From f7f9aae21f342240c80a6e71a81c895289da71fe Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Wed, 15 Apr 2026 13:54:38 -0400 Subject: [PATCH 001/109] wip Signed-off-by: Paarth Neekhara Signed-off-by: Edresson Casanova --- ...nference_multiturn_turn_based_as_magpie.py | 2247 +++++++++++++++++ 1 file changed, 2247 insertions(+) create mode 100644 examples/tts/easy_magpietts_inference_multiturn_turn_based_as_magpie.py diff --git a/examples/tts/easy_magpietts_inference_multiturn_turn_based_as_magpie.py b/examples/tts/easy_magpietts_inference_multiturn_turn_based_as_magpie.py new file mode 100644 index 000000000000..deb2244f2376 --- /dev/null +++ b/examples/tts/easy_magpietts_inference_multiturn_turn_based_as_magpie.py @@ -0,0 +1,2247 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Multi-GPU EasyMagpieTTS / NemotronTTS multiturn inference evaluation. + +Key behavior: + - Uses torchrun env vars RANK, LOCAL_RANK, WORLD_SIZE for sharding/GPU assignment. + - Does NOT initialize torch.distributed. This avoids NeMo ASR doing distributed + collectives during metric computation. + - Generation runs first for all assigned samples. + - ASR and SECS are loaded only after generation is done and the TTS/codec model + has been deleted from GPU memory. + - ASR and SECS are loaded sequentially: ASR first, then released; SECS second. + - For --profile_multiturn_inference, metrics are computed turn-by-turn. + Final filewise outputs are grouped back to one row per original sample, with + lists for asr_hyp/reference_text/cer_turns/wer_turns/secs_turns. + - Uses DistributedSampler with explicit rank/world_size. A few repeated samples + may appear when len(dataset) is not divisible by world_size. Filewise final + metrics deduplicate sampler-padding repeats by (run_id, dataset_index, + turn_id), then group turns into one row per sample with metric lists, while + preserving --num_eval_runs repetitions. + - --sort_by_text_token_count sorts samples by total text-token count before + sharding to improve GPU load balance. + - Saves audio in out_dir/audios/. + - Saves metrics in out_dir/. + +Recommended single-node torchrun: + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + torchrun --standalone --nproc_per_node=8 easy_magpietts_inference_multiturn_multigpu_postgen_metrics.py ... + +Recommended single-node srun wrapper: + srun --nodes=1 --ntasks=1 --ntasks-per-node=1 --container-image=... \ + bash -lc 'torchrun --standalone --nproc_per_node=8 easy_magpietts_inference_multiturn_multigpu_postgen_metrics.py ...' +""" + +import argparse +import csv +import json +import os +import socket +import time +from copy import deepcopy +from functools import partial +from typing import Any, Dict, Iterable, List, Tuple + +import librosa +import soundfile as sf +import torch +from omegaconf import open_dict +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader, Dataset, DistributedSampler, SequentialSampler + +from nemo.collections.audio.parts.utils.transforms import resample +from nemo.collections.asr.metrics.wer import word_error_rate +from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility +from nemo.collections.speechlm2.parts.metrics.secs import SECS +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel +from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter +from nemo.collections.tts.modules.magpietts_modules import CodecHelper +from nemo.collections.tts.parts.utils.tts_dataset_utils import normalize_volume +from nemo.utils import logging +from whisper_normalizer.english import EnglishTextNormalizer + + +torch.set_float32_matmul_precision("medium") +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True + + +# ----------------------------- +# Rank / file helpers +# ----------------------------- + + +def get_rank_info() -> Tuple[bool, int, int, int]: + world_size = int(os.environ.get("WORLD_SIZE", "1")) + rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", "0"))) + local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))) + distributed = world_size > 1 + return distributed, rank, local_rank, world_size + + +def get_visible_device_index(local_rank: int) -> int: + if not torch.cuda.is_available(): + return -1 + ndev = torch.cuda.device_count() + if ndev <= 0: + return -1 + return local_rank % ndev + + +def setup_distributed(): + """ + Do not initialize torch.distributed. + + We only need RANK/LOCAL_RANK/WORLD_SIZE for rank assignment and dataset + sharding. Initializing a process group can cause NeMo ASR to run distributed + collectives during transcribe(), which may hang when ranks have different + audio lengths or workloads. + """ + distributed, rank, local_rank, world_size = get_rank_info() + device_index = get_visible_device_index(local_rank) + + if torch.cuda.is_available() and device_index >= 0: + torch.cuda.set_device(device_index) + + return distributed, rank, local_rank, world_size, device_index + + +def cleanup_distributed(): + return + + +def all_rank_print(rank: int, msg: str): + print(f"[rank={rank}] {msg}", flush=True) + + +def rank0_print(rank: int, msg: str): + if rank == 0: + print(msg, flush=True) + + +def get_audio_out_dir(args) -> str: + return os.path.join(args.out_dir, "audios") + + +def get_generated_turn_audio_dir(args) -> str: + return os.path.join(get_audio_out_dir(args), "metric_turns") + + +def get_context_metric_audio_dir(args) -> str: + return os.path.join(get_audio_out_dir(args), "metric_context") + + +def write_json(path: str, obj: Dict[str, Any]): + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + json.dump(obj, f, indent=2, sort_keys=True, ensure_ascii=False) + os.replace(tmp_path, path) + + +def write_text_atomic(path: str, text: str): + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + f.write(text) + os.replace(tmp_path, path) + + +def write_jsonl(path: str, rows: List[Dict[str, Any]]): + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, sort_keys=True, ensure_ascii=False) + "\n") + os.replace(tmp_path, path) + + +def wait_for_files(paths: List[str], timeout_sec: float = 7200.0, poll_sec: float = 5.0): + start = time.time() + while True: + missing = [p for p in paths if not os.path.exists(p)] + if not missing: + return + if time.time() - start > timeout_sec: + raise TimeoutError("Timed out waiting for files:\n" + "\n".join(missing)) + time.sleep(poll_sec) + + +def wait_for_rank_metric_files(args, world_size: int): + paths = [os.path.join(args.out_dir, f"metrics_rank{r:04d}.json") for r in range(world_size)] + wait_for_files(paths) + + +def wait_for_rank_filewise_metric_files(args, world_size: int): + paths = [os.path.join(args.out_dir, f"filewise_metrics_rank{r:04d}.jsonl") for r in range(world_size)] + wait_for_files(paths) + + +def scalarize_metric_value(v: Any): + if torch.is_tensor(v): + if v.numel() == 1: + return float(v.detach().cpu().item()) + return v.detach().cpu().tolist() + try: + import numpy as np + + if isinstance(v, np.generic): + return float(v.item()) + except Exception: + pass + if isinstance(v, (int, float, str, bool)) or v is None: + return v + return str(v) + + +def metric_dict_to_jsonable(d: Dict[str, Any]) -> Dict[str, Any]: + return {str(k): scalarize_metric_value(v) for k, v in d.items()} + + +def safe_metric_scalar(metric_dict: Dict[str, Any], preferred_keys: List[str]): + for key in preferred_keys: + if key in metric_dict: + value = metric_dict[key] + if torch.is_tensor(value): + return float(value.detach().cpu().item()) + return float(value) + return None + + +def get_first_metric(metrics: Dict[str, Any], names: List[str], default=None): + for name in names: + if name in metrics: + return metrics[name] + return default + + +def format_final_metric_text(final_metrics: Dict[str, Any]) -> str: + intelligibility = final_metrics.get("intelligibility", {}) + secs = final_metrics.get("secs", {}) + + cer = get_first_metric(intelligibility, ["cer", "cer_dataset"]) + wer = get_first_metric(intelligibility, ["wer", "wer_dataset"]) + secs_value = get_first_metric(secs, ["secs", "secs_dataset"]) + + def fmt(x): + if x is None: + return "nan" + try: + return f"{float(x):.10f}" + except Exception: + return str(x) + + return f"Average CER: {fmt(cer)}\nAverage WER: {fmt(wer)}\nSECS: {fmt(secs_value)}\n" + + +def format_filewise_final_metric_text(filewise_summary: Dict[str, Any]) -> str: + def fmt(x): + if x is None: + return "nan" + try: + return f"{float(x):.10f}" + except Exception: + return str(x) + + return ( + f"Average CER: {fmt(filewise_summary.get('cer'))}\n" + f"Average WER: {fmt(filewise_summary.get('wer'))}\n" + f"SECS: {fmt(filewise_summary.get('secs'))}\n" + ) + + +def write_filewise_csv(path: str, rows: List[Dict[str, Any]]): + """Write sample-level filewise metrics. + + Several fields are lists (turn_ids, reference_text, asr_hyp, cer_turns, + etc.), so they are JSON-encoded inside CSV cells. + """ + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + + fieldnames = [ + "run_id", + "dataset_index", + "rank", + "num_turns", + "cer", + "wer", + "secs", + "turn_ids", + "cer_turns", + "wer_turns", + "secs_turns", + "pred_audio_seconds_turns", + "target_audio_path", + "context_audio_path", + "pred_audio_paths", + "reference_text", + "asr_hyp", + ] + + def csv_value(v): + if isinstance(v, (list, dict)): + return json.dumps(v, ensure_ascii=False) + return v + + with open(tmp_path, "w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow({k: csv_value(row.get(k, None)) for k in fieldnames}) + + os.replace(tmp_path, path) + +# ----------------------------- +# Dataset helpers +# ----------------------------- + + +def _combined_audio_name(first_audio_filepath: str, paths: List[str]) -> str: + base_names = [os.path.splitext(os.path.basename(p))[0] for p in paths if p] + ext = os.path.splitext(paths[-1])[1] if paths and paths[-1] else "" + combined_name = "_".join(base_names) + ext + dir_name = os.path.dirname(first_audio_filepath) + return os.path.join(dir_name, combined_name) if dir_name else combined_name + + +class EvalJSONLDataset(Dataset): + def __init__(self, file_path: str, num_turns: int = 1): + self.samples = [] + raw_samples = [] + + with open(file_path, "r", encoding="utf-8") as f: + for line_idx, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + sample = json.loads(line) + sample["__dataset_index__"] = len(raw_samples) + raw_samples.append(sample) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON on line {line_idx}: {e}") + + if num_turns <= 1: + self.samples = raw_samples + return + + single_turn_by_speaker = {} + for sample in raw_samples: + if isinstance(sample["text"], list): + self.samples.append(sample) + else: + speaker = sample.get("speaker", "unknown") + single_turn_by_speaker.setdefault(speaker, []).append(sample) + + synthetic_index = len(raw_samples) + for _, speaker_samples in single_turn_by_speaker.items(): + buffer_texts, buffer_paths = [], [] + first_sample_meta = None + + for sample in speaker_samples: + if not buffer_texts: + first_sample_meta = dict(sample) + + buffer_texts.append(sample["text"]) + buffer_paths.append(sample.get("audio_filepath", "")) + + if len(buffer_texts) == num_turns: + first_sample_meta["text"] = buffer_texts + first_sample_meta["audio_filepath"] = _combined_audio_name( + first_sample_meta.get("audio_filepath", ""), + buffer_paths, + ) + first_sample_meta["__dataset_index__"] = synthetic_index + synthetic_index += 1 + self.samples.append(first_sample_meta) + buffer_texts, buffer_paths, first_sample_meta = [], [], None + + if buffer_texts and first_sample_meta is not None: + first_sample_meta["text"] = buffer_texts + first_sample_meta["audio_filepath"] = _combined_audio_name( + first_sample_meta.get("audio_filepath", ""), + buffer_paths, + ) + first_sample_meta["__dataset_index__"] = synthetic_index + synthetic_index += 1 + self.samples.append(first_sample_meta) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.samples[idx] + + +def _sample_text_segments_for_count(sample: Dict[str, Any], max_eval_turns=None) -> List[str]: + text_data = sample.get("text", "") + if isinstance(text_data, list): + segments = text_data + if max_eval_turns is not None: + segments = segments[: int(max_eval_turns)] + return [str(x) for x in segments] + return [str(text_data)] + + +def estimate_text_token_count(sample: Dict[str, Any], model, max_eval_turns=None) -> int: + main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] + total = 0 + for segment in _sample_text_segments_for_count(sample, max_eval_turns=max_eval_turns): + total += len(model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name)) + 1 + return int(total) + + +class SortedByTextTokenCountDataset(Dataset): + def __init__(self, dataset: Dataset, model, max_eval_turns=None, descending: bool = True): + self.dataset = dataset + scored = [] + for i in range(len(dataset)): + sample = dict(dataset[i]) + token_count = estimate_text_token_count(sample, model=model, max_eval_turns=max_eval_turns) + sample["__text_token_count__"] = int(token_count) + scored.append((token_count, i, sample)) + + scored.sort(key=lambda x: (x[0], -x[1]), reverse=bool(descending)) + self.indices = [i for _, i, _ in scored] + self.token_counts = {i: int(tok) for tok, i, _ in scored} + + def __len__(self): + return len(self.indices) + + def __getitem__(self, local_idx): + original_idx = self.indices[local_idx] + sample = dict(self.dataset[original_idx]) + sample["__text_token_count__"] = self.token_counts[original_idx] + return sample + + +# ----------------------------- +# Audio / collate helpers +# ----------------------------- + + +def _resolve_audio_path(path, root_path): + if path is None: + return None + if root_path is not None and not os.path.isabs(path): + return os.path.join(root_path, path) + return path + + +def _load_audio(path, sample_rate, normalize=True, use_librosa=False): + if path is None or not os.path.exists(path): + return torch.zeros(1, dtype=torch.float32) + + if use_librosa: + wav, sr = librosa.load(path, sr=sample_rate, mono=True) + if normalize: + wav = normalize_volume(wav) + return torch.as_tensor(wav, dtype=torch.float32) + + wav, sr = sf.read(path, dtype="float32") + if wav.ndim > 1: + wav = wav.mean(axis=1) + + if normalize: + wav = normalize_volume(wav) + + wav = torch.as_tensor(wav, dtype=torch.float32).unsqueeze(0) + return resample(wav, sr, sample_rate).squeeze(0) + + +def collate_and_tokenize_custom( + batch, + model, + extra_duration_thrshould=1.3, + sample_rate=22050, + root_path=None, + emulate_duplex_inference=False, + add_interruption_token=False, + pad_factor_text_speech=10, + force_interruption=False, + normalize_audio_volume=True, + use_librosa=False, + profile_multiturn_inference=False, + max_eval_turns=None, +): + main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] + + if max_eval_turns is not None: + max_eval_turns = int(max_eval_turns) + if max_eval_turns <= 0: + raise ValueError("--max_eval_turns must be > 0 when provided.") + + truncated_batch = [] + for s in batch: + s = dict(s) + if isinstance(s["text"], list): + s["text"] = s["text"][:max_eval_turns] + if isinstance(s.get("user_audio_file_path"), list): + s["user_audio_file_path"] = s["user_audio_file_path"][:max_eval_turns] + truncated_batch.append(s) + batch = truncated_batch + + is_profile = profile_multiturn_inference + is_duplex = emulate_duplex_inference and not is_profile + + out_dict = { + "duplex_multiturn": is_duplex, + "regular_multiturn": (not is_duplex) and (not is_profile), + "profile_multiturn": is_profile, + "dataset_indices": [int(s.get("__dataset_index__", -1)) for s in batch], + "text_token_counts": [int(s.get("__text_token_count__", -1)) for s in batch], + } + + tokenized_list = [] + batched_turns = [] + batched_turn_lens = [] + valid_turn_masks = [] + + if is_duplex: + for s in batch: + text_data = s["text"] + + if isinstance(text_data, list): + full_ids = [] + for segment in text_data: + seg_ids = model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name) + [model.eos_id] + pad_ids = [model.pad_id] * (len(seg_ids) * pad_factor_text_speech) + + if force_interruption: + fname = s["audio_filepath"] + no_ext = fname.split(".")[0] + sample_id = int(no_ext.split("_")[-1]) + case = sample_id % 3 + + if case == 0: + if len(seg_ids) >= 2: + seg_ids[-2] = model.interruption_token_id + seg_ids[-1] = model.pad_id + else: + pad_ids[0] = model.interruption_token_id + elif case == 1: + eos_idx = min(6, len(pad_ids) - 1) + pad_ids[eos_idx] = model.interruption_token_id + else: + pad_ids[0] = model.interruption_token_id + + elif add_interruption_token: + eos_idx = int(len(pad_ids) * 0.7) + pad_ids[eos_idx] = model.interruption_token_id + + full_ids.extend(seg_ids) + full_ids.extend(pad_ids) + + tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long)) + else: + tokenized_list.append( + torch.as_tensor( + model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id], + dtype=torch.long, + ) + ) + + prefix = torch.full((25,), model.pad_id, dtype=torch.long) + tokenized_list = [torch.cat([prefix, x]) for x in tokenized_list] + out_dict["input_lengths"] = torch.tensor([len(x) for x in tokenized_list], dtype=torch.long) + out_dict["input_ids"] = pad_sequence(tokenized_list, batch_first=True, padding_value=model.pad_id) + + else: + max_turns = 1 + for s in batch: + if isinstance(s["text"], list): + max_turns = max(max_turns, len(s["text"])) + + for t in range(max_turns): + turn_t_tokens, turn_t_lens, turn_t_valid = [], [], [] + + for s in batch: + text_data = s["text"] + + if isinstance(text_data, list): + if t < len(text_data): + seg_ids = model.tokenizer.encode(text_data[t], tokenizer_name=main_tokenizer_name) + [ + model.eos_id + ] + turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) + turn_t_lens.append(len(seg_ids)) + turn_t_valid.append(True) + else: + turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) + turn_t_lens.append(1) + turn_t_valid.append(False) + else: + if t == 0: + seg_ids = model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [ + model.eos_id + ] + turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) + turn_t_lens.append(len(seg_ids)) + turn_t_valid.append(True) + else: + turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) + turn_t_lens.append(1) + turn_t_valid.append(False) + + batched_turns.append(pad_sequence(turn_t_tokens, batch_first=True, padding_value=model.pad_id)) + batched_turn_lens.append(torch.tensor(turn_t_lens, dtype=torch.long)) + valid_turn_masks.append(torch.tensor(turn_t_valid, dtype=torch.bool)) + + out_dict["batched_turns"] = batched_turns + out_dict["batched_turn_lens"] = batched_turn_lens + out_dict["valid_turn_masks"] = valid_turn_masks + + audio_list, audio_lengths, target_num_frames = [], [], [] + context_audio_paths = [] + max_turns_for_user_audio = len(batched_turns) if not is_duplex else 0 + + if is_profile and max_turns_for_user_audio > 0: + user_audio_by_turn = [[] for _ in range(max_turns_for_user_audio)] + user_audio_lens_by_turn = [[] for _ in range(max_turns_for_user_audio)] + else: + user_audio_by_turn, user_audio_lens_by_turn = [], [] + + for i, s in enumerate(batch): + audio_path = _resolve_audio_path(s.get("context_audio_filepath"), root_path) + context_audio_paths.append(audio_path) + wav = _load_audio(audio_path, sample_rate, normalize=normalize_audio_volume, use_librosa=use_librosa) + audio_list.append(wav) + audio_lengths.append(len(wav)) + + if is_profile and max_turns_for_user_audio > 0: + user_audio_paths = s.get("user_audio_file_path", None) + + for t in range(max_turns_for_user_audio): + has_valid_text_turn = (isinstance(s["text"], list) and t < len(s["text"])) or ( + not isinstance(s["text"], list) and t == 0 + ) + + if ( + isinstance(user_audio_paths, list) + and t < len(user_audio_paths) + and user_audio_paths[t] + and has_valid_text_turn + ): + ua_path = _resolve_audio_path(user_audio_paths[t], root_path) + ua_wav = _load_audio( + ua_path, + sample_rate=sample_rate, + normalize=normalize_audio_volume, + use_librosa=use_librosa, + ) + else: + ua_wav = torch.zeros(int(2 * sample_rate), dtype=torch.float32) + + user_audio_by_turn[t].append(ua_wav) + user_audio_lens_by_turn[t].append(len(ua_wav)) + + tdur_audio_path = _resolve_audio_path(s["audio_filepath"], root_path) + + if tdur_audio_path and os.path.exists(tdur_audio_path): + wav_dur = _load_audio( + tdur_audio_path, + sample_rate, + normalize=normalize_audio_volume, + use_librosa=use_librosa, + ) + tdur = wav_dur.shape[0] // model.input_samples_per_frame + target_num_frames.append(tdur * extra_duration_thrshould) + else: + if is_duplex: + current_text_len = len(tokenized_list[i]) + target_num_frames.append(current_text_len if isinstance(s["text"], list) else current_text_len * 5) + else: + target_num_frames.append(sum([l[i].item() for l in batched_turn_lens]) * 5) + + max_audio_len = max(audio_lengths) + B = len(audio_lengths) + padded_audio = torch.zeros((B, max_audio_len), dtype=torch.float32) + + for i, wav in enumerate(audio_list): + padded_audio[i, : len(wav)] = wav + + if is_profile and max_turns_for_user_audio > 0: + padded_user_audio_turns, padded_user_audio_turns_lens = [], [] + + for t in range(max_turns_for_user_audio): + turn_lens = user_audio_lens_by_turn[t] + max_turn_audio_len = max(turn_lens) + padded_turn_audio = torch.zeros((B, max_turn_audio_len), dtype=torch.float32) + + for i, wav in enumerate(user_audio_by_turn[t]): + padded_turn_audio[i, : len(wav)] = wav + + padded_user_audio_turns.append(padded_turn_audio) + padded_user_audio_turns_lens.append(torch.tensor(turn_lens, dtype=torch.long)) + + out_dict["user_audio_turns"] = padded_user_audio_turns + out_dict["user_audio_turns_lens"] = padded_user_audio_turns_lens + + raw_turn_texts = [] + for s in batch: + if isinstance(s["text"], list): + raw_turn_texts.append([str(x) for x in s["text"]]) + else: + raw_turn_texts.append([str(s["text"])]) + + out_dict["context_audio"] = padded_audio + out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) + out_dict["context_audio_paths"] = context_audio_paths + out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] + out_dict["target_num_frames"] = target_num_frames + out_dict["raw_turn_texts"] = raw_turn_texts + out_dict["raw_text"] = [" ".join(x) for x in raw_turn_texts] + + return out_dict + + +# ----------------------------- +# Model / generation +# ----------------------------- + + +def attach_dtype_counter(model): + handles = [] + stats = {} + examples = {} + + def is_leaf(module): + return len(list(module.children())) == 0 + + def get_dtype(x): + if torch.is_tensor(x): + return str(x.dtype) + if isinstance(x, (list, tuple)): + for t in x: + if torch.is_tensor(t): + return str(t.dtype) + return "other" + + def get_module_group(name): + return name.split(".")[0] if "." in name else name + + def hook_fn(name): + def fn(module, inputs, outputs): + dtype = get_dtype(outputs) + if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: + dtype = "other" + group = get_module_group(name) + if group not in stats: + stats[group] = { + "torch.float16": 0, + "torch.bfloat16": 0, + "torch.float32": 0, + "other": 0, + } + examples[group] = { + "torch.float16": [], + "torch.bfloat16": [], + "torch.float32": [], + "other": [], + } + stats[group][dtype] += 1 + if len(examples[group][dtype]) < 3: + examples[group][dtype].append(module.__class__.__name__) + + return fn + + for name, module in model.named_modules(): + if is_leaf(module): + handles.append(module.register_forward_hook(hook_fn(name))) + return handles, stats, examples + + +def report_dtype_stats(handles, stats, examples, rank=0): + for h in handles: + h.remove() + logging.info(f"[rank={rank}] === DTYPE USAGE PER MODULE ===") + for group, group_stats in stats.items(): + total = sum(group_stats.values()) + if total == 0: + continue + logging.info(f"[rank={rank}] --- {group} ---") + for dtype, count in group_stats.items(): + if count > 0: + logging.info(f"[rank={rank}] {dtype}: {count} ({100 * count / total:.2f}%)") + logging.info(f"[rank={rank}] === DTYPE EXAMPLES ===") + for group, group_examples in examples.items(): + for dtype, mods in group_examples.items(): + if mods: + logging.info(f"[rank={rank}] {group} {dtype}: {mods}") + + +def build_model_and_codec(args, target_device, target_dtype): + model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) + + with open_dict(model_cfg): + model_cfg.target = "nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel" + model_cfg.codecmodel_path = args.codec_model_path + model_cfg.train_ds = None + model_cfg.validation_ds = None + model_cfg.run_val_inference = False + model_cfg.use_utmos = False + model_cfg.use_meta_init_for_decoder = True + + if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: + model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path + + model = EasyMagpieTTSInferenceModel.restore_from( + args.checkpoint_path, + override_config_path=model_cfg, + map_location=torch.device("cpu"), + ) + model.use_kv_cache_for_inference = True + model.to(dtype=target_dtype) + model.eval().to(target_device) + + model.input_samples_per_frame = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) + model.target_samples_per_frame = model.input_samples_per_frame / (model.sample_rate / model.output_sample_rate) + + codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=torch.device("cpu")) + if hasattr(codec_model, "discriminator"): + del codec_model.discriminator + codec_model.freeze() + codec_model = codec_model.to(target_device).eval() + + codec_converter = None + if getattr(model, "_codec_converter", None) is not None: + vq_new = deepcopy(model._codec_converter.vector_quantizer_new).to(target_device).eval() + codec_converter = VectorQuantizerIndexConverter( + vector_quantizer_original=codec_model.vector_quantizer, + vector_quantizer_new=vq_new, + ).to(target_device).eval() + + model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) + model._generate_codec_silence_buffer() + + return model + + +def prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=None): + B = inputs["context_audio"].size(0) + device = model.device + + inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) + inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) + + if args.user_custom_speaker_reference and speaker_wav is not None: + inputs["context_audio"] = speaker_wav.repeat(B, 1).detach() + inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) + + if "user_audio_turns" in inputs: + inputs["user_audio_turns"] = [x.to(device, dtype=target_dtype) for x in inputs["user_audio_turns"]] + inputs["user_audio_turns_lens"] = [x.to(device) for x in inputs["user_audio_turns_lens"]] + + return inputs + + +def run_generation(model, inputs, args, codec_sil_codes): + B = inputs["context_audio"].size(0) + device = model.device + profile_turn_frame_ranges = [] + profile_decode_start_frame = 0 + + with torch.inference_mode(): + wav = inputs["context_audio"] + wav_len = inputs["context_audio_lengths"] + codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) + + use_lang = bool(getattr(model, "add_language_to_context_text", False)) + ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" + ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) + + ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) + ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) + + state = model.streaming_init( + context_audio_codes=codes, + context_audio_codes_lens=codes_lens, + context_text_tokens=ctx_toks, + context_text_tokens_lens=ctx_toks_lens, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + use_local_transformer=True, + temperature=args.temperature, + topk=args.topk, + phoneme_input_type="pred", + phoneme_sampling_method="argmax", + use_inference_mode=True, + ) + + if inputs["duplex_multiturn"]: + text = inputs["input_ids"].to(device) + text_lens = inputs["input_lengths"].to(device) + + in_initial_silence = torch.ones(B, dtype=torch.bool, device=device) + in_post_speech_silence = torch.zeros(B, dtype=torch.bool, device=device) + text_exhausted = state.text_tokens_seen >= text_lens + + while not text_exhausted.all(): + state.finished = state.finished & text_exhausted + state.text_finished = state.text_finished & text_exhausted + + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended = state.phoneme_stream_ended & text_exhausted + + positions = state.text_tokens_seen.clamp(max=text.size(1) - 1) + current_tokens = text[torch.arange(B, device=device), positions] + current_tokens = torch.where( + text_exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) + + is_pad_or_eos = (current_tokens == model.pad_id) | (current_tokens == model.eos_id) + in_initial_silence = in_initial_silence & is_pad_or_eos + in_post_speech_silence = in_post_speech_silence & is_pad_or_eos + + state, audio_codes, _ = model.streaming_step( + state=state, + text_tokens=current_tokens, + use_inference_mode=True, + ) + + if audio_codes is not None and args.force_speech_sil_codes: + force_silence_mask = in_initial_silence | in_post_speech_silence + if force_silence_mask.any(): + expanded_sil = codec_sil_codes.view(1, -1, 1).expand_as(audio_codes) + mask_3d = force_silence_mask.view(B, 1, 1) + state.all_predictions[-1] = torch.where(mask_3d, expanded_sil, audio_codes) + + in_post_speech_silence = in_post_speech_silence | state.finished + text_exhausted = state.text_tokens_seen >= text_lens + + elif inputs["regular_multiturn"]: + batched_turns = inputs["batched_turns"] + batched_turn_lens = inputs["batched_turn_lens"] + valid_turn_masks = inputs["valid_turn_masks"] + turn_offsets = torch.zeros(B, dtype=torch.long, device=device) + + for t in range(len(batched_turns)): + turn_text = batched_turns[t].to(device) + turn_lens = batched_turn_lens[t].to(device) + valid_mask = valid_turn_masks[t].to(device) + + state.finished = state.finished & (~valid_mask) + state.text_finished = state.text_finished & (~valid_mask) + + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended = state.phoneme_stream_ended & (~valid_mask) + + if state.finished.all(): + continue + + turn_offsets = torch.where(valid_mask, state.text_tokens_seen, turn_offsets) + turn_steps = 0 + + while not state.finished.all() and turn_steps < args.max_tts_steps: + turn_steps += 1 + relative_positions = state.text_tokens_seen - turn_offsets + positions = relative_positions.clamp(min=0, max=turn_text.size(1) - 1) + current_tokens = turn_text[torch.arange(B, device=device), positions] + + exhausted = relative_positions >= turn_lens + current_tokens = torch.where( + exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) + + state, _, _ = model.streaming_step( + state=state, + text_tokens=current_tokens, + use_inference_mode=True, + ) + + elif inputs["profile_multiturn"]: + if B != 1: + raise RuntimeError("--profile_multiturn_inference requires --batch_size=1 per process.") + + batched_turns = inputs["batched_turns"] + batched_turn_lens = inputs["batched_turn_lens"] + valid_turn_masks = inputs["valid_turn_masks"] + + for t in range(len(batched_turns)): + turn_text = batched_turns[t].to(device) + turn_lens = batched_turn_lens[t].to(device) + valid_mask = valid_turn_masks[t].to(device) + + if not bool(valid_mask[0].item()): + continue + + state.finished.zero_() + state.text_finished.zero_() + state.audio_prediction_end_idx.fill_(-1) + + if hasattr(state, "turn_text_tokens_seen"): + state.turn_text_tokens_seen.zero_() + if hasattr(state, "phoneme_steps"): + state.phoneme_steps.zero_() + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended.zero_() + if hasattr(state, "phoneme_eos_detected"): + state.phoneme_eos_detected.zero_() + state.last_phoneme_tokens = None + + if not model.cfg.get("condition_on_user_speech", False): + if "user_audio_turns" in inputs: + profile_T = int(round(inputs["user_audio_turns"][t].size(-1) / model.input_samples_per_frame)) + profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate + else: + profile_seconds = args.profile_pad_min_sec + torch.rand((), device=device).item() * ( + args.profile_pad_max_sec - args.profile_pad_min_sec + ) + profile_T = max( + 1, + int(round(profile_seconds * model.sample_rate / model.input_samples_per_frame)), + ) + + profile_tokens = torch.full((1, profile_T), model.pad_id, dtype=torch.long, device=device) + user_audio_channel_embedding = None + + else: + if "user_audio_turns" in inputs: + user_audio = inputs["user_audio_turns"][t] + user_audio_lens = inputs["user_audio_turns_lens"][t] + else: + user_audio = inputs["context_audio"] + user_audio_lens = inputs["context_audio_lengths"] + + user_audio_codes, user_audio_codes_lens = model._codec_helper.audio_to_codes( + user_audio, + user_audio_lens, + ) + + if model._codec_converter is not None: + user_audio_codes = model._codec_converter.convert_original_to_new( + audio_tokens=user_audio_codes, + audio_lens=user_audio_codes_lens, + ).long() + + user_audio_codes, user_audio_codes_lens = model.stack_codes( + user_audio_codes, + user_audio_codes_lens, + model.audio_bos_id, + model.audio_eos_id, + model.frame_stacking_factor, + model.num_audio_codebooks, + ) + + user_audio_embedded = model.embed_audio_tokens(user_audio_codes) + + boundary_trim = model.cfg.get("user_audio_boundary_trim", 4) + boundary_trim = 0 if boundary_trim is None else int(boundary_trim) + + if boundary_trim == 0: + real_start = 0 + real_end = int(user_audio_codes_lens[0].item()) + else: + turn_len_with_special = int(user_audio_codes_lens[0].item()) + real_start = 1 + real_end = max(real_start, turn_len_with_special - 1) + + user_audio_embedded = user_audio_embedded[:, real_start:real_end] + copy_len = user_audio_embedded.size(1) + + if boundary_trim > 0: + trim = min(boundary_trim, copy_len // 2) + if trim > 0: + user_audio_embedded[:, :trim] = 0.0 + user_audio_embedded[:, copy_len - trim :] = 0.0 + + bos_user_pad = torch.zeros( + user_audio_embedded.size(0), + 1, + user_audio_embedded.size(2), + device=user_audio_embedded.device, + dtype=user_audio_embedded.dtype, + ) + user_audio_embedded = torch.cat([bos_user_pad, user_audio_embedded], dim=1) + + profile_T = user_audio_embedded.size(1) + profile_tokens = torch.full((B, profile_T), model.pad_id, dtype=torch.long, device=device) + user_audio_channel_embedding = user_audio_embedded + profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate + + delay_tokens = int(state.config.training_mode.streaming_speech_delay) + delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) + + warmup_tokens = turn_text[:, :delay_tokens] + turn_text = turn_text[:, delay_tokens:] + turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) + + if user_audio_channel_embedding is not None and delay_tokens > 0: + warmup_user_audio = user_audio_channel_embedding[:, -delay_tokens:] + user_audio_channel_embedding = user_audio_channel_embedding[:, :-delay_tokens] + profile_tokens = profile_tokens[:, :-delay_tokens] + else: + warmup_user_audio = None + + if profile_tokens.size(1) > 0: + state = model.streaming_prefill_profile( + state=state, + text_tokens=profile_tokens, + use_inference_mode=True, + user_audio_channel_embedding=user_audio_channel_embedding, + ) + + for i in range(delay_tokens): + user_step_emb = warmup_user_audio[:, i] if warmup_user_audio is not None else None + + state.finished.zero_() + state, _, _ = model.streaming_step( + state=state, + text_tokens=warmup_tokens[:, i], + user_audio_channel_embedding=user_step_emb, + prefill_like_step=not bool(model.cfg.get("agent_mask_include_transition_prefix", False)), + prefill_like_is_last_step=(i == delay_tokens - 1), + use_inference_mode=True, + ) + + logging.info(f"[profile_multiturn] turn={t} prefilled {profile_T} steps ({profile_seconds:.2f}s)") + + turn_start_frame = sum(p.size(-1) for p in state.all_predictions) + if t == 0: + state.audio_prediction_start_idx.fill_(turn_start_frame) + profile_decode_start_frame = turn_start_frame + + turn_offset = state.text_tokens_seen.clone() + turn_steps = 0 + saw_audio = False + turn_ended_with_audio_eos = False + + while turn_steps < args.max_tts_steps: + turn_steps += 1 + state.finished.zero_() + + relative_position = state.text_tokens_seen - turn_offset + text_exhausted = relative_position >= turn_lens + + if turn_text.size(1) == 0: + current_tokens = torch.full((B,), model.eos_id, dtype=torch.long, device=device) + else: + position = relative_position.clamp(min=0, max=turn_text.size(1) - 1) + current_tokens = turn_text[torch.arange(B, device=device), position] + current_tokens = torch.where( + text_exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) + + state, audio_codes, _ = model.streaming_step( + state=state, + text_tokens=current_tokens, + use_inference_mode=True, + ) + + if audio_codes is not None and not saw_audio: + saw_audio = True + + if bool(text_exhausted[0].item()) and bool(state.finished[0].item()): + turn_ended_with_audio_eos = True + break + + state.audio_prediction_end_idx.fill_(-1) + state.finished.zero_() + + turn_end_frame = sum(p.size(-1) for p in state.all_predictions) + profile_turn_frame_ranges.append((t, turn_start_frame, turn_end_frame)) + + logging.info( + f"[profile_multiturn] turn={t} steps={turn_steps} " + f"saw_audio={saw_audio} ended_with_audio_eos={turn_ended_with_audio_eos}" + ) + + bos_id = getattr(model, "audio_bos_id", -1) + eos_id = getattr(model, "audio_eos_id", -1) + speaking_id = getattr(model, "audio_user_speaking_id", -1) + speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) + sil_injection = codec_sil_codes.view(1, -1, 1) + + for step_idx in range(len(state.all_predictions)): + pred = state.all_predictions[step_idx] + mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) + frame_mask = mask.any(dim=1, keepdim=True) + if frame_mask.any(): + state.all_predictions[step_idx] = torch.where(frame_mask, sil_injection.expand_as(pred), pred) + + if inputs["duplex_multiturn"] or inputs["profile_multiturn"]: + state.audio_prediction_end_idx.fill_(-1) + + finalize_output = model.streaming_finalize(state, use_inference_mode=True) + + return finalize_output, profile_turn_frame_ranges, profile_decode_start_frame + + +def load_speaker_wav_if_needed(args, model, target_dtype): + if args.user_custom_speaker_reference and args.inference_speaker_reference: + return _load_audio( + args.inference_speaker_reference, + model.sample_rate, + normalize=args.normalize_volume, + use_librosa=args.use_librosa, + ).unsqueeze(0).to(model.device, dtype=target_dtype) + + return None + + +# ----------------------------- +# Save generation outputs and metric manifests +# ----------------------------- + + +def write_audio_1d(path: str, wav: torch.Tensor, sr: int): + os.makedirs(os.path.dirname(path), exist_ok=True) + wav_np = wav.detach().cpu().float().numpy() + sf.write(path, wav_np, samplerate=sr) + + +def build_metric_item( + run_id: int, + rank: int, + dataset_index: int, + turn_id: int, + target_audio_path: str, + reference_text: str, + pred_audio_path: str, + context_audio_path: str, + pred_audio_samples: int, + context_audio_samples: int, + output_sample_rate: int, + context_sample_rate: int, +): + return { + "run_id": int(run_id), + "rank": int(rank), + "dataset_index": int(dataset_index), + "turn_id": int(turn_id), + "target_audio_path": target_audio_path, + "reference_text": reference_text, + "pred_audio_path": pred_audio_path, + "context_audio_path": context_audio_path, + "pred_audio_samples": int(pred_audio_samples), + "context_audio_samples": int(context_audio_samples), + "pred_audio_seconds": float(pred_audio_samples / output_sample_rate), + "context_audio_seconds": float(context_audio_samples / context_sample_rate), + "output_sample_rate": int(output_sample_rate), + "context_sample_rate": int(context_sample_rate), + } + + +def save_generation_outputs_and_build_metric_items( + model, + inputs, + finalize_output, + profile_turn_frame_ranges, + profile_decode_start_frame, + args, + rank: int, + run_id: int, +): + device = model.device + B = inputs["context_audio"].size(0) + + with fp32_precision(): + audio_f32 = finalize_output.audio.float() + audio_len = finalize_output.audio_len.int() + + expected_audio_lens = ( + torch.tensor(inputs["target_num_frames"], device=device) * model.target_samples_per_frame + ).int() + + if inputs["duplex_multiturn"]: + text_lens = inputs["input_lengths"].to(device) + audio_len = (text_lens * model.target_samples_per_frame).int() + audio_len = torch.min(audio_len, torch.tensor(audio_f32.size(1), device=device)) + elif inputs["profile_multiturn"]: + audio_len = finalize_output.audio_len.int() + else: + audio_len = torch.min(audio_len, expected_audio_lens) + + audio_out_dir = get_audio_out_dir(args) + metric_turn_dir = get_generated_turn_audio_dir(args) + metric_context_dir = get_context_metric_audio_dir(args) + os.makedirs(audio_out_dir, exist_ok=True) + os.makedirs(metric_turn_dir, exist_ok=True) + os.makedirs(metric_context_dir, exist_ok=True) + + audio_f32_cpu = audio_f32.detach().cpu() + audio_len_cpu = audio_len.detach().cpu() + metric_items = [] + + for i in range(B): + target_path = inputs["target_audio_paths"][i] + base_name = os.path.basename(target_path) + stem, ext = os.path.splitext(base_name) + if not ext: + ext = ".wav" + + dataset_idx = int(inputs.get("dataset_indices", [-1] * B)[i]) + safe_stem = ( + f"run{run_id:02d}_idx{dataset_idx:08d}_{stem}" + if dataset_idx >= 0 + else f"run{run_id:02d}_rank{rank}_{stem}" + ) + + context_len = int(inputs["context_audio_lengths"][i].detach().cpu().item()) + context_wav = inputs["context_audio"][i, :context_len].detach().cpu().float() + context_metric_path = os.path.join(metric_context_dir, f"{safe_stem}_context.wav") + write_audio_1d(context_metric_path, context_wav, model.sample_rate) + + if inputs["profile_multiturn"]: + full_len = int(audio_len_cpu[i].item()) + full_wav_t = audio_f32_cpu[i, :full_len].float() + + samples_per_prediction_frame = model.codec_model_samples_per_frame / ( + model.sample_rate / model.output_sample_rate + ) + + aligned_agent = torch.zeros_like(full_wav_t) + raw_turn_texts = inputs.get("raw_turn_texts", [[] for _ in range(B)]) + + for turn_id, start_frame, end_frame in profile_turn_frame_ranges: + rel_start_frame = start_frame - profile_decode_start_frame + rel_end_frame = end_frame - profile_decode_start_frame + + start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) + end_sample = int(round(rel_end_frame * samples_per_prediction_frame)) + + start_sample = max(0, min(start_sample, full_len)) + end_sample = max(start_sample, min(end_sample, full_len)) + + aligned_agent[start_sample:end_sample] = full_wav_t[start_sample:end_sample] + + turn_wav = aligned_agent[start_sample:end_sample].float() + turn_out_path = os.path.join(audio_out_dir, f"{safe_stem}_turn_{turn_id}{ext}") + write_audio_1d(turn_out_path, turn_wav, model.output_sample_rate) + + metric_turn_path = os.path.join(metric_turn_dir, f"{safe_stem}_turn_{turn_id}.wav") + write_audio_1d(metric_turn_path, turn_wav, model.output_sample_rate) + + if turn_id < len(raw_turn_texts[i]): + metric_items.append( + build_metric_item( + run_id=run_id, + rank=rank, + dataset_index=dataset_idx, + turn_id=turn_id, + target_audio_path=target_path, + reference_text=str(raw_turn_texts[i][turn_id]), + pred_audio_path=metric_turn_path, + context_audio_path=context_metric_path, + pred_audio_samples=int(turn_wav.numel()), + context_audio_samples=int(context_wav.numel()), + output_sample_rate=model.output_sample_rate, + context_sample_rate=model.sample_rate, + ) + ) + + full_out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") + write_audio_1d(full_out_path, aligned_agent, model.output_sample_rate) + + if "user_audio_turns" in inputs: + user_segments = [] + + first_user_len_in = int(inputs["user_audio_turns_lens"][0][i].item()) + first_user_delay_out = int(round(first_user_len_in * model.output_sample_rate / model.sample_rate)) + + for turn_id, start_frame, _ in profile_turn_frame_ranges: + if turn_id >= len(inputs["user_audio_turns"]): + continue + + turn_audio = inputs["user_audio_turns"][turn_id][i].detach().cpu().float() + turn_audio_len = int(inputs["user_audio_turns_lens"][turn_id][i].item()) + turn_audio = turn_audio[:turn_audio_len] + + turn_audio_out = resample( + turn_audio.unsqueeze(0), + model.sample_rate, + model.output_sample_rate, + ).squeeze(0) + + if turn_id == 0: + user_start_sample = 0 + else: + prev_turn_end_frame = profile_turn_frame_ranges[turn_id - 1][2] + rel_prev_end_frame = prev_turn_end_frame - profile_decode_start_frame + user_start_sample = first_user_delay_out + int( + round(rel_prev_end_frame * samples_per_prediction_frame) + ) + + user_segments.append((user_start_sample, turn_audio_out.detach().cpu().float())) + + total_user_len = 0 + for s, wav_seg in user_segments: + total_user_len = max(total_user_len, s + wav_seg.numel()) + + user_ch = torch.zeros(total_user_len) + for s, wav_seg in user_segments: + e = s + wav_seg.numel() + user_ch[s:e] += wav_seg + + agent_ch = torch.cat([torch.zeros(first_user_delay_out, dtype=aligned_agent.dtype), aligned_agent]) + + stereo_len = max(user_ch.numel(), agent_ch.numel()) + user_pad = torch.zeros(stereo_len) + agent_pad = torch.zeros(stereo_len) + + user_pad[: user_ch.numel()] = user_ch + agent_pad[: agent_ch.numel()] = agent_ch + + stereo = torch.stack([user_pad, agent_pad], dim=1).numpy() + aligned_path = os.path.join(audio_out_dir, f"{safe_stem}_user_agent_aligned{ext}") + sf.write(aligned_path, stereo, samplerate=model.output_sample_rate) + + else: + full_len = int(audio_len_cpu[i].item()) + wav = audio_f32_cpu[i, :full_len].float() + out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") + write_audio_1d(out_path, wav, model.output_sample_rate) + + metric_turn_path = os.path.join(metric_turn_dir, f"{safe_stem}_turn_0.wav") + write_audio_1d(metric_turn_path, wav, model.output_sample_rate) + + metric_items.append( + build_metric_item( + run_id=run_id, + rank=rank, + dataset_index=dataset_idx, + turn_id=0, + target_audio_path=target_path, + reference_text=str(inputs["raw_text"][i]), + pred_audio_path=metric_turn_path, + context_audio_path=context_metric_path, + pred_audio_samples=int(wav.numel()), + context_audio_samples=int(context_wav.numel()), + output_sample_rate=model.output_sample_rate, + context_sample_rate=model.sample_rate, + ) + ) + + return metric_items + + +# ----------------------------- +# Metrics after generation +# ----------------------------- + + +def torch_rms_norm(wav: torch.Tensor, db_level: float = -27.0) -> torch.Tensor: + denom = torch.sum(wav**2) + if denom <= 0: + return wav + r = 10 ** (db_level / 20) + a = torch.sqrt((wav.size(-1) * (r**2)) / denom) + return wav * a + + +def _load_audio_for_metric(path: str, sample_rate: int): + wav = _load_audio(path, sample_rate=sample_rate, normalize=False, use_librosa=False) + if wav.numel() == 0: + wav = torch.zeros(1, dtype=torch.float32) + return wav.float() + + +def _pad_audio_1d_list(wavs: List[torch.Tensor], device, dtype=torch.float32): + if len(wavs) == 0: + return torch.zeros((0, 1), device=device, dtype=dtype), torch.zeros((0,), device=device, dtype=torch.long) + + lens = torch.tensor([max(1, int(w.numel())) for w in wavs], device=device, dtype=torch.long) + max_len = int(lens.max().item()) + out = torch.zeros((len(wavs), max_len), device=device, dtype=dtype) + + for i, w in enumerate(wavs): + w = w.to(device=device, dtype=dtype).flatten() + if w.numel() == 0: + continue + out[i, : w.numel()] = w + + return out, lens + + +def chunk_list(xs: List[Any], chunk_size: int) -> Iterable[List[Any]]: + chunk_size = max(1, int(chunk_size)) + for start in range(0, len(xs), chunk_size): + yield xs[start : start + chunk_size] + + +def _metric_device(): + return "cuda" if torch.cuda.is_available() else "cpu" + + +def _load_metric_batch_audio(batch_items: List[Dict[str, Any]], args): + pred_wavs = [] + context_wavs = [] + + for item in batch_items: + pred = _load_audio_for_metric(item["pred_audio_path"], sample_rate=int(item["output_sample_rate"])) + context = _load_audio_for_metric(item["context_audio_path"], sample_rate=int(item["context_sample_rate"])) + + if args.max_metric_audio_sec is not None: + max_pred_len = int(float(args.max_metric_audio_sec) * int(item["output_sample_rate"])) + pred = pred[: max(1, max_pred_len)] + + pred_wavs.append(pred) + context_wavs.append(context) + + device = _metric_device() + pred_audio, pred_lens = _pad_audio_1d_list(pred_wavs, device=device) + context_audio, context_lens = _pad_audio_1d_list(context_wavs, device=device) + output_sample_rate = int(batch_items[0]["output_sample_rate"]) + context_sample_rate = int(batch_items[0]["context_sample_rate"]) + + return pred_audio, pred_lens, context_audio, context_lens, output_sample_rate, context_sample_rate + + +def compute_metrics_after_generation(args, rank: int, world_size: int, metric_items: List[Dict[str, Any]]): + """ + Load metric models only after generation is complete. + + Order: + 1. Load ASR, compute turn-level CER/WER and ASR hyps, then free ASR. + 2. Load SECS speaker encoder and compute turn-level SECS. + 3. Save rank-level aggregate metrics from the same turn-level rows. + + SECS is always computed turn-by-turn, like CER/WER. The grouped filewise + output stores secs_turns and sample-level secs, and metrics_final.* receives + the turn-level aggregate SECS. + """ + metric_start = time.time() + + if len(metric_items) == 0: + return { + "rank": int(rank), + "world_size": int(world_size), + "num_processed": 0, + "num_metric_items": 0, + "metric_elapsed_sec": 0.0, + "intelligibility": {}, + "secs": {}, + }, [] + + normalizer = EnglishTextNormalizer() + normalizer.ignore_patterns = r"$^" + filewise_rows = [] + + # ASR pass. + all_rank_print(rank, f"loading ASR after generation: {args.asr_model_name}") + with fp32_precision(): + intelligibility = Intelligibility(args.asr_model_name, reuse_asr_hyps=False).reset() + + for batch_items in chunk_list(metric_items, args.metric_batch_size): + refs = [x["reference_text"] for x in batch_items] + pred_audio, pred_lens, _, _, output_sr, _ = _load_metric_batch_audio(batch_items, args) + + with fp32_precision(): + pred_16k = resample(pred_audio, output_sr, 16000) + pred_16k_lens = (pred_lens / output_sr * 16000).to(torch.long) + pred_16k = torch_rms_norm(pred_16k) + + asr_hyps = intelligibility.update( + name="dataset", + refs=refs, + pred_audio=pred_16k, + pred_audio_lens=pred_16k_lens, + asr_hyps=None, + ) + + for item, hyp in zip(batch_items, asr_hyps): + ref_norm = normalizer(str(item["reference_text"])).strip() + hyp_norm = normalizer(str(hyp)).strip() + if ref_norm == "": + cer = None + wer = None + else: + cer = float(word_error_rate([hyp_norm], [ref_norm], use_cer=True)) + wer = float(word_error_rate([hyp_norm], [ref_norm], use_cer=False)) + + row = dict(item) + row["asr_hyp"] = hyp + row["cer"] = cer + row["wer"] = wer + row["secs"] = None + filewise_rows.append(row) + + with fp32_precision(): + cer_wer = metric_dict_to_jsonable(intelligibility.compute()) + del intelligibility + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # SECS pass. This is intentionally turn-level, matching CER/WER. + # We keep one aggregate SECS metric for metrics_final.* and also compute + # one SECS value per filewise turn row so grouped outputs have secs_turns. + all_rank_print(rank, f"loading speaker encoder after ASR is released: {args.secs_model_name}") + with fp32_precision(): + secs_metric = SECS(args.secs_model_name).reset() + + # Aggregate turn-level SECS for metrics_final.json / metrics_final.txt. + for batch_items in chunk_list(metric_items, args.metric_batch_size): + pred_audio, pred_lens, context_audio, context_lens, output_sr, context_sr = _load_metric_batch_audio( + batch_items, args + ) + + with fp32_precision(): + pred_16k = resample(pred_audio, output_sr, 16000) + pred_16k_lens = (pred_lens / output_sr * 16000).to(torch.long) + context_16k = resample(context_audio, context_sr, 16000) + context_16k_lens = (context_lens / context_sr * 16000).to(torch.long) + + pred_16k = torch_rms_norm(pred_16k) + context_16k = torch_rms_norm(context_16k) + + secs_metric.update( + name="dataset", + target_audio=context_16k, + target_audio_lens=context_16k_lens, + pred_audio=pred_16k, + pred_audio_lens=pred_16k_lens, + ) + + with fp32_precision(): + secs_scores = metric_dict_to_jsonable(secs_metric.compute()) + del secs_metric + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Per-turn SECS for filewise/grouped outputs. This is always computed so + # secs_turns and sample-level secs are never null in final filewise metrics. + # It is slower than aggregate-only SECS, but it matches the turn-level + # semantics requested for CER/WER/SECS. + all_rank_print(rank, "computing per-turn SECS rows") + for row in filewise_rows: + pred_audio, pred_lens, context_audio, context_lens, output_sr, context_sr = _load_metric_batch_audio([row], args) + + with fp32_precision(): + one_secs = SECS(args.secs_model_name).reset() + pred_16k = resample(pred_audio, output_sr, 16000) + pred_16k_lens = (pred_lens / output_sr * 16000).to(torch.long) + context_16k = resample(context_audio, context_sr, 16000) + context_16k_lens = (context_lens / context_sr * 16000).to(torch.long) + + pred_16k = torch_rms_norm(pred_16k) + context_16k = torch_rms_norm(context_16k) + + one_secs.update( + name="dataset", + target_audio=context_16k, + target_audio_lens=context_16k_lens, + pred_audio=pred_16k, + pred_audio_lens=pred_16k_lens, + ) + one_secs_metrics = metric_dict_to_jsonable(one_secs.compute()) + + row["secs"] = safe_metric_scalar(one_secs_metrics, ["secs", "secs_dataset"]) + row["secs_metrics"] = one_secs_metrics + del one_secs + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + metric_elapsed = time.time() - metric_start + + rank_metrics = { + "rank": int(rank), + "world_size": int(world_size), + "num_processed": len({(x["run_id"], x["dataset_index"]) for x in metric_items}), + "num_metric_items": int(len(metric_items)), + "metric_elapsed_sec": float(metric_elapsed), + "intelligibility": cer_wer, + "secs": secs_scores, + } + + return rank_metrics, filewise_rows + + +# ----------------------------- +# Merge helpers +# ----------------------------- + + +def compute_and_save_rank_metrics_file(args, rank_metrics: Dict[str, Any], rank: int): + rank_path = os.path.join(args.out_dir, f"metrics_rank{rank:04d}.json") + write_json(rank_path, rank_metrics) + return rank_metrics + + +def merge_metrics_on_rank0(args, rank, world_size): + if rank != 0: + return None + + rank_metric_files = [os.path.join(args.out_dir, f"metrics_rank{r:04d}.json") for r in range(world_size)] + + rank_metrics = [] + for path in rank_metric_files: + if not os.path.exists(path): + logging.warning(f"Missing rank metric file: {path}") + continue + with open(path, "r", encoding="utf-8") as f: + rank_metrics.append(json.load(f)) + + total_n = sum(int(m.get("num_metric_items", m.get("num_processed", 0))) for m in rank_metrics) + + def weighted_average(section: str): + keys = set() + for m in rank_metrics: + keys.update(m.get(section, {}).keys()) + + out = {} + for k in sorted(keys): + numerator = 0.0 + denominator = 0 + + for m in rank_metrics: + n = int(m.get("num_metric_items", m.get("num_processed", 0))) + if n <= 0: + continue + + value = m.get(section, {}).get(k, None) + if value is None or isinstance(value, str): + continue + + try: + value = float(value) + except Exception: + continue + + numerator += value * n + denominator += n + + if denominator > 0: + out[k] = numerator / denominator + + return out + + final_metrics = { + "world_size": int(world_size), + "num_metric_items": int(total_n), + "aggregation": "sum(rank_metric * rank_num_metric_items) / total_num_metric_items", + "intelligibility": weighted_average("intelligibility"), + "secs": weighted_average("secs"), + "ranks": rank_metrics, + } + + final_json_path = os.path.join(args.out_dir, "metrics_final.json") + final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") + + write_json(final_json_path, final_metrics) + + final_text = format_final_metric_text(final_metrics) + write_text_atomic(final_txt_path, final_text) + + print("\n--- Final Evaluation Metrics ---", flush=True) + print(final_text, flush=True) + + logging.info(f"Final metrics JSON saved to: {final_json_path}") + logging.info(f"Final metrics TXT saved to: {final_txt_path}") + logging.info(json.dumps(final_metrics, indent=2, sort_keys=True)) + + return final_metrics + + +def merge_filewise_metrics_on_rank0(args, rank: int, world_size: int): + """Merge per-turn rank metric rows into one row per original sample. + + Rank files still contain one row per turn because metrics are computed + turn-by-turn. The final filewise outputs group those turn rows by + (run_id, dataset_index), producing one JSONL/CSV row per original sample + with list fields: + reference_text, asr_hyp, cer_turns, wer_turns, secs_turns. + + DistributedSampler padding repeats are deduplicated by + (run_id, dataset_index, turn_id), but repetitions from --num_eval_runs are + preserved because run_id is part of the key. + """ + if rank != 0 or not args.save_filewise_metrics: + return [] + + turn_rows = [] + + for r in range(world_size): + path = os.path.join(args.out_dir, f"filewise_metrics_rank{r:04d}.jsonl") + if not os.path.exists(path): + logging.warning(f"Missing filewise metrics file: {path}") + continue + + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + turn_rows.append(json.loads(line)) + + # Deduplicate DistributedSampler padding repeats, but preserve --num_eval_runs. + deduped_turns = {} + for row in turn_rows: + run_id = int(row.get("run_id", 0)) + idx = int(row.get("dataset_index", -1)) + turn_id = int(row.get("turn_id", 0)) + key = (run_id, idx, turn_id) + if key not in deduped_turns: + deduped_turns[key] = row + + turn_rows = list(deduped_turns.values()) + + # Group turn rows into one row per original file/sample. + grouped = {} + for row in turn_rows: + run_id = int(row.get("run_id", 0)) + idx = int(row.get("dataset_index", -1)) + key = (run_id, idx) + + if key not in grouped: + grouped[key] = { + "run_id": run_id, + "dataset_index": idx, + "rank": int(row.get("rank", -1)), + "target_audio_path": row.get("target_audio_path", ""), + "context_audio_path": row.get("context_audio_path", ""), + "turn_rows": [], + } + + grouped[key]["turn_rows"].append(row) + + def avg(vals): + vals = [float(x) for x in vals if x is not None and math.isfinite(float(x))] + return None if not vals else sum(vals) / len(vals) + + sample_rows = [] + for _, group in grouped.items(): + turns = sorted(group["turn_rows"], key=lambda x: int(x.get("turn_id", 0))) + + cer_turns = [r.get("cer") for r in turns] + wer_turns = [r.get("wer") for r in turns] + secs_turns = [r.get("secs") for r in turns] + + sample_row = { + "run_id": group["run_id"], + "dataset_index": group["dataset_index"], + "rank": group["rank"], + "num_turns": len(turns), + "turn_ids": [int(r.get("turn_id", 0)) for r in turns], + "target_audio_path": group["target_audio_path"], + "context_audio_path": group["context_audio_path"], + "pred_audio_paths": [r.get("pred_audio_path", "") for r in turns], + "pred_audio_seconds_turns": [r.get("pred_audio_seconds") for r in turns], + "reference_text": [r.get("reference_text", "") for r in turns], + "asr_hyp": [r.get("asr_hyp", "") for r in turns], + "cer_turns": cer_turns, + "wer_turns": wer_turns, + "secs_turns": secs_turns, + "cer": avg(cer_turns), + "wer": avg(wer_turns), + "secs": avg(secs_turns), + } + + sample_rows.append(sample_row) + + # Sort samples by average CER descending for failure analysis. + sample_rows.sort( + key=lambda x: ( + x.get("cer") is not None, + float(x.get("cer")) if x.get("cer") is not None else -1.0, + ), + reverse=True, + ) + + jsonl_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.jsonl") + csv_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.csv") + + write_jsonl(jsonl_path, sample_rows) + write_filewise_csv(csv_path, sample_rows) + + logging.info(f"Saved sample-level filewise metrics JSONL to: {jsonl_path}") + logging.info(f"Saved sample-level filewise metrics CSV to: {csv_path}") + + topk = min(int(args.filewise_metrics_topk_log), len(sample_rows)) + if topk > 0: + logging.info(f"Top {topk} worst CER samples:") + for row in sample_rows[:topk]: + logging.info( + "run_id=%s dataset_index=%s num_turns=%s cer=%s wer=%s secs=%s path=%s" + % ( + row.get("run_id"), + row.get("dataset_index"), + row.get("num_turns"), + row.get("cer"), + row.get("wer"), + row.get("secs"), + row.get("target_audio_path"), + ) + ) + + return sample_rows + +def compute_aggregates_from_filewise_rows(rows: List[Dict[str, Any]]): + """Aggregate over sample-level rows. + + Each row may internally contain multiple turn metrics in cer_turns/wer_turns, + but the final filewise average is over original samples/files. + """ + if len(rows) == 0: + return { + "cer": None, + "wer": None, + "secs": None, + "num_samples": 0, + } + + def avg_key(key): + vals = [float(r[key]) for r in rows if r.get(key) is not None] + if len(vals) == 0: + return None + return sum(vals) / len(vals) + + return { + "cer": avg_key("cer"), + "wer": avg_key("wer"), + "secs": avg_key("secs"), + "num_samples": len(rows), + } + +def save_filewise_final_summary(args, filewise_rows: List[Dict[str, Any]]): + filewise_summary = compute_aggregates_from_filewise_rows(filewise_rows) + + obj = { + "aggregation": "mean_over_sample_metrics_each_sample_contains_turn_metric_lists", + **filewise_summary, + } + + path = os.path.join(args.out_dir, "metrics_final_filewise_average.json") + write_json(path, obj) + + sample_metrics_final_path = os.path.join(args.out_dir, "metrics_final_sample_average.json") + write_json(sample_metrics_final_path, obj) + + final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") + final_text = format_filewise_final_metric_text(filewise_summary) + write_text_atomic(final_txt_path, final_text) + + print("\n--- Final Sample-Averaged Evaluation Metrics ---", flush=True) + print(final_text, flush=True) + + logging.info(f"Filewise averaged final metrics saved to: {path}") + logging.info(f"Sample averaged metrics_final JSON saved to: {sample_metrics_final_path}") + logging.info(f"Final metrics TXT saved to: {final_txt_path}") + + return obj + + +# ----------------------------- +# Args / main +# ----------------------------- + + +def parse_args(): + parser = argparse.ArgumentParser(description="EasyMagpieTTS Multi-GPU Inference Evaluation") + + parser.add_argument("--checkpoint_path", type=str, required=True) + parser.add_argument("--codec_model_path", type=str, required=True) + parser.add_argument("--datasets_json_path", type=str, required=True) + parser.add_argument("--out_dir", type=str, required=True) + + parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) + parser.add_argument("--audio_dir", type=str, default=None) + parser.add_argument("--inference_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) + parser.add_argument("--debug_dtype", action="store_true") + parser.add_argument("--debug_gpu_assignment", action="store_true") + parser.add_argument("--use_librosa", action="store_true") + + parser.add_argument("--batch_size", type=int, default=6) + parser.add_argument("--num_workers", type=int, default=4) + parser.add_argument("--num_turns", type=int, default=1) + parser.add_argument("--pad_factor_text_speech", type=int, default=10) + + parser.add_argument("--emulate_duplex_inference", action="store_true") + parser.add_argument("--add_interruption_token", action="store_true") + parser.add_argument("--force_interruption", action="store_true") + parser.add_argument("--profile_multiturn_inference", action="store_true") + parser.add_argument("--profile_pad_min_sec", type=float, default=2.0) + parser.add_argument("--profile_pad_max_sec", type=float, default=2.0) + parser.add_argument("--max_eval_turns", type=int, default=6) + + parser.add_argument("--user_custom_speaker_reference", action="store_true") + parser.add_argument("--inference_speaker_reference", type=str, default=None) + parser.add_argument("--language", type=str, default="en") + + parser.add_argument("--use_cfg", action="store_true") + parser.add_argument("--cfg_scale", type=float, default=2.5) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--topk", type=int, default=80) + parser.add_argument("--max_tts_steps", type=int, default=2000) + parser.add_argument("--force_speech_sil_codes", action="store_true") + parser.add_argument("--normalize_volume", type=lambda x: str(x).lower() in ["true", "1", "yes"], default=True) + + parser.add_argument( + "--save_filewise_metrics", + action="store_true", + help="Save per-turn/file CER/WER metrics sorted by CER descending.", + ) + parser.add_argument( + "--compute_filewise_secs", + action="store_true", + help="Also compute per-turn/file SECS. Slower because it runs SECS per row.", + ) + parser.add_argument( + "--filewise_metrics_topk_log", + type=int, + default=20, + help="Number of worst CER samples to print on rank 0.", + ) + parser.add_argument( + "--num_eval_runs", + type=int, + default=1, + help="Repeat the full eval set N times. Repetitions are preserved in final filewise average.", + ) + parser.add_argument( + "--sort_by_text_token_count", + action="store_true", + help="Sort eval samples by total text token count before distributed sharding for better load balancing.", + ) + parser.add_argument( + "--metric_batch_size", + type=int, + default=8, + help="Batch size used for post-generation ASR/SECS metric computation.", + ) + parser.add_argument( + "--max_metric_audio_sec", + type=float, + default=120.0, + help="Clamp generated audio length used for ASR/SECS metrics to avoid metric OOM/hangs.", + ) + parser.add_argument( + "--asr_model_name", + type=str, + default="stt_en_fastconformer_transducer_large", + help="Pretrained NeMo ASR model used for CER/WER.", + ) + parser.add_argument( + "--secs_model_name", + type=str, + default="titanet_large", + help="Pretrained speaker encoder model used for SECS.", + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + os.makedirs(args.out_dir, exist_ok=True) + os.makedirs(get_audio_out_dir(args), exist_ok=True) + os.makedirs(get_generated_turn_audio_dir(args), exist_ok=True) + os.makedirs(get_context_metric_audio_dir(args), exist_ok=True) + + distributed, rank, local_rank, world_size, device_index = setup_distributed() + + if args.profile_multiturn_inference and args.batch_size != 1: + raise RuntimeError( + "--profile_multiturn_inference requires --batch_size=1 per process. " + "Use multiple GPUs/processes for parallelism instead of increasing batch_size." + ) + + if args.profile_pad_max_sec < args.profile_pad_min_sec: + raise RuntimeError("--profile_pad_max_sec must be >= --profile_pad_min_sec.") + + if args.num_eval_runs <= 0: + raise RuntimeError("--num_eval_runs must be >= 1.") + + target_device = torch.device(f"cuda:{device_index}" if torch.cuda.is_available() and device_index >= 0 else "cpu") + target_dtype = getattr(torch, args.inference_dtype) + torch.set_default_dtype(target_dtype) + + hostname = socket.gethostname() + cuda_name = torch.cuda.get_device_name(target_device) if torch.cuda.is_available() and device_index >= 0 else "cpu" + + all_rank_print( + rank, + f"host={hostname} local_rank={local_rank} world_size={world_size} " + f"device={target_device} device_name={cuda_name}", + ) + + model = build_model_and_codec(args, target_device, target_dtype) + codec_sil_codes = model.codec_sil_codes + + if args.debug_dtype: + handles, stats, examples = attach_dtype_counter(model) + else: + handles = stats = examples = None + + full_eval_dataset = EvalJSONLDataset(args.datasets_json_path, num_turns=args.num_turns) + # debug + # full_eval_dataset.samples = full_eval_dataset.samples[:7] + + if args.sort_by_text_token_count: + full_eval_dataset = SortedByTextTokenCountDataset( + full_eval_dataset, + model=model, + max_eval_turns=args.max_eval_turns, + descending=True, + ) + + collate_fn = partial( + collate_and_tokenize_custom, + model=model, + extra_duration_thrshould=1.5, + sample_rate=model.sample_rate, + root_path=args.audio_dir, + emulate_duplex_inference=args.emulate_duplex_inference, + add_interruption_token=args.add_interruption_token, + pad_factor_text_speech=args.pad_factor_text_speech, + force_interruption=args.force_interruption, + normalize_audio_volume=args.normalize_volume, + use_librosa=args.use_librosa, + profile_multiturn_inference=args.profile_multiturn_inference, + max_eval_turns=args.max_eval_turns, + ) + + speaker_wav = load_speaker_wav_if_needed(args, model, target_dtype) + + generation_start = time.time() + all_metric_items = [] + total_batches = 0 + total_generated_samples = 0 + + for run_id in range(args.num_eval_runs): + if distributed: + sampler = DistributedSampler( + full_eval_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + sampler.set_epoch(run_id) + else: + sampler = SequentialSampler(full_eval_dataset) + + if args.debug_gpu_assignment: + try: + assigned_indices = list(iter(sampler)) + assigned_dataset_indices = [ + int(full_eval_dataset[i].get("__dataset_index__", -1)) for i in assigned_indices + ] + all_rank_print( + rank, + f"run_id={run_id} assigned {len(assigned_dataset_indices)} / {len(full_eval_dataset)} " + f"samples to gpu={local_rank}: dataset_indices={assigned_dataset_indices}", + ) + except Exception as e: + all_rank_print(rank, f"Could not print assigned indices: {repr(e)}") + + dataloader = DataLoader( + dataset=full_eval_dataset, + batch_size=args.batch_size, + sampler=sampler, + collate_fn=collate_fn, + num_workers=args.num_workers, + pin_memory=True, + drop_last=False, + ) + + for batch_id, inputs in enumerate(dataloader): + total_batches += 1 + batch_indices = inputs.get("dataset_indices", []) + total_generated_samples += len(batch_indices) + + if args.debug_gpu_assignment: + all_rank_print( + rank, + f"run_id={run_id} gpu={local_rank} batch_id={batch_id} " + f"dataset_indices={batch_indices} text_token_counts={inputs.get('text_token_counts', [])} " + f"target_paths={inputs.get('target_audio_paths', [])}", + ) + + inputs = prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=speaker_wav) + + finalize_output, profile_turn_frame_ranges, profile_decode_start_frame = run_generation( + model=model, + inputs=inputs, + args=args, + codec_sil_codes=codec_sil_codes, + ) + + metric_items = save_generation_outputs_and_build_metric_items( + model=model, + inputs=inputs, + finalize_output=finalize_output, + profile_turn_frame_ranges=profile_turn_frame_ranges, + profile_decode_start_frame=profile_decode_start_frame, + args=args, + rank=rank, + run_id=run_id, + ) + all_metric_items.extend(metric_items) + + if args.debug_dtype and batch_id == 0 and run_id == 0: + report_dtype_stats(handles, stats, examples, rank=rank) + + generation_elapsed = time.time() - generation_start + + # Save pre-metric manifest for debugging and restartability. + metric_manifest_path = os.path.join(args.out_dir, f"metric_items_rank{rank:04d}.jsonl") + write_jsonl(metric_manifest_path, all_metric_items) + + # Free TTS/codec model memory before loading ASR and speaker encoder metrics. + del model + if speaker_wav is not None: + del speaker_wav + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + all_rank_print( + rank, + f"generation done: batches={total_batches} generated_samples_with_sampler_padding={total_generated_samples} " + f"metric_items={len(all_metric_items)} elapsed_sec={generation_elapsed:.2f}. " + "Loading ASR/SECS metrics now.", + ) + + rank_metrics, rank_filewise_rows = compute_metrics_after_generation( + args=args, + rank=rank, + world_size=world_size, + metric_items=all_metric_items, + ) + rank_metrics["generation_elapsed_sec"] = float(generation_elapsed) + rank_metrics["num_generated_samples_with_sampler_padding"] = int(total_generated_samples) + + rank_metrics = compute_and_save_rank_metrics_file(args, rank_metrics, rank) + all_rank_print(rank, f"saved rank metrics: {json.dumps(rank_metrics, sort_keys=True)}") + + if args.save_filewise_metrics: + rank_filewise_rows.sort( + key=lambda x: ( + x.get("cer") is not None, + float(x.get("cer")) if x.get("cer") is not None else -1.0, + ), + reverse=True, + ) + + rank_filewise_path = os.path.join(args.out_dir, f"filewise_metrics_rank{rank:04d}.jsonl") + write_jsonl(rank_filewise_path, rank_filewise_rows) + all_rank_print(rank, f"saved filewise metrics: {rank_filewise_path}") + + if rank == 0: + wait_for_rank_metric_files(args, world_size) + + merge_metrics_on_rank0(args, rank, world_size) + + if args.save_filewise_metrics: + if rank == 0: + wait_for_rank_filewise_metric_files(args, world_size) + + filewise_rows = merge_filewise_metrics_on_rank0(args, rank, world_size) + + if rank == 0: + save_filewise_final_summary(args, filewise_rows) + + cleanup_distributed() + + +if __name__ == "__main__": + main() From 6cf3e02dbcf926a95ac89bf88d6e6f0b7081a182 Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Thu, 16 Apr 2026 17:15:06 -0400 Subject: [PATCH 002/109] WIP Signed-off-by: Paarth Neekhara Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts_inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index c0ef45d9a7a9..58aaac61f492 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -13,6 +13,7 @@ # limitations under the License. import random import time +import random from dataclasses import dataclass, fields from functools import partial from typing import Any, Dict, List, Optional, Sequence, Tuple From ed01ceb2090e63b997a88fb65233d4d11615f41c Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Sun, 19 Apr 2026 16:24:39 -0700 Subject: [PATCH 003/109] speaker encoder optional Signed-off-by: Paarth Neekhara Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts_inference.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 58aaac61f492..40b18288764c 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -953,7 +953,6 @@ def prepare_context_tensors( ) context_audio_embedded = self.embed_audio_tokens(context_audio_codes) # (B, T', E) batch_size = context_audio_embedded.size(0) - if self.use_speaker_encoder: if ( self.training From 74cd9cdaa349401e16eb5639f444bea3160121cd Mon Sep 17 00:00:00 2001 From: paarthneekhara Date: Wed, 22 Apr 2026 23:37:02 +0000 Subject: [PATCH 004/109] Apply isort and black reformatting Signed-off-by: paarthneekhara Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts_inference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 40b18288764c..039a2e183101 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -11,9 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +<<<<<<< HEAD import random import time +======= +>>>>>>> 3ea2b31298 (Apply isort and black reformatting) import random +import time from dataclasses import dataclass, fields from functools import partial from typing import Any, Dict, List, Optional, Sequence, Tuple From 382fb952a2dc0a9d3757a958662acab244416883 Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Sat, 2 May 2026 15:11:41 -0700 Subject: [PATCH 005/109] add option to remove text embedding and lm head Signed-off-by: Paarth Neekhara Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts_inference.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 039a2e183101..fa6d949ef871 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -11,11 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -<<<<<<< HEAD import random import time -======= ->>>>>>> 3ea2b31298 (Apply isort and black reformatting) import random import time from dataclasses import dataclass, fields From 3899c7e0b00c6c4d7ba2bb8a382601fe1c01598d Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Sat, 2 May 2026 15:35:10 -0700 Subject: [PATCH 006/109] cas encoder layers Signed-off-by: Paarth Neekhara Signed-off-by: Edresson Casanova --- examples/tts/conf/magpietts/easy_magpietts.yaml | 1 + .../tts/conf/magpietts/easy_magpietts_lhotse.yaml | 1 + .../tts/models/easy_magpietts_inference.py | 1 + nemo/collections/tts/modules/magpietts_modules.py | 12 ++++++++++-- 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/examples/tts/conf/magpietts/easy_magpietts.yaml b/examples/tts/conf/magpietts/easy_magpietts.yaml index 2d0c274eb7e0..f1a58d0ae71c 100644 --- a/examples/tts/conf/magpietts/easy_magpietts.yaml +++ b/examples/tts/conf/magpietts/easy_magpietts.yaml @@ -17,6 +17,7 @@ model: disable_lm_text_head: false disable_subword_embedding: false use_bpe_char_tokenizer: true + cas_encoder_n_layers: 1 # HuggingFace backend config (used when decoder_type: "huggingface") transformer_hf_backend: "Qwen/Qwen2.5-1.5B" diff --git a/examples/tts/conf/magpietts/easy_magpietts_lhotse.yaml b/examples/tts/conf/magpietts/easy_magpietts_lhotse.yaml index d6b02adc4b97..01d3666ba958 100644 --- a/examples/tts/conf/magpietts/easy_magpietts_lhotse.yaml +++ b/examples/tts/conf/magpietts/easy_magpietts_lhotse.yaml @@ -15,6 +15,7 @@ model: disable_lm_text_head: false disable_subword_embedding: false use_bpe_char_tokenizer: true + cas_encoder_n_layers: 1 # HuggingFace backend config (used when decoder_type: "huggingface") transformer_hf_backend: "Qwen/Qwen2.5-1.5B" diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index fa6d949ef871..296d93a09315 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -522,6 +522,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): llm_tokenizer_vocab=subword_vocab, subword_padding_idx=self.tokenizer.pad, special_vocab=special_vocab, + n_layers=cfg.get('cas_encoder_n_layers', 1), ) if self.disable_subword_embedding and not hasattr(self, 'cas_encoder'): diff --git a/nemo/collections/tts/modules/magpietts_modules.py b/nemo/collections/tts/modules/magpietts_modules.py index 483bae678243..33666b02ac86 100644 --- a/nemo/collections/tts/modules/magpietts_modules.py +++ b/nemo/collections/tts/modules/magpietts_modules.py @@ -181,7 +181,14 @@ class CharAwareSubwordEncoder(NeuralModule): The output is a tensor of shape (batch_size, max_subword_length, d_embed). """ - def __init__(self, d_embed: int, llm_tokenizer_vocab: dict, subword_padding_idx: int, special_vocab: dict = None): + def __init__( + self, + d_embed: int, + llm_tokenizer_vocab: dict, + subword_padding_idx: int, + special_vocab: dict = None, + n_layers: int = 1, + ): """ Args: d_embed (int): The dimension of the embedding. @@ -191,6 +198,7 @@ def __init__(self, d_embed: int, llm_tokenizer_vocab: dict, subword_padding_idx: subword_padding_idx (int): The padding index for the subword vocabulary. special_vocab (dict): items of special token dictionary (usually BOS, EOS) eg. special_vocab = {'': 30001, '': 30002} + n_layers (int): Number of transformer layers used in the char-aware encoder. """ super().__init__() self.subword_id_to_char_ids, self.char_vocab = build_vocabs( @@ -198,7 +206,7 @@ def __init__(self, d_embed: int, llm_tokenizer_vocab: dict, subword_padding_idx: ) self.embed_tokens = torch.nn.Embedding(self.vocab_size + 1, d_embed, padding_idx=self.vocab_size) self.encoder = transformer_2501.Transformer( - n_layers=1, + n_layers=n_layers, d_model=d_embed, d_ffn=d_embed * 4, sa_n_heads=8, From 3711b3a75b53e055c10065fa3381f9db14690263 Mon Sep 17 00:00:00 2001 From: Paarth Neekhara Date: Sun, 3 May 2026 11:35:32 -0700 Subject: [PATCH 007/109] use IPA as text prob added during training Signed-off-by: Paarth Neekhara Signed-off-by: Edresson Casanova --- .../tts/data/text_to_speech_dataset_lhotse.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py index c9baaaa96324..2905de072f96 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py @@ -463,6 +463,17 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) text_str = cut.supervisions[0].normalized_text else: text_str = cut.supervisions[0].text + + should_use_ipa_as_text = ( + self.dataset_type == 'train' + and self.ipa_as_text_prob > 0.0 + and random.random() < self.ipa_as_text_prob + and cut.supervisions[0].has_custom("ipa") + and language not in self.ignore_phoneme_languages + ) + if should_use_ipa_as_text: + text_str = cut.supervisions[0].ipa + raw_text_list.append(text_str) if cut.has_custom("tokenizer_names"): # Pick a random tokenizer from the list of tokenizers From a5e00b355a54c5b04043ea29832738d3ff1536fb Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 13 Apr 2026 14:35:31 -0700 Subject: [PATCH 008/109] Add multiturn dataloader Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 9 + ...to_speech_dataset_lhotse_multiturn copy.py | 953 ++++++++++++++++++ ...text_to_speech_dataset_lhotse_multiturn.py | 627 ++++++++++++ nemo/collections/tts/models/easy_magpietts.py | 73 +- 4 files changed, 1640 insertions(+), 22 deletions(-) create mode 100644 nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn copy.py create mode 100644 nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 8e613091b7e1..02db79a0a153 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -24,6 +24,8 @@ from pathlib import Path from typing import KeysView, List, Mapping, Sequence, Tuple, Union +from copy import deepcopy + import numpy as np import omegaconf import soundfile as sf @@ -976,6 +978,7 @@ def convert_cut_fn(cut: Cut) -> Cut: zero_audio = np.zeros((1, int(total_duration * sample_rate)), dtype=np.float32) source_recording = create_recording_from_array(zero_audio, sample_rate, recording_id=f"{cut.id}_source") + # Create source cut, preserving cut.custom safely cut_source = MonoCut( id=f"{cut.id}_source", start=0.0, @@ -983,6 +986,7 @@ def convert_cut_fn(cut: Cut) -> Cut: channel=0, recording=source_recording, supervisions=[], + custom=deepcopy(cut.custom) if cut.custom is not None else None, ) # Save to memory @@ -993,6 +997,11 @@ def convert_cut_fn(cut: Cut) -> Cut: user_sup = fastcopy(orig_agent_sup, start=0.0, duration=0.08, speaker="user", text="dummy text") agent_sup = fastcopy(orig_agent_sup, start=0.0, duration=target_audio_orig_dur - 0.08, speaker="agent") + # Safely wipe IPA from dummy user text if it exists + if user_sup.custom is not None and "ipa" in user_sup.custom: + user_sup.custom = deepcopy(user_sup.custom) + user_sup.custom["ipa"] = "" + # Optionally add extra silence if add_extra_end_sil: sil_duration = random.uniform(*extra_end_silence_range) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn copy.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn copy.py new file mode 100644 index 000000000000..413abddfe264 --- /dev/null +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn copy.py @@ -0,0 +1,953 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random +import re +from copy import deepcopy + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.data +from lhotse import CutSet, Seconds, compute_num_frames +from lhotse.cut import Cut +from lhotse.dataset.collation import collate_audio, collate_vectors, collate_matrices +from lhotse.utils import ifnone + +from nemo.collections.common.tokenizers import TokenizerSpec +from nemo.collections.speechlm2.data.utils import get_pad_id +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths +from nemo.utils import logging + +from hydra.utils import instantiate +from omegaconf import DictConfig +from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import IPABPETokenizer + +class MagpieTTSLhotseMultiturnDataset(torch.utils.data.Dataset): + """ + A dataset for duplex speech-to-speech models that handles bidirectional conversations. + + This dataset processes Lhotse CutSet objects containing recordings with supervision segments + from different speakers (roles). It creates aligned representations of audio and text for + both source (input) and target (output) channels, preserving temporal alignment between + audio frames and text tokens. + + Args: + tokenizer (TokenizerSpec): + Tokenizer for converting text to token IDs and vice versa. Must support BOS and EOS tokens. + It's expected to support PAD token as well, otherwise we will use 0 as the pad token + and emit a warning. + + frame_length (Seconds): + Duration of a single frame in seconds. Used to calculate frame positions for token alignment. + + source_sample_rate (int): + Sample rate for source audio (e.g., 16000 Hz). + + target_sample_rate (int): + Sample rate for target audio (e.g., 22050 Hz). + + input_roles (list[str], optional): + List of speaker roles (cut.supervisions[:].speaker) to consider as inputs. Defaults to ["user"]. + + output_roles (list[str], optional): + List of speaker roles (cut.supervisions[:].speaker) to consider as outputs. Defaults to ["agent"]. + + p_drop_description (float, optional): + Probability of dropping text descriptions. Default: `0.0`. + + add_text_bos_and_eos_in_each_turn (bool, optional): + If True, each conversational turn from any speaker is explicitly delimited + with BOS and EOS tokens in the text stream. + Default: `True`. + + Returns: + A dictionary with the following keys: + - sample_id: List of sample IDs for each cut in the batch [B] + + - non_prompt_mask: Bool tensor [B, T] marking positions that are not part of the prompt + - prompt_lens: Tensor of description + audio prompt lengths [B] + + - aligned_attention_mask: Bool tensor [B, T] used by alignment-aware transformer models + - aligned_position_ids: Tensor of position indices aligned to audio frames [B, T] + + - source_audio: Tensor of source waveform samples [B, T] + - source_audio_lens: Tensor of source audio lengths [B] + + - target_audio: Tensor of target waveform samples [B, T] + - target_audio_lens: Tensor of target audio lengths [B] + + - target_text_tokens: Tensor of frame-aligned input text tokens [B, T], + including BOS/EOS/PAD when enabled + - target_token_lens: Tensor of target token sequence lengths [B] + + - source_tokens: Tensor of frame-aligned source text tokens [B, T], + including BOS/EOS/PAD + - source_token_lens: Tensor of source token sequence lengths [B] + + - target_texts: List of full target texts joined from output_roles supervisions [B] + + - audio_prompt: Tensor of optional speaker reference waveform samples [B, T] + - audio_prompt_lens: Tensor of speaker reference audio lengths [B] + + - task: List indicating the task to use for each cut (default "s2s_duplex") [B] + + Notes: + - The dataset ensures frame-level alignment between audio and text by inserting tokens at + specific frame positions based on the timing of supervision segments. + - PAD tokens (typically 0) are used to fill gaps where there's no text. + - BOS tokens mark the beginning of each speech segment. + - EOS tokens mark the end of each speech segment. + - Text tokens from each speaker are placed at frame positions corresponding to their + timestamp in the original recording, preserving the temporal relationship. + This is a segment-level alignment only, not word-level alignment. + """ + + def __init__( + self, + tokenizer, + frame_length: Seconds, + source_sample_rate: int, + target_sample_rate: int, + input_roles: list[str] = None, + output_roles: list[str] = None, + p_drop_description: float = 0.0, + add_text_bos_and_eos_in_each_turn: bool = False, + add_audio_prompt: bool = False, + audio_prompt_duration: float = 3.0, + num_delay_speech_tokens: int = 0, + add_system_prompt: bool = False, + ignore_data_system_prompt: bool = True, + phoneme_tokenizer_config: DictConfig = None, + ignore_phoneme_languages: list[str] = None, + load_cached_codes_if_available: bool = False, + ): + self.tokenizer = tokenizer + self.frame_length = frame_length + self.source_sample_rate = source_sample_rate + self.target_sample_rate = target_sample_rate + self.input_roles = set(ifnone(input_roles, ["user"])) + self.output_roles = set(ifnone(output_roles, ["agent"])) + self.p_drop_description = p_drop_description + self.add_text_bos_and_eos_in_each_turn = add_text_bos_and_eos_in_each_turn + self.add_audio_prompt = add_audio_prompt + self.audio_prompt_duration = audio_prompt_duration + self.num_delay_speech_tokens = num_delay_speech_tokens + self.add_system_prompt = add_system_prompt + self.ignore_data_system_prompt = ignore_data_system_prompt + + self.phoneme_tokenizer_config = phoneme_tokenizer_config + self.ignore_phoneme_languages = ignore_phoneme_languages or [] + self.phoneme_tokenizer = None + self.load_cached_codes_if_available = load_cached_codes_if_available + + self.source_samples_per_frame = int(self.source_sample_rate * self.frame_length) + self.target_samples_per_frame = int(self.target_sample_rate * self.frame_length) + + assert tokenizer.bos is not None, "BOS support in the tokenizer is required for S2S models." + assert tokenizer.eos is not None, "EOS support in the tokenizer is required for S2S models." + + def __getitem__(self, cuts: CutSet) -> dict: + if self.phoneme_tokenizer is None and getattr(self, "phoneme_tokenizer_config", None) is not None: + self.phoneme_tokenizer = instantiate(self.phoneme_tokenizer_config) + + cuts = cuts.transform_text(_strip_timestamps) + + batch_tokenizer_names = [] + for cut in cuts: + if cut.has_custom("tokenizer_names"): + batch_tokenizer_names.append(random.choice(cut.tokenizer_names)) + else: + batch_tokenizer_names.append("english_phoneme") + + target_codes_list = [] + source_codes_list = [] + if self.load_cached_codes_if_available: + for cut in cuts: + if cut.has_custom("target_codes"): + codes_array = cut.target_codes.load().astype(np.int32) + target_codes_list.append(torch.from_numpy(codes_array).T) + if cut.has_custom("source_codes"): + codes_array = cut.source_codes.load().astype(np.int32) + source_codes_list.append(torch.from_numpy(codes_array).T) + + if target_codes_list: + target_codes = collate_matrices(target_codes_list, padding_value=0).transpose(1, 2) + target_codes_lens = torch.tensor([c.shape[0] for c in target_codes_list], dtype=torch.int32) + else: + target_codes, target_codes_lens = None, None + + if source_codes_list: + source_codes = collate_matrices(source_codes_list, padding_value=0).transpose(1, 2) + source_codes_lens = torch.tensor([c.shape[0] for c in source_codes_list], dtype=torch.int32) + else: + source_codes, source_codes_lens = None, None + + with fp32_precision(): + source_audio, source_audio_lens = collate_audio(cuts.resample(self.source_sample_rate)) + target_audio, target_audio_lens = collate_audio( + cuts.resample(self.target_sample_rate, recording_field="target_audio"), recording_field="target_audio" + ) + + target_text_tokens, target_token_lens = collate_token_channel( + cuts, + self.tokenizer, + self.frame_length, + roles=self.output_roles, + add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, + tokenizer_names=batch_tokenizer_names, + ) + source_tokens, source_token_lens = collate_token_channel( + cuts, + self.tokenizer, + self.frame_length, + roles=self.input_roles, + add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, + tokenizer_names=batch_tokenizer_names, + ) + + if self.phoneme_tokenizer is not None: + target_phoneme_tokens, target_phoneme_lens = collate_phoneme_channel( + cuts, + self.phoneme_tokenizer, + self.frame_length, + roles=self.output_roles, + ignore_phoneme_languages=self.ignore_phoneme_languages, + add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, + ) + else: + target_phoneme_tokens, target_phoneme_lens = None, None + + with fp32_precision(): + audio_prompt, audio_prompt_lens = get_audio_prompt( + cuts, self.target_sample_rate, roles=self.output_roles, recording_field="target_audio" + ) + + if self.num_delay_speech_tokens: + ( + source_audio, + source_audio_lens, + target_audio, + target_audio_lens, + source_codes, + source_codes_lens, + target_codes, + target_codes_lens + ) = add_speech_delay( + source_audio, + source_audio_lens, + target_audio, + target_audio_lens, + self.num_delay_speech_tokens, + self.target_samples_per_frame, + self.source_samples_per_frame, + source_codes=source_codes, + source_codes_lens=source_codes_lens, + target_codes=target_codes, + target_codes_lens=target_codes_lens, + ) + + if self.add_system_prompt: + with fp32_precision(): + system_prompts, system_prompts_lens, system_prompts_raw = collate_system_prompt( + cuts, + self.tokenizer, + ignore_data_system_prompt=self.ignore_data_system_prompt, + tokenizer_names=batch_tokenizer_names, + ) + else: + system_prompts = None + system_prompts_lens = None + system_prompts_raw = None + + dataset_type = [getattr(c, "type", "") for c in cuts] + + ( + target_text_tokens, + target_token_lens, + source_tokens, + source_token_lens, + source_audio, + source_audio_lens, + target_audio, + target_audio_lens, + prompt_lens, + target_phoneme_tokens, + target_phoneme_lens, + source_codes, + source_codes_lens, + target_codes, + target_codes_lens, + ) = self.maybe_add_audio_prompt( + target_text_tokens, target_token_lens, source_tokens, source_token_lens, + target_audio, target_audio_lens, source_audio, source_audio_lens, + audio_prompt, audio_prompt_lens, system_prompts, system_prompts_lens, + target_phoneme_tokens=target_phoneme_tokens, target_phoneme_lens=target_phoneme_lens, + source_codes=source_codes, source_codes_lens=source_codes_lens, + target_codes=target_codes, target_codes_lens=target_codes_lens, + ) + + non_prompt_mask = get_mask_from_lengths(target_token_lens) + for i, frame in enumerate(prompt_lens): + non_prompt_mask[i, : frame - 1] = 0.0 + + max_len = max(target_token_lens) + aligned_segment_ids = torch.stack( + [torch.nn.functional.pad(torch.full((seq_len,), i), (0, max_len - seq_len), value=-1) for i, seq_len in enumerate(target_token_lens)], dim=0, + ) + aligned_attention_mask = (aligned_segment_ids.unsqueeze(-2) == aligned_segment_ids.unsqueeze(-1)) & ( + torch.arange(max_len).unsqueeze(0).unsqueeze(1) <= torch.arange(max_len).unsqueeze(0).unsqueeze(-1) + ) + aligned_attention_mask = aligned_attention_mask.unsqueeze(1) + aligned_position_ids = torch.stack( + [torch.nn.functional.pad(torch.arange(seq_len), (0, max(target_token_lens) - seq_len), value=0) for seq_len in target_token_lens], dim=0, + ) + + batch_dict = { + "sample_id": [str(cut.id) for cut in cuts], + "non_prompt_mask": non_prompt_mask.bool(), + "prompt_lens": prompt_lens, + "aligned_attention_mask": aligned_attention_mask.bool(), + "aligned_position_ids": aligned_position_ids, + "source_audio": source_audio, + "source_audio_lens": source_audio_lens, + "target_audio": target_audio, + "target_audio_lens": target_audio_lens, + "target_text_tokens": target_text_tokens, + "target_token_lens": target_token_lens, + "source_tokens": source_tokens, + "source_token_lens": source_token_lens, + "target_texts": [ + " ".join(s.text for s in cut.supervisions if s.speaker in self.output_roles) for cut in cuts + ], + "audio_prompt": audio_prompt, + "audio_prompt_lens": audio_prompt_lens, + "system_prompts_raw": system_prompts_raw, + "dataset_type": dataset_type, + "phoneme_tokens": target_phoneme_tokens, + "phoneme_tokens_lens": target_phoneme_lens, + "task": [getattr(cut, "task", "s2s_duplex") for cut in cuts], + } + + if target_codes is not None: + batch_dict["target_codes"] = target_codes + batch_dict["target_codes_lens"] = target_codes_lens + if source_codes is not None: + batch_dict["source_codes"] = source_codes + batch_dict["source_codes_lens"] = source_codes_lens + + return batch_dict + + def maybe_add_audio_prompt( + self, + target_text_tokens: torch.Tensor, + target_token_lens: torch.Tensor, + source_tokens: torch.Tensor, + source_token_lens: torch.Tensor, + target_audio: torch.Tensor, + target_audio_lens: torch.Tensor, + source_audio: torch.Tensor, + source_audio_lens: torch.Tensor, + audio_prompt: torch.Tensor, + audio_prompt_lens: torch.Tensor, + system_prompts: torch.Tensor = None, + system_prompts_lens: torch.Tensor = None, + target_phoneme_tokens: torch.Tensor = None, + target_phoneme_lens: torch.Tensor = None, + source_codes: torch.Tensor = None, + source_codes_lens: torch.Tensor = None, + target_codes: torch.Tensor = None, + target_codes_lens: torch.Tensor = None, + ): + text_pad_id = get_pad_id(self.tokenizer) + + target_text_tokens_ = [] + source_tokens_ = [] + source_audio_ = [] + target_audio_ = [] + prompt_lens = [] + + target_phoneme_tokens_ = [] + phoneme_pad_id = self.phoneme_tokenizer.pad if self.phoneme_tokenizer else -1 + + source_codes_ = [] + target_codes_ = [] + + for i in range(target_text_tokens.size(0)): + if system_prompts is not None: + text_prompt = system_prompts[i][: system_prompts_lens[i]] + else: + text_prompt = torch.tensor( + [self.tokenizer.eos], + dtype=torch.long, + device=target_text_tokens.device, + ) + + if self.add_audio_prompt: + prompt_audio_size = int( + ((self.audio_prompt_duration * self.target_sample_rate) // self.target_samples_per_frame) + * self.target_samples_per_frame + ) + + prompt_audio = sample_audio_segments_repeat( + audio_prompt, audio_prompt_lens, prompt_audio_size, sample=True + ) + + prompt_audio[:, -int(self.target_samples_per_frame * 2) :] = 0 + + prompt_audio_text_pad_size = prompt_audio_size // self.target_samples_per_frame + prompt_audio_text_pad = ( + torch.ones(prompt_audio_text_pad_size, device=target_text_tokens.device, dtype=target_text_tokens.dtype) + * text_pad_id + ) + prompt_audio_text_pad[-1] = self.tokenizer.eos + + new_target_text_tokens = torch.cat( + [text_prompt.to(target_text_tokens.dtype), prompt_audio_text_pad, target_text_tokens[i]] + ) + target_text_tokens_.append(new_target_text_tokens) + target_token_lens[i] += len(text_prompt) + prompt_audio_text_pad_size + + new_source_tokens = torch.cat([text_prompt, prompt_audio_text_pad, source_tokens[i]]) + source_tokens_.append(new_source_tokens) + source_token_lens[i] += len(text_prompt) + prompt_audio_text_pad_size + + if target_phoneme_tokens is not None: + phoneme_pad_size = len(text_prompt) + prompt_audio_text_pad_size + phoneme_pad = torch.full((phoneme_pad_size,), phoneme_pad_id, device=target_phoneme_tokens.device, dtype=target_phoneme_tokens.dtype) + target_phoneme_tokens_.append(torch.cat([phoneme_pad, target_phoneme_tokens[i]])) + target_phoneme_lens[i] += phoneme_pad_size + + code_pad_size = len(text_prompt) + prompt_audio_text_pad_size + if target_codes is not None: + pad_codes = torch.zeros((target_codes.size(1), code_pad_size), device=target_codes.device, dtype=target_codes.dtype) + target_codes_.append(torch.cat([pad_codes, target_codes[i]], dim=1)) + target_codes_lens[i] += code_pad_size + + if source_codes is not None: + pad_codes = torch.zeros((source_codes.size(1), code_pad_size), device=source_codes.device, dtype=source_codes.dtype) + source_codes_.append(torch.cat([pad_codes, source_codes[i]], dim=1)) + source_codes_lens[i] += code_pad_size + + pad_size_src = (len(text_prompt) * self.source_samples_per_frame) + prompt_audio.size(1) + pad_audio_src = torch.zeros(pad_size_src, device=source_audio.device, dtype=source_audio.dtype) + source_audio_.append(torch.cat([pad_audio_src, source_audio[i]])) + source_audio_lens[i] += pad_size_src + + pad_size_tgt = len(text_prompt) * self.target_samples_per_frame + pad_audio_tgt = torch.zeros(pad_size_tgt, device=target_audio.device, dtype=target_audio.dtype) + target_audio_.append(torch.cat([pad_audio_tgt, prompt_audio[i], target_audio[i]])) + target_audio_lens[i] += pad_size_tgt + prompt_audio.size(1) + + prompt_lens.append(len(text_prompt) + prompt_audio_text_pad_size - 1) + + else: + target_text_tokens_.append(torch.cat([text_prompt, target_text_tokens[i]])) + target_token_lens[i] += len(text_prompt) + + source_tokens_.append(torch.cat([text_prompt, source_tokens[i]])) + source_token_lens[i] += len(text_prompt) + + if target_phoneme_tokens is not None: + phoneme_pad_size = len(text_prompt) + phoneme_pad = torch.full((phoneme_pad_size,), phoneme_pad_id, device=target_phoneme_tokens.device, dtype=target_phoneme_tokens.dtype) + target_phoneme_tokens_.append(torch.cat([phoneme_pad, target_phoneme_tokens[i]])) + target_phoneme_lens[i] += phoneme_pad_size + + code_pad_size = len(text_prompt) + if target_codes is not None: + pad_codes = torch.zeros((target_codes.size(1), code_pad_size), device=target_codes.device, dtype=target_codes.dtype) + target_codes_.append(torch.cat([pad_codes, target_codes[i]], dim=1)) + target_codes_lens[i] += code_pad_size + + if source_codes is not None: + pad_codes = torch.zeros((source_codes.size(1), code_pad_size), device=source_codes.device, dtype=source_codes.dtype) + source_codes_.append(torch.cat([pad_codes, source_codes[i]], dim=1)) + source_codes_lens[i] += code_pad_size + + pad_size_src = len(text_prompt) * self.source_samples_per_frame + pad_audio_src = torch.zeros(pad_size_src, device=source_audio.device, dtype=source_audio.dtype) + source_audio_.append(torch.cat([pad_audio_src, source_audio[i]])) + source_audio_lens[i] += pad_size_src + + pad_size_tgt = len(text_prompt) * self.target_samples_per_frame + pad_audio_tgt = torch.zeros(pad_size_tgt, device=target_audio.device, dtype=target_audio.dtype) + target_audio_.append(torch.cat([pad_audio_tgt, target_audio[i]])) + target_audio_lens[i] += pad_size_tgt + + prompt_lens.append(len(text_prompt)) + + target_text_tokens = collate_vectors(target_text_tokens_, padding_value=text_pad_id) + source_tokens = collate_vectors(source_tokens_, padding_value=text_pad_id) + source_audio = collate_vectors(source_audio_, padding_value=0) + target_audio = collate_vectors(target_audio_, padding_value=0) + + if target_phoneme_tokens is not None: + target_phoneme_tokens = collate_vectors(target_phoneme_tokens_, padding_value=phoneme_pad_id) + + if target_codes is not None: + max_len = max([c.size(1) for c in target_codes_]) + target_codes = torch.stack([F.pad(c, (0, max_len - c.size(1))) for c in target_codes_]) + if source_codes is not None: + max_len = max([c.size(1) for c in source_codes_]) + source_codes = torch.stack([F.pad(c, (0, max_len - c.size(1))) for c in source_codes_]) + + return ( + target_text_tokens, + target_token_lens, + source_tokens, + source_token_lens, + source_audio, + source_audio_lens, + target_audio, + target_audio_lens, + prompt_lens, + target_phoneme_tokens, + target_phoneme_lens, + source_codes, + source_codes_lens, + target_codes, + target_codes_lens, + ) + + +def build_phoneme_channel( + cut: Cut, + phoneme_tokenizer, + frame_length: Seconds, + roles: set[str], + ignore_phoneme_languages: list[str], + pad_id: int = -1, + add_text_bos_and_eos_in_each_turn: bool = True, +) -> torch.Tensor: + """ + Build a frame-aligned phoneme sequence for a single cut, mirroring text token logic. + """ + diagnostic = f"Extra info: {cut.id=}" + if getattr(cut, "shard_origin", None) is not None: + diagnostic = f"{diagnostic} {cut.shard_origin=}" + + total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) + tokens = torch.ones(total, dtype=torch.long) * pad_id + + if cut.has_custom("lang"): + language = cut.lang + else: + language = cut.supervisions[0].language if cut.supervisions[0].has_custom("language") else "en" + + for supervision in cut.supervisions: + if supervision.speaker in roles: + if isinstance(phoneme_tokenizer, IPABPETokenizer): + if not supervision.has_custom("ipa"): + logging.warning(f"'ipa' field not found in cut {cut.id}. Using empty string.") + ipa_text = "" + else: + ipa_text = supervision.ipa + + if language in ignore_phoneme_languages: + ipa_text = "" + else: + ipa_text = supervision.text + + phoneme_ids = phoneme_tokenizer.encode(ipa_text) + if add_text_bos_and_eos_in_each_turn: + phoneme_ids = [phoneme_tokenizer.bos_token_id] + phoneme_ids + + phoneme_ids = torch.as_tensor(phoneme_ids, dtype=torch.long) + + pos = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) + if pos > len(tokens): + logging.warning(f"Supervision offset {pos} larger than {len(tokens)}. {diagnostic}") + continue + + endpos = pos + len(phoneme_ids) + if endpos > len(tokens): + trunc_len = len(tokens) - pos + logging.warning(f"Truncating phoneme_ids by {trunc_len}. {diagnostic}") + phoneme_ids = phoneme_ids[:trunc_len] + + try: + tokens[pos:endpos] = phoneme_ids + except Exception as e: + raise RuntimeError(f"{tokens.shape=} {pos=} {endpos=} {phoneme_ids.shape=} {diagnostic}") from e + + if add_text_bos_and_eos_in_each_turn: + eospos = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) + if eospos < len(tokens): + tokens[eospos] = phoneme_tokenizer.eos_token_id + + return tokens + + +def collate_phoneme_channel( + cuts: CutSet, + phoneme_tokenizer, + frame_length: Seconds, + roles: set[str], + ignore_phoneme_languages: list[str], + add_text_bos_and_eos_in_each_turn: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Collate frame-aligned phoneme channels. + """ + pad_id = phoneme_tokenizer.pad + tokens = [ + build_phoneme_channel( + c, + phoneme_tokenizer=phoneme_tokenizer, + frame_length=frame_length, + roles=roles, + ignore_phoneme_languages=ignore_phoneme_languages, + pad_id=pad_id, + add_text_bos_and_eos_in_each_turn=add_text_bos_and_eos_in_each_turn, + ) + for c in cuts + ] + token_lens = torch.tensor([len(tt) for tt in tokens]) + tokens = collate_vectors(tokens, padding_value=pad_id) + return tokens, token_lens + + +def add_speech_delay( + source_audio: torch.Tensor, + source_audio_lens: torch.Tensor, + target_audio: torch.Tensor, + target_audio_lens: torch.Tensor, + num_delay_speech_tokens: int, + target_samples_per_frame: int, + source_samples_per_frame: int, + source_codes: torch.Tensor = None, + source_codes_lens: torch.Tensor = None, + target_codes: torch.Tensor = None, + target_codes_lens: torch.Tensor = None, +): + """ + Apply a speech delay by padding audio waveforms based on the number of delay speech tokens. + """ + extra_target_samples = int(num_delay_speech_tokens * target_samples_per_frame) + target_audio = F.pad(target_audio, (extra_target_samples, 0)) + target_audio_lens = target_audio_lens + extra_target_samples + + extra_source_samples = int(num_delay_speech_tokens * source_samples_per_frame) + source_audio = F.pad(source_audio, (0, extra_source_samples)) + source_audio_lens = source_audio_lens + extra_source_samples + + if target_codes is not None: + target_codes = F.pad(target_codes, (num_delay_speech_tokens, 0)) + target_codes_lens = target_codes_lens + num_delay_speech_tokens + + if source_codes is not None: + source_codes = F.pad(source_codes, (0, num_delay_speech_tokens)) + source_codes_lens = source_codes_lens + num_delay_speech_tokens + + return ( + source_audio, source_audio_lens, target_audio, target_audio_lens, + source_codes, source_codes_lens, target_codes, target_codes_lens + ) + + +def collate_system_prompt( + cuts: CutSet, + tokenizer: TokenizerSpec, + ignore_data_system_prompt: bool = False, + tokenizer_names: list[str] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Collate system prompts from cuts. + System prompts should be stored in cut.custom['system_prompt']. + """ + pad_id = get_pad_id(tokenizer) + tokens = [] + system_prompts_raw = [] + + for i, c in enumerate(cuts): + tok_name = tokenizer_names[i] if tokenizer_names else "english_phoneme" + + def _encode(txt): + if hasattr(tokenizer, "encode"): + try: + return tokenizer.encode(text=txt, tokenizer_name=tok_name) + except TypeError: + return tokenizer.encode(text=txt) + return tokenizer.text_to_ids(txt) + + if c.custom and c.custom.get("system_prompt", None) and not ignore_data_system_prompt: + prompt_text = c.custom["system_prompt"] + tokens.append( + torch.as_tensor( + [tokenizer.bos] + _encode(prompt_text) + [tokenizer.eos], dtype=torch.long + ) + ) + system_prompts_raw.append(prompt_text) + else: + if getattr(c, "type", None): + prompt_text = c.type + tokens.append( + torch.as_tensor( + [tokenizer.bos] + _encode(prompt_text) + [tokenizer.eos], dtype=torch.long + ) + ) + system_prompts_raw.append(prompt_text) + else: + logging.warning( + "No system prompt or dataset type defined on the config! Using a eos token as system prompt!" + ) + tokens.append(torch.as_tensor([tokenizer.eos], dtype=torch.long)) + system_prompts_raw.append("") + + token_lens = torch.tensor([len(tt) for tt in tokens]) + tokens = collate_vectors(tokens, padding_value=pad_id) + return tokens, token_lens, system_prompts_raw + + +def get_audio_prompt( + cuts: CutSet, + target_sample_rate: int, + roles: set[str], + recording_field: str = "target_audio", +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Retrieve an audio prompt for speaker conditioning. + """ + if hasattr(cuts[0], "context_audio"): + audio_prompt = [] + audio_prompt_lens = [] + + for cut in cuts: + ref_audio = cut.context_audio.resample(target_sample_rate).load_audio() + ref_audio = torch.tensor(ref_audio).float() + ref_audio_len = ref_audio.shape[1] + + audio_prompt.append(ref_audio.squeeze(0)) + audio_prompt_lens.append(ref_audio_len) + + audio_prompt = collate_vectors(audio_prompt, padding_value=0).float() + audio_prompt_lens = torch.tensor(audio_prompt_lens).long() + + else: + cuts = sanitize_cuts(cuts) + audio_prompt, audio_prompt_lens = collate_random_turn_audio( + cuts.resample(target_sample_rate, recording_field=recording_field), + roles=roles, + recording_field=recording_field, + ) + + return audio_prompt, audio_prompt_lens + + +def sanitize_cuts(cuts: CutSet) -> CutSet: + """ + Adjusts supervisions to fit within the cut's truncated duration. + """ + sanitized_list = [] + + for cut in cuts: + valid_supervisions = [] + for sup in cut.supervisions: + if sup.start >= cut.duration: + continue + + if sup.end > cut.duration: + new_duration = cut.duration - sup.start + + if new_duration <= 0: + continue + + new_sup = deepcopy(sup) + new_sup.duration = new_duration + valid_supervisions.append(new_sup) + + else: + valid_supervisions.append(sup) + + cut.supervisions = valid_supervisions + sanitized_list.append(cut) + + return cuts.from_cuts(sanitized_list) + + +def collate_random_turn_audio( + cuts: CutSet, + roles: set[str], + recording_field: str = "target_audio", +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sample and collate reference audio from random speaker turns. + """ + selected_turn_audios = [] + selected_turn_audios_lens = [] + for cut in cuts: + matching_supervisions = [s for s in cut.supervisions if s.speaker in roles] + if len(matching_supervisions) == 0: + target_duration = 5.0 + num_samples = int(target_duration * cut.sampling_rate) + + silence_tensor = torch.zeros(num_samples, dtype=torch.float32) + selected_turn_audios.append(silence_tensor) + selected_turn_audios_lens.append(num_samples) + logging.warning( + "There is no target speaker supervision available on this sample! Using a silence audio as audio prompt!" + ) + else: + selected_supervision = random.choice(matching_supervisions) + truncated_audio = cut.truncate( + offset=max(0, selected_supervision.start), duration=selected_supervision.duration + ).load_custom(recording_field) + + selected_turn_audios.append(truncated_audio.squeeze(0)) + selected_turn_audios_lens.append(truncated_audio.shape[-1]) + + return collate_vectors(selected_turn_audios, padding_value=0), torch.tensor(selected_turn_audios_lens) + + +def collate_token_channel( + cuts: CutSet, + tokenizer: TokenizerSpec, + frame_length: Seconds, + roles: set[str], + add_text_bos_and_eos_in_each_turn: bool = True, + tokenizer_names: list[str] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Build and collate token channels aligned to the audio frame grid. + """ + pad_id = get_pad_id(tokenizer) + tokens = [] + + for i, c in enumerate(cuts): + tok_name = tokenizer_names[i] if tokenizer_names else "english_phoneme" + tokens.append( + build_token_channel( + c, + tokenizer=tokenizer, + frame_length=frame_length, + roles=roles, + pad_id=pad_id, + add_text_bos_and_eos_in_each_turn=add_text_bos_and_eos_in_each_turn, + tokenizer_name=tok_name, + ) + ) + token_lens = torch.tensor([len(tt) for tt in tokens]) + tokens = collate_vectors(tokens, padding_value=pad_id) + return tokens, token_lens + + +def build_token_channel( + cut: Cut, + tokenizer: TokenizerSpec, + frame_length: Seconds, + roles: set[str], + pad_id: int = -1, + add_text_bos_and_eos_in_each_turn: bool = True, + tokenizer_name: str = "english_phoneme", +) -> torch.Tensor: + """ + Build a frame-aligned token sequence for a single cut. + """ + diagnostic = f"Extra info: {cut.id=}" + if getattr(cut, "shard_origin", None) is not None: + diagnostic = f"{diagnostic} {cut.shard_origin=}" + + total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) + tokens = torch.ones(total, dtype=torch.long) * pad_id + for supervision in cut.supervisions: + if supervision.speaker in roles: + text = supervision.text + + if hasattr(tokenizer, "encode"): + try: + raw_ids = tokenizer.encode(text=text, tokenizer_name=tokenizer_name) + except TypeError: + raw_ids = tokenizer.encode(text) + else: + raw_ids = tokenizer.text_to_ids(text) + + if add_text_bos_and_eos_in_each_turn: + text_ids = torch.as_tensor([tokenizer.bos] + raw_ids) + else: + text_ids = torch.as_tensor(raw_ids) + + pos = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) + if pos > len(tokens): + logging.warning( + f"Ill-constructed example: the beginning offset of a supervision {pos} is larger than the example's length {len(tokens)}. {diagnostic}" + ) + continue + + endpos = pos + len(text_ids) + if endpos > len(tokens): + trunc_len = len(tokens) - pos + logging.warning( + f"Truncating training example's text_ids of length {len(text_ids)} by {trunc_len} because {endpos=} > {len(tokens)=}. {diagnostic}" + ) + text_ids = text_ids[:trunc_len] + try: + tokens[pos:endpos] = text_ids + except Exception as e: + raise RuntimeError(f"{tokens.shape=} {pos=} {endpos=} {text_ids.shape=} {diagnostic}") from e + + if add_text_bos_and_eos_in_each_turn: + eospos = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) + if eospos < len(tokens): + tokens[eospos] = tokenizer.eos + + return tokens + + +def _strip_timestamps( + text: str, _TIMESTAMP_PATTERN=re.compile(r"<\|\d+\|>"), _SPACE_PATTERN=re.compile(r"\s+") +) -> str: + """ + Strips timestamp tokens from text. + """ + text = _TIMESTAMP_PATTERN.sub("", text) + return _SPACE_PATTERN.sub(" ", text).strip() + + +def sample_audio_segments_repeat( + prompt_audio: torch.Tensor, + prompt_audio_lens: torch.Tensor, + n_sample: int, + sample: bool = True, +) -> torch.Tensor: + """ + Extract audio segments of length n_sample. + """ + B, T = prompt_audio.shape + device = prompt_audio.device + out = torch.zeros(B, n_sample, device=device, dtype=prompt_audio.dtype) + + for b in range(B): + length = min(prompt_audio_lens[b].item(), T) + + if length <= 0: + continue + + if length >= n_sample: + if sample: + max_start = max(1, length - n_sample + 1) + start = torch.randint(0, max_start, (1,), device=device).item() + else: + start = 0 + out[b] = prompt_audio[b, start : start + n_sample] + + else: + start = 0 + segment = prompt_audio[b, start:length] + + repeat_times = (n_sample + (length - start) - 1) // (length - start) + repeated = segment.repeat(repeat_times)[:n_sample] + out[b] = repeated + + return out diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py new file mode 100644 index 000000000000..e86779ce15f5 --- /dev/null +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -0,0 +1,627 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random +import re +from typing import Dict, List, Union +from copy import deepcopy + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.data +from hydra.utils import instantiate +from lhotse import CutSet, Seconds, compute_num_frames +from lhotse.cut import Cut +from lhotse.dataset.collation import collate_matrices, collate_vectors, collate_audio +from lhotse.utils import ifnone +from omegaconf import DictConfig +from transformers import AutoTokenizer, T5Tokenizer + +from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer, IPABPETokenizer +from nemo.collections.speechlm2.data.utils import get_pad_id +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.tts.parts.utils.tts_dataset_utils import ( + beta_binomial_prior_distribution, + normalize_volume, + stack_tensors, +) +from nemo.utils import logging + + +def setup_tokenizers(all_tokenizers_config, mode='train'): + tokenizers = [] + tokenizer_names = [] + for tokenizer_name in all_tokenizers_config: + tokenizer_config = all_tokenizers_config[tokenizer_name] + if tokenizer_config._target_ == 'AutoTokenizer': + tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.pretrained_model, trust_remote_code=True) + elif tokenizer_config._target_ == 'T5Tokenizer': + tokenizer = T5Tokenizer.from_pretrained(tokenizer_config.pretrained_model) + else: + text_tokenizer_kwargs = {} + if "g2p" in tokenizer_config: + text_tokenizer_kwargs["g2p"] = instantiate(tokenizer_config.g2p) + tokenizer = instantiate(tokenizer_config, **text_tokenizer_kwargs) + if mode == 'test' and hasattr(tokenizer, "set_phone_prob"): + tokenizer.set_phone_prob(1.0) + tokenizers.append(tokenizer) + tokenizer_names.append(tokenizer_name) + + aggregated_tokenizer = AggregatedTTSTokenizer(tokenizers, tokenizer_names) + return aggregated_tokenizer + + +def check_speaker_format(item: str): + pattern = r"\| Language:\w+ Dataset:[\w\d\W]+ Speaker:[\w\d\W]+ \|" + return bool(re.match(pattern, item)) + + +def _strip_timestamps( + text: str, _TIMESTAMP_PATTERN=re.compile(r"<\|\d+\|>"), _SPACE_PATTERN=re.compile(r"\s+") +) -> str: + if text is None: + return "" + text = _TIMESTAMP_PATTERN.sub("", text) + return _SPACE_PATTERN.sub(" ", text).strip() + + +class MagpieTTSLhotseMultiturnDataset(torch.utils.data.Dataset): + """ + A PyTorch Dataset for loading and processing Text-to-Speech data for + MagpieTTS models using Lhotse CutSets, specifically designed for datasets + with text or audio context. But either context can be optional. + + This dataset expects Lhotse Cut objects where each cut represents a + target utterance along with its preceding context. Context can be + audio (preferred) or text. It handles loading either pre-computed audio + codes or raw audio waveforms, applying volume normalization, and tokenizing + text transcripts. Context audio/codes are sliced or repeated to fit within + a specified duration range. Optionally, it loads 16kHz audio suitable for + speaker verification models and calculates alignment priors. + + Tokenizers (for target text and optional context text) are initialized lazily + within each dataloader worker process upon first access. + + Args: + sample_rate (int): Target sample rate for loading audio. Audio will be + resampled if necessary. + volume_norm (bool): If True, applies peak volume normalization to audio + waveforms. Defaults to True. + codec_model_samples_per_frame (int): The total downsampling factor of the + audio codec model used to generate codes. Used for padding audio + and calculating number of codec frames. + num_audio_codebooks (int): Number of codebooks used by the audio codec model. + Needed for creating dummy context codes if necessary. + prior_scaling_factor (Optional[float]): Scaling factor for the beta-binomial + alignment prior calculation. If None, priors are not computed. Defaults to None. + load_cached_codes_if_available (bool): If True, attempts to load pre-computed + audio codes from custom fields in the Lhotse Cut (e.g., 'codes_21fpsCausalDecoder', + 'context_codes_21fpsCausalDecoder'). Falls back to loading audio if codes + are not found. Defaults to True. + dataset_type (str): Specifies the mode ('train' or 'test'), mainly affecting + tokenizer settings like phoneme probability. Defaults to 'train'. + load_16khz_audio (bool): If True, loads 16kHz audio suitable for speaker + verification models. It prioritizes context audio ('context_audio' field) + if available, otherwise uses the target audio ('target_audio' field). + Defaults to True. + pad_context_text_to_max_duration (bool): If True and `use_text_conditioning_tokenizer` + is True, pads the tokenized context text to a length derived from + `context_duration_max`. Defaults to False. + context_duration_min (float): Minimum duration (in seconds) for the context + audio/codes. Context shorter than this will be repeated. Defaults to 3.0. + context_duration_max (float): Maximum duration (in seconds) for the context + audio/codes. Context longer than this will be sliced randomly. Defaults to 10.0. + use_text_conditioning_tokenizer (bool): If True, enables processing of context + text using a separate tokenizer (currently T5Tokenizer). Expects context text + in `cut.supervisions[0].custom['context_text']`. Defaults to False. + tokenizer_config (Optional[DictConfig]): Configuration for the text tokenizers. + Used for lazy initialization within workers. Must be provided if tokenizers + are not set externally. Defaults to None. + text_context_remapping: Dict defining mapping of multiple text contexts to a single text context. + text_context_remapping_prob: Probability of remapping the original text context to a remapped text context. + """ + + def __init__( + self, + sample_rate: int, + volume_norm: bool = True, + codec_model_samples_per_frame: int = None, + num_audio_codebooks: int = None, + prior_scaling_factor: float = None, + load_cached_codes_if_available: bool = True, + dataset_type: str = 'train', + load_16khz_audio: bool = False, + pad_context_text_to_max_duration: bool = False, + context_duration_min: float = 3.0, + context_duration_max: float = 10.0, + use_text_conditioning_tokenizer: bool = False, + text_conditioning_tokenizer_name: str = None, + tokenizer_config: DictConfig = None, + text_context_remapping: Dict[str, str] = None, + text_context_remapping_prob: float = 0.0, + phoneme_tokenizer_config: DictConfig = None, + ignore_phoneme_languages: List[str] = None, + add_language_to_context_text: bool = False, + source_sample_rate: int = 16000, + input_roles: List[str] = ["user", "User"], + output_roles: List[str] = ["assistant", "Assistant", "agent", "Agent"], + add_text_bos_and_eos_in_each_turn: bool = False, + ): + super().__init__() + self.sample_rate = sample_rate + self.volume_norm = volume_norm + + self.codec_model_samples_per_frame = codec_model_samples_per_frame + self.num_audio_codebooks = num_audio_codebooks + + self.include_align_prior = prior_scaling_factor is not None + self.prior_scaling_factor = prior_scaling_factor + self.load_cached_codes_if_available = load_cached_codes_if_available + self.dataset_type = dataset_type + self.load_16khz_audio = load_16khz_audio + self.use_text_conditioning_tokenizer = use_text_conditioning_tokenizer + self.text_conditioning_tokenizer_name = text_conditioning_tokenizer_name + self.pad_context_text_to_max_duration = pad_context_text_to_max_duration + self.context_duration_min = context_duration_min + self.context_duration_max = context_duration_max + self.tokenizer_config = tokenizer_config + self.text_tokenizer = None + self.phoneme_tokenizer = None + self.text_context_remapping = text_context_remapping + self.text_context_remapping_prob = text_context_remapping_prob + self.phoneme_tokenizer_config = phoneme_tokenizer_config + self.ignore_phoneme_languages = ignore_phoneme_languages or [] + self.add_language_to_context_text = add_language_to_context_text + + self.source_sample_rate = source_sample_rate + self.input_roles = set(ifnone(input_roles, ["user"])) + self.output_roles = set(ifnone(output_roles, ["agent"])) + self.add_text_bos_and_eos_in_each_turn = add_text_bos_and_eos_in_each_turn + + self.frame_length = self.codec_model_samples_per_frame / self.sample_rate + + def get_num_audio_samples_to_slice(self, duration, sample_rate): + num_codec_frames = int(duration * sample_rate / self.codec_model_samples_per_frame) + num_audio_samples = num_codec_frames * self.codec_model_samples_per_frame + return num_audio_samples + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: + if self.text_tokenizer is None: + worker_info = torch.utils.data.get_worker_info() + worker_id = worker_info.id if worker_info is not None else 0 + logging.info(f"Worker {worker_id} initializing tokenizers...") + self.text_tokenizer = setup_tokenizers( + all_tokenizers_config=self.tokenizer_config, + mode=self.dataset_type, + ) + self.bos_id = len(self.text_tokenizer.tokens) + self.eos_id = self.bos_id + 1 + self.pad_id = self.text_tokenizer.pad + + if self.phoneme_tokenizer is None and self.phoneme_tokenizer_config is not None: + self.phoneme_tokenizer = instantiate(self.phoneme_tokenizer_config) + + cuts = cuts.transform_text(_strip_timestamps) + + batch_tokenizer_names = [] + for cut in cuts: + if cut.has_custom("tokenizer_names"): + batch_tokenizer_names.append(random.choice(cut.tokenizer_names)) + else: + batch_tokenizer_names.append("english_phoneme") + + with fp32_precision(): + source_audio, source_audio_lens = collate_audio(cuts.resample(self.source_sample_rate)) + target_audio, target_audio_lens = collate_audio( + cuts.resample(self.sample_rate, recording_field="target_audio"), recording_field="target_audio" + ) + + # Apply volume norm if requested + if self.volume_norm: + source_audio = torch.stack([torch.from_numpy(normalize_volume(a.numpy())) for a in source_audio]) + target_audio = torch.stack([torch.from_numpy(normalize_volume(a.numpy())) for a in target_audio]) + + target_text_tokens, target_token_lens = collate_token_channel( + cuts, self.text_tokenizer, self.frame_length, roles=self.output_roles, + add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, + tokenizer_names=batch_tokenizer_names, + ) + source_tokens, source_token_lens = collate_token_channel( + cuts, self.text_tokenizer, self.frame_length, roles=self.input_roles, + add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, + tokenizer_names=batch_tokenizer_names, + ) + + if self.phoneme_tokenizer is not None: + target_phoneme_tokens, target_phoneme_lens = collate_phoneme_channel( + cuts, self.phoneme_tokenizer, self.frame_length, roles=self.output_roles, + ignore_phoneme_languages=self.ignore_phoneme_languages, + add_text_bos_and_eos_in_each_turn=False, + ) + else: + target_phoneme_tokens, target_phoneme_lens = None, None + + + dataset_name_list = [] + audio_list_16khz = [] + audio_len_list_16khz = [] + prior_list = [] + + target_codes_list = [] + source_codes_list = [] + + context_audio_list = [] + context_audio_len_list = [] + context_audio_codes_list = [] + context_audio_codes_len_list = [] + context_text_tokens_list = [] + context_text_tokens_len_list = [] + context_has_text_context_list = [] + reward_list = [] + language_list = [] + + def _sample_context_duration_with_available_limit(available_duration_sec: float) -> float: + effective_duration_max = min(self.context_duration_max, available_duration_sec) + effective_duration_max = max(self.context_duration_min, effective_duration_max) + return random.uniform(self.context_duration_min, effective_duration_max) + + for cut in cuts: + speaker_found = False + for sup in reversed(cut.supervisions): + if check_speaker_format(sup.speaker): + dataset_name = sup.speaker.strip().split()[2].split(":")[-1] + speaker_found = True + break + + if not speaker_found: + dataset_name = "unknown" + dataset_name_list.append(dataset_name) + + language = cut.lang if cut.has_custom("lang") else next((sup.language for sup in reversed(cut.supervisions) if sup.has_custom("language")), "en") + language_list.append(language) + + # Target and Source Codes + if self.load_cached_codes_if_available: + if cut.has_custom("target_codes"): + target_codes_list.append(torch.from_numpy(cut.target_codes.load().astype(np.int32)).T) + if cut.has_custom("source_codes"): + source_codes_list.append(torch.from_numpy(cut.source_codes.load().astype(np.int32)).T) + + # Context Audio or Context Codes + if self.load_cached_codes_if_available and cut.has_custom("context_codes"): + context_audio_codes_array = cut.context_codes.load().astype(np.int32) + context_audio_codes = torch.from_numpy(context_audio_codes_array) + _available_context_duration = (context_audio_codes.shape[1] * self.codec_model_samples_per_frame / self.sample_rate) + _context_duration_to_slice = _sample_context_duration_with_available_limit(_available_context_duration) + _num_frames_to_slice = int(_context_duration_to_slice * self.sample_rate / self.codec_model_samples_per_frame) + + if _num_frames_to_slice < context_audio_codes.shape[1]: + start_idx = random.randint(0, context_audio_codes.shape[1] - _num_frames_to_slice) + context_audio_codes = context_audio_codes[:, start_idx : start_idx + _num_frames_to_slice] + else: + _num_repeats = int(np.ceil(_num_frames_to_slice / context_audio_codes.shape[1])) + context_audio_codes = context_audio_codes.repeat(1, _num_repeats)[:, :_num_frames_to_slice] + + context_audio_codes_list.append(context_audio_codes.T) + context_audio_codes_len_list.append(context_audio_codes.T.shape[0]) + + elif cut.has_custom("context_audio"): + with fp32_precision(): + context_audio_array = cut.context_audio.resample(self.sample_rate).load_audio().squeeze(0) + if self.volume_norm: + context_audio_array = normalize_volume(context_audio_array) + + _available_context_duration = len(context_audio_array) / self.sample_rate + _context_duration_to_slice = _sample_context_duration_with_available_limit(_available_context_duration) + _num_samples_to_slice = self.get_num_audio_samples_to_slice(_context_duration_to_slice, self.sample_rate) + + if _num_samples_to_slice < len(context_audio_array): + start_idx = random.randint(0, len(context_audio_array) - _num_samples_to_slice) + context_audio_array = context_audio_array[start_idx : start_idx + _num_samples_to_slice] + else: + _num_repeats = int(np.ceil(_num_samples_to_slice / len(context_audio_array))) + context_audio_array = np.tile(context_audio_array, _num_repeats)[:_num_samples_to_slice] + + context_audio = torch.from_numpy(context_audio_array) + context_audio_list.append(context_audio) + context_audio_len_list.append(context_audio.shape[0]) + + else: + matching_supervisions = [s for s in cut.supervisions if s.speaker in self.output_roles] + + if self.load_cached_codes_if_available: + if len(matching_supervisions) > 0 and cut.has_custom("target_codes"): + sup = random.choice(matching_supervisions) + codes_array = cut.target_codes.load().astype(np.int32) + start_frame = int(max(0, sup.start) * self.sample_rate / self.codec_model_samples_per_frame) + num_frames = int(sup.duration * self.sample_rate / self.codec_model_samples_per_frame) + context_audio_codes = torch.from_numpy(codes_array)[:, start_frame : start_frame + num_frames].T + else: + context_audio_codes = torch.zeros([0, self.num_audio_codebooks], dtype=torch.int32) + context_audio_codes_list.append(context_audio_codes) + context_audio_codes_len_list.append(context_audio_codes.shape[0]) + else: + if len(matching_supervisions) > 0: + sup = random.choice(matching_supervisions) + with fp32_precision(): + turn_cut = cut.resample(self.sample_rate, recording_field="target_audio").truncate(offset=max(0, sup.start), duration=sup.duration) + context_audio_array = turn_cut.load_custom("target_audio").squeeze(0) + if self.volume_norm: + context_audio_array = normalize_volume(context_audio_array) + context_audio = torch.from_numpy(context_audio_array) + else: + context_audio = torch.zeros(self.codec_model_samples_per_frame, dtype=torch.float32) + + context_audio_list.append(context_audio) + context_audio_len_list.append(context_audio.shape[0]) + + # 16khz audio for SV + if self.load_16khz_audio: + with fp32_precision(): + if cut.has_custom("context_audio"): + audio_array_16khz = cut.context_audio.resample(16_000).load_audio().squeeze(0) + if self.volume_norm: + audio_array_16khz = normalize_volume(audio_array_16khz) + + _available_context_duration = len(audio_array_16khz) / 16_000 + _context_duration_to_slice = _sample_context_duration_with_available_limit(_available_context_duration) + _num_samples_to_slice = int(_context_duration_to_slice * 16_000) + if _num_samples_to_slice < len(audio_array_16khz): + start_idx = random.randint(0, len(audio_array_16khz) - _num_samples_to_slice) + audio_array_16khz = audio_array_16khz[start_idx : start_idx + _num_samples_to_slice] + else: + matching_supervisions = [s for s in cut.supervisions if s.speaker in self.output_roles] + if len(matching_supervisions) > 0: + sup = random.choice(matching_supervisions) + turn_cut = cut.resample(16_000, recording_field="target_audio").truncate(offset=max(0, sup.start), duration=sup.duration) + audio_array_16khz = turn_cut.load_custom("target_audio").squeeze(0) + else: + audio_array_16khz = np.zeros(16000, dtype=np.float32) + + if self.volume_norm: + audio_array_16khz = normalize_volume(audio_array_16khz) + + audio_16khz = torch.from_numpy(audio_array_16khz) + audio_list_16khz.append(audio_16khz) + audio_len_list_16khz.append(audio_16khz.shape[0]) + + # Context Text + if self.use_text_conditioning_tokenizer: + context_text = next((sup.context_text for sup in cut.supervisions if sup.has_custom("context_text")), None) + if context_text is not None: + if self.text_context_remapping is not None and context_text in self.text_context_remapping: + if self.dataset_type == 'train' and random.random() < self.text_context_remapping_prob: + context_text = self.text_context_remapping[context_text] + context_text_tokens = self.text_tokenizer.encode(context_text, tokenizer_name=self.text_conditioning_tokenizer_name) + has_text_context = True + else: + context_text = f"[{language.upper()}]" if self.add_language_to_context_text else "[NO TEXT CONTEXT]" + context_text_tokens = self.text_tokenizer.encode(context_text, tokenizer_name=self.text_conditioning_tokenizer_name) + has_text_context = False + + if self.pad_context_text_to_max_duration: + _required_len = int(self.context_duration_max * self.sample_rate / self.codec_model_samples_per_frame) + 2 + if len(context_text_tokens) < _required_len: + _pad_id = self.text_tokenizer.tokenizer_pad_ids[self.text_conditioning_tokenizer_name] + context_text_tokens += [_pad_id] * (_required_len - len(context_text_tokens)) + else: + context_text_tokens = context_text_tokens[:_required_len] + + context_text_tokens = torch.tensor(context_text_tokens, dtype=torch.int32) + context_text_tokens_list.append(context_text_tokens) + context_text_tokens_len_list.append(context_text_tokens.shape[0]) + context_has_text_context_list.append(has_text_context) + + # Align Prior (Note: Using full target length to preserve shape compatibility) + if self.include_align_prior: + full_text_len = sum([len(self.text_tokenizer.encode(sup.text)) for sup in cut.supervisions if sup.speaker in self.output_roles]) + spec_len = int(cut.duration * self.sample_rate / self.codec_model_samples_per_frame) + 1 + align_prior = beta_binomial_prior_distribution(phoneme_count=full_text_len, mel_count=spec_len, scaling_factor=self.prior_scaling_factor) + prior_list.append(torch.tensor(align_prior, dtype=torch.float32)) + + reward = next((sup.reward for sup in reversed(cut.supervisions) if sup.has_custom("reward")), None) + if reward is not None: + reward_list.append(reward) + + # --- ASSEMBLE FINAL BATCH DICTIONARY --- + batch_dict = { + "sample_id": [str(cut.id) for cut in cuts], + "dataset_names": dataset_name_list, + "languages": language_list, + "source_audio": source_audio, + "source_audio_lens": source_audio_lens, + "target_audio": target_audio, + "target_audio_lens": target_audio_lens, + "source_tokens": source_tokens, + "source_token_lens": source_token_lens, + "target_text_tokens": target_text_tokens, + "target_token_lens": target_token_lens, + "target_texts": [" ".join(s.text for s in cut.supervisions if s.speaker in self.output_roles) for cut in cuts], + "dataset_type": [getattr(c, "type", "") for c in cuts], + } + + if target_codes_list: + batch_dict["target_codes"] = collate_matrices(target_codes_list, padding_value=0).transpose(1, 2) + batch_dict["target_codes_lens"] = torch.IntTensor([c.shape[0] for c in target_codes_list]) + + if source_codes_list: + batch_dict["source_codes"] = collate_matrices(source_codes_list, padding_value=0).transpose(1, 2) + batch_dict["source_codes_lens"] = torch.IntTensor([c.shape[0] for c in source_codes_list]) + + if self.phoneme_tokenizer is not None: + batch_dict["phoneme_tokens"] = target_phoneme_tokens + batch_dict["phoneme_tokens_lens"] = target_phoneme_lens + + if len(audio_list_16khz) > 0: + batch_dict["audio_16khz"] = collate_vectors(audio_list_16khz, padding_value=0.0) + batch_dict["audio_lens_16khz"] = torch.IntTensor(audio_len_list_16khz) + + if len(context_audio_list) > 0: + batch_dict["context_audio"] = collate_vectors(context_audio_list, padding_value=0.0) + batch_dict["context_audio_lens"] = torch.IntTensor(context_audio_len_list) + + if len(context_audio_codes_list) > 0: + batch_dict["context_audio_codes"] = collate_matrices(context_audio_codes_list, padding_value=0).transpose(1, 2) + batch_dict["context_audio_codes_lens"] = torch.IntTensor(context_audio_codes_len_list) + + if self.use_text_conditioning_tokenizer: + batch_dict['context_text_tokens'] = collate_vectors( + tensors=context_text_tokens_list, + padding_value=self.text_tokenizer.tokenizer_pad_ids[self.text_conditioning_tokenizer_name], + ) + batch_dict['context_text_tokens_lens'] = torch.IntTensor(context_text_tokens_len_list) + batch_dict['has_text_context'] = torch.BoolTensor(context_has_text_context_list) + + if self.include_align_prior: + spec_max_len = max([prior.shape[0] for prior in prior_list]) + text_max_len = max([prior.shape[1] for prior in prior_list]) + batch_dict["align_prior_matrix"] = stack_tensors(prior_list, max_lens=[text_max_len, spec_max_len]) + + if len(reward_list) > 0: + batch_dict['rewards'] = torch.FloatTensor(reward_list) + + return batch_dict + + +def collate_token_channel( + cuts: CutSet, + tokenizer, + frame_length: Seconds, + roles: set[str], + add_text_bos_and_eos_in_each_turn: bool = True, + tokenizer_names: list[str] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Build and collate token channels aligned to the audio frame grid.""" + pad_id = getattr(tokenizer, 'pad', -1) + tokens = [] + + for i, c in enumerate(cuts): + tok_name = tokenizer_names[i] if tokenizer_names else "english_phoneme" + tokens.append( + build_token_channel( + c, tokenizer, frame_length, roles, pad_id, + add_text_bos_and_eos_in_each_turn, tok_name + ) + ) + token_lens = torch.tensor([len(tt) for tt in tokens]) + return collate_vectors(tokens, padding_value=pad_id), token_lens + + +def build_token_channel( + cut: Cut, + tokenizer, + frame_length: Seconds, + roles: set[str], + pad_id: int = -1, + add_text_bos_and_eos_in_each_turn: bool = True, + tokenizer_name: str = "english_phoneme", +) -> torch.Tensor: + total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) + tokens = torch.ones(total, dtype=torch.long) * pad_id + bos_id = getattr(tokenizer, 'bos', 0) + eos_id = getattr(tokenizer, 'eos', 1) + + for supervision in cut.supervisions: + if supervision.speaker in roles: + text = supervision.text + + if hasattr(tokenizer, "encode"): + try: + raw_ids = tokenizer.encode(text=text, tokenizer_name=tokenizer_name) + except TypeError: + raw_ids = tokenizer.encode(text) + else: + raw_ids = tokenizer.text_to_ids(text) + + text_ids = torch.as_tensor([bos_id] + raw_ids) if add_text_bos_and_eos_in_each_turn else torch.as_tensor(raw_ids) + + pos = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) + if pos >= len(tokens): + continue + + endpos = pos + len(text_ids) + if endpos > len(tokens): + text_ids = text_ids[:len(tokens) - pos] + tokens[pos:pos+len(text_ids)] = text_ids + + if add_text_bos_and_eos_in_each_turn: + eospos = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) + if eospos < len(tokens): + tokens[eospos] = eos_id + + return tokens + + +def collate_phoneme_channel( + cuts: CutSet, + phoneme_tokenizer, + frame_length: Seconds, + roles: set[str], + ignore_phoneme_languages: list[str], + add_text_bos_and_eos_in_each_turn: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + pad_id = phoneme_tokenizer.pad + tokens = [ + build_phoneme_channel( + c, phoneme_tokenizer, frame_length, roles, + ignore_phoneme_languages, pad_id, add_text_bos_and_eos_in_each_turn + ) for c in cuts + ] + token_lens = torch.tensor([len(tt) for tt in tokens]) + return collate_vectors(tokens, padding_value=pad_id), token_lens + + +def build_phoneme_channel( + cut: Cut, + phoneme_tokenizer, + frame_length: Seconds, + roles: set[str], + ignore_phoneme_languages: list[str], + pad_id: int = -1, + add_text_bos_and_eos_in_each_turn: bool = True, +) -> torch.Tensor: + total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) + tokens = torch.ones(total, dtype=torch.long) * pad_id + + language = cut.lang if cut.has_custom("lang") else next((sup.language for sup in cut.supervisions if sup.has_custom("language")), "en") + + for supervision in cut.supervisions: + if supervision.speaker in roles: + if isinstance(phoneme_tokenizer, IPABPETokenizer): + ipa_text = supervision.ipa if supervision.has_custom("ipa") else "" + if language in ignore_phoneme_languages: + ipa_text = "" + else: + ipa_text = supervision.text + + phoneme_ids = phoneme_tokenizer.encode(ipa_text) + if add_text_bos_and_eos_in_each_turn: + phoneme_ids = [phoneme_tokenizer.bos_token_id] + phoneme_ids + + phoneme_ids = torch.as_tensor(phoneme_ids, dtype=torch.long) + pos = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) + if pos >= len(tokens): + continue + + endpos = pos + len(phoneme_ids) + if endpos > len(tokens): + phoneme_ids = phoneme_ids[:len(tokens) - pos] + tokens[pos:pos+len(phoneme_ids)] = phoneme_ids + + if add_text_bos_and_eos_in_each_turn: + eospos = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) + if eospos < len(tokens): + tokens[eospos] = phoneme_tokenizer.eos_token_id + + return tokens diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index ffe4a0cfd055..9111dcde9b33 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -32,6 +32,8 @@ from nemo.collections.asr.parts.mixins.transcription import TranscribeConfig from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.tts.data.text_to_speech_dataset_lhotse import MagpieTTSLhotseDataset, setup_tokenizers +from nemo.collections.tts.data.text_to_speech_dataset_lhotse_multiturn import MagpieTTSLhotseMultiturnDataset + from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel, TrainingMode from nemo.collections.tts.modules.magpietts_modules import ( LocalTransformerType, @@ -1393,27 +1395,53 @@ def get_dataset(self, dataset_cfg, dataset_type): def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.DataLoader: # TODO @xueyang: better to distinguish cfg. self.cfg is the model cfg, while cfg here is train_ds cfg. Also # cfg is a classifier-free guidance. - dataset = MagpieTTSLhotseDataset( - sample_rate=self.sample_rate, - volume_norm=dataset_cfg.volume_norm, - codec_model_samples_per_frame=self.codec_model_samples_per_frame, - num_audio_codebooks=self.data_num_audio_codebooks, - prior_scaling_factor=0.0, - load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, - dataset_type=mode, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) - load_16khz_audio=False, - pad_context_text_to_max_duration=self.pad_context_text_to_max_duration, - context_duration_min=self.cfg.context_duration_min, - context_duration_max=self.cfg.context_duration_max, - use_text_conditioning_tokenizer=True, - text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name, - tokenizer_config=self.cfg.text_tokenizers, - phoneme_tokenizer_config=self.cfg.get("phoneme_tokenizer", None), - ignore_phoneme_languages=self.cfg.get("ignore_phoneme_languages", []), - phoneme_as_text_prob=self.phoneme_as_text_prob if mode == 'train' else 0.0, - pronunciation_control_g2p=self.cfg.get("pronunciation_control_g2p", None), - add_language_to_context_text=self.add_language_to_context_text, - ) + if self.cfg.get("use_multiturn_dataset", False): + dataset = MagpieTTSLhotseMultiturnDataset( + sample_rate=self.sample_rate, + volume_norm=dataset_cfg.volume_norm, + codec_model_samples_per_frame=self.codec_model_samples_per_frame, + num_audio_codebooks=self.data_num_audio_codebooks, + prior_scaling_factor=0.0, + load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, + dataset_type=mode, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) + load_16khz_audio=False, + pad_context_text_to_max_duration=self.pad_context_text_to_max_duration, + context_duration_min=self.cfg.context_duration_min, + context_duration_max=self.cfg.context_duration_max, + use_text_conditioning_tokenizer=True, + text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name, + tokenizer_config=self.cfg.text_tokenizers, + phoneme_tokenizer_config=self.cfg.get("phoneme_tokenizer", None), + ignore_phoneme_languages=self.cfg.get("ignore_phoneme_languages", []), + add_language_to_context_text=self.add_language_to_context_text, + source_sample_rate=self.sample_rate, + input_roles=["user", "User"], + output_roles=["assistant", "Assistant", "agent", "Agent"], + add_text_bos_and_eos_in_each_turn=self.cfg.get("add_text_bos_and_eos_in_each_turn", False), + # pronunciation_control_g2p=self.cfg.get("pronunciation_control_g2p", None), + ) + else: + dataset = MagpieTTSLhotseDataset( + sample_rate=self.sample_rate, + volume_norm=dataset_cfg.volume_norm, + codec_model_samples_per_frame=self.codec_model_samples_per_frame, + num_audio_codebooks=self.data_num_audio_codebooks, + prior_scaling_factor=0.0, + load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, + dataset_type=mode, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) + load_16khz_audio=False, + pad_context_text_to_max_duration=self.pad_context_text_to_max_duration, + context_duration_min=self.cfg.context_duration_min, + context_duration_max=self.cfg.context_duration_max, + use_text_conditioning_tokenizer=True, + text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name, + tokenizer_config=self.cfg.text_tokenizers, + phoneme_tokenizer_config=self.cfg.get("phoneme_tokenizer", None), + ignore_phoneme_languages=self.cfg.get("ignore_phoneme_languages", []), + phoneme_as_text_prob=self.phoneme_as_text_prob if mode == 'train' else 0.0, + pronunciation_control_g2p=self.cfg.get("pronunciation_control_g2p", None), + add_language_to_context_text=self.add_language_to_context_text, + ) data_loader = get_lhotse_dataloader_from_config( config=dataset_cfg.dataset, @@ -1421,6 +1449,7 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D world_size=self.world_size, dataset=dataset, ) + return data_loader def setup_training_data(self, dataset_cfg): @@ -1452,7 +1481,7 @@ def setup_training_data(self, dataset_cfg): ) def _setup_test_dataloader(self, dataset_cfg) -> torch.utils.data.DataLoader: - if dataset_cfg.get("use_lhotse", False): + if self.cfg.get("use_lhotse", False): data_loader = self.get_lhotse_dataloader(dataset_cfg, mode='test') else: dataset = self.get_dataset(dataset_cfg, dataset_type='test') From ca4ca991ec4bbb9572cd144a3155b5c09bbb7d5b Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 14 Apr 2026 09:30:52 -0700 Subject: [PATCH 009/109] Update multiturn dataloader Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 323 +++--- ...to_speech_dataset_lhotse_multiturn copy.py | 953 ------------------ ...text_to_speech_dataset_lhotse_multiturn.py | 21 +- nemo/collections/tts/models/easy_magpietts.py | 4 +- 4 files changed, 179 insertions(+), 1122 deletions(-) delete mode 100644 nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn copy.py diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 02db79a0a153..5ea4065d5fc1 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -812,131 +812,92 @@ def cut_to_conversation( ) -@data_type_parser(["s2s_duplex_overlap_as_s2s_duplex"]) -def read_s2s_duplex_overlap_as_s2s_duplex(config) -> Tuple[CutSet, bool]: - """ - Convert a CutSet with overlapping agent/user segments into a standard S2S duplex format. - - Use Case: - This parser is designed for conversational data where agent and user speech can overlap - in time (e.g., natural turn-taking with interruptions or backchanneling). The input - format stores agent and user segments separately as `agent_segments` and `user_segments` - attributes on each cut. This function converts them into a unified timeline of sequential - SupervisionSegments, which is the standard format expected by DuplexS2S models. - - Expected Input Data Format: - Each cut should have: - - cut.agent_segments: List[Dict] with keys: - - "start" (float): Start time in seconds - - "end" (float): End time in seconds - - "text" (str): Agent's transcription - - cut.user_segments: List[Dict] with keys: - - "start" (float): Start time in seconds - - "end" (float): End time in seconds - - "text" (str): User's transcription +def _filter_cer_fn(cut: Cut, max_cer: float) -> bool: + return ( + len(cut.supervisions) == 0 + or not cut.supervisions[0].has_custom("cer") + or cut.supervisions[0].cer <= max_cer + ) - Example: - Input cut with overlapping segments: - cut.agent_segments = [ - {"start": 0.5, "end": 2.0, "text": "Hello, how can I help?"}, - {"start": 3.0, "end": 4.5, "text": "Sure, I can do that."} - ] - cut.user_segments = [ - {"start": 1.8, "end": 3.2, "text": "I need assistance"}, - {"start": 4.0, "end": 5.5, "text": "Thank you"} - ] - - Output cut.supervisions (sorted by start time): - [ - SupervisionSegment(start=0.5, duration=1.5, text="Hello, how can I help?", speaker="agent"), - SupervisionSegment(start=1.8, duration=1.4, text="I need assistance", speaker="user"), - SupervisionSegment(start=3.0, duration=1.5, text="Sure, I can do that.", speaker="agent"), - SupervisionSegment(start=4.0, duration=1.5, "Thank you", speaker="user") - ] +def _filter_val_flag_fn(cut: Cut, keep_flag: str) -> bool: + return not cut.has_custom("validation_status") or cut.validation_status == keep_flag - Args: - config: Dictionary containing parser options: - - move_agent_text_back_by (float): Time offset to shift agent text back (default: 0). - Useful for aligning agent text with earlier audio timing. - - filter_samples_starting_with_agent (bool): Whether to remove samples starting with agent (default: False). - When True, only keeps samples where the first speaker is a user. - - agent_roles (List[str]): Roles considered as agent (default: ["agent", "Assistant", "assistant"]). +def _filter_secs_fn(cut: Cut, min_sim: float) -> bool: + return ( + len(cut.supervisions) == 0 + or not cut.supervisions[0].has_custom("context_speaker_similarity") + or cut.supervisions[0].context_speaker_similarity >= min_sim + ) - Returns: - Tuple[CutSet, bool]: Converted cuts with unified supervisions, and a flag indicating if the data was tarred. - """ - move_agent_text_back_by = config.get("move_agent_text_back_by", 0) - filter_samples_starting_with_agent = config.get("filter_samples_starting_with_agent", False) - agent_roles = config.get("agent_roles", ["agent", "Assistant", "assistant"]) +def _filter_target_speaker_fn(cut: Cut, target_speaker: str) -> bool: + return len(cut.supervisions) == 0 or target_speaker is None or target_speaker in cut.supervisions[0].speaker - cuts, is_tarred = read_cutset_from_config(config) +def _create_recording_from_array(samples: np.ndarray, sampling_rate: int, recording_id: str) -> Recording: + with io.BytesIO() as buffer: + sf.write(buffer, samples.T, samplerate=sampling_rate, format='WAV') + buffer.seek(0) + return Recording.from_bytes(buffer.read(), recording_id=recording_id) - def filter_cuts_starting_with_agent_fn(cuts: CutSet, agent_roles: Tuple[str, ...]) -> CutSet: - """Remove cuts where the first supervision belongs to an agent role.""" - - def _filter_fn(cut: Cut) -> bool: - if not cut.supervisions: - return False - cut.supervisions = sorted(cut.supervisions, key=lambda s: s.start) - return cut.supervisions[0].speaker not in agent_roles - - return cuts.filter(_filter_fn) - - def convert_overlap_cut_fn(cut: Cut) -> Cut: - """Convert agent/user overlapping segments into sequential SupervisionSegments.""" - agent_segments = [ - SupervisionSegment( - id=cut.id, - recording_id=cut.id, - start=seg["start"] - move_agent_text_back_by, - duration=seg["end"] - seg["start"] + move_agent_text_back_by, - text=seg["text"], - speaker="agent", - ) - for seg in cut.agent_segments - ] - - user_segments = [ - SupervisionSegment( - id=cut.id, - recording_id=cut.id, - start=seg["start"], - duration=seg["end"] - seg["start"], - text=seg["text"], - speaker="user", - ) - for seg in cut.user_segments - ] +def _convert_cut_fn(cut: Cut, sample_rate: int, add_extra_end_sil: bool, extra_end_silence_range: list) -> Cut: + orig_agent_sup = fastcopy(cut.supervisions[0]) + target_audio_orig_dur = cut.target_audio.duration - cut.supervisions = sorted(agent_segments + user_segments, key=lambda s: s.start) - cut.task = "s2s_duplex_overlap_as_s2s_duplex" - return cut + cut.target_audio = cut.target_audio.resample(sample_rate) + cut.context_audio = cut.context_audio.resample(sample_rate) + total_duration = cut.target_audio.duration - cuts = cuts.map(convert_overlap_cut_fn) - if filter_samples_starting_with_agent: - cuts = filter_cuts_starting_with_agent_fn(cuts, tuple(agent_roles)) + cut_target = MonoCut( + id=f"{cut.id}_target", + start=0.0, + duration=total_duration, + channel=0, + recording=cut.target_audio, + supervisions=[], + ) - return cuts, is_tarred + zero_audio = np.zeros((1, int(total_duration * sample_rate)), dtype=np.float32) + source_recording = _create_recording_from_array(zero_audio, sample_rate, recording_id=f"{cut.id}_source") + + cut_source = MonoCut( + id=f"{cut.id}_source", + start=0.0, + duration=total_duration, + channel=0, + recording=source_recording, + supervisions=[], + custom=deepcopy(cut.custom) if cut.custom is not None else None, + ) + cut_source = cut_source.move_to_memory(audio_format='wav') + cut_target = cut_target.move_to_memory(audio_format='wav') -@data_type_parser(["lhotse_magpietts_data_as_continuation"]) -def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: - """ - Convert MagpieTTS dataset cuts into the Duplex S2S format, with optional - `context_audio` that can be used as a speaker reference. + user_sup = fastcopy(orig_agent_sup, start=0.0, duration=0.08, speaker="user", text="dummy text") + agent_sup = fastcopy(orig_agent_sup, start=0.0, duration=target_audio_orig_dur - 0.08, speaker="agent") - Args: - config: Dictionary containing parser options: - - add_extra_end_silence (bool): Whether to add extra silence at the end. - - extra_end_silence_range (List[float]): Range of extra silence duration. - - max_cer (float): Maximum allowed character error rate. - - min_context_speaker_similarity (float): Minimum similarity score. - - target_speaker (str, optional): Target speaker filter. - - sample_rate (int): Audio sample rate for resampling. + if user_sup.custom is not None and "ipa" in user_sup.custom: + user_sup.custom = deepcopy(user_sup.custom) + user_sup.custom["ipa"] = "" - Returns: - Tuple[CutSet, bool]: Converted cuts and a flag indicating if data was tarred. - """ + if add_extra_end_sil: + sil_duration = random.uniform(*extra_end_silence_range) + cut_target = cut_target.pad(duration=total_duration + sil_duration, direction="right") + cut_source = cut_source.pad(duration=total_duration + sil_duration, direction="right") + cut_source = cut_source.to_mono().move_to_memory(audio_format='wav') + cut_target = cut_target.to_mono().move_to_memory(audio_format='wav') + agent_sup.duration += sil_duration + 1.0 + user_sup.duration += sil_duration + + cut_source.supervisions = [user_sup, agent_sup] + cut_source.target_audio = cut_target.recording + cut_source.duration = cut_target.duration + cut_source.context_audio = cut.context_audio + cut_source.task = "lhotse_magpietts_data_as_continuation" + + return cut_source + + +@data_type_parser(["lhotse_magpietts_data_as_continuation"]) +def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: cuts, is_tarred = read_cutset_from_config(config) add_extra_end_sil = config.get("add_extra_end_silence", False) @@ -948,24 +909,78 @@ def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: target_speaker = config.get("target_speaker", None) keep_flag = "pass" - def create_recording_from_array(samples: np.ndarray, sampling_rate: int, recording_id: str) -> Recording: - """Convert a numpy array into a Lhotse Recording object.""" - with io.BytesIO() as buffer: - sf.write(buffer, samples.T, samplerate=sampling_rate, format='WAV') - buffer.seek(0) - return Recording.from_bytes(buffer.read(), recording_id=recording_id) + cuts = ( + cuts.filter(partial(_filter_cer_fn, max_cer=max_cer)) + .filter(partial(_filter_val_flag_fn, keep_flag=keep_flag)) + .filter(partial(_filter_secs_fn, min_sim=min_context_speaker_similarity)) + .filter(partial(_filter_target_speaker_fn, target_speaker=target_speaker)) + ) + + cuts = cuts.map(partial(_convert_cut_fn, sample_rate=sample_rate, add_extra_end_sil=add_extra_end_sil, extra_end_silence_range=extra_end_silence_range)) - def convert_cut_fn(cut: Cut) -> Cut: - """Convert a single cut into the continuation format.""" - orig_agent_sup = deepcopy(cut.supervisions[0]) + return cuts, is_tarred + + +class FilterCER: + def __init__(self, max_cer: float): + self.max_cer = max_cer + + def __call__(self, cut: Cut) -> bool: + return ( + len(cut.supervisions) == 0 + or not cut.supervisions[0].has_custom("cer") + or cut.supervisions[0].cer <= self.max_cer + ) + +class FilterValFlag: + def __init__(self, keep_flag: str): + self.keep_flag = keep_flag + + def __call__(self, cut: Cut) -> bool: + return not cut.has_custom("validation_status") or cut.validation_status == self.keep_flag + +class FilterSecs: + def __init__(self, min_sim: float): + self.min_sim = min_sim + + def __call__(self, cut: Cut) -> bool: + return ( + len(cut.supervisions) == 0 + or not cut.supervisions[0].has_custom("context_speaker_similarity") + or cut.supervisions[0].context_speaker_similarity >= self.min_sim + ) + +class FilterTargetSpeaker: + def __init__(self, target_speaker: str): + self.target_speaker = target_speaker + + def __call__(self, cut: Cut) -> bool: + return len(cut.supervisions) == 0 or self.target_speaker is None or self.target_speaker in cut.supervisions[0].speaker + +def _create_recording_from_array(samples: np.ndarray, sampling_rate: int, recording_id: str) -> Recording: + with io.BytesIO() as buffer: + sf.write(buffer, samples.T, samplerate=sampling_rate, format='WAV') + buffer.seek(0) + return Recording.from_bytes(buffer.read(), recording_id=recording_id) + +class ConvertCutFn: + def __init__(self, sample_rate: int, add_extra_end_sil: bool, extra_end_silence_range: list): + self.sample_rate = sample_rate + self.add_extra_end_sil = add_extra_end_sil + self.extra_end_silence_range = extra_end_silence_range + + def __call__(self, cut: Cut) -> Cut: + orig_agent_sup = fastcopy(cut.supervisions[0]) target_audio_orig_dur = cut.target_audio.duration - # Resample audios - cut.target_audio = cut.target_audio.resample(sample_rate) - cut.context_audio = cut.context_audio.resample(sample_rate) + cut.target_audio = cut.target_audio.resample(self.sample_rate) + + # --- SAFELY CHECK FOR CONTEXT AUDIO --- + if cut.has_custom("context_audio"): + cut.context_audio = cut.context_audio.resample(self.sample_rate) + total_duration = cut.target_audio.duration - # Prepare MonoCuts cut_target = MonoCut( id=f"{cut.id}_target", start=0.0, @@ -975,10 +990,9 @@ def convert_cut_fn(cut: Cut) -> Cut: supervisions=[], ) - zero_audio = np.zeros((1, int(total_duration * sample_rate)), dtype=np.float32) - source_recording = create_recording_from_array(zero_audio, sample_rate, recording_id=f"{cut.id}_source") + zero_audio = np.zeros((1, int(total_duration * self.sample_rate)), dtype=np.float32) + source_recording = _create_recording_from_array(zero_audio, self.sample_rate, recording_id=f"{cut.id}_source") - # Create source cut, preserving cut.custom safely cut_source = MonoCut( id=f"{cut.id}_source", start=0.0, @@ -989,22 +1003,18 @@ def convert_cut_fn(cut: Cut) -> Cut: custom=deepcopy(cut.custom) if cut.custom is not None else None, ) - # Save to memory cut_source = cut_source.move_to_memory(audio_format='wav') cut_target = cut_target.move_to_memory(audio_format='wav') - # Create user and agent supervisions user_sup = fastcopy(orig_agent_sup, start=0.0, duration=0.08, speaker="user", text="dummy text") agent_sup = fastcopy(orig_agent_sup, start=0.0, duration=target_audio_orig_dur - 0.08, speaker="agent") - # Safely wipe IPA from dummy user text if it exists if user_sup.custom is not None and "ipa" in user_sup.custom: user_sup.custom = deepcopy(user_sup.custom) user_sup.custom["ipa"] = "" - # Optionally add extra silence - if add_extra_end_sil: - sil_duration = random.uniform(*extra_end_silence_range) + if self.add_extra_end_sil: + sil_duration = random.uniform(*self.extra_end_silence_range) cut_target = cut_target.pad(duration=total_duration + sil_duration, direction="right") cut_source = cut_source.pad(duration=total_duration + sil_duration, direction="right") cut_source = cut_source.to_mono().move_to_memory(audio_format='wav') @@ -1012,43 +1022,40 @@ def convert_cut_fn(cut: Cut) -> Cut: agent_sup.duration += sil_duration + 1.0 user_sup.duration += sil_duration - # Assemble final cut cut_source.supervisions = [user_sup, agent_sup] cut_source.target_audio = cut_target.recording cut_source.duration = cut_target.duration - cut_source.context_audio = cut.context_audio + + if cut.has_custom("context_audio"): + cut_source.context_audio = cut.context_audio + cut_source.task = "lhotse_magpietts_data_as_continuation" return cut_source - # Filters - def filter_cer_fn(cut: Cut) -> bool: - return ( - len(cut.supervisions) == 0 - or not cut.supervisions[0].has_custom("cer") - or cut.supervisions[0].cer <= max_cer - ) - def filter_val_flag_fn(cut: Cut) -> bool: - return not cut.has_custom("validation_status") or cut.validation_status == keep_flag +@data_type_parser(["lhotse_magpietts_data_as_continuation"]) +def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: + cuts, is_tarred = read_cutset_from_config(config) - def filter_secs_fn(cut: Cut) -> bool: - return ( - len(cut.supervisions) == 0 - or not cut.supervisions[0].has_custom("context_speaker_similarity") - or cut.supervisions[0].context_speaker_similarity >= min_context_speaker_similarity - ) + add_extra_end_sil = config.get("add_extra_end_silence", False) + extra_end_silence_range = config.get("extra_end_silence_range", [0.5, 6.0]) + sample_rate = config.get("sample_rate", 22050) - def filter_target_speaker_fn(cut: Cut) -> bool: - return len(cut.supervisions) == 0 or target_speaker is None or target_speaker in cut.supervisions[0].speaker + max_cer = config.get("max_cer", 0.03) + min_context_speaker_similarity = config.get("min_context_speaker_similarity", 0.6) + target_speaker = config.get("target_speaker", None) + keep_flag = "pass" - # Apply filters + # Use the globally defined classes cuts = ( - cuts.filter(filter_cer_fn).filter(filter_val_flag_fn).filter(filter_secs_fn).filter(filter_target_speaker_fn) + cuts.filter(FilterCER(max_cer)) + .filter(FilterValFlag(keep_flag)) + .filter(FilterSecs(min_context_speaker_similarity)) + .filter(FilterTargetSpeaker(target_speaker)) ) - # Convert cuts - cuts = cuts.map(convert_cut_fn) + cuts = cuts.map(ConvertCutFn(sample_rate, add_extra_end_sil, extra_end_silence_range)) return cuts, is_tarred diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn copy.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn copy.py deleted file mode 100644 index 413abddfe264..000000000000 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn copy.py +++ /dev/null @@ -1,953 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import random -import re -from copy import deepcopy - -import numpy as np -import torch -import torch.nn.functional as F -import torch.utils.data -from lhotse import CutSet, Seconds, compute_num_frames -from lhotse.cut import Cut -from lhotse.dataset.collation import collate_audio, collate_vectors, collate_matrices -from lhotse.utils import ifnone - -from nemo.collections.common.tokenizers import TokenizerSpec -from nemo.collections.speechlm2.data.utils import get_pad_id -from nemo.collections.speechlm2.parts.precision import fp32_precision -from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths -from nemo.utils import logging - -from hydra.utils import instantiate -from omegaconf import DictConfig -from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import IPABPETokenizer - -class MagpieTTSLhotseMultiturnDataset(torch.utils.data.Dataset): - """ - A dataset for duplex speech-to-speech models that handles bidirectional conversations. - - This dataset processes Lhotse CutSet objects containing recordings with supervision segments - from different speakers (roles). It creates aligned representations of audio and text for - both source (input) and target (output) channels, preserving temporal alignment between - audio frames and text tokens. - - Args: - tokenizer (TokenizerSpec): - Tokenizer for converting text to token IDs and vice versa. Must support BOS and EOS tokens. - It's expected to support PAD token as well, otherwise we will use 0 as the pad token - and emit a warning. - - frame_length (Seconds): - Duration of a single frame in seconds. Used to calculate frame positions for token alignment. - - source_sample_rate (int): - Sample rate for source audio (e.g., 16000 Hz). - - target_sample_rate (int): - Sample rate for target audio (e.g., 22050 Hz). - - input_roles (list[str], optional): - List of speaker roles (cut.supervisions[:].speaker) to consider as inputs. Defaults to ["user"]. - - output_roles (list[str], optional): - List of speaker roles (cut.supervisions[:].speaker) to consider as outputs. Defaults to ["agent"]. - - p_drop_description (float, optional): - Probability of dropping text descriptions. Default: `0.0`. - - add_text_bos_and_eos_in_each_turn (bool, optional): - If True, each conversational turn from any speaker is explicitly delimited - with BOS and EOS tokens in the text stream. - Default: `True`. - - Returns: - A dictionary with the following keys: - - sample_id: List of sample IDs for each cut in the batch [B] - - - non_prompt_mask: Bool tensor [B, T] marking positions that are not part of the prompt - - prompt_lens: Tensor of description + audio prompt lengths [B] - - - aligned_attention_mask: Bool tensor [B, T] used by alignment-aware transformer models - - aligned_position_ids: Tensor of position indices aligned to audio frames [B, T] - - - source_audio: Tensor of source waveform samples [B, T] - - source_audio_lens: Tensor of source audio lengths [B] - - - target_audio: Tensor of target waveform samples [B, T] - - target_audio_lens: Tensor of target audio lengths [B] - - - target_text_tokens: Tensor of frame-aligned input text tokens [B, T], - including BOS/EOS/PAD when enabled - - target_token_lens: Tensor of target token sequence lengths [B] - - - source_tokens: Tensor of frame-aligned source text tokens [B, T], - including BOS/EOS/PAD - - source_token_lens: Tensor of source token sequence lengths [B] - - - target_texts: List of full target texts joined from output_roles supervisions [B] - - - audio_prompt: Tensor of optional speaker reference waveform samples [B, T] - - audio_prompt_lens: Tensor of speaker reference audio lengths [B] - - - task: List indicating the task to use for each cut (default "s2s_duplex") [B] - - Notes: - - The dataset ensures frame-level alignment between audio and text by inserting tokens at - specific frame positions based on the timing of supervision segments. - - PAD tokens (typically 0) are used to fill gaps where there's no text. - - BOS tokens mark the beginning of each speech segment. - - EOS tokens mark the end of each speech segment. - - Text tokens from each speaker are placed at frame positions corresponding to their - timestamp in the original recording, preserving the temporal relationship. - This is a segment-level alignment only, not word-level alignment. - """ - - def __init__( - self, - tokenizer, - frame_length: Seconds, - source_sample_rate: int, - target_sample_rate: int, - input_roles: list[str] = None, - output_roles: list[str] = None, - p_drop_description: float = 0.0, - add_text_bos_and_eos_in_each_turn: bool = False, - add_audio_prompt: bool = False, - audio_prompt_duration: float = 3.0, - num_delay_speech_tokens: int = 0, - add_system_prompt: bool = False, - ignore_data_system_prompt: bool = True, - phoneme_tokenizer_config: DictConfig = None, - ignore_phoneme_languages: list[str] = None, - load_cached_codes_if_available: bool = False, - ): - self.tokenizer = tokenizer - self.frame_length = frame_length - self.source_sample_rate = source_sample_rate - self.target_sample_rate = target_sample_rate - self.input_roles = set(ifnone(input_roles, ["user"])) - self.output_roles = set(ifnone(output_roles, ["agent"])) - self.p_drop_description = p_drop_description - self.add_text_bos_and_eos_in_each_turn = add_text_bos_and_eos_in_each_turn - self.add_audio_prompt = add_audio_prompt - self.audio_prompt_duration = audio_prompt_duration - self.num_delay_speech_tokens = num_delay_speech_tokens - self.add_system_prompt = add_system_prompt - self.ignore_data_system_prompt = ignore_data_system_prompt - - self.phoneme_tokenizer_config = phoneme_tokenizer_config - self.ignore_phoneme_languages = ignore_phoneme_languages or [] - self.phoneme_tokenizer = None - self.load_cached_codes_if_available = load_cached_codes_if_available - - self.source_samples_per_frame = int(self.source_sample_rate * self.frame_length) - self.target_samples_per_frame = int(self.target_sample_rate * self.frame_length) - - assert tokenizer.bos is not None, "BOS support in the tokenizer is required for S2S models." - assert tokenizer.eos is not None, "EOS support in the tokenizer is required for S2S models." - - def __getitem__(self, cuts: CutSet) -> dict: - if self.phoneme_tokenizer is None and getattr(self, "phoneme_tokenizer_config", None) is not None: - self.phoneme_tokenizer = instantiate(self.phoneme_tokenizer_config) - - cuts = cuts.transform_text(_strip_timestamps) - - batch_tokenizer_names = [] - for cut in cuts: - if cut.has_custom("tokenizer_names"): - batch_tokenizer_names.append(random.choice(cut.tokenizer_names)) - else: - batch_tokenizer_names.append("english_phoneme") - - target_codes_list = [] - source_codes_list = [] - if self.load_cached_codes_if_available: - for cut in cuts: - if cut.has_custom("target_codes"): - codes_array = cut.target_codes.load().astype(np.int32) - target_codes_list.append(torch.from_numpy(codes_array).T) - if cut.has_custom("source_codes"): - codes_array = cut.source_codes.load().astype(np.int32) - source_codes_list.append(torch.from_numpy(codes_array).T) - - if target_codes_list: - target_codes = collate_matrices(target_codes_list, padding_value=0).transpose(1, 2) - target_codes_lens = torch.tensor([c.shape[0] for c in target_codes_list], dtype=torch.int32) - else: - target_codes, target_codes_lens = None, None - - if source_codes_list: - source_codes = collate_matrices(source_codes_list, padding_value=0).transpose(1, 2) - source_codes_lens = torch.tensor([c.shape[0] for c in source_codes_list], dtype=torch.int32) - else: - source_codes, source_codes_lens = None, None - - with fp32_precision(): - source_audio, source_audio_lens = collate_audio(cuts.resample(self.source_sample_rate)) - target_audio, target_audio_lens = collate_audio( - cuts.resample(self.target_sample_rate, recording_field="target_audio"), recording_field="target_audio" - ) - - target_text_tokens, target_token_lens = collate_token_channel( - cuts, - self.tokenizer, - self.frame_length, - roles=self.output_roles, - add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, - tokenizer_names=batch_tokenizer_names, - ) - source_tokens, source_token_lens = collate_token_channel( - cuts, - self.tokenizer, - self.frame_length, - roles=self.input_roles, - add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, - tokenizer_names=batch_tokenizer_names, - ) - - if self.phoneme_tokenizer is not None: - target_phoneme_tokens, target_phoneme_lens = collate_phoneme_channel( - cuts, - self.phoneme_tokenizer, - self.frame_length, - roles=self.output_roles, - ignore_phoneme_languages=self.ignore_phoneme_languages, - add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, - ) - else: - target_phoneme_tokens, target_phoneme_lens = None, None - - with fp32_precision(): - audio_prompt, audio_prompt_lens = get_audio_prompt( - cuts, self.target_sample_rate, roles=self.output_roles, recording_field="target_audio" - ) - - if self.num_delay_speech_tokens: - ( - source_audio, - source_audio_lens, - target_audio, - target_audio_lens, - source_codes, - source_codes_lens, - target_codes, - target_codes_lens - ) = add_speech_delay( - source_audio, - source_audio_lens, - target_audio, - target_audio_lens, - self.num_delay_speech_tokens, - self.target_samples_per_frame, - self.source_samples_per_frame, - source_codes=source_codes, - source_codes_lens=source_codes_lens, - target_codes=target_codes, - target_codes_lens=target_codes_lens, - ) - - if self.add_system_prompt: - with fp32_precision(): - system_prompts, system_prompts_lens, system_prompts_raw = collate_system_prompt( - cuts, - self.tokenizer, - ignore_data_system_prompt=self.ignore_data_system_prompt, - tokenizer_names=batch_tokenizer_names, - ) - else: - system_prompts = None - system_prompts_lens = None - system_prompts_raw = None - - dataset_type = [getattr(c, "type", "") for c in cuts] - - ( - target_text_tokens, - target_token_lens, - source_tokens, - source_token_lens, - source_audio, - source_audio_lens, - target_audio, - target_audio_lens, - prompt_lens, - target_phoneme_tokens, - target_phoneme_lens, - source_codes, - source_codes_lens, - target_codes, - target_codes_lens, - ) = self.maybe_add_audio_prompt( - target_text_tokens, target_token_lens, source_tokens, source_token_lens, - target_audio, target_audio_lens, source_audio, source_audio_lens, - audio_prompt, audio_prompt_lens, system_prompts, system_prompts_lens, - target_phoneme_tokens=target_phoneme_tokens, target_phoneme_lens=target_phoneme_lens, - source_codes=source_codes, source_codes_lens=source_codes_lens, - target_codes=target_codes, target_codes_lens=target_codes_lens, - ) - - non_prompt_mask = get_mask_from_lengths(target_token_lens) - for i, frame in enumerate(prompt_lens): - non_prompt_mask[i, : frame - 1] = 0.0 - - max_len = max(target_token_lens) - aligned_segment_ids = torch.stack( - [torch.nn.functional.pad(torch.full((seq_len,), i), (0, max_len - seq_len), value=-1) for i, seq_len in enumerate(target_token_lens)], dim=0, - ) - aligned_attention_mask = (aligned_segment_ids.unsqueeze(-2) == aligned_segment_ids.unsqueeze(-1)) & ( - torch.arange(max_len).unsqueeze(0).unsqueeze(1) <= torch.arange(max_len).unsqueeze(0).unsqueeze(-1) - ) - aligned_attention_mask = aligned_attention_mask.unsqueeze(1) - aligned_position_ids = torch.stack( - [torch.nn.functional.pad(torch.arange(seq_len), (0, max(target_token_lens) - seq_len), value=0) for seq_len in target_token_lens], dim=0, - ) - - batch_dict = { - "sample_id": [str(cut.id) for cut in cuts], - "non_prompt_mask": non_prompt_mask.bool(), - "prompt_lens": prompt_lens, - "aligned_attention_mask": aligned_attention_mask.bool(), - "aligned_position_ids": aligned_position_ids, - "source_audio": source_audio, - "source_audio_lens": source_audio_lens, - "target_audio": target_audio, - "target_audio_lens": target_audio_lens, - "target_text_tokens": target_text_tokens, - "target_token_lens": target_token_lens, - "source_tokens": source_tokens, - "source_token_lens": source_token_lens, - "target_texts": [ - " ".join(s.text for s in cut.supervisions if s.speaker in self.output_roles) for cut in cuts - ], - "audio_prompt": audio_prompt, - "audio_prompt_lens": audio_prompt_lens, - "system_prompts_raw": system_prompts_raw, - "dataset_type": dataset_type, - "phoneme_tokens": target_phoneme_tokens, - "phoneme_tokens_lens": target_phoneme_lens, - "task": [getattr(cut, "task", "s2s_duplex") for cut in cuts], - } - - if target_codes is not None: - batch_dict["target_codes"] = target_codes - batch_dict["target_codes_lens"] = target_codes_lens - if source_codes is not None: - batch_dict["source_codes"] = source_codes - batch_dict["source_codes_lens"] = source_codes_lens - - return batch_dict - - def maybe_add_audio_prompt( - self, - target_text_tokens: torch.Tensor, - target_token_lens: torch.Tensor, - source_tokens: torch.Tensor, - source_token_lens: torch.Tensor, - target_audio: torch.Tensor, - target_audio_lens: torch.Tensor, - source_audio: torch.Tensor, - source_audio_lens: torch.Tensor, - audio_prompt: torch.Tensor, - audio_prompt_lens: torch.Tensor, - system_prompts: torch.Tensor = None, - system_prompts_lens: torch.Tensor = None, - target_phoneme_tokens: torch.Tensor = None, - target_phoneme_lens: torch.Tensor = None, - source_codes: torch.Tensor = None, - source_codes_lens: torch.Tensor = None, - target_codes: torch.Tensor = None, - target_codes_lens: torch.Tensor = None, - ): - text_pad_id = get_pad_id(self.tokenizer) - - target_text_tokens_ = [] - source_tokens_ = [] - source_audio_ = [] - target_audio_ = [] - prompt_lens = [] - - target_phoneme_tokens_ = [] - phoneme_pad_id = self.phoneme_tokenizer.pad if self.phoneme_tokenizer else -1 - - source_codes_ = [] - target_codes_ = [] - - for i in range(target_text_tokens.size(0)): - if system_prompts is not None: - text_prompt = system_prompts[i][: system_prompts_lens[i]] - else: - text_prompt = torch.tensor( - [self.tokenizer.eos], - dtype=torch.long, - device=target_text_tokens.device, - ) - - if self.add_audio_prompt: - prompt_audio_size = int( - ((self.audio_prompt_duration * self.target_sample_rate) // self.target_samples_per_frame) - * self.target_samples_per_frame - ) - - prompt_audio = sample_audio_segments_repeat( - audio_prompt, audio_prompt_lens, prompt_audio_size, sample=True - ) - - prompt_audio[:, -int(self.target_samples_per_frame * 2) :] = 0 - - prompt_audio_text_pad_size = prompt_audio_size // self.target_samples_per_frame - prompt_audio_text_pad = ( - torch.ones(prompt_audio_text_pad_size, device=target_text_tokens.device, dtype=target_text_tokens.dtype) - * text_pad_id - ) - prompt_audio_text_pad[-1] = self.tokenizer.eos - - new_target_text_tokens = torch.cat( - [text_prompt.to(target_text_tokens.dtype), prompt_audio_text_pad, target_text_tokens[i]] - ) - target_text_tokens_.append(new_target_text_tokens) - target_token_lens[i] += len(text_prompt) + prompt_audio_text_pad_size - - new_source_tokens = torch.cat([text_prompt, prompt_audio_text_pad, source_tokens[i]]) - source_tokens_.append(new_source_tokens) - source_token_lens[i] += len(text_prompt) + prompt_audio_text_pad_size - - if target_phoneme_tokens is not None: - phoneme_pad_size = len(text_prompt) + prompt_audio_text_pad_size - phoneme_pad = torch.full((phoneme_pad_size,), phoneme_pad_id, device=target_phoneme_tokens.device, dtype=target_phoneme_tokens.dtype) - target_phoneme_tokens_.append(torch.cat([phoneme_pad, target_phoneme_tokens[i]])) - target_phoneme_lens[i] += phoneme_pad_size - - code_pad_size = len(text_prompt) + prompt_audio_text_pad_size - if target_codes is not None: - pad_codes = torch.zeros((target_codes.size(1), code_pad_size), device=target_codes.device, dtype=target_codes.dtype) - target_codes_.append(torch.cat([pad_codes, target_codes[i]], dim=1)) - target_codes_lens[i] += code_pad_size - - if source_codes is not None: - pad_codes = torch.zeros((source_codes.size(1), code_pad_size), device=source_codes.device, dtype=source_codes.dtype) - source_codes_.append(torch.cat([pad_codes, source_codes[i]], dim=1)) - source_codes_lens[i] += code_pad_size - - pad_size_src = (len(text_prompt) * self.source_samples_per_frame) + prompt_audio.size(1) - pad_audio_src = torch.zeros(pad_size_src, device=source_audio.device, dtype=source_audio.dtype) - source_audio_.append(torch.cat([pad_audio_src, source_audio[i]])) - source_audio_lens[i] += pad_size_src - - pad_size_tgt = len(text_prompt) * self.target_samples_per_frame - pad_audio_tgt = torch.zeros(pad_size_tgt, device=target_audio.device, dtype=target_audio.dtype) - target_audio_.append(torch.cat([pad_audio_tgt, prompt_audio[i], target_audio[i]])) - target_audio_lens[i] += pad_size_tgt + prompt_audio.size(1) - - prompt_lens.append(len(text_prompt) + prompt_audio_text_pad_size - 1) - - else: - target_text_tokens_.append(torch.cat([text_prompt, target_text_tokens[i]])) - target_token_lens[i] += len(text_prompt) - - source_tokens_.append(torch.cat([text_prompt, source_tokens[i]])) - source_token_lens[i] += len(text_prompt) - - if target_phoneme_tokens is not None: - phoneme_pad_size = len(text_prompt) - phoneme_pad = torch.full((phoneme_pad_size,), phoneme_pad_id, device=target_phoneme_tokens.device, dtype=target_phoneme_tokens.dtype) - target_phoneme_tokens_.append(torch.cat([phoneme_pad, target_phoneme_tokens[i]])) - target_phoneme_lens[i] += phoneme_pad_size - - code_pad_size = len(text_prompt) - if target_codes is not None: - pad_codes = torch.zeros((target_codes.size(1), code_pad_size), device=target_codes.device, dtype=target_codes.dtype) - target_codes_.append(torch.cat([pad_codes, target_codes[i]], dim=1)) - target_codes_lens[i] += code_pad_size - - if source_codes is not None: - pad_codes = torch.zeros((source_codes.size(1), code_pad_size), device=source_codes.device, dtype=source_codes.dtype) - source_codes_.append(torch.cat([pad_codes, source_codes[i]], dim=1)) - source_codes_lens[i] += code_pad_size - - pad_size_src = len(text_prompt) * self.source_samples_per_frame - pad_audio_src = torch.zeros(pad_size_src, device=source_audio.device, dtype=source_audio.dtype) - source_audio_.append(torch.cat([pad_audio_src, source_audio[i]])) - source_audio_lens[i] += pad_size_src - - pad_size_tgt = len(text_prompt) * self.target_samples_per_frame - pad_audio_tgt = torch.zeros(pad_size_tgt, device=target_audio.device, dtype=target_audio.dtype) - target_audio_.append(torch.cat([pad_audio_tgt, target_audio[i]])) - target_audio_lens[i] += pad_size_tgt - - prompt_lens.append(len(text_prompt)) - - target_text_tokens = collate_vectors(target_text_tokens_, padding_value=text_pad_id) - source_tokens = collate_vectors(source_tokens_, padding_value=text_pad_id) - source_audio = collate_vectors(source_audio_, padding_value=0) - target_audio = collate_vectors(target_audio_, padding_value=0) - - if target_phoneme_tokens is not None: - target_phoneme_tokens = collate_vectors(target_phoneme_tokens_, padding_value=phoneme_pad_id) - - if target_codes is not None: - max_len = max([c.size(1) for c in target_codes_]) - target_codes = torch.stack([F.pad(c, (0, max_len - c.size(1))) for c in target_codes_]) - if source_codes is not None: - max_len = max([c.size(1) for c in source_codes_]) - source_codes = torch.stack([F.pad(c, (0, max_len - c.size(1))) for c in source_codes_]) - - return ( - target_text_tokens, - target_token_lens, - source_tokens, - source_token_lens, - source_audio, - source_audio_lens, - target_audio, - target_audio_lens, - prompt_lens, - target_phoneme_tokens, - target_phoneme_lens, - source_codes, - source_codes_lens, - target_codes, - target_codes_lens, - ) - - -def build_phoneme_channel( - cut: Cut, - phoneme_tokenizer, - frame_length: Seconds, - roles: set[str], - ignore_phoneme_languages: list[str], - pad_id: int = -1, - add_text_bos_and_eos_in_each_turn: bool = True, -) -> torch.Tensor: - """ - Build a frame-aligned phoneme sequence for a single cut, mirroring text token logic. - """ - diagnostic = f"Extra info: {cut.id=}" - if getattr(cut, "shard_origin", None) is not None: - diagnostic = f"{diagnostic} {cut.shard_origin=}" - - total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) - tokens = torch.ones(total, dtype=torch.long) * pad_id - - if cut.has_custom("lang"): - language = cut.lang - else: - language = cut.supervisions[0].language if cut.supervisions[0].has_custom("language") else "en" - - for supervision in cut.supervisions: - if supervision.speaker in roles: - if isinstance(phoneme_tokenizer, IPABPETokenizer): - if not supervision.has_custom("ipa"): - logging.warning(f"'ipa' field not found in cut {cut.id}. Using empty string.") - ipa_text = "" - else: - ipa_text = supervision.ipa - - if language in ignore_phoneme_languages: - ipa_text = "" - else: - ipa_text = supervision.text - - phoneme_ids = phoneme_tokenizer.encode(ipa_text) - if add_text_bos_and_eos_in_each_turn: - phoneme_ids = [phoneme_tokenizer.bos_token_id] + phoneme_ids - - phoneme_ids = torch.as_tensor(phoneme_ids, dtype=torch.long) - - pos = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) - if pos > len(tokens): - logging.warning(f"Supervision offset {pos} larger than {len(tokens)}. {diagnostic}") - continue - - endpos = pos + len(phoneme_ids) - if endpos > len(tokens): - trunc_len = len(tokens) - pos - logging.warning(f"Truncating phoneme_ids by {trunc_len}. {diagnostic}") - phoneme_ids = phoneme_ids[:trunc_len] - - try: - tokens[pos:endpos] = phoneme_ids - except Exception as e: - raise RuntimeError(f"{tokens.shape=} {pos=} {endpos=} {phoneme_ids.shape=} {diagnostic}") from e - - if add_text_bos_and_eos_in_each_turn: - eospos = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) - if eospos < len(tokens): - tokens[eospos] = phoneme_tokenizer.eos_token_id - - return tokens - - -def collate_phoneme_channel( - cuts: CutSet, - phoneme_tokenizer, - frame_length: Seconds, - roles: set[str], - ignore_phoneme_languages: list[str], - add_text_bos_and_eos_in_each_turn: bool = True, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Collate frame-aligned phoneme channels. - """ - pad_id = phoneme_tokenizer.pad - tokens = [ - build_phoneme_channel( - c, - phoneme_tokenizer=phoneme_tokenizer, - frame_length=frame_length, - roles=roles, - ignore_phoneme_languages=ignore_phoneme_languages, - pad_id=pad_id, - add_text_bos_and_eos_in_each_turn=add_text_bos_and_eos_in_each_turn, - ) - for c in cuts - ] - token_lens = torch.tensor([len(tt) for tt in tokens]) - tokens = collate_vectors(tokens, padding_value=pad_id) - return tokens, token_lens - - -def add_speech_delay( - source_audio: torch.Tensor, - source_audio_lens: torch.Tensor, - target_audio: torch.Tensor, - target_audio_lens: torch.Tensor, - num_delay_speech_tokens: int, - target_samples_per_frame: int, - source_samples_per_frame: int, - source_codes: torch.Tensor = None, - source_codes_lens: torch.Tensor = None, - target_codes: torch.Tensor = None, - target_codes_lens: torch.Tensor = None, -): - """ - Apply a speech delay by padding audio waveforms based on the number of delay speech tokens. - """ - extra_target_samples = int(num_delay_speech_tokens * target_samples_per_frame) - target_audio = F.pad(target_audio, (extra_target_samples, 0)) - target_audio_lens = target_audio_lens + extra_target_samples - - extra_source_samples = int(num_delay_speech_tokens * source_samples_per_frame) - source_audio = F.pad(source_audio, (0, extra_source_samples)) - source_audio_lens = source_audio_lens + extra_source_samples - - if target_codes is not None: - target_codes = F.pad(target_codes, (num_delay_speech_tokens, 0)) - target_codes_lens = target_codes_lens + num_delay_speech_tokens - - if source_codes is not None: - source_codes = F.pad(source_codes, (0, num_delay_speech_tokens)) - source_codes_lens = source_codes_lens + num_delay_speech_tokens - - return ( - source_audio, source_audio_lens, target_audio, target_audio_lens, - source_codes, source_codes_lens, target_codes, target_codes_lens - ) - - -def collate_system_prompt( - cuts: CutSet, - tokenizer: TokenizerSpec, - ignore_data_system_prompt: bool = False, - tokenizer_names: list[str] = None, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Collate system prompts from cuts. - System prompts should be stored in cut.custom['system_prompt']. - """ - pad_id = get_pad_id(tokenizer) - tokens = [] - system_prompts_raw = [] - - for i, c in enumerate(cuts): - tok_name = tokenizer_names[i] if tokenizer_names else "english_phoneme" - - def _encode(txt): - if hasattr(tokenizer, "encode"): - try: - return tokenizer.encode(text=txt, tokenizer_name=tok_name) - except TypeError: - return tokenizer.encode(text=txt) - return tokenizer.text_to_ids(txt) - - if c.custom and c.custom.get("system_prompt", None) and not ignore_data_system_prompt: - prompt_text = c.custom["system_prompt"] - tokens.append( - torch.as_tensor( - [tokenizer.bos] + _encode(prompt_text) + [tokenizer.eos], dtype=torch.long - ) - ) - system_prompts_raw.append(prompt_text) - else: - if getattr(c, "type", None): - prompt_text = c.type - tokens.append( - torch.as_tensor( - [tokenizer.bos] + _encode(prompt_text) + [tokenizer.eos], dtype=torch.long - ) - ) - system_prompts_raw.append(prompt_text) - else: - logging.warning( - "No system prompt or dataset type defined on the config! Using a eos token as system prompt!" - ) - tokens.append(torch.as_tensor([tokenizer.eos], dtype=torch.long)) - system_prompts_raw.append("") - - token_lens = torch.tensor([len(tt) for tt in tokens]) - tokens = collate_vectors(tokens, padding_value=pad_id) - return tokens, token_lens, system_prompts_raw - - -def get_audio_prompt( - cuts: CutSet, - target_sample_rate: int, - roles: set[str], - recording_field: str = "target_audio", -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Retrieve an audio prompt for speaker conditioning. - """ - if hasattr(cuts[0], "context_audio"): - audio_prompt = [] - audio_prompt_lens = [] - - for cut in cuts: - ref_audio = cut.context_audio.resample(target_sample_rate).load_audio() - ref_audio = torch.tensor(ref_audio).float() - ref_audio_len = ref_audio.shape[1] - - audio_prompt.append(ref_audio.squeeze(0)) - audio_prompt_lens.append(ref_audio_len) - - audio_prompt = collate_vectors(audio_prompt, padding_value=0).float() - audio_prompt_lens = torch.tensor(audio_prompt_lens).long() - - else: - cuts = sanitize_cuts(cuts) - audio_prompt, audio_prompt_lens = collate_random_turn_audio( - cuts.resample(target_sample_rate, recording_field=recording_field), - roles=roles, - recording_field=recording_field, - ) - - return audio_prompt, audio_prompt_lens - - -def sanitize_cuts(cuts: CutSet) -> CutSet: - """ - Adjusts supervisions to fit within the cut's truncated duration. - """ - sanitized_list = [] - - for cut in cuts: - valid_supervisions = [] - for sup in cut.supervisions: - if sup.start >= cut.duration: - continue - - if sup.end > cut.duration: - new_duration = cut.duration - sup.start - - if new_duration <= 0: - continue - - new_sup = deepcopy(sup) - new_sup.duration = new_duration - valid_supervisions.append(new_sup) - - else: - valid_supervisions.append(sup) - - cut.supervisions = valid_supervisions - sanitized_list.append(cut) - - return cuts.from_cuts(sanitized_list) - - -def collate_random_turn_audio( - cuts: CutSet, - roles: set[str], - recording_field: str = "target_audio", -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Sample and collate reference audio from random speaker turns. - """ - selected_turn_audios = [] - selected_turn_audios_lens = [] - for cut in cuts: - matching_supervisions = [s for s in cut.supervisions if s.speaker in roles] - if len(matching_supervisions) == 0: - target_duration = 5.0 - num_samples = int(target_duration * cut.sampling_rate) - - silence_tensor = torch.zeros(num_samples, dtype=torch.float32) - selected_turn_audios.append(silence_tensor) - selected_turn_audios_lens.append(num_samples) - logging.warning( - "There is no target speaker supervision available on this sample! Using a silence audio as audio prompt!" - ) - else: - selected_supervision = random.choice(matching_supervisions) - truncated_audio = cut.truncate( - offset=max(0, selected_supervision.start), duration=selected_supervision.duration - ).load_custom(recording_field) - - selected_turn_audios.append(truncated_audio.squeeze(0)) - selected_turn_audios_lens.append(truncated_audio.shape[-1]) - - return collate_vectors(selected_turn_audios, padding_value=0), torch.tensor(selected_turn_audios_lens) - - -def collate_token_channel( - cuts: CutSet, - tokenizer: TokenizerSpec, - frame_length: Seconds, - roles: set[str], - add_text_bos_and_eos_in_each_turn: bool = True, - tokenizer_names: list[str] = None, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Build and collate token channels aligned to the audio frame grid. - """ - pad_id = get_pad_id(tokenizer) - tokens = [] - - for i, c in enumerate(cuts): - tok_name = tokenizer_names[i] if tokenizer_names else "english_phoneme" - tokens.append( - build_token_channel( - c, - tokenizer=tokenizer, - frame_length=frame_length, - roles=roles, - pad_id=pad_id, - add_text_bos_and_eos_in_each_turn=add_text_bos_and_eos_in_each_turn, - tokenizer_name=tok_name, - ) - ) - token_lens = torch.tensor([len(tt) for tt in tokens]) - tokens = collate_vectors(tokens, padding_value=pad_id) - return tokens, token_lens - - -def build_token_channel( - cut: Cut, - tokenizer: TokenizerSpec, - frame_length: Seconds, - roles: set[str], - pad_id: int = -1, - add_text_bos_and_eos_in_each_turn: bool = True, - tokenizer_name: str = "english_phoneme", -) -> torch.Tensor: - """ - Build a frame-aligned token sequence for a single cut. - """ - diagnostic = f"Extra info: {cut.id=}" - if getattr(cut, "shard_origin", None) is not None: - diagnostic = f"{diagnostic} {cut.shard_origin=}" - - total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) - tokens = torch.ones(total, dtype=torch.long) * pad_id - for supervision in cut.supervisions: - if supervision.speaker in roles: - text = supervision.text - - if hasattr(tokenizer, "encode"): - try: - raw_ids = tokenizer.encode(text=text, tokenizer_name=tokenizer_name) - except TypeError: - raw_ids = tokenizer.encode(text) - else: - raw_ids = tokenizer.text_to_ids(text) - - if add_text_bos_and_eos_in_each_turn: - text_ids = torch.as_tensor([tokenizer.bos] + raw_ids) - else: - text_ids = torch.as_tensor(raw_ids) - - pos = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) - if pos > len(tokens): - logging.warning( - f"Ill-constructed example: the beginning offset of a supervision {pos} is larger than the example's length {len(tokens)}. {diagnostic}" - ) - continue - - endpos = pos + len(text_ids) - if endpos > len(tokens): - trunc_len = len(tokens) - pos - logging.warning( - f"Truncating training example's text_ids of length {len(text_ids)} by {trunc_len} because {endpos=} > {len(tokens)=}. {diagnostic}" - ) - text_ids = text_ids[:trunc_len] - try: - tokens[pos:endpos] = text_ids - except Exception as e: - raise RuntimeError(f"{tokens.shape=} {pos=} {endpos=} {text_ids.shape=} {diagnostic}") from e - - if add_text_bos_and_eos_in_each_turn: - eospos = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) - if eospos < len(tokens): - tokens[eospos] = tokenizer.eos - - return tokens - - -def _strip_timestamps( - text: str, _TIMESTAMP_PATTERN=re.compile(r"<\|\d+\|>"), _SPACE_PATTERN=re.compile(r"\s+") -) -> str: - """ - Strips timestamp tokens from text. - """ - text = _TIMESTAMP_PATTERN.sub("", text) - return _SPACE_PATTERN.sub(" ", text).strip() - - -def sample_audio_segments_repeat( - prompt_audio: torch.Tensor, - prompt_audio_lens: torch.Tensor, - n_sample: int, - sample: bool = True, -) -> torch.Tensor: - """ - Extract audio segments of length n_sample. - """ - B, T = prompt_audio.shape - device = prompt_audio.device - out = torch.zeros(B, n_sample, device=device, dtype=prompt_audio.dtype) - - for b in range(B): - length = min(prompt_audio_lens[b].item(), T) - - if length <= 0: - continue - - if length >= n_sample: - if sample: - max_start = max(1, length - n_sample + 1) - start = torch.randint(0, max_start, (1,), device=device).item() - else: - start = 0 - out[b] = prompt_audio[b, start : start + n_sample] - - else: - start = 0 - segment = prompt_audio[b, start:length] - - repeat_times = (n_sample + (length - start) - 1) // (length - start) - repeated = segment.repeat(repeat_times)[:n_sample] - out[b] = repeated - - return out diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index e86779ce15f5..47985db76b94 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -222,11 +222,11 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: batch_tokenizer_names.append("english_phoneme") with fp32_precision(): - source_audio, source_audio_lens = collate_audio(cuts.resample(self.source_sample_rate)) target_audio, target_audio_lens = collate_audio( cuts.resample(self.sample_rate, recording_field="target_audio"), recording_field="target_audio" ) - + source_audio, source_audio_lens = collate_audio(cuts.resample(self.source_sample_rate)) + # Apply volume norm if requested if self.volume_norm: source_audio = torch.stack([torch.from_numpy(normalize_volume(a.numpy())) for a in source_audio]) @@ -276,7 +276,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) effective_duration_max = max(self.context_duration_min, effective_duration_max) return random.uniform(self.context_duration_min, effective_duration_max) - for cut in cuts: + for i, cut in enumerate(cuts): speaker_found = False for sup in reversed(cut.supervisions): if check_speaker_format(sup.speaker): @@ -287,6 +287,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) if not speaker_found: dataset_name = "unknown" dataset_name_list.append(dataset_name) + print("Language is available?", cut.has_custom("lang"), " Has codes?", cut.has_custom("target_codes"), "Has context audio?", cut.has_custom("context_audio"), "Has context codes?", cut.has_custom("context_codes")) language = cut.lang if cut.has_custom("lang") else next((sup.language for sup in reversed(cut.supervisions) if sup.has_custom("language")), "en") language_list.append(language) @@ -425,7 +426,8 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) # Align Prior (Note: Using full target length to preserve shape compatibility) if self.include_align_prior: - full_text_len = sum([len(self.text_tokenizer.encode(sup.text)) for sup in cut.supervisions if sup.speaker in self.output_roles]) + tok_name = batch_tokenizer_names[i] + full_text_len = sum([len(self.text_tokenizer.encode(sup.text, tokenizer_name=tok_name)) for sup in cut.supervisions if sup.speaker in self.output_roles]) spec_len = int(cut.duration * self.sample_rate / self.codec_model_samples_per_frame) + 1 align_prior = beta_binomial_prior_distribution(phoneme_count=full_text_len, mel_count=spec_len, scaling_factor=self.prior_scaling_factor) prior_list.append(torch.tensor(align_prior, dtype=torch.float32)) @@ -434,20 +436,19 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) if reward is not None: reward_list.append(reward) - # --- ASSEMBLE FINAL BATCH DICTIONARY --- batch_dict = { "sample_id": [str(cut.id) for cut in cuts], "dataset_names": dataset_name_list, "languages": language_list, "source_audio": source_audio, "source_audio_lens": source_audio_lens, - "target_audio": target_audio, - "target_audio_lens": target_audio_lens, + "audio": target_audio, + "audio_lens": target_audio_lens, "source_tokens": source_tokens, "source_token_lens": source_token_lens, - "target_text_tokens": target_text_tokens, - "target_token_lens": target_token_lens, - "target_texts": [" ".join(s.text for s in cut.supervisions if s.speaker in self.output_roles) for cut in cuts], + "text": target_text_tokens, + "text_lens": target_token_lens, + "raw_texts": [" ".join(s.text for s in cut.supervisions if s.speaker in self.output_roles) for cut in cuts], "dataset_type": [getattr(c, "type", "") for c in cuts], } diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 9111dcde9b33..4c6a18419d5e 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -31,6 +31,7 @@ from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.mixins.transcription import TranscribeConfig from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.common.data.fallback import FallbackDataset from nemo.collections.tts.data.text_to_speech_dataset_lhotse import MagpieTTSLhotseDataset, setup_tokenizers from nemo.collections.tts.data.text_to_speech_dataset_lhotse_multiturn import MagpieTTSLhotseMultiturnDataset @@ -1420,6 +1421,7 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D add_text_bos_and_eos_in_each_turn=self.cfg.get("add_text_bos_and_eos_in_each_turn", False), # pronunciation_control_g2p=self.cfg.get("pronunciation_control_g2p", None), ) + dataset = FallbackDataset(dataset) else: dataset = MagpieTTSLhotseDataset( sample_rate=self.sample_rate, @@ -1481,7 +1483,7 @@ def setup_training_data(self, dataset_cfg): ) def _setup_test_dataloader(self, dataset_cfg) -> torch.utils.data.DataLoader: - if self.cfg.get("use_lhotse", False): + if dataset_cfg.get("use_lhotse", False): data_loader = self.get_lhotse_dataloader(dataset_cfg, mode='test') else: dataset = self.get_dataset(dataset_cfg, dataset_type='test') From 3b447763dce8dab84f44f7c9e0ff9c84bccba1e3 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 14 Apr 2026 13:52:47 -0700 Subject: [PATCH 010/109] Add multiturn config Signed-off-by: Edresson Casanova --- .../easy_magpietts_lhotse_multiturn.yaml | 228 ++++++++++++++++++ ...text_to_speech_dataset_lhotse_multiturn.py | 16 +- 2 files changed, 241 insertions(+), 3 deletions(-) create mode 100644 examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml diff --git a/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml b/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml new file mode 100644 index 000000000000..84adce615bd2 --- /dev/null +++ b/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml @@ -0,0 +1,228 @@ +name: Magpie-TTS-DecoderOnly-EN + +quadratic_duration: 20 + +# Adjust batch size based on GPU memory +# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. +# If null, then weighted sampling is disabled. + +model: + use_lhotse: true + + # Decoder backend selection + # Options: "huggingface" (default), "nemotron_h" + decoder_type: "huggingface" + + # HuggingFace backend config (used when decoder_type: "huggingface") + transformer_hf_backend: "Qwen/Qwen2.5-1.5B" + + # NemotronH config (used when decoder_type: "nemotron_h") + # Hybrid Mamba2/MoE/Attention model (~3B total, ~600-800M active). Layer types via hybrid_override_pattern: + # 'M' = Mamba2 layer, '*' = Attention layer, '-' = MLP layer, 'E' = MoE layer + nemotron_h_config: + hidden_size: 1536 # Should match embedding_dim + num_hidden_layers: 48 + vocab_size: 131072 + # Attention config + num_attention_heads: 12 + num_key_value_heads: 4 + attention_dropout: 0.0 + attention_bias: false + max_position_embeddings: 8192 + # Mamba config + mamba_num_heads: 64 + mamba_head_dim: 24 + ssm_state_size: 128 + conv_kernel: 4 + n_groups: 8 + chunk_size: 256 + mamba_hidden_act: "silu" + use_conv_bias: true + use_bias: false + # MLP config + intermediate_size: 4096 + mlp_hidden_act: "silu" + mlp_bias: false + # MoE config (scaled from Nemotron-3-Nano-30B-A3B) + n_routed_experts: 48 + num_experts_per_tok: 6 + moe_intermediate_size: 1024 + moe_shared_expert_intermediate_size: 2048 + n_group: 1 + topk_group: 1 + routed_scaling_factor: 2.5 + norm_topk_prob: true + # Layer pattern: (M E M E M *) x 8 => 16 Mamba, 16 MoE, 8 Attention + hybrid_override_pattern: "MEMEM*MEMEM*MEMEM*MEMEM*MEMEM*MEMEM*MEMEM*MEMEM*" + # Normalization + layer_norm_epsilon: 1e-5 + residual_in_fp32: true + + use_text_conditioning_encoder: true # If true, distilbert will be used to encode context_text if provided. + context_duration_min: 5.0 + context_duration_max: 5.0 + load_cached_codes_if_available: true + + embedding_dim: 1536 + hidden_dim: 1536 + audio_embedding_dim: 1536 # Can set a smaller dimension for audio embeddings to reduce parameters. Set equal to hidden_dim for no projection. + codecmodel_path: ??? + + # Local transformer parameters for autoregressive codebook prediction within a frame + local_transformer_type: "autoregressive" # "none", "autoregressive" + # Below args are only relevant if use_local_transformer is autoregressive + local_transformer_loss_scale: 1.0 + phoneme_loss_weight: 1.0 + local_transformer_n_layers: 3 + local_transformer_n_heads: 12 + local_transformer_hidden_dim: 1536 + + cfg_unconditional_prob: 0.05 + + # Multi-mode training configuration + training_modes: + - text_input_mode: "streaming" # Options: "full", "streaming" + streaming_phonemes_delay: 0 + streaming_speech_delay: 1 + + frame_stacking_factor: 2 + phoneme_stacking_factor: 1 + phoneme_confidence_unk_threshold: 0.0 # If max phoneme probability is below this threshold at inference-time, replace the predicted timestep with UNK to reduce error propagation. + dropout_text_input_prob: 0.1 + phoneme_corruption_batch_prob: 0.1 + phoneme_corruption_timestep_ratio: 0.15 + phoneme_corruption_unk_mode_prob: 0.5 + phoneme_corruption_type: "repeat_skip_unk" # "repeat_skip_unk" or "complete_channel" + + phoneme_tokenizer: + _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPABPETokenizer + tokenizer_path: ??? + + text_tokenizers: + nemotron_nano_30b: + _target_: AutoTokenizer + pretrained_model: "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16" + + train_ds: + use_lhotse: ${model.use_lhotse} + volume_norm: true + dataset: + multi_config: true + shuffle: true + seed: 42 + shard_seed: "trng" + + sampler_fusion: randomized_round_robin + sampler_weights: + tts_data: 0.5 + duplex_data: 0.5 + tts_data: + min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 + batch_duration : ??? # in seconds. Adjust based on your GPU memory. + quadratic_duration: ${quadratic_duration} + use_bucketing: true + num_buckets: 20 + bucket_buffer_size: 15_000 + shuffle_buffer_size: 15_000 + num_cuts_for_bins_estimate: 15_000 + shard_seed: "trng" + drop_last: true + shuffle: true + num_workers: 6 + pin_memory: true + + input_cfg: + - type: lhotse_shar + shar_path: ??? + weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] + + duplex_data: + input_cfg: /lustre/fsw/convai_convaird_nemo-speech/data/duplex/multispeaker_syn_duplex.yaml + use_bucketing: true + num_buckets: 20 + bucket_buffer_size: 5_000 + shuffle_buffer_size: 5_000 + num_cuts_for_bins_estimate: 5_000 + max_duration: 300 # 5 mi max duration + bucket_duration_bins: [4.0, 8.9, 10.2, 11.6, 13.2, 15.0, 17.0, 19.3, 25.0, 31.5, 38.5, 46.0, 55.5, 66.5, 79.5, 93.3, 110.0, 130.0, 156.8, 203.3] + bucket_batch_size: [75, 33, 29, 25, 23, 20, 18, 15, 12, 10, 8, 7, 5, 4, 3, 3, 2, 2, 1, 1] + + + validation_ds: + use_lhotse: ${model.use_lhotse} + volume_norm: true + + dataset: + min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 + batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset. + quadratic_duration: ${quadratic_duration} + use_bucketing: false + force_finite: true + force_map_dataset: true + drop_last: false + shuffle: false + num_workers: 2 + pin_memory: true + seed: 42 + shard_seed: "randomized" + + input_cfg: + - type: lhotse_shar + shar_path: ??? + weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] + + optim: + _target_: torch.optim.AdamW + lr: 1e-4 + + sched: + name: ExponentialLR + gamma: 0.998 + +trainer: + num_nodes: 1 + devices: -1 + accelerator: gpu + strategy: ddp_find_unused_parameters_true + precision: bf16-mixed + max_steps: ??? + accumulate_grad_batches: 1 + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 100 + limit_train_batches: 1_000 + val_check_interval: 1_000 + num_sanity_val_steps: 0 + benchmark: false + use_distributed_sampler: false # required because Lhotse has its own handling + gradient_clip_val: 2.5 + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_wandb_logger: false + wandb_logger_kwargs: + entity: null + name: ${name} + project: null + group: null + resume: true + create_checkpoint_callback: true + checkpoint_callback_params: + monitor: val_loss + mode: min + save_top_k: 5 + save_best_model: true + always_save_nemo: true + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}-{epoch}' + resume_if_exists: true + resume_ignore_no_checkpoint: true \ No newline at end of file diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 47985db76b94..a864ceb3f988 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -221,6 +221,14 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: else: batch_tokenizer_names.append("english_phoneme") + def _align_codebooks(t): + C = t.shape[1] + if C < self.num_audio_codebooks: + return F.pad(t, (0, self.num_audio_codebooks - C)) + elif C > self.num_audio_codebooks: + return t[:, :self.num_audio_codebooks] + return t + with fp32_precision(): target_audio, target_audio_lens = collate_audio( cuts.resample(self.sample_rate, recording_field="target_audio"), recording_field="target_audio" @@ -287,7 +295,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) if not speaker_found: dataset_name = "unknown" dataset_name_list.append(dataset_name) - print("Language is available?", cut.has_custom("lang"), " Has codes?", cut.has_custom("target_codes"), "Has context audio?", cut.has_custom("context_audio"), "Has context codes?", cut.has_custom("context_codes")) + # print("Language is available?", cut.has_custom("lang"), " Has codes?", cut.has_custom("target_codes"), "Has context audio?", cut.has_custom("context_audio"), "Has context codes?", cut.has_custom("context_codes")) language = cut.lang if cut.has_custom("lang") else next((sup.language for sup in reversed(cut.supervisions) if sup.has_custom("language")), "en") language_list.append(language) @@ -314,8 +322,9 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) _num_repeats = int(np.ceil(_num_frames_to_slice / context_audio_codes.shape[1])) context_audio_codes = context_audio_codes.repeat(1, _num_repeats)[:, :_num_frames_to_slice] - context_audio_codes_list.append(context_audio_codes.T) - context_audio_codes_len_list.append(context_audio_codes.T.shape[0]) + context_audio_codes = _align_codebooks(context_audio_codes.T) + context_audio_codes_list.append(context_audio_codes) + context_audio_codes_len_list.append(context_audio_codes.shape[0]) elif cut.has_custom("context_audio"): with fp32_precision(): @@ -348,6 +357,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) start_frame = int(max(0, sup.start) * self.sample_rate / self.codec_model_samples_per_frame) num_frames = int(sup.duration * self.sample_rate / self.codec_model_samples_per_frame) context_audio_codes = torch.from_numpy(codes_array)[:, start_frame : start_frame + num_frames].T + context_audio_codes = _align_codebooks(context_audio_codes) else: context_audio_codes = torch.zeros([0, self.num_audio_codebooks], dtype=torch.int32) context_audio_codes_list.append(context_audio_codes) From 7048d3ce28702ebbd073dabe2e025da92231ab99 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 14 Apr 2026 19:29:34 -0700 Subject: [PATCH 011/109] Add a intermediary fix for prior Signed-off-by: Edresson Casanova --- .../data/text_to_speech_dataset_lhotse_multiturn.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index a864ceb3f988..c145ff2e5dc0 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -438,7 +438,17 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) if self.include_align_prior: tok_name = batch_tokenizer_names[i] full_text_len = sum([len(self.text_tokenizer.encode(sup.text, tokenizer_name=tok_name)) for sup in cut.supervisions if sup.speaker in self.output_roles]) - spec_len = int(cut.duration * self.sample_rate / self.codec_model_samples_per_frame) + 1 + + if self.add_text_bos_and_eos_in_each_turn: + full_text_len += 2 * sum([1 for sup in cut.supervisions if sup.speaker in self.output_roles]) + + full_text_len = max(1, full_text_len) + + if self.load_cached_codes_if_available and cut.has_custom("target_codes"): + spec_len = int(target_codes_list[-1].shape[0]) + 1 + else: + spec_len = int(target_audio_lens[i] / self.codec_model_samples_per_frame) + 2 # +1 extra in case it was truncated + align_prior = beta_binomial_prior_distribution(phoneme_count=full_text_len, mel_count=spec_len, scaling_factor=self.prior_scaling_factor) prior_list.append(torch.tensor(align_prior, dtype=torch.float32)) From 1a56d9e475e723dd15f5b64ce36cdbdde76f48a8 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 14 Apr 2026 19:49:25 -0700 Subject: [PATCH 012/109] Fix audio tokens name Signed-off-by: Edresson Casanova --- .../tts/data/text_to_speech_dataset_lhotse_multiturn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index c145ff2e5dc0..378ab5714b01 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -473,8 +473,8 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) } if target_codes_list: - batch_dict["target_codes"] = collate_matrices(target_codes_list, padding_value=0).transpose(1, 2) - batch_dict["target_codes_lens"] = torch.IntTensor([c.shape[0] for c in target_codes_list]) + batch_dict["audio_codes"] = collate_matrices(target_codes_list, padding_value=0).transpose(1, 2) + batch_dict["audio_codes_lens"] = torch.IntTensor([c.shape[0] for c in target_codes_list]) if source_codes_list: batch_dict["source_codes"] = collate_matrices(source_codes_list, padding_value=0).transpose(1, 2) From 7bcaa8e9b629e7a1fdb9e6fab4691182033eca8a Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 15 Apr 2026 15:23:29 -0700 Subject: [PATCH 013/109] Add formatter to support json dataset on lhotse inference Signed-off-by: Edresson Casanova --- .../easy_magpietts_lhotse_multiturn.yaml | 2 +- nemo/collections/common/data/lhotse/cutset.py | 115 +----------------- 2 files changed, 7 insertions(+), 110 deletions(-) diff --git a/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml b/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml index 84adce615bd2..1f4c10b071c3 100644 --- a/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml +++ b/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml @@ -178,7 +178,7 @@ model: weight: 1.0 tags: tokenizer_names: ["english_phoneme"] - + optim: _target_: torch.optim.AdamW lr: 1e-4 diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 5ea4065d5fc1..141b91b65c63 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -14,6 +14,8 @@ """Lhotse CutSet utilities and Parquet manifest support for NeMo.""" import io +import os +import json import logging import random import re @@ -57,6 +59,9 @@ ) from nemo.collections.common.parts.preprocessing.manifest import get_full_path +from lhotse import Recording, AudioSource, SupervisionSegment, MonoCut, CutSet + +from pydub.utils import mediainfo def temperature_reweighting(weights: List[Union[float, int]], temperature: float = 1.0) -> List[float]: """ @@ -812,115 +817,6 @@ def cut_to_conversation( ) -def _filter_cer_fn(cut: Cut, max_cer: float) -> bool: - return ( - len(cut.supervisions) == 0 - or not cut.supervisions[0].has_custom("cer") - or cut.supervisions[0].cer <= max_cer - ) - -def _filter_val_flag_fn(cut: Cut, keep_flag: str) -> bool: - return not cut.has_custom("validation_status") or cut.validation_status == keep_flag - -def _filter_secs_fn(cut: Cut, min_sim: float) -> bool: - return ( - len(cut.supervisions) == 0 - or not cut.supervisions[0].has_custom("context_speaker_similarity") - or cut.supervisions[0].context_speaker_similarity >= min_sim - ) - -def _filter_target_speaker_fn(cut: Cut, target_speaker: str) -> bool: - return len(cut.supervisions) == 0 or target_speaker is None or target_speaker in cut.supervisions[0].speaker - -def _create_recording_from_array(samples: np.ndarray, sampling_rate: int, recording_id: str) -> Recording: - with io.BytesIO() as buffer: - sf.write(buffer, samples.T, samplerate=sampling_rate, format='WAV') - buffer.seek(0) - return Recording.from_bytes(buffer.read(), recording_id=recording_id) - -def _convert_cut_fn(cut: Cut, sample_rate: int, add_extra_end_sil: bool, extra_end_silence_range: list) -> Cut: - orig_agent_sup = fastcopy(cut.supervisions[0]) - target_audio_orig_dur = cut.target_audio.duration - - cut.target_audio = cut.target_audio.resample(sample_rate) - cut.context_audio = cut.context_audio.resample(sample_rate) - total_duration = cut.target_audio.duration - - cut_target = MonoCut( - id=f"{cut.id}_target", - start=0.0, - duration=total_duration, - channel=0, - recording=cut.target_audio, - supervisions=[], - ) - - zero_audio = np.zeros((1, int(total_duration * sample_rate)), dtype=np.float32) - source_recording = _create_recording_from_array(zero_audio, sample_rate, recording_id=f"{cut.id}_source") - - cut_source = MonoCut( - id=f"{cut.id}_source", - start=0.0, - duration=total_duration, - channel=0, - recording=source_recording, - supervisions=[], - custom=deepcopy(cut.custom) if cut.custom is not None else None, - ) - - cut_source = cut_source.move_to_memory(audio_format='wav') - cut_target = cut_target.move_to_memory(audio_format='wav') - - user_sup = fastcopy(orig_agent_sup, start=0.0, duration=0.08, speaker="user", text="dummy text") - agent_sup = fastcopy(orig_agent_sup, start=0.0, duration=target_audio_orig_dur - 0.08, speaker="agent") - - if user_sup.custom is not None and "ipa" in user_sup.custom: - user_sup.custom = deepcopy(user_sup.custom) - user_sup.custom["ipa"] = "" - - if add_extra_end_sil: - sil_duration = random.uniform(*extra_end_silence_range) - cut_target = cut_target.pad(duration=total_duration + sil_duration, direction="right") - cut_source = cut_source.pad(duration=total_duration + sil_duration, direction="right") - cut_source = cut_source.to_mono().move_to_memory(audio_format='wav') - cut_target = cut_target.to_mono().move_to_memory(audio_format='wav') - agent_sup.duration += sil_duration + 1.0 - user_sup.duration += sil_duration - - cut_source.supervisions = [user_sup, agent_sup] - cut_source.target_audio = cut_target.recording - cut_source.duration = cut_target.duration - cut_source.context_audio = cut.context_audio - cut_source.task = "lhotse_magpietts_data_as_continuation" - - return cut_source - - -@data_type_parser(["lhotse_magpietts_data_as_continuation"]) -def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: - cuts, is_tarred = read_cutset_from_config(config) - - add_extra_end_sil = config.get("add_extra_end_silence", False) - extra_end_silence_range = config.get("extra_end_silence_range", [0.5, 6.0]) - sample_rate = config.get("sample_rate", 22050) - - max_cer = config.get("max_cer", 0.03) - min_context_speaker_similarity = config.get("min_context_speaker_similarity", 0.6) - target_speaker = config.get("target_speaker", None) - keep_flag = "pass" - - cuts = ( - cuts.filter(partial(_filter_cer_fn, max_cer=max_cer)) - .filter(partial(_filter_val_flag_fn, keep_flag=keep_flag)) - .filter(partial(_filter_secs_fn, min_sim=min_context_speaker_similarity)) - .filter(partial(_filter_target_speaker_fn, target_speaker=target_speaker)) - ) - - cuts = cuts.map(partial(_convert_cut_fn, sample_rate=sample_rate, add_extra_end_sil=add_extra_end_sil, extra_end_silence_range=extra_end_silence_range)) - - return cuts, is_tarred - - class FilterCER: def __init__(self, max_cer: float): self.max_cer = max_cer @@ -1034,6 +930,7 @@ def __call__(self, cut: Cut) -> Cut: return cut_source + @data_type_parser(["lhotse_magpietts_data_as_continuation"]) def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: cuts, is_tarred = read_cutset_from_config(config) From a741d4413779304760fd3ba7015dfb1a9e925427 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 16 Apr 2026 13:47:21 -0700 Subject: [PATCH 014/109] Update inference to support multiturn dataloader Signed-off-by: Edresson Casanova --- ...text_to_speech_dataset_lhotse_multiturn.py | 39 +++++++++---------- nemo/collections/tts/models/easy_magpietts.py | 28 +++++++++++-- .../tts/models/easy_magpietts_inference.py | 19 ++++++++- 3 files changed, 61 insertions(+), 25 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 378ab5714b01..15cd545f2cab 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -243,24 +243,22 @@ def _align_codebooks(t): target_text_tokens, target_token_lens = collate_token_channel( cuts, self.text_tokenizer, self.frame_length, roles=self.output_roles, add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, - tokenizer_names=batch_tokenizer_names, + tokenizer_names=batch_tokenizer_names, pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, ) source_tokens, source_token_lens = collate_token_channel( cuts, self.text_tokenizer, self.frame_length, roles=self.input_roles, add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, - tokenizer_names=batch_tokenizer_names, + tokenizer_names=batch_tokenizer_names, pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, ) if self.phoneme_tokenizer is not None: target_phoneme_tokens, target_phoneme_lens = collate_phoneme_channel( cuts, self.phoneme_tokenizer, self.frame_length, roles=self.output_roles, - ignore_phoneme_languages=self.ignore_phoneme_languages, - add_text_bos_and_eos_in_each_turn=False, + ignore_phoneme_languages=self.ignore_phoneme_languages, pad_id=self.phoneme_tokenizer.pad, eos_id=self.phoneme_tokenizer.eos_token_id, bos_id=self.phoneme_tokenizer.bos_token_id, ) else: target_phoneme_tokens, target_phoneme_lens = None, None - dataset_name_list = [] audio_list_16khz = [] audio_len_list_16khz = [] @@ -522,16 +520,18 @@ def collate_token_channel( roles: set[str], add_text_bos_and_eos_in_each_turn: bool = True, tokenizer_names: list[str] = None, + pad_id: int = None, + eos_id: int = None, + bos_id: int = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Build and collate token channels aligned to the audio frame grid.""" - pad_id = getattr(tokenizer, 'pad', -1) tokens = [] - + for i, c in enumerate(cuts): tok_name = tokenizer_names[i] if tokenizer_names else "english_phoneme" tokens.append( build_token_channel( - c, tokenizer, frame_length, roles, pad_id, + c, tokenizer, frame_length, roles, pad_id, eos_id, bos_id, add_text_bos_and_eos_in_each_turn, tok_name ) ) @@ -545,13 +545,13 @@ def build_token_channel( frame_length: Seconds, roles: set[str], pad_id: int = -1, + eos_id: int = -2, + bos_id: int = -3, add_text_bos_and_eos_in_each_turn: bool = True, tokenizer_name: str = "english_phoneme", ) -> torch.Tensor: total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) tokens = torch.ones(total, dtype=torch.long) * pad_id - bos_id = getattr(tokenizer, 'bos', 0) - eos_id = getattr(tokenizer, 'eos', 1) for supervision in cut.supervisions: if supervision.speaker in roles: @@ -565,7 +565,7 @@ def build_token_channel( else: raw_ids = tokenizer.text_to_ids(text) - text_ids = torch.as_tensor([bos_id] + raw_ids) if add_text_bos_and_eos_in_each_turn else torch.as_tensor(raw_ids) + text_ids = torch.as_tensor(raw_ids + [eos_id]) pos = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) if pos >= len(tokens): @@ -590,13 +590,15 @@ def collate_phoneme_channel( frame_length: Seconds, roles: set[str], ignore_phoneme_languages: list[str], - add_text_bos_and_eos_in_each_turn: bool = True, + pad_id: int = -1, + eos_id: int = -2, + bos_id: int = -3, ) -> tuple[torch.Tensor, torch.Tensor]: pad_id = phoneme_tokenizer.pad tokens = [ build_phoneme_channel( c, phoneme_tokenizer, frame_length, roles, - ignore_phoneme_languages, pad_id, add_text_bos_and_eos_in_each_turn + ignore_phoneme_languages, pad_id, eos_id, bos_id ) for c in cuts ] token_lens = torch.tensor([len(tt) for tt in tokens]) @@ -610,7 +612,8 @@ def build_phoneme_channel( roles: set[str], ignore_phoneme_languages: list[str], pad_id: int = -1, - add_text_bos_and_eos_in_each_turn: bool = True, + eos_id: int = -2, + bos_id: int = -3, ) -> torch.Tensor: total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) tokens = torch.ones(total, dtype=torch.long) * pad_id @@ -627,8 +630,7 @@ def build_phoneme_channel( ipa_text = supervision.text phoneme_ids = phoneme_tokenizer.encode(ipa_text) - if add_text_bos_and_eos_in_each_turn: - phoneme_ids = [phoneme_tokenizer.bos_token_id] + phoneme_ids + phoneme_ids = [bos_id] + phoneme_ids + [eos_id] phoneme_ids = torch.as_tensor(phoneme_ids, dtype=torch.long) pos = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) @@ -640,9 +642,4 @@ def build_phoneme_channel( phoneme_ids = phoneme_ids[:len(tokens) - pos] tokens[pos:pos+len(phoneme_ids)] = phoneme_ids - if add_text_bos_and_eos_in_each_turn: - eospos = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) - if eospos < len(tokens): - tokens[eospos] = phoneme_tokenizer.eos_token_id - return tokens diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 4c6a18419d5e..e9a580bc6d35 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -329,6 +329,8 @@ def prepare_text_channel_embeddings( text_lens: torch.Tensor, delay: torch.Tensor, dropout_text_input: bool = False, + is_multiturn: bool = False, + text_pad_id: int = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Prepare text embeddings as a channel input with delay handling. @@ -353,12 +355,16 @@ def prepare_text_channel_embeddings( device = text.device # Embed text tokens (CAS-only when disable_subword_embedding=True). - text_embedded = self.embed_text_tokens(text, text_lens=text_lens) # (B, L, E) + text_embedded = self.embed_text_tokens(text, text_lens=text_lens, is_multiturn=is_multiturn) # (B, L, E) # Handle text dropout - zero out the embeddings if dropout_text_input: text_embedded = text_embedded * 0.0 + # multiturn dataset returns a special pad text tokens until it matches the audio len, to keep compatible with regular dataset zero-out those values + if is_multiturn: + text_embedded[text == text_pad_id] = 0.0 + # Create zero tensor for delay padding max_delay = delay.max().item() zero_delay_tensor = torch.zeros(batch_size, max_delay, self.cfg.embedding_dim, device=device) @@ -429,8 +435,15 @@ def prepare_phoneme_channel_embeddings( phoneme_embedded = self.embed_phoneme_tokens(phoneme_tokens_stacked) # (B, T', E) # Apply mask to zero out padding - phoneme_mask = get_mask_from_lengths(phoneme_tokens_lens_stacked) - phoneme_embedded = phoneme_embedded * phoneme_mask.unsqueeze(2) # (B, T', E) + if self.cfg.get("use_multiturn_dataset", False): + phoneme_pad_id = getattr(self.phoneme_tokenizer, "pad", -1) + phoneme_mask = (phoneme_tokens_stacked[:, 0, :] != phoneme_pad_id) # Check the first layer of the stack + # Apply mask to zero out padding + phoneme_embedded = phoneme_embedded * phoneme_mask.unsqueeze(2) # (B, T', E) + else: + phoneme_mask = get_mask_from_lengths(phoneme_tokens_lens_stacked) + phoneme_embedded = phoneme_embedded * phoneme_mask.unsqueeze(2) # (B, T', E) + # Handle phoneme dropout - zero out the embeddings if dropout_complete_phoneme_channel: @@ -724,6 +737,8 @@ def process_batch( text_lens=text_lens, delay=text_delay, dropout_text_input=dropout_text_input or dropout_conditional_input, + is_multiturn=self.cfg.get("use_multiturn_dataset", False), + text_pad_id=self.tokenizer.pad, ) # 4. Prepare phoneme channel embeddings (if phoneme tokenizer is configured) @@ -1307,6 +1322,10 @@ def validation_step(self, batch, batch_idx): return val_output + def on_validation_epoch_start(self) -> None: + if torch.distributed.is_initialized(): + self.trainer.strategy.model.require_backward_grad_sync = False + def on_validation_epoch_end(self): collect = lambda key: torch.stack([x[key] for x in self.validation_step_outputs]).mean() val_loss = collect("val_loss") @@ -1363,6 +1382,9 @@ def collect_if_exists(key): self.validation_step_outputs.clear() # free memory + if torch.distributed.is_initialized(): + self.trainer.strategy.model.require_backward_grad_sync = True + def get_dataset(self, dataset_cfg, dataset_type): dataset = safe_instantiate( dataset_cfg.dataset, diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 296d93a09315..f7c872d735af 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -675,6 +675,7 @@ def embed_text_tokens( text_tokens: torch.Tensor, text_lens: Optional[torch.Tensor] = None, disable_cas_embedding: bool = False, + is_multiturn: bool = False, ) -> torch.Tensor: """Embed text tokens using decoder embedding + optional CAS, or CAS-only when configured. @@ -683,13 +684,18 @@ def embed_text_tokens( text_lens: Optional valid token lengths for constructing the CAS mask. Defaults to the full sequence length. disable_cas_embedding: When True, skip adding CAS embeddings even if the model uses the BPE char tokenizer. This is needed for legacy models where context text was trained without CAS embeddings. + is_multiturn: When True creates the text_mask based on non text pad ids positions, so that it can support multiturn. """ if text_lens is None: text_lens = torch.full( (text_tokens.size(0),), text_tokens.size(1), dtype=torch.long, device=text_tokens.device ) - text_mask = get_mask_from_lengths(text_lens) + if is_multiturn: + text_mask = (text_tokens != self.tokenizer.pad) + else: + text_mask = get_mask_from_lengths(text_lens) + if self.disable_subword_embedding: if disable_cas_embedding: raise ValueError("Cannot disable CAS embedding when `disable_subword_embedding=True`.") @@ -1293,6 +1299,11 @@ def streaming_init( ) gt_phoneme_embeddings = self.embed_phoneme_tokens(gt_phoneme_stacked) # (B, T', E) + if self.cfg.get("use_multiturn_dataset", False): + phoneme_pad_id = getattr(self.phoneme_tokenizer, "pad", -1) + phoneme_mask = (gt_phoneme_stacked[:, 0, :] != phoneme_pad_id) + gt_phoneme_embeddings = gt_phoneme_embeddings * phoneme_mask.unsqueeze(2) + # Process GT audio codes if provided (for teacher forcing) gt_audio_embeddings = None gt_audio_lens_state = None @@ -1459,11 +1470,17 @@ def _prepare_streaming_input( text_embedded = self.embed_text_tokens( text_tokens_2d, text_lens=torch.ones(batch_size, dtype=torch.long, device=device), + is_multiturn=self.cfg.get("use_multiturn_dataset", False), ) # (B, 1, E) if force_dropout_text: text_embedded = text_embedded * 0 + # Zero out padding tokens exactly like in process_batch + if self.cfg.get("use_multiturn_dataset", False): + is_pad = (text_tokens_2d == self.tokenizer.pad) + text_embedded[is_pad] = 0.0 + is_eos_token = (text_tokens == self.eos_id) & needs_text # (B,) bool text_add_mask = needs_text.view(batch_size, 1, 1).float() next_input = next_input + text_embedded * text_add_mask From 4e2e236b39e5699c33012cb0da3e84a75667fc41 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 17 Apr 2026 09:37:57 -0700 Subject: [PATCH 015/109] Bug fix in dataloder Signed-off-by: Edresson Casanova --- ...text_to_speech_dataset_lhotse_multiturn.py | 51 ++++++++++++------- nemo/collections/tts/models/easy_magpietts.py | 34 ++++++++++++- .../tts/models/easy_magpietts_inference.py | 24 +++++++-- 3 files changed, 83 insertions(+), 26 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 15cd545f2cab..d11dd3e7fc2f 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -137,6 +137,8 @@ def __init__( sample_rate: int, volume_norm: bool = True, codec_model_samples_per_frame: int = None, + codec_model_input_sample_rate: int = None, + frame_stacking_factor: int = None, num_audio_codebooks: int = None, prior_scaling_factor: float = None, load_cached_codes_if_available: bool = True, @@ -156,7 +158,7 @@ def __init__( source_sample_rate: int = 16000, input_roles: List[str] = ["user", "User"], output_roles: List[str] = ["assistant", "Assistant", "agent", "Agent"], - add_text_bos_and_eos_in_each_turn: bool = False, + add_text_bos: bool = False, ): super().__init__() self.sample_rate = sample_rate @@ -187,9 +189,9 @@ def __init__( self.source_sample_rate = source_sample_rate self.input_roles = set(ifnone(input_roles, ["user"])) self.output_roles = set(ifnone(output_roles, ["agent"])) - self.add_text_bos_and_eos_in_each_turn = add_text_bos_and_eos_in_each_turn + self.add_text_bos = add_text_bos - self.frame_length = self.codec_model_samples_per_frame / self.sample_rate + self.frame_length = (self.codec_model_samples_per_frame / codec_model_input_sample_rate) * frame_stacking_factor def get_num_audio_samples_to_slice(self, duration, sample_rate): num_codec_frames = int(duration * sample_rate / self.codec_model_samples_per_frame) @@ -205,8 +207,11 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: all_tokenizers_config=self.tokenizer_config, mode=self.dataset_type, ) - self.bos_id = len(self.text_tokenizer.tokens) - self.eos_id = self.bos_id + 1 + num_tokens = len(self.text_tokenizer.tokens) + self.bos_id = num_tokens + self.eos_id = num_tokens + 1 + self.cfg_unk_token_id = num_tokens + 2 + self.interruption_token_id = num_tokens + 3 self.pad_id = self.text_tokenizer.pad if self.phoneme_tokenizer is None and self.phoneme_tokenizer_config is not None: @@ -242,13 +247,13 @@ def _align_codebooks(t): target_text_tokens, target_token_lens = collate_token_channel( cuts, self.text_tokenizer, self.frame_length, roles=self.output_roles, - add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, - tokenizer_names=batch_tokenizer_names, pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, + add_text_bos=self.add_text_bos, tokenizer_names=batch_tokenizer_names, + pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, interruption_token_id=self.interruption_token_id, ) source_tokens, source_token_lens = collate_token_channel( cuts, self.text_tokenizer, self.frame_length, roles=self.input_roles, - add_text_bos_and_eos_in_each_turn=self.add_text_bos_and_eos_in_each_turn, - tokenizer_names=batch_tokenizer_names, pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, + add_text_bos=self.add_text_bos, tokenizer_names=batch_tokenizer_names, + pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, interruption_token_id=self.interruption_token_id, ) if self.phoneme_tokenizer is not None: @@ -437,8 +442,11 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) tok_name = batch_tokenizer_names[i] full_text_len = sum([len(self.text_tokenizer.encode(sup.text, tokenizer_name=tok_name)) for sup in cut.supervisions if sup.speaker in self.output_roles]) - if self.add_text_bos_and_eos_in_each_turn: + if self.add_text_bos: full_text_len += 2 * sum([1 for sup in cut.supervisions if sup.speaker in self.output_roles]) + else: + # cont eos token + full_text_len += sum([1 for sup in cut.supervisions if sup.speaker in self.output_roles]) full_text_len = max(1, full_text_len) @@ -518,11 +526,12 @@ def collate_token_channel( tokenizer, frame_length: Seconds, roles: set[str], - add_text_bos_and_eos_in_each_turn: bool = True, + add_text_bos: bool = True, tokenizer_names: list[str] = None, pad_id: int = None, eos_id: int = None, bos_id: int = None, + interruption_token_id: int = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Build and collate token channels aligned to the audio frame grid.""" tokens = [] @@ -531,8 +540,8 @@ def collate_token_channel( tok_name = tokenizer_names[i] if tokenizer_names else "english_phoneme" tokens.append( build_token_channel( - c, tokenizer, frame_length, roles, pad_id, eos_id, bos_id, - add_text_bos_and_eos_in_each_turn, tok_name + c, tokenizer, frame_length, roles, pad_id, eos_id, bos_id, interruption_token_id, + add_text_bos, tok_name ) ) token_lens = torch.tensor([len(tt) for tt in tokens]) @@ -547,7 +556,8 @@ def build_token_channel( pad_id: int = -1, eos_id: int = -2, bos_id: int = -3, - add_text_bos_and_eos_in_each_turn: bool = True, + interruption_token_id: int = -4, + add_text_bos: bool = True, tokenizer_name: str = "english_phoneme", ) -> torch.Tensor: total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) @@ -565,7 +575,10 @@ def build_token_channel( else: raw_ids = tokenizer.text_to_ids(text) - text_ids = torch.as_tensor(raw_ids + [eos_id]) + if add_text_bos: + text_ids = torch.as_tensor([bos_id] + raw_ids + [eos_id]) + else: + text_ids = torch.as_tensor(raw_ids + [eos_id]) pos = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) if pos >= len(tokens): @@ -576,10 +589,10 @@ def build_token_channel( text_ids = text_ids[:len(tokens) - pos] tokens[pos:pos+len(text_ids)] = text_ids - if add_text_bos_and_eos_in_each_turn: - eospos = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) - if eospos < len(tokens): - tokens[eospos] = eos_id + # add interruption token, used for add speech eos and interrupt the model + interruption_pos = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) + if interruption_pos < len(tokens): + tokens[interruption_pos] = interruption_token_id return tokens diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index e9a580bc6d35..db01bdf14824 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -539,6 +539,7 @@ def prepare_audio_channel_embeddings( audio_codes: torch.Tensor, audio_codes_lens: torch.Tensor, delay: torch.Tensor, + speech_eos_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Prepare audio embeddings as a channel input with delay handling. @@ -588,6 +589,25 @@ def prepare_audio_channel_embeddings( self.num_audio_codebooks, ) + if speech_eos_mask is not None: + # 1. Shift the mask +1 to the right to account for the token and +1 for the EOS + # prepended by add_special_tokens. + B_mask, T_mask = speech_eos_mask.shape + shifted_mask = torch.zeros((B_mask, T_mask + 2), dtype=torch.bool, device=device) + shifted_mask[:, 2:] = speech_eos_mask + + # 2. Find the minimum overlapping time dimension + t_mask = shifted_mask.size(1) + t_audio = audio_codes.size(2) + min_t = min(t_mask, t_audio) + + # 3. Slice both to the valid overlap and broadcast the C dimension + valid_mask = shifted_mask[:, :min_t] + expanded_mask = valid_mask.unsqueeze(1).expand(-1, audio_codes.size(1), -1) + + # Inject the EOS token only into the overlapping region + audio_codes[:, :, :min_t][expanded_mask] = self.audio_eos_id + # Prepare input and target for autoregressive training # Input: all tokens except the last (teacher forcing) # Target: all tokens except the first (shifted by one) @@ -731,6 +751,13 @@ def process_batch( # Streaming mode: context_lens + speech_delay audio_delay = context_lens + current_streaming_speech_delay + speech_eos_mask = None + if self.cfg.get("use_multiturn_dataset", False): + speech_eos_mask = (text == self.interruption_token_id) # (B, T) + # ToDo: do not remove from interruption data + text[speech_eos_mask] = self.tokenizer.pad # Clean up the text channel + + # 3. Prepare text channel embeddings text_channel_embedding, text_channel_lens = self.prepare_text_channel_embeddings( text=text, @@ -738,7 +765,7 @@ def process_batch( delay=text_delay, dropout_text_input=dropout_text_input or dropout_conditional_input, is_multiturn=self.cfg.get("use_multiturn_dataset", False), - text_pad_id=self.tokenizer.pad, + text_pad_id=self.pad_id, ) # 4. Prepare phoneme channel embeddings (if phoneme tokenizer is configured) @@ -788,6 +815,7 @@ def process_batch( audio_codes=audio_codes, audio_codes_lens=audio_codes_lens, delay=audio_delay, + speech_eos_mask=speech_eos_mask, ) # 6. Sum the channel embeddings element-wise @@ -1423,6 +1451,8 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D sample_rate=self.sample_rate, volume_norm=dataset_cfg.volume_norm, codec_model_samples_per_frame=self.codec_model_samples_per_frame, + codec_model_input_sample_rate=self.codec_model_input_sample_rate, + frame_stacking_factor=self.frame_stacking_factor, num_audio_codebooks=self.data_num_audio_codebooks, prior_scaling_factor=0.0, load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, @@ -1440,7 +1470,7 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D source_sample_rate=self.sample_rate, input_roles=["user", "User"], output_roles=["assistant", "Assistant", "agent", "Agent"], - add_text_bos_and_eos_in_each_turn=self.cfg.get("add_text_bos_and_eos_in_each_turn", False), + add_text_bos=self.cfg.get("add_text_bos", False), # pronunciation_control_g2p=self.cfg.get("pronunciation_control_g2p", None), ) dataset = FallbackDataset(dataset) diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index f7c872d735af..d869714c049f 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -257,6 +257,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.codebook_size = codebook_size self.codec_model_samples_per_frame = codec_model.samples_per_frame + self.codec_model_input_sample_rate = codec_model.sample_rate # Our codebooks start with actual audio codec tokens, followed by special tokens. # The `forced_*` options are for backward compatibility for models trained with older code. get_token_index = partial(SpecialAudioToken.get_index, base_codebook_size=self.codebook_size) @@ -335,11 +336,24 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): mode='train', ) - num_tokens_tokenizer = len(self.tokenizer.tokens) - num_tokens = num_tokens_tokenizer + 3 # +3 for BOS, EOS, CFG_UNK - self.bos_id = num_tokens - 3 - self.eos_id = num_tokens - 2 - self.cfg_unk_token_id = num_tokens - 1 + base_num_tokens = len(self.tokenizer.tokens) + + # Assign standard special tokens sequentially + self.bos_id = base_num_tokens + self.eos_id = base_num_tokens + 1 + self.cfg_unk_token_id = base_num_tokens + 2 + special_tokens_added = 3 + + # Conditionally add the interruption token + if cfg.get("use_multiturn_dataset", False): + self.interruption_token_id = base_num_tokens + special_tokens_added + special_tokens_added += 1 + + # Calculate the final total vocabulary size + num_tokens = base_num_tokens + special_tokens_added + + self.pad_id = self.tokenizer.pad + self.phoneme_tokenizer = None if cfg.get('phoneme_tokenizer', None) is not None: self.phoneme_tokenizer = safe_instantiate(cfg.phoneme_tokenizer) From 9cf7584ae067e09db60f41f17c47f7f5612fbc6f Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 17 Apr 2026 12:37:07 -0700 Subject: [PATCH 016/109] Add parameter to remove user turns Signed-off-by: Edresson Casanova --- ...text_to_speech_dataset_lhotse_multiturn.py | 184 ++++++++++++++++-- nemo/collections/tts/models/easy_magpietts.py | 2 +- 2 files changed, 168 insertions(+), 18 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index d11dd3e7fc2f..cabd1fc20849 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -159,6 +159,7 @@ def __init__( input_roles: List[str] = ["user", "User"], output_roles: List[str] = ["assistant", "Assistant", "agent", "Agent"], add_text_bos: bool = False, + remove_user_turns_prob: float = None, ): super().__init__() self.sample_rate = sample_rate @@ -166,6 +167,7 @@ def __init__( self.codec_model_samples_per_frame = codec_model_samples_per_frame self.num_audio_codebooks = num_audio_codebooks + self.remove_user_turns_prob = remove_user_turns_prob self.include_align_prior = prior_scaling_factor is not None self.prior_scaling_factor = prior_scaling_factor @@ -220,12 +222,25 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: cuts = cuts.transform_text(_strip_timestamps) batch_tokenizer_names = [] + remove_user_turn_flags = [] for cut in cuts: if cut.has_custom("tokenizer_names"): batch_tokenizer_names.append(random.choice(cut.tokenizer_names)) else: batch_tokenizer_names.append("english_phoneme") + # Get all agent supervisions in this cut + agent_sups = [sup for sup in cut.supervisions if sup.speaker in self.output_roles] + + # It is a multiturn if there's more than 1 agent turn + is_multiturn = not (len(agent_sups) == 1) + + # Apply augmentation only if it's multiturn AND passes the probability check + if is_multiturn and self.remove_user_turns_prob and random.random() < self.remove_user_turns_prob: + remove_user_turn_flags.append(True) + else: + remove_user_turn_flags.append(False) + def _align_codebooks(t): C = t.shape[1] if C < self.num_audio_codebooks: @@ -239,27 +254,66 @@ def _align_codebooks(t): cuts.resample(self.sample_rate, recording_field="target_audio"), recording_field="target_audio" ) source_audio, source_audio_lens = collate_audio(cuts.resample(self.source_sample_rate)) + target_audio_list = [] + source_audio_list = [] + # normalize volume and apply audio the removal of user turn if needed - # Apply volume norm if requested - if self.volume_norm: - source_audio = torch.stack([torch.from_numpy(normalize_volume(a.numpy())) for a in source_audio]) - target_audio = torch.stack([torch.from_numpy(normalize_volume(a.numpy())) for a in target_audio]) + for i, cut in enumerate(cuts): + remove_user_turn_this_cut = remove_user_turn_flags[i] + + # Extract the raw, unpadded 1D numpy array for this specific cut + t_audio = target_audio[i, :target_audio_lens[i]].numpy() + s_audio = source_audio[i, :source_audio_lens[i]].numpy() + if remove_user_turn_this_cut: + collapsed_t, collapsed_s = [], [] + for sup in cut.supervisions: + if sup.speaker in self.output_roles: + start_t = int(round(max(0, sup.start) * self.sample_rate)) + end_t = int(round(sup.end * self.sample_rate)) + start_s = int(round(max(0, sup.start) * self.source_sample_rate)) + end_s = int(round(sup.end * self.source_sample_rate)) + + # Clamp safely inside the array + start_t, end_t = min(start_t, len(t_audio)), min(end_t, len(t_audio)) + start_s, end_s = min(start_s, len(s_audio)), min(end_s, len(s_audio)) + + if end_t > start_t: collapsed_t.append(t_audio[start_t:end_t]) + if end_s > start_s: collapsed_s.append(s_audio[start_s:end_s]) + + t_audio = np.concatenate(collapsed_t) if collapsed_t else np.zeros(1, dtype=np.float32) + s_audio = np.concatenate(collapsed_s) if collapsed_s else np.zeros(1, dtype=np.float32) + + # Apply volume norm locally (so we only normalize the stitched audio, saving math ops) + if self.volume_norm: + t_audio = normalize_volume(t_audio) + s_audio = normalize_volume(s_audio) + target_audio_list.append(torch.from_numpy(t_audio)) + source_audio_list.append(torch.from_numpy(s_audio)) + + # 3. Re-pad the newly stitched arrays to the batch's new maximum length + target_audio = collate_vectors(target_audio_list, padding_value=0.0) + target_audio_lens = torch.tensor([len(a) for a in target_audio_list], dtype=torch.long) + + source_audio = collate_vectors(source_audio_list, padding_value=0.0) + source_audio_lens = torch.tensor([len(a) for a in source_audio_list], dtype=torch.long) + target_text_tokens, target_token_lens = collate_token_channel( cuts, self.text_tokenizer, self.frame_length, roles=self.output_roles, add_text_bos=self.add_text_bos, tokenizer_names=batch_tokenizer_names, - pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, interruption_token_id=self.interruption_token_id, + pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, interruption_token_id=self.interruption_token_id, remove_user_turn_flags=remove_user_turn_flags ) source_tokens, source_token_lens = collate_token_channel( cuts, self.text_tokenizer, self.frame_length, roles=self.input_roles, add_text_bos=self.add_text_bos, tokenizer_names=batch_tokenizer_names, - pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, interruption_token_id=self.interruption_token_id, + pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, interruption_token_id=self.interruption_token_id, remove_user_turn_flags=remove_user_turn_flags ) if self.phoneme_tokenizer is not None: target_phoneme_tokens, target_phoneme_lens = collate_phoneme_channel( cuts, self.phoneme_tokenizer, self.frame_length, roles=self.output_roles, ignore_phoneme_languages=self.ignore_phoneme_languages, pad_id=self.phoneme_tokenizer.pad, eos_id=self.phoneme_tokenizer.eos_token_id, bos_id=self.phoneme_tokenizer.bos_token_id, + remove_user_turn_flags=remove_user_turn_flags, ) else: target_phoneme_tokens, target_phoneme_lens = None, None @@ -288,6 +342,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) return random.uniform(self.context_duration_min, effective_duration_max) for i, cut in enumerate(cuts): + remove_user_turn_this_cut = remove_user_turn_flags[i] speaker_found = False for sup in reversed(cut.supervisions): if check_speaker_format(sup.speaker): @@ -306,9 +361,15 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) # Target and Source Codes if self.load_cached_codes_if_available: if cut.has_custom("target_codes"): - target_codes_list.append(torch.from_numpy(cut.target_codes.load().astype(np.int32)).T) + codes_array = cut.target_codes.load().astype(np.int32) + if remove_user_turn_this_cut: + raise RuntimeError("Remove user turn augmentation is not implemented for cached codes!") + target_codes_list.append(torch.from_numpy(codes_array).T) + if cut.has_custom("source_codes"): source_codes_list.append(torch.from_numpy(cut.source_codes.load().astype(np.int32)).T) + if remove_user_turn_this_cut: + raise RuntimeError("Remove user turn augmentation is not implemented for cached codes!") # Context Audio or Context Codes if self.load_cached_codes_if_available and cut.has_custom("context_codes"): @@ -532,16 +593,18 @@ def collate_token_channel( eos_id: int = None, bos_id: int = None, interruption_token_id: int = None, + remove_user_turn_flags: list[bool] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Build and collate token channels aligned to the audio frame grid.""" tokens = [] for i, c in enumerate(cuts): tok_name = tokenizer_names[i] if tokenizer_names else "english_phoneme" + flag = remove_user_turn_flags[i] if remove_user_turn_flags else False tokens.append( build_token_channel( c, tokenizer, frame_length, roles, pad_id, eos_id, bos_id, interruption_token_id, - add_text_bos, tok_name + add_text_bos, tok_name, remove_user_turns=flag, ) ) token_lens = torch.tensor([len(tt) for tt in tokens]) @@ -559,7 +622,54 @@ def build_token_channel( interruption_token_id: int = -4, add_text_bos: bool = True, tokenizer_name: str = "english_phoneme", + remove_user_turns: bool = False, ) -> torch.Tensor: + if remove_user_turns: + turn_chunks = [] + for supervision in cut.supervisions: + if supervision.speaker in roles: + # 1. Get exact frame length of THIS turn + start_f = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) + end_f = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) + turn_frames = max(0, end_f - start_f) + + turn_tokens = torch.ones(turn_frames, dtype=torch.long) * pad_id + + if turn_frames == 0: + continue + + # 2. Encode text + text = supervision.text + if hasattr(tokenizer, "encode"): + try: + raw_ids = tokenizer.encode(text=text, tokenizer_name=tokenizer_name) + except TypeError: + raw_ids = tokenizer.encode(text) + else: + raw_ids = tokenizer.text_to_ids(text) + + if add_text_bos: + text_ids = [bos_id] + raw_ids + [eos_id] + else: + text_ids = raw_ids + [eos_id] + + # 3. Place text at the start, keeping the rest as pad_id + text_len = len(text_ids) + if text_len > turn_frames: + text_ids = text_ids[:turn_frames] + text_len = turn_frames + + turn_tokens[0:text_len] = torch.as_tensor(text_ids, dtype=torch.long) + + # 4. Place interruption token at the exact end of the turn + turn_tokens[-1] = interruption_token_id + turn_chunks.append(turn_tokens) + + if turn_chunks: + return torch.cat(turn_chunks, dim=0) + else: + return torch.tensor([pad_id], dtype=torch.long) + total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) tokens = torch.ones(total, dtype=torch.long) * pad_id @@ -606,14 +716,17 @@ def collate_phoneme_channel( pad_id: int = -1, eos_id: int = -2, bos_id: int = -3, + remove_user_turn_flags: list[bool] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - pad_id = phoneme_tokenizer.pad - tokens = [ - build_phoneme_channel( - c, phoneme_tokenizer, frame_length, roles, - ignore_phoneme_languages, pad_id, eos_id, bos_id - ) for c in cuts - ] + tokens = [] + for i, c in enumerate(cuts): + flag = remove_user_turn_flags[i] if remove_user_turn_flags else False + tokens.append( + build_phoneme_channel( + c, phoneme_tokenizer, frame_length, roles, + ignore_phoneme_languages, pad_id, eos_id, bos_id, remove_user_turns=flag + ) + ) token_lens = torch.tensor([len(tt) for tt in tokens]) return collate_vectors(tokens, padding_value=pad_id), token_lens @@ -627,12 +740,49 @@ def build_phoneme_channel( pad_id: int = -1, eos_id: int = -2, bos_id: int = -3, + remove_user_turns: bool = False, ) -> torch.Tensor: + language = cut.lang if cut.has_custom("lang") else next((sup.language for sup in cut.supervisions if sup.has_custom("language")), "en") + + if remove_user_turns: + turn_chunks = [] + for supervision in cut.supervisions: + if supervision.speaker in roles: + start_f = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) + end_f = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) + turn_frames = max(0, end_f - start_f) + + turn_tokens = torch.ones(turn_frames, dtype=torch.long) * pad_id + + if turn_frames == 0: + continue + + if isinstance(phoneme_tokenizer, IPABPETokenizer): + ipa_text = supervision.ipa if supervision.has_custom("ipa") else "" + if language in ignore_phoneme_languages: + ipa_text = "" + else: + ipa_text = supervision.text + + phoneme_ids = phoneme_tokenizer.encode(ipa_text) + phoneme_ids = [bos_id] + phoneme_ids + [eos_id] + + text_len = len(phoneme_ids) + if text_len > turn_frames: + phoneme_ids = phoneme_ids[:turn_frames] + text_len = turn_frames + + turn_tokens[0:text_len] = torch.as_tensor(phoneme_ids, dtype=torch.long) + turn_chunks.append(turn_tokens) + + if turn_chunks: + return torch.cat(turn_chunks, dim=0) + else: + return torch.tensor([pad_id], dtype=torch.long) + total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) tokens = torch.ones(total, dtype=torch.long) * pad_id - language = cut.lang if cut.has_custom("lang") else next((sup.language for sup in cut.supervisions if sup.has_custom("language")), "en") - for supervision in cut.supervisions: if supervision.speaker in roles: if isinstance(phoneme_tokenizer, IPABPETokenizer): diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index db01bdf14824..b519d93a2dcf 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -1471,7 +1471,7 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D input_roles=["user", "User"], output_roles=["assistant", "Assistant", "agent", "Agent"], add_text_bos=self.cfg.get("add_text_bos", False), - # pronunciation_control_g2p=self.cfg.get("pronunciation_control_g2p", None), + remove_user_turns_prob=self.cfg.get("remove_user_turns_prob", None), ) dataset = FallbackDataset(dataset) else: From 7ad497ddbb77530d1a664a5b82858f21e687bb35 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 20 Apr 2026 17:03:56 -0700 Subject: [PATCH 017/109] Add multiturn inference script Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 745 ++++++++++++++++++ .../easy_magpietts_inference_multiturn_old.py | 560 +++++++++++++ .../tts/modules/magpietts_modules.py | 5 + 3 files changed, 1310 insertions(+) create mode 100644 examples/tts/easy_magpietts_inference_multiturn.py create mode 100644 examples/tts/easy_magpietts_inference_multiturn_old.py diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py new file mode 100644 index 000000000000..7e34485d16a3 --- /dev/null +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -0,0 +1,745 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Evaluation script for custom EasyMagpieTTS models. +Features explicit Duplex (10x Padding) and Regular (Turn-by-turn) multi-turn modes. + +Usage: + python easy_magpietts_eval.py \ + --checkpoint_path=/path/to/magpie/model.ckpt \ + --codec_model_path=/path/to/codec/model.ckpt \ + --datasets_json_path=/path/to/evalset_config.jsonl \ + --out_dir=/path/to/out/audio \ + --batch_size=6 \ + --use_cfg +""" + +import argparse +import json +import os +from copy import deepcopy +from functools import partial + +import librosa +import soundfile as sf +import torch +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader, Dataset +from omegaconf import OmegaConf, open_dict + +from nemo.collections.audio.parts.utils.transforms import resample +from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility +from nemo.collections.speechlm2.parts.metrics.secs import SECS +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.utils import logging + +# --- EasyMagpieTTS Imports --- +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter +from nemo.collections.tts.modules.magpietts_modules import CodecHelper +from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel + +torch.set_float32_matmul_precision("medium") +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True + +if torch.cuda.is_available(): + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) + + +def attach_dtype_counter(model): + handles = [] + stats = {} + examples = {} + + def is_leaf(module): + return len(list(module.children())) == 0 + + def get_dtype(x): + if torch.is_tensor(x): + return str(x.dtype) + elif isinstance(x, (list, tuple)): + for t in x: + if torch.is_tensor(t): + return str(t.dtype) + return "other" + + def get_module_group(name): + return name.split(".")[0] if "." in name else name + + def hook_fn(name): + def fn(module, inputs, outputs): + dtype = get_dtype(outputs) + if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: + dtype = "other" + group = get_module_group(name) + if group not in stats: + stats[group] = { + "torch.float16": 0, "torch.bfloat16": 0, "torch.float32": 0, "other": 0, + } + examples[group] = { + "torch.float16": [], "torch.bfloat16": [], "torch.float32": [], "other": [], + } + stats[group][dtype] += 1 + if len(examples[group][dtype]) < 3: + examples[group][dtype].append(module.__class__.__name__) + return fn + + for name, module in model.named_modules(): + if is_leaf(module): + handles.append(module.register_forward_hook(hook_fn(name))) + return handles, stats, examples + + +def report_dtype_stats(handles, stats, examples): + for h in handles: + h.remove() + logging.info("\n=== DTYPE USAGE PER MODULE ===") + for group, group_stats in stats.items(): + total = sum(group_stats.values()) + if total == 0: continue + logging.info(f"\n--- {group} ---") + for dtype, count in group_stats.items(): + if count > 0: + logging.info(f"{dtype}: {count} ({100*count/total:.2f}%)") + logging.info("\n=== EXAMPLES ===") + for group, group_examples in examples.items(): + logging.info(f"\n--- {group} ---") + for dtype, mods in group_examples.items(): + if mods: + logging.info(f"{dtype}: {mods}") + + +class EvalJSONLDataset(Dataset): + def __init__(self, file_path, num_turns=1): + self.samples = [] + raw_samples = [] + with open(file_path, "r", encoding="utf-8") as f: + for line_idx, line in enumerate(f, 1): + line = line.strip() + if not line: continue + try: + raw_samples.append(json.loads(line)) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON on line {line_idx}: {e}") + + if num_turns <= 1: + self.samples = raw_samples + return + + single_turn_by_speaker = {} + for sample in raw_samples: + if isinstance(sample["text"], list): + self.samples.append(sample) + else: + speaker = sample.get("speaker", "unknown") + if speaker not in single_turn_by_speaker: + single_turn_by_speaker[speaker] = [] + single_turn_by_speaker[speaker].append(sample) + + for speaker, speaker_samples in single_turn_by_speaker.items(): + buffer_texts, buffer_paths = [], [] + first_sample_meta = None + + for sample in speaker_samples: + if not buffer_texts: + first_sample_meta = dict(sample) + buffer_texts.append(sample["text"]) + buffer_paths.append(sample.get("audio_filepath", "")) + + if len(buffer_texts) == num_turns: + first_sample_meta["text"] = buffer_texts + base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] + ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" + combined_name = "_".join(base_names) + ext + dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) + if dir_name: + first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) + else: + first_sample_meta["audio_filepath"] = combined_name + + self.samples.append(first_sample_meta) + buffer_texts, buffer_paths, first_sample_meta = [], [], None + + if buffer_texts: + first_sample_meta["text"] = buffer_texts + base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] + ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" + combined_name = "_".join(base_names) + ext + dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) + if dir_name: + first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) + else: + first_sample_meta["audio_filepath"] = combined_name + self.samples.append(first_sample_meta) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.samples[idx] + + +def collate_and_tokenize_custom( + batch, + model, + extra_duration_thrshould=1.3, + sample_rate=22050, + root_path=None, + emulate_duplex_inference=False, + add_interruption_token=False, + pad_factor_text_speech=10, + force_interruption=False, +): + main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] + + # --- MULTI-TURN MODE DECISION --- + is_duplex = emulate_duplex_inference + + out_dict = { + "duplex_multiturn": is_duplex, + "regular_multiturn": not is_duplex, + } + + tokenized_list = [] + batched_turns = [] + batched_turn_lens = [] + valid_turn_masks = [] + + if is_duplex: + # ------------------------------------------------------------- + # DUPLEX MODE (Continuous sequence with 10x pad injection) + # ------------------------------------------------------------- + for s in batch: + text_data = s["text"] + + if isinstance(text_data, list): + full_ids = [] + for segment in text_data: + seg_ids = model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name) + [model.eos_id] + seg_len = len(seg_ids) + pad_len = seg_len * pad_factor_text_speech + pad_ids = [model.pad_id] * pad_len + + if force_interruption: + fname = s["audio_filepath"] + no_ext = fname.split(".")[0] + sample_id = int(no_ext.split("_")[-1]) + case = sample_id % 3 + + if case == 0: + if len(seg_ids) >= 2: + seg_ids[-2] = model.interruption_token_id + seg_ids[-1] = model.pad_id + else: + pad_ids[0] = model.interruption_token_id + elif case == 1: + eos_idx = min(6, len(pad_ids) - 1) + pad_ids[eos_idx] = model.interruption_token_id + else: + eos_idx = 0 + pad_ids[eos_idx] = model.interruption_token_id + else: + if add_interruption_token: + eos_idx = int(len(pad_ids) * 0.7) + pad_ids[eos_idx] = model.interruption_token_id + + full_ids.extend(seg_ids) + full_ids.extend(pad_ids) + + tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long)) + else: + tokenized_list.append( + torch.as_tensor(model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id], dtype=torch.long) + ) + + pad_len = 25 + prefix = torch.full((pad_len,), model.pad_id, dtype=torch.long) + for i in range(len(tokenized_list)): + tokenized_list[i] = torch.cat([prefix, tokenized_list[i]]) + input_lengths = torch.tensor([len(x) for x in tokenized_list], dtype=torch.long) + + input_ids = pad_sequence(tokenized_list, batch_first=True, padding_value=model.pad_id) + + out_dict["input_ids"] = input_ids + out_dict["input_lengths"] = input_lengths + + else: + # ------------------------------------------------------------- + # REGULAR MODE (Turn-by-turn discrete packaging) + # ------------------------------------------------------------- + max_turns = 1 + for s in batch: + if isinstance(s["text"], list): + max_turns = max(max_turns, len(s["text"])) + + for t in range(max_turns): + turn_t_tokens = [] + turn_t_lens = [] + turn_t_valid = [] + + for s in batch: + text_data = s["text"] + if isinstance(text_data, list): + if t < len(text_data): + seg_ids = model.tokenizer.encode(text_data[t], tokenizer_name=main_tokenizer_name) + [model.eos_id] + turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) + turn_t_lens.append(len(seg_ids)) + turn_t_valid.append(True) + else: + # Dummy pad to keep shapes consistent for items with fewer turns + turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) + turn_t_lens.append(1) + turn_t_valid.append(False) + else: + if t == 0: + seg_ids = model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id] + turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) + turn_t_lens.append(len(seg_ids)) + turn_t_valid.append(True) + else: + turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) + turn_t_lens.append(1) + turn_t_valid.append(False) + + padded_turn_t = pad_sequence(turn_t_tokens, batch_first=True, padding_value=model.pad_id) + batched_turns.append(padded_turn_t) + batched_turn_lens.append(torch.tensor(turn_t_lens, dtype=torch.long)) + valid_turn_masks.append(torch.tensor(turn_t_valid, dtype=torch.bool)) + + out_dict["batched_turns"] = batched_turns + out_dict["batched_turn_lens"] = batched_turn_lens + out_dict["valid_turn_masks"] = valid_turn_masks + + # --- AUDIO LOADING --- + audio_list = [] + audio_lengths = [] + target_num_frames = [] + + for i, s in enumerate(batch): + audio_path = s["context_audio_filepath"] + if root_path is not None: + audio_path = os.path.join(root_path, audio_path) + + if os.path.exists(audio_path): + wav, sr = librosa.load(audio_path, sr=sample_rate, mono=True) + wav = torch.as_tensor(wav, dtype=torch.float32) + else: + wav = torch.zeros(1, dtype=torch.float32) + + audio_list.append(wav) + audio_lengths.append(len(wav)) + + tdur_audio_path = s["audio_filepath"] + if root_path is not None: + tdur_audio_path = os.path.join(root_path, tdur_audio_path) + + if tdur_audio_path and os.path.exists(tdur_audio_path): + wav_dur, sr_ = librosa.load(tdur_audio_path, sr=sample_rate, mono=True) + tdur = wav_dur.shape[0] // model.target_samples_per_frame + target_num_frames.append(tdur * extra_duration_thrshould) + else: + # Fallback estimation + if is_duplex: + current_text_len = len(tokenized_list[i]) + if isinstance(s["text"], list): + target_num_frames.append(current_text_len) + else: + target_num_frames.append(current_text_len * 5) + else: + target_num_frames.append(sum([l[i].item() for l in batched_turn_lens]) * 5) + + max_audio_len = max(audio_lengths) + B = len(audio_lengths) + padded_audio = torch.zeros((B, max_audio_len), dtype=torch.float32) + + for i, wav in enumerate(audio_list): + padded_audio[i, : len(wav)] = wav + + out_dict["context_audio"] = padded_audio + out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) + out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] + out_dict["target_num_frames"] = target_num_frames + + out_dict["raw_text"] = [" ".join(s["text"]) if isinstance(s["text"], list) else s["text"] for s in batch] + + return out_dict + + +def main(): + parser = argparse.ArgumentParser(description="EasyMagpieTTS Inference Evaluation") + + # Required Paths + parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the EasyMagpie model") + parser.add_argument("--codec_model_path", type=str, required=True, help="Path to the audio codec") + parser.add_argument("--datasets_json_path", type=str, required=True, help="Path to JSONL data") + parser.add_argument("--out_dir", type=str, required=True, help="Directory to save audio outputs") + + # Optional Paths & General + parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) + parser.add_argument("--audio_dir", type=str, default=None, help="Root dir for audio paths in JSONL") + parser.add_argument("--inference_dtype", type=str, default="float16") + parser.add_argument("--debug_dtype", action="store_true") + + # Dataloader & Batching + parser.add_argument("--batch_size", type=int, default=6) + parser.add_argument("--num_workers", type=int, default=4) + parser.add_argument("--num_turns", type=int, default=1) + parser.add_argument("--pad_factor_text_speech", type=int, default=10) + + # Text Processing Boolean Flags + parser.add_argument("--emulate_duplex_inference", action="store_true") + parser.add_argument("--add_interruption_token", action="store_true") + parser.add_argument("--force_interruption", action="store_true") + + # Speaker & Prompt Configurations + parser.add_argument("--user_custom_speaker_reference", action="store_true") + parser.add_argument("--inference_speaker_reference", type=str, default=None) + parser.add_argument("--language", type=str, default="en") + + # Generation Kwargs + parser.add_argument("--use_cfg", action="store_true") + parser.add_argument("--cfg_scale", type=float, default=2.5) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--topk", type=int, default=80) + parser.add_argument("--max_tts_steps", type=int, default=2000) + parser.add_argument("--force_speech_sil_codes", action="store_true") + + args = parser.parse_args() + + distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1 + if distributed and not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend="nccl") + + target_device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}" if torch.cuda.is_available() else "cpu") + target_dtype = getattr(torch, args.inference_dtype) + torch.set_default_dtype(target_dtype) + + model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) + with open_dict(model_cfg): + model_cfg.target = 'nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel' + model_cfg.codecmodel_path = args.codec_model_path + model_cfg.train_ds = None + model_cfg.validation_ds = None + model_cfg.run_val_inference = False + model_cfg.use_utmos = False + model_cfg.use_meta_init_for_decoder = True + + # --- MISSING FIX: Guarantees silence for pad tokens --- + model_cfg.use_multiturn_dataset = True + + if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: + model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path + + # --- MISSING FIX: Load to CPU first to prevent OOM --- + model = EasyMagpieTTSInferenceModel.restore_from( + args.checkpoint_path, override_config_path=model_cfg, map_location=torch.device("cpu") + ) + model.use_kv_cache_for_inference = True + model.to(dtype=target_dtype) + model.eval().to(target_device) + + # --- DATALOADER COMPATIBILITY PATCHES --- + model.target_samples_per_frame = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) + model.target_sample_rate = model.sample_rate + + # --- MISSING FIX: Load to CPU first to prevent OOM --- + codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=torch.device("cpu")) + if hasattr(codec_model, "discriminator"): + del codec_model.discriminator + codec_model.freeze() + codec_model = codec_model.to(target_device).eval() + + + codec_converter = None + if getattr(model, "_codec_converter", None) is not None: + vq_new = deepcopy(model._codec_converter.vector_quantizer_new).to(target_device).eval() + codec_converter = VectorQuantizerIndexConverter( + vector_quantizer_original=codec_model.vector_quantizer, + vector_quantizer_new=vq_new, + ).to(target_device).eval() + + if not hasattr(model, "_codec_helper") or model._codec_helper is None: + model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) + + from collections import Counter + def get_codec_silence_frame(model, device, target_sample_rate): + # Generate long zero waveform (silence) + audio = torch.zeros(1, 10 * target_sample_rate).float().to(device) + audio_len = torch.tensor([audio.size(-1)]).long().to(device) + + sil_codes, sil_codes_lens = model._codec_helper.audio_to_codes( + audio, audio_len + ) + + if model._codec_converter is not None: + sil_codes = model._codec_converter.convert_original_to_new( + audio_tokens=sil_codes, audio_lens=sil_codes_lens + ).long() + + # sil_codes is shape [1, C, T]. + # Extract batch index 0 and transpose to shape [T, C] + frames = sil_codes[0].transpose(0, 1) + + # Convert each time frame (C integers) into a tuple of integers + combos = [tuple(frame.tolist()) for frame in frames] + + # Count frequencies + counter = Counter(combos) + + # Pick the most common combination + most_common_combo, freq = counter.most_common(1)[0] + + # Return as tensor [C] + return torch.tensor(most_common_combo, device=device, dtype=torch.long) + + codec_sil_codes = get_codec_silence_frame(model, target_device, model.sample_rate) + + if args.debug_dtype: + handles, stats, examples = attach_dtype_counter(model) + + with fp32_precision(): + intelligibility = Intelligibility("stt_en_fastconformer_transducer_large", reuse_asr_hyps=False).reset() + secs_metric = SECS("titanet_large").reset() + + eval_dataset = EvalJSONLDataset(args.datasets_json_path, num_turns=args.num_turns) + + collate_fn = partial( + collate_and_tokenize_custom, + model=model, + extra_duration_thrshould=1.5, + sample_rate=model.target_sample_rate, + root_path=args.audio_dir, + emulate_duplex_inference=args.emulate_duplex_inference, + add_interruption_token=args.add_interruption_token, + pad_factor_text_speech=args.pad_factor_text_speech, + force_interruption=args.force_interruption, + ) + + dataloader = DataLoader( + dataset=eval_dataset, batch_size=args.batch_size, collate_fn=collate_fn, + num_workers=args.num_workers, pin_memory=True, shuffle=False, drop_last=False, + ) + + if args.user_custom_speaker_reference and args.inference_speaker_reference: + wav, sr = librosa.load(args.inference_speaker_reference, sr=model.target_sample_rate, mono=True) + speaker_wav = torch.as_tensor(wav, dtype=target_dtype).unsqueeze(0).to(model.device) + + for batch_id, inputs in enumerate(dataloader): + B = inputs["context_audio"].size(0) + device = model.device + + inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) + inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) + + if args.user_custom_speaker_reference and args.inference_speaker_reference: + inputs["context_audio"] = speaker_wav.expand(B, *speaker_wav.shape[1:]) + inputs["context_audio_lengths"][:] = speaker_wav.size(-1) + + with torch.inference_mode(): + # 1. Base Initialization (Shared between modes) + wav = inputs["context_audio"] + wav_len = inputs["context_audio_lengths"] + codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) + + use_lang = bool(getattr(model, "add_language_to_context_text", False)) + ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" + ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) + + ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) + ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) + + state = model.streaming_init( + context_audio_codes=codes, + context_audio_codes_lens=codes_lens, + context_text_tokens=ctx_toks, + context_text_tokens_lens=ctx_toks_lens, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + use_local_transformer=True, + temperature=args.temperature, + topk=args.topk, + phoneme_input_type="pred", + phoneme_sampling_method="argmax", + use_inference_mode=True, + ) + + # --------------------------------------------------------- + # MODE 1: DUPLEX (Continuous Padding Token Stream) + # --------------------------------------------------------- + if inputs["duplex_multiturn"]: + text = inputs["input_ids"].to(device) + text_lens = inputs["input_lengths"].to(device) + + # Fetch the true silence frame codebook combo once + codec_sil_codes = get_codec_silence_frame(model, device, model.sample_rate) + + # Trackers for our two forced-silence zones + in_initial_silence = torch.ones(B, dtype=torch.bool, device=device) + in_post_speech_silence = torch.zeros(B, dtype=torch.bool, device=device) + + text_exhausted = state.text_tokens_seen >= text_lens + + while not text_exhausted.all() and len(state.all_predictions) < args.max_tts_steps: + + # 1. WAKE UP OVERRIDE: Keep the text pipeline awake to read pads! + # Note: This forces state.finished to False at the start of the loop + state.finished = state.finished & text_exhausted + state.text_finished = state.text_finished & text_exhausted + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended = state.phoneme_stream_ended & text_exhausted + + # 2. Safely index text using the model's internal pointer + positions = state.text_tokens_seen.clamp(max=text.size(1) - 1) + current_tokens = text[torch.arange(B, device=device), positions] + + current_tokens = torch.where( + text_exhausted, torch.full_like(current_tokens, model.eos_id), current_tokens + ) + + # 3. Update our trackers BEFORE the step + is_pad_or_eos = (current_tokens == model.pad_id) | (current_tokens == model.eos_id) + + # Initial silence turns off forever once a real word is seen + in_initial_silence = in_initial_silence & is_pad_or_eos + + # Post-speech silence turns off when a real word for the NEXT turn is seen + in_post_speech_silence = in_post_speech_silence & is_pad_or_eos + + # 4. Step the model + state, audio_codes, _ = model.streaming_step(state=state, text_tokens=current_tokens, use_inference_mode=True) + + # 5. TRIGGER POST-SPEECH SILENCE: + # If the audio decoder naturally predicted a speech EOS, state.finished becomes True here! + in_post_speech_silence = in_post_speech_silence | state.finished + + # 6. SILENCE FORCING INJECTION + if audio_codes is not None and args.force_speech_sil_codes: + # We force silence if we are in the initial prefix OR if the model has finished its sentence + force_silence_mask = in_initial_silence | in_post_speech_silence + + if force_silence_mask.any(): + # Expand silence codes [C] -> [1, C, 1] to match audio_codes [B, C, 1] + expanded_sil = codec_sil_codes.view(1, -1, 1).expand_as(audio_codes) + + # Expand mask [B] -> [B, 1, 1] for broadcasting + mask_3d = force_silence_mask.view(B, 1, 1) + + # Overwrite the prediction with silence codes where the mask is True. + overwritten_codes = torch.where(mask_3d, expanded_sil, audio_codes) + + # Inject back into the model's KV cache history + state.all_predictions[-1] = overwritten_codes + + # Update exhaustion tracker for the next iteration + text_exhausted = state.text_tokens_seen >= text_lens + + # --------------------------------------------------------- + # MODE 2: REGULAR (Turn-by-Turn Re-wakes) + # --------------------------------------------------------- + elif inputs["regular_multiturn"]: + batched_turns = inputs["batched_turns"] + batched_turn_lens = inputs["batched_turn_lens"] + valid_turn_masks = inputs["valid_turn_masks"] + + max_turns = len(batched_turns) + + # Tracking offset ensures sync regardless of context audio length + turn_offsets = torch.zeros(B, dtype=torch.long, device=device) + + for t in range(max_turns): + turn_text = batched_turns[t].to(device) + turn_lens = batched_turn_lens[t].to(device) + valid_mask = valid_turn_masks[t].to(device) + + # Reset ALL finished flags for items participating in this new turn + state.finished = state.finished & (~valid_mask) + state.text_finished = state.text_finished & (~valid_mask) + + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended = state.phoneme_stream_ended & (~valid_mask) + + if state.finished.all(): + continue + + # Record internal token count at start of turn + turn_offsets = torch.where(valid_mask, state.text_tokens_seen, turn_offsets) + turn_steps = 0 + + while not state.finished.all() and turn_steps < args.max_tts_steps: + turn_steps += 1 + + # Fetch token synced relative to the model's progress + relative_positions = state.text_tokens_seen - turn_offsets + positions = relative_positions.clamp(min=0, max=turn_text.size(1) - 1) + current_tokens = turn_text[torch.arange(B, device=device), positions] + + # Once the text for this turn is fully fed, feed EOS so audio can finish + exhausted = relative_positions >= turn_lens + current_tokens = torch.where(exhausted, torch.full_like(current_tokens, model.eos_id), current_tokens) + + state, _, _ = model.streaming_step(state=state, text_tokens=current_tokens, use_inference_mode=True) + + # Finalize decodes the collected Codec states globally regardless of which loop was run + finalize_output = model.streaming_finalize(state, use_inference_mode=True) + + if args.debug_dtype and batch_id == 0: + report_dtype_stats(handles, stats, examples) + + with fp32_precision(): + audio_f32 = finalize_output.audio.float() + audio_len = finalize_output.audio_len.int() + + expected_audio_lens = (torch.tensor(inputs["target_num_frames"], device=device) * model.target_samples_per_frame).int() + + if inputs["duplex_multiturn"]: + # Cap the expected length so it physically cannot exceed the actual generated tensor size + expected_audio_lens = expected_audio_lens.clamp(max=audio_f32.size(1)) + audio_len = torch.max(audio_len, expected_audio_lens) + else: + audio_len = torch.min(audio_len, expected_audio_lens) + + metric_audio_pred = resample(audio_f32, getattr(model, "output_sample_rate", 24000), 16000) + metric_audio_pred_lens = (audio_len / getattr(model, "output_sample_rate", 24000) * 16000).to(torch.long) + + intelligibility.update( + name="dataset", + refs=inputs["raw_text"], + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + asr_hyps=None, + ) + + secs_metric.update( + name="dataset", + target_audio=resample(inputs["context_audio"].float(), model.target_sample_rate, 16000), + target_audio_lens=(inputs["context_audio_lengths"] / model.target_sample_rate * 16000).to(torch.long), + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + ) + + os.makedirs(args.out_dir, exist_ok=True) + audio_f32 = audio_f32.detach().cpu() + audio_len = audio_len.cpu() + + for i in range(B): + wav = audio_f32[i, : audio_len[i]].numpy() + target_path = inputs["target_audio_paths"][i] + base_name = os.path.basename(target_path) + out_path = os.path.join(args.out_dir, base_name) + sf.write(out_path, wav, samplerate=getattr(model, "output_sample_rate", 24000)) + logging.info(f"Saved: {out_path}") + + with fp32_precision(): + logging.info("\n--- Evaluation Metrics ---") + cer_wer = intelligibility.compute() + for k, m in cer_wer.items(): + logging.info(f"Intelligibility - {k}: {m}") + + secs_scores = secs_metric.compute() + for k, m in secs_scores.items(): + logging.info(f"SECS - {k}: {m}") + + +if __name__ == "__main__": + main() diff --git a/examples/tts/easy_magpietts_inference_multiturn_old.py b/examples/tts/easy_magpietts_inference_multiturn_old.py new file mode 100644 index 000000000000..502071d52d81 --- /dev/null +++ b/examples/tts/easy_magpietts_inference_multiturn_old.py @@ -0,0 +1,560 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Evaluation script for custom EasyMagpieTTS models trained on 10x padded inputs. +Stripped of Hydra config requirements. Uses standard argparse. +Features an explicitly exposed, fully batched autoregressive loop for easy multi-turn modding. + +Usage: + python easy_magpietts_eval.py \ + --checkpoint_path=/path/to/magpie/model.ckpt \ + --codec_model_path=/path/to/codec/model.ckpt \ + --datasets_json_path=/path/to/evalset_config.jsonl \ + --out_dir=/path/to/out/audio \ + --batch_size=6 \ + --add_interruption_token \ + --add_beginning_pad_tokens \ + --use_cfg +""" + +import argparse +import json +import os +import time +from copy import deepcopy +from functools import partial + +import librosa +import soundfile as sf +import torch +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader, Dataset +from omegaconf import OmegaConf, open_dict + +from nemo.collections.audio.parts.utils.transforms import resample +from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility +from nemo.collections.speechlm2.parts.metrics.secs import SECS +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.utils import logging + +# --- EasyMagpieTTS Imports --- +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter +from nemo.collections.tts.modules.magpietts_modules import CodecHelper +from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel + +torch.set_float32_matmul_precision("medium") +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True + +if torch.cuda.is_available(): + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) + + +def attach_dtype_counter(model): + handles = [] + stats = {} + examples = {} + + def is_leaf(module): + return len(list(module.children())) == 0 + + def get_dtype(x): + if torch.is_tensor(x): + return str(x.dtype) + elif isinstance(x, (list, tuple)): + for t in x: + if torch.is_tensor(t): + return str(t.dtype) + return "other" + + def get_module_group(name): + return name.split(".")[0] if "." in name else name + + def hook_fn(name): + def fn(module, inputs, outputs): + dtype = get_dtype(outputs) + if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: + dtype = "other" + group = get_module_group(name) + if group not in stats: + stats[group] = { + "torch.float16": 0, "torch.bfloat16": 0, "torch.float32": 0, "other": 0, + } + examples[group] = { + "torch.float16": [], "torch.bfloat16": [], "torch.float32": [], "other": [], + } + stats[group][dtype] += 1 + if len(examples[group][dtype]) < 3: + examples[group][dtype].append(module.__class__.__name__) + return fn + + for name, module in model.named_modules(): + if is_leaf(module): + handles.append(module.register_forward_hook(hook_fn(name))) + return handles, stats, examples + + +def report_dtype_stats(handles, stats, examples): + for h in handles: + h.remove() + logging.info("\n=== DTYPE USAGE PER MODULE ===") + for group, group_stats in stats.items(): + total = sum(group_stats.values()) + if total == 0: continue + logging.info(f"\n--- {group} ---") + for dtype, count in group_stats.items(): + if count > 0: + logging.info(f"{dtype}: {count} ({100*count/total:.2f}%)") + logging.info("\n=== EXAMPLES ===") + for group, group_examples in examples.items(): + logging.info(f"\n--- {group} ---") + for dtype, mods in group_examples.items(): + if mods: + logging.info(f"{dtype}: {mods}") + + +class EvalJSONLDataset(Dataset): + def __init__(self, file_path, num_turns=1): + self.samples = [] + raw_samples = [] + with open(file_path, "r", encoding="utf-8") as f: + for line_idx, line in enumerate(f, 1): + line = line.strip() + if not line: continue + try: + raw_samples.append(json.loads(line)) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON on line {line_idx}: {e}") + + if num_turns <= 1: + self.samples = raw_samples + return + + single_turn_by_speaker = {} + for sample in raw_samples: + if isinstance(sample["text"], list): + self.samples.append(sample) + else: + speaker = sample.get("speaker", "unknown") + if speaker not in single_turn_by_speaker: + single_turn_by_speaker[speaker] = [] + single_turn_by_speaker[speaker].append(sample) + + for speaker, speaker_samples in single_turn_by_speaker.items(): + buffer_texts, buffer_paths = [], [] + first_sample_meta = None + + for sample in speaker_samples: + if not buffer_texts: + first_sample_meta = dict(sample) + buffer_texts.append(sample["text"]) + buffer_paths.append(sample.get("audio_filepath", "")) + + if len(buffer_texts) == num_turns: + first_sample_meta["text"] = buffer_texts + base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] + ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" + combined_name = "_".join(base_names) + ext + dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) + if dir_name: + first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) + else: + first_sample_meta["audio_filepath"] = combined_name + + self.samples.append(first_sample_meta) + buffer_texts, buffer_paths, first_sample_meta = [], [], None + + if buffer_texts: + first_sample_meta["text"] = buffer_texts + base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] + ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" + combined_name = "_".join(base_names) + ext + dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) + if dir_name: + first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) + else: + first_sample_meta["audio_filepath"] = combined_name + self.samples.append(first_sample_meta) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.samples[idx] + + +def collate_and_tokenize_custom( + batch, + model, + extra_duration_thrshould=1.3, + sample_rate=22050, + root_path=None, + add_beginning_pad_tokens=False, + add_interruption_token=False, + pad_factor_text_speech=10, + force_interruption=False, +): + tokenized_list = [] + main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] + for s in batch: + text_data = s["text"] + + if isinstance(text_data, list): + full_ids = [] + for segment in text_data: + seg_ids = model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name) + [model.eos_id] + seg_len = len(seg_ids) + pad_len = seg_len * pad_factor_text_speech + pad_ids = [model.pad_id] * pad_len + + if force_interruption: + fname = s["audio_filepath"] + no_ext = fname.split(".")[0] + sample_id = int(no_ext.split("_")[-1]) + case = sample_id % 3 + + if case == 0: + if len(seg_ids) >= 2: + seg_ids[-2] = model.interruption_token_id + seg_ids[-1] = model.pad_id + else: + pad_ids[0] = model.interruption_token_id + elif case == 1: + eos_idx = min(6, len(pad_ids) - 1) + pad_ids[eos_idx] = model.interruption_token_id + else: + eos_idx = 0 + pad_ids[eos_idx] = model.interruption_token_id + else: + if add_interruption_token: + eos_idx = int(len(pad_ids) * 0.7) + pad_ids[eos_idx] = model.interruption_token_id + + full_ids.extend(seg_ids) + full_ids.extend(pad_ids) + + tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long)) + else: + tokenized_list.append( + torch.as_tensor(model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id], dtype=torch.long) + ) + + if add_beginning_pad_tokens: + pad_len = 25 + prefix = torch.full((pad_len,), model.pad_id, dtype=torch.long) + for i in range(len(tokenized_list)): + tokenized_list[i] = torch.cat([prefix, tokenized_list[i]]) + + # Capture the true sequence length before pad_sequence applies batch alignment padding + input_lengths = torch.tensor([len(x) for x in tokenized_list], dtype=torch.long) + input_ids = pad_sequence(tokenized_list, batch_first=True, padding_value=model.pad_id) + + audio_list = [] + audio_lengths = [] + target_num_frames = [] + + for i, s in enumerate(batch): + audio_path = s["context_audio_filepath"] + if root_path is not None: + audio_path = os.path.join(root_path, audio_path) + + if os.path.exists(audio_path): + wav, sr = librosa.load(audio_path, sr=sample_rate, mono=True) + wav = torch.as_tensor(wav, dtype=torch.float32) + else: + wav = torch.zeros(1, dtype=torch.float32) + + audio_list.append(wav) + audio_lengths.append(len(wav)) + + tdur_audio_path = s["audio_filepath"] + if root_path is not None: + tdur_audio_path = os.path.join(root_path, tdur_audio_path) + + if tdur_audio_path and os.path.exists(tdur_audio_path): + wav_dur, sr_ = librosa.load(tdur_audio_path, sr=sample_rate, mono=True) + tdur = wav_dur.shape[0] // model.target_samples_per_frame + target_num_frames.append(tdur * extra_duration_thrshould) + else: + current_text_len = len(tokenized_list[i]) + if isinstance(s["text"], list): + target_num_frames.append(current_text_len) + else: + target_num_frames.append(current_text_len * 5) + + max_audio_len = max(audio_lengths) + B = len(audio_lengths) + padded_audio = torch.zeros((B, max_audio_len), dtype=torch.float32) + + for i, wav in enumerate(audio_list): + padded_audio[i, : len(wav)] = wav + + audio_lengths = torch.tensor(audio_lengths, dtype=torch.long) + B, L = input_ids.shape + target_len = int(max(target_num_frames)) + target_len = max(target_len, L) + + padded_input_ids = torch.full((B, target_len), fill_value=model.pad_id, dtype=input_ids.dtype) + padded_input_ids[:, :L] = input_ids + + collapsed_raw_text = [" ".join(s["text"]) if isinstance(s["text"], list) else s["text"] for s in batch] + + return { + "input_ids": padded_input_ids, + "input_lengths": input_lengths, + "raw_text": collapsed_raw_text, + "context_audio": padded_audio, + "context_audio_lengths": audio_lengths, + "target_audio_paths": [s["audio_filepath"] for s in batch], + "target_num_frames": target_num_frames, + } + + +def main(): + parser = argparse.ArgumentParser(description="EasyMagpieTTS Inference Evaluation") + + # Required Paths + parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the EasyMagpie model") + parser.add_argument("--codec_model_path", type=str, required=True, help="Path to the audio codec") + parser.add_argument("--datasets_json_path", type=str, required=True, help="Path to JSONL data") + parser.add_argument("--out_dir", type=str, required=True, help="Directory to save audio outputs") + + # Optional Paths & General + parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) + parser.add_argument("--audio_dir", type=str, default=None, help="Root dir for audio paths in JSONL") + parser.add_argument("--inference_dtype", type=str, default="float16") + parser.add_argument("--debug_dtype", action="store_true") + + # Dataloader & Batching + parser.add_argument("--batch_size", type=int, default=6) + parser.add_argument("--num_workers", type=int, default=4) + parser.add_argument("--num_turns", type=int, default=1) + parser.add_argument("--pad_factor_text_speech", type=int, default=10) + + # Text Processing Boolean Flags + parser.add_argument("--add_beginning_pad_tokens", action="store_true") + parser.add_argument("--add_interruption_token", action="store_true") + parser.add_argument("--force_interruption", action="store_true") + + # Speaker & Prompt Configurations + parser.add_argument("--user_custom_speaker_reference", action="store_true") + parser.add_argument("--inference_speaker_reference", type=str, default=None) + parser.add_argument("--language", type=str, default="en") + + # Generation Kwargs + parser.add_argument("--use_cfg", action="store_true") + parser.add_argument("--cfg_scale", type=float, default=2.5) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--topk", type=int, default=80) + parser.add_argument("--max_tts_steps", type=int, default=1000) + + args = parser.parse_args() + + distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1 + if distributed and not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend="nccl") + + target_device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}" if torch.cuda.is_available() else "cpu") + target_dtype = getattr(torch, args.inference_dtype) + torch.set_default_dtype(target_dtype) + + model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) + with open_dict(model_cfg): + model_cfg.target = 'nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel' + model_cfg.codecmodel_path = args.codec_model_path + model_cfg.train_ds = None + model_cfg.validation_ds = None + model_cfg.run_val_inference = False + model_cfg.use_utmos = False + model_cfg.use_meta_init_for_decoder = True + if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: + model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path + + model = EasyMagpieTTSInferenceModel.restore_from( + args.checkpoint_path, override_config_path=model_cfg, map_location=target_device + ) + model.use_kv_cache_for_inference = True + model.eval().to(target_device) + model.to(dtype=target_dtype) + + # --- DATALOADER COMPATIBILITY PATCHES --- + model.target_samples_per_frame = getattr(model, "codec_model_samples_per_frame", 320) + model.target_sample_rate = getattr(model, "sample_rate", 22050) + model.pad_id = getattr(model.tokenizer, "pad_id", 0) + model.text_eos_id = model.eos_id + + codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=target_device) + if hasattr(codec_model, "discriminator"): + del codec_model.discriminator + codec_model.freeze() + codec_model = codec_model.to(target_device).eval() + + codec_converter = None + if getattr(model, "_codec_converter", None) is not None: + vq_new = deepcopy(model._codec_converter.vector_quantizer_new).to(target_device).eval() + codec_converter = VectorQuantizerIndexConverter( + vector_quantizer_original=codec_model.vector_quantizer, + vector_quantizer_new=vq_new, + ).to(target_device).eval() + + if not hasattr(model, "_codec_helper") or model._codec_helper is None: + model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) + + if args.debug_dtype: + handles, stats, examples = attach_dtype_counter(model) + + with fp32_precision(): + intelligibility = Intelligibility("stt_en_fastconformer_transducer_large", reuse_asr_hyps=False).reset() + secs_metric = SECS("titanet_large").reset() + + eval_dataset = EvalJSONLDataset(args.datasets_json_path, num_turns=args.num_turns) + + collate_fn = partial( + collate_and_tokenize_custom, + model=model, + extra_duration_thrshould=1.5, + sample_rate=model.target_sample_rate, + root_path=args.audio_dir, + add_beginning_pad_tokens=args.add_beginning_pad_tokens, + add_interruption_token=args.add_interruption_token, + pad_factor_text_speech=args.pad_factor_text_speech, + force_interruption=args.force_interruption, + ) + + dataloader = DataLoader( + dataset=eval_dataset, batch_size=args.batch_size, collate_fn=collate_fn, + num_workers=args.num_workers, pin_memory=True, shuffle=False, drop_last=False, + ) + + if args.user_custom_speaker_reference and args.inference_speaker_reference: + wav, sr = librosa.load(args.inference_speaker_reference, sr=model.target_sample_rate, mono=True) + speaker_wav = torch.as_tensor(wav, dtype=target_dtype).unsqueeze(0).to(model.device) + + for batch_id, inputs in enumerate(dataloader): + B = inputs["input_ids"].size(0) + device = model.device + + inputs["input_ids"] = inputs["input_ids"].to(device) + inputs["input_lengths"] = inputs["input_lengths"].to(device) + inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) + inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) + + if args.user_custom_speaker_reference and args.inference_speaker_reference: + inputs["context_audio"] = speaker_wav.expand(B, *speaker_wav.shape[1:]) + inputs["context_audio_lengths"][:] = speaker_wav.size(-1) + + # 1. Prepare Context & Initialize + with torch.inference_mode(): + wav = inputs["context_audio"] + wav_len = inputs["context_audio_lengths"] + codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) + + use_lang = bool(getattr(model, "add_language_to_context_text", False)) + ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" + ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) + + ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) + ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) + + state = model.streaming_init( + context_audio_codes=codes, + context_audio_codes_lens=codes_lens, + context_text_tokens=ctx_toks, + context_text_tokens_lens=ctx_toks_lens, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + use_local_transformer=True, + temperature=args.temperature, + topk=args.topk, + phoneme_input_type="pred", + phoneme_sampling_method="argmax", + use_inference_mode=True, + ) + + # --------------------------------------------------------- + # EXPOSED BATCHED GENERATION LOOP (Ready for multi-turn edits!) + # --------------------------------------------------------- + text = inputs["input_ids"] + text_lens = inputs["input_lengths"] + + gen_step = 0 + while not state.finished.all() and len(state.all_predictions) < args.max_tts_steps: + gen_step += 1 + + # Fetch current token dynamically based on state.text_tokens_seen + positions = state.text_tokens_seen.clamp(max=text.size(1) - 1) + current_tokens = text[torch.arange(B, device=device), positions] + + # Mask out sequences that have finished their true length + text_exhausted = state.text_tokens_seen >= text_lens + current_tokens = torch.where( + text_exhausted, torch.full_like(current_tokens, model.eos_id), current_tokens + ) + + # Feed tokens to the model step-by-step + state, audio_codes, _ = model.streaming_step( + state=state, + text_tokens=current_tokens, + use_inference_mode=True, + ) + + # Bulk Decode using the exposed state + finalize_output = model.streaming_finalize(state, use_inference_mode=True) + # --------------------------------------------------------- + + if args.debug_dtype and batch_id == 0: + report_dtype_stats(handles, stats, examples) + + with fp32_precision(): + # Grab output directly from streaming_finalize + audio_f32 = finalize_output.audio.float() + audio_len = finalize_output.audio_len.int() + + expected_audio_lens = (torch.tensor(inputs["target_num_frames"], device=device) * model.target_samples_per_frame).int() + audio_len = torch.min(audio_len, expected_audio_lens) + + metric_audio_pred = resample(audio_f32, getattr(model, "output_sample_rate", 24000), 16000) + metric_audio_pred_lens = (audio_len / getattr(model, "output_sample_rate", 24000) * 16000).to(torch.long) + + intelligibility.update( + name="dataset", + refs=inputs["raw_text"], + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + asr_hyps=None, + ) + + secs_metric.update( + name="dataset", + target_audio=resample(inputs["context_audio"].float(), model.target_sample_rate, 16000), + target_audio_lens=(inputs["context_audio_lengths"] / model.target_sample_rate * 16000).to(torch.long), + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + ) + + os.makedirs(args.out_dir, exist_ok=True) + audio_f32 = audio_f32.detach().cpu() + audio_len = audio_len.cpu() + + for i in range(B): + wav = audio_f32[i, : audio_len[i]].numpy() + target_path = inputs["target_audio_paths"][i] + base_name = os.path.basename(target_path) + out_path = os.path.join(args.out_dir, base_name) + sf.write(out_path, wav, samplerate=getattr(model, "output_sample_rate", 24000)) + logging.info(f"Saved: {out_path}") + + with fp32_precision(): + logging.info("\n--- Evaluation Metrics ---") + cer_wer = intelligibility.compute() + for k, m in cer_wer.items(): + logging.info(f"Intelligibility - {k}: {m}") + + secs_scores = secs_metric.compute() + for k, m in secs_scores.items(): + logging.info(f"SECS - {k}: {m}") + + +if __name__ == "__main__": + main() diff --git a/nemo/collections/tts/modules/magpietts_modules.py b/nemo/collections/tts/modules/magpietts_modules.py index 33666b02ac86..6edcc9c56e26 100644 --- a/nemo/collections/tts/modules/magpietts_modules.py +++ b/nemo/collections/tts/modules/magpietts_modules.py @@ -252,6 +252,11 @@ def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Te if subword_mask.ndim == 3: subword_mask = subword_mask.squeeze(-1) + if not subword_mask.any(): + B, T = subword_ids.shape + D = self.embed_tokens.embedding_dim + return torch.zeros((B, T, D), dtype=self.embed_tokens.weight.dtype, device=device) + char_ids, char_lengths = self.prepare_inputs(subword_ids, subword_mask) char_mask = get_mask_from_lengths(char_lengths) char_emb = self.embed_tokens(char_ids) From 61e2a3eff251bc4dc86ddc3b4c904cdb38f8b549 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 21 Apr 2026 09:06:07 -0700 Subject: [PATCH 018/109] Update inference script Signed-off-by: Edresson Casanova --- ...easy_magpietts_inference_multiturn copy.py | 751 ++++++++++++++++++ .../tts/easy_magpietts_inference_multiturn.py | 106 ++- 2 files changed, 800 insertions(+), 57 deletions(-) create mode 100644 examples/tts/easy_magpietts_inference_multiturn copy.py diff --git a/examples/tts/easy_magpietts_inference_multiturn copy.py b/examples/tts/easy_magpietts_inference_multiturn copy.py new file mode 100644 index 000000000000..a85c9bbb0791 --- /dev/null +++ b/examples/tts/easy_magpietts_inference_multiturn copy.py @@ -0,0 +1,751 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Evaluation script for custom EasyMagpieTTS models. +Features explicit Duplex (10x Padding) and Regular (Turn-by-turn) multi-turn modes. + +Usage: + python easy_magpietts_eval.py \ + --checkpoint_path=/path/to/magpie/model.ckpt \ + --codec_model_path=/path/to/codec/model.ckpt \ + --datasets_json_path=/path/to/evalset_config.jsonl \ + --out_dir=/path/to/out/audio \ + --batch_size=6 \ + --use_cfg +""" + +import argparse +import json +import os +from copy import deepcopy +from functools import partial + +import librosa +import soundfile as sf +import torch +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader, Dataset +from omegaconf import OmegaConf, open_dict + +from nemo.collections.audio.parts.utils.transforms import resample +from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility +from nemo.collections.speechlm2.parts.metrics.secs import SECS +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.utils import logging + +# --- EasyMagpieTTS Imports --- +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter +from nemo.collections.tts.modules.magpietts_modules import CodecHelper +from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel + +torch.set_float32_matmul_precision("medium") +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True + +if torch.cuda.is_available(): + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) + + +def attach_dtype_counter(model): + handles = [] + stats = {} + examples = {} + + def is_leaf(module): + return len(list(module.children())) == 0 + + def get_dtype(x): + if torch.is_tensor(x): + return str(x.dtype) + elif isinstance(x, (list, tuple)): + for t in x: + if torch.is_tensor(t): + return str(t.dtype) + return "other" + + def get_module_group(name): + return name.split(".")[0] if "." in name else name + + def hook_fn(name): + def fn(module, inputs, outputs): + dtype = get_dtype(outputs) + if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: + dtype = "other" + group = get_module_group(name) + if group not in stats: + stats[group] = { + "torch.float16": 0, "torch.bfloat16": 0, "torch.float32": 0, "other": 0, + } + examples[group] = { + "torch.float16": [], "torch.bfloat16": [], "torch.float32": [], "other": [], + } + stats[group][dtype] += 1 + if len(examples[group][dtype]) < 3: + examples[group][dtype].append(module.__class__.__name__) + return fn + + for name, module in model.named_modules(): + if is_leaf(module): + handles.append(module.register_forward_hook(hook_fn(name))) + return handles, stats, examples + + +def report_dtype_stats(handles, stats, examples): + for h in handles: + h.remove() + logging.info("\n=== DTYPE USAGE PER MODULE ===") + for group, group_stats in stats.items(): + total = sum(group_stats.values()) + if total == 0: continue + logging.info(f"\n--- {group} ---") + for dtype, count in group_stats.items(): + if count > 0: + logging.info(f"{dtype}: {count} ({100*count/total:.2f}%)") + logging.info("\n=== EXAMPLES ===") + for group, group_examples in examples.items(): + logging.info(f"\n--- {group} ---") + for dtype, mods in group_examples.items(): + if mods: + logging.info(f"{dtype}: {mods}") + + +class EvalJSONLDataset(Dataset): + def __init__(self, file_path, num_turns=1): + self.samples = [] + raw_samples = [] + with open(file_path, "r", encoding="utf-8") as f: + for line_idx, line in enumerate(f, 1): + line = line.strip() + if not line: continue + try: + raw_samples.append(json.loads(line)) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON on line {line_idx}: {e}") + + if num_turns <= 1: + self.samples = raw_samples + return + + single_turn_by_speaker = {} + for sample in raw_samples: + if isinstance(sample["text"], list): + self.samples.append(sample) + else: + speaker = sample.get("speaker", "unknown") + if speaker not in single_turn_by_speaker: + single_turn_by_speaker[speaker] = [] + single_turn_by_speaker[speaker].append(sample) + + for speaker, speaker_samples in single_turn_by_speaker.items(): + buffer_texts, buffer_paths = [], [] + first_sample_meta = None + + for sample in speaker_samples: + if not buffer_texts: + first_sample_meta = dict(sample) + buffer_texts.append(sample["text"]) + buffer_paths.append(sample.get("audio_filepath", "")) + + if len(buffer_texts) == num_turns: + first_sample_meta["text"] = buffer_texts + base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] + ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" + combined_name = "_".join(base_names) + ext + dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) + if dir_name: + first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) + else: + first_sample_meta["audio_filepath"] = combined_name + + self.samples.append(first_sample_meta) + buffer_texts, buffer_paths, first_sample_meta = [], [], None + + if buffer_texts: + first_sample_meta["text"] = buffer_texts + base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] + ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" + combined_name = "_".join(base_names) + ext + dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) + if dir_name: + first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) + else: + first_sample_meta["audio_filepath"] = combined_name + self.samples.append(first_sample_meta) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.samples[idx] + + +def collate_and_tokenize_custom( + batch, + model, + extra_duration_thrshould=1.3, + sample_rate=22050, + root_path=None, + emulate_duplex_inference=False, + add_interruption_token=False, + pad_factor_text_speech=10, + force_interruption=False, +): + main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] + + # --- MULTI-TURN MODE DECISION --- + is_duplex = emulate_duplex_inference + + out_dict = { + "duplex_multiturn": is_duplex, + "regular_multiturn": not is_duplex, + } + + tokenized_list = [] + batched_turns = [] + batched_turn_lens = [] + valid_turn_masks = [] + + if is_duplex: + # ------------------------------------------------------------- + # DUPLEX MODE (Continuous sequence with 10x pad injection) + # ------------------------------------------------------------- + for s in batch: + text_data = s["text"] + + if isinstance(text_data, list): + full_ids = [] + for segment in text_data: + seg_ids = model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name) + [model.eos_id] + seg_len = len(seg_ids) + pad_len = seg_len * pad_factor_text_speech + pad_ids = [model.pad_id] * pad_len + + if force_interruption: + fname = s["audio_filepath"] + no_ext = fname.split(".")[0] + sample_id = int(no_ext.split("_")[-1]) + case = sample_id % 3 + + if case == 0: + if len(seg_ids) >= 2: + seg_ids[-2] = model.interruption_token_id + seg_ids[-1] = model.pad_id + else: + pad_ids[0] = model.interruption_token_id + elif case == 1: + eos_idx = min(6, len(pad_ids) - 1) + pad_ids[eos_idx] = model.interruption_token_id + else: + eos_idx = 0 + pad_ids[eos_idx] = model.interruption_token_id + else: + if add_interruption_token: + eos_idx = int(len(pad_ids) * 0.7) + pad_ids[eos_idx] = model.interruption_token_id + + full_ids.extend(seg_ids) + full_ids.extend(pad_ids) + + tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long)) + else: + tokenized_list.append( + torch.as_tensor(model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id], dtype=torch.long) + ) + + pad_len = 25 + prefix = torch.full((pad_len,), model.pad_id, dtype=torch.long) + for i in range(len(tokenized_list)): + tokenized_list[i] = torch.cat([prefix, tokenized_list[i]]) + input_lengths = torch.tensor([len(x) for x in tokenized_list], dtype=torch.long) + + input_ids = pad_sequence(tokenized_list, batch_first=True, padding_value=model.pad_id) + + out_dict["input_ids"] = input_ids + out_dict["input_lengths"] = input_lengths + + else: + # ------------------------------------------------------------- + # REGULAR MODE (Turn-by-turn discrete packaging) + # ------------------------------------------------------------- + max_turns = 1 + for s in batch: + if isinstance(s["text"], list): + max_turns = max(max_turns, len(s["text"])) + + for t in range(max_turns): + turn_t_tokens = [] + turn_t_lens = [] + turn_t_valid = [] + + for s in batch: + text_data = s["text"] + if isinstance(text_data, list): + if t < len(text_data): + seg_ids = model.tokenizer.encode(text_data[t], tokenizer_name=main_tokenizer_name) + [model.eos_id] + turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) + turn_t_lens.append(len(seg_ids)) + turn_t_valid.append(True) + else: + # Dummy pad to keep shapes consistent for items with fewer turns + turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) + turn_t_lens.append(1) + turn_t_valid.append(False) + else: + if t == 0: + seg_ids = model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id] + turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) + turn_t_lens.append(len(seg_ids)) + turn_t_valid.append(True) + else: + turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) + turn_t_lens.append(1) + turn_t_valid.append(False) + + padded_turn_t = pad_sequence(turn_t_tokens, batch_first=True, padding_value=model.pad_id) + batched_turns.append(padded_turn_t) + batched_turn_lens.append(torch.tensor(turn_t_lens, dtype=torch.long)) + valid_turn_masks.append(torch.tensor(turn_t_valid, dtype=torch.bool)) + + out_dict["batched_turns"] = batched_turns + out_dict["batched_turn_lens"] = batched_turn_lens + out_dict["valid_turn_masks"] = valid_turn_masks + + # --- AUDIO LOADING --- + audio_list = [] + audio_lengths = [] + target_num_frames = [] + + for i, s in enumerate(batch): + audio_path = s["context_audio_filepath"] + if root_path is not None: + audio_path = os.path.join(root_path, audio_path) + + if os.path.exists(audio_path): + wav, sr = librosa.load(audio_path, sr=sample_rate, mono=True) + wav = torch.as_tensor(wav, dtype=torch.float32) + else: + wav = torch.zeros(1, dtype=torch.float32) + + audio_list.append(wav) + audio_lengths.append(len(wav)) + + tdur_audio_path = s["audio_filepath"] + if root_path is not None: + tdur_audio_path = os.path.join(root_path, tdur_audio_path) + + if tdur_audio_path and os.path.exists(tdur_audio_path): + wav_dur, sr_ = librosa.load(tdur_audio_path, sr=sample_rate, mono=True) + tdur = wav_dur.shape[0] // model.input_samples_per_frame + target_num_frames.append(tdur * extra_duration_thrshould) + else: + # Fallback estimation + if is_duplex: + current_text_len = len(tokenized_list[i]) + if isinstance(s["text"], list): + target_num_frames.append(current_text_len) + else: + target_num_frames.append(current_text_len * 5) + else: + target_num_frames.append(sum([l[i].item() for l in batched_turn_lens]) * 5) + + max_audio_len = max(audio_lengths) + B = len(audio_lengths) + padded_audio = torch.zeros((B, max_audio_len), dtype=torch.float32) + + for i, wav in enumerate(audio_list): + padded_audio[i, : len(wav)] = wav + + out_dict["context_audio"] = padded_audio + out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) + out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] + out_dict["target_num_frames"] = target_num_frames + + out_dict["raw_text"] = [" ".join(s["text"]) if isinstance(s["text"], list) else s["text"] for s in batch] + + return out_dict + + +def main(): + parser = argparse.ArgumentParser(description="EasyMagpieTTS Inference Evaluation") + + # Required Paths + parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the EasyMagpie model") + parser.add_argument("--codec_model_path", type=str, required=True, help="Path to the audio codec") + parser.add_argument("--datasets_json_path", type=str, required=True, help="Path to JSONL data") + parser.add_argument("--out_dir", type=str, required=True, help="Directory to save audio outputs") + + # Optional Paths & General + parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) + parser.add_argument("--audio_dir", type=str, default=None, help="Root dir for audio paths in JSONL") + parser.add_argument("--inference_dtype", type=str, default="float16") + parser.add_argument("--debug_dtype", action="store_true") + + # Dataloader & Batching + parser.add_argument("--batch_size", type=int, default=6) + parser.add_argument("--num_workers", type=int, default=4) + parser.add_argument("--num_turns", type=int, default=1) + parser.add_argument("--pad_factor_text_speech", type=int, default=10) + + # Text Processing Boolean Flags + parser.add_argument("--emulate_duplex_inference", action="store_true") + parser.add_argument("--add_interruption_token", action="store_true") + parser.add_argument("--force_interruption", action="store_true") + + # Speaker & Prompt Configurations + parser.add_argument("--user_custom_speaker_reference", action="store_true") + parser.add_argument("--inference_speaker_reference", type=str, default=None) + parser.add_argument("--language", type=str, default="en") + + # Generation Kwargs + parser.add_argument("--use_cfg", action="store_true") + parser.add_argument("--cfg_scale", type=float, default=2.5) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--topk", type=int, default=80) + parser.add_argument("--max_tts_steps", type=int, default=2000) + parser.add_argument("--force_speech_sil_codes", action="store_true") + + args = parser.parse_args() + + distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1 + if distributed and not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend="nccl") + + target_device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}" if torch.cuda.is_available() else "cpu") + target_dtype = getattr(torch, args.inference_dtype) + torch.set_default_dtype(target_dtype) + + model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) + with open_dict(model_cfg): + model_cfg.target = 'nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel' + model_cfg.codecmodel_path = args.codec_model_path + model_cfg.train_ds = None + model_cfg.validation_ds = None + model_cfg.run_val_inference = False + model_cfg.use_utmos = False + model_cfg.use_meta_init_for_decoder = True + + # --- MISSING FIX: Guarantees silence for pad tokens --- + model_cfg.use_multiturn_dataset = True + + if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: + model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path + + # --- MISSING FIX: Load to CPU first to prevent OOM --- + model = EasyMagpieTTSInferenceModel.restore_from( + args.checkpoint_path, override_config_path=model_cfg, map_location=torch.device("cpu") + ) + model.use_kv_cache_for_inference = True + model.to(dtype=target_dtype) + model.eval().to(target_device) + + # --- DATALOADER COMPATIBILITY PATCHES --- + model.input_samples_per_frame = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) + model.target_samples_per_frame = model.input_samples_per_frame / (model.sample_rate / model.output_sample_rate) + + # --- MISSING FIX: Load to CPU first to prevent OOM --- + codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=torch.device("cpu")) + if hasattr(codec_model, "discriminator"): + del codec_model.discriminator + codec_model.freeze() + codec_model = codec_model.to(target_device).eval() + + + codec_converter = None + if getattr(model, "_codec_converter", None) is not None: + vq_new = deepcopy(model._codec_converter.vector_quantizer_new).to(target_device).eval() + codec_converter = VectorQuantizerIndexConverter( + vector_quantizer_original=codec_model.vector_quantizer, + vector_quantizer_new=vq_new, + ).to(target_device).eval() + + if not hasattr(model, "_codec_helper") or model._codec_helper is None: + model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) + + from collections import Counter + def get_codec_silence_frame(model, device, target_sample_rate): + # Generate long zero waveform (silence) + audio = torch.zeros(1, 10 * target_sample_rate).float().to(device) + audio_len = torch.tensor([audio.size(-1)]).long().to(device) + + sil_codes, sil_codes_lens = model._codec_helper.audio_to_codes( + audio, audio_len + ) + + if model._codec_converter is not None: + sil_codes = model._codec_converter.convert_original_to_new( + audio_tokens=sil_codes, audio_lens=sil_codes_lens + ).long() + + # sil_codes is shape [1, C, T]. + # Extract batch index 0 and transpose to shape [T, C] + frames = sil_codes[0].transpose(0, 1) + + # Convert each time frame (C integers) into a tuple of integers + combos = [tuple(frame.tolist()) for frame in frames] + + # Count frequencies + counter = Counter(combos) + + # Pick the most common combination + most_common_combo, freq = counter.most_common(1)[0] + + # Return as tensor [C] + return torch.tensor(most_common_combo, device=device, dtype=torch.long) + + codec_sil_codes = get_codec_silence_frame(model, target_device, model.sample_rate) + + if args.debug_dtype: + handles, stats, examples = attach_dtype_counter(model) + + with fp32_precision(): + intelligibility = Intelligibility("stt_en_fastconformer_transducer_large", reuse_asr_hyps=False).reset() + secs_metric = SECS("titanet_large").reset() + + eval_dataset = EvalJSONLDataset(args.datasets_json_path, num_turns=args.num_turns) + + collate_fn = partial( + collate_and_tokenize_custom, + model=model, + extra_duration_thrshould=1.5, + sample_rate=model.sample_rate, + root_path=args.audio_dir, + emulate_duplex_inference=args.emulate_duplex_inference, + add_interruption_token=args.add_interruption_token, + pad_factor_text_speech=args.pad_factor_text_speech, + force_interruption=args.force_interruption, + ) + + dataloader = DataLoader( + dataset=eval_dataset, batch_size=args.batch_size, collate_fn=collate_fn, + num_workers=args.num_workers, pin_memory=True, shuffle=False, drop_last=False, + ) + + if args.user_custom_speaker_reference and args.inference_speaker_reference: + wav, sr = librosa.load(args.inference_speaker_reference, sr=model.sample_rate, mono=True) + speaker_wav = torch.as_tensor(wav, dtype=target_dtype).unsqueeze(0).to(model.device) + + for batch_id, inputs in enumerate(dataloader): + B = inputs["context_audio"].size(0) + device = model.device + + inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) + inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) + + if args.user_custom_speaker_reference and args.inference_speaker_reference: + inputs["context_audio"] = speaker_wav.expand(B, *speaker_wav.shape[1:]) + inputs["context_audio_lengths"][:] = speaker_wav.size(-1) + + with torch.inference_mode(): + # 1. Base Initialization (Shared between modes) + wav = inputs["context_audio"] + wav_len = inputs["context_audio_lengths"] + codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) + + use_lang = bool(getattr(model, "add_language_to_context_text", False)) + ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" + ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) + + ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) + ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) + + state = model.streaming_init( + context_audio_codes=codes, + context_audio_codes_lens=codes_lens, + context_text_tokens=ctx_toks, + context_text_tokens_lens=ctx_toks_lens, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + use_local_transformer=True, + temperature=args.temperature, + topk=args.topk, + phoneme_input_type="pred", + phoneme_sampling_method="argmax", + use_inference_mode=True, + ) + + # --------------------------------------------------------- + # MODE 1: DUPLEX (Continuous Padding Token Stream) + # --------------------------------------------------------- + if inputs["duplex_multiturn"]: + text = inputs["input_ids"].to(device) + text_lens = inputs["input_lengths"].to(device) + # print("Text lens:", text.shape, text_lens) + # Fetch the true silence frame codebook combo once + codec_sil_codes = get_codec_silence_frame(model, device, model.sample_rate) + + # Trackers for our two forced-silence zones + in_initial_silence = torch.ones(B, dtype=torch.bool, device=device) + in_post_speech_silence = torch.zeros(B, dtype=torch.bool, device=device) + + text_exhausted = state.text_tokens_seen >= text_lens + while not text_exhausted.all(): + # 1. WAKE UP OVERRIDE: Keep the text pipeline awake to read pads! + state.finished = state.finished & text_exhausted + state.text_finished = state.text_finished & text_exhausted + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended = state.phoneme_stream_ended & text_exhausted + + # 2. Safely index text using the model's internal pointer + positions = state.text_tokens_seen.clamp(max=text.size(1) - 1) + current_tokens = text[torch.arange(B, device=device), positions] + + current_tokens = torch.where( + text_exhausted, torch.full_like(current_tokens, model.eos_id), current_tokens + ) + + # 3. Update our trackers BEFORE the step + is_pad_or_eos = (current_tokens == model.pad_id) | (current_tokens == model.eos_id) + + # Initial silence turns off forever once a real word is seen + in_initial_silence = in_initial_silence & is_pad_or_eos + + # Post-speech silence turns off when a real word for the NEXT turn is seen + in_post_speech_silence = in_post_speech_silence & is_pad_or_eos + + # 4. Step the model + state, audio_codes, _ = model.streaming_step(state=state, text_tokens=current_tokens, use_inference_mode=True) + + # 5. SILENCE FORCING INJECTION (Moved ABOVE the trigger!) + if audio_codes is not None and args.force_speech_sil_codes: + force_silence_mask = in_initial_silence | in_post_speech_silence + + if force_silence_mask.any(): + # Expand silence codes [C] -> [1, C, 1] to match audio_codes [B, C, 1] + expanded_sil = codec_sil_codes.view(1, -1, 1).expand_as(audio_codes) + + # Expand mask [B] -> [B, 1, 1] for broadcasting + mask_3d = force_silence_mask.view(B, 1, 1) + + # Overwrite the prediction with silence codes where the mask is True. + overwritten_codes = torch.where(mask_3d, expanded_sil, audio_codes) + + # Inject back into the model's KV cache history + state.all_predictions[-1] = overwritten_codes + + # 6. TRIGGER POST-SPEECH SILENCE FOR THE *NEXT* FRAME + # If the audio decoder naturally predicted a speech EOS, state.finished becomes True here. + # We update the tracker AFTER injection so we don't accidentally overwrite the EOS token! + in_post_speech_silence = in_post_speech_silence | state.finished + + # Update exhaustion tracker for the next iteration + text_exhausted = state.text_tokens_seen >= text_lens + + + # --------------------------------------------------------- + # MODE 2: REGULAR (Turn-by-Turn Re-wakes) + # --------------------------------------------------------- + elif inputs["regular_multiturn"]: + batched_turns = inputs["batched_turns"] + batched_turn_lens = inputs["batched_turn_lens"] + valid_turn_masks = inputs["valid_turn_masks"] + + max_turns = len(batched_turns) + + # Tracking offset ensures sync regardless of context audio length + turn_offsets = torch.zeros(B, dtype=torch.long, device=device) + + for t in range(max_turns): + turn_text = batched_turns[t].to(device) + turn_lens = batched_turn_lens[t].to(device) + valid_mask = valid_turn_masks[t].to(device) + + # Reset ALL finished flags for items participating in this new turn + state.finished = state.finished & (~valid_mask) + state.text_finished = state.text_finished & (~valid_mask) + + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended = state.phoneme_stream_ended & (~valid_mask) + + if state.finished.all(): + continue + + # Record internal token count at start of turn + turn_offsets = torch.where(valid_mask, state.text_tokens_seen, turn_offsets) + turn_steps = 0 + + while not state.finished.all() and turn_steps < args.max_tts_steps: + turn_steps += 1 + + # Fetch token synced relative to the model's progress + relative_positions = state.text_tokens_seen - turn_offsets + positions = relative_positions.clamp(min=0, max=turn_text.size(1) - 1) + current_tokens = turn_text[torch.arange(B, device=device), positions] + + # Once the text for this turn is fully fed, feed EOS so audio can finish + exhausted = relative_positions >= turn_lens + current_tokens = torch.where(exhausted, torch.full_like(current_tokens, model.eos_id), current_tokens) + + state, _, _ = model.streaming_step(state=state, text_tokens=current_tokens, use_inference_mode=True) + + if inputs["duplex_multiturn"]: + # Erase the internal memory of Turn 1's EOS token so `streaming_finalize` + # decodes the entire physical sequence! + state.audio_prediction_end_idx.fill_(-1) + + # Finalize decodes the collected Codec states globally regardless of which loop was run + finalize_output = model.streaming_finalize(state, use_inference_mode=True) + + if args.debug_dtype and batch_id == 0: + report_dtype_stats(handles, stats, examples) + + with fp32_precision(): + audio_f32 = finalize_output.audio.float() + audio_len = finalize_output.audio_len.int() + + expected_audio_lens = (torch.tensor(inputs["target_num_frames"], device=device) * model.target_samples_per_frame).int() + + if inputs["duplex_multiturn"]: + # Cap the expected length so it physically cannot exceed the actual generated tensor size + # expected_audio_lens = expected_audio_lens.clamp(max=audio_f32.size(1)) + # audio_len = torch.max(audio_len, expected_audio_lens) + # audio_len = torch.full_like(audio_len, audio_f32.size(1)) + print(text_lens, text_lens * model.target_samples_per_frame) + audio_len = (text_lens * model.target_samples_per_frame).int() + else: + audio_len = torch.min(audio_len, expected_audio_lens) + + metric_audio_pred = resample(audio_f32, model.output_sample_rate, 16000) + metric_audio_pred_lens = (audio_len / model.output_sample_rate * 16000).to(torch.long) + + intelligibility.update( + name="dataset", + refs=inputs["raw_text"], + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + asr_hyps=None, + ) + + secs_metric.update( + name="dataset", + target_audio=resample(inputs["context_audio"].float(), model.sample_rate, 16000), + target_audio_lens=(inputs["context_audio_lengths"] / model.sample_rate * 16000).to(torch.long), + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + ) + + os.makedirs(args.out_dir, exist_ok=True) + audio_f32 = audio_f32.detach().cpu() + audio_len = audio_len.cpu() + + for i in range(B): + wav = audio_f32[i, : audio_len[i]].numpy() + target_path = inputs["target_audio_paths"][i] + base_name = os.path.basename(target_path) + out_path = os.path.join(args.out_dir, base_name) + sf.write(out_path, wav, samplerate=model.output_sample_rate) + logging.info(f"Saved: {out_path}") + + with fp32_precision(): + logging.info("\n--- Evaluation Metrics ---") + cer_wer = intelligibility.compute() + for k, m in cer_wer.items(): + logging.info(f"Intelligibility - {k}: {m}") + + secs_scores = secs_metric.compute() + for k, m in secs_scores.items(): + logging.info(f"SECS - {k}: {m}") + + +if __name__ == "__main__": + main() diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index 7e34485d16a3..a7e22bd32458 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -287,7 +287,6 @@ def collate_and_tokenize_custom( turn_t_lens.append(len(seg_ids)) turn_t_valid.append(True) else: - # Dummy pad to keep shapes consistent for items with fewer turns turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) turn_t_lens.append(1) turn_t_valid.append(False) @@ -336,7 +335,7 @@ def collate_and_tokenize_custom( if tdur_audio_path and os.path.exists(tdur_audio_path): wav_dur, sr_ = librosa.load(tdur_audio_path, sr=sample_rate, mono=True) - tdur = wav_dur.shape[0] // model.target_samples_per_frame + tdur = wav_dur.shape[0] // model.input_samples_per_frame target_num_frames.append(tdur * extra_duration_thrshould) else: # Fallback estimation @@ -360,7 +359,7 @@ def collate_and_tokenize_custom( out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] out_dict["target_num_frames"] = target_num_frames - + out_dict["raw_text"] = [" ".join(s["text"]) if isinstance(s["text"], list) else s["text"] for s in batch] return out_dict @@ -424,14 +423,14 @@ def main(): model_cfg.run_val_inference = False model_cfg.use_utmos = False model_cfg.use_meta_init_for_decoder = True - - # --- MISSING FIX: Guarantees silence for pad tokens --- + + # Guarantees silence for pad tokens model_cfg.use_multiturn_dataset = True if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path - # --- MISSING FIX: Load to CPU first to prevent OOM --- + # Load to CPU first to prevent OOM model = EasyMagpieTTSInferenceModel.restore_from( args.checkpoint_path, override_config_path=model_cfg, map_location=torch.device("cpu") ) @@ -440,17 +439,16 @@ def main(): model.eval().to(target_device) # --- DATALOADER COMPATIBILITY PATCHES --- - model.target_samples_per_frame = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) - model.target_sample_rate = model.sample_rate + model.input_samples_per_frame = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) + model.target_samples_per_frame = model.input_samples_per_frame / (model.sample_rate / model.output_sample_rate) - # --- MISSING FIX: Load to CPU first to prevent OOM --- + # Load to CPU first to prevent OOM codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=torch.device("cpu")) if hasattr(codec_model, "discriminator"): del codec_model.discriminator codec_model.freeze() codec_model = codec_model.to(target_device).eval() - codec_converter = None if getattr(model, "_codec_converter", None) is not None: vq_new = deepcopy(model._codec_converter.vector_quantizer_new).to(target_device).eval() @@ -464,7 +462,6 @@ def main(): from collections import Counter def get_codec_silence_frame(model, device, target_sample_rate): - # Generate long zero waveform (silence) audio = torch.zeros(1, 10 * target_sample_rate).float().to(device) audio_len = torch.tensor([audio.size(-1)]).long().to(device) @@ -477,20 +474,10 @@ def get_codec_silence_frame(model, device, target_sample_rate): audio_tokens=sil_codes, audio_lens=sil_codes_lens ).long() - # sil_codes is shape [1, C, T]. - # Extract batch index 0 and transpose to shape [T, C] frames = sil_codes[0].transpose(0, 1) - - # Convert each time frame (C integers) into a tuple of integers combos = [tuple(frame.tolist()) for frame in frames] - - # Count frequencies counter = Counter(combos) - - # Pick the most common combination most_common_combo, freq = counter.most_common(1)[0] - - # Return as tensor [C] return torch.tensor(most_common_combo, device=device, dtype=torch.long) codec_sil_codes = get_codec_silence_frame(model, target_device, model.sample_rate) @@ -508,7 +495,7 @@ def get_codec_silence_frame(model, device, target_sample_rate): collate_and_tokenize_custom, model=model, extra_duration_thrshould=1.5, - sample_rate=model.target_sample_rate, + sample_rate=model.sample_rate, root_path=args.audio_dir, emulate_duplex_inference=args.emulate_duplex_inference, add_interruption_token=args.add_interruption_token, @@ -522,7 +509,7 @@ def get_codec_silence_frame(model, device, target_sample_rate): ) if args.user_custom_speaker_reference and args.inference_speaker_reference: - wav, sr = librosa.load(args.inference_speaker_reference, sr=model.target_sample_rate, mono=True) + wav, sr = librosa.load(args.inference_speaker_reference, sr=model.sample_rate, mono=True) speaker_wav = torch.as_tensor(wav, dtype=target_dtype).unsqueeze(0).to(model.device) for batch_id, inputs in enumerate(dataloader): @@ -537,7 +524,6 @@ def get_codec_silence_frame(model, device, target_sample_rate): inputs["context_audio_lengths"][:] = speaker_wav.size(-1) with torch.inference_mode(): - # 1. Base Initialization (Shared between modes) wav = inputs["context_audio"] wav_len = inputs["context_audio_lengths"] codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) @@ -570,20 +556,14 @@ def get_codec_silence_frame(model, device, target_sample_rate): if inputs["duplex_multiturn"]: text = inputs["input_ids"].to(device) text_lens = inputs["input_lengths"].to(device) - - # Fetch the true silence frame codebook combo once - codec_sil_codes = get_codec_silence_frame(model, device, model.sample_rate) # Trackers for our two forced-silence zones in_initial_silence = torch.ones(B, dtype=torch.bool, device=device) in_post_speech_silence = torch.zeros(B, dtype=torch.bool, device=device) text_exhausted = state.text_tokens_seen >= text_lens - - while not text_exhausted.all() and len(state.all_predictions) < args.max_tts_steps: - + while not text_exhausted.all(): # 1. WAKE UP OVERRIDE: Keep the text pipeline awake to read pads! - # Note: This forces state.finished to False at the start of the loop state.finished = state.finished & text_exhausted state.text_finished = state.text_finished & text_exhausted if hasattr(state, "phoneme_stream_ended"): @@ -599,38 +579,29 @@ def get_codec_silence_frame(model, device, target_sample_rate): # 3. Update our trackers BEFORE the step is_pad_or_eos = (current_tokens == model.pad_id) | (current_tokens == model.eos_id) - - # Initial silence turns off forever once a real word is seen in_initial_silence = in_initial_silence & is_pad_or_eos - - # Post-speech silence turns off when a real word for the NEXT turn is seen in_post_speech_silence = in_post_speech_silence & is_pad_or_eos # 4. Step the model state, audio_codes, _ = model.streaming_step(state=state, text_tokens=current_tokens, use_inference_mode=True) - # 5. TRIGGER POST-SPEECH SILENCE: - # If the audio decoder naturally predicted a speech EOS, state.finished becomes True here! - in_post_speech_silence = in_post_speech_silence | state.finished - - # 6. SILENCE FORCING INJECTION + # 5. SILENCE FORCING INJECTION if audio_codes is not None and args.force_speech_sil_codes: - # We force silence if we are in the initial prefix OR if the model has finished its sentence force_silence_mask = in_initial_silence | in_post_speech_silence if force_silence_mask.any(): # Expand silence codes [C] -> [1, C, 1] to match audio_codes [B, C, 1] expanded_sil = codec_sil_codes.view(1, -1, 1).expand_as(audio_codes) - # Expand mask [B] -> [B, 1, 1] for broadcasting mask_3d = force_silence_mask.view(B, 1, 1) - # Overwrite the prediction with silence codes where the mask is True. overwritten_codes = torch.where(mask_3d, expanded_sil, audio_codes) - # Inject back into the model's KV cache history state.all_predictions[-1] = overwritten_codes + # 6. TRIGGER POST-SPEECH SILENCE FOR THE *NEXT* FRAME + in_post_speech_silence = in_post_speech_silence | state.finished + # Update exhaustion tracker for the next iteration text_exhausted = state.text_tokens_seen >= text_lens @@ -643,8 +614,6 @@ def get_codec_silence_frame(model, device, target_sample_rate): valid_turn_masks = inputs["valid_turn_masks"] max_turns = len(batched_turns) - - # Tracking offset ensures sync regardless of context audio length turn_offsets = torch.zeros(B, dtype=torch.long, device=device) for t in range(max_turns): @@ -652,7 +621,6 @@ def get_codec_silence_frame(model, device, target_sample_rate): turn_lens = batched_turn_lens[t].to(device) valid_mask = valid_turn_masks[t].to(device) - # Reset ALL finished flags for items participating in this new turn state.finished = state.finished & (~valid_mask) state.text_finished = state.text_finished & (~valid_mask) @@ -662,24 +630,46 @@ def get_codec_silence_frame(model, device, target_sample_rate): if state.finished.all(): continue - # Record internal token count at start of turn turn_offsets = torch.where(valid_mask, state.text_tokens_seen, turn_offsets) turn_steps = 0 while not state.finished.all() and turn_steps < args.max_tts_steps: turn_steps += 1 - # Fetch token synced relative to the model's progress relative_positions = state.text_tokens_seen - turn_offsets positions = relative_positions.clamp(min=0, max=turn_text.size(1) - 1) current_tokens = turn_text[torch.arange(B, device=device), positions] - # Once the text for this turn is fully fed, feed EOS so audio can finish exhausted = relative_positions >= turn_lens current_tokens = torch.where(exhausted, torch.full_like(current_tokens, model.eos_id), current_tokens) state, _, _ = model.streaming_step(state=state, text_tokens=current_tokens, use_inference_mode=True) + # Scrub Special Tokens (BOS/EOS) from Audio Codes --- + # Because we force-decode the entire uncropped sequence, any BOS or EOS + # tokens left in the array will produce loud artifacts in the codec. + bos_id = model.audio_bos_id + eos_id = model.audio_eos_id + sil_injection = codec_sil_codes.view(1, -1, 1) + + for step_idx in range(len(state.all_predictions)): + pred = state.all_predictions[step_idx] + # Check if any codebook in the frame has a BOS or EOS token + mask = (pred == bos_id) | (pred == eos_id) + frame_mask = mask.any(dim=1, keepdim=True) + + if frame_mask.any(): + state.all_predictions[step_idx] = torch.where( + frame_mask, + sil_injection.expand_as(pred), + pred + ) + + if inputs["duplex_multiturn"]: + # Erase the internal memory of Turn 1's EOS token so `streaming_finalize` + # decodes the entire physical sequence! + state.audio_prediction_end_idx.fill_(-1) + # Finalize decodes the collected Codec states globally regardless of which loop was run finalize_output = model.streaming_finalize(state, use_inference_mode=True) @@ -693,14 +683,16 @@ def get_codec_silence_frame(model, device, target_sample_rate): expected_audio_lens = (torch.tensor(inputs["target_num_frames"], device=device) * model.target_samples_per_frame).int() if inputs["duplex_multiturn"]: + # Use exact math based on the output samples multiplier! + audio_len = (text_lens * model.target_samples_per_frame).int() + # Cap the expected length so it physically cannot exceed the actual generated tensor size - expected_audio_lens = expected_audio_lens.clamp(max=audio_f32.size(1)) - audio_len = torch.max(audio_len, expected_audio_lens) + audio_len = torch.min(audio_len, torch.tensor(audio_f32.size(1), device=device)) else: audio_len = torch.min(audio_len, expected_audio_lens) - metric_audio_pred = resample(audio_f32, getattr(model, "output_sample_rate", 24000), 16000) - metric_audio_pred_lens = (audio_len / getattr(model, "output_sample_rate", 24000) * 16000).to(torch.long) + metric_audio_pred = resample(audio_f32, model.output_sample_rate, 16000) + metric_audio_pred_lens = (audio_len / model.output_sample_rate * 16000).to(torch.long) intelligibility.update( name="dataset", @@ -712,8 +704,8 @@ def get_codec_silence_frame(model, device, target_sample_rate): secs_metric.update( name="dataset", - target_audio=resample(inputs["context_audio"].float(), model.target_sample_rate, 16000), - target_audio_lens=(inputs["context_audio_lengths"] / model.target_sample_rate * 16000).to(torch.long), + target_audio=resample(inputs["context_audio"].float(), model.sample_rate, 16000), + target_audio_lens=(inputs["context_audio_lengths"] / model.sample_rate * 16000).to(torch.long), pred_audio=metric_audio_pred, pred_audio_lens=metric_audio_pred_lens, ) @@ -727,7 +719,7 @@ def get_codec_silence_frame(model, device, target_sample_rate): target_path = inputs["target_audio_paths"][i] base_name = os.path.basename(target_path) out_path = os.path.join(args.out_dir, base_name) - sf.write(out_path, wav, samplerate=getattr(model, "output_sample_rate", 24000)) + sf.write(out_path, wav, samplerate=model.output_sample_rate) logging.info(f"Saved: {out_path}") with fp32_precision(): From 56edfbb58e67048333eb03d26fad8374a6cb1c27 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 21 Apr 2026 09:06:37 -0700 Subject: [PATCH 019/109] Remove unused codes Signed-off-by: Edresson Casanova --- ...easy_magpietts_inference_multiturn copy.py | 751 ------------------ .../easy_magpietts_inference_multiturn_old.py | 560 ------------- 2 files changed, 1311 deletions(-) delete mode 100644 examples/tts/easy_magpietts_inference_multiturn copy.py delete mode 100644 examples/tts/easy_magpietts_inference_multiturn_old.py diff --git a/examples/tts/easy_magpietts_inference_multiturn copy.py b/examples/tts/easy_magpietts_inference_multiturn copy.py deleted file mode 100644 index a85c9bbb0791..000000000000 --- a/examples/tts/easy_magpietts_inference_multiturn copy.py +++ /dev/null @@ -1,751 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -""" -Evaluation script for custom EasyMagpieTTS models. -Features explicit Duplex (10x Padding) and Regular (Turn-by-turn) multi-turn modes. - -Usage: - python easy_magpietts_eval.py \ - --checkpoint_path=/path/to/magpie/model.ckpt \ - --codec_model_path=/path/to/codec/model.ckpt \ - --datasets_json_path=/path/to/evalset_config.jsonl \ - --out_dir=/path/to/out/audio \ - --batch_size=6 \ - --use_cfg -""" - -import argparse -import json -import os -from copy import deepcopy -from functools import partial - -import librosa -import soundfile as sf -import torch -from torch.nn.utils.rnn import pad_sequence -from torch.utils.data import DataLoader, Dataset -from omegaconf import OmegaConf, open_dict - -from nemo.collections.audio.parts.utils.transforms import resample -from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility -from nemo.collections.speechlm2.parts.metrics.secs import SECS -from nemo.collections.speechlm2.parts.precision import fp32_precision -from nemo.utils import logging - -# --- EasyMagpieTTS Imports --- -from nemo.collections.tts.models import AudioCodecModel -from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter -from nemo.collections.tts.modules.magpietts_modules import CodecHelper -from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel - -torch.set_float32_matmul_precision("medium") -torch.backends.cudnn.allow_tf32 = True -torch.backends.cuda.matmul.allow_tf32 = True - -if torch.cuda.is_available(): - torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) - - -def attach_dtype_counter(model): - handles = [] - stats = {} - examples = {} - - def is_leaf(module): - return len(list(module.children())) == 0 - - def get_dtype(x): - if torch.is_tensor(x): - return str(x.dtype) - elif isinstance(x, (list, tuple)): - for t in x: - if torch.is_tensor(t): - return str(t.dtype) - return "other" - - def get_module_group(name): - return name.split(".")[0] if "." in name else name - - def hook_fn(name): - def fn(module, inputs, outputs): - dtype = get_dtype(outputs) - if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: - dtype = "other" - group = get_module_group(name) - if group not in stats: - stats[group] = { - "torch.float16": 0, "torch.bfloat16": 0, "torch.float32": 0, "other": 0, - } - examples[group] = { - "torch.float16": [], "torch.bfloat16": [], "torch.float32": [], "other": [], - } - stats[group][dtype] += 1 - if len(examples[group][dtype]) < 3: - examples[group][dtype].append(module.__class__.__name__) - return fn - - for name, module in model.named_modules(): - if is_leaf(module): - handles.append(module.register_forward_hook(hook_fn(name))) - return handles, stats, examples - - -def report_dtype_stats(handles, stats, examples): - for h in handles: - h.remove() - logging.info("\n=== DTYPE USAGE PER MODULE ===") - for group, group_stats in stats.items(): - total = sum(group_stats.values()) - if total == 0: continue - logging.info(f"\n--- {group} ---") - for dtype, count in group_stats.items(): - if count > 0: - logging.info(f"{dtype}: {count} ({100*count/total:.2f}%)") - logging.info("\n=== EXAMPLES ===") - for group, group_examples in examples.items(): - logging.info(f"\n--- {group} ---") - for dtype, mods in group_examples.items(): - if mods: - logging.info(f"{dtype}: {mods}") - - -class EvalJSONLDataset(Dataset): - def __init__(self, file_path, num_turns=1): - self.samples = [] - raw_samples = [] - with open(file_path, "r", encoding="utf-8") as f: - for line_idx, line in enumerate(f, 1): - line = line.strip() - if not line: continue - try: - raw_samples.append(json.loads(line)) - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON on line {line_idx}: {e}") - - if num_turns <= 1: - self.samples = raw_samples - return - - single_turn_by_speaker = {} - for sample in raw_samples: - if isinstance(sample["text"], list): - self.samples.append(sample) - else: - speaker = sample.get("speaker", "unknown") - if speaker not in single_turn_by_speaker: - single_turn_by_speaker[speaker] = [] - single_turn_by_speaker[speaker].append(sample) - - for speaker, speaker_samples in single_turn_by_speaker.items(): - buffer_texts, buffer_paths = [], [] - first_sample_meta = None - - for sample in speaker_samples: - if not buffer_texts: - first_sample_meta = dict(sample) - buffer_texts.append(sample["text"]) - buffer_paths.append(sample.get("audio_filepath", "")) - - if len(buffer_texts) == num_turns: - first_sample_meta["text"] = buffer_texts - base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] - ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" - combined_name = "_".join(base_names) + ext - dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) - if dir_name: - first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) - else: - first_sample_meta["audio_filepath"] = combined_name - - self.samples.append(first_sample_meta) - buffer_texts, buffer_paths, first_sample_meta = [], [], None - - if buffer_texts: - first_sample_meta["text"] = buffer_texts - base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] - ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" - combined_name = "_".join(base_names) + ext - dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) - if dir_name: - first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) - else: - first_sample_meta["audio_filepath"] = combined_name - self.samples.append(first_sample_meta) - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - return self.samples[idx] - - -def collate_and_tokenize_custom( - batch, - model, - extra_duration_thrshould=1.3, - sample_rate=22050, - root_path=None, - emulate_duplex_inference=False, - add_interruption_token=False, - pad_factor_text_speech=10, - force_interruption=False, -): - main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] - - # --- MULTI-TURN MODE DECISION --- - is_duplex = emulate_duplex_inference - - out_dict = { - "duplex_multiturn": is_duplex, - "regular_multiturn": not is_duplex, - } - - tokenized_list = [] - batched_turns = [] - batched_turn_lens = [] - valid_turn_masks = [] - - if is_duplex: - # ------------------------------------------------------------- - # DUPLEX MODE (Continuous sequence with 10x pad injection) - # ------------------------------------------------------------- - for s in batch: - text_data = s["text"] - - if isinstance(text_data, list): - full_ids = [] - for segment in text_data: - seg_ids = model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name) + [model.eos_id] - seg_len = len(seg_ids) - pad_len = seg_len * pad_factor_text_speech - pad_ids = [model.pad_id] * pad_len - - if force_interruption: - fname = s["audio_filepath"] - no_ext = fname.split(".")[0] - sample_id = int(no_ext.split("_")[-1]) - case = sample_id % 3 - - if case == 0: - if len(seg_ids) >= 2: - seg_ids[-2] = model.interruption_token_id - seg_ids[-1] = model.pad_id - else: - pad_ids[0] = model.interruption_token_id - elif case == 1: - eos_idx = min(6, len(pad_ids) - 1) - pad_ids[eos_idx] = model.interruption_token_id - else: - eos_idx = 0 - pad_ids[eos_idx] = model.interruption_token_id - else: - if add_interruption_token: - eos_idx = int(len(pad_ids) * 0.7) - pad_ids[eos_idx] = model.interruption_token_id - - full_ids.extend(seg_ids) - full_ids.extend(pad_ids) - - tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long)) - else: - tokenized_list.append( - torch.as_tensor(model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id], dtype=torch.long) - ) - - pad_len = 25 - prefix = torch.full((pad_len,), model.pad_id, dtype=torch.long) - for i in range(len(tokenized_list)): - tokenized_list[i] = torch.cat([prefix, tokenized_list[i]]) - input_lengths = torch.tensor([len(x) for x in tokenized_list], dtype=torch.long) - - input_ids = pad_sequence(tokenized_list, batch_first=True, padding_value=model.pad_id) - - out_dict["input_ids"] = input_ids - out_dict["input_lengths"] = input_lengths - - else: - # ------------------------------------------------------------- - # REGULAR MODE (Turn-by-turn discrete packaging) - # ------------------------------------------------------------- - max_turns = 1 - for s in batch: - if isinstance(s["text"], list): - max_turns = max(max_turns, len(s["text"])) - - for t in range(max_turns): - turn_t_tokens = [] - turn_t_lens = [] - turn_t_valid = [] - - for s in batch: - text_data = s["text"] - if isinstance(text_data, list): - if t < len(text_data): - seg_ids = model.tokenizer.encode(text_data[t], tokenizer_name=main_tokenizer_name) + [model.eos_id] - turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) - turn_t_lens.append(len(seg_ids)) - turn_t_valid.append(True) - else: - # Dummy pad to keep shapes consistent for items with fewer turns - turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) - turn_t_lens.append(1) - turn_t_valid.append(False) - else: - if t == 0: - seg_ids = model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id] - turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) - turn_t_lens.append(len(seg_ids)) - turn_t_valid.append(True) - else: - turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) - turn_t_lens.append(1) - turn_t_valid.append(False) - - padded_turn_t = pad_sequence(turn_t_tokens, batch_first=True, padding_value=model.pad_id) - batched_turns.append(padded_turn_t) - batched_turn_lens.append(torch.tensor(turn_t_lens, dtype=torch.long)) - valid_turn_masks.append(torch.tensor(turn_t_valid, dtype=torch.bool)) - - out_dict["batched_turns"] = batched_turns - out_dict["batched_turn_lens"] = batched_turn_lens - out_dict["valid_turn_masks"] = valid_turn_masks - - # --- AUDIO LOADING --- - audio_list = [] - audio_lengths = [] - target_num_frames = [] - - for i, s in enumerate(batch): - audio_path = s["context_audio_filepath"] - if root_path is not None: - audio_path = os.path.join(root_path, audio_path) - - if os.path.exists(audio_path): - wav, sr = librosa.load(audio_path, sr=sample_rate, mono=True) - wav = torch.as_tensor(wav, dtype=torch.float32) - else: - wav = torch.zeros(1, dtype=torch.float32) - - audio_list.append(wav) - audio_lengths.append(len(wav)) - - tdur_audio_path = s["audio_filepath"] - if root_path is not None: - tdur_audio_path = os.path.join(root_path, tdur_audio_path) - - if tdur_audio_path and os.path.exists(tdur_audio_path): - wav_dur, sr_ = librosa.load(tdur_audio_path, sr=sample_rate, mono=True) - tdur = wav_dur.shape[0] // model.input_samples_per_frame - target_num_frames.append(tdur * extra_duration_thrshould) - else: - # Fallback estimation - if is_duplex: - current_text_len = len(tokenized_list[i]) - if isinstance(s["text"], list): - target_num_frames.append(current_text_len) - else: - target_num_frames.append(current_text_len * 5) - else: - target_num_frames.append(sum([l[i].item() for l in batched_turn_lens]) * 5) - - max_audio_len = max(audio_lengths) - B = len(audio_lengths) - padded_audio = torch.zeros((B, max_audio_len), dtype=torch.float32) - - for i, wav in enumerate(audio_list): - padded_audio[i, : len(wav)] = wav - - out_dict["context_audio"] = padded_audio - out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) - out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] - out_dict["target_num_frames"] = target_num_frames - - out_dict["raw_text"] = [" ".join(s["text"]) if isinstance(s["text"], list) else s["text"] for s in batch] - - return out_dict - - -def main(): - parser = argparse.ArgumentParser(description="EasyMagpieTTS Inference Evaluation") - - # Required Paths - parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the EasyMagpie model") - parser.add_argument("--codec_model_path", type=str, required=True, help="Path to the audio codec") - parser.add_argument("--datasets_json_path", type=str, required=True, help="Path to JSONL data") - parser.add_argument("--out_dir", type=str, required=True, help="Directory to save audio outputs") - - # Optional Paths & General - parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) - parser.add_argument("--audio_dir", type=str, default=None, help="Root dir for audio paths in JSONL") - parser.add_argument("--inference_dtype", type=str, default="float16") - parser.add_argument("--debug_dtype", action="store_true") - - # Dataloader & Batching - parser.add_argument("--batch_size", type=int, default=6) - parser.add_argument("--num_workers", type=int, default=4) - parser.add_argument("--num_turns", type=int, default=1) - parser.add_argument("--pad_factor_text_speech", type=int, default=10) - - # Text Processing Boolean Flags - parser.add_argument("--emulate_duplex_inference", action="store_true") - parser.add_argument("--add_interruption_token", action="store_true") - parser.add_argument("--force_interruption", action="store_true") - - # Speaker & Prompt Configurations - parser.add_argument("--user_custom_speaker_reference", action="store_true") - parser.add_argument("--inference_speaker_reference", type=str, default=None) - parser.add_argument("--language", type=str, default="en") - - # Generation Kwargs - parser.add_argument("--use_cfg", action="store_true") - parser.add_argument("--cfg_scale", type=float, default=2.5) - parser.add_argument("--temperature", type=float, default=0.7) - parser.add_argument("--topk", type=int, default=80) - parser.add_argument("--max_tts_steps", type=int, default=2000) - parser.add_argument("--force_speech_sil_codes", action="store_true") - - args = parser.parse_args() - - distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1 - if distributed and not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") - - target_device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}" if torch.cuda.is_available() else "cpu") - target_dtype = getattr(torch, args.inference_dtype) - torch.set_default_dtype(target_dtype) - - model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) - with open_dict(model_cfg): - model_cfg.target = 'nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel' - model_cfg.codecmodel_path = args.codec_model_path - model_cfg.train_ds = None - model_cfg.validation_ds = None - model_cfg.run_val_inference = False - model_cfg.use_utmos = False - model_cfg.use_meta_init_for_decoder = True - - # --- MISSING FIX: Guarantees silence for pad tokens --- - model_cfg.use_multiturn_dataset = True - - if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: - model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path - - # --- MISSING FIX: Load to CPU first to prevent OOM --- - model = EasyMagpieTTSInferenceModel.restore_from( - args.checkpoint_path, override_config_path=model_cfg, map_location=torch.device("cpu") - ) - model.use_kv_cache_for_inference = True - model.to(dtype=target_dtype) - model.eval().to(target_device) - - # --- DATALOADER COMPATIBILITY PATCHES --- - model.input_samples_per_frame = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) - model.target_samples_per_frame = model.input_samples_per_frame / (model.sample_rate / model.output_sample_rate) - - # --- MISSING FIX: Load to CPU first to prevent OOM --- - codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=torch.device("cpu")) - if hasattr(codec_model, "discriminator"): - del codec_model.discriminator - codec_model.freeze() - codec_model = codec_model.to(target_device).eval() - - - codec_converter = None - if getattr(model, "_codec_converter", None) is not None: - vq_new = deepcopy(model._codec_converter.vector_quantizer_new).to(target_device).eval() - codec_converter = VectorQuantizerIndexConverter( - vector_quantizer_original=codec_model.vector_quantizer, - vector_quantizer_new=vq_new, - ).to(target_device).eval() - - if not hasattr(model, "_codec_helper") or model._codec_helper is None: - model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) - - from collections import Counter - def get_codec_silence_frame(model, device, target_sample_rate): - # Generate long zero waveform (silence) - audio = torch.zeros(1, 10 * target_sample_rate).float().to(device) - audio_len = torch.tensor([audio.size(-1)]).long().to(device) - - sil_codes, sil_codes_lens = model._codec_helper.audio_to_codes( - audio, audio_len - ) - - if model._codec_converter is not None: - sil_codes = model._codec_converter.convert_original_to_new( - audio_tokens=sil_codes, audio_lens=sil_codes_lens - ).long() - - # sil_codes is shape [1, C, T]. - # Extract batch index 0 and transpose to shape [T, C] - frames = sil_codes[0].transpose(0, 1) - - # Convert each time frame (C integers) into a tuple of integers - combos = [tuple(frame.tolist()) for frame in frames] - - # Count frequencies - counter = Counter(combos) - - # Pick the most common combination - most_common_combo, freq = counter.most_common(1)[0] - - # Return as tensor [C] - return torch.tensor(most_common_combo, device=device, dtype=torch.long) - - codec_sil_codes = get_codec_silence_frame(model, target_device, model.sample_rate) - - if args.debug_dtype: - handles, stats, examples = attach_dtype_counter(model) - - with fp32_precision(): - intelligibility = Intelligibility("stt_en_fastconformer_transducer_large", reuse_asr_hyps=False).reset() - secs_metric = SECS("titanet_large").reset() - - eval_dataset = EvalJSONLDataset(args.datasets_json_path, num_turns=args.num_turns) - - collate_fn = partial( - collate_and_tokenize_custom, - model=model, - extra_duration_thrshould=1.5, - sample_rate=model.sample_rate, - root_path=args.audio_dir, - emulate_duplex_inference=args.emulate_duplex_inference, - add_interruption_token=args.add_interruption_token, - pad_factor_text_speech=args.pad_factor_text_speech, - force_interruption=args.force_interruption, - ) - - dataloader = DataLoader( - dataset=eval_dataset, batch_size=args.batch_size, collate_fn=collate_fn, - num_workers=args.num_workers, pin_memory=True, shuffle=False, drop_last=False, - ) - - if args.user_custom_speaker_reference and args.inference_speaker_reference: - wav, sr = librosa.load(args.inference_speaker_reference, sr=model.sample_rate, mono=True) - speaker_wav = torch.as_tensor(wav, dtype=target_dtype).unsqueeze(0).to(model.device) - - for batch_id, inputs in enumerate(dataloader): - B = inputs["context_audio"].size(0) - device = model.device - - inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) - inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) - - if args.user_custom_speaker_reference and args.inference_speaker_reference: - inputs["context_audio"] = speaker_wav.expand(B, *speaker_wav.shape[1:]) - inputs["context_audio_lengths"][:] = speaker_wav.size(-1) - - with torch.inference_mode(): - # 1. Base Initialization (Shared between modes) - wav = inputs["context_audio"] - wav_len = inputs["context_audio_lengths"] - codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) - - use_lang = bool(getattr(model, "add_language_to_context_text", False)) - ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" - ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) - - ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) - ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) - - state = model.streaming_init( - context_audio_codes=codes, - context_audio_codes_lens=codes_lens, - context_text_tokens=ctx_toks, - context_text_tokens_lens=ctx_toks_lens, - use_cfg=args.use_cfg, - cfg_scale=args.cfg_scale, - use_local_transformer=True, - temperature=args.temperature, - topk=args.topk, - phoneme_input_type="pred", - phoneme_sampling_method="argmax", - use_inference_mode=True, - ) - - # --------------------------------------------------------- - # MODE 1: DUPLEX (Continuous Padding Token Stream) - # --------------------------------------------------------- - if inputs["duplex_multiturn"]: - text = inputs["input_ids"].to(device) - text_lens = inputs["input_lengths"].to(device) - # print("Text lens:", text.shape, text_lens) - # Fetch the true silence frame codebook combo once - codec_sil_codes = get_codec_silence_frame(model, device, model.sample_rate) - - # Trackers for our two forced-silence zones - in_initial_silence = torch.ones(B, dtype=torch.bool, device=device) - in_post_speech_silence = torch.zeros(B, dtype=torch.bool, device=device) - - text_exhausted = state.text_tokens_seen >= text_lens - while not text_exhausted.all(): - # 1. WAKE UP OVERRIDE: Keep the text pipeline awake to read pads! - state.finished = state.finished & text_exhausted - state.text_finished = state.text_finished & text_exhausted - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended = state.phoneme_stream_ended & text_exhausted - - # 2. Safely index text using the model's internal pointer - positions = state.text_tokens_seen.clamp(max=text.size(1) - 1) - current_tokens = text[torch.arange(B, device=device), positions] - - current_tokens = torch.where( - text_exhausted, torch.full_like(current_tokens, model.eos_id), current_tokens - ) - - # 3. Update our trackers BEFORE the step - is_pad_or_eos = (current_tokens == model.pad_id) | (current_tokens == model.eos_id) - - # Initial silence turns off forever once a real word is seen - in_initial_silence = in_initial_silence & is_pad_or_eos - - # Post-speech silence turns off when a real word for the NEXT turn is seen - in_post_speech_silence = in_post_speech_silence & is_pad_or_eos - - # 4. Step the model - state, audio_codes, _ = model.streaming_step(state=state, text_tokens=current_tokens, use_inference_mode=True) - - # 5. SILENCE FORCING INJECTION (Moved ABOVE the trigger!) - if audio_codes is not None and args.force_speech_sil_codes: - force_silence_mask = in_initial_silence | in_post_speech_silence - - if force_silence_mask.any(): - # Expand silence codes [C] -> [1, C, 1] to match audio_codes [B, C, 1] - expanded_sil = codec_sil_codes.view(1, -1, 1).expand_as(audio_codes) - - # Expand mask [B] -> [B, 1, 1] for broadcasting - mask_3d = force_silence_mask.view(B, 1, 1) - - # Overwrite the prediction with silence codes where the mask is True. - overwritten_codes = torch.where(mask_3d, expanded_sil, audio_codes) - - # Inject back into the model's KV cache history - state.all_predictions[-1] = overwritten_codes - - # 6. TRIGGER POST-SPEECH SILENCE FOR THE *NEXT* FRAME - # If the audio decoder naturally predicted a speech EOS, state.finished becomes True here. - # We update the tracker AFTER injection so we don't accidentally overwrite the EOS token! - in_post_speech_silence = in_post_speech_silence | state.finished - - # Update exhaustion tracker for the next iteration - text_exhausted = state.text_tokens_seen >= text_lens - - - # --------------------------------------------------------- - # MODE 2: REGULAR (Turn-by-Turn Re-wakes) - # --------------------------------------------------------- - elif inputs["regular_multiturn"]: - batched_turns = inputs["batched_turns"] - batched_turn_lens = inputs["batched_turn_lens"] - valid_turn_masks = inputs["valid_turn_masks"] - - max_turns = len(batched_turns) - - # Tracking offset ensures sync regardless of context audio length - turn_offsets = torch.zeros(B, dtype=torch.long, device=device) - - for t in range(max_turns): - turn_text = batched_turns[t].to(device) - turn_lens = batched_turn_lens[t].to(device) - valid_mask = valid_turn_masks[t].to(device) - - # Reset ALL finished flags for items participating in this new turn - state.finished = state.finished & (~valid_mask) - state.text_finished = state.text_finished & (~valid_mask) - - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended = state.phoneme_stream_ended & (~valid_mask) - - if state.finished.all(): - continue - - # Record internal token count at start of turn - turn_offsets = torch.where(valid_mask, state.text_tokens_seen, turn_offsets) - turn_steps = 0 - - while not state.finished.all() and turn_steps < args.max_tts_steps: - turn_steps += 1 - - # Fetch token synced relative to the model's progress - relative_positions = state.text_tokens_seen - turn_offsets - positions = relative_positions.clamp(min=0, max=turn_text.size(1) - 1) - current_tokens = turn_text[torch.arange(B, device=device), positions] - - # Once the text for this turn is fully fed, feed EOS so audio can finish - exhausted = relative_positions >= turn_lens - current_tokens = torch.where(exhausted, torch.full_like(current_tokens, model.eos_id), current_tokens) - - state, _, _ = model.streaming_step(state=state, text_tokens=current_tokens, use_inference_mode=True) - - if inputs["duplex_multiturn"]: - # Erase the internal memory of Turn 1's EOS token so `streaming_finalize` - # decodes the entire physical sequence! - state.audio_prediction_end_idx.fill_(-1) - - # Finalize decodes the collected Codec states globally regardless of which loop was run - finalize_output = model.streaming_finalize(state, use_inference_mode=True) - - if args.debug_dtype and batch_id == 0: - report_dtype_stats(handles, stats, examples) - - with fp32_precision(): - audio_f32 = finalize_output.audio.float() - audio_len = finalize_output.audio_len.int() - - expected_audio_lens = (torch.tensor(inputs["target_num_frames"], device=device) * model.target_samples_per_frame).int() - - if inputs["duplex_multiturn"]: - # Cap the expected length so it physically cannot exceed the actual generated tensor size - # expected_audio_lens = expected_audio_lens.clamp(max=audio_f32.size(1)) - # audio_len = torch.max(audio_len, expected_audio_lens) - # audio_len = torch.full_like(audio_len, audio_f32.size(1)) - print(text_lens, text_lens * model.target_samples_per_frame) - audio_len = (text_lens * model.target_samples_per_frame).int() - else: - audio_len = torch.min(audio_len, expected_audio_lens) - - metric_audio_pred = resample(audio_f32, model.output_sample_rate, 16000) - metric_audio_pred_lens = (audio_len / model.output_sample_rate * 16000).to(torch.long) - - intelligibility.update( - name="dataset", - refs=inputs["raw_text"], - pred_audio=metric_audio_pred, - pred_audio_lens=metric_audio_pred_lens, - asr_hyps=None, - ) - - secs_metric.update( - name="dataset", - target_audio=resample(inputs["context_audio"].float(), model.sample_rate, 16000), - target_audio_lens=(inputs["context_audio_lengths"] / model.sample_rate * 16000).to(torch.long), - pred_audio=metric_audio_pred, - pred_audio_lens=metric_audio_pred_lens, - ) - - os.makedirs(args.out_dir, exist_ok=True) - audio_f32 = audio_f32.detach().cpu() - audio_len = audio_len.cpu() - - for i in range(B): - wav = audio_f32[i, : audio_len[i]].numpy() - target_path = inputs["target_audio_paths"][i] - base_name = os.path.basename(target_path) - out_path = os.path.join(args.out_dir, base_name) - sf.write(out_path, wav, samplerate=model.output_sample_rate) - logging.info(f"Saved: {out_path}") - - with fp32_precision(): - logging.info("\n--- Evaluation Metrics ---") - cer_wer = intelligibility.compute() - for k, m in cer_wer.items(): - logging.info(f"Intelligibility - {k}: {m}") - - secs_scores = secs_metric.compute() - for k, m in secs_scores.items(): - logging.info(f"SECS - {k}: {m}") - - -if __name__ == "__main__": - main() diff --git a/examples/tts/easy_magpietts_inference_multiturn_old.py b/examples/tts/easy_magpietts_inference_multiturn_old.py deleted file mode 100644 index 502071d52d81..000000000000 --- a/examples/tts/easy_magpietts_inference_multiturn_old.py +++ /dev/null @@ -1,560 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -""" -Evaluation script for custom EasyMagpieTTS models trained on 10x padded inputs. -Stripped of Hydra config requirements. Uses standard argparse. -Features an explicitly exposed, fully batched autoregressive loop for easy multi-turn modding. - -Usage: - python easy_magpietts_eval.py \ - --checkpoint_path=/path/to/magpie/model.ckpt \ - --codec_model_path=/path/to/codec/model.ckpt \ - --datasets_json_path=/path/to/evalset_config.jsonl \ - --out_dir=/path/to/out/audio \ - --batch_size=6 \ - --add_interruption_token \ - --add_beginning_pad_tokens \ - --use_cfg -""" - -import argparse -import json -import os -import time -from copy import deepcopy -from functools import partial - -import librosa -import soundfile as sf -import torch -from torch.nn.utils.rnn import pad_sequence -from torch.utils.data import DataLoader, Dataset -from omegaconf import OmegaConf, open_dict - -from nemo.collections.audio.parts.utils.transforms import resample -from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility -from nemo.collections.speechlm2.parts.metrics.secs import SECS -from nemo.collections.speechlm2.parts.precision import fp32_precision -from nemo.utils import logging - -# --- EasyMagpieTTS Imports --- -from nemo.collections.tts.models import AudioCodecModel -from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter -from nemo.collections.tts.modules.magpietts_modules import CodecHelper -from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel - -torch.set_float32_matmul_precision("medium") -torch.backends.cudnn.allow_tf32 = True -torch.backends.cuda.matmul.allow_tf32 = True - -if torch.cuda.is_available(): - torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) - - -def attach_dtype_counter(model): - handles = [] - stats = {} - examples = {} - - def is_leaf(module): - return len(list(module.children())) == 0 - - def get_dtype(x): - if torch.is_tensor(x): - return str(x.dtype) - elif isinstance(x, (list, tuple)): - for t in x: - if torch.is_tensor(t): - return str(t.dtype) - return "other" - - def get_module_group(name): - return name.split(".")[0] if "." in name else name - - def hook_fn(name): - def fn(module, inputs, outputs): - dtype = get_dtype(outputs) - if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: - dtype = "other" - group = get_module_group(name) - if group not in stats: - stats[group] = { - "torch.float16": 0, "torch.bfloat16": 0, "torch.float32": 0, "other": 0, - } - examples[group] = { - "torch.float16": [], "torch.bfloat16": [], "torch.float32": [], "other": [], - } - stats[group][dtype] += 1 - if len(examples[group][dtype]) < 3: - examples[group][dtype].append(module.__class__.__name__) - return fn - - for name, module in model.named_modules(): - if is_leaf(module): - handles.append(module.register_forward_hook(hook_fn(name))) - return handles, stats, examples - - -def report_dtype_stats(handles, stats, examples): - for h in handles: - h.remove() - logging.info("\n=== DTYPE USAGE PER MODULE ===") - for group, group_stats in stats.items(): - total = sum(group_stats.values()) - if total == 0: continue - logging.info(f"\n--- {group} ---") - for dtype, count in group_stats.items(): - if count > 0: - logging.info(f"{dtype}: {count} ({100*count/total:.2f}%)") - logging.info("\n=== EXAMPLES ===") - for group, group_examples in examples.items(): - logging.info(f"\n--- {group} ---") - for dtype, mods in group_examples.items(): - if mods: - logging.info(f"{dtype}: {mods}") - - -class EvalJSONLDataset(Dataset): - def __init__(self, file_path, num_turns=1): - self.samples = [] - raw_samples = [] - with open(file_path, "r", encoding="utf-8") as f: - for line_idx, line in enumerate(f, 1): - line = line.strip() - if not line: continue - try: - raw_samples.append(json.loads(line)) - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON on line {line_idx}: {e}") - - if num_turns <= 1: - self.samples = raw_samples - return - - single_turn_by_speaker = {} - for sample in raw_samples: - if isinstance(sample["text"], list): - self.samples.append(sample) - else: - speaker = sample.get("speaker", "unknown") - if speaker not in single_turn_by_speaker: - single_turn_by_speaker[speaker] = [] - single_turn_by_speaker[speaker].append(sample) - - for speaker, speaker_samples in single_turn_by_speaker.items(): - buffer_texts, buffer_paths = [], [] - first_sample_meta = None - - for sample in speaker_samples: - if not buffer_texts: - first_sample_meta = dict(sample) - buffer_texts.append(sample["text"]) - buffer_paths.append(sample.get("audio_filepath", "")) - - if len(buffer_texts) == num_turns: - first_sample_meta["text"] = buffer_texts - base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] - ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" - combined_name = "_".join(base_names) + ext - dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) - if dir_name: - first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) - else: - first_sample_meta["audio_filepath"] = combined_name - - self.samples.append(first_sample_meta) - buffer_texts, buffer_paths, first_sample_meta = [], [], None - - if buffer_texts: - first_sample_meta["text"] = buffer_texts - base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] - ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" - combined_name = "_".join(base_names) + ext - dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) - if dir_name: - first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) - else: - first_sample_meta["audio_filepath"] = combined_name - self.samples.append(first_sample_meta) - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - return self.samples[idx] - - -def collate_and_tokenize_custom( - batch, - model, - extra_duration_thrshould=1.3, - sample_rate=22050, - root_path=None, - add_beginning_pad_tokens=False, - add_interruption_token=False, - pad_factor_text_speech=10, - force_interruption=False, -): - tokenized_list = [] - main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] - for s in batch: - text_data = s["text"] - - if isinstance(text_data, list): - full_ids = [] - for segment in text_data: - seg_ids = model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name) + [model.eos_id] - seg_len = len(seg_ids) - pad_len = seg_len * pad_factor_text_speech - pad_ids = [model.pad_id] * pad_len - - if force_interruption: - fname = s["audio_filepath"] - no_ext = fname.split(".")[0] - sample_id = int(no_ext.split("_")[-1]) - case = sample_id % 3 - - if case == 0: - if len(seg_ids) >= 2: - seg_ids[-2] = model.interruption_token_id - seg_ids[-1] = model.pad_id - else: - pad_ids[0] = model.interruption_token_id - elif case == 1: - eos_idx = min(6, len(pad_ids) - 1) - pad_ids[eos_idx] = model.interruption_token_id - else: - eos_idx = 0 - pad_ids[eos_idx] = model.interruption_token_id - else: - if add_interruption_token: - eos_idx = int(len(pad_ids) * 0.7) - pad_ids[eos_idx] = model.interruption_token_id - - full_ids.extend(seg_ids) - full_ids.extend(pad_ids) - - tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long)) - else: - tokenized_list.append( - torch.as_tensor(model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id], dtype=torch.long) - ) - - if add_beginning_pad_tokens: - pad_len = 25 - prefix = torch.full((pad_len,), model.pad_id, dtype=torch.long) - for i in range(len(tokenized_list)): - tokenized_list[i] = torch.cat([prefix, tokenized_list[i]]) - - # Capture the true sequence length before pad_sequence applies batch alignment padding - input_lengths = torch.tensor([len(x) for x in tokenized_list], dtype=torch.long) - input_ids = pad_sequence(tokenized_list, batch_first=True, padding_value=model.pad_id) - - audio_list = [] - audio_lengths = [] - target_num_frames = [] - - for i, s in enumerate(batch): - audio_path = s["context_audio_filepath"] - if root_path is not None: - audio_path = os.path.join(root_path, audio_path) - - if os.path.exists(audio_path): - wav, sr = librosa.load(audio_path, sr=sample_rate, mono=True) - wav = torch.as_tensor(wav, dtype=torch.float32) - else: - wav = torch.zeros(1, dtype=torch.float32) - - audio_list.append(wav) - audio_lengths.append(len(wav)) - - tdur_audio_path = s["audio_filepath"] - if root_path is not None: - tdur_audio_path = os.path.join(root_path, tdur_audio_path) - - if tdur_audio_path and os.path.exists(tdur_audio_path): - wav_dur, sr_ = librosa.load(tdur_audio_path, sr=sample_rate, mono=True) - tdur = wav_dur.shape[0] // model.target_samples_per_frame - target_num_frames.append(tdur * extra_duration_thrshould) - else: - current_text_len = len(tokenized_list[i]) - if isinstance(s["text"], list): - target_num_frames.append(current_text_len) - else: - target_num_frames.append(current_text_len * 5) - - max_audio_len = max(audio_lengths) - B = len(audio_lengths) - padded_audio = torch.zeros((B, max_audio_len), dtype=torch.float32) - - for i, wav in enumerate(audio_list): - padded_audio[i, : len(wav)] = wav - - audio_lengths = torch.tensor(audio_lengths, dtype=torch.long) - B, L = input_ids.shape - target_len = int(max(target_num_frames)) - target_len = max(target_len, L) - - padded_input_ids = torch.full((B, target_len), fill_value=model.pad_id, dtype=input_ids.dtype) - padded_input_ids[:, :L] = input_ids - - collapsed_raw_text = [" ".join(s["text"]) if isinstance(s["text"], list) else s["text"] for s in batch] - - return { - "input_ids": padded_input_ids, - "input_lengths": input_lengths, - "raw_text": collapsed_raw_text, - "context_audio": padded_audio, - "context_audio_lengths": audio_lengths, - "target_audio_paths": [s["audio_filepath"] for s in batch], - "target_num_frames": target_num_frames, - } - - -def main(): - parser = argparse.ArgumentParser(description="EasyMagpieTTS Inference Evaluation") - - # Required Paths - parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the EasyMagpie model") - parser.add_argument("--codec_model_path", type=str, required=True, help="Path to the audio codec") - parser.add_argument("--datasets_json_path", type=str, required=True, help="Path to JSONL data") - parser.add_argument("--out_dir", type=str, required=True, help="Directory to save audio outputs") - - # Optional Paths & General - parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) - parser.add_argument("--audio_dir", type=str, default=None, help="Root dir for audio paths in JSONL") - parser.add_argument("--inference_dtype", type=str, default="float16") - parser.add_argument("--debug_dtype", action="store_true") - - # Dataloader & Batching - parser.add_argument("--batch_size", type=int, default=6) - parser.add_argument("--num_workers", type=int, default=4) - parser.add_argument("--num_turns", type=int, default=1) - parser.add_argument("--pad_factor_text_speech", type=int, default=10) - - # Text Processing Boolean Flags - parser.add_argument("--add_beginning_pad_tokens", action="store_true") - parser.add_argument("--add_interruption_token", action="store_true") - parser.add_argument("--force_interruption", action="store_true") - - # Speaker & Prompt Configurations - parser.add_argument("--user_custom_speaker_reference", action="store_true") - parser.add_argument("--inference_speaker_reference", type=str, default=None) - parser.add_argument("--language", type=str, default="en") - - # Generation Kwargs - parser.add_argument("--use_cfg", action="store_true") - parser.add_argument("--cfg_scale", type=float, default=2.5) - parser.add_argument("--temperature", type=float, default=0.7) - parser.add_argument("--topk", type=int, default=80) - parser.add_argument("--max_tts_steps", type=int, default=1000) - - args = parser.parse_args() - - distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1 - if distributed and not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") - - target_device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}" if torch.cuda.is_available() else "cpu") - target_dtype = getattr(torch, args.inference_dtype) - torch.set_default_dtype(target_dtype) - - model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) - with open_dict(model_cfg): - model_cfg.target = 'nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel' - model_cfg.codecmodel_path = args.codec_model_path - model_cfg.train_ds = None - model_cfg.validation_ds = None - model_cfg.run_val_inference = False - model_cfg.use_utmos = False - model_cfg.use_meta_init_for_decoder = True - if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: - model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path - - model = EasyMagpieTTSInferenceModel.restore_from( - args.checkpoint_path, override_config_path=model_cfg, map_location=target_device - ) - model.use_kv_cache_for_inference = True - model.eval().to(target_device) - model.to(dtype=target_dtype) - - # --- DATALOADER COMPATIBILITY PATCHES --- - model.target_samples_per_frame = getattr(model, "codec_model_samples_per_frame", 320) - model.target_sample_rate = getattr(model, "sample_rate", 22050) - model.pad_id = getattr(model.tokenizer, "pad_id", 0) - model.text_eos_id = model.eos_id - - codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=target_device) - if hasattr(codec_model, "discriminator"): - del codec_model.discriminator - codec_model.freeze() - codec_model = codec_model.to(target_device).eval() - - codec_converter = None - if getattr(model, "_codec_converter", None) is not None: - vq_new = deepcopy(model._codec_converter.vector_quantizer_new).to(target_device).eval() - codec_converter = VectorQuantizerIndexConverter( - vector_quantizer_original=codec_model.vector_quantizer, - vector_quantizer_new=vq_new, - ).to(target_device).eval() - - if not hasattr(model, "_codec_helper") or model._codec_helper is None: - model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) - - if args.debug_dtype: - handles, stats, examples = attach_dtype_counter(model) - - with fp32_precision(): - intelligibility = Intelligibility("stt_en_fastconformer_transducer_large", reuse_asr_hyps=False).reset() - secs_metric = SECS("titanet_large").reset() - - eval_dataset = EvalJSONLDataset(args.datasets_json_path, num_turns=args.num_turns) - - collate_fn = partial( - collate_and_tokenize_custom, - model=model, - extra_duration_thrshould=1.5, - sample_rate=model.target_sample_rate, - root_path=args.audio_dir, - add_beginning_pad_tokens=args.add_beginning_pad_tokens, - add_interruption_token=args.add_interruption_token, - pad_factor_text_speech=args.pad_factor_text_speech, - force_interruption=args.force_interruption, - ) - - dataloader = DataLoader( - dataset=eval_dataset, batch_size=args.batch_size, collate_fn=collate_fn, - num_workers=args.num_workers, pin_memory=True, shuffle=False, drop_last=False, - ) - - if args.user_custom_speaker_reference and args.inference_speaker_reference: - wav, sr = librosa.load(args.inference_speaker_reference, sr=model.target_sample_rate, mono=True) - speaker_wav = torch.as_tensor(wav, dtype=target_dtype).unsqueeze(0).to(model.device) - - for batch_id, inputs in enumerate(dataloader): - B = inputs["input_ids"].size(0) - device = model.device - - inputs["input_ids"] = inputs["input_ids"].to(device) - inputs["input_lengths"] = inputs["input_lengths"].to(device) - inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) - inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) - - if args.user_custom_speaker_reference and args.inference_speaker_reference: - inputs["context_audio"] = speaker_wav.expand(B, *speaker_wav.shape[1:]) - inputs["context_audio_lengths"][:] = speaker_wav.size(-1) - - # 1. Prepare Context & Initialize - with torch.inference_mode(): - wav = inputs["context_audio"] - wav_len = inputs["context_audio_lengths"] - codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) - - use_lang = bool(getattr(model, "add_language_to_context_text", False)) - ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" - ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) - - ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) - ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) - - state = model.streaming_init( - context_audio_codes=codes, - context_audio_codes_lens=codes_lens, - context_text_tokens=ctx_toks, - context_text_tokens_lens=ctx_toks_lens, - use_cfg=args.use_cfg, - cfg_scale=args.cfg_scale, - use_local_transformer=True, - temperature=args.temperature, - topk=args.topk, - phoneme_input_type="pred", - phoneme_sampling_method="argmax", - use_inference_mode=True, - ) - - # --------------------------------------------------------- - # EXPOSED BATCHED GENERATION LOOP (Ready for multi-turn edits!) - # --------------------------------------------------------- - text = inputs["input_ids"] - text_lens = inputs["input_lengths"] - - gen_step = 0 - while not state.finished.all() and len(state.all_predictions) < args.max_tts_steps: - gen_step += 1 - - # Fetch current token dynamically based on state.text_tokens_seen - positions = state.text_tokens_seen.clamp(max=text.size(1) - 1) - current_tokens = text[torch.arange(B, device=device), positions] - - # Mask out sequences that have finished their true length - text_exhausted = state.text_tokens_seen >= text_lens - current_tokens = torch.where( - text_exhausted, torch.full_like(current_tokens, model.eos_id), current_tokens - ) - - # Feed tokens to the model step-by-step - state, audio_codes, _ = model.streaming_step( - state=state, - text_tokens=current_tokens, - use_inference_mode=True, - ) - - # Bulk Decode using the exposed state - finalize_output = model.streaming_finalize(state, use_inference_mode=True) - # --------------------------------------------------------- - - if args.debug_dtype and batch_id == 0: - report_dtype_stats(handles, stats, examples) - - with fp32_precision(): - # Grab output directly from streaming_finalize - audio_f32 = finalize_output.audio.float() - audio_len = finalize_output.audio_len.int() - - expected_audio_lens = (torch.tensor(inputs["target_num_frames"], device=device) * model.target_samples_per_frame).int() - audio_len = torch.min(audio_len, expected_audio_lens) - - metric_audio_pred = resample(audio_f32, getattr(model, "output_sample_rate", 24000), 16000) - metric_audio_pred_lens = (audio_len / getattr(model, "output_sample_rate", 24000) * 16000).to(torch.long) - - intelligibility.update( - name="dataset", - refs=inputs["raw_text"], - pred_audio=metric_audio_pred, - pred_audio_lens=metric_audio_pred_lens, - asr_hyps=None, - ) - - secs_metric.update( - name="dataset", - target_audio=resample(inputs["context_audio"].float(), model.target_sample_rate, 16000), - target_audio_lens=(inputs["context_audio_lengths"] / model.target_sample_rate * 16000).to(torch.long), - pred_audio=metric_audio_pred, - pred_audio_lens=metric_audio_pred_lens, - ) - - os.makedirs(args.out_dir, exist_ok=True) - audio_f32 = audio_f32.detach().cpu() - audio_len = audio_len.cpu() - - for i in range(B): - wav = audio_f32[i, : audio_len[i]].numpy() - target_path = inputs["target_audio_paths"][i] - base_name = os.path.basename(target_path) - out_path = os.path.join(args.out_dir, base_name) - sf.write(out_path, wav, samplerate=getattr(model, "output_sample_rate", 24000)) - logging.info(f"Saved: {out_path}") - - with fp32_precision(): - logging.info("\n--- Evaluation Metrics ---") - cer_wer = intelligibility.compute() - for k, m in cer_wer.items(): - logging.info(f"Intelligibility - {k}: {m}") - - secs_scores = secs_metric.compute() - for k, m in secs_scores.items(): - logging.info(f"SECS - {k}: {m}") - - -if __name__ == "__main__": - main() From 4f1a87e8e2ec26f550b01de1dcd937ab507c3c9b Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 22 Apr 2026 13:25:47 -0700 Subject: [PATCH 020/109] Update inference recipe Signed-off-by: Edresson Casanova --- examples/tts/easy_magpietts_inference_multiturn.py | 14 ++++++++++++-- .../text_to_speech_dataset_lhotse_multiturn.py | 2 +- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index a7e22bd32458..c8f22b456086 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -39,6 +39,8 @@ from nemo.collections.tts.modules.magpietts_modules import CodecHelper from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel +from nemo.collections.tts.parts.utils.tts_dataset_utils import normalize_volume + torch.set_float32_matmul_precision("medium") torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True @@ -190,6 +192,7 @@ def collate_and_tokenize_custom( add_interruption_token=False, pad_factor_text_speech=10, force_interruption=False, + normalize_context_audio_volume=True, ): main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] @@ -322,6 +325,8 @@ def collate_and_tokenize_custom( if os.path.exists(audio_path): wav, sr = librosa.load(audio_path, sr=sample_rate, mono=True) + if normalize_context_audio_volume: + wav = normalize_volume(wav) wav = torch.as_tensor(wav, dtype=torch.float32) else: wav = torch.zeros(1, dtype=torch.float32) @@ -403,6 +408,7 @@ def main(): parser.add_argument("--topk", type=int, default=80) parser.add_argument("--max_tts_steps", type=int, default=2000) parser.add_argument("--force_speech_sil_codes", action="store_true") + parser.add_argument("--normalize_volume", type=bool, default=False) args = parser.parse_args() @@ -501,6 +507,8 @@ def get_codec_silence_frame(model, device, target_sample_rate): add_interruption_token=args.add_interruption_token, pad_factor_text_speech=args.pad_factor_text_speech, force_interruption=args.force_interruption, + normalize_context_audio_volume=args.normalize_volume, + ) dataloader = DataLoader( @@ -510,6 +518,8 @@ def get_codec_silence_frame(model, device, target_sample_rate): if args.user_custom_speaker_reference and args.inference_speaker_reference: wav, sr = librosa.load(args.inference_speaker_reference, sr=model.sample_rate, mono=True) + if args.normalize_volume: + wav = normalize_volume(wav) speaker_wav = torch.as_tensor(wav, dtype=target_dtype).unsqueeze(0).to(model.device) for batch_id, inputs in enumerate(dataloader): @@ -520,8 +530,8 @@ def get_codec_silence_frame(model, device, target_sample_rate): inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) if args.user_custom_speaker_reference and args.inference_speaker_reference: - inputs["context_audio"] = speaker_wav.expand(B, *speaker_wav.shape[1:]) - inputs["context_audio_lengths"][:] = speaker_wav.size(-1) + inputs["context_audio"] = speaker_wav.repeat(B, 1) + inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) with torch.inference_mode(): wav = inputs["context_audio"] diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index cabd1fc20849..bf3055eee63f 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -395,7 +395,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) context_audio_array = cut.context_audio.resample(self.sample_rate).load_audio().squeeze(0) if self.volume_norm: context_audio_array = normalize_volume(context_audio_array) - + _available_context_duration = len(context_audio_array) / self.sample_rate _context_duration_to_slice = _sample_context_duration_with_available_limit(_available_context_duration) _num_samples_to_slice = self.get_num_audio_samples_to_slice(_context_duration_to_slice, self.sample_rate) From 03e0e47ab2e3eb6a2020adbce74cd11a87f61699 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 22 Apr 2026 15:35:27 -0700 Subject: [PATCH 021/109] Remove librosa resample Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 59 ++++++++++++++----- ...text_to_speech_dataset_lhotse_multiturn.py | 2 +- .../tts/models/easy_magpietts_inference.py | 9 ++- 3 files changed, 51 insertions(+), 19 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index c8f22b456086..d87c85194309 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -11,7 +11,8 @@ --datasets_json_path=/path/to/evalset_config.jsonl \ --out_dir=/path/to/out/audio \ --batch_size=6 \ - --use_cfg + --use_cfg \ + --use_librosa """ import argparse @@ -193,6 +194,7 @@ def collate_and_tokenize_custom( pad_factor_text_speech=10, force_interruption=False, normalize_context_audio_volume=True, + use_librosa=False, ): main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] @@ -290,6 +292,7 @@ def collate_and_tokenize_custom( turn_t_lens.append(len(seg_ids)) turn_t_valid.append(True) else: + # Dummy pad to keep shapes consistent for items with fewer turns turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) turn_t_lens.append(1) turn_t_valid.append(False) @@ -324,10 +327,23 @@ def collate_and_tokenize_custom( audio_path = os.path.join(root_path, audio_path) if os.path.exists(audio_path): - wav, sr = librosa.load(audio_path, sr=sample_rate, mono=True) - if normalize_context_audio_volume: - wav = normalize_volume(wav) - wav = torch.as_tensor(wav, dtype=torch.float32) + if use_librosa: + wav, sr = librosa.load(audio_path, sr=sample_rate, mono=True) + if normalize_context_audio_volume: + wav = normalize_volume(wav) + wav = torch.as_tensor(wav, dtype=torch.float32) + else: + wav, sr = sf.read(audio_path, dtype='float32') + # Force Mono + if wav.ndim > 1: + wav = wav.mean(axis=1) + + if normalize_context_audio_volume: + wav = normalize_volume(wav) + + # Convert to tensor, add batch dim for resampler, then remove it + wav_tensor = torch.as_tensor(wav, dtype=torch.float32).unsqueeze(0) + wav = resample(wav_tensor, sr, sample_rate).squeeze(0) else: wav = torch.zeros(1, dtype=torch.float32) @@ -384,6 +400,7 @@ def main(): parser.add_argument("--audio_dir", type=str, default=None, help="Root dir for audio paths in JSONL") parser.add_argument("--inference_dtype", type=str, default="float16") parser.add_argument("--debug_dtype", action="store_true") + parser.add_argument("--use_librosa", action="store_true", help="Use librosa instead of soundfile+torch for audio load") # Dataloader & Batching parser.add_argument("--batch_size", type=int, default=6) @@ -408,7 +425,7 @@ def main(): parser.add_argument("--topk", type=int, default=80) parser.add_argument("--max_tts_steps", type=int, default=2000) parser.add_argument("--force_speech_sil_codes", action="store_true") - parser.add_argument("--normalize_volume", type=bool, default=False) + parser.add_argument("--normalize_volume", type=lambda x: (str(x).lower() in ['true', '1', 'yes']), default=False) args = parser.parse_args() @@ -508,7 +525,7 @@ def get_codec_silence_frame(model, device, target_sample_rate): pad_factor_text_speech=args.pad_factor_text_speech, force_interruption=args.force_interruption, normalize_context_audio_volume=args.normalize_volume, - + use_librosa=args.use_librosa, ) dataloader = DataLoader( @@ -517,18 +534,30 @@ def get_codec_silence_frame(model, device, target_sample_rate): ) if args.user_custom_speaker_reference and args.inference_speaker_reference: - wav, sr = librosa.load(args.inference_speaker_reference, sr=model.sample_rate, mono=True) - if args.normalize_volume: - wav = normalize_volume(wav) - speaker_wav = torch.as_tensor(wav, dtype=target_dtype).unsqueeze(0).to(model.device) + if args.use_librosa: + wav, sr = librosa.load(args.inference_speaker_reference, sr=model.sample_rate, mono=True) + if args.normalize_volume: + wav = normalize_volume(wav) + speaker_wav = torch.as_tensor(wav, dtype=target_dtype).unsqueeze(0).to(model.device) + else: + wav, sr = sf.read(args.inference_speaker_reference, dtype='float32') + # Force Mono + if wav.ndim > 1: + wav = wav.mean(axis=1) + + if args.normalize_volume: + wav = normalize_volume(wav) + + speaker_wav = torch.as_tensor(wav).unsqueeze(0) + speaker_wav = resample(speaker_wav.float(), sr, model.sample_rate).to(target_dtype).to(model.device) for batch_id, inputs in enumerate(dataloader): B = inputs["context_audio"].size(0) device = model.device - + inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) - + if args.user_custom_speaker_reference and args.inference_speaker_reference: inputs["context_audio"] = speaker_wav.repeat(B, 1) inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) @@ -658,8 +687,8 @@ def get_codec_silence_frame(model, device, target_sample_rate): # Scrub Special Tokens (BOS/EOS) from Audio Codes --- # Because we force-decode the entire uncropped sequence, any BOS or EOS # tokens left in the array will produce loud artifacts in the codec. - bos_id = model.audio_bos_id - eos_id = model.audio_eos_id + bos_id = getattr(model, "audio_bos_id", -1) + eos_id = getattr(model, "audio_eos_id", -1) sil_injection = codec_sil_codes.view(1, -1, 1) for step_idx in range(len(state.all_predictions)): diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index bf3055eee63f..1e1d4be117a2 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -399,7 +399,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) _available_context_duration = len(context_audio_array) / self.sample_rate _context_duration_to_slice = _sample_context_duration_with_available_limit(_available_context_duration) _num_samples_to_slice = self.get_num_audio_samples_to_slice(_context_duration_to_slice, self.sample_rate) - + if _num_samples_to_slice < len(context_audio_array): start_idx = random.randint(0, len(context_audio_array) - _num_samples_to_slice) context_audio_array = context_audio_array[start_idx : start_idx + _num_samples_to_slice] diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index d869714c049f..af99666ee0be 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -30,6 +30,7 @@ from nemo.collections.tts.data.text_to_speech_dataset_lhotse import setup_tokenizers from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules import transformer_2501 +from nemo.collections.audio.parts.utils.transforms import resample from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter from nemo.collections.tts.modules.magpietts_modules import ( CharAwareSubwordEncoder, @@ -2132,11 +2133,13 @@ def _load_audio_for_inference(audio_path: str, target_sample_rate: int) -> torch audio, sr = sf.read(audio_path, dtype='float32') if len(audio.shape) > 1: audio = audio.mean(axis=1) + + audio = torch.from_numpy(audio).unsqueeze(0) + if sr != target_sample_rate: - import librosa + audio = resample(audio.float(), sr, target_sample_rate) - audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sample_rate) - return torch.from_numpy(audio).unsqueeze(0) + return audio @staticmethod def _adjust_audio_to_duration_for_inference( From 867e1630f830d7d9aa8adac2ba5ec28db6abee28 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 22 Apr 2026 16:20:32 -0700 Subject: [PATCH 022/109] Add silence aug Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 67 +++++++++++++++++-- 1 file changed, 63 insertions(+), 4 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 141b91b65c63..6f9109e451d0 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -859,11 +859,46 @@ def _create_recording_from_array(samples: np.ndarray, sampling_rate: int, record buffer.seek(0) return Recording.from_bytes(buffer.read(), recording_id=recording_id) +def _prepend_silence_monocut( + cut: MonoCut, + sil_duration: float, + sample_rate: int, + recording_id: str, + cut_id: str, +) -> MonoCut: + """Helper to pad silence at the beginning of a monocut.""" + audio = cut.load_audio() # (C, N) + n_pad = int(round(sil_duration * sample_rate)) + if n_pad <= 0: + return cut + + pad = np.zeros((audio.shape[0], n_pad), dtype=audio.dtype) + audio2 = np.concatenate([pad, audio], axis=1) + + rec = _create_recording_from_array(audio2, sample_rate, recording_id=recording_id) + return MonoCut( + id=cut_id, + start=0.0, + duration=audio2.shape[1] / sample_rate, + channel=0, + recording=rec, + supervisions=[], + ).move_to_memory(audio_format="wav") + class ConvertCutFn: - def __init__(self, sample_rate: int, add_extra_end_sil: bool, extra_end_silence_range: list): + def __init__( + self, + sample_rate: int, + add_extra_end_sil: bool, + extra_end_silence_range: list, + add_extra_begin_sil: bool, + extra_begin_silence_range: list + ): self.sample_rate = sample_rate self.add_extra_end_sil = add_extra_end_sil self.extra_end_silence_range = extra_end_silence_range + self.add_extra_begin_sil = add_extra_begin_sil + self.extra_begin_silence_range = extra_begin_silence_range def __call__(self, cut: Cut) -> Cut: orig_agent_sup = fastcopy(cut.supervisions[0]) @@ -909,6 +944,7 @@ def __call__(self, cut: Cut) -> Cut: user_sup.custom = deepcopy(user_sup.custom) user_sup.custom["ipa"] = "" + # Optionally add extra silence on the end if self.add_extra_end_sil: sil_duration = random.uniform(*self.extra_end_silence_range) cut_target = cut_target.pad(duration=total_duration + sil_duration, direction="right") @@ -918,6 +954,22 @@ def __call__(self, cut: Cut) -> Cut: agent_sup.duration += sil_duration + 1.0 user_sup.duration += sil_duration + # Optionally add extra silence on the start + if self.add_extra_begin_sil: + sil_duration = random.uniform(*self.extra_begin_silence_range) + cut_target = _prepend_silence_monocut( + cut_target, sil_duration, self.sample_rate, + recording_id=f"{cut.id}_target_pre", cut_id=f"{cut.id}_target" + ) + cut_source = _prepend_silence_monocut( + cut_source, sil_duration, self.sample_rate, + recording_id=f"{cut.id}_source_pre", cut_id=f"{cut.id}_source" + ) + + # Shift supervision start times forward, because audio got longer at the beginning + user_sup.start += sil_duration + agent_sup.start += sil_duration + cut_source.supervisions = [user_sup, agent_sup] cut_source.target_audio = cut_target.recording cut_source.duration = cut_target.duration @@ -929,14 +981,14 @@ def __call__(self, cut: Cut) -> Cut: return cut_source - - @data_type_parser(["lhotse_magpietts_data_as_continuation"]) def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: cuts, is_tarred = read_cutset_from_config(config) add_extra_end_sil = config.get("add_extra_end_silence", False) extra_end_silence_range = config.get("extra_end_silence_range", [0.5, 6.0]) + add_extra_begin_sil = config.get("add_extra_begin_sil", False) + extra_begin_silence_range = config.get("extra_begin_silence_range", [0.5, 6.0]) sample_rate = config.get("sample_rate", 22050) max_cer = config.get("max_cer", 0.03) @@ -952,7 +1004,14 @@ def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: .filter(FilterTargetSpeaker(target_speaker)) ) - cuts = cuts.map(ConvertCutFn(sample_rate, add_extra_end_sil, extra_end_silence_range)) + # Pass the beginning silence configs to the updated ConvertCutFn + cuts = cuts.map(ConvertCutFn( + sample_rate=sample_rate, + add_extra_end_sil=add_extra_end_sil, + extra_end_silence_range=extra_end_silence_range, + add_extra_begin_sil=add_extra_begin_sil, + extra_begin_silence_range=extra_begin_silence_range + )) return cuts, is_tarred From 7f5d81e6b603424c9e06af0d8cc1cbfb1b8a9bd7 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 23 Apr 2026 06:34:06 -0700 Subject: [PATCH 023/109] Add silence tts data augmentation Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 22 +------- ...text_to_speech_dataset_lhotse_multiturn.py | 2 +- nemo/collections/tts/models/easy_magpietts.py | 53 +++++++++++++++++++ .../tts/models/easy_magpietts_inference.py | 51 ++++++++++++++++++ 4 files changed, 106 insertions(+), 22 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index d87c85194309..04a7ae6775d1 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -483,27 +483,7 @@ def main(): if not hasattr(model, "_codec_helper") or model._codec_helper is None: model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) - from collections import Counter - def get_codec_silence_frame(model, device, target_sample_rate): - audio = torch.zeros(1, 10 * target_sample_rate).float().to(device) - audio_len = torch.tensor([audio.size(-1)]).long().to(device) - - sil_codes, sil_codes_lens = model._codec_helper.audio_to_codes( - audio, audio_len - ) - - if model._codec_converter is not None: - sil_codes = model._codec_converter.convert_original_to_new( - audio_tokens=sil_codes, audio_lens=sil_codes_lens - ).long() - - frames = sil_codes[0].transpose(0, 1) - combos = [tuple(frame.tolist()) for frame in frames] - counter = Counter(combos) - most_common_combo, freq = counter.most_common(1)[0] - return torch.tensor(most_common_combo, device=device, dtype=torch.long) - - codec_sil_codes = get_codec_silence_frame(model, target_device, model.sample_rate) + codec_sil_codes = model.codec_sil_codes if args.debug_dtype: handles, stats, examples = attach_dtype_counter(model) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 1e1d4be117a2..78e9f14d641c 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -536,7 +536,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) "text": target_text_tokens, "text_lens": target_token_lens, "raw_texts": [" ".join(s.text for s in cut.supervisions if s.speaker in self.output_roles) for cut in cuts], - "dataset_type": [getattr(c, "type", "") for c in cuts], + "task": [getattr(cut, "task", "tts") for cut in cuts], } if target_codes_list: diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index b519d93a2dcf..4826608cc997 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -986,6 +986,59 @@ def training_step(self, batch, batch_idx): audio_lens = batch['audio_lens'] audio_codes, audio_codes_lens = self._codec_helper.audio_to_codes(audio, audio_lens) + + if self.cfg.get("use_multiturn_dataset", False) and "tts" in batch['task']: + prob = self.cfg.get("add_tts_sil_begining_prob", 0.0) + if prob > 0 and torch.rand(1).item() < prob: + audio_codes_lens_max = audio_codes_lens.max() + + # 1. Calculate the raw shift (with the -1 safety buffer) + raw_pad_lens = torch.clamp(audio_codes_lens_max - audio_codes_lens - 1, min=0) + + # 2. Round DOWN to the nearest multiple of the stacking factor + pad_lens = (raw_pad_lens // self.frame_stacking_factor) * self.frame_stacking_factor + + # 3. Calculate perfectly aligned text padding + text_pad_lens = pad_lens // self.frame_stacking_factor + + if pad_lens.max() > 0: + device = audio_codes.device + B, C, T_audio = audio_codes.shape + + # --- Vectorized Audio Shift --- + idx_a = torch.arange(T_audio, device=device).unsqueeze(0) + src_idx_a = idx_a - pad_lens.unsqueeze(1) + + valid_mask_a = (src_idx_a >= 0) & (src_idx_a < audio_codes_lens.unsqueeze(1)) + safe_src_idx_a = src_idx_a.clamp(min=0, max=T_audio - 1) + + safe_src_idx_a_exp = safe_src_idx_a.unsqueeze(1).expand(-1, C, -1) + valid_mask_a_exp = valid_mask_a.unsqueeze(1).expand(-1, C, -1) + + gathered_audio = torch.gather(audio_codes, 2, safe_src_idx_a_exp) + silence_pad = self.codec_sil_codes_unconverted.view(1, C, 1).expand(B, C, T_audio) + + audio_codes = torch.where(valid_mask_a_exp, gathered_audio, silence_pad) + audio_codes_lens = audio_codes_lens + pad_lens + + # --- Vectorized Text Shift (USING SCALED LENS) --- + old_text = batch['text'] + text_lens = batch['text_lens'] + + new_text_lens = text_lens + text_pad_lens + new_T_text = max(new_text_lens.max().item(), old_text.size(1)) + + idx_t = torch.arange(new_T_text, device=device).unsqueeze(0) + src_idx_t = idx_t - text_pad_lens.unsqueeze(1) + + valid_mask_t = (src_idx_t >= 0) & (src_idx_t < text_lens.unsqueeze(1)) + safe_src_idx_t = src_idx_t.clamp(min=0, max=old_text.size(1) - 1) + + gathered_text = torch.gather(old_text, 1, safe_src_idx_t) + + batch['text'] = torch.where(valid_mask_t, gathered_text, self.pad_id) + batch['text_lens'] = new_text_lens + batch_output = self.process_batch( text=batch['text'], text_lens=batch['text_lens'], diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index af99666ee0be..e9d18218f365 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -18,6 +18,7 @@ from dataclasses import dataclass, fields from functools import partial from typing import Any, Dict, List, Optional, Sequence, Tuple +from collections import Counter import numpy as np import soundfile as sf @@ -605,6 +606,56 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): codebook_size=self.codebook_size, ) + @property + def codec_sil_codes(self): + """Returns the CONVERTED silence codes (used by the model's predictions)""" + if not hasattr(self, "_codec_sil_codes_buffer"): + self._generate_codec_silence_buffer() + return self._codec_sil_codes_buffer + + @property + def codec_sil_codes_unconverted(self): + """Returns the RAW, UNCONVERTED silence codes (used for training batch labels)""" + if not hasattr(self, "_codec_sil_codes_buffer_unconverted"): + self._generate_codec_silence_buffer() + return self._codec_sil_codes_buffer_unconverted + + def _generate_codec_silence_buffer(self): + device = self.device if hasattr(self, 'device') else next(self.parameters()).device + + # Generate 5 seconds of silence + audio = torch.zeros(1, 5 * self.sample_rate, dtype=torch.float32, device=device) + audio_len = torch.tensor([audio.size(-1)], dtype=torch.long, device=device) + + with torch.no_grad(): + # 1. Get the RAW codes directly from the helper + sil_codes_raw, sil_codes_lens = self._codec_helper.audio_to_codes(audio, audio_len) + + # Find most common frame for UNCONVERTED + frames_raw = sil_codes_raw[0].transpose(0, 1) + combos_raw = [tuple(frame.tolist()) for frame in frames_raw] + most_common_raw, _ = Counter(combos_raw).most_common(1)[0] + sil_tensor_unconverted = torch.tensor(most_common_raw, device=device, dtype=torch.long) + + # 2. Get the CONVERTED codes (if a converter exists) + if getattr(self, '_codec_converter', None) is not None: + sil_codes_conv = self._codec_converter.convert_original_to_new( + audio_tokens=sil_codes_raw, audio_lens=sil_codes_lens + ).long() + + # Find most common frame for CONVERTED + frames_conv = sil_codes_conv[0].transpose(0, 1) + combos_conv = [tuple(frame.tolist()) for frame in frames_conv] + most_common_conv, _ = Counter(combos_conv).most_common(1)[0] + sil_tensor_converted = torch.tensor(most_common_conv, device=device, dtype=torch.long) + else: + # If no converter exists, they are identical + sil_tensor_converted = sil_tensor_unconverted.clone() + + # 3. Register BOTH as independent buffers + self.register_buffer("_codec_sil_codes_buffer", sil_tensor_converted, persistent=False) + self.register_buffer("_codec_sil_codes_buffer_unconverted", sil_tensor_unconverted, persistent=False) + def _get_state_dict_keys_to_exclude(self) -> List[str]: return [ '_codec_model', From 6aa93ea03d14a1abd5a86365535ed6429cf70ca5 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 27 Apr 2026 13:53:42 -0700 Subject: [PATCH 024/109] Add parameter to remove subword text conditioning Signed-off-by: Edresson Casanova --- .../easy_magpietts_lhotse_multiturn.yaml | 12 +-- .../tts/easy_magpietts_inference_multiturn.py | 1 + ...text_to_speech_dataset_lhotse_multiturn.py | 2 +- nemo/collections/tts/models/easy_magpietts.py | 5 ++ .../tts/models/easy_magpietts_inference.py | 90 ++++++++++++++----- 5 files changed, 80 insertions(+), 30 deletions(-) diff --git a/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml b/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml index 1f4c10b071c3..0f729a41b61e 100644 --- a/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml +++ b/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml @@ -124,9 +124,9 @@ model: quadratic_duration: ${quadratic_duration} use_bucketing: true num_buckets: 20 - bucket_buffer_size: 15_000 - shuffle_buffer_size: 15_000 - num_cuts_for_bins_estimate: 15_000 + bucket_buffer_size: 10_000 + shuffle_buffer_size: 10_000 + num_cuts_for_bins_estimate: 10_000 shard_seed: "trng" drop_last: true shuffle: true @@ -144,9 +144,9 @@ model: input_cfg: /lustre/fsw/convai_convaird_nemo-speech/data/duplex/multispeaker_syn_duplex.yaml use_bucketing: true num_buckets: 20 - bucket_buffer_size: 5_000 - shuffle_buffer_size: 5_000 - num_cuts_for_bins_estimate: 5_000 + bucket_buffer_size: 1_000 + shuffle_buffer_size: 1_000 + num_cuts_for_bins_estimate: 1_000 max_duration: 300 # 5 mi max duration bucket_duration_bins: [4.0, 8.9, 10.2, 11.6, 13.2, 15.0, 17.0, 19.3, 25.0, 31.5, 38.5, 46.0, 55.5, 66.5, 79.5, 93.3, 110.0, 130.0, 156.8, 203.3] bucket_batch_size: [75, 33, 29, 25, 23, 20, 18, 15, 12, 10, 8, 7, 5, 4, 3, 3, 2, 2, 1, 1] diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index 04a7ae6775d1..3e04ec07b361 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -483,6 +483,7 @@ def main(): if not hasattr(model, "_codec_helper") or model._codec_helper is None: model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) + model._generate_codec_silence_buffer() codec_sil_codes = model.codec_sil_codes if args.debug_dtype: diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 78e9f14d641c..07e6e2d55d47 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -410,7 +410,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) context_audio = torch.from_numpy(context_audio_array) context_audio_list.append(context_audio) context_audio_len_list.append(context_audio.shape[0]) - + else: matching_supervisions = [s for s in cut.supervisions if s.speaker in self.output_roles] diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 4826608cc997..29ea67c1c48b 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -1403,6 +1403,11 @@ def validation_step(self, batch, batch_idx): return val_output + def on_fit_start(self): + super().on_fit_start() + if not hasattr(self, "_codec_sil_codes_buffer"): + self._generate_codec_silence_buffer() + def on_validation_epoch_start(self) -> None: if torch.distributed.is_initialized(): self.trainer.strategy.model.require_backward_grad_sync = False diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index e9d18218f365..7b72f374e357 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -338,6 +338,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): mode='train', ) + base_num_tokens = len(self.tokenizer.tokens) # Assign standard special tokens sequentially @@ -608,53 +609,43 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): @property def codec_sil_codes(self): - """Returns the CONVERTED silence codes (used by the model's predictions)""" - if not hasattr(self, "_codec_sil_codes_buffer"): - self._generate_codec_silence_buffer() return self._codec_sil_codes_buffer @property def codec_sil_codes_unconverted(self): - """Returns the RAW, UNCONVERTED silence codes (used for training batch labels)""" - if not hasattr(self, "_codec_sil_codes_buffer_unconverted"): - self._generate_codec_silence_buffer() return self._codec_sil_codes_buffer_unconverted def _generate_codec_silence_buffer(self): - device = self.device if hasattr(self, 'device') else next(self.parameters()).device + codec_device = next(self._codec_model.parameters()).device - # Generate 5 seconds of silence - audio = torch.zeros(1, 5 * self.sample_rate, dtype=torch.float32, device=device) - audio_len = torch.tensor([audio.size(-1)], dtype=torch.long, device=device) + audio = torch.zeros(1, 5 * self.sample_rate, dtype=torch.float32, device=codec_device) + audio_len = torch.tensor([audio.size(-1)], dtype=torch.long, device=codec_device) with torch.no_grad(): - # 1. Get the RAW codes directly from the helper sil_codes_raw, sil_codes_lens = self._codec_helper.audio_to_codes(audio, audio_len) - - # Find most common frame for UNCONVERTED + frames_raw = sil_codes_raw[0].transpose(0, 1) combos_raw = [tuple(frame.tolist()) for frame in frames_raw] most_common_raw, _ = Counter(combos_raw).most_common(1)[0] - sil_tensor_unconverted = torch.tensor(most_common_raw, device=device, dtype=torch.long) + sil_tensor_unconverted = torch.tensor(most_common_raw, device=codec_device, dtype=torch.long) - # 2. Get the CONVERTED codes (if a converter exists) - if getattr(self, '_codec_converter', None) is not None: + if self._codec_converter is not None: sil_codes_conv = self._codec_converter.convert_original_to_new( audio_tokens=sil_codes_raw, audio_lens=sil_codes_lens ).long() - - # Find most common frame for CONVERTED frames_conv = sil_codes_conv[0].transpose(0, 1) combos_conv = [tuple(frame.tolist()) for frame in frames_conv] most_common_conv, _ = Counter(combos_conv).most_common(1)[0] - sil_tensor_converted = torch.tensor(most_common_conv, device=device, dtype=torch.long) + sil_tensor_converted = torch.tensor(most_common_conv, device=codec_device, dtype=torch.long) else: - # If no converter exists, they are identical sil_tensor_converted = sil_tensor_unconverted.clone() - # 3. Register BOTH as independent buffers - self.register_buffer("_codec_sil_codes_buffer", sil_tensor_converted, persistent=False) - self.register_buffer("_codec_sil_codes_buffer_unconverted", sil_tensor_unconverted, persistent=False) + if not hasattr(self, "_codec_sil_codes_buffer"): + self.register_buffer("_codec_sil_codes_buffer", sil_tensor_converted, persistent=False) + self.register_buffer("_codec_sil_codes_buffer_unconverted", sil_tensor_unconverted, persistent=False) + else: + self._codec_sil_codes_buffer.copy_(sil_tensor_converted) + self._codec_sil_codes_buffer_unconverted.copy_(sil_tensor_unconverted) def _get_state_dict_keys_to_exclude(self) -> List[str]: return [ @@ -1084,6 +1075,59 @@ def prepare_context_tensors( return context_embedding, context_lens, context_audio_codes, context_audio_codes_lens + def _embed_context_text_tokens(self, context_text_tokens: torch.Tensor) -> torch.Tensor: + """ + Embed context text tokens. + + Default behavior is preserved: + - disable_subword_embedding_on_context=False: decoder text embedding only + + New behavior: + - disable_subword_embedding_on_context=True: CAS encoder replaces decoder text embedding + """ + if self.cfg.get("disable_subword_embedding_on_context", False): + if not self.cfg.get("use_bpe_char_tokenizer", False): + raise ValueError( + "`disable_subword_embedding_on_context=True` requires " + "`use_bpe_char_tokenizer=True`, because CAS must replace text_embedding." + ) + + if self.cfg.get("use_multiturn_dataset", False): + context_text_mask = context_text_tokens != self.pad_id + else: + context_text_mask = torch.ones_like(context_text_tokens, dtype=torch.bool) + + return self.cas_encoder( + context_text_tokens, + subword_mask=context_text_mask, + ) + + return self.decoder.get_input_embeddings()(context_text_tokens) + + def _get_context_cfg_embedding(self, batch_size: int, device: torch.device) -> torch.Tensor: + """ + Returns the unconditional context embedding used for CFG dropout. + Shape: (B, 1, E) + """ + cfg_token = torch.full( + (batch_size, 1), + self.cfg_unk_token_id, + dtype=torch.long, + device=device, + ) + + if self.cfg.get("disable_subword_embedding_on_context", False): + if not self.cfg.get("use_bpe_char_tokenizer", False): + raise ValueError( + "`disable_subword_embedding_on_context=True` requires " + "`use_bpe_char_tokenizer=True` for CFG context embedding." + ) + + cfg_mask = torch.ones_like(cfg_token, dtype=torch.bool) + return self.cas_encoder(cfg_token, subword_mask=cfg_mask) + + return self.decoder.get_input_embeddings()(cfg_token) + def stack_codes(self, codes, codes_lens, bos_id, eos_id, stacking_factor, num_codebooks): """ Stack multiple time steps into the channel dimension to reduce sequence length. From 60dbe33c52174d3b2e73f67dbaa7f2f66d1b779c Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 27 Apr 2026 16:38:44 -0700 Subject: [PATCH 025/109] Add support for extra duplex dataloaders Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 108 ++++++++++++++++++ nemo/collections/tts/models/easy_magpietts.py | 11 +- .../tts/models/easy_magpietts_inference.py | 3 + 3 files changed, 118 insertions(+), 4 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 6f9109e451d0..8e384e3a8117 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -981,6 +981,114 @@ def __call__(self, cut: Cut) -> Cut: return cut_source + +@data_type_parser(["s2s_duplex_overlap_as_s2s_duplex"]) +def read_s2s_duplex_overlap_as_s2s_duplex(config) -> Tuple[CutSet, bool]: + """ + Convert a CutSet with overlapping agent/user segments into a standard S2S duplex format. + + Use Case: + This parser is designed for conversational data where agent and user speech can overlap + in time (e.g., natural turn-taking with interruptions or backchanneling). The input + format stores agent and user segments separately as `agent_segments` and `user_segments` + attributes on each cut. This function converts them into a unified timeline of sequential + SupervisionSegments, which is the standard format expected by DuplexS2S models. + + Expected Input Data Format: + Each cut should have: + - cut.agent_segments: List[Dict] with keys: + - "start" (float): Start time in seconds + - "end" (float): End time in seconds + - "text" (str): Agent's transcription + - cut.user_segments: List[Dict] with keys: + - "start" (float): Start time in seconds + - "end" (float): End time in seconds + - "text" (str): User's transcription + + Example: + Input cut with overlapping segments: + cut.agent_segments = [ + {"start": 0.5, "end": 2.0, "text": "Hello, how can I help?"}, + {"start": 3.0, "end": 4.5, "text": "Sure, I can do that."} + ] + cut.user_segments = [ + {"start": 1.8, "end": 3.2, "text": "I need assistance"}, + {"start": 4.0, "end": 5.5, "text": "Thank you"} + ] + + Output cut.supervisions (sorted by start time): + [ + SupervisionSegment(start=0.5, duration=1.5, text="Hello, how can I help?", speaker="agent"), + SupervisionSegment(start=1.8, duration=1.4, text="I need assistance", speaker="user"), + SupervisionSegment(start=3.0, duration=1.5, text="Sure, I can do that.", speaker="agent"), + SupervisionSegment(start=4.0, duration=1.5, "Thank you", speaker="user") + ] + + Args: + config: Dictionary containing parser options: + - move_agent_text_back_by (float): Time offset to shift agent text back (default: 0). + Useful for aligning agent text with earlier audio timing. + - filter_samples_starting_with_agent (bool): Whether to remove samples starting with agent (default: False). + When True, only keeps samples where the first speaker is a user. + - agent_roles (List[str]): Roles considered as agent (default: ["agent", "Assistant", "assistant"]). + + Returns: + Tuple[CutSet, bool]: Converted cuts with unified supervisions, and a flag indicating if the data was tarred. + """ + move_agent_text_back_by = config.get("move_agent_text_back_by", 0) + filter_samples_starting_with_agent = config.get("filter_samples_starting_with_agent", False) + agent_roles = config.get("agent_roles", ["agent", "Assistant", "assistant"]) + + cuts, is_tarred = read_cutset_from_config(config) + + def filter_cuts_starting_with_agent_fn(cuts: CutSet, agent_roles: Tuple[str, ...]) -> CutSet: + """Remove cuts where the first supervision belongs to an agent role.""" + + def _filter_fn(cut: Cut) -> bool: + if not cut.supervisions: + return False + cut.supervisions = sorted(cut.supervisions, key=lambda s: s.start) + return cut.supervisions[0].speaker not in agent_roles + + return cuts.filter(_filter_fn) + + def convert_overlap_cut_fn(cut: Cut) -> Cut: + """Convert agent/user overlapping segments into sequential SupervisionSegments.""" + agent_segments = [ + SupervisionSegment( + id=cut.id, + recording_id=cut.id, + start=seg["start"] - move_agent_text_back_by, + duration=seg["end"] - seg["start"] + move_agent_text_back_by, + text=seg["text"], + speaker="agent", + ) + for seg in cut.agent_segments + ] + + user_segments = [ + SupervisionSegment( + id=cut.id, + recording_id=cut.id, + start=seg["start"], + duration=seg["end"] - seg["start"], + text=seg["text"], + speaker="user", + ) + for seg in cut.user_segments + ] + + cut.supervisions = sorted(agent_segments + user_segments, key=lambda s: s.start) + cut.task = "s2s_duplex_overlap_as_s2s_duplex" + return cut + + cuts = cuts.map(convert_overlap_cut_fn) + if filter_samples_starting_with_agent: + cuts = filter_cuts_starting_with_agent_fn(cuts, tuple(agent_roles)) + + return cuts, is_tarred + + @data_type_parser(["lhotse_magpietts_data_as_continuation"]) def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: cuts, is_tarred = read_cutset_from_config(config) diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 29ea67c1c48b..d4d25794ce56 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -61,7 +61,7 @@ HAVE_UTMOSV2 = False from transformers import WhisperForConditionalGeneration, WhisperProcessor - +from typing import List @dataclass class ProcessBatchOutput: @@ -671,6 +671,7 @@ def process_batch( phoneme_tokens_lens: Optional[torch.Tensor] = None, mode: str = "train", training_mode: Optional[TrainingMode] = None, + task: Optional[List[str]] = None, ) -> ProcessBatchOutput: """ Simplified batch processing using channel-based embedding architecture. @@ -754,9 +755,9 @@ def process_batch( speech_eos_mask = None if self.cfg.get("use_multiturn_dataset", False): speech_eos_mask = (text == self.interruption_token_id) # (B, T) - # ToDo: do not remove from interruption data - text[speech_eos_mask] = self.tokenizer.pad # Clean up the text channel - + # remove the interruption token for all task, expect for interruption + if "interruption" not in task[0]: + text[speech_eos_mask] = self.tokenizer.pad # Clean up the text channel # 3. Prepare text channel embeddings text_channel_embedding, text_channel_lens = self.prepare_text_channel_embeddings( @@ -1051,6 +1052,7 @@ def training_step(self, batch, batch_idx): phoneme_tokens=batch.get('phoneme_tokens'), phoneme_tokens_lens=batch.get('phoneme_tokens_lens'), mode="train", + task=batch["task"], ) loss = batch_output.loss codebook_loss = batch_output.codebook_loss @@ -1148,6 +1150,7 @@ def validation_step(self, batch, batch_idx): phoneme_tokens=batch.get('phoneme_tokens'), phoneme_tokens_lens=batch.get('phoneme_tokens_lens'), mode="val", + task=batch["task"], ) # Access ProcessBatchOutput dataclass attributes # logits come from the parallel prediction head diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 7b72f374e357..e8c4c10a829b 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -534,6 +534,9 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): '': self.eos_id, '': self.cfg_unk_token_id, } + if cfg.get("use_multiturn_dataset", False): + special_vocab[""] = self.interruption_token_id + self.cas_encoder = CharAwareSubwordEncoder( d_embed=cfg.embedding_dim, llm_tokenizer_vocab=subword_vocab, From 481543e355a28ad2d36774a3e40db84fdbfe797c Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 27 Apr 2026 17:14:27 -0700 Subject: [PATCH 026/109] Add partial loading Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts_inference.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index e8c4c10a829b..fd7ecc358322 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -47,6 +47,8 @@ from nemo.utils import logging from nemo.utils.exceptions import NeMoBaseException +from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init + @dataclass class TrainingMode: @@ -666,6 +668,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): return state_dict def load_state_dict(self, state_dict, strict=True): + state_dict = set_model_dict_for_partial_init(state_dict, self.state_dict()) if not strict: super().load_state_dict(state_dict, strict=False) modules_to_skip = self._get_state_dict_keys_to_exclude() From 3a38620b2b8c43ccdbf56f3ca1e79349d631220e Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 28 Apr 2026 04:29:39 -0700 Subject: [PATCH 027/109] Fix interruption handling for validation dataset Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index d4d25794ce56..d1917423f9ce 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -756,7 +756,7 @@ def process_batch( if self.cfg.get("use_multiturn_dataset", False): speech_eos_mask = (text == self.interruption_token_id) # (B, T) # remove the interruption token for all task, expect for interruption - if "interruption" not in task[0]: + if not task or "interruption" not in str(task[0]): text[speech_eos_mask] = self.tokenizer.pad # Clean up the text channel # 3. Prepare text channel embeddings @@ -1052,7 +1052,7 @@ def training_step(self, batch, batch_idx): phoneme_tokens=batch.get('phoneme_tokens'), phoneme_tokens_lens=batch.get('phoneme_tokens_lens'), mode="train", - task=batch["task"], + task=batch["task"] if self.cfg.get("use_multiturn_dataset", False) else None, ) loss = batch_output.loss codebook_loss = batch_output.codebook_loss @@ -1150,7 +1150,7 @@ def validation_step(self, batch, batch_idx): phoneme_tokens=batch.get('phoneme_tokens'), phoneme_tokens_lens=batch.get('phoneme_tokens_lens'), mode="val", - task=batch["task"], + task=batch["task"] if "task" in batch else None, ) # Access ProcessBatchOutput dataclass attributes # logits come from the parallel prediction head From 05874118aa1705de7c578021d36ab9f65a89d5e3 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 30 Apr 2026 05:17:06 -0700 Subject: [PATCH 028/109] Add user silence mask Signed-off-by: Edresson Casanova --- ...text_to_speech_dataset_lhotse_multiturn.py | 38 ++++ nemo/collections/tts/models/easy_magpietts.py | 170 ++++++++++++++++-- 2 files changed, 191 insertions(+), 17 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 07e6e2d55d47..b37134625a52 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -579,6 +579,14 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) if len(reward_list) > 0: batch_dict['rewards'] = torch.FloatTensor(reward_list) + agent_mask, agent_mask_lens = collate_speaker_mask_channel( + cuts, + self.frame_length, + self.output_roles, + ) + + batch_dict["agent_mask"] = agent_mask + batch_dict["agent_mask_lens"] = agent_mask return batch_dict @@ -611,6 +619,36 @@ def collate_token_channel( return collate_vectors(tokens, padding_value=pad_id), token_lens +def build_speaker_mask_channel( + cut: Cut, + frame_length: Seconds, + output_roles: set[str], +) -> torch.Tensor: + total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) + mask = torch.zeros(total, dtype=torch.float32) + + for supervision in cut.supervisions: + start = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) + end = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) + + if supervision.speaker in output_roles: + mask[start:end] = 1.0 + + return mask + +def collate_speaker_mask_channel( + cuts: CutSet, + frame_length: Seconds, + output_roles: set[str], +): + masks = [ + build_speaker_mask_channel(cut, frame_length, output_roles) + for cut in cuts + ] + mask_lens = torch.tensor([len(m) for m in masks]) + return collate_vectors(masks, padding_value=0.0), mask_lens + + def build_token_channel( cut: Cut, tokenizer, diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index d1917423f9ce..e46832432a36 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -171,7 +171,14 @@ def _get_state_dict_keys_to_exclude(self): '_utmos_calculator', ] - def compute_loss(self, logits, audio_codes, audio_codes_lens): + + def compute_loss( + self, + logits, + audio_codes, + audio_codes_lens, + agent_mask_target=None, + ): """ Computes the audio codebook loss. Used by (1) The main Magpie-TTS transformer @@ -183,21 +190,26 @@ def compute_loss(self, logits, audio_codes, audio_codes_lens): """ loss_mask = get_mask_from_lengths(audio_codes_lens) loss_mask = loss_mask.unsqueeze(1).repeat(1, audio_codes.size(1), 1) + + if agent_mask_target is not None: + agent_mask_target = agent_mask_target.to(device=audio_codes.device, dtype=loss_mask.dtype) + total_codebook_loss = None for codebook in range(audio_codes.size(1)): si = codebook * self.num_all_tokens_per_codebook ei = si + self.num_all_tokens_per_codebook - codebook_logits = logits[:, :, si:ei] # (B, T', num_tokens_per_codebook) - codebook_targets = audio_codes[:, codebook] # (B, T') - codebook_loss = self.cross_entropy_loss( - codebook_logits.permute(0, 2, 1), codebook_targets.long() # (B, num_tokens_per_codebook, T') + codebook_logits = logits[:, :, si:ei] + codebook_targets = audio_codes[:, codebook] + raw_loss = self.cross_entropy_loss( + codebook_logits.permute(0, 2, 1), + codebook_targets.long(), ) # (B, T') - codebook_loss = codebook_loss * loss_mask[:, codebook, :] - codebook_loss = codebook_loss.sum() / loss_mask[:, codebook, :].sum() - if total_codebook_loss is None: - total_codebook_loss = codebook_loss - else: - total_codebook_loss = total_codebook_loss + codebook_loss + effective_mask = loss_mask[:, codebook, :] + if agent_mask_target is not None: + effective_mask = effective_mask * agent_mask_target + codebook_loss = raw_loss * effective_mask + codebook_loss = codebook_loss.sum() / effective_mask.sum().clamp_min(1.0) + total_codebook_loss = codebook_loss if total_codebook_loss is None else total_codebook_loss + codebook_loss total_codebook_loss = total_codebook_loss / audio_codes.size(1) return total_codebook_loss, loss_mask @@ -577,6 +589,7 @@ def prepare_audio_channel_embeddings( codes_len=audio_codes_lens, bos_id=self.audio_bos_id, eos_id=self.audio_eos_id, + num_eos_tokens=1 if speech_eos_mask is None else 0, ) # Stack audio codes across codebooks @@ -590,11 +603,10 @@ def prepare_audio_channel_embeddings( ) if speech_eos_mask is not None: - # 1. Shift the mask +1 to the right to account for the token and +1 for the EOS - # prepended by add_special_tokens. + # 1. Shift the mask +1 to the right to account for the token. B_mask, T_mask = speech_eos_mask.shape - shifted_mask = torch.zeros((B_mask, T_mask + 2), dtype=torch.bool, device=device) - shifted_mask[:, 2:] = speech_eos_mask + shifted_mask = torch.zeros((B_mask, T_mask + 1), dtype=torch.bool, device=device) + shifted_mask[:, 1:] = speech_eos_mask # 2. Find the minimum overlapping time dimension t_mask = shifted_mask.size(1) @@ -672,6 +684,7 @@ def process_batch( mode: str = "train", training_mode: Optional[TrainingMode] = None, task: Optional[List[str]] = None, + agent_mask: Optional[torch.Tensor] = None, ) -> ProcessBatchOutput: """ Simplified batch processing using channel-based embedding architecture. @@ -907,10 +920,130 @@ def process_batch( pred_embeddings_audio = self.audio_out_projection(pred_embeddings) logits = self.final_proj(pred_embeddings_audio) + if agent_mask is not None: + # pad agent mask + target_T = audio_codes_target.size(2) + if agent_mask.size(1) < target_T: + pad = torch.zeros( + agent_mask.size(0), + target_T - agent_mask.size(1), + device=agent_mask.device, + dtype=agent_mask.dtype, + ) + agent_mask = torch.cat([agent_mask, pad], dim=1) + else: + agent_mask = agent_mask[:, :target_T] + agent_mask = agent_mask.to(audio_codes_target.device).bool() + + # force include EOS + valid = get_mask_from_lengths(audio_codes_lens_target).bool().to(audio_codes_target.device) + eos_any = (audio_codes_target == self.audio_eos_id).any(dim=1) & valid + agent_mask = agent_mask | eos_any + + # include previous token to avoid align issues + eos_prev1 = torch.zeros_like(eos_any) + eos_prev1[:, :-1] = eos_any[:, 1:] + agent_mask = agent_mask | eos_any | eos_prev1 + + """ + # ========================= + # HARD ASSERTS (ALIGNMENT) + # ========================= + agent_mask = agent_mask.bool() + debug_window = 5 + + valid = get_mask_from_lengths(audio_codes_lens_target).bool().to(audio_codes_target.device) + eos_any = (audio_codes_target == self.audio_eos_id).any(dim=1) & valid # (B, T) + + eos_prev1 = torch.zeros_like(eos_any) + eos_prev1[:, :-1] = eos_any[:, 1:] + + eos_prev2 = torch.zeros_like(eos_any) + eos_prev2[:, :-2] = eos_any[:, 2:] + + eos_required = eos_any | eos_prev1 | eos_prev2 + missing_eos = eos_required & (~agent_mask) + + if missing_eos.any(): + b, t = missing_eos.nonzero(as_tuple=False)[0].tolist() + s = max(0, t - debug_window) + e = min(audio_codes_target.size(2), t + debug_window + 1) + + eos_positions_b = eos_any[b].nonzero(as_tuple=False).flatten().detach().cpu().tolist() + required_positions_b = eos_required[b].nonzero(as_tuple=False).flatten().detach().cpu().tolist() + + raise RuntimeError( + "[MASK ERROR] EOS coverage violation\n" + f"first_bad=(b={b}, t={t})\n" + f"audio_len_target={int(audio_codes_lens_target[b])}\n" + f"all_eos_positions_b={eos_positions_b}\n" + f"required_positions_b={required_positions_b}\n" + f"is_eos_at_bad={bool(eos_any[b, t])}\n" + f"is_eos_prev1_at_bad={bool(eos_prev1[b, t])}\n" + f"is_eos_prev2_at_bad={bool(eos_prev2[b, t])}\n" + f"valid_window={valid[b, s:e].detach().cpu().tolist()}\n" + f"eos_window={eos_any[b, s:e].detach().cpu().tolist()}\n" + f"eos_required_window={eos_required[b, s:e].detach().cpu().tolist()}\n" + f"agent_mask_window={agent_mask[b, s:e].detach().cpu().tolist()}\n" + f"codes_window={audio_codes_target[b, :, s:e].detach().cpu().tolist()}\n" + ) + + # ---- All non-pad text must be covered ---- + if text is not None: + text = text.to(agent_mask.device) + valid_text = text != self.pad_id + + if text_lens is not None: + valid_text = valid_text & get_mask_from_lengths(text_lens).bool().to(text.device) + + T_overlap = min(text.size(1), agent_mask.size(1)) + valid_text = valid_text[:, :T_overlap] + agent_mask_text = agent_mask[:, :T_overlap] + + missing_text = valid_text & (~agent_mask_text) + + if missing_text.any(): + b, t = missing_text.nonzero(as_tuple=False)[0].tolist() + s = max(0, t - debug_window) + e = min(T_overlap, t + debug_window + 1) + + first_nonpad = valid_text[b].nonzero(as_tuple=False) + first_nonpad_t = int(first_nonpad[0].item()) if first_nonpad.numel() > 0 else None + + bad_token = int(text[b, t].detach().cpu()) + + token_str = None + try: + token_str = self.tokenizer.ids_to_text([bad_token]) + except Exception: + try: + token_str = self.tokenizer.decode([bad_token]) + except Exception: + try: + token_str = self.tokenizer.id_to_token(bad_token) + except Exception: + token_str = "" + + raise RuntimeError( + "[MASK ERROR] Text coverage violation\n" + f"first_bad=(b={b}, t={t})\n" + f"bad_token_id={bad_token}\n" + f"bad_token_decoded={token_str}\n" + f"text_lens={int(text_lens[b]) if text_lens is not None else None}\n" + f"audio_len_target={int(audio_codes_lens_target[b])}\n" + f"T_text={text.size(1)} T_agent_mask={agent_mask.size(1)} T_overlap={T_overlap}\n" + f"first_nonpad_text_t={first_nonpad_t}\n" + f"text_window={text[b, s:e].detach().cpu().tolist()}\n" + f"valid_text_window={valid_text[b, s:e].detach().cpu().tolist()}\n" + f"agent_mask_window={agent_mask_text[b, s:e].detach().cpu().tolist()}\n" + ) + """ + # Compute codebook loss - codebook_loss, _ = self.compute_loss(logits, audio_codes_target, audio_codes_lens_target) + codebook_loss, _ = self.compute_loss(logits, audio_codes_target, audio_codes_lens_target, agent_mask_target=agent_mask) loss = self.parallel_codebook_loss_scale * codebook_loss + # Compute local transformer loss if applicable local_transformer_loss = None local_transformer_logits = None @@ -920,8 +1053,9 @@ def process_batch( pred_embeddings, audio_codes_target, targets_offset_by_one=False ) local_transformer_loss, _ = self.compute_loss( - local_transformer_logits, audio_codes_target, audio_codes_lens_target + local_transformer_logits, audio_codes_target, audio_codes_lens_target, agent_mask_target=agent_mask ) + loss = loss + self.local_transformer_loss_scale * local_transformer_loss # Compute phoneme loss if applicable @@ -1053,6 +1187,7 @@ def training_step(self, batch, batch_idx): phoneme_tokens_lens=batch.get('phoneme_tokens_lens'), mode="train", task=batch["task"] if self.cfg.get("use_multiturn_dataset", False) else None, + agent_mask=batch["agent_mask"] if self.cfg.get("use_multiturn_dataset", False) and self.cfg.get("mask_user_on_loss", False) else None, ) loss = batch_output.loss codebook_loss = batch_output.codebook_loss @@ -1151,6 +1286,7 @@ def validation_step(self, batch, batch_idx): phoneme_tokens_lens=batch.get('phoneme_tokens_lens'), mode="val", task=batch["task"] if "task" in batch else None, + agent_mask=batch["agent_mask"] if "agent_mask" in batch and self.cfg.get("mask_user_on_loss", False) else None, ) # Access ProcessBatchOutput dataclass attributes # logits come from the parallel prediction head From d4d2c69e71aab98fb59729d1407a6fc5dae63401 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 1 May 2026 07:19:22 -0700 Subject: [PATCH 029/109] Add use_user_speaking_token Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 149 ++++++++++- nemo/collections/tts/models/easy_magpietts.py | 239 ++++++++++-------- .../tts/models/easy_magpietts_inference.py | 115 +++++++++ .../tts/modules/magpietts_modules.py | 2 +- 4 files changed, 399 insertions(+), 106 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index 3e04ec07b361..3eab73a83e5c 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -195,15 +195,18 @@ def collate_and_tokenize_custom( force_interruption=False, normalize_context_audio_volume=True, use_librosa=False, + profile_multiturn_inference=False, ): main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] # --- MULTI-TURN MODE DECISION --- - is_duplex = emulate_duplex_inference + is_profile = profile_multiturn_inference + is_duplex = emulate_duplex_inference and not is_profile out_dict = { "duplex_multiturn": is_duplex, - "regular_multiturn": not is_duplex, + "regular_multiturn": (not is_duplex) and (not is_profile), + "profile_multiturn": is_profile, } tokenized_list = [] @@ -412,7 +415,11 @@ def main(): parser.add_argument("--emulate_duplex_inference", action="store_true") parser.add_argument("--add_interruption_token", action="store_true") parser.add_argument("--force_interruption", action="store_true") - + parser.add_argument("--profile_multiturn_inference", action="store_true") + parser.add_argument("--profile_pad_min_sec", type=float, default=2.0) + parser.add_argument("--profile_pad_max_sec", type=float, default=2.0) + + # Speaker & Prompt Configurations parser.add_argument("--user_custom_speaker_reference", action="store_true") parser.add_argument("--inference_speaker_reference", type=str, default=None) @@ -429,6 +436,12 @@ def main(): args = parser.parse_args() + if args.profile_multiturn_inference and args.batch_size != 1: + raise RuntimeError("--profile_multiturn_inference currently requires --batch_size=1.") + + if args.profile_pad_max_sec < args.profile_pad_min_sec: + raise RuntimeError("--profile_pad_max_sec must be >= --profile_pad_min_sec.") + distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1 if distributed and not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl") @@ -507,6 +520,7 @@ def main(): force_interruption=args.force_interruption, normalize_context_audio_volume=args.normalize_volume, use_librosa=args.use_librosa, + profile_multiturn_inference=args.profile_multiturn_inference ) dataloader = DataLoader( @@ -569,7 +583,6 @@ def main(): phoneme_sampling_method="argmax", use_inference_mode=True, ) - # --------------------------------------------------------- # MODE 1: DUPLEX (Continuous Padding Token Stream) # --------------------------------------------------------- @@ -665,6 +678,129 @@ def main(): state, _, _ = model.streaming_step(state=state, text_tokens=current_tokens, use_inference_mode=True) + # --------------------------------------------------------- + # MODE 3: PROFILE MULTI-TURN + # --------------------------------------------------------- + elif inputs["profile_multiturn"]: + profile_started = False + if B != 1: + raise RuntimeError( + "--profile_multiturn_inference currently supports only batch_size=1. " + "Use --batch_size=1 for this mode." + ) + + batched_turns = inputs["batched_turns"] + batched_turn_lens = inputs["batched_turn_lens"] + valid_turn_masks = inputs["valid_turn_masks"] + + max_turns = len(batched_turns) + prev_turn_immediate_eos = True # force prefill before first turn + prev_turn_ended_with_audio_eos = True # profile before turn 0 + for t in range(max_turns): + turn_text = batched_turns[t].to(device) + turn_lens = batched_turn_lens[t].to(device) + valid_mask = valid_turn_masks[t].to(device) + + if not bool(valid_mask[0].item()): + continue + + # Re-open stream for this turn. + state.finished.zero_() + state.text_finished.zero_() + state.audio_prediction_end_idx.fill_(-1) + + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended.zero_() + if hasattr(state, "phoneme_eos_detected"): + state.phoneme_eos_detected.zero_() + + # Prefill before first turn and only after turns that ended immediately. + if t == 0 or prev_turn_ended_with_audio_eos: + profile_seconds = ( + args.profile_pad_min_sec + + torch.rand((), device=device).item() + * (args.profile_pad_max_sec - args.profile_pad_min_sec) + ) + + profile_T = max( + 1, + int(round(profile_seconds * model.sample_rate / model.input_samples_per_frame)), + ) + + profile_tokens = torch.full( + (1, profile_T), + model.pad_id, + dtype=torch.long, + device=device, + ) + + state = model.streaming_prefill_profile( + state=state, + text_tokens=profile_tokens, + use_inference_mode=True, + ) + + logging.info( + f"[profile_multiturn] turn={t} prefilled {profile_T} steps " + f"({profile_seconds:.2f}s)" + ) + if not profile_started: + start_frame = sum(p.size(-1) for p in state.all_predictions) + state.audio_prediction_start_idx.fill_(start_frame) + profile_started = True + + turn_offset = state.text_tokens_seen.clone() + turn_steps = 0 + saw_audio = False + first_audio_step_finished = False + + turn_text_done = False + + while turn_steps < args.max_tts_steps: + turn_steps += 1 + + state.finished.zero_() + + relative_position = state.text_tokens_seen - turn_offset + text_exhausted = relative_position >= turn_lens + + position = relative_position.clamp(min=0, max=turn_text.size(1) - 1) + current_tokens = turn_text[torch.arange(B, device=device), position] + current_tokens = torch.where( + text_exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) + + state, audio_codes, _ = model.streaming_step( + state=state, + text_tokens=current_tokens, + use_inference_mode=True, + ) + + if audio_codes is not None and not saw_audio: + saw_audio = True + first_audio_step_finished = bool(state.finished[0].item()) + + if bool(text_exhausted[0].item()) and bool(state.finished[0].item()): + turn_ended_with_audio_eos = True + break + + prev_turn_immediate_eos = saw_audio and first_audio_step_finished + # prev_turn_immediate_eos = saw_audio and first_audio_step_finished + prev_turn_ended_with_audio_eos = turn_ended_with_audio_eos + + # keep generated codes, but don't let this turn's EOS crop finalize output + state.audio_prediction_end_idx.fill_(-1) + state.finished.zero_() + + logging.info( + f"[profile_multiturn] turn={t} steps={turn_steps} " + f"saw_audio={saw_audio} immediate_eos={prev_turn_immediate_eos}" + ) + # if state.audio_prediction_end_idx[0].item() >= 0: + # last_audio_prediction_end_idx.copy_(state.audio_prediction_end_idx) + # Scrub Special Tokens (BOS/EOS) from Audio Codes --- # Because we force-decode the entire uncropped sequence, any BOS or EOS # tokens left in the array will produce loud artifacts in the codec. @@ -689,6 +825,9 @@ def main(): # Erase the internal memory of Turn 1's EOS token so `streaming_finalize` # decodes the entire physical sequence! state.audio_prediction_end_idx.fill_(-1) + + if inputs["profile_multiturn"]: + state.audio_prediction_end_idx.fill_(-1) # Finalize decodes the collected Codec states globally regardless of which loop was run finalize_output = model.streaming_finalize(state, use_inference_mode=True) @@ -708,6 +847,8 @@ def main(): # Cap the expected length so it physically cannot exceed the actual generated tensor size audio_len = torch.min(audio_len, torch.tensor(audio_f32.size(1), device=device)) + elif inputs["profile_multiturn"]: + audio_len = finalize_output.audio_len.int() else: audio_len = torch.min(audio_len, expected_audio_lens) diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index e46832432a36..57f98a58ca2c 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -928,119 +928,47 @@ def process_batch( agent_mask.size(0), target_T - agent_mask.size(1), device=agent_mask.device, - dtype=agent_mask.dtype, + dtype=torch.bool, ) - agent_mask = torch.cat([agent_mask, pad], dim=1) + agent_mask = torch.cat([agent_mask.bool(), pad], dim=1) else: - agent_mask = agent_mask[:, :target_T] - agent_mask = agent_mask.to(audio_codes_target.device).bool() + agent_mask = agent_mask[:, :target_T].bool() + + agent_mask = agent_mask.to(audio_codes_target.device) - # force include EOS valid = get_mask_from_lengths(audio_codes_lens_target).bool().to(audio_codes_target.device) + agent_mask = agent_mask & valid + eos_any = (audio_codes_target == self.audio_eos_id).any(dim=1) & valid - agent_mask = agent_mask | eos_any - # include previous token to avoid align issues eos_prev1 = torch.zeros_like(eos_any) eos_prev1[:, :-1] = eos_any[:, 1:] - agent_mask = agent_mask | eos_any | eos_prev1 - - """ - # ========================= - # HARD ASSERTS (ALIGNMENT) - # ========================= - agent_mask = agent_mask.bool() - debug_window = 5 - valid = get_mask_from_lengths(audio_codes_lens_target).bool().to(audio_codes_target.device) - eos_any = (audio_codes_target == self.audio_eos_id).any(dim=1) & valid # (B, T) - - eos_prev1 = torch.zeros_like(eos_any) - eos_prev1[:, :-1] = eos_any[:, 1:] + # Keep EOS/boundary supervised even if the dataloader mask is slightly off. + agent_mask = agent_mask | eos_prev1 | eos_any - eos_prev2 = torch.zeros_like(eos_any) - eos_prev2[:, :-2] = eos_any[:, 2:] - - eos_required = eos_any | eos_prev1 | eos_prev2 - missing_eos = eos_required & (~agent_mask) - - if missing_eos.any(): - b, t = missing_eos.nonzero(as_tuple=False)[0].tolist() - s = max(0, t - debug_window) - e = min(audio_codes_target.size(2), t + debug_window + 1) - - eos_positions_b = eos_any[b].nonzero(as_tuple=False).flatten().detach().cpu().tolist() - required_positions_b = eos_required[b].nonzero(as_tuple=False).flatten().detach().cpu().tolist() - - raise RuntimeError( - "[MASK ERROR] EOS coverage violation\n" - f"first_bad=(b={b}, t={t})\n" - f"audio_len_target={int(audio_codes_lens_target[b])}\n" - f"all_eos_positions_b={eos_positions_b}\n" - f"required_positions_b={required_positions_b}\n" - f"is_eos_at_bad={bool(eos_any[b, t])}\n" - f"is_eos_prev1_at_bad={bool(eos_prev1[b, t])}\n" - f"is_eos_prev2_at_bad={bool(eos_prev2[b, t])}\n" - f"valid_window={valid[b, s:e].detach().cpu().tolist()}\n" - f"eos_window={eos_any[b, s:e].detach().cpu().tolist()}\n" - f"eos_required_window={eos_required[b, s:e].detach().cpu().tolist()}\n" - f"agent_mask_window={agent_mask[b, s:e].detach().cpu().tolist()}\n" - f"codes_window={audio_codes_target[b, :, s:e].detach().cpu().tolist()}\n" + if self.cfg.get("debug_decode_agent_mask", False) and mode == "train" and self.global_step < 5: + self.debug_decode_mask_regions( + audio_codes_target=audio_codes_target, + audio_codes_lens_target=audio_codes_lens_target, + agent_mask=agent_mask, + out_dir=os.path.join(self.trainer.log_dir, "mask_debug", f"step_{self.global_step}"), + prefix=f"batch_{self.global_rank}_{self.global_step}", ) - # ---- All non-pad text must be covered ---- - if text is not None: - text = text.to(agent_mask.device) - valid_text = text != self.pad_id - - if text_lens is not None: - valid_text = valid_text & get_mask_from_lengths(text_lens).bool().to(text.device) - - T_overlap = min(text.size(1), agent_mask.size(1)) - valid_text = valid_text[:, :T_overlap] - agent_mask_text = agent_mask[:, :T_overlap] - - missing_text = valid_text & (~agent_mask_text) - - if missing_text.any(): - b, t = missing_text.nonzero(as_tuple=False)[0].tolist() - s = max(0, t - debug_window) - e = min(T_overlap, t + debug_window + 1) - - first_nonpad = valid_text[b].nonzero(as_tuple=False) - first_nonpad_t = int(first_nonpad[0].item()) if first_nonpad.numel() > 0 else None + # replace all user tokens by a single token + if self.cfg.get("use_user_speaking_token", False) and agent_mask is not None: + non_agent_mask = (~agent_mask) & get_mask_from_lengths(audio_codes_lens_target).bool().to(agent_mask.device) - bad_token = int(text[b, t].detach().cpu()) - - token_str = None - try: - token_str = self.tokenizer.ids_to_text([bad_token]) - except Exception: - try: - token_str = self.tokenizer.decode([bad_token]) - except Exception: - try: - token_str = self.tokenizer.id_to_token(bad_token) - except Exception: - token_str = "" - - raise RuntimeError( - "[MASK ERROR] Text coverage violation\n" - f"first_bad=(b={b}, t={t})\n" - f"bad_token_id={bad_token}\n" - f"bad_token_decoded={token_str}\n" - f"text_lens={int(text_lens[b]) if text_lens is not None else None}\n" - f"audio_len_target={int(audio_codes_lens_target[b])}\n" - f"T_text={text.size(1)} T_agent_mask={agent_mask.size(1)} T_overlap={T_overlap}\n" - f"first_nonpad_text_t={first_nonpad_t}\n" - f"text_window={text[b, s:e].detach().cpu().tolist()}\n" - f"valid_text_window={valid_text[b, s:e].detach().cpu().tolist()}\n" - f"agent_mask_window={agent_mask_text[b, s:e].detach().cpu().tolist()}\n" - ) - """ + user_tok = torch.full_like(audio_codes_target, self.audio_user_speaking_id) + audio_codes_target = torch.where( + non_agent_mask.unsqueeze(1), + user_tok, + audio_codes_target, + ) # Compute codebook loss - codebook_loss, _ = self.compute_loss(logits, audio_codes_target, audio_codes_lens_target, agent_mask_target=agent_mask) + codebook_loss, _ = self.compute_loss(logits, audio_codes_target, audio_codes_lens_target, agent_mask_target=agent_mask if self.cfg.get("mask_user_on_loss", False) else None) loss = self.parallel_codebook_loss_scale * codebook_loss @@ -1053,7 +981,7 @@ def process_batch( pred_embeddings, audio_codes_target, targets_offset_by_one=False ) local_transformer_loss, _ = self.compute_loss( - local_transformer_logits, audio_codes_target, audio_codes_lens_target, agent_mask_target=agent_mask + local_transformer_logits, audio_codes_target, audio_codes_lens_target, agent_mask_target=agent_mask if self.cfg.get("mask_user_on_loss", False) else None ) loss = loss + self.local_transformer_loss_scale * local_transformer_loss @@ -1187,7 +1115,7 @@ def training_step(self, batch, batch_idx): phoneme_tokens_lens=batch.get('phoneme_tokens_lens'), mode="train", task=batch["task"] if self.cfg.get("use_multiturn_dataset", False) else None, - agent_mask=batch["agent_mask"] if self.cfg.get("use_multiturn_dataset", False) and self.cfg.get("mask_user_on_loss", False) else None, + agent_mask=batch["agent_mask"] if self.cfg.get("use_multiturn_dataset", False) else None, ) loss = batch_output.loss codebook_loss = batch_output.codebook_loss @@ -1286,7 +1214,7 @@ def validation_step(self, batch, batch_idx): phoneme_tokens_lens=batch.get('phoneme_tokens_lens'), mode="val", task=batch["task"] if "task" in batch else None, - agent_mask=batch["agent_mask"] if "agent_mask" in batch and self.cfg.get("mask_user_on_loss", False) else None, + agent_mask=batch["agent_mask"] if "agent_mask" in batch else None, ) # Access ProcessBatchOutput dataclass attributes # logits come from the parallel prediction head @@ -1822,3 +1750,112 @@ def val_dataloader(self): self._val_dl_wrapped_with_dist_sampler = True return self._validation_dl + + def debug_decode_mask_regions( + self, + audio_codes_target, + audio_codes_lens_target, + agent_mask, + out_dir, + prefix="debug_mask", + ): + os.makedirs(out_dir, exist_ok=True) + + device = audio_codes_target.device + B, C, T = audio_codes_target.shape + + agent_mask = agent_mask.to(device).bool() + + if agent_mask.size(1) < T: + pad = torch.zeros(B, T - agent_mask.size(1), device=device, dtype=torch.bool) + agent_mask = torch.cat([agent_mask, pad], dim=1) + else: + agent_mask = agent_mask[:, :T] + + valid = get_mask_from_lengths(audio_codes_lens_target).bool().to(device) + agent_mask = agent_mask & valid + + C_base = self.num_audio_codebooks + S = self.frame_stacking_factor + C_target = audio_codes_target.size(1) + + sil = self.codec_sil_codes.to(device=device, dtype=audio_codes_target.dtype) + + if C_target == C_base: + sil = sil.view(1, C_base, 1).expand(B, C_base, T) + + elif C_target == C_base * S: + sil_unstacked = sil.view(1, C_base, 1).expand(B, C_base, T * S).contiguous() + sil_stacked, _ = self.stack_codes( + sil_unstacked, + torch.full((B,), T * S, dtype=torch.long, device=device), + bos_id=self.audio_bos_id, + eos_id=self.audio_eos_id, + stacking_factor=S, + num_codebooks=C_base, + ) + sil = sil_stacked[:, :, :T] + else: + raise RuntimeError( + f"Unexpected codebook dim: target C={C_target}, " + f"base C={C_base}, stacking_factor={S}" + ) + + def decode_and_save(codes, lens, name): + codes = codes.clone() + codes, lens = self._prepare_codes_for_decode(codes, lens) + audio, audio_len, _ = self._codec_helper.codes_to_audio(codes, lens) + + for b in range(B): + wav = audio[b, : audio_len[b]].float().detach().cpu().numpy() + sf.write( + os.path.join(out_dir, f"{prefix}_b{b}_{name}.wav"), + wav, + self.output_sample_rate, + ) + + # 1. full target + decode_and_save(audio_codes_target, audio_codes_lens_target, "full_target") + + # 2. only agent region, silence elsewhere + agent_codes = torch.where(agent_mask[:, None, :], audio_codes_target, sil) + decode_and_save(agent_codes, audio_codes_lens_target, "agent_only_sil_elsewhere") + + # 3. only masked-out region, silence elsewhere + non_agent_codes = torch.where((~agent_mask & valid)[:, None, :], audio_codes_target, sil) + decode_and_save(non_agent_codes, audio_codes_lens_target, "non_agent_only_sil_elsewhere") + + # 4. each contiguous agent segment independently + for b in range(B): + mask_b = agent_mask[b] + idx = mask_b.nonzero(as_tuple=False).flatten() + + if idx.numel() == 0: + continue + + # contiguous runs + breaks = torch.where(idx[1:] != idx[:-1] + 1)[0] + 1 + chunks = torch.tensor_split(idx, breaks.cpu().tolist()) + + for seg_i, seg_idx in enumerate(chunks): + start = int(seg_idx[0]) + end = int(seg_idx[-1]) + 1 + + seg_codes = audio_codes_target[b : b + 1, :, start:end].clone() + seg_lens = torch.tensor([end - start], device=device, dtype=torch.long) + + seg_codes, seg_lens = self._prepare_codes_for_decode(seg_codes, seg_lens) + audio, audio_len, _ = self._codec_helper.codes_to_audio(seg_codes, seg_lens) + + wav = audio[0, : audio_len[0]].float().detach().cpu().numpy() + sf.write( + os.path.join(out_dir, f"{prefix}_b{b}_agent_segment{seg_i}_frames{start}-{end}.wav"), + wav, + self.output_sample_rate, + ) + + logging.info( + f"[mask_debug] saved mask decode files to {out_dir}; " + f"agent coverage frames={agent_mask.sum(dim=1).detach().cpu().tolist()} / " + f"{audio_codes_lens_target.detach().cpu().tolist()}" + ) \ No newline at end of file diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index fd7ecc358322..e5dc238155ca 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -270,6 +270,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.context_audio_bos_id = get_token_index(SpecialAudioToken.AUDIO_CONTEXT_BOS) self.context_audio_eos_id = get_token_index(SpecialAudioToken.AUDIO_CONTEXT_EOS) self.mask_token_id = get_token_index(SpecialAudioToken.MASK_TOKEN) + self.audio_user_speaking_id = get_token_index(SpecialAudioToken.USER_SPEAKING) self.num_all_tokens_per_codebook = self.codebook_size + len(SpecialAudioToken) self.use_bpe_char_tokenizer = cfg.get('use_bpe_char_tokenizer', False) # If True, text tokens are embedded only with the char-aware subword (CAS) encoder, and the decoder token @@ -651,6 +652,120 @@ def _generate_codec_silence_buffer(self): else: self._codec_sil_codes_buffer.copy_(sil_tensor_converted) self._codec_sil_codes_buffer_unconverted.copy_(sil_tensor_unconverted) + + def streaming_prefill_profile( + self, + state: StreamingState, + text_tokens: torch.Tensor, # (B, T) or (B,) + use_inference_mode: bool = True, + # ToDo: implement audio direct support instead of use silence tokens + ) -> StreamingState: + grad_ctx = torch.inference_mode if use_inference_mode else torch.no_grad + with grad_ctx(): + if text_tokens.dim() == 1: + text_tokens = text_tokens[:, None] + + B, T = text_tokens.shape + device = state.config.device + text_tokens = text_tokens.to(device) + + # ----------------------- + # TEXT CHANNEL + # ----------------------- + if self.cfg.get("disable_subword_embedding", False): + text_emb = torch.zeros( + B, + T, + self.cfg.embedding_dim, + dtype=next(self.parameters()).dtype, + device=device, + ) + else: + text_emb = self.decoder.get_input_embeddings()(text_tokens) + + if self.use_bpe_char_tokenizer: + if self.cfg.get("use_multiturn_dataset", False): + text_mask = text_tokens != self.pad_id + else: + text_mask = torch.ones_like(text_tokens, dtype=torch.bool) + + text_emb = text_emb + self.cas_encoder(text_tokens, subword_mask=text_mask) + + if self.cfg.get("use_multiturn_dataset", False): + text_emb[text_tokens == self.pad_id] = 0.0 + + # ----------------------- + # AUDIO CHANNEL: silence previous-token input + # ----------------------- + C = self.num_audio_codebooks + S = self.frame_stacking_factor + + sil_codes = self.codec_sil_codes.to(device=device, dtype=torch.long) # (C,) + + sil_codes_unstacked = sil_codes.view(1, C, 1).expand(B, C, T * S).contiguous() + sil_codes_stacked, _ = self.stack_codes( + sil_codes_unstacked, + torch.full((B,), T * S, dtype=torch.long, device=device), + bos_id=self.audio_bos_id, + eos_id=self.audio_eos_id, + stacking_factor=S, + num_codebooks=C, + ) # (B, C*S, T) + + audio_emb = self.embed_audio_tokens(sil_codes_stacked) + + # Match training channel sum: text + audio silence. + combined_emb = text_emb + audio_emb + + # ----------------------- + # CFG handling + # ----------------------- + if state.config.use_cfg: + # Match regular streaming inference: + # conditional branch = text + audio + # unconditional branch = audio only + inputs_embeds = torch.cat([combined_emb, audio_emb], dim=0) + else: + inputs_embeds = combined_emb + + # ----------------------- + # KV CACHE EXTENSION + # ----------------------- + cache_position = torch.arange( + state.cache_seq_len, + state.cache_seq_len + T, + device=device, + ) + + out = self.forward( + inputs_embeds=inputs_embeds, + attention_mask=None, + use_cache=True, + past_key_values=state.past_key_values, + cache_position=cache_position, + ) + + state.past_key_values = out.past_key_values + state.cache_seq_len += T + state.last_hidden = out.last_hidden_state + + # Advance logical streams consumed by this profile prefill. + state.text_tokens_seen += T + state.audio_steps += T + + # Make the next normal streaming_step continue from silence, not AUDIO_BOS. + state.all_predictions.append(sil_codes_unstacked) # keep silence so that in the target audio user will be silence + if self.cfg.get("use_user_speaking_token", False): + state.last_audio_codes = torch.full( + (B, C * S), + self.audio_user_speaking_id, + dtype=torch.long, + device=device, + ) + else: + state.last_audio_codes = sil_codes_stacked[:, :, -1].contiguous() + + return state def _get_state_dict_keys_to_exclude(self) -> List[str]: return [ diff --git a/nemo/collections/tts/modules/magpietts_modules.py b/nemo/collections/tts/modules/magpietts_modules.py index 6edcc9c56e26..592f9b7b9f5f 100644 --- a/nemo/collections/tts/modules/magpietts_modules.py +++ b/nemo/collections/tts/modules/magpietts_modules.py @@ -96,7 +96,7 @@ class SpecialAudioToken(Enum): AUDIO_CONTEXT_EOS = 3 MASK_TOKEN = 4 # Reserve these values so that if we need to add more special tokens in the future the codebook size will remain the same - RESERVED_1 = 5 + USER_SPEAKING = 5 RESERVED_2 = 6 RESERVED_3 = 7 From d4c29aec46cd82d0c892a00273266c064af3462f Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Fri, 1 May 2026 17:12:08 -0700 Subject: [PATCH 030/109] IPA handling in multiturn data Signed-off-by: Shehzeen Hussain Signed-off-by: Edresson Casanova --- ...text_to_speech_dataset_lhotse_multiturn.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index b37134625a52..f88b0685cc7b 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -76,6 +76,21 @@ def _strip_timestamps( return _SPACE_PATTERN.sub(" ", text).strip() +def _get_supervision_ipa_text(supervision) -> str: + """Return IPA for a supervision, preferring top-level field over custom.""" + ipa_text = getattr(supervision, "ipa", None) + if isinstance(ipa_text, str) and ipa_text.strip(): + return ipa_text + + custom = getattr(supervision, "custom", None) + if isinstance(custom, dict): + custom_ipa = custom.get("ipa") + if isinstance(custom_ipa, str): + return custom_ipa + + return "" + + class MagpieTTSLhotseMultiturnDataset(torch.utils.data.Dataset): """ A PyTorch Dataset for loading and processing Text-to-Speech data for @@ -796,7 +811,7 @@ def build_phoneme_channel( continue if isinstance(phoneme_tokenizer, IPABPETokenizer): - ipa_text = supervision.ipa if supervision.has_custom("ipa") else "" + ipa_text = _get_supervision_ipa_text(supervision) if language in ignore_phoneme_languages: ipa_text = "" else: @@ -824,7 +839,7 @@ def build_phoneme_channel( for supervision in cut.supervisions: if supervision.speaker in roles: if isinstance(phoneme_tokenizer, IPABPETokenizer): - ipa_text = supervision.ipa if supervision.has_custom("ipa") else "" + ipa_text = _get_supervision_ipa_text(supervision) if language in ignore_phoneme_languages: ipa_text = "" else: From 1062f8b0aa9f6d22966c68b03b6a51c61b4a44db Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 4 May 2026 11:08:33 -0700 Subject: [PATCH 031/109] Fix inference script Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 127 ++++++++++------ nemo/collections/tts/models/easy_magpietts.py | 142 ++++++++++++------ .../tts/models/easy_magpietts_inference.py | 50 ++++-- .../tts/modules/magpietts_modules.py | 2 +- 4 files changed, 214 insertions(+), 107 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index 3eab73a83e5c..eaa34a8b5009 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -401,7 +401,7 @@ def main(): # Optional Paths & General parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) parser.add_argument("--audio_dir", type=str, default=None, help="Root dir for audio paths in JSONL") - parser.add_argument("--inference_dtype", type=str, default="float16") + parser.add_argument("--inference_dtype", type=str, default="float32") parser.add_argument("--debug_dtype", action="store_true") parser.add_argument("--use_librosa", action="store_true", help="Use librosa instead of soundfile+torch for audio load") @@ -557,6 +557,7 @@ def main(): inputs["context_audio"] = speaker_wav.repeat(B, 1) inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) + profile_turn_frame_ranges = [] with torch.inference_mode(): wav = inputs["context_audio"] wav_len = inputs["context_audio_lengths"] @@ -682,7 +683,6 @@ def main(): # MODE 3: PROFILE MULTI-TURN # --------------------------------------------------------- elif inputs["profile_multiturn"]: - profile_started = False if B != 1: raise RuntimeError( "--profile_multiturn_inference currently supports only batch_size=1. " @@ -694,9 +694,9 @@ def main(): valid_turn_masks = inputs["valid_turn_masks"] max_turns = len(batched_turns) - prev_turn_immediate_eos = True # force prefill before first turn prev_turn_ended_with_audio_eos = True # profile before turn 0 for t in range(max_turns): + turn_ended_with_audio_eos = False turn_text = batched_turns[t].to(device) turn_lens = batched_turn_lens[t].to(device) valid_mask = valid_turn_masks[t].to(device) @@ -714,40 +714,46 @@ def main(): if hasattr(state, "phoneme_eos_detected"): state.phoneme_eos_detected.zero_() - # Prefill before first turn and only after turns that ended immediately. - if t == 0 or prev_turn_ended_with_audio_eos: - profile_seconds = ( - args.profile_pad_min_sec - + torch.rand((), device=device).item() - * (args.profile_pad_max_sec - args.profile_pad_min_sec) - ) + # Prefill on the begining of each turn + profile_seconds = ( + args.profile_pad_min_sec + + torch.rand((), device=device).item() + * (args.profile_pad_max_sec - args.profile_pad_min_sec) + ) - profile_T = max( - 1, - int(round(profile_seconds * model.sample_rate / model.input_samples_per_frame)), - ) + profile_T = max( + 1, + int(round(profile_seconds * model.sample_rate / model.input_samples_per_frame)), + ) - profile_tokens = torch.full( - (1, profile_T), - model.pad_id, - dtype=torch.long, - device=device, - ) + profile_tokens = torch.full( + (1, profile_T), + model.pad_id, + dtype=torch.long, + device=device, + ) + # add text tokens needed for profilling + delay_tokens = int(state.config.training_mode.streaming_speech_delay) + delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) + profile_tokens[:, -delay_tokens:] = turn_text[:, :delay_tokens] + turn_text = turn_text[:, delay_tokens:] + turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) + + state = model.streaming_prefill_profile( + state=state, + text_tokens=profile_tokens, + use_inference_mode=True, + ) - state = model.streaming_prefill_profile( - state=state, - text_tokens=profile_tokens, - use_inference_mode=True, - ) + logging.info( + f"[profile_multiturn] turn={t} prefilled {profile_T} steps " + f"({profile_seconds:.2f}s)" + ) - logging.info( - f"[profile_multiturn] turn={t} prefilled {profile_T} steps " - f"({profile_seconds:.2f}s)" - ) - if not profile_started: - start_frame = sum(p.size(-1) for p in state.all_predictions) - state.audio_prediction_start_idx.fill_(start_frame) - profile_started = True + turn_start_frame = sum(p.size(-1) for p in state.all_predictions) + if t == 0: + state.audio_prediction_start_idx.fill_(turn_start_frame) + profile_decode_start_frame = turn_start_frame turn_offset = state.text_tokens_seen.clone() turn_steps = 0 @@ -786,8 +792,6 @@ def main(): turn_ended_with_audio_eos = True break - prev_turn_immediate_eos = saw_audio and first_audio_step_finished - # prev_turn_immediate_eos = saw_audio and first_audio_step_finished prev_turn_ended_with_audio_eos = turn_ended_with_audio_eos # keep generated codes, but don't let this turn's EOS crop finalize output @@ -796,8 +800,11 @@ def main(): logging.info( f"[profile_multiturn] turn={t} steps={turn_steps} " - f"saw_audio={saw_audio} immediate_eos={prev_turn_immediate_eos}" + f"saw_audio={saw_audio} immediate_eos={prev_turn_ended_with_audio_eos}" ) + turn_end_frame = sum(p.size(-1) for p in state.all_predictions) + profile_turn_frame_ranges.append((t, turn_start_frame, turn_end_frame)) + # if state.audio_prediction_end_idx[0].item() >= 0: # last_audio_prediction_end_idx.copy_(state.audio_prediction_end_idx) @@ -806,14 +813,17 @@ def main(): # tokens left in the array will produce loud artifacts in the codec. bos_id = getattr(model, "audio_bos_id", -1) eos_id = getattr(model, "audio_eos_id", -1) + speaking_id = getattr(model, "audio_user_speaking_id", -1) + speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) + sil_injection = codec_sil_codes.view(1, -1, 1) for step_idx in range(len(state.all_predictions)): pred = state.all_predictions[step_idx] - # Check if any codebook in the frame has a BOS or EOS token - mask = (pred == bos_id) | (pred == eos_id) + # Check if any codebook in the frame has any special token + mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) frame_mask = mask.any(dim=1, keepdim=True) - + if frame_mask.any(): state.all_predictions[step_idx] = torch.where( frame_mask, @@ -876,12 +886,43 @@ def main(): audio_len = audio_len.cpu() for i in range(B): - wav = audio_f32[i, : audio_len[i]].numpy() target_path = inputs["target_audio_paths"][i] base_name = os.path.basename(target_path) - out_path = os.path.join(args.out_dir, base_name) - sf.write(out_path, wav, samplerate=model.output_sample_rate) - logging.info(f"Saved: {out_path}") + stem, ext = os.path.splitext(base_name) + if not ext: + ext = ".wav" + + if inputs["profile_multiturn"]: + wav = audio_f32[i, : audio_len[i]].numpy() + out_path = os.path.join(args.out_dir, base_name) + sf.write(out_path, wav, samplerate=model.output_sample_rate) + logging.info(f"Full Audio Saved: {out_path}") + + full_wav = audio_f32[i].numpy() + full_len = int(audio_len[i].item()) + print(profile_turn_frame_ranges) + for turn_id, start_frame, end_frame in profile_turn_frame_ranges: + samples_per_prediction_frame = ( + model.codec_model_samples_per_frame / (model.sample_rate / model.output_sample_rate) + ) + rel_start_frame = start_frame - profile_decode_start_frame + rel_end_frame = end_frame - profile_decode_start_frame + + start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) + end_sample = int(round(rel_end_frame * samples_per_prediction_frame)) + start_sample = max(0, min(start_sample, full_len)) + end_sample = max(start_sample, min(end_sample, full_len)) + print("Turn:", turn_id, "Start:", start_sample, "End:", end_sample, "Start S:", start_sample/model.output_sample_rate, "End S:", end_sample/model.output_sample_rate, ) + turn_wav = full_wav[start_sample:end_sample] + + out_path = os.path.join(args.out_dir, f"{stem}_turn_{turn_id}{ext}") + sf.write(out_path, turn_wav, samplerate=model.output_sample_rate) + logging.info(f"Saved: {out_path}") + else: + wav = audio_f32[i, : audio_len[i]].numpy() + out_path = os.path.join(args.out_dir, base_name) + sf.write(out_path, wav, samplerate=model.output_sample_rate) + logging.info(f"Saved: {out_path}") with fp32_precision(): logging.info("\n--- Evaluation Metrics ---") diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 57f98a58ca2c..6eafaf31095d 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -552,6 +552,7 @@ def prepare_audio_channel_embeddings( audio_codes_lens: torch.Tensor, delay: torch.Tensor, speech_eos_mask: Optional[torch.Tensor] = None, + agent_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Prepare audio embeddings as a channel input with delay handling. @@ -627,6 +628,96 @@ def prepare_audio_channel_embeddings( audio_codes_target = audio_codes[:, :, 1:] # (B, C, T'-1) audio_codes_input = audio_codes[:, :, :-1] # (B, C, T'-1) + # deal with agent mask + if agent_mask is not None: + target_T = audio_codes_target.size(2) + + # Align dataloader agent_mask to audio_codes_target time. + if agent_mask.size(1) < target_T: + pad = torch.zeros( + agent_mask.size(0), + target_T - agent_mask.size(1), + device=agent_mask.device, + dtype=torch.bool, + ) + agent_mask = torch.cat([agent_mask.bool(), pad], dim=1) + else: + agent_mask = agent_mask[:, :target_T].bool() + + agent_mask = agent_mask.to(audio_codes_target.device) + + valid = get_mask_from_lengths(audio_codes_lens_target).bool().to(audio_codes_target.device) + agent_mask = agent_mask & valid + + # Keep EOS and the frame before EOS supervised. + eos_any = (audio_codes_target == self.audio_eos_id).any(dim=1) & valid + + eos_prev1 = torch.zeros_like(eos_any) + eos_prev1[:, :-1] = eos_any[:, 1:] + + agent_mask = agent_mask | eos_prev1 | eos_any + target_agent_mask = agent_mask & valid + + if self.cfg.get("debug_decode_agent_mask", False) and self.training and self.global_step < 5: + self.debug_decode_mask_regions( + audio_codes_target=audio_codes_target, + audio_codes_lens_target=audio_codes_lens_target, + agent_mask=agent_mask, + out_dir=os.path.join(self.trainer.log_dir, "mask_debug", f"step_{self.global_step}"), + prefix=f"batch_{self.global_rank}_{self.global_step}", + ) + + + # Replace user/non-agent regions with a learned token. + # Important: audio_codes_input predicts audio_codes_target, so input mask must be shifted. + if self.cfg.get("use_user_speaking_token", False): + target_non_agent = (~target_agent_mask) & valid + + # audio_codes_input[:, :, t] is the previous token used to predict target t. + input_agent_mask = torch.zeros_like(target_agent_mask) + input_agent_mask[:, 1:] = target_agent_mask[:, :-1] + input_agent_mask[:, 0] = True # Keep first/BOS input untouched. + + input_valid = torch.zeros_like(valid) + input_valid[:, 1:] = valid[:, :-1] + input_valid[:, 0] = valid[:, 0] + + input_non_agent = (~input_agent_mask) & input_valid + + user_tok_input = torch.full_like(audio_codes_input, self.audio_user_speaking_id) + audio_codes_input = torch.where( + input_non_agent.unsqueeze(1), + user_tok_input, + audio_codes_input, + ) + + user_tok_target = torch.full_like(audio_codes_target, self.audio_user_speaking_id) + audio_codes_target = torch.where( + target_non_agent.unsqueeze(1), + user_tok_target, + audio_codes_target, + ) + + # Put audio_user_speaking_end_id in the input slot that predicts the first agent frame + # after a non-agent region. + if self.cfg.get("use_user_speaking_end_token", False): + user_to_agent = torch.zeros_like(target_agent_mask) + + user_to_agent[:, 1:] = ( + target_agent_mask[:, 1:] + & (~target_agent_mask[:, :-1]) + & valid[:, 1:] + & valid[:, :-1] + ) + + end_tok_input = torch.full_like(audio_codes_input, self.audio_user_speaking_end_id) + + audio_codes_input = torch.where( + user_to_agent.unsqueeze(1), + end_tok_input, + audio_codes_input, + ) + # Embed audio tokens audio_embedded = self.embed_audio_tokens(audio_codes_input) # (B, T'-1, E) @@ -640,7 +731,7 @@ def prepare_audio_channel_embeddings( lengths=[delay, audio_codes_lens_target], ) - return audio_channel_embedding, audio_channel_lens, audio_codes_target, audio_codes_lens_target + return audio_channel_embedding, audio_channel_lens, audio_codes_target, audio_codes_lens_target, agent_mask def slice_sequence_embeddings(self, sequence_embeddings, context_lens, target_lens): """ @@ -825,11 +916,13 @@ def process_batch( audio_channel_lens, audio_codes_target, audio_codes_lens_target, + agent_mask, ) = self.prepare_audio_channel_embeddings( audio_codes=audio_codes, audio_codes_lens=audio_codes_lens, delay=audio_delay, speech_eos_mask=speech_eos_mask, + agent_mask=agent_mask, ) # 6. Sum the channel embeddings element-wise @@ -920,53 +1013,6 @@ def process_batch( pred_embeddings_audio = self.audio_out_projection(pred_embeddings) logits = self.final_proj(pred_embeddings_audio) - if agent_mask is not None: - # pad agent mask - target_T = audio_codes_target.size(2) - if agent_mask.size(1) < target_T: - pad = torch.zeros( - agent_mask.size(0), - target_T - agent_mask.size(1), - device=agent_mask.device, - dtype=torch.bool, - ) - agent_mask = torch.cat([agent_mask.bool(), pad], dim=1) - else: - agent_mask = agent_mask[:, :target_T].bool() - - agent_mask = agent_mask.to(audio_codes_target.device) - - valid = get_mask_from_lengths(audio_codes_lens_target).bool().to(audio_codes_target.device) - agent_mask = agent_mask & valid - - eos_any = (audio_codes_target == self.audio_eos_id).any(dim=1) & valid - - eos_prev1 = torch.zeros_like(eos_any) - eos_prev1[:, :-1] = eos_any[:, 1:] - - # Keep EOS/boundary supervised even if the dataloader mask is slightly off. - agent_mask = agent_mask | eos_prev1 | eos_any - - if self.cfg.get("debug_decode_agent_mask", False) and mode == "train" and self.global_step < 5: - self.debug_decode_mask_regions( - audio_codes_target=audio_codes_target, - audio_codes_lens_target=audio_codes_lens_target, - agent_mask=agent_mask, - out_dir=os.path.join(self.trainer.log_dir, "mask_debug", f"step_{self.global_step}"), - prefix=f"batch_{self.global_rank}_{self.global_step}", - ) - - # replace all user tokens by a single token - if self.cfg.get("use_user_speaking_token", False) and agent_mask is not None: - non_agent_mask = (~agent_mask) & get_mask_from_lengths(audio_codes_lens_target).bool().to(agent_mask.device) - - user_tok = torch.full_like(audio_codes_target, self.audio_user_speaking_id) - audio_codes_target = torch.where( - non_agent_mask.unsqueeze(1), - user_tok, - audio_codes_target, - ) - # Compute codebook loss codebook_loss, _ = self.compute_loss(logits, audio_codes_target, audio_codes_lens_target, agent_mask_target=agent_mask if self.cfg.get("mask_user_on_loss", False) else None) loss = self.parallel_codebook_loss_scale * codebook_loss diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index e5dc238155ca..dbe46d665c1d 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -271,6 +271,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.context_audio_eos_id = get_token_index(SpecialAudioToken.AUDIO_CONTEXT_EOS) self.mask_token_id = get_token_index(SpecialAudioToken.MASK_TOKEN) self.audio_user_speaking_id = get_token_index(SpecialAudioToken.USER_SPEAKING) + self.audio_user_speaking_end_id = get_token_index(SpecialAudioToken.USER_SPEAKING_END) self.num_all_tokens_per_codebook = self.codebook_size + len(SpecialAudioToken) self.use_bpe_char_tokenizer = cfg.get('use_bpe_char_tokenizer', False) # If True, text tokens are embedded only with the char-aware subword (CAS) encoder, and the decoder token @@ -695,26 +696,38 @@ def streaming_prefill_profile( text_emb[text_tokens == self.pad_id] = 0.0 # ----------------------- - # AUDIO CHANNEL: silence previous-token input + # AUDIO CHANNEL: previous-token input during profile # ----------------------- C = self.num_audio_codebooks S = self.frame_stacking_factor sil_codes = self.codec_sil_codes.to(device=device, dtype=torch.long) # (C,) + # Keep all_predictions as real silence so decoded waveform has silence during profile. sil_codes_unstacked = sil_codes.view(1, C, 1).expand(B, C, T * S).contiguous() - sil_codes_stacked, _ = self.stack_codes( - sil_codes_unstacked, - torch.full((B,), T * S, dtype=torch.long, device=device), - bos_id=self.audio_bos_id, - eos_id=self.audio_eos_id, - stacking_factor=S, - num_codebooks=C, - ) # (B, C*S, T) - - audio_emb = self.embed_audio_tokens(sil_codes_stacked) - - # Match training channel sum: text + audio silence. + + if self.cfg.get("use_user_speaking_token", False): + # Match training: during non-agent/user-speaking regions, the audio INPUT token + # is audio_user_speaking_id. + profile_audio_stacked = torch.full( + (B, C * S, T), + self.audio_user_speaking_id, + dtype=torch.long, + device=device, + ) + else: + profile_audio_stacked, _ = self.stack_codes( + sil_codes_unstacked, + torch.full((B,), T * S, dtype=torch.long, device=device), + bos_id=self.audio_bos_id, + eos_id=self.audio_eos_id, + stacking_factor=S, + num_codebooks=C, + ) # (B, C*S, T) + + audio_emb = self.embed_audio_tokens(profile_audio_stacked) + + # Match training channel sum: text + audio profile-token/silence input. combined_emb = text_emb + audio_emb # ----------------------- @@ -755,7 +768,14 @@ def streaming_prefill_profile( # Make the next normal streaming_step continue from silence, not AUDIO_BOS. state.all_predictions.append(sil_codes_unstacked) # keep silence so that in the target audio user will be silence - if self.cfg.get("use_user_speaking_token", False): + if self.cfg.get("use_user_speaking_end_token", False): + state.last_audio_codes = torch.full( + (B, C * S), + self.audio_user_speaking_end_id, + dtype=torch.long, + device=device, + ) + elif self.cfg.get("use_user_speaking_token", False): state.last_audio_codes = torch.full( (B, C * S), self.audio_user_speaking_id, @@ -763,7 +783,7 @@ def streaming_prefill_profile( device=device, ) else: - state.last_audio_codes = sil_codes_stacked[:, :, -1].contiguous() + state.last_audio_codes = profile_audio_stacked[:, :, -1].contiguous() return state diff --git a/nemo/collections/tts/modules/magpietts_modules.py b/nemo/collections/tts/modules/magpietts_modules.py index 592f9b7b9f5f..fcffeb21e96e 100644 --- a/nemo/collections/tts/modules/magpietts_modules.py +++ b/nemo/collections/tts/modules/magpietts_modules.py @@ -97,7 +97,7 @@ class SpecialAudioToken(Enum): MASK_TOKEN = 4 # Reserve these values so that if we need to add more special tokens in the future the codebook size will remain the same USER_SPEAKING = 5 - RESERVED_2 = 6 + USER_SPEAKING_END = 6 RESERVED_3 = 7 @staticmethod From 4ab752bb0b1daf6566eb39ea4f6793708ef0020d Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 4 May 2026 14:21:52 -0700 Subject: [PATCH 032/109] Add transition tokens on loss Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 13 +++++----- nemo/collections/tts/models/easy_magpietts.py | 24 ++++++++++++++++++- .../tts/models/easy_magpietts_inference.py | 2 +- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index eaa34a8b5009..2d8f2bc892ea 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -554,7 +554,7 @@ def main(): inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) if args.user_custom_speaker_reference and args.inference_speaker_reference: - inputs["context_audio"] = speaker_wav.repeat(B, 1) + inputs["context_audio"] = speaker_wav.repeat(B, 1).detach() inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) profile_turn_frame_ranges = [] @@ -733,11 +733,12 @@ def main(): device=device, ) # add text tokens needed for profilling - delay_tokens = int(state.config.training_mode.streaming_speech_delay) - delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) - profile_tokens[:, -delay_tokens:] = turn_text[:, :delay_tokens] - turn_text = turn_text[:, delay_tokens:] - turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) + if not model.cfg.get("agent_mask_include_transition_prefix", False): + delay_tokens = int(state.config.training_mode.streaming_speech_delay) + delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) + profile_tokens[:, -delay_tokens:] = turn_text[:, :delay_tokens] + turn_text = turn_text[:, delay_tokens:] + turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) state = model.streaming_prefill_profile( state=state, diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 6eafaf31095d..6c14556e4375 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -553,6 +553,7 @@ def prepare_audio_channel_embeddings( delay: torch.Tensor, speech_eos_mask: Optional[torch.Tensor] = None, agent_mask: Optional[torch.Tensor] = None, + current_streaming_speech_delay: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Prepare audio embeddings as a channel input with delay handling. @@ -629,6 +630,7 @@ def prepare_audio_channel_embeddings( audio_codes_input = audio_codes[:, :, :-1] # (B, C, T'-1) # deal with agent mask + loss_agent_mask = None if agent_mask is not None: target_T = audio_codes_target.size(2) @@ -657,6 +659,7 @@ def prepare_audio_channel_embeddings( agent_mask = agent_mask | eos_prev1 | eos_any target_agent_mask = agent_mask & valid + loss_agent_mask = target_agent_mask if self.cfg.get("debug_decode_agent_mask", False) and self.training and self.global_step < 5: self.debug_decode_mask_regions( @@ -718,6 +721,24 @@ def prepare_audio_channel_embeddings( audio_codes_input, ) + # Note that consider the current_streaming_speech_delay tokens/user speaking tokens on the loss, + # allowing to predict them in autoregressive way + transition_prefix = int(current_streaming_speech_delay or 0) + if self.cfg.get("agent_mask_include_transition_prefix", False) and transition_prefix > 0: + agent_i = target_agent_mask.float().unsqueeze(1) + + agent_i = torch.nn.functional.pad(agent_i, (0, transition_prefix)) + loss_agent_mask = ( + torch.nn.functional.max_pool1d( + agent_i, + kernel_size=transition_prefix + 1, + stride=1, + ) + .squeeze(1) + .bool() + & valid + ) + # Embed audio tokens audio_embedded = self.embed_audio_tokens(audio_codes_input) # (B, T'-1, E) @@ -731,7 +752,7 @@ def prepare_audio_channel_embeddings( lengths=[delay, audio_codes_lens_target], ) - return audio_channel_embedding, audio_channel_lens, audio_codes_target, audio_codes_lens_target, agent_mask + return audio_channel_embedding, audio_channel_lens, audio_codes_target, audio_codes_lens_target, loss_agent_mask def slice_sequence_embeddings(self, sequence_embeddings, context_lens, target_lens): """ @@ -923,6 +944,7 @@ def process_batch( delay=audio_delay, speech_eos_mask=speech_eos_mask, agent_mask=agent_mask, + current_streaming_speech_delay=current_streaming_speech_delay ) # 6. Sum the channel embeddings element-wise diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index dbe46d665c1d..98aca9f69e5e 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -768,7 +768,7 @@ def streaming_prefill_profile( # Make the next normal streaming_step continue from silence, not AUDIO_BOS. state.all_predictions.append(sil_codes_unstacked) # keep silence so that in the target audio user will be silence - if self.cfg.get("use_user_speaking_end_token", False): + if self.cfg.get("use_user_speaking_end_token", False) and not self.cfg.get("agent_mask_include_transition_prefix", False): state.last_audio_codes = torch.full( (B, C * S), self.audio_user_speaking_end_id, From 5ab5340a6b9d2cdeee5d0d8342074c3142fc72b5 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 4 May 2026 14:33:20 -0700 Subject: [PATCH 033/109] Add extra parameter type Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 6c14556e4375..a77c726d51a1 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -554,7 +554,7 @@ def prepare_audio_channel_embeddings( speech_eos_mask: Optional[torch.Tensor] = None, agent_mask: Optional[torch.Tensor] = None, current_streaming_speech_delay: Optional[int] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Prepare audio embeddings as a channel input with delay handling. @@ -575,6 +575,7 @@ def prepare_audio_channel_embeddings( - audio_channel_lens: Total length of audio channel for each batch item (B,) - audio_codes_target: Target audio codes for loss computation (B, C, T'-1) - audio_codes_lens_target: Length of target audio codes (B,) + - loss_agent_mask: Optional mask used for loss masking; None when no agent_mask is provided. """ batch_size = audio_codes.size(0) device = audio_codes.device From 807e3a4d2d6b80f664867e176975fa581c2c6d44 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 5 May 2026 05:52:04 -0700 Subject: [PATCH 034/109] Fix augmentation Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts.py | 74 +++++++++++++++---- 1 file changed, 60 insertions(+), 14 deletions(-) diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index a77c726d51a1..26168d455a1c 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -606,10 +606,10 @@ def prepare_audio_channel_embeddings( ) if speech_eos_mask is not None: - # 1. Shift the mask +1 to the right to account for the token. + # Shift +1 for BOS alignment and +1 more so EOS is injected after the marked frame. B_mask, T_mask = speech_eos_mask.shape - shifted_mask = torch.zeros((B_mask, T_mask + 1), dtype=torch.bool, device=device) - shifted_mask[:, 1:] = speech_eos_mask + shifted_mask = torch.zeros((B_mask, T_mask + 2), dtype=torch.bool, device=device) + shifted_mask[:, 2:] = speech_eos_mask # 2. Find the minimum overlapping time dimension t_mask = shifted_mask.size(1) @@ -884,6 +884,7 @@ def process_batch( # remove the interruption token for all task, expect for interruption if not task or "interruption" not in str(task[0]): text[speech_eos_mask] = self.tokenizer.pad # Clean up the text channel + # else: # ToDo: move self.interruption_token_id forward by audio_delay so that soon it saw the interruption token it is forced to stop instead of await audio_delay tokens # 3. Prepare text channel embeddings text_channel_embedding, text_channel_lens = self.prepare_text_channel_embeddings( @@ -1040,7 +1041,6 @@ def process_batch( codebook_loss, _ = self.compute_loss(logits, audio_codes_target, audio_codes_lens_target, agent_mask_target=agent_mask if self.cfg.get("mask_user_on_loss", False) else None) loss = self.parallel_codebook_loss_scale * codebook_loss - # Compute local transformer loss if applicable local_transformer_loss = None local_transformer_logits = None @@ -1118,7 +1118,7 @@ def training_step(self, batch, batch_idx): audio_lens = batch['audio_lens'] audio_codes, audio_codes_lens = self._codec_helper.audio_to_codes(audio, audio_lens) - + # augment tts data to looks more like multiturn data by adding pad on the begining and emulating user speaking. if self.cfg.get("use_multiturn_dataset", False) and "tts" in batch['task']: prob = self.cfg.get("add_tts_sil_begining_prob", 0.0) if prob > 0 and torch.rand(1).item() < prob: @@ -1151,26 +1151,72 @@ def training_step(self, batch, batch_idx): silence_pad = self.codec_sil_codes_unconverted.view(1, C, 1).expand(B, C, T_audio) audio_codes = torch.where(valid_mask_a_exp, gathered_audio, silence_pad) - audio_codes_lens = audio_codes_lens + pad_lens - - # --- Vectorized Text Shift (USING SCALED LENS) --- + audio_codes_lens = torch.clamp(audio_codes_lens + pad_lens, max=T_audio) + + # Vectorized Text Shift old_text = batch['text'] text_lens = batch['text_lens'] - - new_text_lens = text_lens + text_pad_lens - new_T_text = max(new_text_lens.max().item(), old_text.size(1)) - + + new_T_text = old_text.size(1) + new_text_lens = torch.clamp(text_lens + text_pad_lens, max=new_T_text) idx_t = torch.arange(new_T_text, device=device).unsqueeze(0) src_idx_t = idx_t - text_pad_lens.unsqueeze(1) - + valid_mask_t = (src_idx_t >= 0) & (src_idx_t < text_lens.unsqueeze(1)) safe_src_idx_t = src_idx_t.clamp(min=0, max=old_text.size(1) - 1) - gathered_text = torch.gather(old_text, 1, safe_src_idx_t) batch['text'] = torch.where(valid_mask_t, gathered_text, self.pad_id) batch['text_lens'] = new_text_lens + # Vectorized Phoneme Shift + if ( + self.phoneme_tokenizer is not None + and batch.get("phoneme_tokens") is not None + and batch.get("phoneme_tokens_lens") is not None + ): + old_phonemes = batch["phoneme_tokens"] + phoneme_lens = batch["phoneme_tokens_lens"] + + new_T_phoneme = old_phonemes.size(1) + new_phoneme_lens = torch.clamp(phoneme_lens + text_pad_lens, max=new_T_phoneme) + + idx_p = torch.arange(new_T_phoneme, device=device).unsqueeze(0) + src_idx_p = idx_p - text_pad_lens.unsqueeze(1) + + valid_mask_p = (src_idx_p >= 0) & (src_idx_p < phoneme_lens.unsqueeze(1)) + safe_src_idx_p = src_idx_p.clamp(min=0, max=old_phonemes.size(1) - 1) + + gathered_phonemes = torch.gather(old_phonemes, 1, safe_src_idx_p) + + phoneme_pad_id = getattr(self.phoneme_tokenizer, "pad", -1) + batch["phoneme_tokens"] = torch.where( + valid_mask_p, + gathered_phonemes, + torch.full_like(gathered_phonemes, phoneme_pad_id), + ) + batch["phoneme_tokens_lens"] = new_phoneme_lens + + # change batch["agent_mask"] to consider this augmentation (in practice adding zeros/False where we are adding silence ) + if self.cfg.get("use_multiturn_dataset", False) and "agent_mask" in batch: + old_agent_mask = batch["agent_mask"].bool() + T_mask = old_agent_mask.size(1) + + idx_m = torch.arange(T_mask, device=device).unsqueeze(0) + src_idx_m = idx_m - text_pad_lens.unsqueeze(1) + + valid_mask_m = (src_idx_m >= 0) & (src_idx_m < old_agent_mask.size(1)) + safe_src_idx_m = src_idx_m.clamp(min=0, max=old_agent_mask.size(1) - 1) + + gathered_agent_mask = torch.gather(old_agent_mask, 1, safe_src_idx_m) + + # New prepended silence/user region should be non-agent. + batch["agent_mask"] = torch.where( + valid_mask_m, + gathered_agent_mask, + torch.zeros_like(gathered_agent_mask), + ) + batch_output = self.process_batch( text=batch['text'], text_lens=batch['text_lens'], From 0b11e36e9d977c254e2fb332fa5fb2ad0a2611af Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 5 May 2026 12:39:39 -0700 Subject: [PATCH 035/109] Add min_number_of_turns and max_gap_duration_collapse_turns Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 57 ++++++++++++++++++- nemo/collections/tts/models/easy_magpietts.py | 27 ++++++++- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 8e384e3a8117..ae039d3f2d3d 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -1038,7 +1038,8 @@ def read_s2s_duplex_overlap_as_s2s_duplex(config) -> Tuple[CutSet, bool]: move_agent_text_back_by = config.get("move_agent_text_back_by", 0) filter_samples_starting_with_agent = config.get("filter_samples_starting_with_agent", False) agent_roles = config.get("agent_roles", ["agent", "Assistant", "assistant"]) - + min_number_of_turns = int(config.get("min_number_of_turns", 0)) + max_gap_duration_collapse_turns = config.get("max_gap_duration_collapse_turns", None) cuts, is_tarred = read_cutset_from_config(config) def filter_cuts_starting_with_agent_fn(cuts: CutSet, agent_roles: Tuple[str, ...]) -> CutSet: @@ -1052,6 +1053,49 @@ def _filter_fn(cut: Cut) -> bool: return cuts.filter(_filter_fn) + def filter_min_agent_turns_fn(cuts: CutSet, min_number_of_turns: int, agent_roles: Tuple[str, ...]) -> CutSet: + """Keep cuts with at least `min_number_of_turns` agent turns.""" + if min_number_of_turns <= 0: + return cuts + + def _filter_fn(cut: Cut) -> bool: + num_agent_turns = sum(s.speaker in agent_roles for s in cut.supervisions) + if not num_agent_turns >= min_number_of_turns: + logging.info( + f"[Parser] Filtering cut={cut.id}: " + f"agent_turns={num_agent_turns} < min_number_of_turns={min_number_of_turns}" + ) + return num_agent_turns >= min_number_of_turns + + return cuts.filter(_filter_fn) + + def collapse_adjacent_same_speaker(supervisions, max_gap): + if max_gap is None or max_gap <= 0: + return supervisions + + supervisions = sorted(supervisions, key=lambda s: (s.start, s.end)) + collapsed = [] + + for s in supervisions: + if not collapsed: + collapsed.append(s) + continue + + prev = collapsed[-1] + gap = s.start - (prev.start + prev.duration) + + if s.speaker == prev.speaker: + if gap < max_gap: + # MERGE + new_end = max(prev.start + prev.duration, s.start + s.duration) + prev.duration = new_end - prev.start + prev.text = f"{prev.text} {s.text}".strip() + else: + collapsed.append(s) + else: + collapsed.append(s) + return collapsed + def convert_overlap_cut_fn(cut: Cut) -> Cut: """Convert agent/user overlapping segments into sequential SupervisionSegments.""" agent_segments = [ @@ -1080,11 +1124,22 @@ def convert_overlap_cut_fn(cut: Cut) -> Cut: cut.supervisions = sorted(agent_segments + user_segments, key=lambda s: s.start) cut.task = "s2s_duplex_overlap_as_s2s_duplex" + if max_gap_duration_collapse_turns is not None: + cut.supervisions = collapse_adjacent_same_speaker( + cut.supervisions, + max_gap_duration_collapse_turns, + ) + return cut cuts = cuts.map(convert_overlap_cut_fn) + + # Force materialization for accurate counting if filter_samples_starting_with_agent: cuts = filter_cuts_starting_with_agent_fn(cuts, tuple(agent_roles)) + + if min_number_of_turns > 0: + cuts = filter_min_agent_turns_fn(cuts, min_number_of_turns, tuple(agent_roles)) return cuts, is_tarred diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 26168d455a1c..c7be0e1236a7 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -606,6 +606,7 @@ def prepare_audio_channel_embeddings( ) if speech_eos_mask is not None: + audio_codes_before_speech_eos = audio_codes.clone() # Shift +1 for BOS alignment and +1 more so EOS is injected after the marked frame. B_mask, T_mask = speech_eos_mask.shape shifted_mask = torch.zeros((B_mask, T_mask + 2), dtype=torch.bool, device=device) @@ -630,6 +631,31 @@ def prepare_audio_channel_embeddings( audio_codes_target = audio_codes[:, :, 1:] # (B, C, T'-1) audio_codes_input = audio_codes[:, :, :-1] # (B, C, T'-1) + # Drop some EOS frames from audio input so the model learns recovery when inference misses EOS. + # sample_prob keeps some samples untouched, so the model still learns the normal EOS-input behavior. + if speech_eos_mask is not None and self.training: + drop_eos_sample_prob = float(self.cfg.get("drop_eos_from_audio_input_sample_prob", 0.0)) + drop_eos_frame_prob = float(self.cfg.get("drop_eos_from_audio_input_frame_prob", 0.5)) + + if drop_eos_sample_prob > 0.0 and drop_eos_frame_prob > 0.0: + eos_frame_mask = (audio_codes_input == self.audio_eos_id).any(dim=1) # [B, T] + + sample_drop_mask = torch.rand(batch_size, device=device) < drop_eos_sample_prob # [B] + + frame_drop_mask = ( + eos_frame_mask + & sample_drop_mask.unsqueeze(1) + & (torch.rand_like(eos_frame_mask.float()) < drop_eos_frame_prob) + ) # [B, T] + + audio_codes_input_backup = audio_codes_before_speech_eos[:, :, :-1] + + audio_codes_input = torch.where( + frame_drop_mask.unsqueeze(1), + audio_codes_input_backup, + audio_codes_input, + ) + # deal with agent mask loss_agent_mask = None if agent_mask is not None: @@ -671,7 +697,6 @@ def prepare_audio_channel_embeddings( prefix=f"batch_{self.global_rank}_{self.global_step}", ) - # Replace user/non-agent regions with a learned token. # Important: audio_codes_input predicts audio_codes_target, so input mask must be shifted. if self.cfg.get("use_user_speaking_token", False): From 5bd03d1f688f6ee50637264a57d253aac212477c Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 7 May 2026 04:22:51 -0700 Subject: [PATCH 036/109] Add phoneme multiturn inference support and update silence augmentation Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 79 +++++++++++++++++-- nemo/collections/tts/models/easy_magpietts.py | 12 +-- .../tts/models/easy_magpietts_inference.py | 2 +- 3 files changed, 80 insertions(+), 13 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index 2d8f2bc892ea..12439dbd6e83 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -734,7 +734,10 @@ def main(): ) # add text tokens needed for profilling if not model.cfg.get("agent_mask_include_transition_prefix", False): - delay_tokens = int(state.config.training_mode.streaming_speech_delay) + if model.phoneme_tokenizer is not None: + delay_tokens = int(state.config.training_mode.streaming_phonemes_delay) + else: + delay_tokens = int(state.config.training_mode.streaming_speech_delay) delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) profile_tokens[:, -delay_tokens:] = turn_text[:, :delay_tokens] turn_text = turn_text[:, delay_tokens:] @@ -750,8 +753,13 @@ def main(): f"[profile_multiturn] turn={t} prefilled {profile_T} steps " f"({profile_seconds:.2f}s)" ) - - turn_start_frame = sum(p.size(-1) for p in state.all_predictions) + # turn_start_frame = sum(p.size(-1) for p in state.all_predictions) + # count the two extra delays to remove hallucinations on inference time + if model.phoneme_tokenizer: + turn_start_frame = sum(p.size(-1) for p in state.all_predictions) + (int(state.config.training_mode.streaming_speech_delay)-int(state.config.training_mode.streaming_phonemes_delay)) + else: + turn_start_frame = sum(p.size(-1) for p in state.all_predictions) + if t == 0: state.audio_prediction_start_idx.fill_(turn_start_frame) profile_decode_start_frame = turn_start_frame @@ -778,12 +786,68 @@ def main(): torch.full_like(current_tokens, model.eos_id), current_tokens, ) - + # continue the profilling step until all the delays tokens are consumed if phoneme channel is used + if model.phoneme_tokenizer is not None and not model.cfg.get("agent_mask_include_transition_prefix", False) and turn_steps <= (int(state.config.training_mode.streaming_speech_delay)-int(state.config.training_mode.streaming_phonemes_delay)): + C = model.num_audio_codebooks + S = model.frame_stacking_factor + T = 1 + sil_codes = model.codec_sil_codes.to(device=device, dtype=torch.long) # (C,) + sil_codes_unstacked = sil_codes.view(1, C, 1).expand(B, C, T * S).contiguous() + profile_audio_stacked, _ = model.stack_codes( + sil_codes_unstacked, + torch.full((B,), T * S, dtype=torch.long, device=device), + bos_id=model.audio_bos_id, + eos_id=model.audio_eos_id, + stacking_factor=S, + num_codebooks=C, + ) # (B, C*S, T) + + # Feed the pad on speech channel as part of the prediction while profilling until we finish the whole delay + if turn_steps < (int(state.config.training_mode.streaming_speech_delay)-int(state.config.training_mode.streaming_phonemes_delay)): + if model.cfg.get("use_user_speaking_end_token", False) and not model.cfg.get("agent_mask_include_transition_prefix", False): + state.last_audio_codes = torch.full( + state.last_audio_codes.shape, + model.audio_user_speaking_end_id, + dtype=torch.long, + device=device, + ) + elif model.cfg.get("use_user_speaking_token", False): + state.last_audio_codes = torch.full( + state.last_audio_codes.shape, + model.audio_user_speaking_id, + dtype=torch.long, + device=device, + ) + else: + state.last_audio_codes = profile_audio_stacked[:, :, -1].contiguous() + + elif turn_steps == (int(state.config.training_mode.streaming_speech_delay)-int(state.config.training_mode.streaming_phonemes_delay)): + if model.cfg.get("use_user_speaking_end_token", False) and not model.cfg.get("agent_mask_include_transition_prefix", False): + state.last_audio_codes = torch.full( + state.last_audio_codes.shape, + model.audio_user_speaking_end_id, + dtype=torch.long, + device=device, + ) + elif model.cfg.get("use_user_speaking_token", False): + state.last_audio_codes = torch.full( + state.last_audio_codes.shape, + model.audio_user_speaking_id, + dtype=torch.long, + device=device, + ) + else: + state.last_audio_codes = profile_audio_stacked[:, :, -1].contiguous() + + # ToDo: we need to feed the user audio embedding to streaming step as well state, audio_codes, _ = model.streaming_step( state=state, text_tokens=current_tokens, use_inference_mode=True, ) + # Replace predicted delay tokens with silence + if model.phoneme_tokenizer is not None not model.cfg.get("agent_mask_include_transition_prefix", False) and turn_steps <= (int(state.config.training_mode.streaming_speech_delay)-int(state.config.training_mode.streaming_phonemes_delay)): + state.all_predictions[-1] = codec_sil_codes.view(1, -1, 1).expand_as(state.all_predictions[-1]) if audio_codes is not None and not saw_audio: saw_audio = True @@ -817,12 +881,15 @@ def main(): speaking_id = getattr(model, "audio_user_speaking_id", -1) speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) - sil_injection = codec_sil_codes.view(1, -1, 1) + context_audio_bos_id = getattr(model, "context_audio_bos_id", -1) + context_audio_eos_id = getattr(model, "context_audio_eos_id", -1) + mask_token_id = getattr(model, "mask_token_id", -1) + sil_injection = codec_sil_codes.view(1, -1, 1) for step_idx in range(len(state.all_predictions)): pred = state.all_predictions[step_idx] # Check if any codebook in the frame has any special token - mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) + mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) | (pred == context_audio_bos_id) | (pred == context_audio_eos_id) | (pred == mask_token_id) frame_mask = mask.any(dim=1, keepdim=True) if frame_mask.any(): diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index c7be0e1236a7..5037fd72ba58 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -1150,8 +1150,8 @@ def training_step(self, batch, batch_idx): audio_codes_lens_max = audio_codes_lens.max() # 1. Calculate the raw shift (with the -1 safety buffer) - raw_pad_lens = torch.clamp(audio_codes_lens_max - audio_codes_lens - 1, min=0) - + raw_pad_lens = torch.clamp(audio_codes_lens_max - audio_codes_lens - 4, min=0) + # 2. Round DOWN to the nearest multiple of the stacking factor pad_lens = (raw_pad_lens // self.frame_stacking_factor) * self.frame_stacking_factor @@ -1165,13 +1165,13 @@ def training_step(self, batch, batch_idx): # --- Vectorized Audio Shift --- idx_a = torch.arange(T_audio, device=device).unsqueeze(0) src_idx_a = idx_a - pad_lens.unsqueeze(1) - + valid_mask_a = (src_idx_a >= 0) & (src_idx_a < audio_codes_lens.unsqueeze(1)) safe_src_idx_a = src_idx_a.clamp(min=0, max=T_audio - 1) - + safe_src_idx_a_exp = safe_src_idx_a.unsqueeze(1).expand(-1, C, -1) valid_mask_a_exp = valid_mask_a.unsqueeze(1).expand(-1, C, -1) - + gathered_audio = torch.gather(audio_codes, 2, safe_src_idx_a_exp) silence_pad = self.codec_sil_codes_unconverted.view(1, C, 1).expand(B, C, T_audio) @@ -1190,7 +1190,7 @@ def training_step(self, batch, batch_idx): valid_mask_t = (src_idx_t >= 0) & (src_idx_t < text_lens.unsqueeze(1)) safe_src_idx_t = src_idx_t.clamp(min=0, max=old_text.size(1) - 1) gathered_text = torch.gather(old_text, 1, safe_src_idx_t) - + batch['text'] = torch.where(valid_mask_t, gathered_text, self.pad_id) batch['text_lens'] = new_text_lens diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 98aca9f69e5e..55d79f40e32a 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -768,7 +768,7 @@ def streaming_prefill_profile( # Make the next normal streaming_step continue from silence, not AUDIO_BOS. state.all_predictions.append(sil_codes_unstacked) # keep silence so that in the target audio user will be silence - if self.cfg.get("use_user_speaking_end_token", False) and not self.cfg.get("agent_mask_include_transition_prefix", False): + if self.cfg.get("use_user_speaking_end_token", False) and not self.cfg.get("agent_mask_include_transition_prefix", False) and self.phoneme_tokenizer is None: state.last_audio_codes = torch.full( (B, C * S), self.audio_user_speaking_end_id, From 70cf55947f54b6ff822c9f2a83413ec16f283f4c Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 8 May 2026 09:05:30 -0700 Subject: [PATCH 037/109] Fix sil augmentation on formatter Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 414 ++++++------------ ...text_to_speech_dataset_lhotse_multiturn.py | 2 +- nemo/collections/tts/models/easy_magpietts.py | 1 + 3 files changed, 125 insertions(+), 292 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index ae039d3f2d3d..751cc87fee7f 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -816,102 +816,93 @@ def cut_to_conversation( custom=cut.custom, ) +@data_type_parser(["lhotse_magpietts_data_as_continuation"]) +def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: + """ + Convert MagpieTTS dataset cuts into the Duplex S2S format, with optional + `context_audio` that can be used as a speaker reference. -class FilterCER: - def __init__(self, max_cer: float): - self.max_cer = max_cer - - def __call__(self, cut: Cut) -> bool: - return ( - len(cut.supervisions) == 0 - or not cut.supervisions[0].has_custom("cer") - or cut.supervisions[0].cer <= self.max_cer - ) - -class FilterValFlag: - def __init__(self, keep_flag: str): - self.keep_flag = keep_flag - - def __call__(self, cut: Cut) -> bool: - return not cut.has_custom("validation_status") or cut.validation_status == self.keep_flag - -class FilterSecs: - def __init__(self, min_sim: float): - self.min_sim = min_sim - - def __call__(self, cut: Cut) -> bool: - return ( - len(cut.supervisions) == 0 - or not cut.supervisions[0].has_custom("context_speaker_similarity") - or cut.supervisions[0].context_speaker_similarity >= self.min_sim - ) - -class FilterTargetSpeaker: - def __init__(self, target_speaker: str): - self.target_speaker = target_speaker - - def __call__(self, cut: Cut) -> bool: - return len(cut.supervisions) == 0 or self.target_speaker is None or self.target_speaker in cut.supervisions[0].speaker + Args: + config: Dictionary containing parser options: + - add_extra_end_silence (bool): Whether to add extra silence at the end. + - extra_end_silence_range (List[float]): Range of extra silence duration. + - max_cer (float): Maximum allowed character error rate. + - min_context_speaker_similarity (float): Minimum similarity score. + - target_speaker (str, optional): Target speaker filter. + - sample_rate (int): Audio sample rate for resampling. -def _create_recording_from_array(samples: np.ndarray, sampling_rate: int, recording_id: str) -> Recording: - with io.BytesIO() as buffer: - sf.write(buffer, samples.T, samplerate=sampling_rate, format='WAV') - buffer.seek(0) - return Recording.from_bytes(buffer.read(), recording_id=recording_id) + Returns: + Tuple[CutSet, bool]: Converted cuts and a flag indicating if data was tarred. + """ + cuts, is_tarred = read_cutset_from_config(config) -def _prepend_silence_monocut( - cut: MonoCut, - sil_duration: float, - sample_rate: int, - recording_id: str, - cut_id: str, -) -> MonoCut: - """Helper to pad silence at the beginning of a monocut.""" - audio = cut.load_audio() # (C, N) - n_pad = int(round(sil_duration * sample_rate)) - if n_pad <= 0: - return cut + add_extra_end_sil = config.get("add_extra_end_silence", False) + extra_end_silence_range = config.get("extra_end_silence_range", [0.5, 6.0]) + add_extra_begin_sil = config.get("add_extra_begin_sil", False) + extra_begin_silence_range = config.get("extra_begin_silence_range", [0.5, 6.0]) + sample_rate = config.get("sample_rate", 22050) - pad = np.zeros((audio.shape[0], n_pad), dtype=audio.dtype) - audio2 = np.concatenate([pad, audio], axis=1) + max_cer = config.get("max_cer", 0.03) + min_context_speaker_similarity = config.get("min_context_speaker_similarity", 0.6) + target_speaker = config.get("target_speaker", None) + keep_flag = "pass" - rec = _create_recording_from_array(audio2, sample_rate, recording_id=recording_id) - return MonoCut( - id=cut_id, - start=0.0, - duration=audio2.shape[1] / sample_rate, - channel=0, - recording=rec, - supervisions=[], - ).move_to_memory(audio_format="wav") - -class ConvertCutFn: - def __init__( - self, - sample_rate: int, - add_extra_end_sil: bool, - extra_end_silence_range: list, - add_extra_begin_sil: bool, - extra_begin_silence_range: list - ): - self.sample_rate = sample_rate - self.add_extra_end_sil = add_extra_end_sil - self.extra_end_silence_range = extra_end_silence_range - self.add_extra_begin_sil = add_extra_begin_sil - self.extra_begin_silence_range = extra_begin_silence_range - - def __call__(self, cut: Cut) -> Cut: + def create_recording_from_array(samples: np.ndarray, sampling_rate: int, recording_id: str) -> Recording: + """Convert a numpy array into a Lhotse Recording object.""" + with io.BytesIO() as buffer: + sf.write(buffer, samples.T, samplerate=sampling_rate, format='WAV') + buffer.seek(0) + return Recording.from_bytes(buffer.read(), recording_id=recording_id) + + def materialize_to_monocut(cut_like: Cut, cut_id: str, sample_rate: int) -> MonoCut: + audio = cut_like.load_audio() # renders mix -> (C, N) + rec = create_recording_from_array(audio, sample_rate, recording_id=f"{cut_id}_rec") + return MonoCut( + id=cut_id, + start=0.0, + duration=cut_like.duration, + channel=0, + recording=rec, + supervisions=[], + ).move_to_memory(audio_format="wav") + + def prepend_silence_monocut( + cut: MonoCut, + sil_duration: float, + sample_rate: int, + recording_id: str, + cut_id: str, + ) -> MonoCut: + audio = cut.load_audio() # (C, N) + n_pad = int(round(sil_duration * sample_rate)) + if n_pad <= 0: + return cut + + pad = np.zeros((audio.shape[0], n_pad), dtype=audio.dtype) + audio2 = np.concatenate([pad, audio], axis=1) + + rec = create_recording_from_array(audio2, sample_rate, recording_id=recording_id) + return MonoCut( + id=cut_id, + start=0.0, + duration=audio2.shape[1] / sample_rate, + channel=0, + recording=rec, + supervisions=[], + ).move_to_memory(audio_format="wav") + + def convert_cut_fn(cut: Cut) -> Cut: + """Convert a single cut into the continuation format.""" orig_agent_sup = fastcopy(cut.supervisions[0]) target_audio_orig_dur = cut.target_audio.duration - cut.target_audio = cut.target_audio.resample(self.sample_rate) - - # --- SAFELY CHECK FOR CONTEXT AUDIO --- + # Resample audios + cut.target_audio = cut.target_audio.resample(sample_rate) if cut.has_custom("context_audio"): - cut.context_audio = cut.context_audio.resample(self.sample_rate) - + cut.context_audio = cut.context_audio.resample(sample_rate) total_duration = cut.target_audio.duration + # Prepare MonoCuts cut_target = MonoCut( id=f"{cut.id}_target", start=0.0, @@ -921,8 +912,8 @@ def __call__(self, cut: Cut) -> Cut: supervisions=[], ) - zero_audio = np.zeros((1, int(total_duration * self.sample_rate)), dtype=np.float32) - source_recording = _create_recording_from_array(zero_audio, self.sample_rate, recording_id=f"{cut.id}_source") + zero_audio = np.zeros((1, int(total_duration * sample_rate)), dtype=np.float32) + source_recording = create_recording_from_array(zero_audio, sample_rate, recording_id=f"{cut.id}_source") cut_source = MonoCut( id=f"{cut.id}_source", @@ -931,22 +922,19 @@ def __call__(self, cut: Cut) -> Cut: channel=0, recording=source_recording, supervisions=[], - custom=deepcopy(cut.custom) if cut.custom is not None else None, ) + # Save to memory cut_source = cut_source.move_to_memory(audio_format='wav') cut_target = cut_target.move_to_memory(audio_format='wav') + # Create user and agent supervisions user_sup = fastcopy(orig_agent_sup, start=0.0, duration=0.08, speaker="user", text="dummy text") agent_sup = fastcopy(orig_agent_sup, start=0.0, duration=target_audio_orig_dur - 0.08, speaker="agent") - if user_sup.custom is not None and "ipa" in user_sup.custom: - user_sup.custom = deepcopy(user_sup.custom) - user_sup.custom["ipa"] = "" - # Optionally add extra silence on the end - if self.add_extra_end_sil: - sil_duration = random.uniform(*self.extra_end_silence_range) + if add_extra_end_sil: + sil_duration = random.uniform(*extra_end_silence_range) cut_target = cut_target.pad(duration=total_duration + sil_duration, direction="right") cut_source = cut_source.pad(duration=total_duration + sil_duration, direction="right") cut_source = cut_source.to_mono().move_to_memory(audio_format='wav') @@ -955,14 +943,18 @@ def __call__(self, cut: Cut) -> Cut: user_sup.duration += sil_duration # Optionally add extra silence on the start - if self.add_extra_begin_sil: - sil_duration = random.uniform(*self.extra_begin_silence_range) - cut_target = _prepend_silence_monocut( - cut_target, sil_duration, self.sample_rate, + if add_extra_begin_sil: + sil_duration = random.uniform(*extra_begin_silence_range) + # Pad both streams on the left (adds zeros at the start) + prev_target_dur = cut_target.duration + prev_source_dur = cut_source.duration + # prepend zeros explicitly + cut_target = prepend_silence_monocut( + cut_target, sil_duration, sample_rate, recording_id=f"{cut.id}_target_pre", cut_id=f"{cut.id}_target" ) - cut_source = _prepend_silence_monocut( - cut_source, sil_duration, self.sample_rate, + cut_source = prepend_silence_monocut( + cut_source, sil_duration, sample_rate, recording_id=f"{cut.id}_source_pre", cut_id=f"{cut.id}_source" ) @@ -970,211 +962,52 @@ def __call__(self, cut: Cut) -> Cut: user_sup.start += sil_duration agent_sup.start += sil_duration + # Assemble final cut cut_source.supervisions = [user_sup, agent_sup] cut_source.target_audio = cut_target.recording cut_source.duration = cut_target.duration - if cut.has_custom("context_audio"): cut_source.context_audio = cut.context_audio - - cut_source.task = "lhotse_magpietts_data_as_continuation" + if cut.has_custom("context_codes"): + cut_source.context_codes = cut.context_codes + if cut.has_custom("target_codes"): + cut_source.target_codes = cut.target_codes + if cut.has_custom("lang"): + cut_source.lang = cut_source.lang + if cut.has_custom("ipa"): + cut_source.ipa = cut_source.ipa + cut_source.formatter = "lhotse_magpietts_data_as_continuation" return cut_source + # Filters + def filter_cer_fn(cut: Cut) -> bool: + return ( + len(cut.supervisions) == 0 + or not cut.supervisions[0].has_custom("cer") + or cut.supervisions[0].cer <= max_cer + ) -@data_type_parser(["s2s_duplex_overlap_as_s2s_duplex"]) -def read_s2s_duplex_overlap_as_s2s_duplex(config) -> Tuple[CutSet, bool]: - """ - Convert a CutSet with overlapping agent/user segments into a standard S2S duplex format. - - Use Case: - This parser is designed for conversational data where agent and user speech can overlap - in time (e.g., natural turn-taking with interruptions or backchanneling). The input - format stores agent and user segments separately as `agent_segments` and `user_segments` - attributes on each cut. This function converts them into a unified timeline of sequential - SupervisionSegments, which is the standard format expected by DuplexS2S models. - - Expected Input Data Format: - Each cut should have: - - cut.agent_segments: List[Dict] with keys: - - "start" (float): Start time in seconds - - "end" (float): End time in seconds - - "text" (str): Agent's transcription - - cut.user_segments: List[Dict] with keys: - - "start" (float): Start time in seconds - - "end" (float): End time in seconds - - "text" (str): User's transcription - - Example: - Input cut with overlapping segments: - cut.agent_segments = [ - {"start": 0.5, "end": 2.0, "text": "Hello, how can I help?"}, - {"start": 3.0, "end": 4.5, "text": "Sure, I can do that."} - ] - cut.user_segments = [ - {"start": 1.8, "end": 3.2, "text": "I need assistance"}, - {"start": 4.0, "end": 5.5, "text": "Thank you"} - ] - - Output cut.supervisions (sorted by start time): - [ - SupervisionSegment(start=0.5, duration=1.5, text="Hello, how can I help?", speaker="agent"), - SupervisionSegment(start=1.8, duration=1.4, text="I need assistance", speaker="user"), - SupervisionSegment(start=3.0, duration=1.5, text="Sure, I can do that.", speaker="agent"), - SupervisionSegment(start=4.0, duration=1.5, "Thank you", speaker="user") - ] - - Args: - config: Dictionary containing parser options: - - move_agent_text_back_by (float): Time offset to shift agent text back (default: 0). - Useful for aligning agent text with earlier audio timing. - - filter_samples_starting_with_agent (bool): Whether to remove samples starting with agent (default: False). - When True, only keeps samples where the first speaker is a user. - - agent_roles (List[str]): Roles considered as agent (default: ["agent", "Assistant", "assistant"]). - - Returns: - Tuple[CutSet, bool]: Converted cuts with unified supervisions, and a flag indicating if the data was tarred. - """ - move_agent_text_back_by = config.get("move_agent_text_back_by", 0) - filter_samples_starting_with_agent = config.get("filter_samples_starting_with_agent", False) - agent_roles = config.get("agent_roles", ["agent", "Assistant", "assistant"]) - min_number_of_turns = int(config.get("min_number_of_turns", 0)) - max_gap_duration_collapse_turns = config.get("max_gap_duration_collapse_turns", None) - cuts, is_tarred = read_cutset_from_config(config) - - def filter_cuts_starting_with_agent_fn(cuts: CutSet, agent_roles: Tuple[str, ...]) -> CutSet: - """Remove cuts where the first supervision belongs to an agent role.""" - - def _filter_fn(cut: Cut) -> bool: - if not cut.supervisions: - return False - cut.supervisions = sorted(cut.supervisions, key=lambda s: s.start) - return cut.supervisions[0].speaker not in agent_roles - - return cuts.filter(_filter_fn) - - def filter_min_agent_turns_fn(cuts: CutSet, min_number_of_turns: int, agent_roles: Tuple[str, ...]) -> CutSet: - """Keep cuts with at least `min_number_of_turns` agent turns.""" - if min_number_of_turns <= 0: - return cuts - - def _filter_fn(cut: Cut) -> bool: - num_agent_turns = sum(s.speaker in agent_roles for s in cut.supervisions) - if not num_agent_turns >= min_number_of_turns: - logging.info( - f"[Parser] Filtering cut={cut.id}: " - f"agent_turns={num_agent_turns} < min_number_of_turns={min_number_of_turns}" - ) - return num_agent_turns >= min_number_of_turns - - return cuts.filter(_filter_fn) - - def collapse_adjacent_same_speaker(supervisions, max_gap): - if max_gap is None or max_gap <= 0: - return supervisions - - supervisions = sorted(supervisions, key=lambda s: (s.start, s.end)) - collapsed = [] - - for s in supervisions: - if not collapsed: - collapsed.append(s) - continue - - prev = collapsed[-1] - gap = s.start - (prev.start + prev.duration) - - if s.speaker == prev.speaker: - if gap < max_gap: - # MERGE - new_end = max(prev.start + prev.duration, s.start + s.duration) - prev.duration = new_end - prev.start - prev.text = f"{prev.text} {s.text}".strip() - else: - collapsed.append(s) - else: - collapsed.append(s) - return collapsed - - def convert_overlap_cut_fn(cut: Cut) -> Cut: - """Convert agent/user overlapping segments into sequential SupervisionSegments.""" - agent_segments = [ - SupervisionSegment( - id=cut.id, - recording_id=cut.id, - start=seg["start"] - move_agent_text_back_by, - duration=seg["end"] - seg["start"] + move_agent_text_back_by, - text=seg["text"], - speaker="agent", - ) - for seg in cut.agent_segments - ] - - user_segments = [ - SupervisionSegment( - id=cut.id, - recording_id=cut.id, - start=seg["start"], - duration=seg["end"] - seg["start"], - text=seg["text"], - speaker="user", - ) - for seg in cut.user_segments - ] - - cut.supervisions = sorted(agent_segments + user_segments, key=lambda s: s.start) - cut.task = "s2s_duplex_overlap_as_s2s_duplex" - if max_gap_duration_collapse_turns is not None: - cut.supervisions = collapse_adjacent_same_speaker( - cut.supervisions, - max_gap_duration_collapse_turns, - ) - - return cut - - cuts = cuts.map(convert_overlap_cut_fn) - - # Force materialization for accurate counting - if filter_samples_starting_with_agent: - cuts = filter_cuts_starting_with_agent_fn(cuts, tuple(agent_roles)) - - if min_number_of_turns > 0: - cuts = filter_min_agent_turns_fn(cuts, min_number_of_turns, tuple(agent_roles)) - - return cuts, is_tarred - - -@data_type_parser(["lhotse_magpietts_data_as_continuation"]) -def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: - cuts, is_tarred = read_cutset_from_config(config) + def filter_val_flag_fn(cut: Cut) -> bool: + return not cut.has_custom("validation_status") or cut.validation_status == keep_flag - add_extra_end_sil = config.get("add_extra_end_silence", False) - extra_end_silence_range = config.get("extra_end_silence_range", [0.5, 6.0]) - add_extra_begin_sil = config.get("add_extra_begin_sil", False) - extra_begin_silence_range = config.get("extra_begin_silence_range", [0.5, 6.0]) - sample_rate = config.get("sample_rate", 22050) + def filter_secs_fn(cut: Cut) -> bool: + return ( + len(cut.supervisions) == 0 + or not cut.supervisions[0].has_custom("context_speaker_similarity") + or cut.supervisions[0].context_speaker_similarity >= min_context_speaker_similarity + ) - max_cer = config.get("max_cer", 0.03) - min_context_speaker_similarity = config.get("min_context_speaker_similarity", 0.6) - target_speaker = config.get("target_speaker", None) - keep_flag = "pass" + def filter_target_speaker_fn(cut: Cut) -> bool: + return len(cut.supervisions) == 0 or target_speaker is None or target_speaker in cut.supervisions[0].speaker - # Use the globally defined classes + # Apply filters cuts = ( - cuts.filter(FilterCER(max_cer)) - .filter(FilterValFlag(keep_flag)) - .filter(FilterSecs(min_context_speaker_similarity)) - .filter(FilterTargetSpeaker(target_speaker)) + cuts.filter(filter_cer_fn).filter(filter_val_flag_fn).filter(filter_secs_fn).filter(filter_target_speaker_fn) ) - # Pass the beginning silence configs to the updated ConvertCutFn - cuts = cuts.map(ConvertCutFn( - sample_rate=sample_rate, - add_extra_end_sil=add_extra_end_sil, - extra_end_silence_range=extra_end_silence_range, - add_extra_begin_sil=add_extra_begin_sil, - extra_begin_silence_range=extra_begin_silence_range - )) + # Convert cuts + cuts = cuts.map(convert_cut_fn) return cuts, is_tarred @@ -1183,7 +1016,6 @@ def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: def read_s2s_duplex_reverse_role(config) -> Tuple[CutSet, bool]: """ Reverse the speaker roles and swap the source/target audio streams in a Duplex S2S CutSet. - This parser takes an existing conversational dataset and inverts the perspective by swapping the "user" and "agent" supervision labels. It also swaps the primary `recording` (usually source audio) with the `target_audio` to fully simulate the diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index f88b0685cc7b..9d66f086fb5c 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -176,7 +176,7 @@ def __init__( add_text_bos: bool = False, remove_user_turns_prob: float = None, ): - super().__init__() + # super().__init__() self.sample_rate = sample_rate self.volume_norm = volume_norm diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 5037fd72ba58..7b1aee9ab818 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -1143,6 +1143,7 @@ def training_step(self, batch, batch_idx): audio_lens = batch['audio_lens'] audio_codes, audio_codes_lens = self._codec_helper.audio_to_codes(audio, audio_lens) + # augment tts data to looks more like multiturn data by adding pad on the begining and emulating user speaking. if self.cfg.get("use_multiturn_dataset", False) and "tts" in batch['task']: prob = self.cfg.get("add_tts_sil_begining_prob", 0.0) From 4bbfd695647495d50ece5e2fb4f920761eb19368 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 8 May 2026 12:26:17 -0700 Subject: [PATCH 038/109] Fix inference Signed-off-by: Edresson Casanova --- examples/tts/easy_magpietts_inference_multiturn.py | 2 +- nemo/collections/tts/models/easy_magpietts.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index 12439dbd6e83..985cf1fc66f3 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -846,7 +846,7 @@ def main(): use_inference_mode=True, ) # Replace predicted delay tokens with silence - if model.phoneme_tokenizer is not None not model.cfg.get("agent_mask_include_transition_prefix", False) and turn_steps <= (int(state.config.training_mode.streaming_speech_delay)-int(state.config.training_mode.streaming_phonemes_delay)): + if model.phoneme_tokenizer is not None and not model.cfg.get("agent_mask_include_transition_prefix", False) and turn_steps <= (int(state.config.training_mode.streaming_speech_delay)-int(state.config.training_mode.streaming_phonemes_delay)): state.all_predictions[-1] = codec_sil_codes.view(1, -1, 1).expand_as(state.all_predictions[-1]) if audio_codes is not None and not saw_audio: diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 7b1aee9ab818..3753cf901d65 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -1142,7 +1142,6 @@ def training_step(self, batch, batch_idx): audio = batch['audio'] audio_lens = batch['audio_lens'] audio_codes, audio_codes_lens = self._codec_helper.audio_to_codes(audio, audio_lens) - # augment tts data to looks more like multiturn data by adding pad on the begining and emulating user speaking. if self.cfg.get("use_multiturn_dataset", False) and "tts" in batch['task']: From cc38c611db0b500b1ace11b73f1d66221fdaddcc Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 11 May 2026 12:43:45 -0700 Subject: [PATCH 039/109] Add new inference script and fix data formatter Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 79 +- ...gpietts_inference_multiturn_new_prefill.py | 1186 +++++++++++++++++ .../tts/models/easy_magpietts_inference.py | 626 ++++++++- 3 files changed, 1817 insertions(+), 74 deletions(-) create mode 100644 examples/tts/easy_magpietts_inference_multiturn_new_prefill.py diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index 985cf1fc66f3..2d8f2bc892ea 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -734,10 +734,7 @@ def main(): ) # add text tokens needed for profilling if not model.cfg.get("agent_mask_include_transition_prefix", False): - if model.phoneme_tokenizer is not None: - delay_tokens = int(state.config.training_mode.streaming_phonemes_delay) - else: - delay_tokens = int(state.config.training_mode.streaming_speech_delay) + delay_tokens = int(state.config.training_mode.streaming_speech_delay) delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) profile_tokens[:, -delay_tokens:] = turn_text[:, :delay_tokens] turn_text = turn_text[:, delay_tokens:] @@ -753,13 +750,8 @@ def main(): f"[profile_multiturn] turn={t} prefilled {profile_T} steps " f"({profile_seconds:.2f}s)" ) - # turn_start_frame = sum(p.size(-1) for p in state.all_predictions) - # count the two extra delays to remove hallucinations on inference time - if model.phoneme_tokenizer: - turn_start_frame = sum(p.size(-1) for p in state.all_predictions) + (int(state.config.training_mode.streaming_speech_delay)-int(state.config.training_mode.streaming_phonemes_delay)) - else: - turn_start_frame = sum(p.size(-1) for p in state.all_predictions) - + + turn_start_frame = sum(p.size(-1) for p in state.all_predictions) if t == 0: state.audio_prediction_start_idx.fill_(turn_start_frame) profile_decode_start_frame = turn_start_frame @@ -786,68 +778,12 @@ def main(): torch.full_like(current_tokens, model.eos_id), current_tokens, ) - # continue the profilling step until all the delays tokens are consumed if phoneme channel is used - if model.phoneme_tokenizer is not None and not model.cfg.get("agent_mask_include_transition_prefix", False) and turn_steps <= (int(state.config.training_mode.streaming_speech_delay)-int(state.config.training_mode.streaming_phonemes_delay)): - C = model.num_audio_codebooks - S = model.frame_stacking_factor - T = 1 - sil_codes = model.codec_sil_codes.to(device=device, dtype=torch.long) # (C,) - sil_codes_unstacked = sil_codes.view(1, C, 1).expand(B, C, T * S).contiguous() - profile_audio_stacked, _ = model.stack_codes( - sil_codes_unstacked, - torch.full((B,), T * S, dtype=torch.long, device=device), - bos_id=model.audio_bos_id, - eos_id=model.audio_eos_id, - stacking_factor=S, - num_codebooks=C, - ) # (B, C*S, T) - - # Feed the pad on speech channel as part of the prediction while profilling until we finish the whole delay - if turn_steps < (int(state.config.training_mode.streaming_speech_delay)-int(state.config.training_mode.streaming_phonemes_delay)): - if model.cfg.get("use_user_speaking_end_token", False) and not model.cfg.get("agent_mask_include_transition_prefix", False): - state.last_audio_codes = torch.full( - state.last_audio_codes.shape, - model.audio_user_speaking_end_id, - dtype=torch.long, - device=device, - ) - elif model.cfg.get("use_user_speaking_token", False): - state.last_audio_codes = torch.full( - state.last_audio_codes.shape, - model.audio_user_speaking_id, - dtype=torch.long, - device=device, - ) - else: - state.last_audio_codes = profile_audio_stacked[:, :, -1].contiguous() - - elif turn_steps == (int(state.config.training_mode.streaming_speech_delay)-int(state.config.training_mode.streaming_phonemes_delay)): - if model.cfg.get("use_user_speaking_end_token", False) and not model.cfg.get("agent_mask_include_transition_prefix", False): - state.last_audio_codes = torch.full( - state.last_audio_codes.shape, - model.audio_user_speaking_end_id, - dtype=torch.long, - device=device, - ) - elif model.cfg.get("use_user_speaking_token", False): - state.last_audio_codes = torch.full( - state.last_audio_codes.shape, - model.audio_user_speaking_id, - dtype=torch.long, - device=device, - ) - else: - state.last_audio_codes = profile_audio_stacked[:, :, -1].contiguous() - - # ToDo: we need to feed the user audio embedding to streaming step as well + state, audio_codes, _ = model.streaming_step( state=state, text_tokens=current_tokens, use_inference_mode=True, ) - # Replace predicted delay tokens with silence - if model.phoneme_tokenizer is not None and not model.cfg.get("agent_mask_include_transition_prefix", False) and turn_steps <= (int(state.config.training_mode.streaming_speech_delay)-int(state.config.training_mode.streaming_phonemes_delay)): - state.all_predictions[-1] = codec_sil_codes.view(1, -1, 1).expand_as(state.all_predictions[-1]) if audio_codes is not None and not saw_audio: saw_audio = True @@ -881,15 +817,12 @@ def main(): speaking_id = getattr(model, "audio_user_speaking_id", -1) speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) - context_audio_bos_id = getattr(model, "context_audio_bos_id", -1) - context_audio_eos_id = getattr(model, "context_audio_eos_id", -1) - mask_token_id = getattr(model, "mask_token_id", -1) - sil_injection = codec_sil_codes.view(1, -1, 1) + for step_idx in range(len(state.all_predictions)): pred = state.all_predictions[step_idx] # Check if any codebook in the frame has any special token - mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) | (pred == context_audio_bos_id) | (pred == context_audio_eos_id) | (pred == mask_token_id) + mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) frame_mask = mask.any(dim=1, keepdim=True) if frame_mask.any(): diff --git a/examples/tts/easy_magpietts_inference_multiturn_new_prefill.py b/examples/tts/easy_magpietts_inference_multiturn_new_prefill.py new file mode 100644 index 000000000000..30aedefb89ae --- /dev/null +++ b/examples/tts/easy_magpietts_inference_multiturn_new_prefill.py @@ -0,0 +1,1186 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Evaluation script for custom EasyMagpieTTS models. +Features explicit Duplex (10x Padding) and Regular (Turn-by-turn) multi-turn modes. + +Usage: + python easy_magpietts_eval.py \ + --checkpoint_path=/path/to/magpie/model.ckpt \ + --codec_model_path=/path/to/codec/model.ckpt \ + --datasets_json_path=/path/to/evalset_config.jsonl \ + --out_dir=/path/to/out/audio \ + --batch_size=6 \ + --use_cfg \ + --use_librosa +""" + +import argparse +import json +import os +from copy import deepcopy +from functools import partial + +import librosa +import soundfile as sf +import torch +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader, Dataset +from omegaconf import OmegaConf, open_dict + +from nemo.collections.audio.parts.utils.transforms import resample +from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility +from nemo.collections.speechlm2.parts.metrics.secs import SECS +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.utils import logging + +# --- EasyMagpieTTS Imports --- +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter +from nemo.collections.tts.modules.magpietts_modules import CodecHelper +from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel + +from nemo.collections.tts.parts.utils.tts_dataset_utils import normalize_volume + +torch.set_float32_matmul_precision("medium") +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True + +if torch.cuda.is_available(): + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) + + +def attach_dtype_counter(model): + handles = [] + stats = {} + examples = {} + + def is_leaf(module): + return len(list(module.children())) == 0 + + def get_dtype(x): + if torch.is_tensor(x): + return str(x.dtype) + elif isinstance(x, (list, tuple)): + for t in x: + if torch.is_tensor(t): + return str(t.dtype) + return "other" + + def get_module_group(name): + return name.split(".")[0] if "." in name else name + + def hook_fn(name): + def fn(module, inputs, outputs): + dtype = get_dtype(outputs) + if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: + dtype = "other" + group = get_module_group(name) + if group not in stats: + stats[group] = { + "torch.float16": 0, "torch.bfloat16": 0, "torch.float32": 0, "other": 0, + } + examples[group] = { + "torch.float16": [], "torch.bfloat16": [], "torch.float32": [], "other": [], + } + stats[group][dtype] += 1 + if len(examples[group][dtype]) < 3: + examples[group][dtype].append(module.__class__.__name__) + return fn + + for name, module in model.named_modules(): + if is_leaf(module): + handles.append(module.register_forward_hook(hook_fn(name))) + return handles, stats, examples + + +def report_dtype_stats(handles, stats, examples): + for h in handles: + h.remove() + logging.info("\n=== DTYPE USAGE PER MODULE ===") + for group, group_stats in stats.items(): + total = sum(group_stats.values()) + if total == 0: continue + logging.info(f"\n--- {group} ---") + for dtype, count in group_stats.items(): + if count > 0: + logging.info(f"{dtype}: {count} ({100*count/total:.2f}%)") + logging.info("\n=== EXAMPLES ===") + for group, group_examples in examples.items(): + logging.info(f"\n--- {group} ---") + for dtype, mods in group_examples.items(): + if mods: + logging.info(f"{dtype}: {mods}") + + +class EvalJSONLDataset(Dataset): + def __init__(self, file_path, num_turns=1): + self.samples = [] + raw_samples = [] + with open(file_path, "r", encoding="utf-8") as f: + for line_idx, line in enumerate(f, 1): + line = line.strip() + if not line: continue + try: + raw_samples.append(json.loads(line)) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON on line {line_idx}: {e}") + + if num_turns <= 1: + self.samples = raw_samples + return + + single_turn_by_speaker = {} + for sample in raw_samples: + if isinstance(sample["text"], list): + self.samples.append(sample) + else: + speaker = sample.get("speaker", "unknown") + if speaker not in single_turn_by_speaker: + single_turn_by_speaker[speaker] = [] + single_turn_by_speaker[speaker].append(sample) + + for speaker, speaker_samples in single_turn_by_speaker.items(): + buffer_texts, buffer_paths = [], [] + first_sample_meta = None + + for sample in speaker_samples: + if not buffer_texts: + first_sample_meta = dict(sample) + buffer_texts.append(sample["text"]) + buffer_paths.append(sample.get("audio_filepath", "")) + + if len(buffer_texts) == num_turns: + first_sample_meta["text"] = buffer_texts + base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] + ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" + combined_name = "_".join(base_names) + ext + dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) + if dir_name: + first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) + else: + first_sample_meta["audio_filepath"] = combined_name + + self.samples.append(first_sample_meta) + buffer_texts, buffer_paths, first_sample_meta = [], [], None + + if buffer_texts: + first_sample_meta["text"] = buffer_texts + base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] + ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" + combined_name = "_".join(base_names) + ext + dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) + if dir_name: + first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) + else: + first_sample_meta["audio_filepath"] = combined_name + self.samples.append(first_sample_meta) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.samples[idx] + + +def collate_and_tokenize_custom( + batch, + model, + extra_duration_thrshould=1.3, + sample_rate=22050, + root_path=None, + emulate_duplex_inference=False, + add_interruption_token=False, + pad_factor_text_speech=10, + force_interruption=False, + normalize_context_audio_volume=True, + use_librosa=False, + profile_multiturn_inference=False, +): + main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] + + # --- MULTI-TURN MODE DECISION --- + is_profile = profile_multiturn_inference + is_duplex = emulate_duplex_inference and not is_profile + + out_dict = { + "duplex_multiturn": is_duplex, + "regular_multiturn": (not is_duplex) and (not is_profile), + "profile_multiturn": is_profile, + } + + tokenized_list = [] + batched_turns = [] + batched_turn_lens = [] + valid_turn_masks = [] + + if is_duplex: + # ------------------------------------------------------------- + # DUPLEX MODE (Continuous sequence with 10x pad injection) + # ------------------------------------------------------------- + for s in batch: + text_data = s["text"] + + if isinstance(text_data, list): + full_ids = [] + for segment in text_data: + seg_ids = model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name) + [model.eos_id] + seg_len = len(seg_ids) + pad_len = seg_len * pad_factor_text_speech + pad_ids = [model.pad_id] * pad_len + + if force_interruption: + fname = s["audio_filepath"] + no_ext = fname.split(".")[0] + sample_id = int(no_ext.split("_")[-1]) + case = sample_id % 3 + + if case == 0: + if len(seg_ids) >= 2: + seg_ids[-2] = model.interruption_token_id + seg_ids[-1] = model.pad_id + else: + pad_ids[0] = model.interruption_token_id + elif case == 1: + eos_idx = min(6, len(pad_ids) - 1) + pad_ids[eos_idx] = model.interruption_token_id + else: + eos_idx = 0 + pad_ids[eos_idx] = model.interruption_token_id + else: + if add_interruption_token: + eos_idx = int(len(pad_ids) * 0.7) + pad_ids[eos_idx] = model.interruption_token_id + + full_ids.extend(seg_ids) + full_ids.extend(pad_ids) + + tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long)) + else: + tokenized_list.append( + torch.as_tensor(model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id], dtype=torch.long) + ) + + pad_len = 25 + prefix = torch.full((pad_len,), model.pad_id, dtype=torch.long) + for i in range(len(tokenized_list)): + tokenized_list[i] = torch.cat([prefix, tokenized_list[i]]) + input_lengths = torch.tensor([len(x) for x in tokenized_list], dtype=torch.long) + + input_ids = pad_sequence(tokenized_list, batch_first=True, padding_value=model.pad_id) + + out_dict["input_ids"] = input_ids + out_dict["input_lengths"] = input_lengths + + else: + # ------------------------------------------------------------- + # REGULAR MODE (Turn-by-turn discrete packaging) + # ------------------------------------------------------------- + max_turns = 1 + for s in batch: + if isinstance(s["text"], list): + max_turns = max(max_turns, len(s["text"])) + + for t in range(max_turns): + turn_t_tokens = [] + turn_t_lens = [] + turn_t_valid = [] + + for s in batch: + text_data = s["text"] + if isinstance(text_data, list): + if t < len(text_data): + seg_ids = model.tokenizer.encode(text_data[t], tokenizer_name=main_tokenizer_name) + [model.eos_id] + turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) + turn_t_lens.append(len(seg_ids)) + turn_t_valid.append(True) + else: + # Dummy pad to keep shapes consistent for items with fewer turns + turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) + turn_t_lens.append(1) + turn_t_valid.append(False) + else: + if t == 0: + seg_ids = model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id] + turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) + turn_t_lens.append(len(seg_ids)) + turn_t_valid.append(True) + else: + turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) + turn_t_lens.append(1) + turn_t_valid.append(False) + + padded_turn_t = pad_sequence(turn_t_tokens, batch_first=True, padding_value=model.pad_id) + batched_turns.append(padded_turn_t) + batched_turn_lens.append(torch.tensor(turn_t_lens, dtype=torch.long)) + valid_turn_masks.append(torch.tensor(turn_t_valid, dtype=torch.bool)) + + out_dict["batched_turns"] = batched_turns + out_dict["batched_turn_lens"] = batched_turn_lens + out_dict["valid_turn_masks"] = valid_turn_masks + + # --- AUDIO LOADING --- + audio_list = [] + audio_lengths = [] + target_num_frames = [] + + for i, s in enumerate(batch): + audio_path = s["context_audio_filepath"] + if root_path is not None: + audio_path = os.path.join(root_path, audio_path) + + if os.path.exists(audio_path): + if use_librosa: + wav, sr = librosa.load(audio_path, sr=sample_rate, mono=True) + if normalize_context_audio_volume: + wav = normalize_volume(wav) + wav = torch.as_tensor(wav, dtype=torch.float32) + else: + wav, sr = sf.read(audio_path, dtype='float32') + # Force Mono + if wav.ndim > 1: + wav = wav.mean(axis=1) + + if normalize_context_audio_volume: + wav = normalize_volume(wav) + + # Convert to tensor, add batch dim for resampler, then remove it + wav_tensor = torch.as_tensor(wav, dtype=torch.float32).unsqueeze(0) + wav = resample(wav_tensor, sr, sample_rate).squeeze(0) + else: + wav = torch.zeros(1, dtype=torch.float32) + + audio_list.append(wav) + audio_lengths.append(len(wav)) + + tdur_audio_path = s["audio_filepath"] + if root_path is not None: + tdur_audio_path = os.path.join(root_path, tdur_audio_path) + + if tdur_audio_path and os.path.exists(tdur_audio_path): + wav_dur, sr_ = librosa.load(tdur_audio_path, sr=sample_rate, mono=True) + tdur = wav_dur.shape[0] // model.input_samples_per_frame + target_num_frames.append(tdur * extra_duration_thrshould) + else: + # Fallback estimation + if is_duplex: + current_text_len = len(tokenized_list[i]) + if isinstance(s["text"], list): + target_num_frames.append(current_text_len) + else: + target_num_frames.append(current_text_len * 5) + else: + target_num_frames.append(sum([l[i].item() for l in batched_turn_lens]) * 5) + + max_audio_len = max(audio_lengths) + B = len(audio_lengths) + padded_audio = torch.zeros((B, max_audio_len), dtype=torch.float32) + + for i, wav in enumerate(audio_list): + padded_audio[i, : len(wav)] = wav + + out_dict["context_audio"] = padded_audio + out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) + out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] + out_dict["target_num_frames"] = target_num_frames + + out_dict["raw_text"] = [" ".join(s["text"]) if isinstance(s["text"], list) else s["text"] for s in batch] + + return out_dict + + +def main(): + parser = argparse.ArgumentParser(description="EasyMagpieTTS Inference Evaluation") + + # Required Paths + parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the EasyMagpie model") + parser.add_argument("--codec_model_path", type=str, required=True, help="Path to the audio codec") + parser.add_argument("--datasets_json_path", type=str, required=True, help="Path to JSONL data") + parser.add_argument("--out_dir", type=str, required=True, help="Directory to save audio outputs") + + # Optional Paths & General + parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) + parser.add_argument("--audio_dir", type=str, default=None, help="Root dir for audio paths in JSONL") + parser.add_argument("--inference_dtype", type=str, default="float32") + parser.add_argument("--debug_dtype", action="store_true") + parser.add_argument("--use_librosa", action="store_true", help="Use librosa instead of soundfile+torch for audio load") + + # Dataloader & Batching + parser.add_argument("--batch_size", type=int, default=6) + parser.add_argument("--num_workers", type=int, default=4) + parser.add_argument("--num_turns", type=int, default=1) + parser.add_argument("--pad_factor_text_speech", type=int, default=10) + + # Text Processing Boolean Flags + parser.add_argument("--emulate_duplex_inference", action="store_true") + parser.add_argument("--add_interruption_token", action="store_true") + parser.add_argument("--force_interruption", action="store_true") + parser.add_argument("--profile_multiturn_inference", action="store_true") + parser.add_argument("--profile_pad_min_sec", type=float, default=2.0) + parser.add_argument("--profile_pad_max_sec", type=float, default=2.0) + + + # Speaker & Prompt Configurations + parser.add_argument("--user_custom_speaker_reference", action="store_true") + parser.add_argument("--inference_speaker_reference", type=str, default=None) + parser.add_argument("--language", type=str, default="en") + + # Generation Kwargs + parser.add_argument("--use_cfg", action="store_true") + parser.add_argument("--cfg_scale", type=float, default=2.5) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--topk", type=int, default=80) + parser.add_argument("--max_tts_steps", type=int, default=2000) + parser.add_argument("--force_speech_sil_codes", action="store_true") + parser.add_argument("--normalize_volume", type=lambda x: (str(x).lower() in ['true', '1', 'yes']), default=False) + + args = parser.parse_args() + + if args.profile_pad_max_sec < args.profile_pad_min_sec: + raise RuntimeError("--profile_pad_max_sec must be >= --profile_pad_min_sec.") + + distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1 + if distributed and not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend="nccl") + + target_device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}" if torch.cuda.is_available() else "cpu") + target_dtype = getattr(torch, args.inference_dtype) + torch.set_default_dtype(target_dtype) + + model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) + with open_dict(model_cfg): + model_cfg.target = 'nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel' + model_cfg.codecmodel_path = args.codec_model_path + model_cfg.train_ds = None + model_cfg.validation_ds = None + model_cfg.run_val_inference = False + model_cfg.use_utmos = False + model_cfg.use_meta_init_for_decoder = True + + # Guarantees silence for pad tokens + model_cfg.use_multiturn_dataset = True + + if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: + model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path + + # Load to CPU first to prevent OOM + model = EasyMagpieTTSInferenceModel.restore_from( + args.checkpoint_path, override_config_path=model_cfg, map_location=torch.device("cpu") + ) + model.use_kv_cache_for_inference = True + model.to(dtype=target_dtype) + model.eval().to(target_device) + + # --- DATALOADER COMPATIBILITY PATCHES --- + model.input_samples_per_frame = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) + model.target_samples_per_frame = model.input_samples_per_frame / (model.sample_rate / model.output_sample_rate) + + # Load to CPU first to prevent OOM + codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=torch.device("cpu")) + if hasattr(codec_model, "discriminator"): + del codec_model.discriminator + codec_model.freeze() + codec_model = codec_model.to(target_device).eval() + + codec_converter = None + if getattr(model, "_codec_converter", None) is not None: + vq_new = deepcopy(model._codec_converter.vector_quantizer_new).to(target_device).eval() + codec_converter = VectorQuantizerIndexConverter( + vector_quantizer_original=codec_model.vector_quantizer, + vector_quantizer_new=vq_new, + ).to(target_device).eval() + + if not hasattr(model, "_codec_helper") or model._codec_helper is None: + model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) + + model._generate_codec_silence_buffer() + codec_sil_codes = model.codec_sil_codes + + if args.debug_dtype: + handles, stats, examples = attach_dtype_counter(model) + + with fp32_precision(): + intelligibility = Intelligibility("stt_en_fastconformer_transducer_large", reuse_asr_hyps=False).reset() + secs_metric = SECS("titanet_large").reset() + + eval_dataset = EvalJSONLDataset(args.datasets_json_path, num_turns=args.num_turns) + + collate_fn = partial( + collate_and_tokenize_custom, + model=model, + extra_duration_thrshould=1.5, + sample_rate=model.sample_rate, + root_path=args.audio_dir, + emulate_duplex_inference=args.emulate_duplex_inference, + add_interruption_token=args.add_interruption_token, + pad_factor_text_speech=args.pad_factor_text_speech, + force_interruption=args.force_interruption, + normalize_context_audio_volume=args.normalize_volume, + use_librosa=args.use_librosa, + profile_multiturn_inference=args.profile_multiturn_inference + ) + + dataloader = DataLoader( + dataset=eval_dataset, batch_size=args.batch_size, collate_fn=collate_fn, + num_workers=args.num_workers, pin_memory=True, shuffle=False, drop_last=False, + ) + + if args.user_custom_speaker_reference and args.inference_speaker_reference: + if args.use_librosa: + wav, sr = librosa.load(args.inference_speaker_reference, sr=model.sample_rate, mono=True) + if args.normalize_volume: + wav = normalize_volume(wav) + speaker_wav = torch.as_tensor(wav, dtype=target_dtype).unsqueeze(0).to(model.device) + else: + wav, sr = sf.read(args.inference_speaker_reference, dtype='float32') + # Force Mono + if wav.ndim > 1: + wav = wav.mean(axis=1) + + if args.normalize_volume: + wav = normalize_volume(wav) + + speaker_wav = torch.as_tensor(wav).unsqueeze(0) + speaker_wav = resample(speaker_wav.float(), sr, model.sample_rate).to(target_dtype).to(model.device) + + for batch_id, inputs in enumerate(dataloader): + B = inputs["context_audio"].size(0) + device = model.device + + inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) + inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) + + if args.user_custom_speaker_reference and args.inference_speaker_reference: + inputs["context_audio"] = speaker_wav.repeat(B, 1).detach() + inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) + + B = inputs["context_audio"].size(0) + profile_turn_frame_ranges = [[] for _ in range(B)] + profile_decode_start_frame = None + with torch.inference_mode(): + wav = inputs["context_audio"] + wav_len = inputs["context_audio_lengths"] + codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) + + use_lang = bool(getattr(model, "add_language_to_context_text", False)) + ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" + ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) + + ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) + ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) + + state = model.streaming_init( + context_audio_codes=codes, + context_audio_codes_lens=codes_lens, + context_text_tokens=ctx_toks, + context_text_tokens_lens=ctx_toks_lens, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + use_local_transformer=True, + temperature=args.temperature, + topk=args.topk, + phoneme_input_type="pred", + phoneme_sampling_method="argmax", + use_inference_mode=True, + ) + # --------------------------------------------------------- + # MODE 1: DUPLEX (Continuous Padding Token Stream) + # --------------------------------------------------------- + if inputs["duplex_multiturn"]: + text = inputs["input_ids"].to(device) + text_lens = inputs["input_lengths"].to(device) + + # Trackers for our two forced-silence zones + in_initial_silence = torch.ones(B, dtype=torch.bool, device=device) + in_post_speech_silence = torch.zeros(B, dtype=torch.bool, device=device) + + text_exhausted = state.text_tokens_seen >= text_lens + while not text_exhausted.all(): + # 1. WAKE UP OVERRIDE: Keep the text pipeline awake to read pads! + state.finished = state.finished & text_exhausted + state.text_finished = state.text_finished & text_exhausted + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended = state.phoneme_stream_ended & text_exhausted + + # 2. Safely index text using the model's internal pointer + positions = state.text_tokens_seen.clamp(max=text.size(1) - 1) + current_tokens = text[torch.arange(B, device=device), positions] + + current_tokens = torch.where( + text_exhausted, torch.full_like(current_tokens, model.eos_id), current_tokens + ) + + # 3. Update our trackers BEFORE the step + is_pad_or_eos = (current_tokens == model.pad_id) | (current_tokens == model.eos_id) + in_initial_silence = in_initial_silence & is_pad_or_eos + in_post_speech_silence = in_post_speech_silence & is_pad_or_eos + + # 4. Step the model + state, audio_codes, _ = model.streaming_step(state=state, text_tokens=current_tokens, use_inference_mode=True) + + # 5. SILENCE FORCING INJECTION + if audio_codes is not None and args.force_speech_sil_codes: + force_silence_mask = in_initial_silence | in_post_speech_silence + + if force_silence_mask.any(): + # Expand silence codes [C] -> [1, C, 1] to match audio_codes [B, C, 1] + expanded_sil = codec_sil_codes.view(1, -1, 1).expand_as(audio_codes) + # Expand mask [B] -> [B, 1, 1] for broadcasting + mask_3d = force_silence_mask.view(B, 1, 1) + # Overwrite the prediction with silence codes where the mask is True. + overwritten_codes = torch.where(mask_3d, expanded_sil, audio_codes) + # Inject back into the model's KV cache history + state.all_predictions[-1] = overwritten_codes + + # 6. TRIGGER POST-SPEECH SILENCE FOR THE *NEXT* FRAME + in_post_speech_silence = in_post_speech_silence | state.finished + + # Update exhaustion tracker for the next iteration + text_exhausted = state.text_tokens_seen >= text_lens + + # --------------------------------------------------------- + # MODE 2: REGULAR (Turn-by-Turn Re-wakes) + # --------------------------------------------------------- + elif inputs["regular_multiturn"]: + batched_turns = inputs["batched_turns"] + batched_turn_lens = inputs["batched_turn_lens"] + valid_turn_masks = inputs["valid_turn_masks"] + + max_turns = len(batched_turns) + turn_offsets = torch.zeros(B, dtype=torch.long, device=device) + + for t in range(max_turns): + turn_text = batched_turns[t].to(device) + turn_lens = batched_turn_lens[t].to(device) + valid_mask = valid_turn_masks[t].to(device) + + state.finished = state.finished & (~valid_mask) + state.text_finished = state.text_finished & (~valid_mask) + + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended = state.phoneme_stream_ended & (~valid_mask) + + if state.finished.all(): + continue + + turn_offsets = torch.where(valid_mask, state.text_tokens_seen, turn_offsets) + turn_steps = 0 + + while not state.finished.all() and turn_steps < args.max_tts_steps: + turn_steps += 1 + + relative_positions = state.text_tokens_seen - turn_offsets + positions = relative_positions.clamp(min=0, max=turn_text.size(1) - 1) + current_tokens = turn_text[torch.arange(B, device=device), positions] + + exhausted = relative_positions >= turn_lens + current_tokens = torch.where(exhausted, torch.full_like(current_tokens, model.eos_id), current_tokens) + + state, _, _ = model.streaming_step(state=state, text_tokens=current_tokens, use_inference_mode=True) + + # --------------------------------------------------------- + # MODE 3: PROFILE MULTI-TURN, BATCHED + # --------------------------------------------------------- + elif inputs["profile_multiturn"]: + batched_turns = inputs["batched_turns"] + batched_turn_lens = inputs["batched_turn_lens"] + valid_turn_masks = inputs["valid_turn_masks"] + + max_turns = len(batched_turns) + + # Per-sample debug/save metadata. + # profile_turn_frame_ranges[b] = [(turn_id, start_frame, end_frame), ...] + profile_turn_frame_ranges = [[] for _ in range(B)] + + # First decoded frame per sample. Used by streaming_finalize and by per-turn saving. + profile_decode_start_frame = torch.full( + (B,), + -1, + dtype=torch.long, + device=device, + ) + + # Last meaningful decoded frame per sample. This avoids keeping trailing batch-idle silence. + profile_final_end_frame = torch.full( + (B,), + -1, + dtype=torch.long, + device=device, + ) + + arange_B = torch.arange(B, device=device) + + for t in range(max_turns): + turn_text = batched_turns[t].to(device) # (B, T_text_t) + turn_lens = batched_turn_lens[t].to(device) # (B,) + valid_mask = valid_turn_masks[t].to(device) # (B,) + + if not valid_mask.any(): + continue + + # Re-open only rows that have this turn. + state.finished = state.finished & (~valid_mask) + state.text_finished = state.text_finished & (~valid_mask) + + # Let this turn detect its own EOS, but keep audio_prediction_start_idx + # as the first start of the full generated conversation. + state.audio_prediction_end_idx = torch.where( + valid_mask, + torch.full_like(state.audio_prediction_end_idx, -1), + state.audio_prediction_end_idx, + ) + + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended = state.phoneme_stream_ended & (~valid_mask) + if hasattr(state, "phoneme_eos_detected"): + state.phoneme_eos_detected = state.phoneme_eos_detected & (~valid_mask) + + # Optional but usually cleaner for turn-level phoneme generation. + if hasattr(state, "phoneme_steps"): + state.phoneme_steps = torch.where( + valid_mask, + torch.zeros_like(state.phoneme_steps), + state.phoneme_steps, + ) + + # Optional but cleaner for turn-local audio step accounting. + # Profile steps below will immediately set last_audio_codes and increment audio_steps. + state.audio_steps = torch.where( + valid_mask, + torch.zeros_like(state.audio_steps), + state.audio_steps, + ) + + # ----------------------------- + # 1. Batched per-row profiling + # ----------------------------- + profile_seconds = ( + args.profile_pad_min_sec + + torch.rand((B,), device=device) + * (args.profile_pad_max_sec - args.profile_pad_min_sec) + ) + + profile_T = torch.round( + profile_seconds * model.sample_rate / model.input_samples_per_frame + ).to(torch.long) + profile_T = torch.clamp(profile_T, min=1) + + profile_T = torch.where( + valid_mask, + profile_T, + torch.zeros_like(profile_T), + ) + + profile_remaining = profile_T.clone() + profile_step_idx = torch.zeros(B, dtype=torch.long, device=device) + + # Old behavior: put the first streaming_speech_delay BPE tokens into + # the last profile positions, then do normal generation after that. + # + # This keeps your current working behavior, but now per batch row. + if model.cfg.get("agent_mask_include_transition_prefix", False): + delay_tokens = torch.zeros(B, dtype=torch.long, device=device) + else: + delay_tokens = torch.full( + (B,), + int(state.config.training_mode.streaming_speech_delay), + dtype=torch.long, + device=device, + ) + + # Safer version: keep at least one token, normally EOS, outside profile. + # If you want exact old behavior, replace `turn_lens - 1` with `turn_lens`. + max_consumable_text = torch.clamp(turn_lens - 1, min=0) + + delay_tokens = torch.minimum(delay_tokens, max_consumable_text) + delay_tokens = torch.minimum(delay_tokens, profile_T) + delay_tokens = torch.where(valid_mask, delay_tokens, torch.zeros_like(delay_tokens)) + + profile_text_consumed = torch.zeros(B, dtype=torch.long, device=device) + + while profile_remaining.max().item() > 0: + profile_mask = valid_mask & (profile_remaining > 0) + profile_end_mask = profile_mask & (profile_remaining == 1) + + # Default profile text is PAD. Rows in the last delay_tokens profile + # positions receive real BPE text tokens. + profile_text_tokens = torch.full( + (B,), + model.pad_id, + dtype=torch.long, + device=device, + ) + + if (delay_tokens > 0).any(): + # step_in_profile: 0, 1, ..., profile_T[b]-1 + step_in_profile = profile_step_idx + + # Emit text only in the final delay_tokens[b] profile steps. + emit_profile_text = ( + profile_mask + & (delay_tokens > 0) + & (step_in_profile >= (profile_T - delay_tokens)) + & (profile_text_consumed < delay_tokens) + ) + + if emit_profile_text.any(): + text_pos = profile_text_consumed.clamp( + min=0, + max=turn_text.size(1) - 1, + ) + + gathered_profile_text = turn_text[arange_B, text_pos] + profile_text_tokens = torch.where( + emit_profile_text, + gathered_profile_text, + profile_text_tokens, + ) + + profile_text_consumed = profile_text_consumed + emit_profile_text.long() + + # Only rows with profile_mask=True are active in this step. + # Other rows receive silence in the rectangular all_predictions tensor, + # but their logical counters do not advance. + state, _, _ = model.streaming_step_profiled( + state=state, + text_tokens=torch.full( + (B,), + model.eos_id, + dtype=torch.long, + device=device, + ), + profile_mask=profile_mask, + profile_text_tokens=profile_text_tokens, + profile_end_mask=profile_end_mask, + active_mask=profile_mask, + use_inference_mode=True, + ) + + profile_remaining = torch.where( + profile_mask, + profile_remaining - 1, + profile_remaining, + ) + profile_step_idx = torch.where( + profile_mask, + profile_step_idx + 1, + profile_step_idx, + ) + + logging.info( + f"[profile_multiturn] turn={t} profile_steps=" + f"{profile_T.detach().cpu().tolist()} " + f"profile_seconds={profile_seconds.detach().cpu().tolist()}" + ) + + # We start the turn after all rows have finished the profile phase. + # This excludes profile/user-speaking silence from each turn segment. + turn_start_frame_global = sum(p.size(-1) for p in state.all_predictions) + turn_start_frames = torch.full( + (B,), + turn_start_frame_global, + dtype=torch.long, + device=device, + ) + + first_profile_turn = valid_mask & (profile_decode_start_frame < 0) + profile_decode_start_frame = torch.where( + first_profile_turn, + turn_start_frames, + profile_decode_start_frame, + ) + + # Make streaming_finalize start from the first generated turn frame. + state.audio_prediction_start_idx = torch.where( + first_profile_turn, + turn_start_frames, + state.audio_prediction_start_idx, + ) + + # The profile phase already consumed `delay_tokens` text tokens for each row. + # Do not slice turn_text because delay_tokens differs per row; use a per-row base offset. + turn_text_base = delay_tokens + turn_remaining_lens = torch.clamp(turn_lens - turn_text_base, min=0) + + turn_offset = state.text_tokens_seen.clone() + turn_done = ~valid_mask + turn_steps = 0 + + saw_audio = torch.zeros(B, dtype=torch.bool, device=device) + first_audio_step_finished = torch.zeros(B, dtype=torch.bool, device=device) + + # ----------------------------- + # 2. Batched turn generation + # ----------------------------- + while (not turn_done.all()) and turn_steps < args.max_tts_steps: + turn_steps += 1 + + gen_mask = valid_mask & (~turn_done) + + # Old single-sample behavior cleared state.finished every generation step. + # Do it only for active rows. This prevents an early audio EOS from ending + # the turn before text is exhausted. + state.finished = state.finished & (~gen_mask) + + relative_position = state.text_tokens_seen - turn_offset + text_exhausted = relative_position >= turn_remaining_lens + + # Current token index in original turn_text, accounting for tokens consumed + # during profile. + position = turn_text_base + relative_position + position = position.clamp(min=0, max=turn_text.size(1) - 1) + + current_tokens = turn_text[arange_B, position] + current_tokens = torch.where( + text_exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) + + state, audio_codes, _ = model.streaming_step_profiled( + state=state, + text_tokens=current_tokens, + profile_mask=torch.zeros(B, dtype=torch.bool, device=device), + profile_text_tokens=torch.full( + (B,), + model.pad_id, + dtype=torch.long, + device=device, + ), + profile_end_mask=torch.zeros(B, dtype=torch.bool, device=device), + active_mask=gen_mask, + use_inference_mode=True, + ) + + if audio_codes is not None: + newly_saw_audio = gen_mask & (~saw_audio) + saw_audio = saw_audio | gen_mask + first_audio_step_finished = torch.where( + newly_saw_audio, + state.finished, + first_audio_step_finished, + ) + + # Match old stopping condition: + # a turn is done only when its text is exhausted AND audio EOS is detected. + done_now = gen_mask & text_exhausted & state.finished + + if done_now.any(): + current_end_frame_global = sum(p.size(-1) for p in state.all_predictions) + + # Prefer precise EOS frame if _process_predictions_profiled populated it. + eos_end_frames = torch.where( + state.audio_prediction_end_idx >= 0, + state.audio_prediction_end_idx, + torch.full_like( + state.audio_prediction_end_idx, + current_end_frame_global, + ), + ) + + profile_final_end_frame = torch.where( + done_now, + eos_end_frames, + profile_final_end_frame, + ) + + for b in done_now.nonzero(as_tuple=False).flatten().detach().cpu().tolist(): + profile_turn_frame_ranges[b].append( + ( + int(t), + int(turn_start_frames[b].detach().cpu().item()), + int(eos_end_frames[b].detach().cpu().item()), + ) + ) + + turn_done = turn_done | done_now + + # Max-step fallback for rows that did not finish by EOS. + still_running = valid_mask & (~turn_done) + if still_running.any(): + current_end_frame_global = sum(p.size(-1) for p in state.all_predictions) + fallback_end_frames = torch.full( + (B,), + current_end_frame_global, + dtype=torch.long, + device=device, + ) + + profile_final_end_frame = torch.where( + still_running, + fallback_end_frames, + profile_final_end_frame, + ) + + for b in still_running.nonzero(as_tuple=False).flatten().detach().cpu().tolist(): + profile_turn_frame_ranges[b].append( + ( + int(t), + int(turn_start_frames[b].detach().cpu().item()), + int(fallback_end_frames[b].detach().cpu().item()), + ) + ) + + # Do not let this turn's EOS crop the full conversation before later turns. + state.audio_prediction_end_idx = torch.where( + valid_mask, + torch.full_like(state.audio_prediction_end_idx, -1), + state.audio_prediction_end_idx, + ) + state.finished = state.finished & (~valid_mask) + + logging.info( + f"[profile_multiturn] turn={t} steps={turn_steps} " + f"saw_audio={saw_audio.detach().cpu().tolist()} " + f"first_audio_step_finished={first_audio_step_finished.detach().cpu().tolist()}" + ) + + # After all turns, crop each row at its own last meaningful frame. + # This prevents rows that finished early from keeping trailing batch-idle silence. + total_frames = sum(p.size(-1) for p in state.all_predictions) + profile_final_end_frame = torch.where( + profile_final_end_frame >= 0, + profile_final_end_frame, + torch.full_like(profile_final_end_frame, total_frames), + ) + + state.audio_prediction_end_idx.copy_(profile_final_end_frame) + + # if state.audio_prediction_end_idx[0].item() >= 0: + # last_audio_prediction_end_idx.copy_(state.audio_prediction_end_idx) + + # Scrub Special Tokens (BOS/EOS) from Audio Codes --- + # Because we force-decode the entire uncropped sequence, any BOS or EOS + # tokens left in the array will produce loud artifacts in the codec. + bos_id = getattr(model, "audio_bos_id", -1) + eos_id = getattr(model, "audio_eos_id", -1) + speaking_id = getattr(model, "audio_user_speaking_id", -1) + speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) + + sil_injection = codec_sil_codes.view(1, -1, 1) + + for step_idx in range(len(state.all_predictions)): + pred = state.all_predictions[step_idx] + # Check if any codebook in the frame has any special token + mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) + frame_mask = mask.any(dim=1, keepdim=True) + + if frame_mask.any(): + state.all_predictions[step_idx] = torch.where( + frame_mask, + sil_injection.expand_as(pred), + pred + ) + + if inputs["duplex_multiturn"]: + # Erase the internal memory of Turn 1's EOS token so `streaming_finalize` + # decodes the entire physical sequence! + state.audio_prediction_end_idx.fill_(-1) + + + # Finalize decodes the collected Codec states globally regardless of which loop was run + finalize_output = model.streaming_finalize(state, use_inference_mode=True) + + if args.debug_dtype and batch_id == 0: + report_dtype_stats(handles, stats, examples) + + with fp32_precision(): + audio_f32 = finalize_output.audio.float() + audio_len = finalize_output.audio_len.int() + + expected_audio_lens = (torch.tensor(inputs["target_num_frames"], device=device) * model.target_samples_per_frame).int() + + if inputs["duplex_multiturn"]: + # Use exact math based on the output samples multiplier! + audio_len = (text_lens * model.target_samples_per_frame).int() + + # Cap the expected length so it physically cannot exceed the actual generated tensor size + audio_len = torch.min(audio_len, torch.tensor(audio_f32.size(1), device=device)) + elif inputs["profile_multiturn"]: + audio_len = finalize_output.audio_len.int() + else: + audio_len = torch.min(audio_len, expected_audio_lens) + + metric_audio_pred = resample(audio_f32, model.output_sample_rate, 16000) + metric_audio_pred_lens = (audio_len / model.output_sample_rate * 16000).to(torch.long) + + intelligibility.update( + name="dataset", + refs=inputs["raw_text"], + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + asr_hyps=None, + ) + + secs_metric.update( + name="dataset", + target_audio=resample(inputs["context_audio"].float(), model.sample_rate, 16000), + target_audio_lens=(inputs["context_audio_lengths"] / model.sample_rate * 16000).to(torch.long), + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + ) + + os.makedirs(args.out_dir, exist_ok=True) + audio_f32 = audio_f32.detach().cpu() + audio_len = audio_len.cpu() + + for i in range(B): + target_path = inputs["target_audio_paths"][i] + base_name = os.path.basename(target_path) + stem, ext = os.path.splitext(base_name) + if not ext: + ext = ".wav" + + if inputs["profile_multiturn"]: + wav = audio_f32[i, : audio_len[i]].numpy() + out_path = os.path.join(args.out_dir, base_name) + sf.write(out_path, wav, samplerate=model.output_sample_rate) + logging.info(f"Full Audio Saved: {out_path}") + + full_wav = audio_f32[i].numpy() + full_len = int(audio_len[i].item()) + + samples_per_prediction_frame = ( + model.codec_model_samples_per_frame + / (model.sample_rate / model.output_sample_rate) + ) + + decode_start_i = int(profile_decode_start_frame[i].detach().cpu().item()) + if decode_start_i < 0: + decode_start_i = 0 + + for turn_id, start_frame, end_frame in profile_turn_frame_ranges[i]: + rel_start_frame = start_frame - decode_start_i + rel_end_frame = end_frame - decode_start_i + + start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) + end_sample = int(round(rel_end_frame * samples_per_prediction_frame)) + + start_sample = max(0, min(start_sample, full_len)) + end_sample = max(start_sample, min(end_sample, full_len)) + + turn_wav = full_wav[start_sample:end_sample] + + out_path = os.path.join(args.out_dir, f"{stem}_turn_{turn_id}{ext}") + sf.write(out_path, turn_wav, samplerate=model.output_sample_rate) + logging.info(f"Saved: {out_path}") + + else: + wav = audio_f32[i, : audio_len[i]].numpy() + out_path = os.path.join(args.out_dir, base_name) + sf.write(out_path, wav, samplerate=model.output_sample_rate) + logging.info(f"Saved: {out_path}") + + with fp32_precision(): + logging.info("\n--- Evaluation Metrics ---") + cer_wer = intelligibility.compute() + for k, m in cer_wer.items(): + logging.info(f"Intelligibility - {k}: {m}") + + secs_scores = secs_metric.compute() + for k, m in secs_scores.items(): + logging.info(f"SECS - {k}: {m}") + + +if __name__ == "__main__": + main() diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 55d79f40e32a..10deeb48035e 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -1498,7 +1498,7 @@ def streaming_init( dropout_conditional_input=False, ) ) - + # Store full context embedding and lens before any CFG manipulation full_context_embedding = context_embedding.clone() # (B, T_max, E) full_context_lens = context_lens.clone() # (B,) @@ -1833,6 +1833,630 @@ def _prepare_streaming_input( return next_input, needs_context, needs_phoneme, needs_audio + def streaming_step_profiled( + self, + state: StreamingState, + text_tokens: Optional[torch.Tensor] = None, # (B,) + profile_mask: Optional[torch.Tensor] = None, # (B,) + profile_text_tokens: Optional[torch.Tensor] = None, # (B,) + profile_end_mask: Optional[torch.Tensor] = None, # (B,) + active_mask: Optional[torch.Tensor] = None, # (B,) + force_dropout_text: bool = False, + use_inference_mode: bool = True, + ) -> Tuple[StreamingState, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + One streaming step with per-row profiling support. + + profile_mask[b] == True: + row b runs a profiling step instead of a normal generation step. + The model receives profile_text_tokens[b] on text channel and + user-speaking/silence on audio channel. + The decoded output frame for that row is forced silence. + + active_mask[b] == False: + row b is idle this step. Its counters are not advanced. + """ + grad_ctx = torch.inference_mode if use_inference_mode else torch.no_grad + + with grad_ctx(): + device = state.config.device + B = state.config.batch_size + + if profile_mask is None: + profile_mask = torch.zeros(B, dtype=torch.bool, device=device) + else: + profile_mask = profile_mask.to(device=device, dtype=torch.bool) + + if profile_end_mask is None: + profile_end_mask = torch.zeros(B, dtype=torch.bool, device=device) + else: + profile_end_mask = profile_end_mask.to(device=device, dtype=torch.bool) + + if active_mask is None: + # Default: normal active rows are unfinished; profile rows are active. + active_mask = (~state.finished) | profile_mask + else: + active_mask = active_mask.to(device=device, dtype=torch.bool) + + profile_mask = profile_mask & active_mask + + if text_tokens is None: + text_tokens = torch.full( + (B,), + self.eos_id, + dtype=torch.long, + device=device, + ) + else: + text_tokens = text_tokens.to(device=device, dtype=torch.long) + + if profile_text_tokens is None: + profile_text_tokens = torch.full( + (B,), + self.pad_id, + dtype=torch.long, + device=device, + ) + else: + profile_text_tokens = profile_text_tokens.to(device=device, dtype=torch.long) + + ( + next_input, + needs_context, + needs_phoneme, + needs_audio, + effective_profile_mask, + profile_silence_unstacked, + profile_last_audio_codes, + ) = self._prepare_streaming_input_profiled( + state=state, + text_tokens=text_tokens, + profile_mask=profile_mask, + profile_text_tokens=profile_text_tokens, + profile_end_mask=profile_end_mask, + active_mask=active_mask, + force_dropout_text=force_dropout_text, + ) + + cache_position = torch.tensor([state.cache_seq_len], device=device) + + transformer_out = self.forward( + inputs_embeds=next_input, + attention_mask=None, + use_cache=True, + past_key_values=state.past_key_values, + cache_position=cache_position, + ) + + state.last_hidden = transformer_out.last_hidden_state + state.past_key_values = transformer_out.past_key_values + state.cache_seq_len += 1 + + audio_codes_next, pred_phoneme_tokens = self._process_predictions_profiled( + state=state, + needs_context=needs_context, + needs_phoneme=needs_phoneme, + needs_audio=needs_audio, + profile_mask=effective_profile_mask, + active_mask=active_mask, + profile_silence_unstacked=profile_silence_unstacked, + profile_last_audio_codes=profile_last_audio_codes, + ) + + return state, audio_codes_next, pred_phoneme_tokens + + def _embed_one_text_step( + self, + tokens: torch.Tensor, # (B,) + force_dropout_text: bool = False, + ) -> torch.Tensor: + """ + Embed one text step. Returns (B, 1, E). + """ + device = tokens.device + tokens_2d = tokens.unsqueeze(1) + + if self.cfg.get("disable_subword_embedding", False): + text_embedded = torch.zeros( + tokens_2d.size(0), + 1, + self.cfg.embedding_dim, + dtype=next(self.parameters()).dtype, + device=device, + ) + else: + text_embedded = self.decoder.get_input_embeddings()(tokens_2d) + + is_pad = tokens_2d == self.pad_id + + if self.use_bpe_char_tokenizer: + if self.cfg.get("use_multiturn_dataset", False): + text_mask = ~is_pad + else: + text_mask = torch.ones_like(tokens_2d, dtype=torch.bool) + + text_embedded = text_embedded + self.cas_encoder( + tokens_2d, + subword_mask=text_mask, + ) + + if force_dropout_text: + text_embedded = text_embedded * 0.0 + + if self.cfg.get("use_multiturn_dataset", False): + text_embedded[is_pad] = 0.0 + + return text_embedded + + def _prepare_streaming_input_profiled( + self, + state: StreamingState, + text_tokens: torch.Tensor, # (B,) + profile_mask: torch.Tensor, # (B,) + profile_text_tokens: torch.Tensor, # (B,) + profile_end_mask: torch.Tensor, # (B,) + active_mask: torch.Tensor, # (B,) + force_dropout_text: bool, + ): + device = state.config.device + B = state.config.batch_size + dtype = next(self.parameters()).dtype + + streaming_speech_delay = state.config.training_mode.streaming_speech_delay + streaming_phonemes_delay = state.config.training_mode.streaming_phonemes_delay + + needs_context = active_mask & (state.context_position < state.full_context_lens) + + # Profiling only applies after context is consumed. + effective_profile_mask = profile_mask & (~needs_context) + + normal_active = active_mask & (~effective_profile_mask) + + needs_text = ( + normal_active + & (~needs_context) + & (~state.text_finished) + ) + + needs_phoneme = ( + normal_active + & (~needs_context) + & (state.text_tokens_seen >= streaming_phonemes_delay) + & (~state.phoneme_stream_ended) + ) + + needs_audio = ( + normal_active + & (~needs_context) + & (state.text_tokens_seen >= streaming_speech_delay) + & (~state.finished) + ) + + next_input = torch.zeros( + B, + 1, + self.cfg.embedding_dim, + dtype=dtype, + device=device, + ) + + # ----------------------- + # Context rows + # ----------------------- + if needs_context.any(): + ctx_positions = state.context_position.clamp( + max=state.full_context_embedding.size(1) - 1 + ) + + ctx_emb = state.full_context_embedding[ + torch.arange(B, device=device), + ctx_positions, + :, + ].unsqueeze(1) + + next_input = next_input + ctx_emb * needs_context.view(B, 1, 1).to(dtype) + + # ----------------------- + # Normal text rows + # ----------------------- + if needs_text.any(): + text_emb = self._embed_one_text_step( + text_tokens, + force_dropout_text=force_dropout_text, + ) + + next_input = next_input + text_emb * needs_text.view(B, 1, 1).to(dtype) + + is_eos_token = (text_tokens == self.eos_id) & needs_text + state.text_finished = state.text_finished | is_eos_token + + # ----------------------- + # Profile rows: text channel + # ----------------------- + profile_text_emb = None + if effective_profile_mask.any(): + profile_text_emb = self._embed_one_text_step( + profile_text_tokens, + force_dropout_text=force_dropout_text, + ) + + # ----------------------- + # Phoneme rows, normal only + # ----------------------- + if self.phoneme_tokenizer is not None and needs_phoneme.any(): + phoneme_emb = torch.zeros( + B, + 1, + self.cfg.embedding_dim, + dtype=dtype, + device=device, + ) + + if state.config.phoneme_input_type == "gt" and state.gt_phoneme_embeddings is not None: + within_gt_len = state.phoneme_steps < state.gt_phoneme_lens + positions = state.phoneme_steps.clamp( + max=state.gt_phoneme_embeddings.size(1) - 1 + ) + + gt_emb = state.gt_phoneme_embeddings[ + torch.arange(B, device=device), + positions, + :, + ].unsqueeze(1) + + phoneme_mask = (needs_phoneme & within_gt_len).view(B, 1, 1).to(dtype) + phoneme_emb = phoneme_emb + gt_emb * phoneme_mask + + else: + first_phoneme_step = needs_phoneme & (state.phoneme_steps == 0) + has_last_phoneme = ( + needs_phoneme + & (~first_phoneme_step) + & (state.last_phoneme_tokens is not None) + ) + + if first_phoneme_step.any(): + phoneme_bos = torch.full( + (B, self.phoneme_stacking_factor, 1), + self.phoneme_tokenizer.bos_token_id, + device=device, + dtype=torch.long, + ) + phoneme_bos_emb = self.embed_phoneme_tokens(phoneme_bos) + phoneme_emb = phoneme_emb + phoneme_bos_emb * first_phoneme_step.view(B, 1, 1).to(dtype) + + if has_last_phoneme.any() and state.last_phoneme_tokens is not None: + last_phoneme_emb = self.embed_phoneme_tokens( + state.last_phoneme_tokens.unsqueeze(2) + ) + phoneme_emb = phoneme_emb + last_phoneme_emb * has_last_phoneme.view(B, 1, 1).to(dtype) + + state.phoneme_stream_ended = ( + state.phoneme_stream_ended | state.phoneme_eos_detected + ) + + next_input = next_input + phoneme_emb + + # ----------------------- + # Normal audio rows + # ----------------------- + audio_emb = torch.zeros( + B, + 1, + self.cfg.embedding_dim, + dtype=dtype, + device=device, + ) + + if needs_audio.any(): + if state.gt_audio_embeddings is not None: + within_gt_len = state.audio_steps < state.gt_audio_lens + positions = state.audio_steps.clamp( + max=state.gt_audio_embeddings.size(1) - 1 + ) + + gt_emb = state.gt_audio_embeddings[ + torch.arange(B, device=device), + positions, + :, + ].unsqueeze(1) + + audio_mask = (needs_audio & within_gt_len).view(B, 1, 1).to(dtype) + audio_emb = audio_emb + gt_emb * audio_mask + + else: + first_audio_step = needs_audio & (state.audio_steps == 0) + has_last_audio = ( + needs_audio + & (~first_audio_step) + & (state.last_audio_codes is not None) + ) + + if first_audio_step.any(): + audio_bos = torch.full( + (B, self.num_audio_codebooks * self.frame_stacking_factor, 1), + self.audio_bos_id, + device=device, + dtype=torch.long, + ) + audio_bos_emb = self.embed_audio_tokens(audio_bos) + audio_emb = audio_emb + audio_bos_emb * first_audio_step.view(B, 1, 1).to(dtype) + + if has_last_audio.any() and state.last_audio_codes is not None: + last_audio_emb = self.embed_audio_tokens( + state.last_audio_codes.unsqueeze(2) + ) + audio_emb = audio_emb + last_audio_emb * has_last_audio.view(B, 1, 1).to(dtype) + + next_input = next_input + audio_emb + + # ----------------------- + # Profile audio input + # ----------------------- + C = self.num_audio_codebooks + S = self.frame_stacking_factor + + sil_codes = self.codec_sil_codes.to(device=device, dtype=torch.long) + + profile_silence_unstacked = ( + sil_codes.view(1, C, 1) + .expand(B, C, S) + .contiguous() + ) + + if self.cfg.get("use_user_speaking_token", False): + profile_audio_stacked = torch.full( + (B, C * S, 1), + self.audio_user_speaking_id, + dtype=torch.long, + device=device, + ) + else: + profile_audio_stacked, _ = self.stack_codes( + profile_silence_unstacked, + torch.full((B,), S, dtype=torch.long, device=device), + bos_id=self.audio_bos_id, + eos_id=self.audio_eos_id, + stacking_factor=S, + num_codebooks=C, + ) + + profile_last_audio_codes = profile_audio_stacked[:, :, -1].contiguous() + + if self.cfg.get("use_user_speaking_end_token", False): + profile_end_codes = torch.full( + (B, C * S), + self.audio_user_speaking_end_id, + dtype=torch.long, + device=device, + ) + profile_last_audio_codes = torch.where( + profile_end_mask.view(B, 1), + profile_end_codes, + profile_last_audio_codes, + ) + + profile_audio_emb = self.embed_audio_tokens(profile_audio_stacked) + + if effective_profile_mask.any(): + profile_emb = profile_audio_emb + + if profile_text_emb is not None: + profile_emb = profile_emb + profile_text_emb + + next_input = ( + next_input * (~effective_profile_mask).view(B, 1, 1).to(dtype) + + profile_emb * effective_profile_mask.view(B, 1, 1).to(dtype) + ) + + # ----------------------- + # CFG branch + # ----------------------- + if state.config.use_cfg: + next_input_uncond = torch.zeros_like(next_input) + + if needs_context.any(): + ctx_uncond = state.config.dummy_context_embedding_unconditional.expand( + B, + 1, + -1, + ) + next_input_uncond = next_input_uncond + ctx_uncond * needs_context.view(B, 1, 1).to(dtype) + + if needs_audio.any(): + next_input_uncond = next_input_uncond + audio_emb * needs_audio.view(B, 1, 1).to(dtype) + + if effective_profile_mask.any(): + # Match your streaming_prefill_profile behavior: + # conditional = profile text + profile audio + # unconditional = profile audio only + next_input_uncond = ( + next_input_uncond * (~effective_profile_mask).view(B, 1, 1).to(dtype) + + profile_audio_emb * effective_profile_mask.view(B, 1, 1).to(dtype) + ) + + next_input = torch.cat([next_input, next_input_uncond], dim=0) + + return ( + next_input, + needs_context, + needs_phoneme, + needs_audio, + effective_profile_mask, + profile_silence_unstacked, + profile_last_audio_codes, + ) + + def _process_predictions_profiled( + self, + state: StreamingState, + needs_context: torch.Tensor, + needs_phoneme: torch.Tensor, + needs_audio: torch.Tensor, + profile_mask: torch.Tensor, + active_mask: torch.Tensor, + profile_silence_unstacked: torch.Tensor, # (B, C, S) + profile_last_audio_codes: torch.Tensor, # (B, C*S) + ): + B = state.config.batch_size + device = state.config.device + C = self.num_audio_codebooks + S = self.frame_stacking_factor + + # Context always advances only for active context rows. + state.context_position = state.context_position + needs_context.long() + + # Logical text stream advances only for rows participating this step. + # This avoids idle finished rows drifting while other batch rows keep running. + logical_active = active_mask & (~needs_context) + state.text_tokens_seen = state.text_tokens_seen + logical_active.long() + + # Profile behaves like your old streaming_prefill_profile: + # it consumes an audio-step-like profile input, but does not sample audio. + state.audio_steps = state.audio_steps + needs_audio.long() + profile_mask.long() + + state.phoneme_steps = state.phoneme_steps + needs_phoneme.long() + + pred_phoneme_tokens = None + audio_codes_next = None + + # ----------------------- + # Phoneme prediction, normal rows only + # ----------------------- + if needs_phoneme.any() and self.phoneme_tokenizer is not None: + first_phoneme_step = needs_phoneme & (state.phoneme_prediction_start_idx == -1) + + if first_phoneme_step.any(): + current_phoneme_step_idx = len(state.all_phoneme_predictions) + state.phoneme_prediction_start_idx = torch.where( + first_phoneme_step, + torch.full_like(state.phoneme_prediction_start_idx, current_phoneme_step_idx), + state.phoneme_prediction_start_idx, + ) + + pred_phoneme_tokens = self._predict_phoneme_tokens(state) + state.last_phoneme_tokens = pred_phoneme_tokens + state.all_phoneme_predictions.append(pred_phoneme_tokens) + + phoneme_eos_detected = ( + needs_phoneme + & (pred_phoneme_tokens == self.phoneme_tokenizer.eos_token_id).any(dim=1) + ) + + state.phoneme_eos_detected = state.phoneme_eos_detected | phoneme_eos_detected + + newly_ended_phoneme = ( + phoneme_eos_detected + & (state.phoneme_prediction_end_idx == -1) + ) + + if newly_ended_phoneme.any(): + current_phoneme_step_idx = len(state.all_phoneme_predictions) + state.phoneme_prediction_end_idx = torch.where( + newly_ended_phoneme, + torch.full_like(state.phoneme_prediction_end_idx, current_phoneme_step_idx), + state.phoneme_prediction_end_idx, + ) + + # ----------------------- + # Audio/profile output + # ----------------------- + if needs_audio.any() or profile_mask.any(): + mixed_unstacked = profile_silence_unstacked.clone() + mixed_last_codes = profile_last_audio_codes.clone() + + sampled_stacked = None + sampled_argmax = None + + if needs_audio.any(): + first_audio_step = needs_audio & (state.audio_prediction_start_idx == -1) + + if first_audio_step.any(): + current_frame_idx = sum(p.size(-1) for p in state.all_predictions) + state.audio_prediction_start_idx = torch.where( + first_audio_step, + torch.full_like(state.audio_prediction_start_idx, current_frame_idx), + state.audio_prediction_start_idx, + ) + + sampled_stacked, sampled_argmax = self._predict_audio_codes(state) + sampled_unstacked = sampled_stacked.view(B, C, S) + + mixed_unstacked = torch.where( + needs_audio.view(B, 1, 1), + sampled_unstacked, + mixed_unstacked, + ) + + mixed_last_codes = torch.where( + needs_audio.view(B, 1), + sampled_stacked, + mixed_last_codes, + ) + + # Update last audio input only for rows that actually had audio/profile activity. + update_last = needs_audio | profile_mask + + if state.last_audio_codes is None: + state.last_audio_codes = torch.full( + (B, C * S), + self.audio_bos_id, + dtype=torch.long, + device=device, + ) + + state.last_audio_codes = torch.where( + update_last.view(B, 1), + mixed_last_codes, + state.last_audio_codes, + ) + + # EOS detection only for sampled normal audio rows, never profile rows. + if needs_audio.any() and state.gt_audio_embeddings is None: + sampled_argmax_unstacked = sampled_argmax.view(B, C, S) + sampled_unstacked = sampled_stacked.view(B, C, S) + + eos_in_sampled = sampled_unstacked == self.audio_eos_id + eos_in_argmax = sampled_argmax_unstacked == self.audio_eos_id + + eos_any_codebook = ( + eos_in_sampled.any(dim=1) + | eos_in_argmax.any(dim=1) + ) # (B, S) + + eos_frame_idx = torch.where( + eos_any_codebook.any(dim=1), + eos_any_codebook.int().argmax(dim=1), + torch.full((B,), S, device=device), + ) + + audio_eos_detected = eos_any_codebook.any(dim=1) & needs_audio + state.finished = state.finished | audio_eos_detected + + newly_ended_audio = ( + audio_eos_detected + & (state.audio_prediction_end_idx == -1) + ) + + if newly_ended_audio.any(): + current_frame_count = len(state.all_predictions) * S + end_frame_idx = current_frame_count + eos_frame_idx + + state.audio_prediction_end_idx = torch.where( + newly_ended_audio, + end_frame_idx, + state.audio_prediction_end_idx, + ) + + state.all_predictions.append(mixed_unstacked) + audio_codes_next = mixed_unstacked + + if state.gt_audio_embeddings is not None and state.gt_audio_lens is not None: + gt_exhausted = needs_audio & (state.audio_steps >= state.gt_audio_lens) + state.finished = state.finished | gt_exhausted + + return audio_codes_next, pred_phoneme_tokens + def _process_predictions( self, state: StreamingState, From d2cd65bafabec787779198f75a05593f3ca5055e Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 12 May 2026 12:02:24 -0700 Subject: [PATCH 040/109] Fix merge issue Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 108 ++++++++++++++++++ .../tts/models/easy_magpietts_inference.py | 2 +- 2 files changed, 109 insertions(+), 1 deletion(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 751cc87fee7f..4527800d514a 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -816,6 +816,114 @@ def cut_to_conversation( custom=cut.custom, ) + +@data_type_parser(["s2s_duplex_overlap_as_s2s_duplex"]) +def read_s2s_duplex_overlap_as_s2s_duplex(config) -> Tuple[CutSet, bool]: + """ + Convert a CutSet with overlapping agent/user segments into a standard S2S duplex format. + + Use Case: + This parser is designed for conversational data where agent and user speech can overlap + in time (e.g., natural turn-taking with interruptions or backchanneling). The input + format stores agent and user segments separately as `agent_segments` and `user_segments` + attributes on each cut. This function converts them into a unified timeline of sequential + SupervisionSegments, which is the standard format expected by DuplexS2S models. + + Expected Input Data Format: + Each cut should have: + - cut.agent_segments: List[Dict] with keys: + - "start" (float): Start time in seconds + - "end" (float): End time in seconds + - "text" (str): Agent's transcription + - cut.user_segments: List[Dict] with keys: + - "start" (float): Start time in seconds + - "end" (float): End time in seconds + - "text" (str): User's transcription + + Example: + Input cut with overlapping segments: + cut.agent_segments = [ + {"start": 0.5, "end": 2.0, "text": "Hello, how can I help?"}, + {"start": 3.0, "end": 4.5, "text": "Sure, I can do that."} + ] + cut.user_segments = [ + {"start": 1.8, "end": 3.2, "text": "I need assistance"}, + {"start": 4.0, "end": 5.5, "text": "Thank you"} + ] + + Output cut.supervisions (sorted by start time): + [ + SupervisionSegment(start=0.5, duration=1.5, text="Hello, how can I help?", speaker="agent"), + SupervisionSegment(start=1.8, duration=1.4, text="I need assistance", speaker="user"), + SupervisionSegment(start=3.0, duration=1.5, text="Sure, I can do that.", speaker="agent"), + SupervisionSegment(start=4.0, duration=1.5, "Thank you", speaker="user") + ] + + Args: + config: Dictionary containing parser options: + - move_agent_text_back_by (float): Time offset to shift agent text back (default: 0). + Useful for aligning agent text with earlier audio timing. + - filter_samples_starting_with_agent (bool): Whether to remove samples starting with agent (default: False). + When True, only keeps samples where the first speaker is a user. + - agent_roles (List[str]): Roles considered as agent (default: ["agent", "Assistant", "assistant"]). + + Returns: + Tuple[CutSet, bool]: Converted cuts with unified supervisions, and a flag indicating if the data was tarred. + """ + move_agent_text_back_by = config.get("move_agent_text_back_by", 0) + filter_samples_starting_with_agent = config.get("filter_samples_starting_with_agent", False) + agent_roles = config.get("agent_roles", ["agent", "Assistant", "assistant"]) + + cuts, is_tarred = read_cutset_from_config(config) + + def filter_cuts_starting_with_agent_fn(cuts: CutSet, agent_roles: Tuple[str, ...]) -> CutSet: + """Remove cuts where the first supervision belongs to an agent role.""" + + def _filter_fn(cut: Cut) -> bool: + if not cut.supervisions: + return False + cut.supervisions = sorted(cut.supervisions, key=lambda s: s.start) + return cut.supervisions[0].speaker not in agent_roles + + return cuts.filter(_filter_fn) + + def convert_overlap_cut_fn(cut: Cut) -> Cut: + """Convert agent/user overlapping segments into sequential SupervisionSegments.""" + agent_segments = [ + SupervisionSegment( + id=cut.id, + recording_id=cut.id, + start=seg["start"] - move_agent_text_back_by, + duration=seg["end"] - seg["start"] + move_agent_text_back_by, + text=seg["text"], + speaker="agent", + ) + for seg in cut.agent_segments + ] + + user_segments = [ + SupervisionSegment( + id=cut.id, + recording_id=cut.id, + start=seg["start"], + duration=seg["end"] - seg["start"], + text=seg["text"], + speaker="user", + ) + for seg in cut.user_segments + ] + + cut.supervisions = sorted(agent_segments + user_segments, key=lambda s: s.start) + cut.task = "s2s_duplex_overlap_as_s2s_duplex" + return cut + + cuts = cuts.map(convert_overlap_cut_fn) + if filter_samples_starting_with_agent: + cuts = filter_cuts_starting_with_agent_fn(cuts, tuple(agent_roles)) + + return cuts, is_tarred + + @data_type_parser(["lhotse_magpietts_data_as_continuation"]) def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: """ diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 10deeb48035e..2f62632284a0 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -279,7 +279,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.disable_subword_embedding = cfg.get('disable_subword_embedding', False) # If True, remove the decoder LM head over text tokens to save parameters when the model does not train or # infer text-token logits from the decoder output. - self.disable_lm_text_head = cfg.get('disable_lm_text_head', False) + self.disable_lm_text_head = cfg.get('disable_lm_text_head', True) # Legacy checkpoints may have trained context text with decoder embeddings only, even when CAS is enabled for # regular text tokens. This flag skips adding CAS embeddings for context text to match those checkpoints. self.disable_cas_for_context_text = cfg.get('disable_cas_for_context_text', False) From cd955e10eac33f26c415ca65e2d45a36d0e9af6e Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 13 May 2026 06:27:06 -0700 Subject: [PATCH 041/109] Add restore custom checkpoint to avoid full model loading on .nemo checkpoint restore Signed-off-by: Edresson Casanova --- examples/tts/easy_magpietts.py | 4 +++ .../tts/models/easy_magpietts_inference.py | 26 ++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/examples/tts/easy_magpietts.py b/examples/tts/easy_magpietts.py index 5e9be71a7805..26872c8edb87 100644 --- a/examples/tts/easy_magpietts.py +++ b/examples/tts/easy_magpietts.py @@ -54,9 +54,13 @@ def main(cfg): model = EasyMagpieTTSModel(cfg=cfg.model, trainer=trainer) else: raise NotImplementedError(f"Only train, onlinepo_train and test modes are supported. Got {mode}") + + if cfg.get("pretrained_model", None): + model.restore_from_pretrained_checkpoint(cfg.pretrained_model) model.maybe_init_from_pretrained_checkpoint(cfg=cfg) + if mode in ['train', 'onlinepo_train']: trainer.fit(model) elif mode == 'test': diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 2f62632284a0..0beec43538cc 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -15,6 +15,7 @@ import time import random import time +import tempfile from dataclasses import dataclass, fields from functools import partial from typing import Any, Dict, List, Optional, Sequence, Tuple @@ -28,6 +29,7 @@ from torch import nn from transformers import AutoConfig, AutoModelForCausalLM +from nemo.core.connectors.save_restore_connector import SaveRestoreConnector from nemo.collections.tts.data.text_to_speech_dataset_lhotse import setup_tokenizers from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules import transformer_2501 @@ -622,6 +624,29 @@ def codec_sil_codes(self): def codec_sil_codes_unconverted(self): return self._codec_sil_codes_buffer_unconverted + def restore_from_pretrained_checkpoint(self, checkpoint_path): + """ + Loads model weights a pretrained checkpoint file, supporting partial loading from safetensor and PyTorch formats. + + Args: + checkpoint_path (str): Path to checkpoint file. + + Returns: + None. The model is updated in-place. + """ + if checkpoint_path is not None: + if '.nemo' in checkpoint_path: + with tempfile.TemporaryDirectory() as tmpdir: + SaveRestoreConnector._unpack_nemo_file(checkpoint_path, tmpdir) + checkpoint_path = f"{tmpdir}/model_weights.ckpt" + checkpoint_state = torch.load(checkpoint_path, map_location='cpu') + else: + checkpoint_state = torch.load(checkpoint_path, map_location='cpu') + checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, self.state_dict()) + + self.load_state_dict(checkpoint_state, strict=True) + logging.info(f"Model restored from the checkpoint: {checkpoint_path} !") + def _generate_codec_silence_buffer(self): codec_device = next(self._codec_model.parameters()).device @@ -803,7 +828,6 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): return state_dict def load_state_dict(self, state_dict, strict=True): - state_dict = set_model_dict_for_partial_init(state_dict, self.state_dict()) if not strict: super().load_state_dict(state_dict, strict=False) modules_to_skip = self._get_state_dict_keys_to_exclude() From bd764f8f97fd648dcc695f8980f1f086ac9eab7c Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 14 May 2026 04:44:37 -0700 Subject: [PATCH 042/109] Add raw tts data support on TTS dataloader Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 5 ++-- ...text_to_speech_dataset_lhotse_multiturn.py | 26 ++++++++++++------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 4527800d514a..5dc3e944cfcb 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -1001,6 +1001,7 @@ def prepend_silence_monocut( def convert_cut_fn(cut: Cut) -> Cut: """Convert a single cut into the continuation format.""" + orig_agent_sup = fastcopy(cut.supervisions[0]) target_audio_orig_dur = cut.target_audio.duration @@ -1081,9 +1082,9 @@ def convert_cut_fn(cut: Cut) -> Cut: if cut.has_custom("target_codes"): cut_source.target_codes = cut.target_codes if cut.has_custom("lang"): - cut_source.lang = cut_source.lang + cut_source.lang = cut.lang if cut.has_custom("ipa"): - cut_source.ipa = cut_source.ipa + cut_source.ipa = cut.ipa cut_source.formatter = "lhotse_magpietts_data_as_continuation" return cut_source diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 9d66f086fb5c..05d6f41eddb3 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -14,7 +14,6 @@ import random import re from typing import Dict, List, Union -from copy import deepcopy import numpy as np import torch @@ -29,7 +28,6 @@ from transformers import AutoTokenizer, T5Tokenizer from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer, IPABPETokenizer -from nemo.collections.speechlm2.data.utils import get_pad_id from nemo.collections.speechlm2.parts.precision import fp32_precision from nemo.collections.tts.parts.utils.tts_dataset_utils import ( beta_binomial_prior_distribution, @@ -176,7 +174,7 @@ def __init__( add_text_bos: bool = False, remove_user_turns_prob: float = None, ): - # super().__init__() + super().__init__() self.sample_rate = sample_rate self.volume_norm = volume_norm @@ -235,6 +233,10 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: self.phoneme_tokenizer = instantiate(self.phoneme_tokenizer_config) cuts = cuts.transform_text(_strip_timestamps) + for cut in cuts: + if str(getattr(cut, "task", "tts")).lower() == "tts": + for supervision in cut.supervisions: + supervision.speaker = "agent" batch_tokenizer_names = [] remove_user_turn_flags = [] @@ -248,7 +250,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: agent_sups = [sup for sup in cut.supervisions if sup.speaker in self.output_roles] # It is a multiturn if there's more than 1 agent turn - is_multiturn = not (len(agent_sups) == 1) + is_multiturn = not (len(agent_sups) <= 1) # Apply augmentation only if it's multiturn AND passes the probability check if is_multiturn and self.remove_user_turns_prob and random.random() < self.remove_user_turns_prob: @@ -268,17 +270,23 @@ def _align_codebooks(t): target_audio, target_audio_lens = collate_audio( cuts.resample(self.sample_rate, recording_field="target_audio"), recording_field="target_audio" ) - source_audio, source_audio_lens = collate_audio(cuts.resample(self.source_sample_rate)) target_audio_list = [] source_audio_list = [] # normalize volume and apply audio the removal of user turn if needed for i, cut in enumerate(cuts): remove_user_turn_this_cut = remove_user_turn_flags[i] - + # Extract the raw, unpadded 1D numpy array for this specific cut t_audio = target_audio[i, :target_audio_lens[i]].numpy() - s_audio = source_audio[i, :source_audio_lens[i]].numpy() + # For tts/single-turn cuts, source audio should be zeros like target audio + # For multi-turn cuts, keep loading source audio from the cut recording. + if str(getattr(cut, "task", "tts")).lower() == "tts": + src_len = int(round(len(t_audio) / self.sample_rate * self.source_sample_rate)) + s_audio = np.zeros(src_len, dtype=t_audio.dtype) + else: + s_audio = cut.resample(self.source_sample_rate).load_audio().squeeze(0) + if remove_user_turn_this_cut: collapsed_t, collapsed_s = [], [] for sup in cut.supervisions: @@ -601,7 +609,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) ) batch_dict["agent_mask"] = agent_mask - batch_dict["agent_mask_lens"] = agent_mask + batch_dict["agent_mask_lens"] = agent_mask_lens return batch_dict @@ -858,4 +866,4 @@ def build_phoneme_channel( phoneme_ids = phoneme_ids[:len(tokens) - pos] tokens[pos:pos+len(phoneme_ids)] = phoneme_ids - return tokens + return tokens \ No newline at end of file From 9a1c0d3b1d2ff2b46f5bbfe46706b8ecaa03c8dc Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 14 May 2026 11:35:22 -0700 Subject: [PATCH 043/109] Remove complex prefil code and add slupport to nemotron_h on prefill Signed-off-by: Edresson Casanova --- ...gpietts_inference_multiturn_new_prefill.py | 1186 ----------------- .../tts/models/easy_magpietts_inference.py | 650 +-------- 2 files changed, 37 insertions(+), 1799 deletions(-) delete mode 100644 examples/tts/easy_magpietts_inference_multiturn_new_prefill.py diff --git a/examples/tts/easy_magpietts_inference_multiturn_new_prefill.py b/examples/tts/easy_magpietts_inference_multiturn_new_prefill.py deleted file mode 100644 index 30aedefb89ae..000000000000 --- a/examples/tts/easy_magpietts_inference_multiturn_new_prefill.py +++ /dev/null @@ -1,1186 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -""" -Evaluation script for custom EasyMagpieTTS models. -Features explicit Duplex (10x Padding) and Regular (Turn-by-turn) multi-turn modes. - -Usage: - python easy_magpietts_eval.py \ - --checkpoint_path=/path/to/magpie/model.ckpt \ - --codec_model_path=/path/to/codec/model.ckpt \ - --datasets_json_path=/path/to/evalset_config.jsonl \ - --out_dir=/path/to/out/audio \ - --batch_size=6 \ - --use_cfg \ - --use_librosa -""" - -import argparse -import json -import os -from copy import deepcopy -from functools import partial - -import librosa -import soundfile as sf -import torch -from torch.nn.utils.rnn import pad_sequence -from torch.utils.data import DataLoader, Dataset -from omegaconf import OmegaConf, open_dict - -from nemo.collections.audio.parts.utils.transforms import resample -from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility -from nemo.collections.speechlm2.parts.metrics.secs import SECS -from nemo.collections.speechlm2.parts.precision import fp32_precision -from nemo.utils import logging - -# --- EasyMagpieTTS Imports --- -from nemo.collections.tts.models import AudioCodecModel -from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter -from nemo.collections.tts.modules.magpietts_modules import CodecHelper -from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel - -from nemo.collections.tts.parts.utils.tts_dataset_utils import normalize_volume - -torch.set_float32_matmul_precision("medium") -torch.backends.cudnn.allow_tf32 = True -torch.backends.cuda.matmul.allow_tf32 = True - -if torch.cuda.is_available(): - torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) - - -def attach_dtype_counter(model): - handles = [] - stats = {} - examples = {} - - def is_leaf(module): - return len(list(module.children())) == 0 - - def get_dtype(x): - if torch.is_tensor(x): - return str(x.dtype) - elif isinstance(x, (list, tuple)): - for t in x: - if torch.is_tensor(t): - return str(t.dtype) - return "other" - - def get_module_group(name): - return name.split(".")[0] if "." in name else name - - def hook_fn(name): - def fn(module, inputs, outputs): - dtype = get_dtype(outputs) - if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: - dtype = "other" - group = get_module_group(name) - if group not in stats: - stats[group] = { - "torch.float16": 0, "torch.bfloat16": 0, "torch.float32": 0, "other": 0, - } - examples[group] = { - "torch.float16": [], "torch.bfloat16": [], "torch.float32": [], "other": [], - } - stats[group][dtype] += 1 - if len(examples[group][dtype]) < 3: - examples[group][dtype].append(module.__class__.__name__) - return fn - - for name, module in model.named_modules(): - if is_leaf(module): - handles.append(module.register_forward_hook(hook_fn(name))) - return handles, stats, examples - - -def report_dtype_stats(handles, stats, examples): - for h in handles: - h.remove() - logging.info("\n=== DTYPE USAGE PER MODULE ===") - for group, group_stats in stats.items(): - total = sum(group_stats.values()) - if total == 0: continue - logging.info(f"\n--- {group} ---") - for dtype, count in group_stats.items(): - if count > 0: - logging.info(f"{dtype}: {count} ({100*count/total:.2f}%)") - logging.info("\n=== EXAMPLES ===") - for group, group_examples in examples.items(): - logging.info(f"\n--- {group} ---") - for dtype, mods in group_examples.items(): - if mods: - logging.info(f"{dtype}: {mods}") - - -class EvalJSONLDataset(Dataset): - def __init__(self, file_path, num_turns=1): - self.samples = [] - raw_samples = [] - with open(file_path, "r", encoding="utf-8") as f: - for line_idx, line in enumerate(f, 1): - line = line.strip() - if not line: continue - try: - raw_samples.append(json.loads(line)) - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON on line {line_idx}: {e}") - - if num_turns <= 1: - self.samples = raw_samples - return - - single_turn_by_speaker = {} - for sample in raw_samples: - if isinstance(sample["text"], list): - self.samples.append(sample) - else: - speaker = sample.get("speaker", "unknown") - if speaker not in single_turn_by_speaker: - single_turn_by_speaker[speaker] = [] - single_turn_by_speaker[speaker].append(sample) - - for speaker, speaker_samples in single_turn_by_speaker.items(): - buffer_texts, buffer_paths = [], [] - first_sample_meta = None - - for sample in speaker_samples: - if not buffer_texts: - first_sample_meta = dict(sample) - buffer_texts.append(sample["text"]) - buffer_paths.append(sample.get("audio_filepath", "")) - - if len(buffer_texts) == num_turns: - first_sample_meta["text"] = buffer_texts - base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] - ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" - combined_name = "_".join(base_names) + ext - dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) - if dir_name: - first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) - else: - first_sample_meta["audio_filepath"] = combined_name - - self.samples.append(first_sample_meta) - buffer_texts, buffer_paths, first_sample_meta = [], [], None - - if buffer_texts: - first_sample_meta["text"] = buffer_texts - base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] - ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" - combined_name = "_".join(base_names) + ext - dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) - if dir_name: - first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) - else: - first_sample_meta["audio_filepath"] = combined_name - self.samples.append(first_sample_meta) - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - return self.samples[idx] - - -def collate_and_tokenize_custom( - batch, - model, - extra_duration_thrshould=1.3, - sample_rate=22050, - root_path=None, - emulate_duplex_inference=False, - add_interruption_token=False, - pad_factor_text_speech=10, - force_interruption=False, - normalize_context_audio_volume=True, - use_librosa=False, - profile_multiturn_inference=False, -): - main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] - - # --- MULTI-TURN MODE DECISION --- - is_profile = profile_multiturn_inference - is_duplex = emulate_duplex_inference and not is_profile - - out_dict = { - "duplex_multiturn": is_duplex, - "regular_multiturn": (not is_duplex) and (not is_profile), - "profile_multiturn": is_profile, - } - - tokenized_list = [] - batched_turns = [] - batched_turn_lens = [] - valid_turn_masks = [] - - if is_duplex: - # ------------------------------------------------------------- - # DUPLEX MODE (Continuous sequence with 10x pad injection) - # ------------------------------------------------------------- - for s in batch: - text_data = s["text"] - - if isinstance(text_data, list): - full_ids = [] - for segment in text_data: - seg_ids = model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name) + [model.eos_id] - seg_len = len(seg_ids) - pad_len = seg_len * pad_factor_text_speech - pad_ids = [model.pad_id] * pad_len - - if force_interruption: - fname = s["audio_filepath"] - no_ext = fname.split(".")[0] - sample_id = int(no_ext.split("_")[-1]) - case = sample_id % 3 - - if case == 0: - if len(seg_ids) >= 2: - seg_ids[-2] = model.interruption_token_id - seg_ids[-1] = model.pad_id - else: - pad_ids[0] = model.interruption_token_id - elif case == 1: - eos_idx = min(6, len(pad_ids) - 1) - pad_ids[eos_idx] = model.interruption_token_id - else: - eos_idx = 0 - pad_ids[eos_idx] = model.interruption_token_id - else: - if add_interruption_token: - eos_idx = int(len(pad_ids) * 0.7) - pad_ids[eos_idx] = model.interruption_token_id - - full_ids.extend(seg_ids) - full_ids.extend(pad_ids) - - tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long)) - else: - tokenized_list.append( - torch.as_tensor(model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id], dtype=torch.long) - ) - - pad_len = 25 - prefix = torch.full((pad_len,), model.pad_id, dtype=torch.long) - for i in range(len(tokenized_list)): - tokenized_list[i] = torch.cat([prefix, tokenized_list[i]]) - input_lengths = torch.tensor([len(x) for x in tokenized_list], dtype=torch.long) - - input_ids = pad_sequence(tokenized_list, batch_first=True, padding_value=model.pad_id) - - out_dict["input_ids"] = input_ids - out_dict["input_lengths"] = input_lengths - - else: - # ------------------------------------------------------------- - # REGULAR MODE (Turn-by-turn discrete packaging) - # ------------------------------------------------------------- - max_turns = 1 - for s in batch: - if isinstance(s["text"], list): - max_turns = max(max_turns, len(s["text"])) - - for t in range(max_turns): - turn_t_tokens = [] - turn_t_lens = [] - turn_t_valid = [] - - for s in batch: - text_data = s["text"] - if isinstance(text_data, list): - if t < len(text_data): - seg_ids = model.tokenizer.encode(text_data[t], tokenizer_name=main_tokenizer_name) + [model.eos_id] - turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) - turn_t_lens.append(len(seg_ids)) - turn_t_valid.append(True) - else: - # Dummy pad to keep shapes consistent for items with fewer turns - turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) - turn_t_lens.append(1) - turn_t_valid.append(False) - else: - if t == 0: - seg_ids = model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id] - turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) - turn_t_lens.append(len(seg_ids)) - turn_t_valid.append(True) - else: - turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) - turn_t_lens.append(1) - turn_t_valid.append(False) - - padded_turn_t = pad_sequence(turn_t_tokens, batch_first=True, padding_value=model.pad_id) - batched_turns.append(padded_turn_t) - batched_turn_lens.append(torch.tensor(turn_t_lens, dtype=torch.long)) - valid_turn_masks.append(torch.tensor(turn_t_valid, dtype=torch.bool)) - - out_dict["batched_turns"] = batched_turns - out_dict["batched_turn_lens"] = batched_turn_lens - out_dict["valid_turn_masks"] = valid_turn_masks - - # --- AUDIO LOADING --- - audio_list = [] - audio_lengths = [] - target_num_frames = [] - - for i, s in enumerate(batch): - audio_path = s["context_audio_filepath"] - if root_path is not None: - audio_path = os.path.join(root_path, audio_path) - - if os.path.exists(audio_path): - if use_librosa: - wav, sr = librosa.load(audio_path, sr=sample_rate, mono=True) - if normalize_context_audio_volume: - wav = normalize_volume(wav) - wav = torch.as_tensor(wav, dtype=torch.float32) - else: - wav, sr = sf.read(audio_path, dtype='float32') - # Force Mono - if wav.ndim > 1: - wav = wav.mean(axis=1) - - if normalize_context_audio_volume: - wav = normalize_volume(wav) - - # Convert to tensor, add batch dim for resampler, then remove it - wav_tensor = torch.as_tensor(wav, dtype=torch.float32).unsqueeze(0) - wav = resample(wav_tensor, sr, sample_rate).squeeze(0) - else: - wav = torch.zeros(1, dtype=torch.float32) - - audio_list.append(wav) - audio_lengths.append(len(wav)) - - tdur_audio_path = s["audio_filepath"] - if root_path is not None: - tdur_audio_path = os.path.join(root_path, tdur_audio_path) - - if tdur_audio_path and os.path.exists(tdur_audio_path): - wav_dur, sr_ = librosa.load(tdur_audio_path, sr=sample_rate, mono=True) - tdur = wav_dur.shape[0] // model.input_samples_per_frame - target_num_frames.append(tdur * extra_duration_thrshould) - else: - # Fallback estimation - if is_duplex: - current_text_len = len(tokenized_list[i]) - if isinstance(s["text"], list): - target_num_frames.append(current_text_len) - else: - target_num_frames.append(current_text_len * 5) - else: - target_num_frames.append(sum([l[i].item() for l in batched_turn_lens]) * 5) - - max_audio_len = max(audio_lengths) - B = len(audio_lengths) - padded_audio = torch.zeros((B, max_audio_len), dtype=torch.float32) - - for i, wav in enumerate(audio_list): - padded_audio[i, : len(wav)] = wav - - out_dict["context_audio"] = padded_audio - out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) - out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] - out_dict["target_num_frames"] = target_num_frames - - out_dict["raw_text"] = [" ".join(s["text"]) if isinstance(s["text"], list) else s["text"] for s in batch] - - return out_dict - - -def main(): - parser = argparse.ArgumentParser(description="EasyMagpieTTS Inference Evaluation") - - # Required Paths - parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the EasyMagpie model") - parser.add_argument("--codec_model_path", type=str, required=True, help="Path to the audio codec") - parser.add_argument("--datasets_json_path", type=str, required=True, help="Path to JSONL data") - parser.add_argument("--out_dir", type=str, required=True, help="Directory to save audio outputs") - - # Optional Paths & General - parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) - parser.add_argument("--audio_dir", type=str, default=None, help="Root dir for audio paths in JSONL") - parser.add_argument("--inference_dtype", type=str, default="float32") - parser.add_argument("--debug_dtype", action="store_true") - parser.add_argument("--use_librosa", action="store_true", help="Use librosa instead of soundfile+torch for audio load") - - # Dataloader & Batching - parser.add_argument("--batch_size", type=int, default=6) - parser.add_argument("--num_workers", type=int, default=4) - parser.add_argument("--num_turns", type=int, default=1) - parser.add_argument("--pad_factor_text_speech", type=int, default=10) - - # Text Processing Boolean Flags - parser.add_argument("--emulate_duplex_inference", action="store_true") - parser.add_argument("--add_interruption_token", action="store_true") - parser.add_argument("--force_interruption", action="store_true") - parser.add_argument("--profile_multiturn_inference", action="store_true") - parser.add_argument("--profile_pad_min_sec", type=float, default=2.0) - parser.add_argument("--profile_pad_max_sec", type=float, default=2.0) - - - # Speaker & Prompt Configurations - parser.add_argument("--user_custom_speaker_reference", action="store_true") - parser.add_argument("--inference_speaker_reference", type=str, default=None) - parser.add_argument("--language", type=str, default="en") - - # Generation Kwargs - parser.add_argument("--use_cfg", action="store_true") - parser.add_argument("--cfg_scale", type=float, default=2.5) - parser.add_argument("--temperature", type=float, default=0.7) - parser.add_argument("--topk", type=int, default=80) - parser.add_argument("--max_tts_steps", type=int, default=2000) - parser.add_argument("--force_speech_sil_codes", action="store_true") - parser.add_argument("--normalize_volume", type=lambda x: (str(x).lower() in ['true', '1', 'yes']), default=False) - - args = parser.parse_args() - - if args.profile_pad_max_sec < args.profile_pad_min_sec: - raise RuntimeError("--profile_pad_max_sec must be >= --profile_pad_min_sec.") - - distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1 - if distributed and not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") - - target_device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}" if torch.cuda.is_available() else "cpu") - target_dtype = getattr(torch, args.inference_dtype) - torch.set_default_dtype(target_dtype) - - model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) - with open_dict(model_cfg): - model_cfg.target = 'nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel' - model_cfg.codecmodel_path = args.codec_model_path - model_cfg.train_ds = None - model_cfg.validation_ds = None - model_cfg.run_val_inference = False - model_cfg.use_utmos = False - model_cfg.use_meta_init_for_decoder = True - - # Guarantees silence for pad tokens - model_cfg.use_multiturn_dataset = True - - if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: - model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path - - # Load to CPU first to prevent OOM - model = EasyMagpieTTSInferenceModel.restore_from( - args.checkpoint_path, override_config_path=model_cfg, map_location=torch.device("cpu") - ) - model.use_kv_cache_for_inference = True - model.to(dtype=target_dtype) - model.eval().to(target_device) - - # --- DATALOADER COMPATIBILITY PATCHES --- - model.input_samples_per_frame = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) - model.target_samples_per_frame = model.input_samples_per_frame / (model.sample_rate / model.output_sample_rate) - - # Load to CPU first to prevent OOM - codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=torch.device("cpu")) - if hasattr(codec_model, "discriminator"): - del codec_model.discriminator - codec_model.freeze() - codec_model = codec_model.to(target_device).eval() - - codec_converter = None - if getattr(model, "_codec_converter", None) is not None: - vq_new = deepcopy(model._codec_converter.vector_quantizer_new).to(target_device).eval() - codec_converter = VectorQuantizerIndexConverter( - vector_quantizer_original=codec_model.vector_quantizer, - vector_quantizer_new=vq_new, - ).to(target_device).eval() - - if not hasattr(model, "_codec_helper") or model._codec_helper is None: - model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) - - model._generate_codec_silence_buffer() - codec_sil_codes = model.codec_sil_codes - - if args.debug_dtype: - handles, stats, examples = attach_dtype_counter(model) - - with fp32_precision(): - intelligibility = Intelligibility("stt_en_fastconformer_transducer_large", reuse_asr_hyps=False).reset() - secs_metric = SECS("titanet_large").reset() - - eval_dataset = EvalJSONLDataset(args.datasets_json_path, num_turns=args.num_turns) - - collate_fn = partial( - collate_and_tokenize_custom, - model=model, - extra_duration_thrshould=1.5, - sample_rate=model.sample_rate, - root_path=args.audio_dir, - emulate_duplex_inference=args.emulate_duplex_inference, - add_interruption_token=args.add_interruption_token, - pad_factor_text_speech=args.pad_factor_text_speech, - force_interruption=args.force_interruption, - normalize_context_audio_volume=args.normalize_volume, - use_librosa=args.use_librosa, - profile_multiturn_inference=args.profile_multiturn_inference - ) - - dataloader = DataLoader( - dataset=eval_dataset, batch_size=args.batch_size, collate_fn=collate_fn, - num_workers=args.num_workers, pin_memory=True, shuffle=False, drop_last=False, - ) - - if args.user_custom_speaker_reference and args.inference_speaker_reference: - if args.use_librosa: - wav, sr = librosa.load(args.inference_speaker_reference, sr=model.sample_rate, mono=True) - if args.normalize_volume: - wav = normalize_volume(wav) - speaker_wav = torch.as_tensor(wav, dtype=target_dtype).unsqueeze(0).to(model.device) - else: - wav, sr = sf.read(args.inference_speaker_reference, dtype='float32') - # Force Mono - if wav.ndim > 1: - wav = wav.mean(axis=1) - - if args.normalize_volume: - wav = normalize_volume(wav) - - speaker_wav = torch.as_tensor(wav).unsqueeze(0) - speaker_wav = resample(speaker_wav.float(), sr, model.sample_rate).to(target_dtype).to(model.device) - - for batch_id, inputs in enumerate(dataloader): - B = inputs["context_audio"].size(0) - device = model.device - - inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) - inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) - - if args.user_custom_speaker_reference and args.inference_speaker_reference: - inputs["context_audio"] = speaker_wav.repeat(B, 1).detach() - inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) - - B = inputs["context_audio"].size(0) - profile_turn_frame_ranges = [[] for _ in range(B)] - profile_decode_start_frame = None - with torch.inference_mode(): - wav = inputs["context_audio"] - wav_len = inputs["context_audio_lengths"] - codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) - - use_lang = bool(getattr(model, "add_language_to_context_text", False)) - ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" - ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) - - ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) - ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) - - state = model.streaming_init( - context_audio_codes=codes, - context_audio_codes_lens=codes_lens, - context_text_tokens=ctx_toks, - context_text_tokens_lens=ctx_toks_lens, - use_cfg=args.use_cfg, - cfg_scale=args.cfg_scale, - use_local_transformer=True, - temperature=args.temperature, - topk=args.topk, - phoneme_input_type="pred", - phoneme_sampling_method="argmax", - use_inference_mode=True, - ) - # --------------------------------------------------------- - # MODE 1: DUPLEX (Continuous Padding Token Stream) - # --------------------------------------------------------- - if inputs["duplex_multiturn"]: - text = inputs["input_ids"].to(device) - text_lens = inputs["input_lengths"].to(device) - - # Trackers for our two forced-silence zones - in_initial_silence = torch.ones(B, dtype=torch.bool, device=device) - in_post_speech_silence = torch.zeros(B, dtype=torch.bool, device=device) - - text_exhausted = state.text_tokens_seen >= text_lens - while not text_exhausted.all(): - # 1. WAKE UP OVERRIDE: Keep the text pipeline awake to read pads! - state.finished = state.finished & text_exhausted - state.text_finished = state.text_finished & text_exhausted - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended = state.phoneme_stream_ended & text_exhausted - - # 2. Safely index text using the model's internal pointer - positions = state.text_tokens_seen.clamp(max=text.size(1) - 1) - current_tokens = text[torch.arange(B, device=device), positions] - - current_tokens = torch.where( - text_exhausted, torch.full_like(current_tokens, model.eos_id), current_tokens - ) - - # 3. Update our trackers BEFORE the step - is_pad_or_eos = (current_tokens == model.pad_id) | (current_tokens == model.eos_id) - in_initial_silence = in_initial_silence & is_pad_or_eos - in_post_speech_silence = in_post_speech_silence & is_pad_or_eos - - # 4. Step the model - state, audio_codes, _ = model.streaming_step(state=state, text_tokens=current_tokens, use_inference_mode=True) - - # 5. SILENCE FORCING INJECTION - if audio_codes is not None and args.force_speech_sil_codes: - force_silence_mask = in_initial_silence | in_post_speech_silence - - if force_silence_mask.any(): - # Expand silence codes [C] -> [1, C, 1] to match audio_codes [B, C, 1] - expanded_sil = codec_sil_codes.view(1, -1, 1).expand_as(audio_codes) - # Expand mask [B] -> [B, 1, 1] for broadcasting - mask_3d = force_silence_mask.view(B, 1, 1) - # Overwrite the prediction with silence codes where the mask is True. - overwritten_codes = torch.where(mask_3d, expanded_sil, audio_codes) - # Inject back into the model's KV cache history - state.all_predictions[-1] = overwritten_codes - - # 6. TRIGGER POST-SPEECH SILENCE FOR THE *NEXT* FRAME - in_post_speech_silence = in_post_speech_silence | state.finished - - # Update exhaustion tracker for the next iteration - text_exhausted = state.text_tokens_seen >= text_lens - - # --------------------------------------------------------- - # MODE 2: REGULAR (Turn-by-Turn Re-wakes) - # --------------------------------------------------------- - elif inputs["regular_multiturn"]: - batched_turns = inputs["batched_turns"] - batched_turn_lens = inputs["batched_turn_lens"] - valid_turn_masks = inputs["valid_turn_masks"] - - max_turns = len(batched_turns) - turn_offsets = torch.zeros(B, dtype=torch.long, device=device) - - for t in range(max_turns): - turn_text = batched_turns[t].to(device) - turn_lens = batched_turn_lens[t].to(device) - valid_mask = valid_turn_masks[t].to(device) - - state.finished = state.finished & (~valid_mask) - state.text_finished = state.text_finished & (~valid_mask) - - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended = state.phoneme_stream_ended & (~valid_mask) - - if state.finished.all(): - continue - - turn_offsets = torch.where(valid_mask, state.text_tokens_seen, turn_offsets) - turn_steps = 0 - - while not state.finished.all() and turn_steps < args.max_tts_steps: - turn_steps += 1 - - relative_positions = state.text_tokens_seen - turn_offsets - positions = relative_positions.clamp(min=0, max=turn_text.size(1) - 1) - current_tokens = turn_text[torch.arange(B, device=device), positions] - - exhausted = relative_positions >= turn_lens - current_tokens = torch.where(exhausted, torch.full_like(current_tokens, model.eos_id), current_tokens) - - state, _, _ = model.streaming_step(state=state, text_tokens=current_tokens, use_inference_mode=True) - - # --------------------------------------------------------- - # MODE 3: PROFILE MULTI-TURN, BATCHED - # --------------------------------------------------------- - elif inputs["profile_multiturn"]: - batched_turns = inputs["batched_turns"] - batched_turn_lens = inputs["batched_turn_lens"] - valid_turn_masks = inputs["valid_turn_masks"] - - max_turns = len(batched_turns) - - # Per-sample debug/save metadata. - # profile_turn_frame_ranges[b] = [(turn_id, start_frame, end_frame), ...] - profile_turn_frame_ranges = [[] for _ in range(B)] - - # First decoded frame per sample. Used by streaming_finalize and by per-turn saving. - profile_decode_start_frame = torch.full( - (B,), - -1, - dtype=torch.long, - device=device, - ) - - # Last meaningful decoded frame per sample. This avoids keeping trailing batch-idle silence. - profile_final_end_frame = torch.full( - (B,), - -1, - dtype=torch.long, - device=device, - ) - - arange_B = torch.arange(B, device=device) - - for t in range(max_turns): - turn_text = batched_turns[t].to(device) # (B, T_text_t) - turn_lens = batched_turn_lens[t].to(device) # (B,) - valid_mask = valid_turn_masks[t].to(device) # (B,) - - if not valid_mask.any(): - continue - - # Re-open only rows that have this turn. - state.finished = state.finished & (~valid_mask) - state.text_finished = state.text_finished & (~valid_mask) - - # Let this turn detect its own EOS, but keep audio_prediction_start_idx - # as the first start of the full generated conversation. - state.audio_prediction_end_idx = torch.where( - valid_mask, - torch.full_like(state.audio_prediction_end_idx, -1), - state.audio_prediction_end_idx, - ) - - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended = state.phoneme_stream_ended & (~valid_mask) - if hasattr(state, "phoneme_eos_detected"): - state.phoneme_eos_detected = state.phoneme_eos_detected & (~valid_mask) - - # Optional but usually cleaner for turn-level phoneme generation. - if hasattr(state, "phoneme_steps"): - state.phoneme_steps = torch.where( - valid_mask, - torch.zeros_like(state.phoneme_steps), - state.phoneme_steps, - ) - - # Optional but cleaner for turn-local audio step accounting. - # Profile steps below will immediately set last_audio_codes and increment audio_steps. - state.audio_steps = torch.where( - valid_mask, - torch.zeros_like(state.audio_steps), - state.audio_steps, - ) - - # ----------------------------- - # 1. Batched per-row profiling - # ----------------------------- - profile_seconds = ( - args.profile_pad_min_sec - + torch.rand((B,), device=device) - * (args.profile_pad_max_sec - args.profile_pad_min_sec) - ) - - profile_T = torch.round( - profile_seconds * model.sample_rate / model.input_samples_per_frame - ).to(torch.long) - profile_T = torch.clamp(profile_T, min=1) - - profile_T = torch.where( - valid_mask, - profile_T, - torch.zeros_like(profile_T), - ) - - profile_remaining = profile_T.clone() - profile_step_idx = torch.zeros(B, dtype=torch.long, device=device) - - # Old behavior: put the first streaming_speech_delay BPE tokens into - # the last profile positions, then do normal generation after that. - # - # This keeps your current working behavior, but now per batch row. - if model.cfg.get("agent_mask_include_transition_prefix", False): - delay_tokens = torch.zeros(B, dtype=torch.long, device=device) - else: - delay_tokens = torch.full( - (B,), - int(state.config.training_mode.streaming_speech_delay), - dtype=torch.long, - device=device, - ) - - # Safer version: keep at least one token, normally EOS, outside profile. - # If you want exact old behavior, replace `turn_lens - 1` with `turn_lens`. - max_consumable_text = torch.clamp(turn_lens - 1, min=0) - - delay_tokens = torch.minimum(delay_tokens, max_consumable_text) - delay_tokens = torch.minimum(delay_tokens, profile_T) - delay_tokens = torch.where(valid_mask, delay_tokens, torch.zeros_like(delay_tokens)) - - profile_text_consumed = torch.zeros(B, dtype=torch.long, device=device) - - while profile_remaining.max().item() > 0: - profile_mask = valid_mask & (profile_remaining > 0) - profile_end_mask = profile_mask & (profile_remaining == 1) - - # Default profile text is PAD. Rows in the last delay_tokens profile - # positions receive real BPE text tokens. - profile_text_tokens = torch.full( - (B,), - model.pad_id, - dtype=torch.long, - device=device, - ) - - if (delay_tokens > 0).any(): - # step_in_profile: 0, 1, ..., profile_T[b]-1 - step_in_profile = profile_step_idx - - # Emit text only in the final delay_tokens[b] profile steps. - emit_profile_text = ( - profile_mask - & (delay_tokens > 0) - & (step_in_profile >= (profile_T - delay_tokens)) - & (profile_text_consumed < delay_tokens) - ) - - if emit_profile_text.any(): - text_pos = profile_text_consumed.clamp( - min=0, - max=turn_text.size(1) - 1, - ) - - gathered_profile_text = turn_text[arange_B, text_pos] - profile_text_tokens = torch.where( - emit_profile_text, - gathered_profile_text, - profile_text_tokens, - ) - - profile_text_consumed = profile_text_consumed + emit_profile_text.long() - - # Only rows with profile_mask=True are active in this step. - # Other rows receive silence in the rectangular all_predictions tensor, - # but their logical counters do not advance. - state, _, _ = model.streaming_step_profiled( - state=state, - text_tokens=torch.full( - (B,), - model.eos_id, - dtype=torch.long, - device=device, - ), - profile_mask=profile_mask, - profile_text_tokens=profile_text_tokens, - profile_end_mask=profile_end_mask, - active_mask=profile_mask, - use_inference_mode=True, - ) - - profile_remaining = torch.where( - profile_mask, - profile_remaining - 1, - profile_remaining, - ) - profile_step_idx = torch.where( - profile_mask, - profile_step_idx + 1, - profile_step_idx, - ) - - logging.info( - f"[profile_multiturn] turn={t} profile_steps=" - f"{profile_T.detach().cpu().tolist()} " - f"profile_seconds={profile_seconds.detach().cpu().tolist()}" - ) - - # We start the turn after all rows have finished the profile phase. - # This excludes profile/user-speaking silence from each turn segment. - turn_start_frame_global = sum(p.size(-1) for p in state.all_predictions) - turn_start_frames = torch.full( - (B,), - turn_start_frame_global, - dtype=torch.long, - device=device, - ) - - first_profile_turn = valid_mask & (profile_decode_start_frame < 0) - profile_decode_start_frame = torch.where( - first_profile_turn, - turn_start_frames, - profile_decode_start_frame, - ) - - # Make streaming_finalize start from the first generated turn frame. - state.audio_prediction_start_idx = torch.where( - first_profile_turn, - turn_start_frames, - state.audio_prediction_start_idx, - ) - - # The profile phase already consumed `delay_tokens` text tokens for each row. - # Do not slice turn_text because delay_tokens differs per row; use a per-row base offset. - turn_text_base = delay_tokens - turn_remaining_lens = torch.clamp(turn_lens - turn_text_base, min=0) - - turn_offset = state.text_tokens_seen.clone() - turn_done = ~valid_mask - turn_steps = 0 - - saw_audio = torch.zeros(B, dtype=torch.bool, device=device) - first_audio_step_finished = torch.zeros(B, dtype=torch.bool, device=device) - - # ----------------------------- - # 2. Batched turn generation - # ----------------------------- - while (not turn_done.all()) and turn_steps < args.max_tts_steps: - turn_steps += 1 - - gen_mask = valid_mask & (~turn_done) - - # Old single-sample behavior cleared state.finished every generation step. - # Do it only for active rows. This prevents an early audio EOS from ending - # the turn before text is exhausted. - state.finished = state.finished & (~gen_mask) - - relative_position = state.text_tokens_seen - turn_offset - text_exhausted = relative_position >= turn_remaining_lens - - # Current token index in original turn_text, accounting for tokens consumed - # during profile. - position = turn_text_base + relative_position - position = position.clamp(min=0, max=turn_text.size(1) - 1) - - current_tokens = turn_text[arange_B, position] - current_tokens = torch.where( - text_exhausted, - torch.full_like(current_tokens, model.eos_id), - current_tokens, - ) - - state, audio_codes, _ = model.streaming_step_profiled( - state=state, - text_tokens=current_tokens, - profile_mask=torch.zeros(B, dtype=torch.bool, device=device), - profile_text_tokens=torch.full( - (B,), - model.pad_id, - dtype=torch.long, - device=device, - ), - profile_end_mask=torch.zeros(B, dtype=torch.bool, device=device), - active_mask=gen_mask, - use_inference_mode=True, - ) - - if audio_codes is not None: - newly_saw_audio = gen_mask & (~saw_audio) - saw_audio = saw_audio | gen_mask - first_audio_step_finished = torch.where( - newly_saw_audio, - state.finished, - first_audio_step_finished, - ) - - # Match old stopping condition: - # a turn is done only when its text is exhausted AND audio EOS is detected. - done_now = gen_mask & text_exhausted & state.finished - - if done_now.any(): - current_end_frame_global = sum(p.size(-1) for p in state.all_predictions) - - # Prefer precise EOS frame if _process_predictions_profiled populated it. - eos_end_frames = torch.where( - state.audio_prediction_end_idx >= 0, - state.audio_prediction_end_idx, - torch.full_like( - state.audio_prediction_end_idx, - current_end_frame_global, - ), - ) - - profile_final_end_frame = torch.where( - done_now, - eos_end_frames, - profile_final_end_frame, - ) - - for b in done_now.nonzero(as_tuple=False).flatten().detach().cpu().tolist(): - profile_turn_frame_ranges[b].append( - ( - int(t), - int(turn_start_frames[b].detach().cpu().item()), - int(eos_end_frames[b].detach().cpu().item()), - ) - ) - - turn_done = turn_done | done_now - - # Max-step fallback for rows that did not finish by EOS. - still_running = valid_mask & (~turn_done) - if still_running.any(): - current_end_frame_global = sum(p.size(-1) for p in state.all_predictions) - fallback_end_frames = torch.full( - (B,), - current_end_frame_global, - dtype=torch.long, - device=device, - ) - - profile_final_end_frame = torch.where( - still_running, - fallback_end_frames, - profile_final_end_frame, - ) - - for b in still_running.nonzero(as_tuple=False).flatten().detach().cpu().tolist(): - profile_turn_frame_ranges[b].append( - ( - int(t), - int(turn_start_frames[b].detach().cpu().item()), - int(fallback_end_frames[b].detach().cpu().item()), - ) - ) - - # Do not let this turn's EOS crop the full conversation before later turns. - state.audio_prediction_end_idx = torch.where( - valid_mask, - torch.full_like(state.audio_prediction_end_idx, -1), - state.audio_prediction_end_idx, - ) - state.finished = state.finished & (~valid_mask) - - logging.info( - f"[profile_multiturn] turn={t} steps={turn_steps} " - f"saw_audio={saw_audio.detach().cpu().tolist()} " - f"first_audio_step_finished={first_audio_step_finished.detach().cpu().tolist()}" - ) - - # After all turns, crop each row at its own last meaningful frame. - # This prevents rows that finished early from keeping trailing batch-idle silence. - total_frames = sum(p.size(-1) for p in state.all_predictions) - profile_final_end_frame = torch.where( - profile_final_end_frame >= 0, - profile_final_end_frame, - torch.full_like(profile_final_end_frame, total_frames), - ) - - state.audio_prediction_end_idx.copy_(profile_final_end_frame) - - # if state.audio_prediction_end_idx[0].item() >= 0: - # last_audio_prediction_end_idx.copy_(state.audio_prediction_end_idx) - - # Scrub Special Tokens (BOS/EOS) from Audio Codes --- - # Because we force-decode the entire uncropped sequence, any BOS or EOS - # tokens left in the array will produce loud artifacts in the codec. - bos_id = getattr(model, "audio_bos_id", -1) - eos_id = getattr(model, "audio_eos_id", -1) - speaking_id = getattr(model, "audio_user_speaking_id", -1) - speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) - - sil_injection = codec_sil_codes.view(1, -1, 1) - - for step_idx in range(len(state.all_predictions)): - pred = state.all_predictions[step_idx] - # Check if any codebook in the frame has any special token - mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) - frame_mask = mask.any(dim=1, keepdim=True) - - if frame_mask.any(): - state.all_predictions[step_idx] = torch.where( - frame_mask, - sil_injection.expand_as(pred), - pred - ) - - if inputs["duplex_multiturn"]: - # Erase the internal memory of Turn 1's EOS token so `streaming_finalize` - # decodes the entire physical sequence! - state.audio_prediction_end_idx.fill_(-1) - - - # Finalize decodes the collected Codec states globally regardless of which loop was run - finalize_output = model.streaming_finalize(state, use_inference_mode=True) - - if args.debug_dtype and batch_id == 0: - report_dtype_stats(handles, stats, examples) - - with fp32_precision(): - audio_f32 = finalize_output.audio.float() - audio_len = finalize_output.audio_len.int() - - expected_audio_lens = (torch.tensor(inputs["target_num_frames"], device=device) * model.target_samples_per_frame).int() - - if inputs["duplex_multiturn"]: - # Use exact math based on the output samples multiplier! - audio_len = (text_lens * model.target_samples_per_frame).int() - - # Cap the expected length so it physically cannot exceed the actual generated tensor size - audio_len = torch.min(audio_len, torch.tensor(audio_f32.size(1), device=device)) - elif inputs["profile_multiturn"]: - audio_len = finalize_output.audio_len.int() - else: - audio_len = torch.min(audio_len, expected_audio_lens) - - metric_audio_pred = resample(audio_f32, model.output_sample_rate, 16000) - metric_audio_pred_lens = (audio_len / model.output_sample_rate * 16000).to(torch.long) - - intelligibility.update( - name="dataset", - refs=inputs["raw_text"], - pred_audio=metric_audio_pred, - pred_audio_lens=metric_audio_pred_lens, - asr_hyps=None, - ) - - secs_metric.update( - name="dataset", - target_audio=resample(inputs["context_audio"].float(), model.sample_rate, 16000), - target_audio_lens=(inputs["context_audio_lengths"] / model.sample_rate * 16000).to(torch.long), - pred_audio=metric_audio_pred, - pred_audio_lens=metric_audio_pred_lens, - ) - - os.makedirs(args.out_dir, exist_ok=True) - audio_f32 = audio_f32.detach().cpu() - audio_len = audio_len.cpu() - - for i in range(B): - target_path = inputs["target_audio_paths"][i] - base_name = os.path.basename(target_path) - stem, ext = os.path.splitext(base_name) - if not ext: - ext = ".wav" - - if inputs["profile_multiturn"]: - wav = audio_f32[i, : audio_len[i]].numpy() - out_path = os.path.join(args.out_dir, base_name) - sf.write(out_path, wav, samplerate=model.output_sample_rate) - logging.info(f"Full Audio Saved: {out_path}") - - full_wav = audio_f32[i].numpy() - full_len = int(audio_len[i].item()) - - samples_per_prediction_frame = ( - model.codec_model_samples_per_frame - / (model.sample_rate / model.output_sample_rate) - ) - - decode_start_i = int(profile_decode_start_frame[i].detach().cpu().item()) - if decode_start_i < 0: - decode_start_i = 0 - - for turn_id, start_frame, end_frame in profile_turn_frame_ranges[i]: - rel_start_frame = start_frame - decode_start_i - rel_end_frame = end_frame - decode_start_i - - start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) - end_sample = int(round(rel_end_frame * samples_per_prediction_frame)) - - start_sample = max(0, min(start_sample, full_len)) - end_sample = max(start_sample, min(end_sample, full_len)) - - turn_wav = full_wav[start_sample:end_sample] - - out_path = os.path.join(args.out_dir, f"{stem}_turn_{turn_id}{ext}") - sf.write(out_path, turn_wav, samplerate=model.output_sample_rate) - logging.info(f"Saved: {out_path}") - - else: - wav = audio_f32[i, : audio_len[i]].numpy() - out_path = os.path.join(args.out_dir, base_name) - sf.write(out_path, wav, samplerate=model.output_sample_rate) - logging.info(f"Saved: {out_path}") - - with fp32_precision(): - logging.info("\n--- Evaluation Metrics ---") - cer_wer = intelligibility.compute() - for k, m in cer_wer.items(): - logging.info(f"Intelligibility - {k}: {m}") - - secs_scores = secs_metric.compute() - for k, m in secs_scores.items(): - logging.info(f"SECS - {k}: {m}") - - -if __name__ == "__main__": - main() diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 0beec43538cc..915f0a0481b0 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -698,25 +698,7 @@ def streaming_prefill_profile( # ----------------------- # TEXT CHANNEL # ----------------------- - if self.cfg.get("disable_subword_embedding", False): - text_emb = torch.zeros( - B, - T, - self.cfg.embedding_dim, - dtype=next(self.parameters()).dtype, - device=device, - ) - else: - text_emb = self.decoder.get_input_embeddings()(text_tokens) - - if self.use_bpe_char_tokenizer: - if self.cfg.get("use_multiturn_dataset", False): - text_mask = text_tokens != self.pad_id - else: - text_mask = torch.ones_like(text_tokens, dtype=torch.bool) - - text_emb = text_emb + self.cas_encoder(text_tokens, subword_mask=text_mask) - + text_emb = self.embed_text_tokens(text_tokens, text_lens=None, is_multiturn=self.cfg.get("use_multiturn_dataset", False)) if self.cfg.get("use_multiturn_dataset", False): text_emb[text_tokens == self.pad_id] = 0.0 @@ -769,23 +751,44 @@ def streaming_prefill_profile( # ----------------------- # KV CACHE EXTENSION # ----------------------- - cache_position = torch.arange( - state.cache_seq_len, - state.cache_seq_len + T, - device=device, - ) + # ToDo: on VLLM we need to make nemotron_h class support prefill + if self.decoder_type == "nemotron_h" and state.past_key_values is not None and T > 1: + hidden_chunks = [] + for t in range(T): + cache_position_t = torch.tensor([state.cache_seq_len], device=device) + + out = self.forward( + inputs_embeds=inputs_embeds[:, t : t + 1, :], + attention_mask=None, + use_cache=True, + past_key_values=state.past_key_values, + cache_position=cache_position_t, + ) - out = self.forward( - inputs_embeds=inputs_embeds, - attention_mask=None, - use_cache=True, - past_key_values=state.past_key_values, - cache_position=cache_position, - ) + state.past_key_values = out.past_key_values + state.cache_seq_len += 1 + hidden_chunks.append(out.last_hidden_state) - state.past_key_values = out.past_key_values - state.cache_seq_len += T - state.last_hidden = out.last_hidden_state + state.last_hidden = torch.cat(hidden_chunks, dim=1) + + else: + cache_position = torch.arange( + state.cache_seq_len, + state.cache_seq_len + T, + device=device, + ) + + out = self.forward( + inputs_embeds=inputs_embeds, + attention_mask=None, + use_cache=True, + past_key_values=state.past_key_values, + cache_position=cache_position, + ) + + state.past_key_values = out.past_key_values + state.cache_seq_len += T + state.last_hidden = out.last_hidden_state # Advance logical streams consumed by this profile prefill. state.text_tokens_seen += T @@ -1857,117 +1860,6 @@ def _prepare_streaming_input( return next_input, needs_context, needs_phoneme, needs_audio - def streaming_step_profiled( - self, - state: StreamingState, - text_tokens: Optional[torch.Tensor] = None, # (B,) - profile_mask: Optional[torch.Tensor] = None, # (B,) - profile_text_tokens: Optional[torch.Tensor] = None, # (B,) - profile_end_mask: Optional[torch.Tensor] = None, # (B,) - active_mask: Optional[torch.Tensor] = None, # (B,) - force_dropout_text: bool = False, - use_inference_mode: bool = True, - ) -> Tuple[StreamingState, Optional[torch.Tensor], Optional[torch.Tensor]]: - """ - One streaming step with per-row profiling support. - - profile_mask[b] == True: - row b runs a profiling step instead of a normal generation step. - The model receives profile_text_tokens[b] on text channel and - user-speaking/silence on audio channel. - The decoded output frame for that row is forced silence. - - active_mask[b] == False: - row b is idle this step. Its counters are not advanced. - """ - grad_ctx = torch.inference_mode if use_inference_mode else torch.no_grad - - with grad_ctx(): - device = state.config.device - B = state.config.batch_size - - if profile_mask is None: - profile_mask = torch.zeros(B, dtype=torch.bool, device=device) - else: - profile_mask = profile_mask.to(device=device, dtype=torch.bool) - - if profile_end_mask is None: - profile_end_mask = torch.zeros(B, dtype=torch.bool, device=device) - else: - profile_end_mask = profile_end_mask.to(device=device, dtype=torch.bool) - - if active_mask is None: - # Default: normal active rows are unfinished; profile rows are active. - active_mask = (~state.finished) | profile_mask - else: - active_mask = active_mask.to(device=device, dtype=torch.bool) - - profile_mask = profile_mask & active_mask - - if text_tokens is None: - text_tokens = torch.full( - (B,), - self.eos_id, - dtype=torch.long, - device=device, - ) - else: - text_tokens = text_tokens.to(device=device, dtype=torch.long) - - if profile_text_tokens is None: - profile_text_tokens = torch.full( - (B,), - self.pad_id, - dtype=torch.long, - device=device, - ) - else: - profile_text_tokens = profile_text_tokens.to(device=device, dtype=torch.long) - - ( - next_input, - needs_context, - needs_phoneme, - needs_audio, - effective_profile_mask, - profile_silence_unstacked, - profile_last_audio_codes, - ) = self._prepare_streaming_input_profiled( - state=state, - text_tokens=text_tokens, - profile_mask=profile_mask, - profile_text_tokens=profile_text_tokens, - profile_end_mask=profile_end_mask, - active_mask=active_mask, - force_dropout_text=force_dropout_text, - ) - - cache_position = torch.tensor([state.cache_seq_len], device=device) - - transformer_out = self.forward( - inputs_embeds=next_input, - attention_mask=None, - use_cache=True, - past_key_values=state.past_key_values, - cache_position=cache_position, - ) - - state.last_hidden = transformer_out.last_hidden_state - state.past_key_values = transformer_out.past_key_values - state.cache_seq_len += 1 - - audio_codes_next, pred_phoneme_tokens = self._process_predictions_profiled( - state=state, - needs_context=needs_context, - needs_phoneme=needs_phoneme, - needs_audio=needs_audio, - profile_mask=effective_profile_mask, - active_mask=active_mask, - profile_silence_unstacked=profile_silence_unstacked, - profile_last_audio_codes=profile_last_audio_codes, - ) - - return state, audio_codes_next, pred_phoneme_tokens def _embed_one_text_step( self, @@ -2012,474 +1904,6 @@ def _embed_one_text_step( return text_embedded - def _prepare_streaming_input_profiled( - self, - state: StreamingState, - text_tokens: torch.Tensor, # (B,) - profile_mask: torch.Tensor, # (B,) - profile_text_tokens: torch.Tensor, # (B,) - profile_end_mask: torch.Tensor, # (B,) - active_mask: torch.Tensor, # (B,) - force_dropout_text: bool, - ): - device = state.config.device - B = state.config.batch_size - dtype = next(self.parameters()).dtype - - streaming_speech_delay = state.config.training_mode.streaming_speech_delay - streaming_phonemes_delay = state.config.training_mode.streaming_phonemes_delay - - needs_context = active_mask & (state.context_position < state.full_context_lens) - - # Profiling only applies after context is consumed. - effective_profile_mask = profile_mask & (~needs_context) - - normal_active = active_mask & (~effective_profile_mask) - - needs_text = ( - normal_active - & (~needs_context) - & (~state.text_finished) - ) - - needs_phoneme = ( - normal_active - & (~needs_context) - & (state.text_tokens_seen >= streaming_phonemes_delay) - & (~state.phoneme_stream_ended) - ) - - needs_audio = ( - normal_active - & (~needs_context) - & (state.text_tokens_seen >= streaming_speech_delay) - & (~state.finished) - ) - - next_input = torch.zeros( - B, - 1, - self.cfg.embedding_dim, - dtype=dtype, - device=device, - ) - - # ----------------------- - # Context rows - # ----------------------- - if needs_context.any(): - ctx_positions = state.context_position.clamp( - max=state.full_context_embedding.size(1) - 1 - ) - - ctx_emb = state.full_context_embedding[ - torch.arange(B, device=device), - ctx_positions, - :, - ].unsqueeze(1) - - next_input = next_input + ctx_emb * needs_context.view(B, 1, 1).to(dtype) - - # ----------------------- - # Normal text rows - # ----------------------- - if needs_text.any(): - text_emb = self._embed_one_text_step( - text_tokens, - force_dropout_text=force_dropout_text, - ) - - next_input = next_input + text_emb * needs_text.view(B, 1, 1).to(dtype) - - is_eos_token = (text_tokens == self.eos_id) & needs_text - state.text_finished = state.text_finished | is_eos_token - - # ----------------------- - # Profile rows: text channel - # ----------------------- - profile_text_emb = None - if effective_profile_mask.any(): - profile_text_emb = self._embed_one_text_step( - profile_text_tokens, - force_dropout_text=force_dropout_text, - ) - - # ----------------------- - # Phoneme rows, normal only - # ----------------------- - if self.phoneme_tokenizer is not None and needs_phoneme.any(): - phoneme_emb = torch.zeros( - B, - 1, - self.cfg.embedding_dim, - dtype=dtype, - device=device, - ) - - if state.config.phoneme_input_type == "gt" and state.gt_phoneme_embeddings is not None: - within_gt_len = state.phoneme_steps < state.gt_phoneme_lens - positions = state.phoneme_steps.clamp( - max=state.gt_phoneme_embeddings.size(1) - 1 - ) - - gt_emb = state.gt_phoneme_embeddings[ - torch.arange(B, device=device), - positions, - :, - ].unsqueeze(1) - - phoneme_mask = (needs_phoneme & within_gt_len).view(B, 1, 1).to(dtype) - phoneme_emb = phoneme_emb + gt_emb * phoneme_mask - - else: - first_phoneme_step = needs_phoneme & (state.phoneme_steps == 0) - has_last_phoneme = ( - needs_phoneme - & (~first_phoneme_step) - & (state.last_phoneme_tokens is not None) - ) - - if first_phoneme_step.any(): - phoneme_bos = torch.full( - (B, self.phoneme_stacking_factor, 1), - self.phoneme_tokenizer.bos_token_id, - device=device, - dtype=torch.long, - ) - phoneme_bos_emb = self.embed_phoneme_tokens(phoneme_bos) - phoneme_emb = phoneme_emb + phoneme_bos_emb * first_phoneme_step.view(B, 1, 1).to(dtype) - - if has_last_phoneme.any() and state.last_phoneme_tokens is not None: - last_phoneme_emb = self.embed_phoneme_tokens( - state.last_phoneme_tokens.unsqueeze(2) - ) - phoneme_emb = phoneme_emb + last_phoneme_emb * has_last_phoneme.view(B, 1, 1).to(dtype) - - state.phoneme_stream_ended = ( - state.phoneme_stream_ended | state.phoneme_eos_detected - ) - - next_input = next_input + phoneme_emb - - # ----------------------- - # Normal audio rows - # ----------------------- - audio_emb = torch.zeros( - B, - 1, - self.cfg.embedding_dim, - dtype=dtype, - device=device, - ) - - if needs_audio.any(): - if state.gt_audio_embeddings is not None: - within_gt_len = state.audio_steps < state.gt_audio_lens - positions = state.audio_steps.clamp( - max=state.gt_audio_embeddings.size(1) - 1 - ) - - gt_emb = state.gt_audio_embeddings[ - torch.arange(B, device=device), - positions, - :, - ].unsqueeze(1) - - audio_mask = (needs_audio & within_gt_len).view(B, 1, 1).to(dtype) - audio_emb = audio_emb + gt_emb * audio_mask - - else: - first_audio_step = needs_audio & (state.audio_steps == 0) - has_last_audio = ( - needs_audio - & (~first_audio_step) - & (state.last_audio_codes is not None) - ) - - if first_audio_step.any(): - audio_bos = torch.full( - (B, self.num_audio_codebooks * self.frame_stacking_factor, 1), - self.audio_bos_id, - device=device, - dtype=torch.long, - ) - audio_bos_emb = self.embed_audio_tokens(audio_bos) - audio_emb = audio_emb + audio_bos_emb * first_audio_step.view(B, 1, 1).to(dtype) - - if has_last_audio.any() and state.last_audio_codes is not None: - last_audio_emb = self.embed_audio_tokens( - state.last_audio_codes.unsqueeze(2) - ) - audio_emb = audio_emb + last_audio_emb * has_last_audio.view(B, 1, 1).to(dtype) - - next_input = next_input + audio_emb - - # ----------------------- - # Profile audio input - # ----------------------- - C = self.num_audio_codebooks - S = self.frame_stacking_factor - - sil_codes = self.codec_sil_codes.to(device=device, dtype=torch.long) - - profile_silence_unstacked = ( - sil_codes.view(1, C, 1) - .expand(B, C, S) - .contiguous() - ) - - if self.cfg.get("use_user_speaking_token", False): - profile_audio_stacked = torch.full( - (B, C * S, 1), - self.audio_user_speaking_id, - dtype=torch.long, - device=device, - ) - else: - profile_audio_stacked, _ = self.stack_codes( - profile_silence_unstacked, - torch.full((B,), S, dtype=torch.long, device=device), - bos_id=self.audio_bos_id, - eos_id=self.audio_eos_id, - stacking_factor=S, - num_codebooks=C, - ) - - profile_last_audio_codes = profile_audio_stacked[:, :, -1].contiguous() - - if self.cfg.get("use_user_speaking_end_token", False): - profile_end_codes = torch.full( - (B, C * S), - self.audio_user_speaking_end_id, - dtype=torch.long, - device=device, - ) - profile_last_audio_codes = torch.where( - profile_end_mask.view(B, 1), - profile_end_codes, - profile_last_audio_codes, - ) - - profile_audio_emb = self.embed_audio_tokens(profile_audio_stacked) - - if effective_profile_mask.any(): - profile_emb = profile_audio_emb - - if profile_text_emb is not None: - profile_emb = profile_emb + profile_text_emb - - next_input = ( - next_input * (~effective_profile_mask).view(B, 1, 1).to(dtype) - + profile_emb * effective_profile_mask.view(B, 1, 1).to(dtype) - ) - - # ----------------------- - # CFG branch - # ----------------------- - if state.config.use_cfg: - next_input_uncond = torch.zeros_like(next_input) - - if needs_context.any(): - ctx_uncond = state.config.dummy_context_embedding_unconditional.expand( - B, - 1, - -1, - ) - next_input_uncond = next_input_uncond + ctx_uncond * needs_context.view(B, 1, 1).to(dtype) - - if needs_audio.any(): - next_input_uncond = next_input_uncond + audio_emb * needs_audio.view(B, 1, 1).to(dtype) - - if effective_profile_mask.any(): - # Match your streaming_prefill_profile behavior: - # conditional = profile text + profile audio - # unconditional = profile audio only - next_input_uncond = ( - next_input_uncond * (~effective_profile_mask).view(B, 1, 1).to(dtype) - + profile_audio_emb * effective_profile_mask.view(B, 1, 1).to(dtype) - ) - - next_input = torch.cat([next_input, next_input_uncond], dim=0) - - return ( - next_input, - needs_context, - needs_phoneme, - needs_audio, - effective_profile_mask, - profile_silence_unstacked, - profile_last_audio_codes, - ) - - def _process_predictions_profiled( - self, - state: StreamingState, - needs_context: torch.Tensor, - needs_phoneme: torch.Tensor, - needs_audio: torch.Tensor, - profile_mask: torch.Tensor, - active_mask: torch.Tensor, - profile_silence_unstacked: torch.Tensor, # (B, C, S) - profile_last_audio_codes: torch.Tensor, # (B, C*S) - ): - B = state.config.batch_size - device = state.config.device - C = self.num_audio_codebooks - S = self.frame_stacking_factor - - # Context always advances only for active context rows. - state.context_position = state.context_position + needs_context.long() - - # Logical text stream advances only for rows participating this step. - # This avoids idle finished rows drifting while other batch rows keep running. - logical_active = active_mask & (~needs_context) - state.text_tokens_seen = state.text_tokens_seen + logical_active.long() - - # Profile behaves like your old streaming_prefill_profile: - # it consumes an audio-step-like profile input, but does not sample audio. - state.audio_steps = state.audio_steps + needs_audio.long() + profile_mask.long() - - state.phoneme_steps = state.phoneme_steps + needs_phoneme.long() - - pred_phoneme_tokens = None - audio_codes_next = None - - # ----------------------- - # Phoneme prediction, normal rows only - # ----------------------- - if needs_phoneme.any() and self.phoneme_tokenizer is not None: - first_phoneme_step = needs_phoneme & (state.phoneme_prediction_start_idx == -1) - - if first_phoneme_step.any(): - current_phoneme_step_idx = len(state.all_phoneme_predictions) - state.phoneme_prediction_start_idx = torch.where( - first_phoneme_step, - torch.full_like(state.phoneme_prediction_start_idx, current_phoneme_step_idx), - state.phoneme_prediction_start_idx, - ) - - pred_phoneme_tokens = self._predict_phoneme_tokens(state) - state.last_phoneme_tokens = pred_phoneme_tokens - state.all_phoneme_predictions.append(pred_phoneme_tokens) - - phoneme_eos_detected = ( - needs_phoneme - & (pred_phoneme_tokens == self.phoneme_tokenizer.eos_token_id).any(dim=1) - ) - - state.phoneme_eos_detected = state.phoneme_eos_detected | phoneme_eos_detected - - newly_ended_phoneme = ( - phoneme_eos_detected - & (state.phoneme_prediction_end_idx == -1) - ) - - if newly_ended_phoneme.any(): - current_phoneme_step_idx = len(state.all_phoneme_predictions) - state.phoneme_prediction_end_idx = torch.where( - newly_ended_phoneme, - torch.full_like(state.phoneme_prediction_end_idx, current_phoneme_step_idx), - state.phoneme_prediction_end_idx, - ) - - # ----------------------- - # Audio/profile output - # ----------------------- - if needs_audio.any() or profile_mask.any(): - mixed_unstacked = profile_silence_unstacked.clone() - mixed_last_codes = profile_last_audio_codes.clone() - - sampled_stacked = None - sampled_argmax = None - - if needs_audio.any(): - first_audio_step = needs_audio & (state.audio_prediction_start_idx == -1) - - if first_audio_step.any(): - current_frame_idx = sum(p.size(-1) for p in state.all_predictions) - state.audio_prediction_start_idx = torch.where( - first_audio_step, - torch.full_like(state.audio_prediction_start_idx, current_frame_idx), - state.audio_prediction_start_idx, - ) - - sampled_stacked, sampled_argmax = self._predict_audio_codes(state) - sampled_unstacked = sampled_stacked.view(B, C, S) - - mixed_unstacked = torch.where( - needs_audio.view(B, 1, 1), - sampled_unstacked, - mixed_unstacked, - ) - - mixed_last_codes = torch.where( - needs_audio.view(B, 1), - sampled_stacked, - mixed_last_codes, - ) - - # Update last audio input only for rows that actually had audio/profile activity. - update_last = needs_audio | profile_mask - - if state.last_audio_codes is None: - state.last_audio_codes = torch.full( - (B, C * S), - self.audio_bos_id, - dtype=torch.long, - device=device, - ) - - state.last_audio_codes = torch.where( - update_last.view(B, 1), - mixed_last_codes, - state.last_audio_codes, - ) - - # EOS detection only for sampled normal audio rows, never profile rows. - if needs_audio.any() and state.gt_audio_embeddings is None: - sampled_argmax_unstacked = sampled_argmax.view(B, C, S) - sampled_unstacked = sampled_stacked.view(B, C, S) - - eos_in_sampled = sampled_unstacked == self.audio_eos_id - eos_in_argmax = sampled_argmax_unstacked == self.audio_eos_id - - eos_any_codebook = ( - eos_in_sampled.any(dim=1) - | eos_in_argmax.any(dim=1) - ) # (B, S) - - eos_frame_idx = torch.where( - eos_any_codebook.any(dim=1), - eos_any_codebook.int().argmax(dim=1), - torch.full((B,), S, device=device), - ) - - audio_eos_detected = eos_any_codebook.any(dim=1) & needs_audio - state.finished = state.finished | audio_eos_detected - - newly_ended_audio = ( - audio_eos_detected - & (state.audio_prediction_end_idx == -1) - ) - - if newly_ended_audio.any(): - current_frame_count = len(state.all_predictions) * S - end_frame_idx = current_frame_count + eos_frame_idx - - state.audio_prediction_end_idx = torch.where( - newly_ended_audio, - end_frame_idx, - state.audio_prediction_end_idx, - ) - - state.all_predictions.append(mixed_unstacked) - audio_codes_next = mixed_unstacked - - if state.gt_audio_embeddings is not None and state.gt_audio_lens is not None: - gt_exhausted = needs_audio & (state.audio_steps >= state.gt_audio_lens) - state.finished = state.finished | gt_exhausted - - return audio_codes_next, pred_phoneme_tokens def _process_predictions( self, From 0afe00512226c33bb57bed4a35c0deae088da540 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Sun, 17 May 2026 07:14:39 -0700 Subject: [PATCH 044/109] Add user aduio conditioning Signed-off-by: Edresson Casanova --- ...text_to_speech_dataset_lhotse_multiturn.py | 64 +++- nemo/collections/tts/models/easy_magpietts.py | 274 +++++++++++++++++- .../tts/models/easy_magpietts_inference.py | 52 ---- 3 files changed, 336 insertions(+), 54 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 05d6f41eddb3..a7681bb96142 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -320,7 +320,14 @@ def _align_codebooks(t): source_audio = collate_vectors(source_audio_list, padding_value=0.0) source_audio_lens = torch.tensor([len(a) for a in source_audio_list], dtype=torch.long) - + + user_audio_turn_splitted, user_audio_turn_splitted_lens, user_audio_turn_splitted_indices = extract_turn_audio_channel( + cuts=cuts, + source_audio_list=source_audio_list, + source_sample_rate=self.source_sample_rate, + roles=self.input_roles, + ) + target_text_tokens, target_token_lens = collate_token_channel( cuts, self.text_tokenizer, self.frame_length, roles=self.output_roles, add_text_bos=self.add_text_bos, tokenizer_names=batch_tokenizer_names, @@ -560,6 +567,9 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) "text_lens": target_token_lens, "raw_texts": [" ".join(s.text for s in cut.supervisions if s.speaker in self.output_roles) for cut in cuts], "task": [getattr(cut, "task", "tts") for cut in cuts], + "user_audio_turn_splitted": user_audio_turn_splitted, + "user_audio_turn_splitted_lens": user_audio_turn_splitted_lens, + "user_audio_turn_splitted_indices": user_audio_turn_splitted_indices, } if target_codes_list: @@ -768,6 +778,58 @@ def build_token_channel( return tokens +def extract_turn_audio_channel( + cuts: CutSet, + source_audio_list: list[torch.Tensor], + source_sample_rate: int, + roles: set[str], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Packs role-specific speech turns into batch dimension. + + Returns: + user_audio: [N_turns, T_max] + user_audio_lens: [N_turns] + user_audio_indices: [N_turns, 3] + columns = [original_batch_idx, start_sample, end_sample] + in the original source_audio timeline. + """ + turn_audio = [] + turn_lens = [] + turn_indices = [] + + for batch_idx, cut in enumerate(cuts): + audio = source_audio_list[batch_idx] + audio_len = audio.shape[0] + + for sup in cut.supervisions: + if sup.speaker not in roles: + continue + + start = int(round(max(0.0, sup.start) * source_sample_rate)) + end = int(round(max(0.0, sup.end) * source_sample_rate)) + + start = min(max(start, 0), audio_len) + end = min(max(end, 0), audio_len) + + if end <= start: + continue + + chunk = audio[start:end] + turn_audio.append(chunk) + turn_lens.append(chunk.shape[0]) + turn_indices.append([batch_idx, start, end]) + + if len(turn_audio) == 0: + return None, None, None + else: + user_audio = collate_vectors(turn_audio, padding_value=0.0) + user_audio_lens = torch.tensor(turn_lens, dtype=torch.long) + user_audio_indices = torch.tensor(turn_indices, dtype=torch.long) + + return user_audio, user_audio_lens, user_audio_indices + + def collate_phoneme_channel( cuts: CutSet, phoneme_tokenizer, diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 3753cf901d65..16ad640853b4 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -823,6 +823,8 @@ def process_batch( training_mode: Optional[TrainingMode] = None, task: Optional[List[str]] = None, agent_mask: Optional[torch.Tensor] = None, + user_audio_embedded: Optional[torch.Tensor] = None, + user_mask: Optional[torch.Tensor] = None, ) -> ProcessBatchOutput: """ Simplified batch processing using channel-based embedding architecture. @@ -958,6 +960,11 @@ def process_batch( dropout_complete_phoneme_channel=dropout_complete_phoneme_channel, ) + """ + if agent_mask is not None: + debug_agent_mask = agent_mask.clone() + """ + # 5. Prepare audio channel embeddings ( audio_channel_embedding, @@ -1017,6 +1024,190 @@ def process_batch( phoneme_channel_embedding = torch.cat([phoneme_channel_embedding, padding], dim=1) combined_channel_embedding = combined_channel_embedding + phoneme_channel_embedding + if user_audio_embedded is not None: + bos_user_pad = torch.zeros( + user_audio_embedded.size(0), + 1, + user_audio_embedded.size(2), + device=user_audio_embedded.device, + dtype=user_audio_embedded.dtype, + ) + user_audio_embedded = torch.cat([bos_user_pad, user_audio_embedded], dim=1) + + if user_mask is not None: + bos_user_mask = torch.zeros( + user_mask.size(0), + 1, + device=user_mask.device, + dtype=torch.bool, + ) + user_mask = torch.cat([bos_user_mask, user_mask], dim=1) + + # Align user conditioning to audio_codes_target timeline, + # same as agent_mask from prepare_audio_channel_embeddings(). + target_T = audio_codes_target.size(2) + + if user_audio_embedded.size(1) < target_T: + pad_len = target_T - user_audio_embedded.size(1) + + user_audio_embedded = torch.cat( + [ + user_audio_embedded, + torch.zeros( + user_audio_embedded.size(0), + pad_len, + user_audio_embedded.size(2), + device=user_audio_embedded.device, + dtype=user_audio_embedded.dtype, + ), + ], + dim=1, + ) + + if user_mask is not None: + user_mask = torch.cat( + [ + user_mask, + torch.zeros( + user_mask.size(0), + pad_len, + device=user_mask.device, + dtype=torch.bool, + ), + ], + dim=1, + ) + else: + user_audio_embedded = user_audio_embedded[:, :target_T] + if user_mask is not None: + user_mask = user_mask[:, :target_T] + + """ + if user_mask is not None and debug_agent_mask is not None: + valid = get_mask_from_lengths(audio_codes_lens_target).bool().to(debug_agent_mask.device) + target_T = valid.size(1) + + raw_agent = debug_agent_mask.to(valid.device).bool() + if raw_agent.size(1) < target_T: + raw_agent = torch.cat( + [ + raw_agent, + torch.zeros( + raw_agent.size(0), + target_T - raw_agent.size(1), + device=raw_agent.device, + dtype=torch.bool, + ), + ], + dim=1, + ) + else: + raw_agent = raw_agent[:, :target_T] + + raw_agent = raw_agent & valid + + raw_input_agent_mask = torch.zeros_like(raw_agent) + raw_input_agent_mask[:, 1:] = raw_agent[:, :-1] + raw_input_agent_mask[:, 0] = False + + user_mask_cmp = user_mask.to(raw_input_agent_mask.device).bool() + + T_cmp = min(user_mask_cmp.size(1), raw_input_agent_mask.size(1)) + user_mask_cmp = user_mask_cmp[:, :T_cmp] + raw_input_agent_mask = raw_input_agent_mask[:, :T_cmp] + + overlap = user_mask_cmp & raw_input_agent_mask + + # Classify overlap location relative to contiguous user spans. + B, T = user_mask_cmp.shape + left_overlap = torch.zeros(B, device=user_mask_cmp.device, dtype=torch.long) + right_overlap = torch.zeros(B, device=user_mask_cmp.device, dtype=torch.long) + middle_overlap = torch.zeros(B, device=user_mask_cmp.device, dtype=torch.long) + + for b in range(B): + user_idx = user_mask_cmp[b].nonzero(as_tuple=False).flatten() + if user_idx.numel() == 0: + continue + + breaks = torch.where(user_idx[1:] != user_idx[:-1] + 1)[0] + 1 + spans = torch.tensor_split(user_idx, breaks.cpu().tolist()) + + for span in spans: + s = int(span[0].item()) + e = int(span[-1].item()) + 1 + + overlap_span = overlap[b, s:e] + if not overlap_span.any(): + continue + + ov_idx = overlap_span.nonzero(as_tuple=False).flatten() + ov_abs = ov_idx + s + + # Count overlap frames close to boundaries. + boundary_width = int(self.cfg.get("user_agent_overlap_boundary_width", 10)) + + left_region_end = min(e, s + boundary_width) + right_region_start = max(s, e - boundary_width) + + left_overlap[b] += overlap[b, s:left_region_end].sum() + right_overlap[b] += overlap[b, right_region_start:e].sum() + + boundary_mask = torch.zeros_like(overlap[b, s:e]) + boundary_mask[: left_region_end - s] = True + boundary_mask[right_region_start - s :] = True + + middle_overlap[b] += (overlap_span & ~boundary_mask).sum() + + logging.info( + "[user/raw-agent input-overlap side debug] " + f"overlap_frames={overlap.sum(dim=1).detach().cpu().tolist()} " + f"left_overlap={left_overlap.detach().cpu().tolist()} " + f"right_overlap={right_overlap.detach().cpu().tolist()} " + f"middle_overlap={middle_overlap.detach().cpu().tolist()} " + f"user_frames={user_mask_cmp.sum(dim=1).detach().cpu().tolist()} " + f"raw_agent_input_frames={raw_input_agent_mask.sum(dim=1).detach().cpu().tolist()}" + ) + """ + batch_size = user_audio_embedded.size(0) + device = user_audio_embedded.device + + max_delay = audio_delay.max().item() + zero_delay_tensor = torch.zeros( + batch_size, + max_delay, + self.cfg.embedding_dim, + device=device, + dtype=user_audio_embedded.dtype, + ) + + user_audio_lens = audio_codes_lens_target.to(audio_delay.device) + + user_audio_channel_embedding, _ = self.join_embeddings_temporally( + embeddings=[zero_delay_tensor, user_audio_embedded], + lengths=[audio_delay, user_audio_lens], + ) + + if user_audio_channel_embedding.size(1) < max_channel_len: + pad_len = max_channel_len - user_audio_channel_embedding.size(1) + user_audio_channel_embedding = torch.cat( + [ + user_audio_channel_embedding, + torch.zeros( + batch_size, + pad_len, + user_audio_channel_embedding.size(2), + device=user_audio_channel_embedding.device, + dtype=user_audio_channel_embedding.dtype, + ), + ], + dim=1, + ) + else: + user_audio_channel_embedding = user_audio_channel_embedding[:, :max_channel_len] + + combined_channel_embedding = combined_channel_embedding + user_audio_channel_embedding + + # 7. Join context with combined channel embeddings # The combined_channel_lens is the max of all channel lens for each batch item combined_channel_lens = ( @@ -1142,7 +1333,7 @@ def training_step(self, batch, batch_idx): audio = batch['audio'] audio_lens = batch['audio_lens'] audio_codes, audio_codes_lens = self._codec_helper.audio_to_codes(audio, audio_lens) - + # augment tts data to looks more like multiturn data by adding pad on the begining and emulating user speaking. if self.cfg.get("use_multiturn_dataset", False) and "tts" in batch['task']: prob = self.cfg.get("add_tts_sil_begining_prob", 0.0) @@ -1242,6 +1433,85 @@ def training_step(self, batch, batch_idx): torch.zeros_like(gathered_agent_mask), ) + if self.cfg.get("use_multiturn_dataset", False) and batch["user_audio_turn_splitted"] is not None and self.cfg.get("condition_on_user_speech", False): + input_samples_per_frame = self.codec_model_samples_per_frame * self.frame_stacking_factor + + user_audio_codes, user_audio_codes_lens = self._codec_helper.audio_to_codes( + batch["user_audio_turn_splitted"], + batch["user_audio_turn_splitted_lens"], + ) + + if self._codec_converter is not None: + user_audio_codes = self._codec_converter.convert_original_to_new( + audio_tokens=user_audio_codes, audio_lens=user_audio_codes_lens + ).long() + + user_audio_codes, user_audio_codes_lens = self.stack_codes( + user_audio_codes, + user_audio_codes_lens, + self.audio_bos_id, + self.audio_eos_id, + self.frame_stacking_factor, + self.num_audio_codebooks, + ) + + user_audio_embedded = self.embed_audio_tokens(user_audio_codes) + # user_audio_embedded: [N_turns, T_turn_max, D] + # user_audio_codes_lens includes BOS/EOS from stack_codes + + B = batch["text"].shape[0] + T = batch["text"].shape[1] + D = user_audio_embedded.shape[-1] + + user_audio_embedded_restored = user_audio_embedded.new_zeros(B, T, D) + user_audio_embedded_mask = torch.zeros(B, T, device=user_audio_embedded.device, dtype=torch.bool) + + indices = batch["user_audio_turn_splitted_indices"].to(user_audio_embedded.device) + for turn_idx, (b, start_sample, end_sample) in enumerate(indices): + b = int(b.item()) + if b < 0: + continue + + start_frame = int(torch.ceil(start_sample.float() / input_samples_per_frame).item()) + end_frame = int(end_sample.item()) // input_samples_per_frame + + start_frame = max(0, min(start_frame, T)) + end_frame = max(start_frame, min(end_frame, T)) + + seq_len = end_frame - start_frame + if seq_len <= 0: + continue + + # Remove BOS/EOS from user turn embedding. + # stack_codes added special tokens, so timeline frames should use only real user audio frames. + turn_len_with_special = int(user_audio_codes_lens[turn_idx].item()) + real_start = 1 + real_end = max(real_start, turn_len_with_special - 1) + + turn_emb = user_audio_embedded[turn_idx, real_start:real_end] + + copy_len = min(seq_len, turn_emb.size(0)) + if copy_len <= 0: + continue + + dst_start = start_frame + dst_end = start_frame + copy_len + user_audio_embedded_restored[b, dst_start:dst_end] = turn_emb[:copy_len] + user_audio_embedded_mask[b, dst_start:dst_end] = True + boundary_trim = int(self.cfg.get("user_audio_boundary_trim", 4)) + trim = min(boundary_trim, copy_len // 2) + if trim > 0: + user_audio_embedded_restored[b, dst_start:dst_start + trim] = 0.0 + user_audio_embedded_restored[b, dst_end - trim:dst_end] = 0.0 + user_audio_embedded_mask[b, dst_start:dst_start + trim] = False + user_audio_embedded_mask[b, dst_end - trim:dst_end] = False + + user_audio_embedded = user_audio_embedded_restored + user_audio_mask = user_audio_embedded_mask + else: + user_audio_embedded = None + user_audio_mask = None + batch_output = self.process_batch( text=batch['text'], text_lens=batch['text_lens'], @@ -1256,6 +1526,8 @@ def training_step(self, batch, batch_idx): mode="train", task=batch["task"] if self.cfg.get("use_multiturn_dataset", False) else None, agent_mask=batch["agent_mask"] if self.cfg.get("use_multiturn_dataset", False) else None, + user_audio_embedded=user_audio_embedded, + user_mask=user_audio_mask, ) loss = batch_output.loss codebook_loss = batch_output.codebook_loss diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 915f0a0481b0..4aba6393ffe3 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -1243,58 +1243,6 @@ def prepare_context_tensors( return context_embedding, context_lens, context_audio_codes, context_audio_codes_lens - def _embed_context_text_tokens(self, context_text_tokens: torch.Tensor) -> torch.Tensor: - """ - Embed context text tokens. - - Default behavior is preserved: - - disable_subword_embedding_on_context=False: decoder text embedding only - - New behavior: - - disable_subword_embedding_on_context=True: CAS encoder replaces decoder text embedding - """ - if self.cfg.get("disable_subword_embedding_on_context", False): - if not self.cfg.get("use_bpe_char_tokenizer", False): - raise ValueError( - "`disable_subword_embedding_on_context=True` requires " - "`use_bpe_char_tokenizer=True`, because CAS must replace text_embedding." - ) - - if self.cfg.get("use_multiturn_dataset", False): - context_text_mask = context_text_tokens != self.pad_id - else: - context_text_mask = torch.ones_like(context_text_tokens, dtype=torch.bool) - - return self.cas_encoder( - context_text_tokens, - subword_mask=context_text_mask, - ) - - return self.decoder.get_input_embeddings()(context_text_tokens) - - def _get_context_cfg_embedding(self, batch_size: int, device: torch.device) -> torch.Tensor: - """ - Returns the unconditional context embedding used for CFG dropout. - Shape: (B, 1, E) - """ - cfg_token = torch.full( - (batch_size, 1), - self.cfg_unk_token_id, - dtype=torch.long, - device=device, - ) - - if self.cfg.get("disable_subword_embedding_on_context", False): - if not self.cfg.get("use_bpe_char_tokenizer", False): - raise ValueError( - "`disable_subword_embedding_on_context=True` requires " - "`use_bpe_char_tokenizer=True` for CFG context embedding." - ) - - cfg_mask = torch.ones_like(cfg_token, dtype=torch.bool) - return self.cas_encoder(cfg_token, subword_mask=cfg_mask) - - return self.decoder.get_input_embeddings()(cfg_token) def stack_codes(self, codes, codes_lens, bos_id, eos_id, stacking_factor, num_codebooks): """ From a9e106eca00814cec2db29d8da59803e92aadbb7 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 18 May 2026 14:14:43 -0700 Subject: [PATCH 045/109] Add multiturn inference support with user audio conditioning Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 319 ++++++++++++++---- ...text_to_speech_dataset_lhotse_multiturn.py | 8 + nemo/collections/tts/models/easy_magpietts.py | 213 +++++------- .../tts/models/easy_magpietts_inference.py | 4 + 4 files changed, 340 insertions(+), 204 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index 2d8f2bc892ea..afb42f6abdae 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -183,6 +183,35 @@ def __getitem__(self, idx): return self.samples[idx] +def _resolve_audio_path(path, root_path): + if path is None: + return None + if root_path is not None and not os.path.isabs(path): + return os.path.join(root_path, path) + return path + + +def _load_audio(path, sample_rate, normalize=True, use_librosa=False): + if path is None or not os.path.exists(path): + return torch.zeros(1, dtype=torch.float32) + + if use_librosa: + wav, sr = librosa.load(path, sr=sample_rate, mono=True) + if normalize: + wav = normalize_volume(wav) + return torch.as_tensor(wav, dtype=torch.float32) + + wav, sr = sf.read(path, dtype="float32") + if wav.ndim > 1: + wav = wav.mean(axis=1) + + if normalize: + wav = normalize_volume(wav) + + wav = torch.as_tensor(wav, dtype=torch.float32).unsqueeze(0) + return resample(wav, sr, sample_rate).squeeze(0) + + def collate_and_tokenize_custom( batch, model, @@ -193,11 +222,31 @@ def collate_and_tokenize_custom( add_interruption_token=False, pad_factor_text_speech=10, force_interruption=False, - normalize_context_audio_volume=True, + normalize_audio_volume=True, use_librosa=False, profile_multiturn_inference=False, + max_eval_turns=None, ): main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] + + if max_eval_turns is not None: + max_eval_turns = int(max_eval_turns) + if max_eval_turns <= 0: + raise ValueError("--max_eval_turns must be > 0 when provided.") + + truncated_batch = [] + for s in batch: + s = dict(s) + + if isinstance(s["text"], list): + s["text"] = s["text"][:max_eval_turns] + + if isinstance(s.get("user_audio_file_path"), list): + s["user_audio_file_path"] = s["user_audio_file_path"][:max_eval_turns] + + truncated_batch.append(s) + + batch = truncated_batch # --- MULTI-TURN MODE DECISION --- is_profile = profile_multiturn_inference @@ -324,41 +373,63 @@ def collate_and_tokenize_custom( audio_lengths = [] target_num_frames = [] - for i, s in enumerate(batch): - audio_path = s["context_audio_filepath"] - if root_path is not None: - audio_path = os.path.join(root_path, audio_path) - - if os.path.exists(audio_path): - if use_librosa: - wav, sr = librosa.load(audio_path, sr=sample_rate, mono=True) - if normalize_context_audio_volume: - wav = normalize_volume(wav) - wav = torch.as_tensor(wav, dtype=torch.float32) - else: - wav, sr = sf.read(audio_path, dtype='float32') - # Force Mono - if wav.ndim > 1: - wav = wav.mean(axis=1) + max_turns_for_user_audio = len(batched_turns) if (not is_duplex) else 0 - if normalize_context_audio_volume: - wav = normalize_volume(wav) + if is_profile and max_turns_for_user_audio > 0: + user_audio_by_turn = [[] for _ in range(max_turns_for_user_audio)] + user_audio_lens_by_turn = [[] for _ in range(max_turns_for_user_audio)] + else: + user_audio_by_turn = [] + user_audio_lens_by_turn = [] - # Convert to tensor, add batch dim for resampler, then remove it - wav_tensor = torch.as_tensor(wav, dtype=torch.float32).unsqueeze(0) - wav = resample(wav_tensor, sr, sample_rate).squeeze(0) - else: - wav = torch.zeros(1, dtype=torch.float32) + for i, s in enumerate(batch): + audio_path = _resolve_audio_path(s.get("context_audio_filepath"), root_path) + wav = _load_audio( + audio_path, + sample_rate, + normalize=normalize_audio_volume, + use_librosa=use_librosa, + ) audio_list.append(wav) audio_lengths.append(len(wav)) - tdur_audio_path = s["audio_filepath"] - if root_path is not None: - tdur_audio_path = os.path.join(root_path, tdur_audio_path) + # Optional per-turn user audio. + # Expected JSONL field: + # "user_audio": ["turn0_user.wav", "turn1_user.wav", ...] + if is_profile and max_turns_for_user_audio > 0: + user_audio_paths = s.get("user_audio_file_path", None) + + for t in range(max_turns_for_user_audio): + has_valid_text_turn = ( + isinstance(s["text"], list) and t < len(s["text"]) + ) or ( + not isinstance(s["text"], list) and t == 0 + ) + + if ( + isinstance(user_audio_paths, list) + and t < len(user_audio_paths) + and user_audio_paths[t] + and has_valid_text_turn + ): + ua_path = _resolve_audio_path(user_audio_paths[t], root_path) + ua_wav = _load_audio( + ua_path, + sample_rate=sample_rate, + normalize=normalize_audio_volume, + use_librosa=use_librosa, + ) + else: + ua_wav = torch.zeros(1, dtype=torch.float32) + + user_audio_by_turn[t].append(ua_wav) + user_audio_lens_by_turn[t].append(len(ua_wav)) + + tdur_audio_path = _resolve_audio_path(s["audio_filepath"], root_path) if tdur_audio_path and os.path.exists(tdur_audio_path): - wav_dur, sr_ = librosa.load(tdur_audio_path, sr=sample_rate, mono=True) + wav_dur = _load_audio(tdur_audio_path, sample_rate, normalize=normalize_audio_volume, use_librosa=use_librosa) tdur = wav_dur.shape[0] // model.input_samples_per_frame target_num_frames.append(tdur * extra_duration_thrshould) else: @@ -379,6 +450,25 @@ def collate_and_tokenize_custom( for i, wav in enumerate(audio_list): padded_audio[i, : len(wav)] = wav + + if is_profile and max_turns_for_user_audio > 0: + padded_user_audio_turns = [] + padded_user_audio_turns_lens = [] + + for t in range(max_turns_for_user_audio): + turn_lens = user_audio_lens_by_turn[t] + max_turn_audio_len = max(turn_lens) + padded_turn_audio = torch.zeros((B, max_turn_audio_len), dtype=torch.float32) + + for i, wav in enumerate(user_audio_by_turn[t]): + padded_turn_audio[i, : len(wav)] = wav + + padded_user_audio_turns.append(padded_turn_audio) + padded_user_audio_turns_lens.append(torch.tensor(turn_lens, dtype=torch.long)) + + out_dict["user_audio_turns"] = padded_user_audio_turns + out_dict["user_audio_turns_lens"] = padded_user_audio_turns_lens + out_dict["context_audio"] = padded_audio out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] @@ -418,6 +508,12 @@ def main(): parser.add_argument("--profile_multiturn_inference", action="store_true") parser.add_argument("--profile_pad_min_sec", type=float, default=2.0) parser.add_argument("--profile_pad_max_sec", type=float, default=2.0) + parser.add_argument( + "--max_eval_turns", + type=int, + default=6, + help="Maximum number of turns to evaluate per sample. None means use all turns.", + ) # Speaker & Prompt Configurations @@ -432,7 +528,7 @@ def main(): parser.add_argument("--topk", type=int, default=80) parser.add_argument("--max_tts_steps", type=int, default=2000) parser.add_argument("--force_speech_sil_codes", action="store_true") - parser.add_argument("--normalize_volume", type=lambda x: (str(x).lower() in ['true', '1', 'yes']), default=False) + parser.add_argument("--normalize_volume", type=lambda x: (str(x).lower() in ['true', '1', 'yes']), default=True) args = parser.parse_args() @@ -518,9 +614,10 @@ def main(): add_interruption_token=args.add_interruption_token, pad_factor_text_speech=args.pad_factor_text_speech, force_interruption=args.force_interruption, - normalize_context_audio_volume=args.normalize_volume, + normalize_audio_volume=args.normalize_volume, use_librosa=args.use_librosa, - profile_multiturn_inference=args.profile_multiturn_inference + profile_multiturn_inference=args.profile_multiturn_inference, + max_eval_turns=args.max_eval_turns, ) dataloader = DataLoader( @@ -529,22 +626,12 @@ def main(): ) if args.user_custom_speaker_reference and args.inference_speaker_reference: - if args.use_librosa: - wav, sr = librosa.load(args.inference_speaker_reference, sr=model.sample_rate, mono=True) - if args.normalize_volume: - wav = normalize_volume(wav) - speaker_wav = torch.as_tensor(wav, dtype=target_dtype).unsqueeze(0).to(model.device) - else: - wav, sr = sf.read(args.inference_speaker_reference, dtype='float32') - # Force Mono - if wav.ndim > 1: - wav = wav.mean(axis=1) - - if args.normalize_volume: - wav = normalize_volume(wav) - - speaker_wav = torch.as_tensor(wav).unsqueeze(0) - speaker_wav = resample(speaker_wav.float(), sr, model.sample_rate).to(target_dtype).to(model.device) + speaker_wav = _load_audio( + args.inference_speaker_reference, + model.sample_rate, + normalize=args.normalize_volume, + use_librosa=args.use_librosa, + ).unsqueeze(0).to(model.device, dtype=target_dtype) for batch_id, inputs in enumerate(dataloader): B = inputs["context_audio"].size(0) @@ -557,6 +644,14 @@ def main(): inputs["context_audio"] = speaker_wav.repeat(B, 1).detach() inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) + if "user_audio_turns" in inputs: + inputs["user_audio_turns"] = [ + x.to(device, dtype=target_dtype) for x in inputs["user_audio_turns"] + ] + inputs["user_audio_turns_lens"] = [ + x.to(device) for x in inputs["user_audio_turns_lens"] + ] + profile_turn_frame_ranges = [] with torch.inference_mode(): wav = inputs["context_audio"] @@ -714,24 +809,101 @@ def main(): if hasattr(state, "phoneme_eos_detected"): state.phoneme_eos_detected.zero_() - # Prefill on the begining of each turn - profile_seconds = ( - args.profile_pad_min_sec - + torch.rand((), device=device).item() - * (args.profile_pad_max_sec - args.profile_pad_min_sec) - ) + if not model.cfg.get("condition_on_user_speech", False): + # Prefill on the begining of each turn + profile_seconds = ( + args.profile_pad_min_sec + + torch.rand((), device=device).item() + * (args.profile_pad_max_sec - args.profile_pad_min_sec) + ) - profile_T = max( - 1, - int(round(profile_seconds * model.sample_rate / model.input_samples_per_frame)), - ) + profile_T = max( + 1, + int(round(profile_seconds * model.sample_rate / model.input_samples_per_frame)), + ) + + profile_tokens = torch.full( + (1, profile_T), + model.pad_id, + dtype=torch.long, + device=device, + ) + user_audio_channel_embedding = None + else: + user_audio_channel_embedding = None + if "user_audio_turns" in inputs: + user_audio = inputs["user_audio_turns"][t] + user_audio_lens = inputs["user_audio_turns_lens"][t] + else: + print("Warning!! USING CONTEXT AUDIO AS USER AUDIO FOR TESTING !!") + user_audio = inputs["context_audio"] + user_audio_lens = inputs["context_audio_lengths"] + + user_audio_codes, user_audio_codes_lens = model._codec_helper.audio_to_codes( + user_audio, + user_audio_lens, + ) + + if model._codec_converter is not None: + user_audio_codes = model._codec_converter.convert_original_to_new( + audio_tokens=user_audio_codes, audio_lens=user_audio_codes_lens + ).long() + + user_audio_codes, user_audio_codes_lens = model.stack_codes( + user_audio_codes, + user_audio_codes_lens, + model.audio_bos_id, + model.audio_eos_id, + model.frame_stacking_factor, + model.num_audio_codebooks, + ) + user_audio_embedded = model.embed_audio_tokens(user_audio_codes) + + boundary_trim = model.cfg.get("user_audio_boundary_trim", 4) + boundary_trim = 0 if boundary_trim is None else int(boundary_trim) + + # Remove BOS/EOS from the user-audio turn, same as training. + if boundary_trim == 0: + real_start = 0 + real_end = int(user_audio_codes_lens[0].item()) + else: + turn_len_with_special = int(user_audio_codes_lens[0].item()) + real_start = 1 + real_end = max(real_start, turn_len_with_special - 1) + + user_audio_embedded = user_audio_embedded[:, real_start:real_end] + + # Optional: trim boundaries exactly like training. + copy_len = user_audio_embedded.size(1) + + if boundary_trim > 0: + trim = min(boundary_trim, copy_len // 2) + + if trim > 0: + user_audio_embedded[:, :trim] = 0.0 + user_audio_embedded[:, copy_len - trim:] = 0.0 + + # Add BOS-aligned zero frame, because audio input timeline has BOS at t=0. + bos_user_pad = torch.zeros( + user_audio_embedded.size(0), + 1, + user_audio_embedded.size(2), + device=user_audio_embedded.device, + dtype=user_audio_embedded.dtype, + ) + user_audio_embedded = torch.cat([bos_user_pad, user_audio_embedded], dim=1) + + profile_T = user_audio_embedded.size(1) + profile_tokens = torch.full( + (B, profile_T), + model.pad_id, + dtype=torch.long, + device=device, + ) + + user_audio_channel_embedding = user_audio_embedded + profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate - profile_tokens = torch.full( - (1, profile_T), - model.pad_id, - dtype=torch.long, - device=device, - ) # add text tokens needed for profilling if not model.cfg.get("agent_mask_include_transition_prefix", False): delay_tokens = int(state.config.training_mode.streaming_speech_delay) @@ -739,11 +911,15 @@ def main(): profile_tokens[:, -delay_tokens:] = turn_text[:, :delay_tokens] turn_text = turn_text[:, delay_tokens:] turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) + if user_audio_channel_embedding is not None and delay_tokens > 0: + user_audio_channel_embedding = user_audio_channel_embedding.clone() + user_audio_channel_embedding[:, -delay_tokens:] = 0.0 state = model.streaming_prefill_profile( state=state, text_tokens=profile_tokens, use_inference_mode=True, + user_audio_channel_embedding=user_audio_channel_embedding ) logging.info( @@ -771,13 +947,16 @@ def main(): relative_position = state.text_tokens_seen - turn_offset text_exhausted = relative_position >= turn_lens - position = relative_position.clamp(min=0, max=turn_text.size(1) - 1) - current_tokens = turn_text[torch.arange(B, device=device), position] - current_tokens = torch.where( - text_exhausted, - torch.full_like(current_tokens, model.eos_id), - current_tokens, - ) + if turn_text.size(1) == 0: + current_tokens = torch.full((B,), model.eos_id, dtype=torch.long, device=device) + else: + position = relative_position.clamp(min=0, max=turn_text.size(1) - 1) + current_tokens = turn_text[torch.arange(B, device=device), position] + current_tokens = torch.where( + text_exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) state, audio_codes, _ = model.streaming_step( state=state, diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index a7681bb96142..4251df0a3874 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -618,8 +618,16 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) self.output_roles, ) + user_mask, user_mask_lens = collate_speaker_mask_channel( + cuts, + self.frame_length, + self.input_roles, + ) + batch_dict["agent_mask"] = agent_mask batch_dict["agent_mask_lens"] = agent_mask_lens + batch_dict["user_mask"] = user_mask + batch_dict["user_mask_lens"] = user_mask_lens return batch_dict diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 16ad640853b4..75b229d079f5 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -823,8 +823,7 @@ def process_batch( training_mode: Optional[TrainingMode] = None, task: Optional[List[str]] = None, agent_mask: Optional[torch.Tensor] = None, - user_audio_embedded: Optional[torch.Tensor] = None, - user_mask: Optional[torch.Tensor] = None, + user_audio_embedded: Optional[torch.Tensor] = None ) -> ProcessBatchOutput: """ Simplified batch processing using channel-based embedding architecture. @@ -960,11 +959,6 @@ def process_batch( dropout_complete_phoneme_channel=dropout_complete_phoneme_channel, ) - """ - if agent_mask is not None: - debug_agent_mask = agent_mask.clone() - """ - # 5. Prepare audio channel embeddings ( audio_channel_embedding, @@ -1034,22 +1028,12 @@ def process_batch( ) user_audio_embedded = torch.cat([bos_user_pad, user_audio_embedded], dim=1) - if user_mask is not None: - bos_user_mask = torch.zeros( - user_mask.size(0), - 1, - device=user_mask.device, - dtype=torch.bool, - ) - user_mask = torch.cat([bos_user_mask, user_mask], dim=1) - # Align user conditioning to audio_codes_target timeline, # same as agent_mask from prepare_audio_channel_embeddings(). target_T = audio_codes_target.size(2) if user_audio_embedded.size(1) < target_T: pad_len = target_T - user_audio_embedded.size(1) - user_audio_embedded = torch.cat( [ user_audio_embedded, @@ -1063,111 +1047,9 @@ def process_batch( ], dim=1, ) - - if user_mask is not None: - user_mask = torch.cat( - [ - user_mask, - torch.zeros( - user_mask.size(0), - pad_len, - device=user_mask.device, - dtype=torch.bool, - ), - ], - dim=1, - ) else: user_audio_embedded = user_audio_embedded[:, :target_T] - if user_mask is not None: - user_mask = user_mask[:, :target_T] - - """ - if user_mask is not None and debug_agent_mask is not None: - valid = get_mask_from_lengths(audio_codes_lens_target).bool().to(debug_agent_mask.device) - target_T = valid.size(1) - - raw_agent = debug_agent_mask.to(valid.device).bool() - if raw_agent.size(1) < target_T: - raw_agent = torch.cat( - [ - raw_agent, - torch.zeros( - raw_agent.size(0), - target_T - raw_agent.size(1), - device=raw_agent.device, - dtype=torch.bool, - ), - ], - dim=1, - ) - else: - raw_agent = raw_agent[:, :target_T] - - raw_agent = raw_agent & valid - - raw_input_agent_mask = torch.zeros_like(raw_agent) - raw_input_agent_mask[:, 1:] = raw_agent[:, :-1] - raw_input_agent_mask[:, 0] = False - - user_mask_cmp = user_mask.to(raw_input_agent_mask.device).bool() - - T_cmp = min(user_mask_cmp.size(1), raw_input_agent_mask.size(1)) - user_mask_cmp = user_mask_cmp[:, :T_cmp] - raw_input_agent_mask = raw_input_agent_mask[:, :T_cmp] - - overlap = user_mask_cmp & raw_input_agent_mask - - # Classify overlap location relative to contiguous user spans. - B, T = user_mask_cmp.shape - left_overlap = torch.zeros(B, device=user_mask_cmp.device, dtype=torch.long) - right_overlap = torch.zeros(B, device=user_mask_cmp.device, dtype=torch.long) - middle_overlap = torch.zeros(B, device=user_mask_cmp.device, dtype=torch.long) - - for b in range(B): - user_idx = user_mask_cmp[b].nonzero(as_tuple=False).flatten() - if user_idx.numel() == 0: - continue - - breaks = torch.where(user_idx[1:] != user_idx[:-1] + 1)[0] + 1 - spans = torch.tensor_split(user_idx, breaks.cpu().tolist()) - - for span in spans: - s = int(span[0].item()) - e = int(span[-1].item()) + 1 - - overlap_span = overlap[b, s:e] - if not overlap_span.any(): - continue - - ov_idx = overlap_span.nonzero(as_tuple=False).flatten() - ov_abs = ov_idx + s - - # Count overlap frames close to boundaries. - boundary_width = int(self.cfg.get("user_agent_overlap_boundary_width", 10)) - - left_region_end = min(e, s + boundary_width) - right_region_start = max(s, e - boundary_width) - - left_overlap[b] += overlap[b, s:left_region_end].sum() - right_overlap[b] += overlap[b, right_region_start:e].sum() - - boundary_mask = torch.zeros_like(overlap[b, s:e]) - boundary_mask[: left_region_end - s] = True - boundary_mask[right_region_start - s :] = True - middle_overlap[b] += (overlap_span & ~boundary_mask).sum() - - logging.info( - "[user/raw-agent input-overlap side debug] " - f"overlap_frames={overlap.sum(dim=1).detach().cpu().tolist()} " - f"left_overlap={left_overlap.detach().cpu().tolist()} " - f"right_overlap={right_overlap.detach().cpu().tolist()} " - f"middle_overlap={middle_overlap.detach().cpu().tolist()} " - f"user_frames={user_mask_cmp.sum(dim=1).detach().cpu().tolist()} " - f"raw_agent_input_frames={raw_input_agent_mask.sum(dim=1).detach().cpu().tolist()}" - ) - """ batch_size = user_audio_embedded.size(0) device = user_audio_embedded.device @@ -1207,7 +1089,6 @@ def process_batch( combined_channel_embedding = combined_channel_embedding + user_audio_channel_embedding - # 7. Join context with combined channel embeddings # The combined_channel_lens is the max of all channel lens for each batch item combined_channel_lens = ( @@ -1481,12 +1362,19 @@ def training_step(self, batch, batch_idx): seq_len = end_frame - start_frame if seq_len <= 0: continue + + boundary_trim = self.cfg.get("user_audio_boundary_trim", 4) + boundary_trim = 0 if boundary_trim is None else int(boundary_trim) - # Remove BOS/EOS from user turn embedding. - # stack_codes added special tokens, so timeline frames should use only real user audio frames. - turn_len_with_special = int(user_audio_codes_lens[turn_idx].item()) - real_start = 1 - real_end = max(real_start, turn_len_with_special - 1) + if boundary_trim == 0: + # Keep the whole stacked user-audio sequence, including BOS/EOS. + real_start = 0 + real_end = int(user_audio_codes_lens[turn_idx].item()) + else: + # Remove BOS/EOS, then trim boundaries. + turn_len_with_special = int(user_audio_codes_lens[turn_idx].item()) + real_start = 1 + real_end = max(real_start, turn_len_with_special - 1) turn_emb = user_audio_embedded[turn_idx, real_start:real_end] @@ -1498,16 +1386,74 @@ def training_step(self, batch, batch_idx): dst_end = start_frame + copy_len user_audio_embedded_restored[b, dst_start:dst_end] = turn_emb[:copy_len] user_audio_embedded_mask[b, dst_start:dst_end] = True - boundary_trim = int(self.cfg.get("user_audio_boundary_trim", 4)) - trim = min(boundary_trim, copy_len // 2) - if trim > 0: - user_audio_embedded_restored[b, dst_start:dst_start + trim] = 0.0 - user_audio_embedded_restored[b, dst_end - trim:dst_end] = 0.0 - user_audio_embedded_mask[b, dst_start:dst_start + trim] = False - user_audio_embedded_mask[b, dst_end - trim:dst_end] = False + + if boundary_trim > 0: + trim = min(boundary_trim, copy_len // 2) + if trim > 0: + user_audio_embedded_restored[b, dst_start:dst_start + trim] = 0.0 + user_audio_embedded_restored[b, dst_end - trim:dst_end] = 0.0 + user_audio_embedded_mask[b, dst_start:dst_start + trim] = False + user_audio_embedded_mask[b, dst_end - trim:dst_end] = False user_audio_embedded = user_audio_embedded_restored user_audio_mask = user_audio_embedded_mask + # compare these two masks showing count batch level overlaps, left and right overlap per item batch. Consider only where both are ones. + """ + if "agent_mask" in batch and batch["agent_mask"] is not None: + user_cmp = user_audio_mask.bool() + agent_cmp = batch["agent_mask"].to(user_cmp.device).bool() + T_cmp = min(user_cmp.size(1), agent_cmp.size(1)) + valid = torch.arange(T_cmp, device=user_cmp.device)[None, :] < batch["text_lens"].to(user_cmp.device)[:, None] + valid = valid[:, :T_cmp] + user_cmp = user_cmp[:, :T_cmp] & valid + agent_cmp = agent_cmp[:, :T_cmp] & valid + + overlap = user_cmp & agent_cmp + + left_overlap = torch.zeros(B, device=user_cmp.device, dtype=torch.long) + right_overlap = torch.zeros(B, device=user_cmp.device, dtype=torch.long) + middle_overlap = torch.zeros(B, device=user_cmp.device, dtype=torch.long) + + boundary_width = int(self.cfg.get("user_agent_overlap_boundary_width", 10)) + + for bi in range(B): + user_idx = user_cmp[bi].nonzero(as_tuple=False).flatten() + if user_idx.numel() == 0: + continue + + breaks = torch.where(user_idx[1:] != user_idx[:-1] + 1)[0] + 1 + spans = torch.tensor_split(user_idx, breaks.cpu().tolist()) + + for span in spans: + s = int(span[0].item()) + e = int(span[-1].item()) + 1 + + overlap_span = overlap[bi, s:e] + if not overlap_span.any(): + continue + + left_end = min(e, s + boundary_width) + right_start = max(s, e - boundary_width) + + left_overlap[bi] += overlap[bi, s:left_end].sum() + right_overlap[bi] += overlap[bi, right_start:e].sum() + + boundary_mask = torch.zeros_like(overlap_span) + boundary_mask[: left_end - s] = True + boundary_mask[right_start - s :] = True + + middle_overlap[bi] += (overlap_span & ~boundary_mask).sum() + + logging.info( + "[user/agent-mask overlap debug] " + f"overlap_frames={overlap.sum(dim=1).detach().cpu().tolist()} " + f"left_overlap={left_overlap.detach().cpu().tolist()} " + f"right_overlap={right_overlap.detach().cpu().tolist()} " + f"middle_overlap={middle_overlap.detach().cpu().tolist()} " + f"user_frames={user_cmp.sum(dim=1).detach().cpu().tolist()} " + f"agent_frames={agent_cmp.sum(dim=1).detach().cpu().tolist()}" + ) + """ else: user_audio_embedded = None user_audio_mask = None @@ -1526,8 +1472,7 @@ def training_step(self, batch, batch_idx): mode="train", task=batch["task"] if self.cfg.get("use_multiturn_dataset", False) else None, agent_mask=batch["agent_mask"] if self.cfg.get("use_multiturn_dataset", False) else None, - user_audio_embedded=user_audio_embedded, - user_mask=user_audio_mask, + user_audio_embedded=user_audio_embedded ) loss = batch_output.loss codebook_loss = batch_output.codebook_loss diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 4aba6393ffe3..8b2b0b53f7c4 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -683,6 +683,7 @@ def streaming_prefill_profile( self, state: StreamingState, text_tokens: torch.Tensor, # (B, T) or (B,) + user_audio_channel_embedding: torch.Tensor = None, use_inference_mode: bool = True, # ToDo: implement audio direct support instead of use silence tokens ) -> StreamingState: @@ -736,6 +737,9 @@ def streaming_prefill_profile( # Match training channel sum: text + audio profile-token/silence input. combined_emb = text_emb + audio_emb + + if self.cfg.get("condition_on_user_speech", False): + combined_emb = combined_emb + user_audio_channel_embedding # ----------------------- # CFG handling From 8ed4363b1b0b76ce98a326a38e125a98b4b125e0 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 19 May 2026 07:13:17 -0700 Subject: [PATCH 046/109] Add silence if user audio is not available Signed-off-by: Edresson Casanova --- examples/tts/easy_magpietts_inference_multiturn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index afb42f6abdae..5b0ff9de3276 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -421,7 +421,8 @@ def collate_and_tokenize_custom( use_librosa=use_librosa, ) else: - ua_wav = torch.zeros(1, dtype=torch.float32) + print("User audio not founded, using silence two seconds audio") + ua_wav = torch.zeros(int(2 * sample_rate), dtype=torch.float32) user_audio_by_turn[t].append(ua_wav) user_audio_lens_by_turn[t].append(len(ua_wav)) From c9ade51539a88a9ad2f4a3d4c0dc39db35885dcd Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 19 May 2026 13:25:48 -0700 Subject: [PATCH 047/109] Add multiturn augmentations Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts.py | 93 +++++++++++++++---- 1 file changed, 76 insertions(+), 17 deletions(-) diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 75b229d079f5..7fc4d634a03c 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -1314,17 +1314,36 @@ def training_step(self, batch, batch_idx): torch.zeros_like(gathered_agent_mask), ) - if self.cfg.get("use_multiturn_dataset", False) and batch["user_audio_turn_splitted"] is not None and self.cfg.get("condition_on_user_speech", False): + if ( + self.cfg.get("use_multiturn_dataset", False) + and batch["user_audio_turn_splitted"] is not None + and self.cfg.get("condition_on_user_speech", False) + ): input_samples_per_frame = self.codec_model_samples_per_frame * self.frame_stacking_factor + user_audio = batch["user_audio_turn_splitted"] + user_audio_lens = batch["user_audio_turn_splitted_lens"] + + silence_prob = float(self.cfg.get("user_cond_silence_augmentation_prob", 0.0) or 0.0) + if self.training and silence_prob > 0.0: + silence_mask = torch.rand( + user_audio.size(0), + device=user_audio.device, + ) < silence_prob + + if silence_mask.any(): + user_audio = user_audio.clone() + user_audio[silence_mask] = 0.0 + user_audio_codes, user_audio_codes_lens = self._codec_helper.audio_to_codes( - batch["user_audio_turn_splitted"], - batch["user_audio_turn_splitted_lens"], + user_audio, + user_audio_lens, ) if self._codec_converter is not None: user_audio_codes = self._codec_converter.convert_original_to_new( - audio_tokens=user_audio_codes, audio_lens=user_audio_codes_lens + audio_tokens=user_audio_codes, + audio_lens=user_audio_codes_lens, ).long() user_audio_codes, user_audio_codes_lens = self.stack_codes( @@ -1337,8 +1356,6 @@ def training_step(self, batch, batch_idx): ) user_audio_embedded = self.embed_audio_tokens(user_audio_codes) - # user_audio_embedded: [N_turns, T_turn_max, D] - # user_audio_codes_lens includes BOS/EOS from stack_codes B = batch["text"].shape[0] T = batch["text"].shape[1] @@ -1347,7 +1364,17 @@ def training_step(self, batch, batch_idx): user_audio_embedded_restored = user_audio_embedded.new_zeros(B, T, D) user_audio_embedded_mask = torch.zeros(B, T, device=user_audio_embedded.device, dtype=torch.bool) + sample_prob = float(self.cfg.get("user_cond_trim_augmentation_sample_prob", 0.0) or 0.0) + turn_prob = float(self.cfg.get("user_cond_trim_augmentation_turn_prob", 0.0) or 0.0) + base_trim = int(self.cfg.get("user_cond_trim_augmentation_base", 0) or 0) + + if self.training and sample_prob > 0.0 and turn_prob > 0.0 and base_trim > 0: + sample_trim_aug = torch.rand(B, device=user_audio_embedded.device) < sample_prob + else: + sample_trim_aug = torch.zeros(B, device=user_audio_embedded.device, dtype=torch.bool) + indices = batch["user_audio_turn_splitted_indices"].to(user_audio_embedded.device) + for turn_idx, (b, start_sample, end_sample) in enumerate(indices): b = int(b.item()) if b < 0: @@ -1362,16 +1389,14 @@ def training_step(self, batch, batch_idx): seq_len = end_frame - start_frame if seq_len <= 0: continue - + boundary_trim = self.cfg.get("user_audio_boundary_trim", 4) boundary_trim = 0 if boundary_trim is None else int(boundary_trim) if boundary_trim == 0: - # Keep the whole stacked user-audio sequence, including BOS/EOS. real_start = 0 real_end = int(user_audio_codes_lens[turn_idx].item()) else: - # Remove BOS/EOS, then trim boundaries. turn_len_with_special = int(user_audio_codes_lens[turn_idx].item()) real_start = 1 real_end = max(real_start, turn_len_with_special - 1) @@ -1382,21 +1407,55 @@ def training_step(self, batch, batch_idx): if copy_len <= 0: continue - dst_start = start_frame - dst_end = start_frame + copy_len - user_audio_embedded_restored[b, dst_start:dst_end] = turn_emb[:copy_len] - user_audio_embedded_mask[b, dst_start:dst_end] = True + turn_emb = turn_emb[:copy_len].clone() + turn_mask = torch.ones(copy_len, device=user_audio_embedded.device, dtype=torch.bool) if boundary_trim > 0: trim = min(boundary_trim, copy_len // 2) if trim > 0: - user_audio_embedded_restored[b, dst_start:dst_start + trim] = 0.0 - user_audio_embedded_restored[b, dst_end - trim:dst_end] = 0.0 - user_audio_embedded_mask[b, dst_start:dst_start + trim] = False - user_audio_embedded_mask[b, dst_end - trim:dst_end] = False + turn_emb[:trim] = 0.0 + turn_emb[copy_len - trim:] = 0.0 + turn_mask[:trim] = False + turn_mask[copy_len - trim:] = False + + if bool(sample_trim_aug[b].item()): + do_turn_aug = torch.rand((), device=user_audio_embedded.device).item() < turn_prob + + if do_turn_aug: + trim_delta = int(torch.randint( + low=-1, + high=2, + size=(), + device=user_audio_embedded.device, + ).item()) + trim_amount = max(1, base_trim + trim_delta) + trim_amount = min(trim_amount, copy_len) + + aug_choice = random.choices( + ["left", "right", "full"], + weights=[0.45, 0.45, 0.10], + k=1, + )[0] + + if aug_choice == "left": + turn_emb[:trim_amount] = 0.0 + turn_mask[:trim_amount] = False + elif aug_choice == "right": + turn_emb[copy_len - trim_amount:] = 0.0 + turn_mask[copy_len - trim_amount:] = False + else: + turn_emb.zero_() + turn_mask[:] = True + + dst_start = start_frame + dst_end = start_frame + copy_len + + user_audio_embedded_restored[b, dst_start:dst_end] = turn_emb + user_audio_embedded_mask[b, dst_start:dst_end] = turn_mask user_audio_embedded = user_audio_embedded_restored user_audio_mask = user_audio_embedded_mask + # compare these two masks showing count batch level overlaps, left and right overlap per item batch. Consider only where both are ones. """ if "agent_mask" in batch and batch["agent_mask"] is not None: From c59c8f8caee5a0cc63d2950ff6e06d3125895daa Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 20 May 2026 13:09:54 -0700 Subject: [PATCH 048/109] Add update inference Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 118 ++++++++++++++++-- 1 file changed, 108 insertions(+), 10 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index 5b0ff9de3276..3ea74b6fbd86 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -479,6 +479,17 @@ def collate_and_tokenize_custom( return out_dict +def _mix_user_turns_on_timeline(user_audio_turns, user_audio_turns_lens, sample_rate): + total_len = int(sum(x.item() for x in user_audio_turns_lens)) + mixed = torch.zeros(total_len, dtype=torch.float32) + + offset = 0 + for wav, wav_len in zip(user_audio_turns, user_audio_turns_lens): + wav_len = int(wav_len.item()) + mixed[offset : offset + wav_len] = wav[:wav_len] + offset += wav_len + + return mixed def main(): parser = argparse.ArgumentParser(description="EasyMagpieTTS Inference Evaluation") @@ -812,16 +823,19 @@ def main(): if not model.cfg.get("condition_on_user_speech", False): # Prefill on the begining of each turn - profile_seconds = ( - args.profile_pad_min_sec - + torch.rand((), device=device).item() - * (args.profile_pad_max_sec - args.profile_pad_min_sec) - ) - - profile_T = max( - 1, - int(round(profile_seconds * model.sample_rate / model.input_samples_per_frame)), - ) + if "user_audio_turns" in inputs: + profile_T = int(round(inputs["user_audio_turns"][t].size(-1) / model.input_samples_per_frame)) + profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate + else: + profile_seconds = ( + args.profile_pad_min_sec + + torch.rand((), device=device).item() + * (args.profile_pad_max_sec - args.profile_pad_min_sec) + ) + profile_T = max( + 1, + int(round(profile_seconds * model.sample_rate / model.input_samples_per_frame)), + ) profile_tokens = torch.full( (1, profile_T), @@ -1099,6 +1113,90 @@ def main(): out_path = os.path.join(args.out_dir, f"{stem}_turn_{turn_id}{ext}") sf.write(out_path, turn_wav, samplerate=model.output_sample_rate) logging.info(f"Saved: {out_path}") + + # --------------------------------------------------------- + # Save aligned stereo conversation: + # channel 0 = user conditioning audio + # channel 1 = generated agent audio + # --------------------------------------------------------- + if "user_audio_turns" in inputs: + user_segments = [] + + samples_per_prediction_frame = ( + model.codec_model_samples_per_frame + / (model.sample_rate / model.output_sample_rate) + ) + + first_user_len_in = int(inputs["user_audio_turns_lens"][0][i].item()) + first_user_delay_out = int( + round(first_user_len_in * model.output_sample_rate / model.sample_rate) + ) + + for turn_id, start_frame, end_frame in profile_turn_frame_ranges: + if turn_id >= len(inputs["user_audio_turns"]): + continue + + turn_audio = inputs["user_audio_turns"][turn_id][i].detach().cpu().float() + turn_audio_len = int(inputs["user_audio_turns_lens"][turn_id][i].item()) + turn_audio = turn_audio[:turn_audio_len] + + turn_audio_out = resample( + turn_audio.unsqueeze(0), + model.sample_rate, + model.output_sample_rate, + ).squeeze(0) + + if turn_id == 0: + user_start_sample = 0 + else: + prev_turn_end_frame = profile_turn_frame_ranges[turn_id - 1][2] + rel_prev_end_frame = prev_turn_end_frame - profile_decode_start_frame + user_start_sample = first_user_delay_out + int( + round(rel_prev_end_frame * samples_per_prediction_frame) + ) + + user_segments.append((user_start_sample, turn_audio_out)) + + total_user_len = 0 + for s, wav_seg in user_segments: + total_user_len = max(total_user_len, s + wav_seg.numel()) + + user_ch = torch.zeros(total_user_len) + + for s, wav_seg in user_segments: + e = s + wav_seg.numel() + user_ch[s:e] += wav_seg + + agent_pred = torch.from_numpy(wav).float() + agent_ch = torch.cat( + [ + torch.zeros(first_user_delay_out, dtype=agent_pred.dtype), + agent_pred, + ] + ) + + stereo_len = max(user_ch.numel(), agent_ch.numel()) + + user_pad = torch.zeros(stereo_len) + agent_pad = torch.zeros(stereo_len) + + user_pad[: user_ch.numel()] = user_ch + agent_pad[: agent_ch.numel()] = agent_ch + + stereo = torch.stack([user_pad, agent_pad], dim=1).numpy() + + aligned_path = os.path.join( + args.out_dir, + f"{stem}_user_agent_aligned{ext}", + ) + + sf.write( + aligned_path, + stereo, + samplerate=model.output_sample_rate, + ) + + logging.info(f"Aligned user/agent stereo audio saved: {aligned_path}") else: wav = audio_f32[i, : audio_len[i]].numpy() out_path = os.path.join(args.out_dir, base_name) From 6a23934e513d40f532a0ca475b933259dc69c4fd Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 20 May 2026 14:31:53 -0700 Subject: [PATCH 049/109] Add use_explicit_silence_for_streaming_audio_delay Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts.py | 62 ++++++++++++++++++- 1 file changed, 59 insertions(+), 3 deletions(-) diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 7fc4d634a03c..aa005fa3d19d 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -456,7 +456,6 @@ def prepare_phoneme_channel_embeddings( phoneme_mask = get_mask_from_lengths(phoneme_tokens_lens_stacked) phoneme_embedded = phoneme_embedded * phoneme_mask.unsqueeze(2) # (B, T', E) - # Handle phoneme dropout - zero out the embeddings if dropout_complete_phoneme_channel: phoneme_embedded = phoneme_embedded * 0.0 @@ -750,7 +749,10 @@ def prepare_audio_channel_embeddings( # Note that consider the current_streaming_speech_delay tokens/user speaking tokens on the loss, # allowing to predict them in autoregressive way transition_prefix = int(current_streaming_speech_delay or 0) - if self.cfg.get("agent_mask_include_transition_prefix", False) and transition_prefix > 0: + if ( + self.cfg.get("agent_mask_include_transition_prefix", False) + or self.cfg.get("use_explicit_silence_for_streaming_audio_delay", False) + ) and transition_prefix > 0: agent_i = target_agent_mask.float().unsqueeze(1) agent_i = torch.nn.functional.pad(agent_i, (0, transition_prefix)) @@ -867,6 +869,12 @@ def process_batch( current_text_input_mode = selected_training_mode.text_input_mode current_streaming_speech_delay = selected_training_mode.streaming_speech_delay current_streaming_phonemes_delay = selected_training_mode.streaming_phonemes_delay + # Optional: realize streaming speech delay as explicit silence-prefix + # in the audio tokens instead of temporal embedding delay. + audio_silence_prefix_frames = 0 + if self.cfg.get("use_explicit_silence_for_streaming_audio_delay", False): + audio_silence_prefix_frames = int(current_streaming_speech_delay) + current_streaming_speech_delay = 0 # Determine dropout flags dropout_text_input = (random.random() < self.dropout_text_input_prob) if mode == 'train' else False @@ -912,6 +920,54 @@ def process_batch( text[speech_eos_mask] = self.tokenizer.pad # Clean up the text channel # else: # ToDo: move self.interruption_token_id forward by audio_delay so that soon it saw the interruption token it is forced to stop instead of await audio_delay tokens + if audio_silence_prefix_frames > 0: + # audio_codes here are still unstacked: [B, C, T] + prefix_raw_frames = ( + audio_silence_prefix_frames * self.frame_stacking_factor + ) + + B_audio, C_audio, _ = audio_codes.shape + + sil = self.codec_sil_codes_unconverted.to( + device=audio_codes.device, + dtype=audio_codes.dtype, + ).view(1, C_audio, 1) + + sil_prefix = sil.expand(B_audio, C_audio, prefix_raw_frames) + + audio_codes = torch.cat([sil_prefix, audio_codes], dim=2) + audio_codes_lens = audio_codes_lens + prefix_raw_frames + + # agent_mask is in stacked/text frame units + if agent_mask is not None: + agent_mask = agent_mask.bool() + prefix_value = agent_mask[:, :1] if agent_mask.size(1) > 0 else torch.zeros( + agent_mask.size(0), 1, device=agent_mask.device, dtype=torch.bool + ) + + prefix_agent_mask = prefix_value.expand( + -1, + audio_silence_prefix_frames, + ) + + agent_mask = torch.cat( + [prefix_agent_mask, agent_mask], + dim=1, + ) + + # speech_eos_mask is also in stacked/text frame units + if speech_eos_mask is not None: + prefix_speech_eos = torch.zeros( + speech_eos_mask.size(0), + audio_silence_prefix_frames, + device=speech_eos_mask.device, + dtype=speech_eos_mask.dtype, + ) + speech_eos_mask = torch.cat( + [prefix_speech_eos, speech_eos_mask], + dim=1, + ) + # 3. Prepare text channel embeddings text_channel_embedding, text_channel_lens = self.prepare_text_channel_embeddings( text=text, @@ -972,7 +1028,7 @@ def process_batch( delay=audio_delay, speech_eos_mask=speech_eos_mask, agent_mask=agent_mask, - current_streaming_speech_delay=current_streaming_speech_delay + current_streaming_speech_delay=current_streaming_speech_delay if not self.cfg.get("use_explicit_silence_for_streaming_audio_delay", False) else audio_silence_prefix_frames, ) # 6. Sum the channel embeddings element-wise From bf402b7c3f7caeae69cb662e690734abd5e0e457 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 21 May 2026 09:20:23 -0700 Subject: [PATCH 050/109] Add new trim augmentation Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 5 +- nemo/collections/tts/models/easy_magpietts.py | 65 ++++++++++++++----- .../tts/models/easy_magpietts_inference.py | 3 + 3 files changed, 55 insertions(+), 18 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index 3ea74b6fbd86..781aa30531af 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -920,7 +920,10 @@ def main(): profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate # add text tokens needed for profilling - if not model.cfg.get("agent_mask_include_transition_prefix", False): + if ( + not model.cfg.get("agent_mask_include_transition_prefix", False) + and not model.cfg.get("use_explicit_silence_for_streaming_audio_delay", False) + ): delay_tokens = int(state.config.training_mode.streaming_speech_delay) delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) profile_tokens[:, -delay_tokens:] = turn_text[:, :delay_tokens] diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index aa005fa3d19d..79f70229d9df 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -1446,7 +1446,7 @@ def training_step(self, batch, batch_idx): if seq_len <= 0: continue - boundary_trim = self.cfg.get("user_audio_boundary_trim", 4) + boundary_trim = self.cfg.get("user_audio_boundary_trim", 5) boundary_trim = 0 if boundary_trim is None else int(boundary_trim) if boundary_trim == 0: @@ -1478,30 +1478,61 @@ def training_step(self, batch, batch_idx): do_turn_aug = torch.rand((), device=user_audio_embedded.device).item() < turn_prob if do_turn_aug: - trim_delta = int(torch.randint( - low=-1, - high=2, - size=(), - device=user_audio_embedded.device, - ).item()) + trim_delta = int( + torch.randint( + low=-1, + high=2, # {-1, 0, 1} + size=(), + device=user_audio_embedded.device, + ).item() + ) + trim_amount = max(1, base_trim + trim_delta) - trim_amount = min(trim_amount, copy_len) + trim_amount = min(trim_amount, max(1, copy_len - 1)) aug_choice = random.choices( - ["left", "right", "full"], - weights=[0.45, 0.45, 0.10], + ["left", "right", "both"], + weights=[0.3, 0.3, 0.4], k=1, )[0] + zero_emb_pad = turn_emb.new_zeros(trim_amount, turn_emb.size(-1)) + zero_mask_pad = torch.zeros(trim_amount, device=turn_mask.device, dtype=turn_mask.dtype) + if aug_choice == "left": - turn_emb[:trim_amount] = 0.0 - turn_mask[:trim_amount] = False + # Remove tokens from the left, then right-pad zeros. + kept_emb = turn_emb[trim_amount:] + kept_mask = turn_mask[trim_amount:] + turn_emb = torch.cat([kept_emb, zero_emb_pad], dim=0) + turn_mask = torch.cat([kept_mask, zero_mask_pad], dim=0) + elif aug_choice == "right": - turn_emb[copy_len - trim_amount:] = 0.0 - turn_mask[copy_len - trim_amount:] = False - else: - turn_emb.zero_() - turn_mask[:] = True + # Remove tokens from the right, then right-pad zeros. + # This preserves timing of the left side and removes the transition/right edge. + kept_emb = turn_emb[: copy_len - trim_amount] + kept_mask = turn_mask[: copy_len - trim_amount] + + turn_emb = torch.cat([kept_emb, zero_emb_pad], dim=0) + turn_mask = torch.cat([kept_mask, zero_mask_pad], dim=0) + + else: # "both" + # Remove trim_amount total tokens split across left and right. + left_trim = trim_amount // 2 + right_trim = trim_amount - left_trim + + # If trim_amount is odd, randomly decide which side loses the extra token. + if trim_amount % 2 == 1 and torch.rand((), device=user_audio_embedded.device).item() < 0.5: + left_trim, right_trim = right_trim, left_trim + + kept_emb = turn_emb[left_trim : copy_len - right_trim] + kept_mask = turn_mask[left_trim : copy_len - right_trim] + + turn_emb = torch.cat([kept_emb, zero_emb_pad], dim=0) + turn_mask = torch.cat([kept_mask, zero_mask_pad], dim=0) + + # Safety: keep exact same length for restore assignment. + turn_emb = turn_emb[:copy_len] + turn_mask = turn_mask[:copy_len] dst_start = start_frame dst_end = start_frame + copy_len diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 8b2b0b53f7c4..ddeb35a6cdc4 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -1672,6 +1672,9 @@ def _prepare_streaming_input( streaming_speech_delay = state.config.training_mode.streaming_speech_delay streaming_phonemes_delay = state.config.training_mode.streaming_phonemes_delay + if self.cfg.get("use_explicit_silence_for_streaming_audio_delay", False): + streaming_speech_delay = 0 + # Determine phases per batch item needs_context = state.context_position < state.full_context_lens # (B,) bool needs_text = (~needs_context) & (~state.text_finished) From 9f64580b754c79af90b9f4d4f130a34687ab84a9 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 21 May 2026 09:26:00 -0700 Subject: [PATCH 051/109] Add new trim aug Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 79f70229d9df..3141a5db5f60 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -1446,7 +1446,7 @@ def training_step(self, batch, batch_idx): if seq_len <= 0: continue - boundary_trim = self.cfg.get("user_audio_boundary_trim", 5) + boundary_trim = self.cfg.get("user_audio_boundary_trim", 0) boundary_trim = 0 if boundary_trim is None else int(boundary_trim) if boundary_trim == 0: From 183b2b6d523af2f15a2ee155e7dbd258c0c4fe4c Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 21 May 2026 09:58:20 -0700 Subject: [PATCH 052/109] remove use_explicit_silence_for_streaming_audio_delay Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 6 +- nemo/collections/tts/models/easy_magpietts.py | 62 +------------------ .../tts/models/easy_magpietts_inference.py | 3 - 3 files changed, 4 insertions(+), 67 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index 781aa30531af..d67bf992eba7 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -920,15 +920,13 @@ def main(): profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate # add text tokens needed for profilling - if ( - not model.cfg.get("agent_mask_include_transition_prefix", False) - and not model.cfg.get("use_explicit_silence_for_streaming_audio_delay", False) - ): + if not model.cfg.get("agent_mask_include_transition_prefix", False): delay_tokens = int(state.config.training_mode.streaming_speech_delay) delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) profile_tokens[:, -delay_tokens:] = turn_text[:, :delay_tokens] turn_text = turn_text[:, delay_tokens:] turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) + # ToDo: Check if it is really necessary (probably not) if user_audio_channel_embedding is not None and delay_tokens > 0: user_audio_channel_embedding = user_audio_channel_embedding.clone() user_audio_channel_embedding[:, -delay_tokens:] = 0.0 diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 3141a5db5f60..ea79f65edc5e 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -749,10 +749,7 @@ def prepare_audio_channel_embeddings( # Note that consider the current_streaming_speech_delay tokens/user speaking tokens on the loss, # allowing to predict them in autoregressive way transition_prefix = int(current_streaming_speech_delay or 0) - if ( - self.cfg.get("agent_mask_include_transition_prefix", False) - or self.cfg.get("use_explicit_silence_for_streaming_audio_delay", False) - ) and transition_prefix > 0: + if self.cfg.get("agent_mask_include_transition_prefix", False) and transition_prefix > 0: agent_i = target_agent_mask.float().unsqueeze(1) agent_i = torch.nn.functional.pad(agent_i, (0, transition_prefix)) @@ -869,12 +866,6 @@ def process_batch( current_text_input_mode = selected_training_mode.text_input_mode current_streaming_speech_delay = selected_training_mode.streaming_speech_delay current_streaming_phonemes_delay = selected_training_mode.streaming_phonemes_delay - # Optional: realize streaming speech delay as explicit silence-prefix - # in the audio tokens instead of temporal embedding delay. - audio_silence_prefix_frames = 0 - if self.cfg.get("use_explicit_silence_for_streaming_audio_delay", False): - audio_silence_prefix_frames = int(current_streaming_speech_delay) - current_streaming_speech_delay = 0 # Determine dropout flags dropout_text_input = (random.random() < self.dropout_text_input_prob) if mode == 'train' else False @@ -920,54 +911,6 @@ def process_batch( text[speech_eos_mask] = self.tokenizer.pad # Clean up the text channel # else: # ToDo: move self.interruption_token_id forward by audio_delay so that soon it saw the interruption token it is forced to stop instead of await audio_delay tokens - if audio_silence_prefix_frames > 0: - # audio_codes here are still unstacked: [B, C, T] - prefix_raw_frames = ( - audio_silence_prefix_frames * self.frame_stacking_factor - ) - - B_audio, C_audio, _ = audio_codes.shape - - sil = self.codec_sil_codes_unconverted.to( - device=audio_codes.device, - dtype=audio_codes.dtype, - ).view(1, C_audio, 1) - - sil_prefix = sil.expand(B_audio, C_audio, prefix_raw_frames) - - audio_codes = torch.cat([sil_prefix, audio_codes], dim=2) - audio_codes_lens = audio_codes_lens + prefix_raw_frames - - # agent_mask is in stacked/text frame units - if agent_mask is not None: - agent_mask = agent_mask.bool() - prefix_value = agent_mask[:, :1] if agent_mask.size(1) > 0 else torch.zeros( - agent_mask.size(0), 1, device=agent_mask.device, dtype=torch.bool - ) - - prefix_agent_mask = prefix_value.expand( - -1, - audio_silence_prefix_frames, - ) - - agent_mask = torch.cat( - [prefix_agent_mask, agent_mask], - dim=1, - ) - - # speech_eos_mask is also in stacked/text frame units - if speech_eos_mask is not None: - prefix_speech_eos = torch.zeros( - speech_eos_mask.size(0), - audio_silence_prefix_frames, - device=speech_eos_mask.device, - dtype=speech_eos_mask.dtype, - ) - speech_eos_mask = torch.cat( - [prefix_speech_eos, speech_eos_mask], - dim=1, - ) - # 3. Prepare text channel embeddings text_channel_embedding, text_channel_lens = self.prepare_text_channel_embeddings( text=text, @@ -1028,7 +971,7 @@ def process_batch( delay=audio_delay, speech_eos_mask=speech_eos_mask, agent_mask=agent_mask, - current_streaming_speech_delay=current_streaming_speech_delay if not self.cfg.get("use_explicit_silence_for_streaming_audio_delay", False) else audio_silence_prefix_frames, + current_streaming_speech_delay=current_streaming_speech_delay, ) # 6. Sum the channel embeddings element-wise @@ -1430,7 +1373,6 @@ def training_step(self, batch, batch_idx): sample_trim_aug = torch.zeros(B, device=user_audio_embedded.device, dtype=torch.bool) indices = batch["user_audio_turn_splitted_indices"].to(user_audio_embedded.device) - for turn_idx, (b, start_sample, end_sample) in enumerate(indices): b = int(b.item()) if b < 0: diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index ddeb35a6cdc4..8b2b0b53f7c4 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -1672,9 +1672,6 @@ def _prepare_streaming_input( streaming_speech_delay = state.config.training_mode.streaming_speech_delay streaming_phonemes_delay = state.config.training_mode.streaming_phonemes_delay - if self.cfg.get("use_explicit_silence_for_streaming_audio_delay", False): - streaming_speech_delay = 0 - # Determine phases per batch item needs_context = state.context_position < state.full_context_lens # (B,) bool needs_text = (~needs_context) & (~state.text_finished) From 126970ca1e8b6309cfd55b4e63802295d95041e2 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 21 May 2026 12:48:53 -0700 Subject: [PATCH 053/109] Add new inference Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 67 +++++++++++++----- .../tts/models/easy_magpietts_inference.py | 69 ++++++++++++++++--- uv.lock | 2 +- 3 files changed, 109 insertions(+), 29 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index d67bf992eba7..918702ed41a6 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -49,6 +49,10 @@ if torch.cuda.is_available(): torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) +def torch_rms_norm(wav, db_level=-27.0): + r = 10 ** (db_level / 20) + a = torch.sqrt((wav.size(-1) * (r**2)) / torch.sum(wav**2)) + return wav * a def attach_dtype_counter(model): handles = [] @@ -920,23 +924,41 @@ def main(): profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate # add text tokens needed for profilling - if not model.cfg.get("agent_mask_include_transition_prefix", False): - delay_tokens = int(state.config.training_mode.streaming_speech_delay) - delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) - profile_tokens[:, -delay_tokens:] = turn_text[:, :delay_tokens] - turn_text = turn_text[:, delay_tokens:] - turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) - # ToDo: Check if it is really necessary (probably not) - if user_audio_channel_embedding is not None and delay_tokens > 0: - user_audio_channel_embedding = user_audio_channel_embedding.clone() - user_audio_channel_embedding[:, -delay_tokens:] = 0.0 - - state = model.streaming_prefill_profile( - state=state, - text_tokens=profile_tokens, - use_inference_mode=True, - user_audio_channel_embedding=user_audio_channel_embedding - ) + delay_tokens = int(state.config.training_mode.streaming_speech_delay) + delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) + + warmup_tokens = turn_text[:, :delay_tokens] + turn_text = turn_text[:, delay_tokens:] + turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) + + if user_audio_channel_embedding is not None and delay_tokens > 0: + warmup_user_audio = user_audio_channel_embedding[:, -delay_tokens:] + user_audio_channel_embedding = user_audio_channel_embedding[:, :-delay_tokens] + profile_tokens = profile_tokens[:, :-delay_tokens] + else: + warmup_user_audio = None + + if profile_tokens.size(1) > 0: + state = model.streaming_prefill_profile( + state=state, + text_tokens=profile_tokens, + use_inference_mode=True, + user_audio_channel_embedding=user_audio_channel_embedding + ) + + for i in range(delay_tokens): + user_step_emb = warmup_user_audio[:, i] if warmup_user_audio is not None else None + + state.finished.zero_() + + state, _, _ = model.streaming_step( + state=state, + text_tokens=warmup_tokens[:, i], + user_audio_channel_embedding=user_step_emb, + prefill_like_step=not bool(model.cfg.get("agent_mask_include_transition_prefix", False)), + prefill_like_is_last_step=(i == delay_tokens - 1), + use_inference_mode=True, + ) logging.info( f"[profile_multiturn] turn={t} prefilled {profile_T} steps " @@ -1061,6 +1083,13 @@ def main(): metric_audio_pred = resample(audio_f32, model.output_sample_rate, 16000) metric_audio_pred_lens = (audio_len / model.output_sample_rate * 16000).to(torch.long) + context_audio = resample(inputs["context_audio"].float(), model.sample_rate, 16000) + context_audio_lens = (inputs["context_audio_lengths"] / model.sample_rate * 16000).to(torch.long) + + # normalize volume + metric_audio_pred = torch_rms_norm(metric_audio_pred) + context_audio = torch_rms_norm(context_audio) + intelligibility.update( name="dataset", refs=inputs["raw_text"], @@ -1071,8 +1100,8 @@ def main(): secs_metric.update( name="dataset", - target_audio=resample(inputs["context_audio"].float(), model.sample_rate, 16000), - target_audio_lens=(inputs["context_audio_lengths"] / model.sample_rate * 16000).to(torch.long), + target_audio=context_audio, + target_audio_lens=context_audio_lens, pred_audio=metric_audio_pred, pred_audio_lens=metric_audio_pred_lens, ) diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 8b2b0b53f7c4..04c37c79033c 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -800,14 +800,7 @@ def streaming_prefill_profile( # Make the next normal streaming_step continue from silence, not AUDIO_BOS. state.all_predictions.append(sil_codes_unstacked) # keep silence so that in the target audio user will be silence - if self.cfg.get("use_user_speaking_end_token", False) and not self.cfg.get("agent_mask_include_transition_prefix", False) and self.phoneme_tokenizer is None: - state.last_audio_codes = torch.full( - (B, C * S), - self.audio_user_speaking_end_id, - dtype=torch.long, - device=device, - ) - elif self.cfg.get("use_user_speaking_token", False): + if self.cfg.get("use_user_speaking_token", False): state.last_audio_codes = torch.full( (B, C * S), self.audio_user_speaking_id, @@ -1595,7 +1588,10 @@ def streaming_step( state: StreamingState, text_tokens: Optional[torch.Tensor] = None, force_dropout_text: bool = False, + user_audio_channel_embedding: Optional[torch.Tensor] = None, + prefill_like_step: bool = False, use_inference_mode: bool = True, + prefill_like_is_last_step: bool = False, ) -> Tuple[StreamingState, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform one streaming inference step with batch support. @@ -1626,7 +1622,7 @@ def streaming_step( # Phase 1: Prepare input embedding and determine per-item phase masks next_input, needs_context, needs_phoneme, needs_audio = self._prepare_streaming_input( - state, text_tokens, force_dropout_text + state, text_tokens, force_dropout_text, user_audio_channel_embedding=user_audio_channel_embedding, ) # Phase 2: Transformer forward pass @@ -1643,6 +1639,51 @@ def streaming_step( state.past_key_values = transformer_out.past_key_values state.cache_seq_len += 1 + if prefill_like_step: + # Advance logical streams, but do not predict phoneme/audio logits. + state.context_position += needs_context.long() + state.text_tokens_seen += (~needs_context).long() + state.phoneme_steps += needs_phoneme.long() + state.audio_steps += needs_audio.long() + + C = self.num_audio_codebooks + S = self.frame_stacking_factor + B = state.config.batch_size + + sil = self.codec_sil_codes.to(device=device, dtype=torch.long) + sil = sil.view(1, C, 1).expand(B, C, S).contiguous() + + # Keep decoded profile/warmup region silent. + state.all_predictions.append(sil) + + # Only the final prefill-like step should expose USER_SPEAKING_END. + use_end_token = ( + prefill_like_is_last_step + and self.cfg.get("use_user_speaking_end_token", False) + and not self.cfg.get("agent_mask_include_transition_prefix", False) + and self.phoneme_tokenizer is None + ) + + if use_end_token: + state.last_audio_codes = torch.full( + (B, C * S), + self.audio_user_speaking_end_id, + dtype=torch.long, + device=device, + ) + elif self.cfg.get("use_user_speaking_token", False): + state.last_audio_codes = torch.full( + (B, C * S), + self.audio_user_speaking_id, + dtype=torch.long, + device=device, + ) + else: + state.last_audio_codes = sil.reshape(B, C * S) + + return state, None, None + + # Phase 3: Update counters and extract predictions audio_codes_next, pred_phoneme_tokens = self._process_predictions( state, needs_context, needs_phoneme, needs_audio @@ -1655,6 +1696,7 @@ def _prepare_streaming_input( state: StreamingState, text_tokens: Optional[torch.Tensor], force_dropout_text: bool, + user_audio_channel_embedding: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Build the input embedding for one streaming step. @@ -1793,6 +1835,15 @@ def _prepare_streaming_input( next_input = next_input + audio_emb + if user_audio_channel_embedding is not None: + if user_audio_channel_embedding.dim() == 2: + user_audio_channel_embedding = user_audio_channel_embedding.unsqueeze(1) + + next_input = next_input + user_audio_channel_embedding.to( + device=next_input.device, + dtype=next_input.dtype, + ) + # --- Handle CFG --- if state.config.use_cfg: next_input_unconditional_context = state.config.dummy_context_embedding_unconditional.expand( diff --git a/uv.lock b/uv.lock index a6be21cea190..31407662ce89 100644 --- a/uv.lock +++ b/uv.lock @@ -10168,4 +10168,4 @@ source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, -] +] \ No newline at end of file From 460f427cb8dc0cf4250a0116fc8d6b22eb0e865c Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 22 May 2026 11:48:53 -0700 Subject: [PATCH 054/109] Update inference script Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 59 ++++++++++++------- ...text_to_speech_dataset_lhotse_multiturn.py | 2 +- .../tts/models/easy_magpietts_inference.py | 3 +- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index 918702ed41a6..9d20948cde23 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -619,6 +619,8 @@ def main(): secs_metric = SECS("titanet_large").reset() eval_dataset = EvalJSONLDataset(args.datasets_json_path, num_turns=args.num_turns) + # debug + # eval_dataset.samples = eval_dataset.samples[:100] collate_fn = partial( collate_and_tokenize_custom, @@ -1118,45 +1120,61 @@ def main(): ext = ".wav" if inputs["profile_multiturn"]: - wav = audio_f32[i, : audio_len[i]].numpy() - out_path = os.path.join(args.out_dir, base_name) - sf.write(out_path, wav, samplerate=model.output_sample_rate) - logging.info(f"Full Audio Saved: {out_path}") - - full_wav = audio_f32[i].numpy() full_len = int(audio_len[i].item()) + full_wav_t = audio_f32[i, :full_len].detach().cpu().float() + + samples_per_prediction_frame = ( + model.codec_model_samples_per_frame + / (model.sample_rate / model.output_sample_rate) + ) + + # Build artifact-free aligned agent audio: + # start from true zeros and copy only generated turn regions. + aligned_agent = torch.zeros_like(full_wav_t) + print(profile_turn_frame_ranges) + for turn_id, start_frame, end_frame in profile_turn_frame_ranges: - samples_per_prediction_frame = ( - model.codec_model_samples_per_frame / (model.sample_rate / model.output_sample_rate) - ) rel_start_frame = start_frame - profile_decode_start_frame rel_end_frame = end_frame - profile_decode_start_frame start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) end_sample = int(round(rel_end_frame * samples_per_prediction_frame)) + start_sample = max(0, min(start_sample, full_len)) end_sample = max(start_sample, min(end_sample, full_len)) - print("Turn:", turn_id, "Start:", start_sample, "End:", end_sample, "Start S:", start_sample/model.output_sample_rate, "End S:", end_sample/model.output_sample_rate, ) - turn_wav = full_wav[start_sample:end_sample] + print( + "Turn:", turn_id, + "Start:", start_sample, + "End:", end_sample, + "Start S:", start_sample / model.output_sample_rate, + "End S:", end_sample / model.output_sample_rate, + ) + + # Copy only this turn into the aligned full output. + aligned_agent[start_sample:end_sample] = full_wav_t[start_sample:end_sample] + + # Save individual turn from the same aligned region. + turn_wav = aligned_agent[start_sample:end_sample].numpy() out_path = os.path.join(args.out_dir, f"{stem}_turn_{turn_id}{ext}") sf.write(out_path, turn_wav, samplerate=model.output_sample_rate) logging.info(f"Saved: {out_path}") + # Save full artifact-scrubbed agent audio. + wav = aligned_agent.numpy() + out_path = os.path.join(args.out_dir, base_name) + sf.write(out_path, wav, samplerate=model.output_sample_rate) + logging.info(f"Full aligned agent audio saved: {out_path}") + # --------------------------------------------------------- # Save aligned stereo conversation: # channel 0 = user conditioning audio - # channel 1 = generated agent audio + # channel 1 = generated agent audio, zeroed outside turns # --------------------------------------------------------- if "user_audio_turns" in inputs: user_segments = [] - samples_per_prediction_frame = ( - model.codec_model_samples_per_frame - / (model.sample_rate / model.output_sample_rate) - ) - first_user_len_in = int(inputs["user_audio_turns_lens"][0][i].item()) first_user_delay_out = int( round(first_user_len_in * model.output_sample_rate / model.sample_rate) @@ -1197,11 +1215,11 @@ def main(): e = s + wav_seg.numel() user_ch[s:e] += wav_seg - agent_pred = torch.from_numpy(wav).float() + # Agent channel keeps the same previous offset, but uses aligned_agent. agent_ch = torch.cat( [ - torch.zeros(first_user_delay_out, dtype=agent_pred.dtype), - agent_pred, + torch.zeros(first_user_delay_out, dtype=aligned_agent.dtype), + aligned_agent, ] ) @@ -1227,6 +1245,7 @@ def main(): ) logging.info(f"Aligned user/agent stereo audio saved: {aligned_path}") + else: wav = audio_f32[i, : audio_len[i]].numpy() out_path = os.path.join(args.out_dir, base_name) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 4251df0a3874..111176903367 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -478,7 +478,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) audio_array_16khz = cut.context_audio.resample(16_000).load_audio().squeeze(0) if self.volume_norm: audio_array_16khz = normalize_volume(audio_array_16khz) - + _available_context_duration = len(audio_array_16khz) / 16_000 _context_duration_to_slice = _sample_context_duration_with_available_limit(_available_context_duration) _num_samples_to_slice = int(_context_duration_to_slice * 16_000) diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 04c37c79033c..26a0792e6792 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -1836,8 +1836,7 @@ def _prepare_streaming_input( next_input = next_input + audio_emb if user_audio_channel_embedding is not None: - if user_audio_channel_embedding.dim() == 2: - user_audio_channel_embedding = user_audio_channel_embedding.unsqueeze(1) + user_audio_channel_embedding = user_audio_channel_embedding.unsqueeze(1) next_input = next_input + user_audio_channel_embedding.to( device=next_input.device, From 37e8b3bdd6301b8b9de65c9b9238ce93b3da7d76 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 25 May 2026 06:31:13 -0700 Subject: [PATCH 055/109] Fix phoneme loss Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts.py | 67 ++++++++++++++++--- 1 file changed, 58 insertions(+), 9 deletions(-) diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index ea79f65edc5e..45c609fb121f 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -214,21 +214,70 @@ def compute_loss( total_codebook_loss = total_codebook_loss / audio_codes.size(1) return total_codebook_loss, loss_mask - def compute_phoneme_loss(self, logits, phoneme_tokens, phoneme_tokens_lens): + def compute_phoneme_loss( + self, + logits, + phoneme_tokens, + phoneme_tokens_lens, + agent_mask_target=None, + ): + """ + logits: (B, T', phoneme_stacking_factor * phoneme_vocab_size) + phoneme_tokens: (B, S, T') + phoneme_tokens_lens: (B,) + agent_mask_target: optional (B, T') + """ loss_mask = get_mask_from_lengths(phoneme_tokens_lens) + loss_mask = loss_mask.unsqueeze(1).repeat(1, phoneme_tokens.size(1), 1) + + if agent_mask_target is not None: + target_T = phoneme_tokens.size(2) + + if agent_mask_target.size(1) < target_T: + pad = torch.zeros( + agent_mask_target.size(0), + target_T - agent_mask_target.size(1), + device=agent_mask_target.device, + dtype=agent_mask_target.dtype, + ) + agent_mask_target = torch.cat([agent_mask_target, pad], dim=1) + else: + agent_mask_target = agent_mask_target[:, :target_T] + + agent_mask_target = agent_mask_target.to( + device=phoneme_tokens.device, + dtype=loss_mask.dtype, + ) + total_phoneme_loss = None + for codebook in range(self.phoneme_stacking_factor): si = codebook * self.phoneme_vocab_size ei = si + self.phoneme_vocab_size + phoneme_logits = logits[:, :, si:ei] phoneme_targets = phoneme_tokens[:, codebook] - phoneme_loss = self.cross_entropy_loss(phoneme_logits.permute(0, 2, 1), phoneme_targets) - phoneme_loss = phoneme_loss * loss_mask - phoneme_loss = phoneme_loss.sum() / loss_mask.sum() - if total_phoneme_loss is None: - total_phoneme_loss = phoneme_loss - else: - total_phoneme_loss = total_phoneme_loss + phoneme_loss + + raw_loss = self.cross_entropy_loss( + phoneme_logits.permute(0, 2, 1), + phoneme_targets.long(), + ) # (B, T') + + effective_mask = loss_mask[:, codebook, :] + + if agent_mask_target is not None: + effective_mask = effective_mask * agent_mask_target + print("Hereee") + + phoneme_loss = raw_loss * effective_mask + phoneme_loss = phoneme_loss.sum() / effective_mask.sum().clamp_min(1.0) + + total_phoneme_loss = ( + phoneme_loss + if total_phoneme_loss is None + else total_phoneme_loss + phoneme_loss + ) + total_phoneme_loss = total_phoneme_loss / self.phoneme_stacking_factor return total_phoneme_loss, loss_mask @@ -1171,7 +1220,7 @@ def process_batch( dropout_complete_phoneme_channel or dropout_conditional_input or dropout_text_input ): phoneme_loss, _ = self.compute_phoneme_loss( - pb_phoneme_logits, pb_phoneme_tokens_target, pb_phoneme_tokens_lens_target + pb_phoneme_logits, pb_phoneme_tokens_target, pb_phoneme_tokens_lens_target, agent_mask_target=agent_mask if self.cfg.get("mask_user_on_loss", False) else None ) else: phoneme_loss = torch.tensor(0.0, device=logits.device) From 0f2975ccd23546ba96d8579cb903ea04d8453cab Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 25 May 2026 06:37:15 -0700 Subject: [PATCH 056/109] Remove debug print Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 45c609fb121f..8b9698625189 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -267,7 +267,6 @@ def compute_phoneme_loss( if agent_mask_target is not None: effective_mask = effective_mask * agent_mask_target - print("Hereee") phoneme_loss = raw_loss * effective_mask phoneme_loss = phoneme_loss.sum() / effective_mask.sum().clamp_min(1.0) From 0bd3ff7fe9def8a78aece32bd1046540927a4089 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 25 May 2026 12:27:35 -0700 Subject: [PATCH 057/109] Fix phoneme inference Signed-off-by: Edresson Casanova --- .../tts/models/easy_magpietts_inference.py | 69 ++++++++++++++++--- 1 file changed, 61 insertions(+), 8 deletions(-) diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 26a0792e6792..c1d8eb82a78b 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -1640,11 +1640,9 @@ def streaming_step( state.cache_seq_len += 1 if prefill_like_step: - # Advance logical streams, but do not predict phoneme/audio logits. + # Advance logical streams, keep audio silent, but predict phonemes if enabled. state.context_position += needs_context.long() state.text_tokens_seen += (~needs_context).long() - state.phoneme_steps += needs_phoneme.long() - state.audio_steps += needs_audio.long() C = self.num_audio_codebooks S = self.frame_stacking_factor @@ -1656,7 +1654,51 @@ def streaming_step( # Keep decoded profile/warmup region silent. state.all_predictions.append(sil) - # Only the final prefill-like step should expose USER_SPEAKING_END. + pred_phoneme_tokens = None + + if needs_phoneme.any() and self.phoneme_tokenizer is not None: + first_phoneme_step = needs_phoneme & (state.phoneme_prediction_start_idx == -1) + if first_phoneme_step.any(): + current_phoneme_step_idx = len(state.all_phoneme_predictions) + state.phoneme_prediction_start_idx = torch.where( + first_phoneme_step, + torch.full_like(state.phoneme_prediction_start_idx, current_phoneme_step_idx), + state.phoneme_prediction_start_idx, + ) + + pred_phoneme_tokens = self._predict_phoneme_tokens(state) + + if state.last_phoneme_tokens is None: + state.last_phoneme_tokens = pred_phoneme_tokens + else: + update_mask = needs_phoneme.view(B, 1).expand_as(pred_phoneme_tokens) + state.last_phoneme_tokens = torch.where( + update_mask, + pred_phoneme_tokens, + state.last_phoneme_tokens, + ) + + state.all_phoneme_predictions.append(pred_phoneme_tokens) + + phoneme_eos_detected = needs_phoneme & ( + pred_phoneme_tokens == self.phoneme_tokenizer.eos_token_id + ).any(dim=1) + + state.phoneme_eos_detected = state.phoneme_eos_detected | phoneme_eos_detected + + newly_ended_phoneme = phoneme_eos_detected & (state.phoneme_prediction_end_idx == -1) + if newly_ended_phoneme.any(): + current_phoneme_step_idx = len(state.all_phoneme_predictions) + state.phoneme_prediction_end_idx = torch.where( + newly_ended_phoneme, + torch.full_like(state.phoneme_prediction_end_idx, current_phoneme_step_idx), + state.phoneme_prediction_end_idx, + ) + + state.phoneme_steps += needs_phoneme.long() + state.audio_steps += needs_audio.long() + + # Same behavior as your previous code when phoneme is disabled. use_end_token = ( prefill_like_is_last_step and self.cfg.get("use_user_speaking_end_token", False) @@ -1680,15 +1722,13 @@ def streaming_step( ) else: state.last_audio_codes = sil.reshape(B, C * S) - - return state, None, None + return state, None, pred_phoneme_tokens # Phase 3: Update counters and extract predictions audio_codes_next, pred_phoneme_tokens = self._process_predictions( state, needs_context, needs_phoneme, needs_audio ) - return state, audio_codes_next, pred_phoneme_tokens def _prepare_streaming_input( @@ -1791,9 +1831,22 @@ def _prepare_streaming_input( phoneme_emb = phoneme_emb + phoneme_bos_emb * first_mask if has_last_phoneme.any() and state.last_phoneme_tokens is not None: + last_phoneme_tokens = state.last_phoneme_tokens # (B, S_ph) + last_phoneme_emb = self.embed_phoneme_tokens( - state.last_phoneme_tokens.unsqueeze(2) + last_phoneme_tokens.unsqueeze(2) ) # (B, 1, E) + + # Match training: PAD phoneme inputs contribute zero embedding. + if self.cfg.get("use_multiturn_dataset", False): + phoneme_pad_id = getattr(self.phoneme_tokenizer, "pad", None) + if phoneme_pad_id is not None: + # Same convention as training: check first stacked phoneme channel. + phoneme_is_pad = last_phoneme_tokens[:, 0] == phoneme_pad_id # (B,) + last_phoneme_emb = last_phoneme_emb * (~phoneme_is_pad).view(batch_size, 1, 1).float() + else: + raise ValueError("self.phoneme_tokenizer.pad is not defined, so it is not possible to zero-out the phoneme on the padding positon, please verify it!") + last_mask = has_last_phoneme.view(batch_size, 1, 1).float() phoneme_emb = phoneme_emb + last_phoneme_emb * last_mask From f76314ff8cf7cc42372606cc598e026e33f8bd59 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 26 May 2026 06:37:36 -0700 Subject: [PATCH 058/109] Add phoneme_loss_mask_padding Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts.py | 134 +++--------------- 1 file changed, 21 insertions(+), 113 deletions(-) diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 8b9698625189..890b1aefd597 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -219,32 +219,33 @@ def compute_phoneme_loss( logits, phoneme_tokens, phoneme_tokens_lens, - agent_mask_target=None, + custom_mask=None, ): """ logits: (B, T', phoneme_stacking_factor * phoneme_vocab_size) phoneme_tokens: (B, S, T') phoneme_tokens_lens: (B,) - agent_mask_target: optional (B, T') + custom_mask: optional (B, T') """ loss_mask = get_mask_from_lengths(phoneme_tokens_lens) loss_mask = loss_mask.unsqueeze(1).repeat(1, phoneme_tokens.size(1), 1) - if agent_mask_target is not None: + if custom_mask is not None: + custom_mask = custom_mask.bool() target_T = phoneme_tokens.size(2) - if agent_mask_target.size(1) < target_T: + if custom_mask.size(1) < target_T: pad = torch.zeros( - agent_mask_target.size(0), - target_T - agent_mask_target.size(1), - device=agent_mask_target.device, - dtype=agent_mask_target.dtype, + custom_mask.size(0), + target_T - custom_mask.size(1), + device=custom_mask.device, + dtype=custom_mask.dtype, ) - agent_mask_target = torch.cat([agent_mask_target, pad], dim=1) + custom_mask = torch.cat([custom_mask, pad], dim=1) else: - agent_mask_target = agent_mask_target[:, :target_T] + custom_mask = custom_mask[:, :target_T] - agent_mask_target = agent_mask_target.to( + custom_mask = custom_mask.to( device=phoneme_tokens.device, dtype=loss_mask.dtype, ) @@ -265,8 +266,8 @@ def compute_phoneme_loss( effective_mask = loss_mask[:, codebook, :] - if agent_mask_target is not None: - effective_mask = effective_mask * agent_mask_target + if custom_mask is not None: + effective_mask = effective_mask * custom_mask phoneme_loss = raw_loss * effective_mask phoneme_loss = phoneme_loss.sum() / effective_mask.sum().clamp_min(1.0) @@ -1218,8 +1219,14 @@ def process_batch( if (phoneme_corruption_mode != 'repeat_skip') and not ( dropout_complete_phoneme_channel or dropout_conditional_input or dropout_text_input ): + custom_mask = None + if self.cfg.get("phoneme_loss_mask_padding", False): + custom_mask = pb_phoneme_tokens_target[:, 0, :] != self.phoneme_tokenizer.pad # (B, T') + elif self.cfg.get("mask_user_on_loss", False): + custom_mask = agent_mask + phoneme_loss, _ = self.compute_phoneme_loss( - pb_phoneme_logits, pb_phoneme_tokens_target, pb_phoneme_tokens_lens_target, agent_mask_target=agent_mask if self.cfg.get("mask_user_on_loss", False) else None + pb_phoneme_logits, pb_phoneme_tokens_target, pb_phoneme_tokens_lens_target, custom_mask=custom_mask ) else: phoneme_loss = torch.tensor(0.0, device=logits.device) @@ -1262,105 +1269,6 @@ def training_step(self, batch, batch_idx): audio_lens = batch['audio_lens'] audio_codes, audio_codes_lens = self._codec_helper.audio_to_codes(audio, audio_lens) - # augment tts data to looks more like multiturn data by adding pad on the begining and emulating user speaking. - if self.cfg.get("use_multiturn_dataset", False) and "tts" in batch['task']: - prob = self.cfg.get("add_tts_sil_begining_prob", 0.0) - if prob > 0 and torch.rand(1).item() < prob: - audio_codes_lens_max = audio_codes_lens.max() - - # 1. Calculate the raw shift (with the -1 safety buffer) - raw_pad_lens = torch.clamp(audio_codes_lens_max - audio_codes_lens - 4, min=0) - - # 2. Round DOWN to the nearest multiple of the stacking factor - pad_lens = (raw_pad_lens // self.frame_stacking_factor) * self.frame_stacking_factor - - # 3. Calculate perfectly aligned text padding - text_pad_lens = pad_lens // self.frame_stacking_factor - - if pad_lens.max() > 0: - device = audio_codes.device - B, C, T_audio = audio_codes.shape - - # --- Vectorized Audio Shift --- - idx_a = torch.arange(T_audio, device=device).unsqueeze(0) - src_idx_a = idx_a - pad_lens.unsqueeze(1) - - valid_mask_a = (src_idx_a >= 0) & (src_idx_a < audio_codes_lens.unsqueeze(1)) - safe_src_idx_a = src_idx_a.clamp(min=0, max=T_audio - 1) - - safe_src_idx_a_exp = safe_src_idx_a.unsqueeze(1).expand(-1, C, -1) - valid_mask_a_exp = valid_mask_a.unsqueeze(1).expand(-1, C, -1) - - gathered_audio = torch.gather(audio_codes, 2, safe_src_idx_a_exp) - silence_pad = self.codec_sil_codes_unconverted.view(1, C, 1).expand(B, C, T_audio) - - audio_codes = torch.where(valid_mask_a_exp, gathered_audio, silence_pad) - audio_codes_lens = torch.clamp(audio_codes_lens + pad_lens, max=T_audio) - - # Vectorized Text Shift - old_text = batch['text'] - text_lens = batch['text_lens'] - - new_T_text = old_text.size(1) - new_text_lens = torch.clamp(text_lens + text_pad_lens, max=new_T_text) - idx_t = torch.arange(new_T_text, device=device).unsqueeze(0) - src_idx_t = idx_t - text_pad_lens.unsqueeze(1) - - valid_mask_t = (src_idx_t >= 0) & (src_idx_t < text_lens.unsqueeze(1)) - safe_src_idx_t = src_idx_t.clamp(min=0, max=old_text.size(1) - 1) - gathered_text = torch.gather(old_text, 1, safe_src_idx_t) - - batch['text'] = torch.where(valid_mask_t, gathered_text, self.pad_id) - batch['text_lens'] = new_text_lens - - # Vectorized Phoneme Shift - if ( - self.phoneme_tokenizer is not None - and batch.get("phoneme_tokens") is not None - and batch.get("phoneme_tokens_lens") is not None - ): - old_phonemes = batch["phoneme_tokens"] - phoneme_lens = batch["phoneme_tokens_lens"] - - new_T_phoneme = old_phonemes.size(1) - new_phoneme_lens = torch.clamp(phoneme_lens + text_pad_lens, max=new_T_phoneme) - - idx_p = torch.arange(new_T_phoneme, device=device).unsqueeze(0) - src_idx_p = idx_p - text_pad_lens.unsqueeze(1) - - valid_mask_p = (src_idx_p >= 0) & (src_idx_p < phoneme_lens.unsqueeze(1)) - safe_src_idx_p = src_idx_p.clamp(min=0, max=old_phonemes.size(1) - 1) - - gathered_phonemes = torch.gather(old_phonemes, 1, safe_src_idx_p) - - phoneme_pad_id = getattr(self.phoneme_tokenizer, "pad", -1) - batch["phoneme_tokens"] = torch.where( - valid_mask_p, - gathered_phonemes, - torch.full_like(gathered_phonemes, phoneme_pad_id), - ) - batch["phoneme_tokens_lens"] = new_phoneme_lens - - # change batch["agent_mask"] to consider this augmentation (in practice adding zeros/False where we are adding silence ) - if self.cfg.get("use_multiturn_dataset", False) and "agent_mask" in batch: - old_agent_mask = batch["agent_mask"].bool() - T_mask = old_agent_mask.size(1) - - idx_m = torch.arange(T_mask, device=device).unsqueeze(0) - src_idx_m = idx_m - text_pad_lens.unsqueeze(1) - - valid_mask_m = (src_idx_m >= 0) & (src_idx_m < old_agent_mask.size(1)) - safe_src_idx_m = src_idx_m.clamp(min=0, max=old_agent_mask.size(1) - 1) - - gathered_agent_mask = torch.gather(old_agent_mask, 1, safe_src_idx_m) - - # New prepended silence/user region should be non-agent. - batch["agent_mask"] = torch.where( - valid_mask_m, - gathered_agent_mask, - torch.zeros_like(gathered_agent_mask), - ) - if ( self.cfg.get("use_multiturn_dataset", False) and batch["user_audio_turn_splitted"] is not None From 0c254c7913f814a436898dd2cd8d0a7c2ba378ee Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 26 May 2026 06:39:03 -0700 Subject: [PATCH 059/109] Update inference Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 9 +++++++ .../tts/models/easy_magpietts_inference.py | 24 ++++++++++++++++--- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index 9d20948cde23..9ad3068efaf2 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -822,11 +822,20 @@ def main(): state.text_finished.zero_() state.audio_prediction_end_idx.fill_(-1) + if hasattr(state, "turn_text_tokens_seen"): + state.turn_text_tokens_seen.zero_() + + if hasattr(state, "phoneme_steps"): + state.phoneme_steps.zero_() + if hasattr(state, "phoneme_stream_ended"): state.phoneme_stream_ended.zero_() + if hasattr(state, "phoneme_eos_detected"): state.phoneme_eos_detected.zero_() + state.last_phoneme_tokens = None + if not model.cfg.get("condition_on_user_speech", False): # Prefill on the begining of each turn if "user_audio_turns" in inputs: diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index c1d8eb82a78b..b3db7d038876 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -138,6 +138,7 @@ class StreamingState: full_context_lens: torch.Tensor context_position: torch.Tensor text_tokens_seen: torch.Tensor + turn_text_tokens_seen: torch.Tensor phoneme_steps: torch.Tensor audio_steps: torch.Tensor phoneme_stream_ended: torch.Tensor @@ -151,6 +152,7 @@ class StreamingState: audio_prediction_end_idx: torch.Tensor phoneme_prediction_start_idx: torch.Tensor phoneme_prediction_end_idx: torch.Tensor + gt_phoneme_embeddings: Optional[torch.Tensor] = None # (B, T', E) pre-computed GT embeddings gt_phoneme_lens: Optional[torch.Tensor] = None # (B,) lengths after stacking gt_audio_embeddings: Optional[torch.Tensor] = None # (B, T', E) pre-computed GT audio embeddings @@ -1562,6 +1564,7 @@ def streaming_init( full_context_lens=full_context_lens, context_position=torch.full((batch_size,), min_context_len, dtype=torch.long, device=device), text_tokens_seen=torch.zeros(batch_size, dtype=torch.long, device=device), + turn_text_tokens_seen=torch.zeros(batch_size, dtype=torch.long, device=device), phoneme_steps=torch.zeros(batch_size, dtype=torch.long, device=device), audio_steps=torch.zeros(batch_size, dtype=torch.long, device=device), phoneme_stream_ended=torch.zeros(batch_size, dtype=torch.bool, device=device), @@ -1642,8 +1645,12 @@ def streaming_step( if prefill_like_step: # Advance logical streams, keep audio silent, but predict phonemes if enabled. state.context_position += needs_context.long() + advance_text = (~needs_context).long() state.text_tokens_seen += (~needs_context).long() + if hasattr(state, "turn_text_tokens_seen"): + state.turn_text_tokens_seen += (~needs_context).long() + C = self.num_audio_codebooks S = self.frame_stacking_factor B = state.config.batch_size @@ -1667,7 +1674,6 @@ def streaming_step( ) pred_phoneme_tokens = self._predict_phoneme_tokens(state) - if state.last_phoneme_tokens is None: state.last_phoneme_tokens = pred_phoneme_tokens else: @@ -1757,10 +1763,19 @@ def _prepare_streaming_input( # Determine phases per batch item needs_context = state.context_position < state.full_context_lens # (B,) bool needs_text = (~needs_context) & (~state.text_finished) + + turn_text_tokens_seen = getattr(state, "turn_text_tokens_seen", state.text_tokens_seen) needs_phoneme = ( - (~needs_context) & (state.text_tokens_seen >= streaming_phonemes_delay) & (~state.phoneme_stream_ended) + (~needs_context) + & (turn_text_tokens_seen >= streaming_phonemes_delay) + & (~state.phoneme_stream_ended) + ) + + needs_audio = ( + (~needs_context) + & (turn_text_tokens_seen >= streaming_speech_delay) + & (~state.finished) ) - needs_audio = (~needs_context) & (state.text_tokens_seen >= streaming_speech_delay) & (~state.finished) next_input = torch.zeros(batch_size, 1, self.cfg.embedding_dim, device=device) @@ -1985,6 +2000,9 @@ def _process_predictions( # Update counters state.context_position = state.context_position + needs_context.long() state.text_tokens_seen = state.text_tokens_seen + (~needs_context).long() + if hasattr(state, "turn_text_tokens_seen"): + state.turn_text_tokens_seen = state.turn_text_tokens_seen + (~needs_context).long() + state.phoneme_steps = state.phoneme_steps + needs_phoneme.long() state.audio_steps = state.audio_steps + needs_audio.long() From d97441537e88f12cf078c54a9f470c58be4eea51 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 26 May 2026 10:09:34 -0700 Subject: [PATCH 060/109] Add full user prefill support on nemotron_h class Signed-off-by: Edresson Casanova --- .../tts/models/easy_magpietts_inference.py | 51 ++-- .../tts/modules/nemotron_h_decoder.py | 253 +++++++++++++++--- 2 files changed, 224 insertions(+), 80 deletions(-) diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index b3db7d038876..5bedc185bec1 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -757,44 +757,23 @@ def streaming_prefill_profile( # ----------------------- # KV CACHE EXTENSION # ----------------------- - # ToDo: on VLLM we need to make nemotron_h class support prefill - if self.decoder_type == "nemotron_h" and state.past_key_values is not None and T > 1: - hidden_chunks = [] - for t in range(T): - cache_position_t = torch.tensor([state.cache_seq_len], device=device) - - out = self.forward( - inputs_embeds=inputs_embeds[:, t : t + 1, :], - attention_mask=None, - use_cache=True, - past_key_values=state.past_key_values, - cache_position=cache_position_t, - ) - - state.past_key_values = out.past_key_values - state.cache_seq_len += 1 - hidden_chunks.append(out.last_hidden_state) - - state.last_hidden = torch.cat(hidden_chunks, dim=1) - - else: - cache_position = torch.arange( - state.cache_seq_len, - state.cache_seq_len + T, - device=device, - ) + cache_position = torch.arange( + state.cache_seq_len, + state.cache_seq_len + T, + device=device, + ) - out = self.forward( - inputs_embeds=inputs_embeds, - attention_mask=None, - use_cache=True, - past_key_values=state.past_key_values, - cache_position=cache_position, - ) + out = self.forward( + inputs_embeds=inputs_embeds, + attention_mask=None, + use_cache=True, + past_key_values=state.past_key_values, + cache_position=cache_position, + ) - state.past_key_values = out.past_key_values - state.cache_seq_len += T - state.last_hidden = out.last_hidden_state + state.past_key_values = out.past_key_values + state.cache_seq_len += T + state.last_hidden = out.last_hidden_state # Advance logical streams consumed by this profile prefill. state.text_tokens_seen += T diff --git a/nemo/collections/tts/modules/nemotron_h_decoder.py b/nemo/collections/tts/modules/nemotron_h_decoder.py index 986359c0e2b3..f7e08746d753 100644 --- a/nemo/collections/tts/modules/nemotron_h_decoder.py +++ b/nemo/collections/tts/modules/nemotron_h_decoder.py @@ -91,6 +91,47 @@ ) +def make_mamba_conv_cache_from_sequence( + hidden_states_B_C: torch.Tensor, + conv_kernel_size: int, +) -> torch.Tensor: + """ + hidden_states_B_C: (B, T, conv_dim) + returns conv cache: (B, conv_dim, conv_kernel_size) + """ + x = hidden_states_B_C.transpose(1, 2) # (B, conv_dim, T) + + if x.size(-1) >= conv_kernel_size: + return x[:, :, -conv_kernel_size:].contiguous() + + return F.pad(x, (conv_kernel_size - x.size(-1), 0)).contiguous() + +def get_cached_mamba_ssm_state( + cache_params: "HybridMambaAttentionDynamicCache", + layer_idx: int, + batch_size: int, + num_heads: int, + head_dim: int, + ssm_state_size: int, + device: torch.device, +): + """ + Returns cached SSM state in the shape expected by full-sequence/chunk scan: + (B, num_heads, head_dim, ssm_state_size) + + The cache may contain either: + (B, num_heads * head_dim, ssm_state_size) + or: + (B, num_heads, head_dim, ssm_state_size) + depending on which path updated it last. + """ + state = cache_params.ssm_states[layer_idx].to(device=device) + + if state.dim() == 3: + state = state.view(batch_size, num_heads, head_dim, ssm_state_size) + + return state + def get_activation_fn(activation: str): """Get activation function by name.""" if activation == "silu" or activation == "swish": @@ -240,6 +281,7 @@ def __init__(self, config: NemotronHConfig, batch_size: int, dtype=torch.float16 intermediate_size = config.mamba_num_heads * config.mamba_head_dim ssm_state_size = config.ssm_state_size conv_kernel_size = config.conv_kernel + conv_dim = intermediate_size + 2 * config.n_groups * config.ssm_state_size self.conv_states = [] self.ssm_states = [] @@ -250,7 +292,7 @@ def __init__(self, config: NemotronHConfig, batch_size: int, dtype=torch.float16 for i in range(config.num_hidden_layers): if config.layers_block_type[i] == "mamba": self.conv_states.append( - torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) + torch.zeros(batch_size, conv_dim, conv_kernel_size, device=device, dtype=dtype) ) self.ssm_states.append( torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) @@ -495,10 +537,18 @@ def cuda_kernels_forward( - self.num_heads ) // 2 - if cache_params is not None and cache_position is not None and cache_position[0] > 0: - # Cached forward (single token) + has_cache_prefix = ( + cache_params is not None + and cache_position is not None + and cache_position[0] > 0 + ) + is_decode_step = has_cache_prefix and seq_len == 1 + + if is_decode_step: + # Cached single-token decode path. _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( - [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], + dim=-1, ) hidden_states_B_C = causal_conv1d_update( @@ -539,8 +589,9 @@ def cuda_kernels_forward( hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) hidden_states = self.norm(hidden_states, gate) out = self.out_proj(hidden_states)[:, None, ...] + else: - # Full sequence forward + # Full sequence or cached multi-token prefill path. A = -torch.exp(self.A_log.float()) dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} @@ -567,31 +618,59 @@ def cuda_kernels_forward( ) else: _, _, gate, hidden_states_B_C, dt = projected_states.split( - [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], + dim=-1, ) - if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) - conv_states = F.pad( - hidden_states_B_C_transposed, - (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), - ) - cache_params.update_conv_state( - layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True - ) + raw_hidden_states_B_C = hidden_states_B_C + + # For cached T > 1 prefill, convolution must see the previous conv cache. + if has_cache_prefix: + prev_conv = cache_params.conv_states[self.layer_idx].to( + device=raw_hidden_states_B_C.device, + dtype=raw_hidden_states_B_C.dtype, + ) # (B, conv_dim, K) + + conv_input = torch.cat( + [prev_conv, raw_hidden_states_B_C.transpose(1, 2)], + dim=-1, + ) # (B, conv_dim, K + T) + + raw_for_cache = torch.cat( + [prev_conv.transpose(1, 2), raw_hidden_states_B_C], + dim=1, + ) # (B, K + T, conv_dim) + else: + conv_input = raw_hidden_states_B_C.transpose(1, 2) + raw_for_cache = raw_hidden_states_B_C if self.activation not in ["silu", "swish"]: - hidden_states_B_C = self.act( - self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) + conv_out = self.act( + self.conv1d(conv_input)[..., : conv_input.size(-1)].transpose(1, 2) ) else: - hidden_states_B_C = causal_conv1d_fn( - x=hidden_states_B_C.transpose(1, 2), + conv_out = causal_conv1d_fn( + x=conv_input, weight=self.conv1d.weight.squeeze(1), bias=self.conv1d.bias, activation=self.activation, ).transpose(1, 2) + # Keep only outputs corresponding to the new chunk. + hidden_states_B_C = conv_out[:, -seq_len:, :].contiguous() + + # Update cache after using the previous cache. + if cache_params is not None: + conv_states = make_mamba_conv_cache_from_sequence( + raw_for_cache, + cache_params.conv_kernel_size, + ) + cache_params.update_conv_state( + layer_idx=self.layer_idx, + new_conv_state=conv_states, + cache_init=True, + ) + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( hidden_states_B_C, @@ -599,12 +678,7 @@ def cuda_kernels_forward( dim=-1, ) - scan_output, ssm_state = mamba_chunk_scan_combined( - hidden_states.view(batch_size, seq_len, -1, self.head_dim), - dt, - A, - B.view(batch_size, seq_len, self.n_groups, -1), - C.view(batch_size, seq_len, self.n_groups, -1), + scan_kwargs = dict( chunk_size=self.chunk_size, D=self.D, z=None, @@ -615,8 +689,32 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) + # For cached T > 1 prefill, SSM must start from previous cache state. + if has_cache_prefix: + scan_kwargs["initial_states"] = get_cached_mamba_ssm_state( + cache_params=cache_params, + layer_idx=self.layer_idx, + batch_size=batch_size, + num_heads=self.num_heads, + head_dim=self.head_dim, + ssm_state_size=self.ssm_state_size, + device=hidden_states.device, + ) + + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + dt, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + **scan_kwargs, + ) + if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + cache_params.update_ssm_state( + layer_idx=self.layer_idx, + new_ssm_state=ssm_state, + ) scan_output = scan_output.view(batch_size, seq_len, -1) scan_output = self.norm(scan_output, gate) @@ -645,28 +743,76 @@ def torch_forward( - self.num_heads ) // 2 _, _, gate, hidden_states_B_C, dt = projected_states.split( - [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], + dim=-1, + ) + + has_cache_prefix = ( + cache_params is not None + and cache_position is not None + and cache_position[0] > 0 ) + is_decode_step = has_cache_prefix and seq_len == 1 + # ----------------------- # Convolution - if cache_params is not None and cache_position is not None and cache_position[0] > 0: + # ----------------------- + if is_decode_step: cache_params.update_conv_state( - layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False + layer_idx=self.layer_idx, + new_conv_state=hidden_states_B_C, + cache_init=False, + ) + conv_states = cache_params.conv_states[self.layer_idx].to( + device=self.conv1d.weight.device, + dtype=hidden_states_B_C.dtype, ) - conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) hidden_states_B_C = torch.sum(conv_states * self.conv1d.weight.squeeze(1), dim=-1) if self.use_conv_bias: hidden_states_B_C = hidden_states_B_C + self.conv1d.bias hidden_states_B_C = self.act(hidden_states_B_C) + else: + raw_hidden_states_B_C = hidden_states_B_C + + # For cached T > 1 prefill, convolution must see previous conv cache. + if has_cache_prefix: + prev_conv = cache_params.conv_states[self.layer_idx].to( + device=raw_hidden_states_B_C.device, + dtype=raw_hidden_states_B_C.dtype, + ) # (B, conv_dim, K) + + conv_input = torch.cat( + [prev_conv, raw_hidden_states_B_C.transpose(1, 2)], + dim=-1, + ) # (B, conv_dim, K + T) + + raw_for_cache = torch.cat( + [prev_conv.transpose(1, 2), raw_hidden_states_B_C], + dim=1, + ) # (B, K + T, conv_dim) + else: + conv_input = raw_hidden_states_B_C.transpose(1, 2) + raw_for_cache = raw_hidden_states_B_C + + conv_out = self.act( + self.conv1d(conv_input)[..., : conv_input.size(-1)].transpose(1, 2) + ) + + # Keep only outputs corresponding to the new chunk. + hidden_states_B_C = conv_out[:, -seq_len:, :].contiguous() + + # Update cache after using the previous cache. if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) - conv_states = F.pad( - hidden_states_B_C_transposed, - (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), + conv_states = make_mamba_conv_cache_from_sequence( + raw_for_cache, + cache_params.conv_kernel_size, + ) + cache_params.update_conv_state( + layer_idx=self.layer_idx, + new_conv_state=conv_states, + cache_init=True, ) - cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True) - hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( @@ -675,11 +821,13 @@ def torch_forward( dim=-1, ) + # ----------------------- # SSM + # ----------------------- A = -torch.exp(self.A_log.float()) - if cache_params is not None and cache_position is not None and cache_position[0] > 0: - # Single step SSM update + if is_decode_step: + # Single-step SSM update. cache_device = cache_params.ssm_states[self.layer_idx].device dt = dt[:, 0, :][:, None, ...] dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) @@ -688,7 +836,9 @@ def torch_forward( dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) A_expanded = ( - A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + A[..., None, None] + .expand(self.num_heads, self.head_dim, self.ssm_state_size) + .to(dtype=torch.float32) ) dA = (torch.exp(dt[..., None] * A_expanded)).to(device=cache_device) @@ -701,7 +851,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) cache_params.update_ssm_state( - layer_idx=self.layer_idx, new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx + layer_idx=self.layer_idx, + new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx, ) C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] @@ -717,8 +868,9 @@ def torch_forward( D = self.D[..., None].expand(self.D.shape[0], self.head_dim) y = (y + hidden_states * D).to(y.dtype) y = y.reshape(batch_size, -1)[:, None, ...] + else: - # Full sequence SSM (chunked) + # Full-sequence or cached multi-token SSM. dt = F.softplus(dt + self.dt_bias) dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() @@ -751,8 +903,18 @@ def torch_forward( B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) - if cache_params is not None and cache_position is not None and cache_position[0] > 0: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + # This is the critical fix: + # cached T > 1 prefill must start from the previous SSM state. + if has_cache_prefix: + previous_states = get_cached_mamba_ssm_state( + cache_params=cache_params, + layer_idx=self.layer_idx, + batch_size=batch_size, + num_heads=self.num_heads, + head_dim=self.head_dim, + ssm_state_size=self.ssm_state_size, + device=states.device, + )[:, None, ...] else: previous_states = torch.zeros_like(states[:, :1]) @@ -776,7 +938,10 @@ def torch_forward( y = y.reshape(batch_size, seq_len, -1) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + cache_params.update_ssm_state( + layer_idx=self.layer_idx, + new_ssm_state=ssm_state, + ) scan_output = self.norm(y, gate) contextualized_states = self.out_proj(scan_output.to(dtype)) From bb7022c1dee67ab5de42e944f5dbf0a11a523232 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 27 May 2026 06:09:09 -0700 Subject: [PATCH 061/109] Add phoneme_loss_mask_agent_expanded Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 4 +--- nemo/collections/tts/models/easy_magpietts.py | 21 +++++++++++++++++++ .../tts/models/easy_magpietts_inference.py | 5 +++-- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index 9ad3068efaf2..c30a6df138ea 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -571,10 +571,8 @@ def main(): model_cfg.run_val_inference = False model_cfg.use_utmos = False model_cfg.use_meta_init_for_decoder = True - # Guarantees silence for pad tokens - model_cfg.use_multiturn_dataset = True - + # model_cfg.use_multiturn_dataset = True if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 890b1aefd597..b11a07ee1f93 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -1222,6 +1222,27 @@ def process_batch( custom_mask = None if self.cfg.get("phoneme_loss_mask_padding", False): custom_mask = pb_phoneme_tokens_target[:, 0, :] != self.phoneme_tokenizer.pad # (B, T') + elif self.cfg.get("phoneme_loss_mask_agent_expanded", False): + # +1 keeps one supervised PAD step before BOS, which teaches the PAD -> BOS transition + # and makes the mask robust to small frame-shift / target-shift mismatches. + transition_prefix = int(current_streaming_phonemes_delay) + 1 # +1 to learn PAD -> BOS transition and also avoid frame-shift issues + + agent_i = agent_mask.float().unsqueeze(1) # (B, 1, T_agent) + + # Expand supervision to the left of the agent span by transition_prefix steps. + # Padding on the right + max_pool1d makes earlier positions become active. + agent_i = torch.nn.functional.pad(agent_i, (0, transition_prefix)) + + custom_mask = ( + torch.nn.functional.max_pool1d( + agent_i, + kernel_size=transition_prefix + 1, + stride=1, + ) + .squeeze(1) + .bool() + ) + elif self.cfg.get("mask_user_on_loss", False): custom_mask = agent_mask diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 5bedc185bec1..90bd2123d23d 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -1707,6 +1707,7 @@ def streaming_step( ) else: state.last_audio_codes = sil.reshape(B, C * S) + # print("prefill", state.all_phoneme_predictions[-1] if state.all_phoneme_predictions else None, self.phoneme_tokenizer.bos_token_id, self.phoneme_tokenizer.eos_token_id, getattr(self.phoneme_tokenizer, "pad", None), ) return state, None, pred_phoneme_tokens @@ -1714,6 +1715,7 @@ def streaming_step( audio_codes_next, pred_phoneme_tokens = self._process_predictions( state, needs_context, needs_phoneme, needs_audio ) + # print("step", state.all_phoneme_predictions[-1] if state.all_phoneme_predictions else None, self.phoneme_tokenizer.bos_token_id, self.phoneme_tokenizer.eos_token_id, getattr(self.phoneme_tokenizer, "pad", None), ) return state, audio_codes_next, pred_phoneme_tokens def _prepare_streaming_input( @@ -1813,7 +1815,7 @@ def _prepare_streaming_input( else: first_phoneme_step = needs_phoneme & (state.phoneme_steps == 0) has_last_phoneme = needs_phoneme & (~first_phoneme_step) & (state.last_phoneme_tokens is not None) - + if first_phoneme_step.any(): phoneme_bos = torch.full( (batch_size, self.phoneme_stacking_factor, 1), @@ -1840,7 +1842,6 @@ def _prepare_streaming_input( last_phoneme_emb = last_phoneme_emb * (~phoneme_is_pad).view(batch_size, 1, 1).float() else: raise ValueError("self.phoneme_tokenizer.pad is not defined, so it is not possible to zero-out the phoneme on the padding positon, please verify it!") - last_mask = has_last_phoneme.view(batch_size, 1, 1).float() phoneme_emb = phoneme_emb + last_phoneme_emb * last_mask From aea4c3c06f5096ad2af4908fb40044a84c010b29 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 27 May 2026 10:00:27 -0700 Subject: [PATCH 062/109] Rename phoneme_loss_mask_include_transition Signed-off-by: Edresson Casanova --- ...text_to_speech_dataset_lhotse_multiturn.py | 2 +- nemo/collections/tts/models/easy_magpietts.py | 43 +++++++++++-------- .../tts/models/easy_magpietts_inference.py | 2 - 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 111176903367..4b103694966d 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -900,6 +900,7 @@ def build_phoneme_channel( text_len = len(phoneme_ids) if text_len > turn_frames: + phoneme_ids = phoneme_ids[:turn_frames] text_len = turn_frames @@ -925,7 +926,6 @@ def build_phoneme_channel( phoneme_ids = phoneme_tokenizer.encode(ipa_text) phoneme_ids = [bos_id] + phoneme_ids + [eos_id] - phoneme_ids = torch.as_tensor(phoneme_ids, dtype=torch.long) pos = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) if pos >= len(tokens): diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index b11a07ee1f93..093233d45c7c 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -1222,26 +1222,31 @@ def process_batch( custom_mask = None if self.cfg.get("phoneme_loss_mask_padding", False): custom_mask = pb_phoneme_tokens_target[:, 0, :] != self.phoneme_tokenizer.pad # (B, T') - elif self.cfg.get("phoneme_loss_mask_agent_expanded", False): - # +1 keeps one supervised PAD step before BOS, which teaches the PAD -> BOS transition - # and makes the mask robust to small frame-shift / target-shift mismatches. - transition_prefix = int(current_streaming_phonemes_delay) + 1 # +1 to learn PAD -> BOS transition and also avoid frame-shift issues - - agent_i = agent_mask.float().unsqueeze(1) # (B, 1, T_agent) - - # Expand supervision to the left of the agent span by transition_prefix steps. - # Padding on the right + max_pool1d makes earlier positions become active. - agent_i = torch.nn.functional.pad(agent_i, (0, transition_prefix)) - - custom_mask = ( - torch.nn.functional.max_pool1d( - agent_i, - kernel_size=transition_prefix + 1, - stride=1, - ) - .squeeze(1) - .bool() + elif self.cfg.get("phoneme_loss_mask_include_transition", False): + # agent_mask is aligned to the speech/audio supervision region. + # Expand it left so phoneme loss also covers the phoneme->speech transition. + # The optional +1 gives one extra supervised boundary step, useful for PAD -> BOS / target-shift robustness. + transition_prefix = max( + 0, + int(current_streaming_speech_delay - current_streaming_phonemes_delay) + 1, ) + agent_i = agent_mask.float().unsqueeze(1) # (B, 1, T) + + if transition_prefix > 0: + # Right padding + max_pool expands the active region to the left. + agent_i = torch.nn.functional.pad(agent_i, (0, transition_prefix)) + + custom_mask = ( + torch.nn.functional.max_pool1d( + agent_i, + kernel_size=transition_prefix + 1, + stride=1, + ) + .squeeze(1) + .bool() + ) + else: + custom_mask = agent_mask elif self.cfg.get("mask_user_on_loss", False): custom_mask = agent_mask diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 90bd2123d23d..03c4828da089 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -1707,7 +1707,6 @@ def streaming_step( ) else: state.last_audio_codes = sil.reshape(B, C * S) - # print("prefill", state.all_phoneme_predictions[-1] if state.all_phoneme_predictions else None, self.phoneme_tokenizer.bos_token_id, self.phoneme_tokenizer.eos_token_id, getattr(self.phoneme_tokenizer, "pad", None), ) return state, None, pred_phoneme_tokens @@ -1715,7 +1714,6 @@ def streaming_step( audio_codes_next, pred_phoneme_tokens = self._process_predictions( state, needs_context, needs_phoneme, needs_audio ) - # print("step", state.all_phoneme_predictions[-1] if state.all_phoneme_predictions else None, self.phoneme_tokenizer.bos_token_id, self.phoneme_tokenizer.eos_token_id, getattr(self.phoneme_tokenizer, "pad", None), ) return state, audio_codes_next, pred_phoneme_tokens def _prepare_streaming_input( From 10d8904956c6172b980e726033f02e6eba6ab80b Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 28 May 2026 13:02:13 -0700 Subject: [PATCH 063/109] Remove unused code Signed-off-by: Edresson Casanova --- ...text_to_speech_dataset_lhotse_multiturn.py | 134 +----------------- nemo/collections/tts/models/easy_magpietts.py | 119 ---------------- .../tts/models/easy_magpietts_inference.py | 13 +- .../tts/modules/nemotron_h_decoder.py | 15 +- 4 files changed, 31 insertions(+), 250 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 4b103694966d..099c04ade52b 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -172,7 +172,6 @@ def __init__( input_roles: List[str] = ["user", "User"], output_roles: List[str] = ["assistant", "Assistant", "agent", "Agent"], add_text_bos: bool = False, - remove_user_turns_prob: float = None, ): super().__init__() self.sample_rate = sample_rate @@ -180,7 +179,6 @@ def __init__( self.codec_model_samples_per_frame = codec_model_samples_per_frame self.num_audio_codebooks = num_audio_codebooks - self.remove_user_turns_prob = remove_user_turns_prob self.include_align_prior = prior_scaling_factor is not None self.prior_scaling_factor = prior_scaling_factor @@ -239,7 +237,6 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: supervision.speaker = "agent" batch_tokenizer_names = [] - remove_user_turn_flags = [] for cut in cuts: if cut.has_custom("tokenizer_names"): batch_tokenizer_names.append(random.choice(cut.tokenizer_names)) @@ -252,12 +249,6 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: # It is a multiturn if there's more than 1 agent turn is_multiturn = not (len(agent_sups) <= 1) - # Apply augmentation only if it's multiturn AND passes the probability check - if is_multiturn and self.remove_user_turns_prob and random.random() < self.remove_user_turns_prob: - remove_user_turn_flags.append(True) - else: - remove_user_turn_flags.append(False) - def _align_codebooks(t): C = t.shape[1] if C < self.num_audio_codebooks: @@ -275,7 +266,6 @@ def _align_codebooks(t): # normalize volume and apply audio the removal of user turn if needed for i, cut in enumerate(cuts): - remove_user_turn_this_cut = remove_user_turn_flags[i] # Extract the raw, unpadded 1D numpy array for this specific cut t_audio = target_audio[i, :target_audio_lens[i]].numpy() @@ -287,25 +277,6 @@ def _align_codebooks(t): else: s_audio = cut.resample(self.source_sample_rate).load_audio().squeeze(0) - if remove_user_turn_this_cut: - collapsed_t, collapsed_s = [], [] - for sup in cut.supervisions: - if sup.speaker in self.output_roles: - start_t = int(round(max(0, sup.start) * self.sample_rate)) - end_t = int(round(sup.end * self.sample_rate)) - start_s = int(round(max(0, sup.start) * self.source_sample_rate)) - end_s = int(round(sup.end * self.source_sample_rate)) - - # Clamp safely inside the array - start_t, end_t = min(start_t, len(t_audio)), min(end_t, len(t_audio)) - start_s, end_s = min(start_s, len(s_audio)), min(end_s, len(s_audio)) - - if end_t > start_t: collapsed_t.append(t_audio[start_t:end_t]) - if end_s > start_s: collapsed_s.append(s_audio[start_s:end_s]) - - t_audio = np.concatenate(collapsed_t) if collapsed_t else np.zeros(1, dtype=np.float32) - s_audio = np.concatenate(collapsed_s) if collapsed_s else np.zeros(1, dtype=np.float32) - # Apply volume norm locally (so we only normalize the stitched audio, saving math ops) if self.volume_norm: t_audio = normalize_volume(t_audio) @@ -331,19 +302,18 @@ def _align_codebooks(t): target_text_tokens, target_token_lens = collate_token_channel( cuts, self.text_tokenizer, self.frame_length, roles=self.output_roles, add_text_bos=self.add_text_bos, tokenizer_names=batch_tokenizer_names, - pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, interruption_token_id=self.interruption_token_id, remove_user_turn_flags=remove_user_turn_flags + pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, interruption_token_id=self.interruption_token_id ) source_tokens, source_token_lens = collate_token_channel( cuts, self.text_tokenizer, self.frame_length, roles=self.input_roles, add_text_bos=self.add_text_bos, tokenizer_names=batch_tokenizer_names, - pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, interruption_token_id=self.interruption_token_id, remove_user_turn_flags=remove_user_turn_flags + pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, interruption_token_id=self.interruption_token_id ) if self.phoneme_tokenizer is not None: target_phoneme_tokens, target_phoneme_lens = collate_phoneme_channel( cuts, self.phoneme_tokenizer, self.frame_length, roles=self.output_roles, ignore_phoneme_languages=self.ignore_phoneme_languages, pad_id=self.phoneme_tokenizer.pad, eos_id=self.phoneme_tokenizer.eos_token_id, bos_id=self.phoneme_tokenizer.bos_token_id, - remove_user_turn_flags=remove_user_turn_flags, ) else: target_phoneme_tokens, target_phoneme_lens = None, None @@ -372,7 +342,6 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) return random.uniform(self.context_duration_min, effective_duration_max) for i, cut in enumerate(cuts): - remove_user_turn_this_cut = remove_user_turn_flags[i] speaker_found = False for sup in reversed(cut.supervisions): if check_speaker_format(sup.speaker): @@ -383,7 +352,6 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) if not speaker_found: dataset_name = "unknown" dataset_name_list.append(dataset_name) - # print("Language is available?", cut.has_custom("lang"), " Has codes?", cut.has_custom("target_codes"), "Has context audio?", cut.has_custom("context_audio"), "Has context codes?", cut.has_custom("context_codes")) language = cut.lang if cut.has_custom("lang") else next((sup.language for sup in reversed(cut.supervisions) if sup.has_custom("language")), "en") language_list.append(language) @@ -392,14 +360,10 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) if self.load_cached_codes_if_available: if cut.has_custom("target_codes"): codes_array = cut.target_codes.load().astype(np.int32) - if remove_user_turn_this_cut: - raise RuntimeError("Remove user turn augmentation is not implemented for cached codes!") target_codes_list.append(torch.from_numpy(codes_array).T) if cut.has_custom("source_codes"): source_codes_list.append(torch.from_numpy(cut.source_codes.load().astype(np.int32)).T) - if remove_user_turn_this_cut: - raise RuntimeError("Remove user turn augmentation is not implemented for cached codes!") # Context Audio or Context Codes if self.load_cached_codes_if_available and cut.has_custom("context_codes"): @@ -443,7 +407,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) else: matching_supervisions = [s for s in cut.supervisions if s.speaker in self.output_roles] - + if self.load_cached_codes_if_available: if len(matching_supervisions) > 0 and cut.has_custom("target_codes"): sup = random.choice(matching_supervisions) @@ -642,18 +606,16 @@ def collate_token_channel( eos_id: int = None, bos_id: int = None, interruption_token_id: int = None, - remove_user_turn_flags: list[bool] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Build and collate token channels aligned to the audio frame grid.""" tokens = [] for i, c in enumerate(cuts): tok_name = tokenizer_names[i] if tokenizer_names else "english_phoneme" - flag = remove_user_turn_flags[i] if remove_user_turn_flags else False tokens.append( build_token_channel( c, tokenizer, frame_length, roles, pad_id, eos_id, bos_id, interruption_token_id, - add_text_bos, tok_name, remove_user_turns=flag, + add_text_bos, tok_name, ) ) token_lens = torch.tensor([len(tt) for tt in tokens]) @@ -701,53 +663,7 @@ def build_token_channel( interruption_token_id: int = -4, add_text_bos: bool = True, tokenizer_name: str = "english_phoneme", - remove_user_turns: bool = False, ) -> torch.Tensor: - if remove_user_turns: - turn_chunks = [] - for supervision in cut.supervisions: - if supervision.speaker in roles: - # 1. Get exact frame length of THIS turn - start_f = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) - end_f = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) - turn_frames = max(0, end_f - start_f) - - turn_tokens = torch.ones(turn_frames, dtype=torch.long) * pad_id - - if turn_frames == 0: - continue - - # 2. Encode text - text = supervision.text - if hasattr(tokenizer, "encode"): - try: - raw_ids = tokenizer.encode(text=text, tokenizer_name=tokenizer_name) - except TypeError: - raw_ids = tokenizer.encode(text) - else: - raw_ids = tokenizer.text_to_ids(text) - - if add_text_bos: - text_ids = [bos_id] + raw_ids + [eos_id] - else: - text_ids = raw_ids + [eos_id] - - # 3. Place text at the start, keeping the rest as pad_id - text_len = len(text_ids) - if text_len > turn_frames: - text_ids = text_ids[:turn_frames] - text_len = turn_frames - - turn_tokens[0:text_len] = torch.as_tensor(text_ids, dtype=torch.long) - - # 4. Place interruption token at the exact end of the turn - turn_tokens[-1] = interruption_token_id - turn_chunks.append(turn_tokens) - - if turn_chunks: - return torch.cat(turn_chunks, dim=0) - else: - return torch.tensor([pad_id], dtype=torch.long) total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) tokens = torch.ones(total, dtype=torch.long) * pad_id @@ -847,15 +763,13 @@ def collate_phoneme_channel( pad_id: int = -1, eos_id: int = -2, bos_id: int = -3, - remove_user_turn_flags: list[bool] = None, ) -> tuple[torch.Tensor, torch.Tensor]: tokens = [] for i, c in enumerate(cuts): - flag = remove_user_turn_flags[i] if remove_user_turn_flags else False tokens.append( build_phoneme_channel( c, phoneme_tokenizer, frame_length, roles, - ignore_phoneme_languages, pad_id, eos_id, bos_id, remove_user_turns=flag + ignore_phoneme_languages, pad_id, eos_id, bos_id ) ) token_lens = torch.tensor([len(tt) for tt in tokens]) @@ -871,47 +785,9 @@ def build_phoneme_channel( pad_id: int = -1, eos_id: int = -2, bos_id: int = -3, - remove_user_turns: bool = False, ) -> torch.Tensor: language = cut.lang if cut.has_custom("lang") else next((sup.language for sup in cut.supervisions if sup.has_custom("language")), "en") - if remove_user_turns: - turn_chunks = [] - for supervision in cut.supervisions: - if supervision.speaker in roles: - start_f = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) - end_f = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) - turn_frames = max(0, end_f - start_f) - - turn_tokens = torch.ones(turn_frames, dtype=torch.long) * pad_id - - if turn_frames == 0: - continue - - if isinstance(phoneme_tokenizer, IPABPETokenizer): - ipa_text = _get_supervision_ipa_text(supervision) - if language in ignore_phoneme_languages: - ipa_text = "" - else: - ipa_text = supervision.text - - phoneme_ids = phoneme_tokenizer.encode(ipa_text) - phoneme_ids = [bos_id] + phoneme_ids + [eos_id] - - text_len = len(phoneme_ids) - if text_len > turn_frames: - - phoneme_ids = phoneme_ids[:turn_frames] - text_len = turn_frames - - turn_tokens[0:text_len] = torch.as_tensor(phoneme_ids, dtype=torch.long) - turn_chunks.append(turn_tokens) - - if turn_chunks: - return torch.cat(turn_chunks, dim=0) - else: - return torch.tensor([pad_id], dtype=torch.long) - total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) tokens = torch.ones(total, dtype=torch.long) * pad_id diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 093233d45c7c..5ca06f8dca35 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -736,15 +736,6 @@ def prepare_audio_channel_embeddings( target_agent_mask = agent_mask & valid loss_agent_mask = target_agent_mask - if self.cfg.get("debug_decode_agent_mask", False) and self.training and self.global_step < 5: - self.debug_decode_mask_regions( - audio_codes_target=audio_codes_target, - audio_codes_lens_target=audio_codes_lens_target, - agent_mask=agent_mask, - out_dir=os.path.join(self.trainer.log_dir, "mask_debug", f"step_{self.global_step}"), - prefix=f"batch_{self.global_rank}_{self.global_step}", - ) - # Replace user/non-agent regions with a learned token. # Important: audio_codes_input predicts audio_codes_target, so input mask must be shifted. if self.cfg.get("use_user_speaking_token", False): @@ -2023,7 +2014,6 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D input_roles=["user", "User"], output_roles=["assistant", "Assistant", "agent", "Agent"], add_text_bos=self.cfg.get("add_text_bos", False), - remove_user_turns_prob=self.cfg.get("remove_user_turns_prob", None), ) dataset = FallbackDataset(dataset) else: @@ -2177,112 +2167,3 @@ def val_dataloader(self): self._val_dl_wrapped_with_dist_sampler = True return self._validation_dl - - def debug_decode_mask_regions( - self, - audio_codes_target, - audio_codes_lens_target, - agent_mask, - out_dir, - prefix="debug_mask", - ): - os.makedirs(out_dir, exist_ok=True) - - device = audio_codes_target.device - B, C, T = audio_codes_target.shape - - agent_mask = agent_mask.to(device).bool() - - if agent_mask.size(1) < T: - pad = torch.zeros(B, T - agent_mask.size(1), device=device, dtype=torch.bool) - agent_mask = torch.cat([agent_mask, pad], dim=1) - else: - agent_mask = agent_mask[:, :T] - - valid = get_mask_from_lengths(audio_codes_lens_target).bool().to(device) - agent_mask = agent_mask & valid - - C_base = self.num_audio_codebooks - S = self.frame_stacking_factor - C_target = audio_codes_target.size(1) - - sil = self.codec_sil_codes.to(device=device, dtype=audio_codes_target.dtype) - - if C_target == C_base: - sil = sil.view(1, C_base, 1).expand(B, C_base, T) - - elif C_target == C_base * S: - sil_unstacked = sil.view(1, C_base, 1).expand(B, C_base, T * S).contiguous() - sil_stacked, _ = self.stack_codes( - sil_unstacked, - torch.full((B,), T * S, dtype=torch.long, device=device), - bos_id=self.audio_bos_id, - eos_id=self.audio_eos_id, - stacking_factor=S, - num_codebooks=C_base, - ) - sil = sil_stacked[:, :, :T] - else: - raise RuntimeError( - f"Unexpected codebook dim: target C={C_target}, " - f"base C={C_base}, stacking_factor={S}" - ) - - def decode_and_save(codes, lens, name): - codes = codes.clone() - codes, lens = self._prepare_codes_for_decode(codes, lens) - audio, audio_len, _ = self._codec_helper.codes_to_audio(codes, lens) - - for b in range(B): - wav = audio[b, : audio_len[b]].float().detach().cpu().numpy() - sf.write( - os.path.join(out_dir, f"{prefix}_b{b}_{name}.wav"), - wav, - self.output_sample_rate, - ) - - # 1. full target - decode_and_save(audio_codes_target, audio_codes_lens_target, "full_target") - - # 2. only agent region, silence elsewhere - agent_codes = torch.where(agent_mask[:, None, :], audio_codes_target, sil) - decode_and_save(agent_codes, audio_codes_lens_target, "agent_only_sil_elsewhere") - - # 3. only masked-out region, silence elsewhere - non_agent_codes = torch.where((~agent_mask & valid)[:, None, :], audio_codes_target, sil) - decode_and_save(non_agent_codes, audio_codes_lens_target, "non_agent_only_sil_elsewhere") - - # 4. each contiguous agent segment independently - for b in range(B): - mask_b = agent_mask[b] - idx = mask_b.nonzero(as_tuple=False).flatten() - - if idx.numel() == 0: - continue - - # contiguous runs - breaks = torch.where(idx[1:] != idx[:-1] + 1)[0] + 1 - chunks = torch.tensor_split(idx, breaks.cpu().tolist()) - - for seg_i, seg_idx in enumerate(chunks): - start = int(seg_idx[0]) - end = int(seg_idx[-1]) + 1 - - seg_codes = audio_codes_target[b : b + 1, :, start:end].clone() - seg_lens = torch.tensor([end - start], device=device, dtype=torch.long) - - seg_codes, seg_lens = self._prepare_codes_for_decode(seg_codes, seg_lens) - audio, audio_len, _ = self._codec_helper.codes_to_audio(seg_codes, seg_lens) - - wav = audio[0, : audio_len[0]].float().detach().cpu().numpy() - sf.write( - os.path.join(out_dir, f"{prefix}_b{b}_agent_segment{seg_i}_frames{start}-{end}.wav"), - wav, - self.output_sample_rate, - ) - - logging.info( - f"[mask_debug] saved mask decode files to {out_dir}; " - f"agent coverage frames={agent_mask.sum(dim=1).detach().cpu().tolist()} / " - f"{audio_codes_lens_target.detach().cpu().tolist()}" - ) \ No newline at end of file diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 03c4828da089..26d6bf3ae0ef 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -507,6 +507,17 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): else: raise ValueError(f"Unknown decoder_type: {self.decoder_type}. Supported: 'huggingface', 'nemotron_h'") + self.activation_checkpointing = cfg.get("activation_checkpointing", False) + if self.activation_checkpointing: + logging.info("Enabling activation checkpointing for decoder") + + if self.decoder_type == "nemotron_h": + self.decoder.gradient_checkpointing = True + elif hasattr(self.decoder, "gradient_checkpointing_enable"): + self.decoder.gradient_checkpointing_enable() + elif hasattr(self.decoder, "gradient_checkpointing"): + self.decoder.gradient_checkpointing = True + if self.disable_lm_text_head and hasattr(self.decoder, 'lm_head'): self.decoder.lm_head = None @@ -1888,7 +1899,7 @@ def _prepare_streaming_input( device=next_input.device, dtype=next_input.dtype, ) - + # --- Handle CFG --- if state.config.use_cfg: next_input_unconditional_context = state.config.dummy_context_embedding_unconditional.expand( diff --git a/nemo/collections/tts/modules/nemotron_h_decoder.py b/nemo/collections/tts/modules/nemotron_h_decoder.py index f7e08746d753..75a47aeb9d5b 100644 --- a/nemo/collections/tts/modules/nemotron_h_decoder.py +++ b/nemo/collections/tts/modules/nemotron_h_decoder.py @@ -1536,9 +1536,22 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + def create_custom_forward(layer, layer_mask): + def custom_forward(hidden_states): + return layer( + hidden_states, + cache_params=None, + cache_position=None, + attention_mask=layer_mask, + ) + return custom_forward + + if self.gradient_checkpointing and self.training: hidden_states = torch.utils.checkpoint.checkpoint( - layer.__call__, hidden_states, cache_params, cache_position, layer_mask + create_custom_forward(layer, layer_mask), + hidden_states, + use_reentrant=False, ) else: hidden_states = layer( From c01730fc2d35f4caadd96c9198203af6e94bd42e Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 28 May 2026 13:41:10 -0700 Subject: [PATCH 064/109] Add partial copy support Signed-off-by: Edresson Casanova --- .../collections/speechlm2/parts/pretrained.py | 114 ++++++++++++++---- .../tts/models/easy_magpietts_inference.py | 2 +- 2 files changed, 90 insertions(+), 26 deletions(-) diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index e0d485bc0529..301f31c2179e 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -302,46 +302,110 @@ def setup_parallel_expert_encoder(model: torch.nn.Module): def set_model_dict_for_partial_init( - pretrained_dict: Dict[str, torch.Tensor], model_dict: Dict[str, torch.Tensor] + pretrained_dict: Dict[str, torch.Tensor], + model_dict: Dict[str, torch.Tensor], + allow_partial_copy: bool = False, ) -> Dict[str, torch.Tensor]: """ Partially initialize a model's state dictionary with a pretrained state dictionary. + This function safely copies compatible layers from a pretrained model into a new model, - ignoring layers with mismatched shapes or missing keys. + ignoring layers with missing keys or incompatible shapes. + + By default, only tensors with exactly matching shapes are restored. - Steps: - 1. Remove layers from the pretrained dictionary if their shape does not match the target model. - 2. Keep only keys that exist in the target model. - 3. Update the model dictionary with the filtered pretrained weights. + If ``allow_partial_copy=True``, tensors whose shapes differ only in the first + dimension are partially restored by copying the overlapping rows from the + pretrained tensor into the target tensor. The remaining rows keep their + model-initialized values. This is useful when adding new vocabulary rows or + special tokens, e.g. adding an interruption token to an embedding table. Args: - pretrained_dict (Dict[str, torch.Tensor]): - The state dictionary of the pretrained model. - model_dict (Dict[str, torch.Tensor]): - The state dictionary of the target model to be partially initialized. + pretrained_dict: + State dictionary from the pretrained checkpoint. + + model_dict: + State dictionary of the target model. + + allow_partial_copy: + If True, allow partial row-wise restore for tensors where only + dimension 0 differs and all trailing dimensions match. Defaults to False. Returns: - Dict[str, torch.Tensor]: - The updated model state dictionary with compatible layers loaded from the pretrained dictionary. + Updated model state dictionary with compatible pretrained weights loaded. Example: >>> model_dict = model.state_dict() >>> pretrained_dict = load_checkpoint("pretrained_model.ckpt") - >>> model_dict = set_model_dict_for_partial_init(pretrained_dict, model_dict) + >>> model_dict = set_model_dict_for_partial_init( + ... pretrained_dict, + ... model_dict, + ... allow_partial_copy=True, + ... ) >>> model.load_state_dict(model_dict) """ - # 1. Remove layers where pretrained shape differs from model shape - for k, v in list(pretrained_dict.items()): - if k in model_dict and hasattr(model_dict[k], "numel") and v.numel() != model_dict[k].numel(): - del pretrained_dict[k] - logging.info(f" | > Layer with shape mismatch in the model definition: {k}") - - # 2. Keep only keys that exist in the target model - pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} - - # 3. Update model dictionary with filtered pretrained layers - model_dict.update(pretrained_dict) - logging.info(f" | > {len(pretrained_dict)} / {len(model_dict)} layers are restored.") + restored_dict = {} + exact_restored = 0 + partial_restored = 0 + skipped_mismatch = 0 + + for key, pretrained_value in pretrained_dict.items(): + if key not in model_dict: + continue + + model_value = model_dict[key] + + if not hasattr(pretrained_value, "shape") or not hasattr(model_value, "shape"): + continue + + if pretrained_value.shape == model_value.shape: + restored_dict[key] = pretrained_value + exact_restored += 1 + continue + + can_partial_copy = ( + allow_partial_copy + and pretrained_value.ndim == model_value.ndim + and pretrained_value.ndim > 0 + and pretrained_value.shape[1:] == model_value.shape[1:] + ) + + if can_partial_copy: + merged_value = model_value.clone() + rows_to_copy = min(pretrained_value.shape[0], model_value.shape[0]) + + merged_value[:rows_to_copy].copy_( + pretrained_value[:rows_to_copy].to( + device=merged_value.device, + dtype=merged_value.dtype, + ) + ) + + restored_dict[key] = merged_value + partial_restored += 1 + + logging.info( + f" | > Partially restored resized tensor: {key} " + f"pretrained={tuple(pretrained_value.shape)} " + f"model={tuple(model_value.shape)} " + f"copied_rows={rows_to_copy}" + ) + continue + + skipped_mismatch += 1 + logging.info( + f" | > Layer with shape mismatch in the model definition: {key} " + f"pretrained={tuple(pretrained_value.shape)} " + f"model={tuple(model_value.shape)}" + ) + + model_dict.update(restored_dict) + + logging.info( + f" | > {len(restored_dict)} / {len(model_dict)} layers are restored " + f"({exact_restored} exact, {partial_restored} partial, " + f"{skipped_mismatch} skipped due to incompatible shape)." + ) return model_dict diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 26d6bf3ae0ef..b6a0663f3fee 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -655,7 +655,7 @@ def restore_from_pretrained_checkpoint(self, checkpoint_path): checkpoint_state = torch.load(checkpoint_path, map_location='cpu') else: checkpoint_state = torch.load(checkpoint_path, map_location='cpu') - checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, self.state_dict()) + checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, self.state_dict(), allow_partial_copy=True) self.load_state_dict(checkpoint_state, strict=True) logging.info(f"Model restored from the checkpoint: {checkpoint_path} !") From a2a21d2a768962818de3824dbfa494d30da838f8 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 29 May 2026 10:11:38 -0700 Subject: [PATCH 065/109] Add parameter to drop all turn in sample Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts.py | 52 ++++++++++++++++--- 1 file changed, 45 insertions(+), 7 deletions(-) diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 5ca06f8dca35..febe8ec9e907 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -1296,15 +1296,53 @@ def training_step(self, batch, batch_idx): user_audio = batch["user_audio_turn_splitted"] user_audio_lens = batch["user_audio_turn_splitted_lens"] - silence_prob = float(self.cfg.get("user_cond_silence_augmentation_prob", 0.0) or 0.0) - if self.training and silence_prob > 0.0: - silence_mask = torch.rand( - user_audio.size(0), - device=user_audio.device, - ) < silence_prob + turn_silence_prob = float(self.cfg.get("user_cond_silence_augmentation_prob", 0.0) or 0.0) + sample_silence_prob = float(self.cfg.get("user_cond_sample_silence_augmentation_prob", 0.0) or 0.0) + + if self.training and (turn_silence_prob > 0.0 or sample_silence_prob > 0.0): + user_audio = user_audio.clone() + + # randomly drop individual turns. + if turn_silence_prob > 0.0: + turn_silence_mask = torch.rand( + user_audio.size(0), + device=user_audio.device, + ) < turn_silence_prob + else: + turn_silence_mask = torch.zeros( + user_audio.size(0), + device=user_audio.device, + dtype=torch.bool, + ) + + # randomly drop all turns for selected samples + if sample_silence_prob > 0.0: + B = batch["text"].shape[0] + + sample_silence_mask = torch.rand( + B, + device=user_audio.device, + ) < sample_silence_prob + + indices = batch["user_audio_turn_splitted_indices"].to(user_audio.device) + turn_batch_indices = indices[:, 0].long() + + valid_turns = turn_batch_indices >= 0 + sample_drop_turn_mask = torch.zeros( + user_audio.size(0), + device=user_audio.device, + dtype=torch.bool, + ) + + sample_drop_turn_mask[valid_turns] = sample_silence_mask[ + turn_batch_indices[valid_turns] + ] + + silence_mask = turn_silence_mask | sample_drop_turn_mask + else: + silence_mask = turn_silence_mask if silence_mask.any(): - user_audio = user_audio.clone() user_audio[silence_mask] = 0.0 user_audio_codes, user_audio_codes_lens = self._codec_helper.audio_to_codes( From d52449e7bc9218760fde669ee5a477f3a4a2e61d Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Fri, 29 May 2026 23:54:13 -0700 Subject: [PATCH 066/109] phoneme turn dropout Signed-off-by: Shehzeen Hussain Signed-off-by: Edresson Casanova --- .../easy_magpietts_lhotse_multiturn.yaml | 4 +- ...text_to_speech_dataset_lhotse_multiturn.py | 51 +++++++++++++++---- nemo/collections/tts/models/easy_magpietts.py | 9 +++- 3 files changed, 51 insertions(+), 13 deletions(-) diff --git a/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml b/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml index 0f729a41b61e..26c61172de4e 100644 --- a/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml +++ b/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml @@ -93,7 +93,9 @@ model: phoneme_corruption_timestep_ratio: 0.15 phoneme_corruption_unk_mode_prob: 0.5 phoneme_corruption_type: "repeat_skip_unk" # "repeat_skip_unk" or "complete_channel" - + phoneme_turn_dropout_batch_prob: 0.0 # prob of applying turn dropout to a sample + phoneme_turn_dropout_turn_prob: 0.0 # prob of dropping each phoneme turn within a sample + phoneme_tokenizer: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPABPETokenizer tokenizer_path: ??? diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 099c04ade52b..6ae5642d0e55 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -172,6 +172,8 @@ def __init__( input_roles: List[str] = ["user", "User"], output_roles: List[str] = ["assistant", "Assistant", "agent", "Agent"], add_text_bos: bool = False, + phoneme_turn_dropout_batch_prob: float = 0.0, + phoneme_turn_dropout_turn_prob: float = 0.0, ): super().__init__() self.sample_rate = sample_rate @@ -203,6 +205,8 @@ def __init__( self.input_roles = set(ifnone(input_roles, ["user"])) self.output_roles = set(ifnone(output_roles, ["agent"])) self.add_text_bos = add_text_bos + self.phoneme_turn_dropout_batch_prob = phoneme_turn_dropout_batch_prob + self.phoneme_turn_dropout_turn_prob = phoneme_turn_dropout_turn_prob self.frame_length = (self.codec_model_samples_per_frame / codec_model_input_sample_rate) * frame_stacking_factor @@ -311,12 +315,15 @@ def _align_codebooks(t): ) if self.phoneme_tokenizer is not None: - target_phoneme_tokens, target_phoneme_lens = collate_phoneme_channel( + target_phoneme_tokens, target_phoneme_lens, phoneme_turn_dropout = collate_phoneme_channel( cuts, self.phoneme_tokenizer, self.frame_length, roles=self.output_roles, ignore_phoneme_languages=self.ignore_phoneme_languages, pad_id=self.phoneme_tokenizer.pad, eos_id=self.phoneme_tokenizer.eos_token_id, bos_id=self.phoneme_tokenizer.bos_token_id, + phoneme_turn_dropout_batch_prob=self.phoneme_turn_dropout_batch_prob, + phoneme_turn_dropout_turn_prob=self.phoneme_turn_dropout_turn_prob, + apply_turn_dropout=self.dataset_type == 'train', ) else: - target_phoneme_tokens, target_phoneme_lens = None, None + target_phoneme_tokens, target_phoneme_lens, phoneme_turn_dropout = None, None, None dataset_name_list = [] audio_list_16khz = [] @@ -547,6 +554,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) if self.phoneme_tokenizer is not None: batch_dict["phoneme_tokens"] = target_phoneme_tokens batch_dict["phoneme_tokens_lens"] = target_phoneme_lens + batch_dict["phoneme_turn_dropout"] = phoneme_turn_dropout if len(audio_list_16khz) > 0: batch_dict["audio_16khz"] = collate_vectors(audio_list_16khz, padding_value=0.0) @@ -763,17 +771,24 @@ def collate_phoneme_channel( pad_id: int = -1, eos_id: int = -2, bos_id: int = -3, -) -> tuple[torch.Tensor, torch.Tensor]: + phoneme_turn_dropout_batch_prob: float = 0.0, + phoneme_turn_dropout_turn_prob: float = 0.0, + apply_turn_dropout: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: tokens = [] + dropout_flags = [] for i, c in enumerate(cuts): - tokens.append( - build_phoneme_channel( - c, phoneme_tokenizer, frame_length, roles, - ignore_phoneme_languages, pad_id, eos_id, bos_id - ) + token, dropout_applied = build_phoneme_channel( + c, phoneme_tokenizer, frame_length, roles, + ignore_phoneme_languages, pad_id, eos_id, bos_id, + phoneme_turn_dropout_batch_prob=phoneme_turn_dropout_batch_prob, + phoneme_turn_dropout_turn_prob=phoneme_turn_dropout_turn_prob, + apply_turn_dropout=apply_turn_dropout, ) + tokens.append(token) + dropout_flags.append(dropout_applied) token_lens = torch.tensor([len(tt) for tt in tokens]) - return collate_vectors(tokens, padding_value=pad_id), token_lens + return collate_vectors(tokens, padding_value=pad_id), token_lens, torch.tensor(dropout_flags, dtype=torch.bool) def build_phoneme_channel( @@ -785,14 +800,28 @@ def build_phoneme_channel( pad_id: int = -1, eos_id: int = -2, bos_id: int = -3, -) -> torch.Tensor: + phoneme_turn_dropout_batch_prob: float = 0.0, + phoneme_turn_dropout_turn_prob: float = 0.0, + apply_turn_dropout: bool = False, +) -> tuple[torch.Tensor, bool]: language = cut.lang if cut.has_custom("lang") else next((sup.language for sup in cut.supervisions if sup.has_custom("language")), "en") total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) tokens = torch.ones(total, dtype=torch.long) * pad_id + dropout_applied = False + apply_dropout = ( + apply_turn_dropout + and phoneme_turn_dropout_batch_prob > 0.0 + and phoneme_turn_dropout_turn_prob > 0.0 + and random.random() < phoneme_turn_dropout_batch_prob + ) for supervision in cut.supervisions: if supervision.speaker in roles: + if apply_dropout and random.random() < phoneme_turn_dropout_turn_prob: + dropout_applied = True + continue + if isinstance(phoneme_tokenizer, IPABPETokenizer): ipa_text = _get_supervision_ipa_text(supervision) if language in ignore_phoneme_languages: @@ -812,4 +841,4 @@ def build_phoneme_channel( phoneme_ids = phoneme_ids[:len(tokens) - pos] tokens[pos:pos+len(phoneme_ids)] = phoneme_ids - return tokens \ No newline at end of file + return tokens, dropout_applied \ No newline at end of file diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index febe8ec9e907..775c61f207b5 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -858,6 +858,7 @@ def process_batch( context_audio_codes_lens: torch.Tensor, phoneme_tokens: Optional[torch.Tensor] = None, phoneme_tokens_lens: Optional[torch.Tensor] = None, + phoneme_turn_dropout: Optional[torch.Tensor] = None, mode: str = "train", training_mode: Optional[TrainingMode] = None, task: Optional[List[str]] = None, @@ -1208,7 +1209,10 @@ def process_batch( pb_phoneme_tokens_lens_target = phoneme_tokens_lens_stacked - 1 if (phoneme_corruption_mode != 'repeat_skip') and not ( - dropout_complete_phoneme_channel or dropout_conditional_input or dropout_text_input + dropout_complete_phoneme_channel + or (phoneme_turn_dropout is not None and phoneme_turn_dropout.any()) + or dropout_conditional_input + or dropout_text_input ): custom_mask = None if self.cfg.get("phoneme_loss_mask_padding", False): @@ -1568,6 +1572,7 @@ def training_step(self, batch, batch_idx): context_audio_codes_lens=context_audio_codes_lens, phoneme_tokens=batch.get('phoneme_tokens'), phoneme_tokens_lens=batch.get('phoneme_tokens_lens'), + phoneme_turn_dropout=batch.get('phoneme_turn_dropout'), mode="train", task=batch["task"] if self.cfg.get("use_multiturn_dataset", False) else None, agent_mask=batch["agent_mask"] if self.cfg.get("use_multiturn_dataset", False) else None, @@ -2052,6 +2057,8 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D input_roles=["user", "User"], output_roles=["assistant", "Assistant", "agent", "Agent"], add_text_bos=self.cfg.get("add_text_bos", False), + phoneme_turn_dropout_batch_prob=self.cfg.get("phoneme_turn_dropout_batch_prob", 0.0), + phoneme_turn_dropout_turn_prob=self.cfg.get("phoneme_turn_dropout_turn_prob", 0.0), ) dataset = FallbackDataset(dataset) else: From 649f75812e8a22b169c436b907f82d6f977bda03 Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Mon, 1 Jun 2026 09:15:42 -0700 Subject: [PATCH 067/109] filewise metrics and aggregated metrics in the inference script Signed-off-by: Shehzeen Hussain Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 141 +++++++++++++++++- 1 file changed, 134 insertions(+), 7 deletions(-) diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index c30a6df138ea..2aa5f0de6fab 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -28,6 +28,7 @@ from torch.utils.data import DataLoader, Dataset from omegaconf import OmegaConf, open_dict +from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.audio.parts.utils.transforms import resample from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility from nemo.collections.speechlm2.parts.metrics.secs import SECS @@ -216,6 +217,46 @@ def _load_audio(path, sample_rate, normalize=True, use_librosa=False): return resample(wav, sr, sample_rate).squeeze(0) +def _json_metric_value(value): + if torch.is_tensor(value): + value = value.detach().cpu() + if value.numel() == 1: + return value.item() + return value.tolist() + return value + + +def _write_json(path, data): + output_dir = os.path.dirname(path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + + +def _write_filewise_metrics(path, filewise_metrics): + sorted_metrics = sorted(filewise_metrics, key=lambda row: row.get("cer", float("-inf")), reverse=True) + _write_json(path, sorted_metrics) + + +def _compute_secs_per_sample(secs_metric, name, target_audio, target_audio_lens, pred_audio, pred_audio_lens): + if secs_metric.speaker_encoder is None: + secs_metric.reset() + + with fp32_precision(): + with torch.no_grad(): + _, t_g = secs_metric.speaker_encoder( + input_signal=target_audio, input_signal_length=target_audio_lens.long() + ) + _, s_g = secs_metric.speaker_encoder( + input_signal=pred_audio, input_signal_length=pred_audio_lens.long() + ) + secs = torch.nn.functional.cosine_similarity(t_g, s_g, dim=-1) + + secs_metric._secs[name].append(secs.mean()) + return secs.detach().cpu() + + def collate_and_tokenize_custom( batch, model, @@ -476,6 +517,7 @@ def collate_and_tokenize_custom( out_dict["context_audio"] = padded_audio out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) + out_dict["context_audio_paths"] = [s.get("context_audio_filepath") for s in batch] out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] out_dict["target_num_frames"] = target_num_frames @@ -510,6 +552,18 @@ def main(): parser.add_argument("--inference_dtype", type=str, default="float32") parser.add_argument("--debug_dtype", action="store_true") parser.add_argument("--use_librosa", action="store_true", help="Use librosa instead of soundfile+torch for audio load") + parser.add_argument( + "--filewise_metrics_path", + type=str, + default=None, + help="Path to save per-file metrics JSON. Defaults to /filewise_metrics.json", + ) + parser.add_argument( + "--aggregate_metrics_path", + type=str, + default=None, + help="Path to save aggregate metrics JSON. Defaults to /aggregate_metrics.json", + ) # Dataloader & Batching parser.add_argument("--batch_size", type=int, default=6) @@ -554,6 +608,9 @@ def main(): if args.profile_pad_max_sec < args.profile_pad_min_sec: raise RuntimeError("--profile_pad_max_sec must be >= --profile_pad_min_sec.") + filewise_metrics_path = args.filewise_metrics_path or os.path.join(args.out_dir, "filewise_metrics.json") + aggregate_metrics_path = args.aggregate_metrics_path or os.path.join(args.out_dir, "aggregate_metrics.json") + distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1 if distributed and not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl") @@ -649,6 +706,10 @@ def main(): use_librosa=args.use_librosa, ).unsqueeze(0).to(model.device, dtype=target_dtype) + filewise_metrics = [] + running_metric_sums = {"cer": 0.0, "wer": 0.0, "secs": 0.0} + running_metric_count = 0 + for batch_id, inputs in enumerate(dataloader): B = inputs["context_audio"].size(0) device = model.device @@ -1099,7 +1160,7 @@ def main(): metric_audio_pred = torch_rms_norm(metric_audio_pred) context_audio = torch_rms_norm(context_audio) - intelligibility.update( + asr_hyps = intelligibility.update( name="dataset", refs=inputs["raw_text"], pred_audio=metric_audio_pred, @@ -1107,14 +1168,61 @@ def main(): asr_hyps=None, ) - secs_metric.update( - name="dataset", - target_audio=context_audio, - target_audio_lens=context_audio_lens, - pred_audio=metric_audio_pred, - pred_audio_lens=metric_audio_pred_lens, + secs_values = _compute_secs_per_sample( + secs_metric, + "dataset", + context_audio, + context_audio_lens, + metric_audio_pred, + metric_audio_pred_lens, ) + batch_filewise_metrics = [] + for i, (raw_ref, raw_hyp) in enumerate(zip(inputs["raw_text"], asr_hyps)): + normalized_ref = intelligibility.normalizer(raw_ref) + normalized_hyp = intelligibility.normalizer(raw_hyp) + target_path = inputs["target_audio_paths"][i] + context_path = inputs["context_audio_paths"][i] + cer = word_error_rate([normalized_hyp], [normalized_ref], use_cer=True) + wer = word_error_rate([normalized_hyp], [normalized_ref], use_cer=False) + secs = _json_metric_value(secs_values[i]) + running_metric_count += 1 + running_metric_sums["cer"] += cer + running_metric_sums["wer"] += wer + running_metric_sums["secs"] += secs + row = { + "sample_index": running_metric_count - 1, + "batch_id": batch_id, + "batch_sample_index": i, + "target_audio_filepath": target_path, + "target_audio_resolved_filepath": _resolve_audio_path(target_path, args.audio_dir), + "context_audio_filepath": context_path, + "context_audio_resolved_filepath": _resolve_audio_path(context_path, args.audio_dir), + "speaker_reference_filepath": ( + args.inference_speaker_reference + if args.user_custom_speaker_reference and args.inference_speaker_reference + else None + ), + "generated_audio_filepath": None, + "generated_turn_audio_filepaths": [], + "aligned_user_agent_audio_filepath": None, + "reference_transcript": raw_ref, + "asr_hypothesis": raw_hyp, + "normalized_reference_transcript": normalized_ref, + "normalized_asr_hypothesis": normalized_hyp, + "cer": cer, + "wer": wer, + "secs": secs, + "running_average_metrics": { + "num_samples": running_metric_count, + "cer": running_metric_sums["cer"] / running_metric_count, + "wer": running_metric_sums["wer"] / running_metric_count, + "secs": running_metric_sums["secs"] / running_metric_count, + }, + } + batch_filewise_metrics.append(row) + filewise_metrics.append(row) + os.makedirs(args.out_dir, exist_ok=True) audio_f32 = audio_f32.detach().cpu() audio_len = audio_len.cpu() @@ -1166,12 +1274,14 @@ def main(): turn_wav = aligned_agent[start_sample:end_sample].numpy() out_path = os.path.join(args.out_dir, f"{stem}_turn_{turn_id}{ext}") sf.write(out_path, turn_wav, samplerate=model.output_sample_rate) + batch_filewise_metrics[i]["generated_turn_audio_filepaths"].append(out_path) logging.info(f"Saved: {out_path}") # Save full artifact-scrubbed agent audio. wav = aligned_agent.numpy() out_path = os.path.join(args.out_dir, base_name) sf.write(out_path, wav, samplerate=model.output_sample_rate) + batch_filewise_metrics[i]["generated_audio_filepath"] = out_path logging.info(f"Full aligned agent audio saved: {out_path}") # --------------------------------------------------------- @@ -1251,14 +1361,22 @@ def main(): samplerate=model.output_sample_rate, ) + batch_filewise_metrics[i]["aligned_user_agent_audio_filepath"] = aligned_path logging.info(f"Aligned user/agent stereo audio saved: {aligned_path}") else: wav = audio_f32[i, : audio_len[i]].numpy() out_path = os.path.join(args.out_dir, base_name) sf.write(out_path, wav, samplerate=model.output_sample_rate) + batch_filewise_metrics[i]["generated_audio_filepath"] = out_path logging.info(f"Saved: {out_path}") + _write_filewise_metrics(filewise_metrics_path, filewise_metrics) + logging.info( + f"Filewise metrics checkpoint saved: {filewise_metrics_path} " + f"({len(filewise_metrics)} samples, sorted by CER)" + ) + with fp32_precision(): logging.info("\n--- Evaluation Metrics ---") cer_wer = intelligibility.compute() @@ -1269,6 +1387,15 @@ def main(): for k, m in secs_scores.items(): logging.info(f"SECS - {k}: {m}") + aggregate_metrics = { + **{k: _json_metric_value(v) for k, v in cer_wer.items()}, + **{k: _json_metric_value(v) for k, v in secs_scores.items()}, + } + _write_filewise_metrics(filewise_metrics_path, filewise_metrics) + logging.info(f"Filewise metrics saved: {filewise_metrics_path}") + _write_json(aggregate_metrics_path, aggregate_metrics) + logging.info(f"Aggregate metrics saved: {aggregate_metrics_path}") + if __name__ == "__main__": main() From e32d4e1e1637b4673955eaf67bf15d30021666fe Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 1 Jun 2026 12:03:52 -0700 Subject: [PATCH 068/109] Add multigpu inference script Signed-off-by: Edresson Casanova --- ..._magpietts_inference_multiturn_multigpu.py | 1992 +++++++++++++++++ 1 file changed, 1992 insertions(+) create mode 100644 examples/tts/easy_magpietts_inference_multiturn_multigpu.py diff --git a/examples/tts/easy_magpietts_inference_multiturn_multigpu.py b/examples/tts/easy_magpietts_inference_multiturn_multigpu.py new file mode 100644 index 000000000000..fbe5a24e4712 --- /dev/null +++ b/examples/tts/easy_magpietts_inference_multiturn_multigpu.py @@ -0,0 +1,1992 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Multi-GPU evaluation script for custom EasyMagpieTTS models. + +Properties: + - One process per GPU with torch.distributed when WORLD_SIZE > 1. + - Uses DistributedSampler(drop_last=False). If len(dataset) is not divisible by + world_size, PyTorch may repeat a few samples so all ranks have equal work. + This avoids last-rank/last-batch distributed hangs. Repeated samples are + deduplicated from filewise final metrics. + - Optional --num_eval_runs N repeats the same eval set N times and reports + run-averaged filewise metrics when --save_filewise_metrics is enabled. + - Optional --sort_by_text_token_count orders samples by total text token count + so each GPU step receives similarly-sized examples. By default it sorts + ascending, so DistributedSampler padding repeats short examples. + - profile_multiturn_inference remains batch_size=1 per rank, but runs in + parallel across ranks/GPUs. + - Saves global metrics in out_dir: + metrics_rankXXXX.json + metrics_final.json + metrics_final.txt + - Saves generated audio files in: + out_dir/audios/ + - Optional filewise metrics: + --save_filewise_metrics + Saves: + filewise_metrics_rankXXXX.jsonl + filewise_metrics_sorted_by_cer.jsonl + filewise_metrics_sorted_by_cer.csv + metrics_final_filewise_average.json + The merged filewise outputs deduplicate repeated DistributedSampler samples + by dataset_index. + - Prints the final text metric summary on rank 0: + Average CER: value + Average WER: value + SECS: value + +Recommended torchrun: + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + torchrun --nproc_per_node=8 easy_magpietts_inference_multiturn_multigpu.py ... + +Recommended SLURM/srun: + srun --ntasks-per-node=8 --gpus-per-task=1 --gpu-bind=single:1 \ + python easy_magpietts_inference_multiturn_multigpu.py ... +""" + +import argparse +import csv +import json +import os +import socket +import time +from copy import deepcopy +from functools import partial +from typing import Any, Dict, List + +import librosa +import soundfile as sf +import torch +import torch.distributed as dist +from omegaconf import open_dict +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader, Dataset, DistributedSampler, SequentialSampler + +from nemo.collections.audio.parts.utils.transforms import resample +from nemo.collections.asr.metrics.wer import word_error_rate +from whisper_normalizer.english import EnglishTextNormalizer +from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility +from nemo.collections.speechlm2.parts.metrics.secs import SECS +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel +from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter +from nemo.collections.tts.modules.magpietts_modules import CodecHelper +from nemo.collections.tts.parts.utils.tts_dataset_utils import normalize_volume +from nemo.utils import logging + + +torch.set_float32_matmul_precision("medium") +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True + + +def get_rank_info(): + world_size = int(os.environ.get("WORLD_SIZE", "1")) + rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", "0"))) + local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))) + distributed = world_size > 1 + return distributed, rank, local_rank, world_size + + +def get_visible_device_index(local_rank: int) -> int: + if not torch.cuda.is_available(): + return -1 + ndev = torch.cuda.device_count() + if ndev <= 0: + return -1 + return local_rank % ndev + + +def setup_distributed(): + distributed, rank, local_rank, world_size = get_rank_info() + device_index = get_visible_device_index(local_rank) + + if torch.cuda.is_available() and device_index >= 0: + torch.cuda.set_device(device_index) + + if distributed and not dist.is_initialized(): + dist.init_process_group(backend="nccl") + dist.barrier() + + return distributed, rank, local_rank, world_size, device_index + + +def cleanup_distributed(): + if dist.is_available() and dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + +def all_rank_print(rank: int, msg: str): + print(f"[rank={rank}] {msg}", flush=True) + + +def rank0_print(rank: int, msg: str): + if rank == 0: + print(msg, flush=True) + + +def get_audio_out_dir(args) -> str: + return os.path.join(args.out_dir, "audios") + + +def torch_rms_norm(wav: torch.Tensor, db_level: float = -27.0) -> torch.Tensor: + denom = torch.sum(wav**2) + if denom <= 0: + return wav + r = 10 ** (db_level / 20) + a = torch.sqrt((wav.size(-1) * (r**2)) / denom) + return wav * a + + +def scalarize_metric_value(v: Any): + if torch.is_tensor(v): + if v.numel() == 1: + return float(v.detach().cpu().item()) + return v.detach().cpu().tolist() + + try: + import numpy as np + + if isinstance(v, np.generic): + return float(v.item()) + except Exception: + pass + + if isinstance(v, (int, float, str, bool)) or v is None: + return v + + return str(v) + + +def metric_dict_to_jsonable(d: Dict[str, Any]) -> Dict[str, Any]: + return {str(k): scalarize_metric_value(v) for k, v in d.items()} + + +def write_json(path: str, obj: Dict[str, Any]): + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + json.dump(obj, f, indent=2, sort_keys=True) + os.replace(tmp_path, path) + + +def write_text_atomic(path: str, text: str): + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + f.write(text) + os.replace(tmp_path, path) + + +def write_jsonl(path: str, rows: List[Dict[str, Any]]): + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, sort_keys=True) + "\n") + os.replace(tmp_path, path) + + +def write_filewise_csv(path: str, rows: List[Dict[str, Any]]): + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + + fieldnames = [ + "run_id", + "dataset_index", + "rank", + "cer", + "wer", + "secs", + "pred_audio_seconds", + "target_audio_path", + "reference_text", + "asr_hyp", + ] + + with open(tmp_path, "w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow({k: row.get(k, None) for k in fieldnames}) + + os.replace(tmp_path, path) + + +def get_first_metric(metrics: Dict[str, Any], names: List[str], default=None): + for name in names: + if name in metrics: + return metrics[name] + return default + + +def safe_metric_scalar(metric_dict: Dict[str, Any], preferred_keys: List[str]): + for key in preferred_keys: + if key in metric_dict: + value = metric_dict[key] + if torch.is_tensor(value): + return float(value.detach().cpu().item()) + return float(value) + return None + + +def format_final_metric_text(final_metrics: Dict[str, Any]) -> str: + intelligibility = final_metrics.get("intelligibility", {}) + secs = final_metrics.get("secs", {}) + + cer = get_first_metric(intelligibility, ["cer", "cer_dataset"]) + wer = get_first_metric(intelligibility, ["wer", "wer_dataset"]) + secs_value = get_first_metric(secs, ["secs", "secs_dataset"]) + + def fmt(x): + if x is None: + return "nan" + try: + return f"{float(x):.10f}" + except Exception: + return str(x) + + return ( + f"Average CER: {fmt(cer)}\n" + f"Average WER: {fmt(wer)}\n" + f"SECS: {fmt(secs_value)}\n" + ) + + +def format_filewise_final_metric_text(filewise_summary: Dict[str, Any]) -> str: + def fmt(x): + if x is None: + return "nan" + try: + return f"{float(x):.10f}" + except Exception: + return str(x) + + return ( + f"Average CER: {fmt(filewise_summary.get('cer'))}\n" + f"Average WER: {fmt(filewise_summary.get('wer'))}\n" + f"SECS: {fmt(filewise_summary.get('secs'))}\n" + ) + + +def attach_dtype_counter(model): + handles = [] + stats = {} + examples = {} + + def is_leaf(module): + return len(list(module.children())) == 0 + + def get_dtype(x): + if torch.is_tensor(x): + return str(x.dtype) + if isinstance(x, (list, tuple)): + for t in x: + if torch.is_tensor(t): + return str(t.dtype) + return "other" + + def get_module_group(name): + return name.split(".")[0] if "." in name else name + + def hook_fn(name): + def fn(module, inputs, outputs): + dtype = get_dtype(outputs) + if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: + dtype = "other" + group = get_module_group(name) + if group not in stats: + stats[group] = { + "torch.float16": 0, + "torch.bfloat16": 0, + "torch.float32": 0, + "other": 0, + } + examples[group] = { + "torch.float16": [], + "torch.bfloat16": [], + "torch.float32": [], + "other": [], + } + stats[group][dtype] += 1 + if len(examples[group][dtype]) < 3: + examples[group][dtype].append(module.__class__.__name__) + + return fn + + for name, module in model.named_modules(): + if is_leaf(module): + handles.append(module.register_forward_hook(hook_fn(name))) + return handles, stats, examples + + +def report_dtype_stats(handles, stats, examples, rank=0): + for h in handles: + h.remove() + logging.info(f"[rank={rank}] === DTYPE USAGE PER MODULE ===") + for group, group_stats in stats.items(): + total = sum(group_stats.values()) + if total == 0: + continue + logging.info(f"[rank={rank}] --- {group} ---") + for dtype, count in group_stats.items(): + if count > 0: + logging.info(f"[rank={rank}] {dtype}: {count} ({100 * count / total:.2f}%)") + logging.info(f"[rank={rank}] === DTYPE EXAMPLES ===") + for group, group_examples in examples.items(): + for dtype, mods in group_examples.items(): + if mods: + logging.info(f"[rank={rank}] {group} {dtype}: {mods}") + + +def _combined_audio_name(first_audio_filepath: str, paths: List[str]) -> str: + base_names = [os.path.splitext(os.path.basename(p))[0] for p in paths if p] + ext = os.path.splitext(paths[-1])[1] if paths and paths[-1] else "" + combined_name = "_".join(base_names) + ext + dir_name = os.path.dirname(first_audio_filepath) + return os.path.join(dir_name, combined_name) if dir_name else combined_name + + +class EvalJSONLDataset(Dataset): + def __init__(self, file_path: str, num_turns: int = 1): + self.samples = [] + raw_samples = [] + + with open(file_path, "r", encoding="utf-8") as f: + for line_idx, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + sample = json.loads(line) + sample["__dataset_index__"] = len(raw_samples) + raw_samples.append(sample) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON on line {line_idx}: {e}") + + if num_turns <= 1: + self.samples = raw_samples + return + + single_turn_by_speaker = {} + for sample in raw_samples: + if isinstance(sample["text"], list): + self.samples.append(sample) + else: + speaker = sample.get("speaker", "unknown") + single_turn_by_speaker.setdefault(speaker, []).append(sample) + + synthetic_index = len(raw_samples) + for _, speaker_samples in single_turn_by_speaker.items(): + buffer_texts, buffer_paths = [], [] + first_sample_meta = None + + for sample in speaker_samples: + if not buffer_texts: + first_sample_meta = dict(sample) + + buffer_texts.append(sample["text"]) + buffer_paths.append(sample.get("audio_filepath", "")) + + if len(buffer_texts) == num_turns: + first_sample_meta["text"] = buffer_texts + first_sample_meta["audio_filepath"] = _combined_audio_name( + first_sample_meta.get("audio_filepath", ""), + buffer_paths, + ) + first_sample_meta["__dataset_index__"] = synthetic_index + synthetic_index += 1 + self.samples.append(first_sample_meta) + buffer_texts, buffer_paths, first_sample_meta = [], [], None + + if buffer_texts and first_sample_meta is not None: + first_sample_meta["text"] = buffer_texts + first_sample_meta["audio_filepath"] = _combined_audio_name( + first_sample_meta.get("audio_filepath", ""), + buffer_paths, + ) + first_sample_meta["__dataset_index__"] = synthetic_index + synthetic_index += 1 + self.samples.append(first_sample_meta) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.samples[idx] + + +def _sample_text_segments_for_count(sample: Dict[str, Any], max_eval_turns=None) -> List[str]: + text_data = sample.get("text", "") + if isinstance(text_data, list): + segments = text_data + if max_eval_turns is not None: + segments = segments[: int(max_eval_turns)] + return [str(x) for x in segments] + return [str(text_data)] + + +def estimate_text_token_count(sample: Dict[str, Any], model, max_eval_turns=None) -> int: + """Approximate generation cost by summing text tokenizer lengths over all turns.""" + main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] + total = 0 + for segment in _sample_text_segments_for_count(sample, max_eval_turns=max_eval_turns): + total += len(model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name)) + 1 # +EOS + return int(total) + + +class SortedByTextTokenCountDataset(Dataset): + """ + Dataset wrapper that orders examples by total text-token count. + + With DistributedSampler(shuffle=False), rank r sees positions: + r, r + world_size, r + 2 * world_size, ... + + If the wrapper is sorted by length descending, then each GPU step gets a + block of examples with similar token lengths. This usually reduces + straggler effects for autoregressive/profile inference. + """ + + def __init__(self, dataset: Dataset, model, max_eval_turns=None, descending: bool = True): + self.dataset = dataset + scored = [] + for i in range(len(dataset)): + sample = dict(dataset[i]) + token_count = estimate_text_token_count(sample, model=model, max_eval_turns=max_eval_turns) + sample["__text_token_count__"] = int(token_count) + scored.append((token_count, i, sample)) + + scored.sort(key=lambda x: (x[0], -x[1]), reverse=bool(descending)) + self.indices = [i for _, i, _ in scored] + self.token_counts = {i: int(tok) for tok, i, _ in scored} + + def __len__(self): + return len(self.indices) + + def __getitem__(self, local_idx): + original_idx = self.indices[local_idx] + sample = dict(self.dataset[original_idx]) + sample["__text_token_count__"] = self.token_counts[original_idx] + return sample + + +def _resolve_audio_path(path, root_path): + if path is None: + return None + if root_path is not None and not os.path.isabs(path): + return os.path.join(root_path, path) + return path + + +def _load_audio(path, sample_rate, normalize=True, use_librosa=False): + if path is None or not os.path.exists(path): + return torch.zeros(1, dtype=torch.float32) + + if use_librosa: + wav, sr = librosa.load(path, sr=sample_rate, mono=True) + if normalize: + wav = normalize_volume(wav) + return torch.as_tensor(wav, dtype=torch.float32) + + wav, sr = sf.read(path, dtype="float32") + if wav.ndim > 1: + wav = wav.mean(axis=1) + + if normalize: + wav = normalize_volume(wav) + + wav = torch.as_tensor(wav, dtype=torch.float32).unsqueeze(0) + return resample(wav, sr, sample_rate).squeeze(0) + + +def collate_and_tokenize_custom( + batch, + model, + extra_duration_thrshould=1.3, + sample_rate=22050, + root_path=None, + emulate_duplex_inference=False, + add_interruption_token=False, + pad_factor_text_speech=10, + force_interruption=False, + normalize_audio_volume=True, + use_librosa=False, + profile_multiturn_inference=False, + max_eval_turns=None, +): + main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] + + if max_eval_turns is not None: + max_eval_turns = int(max_eval_turns) + if max_eval_turns <= 0: + raise ValueError("--max_eval_turns must be > 0 when provided.") + + truncated_batch = [] + for s in batch: + s = dict(s) + if isinstance(s["text"], list): + s["text"] = s["text"][:max_eval_turns] + if isinstance(s.get("user_audio_file_path"), list): + s["user_audio_file_path"] = s["user_audio_file_path"][:max_eval_turns] + truncated_batch.append(s) + batch = truncated_batch + + is_profile = profile_multiturn_inference + is_duplex = emulate_duplex_inference and not is_profile + + out_dict = { + "duplex_multiturn": is_duplex, + "regular_multiturn": (not is_duplex) and (not is_profile), + "profile_multiturn": is_profile, + "dataset_indices": [int(s.get("__dataset_index__", -1)) for s in batch], + "text_token_counts": [int(s.get("__text_token_count__", -1)) for s in batch], + } + + tokenized_list = [] + batched_turns = [] + batched_turn_lens = [] + valid_turn_masks = [] + + if is_duplex: + for s in batch: + text_data = s["text"] + + if isinstance(text_data, list): + full_ids = [] + for segment in text_data: + seg_ids = model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name) + [model.eos_id] + pad_ids = [model.pad_id] * (len(seg_ids) * pad_factor_text_speech) + + if force_interruption: + fname = s["audio_filepath"] + no_ext = fname.split(".")[0] + sample_id = int(no_ext.split("_")[-1]) + case = sample_id % 3 + + if case == 0: + if len(seg_ids) >= 2: + seg_ids[-2] = model.interruption_token_id + seg_ids[-1] = model.pad_id + else: + pad_ids[0] = model.interruption_token_id + elif case == 1: + eos_idx = min(6, len(pad_ids) - 1) + pad_ids[eos_idx] = model.interruption_token_id + else: + pad_ids[0] = model.interruption_token_id + + elif add_interruption_token: + eos_idx = int(len(pad_ids) * 0.7) + pad_ids[eos_idx] = model.interruption_token_id + + full_ids.extend(seg_ids) + full_ids.extend(pad_ids) + + tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long)) + else: + tokenized_list.append( + torch.as_tensor( + model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id], + dtype=torch.long, + ) + ) + + prefix = torch.full((25,), model.pad_id, dtype=torch.long) + tokenized_list = [torch.cat([prefix, x]) for x in tokenized_list] + out_dict["input_lengths"] = torch.tensor([len(x) for x in tokenized_list], dtype=torch.long) + out_dict["input_ids"] = pad_sequence(tokenized_list, batch_first=True, padding_value=model.pad_id) + + else: + max_turns = 1 + for s in batch: + if isinstance(s["text"], list): + max_turns = max(max_turns, len(s["text"])) + + for t in range(max_turns): + turn_t_tokens, turn_t_lens, turn_t_valid = [], [], [] + + for s in batch: + text_data = s["text"] + + if isinstance(text_data, list): + if t < len(text_data): + seg_ids = model.tokenizer.encode(text_data[t], tokenizer_name=main_tokenizer_name) + [ + model.eos_id + ] + turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) + turn_t_lens.append(len(seg_ids)) + turn_t_valid.append(True) + else: + turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) + turn_t_lens.append(1) + turn_t_valid.append(False) + else: + if t == 0: + seg_ids = model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [ + model.eos_id + ] + turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) + turn_t_lens.append(len(seg_ids)) + turn_t_valid.append(True) + else: + turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) + turn_t_lens.append(1) + turn_t_valid.append(False) + + batched_turns.append(pad_sequence(turn_t_tokens, batch_first=True, padding_value=model.pad_id)) + batched_turn_lens.append(torch.tensor(turn_t_lens, dtype=torch.long)) + valid_turn_masks.append(torch.tensor(turn_t_valid, dtype=torch.bool)) + + out_dict["batched_turns"] = batched_turns + out_dict["batched_turn_lens"] = batched_turn_lens + out_dict["valid_turn_masks"] = valid_turn_masks + + audio_list, audio_lengths, target_num_frames = [], [], [] + max_turns_for_user_audio = len(batched_turns) if not is_duplex else 0 + + if is_profile and max_turns_for_user_audio > 0: + user_audio_by_turn = [[] for _ in range(max_turns_for_user_audio)] + user_audio_lens_by_turn = [[] for _ in range(max_turns_for_user_audio)] + else: + user_audio_by_turn, user_audio_lens_by_turn = [], [] + + for i, s in enumerate(batch): + audio_path = _resolve_audio_path(s.get("context_audio_filepath"), root_path) + wav = _load_audio(audio_path, sample_rate, normalize=normalize_audio_volume, use_librosa=use_librosa) + audio_list.append(wav) + audio_lengths.append(len(wav)) + + if is_profile and max_turns_for_user_audio > 0: + user_audio_paths = s.get("user_audio_file_path", None) + + for t in range(max_turns_for_user_audio): + has_valid_text_turn = (isinstance(s["text"], list) and t < len(s["text"])) or ( + not isinstance(s["text"], list) and t == 0 + ) + + if ( + isinstance(user_audio_paths, list) + and t < len(user_audio_paths) + and user_audio_paths[t] + and has_valid_text_turn + ): + ua_path = _resolve_audio_path(user_audio_paths[t], root_path) + ua_wav = _load_audio( + ua_path, + sample_rate=sample_rate, + normalize=normalize_audio_volume, + use_librosa=use_librosa, + ) + else: + ua_wav = torch.zeros(int(2 * sample_rate), dtype=torch.float32) + + user_audio_by_turn[t].append(ua_wav) + user_audio_lens_by_turn[t].append(len(ua_wav)) + + tdur_audio_path = _resolve_audio_path(s["audio_filepath"], root_path) + + if tdur_audio_path and os.path.exists(tdur_audio_path): + wav_dur = _load_audio( + tdur_audio_path, + sample_rate, + normalize=normalize_audio_volume, + use_librosa=use_librosa, + ) + tdur = wav_dur.shape[0] // model.input_samples_per_frame + target_num_frames.append(tdur * extra_duration_thrshould) + else: + if is_duplex: + current_text_len = len(tokenized_list[i]) + target_num_frames.append(current_text_len if isinstance(s["text"], list) else current_text_len * 5) + else: + target_num_frames.append(sum([l[i].item() for l in batched_turn_lens]) * 5) + + max_audio_len = max(audio_lengths) + B = len(audio_lengths) + padded_audio = torch.zeros((B, max_audio_len), dtype=torch.float32) + + for i, wav in enumerate(audio_list): + padded_audio[i, : len(wav)] = wav + + if is_profile and max_turns_for_user_audio > 0: + padded_user_audio_turns, padded_user_audio_turns_lens = [], [] + + for t in range(max_turns_for_user_audio): + turn_lens = user_audio_lens_by_turn[t] + max_turn_audio_len = max(turn_lens) + padded_turn_audio = torch.zeros((B, max_turn_audio_len), dtype=torch.float32) + + for i, wav in enumerate(user_audio_by_turn[t]): + padded_turn_audio[i, : len(wav)] = wav + + padded_user_audio_turns.append(padded_turn_audio) + padded_user_audio_turns_lens.append(torch.tensor(turn_lens, dtype=torch.long)) + + out_dict["user_audio_turns"] = padded_user_audio_turns + out_dict["user_audio_turns_lens"] = padded_user_audio_turns_lens + + out_dict["context_audio"] = padded_audio + out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) + out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] + out_dict["target_num_frames"] = target_num_frames + out_dict["raw_text"] = [" ".join(s["text"]) if isinstance(s["text"], list) else s["text"] for s in batch] + + return out_dict + + +def build_model_and_codec(args, target_device, target_dtype): + model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) + + with open_dict(model_cfg): + model_cfg.target = "nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel" + model_cfg.codecmodel_path = args.codec_model_path + model_cfg.train_ds = None + model_cfg.validation_ds = None + model_cfg.run_val_inference = False + model_cfg.use_utmos = False + model_cfg.use_meta_init_for_decoder = True + + if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: + model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path + + model = EasyMagpieTTSInferenceModel.restore_from( + args.checkpoint_path, + override_config_path=model_cfg, + map_location=torch.device("cpu"), + ) + model.use_kv_cache_for_inference = True + model.to(dtype=target_dtype) + model.eval().to(target_device) + + model.input_samples_per_frame = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) + model.target_samples_per_frame = model.input_samples_per_frame / (model.sample_rate / model.output_sample_rate) + + codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=torch.device("cpu")) + if hasattr(codec_model, "discriminator"): + del codec_model.discriminator + codec_model.freeze() + codec_model = codec_model.to(target_device).eval() + + codec_converter = None + if getattr(model, "_codec_converter", None) is not None: + vq_new = deepcopy(model._codec_converter.vector_quantizer_new).to(target_device).eval() + codec_converter = VectorQuantizerIndexConverter( + vector_quantizer_original=codec_model.vector_quantizer, + vector_quantizer_new=vq_new, + ).to(target_device).eval() + + model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) + model._generate_codec_silence_buffer() + + return model + + +def prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=None): + B = inputs["context_audio"].size(0) + device = model.device + + inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) + inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) + + if args.user_custom_speaker_reference and speaker_wav is not None: + inputs["context_audio"] = speaker_wav.repeat(B, 1).detach() + inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) + + if "user_audio_turns" in inputs: + inputs["user_audio_turns"] = [x.to(device, dtype=target_dtype) for x in inputs["user_audio_turns"]] + inputs["user_audio_turns_lens"] = [x.to(device) for x in inputs["user_audio_turns_lens"]] + + return inputs + + +def run_generation(model, inputs, args, codec_sil_codes): + B = inputs["context_audio"].size(0) + device = model.device + profile_turn_frame_ranges = [] + profile_decode_start_frame = 0 + + with torch.inference_mode(): + wav = inputs["context_audio"] + wav_len = inputs["context_audio_lengths"] + codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) + + use_lang = bool(getattr(model, "add_language_to_context_text", False)) + ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" + ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) + + ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) + ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) + + state = model.streaming_init( + context_audio_codes=codes, + context_audio_codes_lens=codes_lens, + context_text_tokens=ctx_toks, + context_text_tokens_lens=ctx_toks_lens, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + use_local_transformer=True, + temperature=args.temperature, + topk=args.topk, + phoneme_input_type="pred", + phoneme_sampling_method="argmax", + use_inference_mode=True, + ) + + if inputs["duplex_multiturn"]: + text = inputs["input_ids"].to(device) + text_lens = inputs["input_lengths"].to(device) + + in_initial_silence = torch.ones(B, dtype=torch.bool, device=device) + in_post_speech_silence = torch.zeros(B, dtype=torch.bool, device=device) + + text_exhausted = state.text_tokens_seen >= text_lens + + while not text_exhausted.all(): + state.finished = state.finished & text_exhausted + state.text_finished = state.text_finished & text_exhausted + + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended = state.phoneme_stream_ended & text_exhausted + + positions = state.text_tokens_seen.clamp(max=text.size(1) - 1) + current_tokens = text[torch.arange(B, device=device), positions] + current_tokens = torch.where( + text_exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) + + is_pad_or_eos = (current_tokens == model.pad_id) | (current_tokens == model.eos_id) + in_initial_silence = in_initial_silence & is_pad_or_eos + in_post_speech_silence = in_post_speech_silence & is_pad_or_eos + + state, audio_codes, _ = model.streaming_step( + state=state, + text_tokens=current_tokens, + use_inference_mode=True, + ) + + if audio_codes is not None and args.force_speech_sil_codes: + force_silence_mask = in_initial_silence | in_post_speech_silence + if force_silence_mask.any(): + expanded_sil = codec_sil_codes.view(1, -1, 1).expand_as(audio_codes) + mask_3d = force_silence_mask.view(B, 1, 1) + state.all_predictions[-1] = torch.where(mask_3d, expanded_sil, audio_codes) + + in_post_speech_silence = in_post_speech_silence | state.finished + text_exhausted = state.text_tokens_seen >= text_lens + + elif inputs["regular_multiturn"]: + batched_turns = inputs["batched_turns"] + batched_turn_lens = inputs["batched_turn_lens"] + valid_turn_masks = inputs["valid_turn_masks"] + + turn_offsets = torch.zeros(B, dtype=torch.long, device=device) + + for t in range(len(batched_turns)): + turn_text = batched_turns[t].to(device) + turn_lens = batched_turn_lens[t].to(device) + valid_mask = valid_turn_masks[t].to(device) + + state.finished = state.finished & (~valid_mask) + state.text_finished = state.text_finished & (~valid_mask) + + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended = state.phoneme_stream_ended & (~valid_mask) + + if state.finished.all(): + continue + + turn_offsets = torch.where(valid_mask, state.text_tokens_seen, turn_offsets) + turn_steps = 0 + + while not state.finished.all() and turn_steps < args.max_tts_steps: + turn_steps += 1 + + relative_positions = state.text_tokens_seen - turn_offsets + positions = relative_positions.clamp(min=0, max=turn_text.size(1) - 1) + current_tokens = turn_text[torch.arange(B, device=device), positions] + + exhausted = relative_positions >= turn_lens + current_tokens = torch.where( + exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) + + state, _, _ = model.streaming_step( + state=state, + text_tokens=current_tokens, + use_inference_mode=True, + ) + + elif inputs["profile_multiturn"]: + if B != 1: + raise RuntimeError("--profile_multiturn_inference requires --batch_size=1 per process.") + + batched_turns = inputs["batched_turns"] + batched_turn_lens = inputs["batched_turn_lens"] + valid_turn_masks = inputs["valid_turn_masks"] + + for t in range(len(batched_turns)): + turn_text = batched_turns[t].to(device) + turn_lens = batched_turn_lens[t].to(device) + valid_mask = valid_turn_masks[t].to(device) + + if not bool(valid_mask[0].item()): + continue + + state.finished.zero_() + state.text_finished.zero_() + state.audio_prediction_end_idx.fill_(-1) + + if hasattr(state, "turn_text_tokens_seen"): + state.turn_text_tokens_seen.zero_() + if hasattr(state, "phoneme_steps"): + state.phoneme_steps.zero_() + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended.zero_() + if hasattr(state, "phoneme_eos_detected"): + state.phoneme_eos_detected.zero_() + + state.last_phoneme_tokens = None + + if not model.cfg.get("condition_on_user_speech", False): + if "user_audio_turns" in inputs: + profile_T = int(round(inputs["user_audio_turns"][t].size(-1) / model.input_samples_per_frame)) + profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate + else: + profile_seconds = args.profile_pad_min_sec + torch.rand((), device=device).item() * ( + args.profile_pad_max_sec - args.profile_pad_min_sec + ) + profile_T = max( + 1, + int(round(profile_seconds * model.sample_rate / model.input_samples_per_frame)), + ) + + profile_tokens = torch.full((1, profile_T), model.pad_id, dtype=torch.long, device=device) + user_audio_channel_embedding = None + + else: + if "user_audio_turns" in inputs: + user_audio = inputs["user_audio_turns"][t] + user_audio_lens = inputs["user_audio_turns_lens"][t] + else: + user_audio = inputs["context_audio"] + user_audio_lens = inputs["context_audio_lengths"] + + user_audio_codes, user_audio_codes_lens = model._codec_helper.audio_to_codes( + user_audio, + user_audio_lens, + ) + + if model._codec_converter is not None: + user_audio_codes = model._codec_converter.convert_original_to_new( + audio_tokens=user_audio_codes, + audio_lens=user_audio_codes_lens, + ).long() + + user_audio_codes, user_audio_codes_lens = model.stack_codes( + user_audio_codes, + user_audio_codes_lens, + model.audio_bos_id, + model.audio_eos_id, + model.frame_stacking_factor, + model.num_audio_codebooks, + ) + + user_audio_embedded = model.embed_audio_tokens(user_audio_codes) + + boundary_trim = model.cfg.get("user_audio_boundary_trim", 4) + boundary_trim = 0 if boundary_trim is None else int(boundary_trim) + + if boundary_trim == 0: + real_start = 0 + real_end = int(user_audio_codes_lens[0].item()) + else: + turn_len_with_special = int(user_audio_codes_lens[0].item()) + real_start = 1 + real_end = max(real_start, turn_len_with_special - 1) + + user_audio_embedded = user_audio_embedded[:, real_start:real_end] + copy_len = user_audio_embedded.size(1) + + if boundary_trim > 0: + trim = min(boundary_trim, copy_len // 2) + if trim > 0: + user_audio_embedded[:, :trim] = 0.0 + user_audio_embedded[:, copy_len - trim :] = 0.0 + + bos_user_pad = torch.zeros( + user_audio_embedded.size(0), + 1, + user_audio_embedded.size(2), + device=user_audio_embedded.device, + dtype=user_audio_embedded.dtype, + ) + user_audio_embedded = torch.cat([bos_user_pad, user_audio_embedded], dim=1) + + profile_T = user_audio_embedded.size(1) + profile_tokens = torch.full((B, profile_T), model.pad_id, dtype=torch.long, device=device) + user_audio_channel_embedding = user_audio_embedded + profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate + + delay_tokens = int(state.config.training_mode.streaming_speech_delay) + delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) + + warmup_tokens = turn_text[:, :delay_tokens] + turn_text = turn_text[:, delay_tokens:] + turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) + + if user_audio_channel_embedding is not None and delay_tokens > 0: + warmup_user_audio = user_audio_channel_embedding[:, -delay_tokens:] + user_audio_channel_embedding = user_audio_channel_embedding[:, :-delay_tokens] + profile_tokens = profile_tokens[:, :-delay_tokens] + else: + warmup_user_audio = None + + if profile_tokens.size(1) > 0: + state = model.streaming_prefill_profile( + state=state, + text_tokens=profile_tokens, + use_inference_mode=True, + user_audio_channel_embedding=user_audio_channel_embedding, + ) + + for i in range(delay_tokens): + user_step_emb = warmup_user_audio[:, i] if warmup_user_audio is not None else None + + state.finished.zero_() + state, _, _ = model.streaming_step( + state=state, + text_tokens=warmup_tokens[:, i], + user_audio_channel_embedding=user_step_emb, + prefill_like_step=not bool(model.cfg.get("agent_mask_include_transition_prefix", False)), + prefill_like_is_last_step=(i == delay_tokens - 1), + use_inference_mode=True, + ) + + logging.info(f"[profile_multiturn] turn={t} prefilled {profile_T} steps ({profile_seconds:.2f}s)") + + turn_start_frame = sum(p.size(-1) for p in state.all_predictions) + if t == 0: + state.audio_prediction_start_idx.fill_(turn_start_frame) + profile_decode_start_frame = turn_start_frame + + turn_offset = state.text_tokens_seen.clone() + turn_steps = 0 + saw_audio = False + turn_ended_with_audio_eos = False + + while turn_steps < args.max_tts_steps: + turn_steps += 1 + state.finished.zero_() + + relative_position = state.text_tokens_seen - turn_offset + text_exhausted = relative_position >= turn_lens + + if turn_text.size(1) == 0: + current_tokens = torch.full((B,), model.eos_id, dtype=torch.long, device=device) + else: + position = relative_position.clamp(min=0, max=turn_text.size(1) - 1) + current_tokens = turn_text[torch.arange(B, device=device), position] + current_tokens = torch.where( + text_exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) + + state, audio_codes, _ = model.streaming_step( + state=state, + text_tokens=current_tokens, + use_inference_mode=True, + ) + + if audio_codes is not None and not saw_audio: + saw_audio = True + + if bool(text_exhausted[0].item()) and bool(state.finished[0].item()): + turn_ended_with_audio_eos = True + break + + state.audio_prediction_end_idx.fill_(-1) + state.finished.zero_() + + turn_end_frame = sum(p.size(-1) for p in state.all_predictions) + profile_turn_frame_ranges.append((t, turn_start_frame, turn_end_frame)) + + logging.info( + f"[profile_multiturn] turn={t} steps={turn_steps} " + f"saw_audio={saw_audio} ended_with_audio_eos={turn_ended_with_audio_eos}" + ) + + bos_id = getattr(model, "audio_bos_id", -1) + eos_id = getattr(model, "audio_eos_id", -1) + speaking_id = getattr(model, "audio_user_speaking_id", -1) + speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) + sil_injection = codec_sil_codes.view(1, -1, 1) + + for step_idx in range(len(state.all_predictions)): + pred = state.all_predictions[step_idx] + mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) + frame_mask = mask.any(dim=1, keepdim=True) + if frame_mask.any(): + state.all_predictions[step_idx] = torch.where(frame_mask, sil_injection.expand_as(pred), pred) + + if inputs["duplex_multiturn"] or inputs["profile_multiturn"]: + state.audio_prediction_end_idx.fill_(-1) + + finalize_output = model.streaming_finalize(state, use_inference_mode=True) + + return finalize_output, profile_turn_frame_ranges, profile_decode_start_frame + + +def update_metrics_and_save_audio( + model, + inputs, + finalize_output, + profile_turn_frame_ranges, + profile_decode_start_frame, + intelligibility, + secs_metric, + args, + rank, + run_id: int = 0, +): + device = model.device + B = inputs["context_audio"].size(0) + + with fp32_precision(): + audio_f32 = finalize_output.audio.float() + audio_len = finalize_output.audio_len.int() + + expected_audio_lens = ( + torch.tensor(inputs["target_num_frames"], device=device) * model.target_samples_per_frame + ).int() + + if inputs["duplex_multiturn"]: + text_lens = inputs["input_lengths"].to(device) + audio_len = (text_lens * model.target_samples_per_frame).int() + audio_len = torch.min(audio_len, torch.tensor(audio_f32.size(1), device=device)) + elif inputs["profile_multiturn"]: + audio_len = finalize_output.audio_len.int() + else: + audio_len = torch.min(audio_len, expected_audio_lens) + + metric_audio_pred = resample(audio_f32, model.output_sample_rate, 16000) + metric_audio_pred_lens = (audio_len / model.output_sample_rate * 16000).to(torch.long) + + context_audio = resample(inputs["context_audio"].float(), model.sample_rate, 16000) + context_audio_lens = (inputs["context_audio_lengths"] / model.sample_rate * 16000).to(torch.long) + + metric_audio_pred = torch_rms_norm(metric_audio_pred) + context_audio = torch_rms_norm(context_audio) + + asr_hyps = intelligibility.update( + name="dataset", + refs=inputs["raw_text"], + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + asr_hyps=None, + ) + + secs_metric.update( + name="dataset", + target_audio=context_audio, + target_audio_lens=context_audio_lens, + pred_audio=metric_audio_pred, + pred_audio_lens=metric_audio_pred_lens, + ) + + audio_out_dir = get_audio_out_dir(args) + os.makedirs(audio_out_dir, exist_ok=True) + + audio_f32_cpu = audio_f32.detach().cpu() + audio_len_cpu = audio_len.detach().cpu() + + for i in range(B): + target_path = inputs["target_audio_paths"][i] + base_name = os.path.basename(target_path) + stem, ext = os.path.splitext(base_name) + if not ext: + ext = ".wav" + + dataset_idx = inputs.get("dataset_indices", [-1] * B)[i] + run_prefix = f"run{int(run_id):03d}_" if int(getattr(args, "num_eval_runs", 1)) > 1 else "" + safe_stem = f"{run_prefix}idx{dataset_idx:08d}_{stem}" if dataset_idx >= 0 else f"{run_prefix}rank{rank}_{stem}" + + if inputs["profile_multiturn"]: + full_len = int(audio_len_cpu[i].item()) + full_wav_t = audio_f32_cpu[i, :full_len].float() + + samples_per_prediction_frame = model.codec_model_samples_per_frame / ( + model.sample_rate / model.output_sample_rate + ) + + aligned_agent = torch.zeros_like(full_wav_t) + + for turn_id, start_frame, end_frame in profile_turn_frame_ranges: + rel_start_frame = start_frame - profile_decode_start_frame + rel_end_frame = end_frame - profile_decode_start_frame + + start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) + end_sample = int(round(rel_end_frame * samples_per_prediction_frame)) + + start_sample = max(0, min(start_sample, full_len)) + end_sample = max(start_sample, min(end_sample, full_len)) + + aligned_agent[start_sample:end_sample] = full_wav_t[start_sample:end_sample] + + turn_wav = aligned_agent[start_sample:end_sample].numpy() + out_path = os.path.join(audio_out_dir, f"{safe_stem}_turn_{turn_id}{ext}") + sf.write(out_path, turn_wav, samplerate=model.output_sample_rate) + + out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") + sf.write(out_path, aligned_agent.numpy(), samplerate=model.output_sample_rate) + + if "user_audio_turns" in inputs: + user_segments = [] + + first_user_len_in = int(inputs["user_audio_turns_lens"][0][i].item()) + first_user_delay_out = int(round(first_user_len_in * model.output_sample_rate / model.sample_rate)) + + for turn_id, start_frame, _ in profile_turn_frame_ranges: + if turn_id >= len(inputs["user_audio_turns"]): + continue + + turn_audio = inputs["user_audio_turns"][turn_id][i].detach().cpu().float() + turn_audio_len = int(inputs["user_audio_turns_lens"][turn_id][i].item()) + turn_audio = turn_audio[:turn_audio_len] + + turn_audio_out = resample( + turn_audio.unsqueeze(0), + model.sample_rate, + model.output_sample_rate, + ).squeeze(0) + + if turn_id == 0: + user_start_sample = 0 + else: + prev_turn_end_frame = profile_turn_frame_ranges[turn_id - 1][2] + rel_prev_end_frame = prev_turn_end_frame - profile_decode_start_frame + user_start_sample = first_user_delay_out + int( + round(rel_prev_end_frame * samples_per_prediction_frame) + ) + + user_segments.append((user_start_sample, turn_audio_out)) + + total_user_len = 0 + for s, wav_seg in user_segments: + total_user_len = max(total_user_len, s + wav_seg.numel()) + + user_ch = torch.zeros(total_user_len) + for s, wav_seg in user_segments: + e = s + wav_seg.numel() + user_ch[s:e] += wav_seg + + agent_ch = torch.cat([torch.zeros(first_user_delay_out, dtype=aligned_agent.dtype), aligned_agent]) + + stereo_len = max(user_ch.numel(), agent_ch.numel()) + user_pad = torch.zeros(stereo_len) + agent_pad = torch.zeros(stereo_len) + + user_pad[: user_ch.numel()] = user_ch + agent_pad[: agent_ch.numel()] = agent_ch + + stereo = torch.stack([user_pad, agent_pad], dim=1).numpy() + aligned_path = os.path.join(audio_out_dir, f"{safe_stem}_user_agent_aligned{ext}") + sf.write(aligned_path, stereo, samplerate=model.output_sample_rate) + + else: + wav = audio_f32_cpu[i, : audio_len_cpu[i]].numpy() + out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") + sf.write(out_path, wav, samplerate=model.output_sample_rate) + + return audio_f32.detach(), audio_len.detach(), asr_hyps + + +def compute_filewise_metrics_for_batch( + rank: int, + model, + inputs, + audio_f32: torch.Tensor, + audio_len: torch.Tensor, + asr_hyps: List[str], + run_id: int = 0, +): + filewise_rows = [] + B = audio_f32.size(0) + device = model.device + + for i in range(B): + dataset_idx = int(inputs.get("dataset_indices", [-1] * B)[i]) + target_path = inputs["target_audio_paths"][i] + ref_text = inputs["raw_text"][i] + asr_hyp_text = asr_hyps[i] if asr_hyps is not None and i < len(asr_hyps) else None + + pred_len_i = int(audio_len[i].item()) + pred_audio_i = audio_f32[i : i + 1, :pred_len_i].float() + pred_audio_len_i = torch.tensor([pred_len_i], dtype=torch.long, device=device) + + context_len_i = int(inputs["context_audio_lengths"][i].item()) + context_audio_i = inputs["context_audio"][i : i + 1, :context_len_i].float() + context_audio_len_i = torch.tensor([context_len_i], dtype=torch.long, device=device) + + with fp32_precision(): + pred_16k = resample(pred_audio_i, model.output_sample_rate, 16000) + pred_16k_len = (pred_audio_len_i / model.output_sample_rate * 16000).to(torch.long) + + context_16k = resample(context_audio_i, model.sample_rate, 16000) + context_16k_len = (context_audio_len_i / model.sample_rate * 16000).to(torch.long) + + pred_16k = torch_rms_norm(pred_16k) + context_16k = torch_rms_norm(context_16k) + + one_intelligibility = Intelligibility( + "stt_en_fastconformer_transducer_large", + reuse_asr_hyps=True, + ).reset() + one_intelligibility.update( + name="dataset", + refs=[ref_text], + pred_audio=None, + pred_audio_lens=None, + asr_hyps=[asr_hyp_text or ""], + ) + one_intel_metrics = metric_dict_to_jsonable(one_intelligibility.compute()) + + one_secs = SECS("titanet_large").reset() + one_secs.update( + name="dataset", + target_audio=context_16k, + target_audio_lens=context_16k_len, + pred_audio=pred_16k, + pred_audio_lens=pred_16k_len, + ) + one_secs_metrics = metric_dict_to_jsonable(one_secs.compute()) + + cer = safe_metric_scalar(one_intel_metrics, ["cer", "cer_dataset"]) + wer = safe_metric_scalar(one_intel_metrics, ["wer", "wer_dataset"]) + secs = safe_metric_scalar(one_secs_metrics, ["secs", "secs_dataset"]) + + filewise_rows.append( + { + "run_id": int(run_id), + "rank": int(rank), + "dataset_index": int(dataset_idx), + "target_audio_path": target_path, + "reference_text": ref_text, + "asr_hyp": asr_hyp_text, + "cer": cer, + "wer": wer, + "secs": secs, + "pred_audio_samples": int(pred_len_i), + "pred_audio_seconds": float(pred_len_i / model.output_sample_rate), + "intelligibility": one_intel_metrics, + "secs_metrics": one_secs_metrics, + } + ) + + return filewise_rows + + +def load_speaker_wav_if_needed(args, model, target_dtype): + if args.user_custom_speaker_reference and args.inference_speaker_reference: + return _load_audio( + args.inference_speaker_reference, + model.sample_rate, + normalize=args.normalize_volume, + use_librosa=args.use_librosa, + ).unsqueeze(0).to(model.device, dtype=target_dtype) + + return None + + +def compute_and_save_rank_metrics(args, rank, world_size, num_processed, elapsed, intelligibility, secs_metric): + if num_processed > 0: + with fp32_precision(): + cer_wer = metric_dict_to_jsonable(intelligibility.compute()) + secs_scores = metric_dict_to_jsonable(secs_metric.compute()) + else: + cer_wer = {} + secs_scores = {} + + rank_metrics = { + "rank": int(rank), + "world_size": int(world_size), + "num_processed": int(num_processed), + "elapsed_sec": float(elapsed), + "num_eval_runs": int(getattr(args, "num_eval_runs", 1)), + "intelligibility": cer_wer, + "secs": secs_scores, + } + + rank_path = os.path.join(args.out_dir, f"metrics_rank{rank:04d}.json") + write_json(rank_path, rank_metrics) + + return rank_metrics + + +def merge_metrics_on_rank0(args, rank, world_size): + if rank != 0: + return None + + rank_metric_files = [os.path.join(args.out_dir, f"metrics_rank{r:04d}.json") for r in range(world_size)] + + rank_metrics = [] + for path in rank_metric_files: + if not os.path.exists(path): + logging.warning(f"Missing rank metric file: {path}") + continue + with open(path, "r", encoding="utf-8") as f: + rank_metrics.append(json.load(f)) + + total_n = sum(int(m.get("num_processed", 0)) for m in rank_metrics) + + def weighted_average(section: str): + keys = set() + for m in rank_metrics: + keys.update(m.get(section, {}).keys()) + + out = {} + for k in sorted(keys): + numerator = 0.0 + denominator = 0 + + for m in rank_metrics: + n = int(m.get("num_processed", 0)) + if n <= 0: + continue + + value = m.get(section, {}).get(k, None) + if value is None or isinstance(value, str): + continue + + try: + value = float(value) + except Exception: + continue + + numerator += value * n + denominator += n + + if denominator > 0: + out[k] = numerator / denominator + + return out + + final_metrics = { + "world_size": int(world_size), + "num_processed": int(total_n), + "aggregation": "sum(rank_metric * rank_num_samples) / total_num_samples; repeated DistributedSampler samples included", + "intelligibility": weighted_average("intelligibility"), + "secs": weighted_average("secs"), + "ranks": rank_metrics, + } + + final_json_path = os.path.join(args.out_dir, "metrics_final.json") + final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") + + write_json(final_json_path, final_metrics) + + final_text = format_final_metric_text(final_metrics) + write_text_atomic(final_txt_path, final_text) + + print("\n--- Final Evaluation Metrics ---", flush=True) + print(final_text, flush=True) + + logging.info(f"Final metrics JSON saved to: {final_json_path}") + logging.info(f"Final metrics TXT saved to: {final_txt_path}") + logging.info(json.dumps(final_metrics, indent=2, sort_keys=True)) + + return final_metrics + + +def merge_filewise_metrics_on_rank0(args, rank: int, world_size: int): + if rank != 0 or not args.save_filewise_metrics: + return [] + + all_rows = [] + + for r in range(world_size): + path = os.path.join(args.out_dir, f"filewise_metrics_rank{r:04d}.jsonl") + if not os.path.exists(path): + logging.warning(f"Missing filewise metrics file: {path}") + continue + + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + all_rows.append(json.loads(line)) + + deduped = {} + for row in all_rows: + run_id = int(row.get("run_id", 0)) + idx = int(row.get("dataset_index", -1)) + key = (run_id, idx) + if key not in deduped: + deduped[key] = row + + all_rows = list(deduped.values()) + + all_rows.sort( + key=lambda x: ( + x.get("cer") is not None, + float(x.get("cer")) if x.get("cer") is not None else -1.0, + ), + reverse=True, + ) + + jsonl_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.jsonl") + csv_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.csv") + + write_jsonl(jsonl_path, all_rows) + write_filewise_csv(csv_path, all_rows) + + logging.info(f"Saved sorted filewise metrics JSONL to: {jsonl_path}") + logging.info(f"Saved sorted filewise metrics CSV to: {csv_path}") + + topk = min(int(args.filewise_metrics_topk_log), len(all_rows)) + if topk > 0: + logging.info(f"Top {topk} worst CER samples:") + for row in all_rows[:topk]: + logging.info( + "run_id=%s dataset_index=%s cer=%s wer=%s secs=%s path=%s text=%s" + % ( + row.get("run_id"), + row.get("dataset_index"), + row.get("cer"), + row.get("wer"), + row.get("secs"), + row.get("target_audio_path"), + row.get("reference_text"), + ) + ) + + return all_rows + + +def compute_aggregates_from_filewise_rows(rows: List[Dict[str, Any]]): + """ + Compute final metrics after deduplicating DistributedSampler repeats. + + CER/WER are computed the same way Intelligibility.compute() does: + word_error_rate(all_normalized_hyps, all_normalized_refs, use_cer=True/False) + + SECS is averaged over the deduplicated per-file SECS values. + """ + if len(rows) == 0: + return { + "cer": None, + "wer": None, + "secs": None, + "num_samples": 0, + } + + normalizer = EnglishTextNormalizer() + refs = [] + hyps = [] + secs_vals = [] + + for row in rows: + ref = row.get("reference_text", "") + hyp = row.get("asr_hyp", "") + refs.append(normalizer(ref)) + hyps.append(normalizer(hyp)) + + if row.get("secs") is not None: + secs_vals.append(float(row["secs"])) + + cer = float(word_error_rate(hyps, refs, use_cer=True)) if refs else None + wer = float(word_error_rate(hyps, refs, use_cer=False)) if refs else None + secs = (sum(secs_vals) / len(secs_vals)) if secs_vals else None + + return { + "cer": cer, + "wer": wer, + "secs": secs, + "num_samples": len(rows), + } + + +def compute_run_averaged_aggregates_from_filewise_rows(rows: List[Dict[str, Any]]): + """ + Compute a metric per run, then average those run metrics equally. + + This is useful when --num_eval_runs > 1 and you want the final number to mean: + average(metric(run_0), metric(run_1), ..., metric(run_N-1)) + """ + if len(rows) == 0: + return { + "cer": None, + "wer": None, + "secs": None, + "num_runs": 0, + "num_samples_per_run": {}, + "runs": [], + } + + grouped = {} + for row in rows: + run_id = int(row.get("run_id", 0)) + grouped.setdefault(run_id, []).append(row) + + run_summaries = [] + for run_id in sorted(grouped.keys()): + summary = compute_aggregates_from_filewise_rows(grouped[run_id]) + summary["run_id"] = int(run_id) + run_summaries.append(summary) + + def avg_key(key): + vals = [float(r[key]) for r in run_summaries if r.get(key) is not None] + if not vals: + return None + return sum(vals) / len(vals) + + return { + "cer": avg_key("cer"), + "wer": avg_key("wer"), + "secs": avg_key("secs"), + "num_runs": len(run_summaries), + "num_samples_per_run": {str(r["run_id"]): int(r.get("num_samples", 0)) for r in run_summaries}, + "runs": run_summaries, + } + +def save_filewise_final_summary(args, filewise_rows: List[Dict[str, Any]]): + all_observation_summary = compute_aggregates_from_filewise_rows(filewise_rows) + run_averaged_summary = compute_run_averaged_aggregates_from_filewise_rows(filewise_rows) + + obj = { + "aggregation": ( + "deduplicated_by_(run_id,dataset_index); " + "cer_wer_use_corpus_word_error_rate_matching_Intelligibility_compute; " + "primary_summary_is_mean_over_runs" + ), + "run_averaged": run_averaged_summary, + "all_observations": all_observation_summary, + } + + path = os.path.join(args.out_dir, "metrics_final_filewise_average.json") + write_json(path, obj) + + final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") + final_text = format_filewise_final_metric_text(run_averaged_summary) + write_text_atomic(final_txt_path, final_text) + + print("\n--- Final Filewise Run-Averaged Evaluation Metrics ---", flush=True) + print(final_text, flush=True) + + logging.info(f"Filewise averaged final metrics saved to: {path}") + logging.info(f"Final metrics TXT saved to: {final_txt_path}") + + return obj + + +def parse_args(): + parser = argparse.ArgumentParser(description="EasyMagpieTTS Multi-GPU Inference Evaluation") + + parser.add_argument("--checkpoint_path", type=str, required=True) + parser.add_argument("--codec_model_path", type=str, required=True) + parser.add_argument("--datasets_json_path", type=str, required=True) + parser.add_argument("--out_dir", type=str, required=True) + + parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) + parser.add_argument("--audio_dir", type=str, default=None) + parser.add_argument("--inference_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) + parser.add_argument("--debug_dtype", action="store_true") + parser.add_argument("--debug_gpu_assignment", action="store_true") + parser.add_argument("--use_librosa", action="store_true") + + parser.add_argument("--batch_size", type=int, default=6) + parser.add_argument( + "--num_eval_runs", + type=int, + default=1, + help="Repeat the same evaluation set N times. Final filewise metrics are averaged across runs.", + ) + parser.add_argument( + "--sort_by_text_token_count", + action="store_true", + help="Sort samples by summed text-token count before DistributedSampler sharding for better GPU load balance.", + ) + parser.add_argument( + "--sort_text_token_count_descending", + action="store_true", + help="When sorting by token count, sort longest first. Default is shortest first to make DistributedSampler padding cheap.", + ) + parser.add_argument("--num_workers", type=int, default=4) + parser.add_argument("--num_turns", type=int, default=1) + parser.add_argument("--pad_factor_text_speech", type=int, default=10) + + parser.add_argument("--emulate_duplex_inference", action="store_true") + parser.add_argument("--add_interruption_token", action="store_true") + parser.add_argument("--force_interruption", action="store_true") + parser.add_argument("--profile_multiturn_inference", action="store_true") + parser.add_argument("--profile_pad_min_sec", type=float, default=2.0) + parser.add_argument("--profile_pad_max_sec", type=float, default=2.0) + parser.add_argument("--max_eval_turns", type=int, default=6) + + parser.add_argument("--user_custom_speaker_reference", action="store_true") + parser.add_argument("--inference_speaker_reference", type=str, default=None) + parser.add_argument("--language", type=str, default="en") + + parser.add_argument("--use_cfg", action="store_true") + parser.add_argument("--cfg_scale", type=float, default=2.5) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--topk", type=int, default=80) + parser.add_argument("--max_tts_steps", type=int, default=2000) + parser.add_argument("--force_speech_sil_codes", action="store_true") + parser.add_argument("--normalize_volume", type=lambda x: str(x).lower() in ["true", "1", "yes"], default=True) + + parser.add_argument( + "--save_filewise_metrics", + action="store_true", + help="Save per-file CER/WER/SECS metrics sorted by CER descending.", + ) + parser.add_argument( + "--filewise_metrics_topk_log", + type=int, + default=20, + help="Number of worst CER samples to print on rank 0.", + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + if int(args.num_eval_runs) <= 0: + raise RuntimeError("--num_eval_runs must be >= 1") + if int(args.num_eval_runs) > 1 and not args.save_filewise_metrics: + args.save_filewise_metrics = True + print("[info] --num_eval_runs > 1, enabling --save_filewise_metrics for run-averaged final metrics.", flush=True) + os.makedirs(args.out_dir, exist_ok=True) + os.makedirs(get_audio_out_dir(args), exist_ok=True) + + distributed, rank, local_rank, world_size, device_index = setup_distributed() + + if args.profile_multiturn_inference and args.batch_size != 1: + raise RuntimeError( + "--profile_multiturn_inference requires --batch_size=1 per process. " + "Use multiple GPUs/processes for parallelism instead of increasing batch_size." + ) + + if args.profile_pad_max_sec < args.profile_pad_min_sec: + raise RuntimeError("--profile_pad_max_sec must be >= --profile_pad_min_sec.") + + target_device = torch.device(f"cuda:{device_index}" if torch.cuda.is_available() and device_index >= 0 else "cpu") + target_dtype = getattr(torch, args.inference_dtype) + torch.set_default_dtype(target_dtype) + + hostname = socket.gethostname() + cuda_name = torch.cuda.get_device_name(target_device) if torch.cuda.is_available() and device_index >= 0 else "cpu" + + all_rank_print( + rank, + f"host={hostname} local_rank={local_rank} world_size={world_size} " + f"device={target_device} device_name={cuda_name}", + ) + + model = build_model_and_codec(args, target_device, target_dtype) + codec_sil_codes = model.codec_sil_codes + + if args.debug_dtype: + handles, stats, examples = attach_dtype_counter(model) + else: + handles = stats = examples = None + + with fp32_precision(): + intelligibility = Intelligibility("stt_en_fastconformer_transducer_large", reuse_asr_hyps=False).reset() + secs_metric = SECS("titanet_large").reset() + + eval_dataset = EvalJSONLDataset(args.datasets_json_path, num_turns=args.num_turns) + if args.sort_by_text_token_count: + eval_dataset = SortedByTextTokenCountDataset( + dataset=eval_dataset, + model=model, + max_eval_turns=args.max_eval_turns, + descending=bool(args.sort_text_token_count_descending), + ) + sort_dir = "descending" if args.sort_text_token_count_descending else "ascending" + rank0_print(rank, f"[info] Sorted evaluation samples by summed text-token count {sort_dir}.") + + if distributed: + sampler = DistributedSampler( + eval_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + else: + sampler = SequentialSampler(eval_dataset) + + if args.debug_gpu_assignment: + if distributed: + assigned_sampler_indices = list(iter(sampler)) + assigned_dataset_indices = [ + int(eval_dataset[i].get("__dataset_index__", -1)) + for i in assigned_sampler_indices + ] + repeated_on_rank = len(assigned_dataset_indices) - len(set(assigned_dataset_indices)) + all_rank_print( + rank, + f"assigned {len(assigned_dataset_indices)} / {len(eval_dataset)} samples " + f"to gpu={local_rank}; repeated_on_this_rank={repeated_on_rank}; " + f"dataset_indices={assigned_dataset_indices}; " + f"text_token_counts={[int(eval_dataset[i].get('__text_token_count__', -1)) for i in assigned_sampler_indices]}", + ) + else: + assigned_dataset_indices = [ + int(eval_dataset[i].get("__dataset_index__", -1)) + for i in range(len(eval_dataset)) + ] + all_rank_print( + rank, + f"assigned {len(assigned_dataset_indices)} samples to single process: " + f"dataset_indices={assigned_dataset_indices}; " + f"text_token_counts={[int(eval_dataset[i].get('__text_token_count__', -1)) for i in range(len(eval_dataset))]}", + ) + + collate_fn = partial( + collate_and_tokenize_custom, + model=model, + extra_duration_thrshould=1.5, + sample_rate=model.sample_rate, + root_path=args.audio_dir, + emulate_duplex_inference=args.emulate_duplex_inference, + add_interruption_token=args.add_interruption_token, + pad_factor_text_speech=args.pad_factor_text_speech, + force_interruption=args.force_interruption, + normalize_audio_volume=args.normalize_volume, + use_librosa=args.use_librosa, + profile_multiturn_inference=args.profile_multiturn_inference, + max_eval_turns=args.max_eval_turns, + ) + + dataloader = DataLoader( + dataset=eval_dataset, + batch_size=args.batch_size, + sampler=sampler, + collate_fn=collate_fn, + num_workers=args.num_workers, + pin_memory=True, + drop_last=False, + ) + + speaker_wav = load_speaker_wav_if_needed(args, model, target_dtype) + + if distributed: + dist.barrier() + + start_time = time.time() + num_processed = 0 + rank_filewise_rows = [] + + for run_id in range(int(args.num_eval_runs)): + if distributed and hasattr(sampler, "set_epoch"): + sampler.set_epoch(run_id) + + if args.debug_gpu_assignment: + all_rank_print(rank, f"starting eval run {run_id + 1}/{int(args.num_eval_runs)}") + + for batch_id, inputs in enumerate(dataloader): + batch_indices = inputs.get("dataset_indices", []) + num_processed += len(batch_indices) + + if args.debug_gpu_assignment: + all_rank_print( + rank, + f"run_id={run_id} gpu={local_rank} batch_id={batch_id} " + f"dataset_indices={batch_indices} " + f"text_token_counts={inputs.get('text_token_counts', [])} " + f"target_paths={inputs.get('target_audio_paths', [])}", + ) + + inputs = prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=speaker_wav) + + finalize_output, profile_turn_frame_ranges, profile_decode_start_frame = run_generation( + model=model, + inputs=inputs, + args=args, + codec_sil_codes=codec_sil_codes, + ) + + audio_f32_for_metrics, audio_len_for_metrics, asr_hyps_for_metrics = update_metrics_and_save_audio( + model=model, + inputs=inputs, + finalize_output=finalize_output, + profile_turn_frame_ranges=profile_turn_frame_ranges, + profile_decode_start_frame=profile_decode_start_frame, + intelligibility=intelligibility, + secs_metric=secs_metric, + args=args, + rank=rank, + run_id=run_id, + ) + + if args.save_filewise_metrics: + filewise_rows = compute_filewise_metrics_for_batch( + rank=rank, + model=model, + inputs=inputs, + audio_f32=audio_f32_for_metrics, + audio_len=audio_len_for_metrics, + asr_hyps=asr_hyps_for_metrics, + run_id=run_id, + ) + rank_filewise_rows.extend(filewise_rows) + + if args.debug_dtype and batch_id == 0 and run_id == 0: + report_dtype_stats(handles, stats, examples, rank=rank) + + elapsed = time.time() - start_time + + rank_metrics = compute_and_save_rank_metrics( + args=args, + rank=rank, + world_size=world_size, + num_processed=num_processed, + elapsed=elapsed, + intelligibility=intelligibility, + secs_metric=secs_metric, + ) + + all_rank_print(rank, f"saved rank metrics: {json.dumps(rank_metrics, sort_keys=True)}") + + if args.save_filewise_metrics: + rank_filewise_rows.sort( + key=lambda x: ( + x.get("cer") is not None, + float(x.get("cer")) if x.get("cer") is not None else -1.0, + ), + reverse=True, + ) + + rank_filewise_path = os.path.join(args.out_dir, f"filewise_metrics_rank{rank:04d}.jsonl") + write_jsonl(rank_filewise_path, rank_filewise_rows) + all_rank_print(rank, f"saved filewise metrics: {rank_filewise_path}") + + if distributed: + dist.barrier() + + merge_metrics_on_rank0(args, rank, world_size) + + if args.save_filewise_metrics: + if distributed: + dist.barrier() + + filewise_rows = merge_filewise_metrics_on_rank0(args, rank, world_size) + + if rank == 0: + save_filewise_final_summary(args, filewise_rows) + + cleanup_distributed() + + +if __name__ == "__main__": + main() From 7cea0ef3cd87b961f48a9a20cae58c3665eca8bc Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Wed, 3 Jun 2026 16:25:59 -0700 Subject: [PATCH 069/109] phoneme pad for short turns Signed-off-by: Shehzeen Hussain Signed-off-by: Edresson Casanova --- .../magpietts/easy_magpietts_lhotse_multiturn.yaml | 1 + .../data/text_to_speech_dataset_lhotse_multiturn.py | 10 ++++++++++ nemo/collections/tts/models/easy_magpietts.py | 1 + 3 files changed, 12 insertions(+) diff --git a/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml b/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml index 26c61172de4e..e0571290516b 100644 --- a/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml +++ b/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml @@ -95,6 +95,7 @@ model: phoneme_corruption_type: "repeat_skip_unk" # "repeat_skip_unk" or "complete_channel" phoneme_turn_dropout_batch_prob: 0.0 # prob of applying turn dropout to a sample phoneme_turn_dropout_turn_prob: 0.0 # prob of dropping each phoneme turn within a sample + phoneme_turn_max_words_to_drop: 0 # turns with <= this many words keep phoneme tokens as pad_id phoneme_tokenizer: _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPABPETokenizer diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 6ae5642d0e55..4b4dcbea6e53 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -143,6 +143,7 @@ class MagpieTTSLhotseMultiturnDataset(torch.utils.data.Dataset): are not set externally. Defaults to None. text_context_remapping: Dict defining mapping of multiple text contexts to a single text context. text_context_remapping_prob: Probability of remapping the original text context to a remapped text context. + phoneme_turn_max_words_to_drop: Turns with this many words or fewer keep phoneme tokens as pad_id. """ def __init__( @@ -174,6 +175,7 @@ def __init__( add_text_bos: bool = False, phoneme_turn_dropout_batch_prob: float = 0.0, phoneme_turn_dropout_turn_prob: float = 0.0, + phoneme_turn_max_words_to_drop: int = 2, ): super().__init__() self.sample_rate = sample_rate @@ -207,6 +209,7 @@ def __init__( self.add_text_bos = add_text_bos self.phoneme_turn_dropout_batch_prob = phoneme_turn_dropout_batch_prob self.phoneme_turn_dropout_turn_prob = phoneme_turn_dropout_turn_prob + self.phoneme_turn_max_words_to_drop = phoneme_turn_max_words_to_drop self.frame_length = (self.codec_model_samples_per_frame / codec_model_input_sample_rate) * frame_stacking_factor @@ -320,6 +323,7 @@ def _align_codebooks(t): ignore_phoneme_languages=self.ignore_phoneme_languages, pad_id=self.phoneme_tokenizer.pad, eos_id=self.phoneme_tokenizer.eos_token_id, bos_id=self.phoneme_tokenizer.bos_token_id, phoneme_turn_dropout_batch_prob=self.phoneme_turn_dropout_batch_prob, phoneme_turn_dropout_turn_prob=self.phoneme_turn_dropout_turn_prob, + phoneme_turn_max_words_to_drop=self.phoneme_turn_max_words_to_drop, apply_turn_dropout=self.dataset_type == 'train', ) else: @@ -773,6 +777,7 @@ def collate_phoneme_channel( bos_id: int = -3, phoneme_turn_dropout_batch_prob: float = 0.0, phoneme_turn_dropout_turn_prob: float = 0.0, + phoneme_turn_max_words_to_drop: int = 2, apply_turn_dropout: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: tokens = [] @@ -783,6 +788,7 @@ def collate_phoneme_channel( ignore_phoneme_languages, pad_id, eos_id, bos_id, phoneme_turn_dropout_batch_prob=phoneme_turn_dropout_batch_prob, phoneme_turn_dropout_turn_prob=phoneme_turn_dropout_turn_prob, + phoneme_turn_max_words_to_drop=phoneme_turn_max_words_to_drop, apply_turn_dropout=apply_turn_dropout, ) tokens.append(token) @@ -802,6 +808,7 @@ def build_phoneme_channel( bos_id: int = -3, phoneme_turn_dropout_batch_prob: float = 0.0, phoneme_turn_dropout_turn_prob: float = 0.0, + phoneme_turn_max_words_to_drop: int = 2, apply_turn_dropout: bool = False, ) -> tuple[torch.Tensor, bool]: language = cut.lang if cut.has_custom("lang") else next((sup.language for sup in cut.supervisions if sup.has_custom("language")), "en") @@ -821,6 +828,9 @@ def build_phoneme_channel( if apply_dropout and random.random() < phoneme_turn_dropout_turn_prob: dropout_applied = True continue + + if len(_strip_timestamps(supervision.text).split()) <= phoneme_turn_max_words_to_drop: + continue if isinstance(phoneme_tokenizer, IPABPETokenizer): ipa_text = _get_supervision_ipa_text(supervision) diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 775c61f207b5..43c8ec13bd9b 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -2059,6 +2059,7 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D add_text_bos=self.cfg.get("add_text_bos", False), phoneme_turn_dropout_batch_prob=self.cfg.get("phoneme_turn_dropout_batch_prob", 0.0), phoneme_turn_dropout_turn_prob=self.cfg.get("phoneme_turn_dropout_turn_prob", 0.0), + phoneme_turn_max_words_to_drop=self.cfg.get("phoneme_turn_max_words_to_drop", 2), ) dataset = FallbackDataset(dataset) else: From ed8007ff145e8c852f2e34ddda0ca50850cb2dd5 Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Thu, 4 Jun 2026 11:04:35 -0700 Subject: [PATCH 070/109] ignore punctuation in word counting Signed-off-by: Shehzeen Hussain Signed-off-by: Edresson Casanova --- .../tts/data/text_to_speech_dataset_lhotse_multiturn.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 4b4dcbea6e53..0bdb8c2009a6 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -74,6 +74,11 @@ def _strip_timestamps( return _SPACE_PATTERN.sub(" ", text).strip() +def _count_words_ignoring_punctuation(text: str, _PUNCTUATION_PATTERN=re.compile(r"[^\w\s]+")) -> int: + text = _PUNCTUATION_PATTERN.sub("", _strip_timestamps(text)) + return len(text.split()) + + def _get_supervision_ipa_text(supervision) -> str: """Return IPA for a supervision, preferring top-level field over custom.""" ipa_text = getattr(supervision, "ipa", None) @@ -828,8 +833,8 @@ def build_phoneme_channel( if apply_dropout and random.random() < phoneme_turn_dropout_turn_prob: dropout_applied = True continue - - if len(_strip_timestamps(supervision.text).split()) <= phoneme_turn_max_words_to_drop: + + if _count_words_ignoring_punctuation(supervision.text) <= phoneme_turn_max_words_to_drop: continue if isinstance(phoneme_tokenizer, IPABPETokenizer): From 5d130dcfc025a7913018769cf09aca19698841eb Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Sat, 6 Jun 2026 05:55:35 -0700 Subject: [PATCH 071/109] Add turn based metrics Signed-off-by: Edresson Casanova --- ...nference_multiturn_multigpu_turn_metric.py | 2247 +++++++++++++++++ 1 file changed, 2247 insertions(+) create mode 100644 examples/tts/easy_magpietts_inference_multiturn_multigpu_turn_metric.py diff --git a/examples/tts/easy_magpietts_inference_multiturn_multigpu_turn_metric.py b/examples/tts/easy_magpietts_inference_multiturn_multigpu_turn_metric.py new file mode 100644 index 000000000000..deb2244f2376 --- /dev/null +++ b/examples/tts/easy_magpietts_inference_multiturn_multigpu_turn_metric.py @@ -0,0 +1,2247 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Multi-GPU EasyMagpieTTS / NemotronTTS multiturn inference evaluation. + +Key behavior: + - Uses torchrun env vars RANK, LOCAL_RANK, WORLD_SIZE for sharding/GPU assignment. + - Does NOT initialize torch.distributed. This avoids NeMo ASR doing distributed + collectives during metric computation. + - Generation runs first for all assigned samples. + - ASR and SECS are loaded only after generation is done and the TTS/codec model + has been deleted from GPU memory. + - ASR and SECS are loaded sequentially: ASR first, then released; SECS second. + - For --profile_multiturn_inference, metrics are computed turn-by-turn. + Final filewise outputs are grouped back to one row per original sample, with + lists for asr_hyp/reference_text/cer_turns/wer_turns/secs_turns. + - Uses DistributedSampler with explicit rank/world_size. A few repeated samples + may appear when len(dataset) is not divisible by world_size. Filewise final + metrics deduplicate sampler-padding repeats by (run_id, dataset_index, + turn_id), then group turns into one row per sample with metric lists, while + preserving --num_eval_runs repetitions. + - --sort_by_text_token_count sorts samples by total text-token count before + sharding to improve GPU load balance. + - Saves audio in out_dir/audios/. + - Saves metrics in out_dir/. + +Recommended single-node torchrun: + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + torchrun --standalone --nproc_per_node=8 easy_magpietts_inference_multiturn_multigpu_postgen_metrics.py ... + +Recommended single-node srun wrapper: + srun --nodes=1 --ntasks=1 --ntasks-per-node=1 --container-image=... \ + bash -lc 'torchrun --standalone --nproc_per_node=8 easy_magpietts_inference_multiturn_multigpu_postgen_metrics.py ...' +""" + +import argparse +import csv +import json +import os +import socket +import time +from copy import deepcopy +from functools import partial +from typing import Any, Dict, Iterable, List, Tuple + +import librosa +import soundfile as sf +import torch +from omegaconf import open_dict +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader, Dataset, DistributedSampler, SequentialSampler + +from nemo.collections.audio.parts.utils.transforms import resample +from nemo.collections.asr.metrics.wer import word_error_rate +from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility +from nemo.collections.speechlm2.parts.metrics.secs import SECS +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel +from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter +from nemo.collections.tts.modules.magpietts_modules import CodecHelper +from nemo.collections.tts.parts.utils.tts_dataset_utils import normalize_volume +from nemo.utils import logging +from whisper_normalizer.english import EnglishTextNormalizer + + +torch.set_float32_matmul_precision("medium") +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True + + +# ----------------------------- +# Rank / file helpers +# ----------------------------- + + +def get_rank_info() -> Tuple[bool, int, int, int]: + world_size = int(os.environ.get("WORLD_SIZE", "1")) + rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", "0"))) + local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))) + distributed = world_size > 1 + return distributed, rank, local_rank, world_size + + +def get_visible_device_index(local_rank: int) -> int: + if not torch.cuda.is_available(): + return -1 + ndev = torch.cuda.device_count() + if ndev <= 0: + return -1 + return local_rank % ndev + + +def setup_distributed(): + """ + Do not initialize torch.distributed. + + We only need RANK/LOCAL_RANK/WORLD_SIZE for rank assignment and dataset + sharding. Initializing a process group can cause NeMo ASR to run distributed + collectives during transcribe(), which may hang when ranks have different + audio lengths or workloads. + """ + distributed, rank, local_rank, world_size = get_rank_info() + device_index = get_visible_device_index(local_rank) + + if torch.cuda.is_available() and device_index >= 0: + torch.cuda.set_device(device_index) + + return distributed, rank, local_rank, world_size, device_index + + +def cleanup_distributed(): + return + + +def all_rank_print(rank: int, msg: str): + print(f"[rank={rank}] {msg}", flush=True) + + +def rank0_print(rank: int, msg: str): + if rank == 0: + print(msg, flush=True) + + +def get_audio_out_dir(args) -> str: + return os.path.join(args.out_dir, "audios") + + +def get_generated_turn_audio_dir(args) -> str: + return os.path.join(get_audio_out_dir(args), "metric_turns") + + +def get_context_metric_audio_dir(args) -> str: + return os.path.join(get_audio_out_dir(args), "metric_context") + + +def write_json(path: str, obj: Dict[str, Any]): + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + json.dump(obj, f, indent=2, sort_keys=True, ensure_ascii=False) + os.replace(tmp_path, path) + + +def write_text_atomic(path: str, text: str): + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + f.write(text) + os.replace(tmp_path, path) + + +def write_jsonl(path: str, rows: List[Dict[str, Any]]): + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, sort_keys=True, ensure_ascii=False) + "\n") + os.replace(tmp_path, path) + + +def wait_for_files(paths: List[str], timeout_sec: float = 7200.0, poll_sec: float = 5.0): + start = time.time() + while True: + missing = [p for p in paths if not os.path.exists(p)] + if not missing: + return + if time.time() - start > timeout_sec: + raise TimeoutError("Timed out waiting for files:\n" + "\n".join(missing)) + time.sleep(poll_sec) + + +def wait_for_rank_metric_files(args, world_size: int): + paths = [os.path.join(args.out_dir, f"metrics_rank{r:04d}.json") for r in range(world_size)] + wait_for_files(paths) + + +def wait_for_rank_filewise_metric_files(args, world_size: int): + paths = [os.path.join(args.out_dir, f"filewise_metrics_rank{r:04d}.jsonl") for r in range(world_size)] + wait_for_files(paths) + + +def scalarize_metric_value(v: Any): + if torch.is_tensor(v): + if v.numel() == 1: + return float(v.detach().cpu().item()) + return v.detach().cpu().tolist() + try: + import numpy as np + + if isinstance(v, np.generic): + return float(v.item()) + except Exception: + pass + if isinstance(v, (int, float, str, bool)) or v is None: + return v + return str(v) + + +def metric_dict_to_jsonable(d: Dict[str, Any]) -> Dict[str, Any]: + return {str(k): scalarize_metric_value(v) for k, v in d.items()} + + +def safe_metric_scalar(metric_dict: Dict[str, Any], preferred_keys: List[str]): + for key in preferred_keys: + if key in metric_dict: + value = metric_dict[key] + if torch.is_tensor(value): + return float(value.detach().cpu().item()) + return float(value) + return None + + +def get_first_metric(metrics: Dict[str, Any], names: List[str], default=None): + for name in names: + if name in metrics: + return metrics[name] + return default + + +def format_final_metric_text(final_metrics: Dict[str, Any]) -> str: + intelligibility = final_metrics.get("intelligibility", {}) + secs = final_metrics.get("secs", {}) + + cer = get_first_metric(intelligibility, ["cer", "cer_dataset"]) + wer = get_first_metric(intelligibility, ["wer", "wer_dataset"]) + secs_value = get_first_metric(secs, ["secs", "secs_dataset"]) + + def fmt(x): + if x is None: + return "nan" + try: + return f"{float(x):.10f}" + except Exception: + return str(x) + + return f"Average CER: {fmt(cer)}\nAverage WER: {fmt(wer)}\nSECS: {fmt(secs_value)}\n" + + +def format_filewise_final_metric_text(filewise_summary: Dict[str, Any]) -> str: + def fmt(x): + if x is None: + return "nan" + try: + return f"{float(x):.10f}" + except Exception: + return str(x) + + return ( + f"Average CER: {fmt(filewise_summary.get('cer'))}\n" + f"Average WER: {fmt(filewise_summary.get('wer'))}\n" + f"SECS: {fmt(filewise_summary.get('secs'))}\n" + ) + + +def write_filewise_csv(path: str, rows: List[Dict[str, Any]]): + """Write sample-level filewise metrics. + + Several fields are lists (turn_ids, reference_text, asr_hyp, cer_turns, + etc.), so they are JSON-encoded inside CSV cells. + """ + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + + fieldnames = [ + "run_id", + "dataset_index", + "rank", + "num_turns", + "cer", + "wer", + "secs", + "turn_ids", + "cer_turns", + "wer_turns", + "secs_turns", + "pred_audio_seconds_turns", + "target_audio_path", + "context_audio_path", + "pred_audio_paths", + "reference_text", + "asr_hyp", + ] + + def csv_value(v): + if isinstance(v, (list, dict)): + return json.dumps(v, ensure_ascii=False) + return v + + with open(tmp_path, "w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow({k: csv_value(row.get(k, None)) for k in fieldnames}) + + os.replace(tmp_path, path) + +# ----------------------------- +# Dataset helpers +# ----------------------------- + + +def _combined_audio_name(first_audio_filepath: str, paths: List[str]) -> str: + base_names = [os.path.splitext(os.path.basename(p))[0] for p in paths if p] + ext = os.path.splitext(paths[-1])[1] if paths and paths[-1] else "" + combined_name = "_".join(base_names) + ext + dir_name = os.path.dirname(first_audio_filepath) + return os.path.join(dir_name, combined_name) if dir_name else combined_name + + +class EvalJSONLDataset(Dataset): + def __init__(self, file_path: str, num_turns: int = 1): + self.samples = [] + raw_samples = [] + + with open(file_path, "r", encoding="utf-8") as f: + for line_idx, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + sample = json.loads(line) + sample["__dataset_index__"] = len(raw_samples) + raw_samples.append(sample) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON on line {line_idx}: {e}") + + if num_turns <= 1: + self.samples = raw_samples + return + + single_turn_by_speaker = {} + for sample in raw_samples: + if isinstance(sample["text"], list): + self.samples.append(sample) + else: + speaker = sample.get("speaker", "unknown") + single_turn_by_speaker.setdefault(speaker, []).append(sample) + + synthetic_index = len(raw_samples) + for _, speaker_samples in single_turn_by_speaker.items(): + buffer_texts, buffer_paths = [], [] + first_sample_meta = None + + for sample in speaker_samples: + if not buffer_texts: + first_sample_meta = dict(sample) + + buffer_texts.append(sample["text"]) + buffer_paths.append(sample.get("audio_filepath", "")) + + if len(buffer_texts) == num_turns: + first_sample_meta["text"] = buffer_texts + first_sample_meta["audio_filepath"] = _combined_audio_name( + first_sample_meta.get("audio_filepath", ""), + buffer_paths, + ) + first_sample_meta["__dataset_index__"] = synthetic_index + synthetic_index += 1 + self.samples.append(first_sample_meta) + buffer_texts, buffer_paths, first_sample_meta = [], [], None + + if buffer_texts and first_sample_meta is not None: + first_sample_meta["text"] = buffer_texts + first_sample_meta["audio_filepath"] = _combined_audio_name( + first_sample_meta.get("audio_filepath", ""), + buffer_paths, + ) + first_sample_meta["__dataset_index__"] = synthetic_index + synthetic_index += 1 + self.samples.append(first_sample_meta) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.samples[idx] + + +def _sample_text_segments_for_count(sample: Dict[str, Any], max_eval_turns=None) -> List[str]: + text_data = sample.get("text", "") + if isinstance(text_data, list): + segments = text_data + if max_eval_turns is not None: + segments = segments[: int(max_eval_turns)] + return [str(x) for x in segments] + return [str(text_data)] + + +def estimate_text_token_count(sample: Dict[str, Any], model, max_eval_turns=None) -> int: + main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] + total = 0 + for segment in _sample_text_segments_for_count(sample, max_eval_turns=max_eval_turns): + total += len(model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name)) + 1 + return int(total) + + +class SortedByTextTokenCountDataset(Dataset): + def __init__(self, dataset: Dataset, model, max_eval_turns=None, descending: bool = True): + self.dataset = dataset + scored = [] + for i in range(len(dataset)): + sample = dict(dataset[i]) + token_count = estimate_text_token_count(sample, model=model, max_eval_turns=max_eval_turns) + sample["__text_token_count__"] = int(token_count) + scored.append((token_count, i, sample)) + + scored.sort(key=lambda x: (x[0], -x[1]), reverse=bool(descending)) + self.indices = [i for _, i, _ in scored] + self.token_counts = {i: int(tok) for tok, i, _ in scored} + + def __len__(self): + return len(self.indices) + + def __getitem__(self, local_idx): + original_idx = self.indices[local_idx] + sample = dict(self.dataset[original_idx]) + sample["__text_token_count__"] = self.token_counts[original_idx] + return sample + + +# ----------------------------- +# Audio / collate helpers +# ----------------------------- + + +def _resolve_audio_path(path, root_path): + if path is None: + return None + if root_path is not None and not os.path.isabs(path): + return os.path.join(root_path, path) + return path + + +def _load_audio(path, sample_rate, normalize=True, use_librosa=False): + if path is None or not os.path.exists(path): + return torch.zeros(1, dtype=torch.float32) + + if use_librosa: + wav, sr = librosa.load(path, sr=sample_rate, mono=True) + if normalize: + wav = normalize_volume(wav) + return torch.as_tensor(wav, dtype=torch.float32) + + wav, sr = sf.read(path, dtype="float32") + if wav.ndim > 1: + wav = wav.mean(axis=1) + + if normalize: + wav = normalize_volume(wav) + + wav = torch.as_tensor(wav, dtype=torch.float32).unsqueeze(0) + return resample(wav, sr, sample_rate).squeeze(0) + + +def collate_and_tokenize_custom( + batch, + model, + extra_duration_thrshould=1.3, + sample_rate=22050, + root_path=None, + emulate_duplex_inference=False, + add_interruption_token=False, + pad_factor_text_speech=10, + force_interruption=False, + normalize_audio_volume=True, + use_librosa=False, + profile_multiturn_inference=False, + max_eval_turns=None, +): + main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] + + if max_eval_turns is not None: + max_eval_turns = int(max_eval_turns) + if max_eval_turns <= 0: + raise ValueError("--max_eval_turns must be > 0 when provided.") + + truncated_batch = [] + for s in batch: + s = dict(s) + if isinstance(s["text"], list): + s["text"] = s["text"][:max_eval_turns] + if isinstance(s.get("user_audio_file_path"), list): + s["user_audio_file_path"] = s["user_audio_file_path"][:max_eval_turns] + truncated_batch.append(s) + batch = truncated_batch + + is_profile = profile_multiturn_inference + is_duplex = emulate_duplex_inference and not is_profile + + out_dict = { + "duplex_multiturn": is_duplex, + "regular_multiturn": (not is_duplex) and (not is_profile), + "profile_multiturn": is_profile, + "dataset_indices": [int(s.get("__dataset_index__", -1)) for s in batch], + "text_token_counts": [int(s.get("__text_token_count__", -1)) for s in batch], + } + + tokenized_list = [] + batched_turns = [] + batched_turn_lens = [] + valid_turn_masks = [] + + if is_duplex: + for s in batch: + text_data = s["text"] + + if isinstance(text_data, list): + full_ids = [] + for segment in text_data: + seg_ids = model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name) + [model.eos_id] + pad_ids = [model.pad_id] * (len(seg_ids) * pad_factor_text_speech) + + if force_interruption: + fname = s["audio_filepath"] + no_ext = fname.split(".")[0] + sample_id = int(no_ext.split("_")[-1]) + case = sample_id % 3 + + if case == 0: + if len(seg_ids) >= 2: + seg_ids[-2] = model.interruption_token_id + seg_ids[-1] = model.pad_id + else: + pad_ids[0] = model.interruption_token_id + elif case == 1: + eos_idx = min(6, len(pad_ids) - 1) + pad_ids[eos_idx] = model.interruption_token_id + else: + pad_ids[0] = model.interruption_token_id + + elif add_interruption_token: + eos_idx = int(len(pad_ids) * 0.7) + pad_ids[eos_idx] = model.interruption_token_id + + full_ids.extend(seg_ids) + full_ids.extend(pad_ids) + + tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long)) + else: + tokenized_list.append( + torch.as_tensor( + model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id], + dtype=torch.long, + ) + ) + + prefix = torch.full((25,), model.pad_id, dtype=torch.long) + tokenized_list = [torch.cat([prefix, x]) for x in tokenized_list] + out_dict["input_lengths"] = torch.tensor([len(x) for x in tokenized_list], dtype=torch.long) + out_dict["input_ids"] = pad_sequence(tokenized_list, batch_first=True, padding_value=model.pad_id) + + else: + max_turns = 1 + for s in batch: + if isinstance(s["text"], list): + max_turns = max(max_turns, len(s["text"])) + + for t in range(max_turns): + turn_t_tokens, turn_t_lens, turn_t_valid = [], [], [] + + for s in batch: + text_data = s["text"] + + if isinstance(text_data, list): + if t < len(text_data): + seg_ids = model.tokenizer.encode(text_data[t], tokenizer_name=main_tokenizer_name) + [ + model.eos_id + ] + turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) + turn_t_lens.append(len(seg_ids)) + turn_t_valid.append(True) + else: + turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) + turn_t_lens.append(1) + turn_t_valid.append(False) + else: + if t == 0: + seg_ids = model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [ + model.eos_id + ] + turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) + turn_t_lens.append(len(seg_ids)) + turn_t_valid.append(True) + else: + turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) + turn_t_lens.append(1) + turn_t_valid.append(False) + + batched_turns.append(pad_sequence(turn_t_tokens, batch_first=True, padding_value=model.pad_id)) + batched_turn_lens.append(torch.tensor(turn_t_lens, dtype=torch.long)) + valid_turn_masks.append(torch.tensor(turn_t_valid, dtype=torch.bool)) + + out_dict["batched_turns"] = batched_turns + out_dict["batched_turn_lens"] = batched_turn_lens + out_dict["valid_turn_masks"] = valid_turn_masks + + audio_list, audio_lengths, target_num_frames = [], [], [] + context_audio_paths = [] + max_turns_for_user_audio = len(batched_turns) if not is_duplex else 0 + + if is_profile and max_turns_for_user_audio > 0: + user_audio_by_turn = [[] for _ in range(max_turns_for_user_audio)] + user_audio_lens_by_turn = [[] for _ in range(max_turns_for_user_audio)] + else: + user_audio_by_turn, user_audio_lens_by_turn = [], [] + + for i, s in enumerate(batch): + audio_path = _resolve_audio_path(s.get("context_audio_filepath"), root_path) + context_audio_paths.append(audio_path) + wav = _load_audio(audio_path, sample_rate, normalize=normalize_audio_volume, use_librosa=use_librosa) + audio_list.append(wav) + audio_lengths.append(len(wav)) + + if is_profile and max_turns_for_user_audio > 0: + user_audio_paths = s.get("user_audio_file_path", None) + + for t in range(max_turns_for_user_audio): + has_valid_text_turn = (isinstance(s["text"], list) and t < len(s["text"])) or ( + not isinstance(s["text"], list) and t == 0 + ) + + if ( + isinstance(user_audio_paths, list) + and t < len(user_audio_paths) + and user_audio_paths[t] + and has_valid_text_turn + ): + ua_path = _resolve_audio_path(user_audio_paths[t], root_path) + ua_wav = _load_audio( + ua_path, + sample_rate=sample_rate, + normalize=normalize_audio_volume, + use_librosa=use_librosa, + ) + else: + ua_wav = torch.zeros(int(2 * sample_rate), dtype=torch.float32) + + user_audio_by_turn[t].append(ua_wav) + user_audio_lens_by_turn[t].append(len(ua_wav)) + + tdur_audio_path = _resolve_audio_path(s["audio_filepath"], root_path) + + if tdur_audio_path and os.path.exists(tdur_audio_path): + wav_dur = _load_audio( + tdur_audio_path, + sample_rate, + normalize=normalize_audio_volume, + use_librosa=use_librosa, + ) + tdur = wav_dur.shape[0] // model.input_samples_per_frame + target_num_frames.append(tdur * extra_duration_thrshould) + else: + if is_duplex: + current_text_len = len(tokenized_list[i]) + target_num_frames.append(current_text_len if isinstance(s["text"], list) else current_text_len * 5) + else: + target_num_frames.append(sum([l[i].item() for l in batched_turn_lens]) * 5) + + max_audio_len = max(audio_lengths) + B = len(audio_lengths) + padded_audio = torch.zeros((B, max_audio_len), dtype=torch.float32) + + for i, wav in enumerate(audio_list): + padded_audio[i, : len(wav)] = wav + + if is_profile and max_turns_for_user_audio > 0: + padded_user_audio_turns, padded_user_audio_turns_lens = [], [] + + for t in range(max_turns_for_user_audio): + turn_lens = user_audio_lens_by_turn[t] + max_turn_audio_len = max(turn_lens) + padded_turn_audio = torch.zeros((B, max_turn_audio_len), dtype=torch.float32) + + for i, wav in enumerate(user_audio_by_turn[t]): + padded_turn_audio[i, : len(wav)] = wav + + padded_user_audio_turns.append(padded_turn_audio) + padded_user_audio_turns_lens.append(torch.tensor(turn_lens, dtype=torch.long)) + + out_dict["user_audio_turns"] = padded_user_audio_turns + out_dict["user_audio_turns_lens"] = padded_user_audio_turns_lens + + raw_turn_texts = [] + for s in batch: + if isinstance(s["text"], list): + raw_turn_texts.append([str(x) for x in s["text"]]) + else: + raw_turn_texts.append([str(s["text"])]) + + out_dict["context_audio"] = padded_audio + out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) + out_dict["context_audio_paths"] = context_audio_paths + out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] + out_dict["target_num_frames"] = target_num_frames + out_dict["raw_turn_texts"] = raw_turn_texts + out_dict["raw_text"] = [" ".join(x) for x in raw_turn_texts] + + return out_dict + + +# ----------------------------- +# Model / generation +# ----------------------------- + + +def attach_dtype_counter(model): + handles = [] + stats = {} + examples = {} + + def is_leaf(module): + return len(list(module.children())) == 0 + + def get_dtype(x): + if torch.is_tensor(x): + return str(x.dtype) + if isinstance(x, (list, tuple)): + for t in x: + if torch.is_tensor(t): + return str(t.dtype) + return "other" + + def get_module_group(name): + return name.split(".")[0] if "." in name else name + + def hook_fn(name): + def fn(module, inputs, outputs): + dtype = get_dtype(outputs) + if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: + dtype = "other" + group = get_module_group(name) + if group not in stats: + stats[group] = { + "torch.float16": 0, + "torch.bfloat16": 0, + "torch.float32": 0, + "other": 0, + } + examples[group] = { + "torch.float16": [], + "torch.bfloat16": [], + "torch.float32": [], + "other": [], + } + stats[group][dtype] += 1 + if len(examples[group][dtype]) < 3: + examples[group][dtype].append(module.__class__.__name__) + + return fn + + for name, module in model.named_modules(): + if is_leaf(module): + handles.append(module.register_forward_hook(hook_fn(name))) + return handles, stats, examples + + +def report_dtype_stats(handles, stats, examples, rank=0): + for h in handles: + h.remove() + logging.info(f"[rank={rank}] === DTYPE USAGE PER MODULE ===") + for group, group_stats in stats.items(): + total = sum(group_stats.values()) + if total == 0: + continue + logging.info(f"[rank={rank}] --- {group} ---") + for dtype, count in group_stats.items(): + if count > 0: + logging.info(f"[rank={rank}] {dtype}: {count} ({100 * count / total:.2f}%)") + logging.info(f"[rank={rank}] === DTYPE EXAMPLES ===") + for group, group_examples in examples.items(): + for dtype, mods in group_examples.items(): + if mods: + logging.info(f"[rank={rank}] {group} {dtype}: {mods}") + + +def build_model_and_codec(args, target_device, target_dtype): + model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) + + with open_dict(model_cfg): + model_cfg.target = "nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel" + model_cfg.codecmodel_path = args.codec_model_path + model_cfg.train_ds = None + model_cfg.validation_ds = None + model_cfg.run_val_inference = False + model_cfg.use_utmos = False + model_cfg.use_meta_init_for_decoder = True + + if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: + model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path + + model = EasyMagpieTTSInferenceModel.restore_from( + args.checkpoint_path, + override_config_path=model_cfg, + map_location=torch.device("cpu"), + ) + model.use_kv_cache_for_inference = True + model.to(dtype=target_dtype) + model.eval().to(target_device) + + model.input_samples_per_frame = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) + model.target_samples_per_frame = model.input_samples_per_frame / (model.sample_rate / model.output_sample_rate) + + codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=torch.device("cpu")) + if hasattr(codec_model, "discriminator"): + del codec_model.discriminator + codec_model.freeze() + codec_model = codec_model.to(target_device).eval() + + codec_converter = None + if getattr(model, "_codec_converter", None) is not None: + vq_new = deepcopy(model._codec_converter.vector_quantizer_new).to(target_device).eval() + codec_converter = VectorQuantizerIndexConverter( + vector_quantizer_original=codec_model.vector_quantizer, + vector_quantizer_new=vq_new, + ).to(target_device).eval() + + model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) + model._generate_codec_silence_buffer() + + return model + + +def prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=None): + B = inputs["context_audio"].size(0) + device = model.device + + inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) + inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) + + if args.user_custom_speaker_reference and speaker_wav is not None: + inputs["context_audio"] = speaker_wav.repeat(B, 1).detach() + inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) + + if "user_audio_turns" in inputs: + inputs["user_audio_turns"] = [x.to(device, dtype=target_dtype) for x in inputs["user_audio_turns"]] + inputs["user_audio_turns_lens"] = [x.to(device) for x in inputs["user_audio_turns_lens"]] + + return inputs + + +def run_generation(model, inputs, args, codec_sil_codes): + B = inputs["context_audio"].size(0) + device = model.device + profile_turn_frame_ranges = [] + profile_decode_start_frame = 0 + + with torch.inference_mode(): + wav = inputs["context_audio"] + wav_len = inputs["context_audio_lengths"] + codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) + + use_lang = bool(getattr(model, "add_language_to_context_text", False)) + ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" + ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) + + ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) + ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) + + state = model.streaming_init( + context_audio_codes=codes, + context_audio_codes_lens=codes_lens, + context_text_tokens=ctx_toks, + context_text_tokens_lens=ctx_toks_lens, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + use_local_transformer=True, + temperature=args.temperature, + topk=args.topk, + phoneme_input_type="pred", + phoneme_sampling_method="argmax", + use_inference_mode=True, + ) + + if inputs["duplex_multiturn"]: + text = inputs["input_ids"].to(device) + text_lens = inputs["input_lengths"].to(device) + + in_initial_silence = torch.ones(B, dtype=torch.bool, device=device) + in_post_speech_silence = torch.zeros(B, dtype=torch.bool, device=device) + text_exhausted = state.text_tokens_seen >= text_lens + + while not text_exhausted.all(): + state.finished = state.finished & text_exhausted + state.text_finished = state.text_finished & text_exhausted + + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended = state.phoneme_stream_ended & text_exhausted + + positions = state.text_tokens_seen.clamp(max=text.size(1) - 1) + current_tokens = text[torch.arange(B, device=device), positions] + current_tokens = torch.where( + text_exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) + + is_pad_or_eos = (current_tokens == model.pad_id) | (current_tokens == model.eos_id) + in_initial_silence = in_initial_silence & is_pad_or_eos + in_post_speech_silence = in_post_speech_silence & is_pad_or_eos + + state, audio_codes, _ = model.streaming_step( + state=state, + text_tokens=current_tokens, + use_inference_mode=True, + ) + + if audio_codes is not None and args.force_speech_sil_codes: + force_silence_mask = in_initial_silence | in_post_speech_silence + if force_silence_mask.any(): + expanded_sil = codec_sil_codes.view(1, -1, 1).expand_as(audio_codes) + mask_3d = force_silence_mask.view(B, 1, 1) + state.all_predictions[-1] = torch.where(mask_3d, expanded_sil, audio_codes) + + in_post_speech_silence = in_post_speech_silence | state.finished + text_exhausted = state.text_tokens_seen >= text_lens + + elif inputs["regular_multiturn"]: + batched_turns = inputs["batched_turns"] + batched_turn_lens = inputs["batched_turn_lens"] + valid_turn_masks = inputs["valid_turn_masks"] + turn_offsets = torch.zeros(B, dtype=torch.long, device=device) + + for t in range(len(batched_turns)): + turn_text = batched_turns[t].to(device) + turn_lens = batched_turn_lens[t].to(device) + valid_mask = valid_turn_masks[t].to(device) + + state.finished = state.finished & (~valid_mask) + state.text_finished = state.text_finished & (~valid_mask) + + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended = state.phoneme_stream_ended & (~valid_mask) + + if state.finished.all(): + continue + + turn_offsets = torch.where(valid_mask, state.text_tokens_seen, turn_offsets) + turn_steps = 0 + + while not state.finished.all() and turn_steps < args.max_tts_steps: + turn_steps += 1 + relative_positions = state.text_tokens_seen - turn_offsets + positions = relative_positions.clamp(min=0, max=turn_text.size(1) - 1) + current_tokens = turn_text[torch.arange(B, device=device), positions] + + exhausted = relative_positions >= turn_lens + current_tokens = torch.where( + exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) + + state, _, _ = model.streaming_step( + state=state, + text_tokens=current_tokens, + use_inference_mode=True, + ) + + elif inputs["profile_multiturn"]: + if B != 1: + raise RuntimeError("--profile_multiturn_inference requires --batch_size=1 per process.") + + batched_turns = inputs["batched_turns"] + batched_turn_lens = inputs["batched_turn_lens"] + valid_turn_masks = inputs["valid_turn_masks"] + + for t in range(len(batched_turns)): + turn_text = batched_turns[t].to(device) + turn_lens = batched_turn_lens[t].to(device) + valid_mask = valid_turn_masks[t].to(device) + + if not bool(valid_mask[0].item()): + continue + + state.finished.zero_() + state.text_finished.zero_() + state.audio_prediction_end_idx.fill_(-1) + + if hasattr(state, "turn_text_tokens_seen"): + state.turn_text_tokens_seen.zero_() + if hasattr(state, "phoneme_steps"): + state.phoneme_steps.zero_() + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended.zero_() + if hasattr(state, "phoneme_eos_detected"): + state.phoneme_eos_detected.zero_() + state.last_phoneme_tokens = None + + if not model.cfg.get("condition_on_user_speech", False): + if "user_audio_turns" in inputs: + profile_T = int(round(inputs["user_audio_turns"][t].size(-1) / model.input_samples_per_frame)) + profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate + else: + profile_seconds = args.profile_pad_min_sec + torch.rand((), device=device).item() * ( + args.profile_pad_max_sec - args.profile_pad_min_sec + ) + profile_T = max( + 1, + int(round(profile_seconds * model.sample_rate / model.input_samples_per_frame)), + ) + + profile_tokens = torch.full((1, profile_T), model.pad_id, dtype=torch.long, device=device) + user_audio_channel_embedding = None + + else: + if "user_audio_turns" in inputs: + user_audio = inputs["user_audio_turns"][t] + user_audio_lens = inputs["user_audio_turns_lens"][t] + else: + user_audio = inputs["context_audio"] + user_audio_lens = inputs["context_audio_lengths"] + + user_audio_codes, user_audio_codes_lens = model._codec_helper.audio_to_codes( + user_audio, + user_audio_lens, + ) + + if model._codec_converter is not None: + user_audio_codes = model._codec_converter.convert_original_to_new( + audio_tokens=user_audio_codes, + audio_lens=user_audio_codes_lens, + ).long() + + user_audio_codes, user_audio_codes_lens = model.stack_codes( + user_audio_codes, + user_audio_codes_lens, + model.audio_bos_id, + model.audio_eos_id, + model.frame_stacking_factor, + model.num_audio_codebooks, + ) + + user_audio_embedded = model.embed_audio_tokens(user_audio_codes) + + boundary_trim = model.cfg.get("user_audio_boundary_trim", 4) + boundary_trim = 0 if boundary_trim is None else int(boundary_trim) + + if boundary_trim == 0: + real_start = 0 + real_end = int(user_audio_codes_lens[0].item()) + else: + turn_len_with_special = int(user_audio_codes_lens[0].item()) + real_start = 1 + real_end = max(real_start, turn_len_with_special - 1) + + user_audio_embedded = user_audio_embedded[:, real_start:real_end] + copy_len = user_audio_embedded.size(1) + + if boundary_trim > 0: + trim = min(boundary_trim, copy_len // 2) + if trim > 0: + user_audio_embedded[:, :trim] = 0.0 + user_audio_embedded[:, copy_len - trim :] = 0.0 + + bos_user_pad = torch.zeros( + user_audio_embedded.size(0), + 1, + user_audio_embedded.size(2), + device=user_audio_embedded.device, + dtype=user_audio_embedded.dtype, + ) + user_audio_embedded = torch.cat([bos_user_pad, user_audio_embedded], dim=1) + + profile_T = user_audio_embedded.size(1) + profile_tokens = torch.full((B, profile_T), model.pad_id, dtype=torch.long, device=device) + user_audio_channel_embedding = user_audio_embedded + profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate + + delay_tokens = int(state.config.training_mode.streaming_speech_delay) + delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) + + warmup_tokens = turn_text[:, :delay_tokens] + turn_text = turn_text[:, delay_tokens:] + turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) + + if user_audio_channel_embedding is not None and delay_tokens > 0: + warmup_user_audio = user_audio_channel_embedding[:, -delay_tokens:] + user_audio_channel_embedding = user_audio_channel_embedding[:, :-delay_tokens] + profile_tokens = profile_tokens[:, :-delay_tokens] + else: + warmup_user_audio = None + + if profile_tokens.size(1) > 0: + state = model.streaming_prefill_profile( + state=state, + text_tokens=profile_tokens, + use_inference_mode=True, + user_audio_channel_embedding=user_audio_channel_embedding, + ) + + for i in range(delay_tokens): + user_step_emb = warmup_user_audio[:, i] if warmup_user_audio is not None else None + + state.finished.zero_() + state, _, _ = model.streaming_step( + state=state, + text_tokens=warmup_tokens[:, i], + user_audio_channel_embedding=user_step_emb, + prefill_like_step=not bool(model.cfg.get("agent_mask_include_transition_prefix", False)), + prefill_like_is_last_step=(i == delay_tokens - 1), + use_inference_mode=True, + ) + + logging.info(f"[profile_multiturn] turn={t} prefilled {profile_T} steps ({profile_seconds:.2f}s)") + + turn_start_frame = sum(p.size(-1) for p in state.all_predictions) + if t == 0: + state.audio_prediction_start_idx.fill_(turn_start_frame) + profile_decode_start_frame = turn_start_frame + + turn_offset = state.text_tokens_seen.clone() + turn_steps = 0 + saw_audio = False + turn_ended_with_audio_eos = False + + while turn_steps < args.max_tts_steps: + turn_steps += 1 + state.finished.zero_() + + relative_position = state.text_tokens_seen - turn_offset + text_exhausted = relative_position >= turn_lens + + if turn_text.size(1) == 0: + current_tokens = torch.full((B,), model.eos_id, dtype=torch.long, device=device) + else: + position = relative_position.clamp(min=0, max=turn_text.size(1) - 1) + current_tokens = turn_text[torch.arange(B, device=device), position] + current_tokens = torch.where( + text_exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) + + state, audio_codes, _ = model.streaming_step( + state=state, + text_tokens=current_tokens, + use_inference_mode=True, + ) + + if audio_codes is not None and not saw_audio: + saw_audio = True + + if bool(text_exhausted[0].item()) and bool(state.finished[0].item()): + turn_ended_with_audio_eos = True + break + + state.audio_prediction_end_idx.fill_(-1) + state.finished.zero_() + + turn_end_frame = sum(p.size(-1) for p in state.all_predictions) + profile_turn_frame_ranges.append((t, turn_start_frame, turn_end_frame)) + + logging.info( + f"[profile_multiturn] turn={t} steps={turn_steps} " + f"saw_audio={saw_audio} ended_with_audio_eos={turn_ended_with_audio_eos}" + ) + + bos_id = getattr(model, "audio_bos_id", -1) + eos_id = getattr(model, "audio_eos_id", -1) + speaking_id = getattr(model, "audio_user_speaking_id", -1) + speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) + sil_injection = codec_sil_codes.view(1, -1, 1) + + for step_idx in range(len(state.all_predictions)): + pred = state.all_predictions[step_idx] + mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) + frame_mask = mask.any(dim=1, keepdim=True) + if frame_mask.any(): + state.all_predictions[step_idx] = torch.where(frame_mask, sil_injection.expand_as(pred), pred) + + if inputs["duplex_multiturn"] or inputs["profile_multiturn"]: + state.audio_prediction_end_idx.fill_(-1) + + finalize_output = model.streaming_finalize(state, use_inference_mode=True) + + return finalize_output, profile_turn_frame_ranges, profile_decode_start_frame + + +def load_speaker_wav_if_needed(args, model, target_dtype): + if args.user_custom_speaker_reference and args.inference_speaker_reference: + return _load_audio( + args.inference_speaker_reference, + model.sample_rate, + normalize=args.normalize_volume, + use_librosa=args.use_librosa, + ).unsqueeze(0).to(model.device, dtype=target_dtype) + + return None + + +# ----------------------------- +# Save generation outputs and metric manifests +# ----------------------------- + + +def write_audio_1d(path: str, wav: torch.Tensor, sr: int): + os.makedirs(os.path.dirname(path), exist_ok=True) + wav_np = wav.detach().cpu().float().numpy() + sf.write(path, wav_np, samplerate=sr) + + +def build_metric_item( + run_id: int, + rank: int, + dataset_index: int, + turn_id: int, + target_audio_path: str, + reference_text: str, + pred_audio_path: str, + context_audio_path: str, + pred_audio_samples: int, + context_audio_samples: int, + output_sample_rate: int, + context_sample_rate: int, +): + return { + "run_id": int(run_id), + "rank": int(rank), + "dataset_index": int(dataset_index), + "turn_id": int(turn_id), + "target_audio_path": target_audio_path, + "reference_text": reference_text, + "pred_audio_path": pred_audio_path, + "context_audio_path": context_audio_path, + "pred_audio_samples": int(pred_audio_samples), + "context_audio_samples": int(context_audio_samples), + "pred_audio_seconds": float(pred_audio_samples / output_sample_rate), + "context_audio_seconds": float(context_audio_samples / context_sample_rate), + "output_sample_rate": int(output_sample_rate), + "context_sample_rate": int(context_sample_rate), + } + + +def save_generation_outputs_and_build_metric_items( + model, + inputs, + finalize_output, + profile_turn_frame_ranges, + profile_decode_start_frame, + args, + rank: int, + run_id: int, +): + device = model.device + B = inputs["context_audio"].size(0) + + with fp32_precision(): + audio_f32 = finalize_output.audio.float() + audio_len = finalize_output.audio_len.int() + + expected_audio_lens = ( + torch.tensor(inputs["target_num_frames"], device=device) * model.target_samples_per_frame + ).int() + + if inputs["duplex_multiturn"]: + text_lens = inputs["input_lengths"].to(device) + audio_len = (text_lens * model.target_samples_per_frame).int() + audio_len = torch.min(audio_len, torch.tensor(audio_f32.size(1), device=device)) + elif inputs["profile_multiturn"]: + audio_len = finalize_output.audio_len.int() + else: + audio_len = torch.min(audio_len, expected_audio_lens) + + audio_out_dir = get_audio_out_dir(args) + metric_turn_dir = get_generated_turn_audio_dir(args) + metric_context_dir = get_context_metric_audio_dir(args) + os.makedirs(audio_out_dir, exist_ok=True) + os.makedirs(metric_turn_dir, exist_ok=True) + os.makedirs(metric_context_dir, exist_ok=True) + + audio_f32_cpu = audio_f32.detach().cpu() + audio_len_cpu = audio_len.detach().cpu() + metric_items = [] + + for i in range(B): + target_path = inputs["target_audio_paths"][i] + base_name = os.path.basename(target_path) + stem, ext = os.path.splitext(base_name) + if not ext: + ext = ".wav" + + dataset_idx = int(inputs.get("dataset_indices", [-1] * B)[i]) + safe_stem = ( + f"run{run_id:02d}_idx{dataset_idx:08d}_{stem}" + if dataset_idx >= 0 + else f"run{run_id:02d}_rank{rank}_{stem}" + ) + + context_len = int(inputs["context_audio_lengths"][i].detach().cpu().item()) + context_wav = inputs["context_audio"][i, :context_len].detach().cpu().float() + context_metric_path = os.path.join(metric_context_dir, f"{safe_stem}_context.wav") + write_audio_1d(context_metric_path, context_wav, model.sample_rate) + + if inputs["profile_multiturn"]: + full_len = int(audio_len_cpu[i].item()) + full_wav_t = audio_f32_cpu[i, :full_len].float() + + samples_per_prediction_frame = model.codec_model_samples_per_frame / ( + model.sample_rate / model.output_sample_rate + ) + + aligned_agent = torch.zeros_like(full_wav_t) + raw_turn_texts = inputs.get("raw_turn_texts", [[] for _ in range(B)]) + + for turn_id, start_frame, end_frame in profile_turn_frame_ranges: + rel_start_frame = start_frame - profile_decode_start_frame + rel_end_frame = end_frame - profile_decode_start_frame + + start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) + end_sample = int(round(rel_end_frame * samples_per_prediction_frame)) + + start_sample = max(0, min(start_sample, full_len)) + end_sample = max(start_sample, min(end_sample, full_len)) + + aligned_agent[start_sample:end_sample] = full_wav_t[start_sample:end_sample] + + turn_wav = aligned_agent[start_sample:end_sample].float() + turn_out_path = os.path.join(audio_out_dir, f"{safe_stem}_turn_{turn_id}{ext}") + write_audio_1d(turn_out_path, turn_wav, model.output_sample_rate) + + metric_turn_path = os.path.join(metric_turn_dir, f"{safe_stem}_turn_{turn_id}.wav") + write_audio_1d(metric_turn_path, turn_wav, model.output_sample_rate) + + if turn_id < len(raw_turn_texts[i]): + metric_items.append( + build_metric_item( + run_id=run_id, + rank=rank, + dataset_index=dataset_idx, + turn_id=turn_id, + target_audio_path=target_path, + reference_text=str(raw_turn_texts[i][turn_id]), + pred_audio_path=metric_turn_path, + context_audio_path=context_metric_path, + pred_audio_samples=int(turn_wav.numel()), + context_audio_samples=int(context_wav.numel()), + output_sample_rate=model.output_sample_rate, + context_sample_rate=model.sample_rate, + ) + ) + + full_out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") + write_audio_1d(full_out_path, aligned_agent, model.output_sample_rate) + + if "user_audio_turns" in inputs: + user_segments = [] + + first_user_len_in = int(inputs["user_audio_turns_lens"][0][i].item()) + first_user_delay_out = int(round(first_user_len_in * model.output_sample_rate / model.sample_rate)) + + for turn_id, start_frame, _ in profile_turn_frame_ranges: + if turn_id >= len(inputs["user_audio_turns"]): + continue + + turn_audio = inputs["user_audio_turns"][turn_id][i].detach().cpu().float() + turn_audio_len = int(inputs["user_audio_turns_lens"][turn_id][i].item()) + turn_audio = turn_audio[:turn_audio_len] + + turn_audio_out = resample( + turn_audio.unsqueeze(0), + model.sample_rate, + model.output_sample_rate, + ).squeeze(0) + + if turn_id == 0: + user_start_sample = 0 + else: + prev_turn_end_frame = profile_turn_frame_ranges[turn_id - 1][2] + rel_prev_end_frame = prev_turn_end_frame - profile_decode_start_frame + user_start_sample = first_user_delay_out + int( + round(rel_prev_end_frame * samples_per_prediction_frame) + ) + + user_segments.append((user_start_sample, turn_audio_out.detach().cpu().float())) + + total_user_len = 0 + for s, wav_seg in user_segments: + total_user_len = max(total_user_len, s + wav_seg.numel()) + + user_ch = torch.zeros(total_user_len) + for s, wav_seg in user_segments: + e = s + wav_seg.numel() + user_ch[s:e] += wav_seg + + agent_ch = torch.cat([torch.zeros(first_user_delay_out, dtype=aligned_agent.dtype), aligned_agent]) + + stereo_len = max(user_ch.numel(), agent_ch.numel()) + user_pad = torch.zeros(stereo_len) + agent_pad = torch.zeros(stereo_len) + + user_pad[: user_ch.numel()] = user_ch + agent_pad[: agent_ch.numel()] = agent_ch + + stereo = torch.stack([user_pad, agent_pad], dim=1).numpy() + aligned_path = os.path.join(audio_out_dir, f"{safe_stem}_user_agent_aligned{ext}") + sf.write(aligned_path, stereo, samplerate=model.output_sample_rate) + + else: + full_len = int(audio_len_cpu[i].item()) + wav = audio_f32_cpu[i, :full_len].float() + out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") + write_audio_1d(out_path, wav, model.output_sample_rate) + + metric_turn_path = os.path.join(metric_turn_dir, f"{safe_stem}_turn_0.wav") + write_audio_1d(metric_turn_path, wav, model.output_sample_rate) + + metric_items.append( + build_metric_item( + run_id=run_id, + rank=rank, + dataset_index=dataset_idx, + turn_id=0, + target_audio_path=target_path, + reference_text=str(inputs["raw_text"][i]), + pred_audio_path=metric_turn_path, + context_audio_path=context_metric_path, + pred_audio_samples=int(wav.numel()), + context_audio_samples=int(context_wav.numel()), + output_sample_rate=model.output_sample_rate, + context_sample_rate=model.sample_rate, + ) + ) + + return metric_items + + +# ----------------------------- +# Metrics after generation +# ----------------------------- + + +def torch_rms_norm(wav: torch.Tensor, db_level: float = -27.0) -> torch.Tensor: + denom = torch.sum(wav**2) + if denom <= 0: + return wav + r = 10 ** (db_level / 20) + a = torch.sqrt((wav.size(-1) * (r**2)) / denom) + return wav * a + + +def _load_audio_for_metric(path: str, sample_rate: int): + wav = _load_audio(path, sample_rate=sample_rate, normalize=False, use_librosa=False) + if wav.numel() == 0: + wav = torch.zeros(1, dtype=torch.float32) + return wav.float() + + +def _pad_audio_1d_list(wavs: List[torch.Tensor], device, dtype=torch.float32): + if len(wavs) == 0: + return torch.zeros((0, 1), device=device, dtype=dtype), torch.zeros((0,), device=device, dtype=torch.long) + + lens = torch.tensor([max(1, int(w.numel())) for w in wavs], device=device, dtype=torch.long) + max_len = int(lens.max().item()) + out = torch.zeros((len(wavs), max_len), device=device, dtype=dtype) + + for i, w in enumerate(wavs): + w = w.to(device=device, dtype=dtype).flatten() + if w.numel() == 0: + continue + out[i, : w.numel()] = w + + return out, lens + + +def chunk_list(xs: List[Any], chunk_size: int) -> Iterable[List[Any]]: + chunk_size = max(1, int(chunk_size)) + for start in range(0, len(xs), chunk_size): + yield xs[start : start + chunk_size] + + +def _metric_device(): + return "cuda" if torch.cuda.is_available() else "cpu" + + +def _load_metric_batch_audio(batch_items: List[Dict[str, Any]], args): + pred_wavs = [] + context_wavs = [] + + for item in batch_items: + pred = _load_audio_for_metric(item["pred_audio_path"], sample_rate=int(item["output_sample_rate"])) + context = _load_audio_for_metric(item["context_audio_path"], sample_rate=int(item["context_sample_rate"])) + + if args.max_metric_audio_sec is not None: + max_pred_len = int(float(args.max_metric_audio_sec) * int(item["output_sample_rate"])) + pred = pred[: max(1, max_pred_len)] + + pred_wavs.append(pred) + context_wavs.append(context) + + device = _metric_device() + pred_audio, pred_lens = _pad_audio_1d_list(pred_wavs, device=device) + context_audio, context_lens = _pad_audio_1d_list(context_wavs, device=device) + output_sample_rate = int(batch_items[0]["output_sample_rate"]) + context_sample_rate = int(batch_items[0]["context_sample_rate"]) + + return pred_audio, pred_lens, context_audio, context_lens, output_sample_rate, context_sample_rate + + +def compute_metrics_after_generation(args, rank: int, world_size: int, metric_items: List[Dict[str, Any]]): + """ + Load metric models only after generation is complete. + + Order: + 1. Load ASR, compute turn-level CER/WER and ASR hyps, then free ASR. + 2. Load SECS speaker encoder and compute turn-level SECS. + 3. Save rank-level aggregate metrics from the same turn-level rows. + + SECS is always computed turn-by-turn, like CER/WER. The grouped filewise + output stores secs_turns and sample-level secs, and metrics_final.* receives + the turn-level aggregate SECS. + """ + metric_start = time.time() + + if len(metric_items) == 0: + return { + "rank": int(rank), + "world_size": int(world_size), + "num_processed": 0, + "num_metric_items": 0, + "metric_elapsed_sec": 0.0, + "intelligibility": {}, + "secs": {}, + }, [] + + normalizer = EnglishTextNormalizer() + normalizer.ignore_patterns = r"$^" + filewise_rows = [] + + # ASR pass. + all_rank_print(rank, f"loading ASR after generation: {args.asr_model_name}") + with fp32_precision(): + intelligibility = Intelligibility(args.asr_model_name, reuse_asr_hyps=False).reset() + + for batch_items in chunk_list(metric_items, args.metric_batch_size): + refs = [x["reference_text"] for x in batch_items] + pred_audio, pred_lens, _, _, output_sr, _ = _load_metric_batch_audio(batch_items, args) + + with fp32_precision(): + pred_16k = resample(pred_audio, output_sr, 16000) + pred_16k_lens = (pred_lens / output_sr * 16000).to(torch.long) + pred_16k = torch_rms_norm(pred_16k) + + asr_hyps = intelligibility.update( + name="dataset", + refs=refs, + pred_audio=pred_16k, + pred_audio_lens=pred_16k_lens, + asr_hyps=None, + ) + + for item, hyp in zip(batch_items, asr_hyps): + ref_norm = normalizer(str(item["reference_text"])).strip() + hyp_norm = normalizer(str(hyp)).strip() + if ref_norm == "": + cer = None + wer = None + else: + cer = float(word_error_rate([hyp_norm], [ref_norm], use_cer=True)) + wer = float(word_error_rate([hyp_norm], [ref_norm], use_cer=False)) + + row = dict(item) + row["asr_hyp"] = hyp + row["cer"] = cer + row["wer"] = wer + row["secs"] = None + filewise_rows.append(row) + + with fp32_precision(): + cer_wer = metric_dict_to_jsonable(intelligibility.compute()) + del intelligibility + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # SECS pass. This is intentionally turn-level, matching CER/WER. + # We keep one aggregate SECS metric for metrics_final.* and also compute + # one SECS value per filewise turn row so grouped outputs have secs_turns. + all_rank_print(rank, f"loading speaker encoder after ASR is released: {args.secs_model_name}") + with fp32_precision(): + secs_metric = SECS(args.secs_model_name).reset() + + # Aggregate turn-level SECS for metrics_final.json / metrics_final.txt. + for batch_items in chunk_list(metric_items, args.metric_batch_size): + pred_audio, pred_lens, context_audio, context_lens, output_sr, context_sr = _load_metric_batch_audio( + batch_items, args + ) + + with fp32_precision(): + pred_16k = resample(pred_audio, output_sr, 16000) + pred_16k_lens = (pred_lens / output_sr * 16000).to(torch.long) + context_16k = resample(context_audio, context_sr, 16000) + context_16k_lens = (context_lens / context_sr * 16000).to(torch.long) + + pred_16k = torch_rms_norm(pred_16k) + context_16k = torch_rms_norm(context_16k) + + secs_metric.update( + name="dataset", + target_audio=context_16k, + target_audio_lens=context_16k_lens, + pred_audio=pred_16k, + pred_audio_lens=pred_16k_lens, + ) + + with fp32_precision(): + secs_scores = metric_dict_to_jsonable(secs_metric.compute()) + del secs_metric + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Per-turn SECS for filewise/grouped outputs. This is always computed so + # secs_turns and sample-level secs are never null in final filewise metrics. + # It is slower than aggregate-only SECS, but it matches the turn-level + # semantics requested for CER/WER/SECS. + all_rank_print(rank, "computing per-turn SECS rows") + for row in filewise_rows: + pred_audio, pred_lens, context_audio, context_lens, output_sr, context_sr = _load_metric_batch_audio([row], args) + + with fp32_precision(): + one_secs = SECS(args.secs_model_name).reset() + pred_16k = resample(pred_audio, output_sr, 16000) + pred_16k_lens = (pred_lens / output_sr * 16000).to(torch.long) + context_16k = resample(context_audio, context_sr, 16000) + context_16k_lens = (context_lens / context_sr * 16000).to(torch.long) + + pred_16k = torch_rms_norm(pred_16k) + context_16k = torch_rms_norm(context_16k) + + one_secs.update( + name="dataset", + target_audio=context_16k, + target_audio_lens=context_16k_lens, + pred_audio=pred_16k, + pred_audio_lens=pred_16k_lens, + ) + one_secs_metrics = metric_dict_to_jsonable(one_secs.compute()) + + row["secs"] = safe_metric_scalar(one_secs_metrics, ["secs", "secs_dataset"]) + row["secs_metrics"] = one_secs_metrics + del one_secs + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + metric_elapsed = time.time() - metric_start + + rank_metrics = { + "rank": int(rank), + "world_size": int(world_size), + "num_processed": len({(x["run_id"], x["dataset_index"]) for x in metric_items}), + "num_metric_items": int(len(metric_items)), + "metric_elapsed_sec": float(metric_elapsed), + "intelligibility": cer_wer, + "secs": secs_scores, + } + + return rank_metrics, filewise_rows + + +# ----------------------------- +# Merge helpers +# ----------------------------- + + +def compute_and_save_rank_metrics_file(args, rank_metrics: Dict[str, Any], rank: int): + rank_path = os.path.join(args.out_dir, f"metrics_rank{rank:04d}.json") + write_json(rank_path, rank_metrics) + return rank_metrics + + +def merge_metrics_on_rank0(args, rank, world_size): + if rank != 0: + return None + + rank_metric_files = [os.path.join(args.out_dir, f"metrics_rank{r:04d}.json") for r in range(world_size)] + + rank_metrics = [] + for path in rank_metric_files: + if not os.path.exists(path): + logging.warning(f"Missing rank metric file: {path}") + continue + with open(path, "r", encoding="utf-8") as f: + rank_metrics.append(json.load(f)) + + total_n = sum(int(m.get("num_metric_items", m.get("num_processed", 0))) for m in rank_metrics) + + def weighted_average(section: str): + keys = set() + for m in rank_metrics: + keys.update(m.get(section, {}).keys()) + + out = {} + for k in sorted(keys): + numerator = 0.0 + denominator = 0 + + for m in rank_metrics: + n = int(m.get("num_metric_items", m.get("num_processed", 0))) + if n <= 0: + continue + + value = m.get(section, {}).get(k, None) + if value is None or isinstance(value, str): + continue + + try: + value = float(value) + except Exception: + continue + + numerator += value * n + denominator += n + + if denominator > 0: + out[k] = numerator / denominator + + return out + + final_metrics = { + "world_size": int(world_size), + "num_metric_items": int(total_n), + "aggregation": "sum(rank_metric * rank_num_metric_items) / total_num_metric_items", + "intelligibility": weighted_average("intelligibility"), + "secs": weighted_average("secs"), + "ranks": rank_metrics, + } + + final_json_path = os.path.join(args.out_dir, "metrics_final.json") + final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") + + write_json(final_json_path, final_metrics) + + final_text = format_final_metric_text(final_metrics) + write_text_atomic(final_txt_path, final_text) + + print("\n--- Final Evaluation Metrics ---", flush=True) + print(final_text, flush=True) + + logging.info(f"Final metrics JSON saved to: {final_json_path}") + logging.info(f"Final metrics TXT saved to: {final_txt_path}") + logging.info(json.dumps(final_metrics, indent=2, sort_keys=True)) + + return final_metrics + + +def merge_filewise_metrics_on_rank0(args, rank: int, world_size: int): + """Merge per-turn rank metric rows into one row per original sample. + + Rank files still contain one row per turn because metrics are computed + turn-by-turn. The final filewise outputs group those turn rows by + (run_id, dataset_index), producing one JSONL/CSV row per original sample + with list fields: + reference_text, asr_hyp, cer_turns, wer_turns, secs_turns. + + DistributedSampler padding repeats are deduplicated by + (run_id, dataset_index, turn_id), but repetitions from --num_eval_runs are + preserved because run_id is part of the key. + """ + if rank != 0 or not args.save_filewise_metrics: + return [] + + turn_rows = [] + + for r in range(world_size): + path = os.path.join(args.out_dir, f"filewise_metrics_rank{r:04d}.jsonl") + if not os.path.exists(path): + logging.warning(f"Missing filewise metrics file: {path}") + continue + + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + turn_rows.append(json.loads(line)) + + # Deduplicate DistributedSampler padding repeats, but preserve --num_eval_runs. + deduped_turns = {} + for row in turn_rows: + run_id = int(row.get("run_id", 0)) + idx = int(row.get("dataset_index", -1)) + turn_id = int(row.get("turn_id", 0)) + key = (run_id, idx, turn_id) + if key not in deduped_turns: + deduped_turns[key] = row + + turn_rows = list(deduped_turns.values()) + + # Group turn rows into one row per original file/sample. + grouped = {} + for row in turn_rows: + run_id = int(row.get("run_id", 0)) + idx = int(row.get("dataset_index", -1)) + key = (run_id, idx) + + if key not in grouped: + grouped[key] = { + "run_id": run_id, + "dataset_index": idx, + "rank": int(row.get("rank", -1)), + "target_audio_path": row.get("target_audio_path", ""), + "context_audio_path": row.get("context_audio_path", ""), + "turn_rows": [], + } + + grouped[key]["turn_rows"].append(row) + + def avg(vals): + vals = [float(x) for x in vals if x is not None and math.isfinite(float(x))] + return None if not vals else sum(vals) / len(vals) + + sample_rows = [] + for _, group in grouped.items(): + turns = sorted(group["turn_rows"], key=lambda x: int(x.get("turn_id", 0))) + + cer_turns = [r.get("cer") for r in turns] + wer_turns = [r.get("wer") for r in turns] + secs_turns = [r.get("secs") for r in turns] + + sample_row = { + "run_id": group["run_id"], + "dataset_index": group["dataset_index"], + "rank": group["rank"], + "num_turns": len(turns), + "turn_ids": [int(r.get("turn_id", 0)) for r in turns], + "target_audio_path": group["target_audio_path"], + "context_audio_path": group["context_audio_path"], + "pred_audio_paths": [r.get("pred_audio_path", "") for r in turns], + "pred_audio_seconds_turns": [r.get("pred_audio_seconds") for r in turns], + "reference_text": [r.get("reference_text", "") for r in turns], + "asr_hyp": [r.get("asr_hyp", "") for r in turns], + "cer_turns": cer_turns, + "wer_turns": wer_turns, + "secs_turns": secs_turns, + "cer": avg(cer_turns), + "wer": avg(wer_turns), + "secs": avg(secs_turns), + } + + sample_rows.append(sample_row) + + # Sort samples by average CER descending for failure analysis. + sample_rows.sort( + key=lambda x: ( + x.get("cer") is not None, + float(x.get("cer")) if x.get("cer") is not None else -1.0, + ), + reverse=True, + ) + + jsonl_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.jsonl") + csv_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.csv") + + write_jsonl(jsonl_path, sample_rows) + write_filewise_csv(csv_path, sample_rows) + + logging.info(f"Saved sample-level filewise metrics JSONL to: {jsonl_path}") + logging.info(f"Saved sample-level filewise metrics CSV to: {csv_path}") + + topk = min(int(args.filewise_metrics_topk_log), len(sample_rows)) + if topk > 0: + logging.info(f"Top {topk} worst CER samples:") + for row in sample_rows[:topk]: + logging.info( + "run_id=%s dataset_index=%s num_turns=%s cer=%s wer=%s secs=%s path=%s" + % ( + row.get("run_id"), + row.get("dataset_index"), + row.get("num_turns"), + row.get("cer"), + row.get("wer"), + row.get("secs"), + row.get("target_audio_path"), + ) + ) + + return sample_rows + +def compute_aggregates_from_filewise_rows(rows: List[Dict[str, Any]]): + """Aggregate over sample-level rows. + + Each row may internally contain multiple turn metrics in cer_turns/wer_turns, + but the final filewise average is over original samples/files. + """ + if len(rows) == 0: + return { + "cer": None, + "wer": None, + "secs": None, + "num_samples": 0, + } + + def avg_key(key): + vals = [float(r[key]) for r in rows if r.get(key) is not None] + if len(vals) == 0: + return None + return sum(vals) / len(vals) + + return { + "cer": avg_key("cer"), + "wer": avg_key("wer"), + "secs": avg_key("secs"), + "num_samples": len(rows), + } + +def save_filewise_final_summary(args, filewise_rows: List[Dict[str, Any]]): + filewise_summary = compute_aggregates_from_filewise_rows(filewise_rows) + + obj = { + "aggregation": "mean_over_sample_metrics_each_sample_contains_turn_metric_lists", + **filewise_summary, + } + + path = os.path.join(args.out_dir, "metrics_final_filewise_average.json") + write_json(path, obj) + + sample_metrics_final_path = os.path.join(args.out_dir, "metrics_final_sample_average.json") + write_json(sample_metrics_final_path, obj) + + final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") + final_text = format_filewise_final_metric_text(filewise_summary) + write_text_atomic(final_txt_path, final_text) + + print("\n--- Final Sample-Averaged Evaluation Metrics ---", flush=True) + print(final_text, flush=True) + + logging.info(f"Filewise averaged final metrics saved to: {path}") + logging.info(f"Sample averaged metrics_final JSON saved to: {sample_metrics_final_path}") + logging.info(f"Final metrics TXT saved to: {final_txt_path}") + + return obj + + +# ----------------------------- +# Args / main +# ----------------------------- + + +def parse_args(): + parser = argparse.ArgumentParser(description="EasyMagpieTTS Multi-GPU Inference Evaluation") + + parser.add_argument("--checkpoint_path", type=str, required=True) + parser.add_argument("--codec_model_path", type=str, required=True) + parser.add_argument("--datasets_json_path", type=str, required=True) + parser.add_argument("--out_dir", type=str, required=True) + + parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) + parser.add_argument("--audio_dir", type=str, default=None) + parser.add_argument("--inference_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) + parser.add_argument("--debug_dtype", action="store_true") + parser.add_argument("--debug_gpu_assignment", action="store_true") + parser.add_argument("--use_librosa", action="store_true") + + parser.add_argument("--batch_size", type=int, default=6) + parser.add_argument("--num_workers", type=int, default=4) + parser.add_argument("--num_turns", type=int, default=1) + parser.add_argument("--pad_factor_text_speech", type=int, default=10) + + parser.add_argument("--emulate_duplex_inference", action="store_true") + parser.add_argument("--add_interruption_token", action="store_true") + parser.add_argument("--force_interruption", action="store_true") + parser.add_argument("--profile_multiturn_inference", action="store_true") + parser.add_argument("--profile_pad_min_sec", type=float, default=2.0) + parser.add_argument("--profile_pad_max_sec", type=float, default=2.0) + parser.add_argument("--max_eval_turns", type=int, default=6) + + parser.add_argument("--user_custom_speaker_reference", action="store_true") + parser.add_argument("--inference_speaker_reference", type=str, default=None) + parser.add_argument("--language", type=str, default="en") + + parser.add_argument("--use_cfg", action="store_true") + parser.add_argument("--cfg_scale", type=float, default=2.5) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--topk", type=int, default=80) + parser.add_argument("--max_tts_steps", type=int, default=2000) + parser.add_argument("--force_speech_sil_codes", action="store_true") + parser.add_argument("--normalize_volume", type=lambda x: str(x).lower() in ["true", "1", "yes"], default=True) + + parser.add_argument( + "--save_filewise_metrics", + action="store_true", + help="Save per-turn/file CER/WER metrics sorted by CER descending.", + ) + parser.add_argument( + "--compute_filewise_secs", + action="store_true", + help="Also compute per-turn/file SECS. Slower because it runs SECS per row.", + ) + parser.add_argument( + "--filewise_metrics_topk_log", + type=int, + default=20, + help="Number of worst CER samples to print on rank 0.", + ) + parser.add_argument( + "--num_eval_runs", + type=int, + default=1, + help="Repeat the full eval set N times. Repetitions are preserved in final filewise average.", + ) + parser.add_argument( + "--sort_by_text_token_count", + action="store_true", + help="Sort eval samples by total text token count before distributed sharding for better load balancing.", + ) + parser.add_argument( + "--metric_batch_size", + type=int, + default=8, + help="Batch size used for post-generation ASR/SECS metric computation.", + ) + parser.add_argument( + "--max_metric_audio_sec", + type=float, + default=120.0, + help="Clamp generated audio length used for ASR/SECS metrics to avoid metric OOM/hangs.", + ) + parser.add_argument( + "--asr_model_name", + type=str, + default="stt_en_fastconformer_transducer_large", + help="Pretrained NeMo ASR model used for CER/WER.", + ) + parser.add_argument( + "--secs_model_name", + type=str, + default="titanet_large", + help="Pretrained speaker encoder model used for SECS.", + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + os.makedirs(args.out_dir, exist_ok=True) + os.makedirs(get_audio_out_dir(args), exist_ok=True) + os.makedirs(get_generated_turn_audio_dir(args), exist_ok=True) + os.makedirs(get_context_metric_audio_dir(args), exist_ok=True) + + distributed, rank, local_rank, world_size, device_index = setup_distributed() + + if args.profile_multiturn_inference and args.batch_size != 1: + raise RuntimeError( + "--profile_multiturn_inference requires --batch_size=1 per process. " + "Use multiple GPUs/processes for parallelism instead of increasing batch_size." + ) + + if args.profile_pad_max_sec < args.profile_pad_min_sec: + raise RuntimeError("--profile_pad_max_sec must be >= --profile_pad_min_sec.") + + if args.num_eval_runs <= 0: + raise RuntimeError("--num_eval_runs must be >= 1.") + + target_device = torch.device(f"cuda:{device_index}" if torch.cuda.is_available() and device_index >= 0 else "cpu") + target_dtype = getattr(torch, args.inference_dtype) + torch.set_default_dtype(target_dtype) + + hostname = socket.gethostname() + cuda_name = torch.cuda.get_device_name(target_device) if torch.cuda.is_available() and device_index >= 0 else "cpu" + + all_rank_print( + rank, + f"host={hostname} local_rank={local_rank} world_size={world_size} " + f"device={target_device} device_name={cuda_name}", + ) + + model = build_model_and_codec(args, target_device, target_dtype) + codec_sil_codes = model.codec_sil_codes + + if args.debug_dtype: + handles, stats, examples = attach_dtype_counter(model) + else: + handles = stats = examples = None + + full_eval_dataset = EvalJSONLDataset(args.datasets_json_path, num_turns=args.num_turns) + # debug + # full_eval_dataset.samples = full_eval_dataset.samples[:7] + + if args.sort_by_text_token_count: + full_eval_dataset = SortedByTextTokenCountDataset( + full_eval_dataset, + model=model, + max_eval_turns=args.max_eval_turns, + descending=True, + ) + + collate_fn = partial( + collate_and_tokenize_custom, + model=model, + extra_duration_thrshould=1.5, + sample_rate=model.sample_rate, + root_path=args.audio_dir, + emulate_duplex_inference=args.emulate_duplex_inference, + add_interruption_token=args.add_interruption_token, + pad_factor_text_speech=args.pad_factor_text_speech, + force_interruption=args.force_interruption, + normalize_audio_volume=args.normalize_volume, + use_librosa=args.use_librosa, + profile_multiturn_inference=args.profile_multiturn_inference, + max_eval_turns=args.max_eval_turns, + ) + + speaker_wav = load_speaker_wav_if_needed(args, model, target_dtype) + + generation_start = time.time() + all_metric_items = [] + total_batches = 0 + total_generated_samples = 0 + + for run_id in range(args.num_eval_runs): + if distributed: + sampler = DistributedSampler( + full_eval_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + sampler.set_epoch(run_id) + else: + sampler = SequentialSampler(full_eval_dataset) + + if args.debug_gpu_assignment: + try: + assigned_indices = list(iter(sampler)) + assigned_dataset_indices = [ + int(full_eval_dataset[i].get("__dataset_index__", -1)) for i in assigned_indices + ] + all_rank_print( + rank, + f"run_id={run_id} assigned {len(assigned_dataset_indices)} / {len(full_eval_dataset)} " + f"samples to gpu={local_rank}: dataset_indices={assigned_dataset_indices}", + ) + except Exception as e: + all_rank_print(rank, f"Could not print assigned indices: {repr(e)}") + + dataloader = DataLoader( + dataset=full_eval_dataset, + batch_size=args.batch_size, + sampler=sampler, + collate_fn=collate_fn, + num_workers=args.num_workers, + pin_memory=True, + drop_last=False, + ) + + for batch_id, inputs in enumerate(dataloader): + total_batches += 1 + batch_indices = inputs.get("dataset_indices", []) + total_generated_samples += len(batch_indices) + + if args.debug_gpu_assignment: + all_rank_print( + rank, + f"run_id={run_id} gpu={local_rank} batch_id={batch_id} " + f"dataset_indices={batch_indices} text_token_counts={inputs.get('text_token_counts', [])} " + f"target_paths={inputs.get('target_audio_paths', [])}", + ) + + inputs = prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=speaker_wav) + + finalize_output, profile_turn_frame_ranges, profile_decode_start_frame = run_generation( + model=model, + inputs=inputs, + args=args, + codec_sil_codes=codec_sil_codes, + ) + + metric_items = save_generation_outputs_and_build_metric_items( + model=model, + inputs=inputs, + finalize_output=finalize_output, + profile_turn_frame_ranges=profile_turn_frame_ranges, + profile_decode_start_frame=profile_decode_start_frame, + args=args, + rank=rank, + run_id=run_id, + ) + all_metric_items.extend(metric_items) + + if args.debug_dtype and batch_id == 0 and run_id == 0: + report_dtype_stats(handles, stats, examples, rank=rank) + + generation_elapsed = time.time() - generation_start + + # Save pre-metric manifest for debugging and restartability. + metric_manifest_path = os.path.join(args.out_dir, f"metric_items_rank{rank:04d}.jsonl") + write_jsonl(metric_manifest_path, all_metric_items) + + # Free TTS/codec model memory before loading ASR and speaker encoder metrics. + del model + if speaker_wav is not None: + del speaker_wav + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + all_rank_print( + rank, + f"generation done: batches={total_batches} generated_samples_with_sampler_padding={total_generated_samples} " + f"metric_items={len(all_metric_items)} elapsed_sec={generation_elapsed:.2f}. " + "Loading ASR/SECS metrics now.", + ) + + rank_metrics, rank_filewise_rows = compute_metrics_after_generation( + args=args, + rank=rank, + world_size=world_size, + metric_items=all_metric_items, + ) + rank_metrics["generation_elapsed_sec"] = float(generation_elapsed) + rank_metrics["num_generated_samples_with_sampler_padding"] = int(total_generated_samples) + + rank_metrics = compute_and_save_rank_metrics_file(args, rank_metrics, rank) + all_rank_print(rank, f"saved rank metrics: {json.dumps(rank_metrics, sort_keys=True)}") + + if args.save_filewise_metrics: + rank_filewise_rows.sort( + key=lambda x: ( + x.get("cer") is not None, + float(x.get("cer")) if x.get("cer") is not None else -1.0, + ), + reverse=True, + ) + + rank_filewise_path = os.path.join(args.out_dir, f"filewise_metrics_rank{rank:04d}.jsonl") + write_jsonl(rank_filewise_path, rank_filewise_rows) + all_rank_print(rank, f"saved filewise metrics: {rank_filewise_path}") + + if rank == 0: + wait_for_rank_metric_files(args, world_size) + + merge_metrics_on_rank0(args, rank, world_size) + + if args.save_filewise_metrics: + if rank == 0: + wait_for_rank_filewise_metric_files(args, world_size) + + filewise_rows = merge_filewise_metrics_on_rank0(args, rank, world_size) + + if rank == 0: + save_filewise_final_summary(args, filewise_rows) + + cleanup_distributed() + + +if __name__ == "__main__": + main() From 646e576a915bb6549a7b901f5e59cc195398923c Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 8 Jun 2026 11:52:11 -0700 Subject: [PATCH 072/109] Remove unused methods Signed-off-by: Edresson Casanova --- .../tts/models/easy_magpietts_inference.py | 49 +------------------ 1 file changed, 2 insertions(+), 47 deletions(-) diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index b6a0663f3fee..fb34090aecd8 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -1462,7 +1462,7 @@ def streaming_init( dropout_conditional_input=False, ) ) - + # Store full context embedding and lens before any CFG manipulation full_context_embedding = context_embedding.clone() # (B, T_max, E) full_context_lens = context_lens.clone() # (B,) @@ -1635,7 +1635,6 @@ def streaming_step( if prefill_like_step: # Advance logical streams, keep audio silent, but predict phonemes if enabled. state.context_position += needs_context.long() - advance_text = (~needs_context).long() state.text_tokens_seen += (~needs_context).long() if hasattr(state, "turn_text_tokens_seen"): @@ -1824,7 +1823,7 @@ def _prepare_streaming_input( else: first_phoneme_step = needs_phoneme & (state.phoneme_steps == 0) has_last_phoneme = needs_phoneme & (~first_phoneme_step) & (state.last_phoneme_tokens is not None) - + if first_phoneme_step.any(): phoneme_bos = torch.full( (batch_size, self.phoneme_stacking_factor, 1), @@ -1920,50 +1919,6 @@ def _prepare_streaming_input( return next_input, needs_context, needs_phoneme, needs_audio - def _embed_one_text_step( - self, - tokens: torch.Tensor, # (B,) - force_dropout_text: bool = False, - ) -> torch.Tensor: - """ - Embed one text step. Returns (B, 1, E). - """ - device = tokens.device - tokens_2d = tokens.unsqueeze(1) - - if self.cfg.get("disable_subword_embedding", False): - text_embedded = torch.zeros( - tokens_2d.size(0), - 1, - self.cfg.embedding_dim, - dtype=next(self.parameters()).dtype, - device=device, - ) - else: - text_embedded = self.decoder.get_input_embeddings()(tokens_2d) - - is_pad = tokens_2d == self.pad_id - - if self.use_bpe_char_tokenizer: - if self.cfg.get("use_multiturn_dataset", False): - text_mask = ~is_pad - else: - text_mask = torch.ones_like(tokens_2d, dtype=torch.bool) - - text_embedded = text_embedded + self.cas_encoder( - tokens_2d, - subword_mask=text_mask, - ) - - if force_dropout_text: - text_embedded = text_embedded * 0.0 - - if self.cfg.get("use_multiturn_dataset", False): - text_embedded[is_pad] = 0.0 - - return text_embedded - - def _process_predictions( self, state: StreamingState, From dea328c91d3e3a20950a4d73cdc3435fc0c83081 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 8 Jun 2026 11:52:34 -0700 Subject: [PATCH 073/109] Add new easymagpie compatible inference script Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 4023 ++++++++++++----- ...sy_magpietts_inference_multiturn_runner.py | 750 +++ ...nference_multiturn_turn_based_as_magpie.py | 2247 --------- .../modules/magpietts_inference/inference.py | 830 +++- 4 files changed, 4510 insertions(+), 3340 deletions(-) create mode 100644 examples/tts/easy_magpietts_inference_multiturn_runner.py delete mode 100644 examples/tts/easy_magpietts_inference_multiturn_turn_based_as_magpie.py diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py index 2aa5f0de6fab..f9c54f1c5393 100644 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ b/examples/tts/easy_magpietts_inference_multiturn.py @@ -1,137 +1,730 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. """ -Evaluation script for custom EasyMagpieTTS models. -Features explicit Duplex (10x Padding) and Regular (Turn-by-turn) multi-turn modes. - -Usage: - python easy_magpietts_eval.py \ - --checkpoint_path=/path/to/magpie/model.ckpt \ - --codec_model_path=/path/to/codec/model.ckpt \ - --datasets_json_path=/path/to/evalset_config.jsonl \ - --out_dir=/path/to/out/audio \ - --batch_size=6 \ - --use_cfg \ - --use_librosa +Multi-GPU EasyMagpieTTS / NemotronTTS multiturn inference evaluation. + +Key behavior: + - Uses torchrun env vars RANK, LOCAL_RANK, WORLD_SIZE for sharding/GPU assignment. + - Does NOT initialize torch.distributed. This avoids NeMo ASR doing distributed + collectives during metric computation. + - Generation runs first for all assigned samples. + - ASR and speaker-similarity models are loaded only after generation is done and the TTS/codec model + has been deleted from GPU memory. + - ASR and speaker-similarity models are loaded sequentially: ASR first, then released; speaker-similarity second. + - Supports multiturn-user-audio and regular single-turn inference; metrics are turn/file based. + Final filewise outputs are grouped back to one row per original sample, with + lists for asr_hyp/reference_text/cer_turns/wer_turns/ssim_turns. + - Uses DistributedSampler with explicit rank/world_size. A few repeated samples + may appear when len(dataset) is not divisible by world_size. Filewise final + metrics deduplicate sampler-padding repeats by (run_id, dataset_index, + turn_id), then group turns into one row per sample with metric lists, while + preserving --num_eval_runs repetitions. + - --sort_by_text_token_count sorts samples by total text-token count before + sharding to improve GPU load balance. + - Saves audio in out_dir/audios/. + - Saves metrics in out_dir/. + +Recommended single-node torchrun: + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + torchrun --standalone --nproc_per_node=8 easy_magpietts_inference_multiturn_multigpu_postgen_metrics.py ... + +Recommended single-node srun wrapper: + srun --nodes=1 --ntasks=1 --ntasks-per-node=1 --container-image=... \ + bash -lc 'torchrun --standalone --nproc_per_node=8 easy_magpietts_inference_multiturn_multigpu_postgen_metrics.py ...' """ import argparse +import csv import json +import math import os +import socket +import shutil +import time +from collections import Counter from copy import deepcopy from functools import partial +from typing import Any, Dict, Iterable, List, Tuple import librosa import soundfile as sf import torch +from omegaconf import open_dict from torch.nn.utils.rnn import pad_sequence -from torch.utils.data import DataLoader, Dataset -from omegaconf import OmegaConf, open_dict +from torch.utils.data import DataLoader, Dataset, DistributedSampler, SequentialSampler -from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.audio.parts.utils.transforms import resample -from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility -from nemo.collections.speechlm2.parts.metrics.secs import SECS +from nemo.collections.asr.metrics.wer import word_error_rate, word_error_rate_detail from nemo.collections.speechlm2.parts.precision import fp32_precision -from nemo.utils import logging - -# --- EasyMagpieTTS Imports --- from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter from nemo.collections.tts.modules.magpietts_modules import CodecHelper -from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel - from nemo.collections.tts.parts.utils.tts_dataset_utils import normalize_volume +from nemo.utils import logging +from whisper_normalizer.english import EnglishTextNormalizer + +try: + import nemo.collections.asr as nemo_asr +except Exception: + nemo_asr = None + +try: + from nemo.collections.asr.models import ASRModel +except Exception: + ASRModel = None + +try: + from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector +except Exception: + Wav2Vec2FeatureExtractor = None + WavLMForXVector = None + +try: + from nemo.collections.tts.modules.magpietts_inference.evaluate_generated_audio import ( + compute_utmosv2_scores, + extract_embedding, + ) +except Exception: + compute_utmosv2_scores = None + extract_embedding = None + +try: + from nemo.collections.tts.metrics.eou_classifier import EoUClassifier, EoUType +except Exception: + EoUClassifier = None + EoUType = None + +try: + from nemo.collections.tts.modules.magpietts_inference.evaluation import DEFAULT_VIOLIN_METRICS +except Exception: + DEFAULT_VIOLIN_METRICS = ['cer', 'pred_context_ssim', 'utmosv2'] + +try: + from nemo.collections.tts.modules.magpietts_inference.visualization import create_violin_plot +except Exception: + create_violin_plot = None + +try: + from nemo.collections.tts.metrics.frechet_codec_distance import FrechetCodecDistance +except Exception: + FrechetCodecDistance = None + + torch.set_float32_matmul_precision("medium") torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True -if torch.cuda.is_available(): - torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) -def torch_rms_norm(wav, db_level=-27.0): - r = 10 ** (db_level / 20) - a = torch.sqrt((wav.size(-1) * (r**2)) / torch.sum(wav**2)) - return wav * a +# ----------------------------- +# Rank / file helpers +# ----------------------------- -def attach_dtype_counter(model): - handles = [] - stats = {} - examples = {} - def is_leaf(module): - return len(list(module.children())) == 0 +def get_rank_info() -> Tuple[bool, int, int, int]: + world_size = int(os.environ.get("WORLD_SIZE", "1")) + rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", "0"))) + local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))) + distributed = world_size > 1 + return distributed, rank, local_rank, world_size - def get_dtype(x): - if torch.is_tensor(x): - return str(x.dtype) - elif isinstance(x, (list, tuple)): - for t in x: - if torch.is_tensor(t): - return str(t.dtype) - return "other" - def get_module_group(name): - return name.split(".")[0] if "." in name else name +def get_visible_device_index(local_rank: int) -> int: + if not torch.cuda.is_available(): + return -1 + ndev = torch.cuda.device_count() + if ndev <= 0: + return -1 + return local_rank % ndev - def hook_fn(name): - def fn(module, inputs, outputs): - dtype = get_dtype(outputs) - if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: - dtype = "other" - group = get_module_group(name) - if group not in stats: - stats[group] = { - "torch.float16": 0, "torch.bfloat16": 0, "torch.float32": 0, "other": 0, - } - examples[group] = { - "torch.float16": [], "torch.bfloat16": [], "torch.float32": [], "other": [], - } - stats[group][dtype] += 1 - if len(examples[group][dtype]) < 3: - examples[group][dtype].append(module.__class__.__name__) - return fn - for name, module in model.named_modules(): - if is_leaf(module): - handles.append(module.register_forward_hook(hook_fn(name))) - return handles, stats, examples +def setup_distributed(): + """ + Do not initialize torch.distributed. + We only need RANK/LOCAL_RANK/WORLD_SIZE for rank assignment and dataset + sharding. Initializing a process group can cause NeMo ASR to run distributed + collectives during transcribe(), which may hang when ranks have different + audio lengths or workloads. + """ + distributed, rank, local_rank, world_size = get_rank_info() + device_index = get_visible_device_index(local_rank) -def report_dtype_stats(handles, stats, examples): - for h in handles: - h.remove() - logging.info("\n=== DTYPE USAGE PER MODULE ===") - for group, group_stats in stats.items(): - total = sum(group_stats.values()) - if total == 0: continue - logging.info(f"\n--- {group} ---") - for dtype, count in group_stats.items(): - if count > 0: - logging.info(f"{dtype}: {count} ({100*count/total:.2f}%)") - logging.info("\n=== EXAMPLES ===") - for group, group_examples in examples.items(): - logging.info(f"\n--- {group} ---") - for dtype, mods in group_examples.items(): - if mods: - logging.info(f"{dtype}: {mods}") + if torch.cuda.is_available() and device_index >= 0: + torch.cuda.set_device(device_index) + + return distributed, rank, local_rank, world_size, device_index + + +def cleanup_distributed(): + return + + +def all_rank_print(rank: int, msg: str): + print(f"[rank={rank}] {msg}", flush=True) + + +def rank0_print(rank: int, msg: str): + if rank == 0: + print(msg, flush=True) + + +def get_audio_out_dir(args) -> str: + return os.path.join(args.out_dir, "audios") + + +def get_generated_turn_audio_dir(args) -> str: + return os.path.join(get_audio_out_dir(args), "metric_turns") + + +def get_context_metric_audio_dir(args) -> str: + return os.path.join(get_audio_out_dir(args), "metric_context") + + +def get_predicted_codes_dir(args) -> str: + return os.path.join(get_audio_out_dir(args), "predicted_codes") + + +def write_json(path: str, obj: Dict[str, Any]): + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + json.dump(obj, f, indent=2, sort_keys=True, ensure_ascii=False) + os.replace(tmp_path, path) + + +def write_text_atomic(path: str, text: str): + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + f.write(text) + os.replace(tmp_path, path) + + +def write_jsonl(path: str, rows: List[Dict[str, Any]]): + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, sort_keys=True, ensure_ascii=False) + "\n") + os.replace(tmp_path, path) + +def write_csv_header_if_needed(csv_path: str, header: str) -> None: + os.makedirs(os.path.dirname(csv_path), exist_ok=True) + if not os.path.exists(csv_path): + with open(csv_path, "w", encoding="utf-8") as f: + f.write(header + "\n") + + +def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, metrics: Dict[str, Any]) -> None: + """Append metrics using the same column order as MagpieTTS inference/eval.""" + csv_header = ( + "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative," + "wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg," + "ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate," + "ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative," + "utmosv2_avg,total_gen_audio_seconds,frechet_codec_distance," + "eou_cutoff_rate,eou_silence_rate,eou_noise_rate,eou_error_rate" + ) + write_csv_header_if_needed(csv_path, csv_header) + + values = [ + checkpoint_name, + dataset, + metrics.get("cer_filewise_avg", ""), + metrics.get("wer_filewise_avg", ""), + metrics.get("cer_cumulative", ""), + metrics.get("wer_cumulative", ""), + metrics.get("ssim_pred_gt_avg", ""), + metrics.get("ssim_pred_context_avg", ""), + metrics.get("ssim_gt_context_avg", ""), + metrics.get("ssim_pred_gt_avg_alternate", ""), + metrics.get("ssim_pred_context_avg_alternate", ""), + metrics.get("ssim_gt_context_avg_alternate", ""), + metrics.get("cer_gt_audio_cumulative", ""), + metrics.get("wer_gt_audio_cumulative", ""), + metrics.get("utmosv2_avg", ""), + metrics.get("total_gen_audio_seconds", ""), + metrics.get("frechet_codec_distance", ""), + metrics.get("eou_cutoff_rate", ""), + metrics.get("eou_silence_rate", ""), + metrics.get("eou_noise_rate", ""), + metrics.get("eou_error_rate", ""), + ] + + def clean_csv_value(v): + if v is None: + return "" + if isinstance(v, float) and not math.isfinite(v): + return "nan" + return str(v).replace(",", " ") + + with open(csv_path, "a", encoding="utf-8") as f: + f.write(",".join(clean_csv_value(v) for v in values) + "\n") + logging.info(f"Metrics appended to: {csv_path}") + + +def get_checkpoint_name(args) -> str: + checkpoint_path = getattr(args, "checkpoint_path", None) + if checkpoint_path: + stem = os.path.basename(checkpoint_path) + if stem.endswith(".nemo"): + stem = stem[:-5] + return stem + return "checkpoint" + + +def get_dataset_name(args) -> str: + out_name = os.path.basename(os.path.normpath(args.out_dir)) + if out_name: + return out_name + dataset_path = getattr(args, "datasets_json_path", None) + return os.path.splitext(os.path.basename(dataset_path))[0] if dataset_path else "dataset" + + +def create_violin_plot_if_available(metrics: List[Dict[str, Any]], metric_keys: List[str], output_path: str): + if create_violin_plot is None: + logging.warning( + "create_violin_plot is unavailable; skipping violin plot. " + "Make sure nemo.collections.tts.modules.magpietts_inference.visualization is importable." + ) + return + + if not metrics: + logging.warning(f"No metrics available for violin plot: {output_path}") + return + + available_keys = [] + for key in metric_keys: + for row in metrics: + value = row.get(key, None) + if value is None: + continue + try: + value = float(value) + except Exception: + continue + if math.isfinite(value): + available_keys.append(key) + break + + if not available_keys: + logging.warning(f"No finite requested plot metrics available for violin plot: {output_path}") + return + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + create_violin_plot(metrics, available_keys, output_path) + + +def _copy_or_link(src: str, dst: str): + if src is None or not src or not os.path.exists(src): + return None + os.makedirs(os.path.dirname(dst), exist_ok=True) + try: + if os.path.lexists(dst): + os.remove(dst) + os.symlink(os.path.abspath(src), dst) + except Exception: + shutil.copyfile(src, dst) + return dst + + +def write_easymagpie_generated_audio_dir(args, sample_rows: List[Dict[str, Any]]): + """Write EasyMagpie/MagpieTTS-style generated audio/code files. + + This creates files named predicted_audio_*.wav and predicted_codes_*.pt, + plus target/context audio files, so downstream EasyMagpie demo/report tools + can consume this output directory. + """ + generated_audio_dir = os.path.join(args.out_dir, "easy_magpie_generated_audio") + os.makedirs(generated_audio_dir, exist_ok=True) + + manifest_rows = [] + filewise_rows = [] + + rows = sorted(sample_rows, key=lambda r: (int(r.get("run_id", 0)), int(r.get("dataset_index", -1)))) + + for item_idx, row in enumerate(rows): + pred_src = row.get("sample_pred_audio_path") or ( + row.get("pred_audio_paths", [None])[0] if isinstance(row.get("pred_audio_paths"), list) else None + ) + code_src = row.get("sample_predicted_codes_path") or ( + row.get("predicted_codes_paths", [None])[0] if isinstance(row.get("predicted_codes_paths"), list) else None + ) + target_src = _resolve_audio_path(row.get("target_audio_path"), args.audio_dir) + context_src = row.get("context_audio_path") + + pred_dst = os.path.join(generated_audio_dir, f"predicted_audio_{item_idx}.wav") + code_dst = os.path.join(generated_audio_dir, f"predicted_codes_{item_idx}.pt") + target_dst = os.path.join(generated_audio_dir, f"target_audio_{item_idx}.wav") + context_dst = os.path.join(generated_audio_dir, f"context_audio_{item_idx}.wav") + + _copy_or_link(pred_src, pred_dst) + _copy_or_link(code_src, code_dst) + _copy_or_link(target_src, target_dst) + _copy_or_link(context_src, context_dst) + + reference_text = row.get("reference_text", "") + if isinstance(reference_text, list): + manifest_text = " ".join(str(x) for x in reference_text) + else: + manifest_text = str(reference_text) + + manifest_rows.append( + { + "audio_filepath": f"target_audio_{item_idx}.wav", + "context_audio_filepath": f"context_audio_{item_idx}.wav", + "text": manifest_text, + "speaker": row.get("dataset_index", item_idx), + "original_dataset_index": row.get("dataset_index"), + "run_id": row.get("run_id", 0), + } + ) + + metric_row = dict(row) + metric_row.update( + { + "easy_magpie_item_idx": item_idx, + "gt_audio_filepath": target_dst if os.path.exists(target_dst) else target_src, + "pred_audio_filepath": pred_dst if os.path.exists(pred_dst) else pred_src, + "context_audio_filepath": context_dst if os.path.exists(context_dst) else context_src, + "predicted_codes_path": code_dst if os.path.exists(code_dst) else code_src, + } + ) + filewise_rows.append(metric_row) + + manifest_path = os.path.join(args.out_dir, "easy_magpie_generated_manifest.jsonl") + filewise_path = os.path.join(args.out_dir, "easy_magpie_generated_filewise_metrics.json") + write_jsonl(manifest_path, manifest_rows) + write_json(filewise_path, {"filewise_metrics": filewise_rows}) + + logging.info(f"Saved EasyMagpie-style generated audio dir to: {generated_audio_dir}") + logging.info(f"Saved EasyMagpie-style generated manifest to: {manifest_path}") + logging.info(f"Saved EasyMagpie-style generated filewise metrics to: {filewise_path}") + + return { + "generated_audio_dir": generated_audio_dir, + "manifest_path": manifest_path, + "filewise_metrics_path": filewise_path, + } + + +def save_easymagpie_style_eval_outputs(args, sample_rows: List[Dict[str, Any]], filewise_summary: Dict[str, Any]): + """Save CSV, plots, and generated-audio artifacts following EasyMagpie conventions.""" + easy_magpie_artifacts = write_easymagpie_generated_audio_dir(args, sample_rows) + filewise_summary["easy_magpie_generated_audio_dir"] = easy_magpie_artifacts["generated_audio_dir"] + filewise_summary["easy_magpie_generated_manifest"] = easy_magpie_artifacts["manifest_path"] + + checkpoint_name = get_checkpoint_name(args) + dataset_name = get_dataset_name(args) + + per_run_csv = os.path.join(args.out_dir, "all_experiment_metrics.csv") + append_metrics_to_csv(per_run_csv, checkpoint_name, dataset_name, filewise_summary) + + # Keep this alias because EasyMagpie aggregation scripts often look for the CI CSV. + ci_csv = os.path.join(args.out_dir, "all_experiment_metrics_with_ci.csv") + append_metrics_to_csv(ci_csv, checkpoint_name, dataset_name, filewise_summary) + + if not args.save_plots: + return + + violin_metrics = list(args.violin_plot_metrics) + if args.disable_utmosv2 and "utmosv2" in violin_metrics: + violin_metrics.remove("utmosv2") + + plot_dir = os.path.join(args.out_dir, "plots") + create_violin_plot_if_available( + sample_rows, + violin_metrics, + os.path.join(plot_dir, f"{dataset_name}_violin.png"), + ) + + # Also write in eval_dir root with the same style used by MagpieTTS: + # f"{dataset}_violin_{repeat_idx}.png". Here the merged final output is repeat 0. + create_violin_plot_if_available( + sample_rows, + violin_metrics, + os.path.join(args.out_dir, f"{dataset_name}_violin_0.png"), + ) + + +def wait_for_files(paths: List[str], timeout_sec: float = 7200.0, poll_sec: float = 5.0): + start = time.time() + while True: + missing = [p for p in paths if not os.path.exists(p)] + if not missing: + return + if time.time() - start > timeout_sec: + raise TimeoutError("Timed out waiting for files:\n" + "\n".join(missing)) + time.sleep(poll_sec) + + +def wait_for_rank_metric_files(args, world_size: int): + paths = [os.path.join(args.out_dir, f"metrics_rank{r:04d}.json") for r in range(world_size)] + wait_for_files(paths) + + +def wait_for_rank_filewise_metric_files(args, world_size: int): + paths = [os.path.join(args.out_dir, f"filewise_metrics_rank{r:04d}.jsonl") for r in range(world_size)] + wait_for_files(paths) + + +def scalarize_metric_value(v: Any): + if torch.is_tensor(v): + if v.numel() == 1: + return float(v.detach().cpu().item()) + return v.detach().cpu().tolist() + try: + import numpy as np + + if isinstance(v, np.generic): + return float(v.item()) + except Exception: + pass + if isinstance(v, (int, float, str, bool)) or v is None: + return v + return str(v) + + +def metric_dict_to_jsonable(d: Dict[str, Any]) -> Dict[str, Any]: + return {str(k): scalarize_metric_value(v) for k, v in d.items()} + + +def safe_metric_scalar(metric_dict: Dict[str, Any], preferred_keys: List[str]): + for key in preferred_keys: + if key in metric_dict: + value = metric_dict[key] + if torch.is_tensor(value): + return float(value.detach().cpu().item()) + return float(value) + return None + + +def get_first_metric(metrics: Dict[str, Any], names: List[str], default=None): + for name in names: + if name in metrics: + return metrics[name] + return default + + + +def format_final_metric_text(final_metrics: Dict[str, Any]) -> str: + intelligibility = final_metrics.get("intelligibility", {}) + speaker_similarity = final_metrics.get("speaker_similarity", {}) + + cer = get_first_metric(intelligibility, ["cer", "cer_dataset", "cer_cumulative"]) + wer = get_first_metric(intelligibility, ["wer", "wer_dataset", "wer_cumulative"]) + ssim_value = get_first_metric( + speaker_similarity, + ["ssim", "ssim_dataset", "ssim_pred_context_avg", "pred_context_ssim"], + ) + + def fmt(x): + if x is None: + return "nan" + try: + return f"{float(x):.10f}" + except Exception: + return str(x) + + return f"Average CER: {fmt(cer)}\nAverage WER: {fmt(wer)}\nSSIM: {fmt(ssim_value)}\n" + + + +def format_filewise_final_metric_text(filewise_summary: Dict[str, Any]) -> str: + def fmt(x): + if x is None: + return "nan" + try: + return f"{float(x):.10f}" + except Exception: + return str(x) + + ordered_keys = [ + ("cer", "CER filewise avg"), + ("wer", "WER filewise avg"), + ("cer_cumulative", "CER cumulative"), + ("wer_cumulative", "WER cumulative"), + ("ssim", "SSIM"), + ("ssim_pred_gt_avg", "SSIM pred/GT avg"), + ("ssim_pred_context_avg", "SSIM pred/context avg"), + ("ssim_gt_context_avg", "SSIM GT/context avg"), + ("ssim_pred_gt_avg_alternate", "SSIM pred/GT avg alternate"), + ("ssim_pred_context_avg_alternate", "SSIM pred/context avg alternate"), + ("ssim_gt_context_avg_alternate", "SSIM GT/context avg alternate"), + ("cer_gt_audio_cumulative", "CER GT-audio cumulative"), + ("wer_gt_audio_cumulative", "WER GT-audio cumulative"), + ("utmosv2_avg", "UTMOSv2 avg"), + ("total_gen_audio_seconds", "Total generated audio seconds"), + ("frechet_codec_distance", "Frechet codec distance"), + ("eou_cutoff_rate", "EOU cutoff rate"), + ("eou_silence_rate", "EOU silence rate"), + ("eou_noise_rate", "EOU noise rate"), + ("eou_error_rate", "EOU error rate"), + ] + + lines = [ + f"Average CER: {fmt(filewise_summary.get('cer'))}", + f"Average WER: {fmt(filewise_summary.get('wer'))}", + f"SSIM: {fmt(filewise_summary.get('ssim'))}", + ] + + for key, label in ordered_keys: + if key in {"cer", "wer", "ssim"}: + continue + if key in filewise_summary: + lines.append(f"{label}: {fmt(filewise_summary.get(key))}") + + return "\n".join(lines) + "\n" + + + +def write_filewise_csv(path: str, rows: List[Dict[str, Any]]): + """Write sample-level filewise metrics. + + Several fields are lists (turn_ids, reference_text, asr_hyp, cer_turns, + etc.), so they are JSON-encoded inside CSV cells. + """ + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + + fieldnames = [ + "run_id", + "dataset_index", + "rank", + "num_turns", + "cer", + "wer", + "ssim", + "pred_gt_ssim", + "pred_context_ssim", + "gt_context_ssim", + "pred_gt_ssim_alternate", + "pred_context_ssim_alternate", + "gt_context_ssim_alternate", + "utmosv2", + "eou_error", + "turn_ids", + "cer_turns", + "wer_turns", + "ssim_turns", + "pred_gt_ssim_turns", + "pred_context_ssim_turns", + "gt_context_ssim_turns", + "pred_gt_ssim_alternate_turns", + "pred_context_ssim_alternate_turns", + "gt_context_ssim_alternate_turns", + "utmosv2_turns", + "eou_type_turns", + "eou_trailing_duration_turns", + "eou_trail_rms_ratio_turns", + "pred_audio_seconds_turns", + "target_audio_path", + "context_audio_path", + "pred_audio_paths", + "predicted_codes_paths", + "sample_pred_audio_path", + "sample_predicted_codes_path", + "reference_text", + "asr_hyp", + ] + + def csv_value(v): + if isinstance(v, (list, dict)): + return json.dumps(v, ensure_ascii=False) + return v + + with open(tmp_path, "w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow({k: csv_value(row.get(k, None)) for k in fieldnames}) + + os.replace(tmp_path, path) + + + +def write_turnwise_csv(path: str, rows: List[Dict[str, Any]]): + """Write merged turn-level filewise metrics sorted by CER.""" + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp_path = path + ".tmp" + + fieldnames = [ + "run_id", + "dataset_index", + "turn_id", + "rank", + "cer", + "wer", + "ssim", + "pred_gt_ssim", + "pred_context_ssim", + "gt_context_ssim", + "pred_gt_ssim_alternate", + "pred_context_ssim_alternate", + "gt_context_ssim_alternate", + "utmosv2", + "eou_type", + "eou_trailing_duration", + "eou_trail_rms_ratio", + "pred_audio_seconds", + "target_audio_path", + "context_audio_path", + "pred_audio_path", + "predicted_codes_path", + "sample_pred_audio_path", + "sample_predicted_codes_path", + "reference_text", + "asr_hyp", + ] + + def csv_value(v): + if isinstance(v, (list, dict)): + return json.dumps(v, ensure_ascii=False) + return v + + with open(tmp_path, "w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow({k: csv_value(row.get(k, None)) for k in fieldnames}) + + os.replace(tmp_path, path) + + +# ----------------------------- +# Dataset helpers +# ----------------------------- + + +def _combined_audio_name(first_audio_filepath: str, paths: List[str]) -> str: + base_names = [os.path.splitext(os.path.basename(p))[0] for p in paths if p] + ext = os.path.splitext(paths[-1])[1] if paths and paths[-1] else "" + combined_name = "_".join(base_names) + ext + dir_name = os.path.dirname(first_audio_filepath) + return os.path.join(dir_name, combined_name) if dir_name else combined_name class EvalJSONLDataset(Dataset): - def __init__(self, file_path, num_turns=1): + def __init__(self, file_path: str, emulate_multiturn_num_turns: int = 1): self.samples = [] raw_samples = [] + with open(file_path, "r", encoding="utf-8") as f: for line_idx, line in enumerate(f, 1): line = line.strip() - if not line: continue + if not line: + continue try: - raw_samples.append(json.loads(line)) + sample = json.loads(line) + sample["__dataset_index__"] = len(raw_samples) + raw_samples.append(sample) except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON on line {line_idx}: {e}") - if num_turns <= 1: + if emulate_multiturn_num_turns <= 1: self.samples = raw_samples return @@ -140,45 +733,40 @@ def __init__(self, file_path, num_turns=1): if isinstance(sample["text"], list): self.samples.append(sample) else: - speaker = sample.get("speaker", "unknown") - if speaker not in single_turn_by_speaker: - single_turn_by_speaker[speaker] = [] - single_turn_by_speaker[speaker].append(sample) + speaker = sample.get("speaker", "unknown") + single_turn_by_speaker.setdefault(speaker, []).append(sample) - for speaker, speaker_samples in single_turn_by_speaker.items(): + synthetic_index = len(raw_samples) + for _, speaker_samples in single_turn_by_speaker.items(): buffer_texts, buffer_paths = [], [] first_sample_meta = None for sample in speaker_samples: if not buffer_texts: first_sample_meta = dict(sample) + buffer_texts.append(sample["text"]) buffer_paths.append(sample.get("audio_filepath", "")) - if len(buffer_texts) == num_turns: + if len(buffer_texts) == emulate_multiturn_num_turns: first_sample_meta["text"] = buffer_texts - base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] - ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" - combined_name = "_".join(base_names) + ext - dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) - if dir_name: - first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) - else: - first_sample_meta["audio_filepath"] = combined_name - + first_sample_meta["audio_filepath"] = _combined_audio_name( + first_sample_meta.get("audio_filepath", ""), + buffer_paths, + ) + first_sample_meta["__dataset_index__"] = synthetic_index + synthetic_index += 1 self.samples.append(first_sample_meta) buffer_texts, buffer_paths, first_sample_meta = [], [], None - if buffer_texts: + if buffer_texts and first_sample_meta is not None: first_sample_meta["text"] = buffer_texts - base_names = [os.path.splitext(os.path.basename(p))[0] for p in buffer_paths if p] - ext = os.path.splitext(buffer_paths[-1])[1] if buffer_paths[-1] else "" - combined_name = "_".join(base_names) + ext - dir_name = os.path.dirname(first_sample_meta.get("audio_filepath", "")) - if dir_name: - first_sample_meta["audio_filepath"] = os.path.join(dir_name, combined_name) - else: - first_sample_meta["audio_filepath"] = combined_name + first_sample_meta["audio_filepath"] = _combined_audio_name( + first_sample_meta.get("audio_filepath", ""), + buffer_paths, + ) + first_sample_meta["__dataset_index__"] = synthetic_index + synthetic_index += 1 self.samples.append(first_sample_meta) def __len__(self): @@ -188,6 +776,53 @@ def __getitem__(self, idx): return self.samples[idx] +def _sample_text_segments_for_count(sample: Dict[str, Any], max_eval_turns=None) -> List[str]: + text_data = sample.get("text", "") + if isinstance(text_data, list): + segments = text_data + if max_eval_turns is not None: + segments = segments[: int(max_eval_turns)] + return [str(x) for x in segments] + return [str(text_data)] + + +def estimate_text_token_count(sample: Dict[str, Any], model, max_eval_turns=None) -> int: + main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] + total = 0 + for segment in _sample_text_segments_for_count(sample, max_eval_turns=max_eval_turns): + total += len(model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name)) + 1 + return int(total) + + +class SortedByTextTokenCountDataset(Dataset): + def __init__(self, dataset: Dataset, model, max_eval_turns=None, descending: bool = True): + self.dataset = dataset + scored = [] + for i in range(len(dataset)): + sample = dict(dataset[i]) + token_count = estimate_text_token_count(sample, model=model, max_eval_turns=max_eval_turns) + sample["__text_token_count__"] = int(token_count) + scored.append((token_count, i, sample)) + + scored.sort(key=lambda x: (x[0], -x[1]), reverse=bool(descending)) + self.indices = [i for _, i, _ in scored] + self.token_counts = {i: int(tok) for tok, i, _ in scored} + + def __len__(self): + return len(self.indices) + + def __getitem__(self, local_idx): + original_idx = self.indices[local_idx] + sample = dict(self.dataset[original_idx]) + sample["__text_token_count__"] = self.token_counts[original_idx] + return sample + + +# ----------------------------- +# Audio / collate helpers +# ----------------------------- + + def _resolve_audio_path(path, root_path): if path is None: return None @@ -217,435 +852,303 @@ def _load_audio(path, sample_rate, normalize=True, use_librosa=False): return resample(wav, sr, sample_rate).squeeze(0) -def _json_metric_value(value): - if torch.is_tensor(value): - value = value.detach().cpu() - if value.numel() == 1: - return value.item() - return value.tolist() - return value - - -def _write_json(path, data): - output_dir = os.path.dirname(path) - if output_dir: - os.makedirs(output_dir, exist_ok=True) - with open(path, "w", encoding="utf-8") as f: - json.dump(data, f, indent=2, ensure_ascii=False) - - -def _write_filewise_metrics(path, filewise_metrics): - sorted_metrics = sorted(filewise_metrics, key=lambda row: row.get("cer", float("-inf")), reverse=True) - _write_json(path, sorted_metrics) - - -def _compute_secs_per_sample(secs_metric, name, target_audio, target_audio_lens, pred_audio, pred_audio_lens): - if secs_metric.speaker_encoder is None: - secs_metric.reset() - - with fp32_precision(): - with torch.no_grad(): - _, t_g = secs_metric.speaker_encoder( - input_signal=target_audio, input_signal_length=target_audio_lens.long() - ) - _, s_g = secs_metric.speaker_encoder( - input_signal=pred_audio, input_signal_length=pred_audio_lens.long() - ) - secs = torch.nn.functional.cosine_similarity(t_g, s_g, dim=-1) - - secs_metric._secs[name].append(secs.mean()) - return secs.detach().cpu() def collate_and_tokenize_custom( batch, model, - extra_duration_thrshould=1.3, sample_rate=22050, root_path=None, - emulate_duplex_inference=False, - add_interruption_token=False, - pad_factor_text_speech=10, - force_interruption=False, normalize_audio_volume=True, use_librosa=False, - profile_multiturn_inference=False, max_eval_turns=None, + inference_mode="auto", ): + """Collate for either multiturn-user-audio or regular single-turn inference. + + Mode selection: + - multiturn_user_audio: turn-based multiturn user-audio prefill with user_audio_file_path. + - single_turn: regular batched TTS, no user-speech/silence prefill. + - auto: multiturn_user_audio when samples look multiturn/user-conditioned; otherwise + single_turn. This keeps old LibriTTS commands working with batch_size=32. + """ main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] - + if max_eval_turns is not None: max_eval_turns = int(max_eval_turns) if max_eval_turns <= 0: raise ValueError("--max_eval_turns must be > 0 when provided.") - truncated_batch = [] - for s in batch: - s = dict(s) - - if isinstance(s["text"], list): - s["text"] = s["text"][:max_eval_turns] - - if isinstance(s.get("user_audio_file_path"), list): - s["user_audio_file_path"] = s["user_audio_file_path"][:max_eval_turns] - - truncated_batch.append(s) - + for sample in batch: + sample = dict(sample) + if isinstance(sample["text"], list): + sample["text"] = sample["text"][:max_eval_turns] + if isinstance(sample.get("user_audio_file_path"), list): + sample["user_audio_file_path"] = sample["user_audio_file_path"][:max_eval_turns] + truncated_batch.append(sample) batch = truncated_batch - # --- MULTI-TURN MODE DECISION --- - is_profile = profile_multiturn_inference - is_duplex = emulate_duplex_inference and not is_profile + def looks_multiturn_user_audio(sample): + return isinstance(sample.get("text"), list) or bool(sample.get("user_audio_file_path", None)) + + if inference_mode == "multiturn_user_audio": + is_multiturn_user_audio = True + elif inference_mode == "single_turn": + is_multiturn_user_audio = False + elif inference_mode == "auto": + is_multiturn_user_audio = any(looks_multiturn_user_audio(sample) for sample in batch) + else: + raise ValueError(f"Unknown inference_mode={inference_mode}") out_dict = { - "duplex_multiturn": is_duplex, - "regular_multiturn": (not is_duplex) and (not is_profile), - "profile_multiturn": is_profile, + "multiturn_user_audio": bool(is_multiturn_user_audio), + "dataset_indices": [int(s.get("__dataset_index__", -1)) for s in batch], } - tokenized_list = [] - batched_turns = [] - batched_turn_lens = [] - valid_turn_masks = [] - - if is_duplex: - # ------------------------------------------------------------- - # DUPLEX MODE (Continuous sequence with 10x pad injection) - # ------------------------------------------------------------- - for s in batch: - text_data = s["text"] - - if isinstance(text_data, list): - full_ids = [] - for segment in text_data: - seg_ids = model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name) + [model.eos_id] - seg_len = len(seg_ids) - pad_len = seg_len * pad_factor_text_speech - pad_ids = [model.pad_id] * pad_len - - if force_interruption: - fname = s["audio_filepath"] - no_ext = fname.split(".")[0] - sample_id = int(no_ext.split("_")[-1]) - case = sample_id % 3 - - if case == 0: - if len(seg_ids) >= 2: - seg_ids[-2] = model.interruption_token_id - seg_ids[-1] = model.pad_id - else: - pad_ids[0] = model.interruption_token_id - elif case == 1: - eos_idx = min(6, len(pad_ids) - 1) - pad_ids[eos_idx] = model.interruption_token_id - else: - eos_idx = 0 - pad_ids[eos_idx] = model.interruption_token_id - else: - if add_interruption_token: - eos_idx = int(len(pad_ids) * 0.7) - pad_ids[eos_idx] = model.interruption_token_id - - full_ids.extend(seg_ids) - full_ids.extend(pad_ids) + if is_multiturn_user_audio: + max_turns = 1 + for sample in batch: + if isinstance(sample["text"], list): + max_turns = max(max_turns, len(sample["text"])) - tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long)) + raw_turn_texts = [] + for sample in batch: + if isinstance(sample["text"], list): + raw_turn_texts.append([str(x) for x in sample["text"]]) else: - tokenized_list.append( - torch.as_tensor(model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id], dtype=torch.long) - ) - - pad_len = 25 - prefix = torch.full((pad_len,), model.pad_id, dtype=torch.long) - for i in range(len(tokenized_list)): - tokenized_list[i] = torch.cat([prefix, tokenized_list[i]]) - input_lengths = torch.tensor([len(x) for x in tokenized_list], dtype=torch.long) - - input_ids = pad_sequence(tokenized_list, batch_first=True, padding_value=model.pad_id) - - out_dict["input_ids"] = input_ids - out_dict["input_lengths"] = input_lengths - - else: - # ------------------------------------------------------------- - # REGULAR MODE (Turn-by-turn discrete packaging) - # ------------------------------------------------------------- - max_turns = 1 - for s in batch: - if isinstance(s["text"], list): - max_turns = max(max_turns, len(s["text"])) - - for t in range(max_turns): - turn_t_tokens = [] - turn_t_lens = [] - turn_t_valid = [] - - for s in batch: - text_data = s["text"] + raw_turn_texts.append([str(sample["text"])]) + + batched_turns = [] + batched_turn_lens = [] + valid_turn_masks = [] + + for turn_id in range(max_turns): + turn_tokens = [] + turn_lens = [] + turn_valid = [] + for sample in batch: + text_data = sample["text"] if isinstance(text_data, list): - if t < len(text_data): - seg_ids = model.tokenizer.encode(text_data[t], tokenizer_name=main_tokenizer_name) + [model.eos_id] - turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) - turn_t_lens.append(len(seg_ids)) - turn_t_valid.append(True) + if turn_id < len(text_data): + seg_ids = model.tokenizer.encode(text_data[turn_id], tokenizer_name=main_tokenizer_name) + [ + model.eos_id + ] + turn_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) + turn_lens.append(len(seg_ids)) + turn_valid.append(True) else: - # Dummy pad to keep shapes consistent for items with fewer turns - turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) - turn_t_lens.append(1) - turn_t_valid.append(False) + turn_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) + turn_lens.append(1) + turn_valid.append(False) else: - if t == 0: + if turn_id == 0: seg_ids = model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id] - turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) - turn_t_lens.append(len(seg_ids)) - turn_t_valid.append(True) + turn_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) + turn_lens.append(len(seg_ids)) + turn_valid.append(True) else: - turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) - turn_t_lens.append(1) - turn_t_valid.append(False) - - padded_turn_t = pad_sequence(turn_t_tokens, batch_first=True, padding_value=model.pad_id) - batched_turns.append(padded_turn_t) - batched_turn_lens.append(torch.tensor(turn_t_lens, dtype=torch.long)) - valid_turn_masks.append(torch.tensor(turn_t_valid, dtype=torch.bool)) - - out_dict["batched_turns"] = batched_turns - out_dict["batched_turn_lens"] = batched_turn_lens - out_dict["valid_turn_masks"] = valid_turn_masks + turn_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) + turn_lens.append(1) + turn_valid.append(False) - # --- AUDIO LOADING --- - audio_list = [] - audio_lengths = [] - target_num_frames = [] + batched_turns.append(pad_sequence(turn_tokens, batch_first=True, padding_value=model.pad_id)) + batched_turn_lens.append(torch.tensor(turn_lens, dtype=torch.long)) + valid_turn_masks.append(torch.tensor(turn_valid, dtype=torch.bool)) - max_turns_for_user_audio = len(batched_turns) if (not is_duplex) else 0 + user_audio_by_turn = [[] for _ in range(max_turns)] + user_audio_lens_by_turn = [[] for _ in range(max_turns)] - if is_profile and max_turns_for_user_audio > 0: - user_audio_by_turn = [[] for _ in range(max_turns_for_user_audio)] - user_audio_lens_by_turn = [[] for _ in range(max_turns_for_user_audio)] else: + # Single-turn regular inference: one text segment per sample, batched. + raw_turn_texts = [] + single_turn_tokens = [] + single_turn_lens = [] + for sample in batch: + text_data = sample["text"] + if isinstance(text_data, list): + text = " ".join(str(x) for x in text_data) + else: + text = str(text_data) + raw_turn_texts.append([text]) + seg_ids = model.tokenizer.encode(text, tokenizer_name=main_tokenizer_name) + [model.eos_id] + single_turn_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) + single_turn_lens.append(len(seg_ids)) + + out_dict["input_ids"] = pad_sequence(single_turn_tokens, batch_first=True, padding_value=model.pad_id) + out_dict["input_lengths"] = torch.tensor(single_turn_lens, dtype=torch.long) user_audio_by_turn = [] user_audio_lens_by_turn = [] - for i, s in enumerate(batch): - audio_path = _resolve_audio_path(s.get("context_audio_filepath"), root_path) - wav = _load_audio( - audio_path, - sample_rate, - normalize=normalize_audio_volume, - use_librosa=use_librosa, - ) - - audio_list.append(wav) - audio_lengths.append(len(wav)) + audio_list = [] + audio_lengths = [] - # Optional per-turn user audio. - # Expected JSONL field: - # "user_audio": ["turn0_user.wav", "turn1_user.wav", ...] - if is_profile and max_turns_for_user_audio > 0: - user_audio_paths = s.get("user_audio_file_path", None) + for i, sample in enumerate(batch): + context_path = _resolve_audio_path(sample.get("context_audio_filepath"), root_path) + context_wav = _load_audio(context_path, sample_rate, normalize=normalize_audio_volume, use_librosa=use_librosa) + audio_list.append(context_wav) + audio_lengths.append(len(context_wav)) - for t in range(max_turns_for_user_audio): + if is_multiturn_user_audio: + user_audio_paths = sample.get("user_audio_file_path", None) + for turn_id in range(len(user_audio_by_turn)): has_valid_text_turn = ( - isinstance(s["text"], list) and t < len(s["text"]) - ) or ( - not isinstance(s["text"], list) and t == 0 - ) + isinstance(sample["text"], list) and turn_id < len(sample["text"]) + ) or ((not isinstance(sample["text"], list)) and turn_id == 0) if ( isinstance(user_audio_paths, list) - and t < len(user_audio_paths) - and user_audio_paths[t] + and turn_id < len(user_audio_paths) + and user_audio_paths[turn_id] and has_valid_text_turn ): - ua_path = _resolve_audio_path(user_audio_paths[t], root_path) - ua_wav = _load_audio( - ua_path, + user_path = _resolve_audio_path(user_audio_paths[turn_id], root_path) + user_wav = _load_audio( + user_path, sample_rate=sample_rate, normalize=normalize_audio_volume, use_librosa=use_librosa, ) else: - print("User audio not founded, using silence two seconds audio") - ua_wav = torch.zeros(int(2 * sample_rate), dtype=torch.float32) - - user_audio_by_turn[t].append(ua_wav) - user_audio_lens_by_turn[t].append(len(ua_wav)) + user_wav = torch.zeros(int(2 * sample_rate), dtype=torch.float32) - tdur_audio_path = _resolve_audio_path(s["audio_filepath"], root_path) - - if tdur_audio_path and os.path.exists(tdur_audio_path): - wav_dur = _load_audio(tdur_audio_path, sample_rate, normalize=normalize_audio_volume, use_librosa=use_librosa) - tdur = wav_dur.shape[0] // model.input_samples_per_frame - target_num_frames.append(tdur * extra_duration_thrshould) - else: - # Fallback estimation - if is_duplex: - current_text_len = len(tokenized_list[i]) - if isinstance(s["text"], list): - target_num_frames.append(current_text_len) - else: - target_num_frames.append(current_text_len * 5) - else: - target_num_frames.append(sum([l[i].item() for l in batched_turn_lens]) * 5) + user_audio_by_turn[turn_id].append(user_wav) + user_audio_lens_by_turn[turn_id].append(len(user_wav)) max_audio_len = max(audio_lengths) - B = len(audio_lengths) - padded_audio = torch.zeros((B, max_audio_len), dtype=torch.float32) - + batch_size = len(audio_lengths) + padded_audio = torch.zeros((batch_size, max_audio_len), dtype=torch.float32) for i, wav in enumerate(audio_list): padded_audio[i, : len(wav)] = wav - - if is_profile and max_turns_for_user_audio > 0: + if is_multiturn_user_audio: padded_user_audio_turns = [] - padded_user_audio_turns_lens = [] - - for t in range(max_turns_for_user_audio): - turn_lens = user_audio_lens_by_turn[t] + padded_user_audio_turn_lens = [] + for turn_id in range(len(user_audio_by_turn)): + turn_lens = user_audio_lens_by_turn[turn_id] max_turn_audio_len = max(turn_lens) - padded_turn_audio = torch.zeros((B, max_turn_audio_len), dtype=torch.float32) - - for i, wav in enumerate(user_audio_by_turn[t]): + padded_turn_audio = torch.zeros((batch_size, max_turn_audio_len), dtype=torch.float32) + for i, wav in enumerate(user_audio_by_turn[turn_id]): padded_turn_audio[i, : len(wav)] = wav - padded_user_audio_turns.append(padded_turn_audio) - padded_user_audio_turns_lens.append(torch.tensor(turn_lens, dtype=torch.long)) + padded_user_audio_turn_lens.append(torch.tensor(turn_lens, dtype=torch.long)) + out_dict["batched_turns"] = batched_turns + out_dict["batched_turn_lens"] = batched_turn_lens + out_dict["valid_turn_masks"] = valid_turn_masks out_dict["user_audio_turns"] = padded_user_audio_turns - out_dict["user_audio_turns_lens"] = padded_user_audio_turns_lens + out_dict["user_audio_turns_lens"] = padded_user_audio_turn_lens out_dict["context_audio"] = padded_audio out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) - out_dict["context_audio_paths"] = [s.get("context_audio_filepath") for s in batch] out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] - out_dict["target_num_frames"] = target_num_frames - - out_dict["raw_text"] = [" ".join(s["text"]) if isinstance(s["text"], list) else s["text"] for s in batch] + out_dict["raw_text"] = [" ".join(x) for x in raw_turn_texts] + out_dict["raw_turn_texts"] = raw_turn_texts return out_dict -def _mix_user_turns_on_timeline(user_audio_turns, user_audio_turns_lens, sample_rate): - total_len = int(sum(x.item() for x in user_audio_turns_lens)) - mixed = torch.zeros(total_len, dtype=torch.float32) - - offset = 0 - for wav, wav_len in zip(user_audio_turns, user_audio_turns_lens): - wav_len = int(wav_len.item()) - mixed[offset : offset + wav_len] = wav[:wav_len] - offset += wav_len - - return mixed - -def main(): - parser = argparse.ArgumentParser(description="EasyMagpieTTS Inference Evaluation") - - # Required Paths - parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the EasyMagpie model") - parser.add_argument("--codec_model_path", type=str, required=True, help="Path to the audio codec") - parser.add_argument("--datasets_json_path", type=str, required=True, help="Path to JSONL data") - parser.add_argument("--out_dir", type=str, required=True, help="Directory to save audio outputs") - - # Optional Paths & General - parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) - parser.add_argument("--audio_dir", type=str, default=None, help="Root dir for audio paths in JSONL") - parser.add_argument("--inference_dtype", type=str, default="float32") - parser.add_argument("--debug_dtype", action="store_true") - parser.add_argument("--use_librosa", action="store_true", help="Use librosa instead of soundfile+torch for audio load") - parser.add_argument( - "--filewise_metrics_path", - type=str, - default=None, - help="Path to save per-file metrics JSON. Defaults to /filewise_metrics.json", - ) - parser.add_argument( - "--aggregate_metrics_path", - type=str, - default=None, - help="Path to save aggregate metrics JSON. Defaults to /aggregate_metrics.json", - ) - - # Dataloader & Batching - parser.add_argument("--batch_size", type=int, default=6) - parser.add_argument("--num_workers", type=int, default=4) - parser.add_argument("--num_turns", type=int, default=1) - parser.add_argument("--pad_factor_text_speech", type=int, default=10) - - # Text Processing Boolean Flags - parser.add_argument("--emulate_duplex_inference", action="store_true") - parser.add_argument("--add_interruption_token", action="store_true") - parser.add_argument("--force_interruption", action="store_true") - parser.add_argument("--profile_multiturn_inference", action="store_true") - parser.add_argument("--profile_pad_min_sec", type=float, default=2.0) - parser.add_argument("--profile_pad_max_sec", type=float, default=2.0) - parser.add_argument( - "--max_eval_turns", - type=int, - default=6, - help="Maximum number of turns to evaluate per sample. None means use all turns.", - ) - - # Speaker & Prompt Configurations - parser.add_argument("--user_custom_speaker_reference", action="store_true") - parser.add_argument("--inference_speaker_reference", type=str, default=None) - parser.add_argument("--language", type=str, default="en") - - # Generation Kwargs - parser.add_argument("--use_cfg", action="store_true") - parser.add_argument("--cfg_scale", type=float, default=2.5) - parser.add_argument("--temperature", type=float, default=0.7) - parser.add_argument("--topk", type=int, default=80) - parser.add_argument("--max_tts_steps", type=int, default=2000) - parser.add_argument("--force_speech_sil_codes", action="store_true") - parser.add_argument("--normalize_volume", type=lambda x: (str(x).lower() in ['true', '1', 'yes']), default=True) +# ----------------------------- +# Model / generation +# ----------------------------- - args = parser.parse_args() - if args.profile_multiturn_inference and args.batch_size != 1: - raise RuntimeError("--profile_multiturn_inference currently requires --batch_size=1.") - - if args.profile_pad_max_sec < args.profile_pad_min_sec: - raise RuntimeError("--profile_pad_max_sec must be >= --profile_pad_min_sec.") +def attach_dtype_counter(model): + handles = [] + stats = {} + examples = {} - filewise_metrics_path = args.filewise_metrics_path or os.path.join(args.out_dir, "filewise_metrics.json") - aggregate_metrics_path = args.aggregate_metrics_path or os.path.join(args.out_dir, "aggregate_metrics.json") + def is_leaf(module): + return len(list(module.children())) == 0 - distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1 - if distributed and not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + def get_dtype(x): + if torch.is_tensor(x): + return str(x.dtype) + if isinstance(x, (list, tuple)): + for t in x: + if torch.is_tensor(t): + return str(t.dtype) + return "other" - target_device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}" if torch.cuda.is_available() else "cpu") - target_dtype = getattr(torch, args.inference_dtype) - torch.set_default_dtype(target_dtype) + def get_module_group(name): + return name.split(".")[0] if "." in name else name - model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) - with open_dict(model_cfg): - model_cfg.target = 'nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel' - model_cfg.codecmodel_path = args.codec_model_path - model_cfg.train_ds = None - model_cfg.validation_ds = None - model_cfg.run_val_inference = False + def hook_fn(name): + def fn(module, inputs, outputs): + dtype = get_dtype(outputs) + if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: + dtype = "other" + group = get_module_group(name) + if group not in stats: + stats[group] = { + "torch.float16": 0, + "torch.bfloat16": 0, + "torch.float32": 0, + "other": 0, + } + examples[group] = { + "torch.float16": [], + "torch.bfloat16": [], + "torch.float32": [], + "other": [], + } + stats[group][dtype] += 1 + if len(examples[group][dtype]) < 3: + examples[group][dtype].append(module.__class__.__name__) + + return fn + + for name, module in model.named_modules(): + if is_leaf(module): + handles.append(module.register_forward_hook(hook_fn(name))) + return handles, stats, examples + + +def report_dtype_stats(handles, stats, examples, rank=0): + for h in handles: + h.remove() + logging.info(f"[rank={rank}] === DTYPE USAGE PER MODULE ===") + for group, group_stats in stats.items(): + total = sum(group_stats.values()) + if total == 0: + continue + logging.info(f"[rank={rank}] --- {group} ---") + for dtype, count in group_stats.items(): + if count > 0: + logging.info(f"[rank={rank}] {dtype}: {count} ({100 * count / total:.2f}%)") + logging.info(f"[rank={rank}] === DTYPE EXAMPLES ===") + for group, group_examples in examples.items(): + for dtype, mods in group_examples.items(): + if mods: + logging.info(f"[rank={rank}] {group} {dtype}: {mods}") + + +def build_model_and_codec(args, target_device, target_dtype): + model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) + + with open_dict(model_cfg): + model_cfg.target = "nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel" + model_cfg.codecmodel_path = args.codec_model_path + model_cfg.train_ds = None + model_cfg.validation_ds = None + model_cfg.run_val_inference = False model_cfg.use_utmos = False model_cfg.use_meta_init_for_decoder = True - # Guarantees silence for pad tokens - # model_cfg.use_multiturn_dataset = True + if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path - # Load to CPU first to prevent OOM model = EasyMagpieTTSInferenceModel.restore_from( - args.checkpoint_path, override_config_path=model_cfg, map_location=torch.device("cpu") + args.checkpoint_path, + override_config_path=model_cfg, + map_location=torch.device("cpu"), ) model.use_kv_cache_for_inference = True model.to(dtype=target_dtype) model.eval().to(target_device) - # --- DATALOADER COMPATIBILITY PATCHES --- model.input_samples_per_frame = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) model.target_samples_per_frame = model.input_samples_per_frame / (model.sample_rate / model.output_sample_rate) - # Load to CPU first to prevent OOM codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=torch.device("cpu")) if hasattr(codec_model, "discriminator"): del codec_model.discriminator @@ -660,741 +1163,2077 @@ def main(): vector_quantizer_new=vq_new, ).to(target_device).eval() - if not hasattr(model, "_codec_helper") or model._codec_helper is None: - model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) - + model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) model._generate_codec_silence_buffer() - codec_sil_codes = model.codec_sil_codes - if args.debug_dtype: - handles, stats, examples = attach_dtype_counter(model) + return model - with fp32_precision(): - intelligibility = Intelligibility("stt_en_fastconformer_transducer_large", reuse_asr_hyps=False).reset() - secs_metric = SECS("titanet_large").reset() - eval_dataset = EvalJSONLDataset(args.datasets_json_path, num_turns=args.num_turns) - # debug - # eval_dataset.samples = eval_dataset.samples[:100] +def prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=None): + B = inputs["context_audio"].size(0) + device = model.device - collate_fn = partial( - collate_and_tokenize_custom, - model=model, - extra_duration_thrshould=1.5, - sample_rate=model.sample_rate, - root_path=args.audio_dir, - emulate_duplex_inference=args.emulate_duplex_inference, - add_interruption_token=args.add_interruption_token, - pad_factor_text_speech=args.pad_factor_text_speech, - force_interruption=args.force_interruption, - normalize_audio_volume=args.normalize_volume, - use_librosa=args.use_librosa, - profile_multiturn_inference=args.profile_multiturn_inference, - max_eval_turns=args.max_eval_turns, - ) + inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) + inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) - dataloader = DataLoader( - dataset=eval_dataset, batch_size=args.batch_size, collate_fn=collate_fn, - num_workers=args.num_workers, pin_memory=True, shuffle=False, drop_last=False, - ) + if args.user_custom_speaker_reference and speaker_wav is not None: + inputs["context_audio"] = speaker_wav.repeat(B, 1).detach() + inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) + + if "user_audio_turns" in inputs: + inputs["user_audio_turns"] = [x.to(device, dtype=target_dtype) for x in inputs["user_audio_turns"]] + inputs["user_audio_turns_lens"] = [x.to(device) for x in inputs["user_audio_turns_lens"]] + + return inputs + + + + +def run_single_turn_generation(model, inputs, args): + """Regular batched single-turn EasyMagpieTTS generation. + + This path does not prefill with user speech or synthetic silence. It is for + classic single-turn datasets such as LibriTTS and supports batch_size > 1. + """ + B = inputs["context_audio"].size(0) + device = model.device + + with torch.inference_mode(): + wav = inputs["context_audio"] + wav_len = inputs["context_audio_lengths"] + codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) + + use_lang = bool(getattr(model, "add_language_to_context_text", False)) + ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" + ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) + ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) + ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) + + state = model.streaming_init( + context_audio_codes=codes, + context_audio_codes_lens=codes_lens, + context_text_tokens=ctx_toks, + context_text_tokens_lens=ctx_toks_lens, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + use_local_transformer=True, + temperature=args.temperature, + topk=args.topk, + phoneme_input_type="pred", + phoneme_sampling_method="argmax", + use_inference_mode=True, + ) + + text = inputs["input_ids"].to(device) + text_lens = inputs["input_lengths"].to(device) + + turn_offsets = torch.zeros(B, dtype=torch.long, device=device) + turn_steps = 0 + + while not state.finished.all() and turn_steps < args.max_tts_steps: + turn_steps += 1 + relative_positions = state.text_tokens_seen - turn_offsets + positions = relative_positions.clamp(min=0, max=text.size(1) - 1) + current_tokens = text[torch.arange(B, device=device), positions] + exhausted = relative_positions >= text_lens + current_tokens = torch.where( + exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) + state, _, _ = model.streaming_step( + state=state, + text_tokens=current_tokens, + use_inference_mode=True, + ) + + generated_codes = None + if getattr(state, "all_predictions", None): + try: + generated_codes = torch.cat(state.all_predictions, dim=-1).detach() + except Exception: + generated_codes = None + + finalize_output = model.streaming_finalize(state, use_inference_mode=True) + + # For single turn, there is one generated segment per sample and no + # multiturn frame alignment needed. + return finalize_output, [], 0, generated_codes + +def run_generation(model, inputs, args, codec_sil_codes): + """Run either multiturn-user-audio or regular single-turn generation.""" + if not inputs.get("multiturn_user_audio", False): + return run_single_turn_generation(model, inputs, args) + + B = inputs["context_audio"].size(0) + if B != 1: + raise RuntimeError("Multiturn user-audio inference requires --batch_size=1 per process.") + + device = model.device + multiturn_turn_frame_ranges = [] + multiturn_decode_start_frame = 0 + + with torch.inference_mode(): + wav = inputs["context_audio"] + wav_len = inputs["context_audio_lengths"] + codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) + + use_lang = bool(getattr(model, "add_language_to_context_text", False)) + ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" + ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) + ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) + ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) + + state = model.streaming_init( + context_audio_codes=codes, + context_audio_codes_lens=codes_lens, + context_text_tokens=ctx_toks, + context_text_tokens_lens=ctx_toks_lens, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + use_local_transformer=True, + temperature=args.temperature, + topk=args.topk, + phoneme_input_type="pred", + phoneme_sampling_method="argmax", + use_inference_mode=True, + ) + + batched_turns = inputs["batched_turns"] + batched_turn_lens = inputs["batched_turn_lens"] + valid_turn_masks = inputs["valid_turn_masks"] + + for turn_id in range(len(batched_turns)): + turn_text = batched_turns[turn_id].to(device) + turn_lens = batched_turn_lens[turn_id].to(device) + valid_mask = valid_turn_masks[turn_id].to(device) + if not bool(valid_mask[0].item()): + continue + + state.finished.zero_() + state.text_finished.zero_() + state.audio_prediction_end_idx.fill_(-1) + if hasattr(state, "turn_text_tokens_seen"): + state.turn_text_tokens_seen.zero_() + if hasattr(state, "phoneme_steps"): + state.phoneme_steps.zero_() + if hasattr(state, "phoneme_stream_ended"): + state.phoneme_stream_ended.zero_() + if hasattr(state, "phoneme_eos_detected"): + state.phoneme_eos_detected.zero_() + state.last_phoneme_tokens = None + + if not model.cfg.get("condition_on_user_speech", False): + user_audio = inputs["user_audio_turns"][turn_id] + user_audio_prefill_steps = int(round(user_audio.size(-1) / model.input_samples_per_frame)) + user_audio_prefill_seconds = user_audio_prefill_steps * model.input_samples_per_frame / model.sample_rate + user_audio_prefill_tokens = torch.full((1, user_audio_prefill_steps), model.pad_id, dtype=torch.long, device=device) + user_audio_channel_embedding = None + else: + user_audio = inputs["user_audio_turns"][turn_id] + user_audio_lens = inputs["user_audio_turns_lens"][turn_id] + user_audio_codes, user_audio_codes_lens = model._codec_helper.audio_to_codes(user_audio, user_audio_lens) + + if model._codec_converter is not None: + user_audio_codes = model._codec_converter.convert_original_to_new( + audio_tokens=user_audio_codes, + audio_lens=user_audio_codes_lens, + ).long() + + user_audio_codes, user_audio_codes_lens = model.stack_codes( + user_audio_codes, + user_audio_codes_lens, + model.audio_bos_id, + model.audio_eos_id, + model.frame_stacking_factor, + model.num_audio_codebooks, + ) + user_audio_embedded = model.embed_audio_tokens(user_audio_codes) + + boundary_trim = model.cfg.get("user_audio_boundary_trim", 4) + boundary_trim = 0 if boundary_trim is None else int(boundary_trim) + if boundary_trim == 0: + real_start = 0 + real_end = int(user_audio_codes_lens[0].item()) + else: + turn_len_with_special = int(user_audio_codes_lens[0].item()) + real_start = 1 + real_end = max(real_start, turn_len_with_special - 1) + + user_audio_embedded = user_audio_embedded[:, real_start:real_end] + copy_len = user_audio_embedded.size(1) + if boundary_trim > 0: + trim = min(boundary_trim, copy_len // 2) + if trim > 0: + user_audio_embedded[:, :trim] = 0.0 + user_audio_embedded[:, copy_len - trim :] = 0.0 + + bos_user_pad = torch.zeros( + user_audio_embedded.size(0), + 1, + user_audio_embedded.size(2), + device=user_audio_embedded.device, + dtype=user_audio_embedded.dtype, + ) + user_audio_embedded = torch.cat([bos_user_pad, user_audio_embedded], dim=1) + user_audio_prefill_steps = user_audio_embedded.size(1) + user_audio_prefill_tokens = torch.full((B, user_audio_prefill_steps), model.pad_id, dtype=torch.long, device=device) + user_audio_channel_embedding = user_audio_embedded + user_audio_prefill_seconds = user_audio_prefill_steps * model.input_samples_per_frame / model.sample_rate + + delay_tokens = int(state.config.training_mode.streaming_speech_delay) + delay_tokens = min(delay_tokens, int(turn_lens[0].item()), user_audio_prefill_steps) + + warmup_tokens = turn_text[:, :delay_tokens] + turn_text = turn_text[:, delay_tokens:] + turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) + + if user_audio_channel_embedding is not None and delay_tokens > 0: + warmup_user_audio = user_audio_channel_embedding[:, -delay_tokens:] + user_audio_channel_embedding = user_audio_channel_embedding[:, :-delay_tokens] + user_audio_prefill_tokens = user_audio_prefill_tokens[:, :-delay_tokens] + else: + warmup_user_audio = None + + if user_audio_prefill_tokens.size(1) > 0: + state = model.streaming_prefill_profile( + state=state, + text_tokens=user_audio_prefill_tokens, + use_inference_mode=True, + user_audio_channel_embedding=user_audio_channel_embedding, + ) + + for i in range(delay_tokens): + user_step_emb = warmup_user_audio[:, i] if warmup_user_audio is not None else None + state.finished.zero_() + state, _, _ = model.streaming_step( + state=state, + text_tokens=warmup_tokens[:, i], + user_audio_channel_embedding=user_step_emb, + prefill_like_step=not bool(model.cfg.get("agent_mask_include_transition_prefix", False)), + prefill_like_is_last_step=(i == delay_tokens - 1), + use_inference_mode=True, + ) + + logging.info(f"[multiturn_user_audio] turn={turn_id} prefilled {user_audio_prefill_steps} steps ({user_audio_prefill_seconds:.2f}s)") + + turn_start_frame = sum(p.size(-1) for p in state.all_predictions) + if turn_id == 0: + state.audio_prediction_start_idx.fill_(turn_start_frame) + multiturn_decode_start_frame = turn_start_frame + + turn_offset = state.text_tokens_seen.clone() + turn_steps = 0 + saw_audio = False + turn_ended_with_audio_eos = False + + while turn_steps < args.max_tts_steps: + turn_steps += 1 + state.finished.zero_() + relative_position = state.text_tokens_seen - turn_offset + text_exhausted = relative_position >= turn_lens + + if turn_text.size(1) == 0: + current_tokens = torch.full((B,), model.eos_id, dtype=torch.long, device=device) + else: + position = relative_position.clamp(min=0, max=turn_text.size(1) - 1) + current_tokens = turn_text[torch.arange(B, device=device), position] + current_tokens = torch.where( + text_exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) + + state, audio_codes, _ = model.streaming_step( + state=state, + text_tokens=current_tokens, + use_inference_mode=True, + ) + + if audio_codes is not None and not saw_audio: + saw_audio = True + + if bool(text_exhausted[0].item()) and bool(state.finished[0].item()): + turn_ended_with_audio_eos = True + break + + state.audio_prediction_end_idx.fill_(-1) + state.finished.zero_() + turn_end_frame = sum(p.size(-1) for p in state.all_predictions) + multiturn_turn_frame_ranges.append((turn_id, turn_start_frame, turn_end_frame)) + logging.info( + f"[multiturn_user_audio] turn={turn_id} steps={turn_steps} " + f"saw_audio={saw_audio} ended_with_audio_eos={turn_ended_with_audio_eos}" + ) + + bos_id = getattr(model, "audio_bos_id", -1) + eos_id = getattr(model, "audio_eos_id", -1) + speaking_id = getattr(model, "audio_user_speaking_id", -1) + speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) + sil_injection = codec_sil_codes.view(1, -1, 1) + for step_idx in range(len(state.all_predictions)): + pred = state.all_predictions[step_idx] + mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) + frame_mask = mask.any(dim=1, keepdim=True) + if frame_mask.any(): + state.all_predictions[step_idx] = torch.where(frame_mask, sil_injection.expand_as(pred), pred) + + state.audio_prediction_end_idx.fill_(-1) + + generated_codes = None + if getattr(state, "all_predictions", None): + try: + generated_codes = torch.cat(state.all_predictions, dim=-1).detach() + except Exception: + generated_codes = None + + finalize_output = model.streaming_finalize(state, use_inference_mode=True) + + return finalize_output, multiturn_turn_frame_ranges, multiturn_decode_start_frame, generated_codes + + +def load_speaker_wav_if_needed(args, model, target_dtype): if args.user_custom_speaker_reference and args.inference_speaker_reference: - speaker_wav = _load_audio( + return _load_audio( args.inference_speaker_reference, model.sample_rate, normalize=args.normalize_volume, use_librosa=args.use_librosa, ).unsqueeze(0).to(model.device, dtype=target_dtype) - filewise_metrics = [] - running_metric_sums = {"cer": 0.0, "wer": 0.0, "secs": 0.0} - running_metric_count = 0 - - for batch_id, inputs in enumerate(dataloader): - B = inputs["context_audio"].size(0) - device = model.device - - inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) - inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) - - if args.user_custom_speaker_reference and args.inference_speaker_reference: - inputs["context_audio"] = speaker_wav.repeat(B, 1).detach() - inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) - - if "user_audio_turns" in inputs: - inputs["user_audio_turns"] = [ - x.to(device, dtype=target_dtype) for x in inputs["user_audio_turns"] - ] - inputs["user_audio_turns_lens"] = [ - x.to(device) for x in inputs["user_audio_turns_lens"] - ] - - profile_turn_frame_ranges = [] - with torch.inference_mode(): - wav = inputs["context_audio"] - wav_len = inputs["context_audio_lengths"] - codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) - - use_lang = bool(getattr(model, "add_language_to_context_text", False)) - ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" - ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) - - ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) - ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) - - state = model.streaming_init( - context_audio_codes=codes, - context_audio_codes_lens=codes_lens, - context_text_tokens=ctx_toks, - context_text_tokens_lens=ctx_toks_lens, - use_cfg=args.use_cfg, - cfg_scale=args.cfg_scale, - use_local_transformer=True, - temperature=args.temperature, - topk=args.topk, - phoneme_input_type="pred", - phoneme_sampling_method="argmax", - use_inference_mode=True, + return None + + +# ----------------------------- +# Save generation outputs and metric manifests +# ----------------------------- + + +def write_audio_1d(path: str, wav: torch.Tensor, sr: int): + os.makedirs(os.path.dirname(path), exist_ok=True) + wav_np = wav.detach().cpu().float().numpy() + sf.write(path, wav_np, samplerate=sr) + + +def build_metric_item( + run_id: int, + rank: int, + dataset_index: int, + turn_id: int, + target_audio_path: str, + reference_text: str, + pred_audio_path: str, + context_audio_path: str, + pred_audio_samples: int, + context_audio_samples: int, + output_sample_rate: int, + context_sample_rate: int, + predicted_codes_path: str = None, + sample_pred_audio_path: str = None, + sample_predicted_codes_path: str = None, +): + return { + "run_id": int(run_id), + "rank": int(rank), + "dataset_index": int(dataset_index), + "turn_id": int(turn_id), + "target_audio_path": target_audio_path, + "reference_text": reference_text, + "pred_audio_path": pred_audio_path, + "context_audio_path": context_audio_path, + "pred_audio_samples": int(pred_audio_samples), + "context_audio_samples": int(context_audio_samples), + "pred_audio_seconds": float(pred_audio_samples / output_sample_rate), + "context_audio_seconds": float(context_audio_samples / context_sample_rate), + "output_sample_rate": int(output_sample_rate), + "context_sample_rate": int(context_sample_rate), + "predicted_codes_path": predicted_codes_path, + "sample_pred_audio_path": sample_pred_audio_path or pred_audio_path, + "sample_predicted_codes_path": sample_predicted_codes_path or predicted_codes_path, + } + + +def save_generated_code_slice(generated_codes, batch_idx: int, start_frame: int, end_frame: int, path: str): + """Save predicted codec codes as [num_codebooks, T] for MagpieTTS FCD.""" + if generated_codes is None: + return None + try: + os.makedirs(os.path.dirname(path), exist_ok=True) + T = int(generated_codes.size(-1)) + start_frame = max(0, min(int(start_frame), T)) + end_frame = max(start_frame, min(int(end_frame), T)) + if end_frame <= start_frame: + return None + codes = generated_codes[batch_idx, :, start_frame:end_frame].detach().cpu().long() + if codes.numel() == 0: + return None + torch.save(codes, path) + return path + except Exception as e: + logging.warning(f"Could not save predicted codes to {path}: {repr(e)}") + return None + + +def save_generation_outputs_and_build_metric_items( + model, + inputs, + finalize_output, + multiturn_turn_frame_ranges, + multiturn_decode_start_frame, + generated_codes, + args, + rank: int, + run_id: int, +): + device = model.device + B = inputs["context_audio"].size(0) + + with fp32_precision(): + audio_f32 = finalize_output.audio.float() + audio_len = finalize_output.audio_len.int() + + # Use model-reported generated audio lengths for both supported modes. + audio_len = torch.clamp(audio_len, max=audio_f32.size(1)) + + audio_out_dir = get_audio_out_dir(args) + metric_turn_dir = get_generated_turn_audio_dir(args) + metric_context_dir = get_context_metric_audio_dir(args) + predicted_codes_dir = get_predicted_codes_dir(args) + os.makedirs(audio_out_dir, exist_ok=True) + os.makedirs(metric_turn_dir, exist_ok=True) + os.makedirs(metric_context_dir, exist_ok=True) + os.makedirs(predicted_codes_dir, exist_ok=True) + + audio_f32_cpu = audio_f32.detach().cpu() + audio_len_cpu = audio_len.detach().cpu() + metric_items = [] + + for i in range(B): + target_path = inputs["target_audio_paths"][i] + base_name = os.path.basename(target_path) + stem, ext = os.path.splitext(base_name) + if not ext: + ext = ".wav" + + dataset_idx = int(inputs.get("dataset_indices", [-1] * B)[i]) + safe_stem = ( + f"run{run_id:02d}_idx{dataset_idx:08d}_{stem}" + if dataset_idx >= 0 + else f"run{run_id:02d}_rank{rank}_{stem}" ) - # --------------------------------------------------------- - # MODE 1: DUPLEX (Continuous Padding Token Stream) - # --------------------------------------------------------- - if inputs["duplex_multiturn"]: - text = inputs["input_ids"].to(device) - text_lens = inputs["input_lengths"].to(device) - - # Trackers for our two forced-silence zones - in_initial_silence = torch.ones(B, dtype=torch.bool, device=device) - in_post_speech_silence = torch.zeros(B, dtype=torch.bool, device=device) - - text_exhausted = state.text_tokens_seen >= text_lens - while not text_exhausted.all(): - # 1. WAKE UP OVERRIDE: Keep the text pipeline awake to read pads! - state.finished = state.finished & text_exhausted - state.text_finished = state.text_finished & text_exhausted - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended = state.phoneme_stream_ended & text_exhausted - - # 2. Safely index text using the model's internal pointer - positions = state.text_tokens_seen.clamp(max=text.size(1) - 1) - current_tokens = text[torch.arange(B, device=device), positions] - current_tokens = torch.where( - text_exhausted, torch.full_like(current_tokens, model.eos_id), current_tokens - ) + context_len = int(inputs["context_audio_lengths"][i].detach().cpu().item()) + context_wav = inputs["context_audio"][i, :context_len].detach().cpu().float() + context_metric_path = os.path.join(metric_context_dir, f"{safe_stem}_context.wav") + write_audio_1d(context_metric_path, context_wav, model.sample_rate) + + if inputs.get("multiturn_user_audio", False): + full_len = int(audio_len_cpu[i].item()) + full_wav_t = audio_f32_cpu[i, :full_len].float() + full_out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") + + full_codes_path = os.path.join(predicted_codes_dir, f"{safe_stem}_sample.pt") + sample_predicted_codes_path = save_generated_code_slice( + generated_codes, + i, + multiturn_decode_start_frame, + generated_codes.size(-1) if generated_codes is not None else multiturn_decode_start_frame, + full_codes_path, + ) - # 3. Update our trackers BEFORE the step - is_pad_or_eos = (current_tokens == model.pad_id) | (current_tokens == model.eos_id) - in_initial_silence = in_initial_silence & is_pad_or_eos - in_post_speech_silence = in_post_speech_silence & is_pad_or_eos - - # 4. Step the model - state, audio_codes, _ = model.streaming_step(state=state, text_tokens=current_tokens, use_inference_mode=True) - - # 5. SILENCE FORCING INJECTION - if audio_codes is not None and args.force_speech_sil_codes: - force_silence_mask = in_initial_silence | in_post_speech_silence - - if force_silence_mask.any(): - # Expand silence codes [C] -> [1, C, 1] to match audio_codes [B, C, 1] - expanded_sil = codec_sil_codes.view(1, -1, 1).expand_as(audio_codes) - # Expand mask [B] -> [B, 1, 1] for broadcasting - mask_3d = force_silence_mask.view(B, 1, 1) - # Overwrite the prediction with silence codes where the mask is True. - overwritten_codes = torch.where(mask_3d, expanded_sil, audio_codes) - # Inject back into the model's KV cache history - state.all_predictions[-1] = overwritten_codes - - # 6. TRIGGER POST-SPEECH SILENCE FOR THE *NEXT* FRAME - in_post_speech_silence = in_post_speech_silence | state.finished - - # Update exhaustion tracker for the next iteration - text_exhausted = state.text_tokens_seen >= text_lens - - # --------------------------------------------------------- - # MODE 2: REGULAR (Turn-by-Turn Re-wakes) - # --------------------------------------------------------- - elif inputs["regular_multiturn"]: - batched_turns = inputs["batched_turns"] - batched_turn_lens = inputs["batched_turn_lens"] - valid_turn_masks = inputs["valid_turn_masks"] - - max_turns = len(batched_turns) - turn_offsets = torch.zeros(B, dtype=torch.long, device=device) - - for t in range(max_turns): - turn_text = batched_turns[t].to(device) - turn_lens = batched_turn_lens[t].to(device) - valid_mask = valid_turn_masks[t].to(device) - - state.finished = state.finished & (~valid_mask) - state.text_finished = state.text_finished & (~valid_mask) - - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended = state.phoneme_stream_ended & (~valid_mask) - - if state.finished.all(): - continue - - turn_offsets = torch.where(valid_mask, state.text_tokens_seen, turn_offsets) - turn_steps = 0 - - while not state.finished.all() and turn_steps < args.max_tts_steps: - turn_steps += 1 - - relative_positions = state.text_tokens_seen - turn_offsets - positions = relative_positions.clamp(min=0, max=turn_text.size(1) - 1) - current_tokens = turn_text[torch.arange(B, device=device), positions] - - exhausted = relative_positions >= turn_lens - current_tokens = torch.where(exhausted, torch.full_like(current_tokens, model.eos_id), current_tokens) - - state, _, _ = model.streaming_step(state=state, text_tokens=current_tokens, use_inference_mode=True) - - # --------------------------------------------------------- - # MODE 3: PROFILE MULTI-TURN - # --------------------------------------------------------- - elif inputs["profile_multiturn"]: - if B != 1: - raise RuntimeError( - "--profile_multiturn_inference currently supports only batch_size=1. " - "Use --batch_size=1 for this mode." - ) + samples_per_prediction_frame = model.codec_model_samples_per_frame / ( + model.sample_rate / model.output_sample_rate + ) + + aligned_agent = torch.zeros_like(full_wav_t) + raw_turn_texts = inputs.get("raw_turn_texts", [[] for _ in range(B)]) + + for turn_id, start_frame, end_frame in multiturn_turn_frame_ranges: + rel_start_frame = start_frame - multiturn_decode_start_frame + rel_end_frame = end_frame - multiturn_decode_start_frame + + start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) + end_sample = int(round(rel_end_frame * samples_per_prediction_frame)) - batched_turns = inputs["batched_turns"] - batched_turn_lens = inputs["batched_turn_lens"] - valid_turn_masks = inputs["valid_turn_masks"] + start_sample = max(0, min(start_sample, full_len)) + end_sample = max(start_sample, min(end_sample, full_len)) - max_turns = len(batched_turns) - prev_turn_ended_with_audio_eos = True # profile before turn 0 - for t in range(max_turns): - turn_ended_with_audio_eos = False - turn_text = batched_turns[t].to(device) - turn_lens = batched_turn_lens[t].to(device) - valid_mask = valid_turn_masks[t].to(device) + aligned_agent[start_sample:end_sample] = full_wav_t[start_sample:end_sample] - if not bool(valid_mask[0].item()): - continue + turn_wav = aligned_agent[start_sample:end_sample].float() + turn_out_path = os.path.join(audio_out_dir, f"{safe_stem}_turn_{turn_id}{ext}") + write_audio_1d(turn_out_path, turn_wav, model.output_sample_rate) + + metric_turn_path = os.path.join(metric_turn_dir, f"{safe_stem}_turn_{turn_id}.wav") + write_audio_1d(metric_turn_path, turn_wav, model.output_sample_rate) + + turn_codes_path = os.path.join(predicted_codes_dir, f"{safe_stem}_turn_{turn_id}.pt") + predicted_codes_path = save_generated_code_slice( + generated_codes, + i, + start_frame, + end_frame, + turn_codes_path, + ) + + if turn_id < len(raw_turn_texts[i]): + metric_items.append( + build_metric_item( + run_id=run_id, + rank=rank, + dataset_index=dataset_idx, + turn_id=turn_id, + target_audio_path=target_path, + reference_text=str(raw_turn_texts[i][turn_id]), + pred_audio_path=metric_turn_path, + context_audio_path=context_metric_path, + pred_audio_samples=int(turn_wav.numel()), + context_audio_samples=int(context_wav.numel()), + output_sample_rate=model.output_sample_rate, + context_sample_rate=model.sample_rate, + predicted_codes_path=predicted_codes_path, + sample_pred_audio_path=full_out_path, + sample_predicted_codes_path=sample_predicted_codes_path, + ) + ) - # Re-open stream for this turn. - state.finished.zero_() - state.text_finished.zero_() - state.audio_prediction_end_idx.fill_(-1) + write_audio_1d(full_out_path, aligned_agent, model.output_sample_rate) - if hasattr(state, "turn_text_tokens_seen"): - state.turn_text_tokens_seen.zero_() + if "user_audio_turns" in inputs: + user_segments = [] - if hasattr(state, "phoneme_steps"): - state.phoneme_steps.zero_() + first_user_len_in = int(inputs["user_audio_turns_lens"][0][i].item()) + first_user_delay_out = int(round(first_user_len_in * model.output_sample_rate / model.sample_rate)) - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended.zero_() + for turn_id, start_frame, _ in multiturn_turn_frame_ranges: + if turn_id >= len(inputs["user_audio_turns"]): + continue - if hasattr(state, "phoneme_eos_detected"): - state.phoneme_eos_detected.zero_() + turn_audio = inputs["user_audio_turns"][turn_id][i].detach().cpu().float() + turn_audio_len = int(inputs["user_audio_turns_lens"][turn_id][i].item()) + turn_audio = turn_audio[:turn_audio_len] - state.last_phoneme_tokens = None + turn_audio_out = resample( + turn_audio.unsqueeze(0), + model.sample_rate, + model.output_sample_rate, + ).squeeze(0) - if not model.cfg.get("condition_on_user_speech", False): - # Prefill on the begining of each turn - if "user_audio_turns" in inputs: - profile_T = int(round(inputs["user_audio_turns"][t].size(-1) / model.input_samples_per_frame)) - profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate + if turn_id == 0: + user_start_sample = 0 else: - profile_seconds = ( - args.profile_pad_min_sec - + torch.rand((), device=device).item() - * (args.profile_pad_max_sec - args.profile_pad_min_sec) - ) - profile_T = max( - 1, - int(round(profile_seconds * model.sample_rate / model.input_samples_per_frame)), + prev_turn_end_frame = multiturn_turn_frame_ranges[turn_id - 1][2] + rel_prev_end_frame = prev_turn_end_frame - multiturn_decode_start_frame + user_start_sample = first_user_delay_out + int( + round(rel_prev_end_frame * samples_per_prediction_frame) ) - profile_tokens = torch.full( - (1, profile_T), - model.pad_id, - dtype=torch.long, - device=device, - ) - user_audio_channel_embedding = None - else: - user_audio_channel_embedding = None - if "user_audio_turns" in inputs: - user_audio = inputs["user_audio_turns"][t] - user_audio_lens = inputs["user_audio_turns_lens"][t] - else: - print("Warning!! USING CONTEXT AUDIO AS USER AUDIO FOR TESTING !!") - user_audio = inputs["context_audio"] - user_audio_lens = inputs["context_audio_lengths"] + user_segments.append((user_start_sample, turn_audio_out.detach().cpu().float())) - user_audio_codes, user_audio_codes_lens = model._codec_helper.audio_to_codes( - user_audio, - user_audio_lens, - ) + total_user_len = 0 + for s, wav_seg in user_segments: + total_user_len = max(total_user_len, s + wav_seg.numel()) - if model._codec_converter is not None: - user_audio_codes = model._codec_converter.convert_original_to_new( - audio_tokens=user_audio_codes, audio_lens=user_audio_codes_lens - ).long() - - user_audio_codes, user_audio_codes_lens = model.stack_codes( - user_audio_codes, - user_audio_codes_lens, - model.audio_bos_id, - model.audio_eos_id, - model.frame_stacking_factor, - model.num_audio_codebooks, - ) - user_audio_embedded = model.embed_audio_tokens(user_audio_codes) + user_ch = torch.zeros(total_user_len) + for s, wav_seg in user_segments: + e = s + wav_seg.numel() + user_ch[s:e] += wav_seg - boundary_trim = model.cfg.get("user_audio_boundary_trim", 4) - boundary_trim = 0 if boundary_trim is None else int(boundary_trim) + agent_ch = torch.cat([torch.zeros(first_user_delay_out, dtype=aligned_agent.dtype), aligned_agent]) - # Remove BOS/EOS from the user-audio turn, same as training. - if boundary_trim == 0: - real_start = 0 - real_end = int(user_audio_codes_lens[0].item()) - else: - turn_len_with_special = int(user_audio_codes_lens[0].item()) - real_start = 1 - real_end = max(real_start, turn_len_with_special - 1) + stereo_len = max(user_ch.numel(), agent_ch.numel()) + user_pad = torch.zeros(stereo_len) + agent_pad = torch.zeros(stereo_len) - user_audio_embedded = user_audio_embedded[:, real_start:real_end] + user_pad[: user_ch.numel()] = user_ch + agent_pad[: agent_ch.numel()] = agent_ch - # Optional: trim boundaries exactly like training. - copy_len = user_audio_embedded.size(1) + stereo = torch.stack([user_pad, agent_pad], dim=1).numpy() + aligned_path = os.path.join(audio_out_dir, f"{safe_stem}_user_agent_aligned{ext}") + sf.write(aligned_path, stereo, samplerate=model.output_sample_rate) - if boundary_trim > 0: - trim = min(boundary_trim, copy_len // 2) + else: + full_len = int(audio_len_cpu[i].item()) + wav = audio_f32_cpu[i, :full_len].float() + out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") + write_audio_1d(out_path, wav, model.output_sample_rate) + + metric_turn_path = os.path.join(metric_turn_dir, f"{safe_stem}_turn_0.wav") + write_audio_1d(metric_turn_path, wav, model.output_sample_rate) + + codes_path = os.path.join(predicted_codes_dir, f"{safe_stem}_turn_0.pt") + predicted_codes_path = save_generated_code_slice( + generated_codes, + i, + 0, + generated_codes.size(-1) if generated_codes is not None else 0, + codes_path, + ) - if trim > 0: - user_audio_embedded[:, :trim] = 0.0 - user_audio_embedded[:, copy_len - trim:] = 0.0 + metric_items.append( + build_metric_item( + run_id=run_id, + rank=rank, + dataset_index=dataset_idx, + turn_id=0, + target_audio_path=target_path, + reference_text=str(inputs["raw_text"][i]), + pred_audio_path=metric_turn_path, + context_audio_path=context_metric_path, + pred_audio_samples=int(wav.numel()), + context_audio_samples=int(context_wav.numel()), + output_sample_rate=model.output_sample_rate, + context_sample_rate=model.sample_rate, + predicted_codes_path=predicted_codes_path, + sample_pred_audio_path=out_path, + sample_predicted_codes_path=predicted_codes_path, + ) + ) - # Add BOS-aligned zero frame, because audio input timeline has BOS at t=0. - bos_user_pad = torch.zeros( - user_audio_embedded.size(0), - 1, - user_audio_embedded.size(2), - device=user_audio_embedded.device, - dtype=user_audio_embedded.dtype, - ) - user_audio_embedded = torch.cat([bos_user_pad, user_audio_embedded], dim=1) - - profile_T = user_audio_embedded.size(1) - profile_tokens = torch.full( - (B, profile_T), - model.pad_id, - dtype=torch.long, - device=device, - ) + return metric_items - user_audio_channel_embedding = user_audio_embedded - profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate - # add text tokens needed for profilling - delay_tokens = int(state.config.training_mode.streaming_speech_delay) - delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) +# ----------------------------- +# Metrics after generation +# ----------------------------- - warmup_tokens = turn_text[:, :delay_tokens] - turn_text = turn_text[:, delay_tokens:] - turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) - if user_audio_channel_embedding is not None and delay_tokens > 0: - warmup_user_audio = user_audio_channel_embedding[:, -delay_tokens:] - user_audio_channel_embedding = user_audio_channel_embedding[:, :-delay_tokens] - profile_tokens = profile_tokens[:, :-delay_tokens] - else: - warmup_user_audio = None - - if profile_tokens.size(1) > 0: - state = model.streaming_prefill_profile( - state=state, - text_tokens=profile_tokens, - use_inference_mode=True, - user_audio_channel_embedding=user_audio_channel_embedding - ) +def torch_rms_norm(wav: torch.Tensor, db_level: float = -27.0) -> torch.Tensor: + denom = torch.sum(wav**2) + if denom <= 0: + return wav + r = 10 ** (db_level / 20) + a = torch.sqrt((wav.size(-1) * (r**2)) / denom) + return wav * a - for i in range(delay_tokens): - user_step_emb = warmup_user_audio[:, i] if warmup_user_audio is not None else None - state.finished.zero_() +def _load_audio_for_metric(path: str, sample_rate: int): + wav = _load_audio(path, sample_rate=sample_rate, normalize=False, use_librosa=False) + if wav.numel() == 0: + wav = torch.zeros(1, dtype=torch.float32) + return wav.float() - state, _, _ = model.streaming_step( - state=state, - text_tokens=warmup_tokens[:, i], - user_audio_channel_embedding=user_step_emb, - prefill_like_step=not bool(model.cfg.get("agent_mask_include_transition_prefix", False)), - prefill_like_is_last_step=(i == delay_tokens - 1), - use_inference_mode=True, - ) - logging.info( - f"[profile_multiturn] turn={t} prefilled {profile_T} steps " - f"({profile_seconds:.2f}s)" - ) +def _pad_audio_1d_list(wavs: List[torch.Tensor], device, dtype=torch.float32): + if len(wavs) == 0: + return torch.zeros((0, 1), device=device, dtype=dtype), torch.zeros((0,), device=device, dtype=torch.long) - turn_start_frame = sum(p.size(-1) for p in state.all_predictions) - if t == 0: - state.audio_prediction_start_idx.fill_(turn_start_frame) - profile_decode_start_frame = turn_start_frame + lens = torch.tensor([max(1, int(w.numel())) for w in wavs], device=device, dtype=torch.long) + max_len = int(lens.max().item()) + out = torch.zeros((len(wavs), max_len), device=device, dtype=dtype) - turn_offset = state.text_tokens_seen.clone() - turn_steps = 0 - saw_audio = False - first_audio_step_finished = False + for i, w in enumerate(wavs): + w = w.to(device=device, dtype=dtype).flatten() + if w.numel() == 0: + continue + out[i, : w.numel()] = w - turn_text_done = False + return out, lens - while turn_steps < args.max_tts_steps: - turn_steps += 1 - state.finished.zero_() +def chunk_list(xs: List[Any], chunk_size: int) -> Iterable[List[Any]]: + chunk_size = max(1, int(chunk_size)) + for start in range(0, len(xs), chunk_size): + yield xs[start : start + chunk_size] - relative_position = state.text_tokens_seen - turn_offset - text_exhausted = relative_position >= turn_lens - if turn_text.size(1) == 0: - current_tokens = torch.full((B,), model.eos_id, dtype=torch.long, device=device) - else: - position = relative_position.clamp(min=0, max=turn_text.size(1) - 1) - current_tokens = turn_text[torch.arange(B, device=device), position] - current_tokens = torch.where( - text_exhausted, - torch.full_like(current_tokens, model.eos_id), - current_tokens, - ) +def _metric_device(): + return "cuda" if torch.cuda.is_available() else "cpu" - state, audio_codes, _ = model.streaming_step( - state=state, - text_tokens=current_tokens, - use_inference_mode=True, - ) - if audio_codes is not None and not saw_audio: - saw_audio = True - first_audio_step_finished = bool(state.finished[0].item()) +def _load_metric_batch_audio(batch_items: List[Dict[str, Any]], args): + pred_wavs = [] + context_wavs = [] - if bool(text_exhausted[0].item()) and bool(state.finished[0].item()): - turn_ended_with_audio_eos = True - break + for item in batch_items: + pred = _load_audio_for_metric(item["pred_audio_path"], sample_rate=int(item["output_sample_rate"])) + context = _load_audio_for_metric(item["context_audio_path"], sample_rate=int(item["context_sample_rate"])) - prev_turn_ended_with_audio_eos = turn_ended_with_audio_eos + if args.max_metric_audio_sec is not None: + max_pred_len = int(float(args.max_metric_audio_sec) * int(item["output_sample_rate"])) + pred = pred[: max(1, max_pred_len)] - # keep generated codes, but don't let this turn's EOS crop finalize output - state.audio_prediction_end_idx.fill_(-1) - state.finished.zero_() + pred_wavs.append(pred) + context_wavs.append(context) - logging.info( - f"[profile_multiturn] turn={t} steps={turn_steps} " - f"saw_audio={saw_audio} immediate_eos={prev_turn_ended_with_audio_eos}" - ) - turn_end_frame = sum(p.size(-1) for p in state.all_predictions) - profile_turn_frame_ranges.append((t, turn_start_frame, turn_end_frame)) - - # if state.audio_prediction_end_idx[0].item() >= 0: - # last_audio_prediction_end_idx.copy_(state.audio_prediction_end_idx) - - # Scrub Special Tokens (BOS/EOS) from Audio Codes --- - # Because we force-decode the entire uncropped sequence, any BOS or EOS - # tokens left in the array will produce loud artifacts in the codec. - bos_id = getattr(model, "audio_bos_id", -1) - eos_id = getattr(model, "audio_eos_id", -1) - speaking_id = getattr(model, "audio_user_speaking_id", -1) - speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) - - sil_injection = codec_sil_codes.view(1, -1, 1) - - for step_idx in range(len(state.all_predictions)): - pred = state.all_predictions[step_idx] - # Check if any codebook in the frame has any special token - mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) - frame_mask = mask.any(dim=1, keepdim=True) - - if frame_mask.any(): - state.all_predictions[step_idx] = torch.where( - frame_mask, - sil_injection.expand_as(pred), - pred - ) + device = _metric_device() + pred_audio, pred_lens = _pad_audio_1d_list(pred_wavs, device=device) + context_audio, context_lens = _pad_audio_1d_list(context_wavs, device=device) + output_sample_rate = int(batch_items[0]["output_sample_rate"]) + context_sample_rate = int(batch_items[0]["context_sample_rate"]) - if inputs["duplex_multiturn"]: - # Erase the internal memory of Turn 1's EOS token so `streaming_finalize` - # decodes the entire physical sequence! - state.audio_prediction_end_idx.fill_(-1) - - if inputs["profile_multiturn"]: - state.audio_prediction_end_idx.fill_(-1) + return pred_audio, pred_lens, context_audio, context_lens, output_sample_rate, context_sample_rate - # Finalize decodes the collected Codec states globally regardless of which loop was run - finalize_output = model.streaming_finalize(state, use_inference_mode=True) - if args.debug_dtype and batch_id == 0: - report_dtype_stats(handles, stats, examples) - with fp32_precision(): - audio_f32 = finalize_output.audio.float() - audio_len = finalize_output.audio_len.int() - - expected_audio_lens = (torch.tensor(inputs["target_num_frames"], device=device) * model.target_samples_per_frame).int() - - if inputs["duplex_multiturn"]: - # Use exact math based on the output samples multiplier! - audio_len = (text_lens * model.target_samples_per_frame).int() - - # Cap the expected length so it physically cannot exceed the actual generated tensor size - audio_len = torch.min(audio_len, torch.tensor(audio_f32.size(1), device=device)) - elif inputs["profile_multiturn"]: - audio_len = finalize_output.audio_len.int() - else: - audio_len = torch.min(audio_len, expected_audio_lens) +def _nan(): + return float("nan") + + +def finite_avg(values): + finite_values = [] + for value in values: + if value is None: + continue + try: + value = float(value) + except Exception: + continue + if math.isfinite(value): + finite_values.append(value) + if not finite_values: + return None + return sum(finite_values) / len(finite_values) + + +def _safe_word_error_detail(hyp_text: str, ref_text: str, use_cer: bool): + ref_text = "" if ref_text is None else str(ref_text).strip() + hyp_text = "" if hyp_text is None else str(hyp_text).strip() + if ref_text == "": + return None + try: + detailed = word_error_rate_detail(hypotheses=[hyp_text], references=[ref_text], use_cer=use_cer) + value = float(detailed[0]) + if not math.isfinite(value): + return None + return detailed + except Exception: + return None + + +def _safe_detail_value(detailed): + if detailed is None: + return None + try: + value = float(detailed[0]) + except Exception: + return None + if not math.isfinite(value): + return None + return value - metric_audio_pred = resample(audio_f32, model.output_sample_rate, 16000) - metric_audio_pred_lens = (audio_len / model.output_sample_rate * 16000).to(torch.long) - context_audio = resample(inputs["context_audio"].float(), model.sample_rate, 16000) - context_audio_lens = (inputs["context_audio_lengths"] / model.sample_rate * 16000).to(torch.long) +def _load_speaker_eval_models(args, device: str): + """Load speaker verification models in the same style as MagpieTTS evaluation utils.""" + models = { + "feature_extractor": None, + "sv_model": None, + "sv_model_alternate": None, + } - # normalize volume - metric_audio_pred = torch_rms_norm(metric_audio_pred) - context_audio = torch_rms_norm(context_audio) + if nemo_asr is None or extract_embedding is None: + logging.warning("Speaker metric dependencies are unavailable; speaker similarity metrics will be NaN.") + return models - asr_hyps = intelligibility.update( - name="dataset", - refs=inputs["raw_text"], - pred_audio=metric_audio_pred, - pred_audio_lens=metric_audio_pred_lens, - asr_hyps=None, + try: + if args.sv_model_type == "wavlm": + if Wav2Vec2FeatureExtractor is None or WavLMForXVector is None: + raise RuntimeError("transformers WavLM dependencies are unavailable") + models["feature_extractor"] = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sv") + models["sv_model"] = WavLMForXVector.from_pretrained("microsoft/wavlm-base-plus-sv").to(device).eval() + else: + models["sv_model"] = ( + nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name="titanet_large").to(device).eval() ) - secs_values = _compute_secs_per_sample( - secs_metric, - "dataset", - context_audio, - context_audio_lens, - metric_audio_pred, - metric_audio_pred_lens, + logging.info("Loading alternate speaker model `titanet_small`.") + with logging.temp_verbosity(logging.ERROR): + models["sv_model_alternate"] = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( + model_name="titanet_small" + ) + models["sv_model_alternate"] = models["sv_model_alternate"].to(device).eval() + except Exception as e: + logging.warning(f"Could not load speaker evaluation models: {repr(e)}") + models = {"feature_extractor": None, "sv_model": None, "sv_model_alternate": None} + + return models + + +def _compute_speaker_similarity_rows(args, rows: List[Dict[str, Any]]): + """Populate pred/GT/context speaker similarity metrics per turn.""" + if args.disable_speaker_metrics: + for row in rows: + row["pred_gt_ssim"] = _nan() + row["pred_context_ssim"] = _nan() + row["gt_context_ssim"] = _nan() + row["pred_gt_ssim_alternate"] = _nan() + row["pred_context_ssim_alternate"] = _nan() + row["gt_context_ssim_alternate"] = _nan() + row["ssim"] = _nan() + return + + device = _metric_device() + models = _load_speaker_eval_models(args, device=device) + sv_model = models.get("sv_model") + sv_model_alt = models.get("sv_model_alternate") + extractor = models.get("feature_extractor") + + if sv_model is None or sv_model_alt is None or extract_embedding is None: + for row in rows: + row["pred_gt_ssim"] = _nan() + row["pred_context_ssim"] = _nan() + row["gt_context_ssim"] = _nan() + row["pred_gt_ssim_alternate"] = _nan() + row["pred_context_ssim_alternate"] = _nan() + row["gt_context_ssim_alternate"] = _nan() + row["ssim"] = _nan() + return + + emb_cache = {} + emb_alt_cache = {} + + def get_emb(path: str, alternate: bool = False): + if path is None or not path or not os.path.exists(path): + return None + cache = emb_alt_cache if alternate else emb_cache + if path in cache: + return cache[path] + model = sv_model_alt if alternate else sv_model + sv_type = "titanet" if alternate else args.sv_model_type + try: + with torch.inference_mode(): + emb = extract_embedding( + model=model, + extractor=extractor, + audio_path=path, + device=device, + sv_model_type=sv_type, + ) + cache[path] = emb + return emb + except Exception as e: + logging.warning(f"Could not extract speaker embedding for {path}: {repr(e)}") + cache[path] = None + return None + + def cosine(a, b): + if a is None or b is None: + return _nan() + try: + return torch.nn.functional.cosine_similarity(a, b, dim=0).item() + except Exception: + return _nan() + + for row in rows: + pred_path = row.get("pred_audio_path") + gt_path = row.get("target_audio_path") + context_path = row.get("context_audio_path") + + pred = get_emb(pred_path, alternate=False) + gt = get_emb(gt_path, alternate=False) + context = get_emb(context_path, alternate=False) + + pred_alt = get_emb(pred_path, alternate=True) + gt_alt = get_emb(gt_path, alternate=True) + context_alt = get_emb(context_path, alternate=True) + + row["pred_gt_ssim"] = cosine(pred, gt) + row["pred_context_ssim"] = cosine(pred, context) + row["gt_context_ssim"] = cosine(gt, context) + row["pred_gt_ssim_alternate"] = cosine(pred_alt, gt_alt) + row["pred_context_ssim_alternate"] = cosine(pred_alt, context_alt) + row["gt_context_ssim_alternate"] = cosine(gt_alt, context_alt) + row["ssim"] = row["pred_context_ssim"] + + del models + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def _compute_utmos_rows(args, rows: List[Dict[str, Any]], rank: int): + if args.disable_utmosv2: + for row in rows: + row["utmosv2"] = _nan() + return + + if compute_utmosv2_scores is None: + logging.warning("UTMOSv2 utility is unavailable; setting utmosv2 to NaN.") + for row in rows: + row["utmosv2"] = _nan() + return + + try: + # All predicted metric turns are written into the same directory. + audio_dir = get_generated_turn_audio_dir(args) + scores = compute_utmosv2_scores(audio_dir, _metric_device()) + for row in rows: + row["utmosv2"] = scores.get(os.path.normpath(row.get("pred_audio_path", "")), _nan()) + except Exception as e: + all_rank_print(rank, f"UTMOSv2 computation failed; setting utmosv2 to NaN: {repr(e)}") + for row in rows: + row["utmosv2"] = _nan() + + +def _compute_eou_rows(args, rows: List[Dict[str, Any]], rank: int): + if args.disable_eou or args.language != "en" or EoUClassifier is None: + for row in rows: + row["eou_type"] = None + row["eou_trailing_duration"] = _nan() + row["eou_trail_rms_ratio"] = _nan() + return + + try: + kwargs = {"device": _metric_device()} + if args.eou_model_name: + kwargs["model_name"] = args.eou_model_name + classifier = EoUClassifier(**kwargs) + items = [(row.get("pred_audio_path"), row.get("reference_text", "")) for row in rows] + + results = [] + batch_size = max(1, int(args.eou_batch_size)) + for start in range(0, len(items), batch_size): + results.extend(classifier.classify_batch(items[start : start + batch_size])) + + for row, result in zip(rows, results): + row["eou_type"] = result.eou_type.value + row["eou_trailing_duration"] = result.trailing_duration + row["eou_trail_rms_ratio"] = result.trail_rms_ratio + except Exception as e: + all_rank_print(rank, f"EOU computation failed; setting EOU metrics to NaN: {repr(e)}") + for row in rows: + row["eou_type"] = None + row["eou_trailing_duration"] = _nan() + row["eou_trail_rms_ratio"] = _nan() + + +def compute_magpie_style_global_metrics(rows: List[Dict[str, Any]]) -> Dict[str, Any]: + """Aggregate the same metric keys used by MagpieTTS evaluate_generated_audio.""" + n = len(rows) + if n == 0: + return { + "cer_filewise_avg": None, + "wer_filewise_avg": None, + "cer_cumulative": None, + "wer_cumulative": None, + "ssim_pred_gt_avg": None, + "ssim_pred_context_avg": None, + "ssim_gt_context_avg": None, + "ssim_pred_gt_avg_alternate": None, + "ssim_pred_context_avg_alternate": None, + "ssim_gt_context_avg_alternate": None, + "cer_gt_audio_cumulative": _nan(), + "wer_gt_audio_cumulative": _nan(), + "utmosv2_avg": None, + "total_gen_audio_seconds": 0.0, + "frechet_codec_distance": _nan(), + "eou_cutoff_rate": _nan(), + "eou_silence_rate": _nan(), + "eou_noise_rate": _nan(), + "eou_error_rate": _nan(), + } + + pred_texts = [str(r.get("pred_text", r.get("asr_hyp", ""))) for r in rows if r.get("gt_text", r.get("reference_text", ""))] + gt_texts = [str(r.get("gt_text", r.get("reference_text", ""))) for r in rows if r.get("gt_text", r.get("reference_text", ""))] + + out = {} + out["cer_filewise_avg"] = finite_avg([r.get("cer") for r in rows]) + out["wer_filewise_avg"] = finite_avg([r.get("wer") for r in rows]) + + if pred_texts and gt_texts: + try: + out["cer_cumulative"] = float(word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=True)[0]) + except Exception: + out["cer_cumulative"] = None + try: + out["wer_cumulative"] = float(word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=False)[0]) + except Exception: + out["wer_cumulative"] = None + else: + out["cer_cumulative"] = None + out["wer_cumulative"] = None + + out["ssim_pred_gt_avg"] = finite_avg([r.get("pred_gt_ssim") for r in rows]) + out["ssim_pred_context_avg"] = finite_avg([r.get("pred_context_ssim") for r in rows]) + out["ssim_gt_context_avg"] = finite_avg([r.get("gt_context_ssim") for r in rows]) + out["ssim_pred_gt_avg_alternate"] = finite_avg([r.get("pred_gt_ssim_alternate") for r in rows]) + out["ssim_pred_context_avg_alternate"] = finite_avg([r.get("pred_context_ssim_alternate") for r in rows]) + out["ssim_gt_context_avg_alternate"] = finite_avg([r.get("gt_context_ssim_alternate") for r in rows]) + + gt_audio_texts = [r.get("gt_audio_text") for r in rows] + if gt_audio_texts and all(x is not None for x in gt_audio_texts): + try: + out["cer_gt_audio_cumulative"] = float( + word_error_rate_detail(hypotheses=gt_audio_texts, references=gt_texts, use_cer=True)[0] ) + except Exception: + out["cer_gt_audio_cumulative"] = _nan() + try: + out["wer_gt_audio_cumulative"] = float( + word_error_rate_detail(hypotheses=gt_audio_texts, references=gt_texts, use_cer=False)[0] + ) + except Exception: + out["wer_gt_audio_cumulative"] = _nan() + else: + out["cer_gt_audio_cumulative"] = _nan() + out["wer_gt_audio_cumulative"] = _nan() + + out["utmosv2_avg"] = finite_avg([r.get("utmosv2") for r in rows]) + out["total_gen_audio_seconds"] = sum(float(r.get("total_gen_audio_seconds", r.get("pred_audio_seconds", 0.0)) or 0.0) for r in rows) + out["frechet_codec_distance"] = _nan() + + eou_types = [r.get("eou_type") for r in rows] + if eou_types and eou_types[0] is not None: + counts = Counter(eou_types) + if EoUType is not None: + labels = list(EoUType.error_types()) + good_label = EoUType.GOOD + else: + labels = ["cutoff", "silence", "noise"] + good_label = "good" - batch_filewise_metrics = [] - for i, (raw_ref, raw_hyp) in enumerate(zip(inputs["raw_text"], asr_hyps)): - normalized_ref = intelligibility.normalizer(raw_ref) - normalized_hyp = intelligibility.normalizer(raw_hyp) - target_path = inputs["target_audio_paths"][i] - context_path = inputs["context_audio_paths"][i] - cer = word_error_rate([normalized_hyp], [normalized_ref], use_cer=True) - wer = word_error_rate([normalized_hyp], [normalized_ref], use_cer=False) - secs = _json_metric_value(secs_values[i]) - running_metric_count += 1 - running_metric_sums["cer"] += cer - running_metric_sums["wer"] += wer - running_metric_sums["secs"] += secs - row = { - "sample_index": running_metric_count - 1, - "batch_id": batch_id, - "batch_sample_index": i, - "target_audio_filepath": target_path, - "target_audio_resolved_filepath": _resolve_audio_path(target_path, args.audio_dir), - "context_audio_filepath": context_path, - "context_audio_resolved_filepath": _resolve_audio_path(context_path, args.audio_dir), - "speaker_reference_filepath": ( - args.inference_speaker_reference - if args.user_custom_speaker_reference and args.inference_speaker_reference - else None - ), - "generated_audio_filepath": None, - "generated_turn_audio_filepaths": [], - "aligned_user_agent_audio_filepath": None, - "reference_transcript": raw_ref, - "asr_hypothesis": raw_hyp, - "normalized_reference_transcript": normalized_ref, - "normalized_asr_hypothesis": normalized_hyp, - "cer": cer, - "wer": wer, - "secs": secs, - "running_average_metrics": { - "num_samples": running_metric_count, - "cer": running_metric_sums["cer"] / running_metric_count, - "wer": running_metric_sums["wer"] / running_metric_count, - "secs": running_metric_sums["secs"] / running_metric_count, - }, - } - batch_filewise_metrics.append(row) - filewise_metrics.append(row) - - os.makedirs(args.out_dir, exist_ok=True) - audio_f32 = audio_f32.detach().cpu() - audio_len = audio_len.cpu() - - for i in range(B): - target_path = inputs["target_audio_paths"][i] - base_name = os.path.basename(target_path) - stem, ext = os.path.splitext(base_name) - if not ext: - ext = ".wav" - - if inputs["profile_multiturn"]: - full_len = int(audio_len[i].item()) - full_wav_t = audio_f32[i, :full_len].detach().cpu().float() - - samples_per_prediction_frame = ( - model.codec_model_samples_per_frame - / (model.sample_rate / model.output_sample_rate) - ) + for label in labels: + out[f"eou_{label}_rate"] = counts.get(label, 0) / n + out["eou_error_rate"] = 1.0 - counts.get(good_label, 0) / n + else: + out["eou_cutoff_rate"] = _nan() + out["eou_silence_rate"] = _nan() + out["eou_noise_rate"] = _nan() + out["eou_error_rate"] = _nan() - # Build artifact-free aligned agent audio: - # start from true zeros and copy only generated turn regions. - aligned_agent = torch.zeros_like(full_wav_t) + return out - print(profile_turn_frame_ranges) - for turn_id, start_frame, end_frame in profile_turn_frame_ranges: - rel_start_frame = start_frame - profile_decode_start_frame - rel_end_frame = end_frame - profile_decode_start_frame +def _load_asr_model_for_metrics(args, rank: int): + """Load ASR directly, matching the EasyMagpie/MagpieTTS evaluation style.""" + asr_cls = ASRModel + if asr_cls is None and nemo_asr is not None: + asr_cls = getattr(getattr(nemo_asr, "models", None), "ASRModel", None) + if asr_cls is None: + raise RuntimeError("NeMo ASRModel is unavailable, cannot load ASR model.") - start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) - end_sample = int(round(rel_end_frame * samples_per_prediction_frame)) + all_rank_print(rank, f"loading ASR model after generation: {args.asr_model_name}") + with fp32_precision(): + asr_model = asr_cls.from_pretrained(model_name=args.asr_model_name) + asr_model = asr_model.to(_metric_device()).eval() - start_sample = max(0, min(start_sample, full_len)) - end_sample = max(start_sample, min(end_sample, full_len)) + return asr_model - print( - "Turn:", turn_id, - "Start:", start_sample, - "End:", end_sample, - "Start S:", start_sample / model.output_sample_rate, - "End S:", end_sample / model.output_sample_rate, - ) - # Copy only this turn into the aligned full output. - aligned_agent[start_sample:end_sample] = full_wav_t[start_sample:end_sample] - - # Save individual turn from the same aligned region. - turn_wav = aligned_agent[start_sample:end_sample].numpy() - out_path = os.path.join(args.out_dir, f"{stem}_turn_{turn_id}{ext}") - sf.write(out_path, turn_wav, samplerate=model.output_sample_rate) - batch_filewise_metrics[i]["generated_turn_audio_filepaths"].append(out_path) - logging.info(f"Saved: {out_path}") - - # Save full artifact-scrubbed agent audio. - wav = aligned_agent.numpy() - out_path = os.path.join(args.out_dir, base_name) - sf.write(out_path, wav, samplerate=model.output_sample_rate) - batch_filewise_metrics[i]["generated_audio_filepath"] = out_path - logging.info(f"Full aligned agent audio saved: {out_path}") - - # --------------------------------------------------------- - # Save aligned stereo conversation: - # channel 0 = user conditioning audio - # channel 1 = generated agent audio, zeroed outside turns - # --------------------------------------------------------- - if "user_audio_turns" in inputs: - user_segments = [] - - first_user_len_in = int(inputs["user_audio_turns_lens"][0][i].item()) - first_user_delay_out = int( - round(first_user_len_in * model.output_sample_rate / model.sample_rate) - ) +def _asr_transcribe_audio_batch(asr_model, audio: torch.Tensor, audio_lens: torch.Tensor, batch_size: int): + audio_list = [a[: int(alen.item())].detach().cpu() for a, alen in zip(audio, audio_lens)] + with fp32_precision(), torch.inference_mode(): + hyps = asr_model.transcribe(audio_list, batch_size=batch_size, verbose=False) - for turn_id, start_frame, end_frame in profile_turn_frame_ranges: - if turn_id >= len(inputs["user_audio_turns"]): - continue - - turn_audio = inputs["user_audio_turns"][turn_id][i].detach().cpu().float() - turn_audio_len = int(inputs["user_audio_turns_lens"][turn_id][i].item()) - turn_audio = turn_audio[:turn_audio_len] - - turn_audio_out = resample( - turn_audio.unsqueeze(0), - model.sample_rate, - model.output_sample_rate, - ).squeeze(0) - - if turn_id == 0: - user_start_sample = 0 - else: - prev_turn_end_frame = profile_turn_frame_ranges[turn_id - 1][2] - rel_prev_end_frame = prev_turn_end_frame - profile_decode_start_frame - user_start_sample = first_user_delay_out + int( - round(rel_prev_end_frame * samples_per_prediction_frame) - ) - - user_segments.append((user_start_sample, turn_audio_out)) - - total_user_len = 0 - for s, wav_seg in user_segments: - total_user_len = max(total_user_len, s + wav_seg.numel()) - - user_ch = torch.zeros(total_user_len) - - for s, wav_seg in user_segments: - e = s + wav_seg.numel() - user_ch[s:e] += wav_seg - - # Agent channel keeps the same previous offset, but uses aligned_agent. - agent_ch = torch.cat( - [ - torch.zeros(first_user_delay_out, dtype=aligned_agent.dtype), - aligned_agent, - ] - ) + out = [] + for hyp in hyps: + if hasattr(hyp, "text"): + out.append(str(hyp.text)) + else: + out.append(str(hyp)) + return out - stereo_len = max(user_ch.numel(), agent_ch.numel()) - user_pad = torch.zeros(stereo_len) - agent_pad = torch.zeros(stereo_len) - user_pad[: user_ch.numel()] = user_ch - agent_pad[: agent_ch.numel()] = agent_ch +def compute_metrics_after_generation(args, rank: int, world_size: int, metric_items: List[Dict[str, Any]]): + """ + Compute metrics after generation without the speechlm2 metric wrappers. - stereo = torch.stack([user_pad, agent_pad], dim=1).numpy() + This follows the MagpieTTS/EasyMagpieTTS evaluation style more closely: + - ASR is loaded directly from args.asr_model_name and used for transcription. + - CER/WER are computed from ASR hypotheses with word_error_rate_detail. + - Speaker similarity is computed with the MagpieTTS embedding helper and + reported as SSIM, especially pred_context_ssim / ssim. + """ + metric_start = time.time() - aligned_path = os.path.join( - args.out_dir, - f"{stem}_user_agent_aligned{ext}", - ) + if len(metric_items) == 0: + return { + "rank": int(rank), + "world_size": int(world_size), + "num_processed": 0, + "num_metric_items": 0, + "metric_elapsed_sec": 0.0, + "intelligibility": {}, + "speaker_similarity": {}, + "magpie_style_metrics": {}, + }, [] - sf.write( - aligned_path, - stereo, - samplerate=model.output_sample_rate, - ) + normalizer = EnglishTextNormalizer() + normalizer.ignore_patterns = r"$^" + filewise_rows = [] - batch_filewise_metrics[i]["aligned_user_agent_audio_filepath"] = aligned_path - logging.info(f"Aligned user/agent stereo audio saved: {aligned_path}") + # ASR pass, directly using ASRModel as in MagpieTTS evaluation. + asr_model = _load_asr_model_for_metrics(args, rank=rank) - else: - wav = audio_f32[i, : audio_len[i]].numpy() - out_path = os.path.join(args.out_dir, base_name) - sf.write(out_path, wav, samplerate=model.output_sample_rate) - batch_filewise_metrics[i]["generated_audio_filepath"] = out_path - logging.info(f"Saved: {out_path}") + for batch_items in chunk_list(metric_items, args.metric_batch_size): + pred_audio, pred_lens, _, _, output_sr, _ = _load_metric_batch_audio(batch_items, args) + + with fp32_precision(): + pred_16k = resample(pred_audio, output_sr, 16000) + pred_16k_lens = (pred_lens / output_sr * 16000).to(torch.long) + + asr_hyps = _asr_transcribe_audio_batch( + asr_model=asr_model, + audio=pred_16k, + audio_lens=pred_16k_lens, + batch_size=len(batch_items), + ) + + for item, hyp in zip(batch_items, asr_hyps): + ref_norm = normalizer(str(item["reference_text"])).strip() + hyp_norm = normalizer(str(hyp)).strip() + + detailed_cer = _safe_word_error_detail(hyp_norm, ref_norm, use_cer=True) + detailed_wer = _safe_word_error_detail(hyp_norm, ref_norm, use_cer=False) + cer = _safe_detail_value(detailed_cer) + wer = _safe_detail_value(detailed_wer) + + row = dict(item) + row["asr_hyp"] = hyp + row["pred_text"] = hyp_norm + row["gt_text"] = ref_norm + row["detailed_cer"] = detailed_cer + row["detailed_wer"] = detailed_wer + row["cer"] = cer + row["wer"] = wer + row["ssim"] = _nan() + row["gt_audio_text"] = None + row["utmosv2"] = _nan() + row["eou_type"] = None + row["eou_trailing_duration"] = _nan() + row["eou_trail_rms_ratio"] = _nan() + row["pred_gt_ssim"] = _nan() + row["pred_context_ssim"] = _nan() + row["gt_context_ssim"] = _nan() + row["pred_gt_ssim_alternate"] = _nan() + row["pred_context_ssim_alternate"] = _nan() + row["gt_context_ssim_alternate"] = _nan() + row["total_gen_audio_seconds"] = float(row.get("pred_audio_seconds", 0.0) or 0.0) + filewise_rows.append(row) + + del asr_model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Speaker similarity pass. This is the standardized "SSIM" used by the + # MagpieTTS evaluation scripts: pred_context_ssim is the main speaker + # similarity against the conditioning/context audio. + _compute_speaker_similarity_rows(args, filewise_rows) + for row in filewise_rows: + row["ssim"] = row.get("pred_context_ssim", _nan()) + + _compute_utmos_rows(args, filewise_rows, rank=rank) + _compute_eou_rows(args, filewise_rows, rank=rank) + + magpie_style_metrics = compute_magpie_style_global_metrics(filewise_rows) + + cer_wer = { + "cer": magpie_style_metrics.get("cer_cumulative"), + "wer": magpie_style_metrics.get("wer_cumulative"), + "cer_dataset": magpie_style_metrics.get("cer_cumulative"), + "wer_dataset": magpie_style_metrics.get("wer_cumulative"), + "cer_filewise_avg": magpie_style_metrics.get("cer_filewise_avg"), + "wer_filewise_avg": magpie_style_metrics.get("wer_filewise_avg"), + } + + speaker_similarity = { + "ssim": magpie_style_metrics.get("ssim_pred_context_avg"), + "ssim_dataset": magpie_style_metrics.get("ssim_pred_context_avg"), + "ssim_pred_gt_avg": magpie_style_metrics.get("ssim_pred_gt_avg"), + "ssim_pred_context_avg": magpie_style_metrics.get("ssim_pred_context_avg"), + "ssim_gt_context_avg": magpie_style_metrics.get("ssim_gt_context_avg"), + "ssim_pred_gt_avg_alternate": magpie_style_metrics.get("ssim_pred_gt_avg_alternate"), + "ssim_pred_context_avg_alternate": magpie_style_metrics.get("ssim_pred_context_avg_alternate"), + "ssim_gt_context_avg_alternate": magpie_style_metrics.get("ssim_gt_context_avg_alternate"), + } + + metric_elapsed = time.time() - metric_start + + rank_metrics = { + "rank": int(rank), + "world_size": int(world_size), + "num_processed": len({(x["run_id"], x["dataset_index"]) for x in metric_items}), + "num_metric_items": int(len(metric_items)), + "metric_elapsed_sec": float(metric_elapsed), + "intelligibility": cer_wer, + "speaker_similarity": speaker_similarity, + "magpie_style_metrics": magpie_style_metrics, + } + + return rank_metrics, filewise_rows + + +# ----------------------------- +# Merge helpers +# ----------------------------- + + +def compute_and_save_rank_metrics_file(args, rank_metrics: Dict[str, Any], rank: int): + rank_path = os.path.join(args.out_dir, f"metrics_rank{rank:04d}.json") + write_json(rank_path, rank_metrics) + return rank_metrics + + +def merge_metrics_on_rank0(args, rank, world_size): + if rank != 0: + return None + + rank_metric_files = [os.path.join(args.out_dir, f"metrics_rank{r:04d}.json") for r in range(world_size)] + + rank_metrics = [] + for path in rank_metric_files: + if not os.path.exists(path): + logging.warning(f"Missing rank metric file: {path}") + continue + with open(path, "r", encoding="utf-8") as f: + rank_metrics.append(json.load(f)) + + total_n = sum(int(m.get("num_metric_items", m.get("num_processed", 0))) for m in rank_metrics) + + def weighted_average(section: str): + keys = set() + for m in rank_metrics: + keys.update(m.get(section, {}).keys()) + + out = {} + for k in sorted(keys): + numerator = 0.0 + denominator = 0 + + for m in rank_metrics: + n = int(m.get("num_metric_items", m.get("num_processed", 0))) + if n <= 0: + continue + + value = m.get(section, {}).get(k, None) + if value is None or isinstance(value, str): + continue + + try: + value = float(value) + except Exception: + continue + + numerator += value * n + denominator += n + + if denominator > 0: + out[k] = numerator / denominator + + return out + + final_metrics = { + "world_size": int(world_size), + "num_metric_items": int(total_n), + "aggregation": "sum(rank_metric * rank_num_metric_items) / total_num_metric_items", + "intelligibility": weighted_average("intelligibility"), + "speaker_similarity": weighted_average("speaker_similarity"), + "ranks": rank_metrics, + } + + final_json_path = os.path.join(args.out_dir, "metrics_final.json") + final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") + + write_json(final_json_path, final_metrics) + + final_text = format_final_metric_text(final_metrics) + write_text_atomic(final_txt_path, final_text) + + print("\n--- Final Evaluation Metrics ---", flush=True) + print(final_text, flush=True) + + logging.info(f"Final metrics JSON saved to: {final_json_path}") + logging.info(f"Final metrics TXT saved to: {final_txt_path}") + logging.info(json.dumps(final_metrics, indent=2, sort_keys=True)) + + return final_metrics + + +def _cer_sort_value(row: Dict[str, Any]) -> float: + """Return finite CER for sorting; missing/non-finite values go last.""" + value = row.get("cer", None) + if value is None: + return float("-inf") + try: + value = float(value) + except Exception: + return float("-inf") + if not math.isfinite(value): + return float("-inf") + return value + + +def merge_filewise_metrics_on_rank0(args, rank: int, world_size: int): + """Merge per-turn rank metric rows and write global CER-sorted outputs. + + Writes: + - filewise_metrics_turns_sorted_by_cer.jsonl/csv: + one row per turn, merged across ranks, sorted by turn CER. + - filewise_metrics_global_sorted_by_cer.jsonl/csv: + compatibility alias for the same turn-level global output. + - filewise_metrics_sorted_by_cer.jsonl/csv: + one row per original sample, with turn metric lists, sorted by + sample-average CER. + """ + if rank != 0 or not args.save_filewise_metrics: + return [] + + turn_rows = [] + + for r in range(world_size): + path = os.path.join(args.out_dir, f"filewise_metrics_rank{r:04d}.jsonl") + if not os.path.exists(path): + logging.warning(f"Missing filewise metrics file: {path}") + continue + + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + turn_rows.append(json.loads(line)) + + # Deduplicate DistributedSampler padding repeats, but preserve --num_eval_runs. + deduped_turns = {} + for row in turn_rows: + run_id = int(row.get("run_id", 0)) + idx = int(row.get("dataset_index", -1)) + turn_id = int(row.get("turn_id", 0)) + key = (run_id, idx, turn_id) + if key not in deduped_turns: + deduped_turns[key] = row + + turn_rows = list(deduped_turns.values()) + + # Global turn-level output sorted by CER descending. + turn_rows_sorted = sorted( + turn_rows, + key=lambda x: ( + x.get("cer") is not None, + _cer_sort_value(x), + ), + reverse=True, + ) + + turn_jsonl_path = os.path.join(args.out_dir, "filewise_metrics_turns_sorted_by_cer.jsonl") + turn_csv_path = os.path.join(args.out_dir, "filewise_metrics_turns_sorted_by_cer.csv") + turn_json_path = os.path.join(args.out_dir, "filewise_metrics_turns_sorted_by_cer.json") + write_jsonl(turn_jsonl_path, turn_rows_sorted) + write_json(turn_json_path, {"filewise_metrics": turn_rows_sorted}) + write_turnwise_csv(turn_csv_path, turn_rows_sorted) + + # Compatibility alias using the "global" name. + global_jsonl_path = os.path.join(args.out_dir, "filewise_metrics_global_sorted_by_cer.jsonl") + global_csv_path = os.path.join(args.out_dir, "filewise_metrics_global_sorted_by_cer.csv") + write_jsonl(global_jsonl_path, turn_rows_sorted) + write_turnwise_csv(global_csv_path, turn_rows_sorted) + + turn_global_metrics = compute_magpie_style_global_metrics(turn_rows_sorted) + turn_global_metrics_path = os.path.join(args.out_dir, "metrics_final_turn_global.json") + write_json( + turn_global_metrics_path, + { + "aggregation": "magpie_style_global_metrics_over_turn_rows", + **turn_global_metrics, + }, + ) + + logging.info(f"Saved global turn-level filewise metrics JSONL to: {turn_jsonl_path}") + logging.info(f"Saved global turn-level filewise metrics JSON to: {turn_json_path}") + logging.info(f"Saved global turn-level filewise metrics CSV to: {turn_csv_path}") + logging.info(f"Saved global filewise compatibility JSONL to: {global_jsonl_path}") + logging.info(f"Saved global filewise compatibility CSV to: {global_csv_path}") + logging.info(f"Saved turn global metrics JSON to: {turn_global_metrics_path}") + + # Group turn rows into one row per original file/sample. + grouped = {} + for row in turn_rows: + run_id = int(row.get("run_id", 0)) + idx = int(row.get("dataset_index", -1)) + key = (run_id, idx) + + if key not in grouped: + grouped[key] = { + "run_id": run_id, + "dataset_index": idx, + "rank": int(row.get("rank", -1)), + "target_audio_path": row.get("target_audio_path", ""), + "context_audio_path": row.get("context_audio_path", ""), + "turn_rows": [], + } + + grouped[key]["turn_rows"].append(row) + + def avg(vals): + finite_vals = [] + for x in vals: + if x is None: + continue + try: + x = float(x) + except Exception: + continue + if math.isfinite(x): + finite_vals.append(x) + return None if not finite_vals else sum(finite_vals) / len(finite_vals) + + sample_rows = [] + for _, group in grouped.items(): + turns = sorted(group["turn_rows"], key=lambda x: int(x.get("turn_id", 0))) + + cer_turns = [r.get("cer") for r in turns] + wer_turns = [r.get("wer") for r in turns] + ssim_turns = [r.get("ssim") for r in turns] + + pred_gt_ssim_turns = [r.get("pred_gt_ssim") for r in turns] + pred_context_ssim_turns = [r.get("pred_context_ssim") for r in turns] + gt_context_ssim_turns = [r.get("gt_context_ssim") for r in turns] + pred_gt_ssim_alternate_turns = [r.get("pred_gt_ssim_alternate") for r in turns] + pred_context_ssim_alternate_turns = [r.get("pred_context_ssim_alternate") for r in turns] + gt_context_ssim_alternate_turns = [r.get("gt_context_ssim_alternate") for r in turns] + utmosv2_turns = [r.get("utmosv2") for r in turns] + eou_type_turns = [r.get("eou_type") for r in turns] + eou_trailing_duration_turns = [r.get("eou_trailing_duration") for r in turns] + eou_trail_rms_ratio_turns = [r.get("eou_trail_rms_ratio") for r in turns] + + sample_row = { + "run_id": group["run_id"], + "dataset_index": group["dataset_index"], + "rank": group["rank"], + "num_turns": len(turns), + "turn_ids": [int(r.get("turn_id", 0)) for r in turns], + "target_audio_path": group["target_audio_path"], + "context_audio_path": group["context_audio_path"], + "pred_audio_paths": [r.get("pred_audio_path", "") for r in turns], + "predicted_codes_paths": [r.get("predicted_codes_path") for r in turns], + "sample_pred_audio_path": turns[0].get("sample_pred_audio_path", turns[0].get("pred_audio_path", "")), + "sample_predicted_codes_path": turns[0].get( + "sample_predicted_codes_path", + turns[0].get("predicted_codes_path"), + ), + "pred_audio_seconds_turns": [r.get("pred_audio_seconds") for r in turns], + "reference_text": [r.get("reference_text", "") for r in turns], + "asr_hyp": [r.get("asr_hyp", "") for r in turns], + "gt_text": [r.get("gt_text", "") for r in turns], + "pred_text": [r.get("pred_text", "") for r in turns], + "cer_turns": cer_turns, + "wer_turns": wer_turns, + "ssim_turns": ssim_turns, + "pred_gt_ssim_turns": pred_gt_ssim_turns, + "pred_context_ssim_turns": pred_context_ssim_turns, + "gt_context_ssim_turns": gt_context_ssim_turns, + "pred_gt_ssim_alternate_turns": pred_gt_ssim_alternate_turns, + "pred_context_ssim_alternate_turns": pred_context_ssim_alternate_turns, + "gt_context_ssim_alternate_turns": gt_context_ssim_alternate_turns, + "utmosv2_turns": utmosv2_turns, + "eou_type_turns": eou_type_turns, + "eou_trailing_duration_turns": eou_trailing_duration_turns, + "eou_trail_rms_ratio_turns": eou_trail_rms_ratio_turns, + "cer": avg(cer_turns), + "wer": avg(wer_turns), + "ssim": avg(ssim_turns), + "pred_gt_ssim": avg(pred_gt_ssim_turns), + "pred_context_ssim": avg(pred_context_ssim_turns), + "gt_context_ssim": avg(gt_context_ssim_turns), + "pred_gt_ssim_alternate": avg(pred_gt_ssim_alternate_turns), + "pred_context_ssim_alternate": avg(pred_context_ssim_alternate_turns), + "gt_context_ssim_alternate": avg(gt_context_ssim_alternate_turns), + "utmosv2": avg(utmosv2_turns), + "eou_error": None if not eou_type_turns or eou_type_turns[0] is None else float( + sum(1 for x in eou_type_turns if str(x).lower() != "good") / len(eou_type_turns) + ), + "total_gen_audio_seconds": sum(float(r.get("total_gen_audio_seconds", r.get("pred_audio_seconds", 0.0)) or 0.0) for r in turns), + } + + sample_rows.append(sample_row) + + # Sample-level output sorted by average CER descending. + sample_rows.sort( + key=lambda x: ( + x.get("cer") is not None, + _cer_sort_value(x), + ), + reverse=True, + ) + + jsonl_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.jsonl") + json_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.json") + csv_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.csv") + + write_jsonl(jsonl_path, sample_rows) + write_json(json_path, {"filewise_metrics": sample_rows}) + write_filewise_csv(csv_path, sample_rows) + + logging.info(f"Saved sample-level filewise metrics JSONL to: {jsonl_path}") + logging.info(f"Saved sample-level filewise metrics JSON to: {json_path}") + logging.info(f"Saved sample-level filewise metrics CSV to: {csv_path}") - _write_filewise_metrics(filewise_metrics_path, filewise_metrics) + topk = min(int(args.filewise_metrics_topk_log), len(sample_rows)) + if topk > 0: + logging.info(f"Top {topk} worst CER samples:") + for row in sample_rows[:topk]: logging.info( - f"Filewise metrics checkpoint saved: {filewise_metrics_path} " - f"({len(filewise_metrics)} samples, sorted by CER)" + "run_id=%s dataset_index=%s num_turns=%s cer=%s wer=%s ssim=%s path=%s" + % ( + row.get("run_id"), + row.get("dataset_index"), + row.get("num_turns"), + row.get("cer"), + row.get("wer"), + row.get("ssim"), + row.get("target_audio_path"), + ) ) - with fp32_precision(): - logging.info("\n--- Evaluation Metrics ---") - cer_wer = intelligibility.compute() - for k, m in cer_wer.items(): - logging.info(f"Intelligibility - {k}: {m}") - - secs_scores = secs_metric.compute() - for k, m in secs_scores.items(): - logging.info(f"SECS - {k}: {m}") - - aggregate_metrics = { - **{k: _json_metric_value(v) for k, v in cer_wer.items()}, - **{k: _json_metric_value(v) for k, v in secs_scores.items()}, + topk_turns = min(int(args.filewise_metrics_topk_log), len(turn_rows_sorted)) + if topk_turns > 0: + logging.info(f"Top {topk_turns} worst CER turns:") + for row in turn_rows_sorted[:topk_turns]: + logging.info( + "run_id=%s dataset_index=%s turn_id=%s cer=%s wer=%s ssim=%s path=%s text=%s" + % ( + row.get("run_id"), + row.get("dataset_index"), + row.get("turn_id"), + row.get("cer"), + row.get("wer"), + row.get("ssim"), + row.get("pred_audio_path"), + row.get("reference_text"), + ) + ) + + return sample_rows + +def compute_frechet_codec_distance_from_sample_rows(args, rows: List[Dict[str, Any]]): + """Compute FCD in the same spirit as MagpieTTS: GT audio vs predicted codec codes.""" + if args.disable_fcd: + return _nan() + if FrechetCodecDistance is None: + logging.warning("FrechetCodecDistance is unavailable; setting frechet_codec_distance to NaN.") + return _nan() + + gt_paths = [] + code_paths = [] + seen = set() + + for row in rows: + key = (int(row.get("run_id", 0)), int(row.get("dataset_index", -1))) + if key in seen: + continue + seen.add(key) + + gt_path = _resolve_audio_path(row.get("target_audio_path"), args.audio_dir) + code_path = row.get("sample_predicted_codes_path") or row.get("predicted_codes_path") + if gt_path and code_path and os.path.exists(gt_path) and os.path.exists(code_path): + gt_paths.append(gt_path) + code_paths.append(code_path) + + if not gt_paths: + logging.warning("No valid GT-audio/predicted-code pairs found for FCD; setting FCD to NaN.") + return _nan() + + device = _metric_device() + try: + fcd_metric = FrechetCodecDistance(codec_name=args.codec_model_path).to(device) + for gt_path, code_path in zip(gt_paths, code_paths): + fcd_metric.update_from_audio_file(gt_path, True) + predicted_codes = torch.load(code_path, map_location="cpu").unsqueeze(0).to(device) + predicted_codes_lens = torch.tensor([predicted_codes.size(-1)], dtype=torch.int, device=device) + fcd_metric.update(predicted_codes, predicted_codes_lens, False) + + fcd = fcd_metric.compute().detach().cpu().item() + fcd_metric.reset() + return float(fcd) + except Exception as e: + logging.warning(f"Frechet Codec Distance computation failed: {repr(e)}") + return _nan() + + +def compute_aggregates_from_filewise_rows(rows: List[Dict[str, Any]]): + """Aggregate over sample-level rows using the MagpieTTS evaluation metric set.""" + if len(rows) == 0: + out = compute_magpie_style_global_metrics([]) + out["cer"] = None + out["wer"] = None + out["ssim"] = None + out["num_samples"] = 0 + return out + + def avg_key(key): + return finite_avg([r.get(key) for r in rows]) + + out = { + "cer": avg_key("cer"), + "wer": avg_key("wer"), + "ssim": avg_key("ssim"), + "num_samples": len(rows), + "cer_filewise_avg": avg_key("cer"), + "wer_filewise_avg": avg_key("wer"), + "ssim_pred_gt_avg": avg_key("pred_gt_ssim"), + "ssim_pred_context_avg": avg_key("pred_context_ssim"), + "ssim_gt_context_avg": avg_key("gt_context_ssim"), + "ssim_pred_gt_avg_alternate": avg_key("pred_gt_ssim_alternate"), + "ssim_pred_context_avg_alternate": avg_key("pred_context_ssim_alternate"), + "ssim_gt_context_avg_alternate": avg_key("gt_context_ssim_alternate"), + "utmosv2_avg": avg_key("utmosv2"), + "total_gen_audio_seconds": sum(float(r.get("total_gen_audio_seconds", 0.0) or 0.0) for r in rows), + "frechet_codec_distance": _nan(), } - _write_filewise_metrics(filewise_metrics_path, filewise_metrics) - logging.info(f"Filewise metrics saved: {filewise_metrics_path}") - _write_json(aggregate_metrics_path, aggregate_metrics) - logging.info(f"Aggregate metrics saved: {aggregate_metrics_path}") + + # Sample rows contain lists, so cumulative CER/WER are computed by flattening + # the normalized turn text lists. + pred_texts = [] + gt_texts = [] + for row in rows: + preds = row.get("pred_text", row.get("asr_hyp", [])) + refs = row.get("gt_text", row.get("reference_text", [])) + if not isinstance(preds, list): + preds = [preds] + if not isinstance(refs, list): + refs = [refs] + for pred, ref in zip(preds, refs): + ref = "" if ref is None else str(ref).strip() + pred = "" if pred is None else str(pred).strip() + if ref: + pred_texts.append(pred) + gt_texts.append(ref) + + if pred_texts and gt_texts: + try: + out["cer_cumulative"] = float(word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=True)[0]) + except Exception: + out["cer_cumulative"] = None + try: + out["wer_cumulative"] = float(word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=False)[0]) + except Exception: + out["wer_cumulative"] = None + else: + out["cer_cumulative"] = None + out["wer_cumulative"] = None + + out["cer_gt_audio_cumulative"] = _nan() + out["wer_gt_audio_cumulative"] = _nan() + + eou_types = [] + for row in rows: + values = row.get("eou_type_turns", []) + if isinstance(values, list): + eou_types.extend(values) + eou_types = [x for x in eou_types if x is not None] + + if eou_types: + counts = Counter(eou_types) + n = len(eou_types) + if EoUType is not None: + labels = list(EoUType.error_types()) + good_label = EoUType.GOOD + else: + labels = ["cutoff", "silence", "noise"] + good_label = "good" + for label in labels: + out[f"eou_{label}_rate"] = counts.get(label, 0) / n + out["eou_error_rate"] = 1.0 - counts.get(good_label, 0) / n + else: + out["eou_cutoff_rate"] = _nan() + out["eou_silence_rate"] = _nan() + out["eou_noise_rate"] = _nan() + out["eou_error_rate"] = _nan() + + return out + +def save_filewise_final_summary(args, filewise_rows: List[Dict[str, Any]]): + filewise_summary = compute_aggregates_from_filewise_rows(filewise_rows) + filewise_summary["frechet_codec_distance"] = compute_frechet_codec_distance_from_sample_rows(args, filewise_rows) + save_easymagpie_style_eval_outputs(args, filewise_rows, filewise_summary) + + obj = { + "aggregation": "mean_over_sample_metrics_each_sample_contains_turn_metric_lists", + **filewise_summary, + } + + path = os.path.join(args.out_dir, "metrics_final_filewise_average.json") + write_json(path, obj) + + sample_metrics_final_path = os.path.join(args.out_dir, "metrics_final_sample_average.json") + write_json(sample_metrics_final_path, obj) + + final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") + final_text = format_filewise_final_metric_text(filewise_summary) + write_text_atomic(final_txt_path, final_text) + + print("\n--- Final Sample-Averaged Evaluation Metrics ---", flush=True) + print(final_text, flush=True) + + logging.info(f"Filewise averaged final metrics saved to: {path}") + logging.info(f"Sample averaged metrics_final JSON saved to: {sample_metrics_final_path}") + logging.info(f"Final metrics TXT saved to: {final_txt_path}") + + return obj + + +# ----------------------------- +# Args / main +# ----------------------------- + + +def parse_args(): + parser = argparse.ArgumentParser(description="EasyMagpieTTS Multi-GPU Inference Evaluation") + + parser.add_argument("--checkpoint_path", type=str, required=True) + parser.add_argument("--codec_model_path", type=str, required=True) + parser.add_argument("--datasets_json_path", type=str, required=True) + parser.add_argument("--out_dir", type=str, required=True) + + parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) + parser.add_argument("--audio_dir", type=str, default=None) + parser.add_argument("--inference_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) + parser.add_argument("--debug_dtype", action="store_true") + parser.add_argument("--debug_gpu_assignment", action="store_true") + parser.add_argument("--use_librosa", action="store_true") + + parser.add_argument("--batch_size", type=int, default=6) + parser.add_argument("--num_workers", type=int, default=4) + parser.add_argument( + "--emulate_multiturn", + action="store_true", + help=( + "Group scalar single-turn JSONL rows by speaker into synthetic multiturn samples. " + "This replaces the older --num_turns behavior." + ), + ) + parser.add_argument( + "--emulate_multiturn_num_turns", + type=int, + default=1, + help="Number of scalar single-turn rows to group when --emulate_multiturn is enabled.", + ) + parser.add_argument("--max_eval_turns", type=int, default=6) + + parser.add_argument( + "--inference_mode", + type=str, + default="auto", + choices=["auto", "multiturn_user_audio", "single_turn"], + help=( + "auto selects multiturn_user_audio for samples with list text or user_audio_file_path, " + "and single_turn for classic scalar-text datasets such as LibriTTS. " + "single_turn does not prefill with user/silence audio and supports batch_size > 1." + ), + ) + + parser.add_argument("--user_custom_speaker_reference", action="store_true") + parser.add_argument("--inference_speaker_reference", type=str, default=None) + parser.add_argument("--language", type=str, default="en") + + parser.add_argument("--use_cfg", action="store_true") + parser.add_argument("--cfg_scale", type=float, default=2.5) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--topk", type=int, default=80) + parser.add_argument("--max_tts_steps", type=int, default=2000) + parser.add_argument("--normalize_volume", type=lambda x: str(x).lower() in ["true", "1", "yes"], default=True) + + parser.add_argument( + "--save_filewise_metrics", + action=argparse.BooleanOptionalAction, + default=True, + help="Save filewise metrics. Enabled by default. Use --no-save_filewise_metrics to disable.", + ) + parser.add_argument( + "--filewise_metrics_topk_log", + type=int, + default=20, + help="Number of worst CER samples to print on rank 0.", + ) + parser.add_argument( + "--num_eval_runs", + type=int, + default=1, + help="Repeat the full eval set N times. Repetitions are preserved in final filewise average.", + ) + parser.add_argument( + "--sort_by_text_token_count", + action="store_true", + help="Sort eval samples by total text token count before distributed sharding for better load balancing.", + ) + parser.add_argument( + "--metric_batch_size", + type=int, + default=8, + help="Batch size used for post-generation ASR/SSIM metric computation.", + ) + parser.add_argument( + "--max_metric_audio_sec", + type=float, + default=120.0, + help="Clamp generated audio length used for ASR/SSIM metrics to avoid metric OOM/hangs.", + ) + parser.add_argument( + "--asr_model_name", + type=str, + default="nvidia/parakeet-tdt-1.1b", + help="Pretrained ASR model used for CER/WER, matching the EasyMagpie/MagpieTTS eval default.", + ) + parser.add_argument( + "--sv_model_type", + type=str, + default="titanet", + choices=["titanet", "wavlm"], + help="Speaker verification model type for MagpieTTS-style SSIM metrics.", + ) + parser.add_argument( + "--disable_speaker_metrics", + action="store_true", + help="Disable pred/GT/context speaker similarity metrics.", + ) + parser.add_argument( + "--disable_utmosv2", + action="store_true", + help="Disable UTMOSv2. By default UTMOSv2 is computed when the dependency is available.", + ) + parser.add_argument( + "--disable_eou", + action="store_true", + help="Disable end-of-utterance classification metrics.", + ) + parser.add_argument( + "--disable_fcd", + action="store_true", + help="Disable Frechet Codec Distance. By default FCD is computed from saved predicted codec codes.", + ) + parser.add_argument( + "--eou_model_name", + type=str, + default="facebook/wav2vec2-base-960h", + help="Hugging Face model id or local path for the EOU classifier.", + ) + parser.add_argument( + "--eou_batch_size", + type=int, + default=32, + help="Batch size for EOU classification.", + ) + + parser.add_argument( + "--save_plots", + action=argparse.BooleanOptionalAction, + default=True, + help="Save EasyMagpie/MagpieTTS-style violin plots. Enabled by default.", + ) + parser.add_argument( + "--violin_plot_metrics", + type=str, + nargs="*", + default=list(DEFAULT_VIOLIN_METRICS), + help="Metrics to include in violin plots.", + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + + os.makedirs(args.out_dir, exist_ok=True) + os.makedirs(get_audio_out_dir(args), exist_ok=True) + os.makedirs(get_generated_turn_audio_dir(args), exist_ok=True) + os.makedirs(get_context_metric_audio_dir(args), exist_ok=True) + os.makedirs(get_predicted_codes_dir(args), exist_ok=True) + + distributed, rank, local_rank, world_size, device_index = setup_distributed() + + if args.inference_mode == "multiturn_user_audio" and args.batch_size != 1: + raise RuntimeError( + "--inference_mode multiturn_user_audio requires --batch_size=1 per process. " + "Use multiple GPUs/processes for parallelism instead of increasing batch_size." + ) + + if args.num_eval_runs <= 0: + raise RuntimeError("--num_eval_runs must be >= 1.") + + if args.emulate_multiturn and args.emulate_multiturn_num_turns <= 1: + raise RuntimeError("--emulate_multiturn_num_turns must be > 1 when --emulate_multiturn is enabled.") + + target_device = torch.device(f"cuda:{device_index}" if torch.cuda.is_available() and device_index >= 0 else "cpu") + target_dtype = getattr(torch, args.inference_dtype) + torch.set_default_dtype(target_dtype) + + hostname = socket.gethostname() + cuda_name = torch.cuda.get_device_name(target_device) if torch.cuda.is_available() and device_index >= 0 else "cpu" + + all_rank_print( + rank, + f"host={hostname} local_rank={local_rank} world_size={world_size} " + f"device={target_device} device_name={cuda_name}", + ) + + model = build_model_and_codec(args, target_device, target_dtype) + codec_sil_codes = model.codec_sil_codes + + if args.debug_dtype: + handles, stats, examples = attach_dtype_counter(model) + else: + handles = stats = examples = None + + emulate_multiturn_num_turns = args.emulate_multiturn_num_turns if args.emulate_multiturn else 1 + full_eval_dataset = EvalJSONLDataset( + args.datasets_json_path, + emulate_multiturn_num_turns=emulate_multiturn_num_turns, + ) + # debug + # full_eval_dataset.samples = full_eval_dataset.samples[:7] + + if args.sort_by_text_token_count: + full_eval_dataset = SortedByTextTokenCountDataset( + full_eval_dataset, + model=model, + max_eval_turns=args.max_eval_turns, + descending=True, + ) + + collate_fn = partial( + collate_and_tokenize_custom, + model=model, + sample_rate=model.sample_rate, + root_path=args.audio_dir, + normalize_audio_volume=args.normalize_volume, + use_librosa=args.use_librosa, + max_eval_turns=args.max_eval_turns, + inference_mode=args.inference_mode, + ) + + speaker_wav = load_speaker_wav_if_needed(args, model, target_dtype) + + generation_start = time.time() + all_metric_items = [] + total_batches = 0 + total_generated_samples = 0 + + for run_id in range(args.num_eval_runs): + if distributed: + sampler = DistributedSampler( + full_eval_dataset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + sampler.set_epoch(run_id) + else: + sampler = SequentialSampler(full_eval_dataset) + + if args.debug_gpu_assignment: + try: + assigned_indices = list(iter(sampler)) + assigned_dataset_indices = [ + int(full_eval_dataset[i].get("__dataset_index__", -1)) for i in assigned_indices + ] + all_rank_print( + rank, + f"run_id={run_id} assigned {len(assigned_dataset_indices)} / {len(full_eval_dataset)} " + f"samples to gpu={local_rank}: dataset_indices={assigned_dataset_indices}", + ) + except Exception as e: + all_rank_print(rank, f"Could not print assigned indices: {repr(e)}") + + dataloader = DataLoader( + dataset=full_eval_dataset, + batch_size=args.batch_size, + sampler=sampler, + collate_fn=collate_fn, + num_workers=args.num_workers, + pin_memory=True, + drop_last=False, + ) + + for batch_id, inputs in enumerate(dataloader): + total_batches += 1 + batch_indices = inputs.get("dataset_indices", []) + total_generated_samples += len(batch_indices) + + if args.debug_gpu_assignment: + all_rank_print( + rank, + f"run_id={run_id} gpu={local_rank} batch_id={batch_id} " + f"dataset_indices={batch_indices} text_token_counts={inputs.get('text_token_counts', [])} " + f"target_paths={inputs.get('target_audio_paths', [])}", + ) + + inputs = prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=speaker_wav) + + finalize_output, multiturn_turn_frame_ranges, multiturn_decode_start_frame, generated_codes = run_generation( + model=model, + inputs=inputs, + args=args, + codec_sil_codes=codec_sil_codes, + ) + + metric_items = save_generation_outputs_and_build_metric_items( + model=model, + inputs=inputs, + finalize_output=finalize_output, + multiturn_turn_frame_ranges=multiturn_turn_frame_ranges, + multiturn_decode_start_frame=multiturn_decode_start_frame, + generated_codes=generated_codes, + args=args, + rank=rank, + run_id=run_id, + ) + all_metric_items.extend(metric_items) + + if args.debug_dtype and batch_id == 0 and run_id == 0: + report_dtype_stats(handles, stats, examples, rank=rank) + + generation_elapsed = time.time() - generation_start + + # Save pre-metric manifest for debugging and restartability. + metric_manifest_path = os.path.join(args.out_dir, f"metric_items_rank{rank:04d}.jsonl") + write_jsonl(metric_manifest_path, all_metric_items) + + # Free TTS/codec model memory before loading ASR and speaker encoder metrics. + del model + if speaker_wav is not None: + del speaker_wav + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + all_rank_print( + rank, + f"generation done: batches={total_batches} generated_samples_with_sampler_padding={total_generated_samples} " + f"metric_items={len(all_metric_items)} elapsed_sec={generation_elapsed:.2f}. " + "Loading ASR/SSIM metrics now.", + ) + + rank_metrics, rank_filewise_rows = compute_metrics_after_generation( + args=args, + rank=rank, + world_size=world_size, + metric_items=all_metric_items, + ) + rank_metrics["generation_elapsed_sec"] = float(generation_elapsed) + rank_metrics["num_generated_samples_with_sampler_padding"] = int(total_generated_samples) + + rank_metrics = compute_and_save_rank_metrics_file(args, rank_metrics, rank) + all_rank_print(rank, f"saved rank metrics: {json.dumps(rank_metrics, sort_keys=True)}") + + if args.save_filewise_metrics: + rank_filewise_rows.sort( + key=lambda x: ( + x.get("cer") is not None, + float(x.get("cer")) if x.get("cer") is not None else -1.0, + ), + reverse=True, + ) + + rank_filewise_path = os.path.join(args.out_dir, f"filewise_metrics_rank{rank:04d}.jsonl") + write_jsonl(rank_filewise_path, rank_filewise_rows) + all_rank_print(rank, f"saved filewise metrics: {rank_filewise_path}") + + if rank == 0: + wait_for_rank_metric_files(args, world_size) + + merge_metrics_on_rank0(args, rank, world_size) + + if args.save_filewise_metrics: + if rank == 0: + wait_for_rank_filewise_metric_files(args, world_size) + + filewise_rows = merge_filewise_metrics_on_rank0(args, rank, world_size) + + if rank == 0: + save_filewise_final_summary(args, filewise_rows) + + cleanup_distributed() if __name__ == "__main__": diff --git a/examples/tts/easy_magpietts_inference_multiturn_runner.py b/examples/tts/easy_magpietts_inference_multiturn_runner.py new file mode 100644 index 000000000000..5d089bf4b95e --- /dev/null +++ b/examples/tts/easy_magpietts_inference_multiturn_runner.py @@ -0,0 +1,750 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +TTS inference and evaluation entry point for MagpieTTS/EasyMagpieTTS. + +This version adds EasyMagpie multiturn user-audio inference as a first-class +runner mode while keeping the existing EasyMagpie evaluation pipeline. The new +runner writes turn-level EasyMagpie-compatible generated files and a generated +turn-level manifest, so ``evaluate_generated_audio_dir`` can compute CER/WER, +SSIM, UTMOSv2, EOU, FCD, CSVs and plots without custom metric code. +""" +from __future__ import annotations + +import argparse +import copy +import json +import os +import random +import shutil +import time +from dataclasses import fields +from pathlib import Path +from typing import List, Optional, Tuple + +import numpy as np +import torch + +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest +from nemo.collections.tts.models.easy_magpietts_inference import EasyModelInferenceParameters +from nemo.collections.tts.models.magpietts import ModelInferenceParameters +from nemo.collections.tts.modules.magpietts_inference.evaluate_generated_audio import load_evalset_config +from nemo.collections.tts.modules.magpietts_inference.evaluation import ( + DEFAULT_VIOLIN_METRICS, + EvaluationConfig, + compute_mean_with_confidence_interval, + evaluate_generated_audio_dir, +) +from nemo.collections.tts.modules.magpietts_inference.inference import ( + BaseInferenceConfig, + BaseInferenceRunner, + EasyMagpieInferenceConfig, + EasyMagpieInferenceRunner, + EasyMagpieMultiturnUserAudioInferenceConfig, + EasyMagpieMultiturnUserAudioInferenceRunner, + MagpieInferenceConfig, + MagpieInferenceRunner, +) +from nemo.collections.tts.modules.magpietts_inference.utils import ( + ModelLoadConfig, + get_experiment_name_from_checkpoint_path, + load_easy_magpie_model, + load_magpie_model, + log_model_architecture_summary, +) +from nemo.collections.tts.modules.magpietts_inference.visualization import create_combined_box_plot, create_violin_plot +from nemo.collections.tts.modules.magpietts_modules import EOSDetectionMethod +from nemo.utils import logging + + +def parse_layer_list(layer_str: Optional[str]) -> Optional[List[int]]: + if layer_str is None: + return None + return [int(l.strip()) for l in layer_str.split(",")] + + +def write_csv_header_if_needed(csv_path: str, header: str) -> None: + if not os.path.exists(csv_path): + os.makedirs(os.path.dirname(csv_path), exist_ok=True) + with open(csv_path, "w", encoding="utf-8") as f: + f.write(header + "\n") + + +def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, metrics: dict) -> None: + values = [ + checkpoint_name, + dataset, + metrics.get('cer_filewise_avg', ''), + metrics.get('wer_filewise_avg', ''), + metrics.get('cer_cumulative', ''), + metrics.get('wer_cumulative', ''), + metrics.get('ssim_pred_gt_avg', ''), + metrics.get('ssim_pred_context_avg', ''), + metrics.get('ssim_gt_context_avg', ''), + metrics.get('ssim_pred_gt_avg_alternate', ''), + metrics.get('ssim_pred_context_avg_alternate', ''), + metrics.get('ssim_gt_context_avg_alternate', ''), + metrics.get('cer_gt_audio_cumulative', ''), + metrics.get('wer_gt_audio_cumulative', ''), + metrics.get('utmosv2_avg', ''), + metrics.get('total_gen_audio_seconds', ''), + metrics.get('frechet_codec_distance', ''), + metrics.get('eou_cutoff_rate', ''), + metrics.get('eou_silence_rate', ''), + metrics.get('eou_noise_rate', ''), + metrics.get('eou_error_rate', ''), + ] + with open(csv_path, "a", encoding="utf-8") as f: + f.write(",".join(str(v).replace(",", " ") for v in values) + "\n") + logging.info(f"Metrics appended to: {csv_path}") + + +def create_formatted_metrics_mean_ci(metrics_mean_ci: dict) -> dict: + for k, v in metrics_mean_ci.items(): + if isinstance(v, list): + mean, ci = float(v[0]), float(v[1]) + logging.info(f"Metric {k}: {mean:.4f} ± {ci:.4f}") + metrics_mean_ci[k] = f"{mean:.4f} ± {ci:.4f}" + return metrics_mean_ci + + +def filter_datasets(dataset_meta_info: dict, datasets: Optional[str]) -> List[str]: + if datasets is None: + return list(dataset_meta_info.keys()) + selected = datasets.split(",") + for dataset in selected: + if dataset not in dataset_meta_info: + raise ValueError(f"Dataset {dataset} not found in dataset meta info") + return selected + + +def _runner_eval_manifest_and_audio_dir(runner: BaseInferenceRunner, default_manifest: str, default_audio_dir: str): + """Return evaluation manifest/audio dir produced by the runner, if any.""" + eval_manifest = getattr(runner, "evaluation_manifest_path", None) or default_manifest + eval_audio_dir = getattr(runner, "evaluation_audio_dir", None) or default_audio_dir + return eval_manifest, eval_audio_dir + + + +def _get_torchrun_rank_info() -> Tuple[int, int, int]: + """Return (rank, world_size, local_rank) from torchrun/SLURM env vars. + + We intentionally do not initialize torch.distributed here. The inference + script only needs env-based sharding, while NeMo evaluation models can run + without distributed collectives. + """ + rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", "0"))) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))) + return rank, world_size, local_rank + + +def _configure_cuda_for_rank() -> Tuple[int, int, int]: + rank, world_size, local_rank = _get_torchrun_rank_info() + if torch.cuda.is_available(): + device_count = torch.cuda.device_count() + if device_count > 0: + torch.cuda.set_device(local_rank % device_count) + logging.info( + f"Using CUDA device {local_rank % device_count}; " + f"rank={rank}, local_rank={local_rank}, world_size={world_size}" + ) + return rank, world_size, local_rank + + +def _wait_for_multiturn_rank_manifests(repeat_audio_dir: str, world_size: int, timeout_sec: int = 7200) -> None: + deadline = time.time() + timeout_sec + while time.time() < deadline: + missing = [] + for rank in range(world_size): + path = os.path.join( + repeat_audio_dir, + f"rank_{rank:04d}", + f"multiturn_user_audio_turn_manifest_rank{rank:04d}.jsonl", + ) + if not os.path.exists(path): + missing.append(path) + if not missing: + return + time.sleep(5) + raise RuntimeError(f"Timed out waiting for multiturn rank manifests: {missing}") + + +def _copy_or_link(src: str, dst: str) -> None: + if src is None or not os.path.exists(src): + return + os.makedirs(os.path.dirname(dst), exist_ok=True) + try: + if os.path.lexists(dst): + os.remove(dst) + os.symlink(os.path.abspath(src), dst) + except Exception: + shutil.copyfile(src, dst) + + +def _merge_multiturn_rank_outputs(repeat_audio_dir: str, world_size: int, save_predicted_codes: bool) -> str: + """Merge rank-local multiturn outputs into one EasyMagpie-compatible dir. + + Each rank writes local files named predicted_audio_0.wav, target_audio_0.wav, + context_audio_0.wav, predicted_codes_0.pt, ... inside rank_XXXX/. This + function remaps them to contiguous global indices in repeat_audio_dir/ and + writes a merged turn-level manifest. + """ + merged_records = [] + global_idx = 0 + + for rank in range(world_size): + rank_dir = os.path.join(repeat_audio_dir, f"rank_{rank:04d}") + rank_manifest = os.path.join(rank_dir, f"multiturn_user_audio_turn_manifest_rank{rank:04d}.jsonl") + if not os.path.exists(rank_manifest): + raise FileNotFoundError(f"Missing rank manifest: {rank_manifest}") + + with open(rank_manifest, "r", encoding="utf-8") as f: + rank_records = [json.loads(line) for line in f if line.strip()] + + for local_idx, record in enumerate(rank_records): + pred_src = os.path.join(rank_dir, f"predicted_audio_{local_idx}.wav") + pred_dst = os.path.join(repeat_audio_dir, f"predicted_audio_{global_idx}.wav") + _copy_or_link(pred_src, pred_dst) + + if save_predicted_codes: + code_src = os.path.join(rank_dir, f"predicted_codes_{local_idx}.pt") + code_dst = os.path.join(repeat_audio_dir, f"predicted_codes_{global_idx}.pt") + _copy_or_link(code_src, code_dst) + + target_src = os.path.join(rank_dir, record.get("audio_filepath", f"target_audio_{local_idx}.wav")) + target_dst = os.path.join(repeat_audio_dir, f"target_audio_{global_idx}.wav") + _copy_or_link(target_src, target_dst) + + context_src = os.path.join( + rank_dir, + record.get("context_audio_filepath", f"context_audio_{local_idx}.wav"), + ) + context_dst = os.path.join(repeat_audio_dir, f"context_audio_{global_idx}.wav") + _copy_or_link(context_src, context_dst) + + merged = dict(record) + merged["audio_filepath"] = f"target_audio_{global_idx}.wav" + merged["context_audio_filepath"] = f"context_audio_{global_idx}.wav" + merged["rank"] = rank + merged["rank_local_idx"] = local_idx + merged_records.append(merged) + global_idx += 1 + + merged_manifest = os.path.join(repeat_audio_dir, "multiturn_user_audio_turn_manifest.jsonl") + with open(merged_manifest, "w", encoding="utf-8") as f: + for record in merged_records: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + logging.info(f"Merged {len(merged_records)} multiturn turn records into {merged_manifest}") + return merged_manifest + + +def run_inference_and_evaluation( + runner: BaseInferenceRunner, + checkpoint_name: str, + inference_config: BaseInferenceConfig, + eval_config: EvaluationConfig, + dataset_meta_info: dict, + datasets: List[str], + out_dir: str, + flops_per_component: dict, + moe_info: str, + num_repeats: int = 1, + confidence_level: float = 0.95, + violin_plot_metrics: Optional[List[str]] = None, + clean_up_disk: bool = False, + skip_evaluation: bool = False, +) -> Tuple[Optional[float], Optional[float]]: + if violin_plot_metrics is None: + violin_plot_metrics = list(DEFAULT_VIOLIN_METRICS) + if not eval_config.with_utmosv2 and 'utmosv2' in violin_plot_metrics: + violin_plot_metrics.remove('utmosv2') + + rank, world_size, _ = _get_torchrun_rank_info() + is_distributed = world_size > 1 + is_multiturn_user_audio = getattr(runner, "produces_turn_level_evaluation", False) + + if hasattr(runner, "set_distributed_context"): + runner.set_distributed_context(rank=rank, world_size=world_size) + + full_checkpoint_name = ( + f"{checkpoint_name}_{moe_info}{inference_config.build_identifier()}_SV_{eval_config.sv_model}" + ) + + ssim_per_dataset = [] + cer_per_dataset = [] + all_datasets_filewise_metrics = {} + + csv_header = ( + "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative," + "wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg," + "ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate," + "ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative," + "utmosv2_avg,total_gen_audio_seconds,frechet_codec_distance," + "eou_cutoff_rate,eou_silence_rate,eou_noise_rate,eou_error_rate" + ) + + for dataset in datasets: + logging.info(f"Processing dataset: {dataset}") + meta = dataset_meta_info[dataset] + manifest_records = read_manifest(meta['manifest_path']) + language = meta.get('whisper_language', 'en') + + dataset_meta_for_dl = copy.deepcopy(meta) + for key in ["whisper_language", "load_cached_codes_if_available"]: + dataset_meta_for_dl.pop(key, None) + + eval_dir = os.path.join(out_dir, f"{full_checkpoint_name}_{dataset}") + audio_dir = os.path.join(eval_dir, "audio") + os.makedirs(eval_dir, exist_ok=True) + + per_run_csv = os.path.join(eval_dir, "all_experiment_metrics.csv") + if rank == 0: + write_csv_header_if_needed(per_run_csv, csv_header) + + metrics_all_repeats = [] + filewise_metrics_all_repeats = [] + + for repeat_idx in range(num_repeats): + logging.info(f"Repeat {repeat_idx + 1}/{num_repeats} for dataset {dataset}, rank {rank}/{world_size}") + repeat_audio_dir = os.path.join(audio_dir, f"repeat_{repeat_idx}") + os.makedirs(repeat_audio_dir, exist_ok=True) + + test_dataset = runner.create_dataset({dataset: dataset_meta_for_dl}) + + if not is_multiturn_user_audio: + if is_distributed: + raise RuntimeError( + "torchrun multi-GPU sharding is currently implemented for " + "--easy_magpie_inference_mode multiturn_user_audio only. " + "Use the existing single-process path for single_turn/magpie, or add a " + "rank-safe merge path for those runners." + ) + if len(test_dataset) != len(manifest_records): + raise ValueError( + f"Dataset length mismatch: {len(test_dataset)} vs {len(manifest_records)} manifest records" + ) + + if is_distributed and is_multiturn_user_audio: + rank_audio_dir = os.path.join(repeat_audio_dir, f"rank_{rank:04d}") + inference_output_dir = rank_audio_dir + else: + inference_output_dir = repeat_audio_dir + + rtf_metrics_list, _, codec_file_paths = runner.run_inference_on_dataset( + dataset=test_dataset, + output_dir=inference_output_dir, + manifest_records=manifest_records, + audio_base_dir=meta['audio_dir'], + save_cross_attention_maps=True, + save_context_audio=(repeat_idx == 0), + save_predicted_codes=eval_config.with_fcd, + ) + + mean_rtf = runner.compute_mean_rtf_metrics(rtf_metrics_list) + for component_name, component_flops in flops_per_component.items(): + for key, value in component_flops.items(): + mean_rtf[f"{component_name}_{key}"] = value + logging.info(f"{component_name} FLOPs per token: {component_flops['total_flops_per_token']:,}") + + rtf_path = os.path.join(eval_dir, f"{dataset}_rtf_metrics_{repeat_idx}_rank{rank:04d}.json") + with open(rtf_path, "w", encoding="utf-8") as f: + json.dump(mean_rtf, f, indent=4) + + if skip_evaluation: + logging.info("Skipping evaluation as requested.") + continue + + if is_distributed and is_multiturn_user_audio: + if rank != 0: + # Non-zero ranks only generate. Rank 0 waits and evaluates merged outputs. + continue + + _wait_for_multiturn_rank_manifests(repeat_audio_dir, world_size) + merged_manifest_path = _merge_multiturn_rank_outputs( + repeat_audio_dir=repeat_audio_dir, + world_size=world_size, + save_predicted_codes=eval_config.with_fcd, + ) + eval_manifest_path = merged_manifest_path + eval_audio_dir = repeat_audio_dir + else: + eval_manifest_path, eval_audio_dir = _runner_eval_manifest_and_audio_dir( + runner, + default_manifest=meta['manifest_path'], + default_audio_dir=meta['audio_dir'], + ) + + eval_config_for_dataset = EvaluationConfig( + sv_model=eval_config.sv_model, + asr_model_name=eval_config.asr_model_name, + eou_model_name=eval_config.eou_model_name, + language=language, + with_utmosv2=eval_config.with_utmosv2, + with_fcd=eval_config.with_fcd, + codec_model_path=eval_config.codec_model_path, + device=eval_config.device, + ) + + metrics, filewise_metrics = evaluate_generated_audio_dir( + manifest_path=eval_manifest_path, + audio_dir=eval_audio_dir, + generated_audio_dir=repeat_audio_dir, + config=eval_config_for_dataset, + ) + + metrics_all_repeats.append(metrics) + filewise_metrics_all_repeats.extend(filewise_metrics) + + with open(os.path.join(eval_dir, f"{dataset}_metrics_{repeat_idx}.json"), "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4) + + sorted_filewise = sorted(filewise_metrics, key=lambda x: x.get('cer', 0), reverse=True) + with open(os.path.join(eval_dir, f"{dataset}_filewise_metrics_{repeat_idx}.json"), "w", encoding="utf-8") as f: + json.dump(sorted_filewise, f, indent=4) + + append_metrics_to_csv(per_run_csv, full_checkpoint_name, dataset, metrics) + create_violin_plot( + filewise_metrics, + violin_plot_metrics, + Path(eval_dir) / f"{dataset}_violin_{repeat_idx}.png", + ) + + # EasyMagpie deletes codec files after evaluation. For distributed + # multiturn, the merged predicted_codes_*.pt live in repeat_audio_dir. + cleanup_code_paths = codec_file_paths + if is_distributed and is_multiturn_user_audio: + cleanup_code_paths = list(Path(repeat_audio_dir).glob("predicted_codes_*.pt")) + for codec_file_path in cleanup_code_paths: + if os.path.exists(codec_file_path): + os.remove(codec_file_path) + + if rank != 0: + continue + + if skip_evaluation or not metrics_all_repeats: + continue + + all_datasets_filewise_metrics[dataset] = filewise_metrics_all_repeats + metrics_mean_ci = compute_mean_with_confidence_interval(metrics_all_repeats, confidence=confidence_level) + formatted_metrics_mean_ci = create_formatted_metrics_mean_ci(metrics_mean_ci) + + ci_csv = os.path.join(out_dir, "all_experiment_metrics_with_ci.csv") + write_csv_header_if_needed(ci_csv, csv_header) + append_metrics_to_csv(ci_csv, full_checkpoint_name, dataset, formatted_metrics_mean_ci) + + ssim_values = [m['ssim_pred_context_avg'] for m in metrics_all_repeats] + cer_values = [m['cer_cumulative'] for m in metrics_all_repeats] + ssim_per_dataset.append(np.mean(ssim_values)) + cer_per_dataset.append(np.mean(cer_values)) + + if rank == 0 and len(all_datasets_filewise_metrics) > 1: + combined_plot_path = os.path.join(out_dir, f"{full_checkpoint_name}_combined_violin_plot.png") + create_combined_box_plot(all_datasets_filewise_metrics, violin_plot_metrics, combined_plot_path) + + if rank == 0 and clean_up_disk: + logging.info(f"Cleaning up output directory: {out_dir}") + shutil.rmtree(out_dir) + + if rank == 0 and ssim_per_dataset and cer_per_dataset: + return np.mean(cer_per_dataset), np.mean(ssim_per_dataset) + return None, None + + +def _get_shared_inference_param_names() -> set: + magpie_fields = {f.name for f in fields(ModelInferenceParameters)} + easy_fields = {f.name for f in fields(EasyModelInferenceParameters)} + return magpie_fields & easy_fields + + +def _add_inference_param_fields( + group: argparse._ArgumentGroup, + param_cls: type, + skip_fields: Optional[set] = None, + only_fields: Optional[set] = None, +) -> None: + if skip_fields is None: + skip_fields = set() + for f in fields(param_cls): + if f.name in skip_fields: + continue + if only_fields is not None and f.name not in only_fields: + continue + extra_args: dict = {"type": f.type} + if f.type == bool: + extra_args = {"action": "store_true"} + if f.name in ("estimate_alignment_from_layers", "apply_prior_to_layers"): + extra_args = {"help": "Must be a comma separate string. Not enclosed in brackets", "type": str} + elif f.name == "eos_detection_method": + extra_args["choices"] = [m.value for m in EOSDetectionMethod] + group.add_argument(f"--{f.name}", **extra_args) + + +def _add_common_args(parser: argparse.ArgumentParser) -> None: + parser.add_argument('--model_type', type=str, default='magpie', choices=['magpie', 'easy_magpie']) + parser.add_argument('--deterministic', action='store_true') + + model_group = parser.add_argument_group('Model Loading') + model_group.add_argument('--hparams_files', type=str, default=None) + model_group.add_argument('--checkpoint_files', type=str, default=None) + model_group.add_argument('--nemo_files', type=str, default=None) + model_group.add_argument('--codecmodel_path', type=str, required=True) + model_group.add_argument('--hparams_file_from_wandb', action='store_true') + model_group.add_argument('--legacy_codebooks', action='store_true') + model_group.add_argument('--legacy_text_conditioning', action='store_true') + + data_group = parser.add_argument_group('Dataset and Output') + data_group.add_argument('--datasets_json_path', type=str, required=True, default=None) + data_group.add_argument('--datasets_base_path', type=Path, default=None) + data_group.add_argument('--datasets', type=str, default=None) + data_group.add_argument('--out_dir', type=str, required=True) + data_group.add_argument('--log_exp_name', action='store_true') + data_group.add_argument('--clean_up_disk', action='store_true') + + infer_group = parser.add_argument_group('Common Inference Parameters') + infer_group.add_argument('--batch_size', type=int, default=32) + infer_group.add_argument('--use_cfg', action='store_true') + infer_group.add_argument('--use_local_transformer', action='store_true') + shared_param_names = _get_shared_inference_param_names() + _add_inference_param_fields(infer_group, ModelInferenceParameters, only_fields=shared_param_names) + + eval_group = parser.add_argument_group('Evaluation') + eval_group.add_argument('--run_evaluation', action='store_true') + eval_group.add_argument('--sv_model', type=str, default='titanet', choices=['titanet', 'wavlm']) + eval_group.add_argument('--asr_model_name', type=str, default='nvidia/parakeet-tdt-1.1b') + eval_group.add_argument('--eou_model_name', type=str, default='facebook/wav2vec2-base-960h') + eval_group.add_argument('--num_repeats', type=int, default=1) + eval_group.add_argument('--confidence_level', type=float, default=0.95) + eval_group.add_argument('--disable_utmosv2', action='store_true') + eval_group.add_argument('--violin_plot_metrics', type=str, nargs='*', default=['cer', 'pred_context_ssim', 'utmosv2']) + eval_group.add_argument('--disable_fcd', action='store_true') + + target_group = parser.add_argument_group('Quality Targets') + target_group.add_argument('--cer_target', type=float, default=None) + target_group.add_argument('--ssim_target', type=float, default=None) + + +def seed_all(seed: int): + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + + +def _add_magpie_args(parser: argparse.ArgumentParser) -> None: + group = parser.add_argument_group('MagpieTTS-specific Parameters') + shared_param_names = _get_shared_inference_param_names() + _add_inference_param_fields(group, ModelInferenceParameters, skip_fields=shared_param_names) + group.add_argument('--maskgit_n_steps', type=int, default=3) + group.add_argument('--maskgit_noise_scale', type=float, default=0.0) + group.add_argument('--maskgit_fixed_schedule', type=int, nargs='+', default=None) + group.add_argument('--maskgit_sampling_type', default=None, choices=['default', 'causal', 'purity_causal', 'purity_default']) + + +def _add_easy_magpie_args(parser: argparse.ArgumentParser) -> None: + group = parser.add_argument_group('EasyMagpieTTS-specific Parameters') + group.add_argument('--easy_magpie_inference_mode', type=str, default='single_turn', choices=['single_turn', 'multiturn_user_audio']) + group.add_argument('--max_eval_turns', type=int, default=6) + group.add_argument('--no_save_debug_multiturn_audio', action='store_true') + group.add_argument('--phoneme_input_type', type=str, default='gt', choices=['gt', 'predicted']) + group.add_argument('--phoneme_sampling_method', type=str, default='argmax', choices=['argmax', 'multinomial', 'greedy']) + group.add_argument('--dropout_text_input', action='store_true') + group.add_argument('--phoneme_tokenizer_path', type=str, default=None) + group.add_argument('--disable_cas_for_context_text', action='store_true') + + +def create_argument_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description='TTS Inference and Evaluation (MagpieTTS & EasyMagpieTTS)') + _add_common_args(parser) + _add_magpie_args(parser) + _add_easy_magpie_args(parser) + return parser + + +def _build_inference_params_from_args(param_cls: type, args): + params = {} + for f in fields(param_cls): + arg_val = vars(args).get(f.name) + if arg_val is not None: + if f.name in ('estimate_alignment_from_layers', 'apply_prior_to_layers'): + params[f.name] = parse_layer_list(arg_val) + else: + params[f.name] = arg_val + return param_cls.from_dict(params) + + +def _build_magpie_config(args) -> MagpieInferenceConfig: + return MagpieInferenceConfig( + model_inference_parameters=_build_inference_params_from_args(ModelInferenceParameters, args), + batch_size=args.batch_size, + use_cfg=args.use_cfg, + apply_attention_prior=args.apply_attention_prior, + use_local_transformer=args.use_local_transformer, + maskgit_n_steps=args.maskgit_n_steps, + maskgit_noise_scale=args.maskgit_noise_scale, + maskgit_fixed_schedule=args.maskgit_fixed_schedule, + maskgit_sampling_type=args.maskgit_sampling_type, + ) + + +def _build_easy_magpie_config(args) -> EasyMagpieInferenceConfig: + cfg_cls = ( + EasyMagpieMultiturnUserAudioInferenceConfig + if args.easy_magpie_inference_mode == 'multiturn_user_audio' + else EasyMagpieInferenceConfig + ) + kwargs = dict( + model_inference_parameters=_build_inference_params_from_args(EasyModelInferenceParameters, args), + batch_size=args.batch_size, + use_cfg=args.use_cfg, + use_local_transformer=args.use_local_transformer, + phoneme_input_type=args.phoneme_input_type, + phoneme_sampling_method=args.phoneme_sampling_method, + dropout_text_input=args.dropout_text_input, + ) + if cfg_cls is EasyMagpieMultiturnUserAudioInferenceConfig: + kwargs.update( + max_eval_turns=args.max_eval_turns, + save_debug_multiturn_audio=not args.no_save_debug_multiturn_audio, + ) + return cfg_cls(**kwargs) + + +def _select_runner_cls(args): + if args.model_type == 'magpie': + if args.easy_magpie_inference_mode != 'single_turn': + raise ValueError('--easy_magpie_inference_mode is only supported with --model_type easy_magpie') + return MagpieInferenceRunner + if args.easy_magpie_inference_mode == 'multiturn_user_audio': + return EasyMagpieMultiturnUserAudioInferenceRunner + return EasyMagpieInferenceRunner + + +def main(argv=None): + parser = create_argument_parser() + args = parser.parse_args(argv) + rank, world_size, local_rank = _configure_cuda_for_rank() + + if args.model_type == 'easy_magpie' and args.easy_magpie_inference_mode == 'multiturn_user_audio' and args.batch_size > 1: + parser.error("--easy_magpie_inference_mode multiturn_user_audio requires --batch_size 1.") + + if args.deterministic: + seed_all(seed=9) + + dataset_meta_info = load_evalset_config(config_path=args.datasets_json_path, dataset_base_path=args.datasets_base_path) + datasets = filter_datasets(dataset_meta_info, args.datasets) + logging.info(f"Loaded {len(datasets)} datasets: {', '.join(datasets)}") + + has_checkpoint_mode = ( + args.hparams_files is not None + and args.checkpoint_files is not None + and args.hparams_files != 'null' + and args.checkpoint_files != 'null' + ) + has_nemo_mode = args.nemo_files is not None and args.nemo_files != 'null' + + if not has_checkpoint_mode and not has_nemo_mode: + parser.error('You must provide either --hparams_files/--checkpoint_files or --nemo_files') + + is_easy_magpie = args.model_type == 'easy_magpie' + load_fn = load_easy_magpie_model if is_easy_magpie else load_magpie_model + inference_config = _build_easy_magpie_config(args) if is_easy_magpie else _build_magpie_config(args) + runner_cls = _select_runner_cls(args) + + eval_config = EvaluationConfig( + sv_model=args.sv_model, + asr_model_name=args.asr_model_name, + eou_model_name=args.eou_model_name, + with_utmosv2=not args.disable_utmosv2, + with_fcd=not args.disable_fcd, + codec_model_path=args.codecmodel_path if not args.disable_fcd else None, + ) + + cer, ssim = None, None + + def run_one_model(model_config: ModelLoadConfig): + nonlocal cer, ssim + model, checkpoint_name = load_fn(model_config) + moe_info, flops_per_component = log_model_architecture_summary(model) + if args.log_exp_name and model_config.checkpoint_file: + exp_name = get_experiment_name_from_checkpoint_path(model_config.checkpoint_file) + checkpoint_name = f'{exp_name}__{checkpoint_name}' + runner = runner_cls(model, inference_config) + cer, ssim = run_inference_and_evaluation( + runner=runner, + checkpoint_name=checkpoint_name, + inference_config=inference_config, + eval_config=eval_config, + dataset_meta_info=dataset_meta_info, + datasets=datasets, + out_dir=args.out_dir, + flops_per_component=flops_per_component, + moe_info=moe_info, + num_repeats=args.num_repeats, + confidence_level=args.confidence_level, + violin_plot_metrics=args.violin_plot_metrics, + clean_up_disk=args.clean_up_disk, + skip_evaluation=not args.run_evaluation, + ) + + if has_checkpoint_mode: + hparam_files = args.hparams_files.split(',') + checkpoint_files = args.checkpoint_files.split(',') + if len(hparam_files) != len(checkpoint_files): + parser.error('Number of hparams_files must match number of checkpoint_files') + for hparams_file, checkpoint_file in zip(hparam_files, checkpoint_files): + logging.info(f'Processing checkpoint: {checkpoint_file}') + run_one_model( + ModelLoadConfig( + hparams_file=hparams_file, + checkpoint_file=checkpoint_file, + codecmodel_path=args.codecmodel_path, + legacy_codebooks=args.legacy_codebooks, + legacy_text_conditioning=args.legacy_text_conditioning, + hparams_from_wandb=args.hparams_file_from_wandb, + phoneme_tokenizer_path=getattr(args, 'phoneme_tokenizer_path', None), + disable_cas_for_context_text=args.disable_cas_for_context_text, + ) + ) + else: + for nemo_file in args.nemo_files.split(','): + logging.info(f'Processing NeMo file: {nemo_file}') + run_one_model( + ModelLoadConfig( + nemo_file=nemo_file, + codecmodel_path=args.codecmodel_path, + legacy_codebooks=args.legacy_codebooks, + legacy_text_conditioning=args.legacy_text_conditioning, + phoneme_tokenizer_path=getattr(args, 'phoneme_tokenizer_path', None), + disable_cas_for_context_text=args.disable_cas_for_context_text, + ) + ) + + if cer is not None and args.cer_target is not None: + if cer > args.cer_target: + raise ValueError(f'CER {cer:.4f} exceeds target {args.cer_target:.4f}') + logging.info(f'CER {cer:.4f} meets target {args.cer_target:.4f}') + + if ssim is not None and args.ssim_target is not None: + if ssim < args.ssim_target: + raise ValueError(f'SSIM {ssim:.4f} below target {args.ssim_target:.4f}') + logging.info(f'SSIM {ssim:.4f} meets target {args.ssim_target:.4f}') + + logging.info('Inference and evaluation completed successfully.') + + +if __name__ == '__main__': + main() diff --git a/examples/tts/easy_magpietts_inference_multiturn_turn_based_as_magpie.py b/examples/tts/easy_magpietts_inference_multiturn_turn_based_as_magpie.py deleted file mode 100644 index deb2244f2376..000000000000 --- a/examples/tts/easy_magpietts_inference_multiturn_turn_based_as_magpie.py +++ /dev/null @@ -1,2247 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -""" -Multi-GPU EasyMagpieTTS / NemotronTTS multiturn inference evaluation. - -Key behavior: - - Uses torchrun env vars RANK, LOCAL_RANK, WORLD_SIZE for sharding/GPU assignment. - - Does NOT initialize torch.distributed. This avoids NeMo ASR doing distributed - collectives during metric computation. - - Generation runs first for all assigned samples. - - ASR and SECS are loaded only after generation is done and the TTS/codec model - has been deleted from GPU memory. - - ASR and SECS are loaded sequentially: ASR first, then released; SECS second. - - For --profile_multiturn_inference, metrics are computed turn-by-turn. - Final filewise outputs are grouped back to one row per original sample, with - lists for asr_hyp/reference_text/cer_turns/wer_turns/secs_turns. - - Uses DistributedSampler with explicit rank/world_size. A few repeated samples - may appear when len(dataset) is not divisible by world_size. Filewise final - metrics deduplicate sampler-padding repeats by (run_id, dataset_index, - turn_id), then group turns into one row per sample with metric lists, while - preserving --num_eval_runs repetitions. - - --sort_by_text_token_count sorts samples by total text-token count before - sharding to improve GPU load balance. - - Saves audio in out_dir/audios/. - - Saves metrics in out_dir/. - -Recommended single-node torchrun: - CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ - torchrun --standalone --nproc_per_node=8 easy_magpietts_inference_multiturn_multigpu_postgen_metrics.py ... - -Recommended single-node srun wrapper: - srun --nodes=1 --ntasks=1 --ntasks-per-node=1 --container-image=... \ - bash -lc 'torchrun --standalone --nproc_per_node=8 easy_magpietts_inference_multiturn_multigpu_postgen_metrics.py ...' -""" - -import argparse -import csv -import json -import os -import socket -import time -from copy import deepcopy -from functools import partial -from typing import Any, Dict, Iterable, List, Tuple - -import librosa -import soundfile as sf -import torch -from omegaconf import open_dict -from torch.nn.utils.rnn import pad_sequence -from torch.utils.data import DataLoader, Dataset, DistributedSampler, SequentialSampler - -from nemo.collections.audio.parts.utils.transforms import resample -from nemo.collections.asr.metrics.wer import word_error_rate -from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility -from nemo.collections.speechlm2.parts.metrics.secs import SECS -from nemo.collections.speechlm2.parts.precision import fp32_precision -from nemo.collections.tts.models import AudioCodecModel -from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel -from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter -from nemo.collections.tts.modules.magpietts_modules import CodecHelper -from nemo.collections.tts.parts.utils.tts_dataset_utils import normalize_volume -from nemo.utils import logging -from whisper_normalizer.english import EnglishTextNormalizer - - -torch.set_float32_matmul_precision("medium") -torch.backends.cudnn.allow_tf32 = True -torch.backends.cuda.matmul.allow_tf32 = True - - -# ----------------------------- -# Rank / file helpers -# ----------------------------- - - -def get_rank_info() -> Tuple[bool, int, int, int]: - world_size = int(os.environ.get("WORLD_SIZE", "1")) - rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", "0"))) - local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))) - distributed = world_size > 1 - return distributed, rank, local_rank, world_size - - -def get_visible_device_index(local_rank: int) -> int: - if not torch.cuda.is_available(): - return -1 - ndev = torch.cuda.device_count() - if ndev <= 0: - return -1 - return local_rank % ndev - - -def setup_distributed(): - """ - Do not initialize torch.distributed. - - We only need RANK/LOCAL_RANK/WORLD_SIZE for rank assignment and dataset - sharding. Initializing a process group can cause NeMo ASR to run distributed - collectives during transcribe(), which may hang when ranks have different - audio lengths or workloads. - """ - distributed, rank, local_rank, world_size = get_rank_info() - device_index = get_visible_device_index(local_rank) - - if torch.cuda.is_available() and device_index >= 0: - torch.cuda.set_device(device_index) - - return distributed, rank, local_rank, world_size, device_index - - -def cleanup_distributed(): - return - - -def all_rank_print(rank: int, msg: str): - print(f"[rank={rank}] {msg}", flush=True) - - -def rank0_print(rank: int, msg: str): - if rank == 0: - print(msg, flush=True) - - -def get_audio_out_dir(args) -> str: - return os.path.join(args.out_dir, "audios") - - -def get_generated_turn_audio_dir(args) -> str: - return os.path.join(get_audio_out_dir(args), "metric_turns") - - -def get_context_metric_audio_dir(args) -> str: - return os.path.join(get_audio_out_dir(args), "metric_context") - - -def write_json(path: str, obj: Dict[str, Any]): - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - with open(tmp_path, "w", encoding="utf-8") as f: - json.dump(obj, f, indent=2, sort_keys=True, ensure_ascii=False) - os.replace(tmp_path, path) - - -def write_text_atomic(path: str, text: str): - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - with open(tmp_path, "w", encoding="utf-8") as f: - f.write(text) - os.replace(tmp_path, path) - - -def write_jsonl(path: str, rows: List[Dict[str, Any]]): - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - with open(tmp_path, "w", encoding="utf-8") as f: - for row in rows: - f.write(json.dumps(row, sort_keys=True, ensure_ascii=False) + "\n") - os.replace(tmp_path, path) - - -def wait_for_files(paths: List[str], timeout_sec: float = 7200.0, poll_sec: float = 5.0): - start = time.time() - while True: - missing = [p for p in paths if not os.path.exists(p)] - if not missing: - return - if time.time() - start > timeout_sec: - raise TimeoutError("Timed out waiting for files:\n" + "\n".join(missing)) - time.sleep(poll_sec) - - -def wait_for_rank_metric_files(args, world_size: int): - paths = [os.path.join(args.out_dir, f"metrics_rank{r:04d}.json") for r in range(world_size)] - wait_for_files(paths) - - -def wait_for_rank_filewise_metric_files(args, world_size: int): - paths = [os.path.join(args.out_dir, f"filewise_metrics_rank{r:04d}.jsonl") for r in range(world_size)] - wait_for_files(paths) - - -def scalarize_metric_value(v: Any): - if torch.is_tensor(v): - if v.numel() == 1: - return float(v.detach().cpu().item()) - return v.detach().cpu().tolist() - try: - import numpy as np - - if isinstance(v, np.generic): - return float(v.item()) - except Exception: - pass - if isinstance(v, (int, float, str, bool)) or v is None: - return v - return str(v) - - -def metric_dict_to_jsonable(d: Dict[str, Any]) -> Dict[str, Any]: - return {str(k): scalarize_metric_value(v) for k, v in d.items()} - - -def safe_metric_scalar(metric_dict: Dict[str, Any], preferred_keys: List[str]): - for key in preferred_keys: - if key in metric_dict: - value = metric_dict[key] - if torch.is_tensor(value): - return float(value.detach().cpu().item()) - return float(value) - return None - - -def get_first_metric(metrics: Dict[str, Any], names: List[str], default=None): - for name in names: - if name in metrics: - return metrics[name] - return default - - -def format_final_metric_text(final_metrics: Dict[str, Any]) -> str: - intelligibility = final_metrics.get("intelligibility", {}) - secs = final_metrics.get("secs", {}) - - cer = get_first_metric(intelligibility, ["cer", "cer_dataset"]) - wer = get_first_metric(intelligibility, ["wer", "wer_dataset"]) - secs_value = get_first_metric(secs, ["secs", "secs_dataset"]) - - def fmt(x): - if x is None: - return "nan" - try: - return f"{float(x):.10f}" - except Exception: - return str(x) - - return f"Average CER: {fmt(cer)}\nAverage WER: {fmt(wer)}\nSECS: {fmt(secs_value)}\n" - - -def format_filewise_final_metric_text(filewise_summary: Dict[str, Any]) -> str: - def fmt(x): - if x is None: - return "nan" - try: - return f"{float(x):.10f}" - except Exception: - return str(x) - - return ( - f"Average CER: {fmt(filewise_summary.get('cer'))}\n" - f"Average WER: {fmt(filewise_summary.get('wer'))}\n" - f"SECS: {fmt(filewise_summary.get('secs'))}\n" - ) - - -def write_filewise_csv(path: str, rows: List[Dict[str, Any]]): - """Write sample-level filewise metrics. - - Several fields are lists (turn_ids, reference_text, asr_hyp, cer_turns, - etc.), so they are JSON-encoded inside CSV cells. - """ - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - - fieldnames = [ - "run_id", - "dataset_index", - "rank", - "num_turns", - "cer", - "wer", - "secs", - "turn_ids", - "cer_turns", - "wer_turns", - "secs_turns", - "pred_audio_seconds_turns", - "target_audio_path", - "context_audio_path", - "pred_audio_paths", - "reference_text", - "asr_hyp", - ] - - def csv_value(v): - if isinstance(v, (list, dict)): - return json.dumps(v, ensure_ascii=False) - return v - - with open(tmp_path, "w", encoding="utf-8", newline="") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - for row in rows: - writer.writerow({k: csv_value(row.get(k, None)) for k in fieldnames}) - - os.replace(tmp_path, path) - -# ----------------------------- -# Dataset helpers -# ----------------------------- - - -def _combined_audio_name(first_audio_filepath: str, paths: List[str]) -> str: - base_names = [os.path.splitext(os.path.basename(p))[0] for p in paths if p] - ext = os.path.splitext(paths[-1])[1] if paths and paths[-1] else "" - combined_name = "_".join(base_names) + ext - dir_name = os.path.dirname(first_audio_filepath) - return os.path.join(dir_name, combined_name) if dir_name else combined_name - - -class EvalJSONLDataset(Dataset): - def __init__(self, file_path: str, num_turns: int = 1): - self.samples = [] - raw_samples = [] - - with open(file_path, "r", encoding="utf-8") as f: - for line_idx, line in enumerate(f, 1): - line = line.strip() - if not line: - continue - try: - sample = json.loads(line) - sample["__dataset_index__"] = len(raw_samples) - raw_samples.append(sample) - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON on line {line_idx}: {e}") - - if num_turns <= 1: - self.samples = raw_samples - return - - single_turn_by_speaker = {} - for sample in raw_samples: - if isinstance(sample["text"], list): - self.samples.append(sample) - else: - speaker = sample.get("speaker", "unknown") - single_turn_by_speaker.setdefault(speaker, []).append(sample) - - synthetic_index = len(raw_samples) - for _, speaker_samples in single_turn_by_speaker.items(): - buffer_texts, buffer_paths = [], [] - first_sample_meta = None - - for sample in speaker_samples: - if not buffer_texts: - first_sample_meta = dict(sample) - - buffer_texts.append(sample["text"]) - buffer_paths.append(sample.get("audio_filepath", "")) - - if len(buffer_texts) == num_turns: - first_sample_meta["text"] = buffer_texts - first_sample_meta["audio_filepath"] = _combined_audio_name( - first_sample_meta.get("audio_filepath", ""), - buffer_paths, - ) - first_sample_meta["__dataset_index__"] = synthetic_index - synthetic_index += 1 - self.samples.append(first_sample_meta) - buffer_texts, buffer_paths, first_sample_meta = [], [], None - - if buffer_texts and first_sample_meta is not None: - first_sample_meta["text"] = buffer_texts - first_sample_meta["audio_filepath"] = _combined_audio_name( - first_sample_meta.get("audio_filepath", ""), - buffer_paths, - ) - first_sample_meta["__dataset_index__"] = synthetic_index - synthetic_index += 1 - self.samples.append(first_sample_meta) - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - return self.samples[idx] - - -def _sample_text_segments_for_count(sample: Dict[str, Any], max_eval_turns=None) -> List[str]: - text_data = sample.get("text", "") - if isinstance(text_data, list): - segments = text_data - if max_eval_turns is not None: - segments = segments[: int(max_eval_turns)] - return [str(x) for x in segments] - return [str(text_data)] - - -def estimate_text_token_count(sample: Dict[str, Any], model, max_eval_turns=None) -> int: - main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] - total = 0 - for segment in _sample_text_segments_for_count(sample, max_eval_turns=max_eval_turns): - total += len(model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name)) + 1 - return int(total) - - -class SortedByTextTokenCountDataset(Dataset): - def __init__(self, dataset: Dataset, model, max_eval_turns=None, descending: bool = True): - self.dataset = dataset - scored = [] - for i in range(len(dataset)): - sample = dict(dataset[i]) - token_count = estimate_text_token_count(sample, model=model, max_eval_turns=max_eval_turns) - sample["__text_token_count__"] = int(token_count) - scored.append((token_count, i, sample)) - - scored.sort(key=lambda x: (x[0], -x[1]), reverse=bool(descending)) - self.indices = [i for _, i, _ in scored] - self.token_counts = {i: int(tok) for tok, i, _ in scored} - - def __len__(self): - return len(self.indices) - - def __getitem__(self, local_idx): - original_idx = self.indices[local_idx] - sample = dict(self.dataset[original_idx]) - sample["__text_token_count__"] = self.token_counts[original_idx] - return sample - - -# ----------------------------- -# Audio / collate helpers -# ----------------------------- - - -def _resolve_audio_path(path, root_path): - if path is None: - return None - if root_path is not None and not os.path.isabs(path): - return os.path.join(root_path, path) - return path - - -def _load_audio(path, sample_rate, normalize=True, use_librosa=False): - if path is None or not os.path.exists(path): - return torch.zeros(1, dtype=torch.float32) - - if use_librosa: - wav, sr = librosa.load(path, sr=sample_rate, mono=True) - if normalize: - wav = normalize_volume(wav) - return torch.as_tensor(wav, dtype=torch.float32) - - wav, sr = sf.read(path, dtype="float32") - if wav.ndim > 1: - wav = wav.mean(axis=1) - - if normalize: - wav = normalize_volume(wav) - - wav = torch.as_tensor(wav, dtype=torch.float32).unsqueeze(0) - return resample(wav, sr, sample_rate).squeeze(0) - - -def collate_and_tokenize_custom( - batch, - model, - extra_duration_thrshould=1.3, - sample_rate=22050, - root_path=None, - emulate_duplex_inference=False, - add_interruption_token=False, - pad_factor_text_speech=10, - force_interruption=False, - normalize_audio_volume=True, - use_librosa=False, - profile_multiturn_inference=False, - max_eval_turns=None, -): - main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] - - if max_eval_turns is not None: - max_eval_turns = int(max_eval_turns) - if max_eval_turns <= 0: - raise ValueError("--max_eval_turns must be > 0 when provided.") - - truncated_batch = [] - for s in batch: - s = dict(s) - if isinstance(s["text"], list): - s["text"] = s["text"][:max_eval_turns] - if isinstance(s.get("user_audio_file_path"), list): - s["user_audio_file_path"] = s["user_audio_file_path"][:max_eval_turns] - truncated_batch.append(s) - batch = truncated_batch - - is_profile = profile_multiturn_inference - is_duplex = emulate_duplex_inference and not is_profile - - out_dict = { - "duplex_multiturn": is_duplex, - "regular_multiturn": (not is_duplex) and (not is_profile), - "profile_multiturn": is_profile, - "dataset_indices": [int(s.get("__dataset_index__", -1)) for s in batch], - "text_token_counts": [int(s.get("__text_token_count__", -1)) for s in batch], - } - - tokenized_list = [] - batched_turns = [] - batched_turn_lens = [] - valid_turn_masks = [] - - if is_duplex: - for s in batch: - text_data = s["text"] - - if isinstance(text_data, list): - full_ids = [] - for segment in text_data: - seg_ids = model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name) + [model.eos_id] - pad_ids = [model.pad_id] * (len(seg_ids) * pad_factor_text_speech) - - if force_interruption: - fname = s["audio_filepath"] - no_ext = fname.split(".")[0] - sample_id = int(no_ext.split("_")[-1]) - case = sample_id % 3 - - if case == 0: - if len(seg_ids) >= 2: - seg_ids[-2] = model.interruption_token_id - seg_ids[-1] = model.pad_id - else: - pad_ids[0] = model.interruption_token_id - elif case == 1: - eos_idx = min(6, len(pad_ids) - 1) - pad_ids[eos_idx] = model.interruption_token_id - else: - pad_ids[0] = model.interruption_token_id - - elif add_interruption_token: - eos_idx = int(len(pad_ids) * 0.7) - pad_ids[eos_idx] = model.interruption_token_id - - full_ids.extend(seg_ids) - full_ids.extend(pad_ids) - - tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long)) - else: - tokenized_list.append( - torch.as_tensor( - model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id], - dtype=torch.long, - ) - ) - - prefix = torch.full((25,), model.pad_id, dtype=torch.long) - tokenized_list = [torch.cat([prefix, x]) for x in tokenized_list] - out_dict["input_lengths"] = torch.tensor([len(x) for x in tokenized_list], dtype=torch.long) - out_dict["input_ids"] = pad_sequence(tokenized_list, batch_first=True, padding_value=model.pad_id) - - else: - max_turns = 1 - for s in batch: - if isinstance(s["text"], list): - max_turns = max(max_turns, len(s["text"])) - - for t in range(max_turns): - turn_t_tokens, turn_t_lens, turn_t_valid = [], [], [] - - for s in batch: - text_data = s["text"] - - if isinstance(text_data, list): - if t < len(text_data): - seg_ids = model.tokenizer.encode(text_data[t], tokenizer_name=main_tokenizer_name) + [ - model.eos_id - ] - turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) - turn_t_lens.append(len(seg_ids)) - turn_t_valid.append(True) - else: - turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) - turn_t_lens.append(1) - turn_t_valid.append(False) - else: - if t == 0: - seg_ids = model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [ - model.eos_id - ] - turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) - turn_t_lens.append(len(seg_ids)) - turn_t_valid.append(True) - else: - turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) - turn_t_lens.append(1) - turn_t_valid.append(False) - - batched_turns.append(pad_sequence(turn_t_tokens, batch_first=True, padding_value=model.pad_id)) - batched_turn_lens.append(torch.tensor(turn_t_lens, dtype=torch.long)) - valid_turn_masks.append(torch.tensor(turn_t_valid, dtype=torch.bool)) - - out_dict["batched_turns"] = batched_turns - out_dict["batched_turn_lens"] = batched_turn_lens - out_dict["valid_turn_masks"] = valid_turn_masks - - audio_list, audio_lengths, target_num_frames = [], [], [] - context_audio_paths = [] - max_turns_for_user_audio = len(batched_turns) if not is_duplex else 0 - - if is_profile and max_turns_for_user_audio > 0: - user_audio_by_turn = [[] for _ in range(max_turns_for_user_audio)] - user_audio_lens_by_turn = [[] for _ in range(max_turns_for_user_audio)] - else: - user_audio_by_turn, user_audio_lens_by_turn = [], [] - - for i, s in enumerate(batch): - audio_path = _resolve_audio_path(s.get("context_audio_filepath"), root_path) - context_audio_paths.append(audio_path) - wav = _load_audio(audio_path, sample_rate, normalize=normalize_audio_volume, use_librosa=use_librosa) - audio_list.append(wav) - audio_lengths.append(len(wav)) - - if is_profile and max_turns_for_user_audio > 0: - user_audio_paths = s.get("user_audio_file_path", None) - - for t in range(max_turns_for_user_audio): - has_valid_text_turn = (isinstance(s["text"], list) and t < len(s["text"])) or ( - not isinstance(s["text"], list) and t == 0 - ) - - if ( - isinstance(user_audio_paths, list) - and t < len(user_audio_paths) - and user_audio_paths[t] - and has_valid_text_turn - ): - ua_path = _resolve_audio_path(user_audio_paths[t], root_path) - ua_wav = _load_audio( - ua_path, - sample_rate=sample_rate, - normalize=normalize_audio_volume, - use_librosa=use_librosa, - ) - else: - ua_wav = torch.zeros(int(2 * sample_rate), dtype=torch.float32) - - user_audio_by_turn[t].append(ua_wav) - user_audio_lens_by_turn[t].append(len(ua_wav)) - - tdur_audio_path = _resolve_audio_path(s["audio_filepath"], root_path) - - if tdur_audio_path and os.path.exists(tdur_audio_path): - wav_dur = _load_audio( - tdur_audio_path, - sample_rate, - normalize=normalize_audio_volume, - use_librosa=use_librosa, - ) - tdur = wav_dur.shape[0] // model.input_samples_per_frame - target_num_frames.append(tdur * extra_duration_thrshould) - else: - if is_duplex: - current_text_len = len(tokenized_list[i]) - target_num_frames.append(current_text_len if isinstance(s["text"], list) else current_text_len * 5) - else: - target_num_frames.append(sum([l[i].item() for l in batched_turn_lens]) * 5) - - max_audio_len = max(audio_lengths) - B = len(audio_lengths) - padded_audio = torch.zeros((B, max_audio_len), dtype=torch.float32) - - for i, wav in enumerate(audio_list): - padded_audio[i, : len(wav)] = wav - - if is_profile and max_turns_for_user_audio > 0: - padded_user_audio_turns, padded_user_audio_turns_lens = [], [] - - for t in range(max_turns_for_user_audio): - turn_lens = user_audio_lens_by_turn[t] - max_turn_audio_len = max(turn_lens) - padded_turn_audio = torch.zeros((B, max_turn_audio_len), dtype=torch.float32) - - for i, wav in enumerate(user_audio_by_turn[t]): - padded_turn_audio[i, : len(wav)] = wav - - padded_user_audio_turns.append(padded_turn_audio) - padded_user_audio_turns_lens.append(torch.tensor(turn_lens, dtype=torch.long)) - - out_dict["user_audio_turns"] = padded_user_audio_turns - out_dict["user_audio_turns_lens"] = padded_user_audio_turns_lens - - raw_turn_texts = [] - for s in batch: - if isinstance(s["text"], list): - raw_turn_texts.append([str(x) for x in s["text"]]) - else: - raw_turn_texts.append([str(s["text"])]) - - out_dict["context_audio"] = padded_audio - out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) - out_dict["context_audio_paths"] = context_audio_paths - out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] - out_dict["target_num_frames"] = target_num_frames - out_dict["raw_turn_texts"] = raw_turn_texts - out_dict["raw_text"] = [" ".join(x) for x in raw_turn_texts] - - return out_dict - - -# ----------------------------- -# Model / generation -# ----------------------------- - - -def attach_dtype_counter(model): - handles = [] - stats = {} - examples = {} - - def is_leaf(module): - return len(list(module.children())) == 0 - - def get_dtype(x): - if torch.is_tensor(x): - return str(x.dtype) - if isinstance(x, (list, tuple)): - for t in x: - if torch.is_tensor(t): - return str(t.dtype) - return "other" - - def get_module_group(name): - return name.split(".")[0] if "." in name else name - - def hook_fn(name): - def fn(module, inputs, outputs): - dtype = get_dtype(outputs) - if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: - dtype = "other" - group = get_module_group(name) - if group not in stats: - stats[group] = { - "torch.float16": 0, - "torch.bfloat16": 0, - "torch.float32": 0, - "other": 0, - } - examples[group] = { - "torch.float16": [], - "torch.bfloat16": [], - "torch.float32": [], - "other": [], - } - stats[group][dtype] += 1 - if len(examples[group][dtype]) < 3: - examples[group][dtype].append(module.__class__.__name__) - - return fn - - for name, module in model.named_modules(): - if is_leaf(module): - handles.append(module.register_forward_hook(hook_fn(name))) - return handles, stats, examples - - -def report_dtype_stats(handles, stats, examples, rank=0): - for h in handles: - h.remove() - logging.info(f"[rank={rank}] === DTYPE USAGE PER MODULE ===") - for group, group_stats in stats.items(): - total = sum(group_stats.values()) - if total == 0: - continue - logging.info(f"[rank={rank}] --- {group} ---") - for dtype, count in group_stats.items(): - if count > 0: - logging.info(f"[rank={rank}] {dtype}: {count} ({100 * count / total:.2f}%)") - logging.info(f"[rank={rank}] === DTYPE EXAMPLES ===") - for group, group_examples in examples.items(): - for dtype, mods in group_examples.items(): - if mods: - logging.info(f"[rank={rank}] {group} {dtype}: {mods}") - - -def build_model_and_codec(args, target_device, target_dtype): - model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) - - with open_dict(model_cfg): - model_cfg.target = "nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel" - model_cfg.codecmodel_path = args.codec_model_path - model_cfg.train_ds = None - model_cfg.validation_ds = None - model_cfg.run_val_inference = False - model_cfg.use_utmos = False - model_cfg.use_meta_init_for_decoder = True - - if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: - model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path - - model = EasyMagpieTTSInferenceModel.restore_from( - args.checkpoint_path, - override_config_path=model_cfg, - map_location=torch.device("cpu"), - ) - model.use_kv_cache_for_inference = True - model.to(dtype=target_dtype) - model.eval().to(target_device) - - model.input_samples_per_frame = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) - model.target_samples_per_frame = model.input_samples_per_frame / (model.sample_rate / model.output_sample_rate) - - codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=torch.device("cpu")) - if hasattr(codec_model, "discriminator"): - del codec_model.discriminator - codec_model.freeze() - codec_model = codec_model.to(target_device).eval() - - codec_converter = None - if getattr(model, "_codec_converter", None) is not None: - vq_new = deepcopy(model._codec_converter.vector_quantizer_new).to(target_device).eval() - codec_converter = VectorQuantizerIndexConverter( - vector_quantizer_original=codec_model.vector_quantizer, - vector_quantizer_new=vq_new, - ).to(target_device).eval() - - model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) - model._generate_codec_silence_buffer() - - return model - - -def prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=None): - B = inputs["context_audio"].size(0) - device = model.device - - inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) - inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) - - if args.user_custom_speaker_reference and speaker_wav is not None: - inputs["context_audio"] = speaker_wav.repeat(B, 1).detach() - inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) - - if "user_audio_turns" in inputs: - inputs["user_audio_turns"] = [x.to(device, dtype=target_dtype) for x in inputs["user_audio_turns"]] - inputs["user_audio_turns_lens"] = [x.to(device) for x in inputs["user_audio_turns_lens"]] - - return inputs - - -def run_generation(model, inputs, args, codec_sil_codes): - B = inputs["context_audio"].size(0) - device = model.device - profile_turn_frame_ranges = [] - profile_decode_start_frame = 0 - - with torch.inference_mode(): - wav = inputs["context_audio"] - wav_len = inputs["context_audio_lengths"] - codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) - - use_lang = bool(getattr(model, "add_language_to_context_text", False)) - ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" - ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) - - ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) - ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) - - state = model.streaming_init( - context_audio_codes=codes, - context_audio_codes_lens=codes_lens, - context_text_tokens=ctx_toks, - context_text_tokens_lens=ctx_toks_lens, - use_cfg=args.use_cfg, - cfg_scale=args.cfg_scale, - use_local_transformer=True, - temperature=args.temperature, - topk=args.topk, - phoneme_input_type="pred", - phoneme_sampling_method="argmax", - use_inference_mode=True, - ) - - if inputs["duplex_multiturn"]: - text = inputs["input_ids"].to(device) - text_lens = inputs["input_lengths"].to(device) - - in_initial_silence = torch.ones(B, dtype=torch.bool, device=device) - in_post_speech_silence = torch.zeros(B, dtype=torch.bool, device=device) - text_exhausted = state.text_tokens_seen >= text_lens - - while not text_exhausted.all(): - state.finished = state.finished & text_exhausted - state.text_finished = state.text_finished & text_exhausted - - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended = state.phoneme_stream_ended & text_exhausted - - positions = state.text_tokens_seen.clamp(max=text.size(1) - 1) - current_tokens = text[torch.arange(B, device=device), positions] - current_tokens = torch.where( - text_exhausted, - torch.full_like(current_tokens, model.eos_id), - current_tokens, - ) - - is_pad_or_eos = (current_tokens == model.pad_id) | (current_tokens == model.eos_id) - in_initial_silence = in_initial_silence & is_pad_or_eos - in_post_speech_silence = in_post_speech_silence & is_pad_or_eos - - state, audio_codes, _ = model.streaming_step( - state=state, - text_tokens=current_tokens, - use_inference_mode=True, - ) - - if audio_codes is not None and args.force_speech_sil_codes: - force_silence_mask = in_initial_silence | in_post_speech_silence - if force_silence_mask.any(): - expanded_sil = codec_sil_codes.view(1, -1, 1).expand_as(audio_codes) - mask_3d = force_silence_mask.view(B, 1, 1) - state.all_predictions[-1] = torch.where(mask_3d, expanded_sil, audio_codes) - - in_post_speech_silence = in_post_speech_silence | state.finished - text_exhausted = state.text_tokens_seen >= text_lens - - elif inputs["regular_multiturn"]: - batched_turns = inputs["batched_turns"] - batched_turn_lens = inputs["batched_turn_lens"] - valid_turn_masks = inputs["valid_turn_masks"] - turn_offsets = torch.zeros(B, dtype=torch.long, device=device) - - for t in range(len(batched_turns)): - turn_text = batched_turns[t].to(device) - turn_lens = batched_turn_lens[t].to(device) - valid_mask = valid_turn_masks[t].to(device) - - state.finished = state.finished & (~valid_mask) - state.text_finished = state.text_finished & (~valid_mask) - - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended = state.phoneme_stream_ended & (~valid_mask) - - if state.finished.all(): - continue - - turn_offsets = torch.where(valid_mask, state.text_tokens_seen, turn_offsets) - turn_steps = 0 - - while not state.finished.all() and turn_steps < args.max_tts_steps: - turn_steps += 1 - relative_positions = state.text_tokens_seen - turn_offsets - positions = relative_positions.clamp(min=0, max=turn_text.size(1) - 1) - current_tokens = turn_text[torch.arange(B, device=device), positions] - - exhausted = relative_positions >= turn_lens - current_tokens = torch.where( - exhausted, - torch.full_like(current_tokens, model.eos_id), - current_tokens, - ) - - state, _, _ = model.streaming_step( - state=state, - text_tokens=current_tokens, - use_inference_mode=True, - ) - - elif inputs["profile_multiturn"]: - if B != 1: - raise RuntimeError("--profile_multiturn_inference requires --batch_size=1 per process.") - - batched_turns = inputs["batched_turns"] - batched_turn_lens = inputs["batched_turn_lens"] - valid_turn_masks = inputs["valid_turn_masks"] - - for t in range(len(batched_turns)): - turn_text = batched_turns[t].to(device) - turn_lens = batched_turn_lens[t].to(device) - valid_mask = valid_turn_masks[t].to(device) - - if not bool(valid_mask[0].item()): - continue - - state.finished.zero_() - state.text_finished.zero_() - state.audio_prediction_end_idx.fill_(-1) - - if hasattr(state, "turn_text_tokens_seen"): - state.turn_text_tokens_seen.zero_() - if hasattr(state, "phoneme_steps"): - state.phoneme_steps.zero_() - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended.zero_() - if hasattr(state, "phoneme_eos_detected"): - state.phoneme_eos_detected.zero_() - state.last_phoneme_tokens = None - - if not model.cfg.get("condition_on_user_speech", False): - if "user_audio_turns" in inputs: - profile_T = int(round(inputs["user_audio_turns"][t].size(-1) / model.input_samples_per_frame)) - profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate - else: - profile_seconds = args.profile_pad_min_sec + torch.rand((), device=device).item() * ( - args.profile_pad_max_sec - args.profile_pad_min_sec - ) - profile_T = max( - 1, - int(round(profile_seconds * model.sample_rate / model.input_samples_per_frame)), - ) - - profile_tokens = torch.full((1, profile_T), model.pad_id, dtype=torch.long, device=device) - user_audio_channel_embedding = None - - else: - if "user_audio_turns" in inputs: - user_audio = inputs["user_audio_turns"][t] - user_audio_lens = inputs["user_audio_turns_lens"][t] - else: - user_audio = inputs["context_audio"] - user_audio_lens = inputs["context_audio_lengths"] - - user_audio_codes, user_audio_codes_lens = model._codec_helper.audio_to_codes( - user_audio, - user_audio_lens, - ) - - if model._codec_converter is not None: - user_audio_codes = model._codec_converter.convert_original_to_new( - audio_tokens=user_audio_codes, - audio_lens=user_audio_codes_lens, - ).long() - - user_audio_codes, user_audio_codes_lens = model.stack_codes( - user_audio_codes, - user_audio_codes_lens, - model.audio_bos_id, - model.audio_eos_id, - model.frame_stacking_factor, - model.num_audio_codebooks, - ) - - user_audio_embedded = model.embed_audio_tokens(user_audio_codes) - - boundary_trim = model.cfg.get("user_audio_boundary_trim", 4) - boundary_trim = 0 if boundary_trim is None else int(boundary_trim) - - if boundary_trim == 0: - real_start = 0 - real_end = int(user_audio_codes_lens[0].item()) - else: - turn_len_with_special = int(user_audio_codes_lens[0].item()) - real_start = 1 - real_end = max(real_start, turn_len_with_special - 1) - - user_audio_embedded = user_audio_embedded[:, real_start:real_end] - copy_len = user_audio_embedded.size(1) - - if boundary_trim > 0: - trim = min(boundary_trim, copy_len // 2) - if trim > 0: - user_audio_embedded[:, :trim] = 0.0 - user_audio_embedded[:, copy_len - trim :] = 0.0 - - bos_user_pad = torch.zeros( - user_audio_embedded.size(0), - 1, - user_audio_embedded.size(2), - device=user_audio_embedded.device, - dtype=user_audio_embedded.dtype, - ) - user_audio_embedded = torch.cat([bos_user_pad, user_audio_embedded], dim=1) - - profile_T = user_audio_embedded.size(1) - profile_tokens = torch.full((B, profile_T), model.pad_id, dtype=torch.long, device=device) - user_audio_channel_embedding = user_audio_embedded - profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate - - delay_tokens = int(state.config.training_mode.streaming_speech_delay) - delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) - - warmup_tokens = turn_text[:, :delay_tokens] - turn_text = turn_text[:, delay_tokens:] - turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) - - if user_audio_channel_embedding is not None and delay_tokens > 0: - warmup_user_audio = user_audio_channel_embedding[:, -delay_tokens:] - user_audio_channel_embedding = user_audio_channel_embedding[:, :-delay_tokens] - profile_tokens = profile_tokens[:, :-delay_tokens] - else: - warmup_user_audio = None - - if profile_tokens.size(1) > 0: - state = model.streaming_prefill_profile( - state=state, - text_tokens=profile_tokens, - use_inference_mode=True, - user_audio_channel_embedding=user_audio_channel_embedding, - ) - - for i in range(delay_tokens): - user_step_emb = warmup_user_audio[:, i] if warmup_user_audio is not None else None - - state.finished.zero_() - state, _, _ = model.streaming_step( - state=state, - text_tokens=warmup_tokens[:, i], - user_audio_channel_embedding=user_step_emb, - prefill_like_step=not bool(model.cfg.get("agent_mask_include_transition_prefix", False)), - prefill_like_is_last_step=(i == delay_tokens - 1), - use_inference_mode=True, - ) - - logging.info(f"[profile_multiturn] turn={t} prefilled {profile_T} steps ({profile_seconds:.2f}s)") - - turn_start_frame = sum(p.size(-1) for p in state.all_predictions) - if t == 0: - state.audio_prediction_start_idx.fill_(turn_start_frame) - profile_decode_start_frame = turn_start_frame - - turn_offset = state.text_tokens_seen.clone() - turn_steps = 0 - saw_audio = False - turn_ended_with_audio_eos = False - - while turn_steps < args.max_tts_steps: - turn_steps += 1 - state.finished.zero_() - - relative_position = state.text_tokens_seen - turn_offset - text_exhausted = relative_position >= turn_lens - - if turn_text.size(1) == 0: - current_tokens = torch.full((B,), model.eos_id, dtype=torch.long, device=device) - else: - position = relative_position.clamp(min=0, max=turn_text.size(1) - 1) - current_tokens = turn_text[torch.arange(B, device=device), position] - current_tokens = torch.where( - text_exhausted, - torch.full_like(current_tokens, model.eos_id), - current_tokens, - ) - - state, audio_codes, _ = model.streaming_step( - state=state, - text_tokens=current_tokens, - use_inference_mode=True, - ) - - if audio_codes is not None and not saw_audio: - saw_audio = True - - if bool(text_exhausted[0].item()) and bool(state.finished[0].item()): - turn_ended_with_audio_eos = True - break - - state.audio_prediction_end_idx.fill_(-1) - state.finished.zero_() - - turn_end_frame = sum(p.size(-1) for p in state.all_predictions) - profile_turn_frame_ranges.append((t, turn_start_frame, turn_end_frame)) - - logging.info( - f"[profile_multiturn] turn={t} steps={turn_steps} " - f"saw_audio={saw_audio} ended_with_audio_eos={turn_ended_with_audio_eos}" - ) - - bos_id = getattr(model, "audio_bos_id", -1) - eos_id = getattr(model, "audio_eos_id", -1) - speaking_id = getattr(model, "audio_user_speaking_id", -1) - speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) - sil_injection = codec_sil_codes.view(1, -1, 1) - - for step_idx in range(len(state.all_predictions)): - pred = state.all_predictions[step_idx] - mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) - frame_mask = mask.any(dim=1, keepdim=True) - if frame_mask.any(): - state.all_predictions[step_idx] = torch.where(frame_mask, sil_injection.expand_as(pred), pred) - - if inputs["duplex_multiturn"] or inputs["profile_multiturn"]: - state.audio_prediction_end_idx.fill_(-1) - - finalize_output = model.streaming_finalize(state, use_inference_mode=True) - - return finalize_output, profile_turn_frame_ranges, profile_decode_start_frame - - -def load_speaker_wav_if_needed(args, model, target_dtype): - if args.user_custom_speaker_reference and args.inference_speaker_reference: - return _load_audio( - args.inference_speaker_reference, - model.sample_rate, - normalize=args.normalize_volume, - use_librosa=args.use_librosa, - ).unsqueeze(0).to(model.device, dtype=target_dtype) - - return None - - -# ----------------------------- -# Save generation outputs and metric manifests -# ----------------------------- - - -def write_audio_1d(path: str, wav: torch.Tensor, sr: int): - os.makedirs(os.path.dirname(path), exist_ok=True) - wav_np = wav.detach().cpu().float().numpy() - sf.write(path, wav_np, samplerate=sr) - - -def build_metric_item( - run_id: int, - rank: int, - dataset_index: int, - turn_id: int, - target_audio_path: str, - reference_text: str, - pred_audio_path: str, - context_audio_path: str, - pred_audio_samples: int, - context_audio_samples: int, - output_sample_rate: int, - context_sample_rate: int, -): - return { - "run_id": int(run_id), - "rank": int(rank), - "dataset_index": int(dataset_index), - "turn_id": int(turn_id), - "target_audio_path": target_audio_path, - "reference_text": reference_text, - "pred_audio_path": pred_audio_path, - "context_audio_path": context_audio_path, - "pred_audio_samples": int(pred_audio_samples), - "context_audio_samples": int(context_audio_samples), - "pred_audio_seconds": float(pred_audio_samples / output_sample_rate), - "context_audio_seconds": float(context_audio_samples / context_sample_rate), - "output_sample_rate": int(output_sample_rate), - "context_sample_rate": int(context_sample_rate), - } - - -def save_generation_outputs_and_build_metric_items( - model, - inputs, - finalize_output, - profile_turn_frame_ranges, - profile_decode_start_frame, - args, - rank: int, - run_id: int, -): - device = model.device - B = inputs["context_audio"].size(0) - - with fp32_precision(): - audio_f32 = finalize_output.audio.float() - audio_len = finalize_output.audio_len.int() - - expected_audio_lens = ( - torch.tensor(inputs["target_num_frames"], device=device) * model.target_samples_per_frame - ).int() - - if inputs["duplex_multiturn"]: - text_lens = inputs["input_lengths"].to(device) - audio_len = (text_lens * model.target_samples_per_frame).int() - audio_len = torch.min(audio_len, torch.tensor(audio_f32.size(1), device=device)) - elif inputs["profile_multiturn"]: - audio_len = finalize_output.audio_len.int() - else: - audio_len = torch.min(audio_len, expected_audio_lens) - - audio_out_dir = get_audio_out_dir(args) - metric_turn_dir = get_generated_turn_audio_dir(args) - metric_context_dir = get_context_metric_audio_dir(args) - os.makedirs(audio_out_dir, exist_ok=True) - os.makedirs(metric_turn_dir, exist_ok=True) - os.makedirs(metric_context_dir, exist_ok=True) - - audio_f32_cpu = audio_f32.detach().cpu() - audio_len_cpu = audio_len.detach().cpu() - metric_items = [] - - for i in range(B): - target_path = inputs["target_audio_paths"][i] - base_name = os.path.basename(target_path) - stem, ext = os.path.splitext(base_name) - if not ext: - ext = ".wav" - - dataset_idx = int(inputs.get("dataset_indices", [-1] * B)[i]) - safe_stem = ( - f"run{run_id:02d}_idx{dataset_idx:08d}_{stem}" - if dataset_idx >= 0 - else f"run{run_id:02d}_rank{rank}_{stem}" - ) - - context_len = int(inputs["context_audio_lengths"][i].detach().cpu().item()) - context_wav = inputs["context_audio"][i, :context_len].detach().cpu().float() - context_metric_path = os.path.join(metric_context_dir, f"{safe_stem}_context.wav") - write_audio_1d(context_metric_path, context_wav, model.sample_rate) - - if inputs["profile_multiturn"]: - full_len = int(audio_len_cpu[i].item()) - full_wav_t = audio_f32_cpu[i, :full_len].float() - - samples_per_prediction_frame = model.codec_model_samples_per_frame / ( - model.sample_rate / model.output_sample_rate - ) - - aligned_agent = torch.zeros_like(full_wav_t) - raw_turn_texts = inputs.get("raw_turn_texts", [[] for _ in range(B)]) - - for turn_id, start_frame, end_frame in profile_turn_frame_ranges: - rel_start_frame = start_frame - profile_decode_start_frame - rel_end_frame = end_frame - profile_decode_start_frame - - start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) - end_sample = int(round(rel_end_frame * samples_per_prediction_frame)) - - start_sample = max(0, min(start_sample, full_len)) - end_sample = max(start_sample, min(end_sample, full_len)) - - aligned_agent[start_sample:end_sample] = full_wav_t[start_sample:end_sample] - - turn_wav = aligned_agent[start_sample:end_sample].float() - turn_out_path = os.path.join(audio_out_dir, f"{safe_stem}_turn_{turn_id}{ext}") - write_audio_1d(turn_out_path, turn_wav, model.output_sample_rate) - - metric_turn_path = os.path.join(metric_turn_dir, f"{safe_stem}_turn_{turn_id}.wav") - write_audio_1d(metric_turn_path, turn_wav, model.output_sample_rate) - - if turn_id < len(raw_turn_texts[i]): - metric_items.append( - build_metric_item( - run_id=run_id, - rank=rank, - dataset_index=dataset_idx, - turn_id=turn_id, - target_audio_path=target_path, - reference_text=str(raw_turn_texts[i][turn_id]), - pred_audio_path=metric_turn_path, - context_audio_path=context_metric_path, - pred_audio_samples=int(turn_wav.numel()), - context_audio_samples=int(context_wav.numel()), - output_sample_rate=model.output_sample_rate, - context_sample_rate=model.sample_rate, - ) - ) - - full_out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") - write_audio_1d(full_out_path, aligned_agent, model.output_sample_rate) - - if "user_audio_turns" in inputs: - user_segments = [] - - first_user_len_in = int(inputs["user_audio_turns_lens"][0][i].item()) - first_user_delay_out = int(round(first_user_len_in * model.output_sample_rate / model.sample_rate)) - - for turn_id, start_frame, _ in profile_turn_frame_ranges: - if turn_id >= len(inputs["user_audio_turns"]): - continue - - turn_audio = inputs["user_audio_turns"][turn_id][i].detach().cpu().float() - turn_audio_len = int(inputs["user_audio_turns_lens"][turn_id][i].item()) - turn_audio = turn_audio[:turn_audio_len] - - turn_audio_out = resample( - turn_audio.unsqueeze(0), - model.sample_rate, - model.output_sample_rate, - ).squeeze(0) - - if turn_id == 0: - user_start_sample = 0 - else: - prev_turn_end_frame = profile_turn_frame_ranges[turn_id - 1][2] - rel_prev_end_frame = prev_turn_end_frame - profile_decode_start_frame - user_start_sample = first_user_delay_out + int( - round(rel_prev_end_frame * samples_per_prediction_frame) - ) - - user_segments.append((user_start_sample, turn_audio_out.detach().cpu().float())) - - total_user_len = 0 - for s, wav_seg in user_segments: - total_user_len = max(total_user_len, s + wav_seg.numel()) - - user_ch = torch.zeros(total_user_len) - for s, wav_seg in user_segments: - e = s + wav_seg.numel() - user_ch[s:e] += wav_seg - - agent_ch = torch.cat([torch.zeros(first_user_delay_out, dtype=aligned_agent.dtype), aligned_agent]) - - stereo_len = max(user_ch.numel(), agent_ch.numel()) - user_pad = torch.zeros(stereo_len) - agent_pad = torch.zeros(stereo_len) - - user_pad[: user_ch.numel()] = user_ch - agent_pad[: agent_ch.numel()] = agent_ch - - stereo = torch.stack([user_pad, agent_pad], dim=1).numpy() - aligned_path = os.path.join(audio_out_dir, f"{safe_stem}_user_agent_aligned{ext}") - sf.write(aligned_path, stereo, samplerate=model.output_sample_rate) - - else: - full_len = int(audio_len_cpu[i].item()) - wav = audio_f32_cpu[i, :full_len].float() - out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") - write_audio_1d(out_path, wav, model.output_sample_rate) - - metric_turn_path = os.path.join(metric_turn_dir, f"{safe_stem}_turn_0.wav") - write_audio_1d(metric_turn_path, wav, model.output_sample_rate) - - metric_items.append( - build_metric_item( - run_id=run_id, - rank=rank, - dataset_index=dataset_idx, - turn_id=0, - target_audio_path=target_path, - reference_text=str(inputs["raw_text"][i]), - pred_audio_path=metric_turn_path, - context_audio_path=context_metric_path, - pred_audio_samples=int(wav.numel()), - context_audio_samples=int(context_wav.numel()), - output_sample_rate=model.output_sample_rate, - context_sample_rate=model.sample_rate, - ) - ) - - return metric_items - - -# ----------------------------- -# Metrics after generation -# ----------------------------- - - -def torch_rms_norm(wav: torch.Tensor, db_level: float = -27.0) -> torch.Tensor: - denom = torch.sum(wav**2) - if denom <= 0: - return wav - r = 10 ** (db_level / 20) - a = torch.sqrt((wav.size(-1) * (r**2)) / denom) - return wav * a - - -def _load_audio_for_metric(path: str, sample_rate: int): - wav = _load_audio(path, sample_rate=sample_rate, normalize=False, use_librosa=False) - if wav.numel() == 0: - wav = torch.zeros(1, dtype=torch.float32) - return wav.float() - - -def _pad_audio_1d_list(wavs: List[torch.Tensor], device, dtype=torch.float32): - if len(wavs) == 0: - return torch.zeros((0, 1), device=device, dtype=dtype), torch.zeros((0,), device=device, dtype=torch.long) - - lens = torch.tensor([max(1, int(w.numel())) for w in wavs], device=device, dtype=torch.long) - max_len = int(lens.max().item()) - out = torch.zeros((len(wavs), max_len), device=device, dtype=dtype) - - for i, w in enumerate(wavs): - w = w.to(device=device, dtype=dtype).flatten() - if w.numel() == 0: - continue - out[i, : w.numel()] = w - - return out, lens - - -def chunk_list(xs: List[Any], chunk_size: int) -> Iterable[List[Any]]: - chunk_size = max(1, int(chunk_size)) - for start in range(0, len(xs), chunk_size): - yield xs[start : start + chunk_size] - - -def _metric_device(): - return "cuda" if torch.cuda.is_available() else "cpu" - - -def _load_metric_batch_audio(batch_items: List[Dict[str, Any]], args): - pred_wavs = [] - context_wavs = [] - - for item in batch_items: - pred = _load_audio_for_metric(item["pred_audio_path"], sample_rate=int(item["output_sample_rate"])) - context = _load_audio_for_metric(item["context_audio_path"], sample_rate=int(item["context_sample_rate"])) - - if args.max_metric_audio_sec is not None: - max_pred_len = int(float(args.max_metric_audio_sec) * int(item["output_sample_rate"])) - pred = pred[: max(1, max_pred_len)] - - pred_wavs.append(pred) - context_wavs.append(context) - - device = _metric_device() - pred_audio, pred_lens = _pad_audio_1d_list(pred_wavs, device=device) - context_audio, context_lens = _pad_audio_1d_list(context_wavs, device=device) - output_sample_rate = int(batch_items[0]["output_sample_rate"]) - context_sample_rate = int(batch_items[0]["context_sample_rate"]) - - return pred_audio, pred_lens, context_audio, context_lens, output_sample_rate, context_sample_rate - - -def compute_metrics_after_generation(args, rank: int, world_size: int, metric_items: List[Dict[str, Any]]): - """ - Load metric models only after generation is complete. - - Order: - 1. Load ASR, compute turn-level CER/WER and ASR hyps, then free ASR. - 2. Load SECS speaker encoder and compute turn-level SECS. - 3. Save rank-level aggregate metrics from the same turn-level rows. - - SECS is always computed turn-by-turn, like CER/WER. The grouped filewise - output stores secs_turns and sample-level secs, and metrics_final.* receives - the turn-level aggregate SECS. - """ - metric_start = time.time() - - if len(metric_items) == 0: - return { - "rank": int(rank), - "world_size": int(world_size), - "num_processed": 0, - "num_metric_items": 0, - "metric_elapsed_sec": 0.0, - "intelligibility": {}, - "secs": {}, - }, [] - - normalizer = EnglishTextNormalizer() - normalizer.ignore_patterns = r"$^" - filewise_rows = [] - - # ASR pass. - all_rank_print(rank, f"loading ASR after generation: {args.asr_model_name}") - with fp32_precision(): - intelligibility = Intelligibility(args.asr_model_name, reuse_asr_hyps=False).reset() - - for batch_items in chunk_list(metric_items, args.metric_batch_size): - refs = [x["reference_text"] for x in batch_items] - pred_audio, pred_lens, _, _, output_sr, _ = _load_metric_batch_audio(batch_items, args) - - with fp32_precision(): - pred_16k = resample(pred_audio, output_sr, 16000) - pred_16k_lens = (pred_lens / output_sr * 16000).to(torch.long) - pred_16k = torch_rms_norm(pred_16k) - - asr_hyps = intelligibility.update( - name="dataset", - refs=refs, - pred_audio=pred_16k, - pred_audio_lens=pred_16k_lens, - asr_hyps=None, - ) - - for item, hyp in zip(batch_items, asr_hyps): - ref_norm = normalizer(str(item["reference_text"])).strip() - hyp_norm = normalizer(str(hyp)).strip() - if ref_norm == "": - cer = None - wer = None - else: - cer = float(word_error_rate([hyp_norm], [ref_norm], use_cer=True)) - wer = float(word_error_rate([hyp_norm], [ref_norm], use_cer=False)) - - row = dict(item) - row["asr_hyp"] = hyp - row["cer"] = cer - row["wer"] = wer - row["secs"] = None - filewise_rows.append(row) - - with fp32_precision(): - cer_wer = metric_dict_to_jsonable(intelligibility.compute()) - del intelligibility - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # SECS pass. This is intentionally turn-level, matching CER/WER. - # We keep one aggregate SECS metric for metrics_final.* and also compute - # one SECS value per filewise turn row so grouped outputs have secs_turns. - all_rank_print(rank, f"loading speaker encoder after ASR is released: {args.secs_model_name}") - with fp32_precision(): - secs_metric = SECS(args.secs_model_name).reset() - - # Aggregate turn-level SECS for metrics_final.json / metrics_final.txt. - for batch_items in chunk_list(metric_items, args.metric_batch_size): - pred_audio, pred_lens, context_audio, context_lens, output_sr, context_sr = _load_metric_batch_audio( - batch_items, args - ) - - with fp32_precision(): - pred_16k = resample(pred_audio, output_sr, 16000) - pred_16k_lens = (pred_lens / output_sr * 16000).to(torch.long) - context_16k = resample(context_audio, context_sr, 16000) - context_16k_lens = (context_lens / context_sr * 16000).to(torch.long) - - pred_16k = torch_rms_norm(pred_16k) - context_16k = torch_rms_norm(context_16k) - - secs_metric.update( - name="dataset", - target_audio=context_16k, - target_audio_lens=context_16k_lens, - pred_audio=pred_16k, - pred_audio_lens=pred_16k_lens, - ) - - with fp32_precision(): - secs_scores = metric_dict_to_jsonable(secs_metric.compute()) - del secs_metric - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # Per-turn SECS for filewise/grouped outputs. This is always computed so - # secs_turns and sample-level secs are never null in final filewise metrics. - # It is slower than aggregate-only SECS, but it matches the turn-level - # semantics requested for CER/WER/SECS. - all_rank_print(rank, "computing per-turn SECS rows") - for row in filewise_rows: - pred_audio, pred_lens, context_audio, context_lens, output_sr, context_sr = _load_metric_batch_audio([row], args) - - with fp32_precision(): - one_secs = SECS(args.secs_model_name).reset() - pred_16k = resample(pred_audio, output_sr, 16000) - pred_16k_lens = (pred_lens / output_sr * 16000).to(torch.long) - context_16k = resample(context_audio, context_sr, 16000) - context_16k_lens = (context_lens / context_sr * 16000).to(torch.long) - - pred_16k = torch_rms_norm(pred_16k) - context_16k = torch_rms_norm(context_16k) - - one_secs.update( - name="dataset", - target_audio=context_16k, - target_audio_lens=context_16k_lens, - pred_audio=pred_16k, - pred_audio_lens=pred_16k_lens, - ) - one_secs_metrics = metric_dict_to_jsonable(one_secs.compute()) - - row["secs"] = safe_metric_scalar(one_secs_metrics, ["secs", "secs_dataset"]) - row["secs_metrics"] = one_secs_metrics - del one_secs - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - metric_elapsed = time.time() - metric_start - - rank_metrics = { - "rank": int(rank), - "world_size": int(world_size), - "num_processed": len({(x["run_id"], x["dataset_index"]) for x in metric_items}), - "num_metric_items": int(len(metric_items)), - "metric_elapsed_sec": float(metric_elapsed), - "intelligibility": cer_wer, - "secs": secs_scores, - } - - return rank_metrics, filewise_rows - - -# ----------------------------- -# Merge helpers -# ----------------------------- - - -def compute_and_save_rank_metrics_file(args, rank_metrics: Dict[str, Any], rank: int): - rank_path = os.path.join(args.out_dir, f"metrics_rank{rank:04d}.json") - write_json(rank_path, rank_metrics) - return rank_metrics - - -def merge_metrics_on_rank0(args, rank, world_size): - if rank != 0: - return None - - rank_metric_files = [os.path.join(args.out_dir, f"metrics_rank{r:04d}.json") for r in range(world_size)] - - rank_metrics = [] - for path in rank_metric_files: - if not os.path.exists(path): - logging.warning(f"Missing rank metric file: {path}") - continue - with open(path, "r", encoding="utf-8") as f: - rank_metrics.append(json.load(f)) - - total_n = sum(int(m.get("num_metric_items", m.get("num_processed", 0))) for m in rank_metrics) - - def weighted_average(section: str): - keys = set() - for m in rank_metrics: - keys.update(m.get(section, {}).keys()) - - out = {} - for k in sorted(keys): - numerator = 0.0 - denominator = 0 - - for m in rank_metrics: - n = int(m.get("num_metric_items", m.get("num_processed", 0))) - if n <= 0: - continue - - value = m.get(section, {}).get(k, None) - if value is None or isinstance(value, str): - continue - - try: - value = float(value) - except Exception: - continue - - numerator += value * n - denominator += n - - if denominator > 0: - out[k] = numerator / denominator - - return out - - final_metrics = { - "world_size": int(world_size), - "num_metric_items": int(total_n), - "aggregation": "sum(rank_metric * rank_num_metric_items) / total_num_metric_items", - "intelligibility": weighted_average("intelligibility"), - "secs": weighted_average("secs"), - "ranks": rank_metrics, - } - - final_json_path = os.path.join(args.out_dir, "metrics_final.json") - final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") - - write_json(final_json_path, final_metrics) - - final_text = format_final_metric_text(final_metrics) - write_text_atomic(final_txt_path, final_text) - - print("\n--- Final Evaluation Metrics ---", flush=True) - print(final_text, flush=True) - - logging.info(f"Final metrics JSON saved to: {final_json_path}") - logging.info(f"Final metrics TXT saved to: {final_txt_path}") - logging.info(json.dumps(final_metrics, indent=2, sort_keys=True)) - - return final_metrics - - -def merge_filewise_metrics_on_rank0(args, rank: int, world_size: int): - """Merge per-turn rank metric rows into one row per original sample. - - Rank files still contain one row per turn because metrics are computed - turn-by-turn. The final filewise outputs group those turn rows by - (run_id, dataset_index), producing one JSONL/CSV row per original sample - with list fields: - reference_text, asr_hyp, cer_turns, wer_turns, secs_turns. - - DistributedSampler padding repeats are deduplicated by - (run_id, dataset_index, turn_id), but repetitions from --num_eval_runs are - preserved because run_id is part of the key. - """ - if rank != 0 or not args.save_filewise_metrics: - return [] - - turn_rows = [] - - for r in range(world_size): - path = os.path.join(args.out_dir, f"filewise_metrics_rank{r:04d}.jsonl") - if not os.path.exists(path): - logging.warning(f"Missing filewise metrics file: {path}") - continue - - with open(path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line: - turn_rows.append(json.loads(line)) - - # Deduplicate DistributedSampler padding repeats, but preserve --num_eval_runs. - deduped_turns = {} - for row in turn_rows: - run_id = int(row.get("run_id", 0)) - idx = int(row.get("dataset_index", -1)) - turn_id = int(row.get("turn_id", 0)) - key = (run_id, idx, turn_id) - if key not in deduped_turns: - deduped_turns[key] = row - - turn_rows = list(deduped_turns.values()) - - # Group turn rows into one row per original file/sample. - grouped = {} - for row in turn_rows: - run_id = int(row.get("run_id", 0)) - idx = int(row.get("dataset_index", -1)) - key = (run_id, idx) - - if key not in grouped: - grouped[key] = { - "run_id": run_id, - "dataset_index": idx, - "rank": int(row.get("rank", -1)), - "target_audio_path": row.get("target_audio_path", ""), - "context_audio_path": row.get("context_audio_path", ""), - "turn_rows": [], - } - - grouped[key]["turn_rows"].append(row) - - def avg(vals): - vals = [float(x) for x in vals if x is not None and math.isfinite(float(x))] - return None if not vals else sum(vals) / len(vals) - - sample_rows = [] - for _, group in grouped.items(): - turns = sorted(group["turn_rows"], key=lambda x: int(x.get("turn_id", 0))) - - cer_turns = [r.get("cer") for r in turns] - wer_turns = [r.get("wer") for r in turns] - secs_turns = [r.get("secs") for r in turns] - - sample_row = { - "run_id": group["run_id"], - "dataset_index": group["dataset_index"], - "rank": group["rank"], - "num_turns": len(turns), - "turn_ids": [int(r.get("turn_id", 0)) for r in turns], - "target_audio_path": group["target_audio_path"], - "context_audio_path": group["context_audio_path"], - "pred_audio_paths": [r.get("pred_audio_path", "") for r in turns], - "pred_audio_seconds_turns": [r.get("pred_audio_seconds") for r in turns], - "reference_text": [r.get("reference_text", "") for r in turns], - "asr_hyp": [r.get("asr_hyp", "") for r in turns], - "cer_turns": cer_turns, - "wer_turns": wer_turns, - "secs_turns": secs_turns, - "cer": avg(cer_turns), - "wer": avg(wer_turns), - "secs": avg(secs_turns), - } - - sample_rows.append(sample_row) - - # Sort samples by average CER descending for failure analysis. - sample_rows.sort( - key=lambda x: ( - x.get("cer") is not None, - float(x.get("cer")) if x.get("cer") is not None else -1.0, - ), - reverse=True, - ) - - jsonl_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.jsonl") - csv_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.csv") - - write_jsonl(jsonl_path, sample_rows) - write_filewise_csv(csv_path, sample_rows) - - logging.info(f"Saved sample-level filewise metrics JSONL to: {jsonl_path}") - logging.info(f"Saved sample-level filewise metrics CSV to: {csv_path}") - - topk = min(int(args.filewise_metrics_topk_log), len(sample_rows)) - if topk > 0: - logging.info(f"Top {topk} worst CER samples:") - for row in sample_rows[:topk]: - logging.info( - "run_id=%s dataset_index=%s num_turns=%s cer=%s wer=%s secs=%s path=%s" - % ( - row.get("run_id"), - row.get("dataset_index"), - row.get("num_turns"), - row.get("cer"), - row.get("wer"), - row.get("secs"), - row.get("target_audio_path"), - ) - ) - - return sample_rows - -def compute_aggregates_from_filewise_rows(rows: List[Dict[str, Any]]): - """Aggregate over sample-level rows. - - Each row may internally contain multiple turn metrics in cer_turns/wer_turns, - but the final filewise average is over original samples/files. - """ - if len(rows) == 0: - return { - "cer": None, - "wer": None, - "secs": None, - "num_samples": 0, - } - - def avg_key(key): - vals = [float(r[key]) for r in rows if r.get(key) is not None] - if len(vals) == 0: - return None - return sum(vals) / len(vals) - - return { - "cer": avg_key("cer"), - "wer": avg_key("wer"), - "secs": avg_key("secs"), - "num_samples": len(rows), - } - -def save_filewise_final_summary(args, filewise_rows: List[Dict[str, Any]]): - filewise_summary = compute_aggregates_from_filewise_rows(filewise_rows) - - obj = { - "aggregation": "mean_over_sample_metrics_each_sample_contains_turn_metric_lists", - **filewise_summary, - } - - path = os.path.join(args.out_dir, "metrics_final_filewise_average.json") - write_json(path, obj) - - sample_metrics_final_path = os.path.join(args.out_dir, "metrics_final_sample_average.json") - write_json(sample_metrics_final_path, obj) - - final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") - final_text = format_filewise_final_metric_text(filewise_summary) - write_text_atomic(final_txt_path, final_text) - - print("\n--- Final Sample-Averaged Evaluation Metrics ---", flush=True) - print(final_text, flush=True) - - logging.info(f"Filewise averaged final metrics saved to: {path}") - logging.info(f"Sample averaged metrics_final JSON saved to: {sample_metrics_final_path}") - logging.info(f"Final metrics TXT saved to: {final_txt_path}") - - return obj - - -# ----------------------------- -# Args / main -# ----------------------------- - - -def parse_args(): - parser = argparse.ArgumentParser(description="EasyMagpieTTS Multi-GPU Inference Evaluation") - - parser.add_argument("--checkpoint_path", type=str, required=True) - parser.add_argument("--codec_model_path", type=str, required=True) - parser.add_argument("--datasets_json_path", type=str, required=True) - parser.add_argument("--out_dir", type=str, required=True) - - parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) - parser.add_argument("--audio_dir", type=str, default=None) - parser.add_argument("--inference_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) - parser.add_argument("--debug_dtype", action="store_true") - parser.add_argument("--debug_gpu_assignment", action="store_true") - parser.add_argument("--use_librosa", action="store_true") - - parser.add_argument("--batch_size", type=int, default=6) - parser.add_argument("--num_workers", type=int, default=4) - parser.add_argument("--num_turns", type=int, default=1) - parser.add_argument("--pad_factor_text_speech", type=int, default=10) - - parser.add_argument("--emulate_duplex_inference", action="store_true") - parser.add_argument("--add_interruption_token", action="store_true") - parser.add_argument("--force_interruption", action="store_true") - parser.add_argument("--profile_multiturn_inference", action="store_true") - parser.add_argument("--profile_pad_min_sec", type=float, default=2.0) - parser.add_argument("--profile_pad_max_sec", type=float, default=2.0) - parser.add_argument("--max_eval_turns", type=int, default=6) - - parser.add_argument("--user_custom_speaker_reference", action="store_true") - parser.add_argument("--inference_speaker_reference", type=str, default=None) - parser.add_argument("--language", type=str, default="en") - - parser.add_argument("--use_cfg", action="store_true") - parser.add_argument("--cfg_scale", type=float, default=2.5) - parser.add_argument("--temperature", type=float, default=0.7) - parser.add_argument("--topk", type=int, default=80) - parser.add_argument("--max_tts_steps", type=int, default=2000) - parser.add_argument("--force_speech_sil_codes", action="store_true") - parser.add_argument("--normalize_volume", type=lambda x: str(x).lower() in ["true", "1", "yes"], default=True) - - parser.add_argument( - "--save_filewise_metrics", - action="store_true", - help="Save per-turn/file CER/WER metrics sorted by CER descending.", - ) - parser.add_argument( - "--compute_filewise_secs", - action="store_true", - help="Also compute per-turn/file SECS. Slower because it runs SECS per row.", - ) - parser.add_argument( - "--filewise_metrics_topk_log", - type=int, - default=20, - help="Number of worst CER samples to print on rank 0.", - ) - parser.add_argument( - "--num_eval_runs", - type=int, - default=1, - help="Repeat the full eval set N times. Repetitions are preserved in final filewise average.", - ) - parser.add_argument( - "--sort_by_text_token_count", - action="store_true", - help="Sort eval samples by total text token count before distributed sharding for better load balancing.", - ) - parser.add_argument( - "--metric_batch_size", - type=int, - default=8, - help="Batch size used for post-generation ASR/SECS metric computation.", - ) - parser.add_argument( - "--max_metric_audio_sec", - type=float, - default=120.0, - help="Clamp generated audio length used for ASR/SECS metrics to avoid metric OOM/hangs.", - ) - parser.add_argument( - "--asr_model_name", - type=str, - default="stt_en_fastconformer_transducer_large", - help="Pretrained NeMo ASR model used for CER/WER.", - ) - parser.add_argument( - "--secs_model_name", - type=str, - default="titanet_large", - help="Pretrained speaker encoder model used for SECS.", - ) - - return parser.parse_args() - - -def main(): - args = parse_args() - os.makedirs(args.out_dir, exist_ok=True) - os.makedirs(get_audio_out_dir(args), exist_ok=True) - os.makedirs(get_generated_turn_audio_dir(args), exist_ok=True) - os.makedirs(get_context_metric_audio_dir(args), exist_ok=True) - - distributed, rank, local_rank, world_size, device_index = setup_distributed() - - if args.profile_multiturn_inference and args.batch_size != 1: - raise RuntimeError( - "--profile_multiturn_inference requires --batch_size=1 per process. " - "Use multiple GPUs/processes for parallelism instead of increasing batch_size." - ) - - if args.profile_pad_max_sec < args.profile_pad_min_sec: - raise RuntimeError("--profile_pad_max_sec must be >= --profile_pad_min_sec.") - - if args.num_eval_runs <= 0: - raise RuntimeError("--num_eval_runs must be >= 1.") - - target_device = torch.device(f"cuda:{device_index}" if torch.cuda.is_available() and device_index >= 0 else "cpu") - target_dtype = getattr(torch, args.inference_dtype) - torch.set_default_dtype(target_dtype) - - hostname = socket.gethostname() - cuda_name = torch.cuda.get_device_name(target_device) if torch.cuda.is_available() and device_index >= 0 else "cpu" - - all_rank_print( - rank, - f"host={hostname} local_rank={local_rank} world_size={world_size} " - f"device={target_device} device_name={cuda_name}", - ) - - model = build_model_and_codec(args, target_device, target_dtype) - codec_sil_codes = model.codec_sil_codes - - if args.debug_dtype: - handles, stats, examples = attach_dtype_counter(model) - else: - handles = stats = examples = None - - full_eval_dataset = EvalJSONLDataset(args.datasets_json_path, num_turns=args.num_turns) - # debug - # full_eval_dataset.samples = full_eval_dataset.samples[:7] - - if args.sort_by_text_token_count: - full_eval_dataset = SortedByTextTokenCountDataset( - full_eval_dataset, - model=model, - max_eval_turns=args.max_eval_turns, - descending=True, - ) - - collate_fn = partial( - collate_and_tokenize_custom, - model=model, - extra_duration_thrshould=1.5, - sample_rate=model.sample_rate, - root_path=args.audio_dir, - emulate_duplex_inference=args.emulate_duplex_inference, - add_interruption_token=args.add_interruption_token, - pad_factor_text_speech=args.pad_factor_text_speech, - force_interruption=args.force_interruption, - normalize_audio_volume=args.normalize_volume, - use_librosa=args.use_librosa, - profile_multiturn_inference=args.profile_multiturn_inference, - max_eval_turns=args.max_eval_turns, - ) - - speaker_wav = load_speaker_wav_if_needed(args, model, target_dtype) - - generation_start = time.time() - all_metric_items = [] - total_batches = 0 - total_generated_samples = 0 - - for run_id in range(args.num_eval_runs): - if distributed: - sampler = DistributedSampler( - full_eval_dataset, - num_replicas=world_size, - rank=rank, - shuffle=False, - drop_last=False, - ) - sampler.set_epoch(run_id) - else: - sampler = SequentialSampler(full_eval_dataset) - - if args.debug_gpu_assignment: - try: - assigned_indices = list(iter(sampler)) - assigned_dataset_indices = [ - int(full_eval_dataset[i].get("__dataset_index__", -1)) for i in assigned_indices - ] - all_rank_print( - rank, - f"run_id={run_id} assigned {len(assigned_dataset_indices)} / {len(full_eval_dataset)} " - f"samples to gpu={local_rank}: dataset_indices={assigned_dataset_indices}", - ) - except Exception as e: - all_rank_print(rank, f"Could not print assigned indices: {repr(e)}") - - dataloader = DataLoader( - dataset=full_eval_dataset, - batch_size=args.batch_size, - sampler=sampler, - collate_fn=collate_fn, - num_workers=args.num_workers, - pin_memory=True, - drop_last=False, - ) - - for batch_id, inputs in enumerate(dataloader): - total_batches += 1 - batch_indices = inputs.get("dataset_indices", []) - total_generated_samples += len(batch_indices) - - if args.debug_gpu_assignment: - all_rank_print( - rank, - f"run_id={run_id} gpu={local_rank} batch_id={batch_id} " - f"dataset_indices={batch_indices} text_token_counts={inputs.get('text_token_counts', [])} " - f"target_paths={inputs.get('target_audio_paths', [])}", - ) - - inputs = prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=speaker_wav) - - finalize_output, profile_turn_frame_ranges, profile_decode_start_frame = run_generation( - model=model, - inputs=inputs, - args=args, - codec_sil_codes=codec_sil_codes, - ) - - metric_items = save_generation_outputs_and_build_metric_items( - model=model, - inputs=inputs, - finalize_output=finalize_output, - profile_turn_frame_ranges=profile_turn_frame_ranges, - profile_decode_start_frame=profile_decode_start_frame, - args=args, - rank=rank, - run_id=run_id, - ) - all_metric_items.extend(metric_items) - - if args.debug_dtype and batch_id == 0 and run_id == 0: - report_dtype_stats(handles, stats, examples, rank=rank) - - generation_elapsed = time.time() - generation_start - - # Save pre-metric manifest for debugging and restartability. - metric_manifest_path = os.path.join(args.out_dir, f"metric_items_rank{rank:04d}.jsonl") - write_jsonl(metric_manifest_path, all_metric_items) - - # Free TTS/codec model memory before loading ASR and speaker encoder metrics. - del model - if speaker_wav is not None: - del speaker_wav - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - all_rank_print( - rank, - f"generation done: batches={total_batches} generated_samples_with_sampler_padding={total_generated_samples} " - f"metric_items={len(all_metric_items)} elapsed_sec={generation_elapsed:.2f}. " - "Loading ASR/SECS metrics now.", - ) - - rank_metrics, rank_filewise_rows = compute_metrics_after_generation( - args=args, - rank=rank, - world_size=world_size, - metric_items=all_metric_items, - ) - rank_metrics["generation_elapsed_sec"] = float(generation_elapsed) - rank_metrics["num_generated_samples_with_sampler_padding"] = int(total_generated_samples) - - rank_metrics = compute_and_save_rank_metrics_file(args, rank_metrics, rank) - all_rank_print(rank, f"saved rank metrics: {json.dumps(rank_metrics, sort_keys=True)}") - - if args.save_filewise_metrics: - rank_filewise_rows.sort( - key=lambda x: ( - x.get("cer") is not None, - float(x.get("cer")) if x.get("cer") is not None else -1.0, - ), - reverse=True, - ) - - rank_filewise_path = os.path.join(args.out_dir, f"filewise_metrics_rank{rank:04d}.jsonl") - write_jsonl(rank_filewise_path, rank_filewise_rows) - all_rank_print(rank, f"saved filewise metrics: {rank_filewise_path}") - - if rank == 0: - wait_for_rank_metric_files(args, world_size) - - merge_metrics_on_rank0(args, rank, world_size) - - if args.save_filewise_metrics: - if rank == 0: - wait_for_rank_filewise_metric_files(args, world_size) - - filewise_rows = merge_filewise_metrics_on_rank0(args, rank, world_size) - - if rank == 0: - save_filewise_final_summary(args, filewise_rows) - - cleanup_distributed() - - -if __name__ == "__main__": - main() diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index 287149627988..31f9a2105877 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -28,6 +28,7 @@ import abc import glob +import json import os import shutil import time @@ -42,7 +43,8 @@ from nemo.collections.tts.data.text_to_speech_dataset import ChunkedTTSInferenceDataset, MagpieTTSDataset from nemo.collections.tts.models.easy_magpietts_inference import EasyModelInferenceParameters from nemo.collections.tts.models.magpietts import ModelInferenceParameters -from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors +from nemo.collections.audio.parts.utils.transforms import resample +from nemo.collections.tts.parts.utils.tts_dataset_utils import normalize_volume, stack_tensors from nemo.utils import logging @@ -680,6 +682,169 @@ def _compute_end_of_text_flags( return is_end_of_text + + +@dataclass +class EasyMagpieMultiturnUserAudioInferenceConfig(EasyMagpieInferenceConfig): + """Configuration for EasyMagpie multiturn user-audio inference. + + This mode keeps the standard EasyMagpie/MagpieTTS evaluation contract by + writing one evaluation row per generated agent turn: + + predicted_audio_.wav + predicted_codes_.pt + target_audio_.wav + context_audio_.wav + + The generated turn-level manifest is cached on the runner as + ``evaluation_manifest_path`` so the top-level inference script can pass it + to ``evaluate_generated_audio_dir``. + """ + + max_eval_turns: int = 6 + save_debug_multiturn_audio: bool = True + + def build_identifier(self) -> str: + return super().build_identifier() + f"_MTUserAudio_MaxTurns{self.max_eval_turns}" + + +class EasyMagpieMultiturnUserAudioDataset(torch.utils.data.Dataset): + """Manifest dataset for turn-level multiturn user-audio EasyMagpie inference.""" + + def __init__( + self, + manifest_path: str, + audio_dir: str, + model, + max_eval_turns: int = 6, + normalize_audio: bool = True, + ): + self.manifest_path = manifest_path + self.audio_dir = audio_dir or "" + self.model = model + self.max_eval_turns = max_eval_turns + self.normalize_audio = normalize_audio + self.records = read_manifest(manifest_path) + # debug + self.records = self.records[:7] + def __len__(self): + return len(self.records) + + def __getitem__(self, idx: int): + item = dict(self.records[idx]) + item["idx"] = idx + return item + + def _resolve_path(self, path: Optional[str]) -> Optional[str]: + if path is None or path == "": + return None + if os.path.isabs(path): + return path + return os.path.join(self.audio_dir, path) + + def _load_audio_1d(self, path: str, sample_rate: int) -> torch.Tensor: + path = self._resolve_path(path) + if path is None or not os.path.exists(path): + raise FileNotFoundError(f"Missing audio path: {path}") + + audio, sr = sf.read(path, dtype="float32", always_2d=False) + wav = torch.as_tensor(audio, dtype=torch.float32) + if wav.ndim == 2: + wav = wav.mean(dim=1) + wav = wav.flatten() + + if sr != sample_rate: + wav = resample(wav.unsqueeze(0), sr, sample_rate).squeeze(0) + + if self.normalize_audio: + try: + wav = normalize_volume(wav) + except Exception: + # Keep evaluation robust across normalize_volume signature changes. + pass + + return wav.contiguous() + + @staticmethod + def _as_turn_list(value) -> List[str]: + if isinstance(value, list): + return [str(x) for x in value] + return [str(value)] + + def collate_fn(self, batch: List[dict]) -> Dict[str, Any]: + if len(batch) != 1: + raise RuntimeError("multiturn_user_audio inference currently requires batch_size=1.") + + sample = batch[0] + model = self.model + sample_rate = model.sample_rate + main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] + + raw_turn_texts = self._as_turn_list(sample["text"])[: self.max_eval_turns] + max_turns = len(raw_turn_texts) + + batched_turns = [] + batched_turn_lens = [] + valid_turn_masks = [] + for turn_text in raw_turn_texts: + ids = model.tokenizer.encode(turn_text, tokenizer_name=main_tokenizer_name) + [model.eos_id] + batched_turns.append(torch.tensor([ids], dtype=torch.long)) + batched_turn_lens.append(torch.tensor([len(ids)], dtype=torch.long)) + valid_turn_masks.append(torch.tensor([True], dtype=torch.bool)) + + context_path = self._resolve_path(sample.get("context_audio_filepath")) + context_audio = self._load_audio_1d(context_path, sample_rate).unsqueeze(0) + context_audio_lens = torch.tensor([context_audio.size(1)], dtype=torch.long) + + user_audio_paths = sample.get("user_audio_file_path", None) + if not isinstance(user_audio_paths, list): + user_audio_paths = [] + + user_audio_turns = [] + user_audio_turns_lens = [] + for turn_id in range(max_turns): + if turn_id < len(user_audio_paths) and user_audio_paths[turn_id]: + wav = self._load_audio_1d(user_audio_paths[turn_id], sample_rate) + else: + wav = torch.zeros(int(2 * sample_rate), dtype=torch.float32) + user_audio_turns.append(wav.unsqueeze(0)) + user_audio_turns_lens.append(torch.tensor([wav.numel()], dtype=torch.long)) + + target_turn_audio_paths = sample.get("target_audio_file_path", sample.get("target_audio_filepath", None)) + if target_turn_audio_paths is not None and not isinstance(target_turn_audio_paths, list): + target_turn_audio_paths = [target_turn_audio_paths] + + return { + "idx": torch.tensor([int(sample["idx"])], dtype=torch.long), + "raw_record": sample, + "raw_turn_texts": [raw_turn_texts], + "batched_turns": batched_turns, + "batched_turn_lens": batched_turn_lens, + "valid_turn_masks": valid_turn_masks, + "context_audio": context_audio, + "context_audio_lengths": context_audio_lens, + "user_audio_turns": user_audio_turns, + "user_audio_turns_lens": user_audio_turns_lens, + "target_audio_path": sample.get("audio_filepath"), + "target_turn_audio_paths": target_turn_audio_paths, + } + + + +class _InferenceSubset(torch.utils.data.Dataset): + """Subset wrapper that preserves the wrapped dataset collate_fn.""" + + def __init__(self, dataset, indices: List[int]): + self.dataset = dataset + self.indices = list(indices) + self.collate_fn = getattr(dataset, "collate_fn", None) + + def __len__(self): + return len(self.indices) + + def __getitem__(self, idx: int): + return self.dataset[self.indices[idx]] + class EasyMagpieInferenceRunner(BaseInferenceRunner): """Runner for decoder-only EasyMagpieTTSInferenceModel. @@ -823,3 +988,666 @@ def _run_decoder_only_inference( item_idx += 1 return all_rtf_metrics, generated_audio_paths, codec_file_paths + +class EasyMagpieMultiturnUserAudioInferenceRunner(BaseInferenceRunner): + """Runner for decoder-only EasyMagpieTTS multiturn user-audio inference. + + It generates one agent turn at a time using user-audio prefill, but writes + outputs using the standard EasyMagpie evaluation contract. Therefore the + existing magpietts_inference.py evaluation code can call + evaluate_generated_audio_dir() unchanged. + """ + + produces_turn_level_evaluation: bool = True + + def __init__(self, model, config: EasyMagpieMultiturnUserAudioInferenceConfig): + if config.batch_size != 1: + raise ValueError("EasyMagpie multiturn user-audio inference requires batch_size=1.") + super().__init__(model, config) + self.evaluation_manifest_path: Optional[str] = None + self.evaluation_audio_dir: Optional[str] = None + self.evaluation_manifest_records: Optional[List[dict]] = None + + # Used by examples/tts/magpietts_inference.py for torchrun sharding. + self.distributed_rank: int = int(os.environ.get("RANK", "0")) + self.distributed_world_size: int = int(os.environ.get("WORLD_SIZE", "1")) + + def create_dataset( + self, + dataset_meta: dict, + context_duration_min: Optional[float] = None, + context_duration_max: Optional[float] = None, + ) -> EasyMagpieMultiturnUserAudioDataset: + manifest_path, audio_dir = self._read_and_cache_manifest(dataset_meta) + logging.info("Creating multiturn user-audio inference dataset for decoder-only model") + return EasyMagpieMultiturnUserAudioDataset( + manifest_path=manifest_path, + audio_dir=audio_dir, + model=self.model, + max_eval_turns=self.config.max_eval_turns, + normalize_audio=True, + ) + + def set_distributed_context(self, rank: int, world_size: int) -> None: + self.distributed_rank = int(rank) + self.distributed_world_size = int(world_size) + + def run_inference_on_dataset( + self, + dataset: EasyMagpieMultiturnUserAudioDataset, + output_dir: str, + manifest_records: Optional[List[dict]] = None, + audio_base_dir: Optional[str] = None, + save_cross_attention_maps: bool = True, + save_context_audio: bool = True, + save_predicted_codes: bool = True, + ) -> Tuple[List[dict], List[str], List[str]]: + manifest_records, audio_base_dir = self._resolve_manifest_and_audio_dir(manifest_records, audio_base_dir) + return self._run_multiturn_user_audio_inference( + dataset=dataset, + output_dir=output_dir, + manifest_records=manifest_records, + audio_base_dir=audio_base_dir, + save_context_audio=save_context_audio, + save_predicted_codes=save_predicted_codes, + ) + + @staticmethod + def _move_batch_to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]: + out = {} + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + out[key] = value.to(device) + elif isinstance(value, list): + out[key] = [v.to(device) if isinstance(v, torch.Tensor) else v for v in value] + else: + out[key] = value + return out + + @staticmethod + def _copy_or_link(src: Optional[str], dst: str, required: bool = False, description: str = "audio") -> Optional[str]: + """Copy/symlink an audio artifact and optionally fail fast if missing. + + Evaluation later expects target_audio_*.wav/context_audio_*.wav to exist. + Silently skipping those files makes evaluate_generated_audio_dir fail much + later with a less useful FileNotFoundError, so target/context paths should + call this with required=True. + """ + if src is None or src == "": + if required: + raise FileNotFoundError(f"Missing required {description}: source path is empty for destination {dst}") + return None + if not os.path.exists(src): + if required: + raise FileNotFoundError( + f"Missing required {description}: source path does not exist: {src}; destination would be: {dst}" + ) + return None + + os.makedirs(os.path.dirname(dst), exist_ok=True) + try: + if os.path.lexists(dst): + os.remove(dst) + os.symlink(os.path.abspath(src), dst) + except Exception: + shutil.copyfile(src, dst) + + if required and not os.path.exists(dst): + raise FileNotFoundError( + f"Failed to materialize required {description}: src={src}, dst={dst}. " + "The destination may be a broken symlink." + ) + return dst + + def _resolve_audio_path(self, path: Optional[str], audio_base_dir: str) -> Optional[str]: + if path is None or path == "": + return None + if os.path.isabs(path): + return path + return os.path.join(audio_base_dir, path) + + def _ensure_codec_silence_codes(self) -> torch.Tensor: + """Ensure silence codec codes exist before streaming_prefill_profile. + + Newer EasyMagpieTTSInferenceModel exposes codec_sil_codes as a @property + backed by _codec_sil_codes_buffer. Some older branches/checkpoints do not + have that property, but still have _generate_codec_silence_buffer(). This + helper supports both cases and creates a plain module attribute fallback + when the property is absent. + """ + if not hasattr(self.model, "_codec_sil_codes_buffer"): + if not hasattr(self.model, "_generate_codec_silence_buffer"): + raise AttributeError( + "Model does not have _codec_sil_codes_buffer or _generate_codec_silence_buffer(); " + "cannot run multiturn_user_audio prefill." + ) + self.model._generate_codec_silence_buffer() + + class_codec_sil_codes = getattr(type(self.model), "codec_sil_codes", None) + if class_codec_sil_codes is None: + # Compatibility with branches where streaming_prefill_profile expects + # self.codec_sil_codes but the @property was not added. + self.model.codec_sil_codes = self.model._codec_sil_codes_buffer + if hasattr(self.model, "_codec_sil_codes_buffer_unconverted"): + self.model.codec_sil_codes_unconverted = self.model._codec_sil_codes_buffer_unconverted + + return self.model._codec_sil_codes_buffer.to(self.model.device).long() + + def _run_multiturn_generation(self, batch: Dict[str, Any]): + model = self.model + device = model.device + B = int(batch["context_audio"].size(0)) + if B != 1: + raise RuntimeError("multiturn_user_audio generation requires batch_size=1.") + + with torch.inference_mode(): + # streaming_prefill_profile reads self.model.codec_sil_codes, so make + # sure the silence buffer/property exists before entering the turn loop. + self._ensure_codec_silence_codes() + + wav = batch["context_audio"] + wav_len = batch["context_audio_lengths"] + codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) + + use_lang = bool(getattr(model, "add_language_to_context_text", False)) + ctx_text = "[EN]" if use_lang else "[NO TEXT CONTEXT]" + ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) + ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) + ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) + + params = self.config.model_inference_parameters + state = model.streaming_init( + context_audio_codes=codes, + context_audio_codes_lens=codes_lens, + context_text_tokens=ctx_toks, + context_text_tokens_lens=ctx_toks_lens, + use_cfg=self.config.use_cfg, + cfg_scale=params.cfg_scale, + use_local_transformer=self.config.use_local_transformer, + temperature=params.temperature, + topk=params.topk, + phoneme_input_type="pred", + phoneme_sampling_method=self.config.phoneme_sampling_method, + use_inference_mode=True, + ) + + turn_frame_ranges = [] + decode_start_frame = 0 + max_decoder_steps = params.max_decoder_steps + + for turn_id in range(len(batch["batched_turns"])): + turn_text = batch["batched_turns"][turn_id].to(device) + turn_lens = batch["batched_turn_lens"][turn_id].to(device) + valid_mask = batch["valid_turn_masks"][turn_id].to(device) + if not bool(valid_mask[0].item()): + continue + + state.finished.zero_() + state.text_finished.zero_() + state.audio_prediction_end_idx.fill_(-1) + for attr in [ + "turn_text_tokens_seen", + "phoneme_steps", + "phoneme_stream_ended", + "phoneme_eos_detected", + ]: + if hasattr(state, attr): + getattr(state, attr).zero_() + state.last_phoneme_tokens = None + + if not model.cfg.get("condition_on_user_speech", False): + user_audio = batch["user_audio_turns"][turn_id] + user_audio_prefill_steps = int(round(user_audio.size(-1) / model.input_samples_per_frame)) + user_audio_prefill_tokens = torch.full( + (1, user_audio_prefill_steps), model.pad_id, dtype=torch.long, device=device + ) + user_audio_channel_embedding = None + else: + user_audio = batch["user_audio_turns"][turn_id] + user_audio_lens = batch["user_audio_turns_lens"][turn_id] + user_audio_codes, user_audio_codes_lens = model._codec_helper.audio_to_codes( + user_audio, user_audio_lens + ) + + if model._codec_converter is not None: + user_audio_codes = model._codec_converter.convert_original_to_new( + audio_tokens=user_audio_codes, + audio_lens=user_audio_codes_lens, + ).long() + + user_audio_codes, user_audio_codes_lens = model.stack_codes( + user_audio_codes, + user_audio_codes_lens, + model.audio_bos_id, + model.audio_eos_id, + model.frame_stacking_factor, + model.num_audio_codebooks, + ) + + user_audio_embedded = model.embed_audio_tokens(user_audio_codes) + boundary_trim = model.cfg.get("user_audio_boundary_trim", 4) + boundary_trim = 0 if boundary_trim is None else int(boundary_trim) + + if boundary_trim == 0: + real_start = 0 + real_end = int(user_audio_codes_lens[0].item()) + else: + real_start = 1 + real_end = max(real_start, int(user_audio_codes_lens[0].item()) - 1) + + user_audio_embedded = user_audio_embedded[:, real_start:real_end] + copy_len = user_audio_embedded.size(1) + if boundary_trim > 0: + trim = min(boundary_trim, copy_len // 2) + if trim > 0: + user_audio_embedded[:, :trim] = 0.0 + user_audio_embedded[:, copy_len - trim :] = 0.0 + + bos_user_pad = torch.zeros( + user_audio_embedded.size(0), + 1, + user_audio_embedded.size(2), + device=user_audio_embedded.device, + dtype=user_audio_embedded.dtype, + ) + user_audio_embedded = torch.cat([bos_user_pad, user_audio_embedded], dim=1) + user_audio_prefill_steps = user_audio_embedded.size(1) + user_audio_prefill_tokens = torch.full( + (B, user_audio_prefill_steps), model.pad_id, dtype=torch.long, device=device + ) + user_audio_channel_embedding = user_audio_embedded + + delay_tokens = int(state.config.training_mode.streaming_speech_delay) + delay_tokens = min(delay_tokens, int(turn_lens[0].item()), user_audio_prefill_steps) + + warmup_tokens = turn_text[:, :delay_tokens] + turn_text = turn_text[:, delay_tokens:] + turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) + + if user_audio_channel_embedding is not None and delay_tokens > 0: + warmup_user_audio = user_audio_channel_embedding[:, -delay_tokens:] + user_audio_channel_embedding = user_audio_channel_embedding[:, :-delay_tokens] + user_audio_prefill_tokens = user_audio_prefill_tokens[:, :-delay_tokens] + else: + warmup_user_audio = None + + if user_audio_prefill_tokens.size(1) > 0: + state = model.streaming_prefill_profile( + state=state, + text_tokens=user_audio_prefill_tokens, + use_inference_mode=True, + user_audio_channel_embedding=user_audio_channel_embedding, + ) + + for i in range(delay_tokens): + user_step_emb = warmup_user_audio[:, i] if warmup_user_audio is not None else None + state.finished.zero_() + state, _, _ = model.streaming_step( + state=state, + text_tokens=warmup_tokens[:, i], + user_audio_channel_embedding=user_step_emb, + prefill_like_step=not bool(model.cfg.get("agent_mask_include_transition_prefix", False)), + prefill_like_is_last_step=(i == delay_tokens - 1), + use_inference_mode=True, + ) + + turn_start_frame = sum(p.size(-1) for p in state.all_predictions) + if turn_id == 0: + state.audio_prediction_start_idx.fill_(turn_start_frame) + decode_start_frame = turn_start_frame + + turn_offset = state.text_tokens_seen.clone() + steps = 0 + while steps < max_decoder_steps: + steps += 1 + state.finished.zero_() + relative_position = state.text_tokens_seen - turn_offset + text_exhausted = relative_position >= turn_lens + + if turn_text.size(1) == 0: + current_tokens = torch.full((B,), model.eos_id, dtype=torch.long, device=device) + else: + position = relative_position.clamp(min=0, max=turn_text.size(1) - 1) + current_tokens = turn_text[torch.arange(B, device=device), position] + current_tokens = torch.where( + text_exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) + + state, _, _ = model.streaming_step( + state=state, + text_tokens=current_tokens, + use_inference_mode=True, + ) + + if bool(text_exhausted[0].item()) and bool(state.finished[0].item()): + break + + state.audio_prediction_end_idx.fill_(-1) + state.finished.zero_() + turn_end_frame = sum(p.size(-1) for p in state.all_predictions) + turn_frame_ranges.append((turn_id, turn_start_frame, turn_end_frame)) + + codec_sil_codes = self._ensure_codec_silence_codes() + bos_id = getattr(model, "audio_bos_id", -1) + eos_id = getattr(model, "audio_eos_id", -1) + speaking_id = getattr(model, "audio_user_speaking_id", -1) + speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) + sil_injection = codec_sil_codes.view(1, -1, 1) + + for step_idx in range(len(state.all_predictions)): + pred = state.all_predictions[step_idx] + mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) + frame_mask = mask.any(dim=1, keepdim=True) + if frame_mask.any(): + state.all_predictions[step_idx] = torch.where(frame_mask, sil_injection.expand_as(pred), pred) + + state.audio_prediction_end_idx.fill_(-1) + generated_codes = None + if getattr(state, "all_predictions", None): + generated_codes = torch.cat(state.all_predictions, dim=-1).detach() + + finalize_output = model.streaming_finalize(state, use_inference_mode=True) + + return finalize_output, turn_frame_ranges, decode_start_frame, generated_codes + + @staticmethod + def _save_code_slice(generated_codes, batch_idx: int, start_frame: int, end_frame: int, path: str) -> Optional[str]: + if generated_codes is None: + return None + os.makedirs(os.path.dirname(path), exist_ok=True) + total_frames = int(generated_codes.size(-1)) + start_frame = max(0, min(int(start_frame), total_frames)) + end_frame = max(start_frame, min(int(end_frame), total_frames)) + if end_frame <= start_frame: + return None + codes = generated_codes[batch_idx, :, start_frame:end_frame].detach().cpu().long() + torch.save(codes, path) + return path + + def _resolve_target_audio_for_turn( + self, + raw_record: dict, + target_turn_audio_paths, + local_turn_idx: int, + audio_base_dir: str, + ) -> Optional[str]: + """Resolve the GT target audio for one evaluation turn. + + Prefer per-turn GT if present; otherwise fall back to sample-level audio_filepath. + If no candidate exists, returns None so the caller can fall back to + context audio and keep EasyMagpie evaluation from failing on missing + target_audio_*.wav. + """ + candidates = [] + + if isinstance(target_turn_audio_paths, list) and local_turn_idx < len(target_turn_audio_paths): + candidates.append(target_turn_audio_paths[local_turn_idx]) + elif isinstance(target_turn_audio_paths, str): + candidates.append(target_turn_audio_paths) + + # Common manifest keys. + candidates.extend( + [ + raw_record.get("target_audio_file_path"), + raw_record.get("target_audio_filepath"), + raw_record.get("audio_filepath"), + ] + ) + + tried = [] + for candidate in candidates: + if candidate is None or candidate == "": + continue + if isinstance(candidate, list): + if local_turn_idx < len(candidate): + candidate = candidate[local_turn_idx] + else: + continue + resolved = self._resolve_audio_path(candidate, audio_base_dir) + tried.append(resolved) + if resolved is not None and os.path.exists(resolved): + return resolved + + logging.warning( + "Could not resolve target audio for multiturn_user_audio evaluation turn; " + "caller will fall back to context audio. " + f"sample_idx={raw_record.get('idx')}, local_turn_idx={local_turn_idx}, " + f"audio_base_dir={audio_base_dir}, tried={tried}, " + f"raw_record_keys={sorted(raw_record.keys())}" + ) + return None + + def _run_multiturn_user_audio_inference( + self, + dataset: EasyMagpieMultiturnUserAudioDataset, + output_dir: str, + manifest_records: List[dict], + audio_base_dir: str, + save_context_audio: bool = True, + save_predicted_codes: bool = True, + ) -> Tuple[List[dict], List[str], List[str]]: + os.makedirs(output_dir, exist_ok=True) + self._delete_old_generated_files(output_dir) + + debug_user_dir = os.path.join(output_dir, "debug_user_turns") + debug_mixed_dir = os.path.join(output_dir, "debug_mixed_user_agent") + if self.config.save_debug_multiturn_audio: + os.makedirs(debug_user_dir, exist_ok=True) + os.makedirs(debug_mixed_dir, exist_ok=True) + + rank = int(getattr(self, "distributed_rank", 0)) + world_size = int(getattr(self, "distributed_world_size", 1)) + if world_size > 1: + rank_indices = list(range(rank, len(dataset), world_size)) + logging.info( + f"multiturn_user_audio distributed sharding: rank={rank}/{world_size}, " + f"local_samples={len(rank_indices)}, total_samples={len(dataset)}" + ) + dataset_for_rank = _InferenceSubset(dataset, rank_indices) + else: + dataset_for_rank = dataset + + dataloader = torch.utils.data.DataLoader( + dataset_for_rank, + batch_size=1, + collate_fn=dataset_for_rank.collate_fn, + num_workers=0, + shuffle=False, + ) + + all_rtf_metrics = [] + generated_audio_paths = [] + codec_file_paths = [] + turn_manifest_records = [] + item_idx = 0 + + sample_rate = getattr(self.model, "output_sample_rate", self.model.sample_rate) + + for batch_idx, batch in enumerate(dataloader): + logging.info(f"Processing multiturn user-audio sample {batch_idx + 1}/{len(dataloader)}") + batch = self._move_batch_to_device(batch, self.model.device) + sample_idx = int(batch["idx"][0].item()) + raw_record = batch["raw_record"] + raw_turn_texts = batch["raw_turn_texts"][0] + + start_time = time.time() + output, turn_frame_ranges, decode_start_frame, generated_codes = self._run_multiturn_generation(batch) + elapsed = time.time() - start_time + + predicted_audio = output.audio.float().detach().cpu() + predicted_audio_lens = output.audio_len.int().detach().cpu() + full_len = int(predicted_audio_lens[0].item()) + full_wav = predicted_audio[0, :full_len] + + samples_per_prediction_frame = self.model.codec_model_samples_per_frame / ( + self.model.sample_rate / sample_rate + ) + aligned_agent = torch.zeros_like(full_wav) + + context_len = int(batch["context_audio_lengths"][0].detach().cpu().item()) + context_wav = batch["context_audio"][0, :context_len].detach().cpu().float() + context_audio_path = os.path.join(output_dir, f"context_audio_sample_{sample_idx}.wav") + # Always write context audio because evaluate_generated_audio_dir reads + # context_audio_filepath from the generated turn-level manifest for every repeat. + sf.write(context_audio_path, context_wav.numpy(), self.model.sample_rate) + + target_turn_audio_paths = batch.get("target_turn_audio_paths") + + for local_turn_idx, (turn_id, start_frame, end_frame) in enumerate(turn_frame_ranges): + rel_start_frame = start_frame - decode_start_frame + rel_end_frame = end_frame - decode_start_frame + start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) + end_sample = int(round(rel_end_frame * samples_per_prediction_frame)) + start_sample = max(0, min(start_sample, full_len)) + end_sample = max(start_sample, min(end_sample, full_len)) + + aligned_agent[start_sample:end_sample] = full_wav[start_sample:end_sample] + turn_wav = aligned_agent[start_sample:end_sample].float() + + predicted_audio_path = os.path.join(output_dir, f"predicted_audio_{item_idx}.wav") + sf.write(predicted_audio_path, turn_wav.numpy(), sample_rate) + generated_audio_paths.append(predicted_audio_path) + + if save_predicted_codes: + code_path = os.path.join(output_dir, f"predicted_codes_{item_idx}.pt") + saved_code_path = self._save_code_slice(generated_codes, 0, start_frame, end_frame, code_path) + if saved_code_path is not None: + codec_file_paths.append(saved_code_path) + + turn_context_path = os.path.join(output_dir, f"context_audio_{item_idx}.wav") + self._copy_or_link( + context_audio_path, + turn_context_path, + required=True, + description=f"context audio for sample_idx={sample_idx}, turn_id={turn_id}", + ) + + target_src = self._resolve_target_audio_for_turn( + raw_record=raw_record, + target_turn_audio_paths=target_turn_audio_paths, + local_turn_idx=local_turn_idx, + audio_base_dir=audio_base_dir, + ) + if target_src is None or not os.path.exists(target_src): + logging.warning( + "Target audio is missing for multiturn_user_audio evaluation; " + "using context audio as target fallback to avoid evaluator failure. " + f"sample_idx={sample_idx}, turn_id={turn_id}, missing_target={target_src}, " + f"context_audio_path={context_audio_path}" + ) + target_src = context_audio_path + + target_dst = os.path.join(output_dir, f"target_audio_{item_idx}.wav") + self._copy_or_link( + target_src, + target_dst, + required=True, + description=f"target audio fallback/context for sample_idx={sample_idx}, turn_id={turn_id}", + ) + + turn_manifest_records.append( + { + "audio_filepath": f"target_audio_{item_idx}.wav", + "context_audio_filepath": f"context_audio_{item_idx}.wav", + "text": raw_turn_texts[local_turn_idx] if local_turn_idx < len(raw_turn_texts) else "", + "speaker": str(sample_idx), + "source_sample_idx": sample_idx, + "turn_id": int(turn_id), + } + ) + item_idx += 1 + + full_agent_path = os.path.join(output_dir, f"predicted_audio_sample_{sample_idx}_full_agent.wav") + sf.write(full_agent_path, aligned_agent.numpy(), sample_rate) + + if self.config.save_debug_multiturn_audio and "user_audio_turns" in batch: + self._save_debug_user_agent_audio( + batch=batch, + sample_idx=sample_idx, + turn_frame_ranges=turn_frame_ranges, + decode_start_frame=decode_start_frame, + aligned_agent=aligned_agent, + samples_per_prediction_frame=samples_per_prediction_frame, + output_dir=output_dir, + debug_user_dir=debug_user_dir, + debug_mixed_dir=debug_mixed_dir, + ) + + audio_seconds = sum( + max(0, int(round((end - start) * samples_per_prediction_frame))) for _, start, end in turn_frame_ranges + ) / sample_rate + all_rtf_metrics.append( + { + "inference_time": elapsed, + "audio_seconds": audio_seconds, + "rtf": elapsed / audio_seconds if audio_seconds > 0 else 0.0, + } + ) + + self.evaluation_audio_dir = output_dir + rank = int(getattr(self, "distributed_rank", 0)) + self.evaluation_manifest_path = os.path.join(output_dir, f"multiturn_user_audio_turn_manifest_rank{rank:04d}.jsonl") + self.evaluation_manifest_records = turn_manifest_records + with open(self.evaluation_manifest_path, "w", encoding="utf-8") as f: + for record in turn_manifest_records: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + logging.info(f"Wrote multiturn turn-level evaluation manifest: {self.evaluation_manifest_path}") + return all_rtf_metrics, generated_audio_paths, codec_file_paths + + def _save_debug_user_agent_audio( + self, + batch: Dict[str, Any], + sample_idx: int, + turn_frame_ranges: List[Tuple[int, int, int]], + decode_start_frame: int, + aligned_agent: torch.Tensor, + samples_per_prediction_frame: float, + output_dir: str, + debug_user_dir: str, + debug_mixed_dir: str, + ) -> None: + sample_rate = getattr(self.model, "output_sample_rate", self.model.sample_rate) + first_user_len_in = int(batch["user_audio_turns_lens"][0][0].detach().cpu().item()) + first_user_delay_out = int(round(first_user_len_in * sample_rate / self.model.sample_rate)) + + user_segments = [] + for turn_id, _, _ in turn_frame_ranges: + if turn_id >= len(batch["user_audio_turns"]): + continue + turn_audio = batch["user_audio_turns"][turn_id][0].detach().cpu().float() + turn_audio_len = int(batch["user_audio_turns_lens"][turn_id][0].detach().cpu().item()) + turn_audio = turn_audio[:turn_audio_len] + turn_audio_out = resample(turn_audio.unsqueeze(0), self.model.sample_rate, sample_rate).squeeze(0) + + user_turn_path = os.path.join(debug_user_dir, f"sample_{sample_idx}_user_turn_{turn_id}.wav") + sf.write(user_turn_path, turn_audio_out.numpy(), sample_rate) + + if turn_id == 0: + user_start_sample = 0 + else: + prev_turn_end_frame = turn_frame_ranges[turn_id - 1][2] + rel_prev_end_frame = prev_turn_end_frame - decode_start_frame + user_start_sample = first_user_delay_out + int(round(rel_prev_end_frame * samples_per_prediction_frame)) + user_segments.append((user_start_sample, turn_audio_out)) + + total_user_len = 0 + for start, wav in user_segments: + total_user_len = max(total_user_len, start + wav.numel()) + user_ch = torch.zeros(total_user_len) + for start, wav in user_segments: + user_ch[start : start + wav.numel()] += wav + + agent_ch = torch.cat([torch.zeros(first_user_delay_out, dtype=aligned_agent.dtype), aligned_agent]) + stereo_len = max(user_ch.numel(), agent_ch.numel()) + user_pad = torch.zeros(stereo_len) + agent_pad = torch.zeros(stereo_len) + user_pad[: user_ch.numel()] = user_ch + agent_pad[: agent_ch.numel()] = agent_ch + + mono_mix = torch.clamp(user_pad + agent_pad, min=-1.0, max=1.0) + sf.write(os.path.join(debug_mixed_dir, f"sample_{sample_idx}_user_agent_mixed_mono.wav"), mono_mix.numpy(), sample_rate) + stereo = torch.stack([user_pad, agent_pad], dim=1).numpy() + sf.write(os.path.join(debug_mixed_dir, f"sample_{sample_idx}_user_agent_aligned_stereo.wav"), stereo, sample_rate) From 6ab2a2e1fdd019a4476547380ef277b2f33e5256 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 8 Jun 2026 12:43:35 -0700 Subject: [PATCH 074/109] Update EasyMagpie inference script to support multiturn Signed-off-by: Edresson Casanova --- examples/tts/magpietts_inference.py | 242 ++++++++++++++++-- .../modules/magpietts_inference/inference.py | 3 +- 2 files changed, 229 insertions(+), 16 deletions(-) diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index 90f97fdedfa8..a8763010e5e7 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -60,6 +60,7 @@ import os import random import shutil +import time from dataclasses import fields from pathlib import Path from typing import List, Optional, Tuple @@ -82,6 +83,8 @@ BaseInferenceRunner, EasyMagpieInferenceConfig, EasyMagpieInferenceRunner, + EasyMagpieMultiturnUserAudioInferenceConfig, + EasyMagpieMultiturnUserAudioInferenceRunner, MagpieInferenceConfig, MagpieInferenceRunner, ) @@ -169,6 +172,127 @@ def filter_datasets( return datasets +def _runner_eval_manifest_and_audio_dir(runner: BaseInferenceRunner, default_manifest: str, default_audio_dir: str): + """Return evaluation manifest/audio dir produced by the runner, if any.""" + eval_manifest = getattr(runner, "evaluation_manifest_path", None) or default_manifest + eval_audio_dir = getattr(runner, "evaluation_audio_dir", None) or default_audio_dir + return eval_manifest, eval_audio_dir + + +def _get_torchrun_rank_info() -> Tuple[int, int, int]: + """Return (rank, world_size, local_rank) from torchrun/SLURM env vars. + + We intentionally do not initialize torch.distributed here. The inference + script only needs env-based sharding, while NeMo evaluation models can run + without distributed collectives. + """ + rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", "0"))) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))) + return rank, world_size, local_rank + + +def _configure_cuda_for_rank() -> Tuple[int, int, int]: + rank, world_size, local_rank = _get_torchrun_rank_info() + if torch.cuda.is_available(): + device_count = torch.cuda.device_count() + if device_count > 0: + torch.cuda.set_device(local_rank % device_count) + logging.info( + f"Using CUDA device {local_rank % device_count}; " + f"rank={rank}, local_rank={local_rank}, world_size={world_size}" + ) + return rank, world_size, local_rank + + +def _wait_for_multiturn_rank_manifests(repeat_audio_dir: str, world_size: int, timeout_sec: int = 7200) -> None: + deadline = time.time() + timeout_sec + while time.time() < deadline: + missing = [] + for rank in range(world_size): + path = os.path.join( + repeat_audio_dir, + f"rank_{rank:04d}", + f"multiturn_user_audio_turn_manifest_rank{rank:04d}.jsonl", + ) + if not os.path.exists(path): + missing.append(path) + if not missing: + return + time.sleep(5) + raise RuntimeError(f"Timed out waiting for multiturn rank manifests: {missing}") + + +def _copy_or_link(src: str, dst: str) -> None: + if src is None or not os.path.exists(src): + return + os.makedirs(os.path.dirname(dst), exist_ok=True) + try: + if os.path.lexists(dst): + os.remove(dst) + os.symlink(os.path.abspath(src), dst) + except Exception: + shutil.copyfile(src, dst) + + +def _merge_multiturn_rank_outputs(repeat_audio_dir: str, world_size: int, save_predicted_codes: bool) -> str: + """Merge rank-local multiturn outputs into one EasyMagpie-compatible dir. + + Each rank writes local files named predicted_audio_0.wav, target_audio_0.wav, + context_audio_0.wav, predicted_codes_0.pt, ... inside rank_XXXX/. This + function remaps them to contiguous global indices in repeat_audio_dir/ and + writes a merged turn-level manifest. + """ + merged_records = [] + global_idx = 0 + + for rank in range(world_size): + rank_dir = os.path.join(repeat_audio_dir, f"rank_{rank:04d}") + rank_manifest = os.path.join(rank_dir, f"multiturn_user_audio_turn_manifest_rank{rank:04d}.jsonl") + if not os.path.exists(rank_manifest): + raise FileNotFoundError(f"Missing rank manifest: {rank_manifest}") + + with open(rank_manifest, "r", encoding="utf-8") as f: + rank_records = [json.loads(line) for line in f if line.strip()] + + for local_idx, record in enumerate(rank_records): + pred_src = os.path.join(rank_dir, f"predicted_audio_{local_idx}.wav") + pred_dst = os.path.join(repeat_audio_dir, f"predicted_audio_{global_idx}.wav") + _copy_or_link(pred_src, pred_dst) + + if save_predicted_codes: + code_src = os.path.join(rank_dir, f"predicted_codes_{local_idx}.pt") + code_dst = os.path.join(repeat_audio_dir, f"predicted_codes_{global_idx}.pt") + _copy_or_link(code_src, code_dst) + + target_src = os.path.join(rank_dir, record.get("audio_filepath", f"target_audio_{local_idx}.wav")) + target_dst = os.path.join(repeat_audio_dir, f"target_audio_{global_idx}.wav") + _copy_or_link(target_src, target_dst) + + context_src = os.path.join( + rank_dir, + record.get("context_audio_filepath", f"context_audio_{local_idx}.wav"), + ) + context_dst = os.path.join(repeat_audio_dir, f"context_audio_{global_idx}.wav") + _copy_or_link(context_src, context_dst) + + merged = dict(record) + merged["audio_filepath"] = f"target_audio_{global_idx}.wav" + merged["context_audio_filepath"] = f"context_audio_{global_idx}.wav" + merged["rank"] = rank + merged["rank_local_idx"] = local_idx + merged_records.append(merged) + global_idx += 1 + + merged_manifest = os.path.join(repeat_audio_dir, "multiturn_user_audio_turn_manifest.jsonl") + with open(merged_manifest, "w", encoding="utf-8") as f: + for record in merged_records: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + logging.info(f"Merged {len(merged_records)} multiturn turn records into {merged_manifest}") + return merged_manifest + + def run_inference_and_evaluation( runner: BaseInferenceRunner, checkpoint_name: str, @@ -216,6 +340,13 @@ def run_inference_and_evaluation( if not eval_config.with_utmosv2 and 'utmosv2' in violin_plot_metrics: violin_plot_metrics.remove('utmosv2') + rank, world_size, _ = _get_torchrun_rank_info() + is_distributed = world_size > 1 + is_multiturn_user_audio = getattr(runner, "produces_turn_level_evaluation", False) + + if hasattr(runner, "set_distributed_context"): + runner.set_distributed_context(rank=rank, world_size=world_size) + # Build full checkpoint identifier (include MoE info if present) full_checkpoint_name = ( f"{checkpoint_name}_{moe_info}{inference_config.build_identifier()}_SV_{eval_config.sv_model}" @@ -255,13 +386,17 @@ def run_inference_and_evaluation( # Setup CSV files per_run_csv = os.path.join(eval_dir, "all_experiment_metrics.csv") - write_csv_header_if_needed(per_run_csv, csv_header) + if rank == 0: + write_csv_header_if_needed(per_run_csv, csv_header) metrics_all_repeats = [] filewise_metrics_all_repeats = [] for repeat_idx in range(num_repeats): - logging.info(f"Repeat {repeat_idx + 1}/{num_repeats} for dataset {dataset}") + repeat_log_msg = f"Repeat {repeat_idx + 1}/{num_repeats} for dataset {dataset}" + if is_distributed: + repeat_log_msg += f", rank {rank}/{world_size}" + logging.info(repeat_log_msg) repeat_audio_dir = os.path.join(audio_dir, f"repeat_{repeat_idx}") os.makedirs(repeat_audio_dir, exist_ok=True) @@ -274,9 +409,21 @@ def run_inference_and_evaluation( f"Dataset length mismatch: {len(test_dataset)} vs {len(manifest_records)} manifest records" ) + if is_distributed and not is_multiturn_user_audio: + raise RuntimeError( + "torchrun multi-GPU sharding is currently implemented for " + "--easy_magpie_inference_mode multiturn_user_audio only. " + "Use the existing single-process path for single_turn/magpie, or add a " + "rank-safe merge path for those runners." + ) + + inference_output_dir = repeat_audio_dir + if is_distributed and is_multiturn_user_audio: + inference_output_dir = os.path.join(repeat_audio_dir, f"rank_{rank:04d}") + rtf_metrics_list, _, codec_file_paths = runner.run_inference_on_dataset( dataset=test_dataset, - output_dir=repeat_audio_dir, + output_dir=inference_output_dir, manifest_records=manifest_records, audio_base_dir=meta['audio_dir'], save_cross_attention_maps=True, @@ -293,7 +440,10 @@ def run_inference_and_evaluation( mean_rtf[f"{component_name}_{key}"] = value logging.info(f"{component_name} FLOPs per token: {component_flops['total_flops_per_token']:,}") - with open(os.path.join(eval_dir, f"{dataset}_rtf_metrics_{repeat_idx}.json"), "w") as f: + rtf_metrics_filename = f"{dataset}_rtf_metrics_{repeat_idx}.json" + if is_distributed: + rtf_metrics_filename = f"{dataset}_rtf_metrics_{repeat_idx}_rank{rank:04d}.json" + with open(os.path.join(eval_dir, rtf_metrics_filename), "w") as f: json.dump(mean_rtf, f, indent=4) if skip_evaluation: @@ -301,6 +451,26 @@ def run_inference_and_evaluation( continue # Run evaluation + if is_distributed and is_multiturn_user_audio: + if rank != 0: + # Non-zero ranks only generate. Rank 0 waits and evaluates merged outputs. + continue + + _wait_for_multiturn_rank_manifests(repeat_audio_dir, world_size) + merged_manifest_path = _merge_multiturn_rank_outputs( + repeat_audio_dir=repeat_audio_dir, + world_size=world_size, + save_predicted_codes=eval_config.with_fcd, + ) + eval_manifest_path = merged_manifest_path + eval_audio_dir = repeat_audio_dir + else: + eval_manifest_path, eval_audio_dir = _runner_eval_manifest_and_audio_dir( + runner, + default_manifest=meta['manifest_path'], + default_audio_dir=meta['audio_dir'], + ) + eval_config_for_dataset = EvaluationConfig( sv_model=eval_config.sv_model, asr_model_name=eval_config.asr_model_name, @@ -313,8 +483,8 @@ def run_inference_and_evaluation( ) metrics, filewise_metrics = evaluate_generated_audio_dir( - manifest_path=meta['manifest_path'], - audio_dir=meta['audio_dir'], + manifest_path=eval_manifest_path, + audio_dir=eval_audio_dir, generated_audio_dir=repeat_audio_dir, config=eval_config_for_dataset, ) @@ -338,8 +508,16 @@ def run_inference_and_evaluation( create_violin_plot(filewise_metrics, violin_plot_metrics, violin_path) # Delete temporary predicted codes files - for codec_file_path in codec_file_paths: - os.remove(codec_file_path) + if is_distributed and is_multiturn_user_audio: + for codec_file_path in Path(repeat_audio_dir).glob("predicted_codes_*.pt"): + if os.path.exists(codec_file_path): + os.remove(codec_file_path) + else: + for codec_file_path in codec_file_paths: + os.remove(codec_file_path) + + if rank != 0: + continue if skip_evaluation or not metrics_all_repeats: continue @@ -367,17 +545,17 @@ def run_inference_and_evaluation( cer_per_dataset.append(np.mean(cer_values)) # Create combined plot if we have multiple datasets - if len(all_datasets_filewise_metrics) > 1: + if rank == 0 and len(all_datasets_filewise_metrics) > 1: combined_plot_path = os.path.join(out_dir, f"{full_checkpoint_name}_combined_violin_plot.png") create_combined_box_plot(all_datasets_filewise_metrics, violin_plot_metrics, combined_plot_path) # Clean up if requested - if clean_up_disk: + if rank == 0 and clean_up_disk: logging.info(f"Cleaning up output directory: {out_dir}") shutil.rmtree(out_dir) # Return averaged metrics - if ssim_per_dataset and cer_per_dataset: + if rank == 0 and ssim_per_dataset and cer_per_dataset: return np.mean(cer_per_dataset), np.mean(ssim_per_dataset) return None, None @@ -580,6 +758,14 @@ def _add_magpie_args(parser: argparse.ArgumentParser) -> None: def _add_easy_magpie_args(parser: argparse.ArgumentParser) -> None: """Add arguments specific to decoder-only EasyMagpieTTSInferenceModel.""" group = parser.add_argument_group('EasyMagpieTTS-specific Parameters') + group.add_argument( + '--easy_magpie_inference_mode', + type=str, + default='single_turn', + choices=['single_turn', 'multiturn_user_audio'], + ) + group.add_argument('--max_eval_turns', type=int, default=6) + group.add_argument('--no_save_debug_multiturn_audio', action='store_true') group.add_argument( '--phoneme_input_type', type=str, @@ -591,7 +777,7 @@ def _add_easy_magpie_args(parser: argparse.ArgumentParser) -> None: '--phoneme_sampling_method', type=str, default='argmax', - choices=['argmax', 'multinomial'], + choices=['argmax', 'multinomial', 'greedy'], help='Sampling method for phoneme prediction', ) group.add_argument('--dropout_text_input', action='store_true', help='Force dropout on text input') @@ -649,7 +835,12 @@ def _build_magpie_config(args) -> MagpieInferenceConfig: def _build_easy_magpie_config(args) -> EasyMagpieInferenceConfig: - return EasyMagpieInferenceConfig( + cfg_cls = ( + EasyMagpieMultiturnUserAudioInferenceConfig + if args.easy_magpie_inference_mode == 'multiturn_user_audio' + else EasyMagpieInferenceConfig + ) + kwargs = dict( model_inference_parameters=_build_inference_params_from_args(EasyModelInferenceParameters, args), batch_size=args.batch_size, use_cfg=args.use_cfg, @@ -658,12 +849,33 @@ def _build_easy_magpie_config(args) -> EasyMagpieInferenceConfig: phoneme_sampling_method=args.phoneme_sampling_method, dropout_text_input=args.dropout_text_input, ) + if cfg_cls is EasyMagpieMultiturnUserAudioInferenceConfig: + kwargs.update( + max_eval_turns=args.max_eval_turns, + save_debug_multiturn_audio=not args.no_save_debug_multiturn_audio, + ) + return cfg_cls(**kwargs) + + +def _select_runner_cls(args): + if args.model_type == 'magpie': + if args.easy_magpie_inference_mode != 'single_turn': + raise ValueError('--easy_magpie_inference_mode is only supported with --model_type easy_magpie') + return MagpieInferenceRunner + if args.easy_magpie_inference_mode == 'multiturn_user_audio': + return EasyMagpieMultiturnUserAudioInferenceRunner + return EasyMagpieInferenceRunner def main(argv=None): """Entry point for TTS inference and evaluation.""" parser = create_argument_parser() args = parser.parse_args(argv) + if args.model_type == 'easy_magpie' and args.easy_magpie_inference_mode == 'multiturn_user_audio': + _configure_cuda_for_rank() + if args.batch_size > 1: + parser.error("--easy_magpie_inference_mode multiturn_user_audio requires --batch_size 1.") + if args.deterministic: seed_all(seed=9) @@ -689,7 +901,7 @@ def main(argv=None): is_easy_magpie = args.model_type == 'easy_magpie' load_fn = load_easy_magpie_model if is_easy_magpie else load_magpie_model inference_config = _build_easy_magpie_config(args) if is_easy_magpie else _build_magpie_config(args) - runner_cls = EasyMagpieInferenceRunner if is_easy_magpie else MagpieInferenceRunner + runner_cls = _select_runner_cls(args) eval_config = EvaluationConfig( sv_model=args.sv_model, @@ -807,4 +1019,4 @@ def main(argv=None): if __name__ == '__main__': - main() + main() \ No newline at end of file diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index 31f9a2105877..e4ddebaf6faf 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -726,7 +726,8 @@ def __init__( self.normalize_audio = normalize_audio self.records = read_manifest(manifest_path) # debug - self.records = self.records[:7] + # self.records = self.records[:7] + def __len__(self): return len(self.records) From 585decbbbf20a47614f37343d3e32e31b2139bcf Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 8 Jun 2026 15:48:26 -0700 Subject: [PATCH 075/109] Fix new inference volume norm Signed-off-by: Edresson Casanova --- .../modules/magpietts_inference/inference.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index e4ddebaf6faf..680bcee9fb0a 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -749,20 +749,17 @@ def _load_audio_1d(self, path: str, sample_rate: int) -> torch.Tensor: raise FileNotFoundError(f"Missing audio path: {path}") audio, sr = sf.read(path, dtype="float32", always_2d=False) - wav = torch.as_tensor(audio, dtype=torch.float32) - if wav.ndim == 2: - wav = wav.mean(dim=1) - wav = wav.flatten() - if sr != sample_rate: - wav = resample(wav.unsqueeze(0), sr, sample_rate).squeeze(0) + if audio.ndim == 2: + audio = audio.mean(axis=1) if self.normalize_audio: - try: - wav = normalize_volume(wav) - except Exception: - # Keep evaluation robust across normalize_volume signature changes. - pass + audio = normalize_volume(audio) + + wav = torch.as_tensor(audio, dtype=torch.float32).flatten() + + if sr != sample_rate: + wav = resample(wav.unsqueeze(0), sr, sample_rate).squeeze(0) return wav.contiguous() @@ -1151,7 +1148,8 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) use_lang = bool(getattr(model, "add_language_to_context_text", False)) - ctx_text = "[EN]" if use_lang else "[NO TEXT CONTEXT]" + language = getattr(self.config, "language", "en") + ctx_text = f"[{language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) @@ -1226,7 +1224,7 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): ) user_audio_embedded = model.embed_audio_tokens(user_audio_codes) - boundary_trim = model.cfg.get("user_audio_boundary_trim", 4) + boundary_trim = model.cfg.get("user_audio_boundary_trim", 0) boundary_trim = 0 if boundary_trim is None else int(boundary_trim) if boundary_trim == 0: From 032e8df8e89fbcce250caca6f81b9f73961807a8 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 9 Jun 2026 04:54:07 -0700 Subject: [PATCH 076/109] Expose ASR and EOU batch sizes on config Signed-off-by: Edresson Casanova --- examples/tts/magpietts_inference.py | 6 ++++++ .../tts/modules/magpietts_inference/evaluation.py | 4 ++++ .../tts/modules/magpietts_inference/inference.py | 2 +- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index a8763010e5e7..b49adc18a883 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -480,6 +480,8 @@ def run_inference_and_evaluation( with_fcd=eval_config.with_fcd, codec_model_path=eval_config.codec_model_path, device=eval_config.device, + asr_batch_size=eval_config.asr_batch_size, + eou_batch_size=eval_config.eou_batch_size, ) metrics, filewise_metrics = evaluate_generated_audio_dir( @@ -719,6 +721,8 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None: default=['cer', 'pred_context_ssim', 'utmosv2'], ) eval_group.add_argument('--disable_fcd', action='store_true') + eval_group.add_argument("--asr_batch_size", type=int, default=32) + eval_group.add_argument("--eou_batch_size", type=int, default=32) # Quality targets target_group = parser.add_argument_group('Quality Targets') @@ -910,6 +914,8 @@ def main(argv=None): with_utmosv2=not args.disable_utmosv2, with_fcd=not args.disable_fcd, codec_model_path=args.codecmodel_path if not args.disable_fcd else None, + asr_batch_size=args.asr_batch_size, + eou_batch_size=args.eou_batch_size, ) cer, ssim = None, None diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluation.py b/nemo/collections/tts/modules/magpietts_inference/evaluation.py index 00b25614a9ab..bb9013cc9ff1 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluation.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluation.py @@ -54,6 +54,8 @@ class EvaluationConfig: with_fcd: bool = True codec_model_path: str = None device: str = "cuda" + asr_batch_size: int = 32 + eou_batch_size: int = 32 def evaluate_generated_audio_dir( @@ -95,6 +97,8 @@ def evaluate_generated_audio_dir( codec_model_path=config.codec_model_path, device=config.device, eou_model_name=config.eou_model_name, + asr_batch_size=config.asr_batch_size, + eou_batch_size=config.eou_batch_size, ) return avg_metrics, filewise_metrics diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index 680bcee9fb0a..c85c6b6f235d 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -726,7 +726,7 @@ def __init__( self.normalize_audio = normalize_audio self.records = read_manifest(manifest_path) # debug - # self.records = self.records[:7] + self.records = self.records[:7] def __len__(self): return len(self.records) From 79a767028d30fc00cf866c21b12c99caae835bd4 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 9 Jun 2026 06:09:06 -0700 Subject: [PATCH 077/109] Get language from dataloader for multiturn eval Signed-off-by: Edresson Casanova --- examples/tts/magpietts_inference.py | 227 ++++++++++++++++++ .../modules/magpietts_inference/inference.py | 213 ++++++++++++++-- 2 files changed, 414 insertions(+), 26 deletions(-) diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index b49adc18a883..033fe0efd8e2 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -144,6 +144,224 @@ def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, met logging.info(f"Metrics appended to: {csv_path}") +def _mean_finite(values: list): + vals = [] + for value in values: + try: + value = float(value) + except Exception: + continue + if np.isfinite(value): + vals.append(value) + return None if not vals else float(np.mean(vals)) + + +def _enrich_filewise_metrics_with_manifest(filewise_metrics: list, manifest_path: str) -> list: + """Attach multiturn manifest metadata back to evaluator filewise rows. + + evaluate_generated_audio_dir() returns one filewise row per generated turn, + but the filtered row does not preserve source_sample_idx/turn_id. The + generated multiturn manifest has the same order as predicted_audio_*.wav, so + we merge by list index before grouping. + """ + if manifest_path is None or not os.path.exists(manifest_path): + logging.warning(f"Could not enrich multiturn filewise metrics; manifest missing: {manifest_path}") + return filewise_metrics + + manifest_records = read_manifest(manifest_path) + if len(manifest_records) != len(filewise_metrics): + logging.warning( + "Could not safely enrich multiturn filewise metrics; " + f"manifest rows={len(manifest_records)} filewise rows={len(filewise_metrics)} " + f"manifest_path={manifest_path}" + ) + return filewise_metrics + + enriched = [] + for row, record in zip(filewise_metrics, manifest_records): + new_row = dict(row) + for key in [ + "source_sample_idx", + "turn_id", + "speaker", + "rank", + "rank_local_idx", + "audio_filepath", + "context_audio_filepath", + ]: + if key in record and key not in new_row: + new_row[key] = record[key] + enriched.append(new_row) + + return enriched + + +def _group_multiturn_filewise_metrics_by_sample(filewise_metrics: list) -> list: + """Group turn-level multiturn metrics into one old-style row per sample. + + Each grouped row keeps turn-by-turn CER/WER/SSIM/UTMOS/text/audio lists plus + averaged sample-level values. Rows are sorted by averaged CER descending so + the worst conversations/samples appear first. + """ + grouped = {} + + for row_idx, row in enumerate(filewise_metrics): + source_sample_idx = row.get("source_sample_idx", None) + if source_sample_idx is None: + source_sample_idx = row.get("speaker", None) + if source_sample_idx is None: + source_sample_idx = row_idx + + key = str(source_sample_idx) + if key not in grouped: + grouped[key] = { + "source_sample_idx": source_sample_idx, + "speaker": row.get("speaker", source_sample_idx), + "rank": row.get("rank", None), + "target_audio_path": row.get("gt_audio_filepath", row.get("audio_filepath", "")), + "context_audio_path": row.get("context_audio_filepath", ""), + "turn_rows": [], + } + grouped[key]["turn_rows"].append(row) + + grouped_rows = [] + for _, group in grouped.items(): + turns = group["turn_rows"] + + def turn_sort_key(r): + try: + return int(r.get("turn_id", 0)) + except Exception: + return 0 + + turns = sorted(turns, key=turn_sort_key) + + cer_turns = [r.get("cer") for r in turns] + wer_turns = [r.get("wer") for r in turns] + pred_context_ssim_turns = [r.get("pred_context_ssim") for r in turns] + pred_gt_ssim_turns = [r.get("pred_gt_ssim") for r in turns] + gt_context_ssim_turns = [r.get("gt_context_ssim") for r in turns] + utmosv2_turns = [r.get("utmosv2") for r in turns] + eou_type_turns = [r.get("eou_type") for r in turns] + eou_trailing_duration_turns = [r.get("eou_trailing_duration") for r in turns] + eou_trail_rms_ratio_turns = [r.get("eou_trail_rms_ratio") for r in turns] + + grouped_rows.append( + { + "source_sample_idx": group["source_sample_idx"], + "speaker": group["speaker"], + "rank": group["rank"], + "num_turns": len(turns), + + # Sample-level averages over all turns. + "cer": _mean_finite(cer_turns), + "wer": _mean_finite(wer_turns), + "pred_context_ssim": _mean_finite(pred_context_ssim_turns), + "pred_gt_ssim": _mean_finite(pred_gt_ssim_turns), + "gt_context_ssim": _mean_finite(gt_context_ssim_turns), + "utmosv2": _mean_finite(utmosv2_turns), + "eou_trailing_duration": _mean_finite(eou_trailing_duration_turns), + "eou_trail_rms_ratio": _mean_finite(eou_trail_rms_ratio_turns), + + # Turn-by-turn values, old-script style. + "turn_ids": [r.get("turn_id", i) for i, r in enumerate(turns)], + "cer_turns": cer_turns, + "wer_turns": wer_turns, + "pred_context_ssim_turns": pred_context_ssim_turns, + "pred_gt_ssim_turns": pred_gt_ssim_turns, + "gt_context_ssim_turns": gt_context_ssim_turns, + "utmosv2_turns": utmosv2_turns, + "eou_type_turns": eou_type_turns, + "eou_trailing_duration_turns": eou_trailing_duration_turns, + "eou_trail_rms_ratio_turns": eou_trail_rms_ratio_turns, + "reference_text": [r.get("gt_text", "") for r in turns], + "asr_hyp": [r.get("pred_text", "") for r in turns], + "pred_audio_paths": [r.get("pred_audio_filepath", "") for r in turns], + "target_audio_path": group["target_audio_path"], + "context_audio_path": group["context_audio_path"], + "turn_metrics": turns, + } + ) + + grouped_rows.sort( + key=lambda r: ( + r.get("cer") is not None, + float(r["cer"]) if r.get("cer") is not None else -1.0, + ), + reverse=True, + ) + return grouped_rows + + +def _write_grouped_multiturn_filewise_metrics_csv(csv_path: str, grouped_rows: list) -> None: + fieldnames = [ + "source_sample_idx", + "speaker", + "rank", + "num_turns", + "cer", + "wer", + "pred_context_ssim", + "pred_gt_ssim", + "gt_context_ssim", + "utmosv2", + "eou_trailing_duration", + "eou_trail_rms_ratio", + "turn_ids", + "cer_turns", + "wer_turns", + "pred_context_ssim_turns", + "pred_gt_ssim_turns", + "gt_context_ssim_turns", + "utmosv2_turns", + "eou_type_turns", + "eou_trailing_duration_turns", + "eou_trail_rms_ratio_turns", + "target_audio_path", + "context_audio_path", + "pred_audio_paths", + "reference_text", + "asr_hyp", + ] + + def csv_value(value): + if isinstance(value, (list, dict)): + value = json.dumps(value, ensure_ascii=False) + if value is None: + value = "" + value = str(value).replace('"', '""') + if "," in value or "\n" in value or "[" in value or "{" in value: + value = f'"{value}"' + return value + + with open(csv_path, "w", encoding="utf-8") as f: + f.write(",".join(fieldnames) + "\n") + for row in grouped_rows: + f.write(",".join(csv_value(row.get(k, "")) for k in fieldnames) + "\n") + + +def _save_grouped_multiturn_filewise_metrics( + eval_dir: str, + dataset: str, + repeat_idx: int, + filewise_metrics: list, + manifest_path: str, +) -> None: + enriched_filewise = _enrich_filewise_metrics_with_manifest(filewise_metrics, manifest_path) + grouped_rows = _group_multiturn_filewise_metrics_by_sample(enriched_filewise) + + json_path = os.path.join(eval_dir, f"{dataset}_grouped_filewise_metrics_{repeat_idx}.json") + csv_path = os.path.join(eval_dir, f"{dataset}_grouped_filewise_metrics_{repeat_idx}.csv") + + with open(json_path, "w", encoding="utf-8") as f: + json.dump(grouped_rows, f, indent=4, ensure_ascii=False) + + _write_grouped_multiturn_filewise_metrics_csv(csv_path, grouped_rows) + + logging.info(f"Saved grouped multiturn filewise metrics JSON to: {json_path}") + logging.info(f"Saved grouped multiturn filewise metrics CSV to: {csv_path}") + + def create_formatted_metrics_mean_ci(metrics_mean_ci: dict) -> dict: """Create formatted metrics mean CI.""" for k, v in metrics_mean_ci.items(): @@ -502,6 +720,15 @@ def run_inference_and_evaluation( with open(os.path.join(eval_dir, f"{dataset}_filewise_metrics_{repeat_idx}.json"), "w") as f: json.dump(sorted_filewise, f, indent=4) + if is_multiturn_user_audio: + _save_grouped_multiturn_filewise_metrics( + eval_dir=eval_dir, + dataset=dataset, + repeat_idx=repeat_idx, + filewise_metrics=filewise_metrics, + manifest_path=eval_manifest_path, + ) + # Append to per-run CSV append_metrics_to_csv(per_run_csv, full_checkpoint_name, dataset, metrics) diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index c85c6b6f235d..947579471dbc 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -725,8 +725,6 @@ def __init__( self.max_eval_turns = max_eval_turns self.normalize_audio = normalize_audio self.records = read_manifest(manifest_path) - # debug - self.records = self.records[:7] def __len__(self): return len(self.records) @@ -749,18 +747,21 @@ def _load_audio_1d(self, path: str, sample_rate: int) -> torch.Tensor: raise FileNotFoundError(f"Missing audio path: {path}") audio, sr = sf.read(path, dtype="float32", always_2d=False) - - if audio.ndim == 2: - audio = audio.mean(axis=1) - - if self.normalize_audio: - audio = normalize_volume(audio) - - wav = torch.as_tensor(audio, dtype=torch.float32).flatten() + wav = torch.as_tensor(audio, dtype=torch.float32) + if wav.ndim == 2: + wav = wav.mean(dim=1) + wav = wav.flatten() if sr != sample_rate: wav = resample(wav.unsqueeze(0), sr, sample_rate).squeeze(0) + if self.normalize_audio: + try: + wav = normalize_volume(wav) + except Exception: + # Keep evaluation robust across normalize_volume signature changes. + pass + return wav.contiguous() @staticmethod @@ -825,6 +826,7 @@ def collate_fn(self, batch: List[dict]) -> Dict[str, Any]: "user_audio_turns_lens": user_audio_turns_lens, "target_audio_path": sample.get("audio_filepath"), "target_turn_audio_paths": target_turn_audio_paths, + "languages": [sample.get("language", "en")], } @@ -1147,12 +1149,42 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): wav_len = batch["context_audio_lengths"] codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) + # add language on context if needed use_lang = bool(getattr(model, "add_language_to_context_text", False)) - language = getattr(self.config, "language", "en") - ctx_text = f"[{language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" - ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) - ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) - ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) + + languages = batch.get("languages", None) + if languages is None: + raise RuntimeError("Missing batch['languages']; collate_fn must provide one language per sample.") + + if not isinstance(languages, (list, tuple)): + raise RuntimeError(f"Expected batch['languages'] to be a list/tuple, got {type(languages)}") + + if len(languages) != B: + raise RuntimeError(f"Expected {B} language entries from collate_fn, got {len(languages)}") + + ctx_texts = [] + for b, language in enumerate(languages): + if language is None or language == "": + raise RuntimeError(f"Missing language for batch item {b}") + ctx_texts.append(f"[{str(language).upper()}]" if use_lang else "[NO TEXT CONTEXT]") + + ctx_text_ids_list = [ + model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) + for ctx_text in ctx_texts + ] + + ctx_toks_lens = torch.tensor( + [len(ctx_text_ids) for ctx_text_ids in ctx_text_ids_list], + dtype=torch.long, + device=device, + ) + + max_ctx_len = int(ctx_toks_lens.max().item()) + ctx_pad_id = int(getattr(model, "pad_id", 0)) + ctx_toks = torch.full((B, max_ctx_len), ctx_pad_id, dtype=torch.long, device=device) + + for b, ctx_text_ids in enumerate(ctx_text_ids_list): + ctx_toks[b, : len(ctx_text_ids)] = torch.tensor(ctx_text_ids, dtype=torch.long, device=device) params = self.config.model_inference_parameters state = model.streaming_init( @@ -1224,7 +1256,7 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): ) user_audio_embedded = model.embed_audio_tokens(user_audio_codes) - boundary_trim = model.cfg.get("user_audio_boundary_trim", 0) + boundary_trim = model.cfg.get("user_audio_boundary_trim", 4) boundary_trim = 0 if boundary_trim is None else int(boundary_trim) if boundary_trim == 0: @@ -1418,6 +1450,69 @@ def _resolve_target_audio_for_turn( ) return None + @staticmethod + def _get_multiturn_debug_output_dir(output_dir: str) -> str: + """Return a sibling audios_MT path for debug/listening artifacts. + + Examples: + /audio/repeat_0 + -> /audios_MT/repeat_0 + /audio/repeat_0/rank_0003 + -> /audios_MT/repeat_0/rank_0003 + + Evaluation files stay in the standard audio/ directory; only + human-listening/debug multiturn files go under audios_MT/. + """ + normalized = os.path.normpath(output_dir) + parts = normalized.split(os.sep) + for i in range(len(parts) - 1, -1, -1): + if parts[i] == "audio": + parts[i] = "audios_MT" + prefix = os.sep if normalized.startswith(os.sep) else "" + return prefix + os.path.join(*[p for p in parts if p != ""]) + return normalized + "_MT" + + @staticmethod + def _safe_audio_stem(path_or_name: Optional[str], fallback: str) -> str: + if path_or_name is None or path_or_name == "": + stem = fallback + else: + stem = os.path.splitext(os.path.basename(str(path_or_name)))[0] or fallback + safe = "".join(ch if ch.isalnum() or ch in ("-", "_", ".") else "_" for ch in stem) + return safe or fallback + + def _target_audio_stem_for_debug( + self, + raw_record: dict, + sample_idx: int, + local_turn_idx: Optional[int] = None, + ) -> str: + """Choose a readable debug/listening filename stem from original manifest target audio. + + This does not affect evaluator file names. It is only used under audios_MT/. + Prefer per-turn target fields when present; otherwise use sample-level + audio_filepath. Fall back to sample_. + """ + fallback = f"sample_{sample_idx}" + candidates = [ + raw_record.get("target_audio_file_path"), + raw_record.get("target_audio_filepath"), + raw_record.get("audio_filepath"), + ] + for candidate in candidates: + if candidate is None or candidate == "": + continue + if isinstance(candidate, list): + if local_turn_idx is None: + candidate = candidate[0] if candidate else None + elif local_turn_idx < len(candidate): + candidate = candidate[local_turn_idx] + else: + candidate = None + if candidate: + return self._safe_audio_stem(candidate, fallback) + return fallback + def _run_multiturn_user_audio_inference( self, dataset: EasyMagpieMultiturnUserAudioDataset, @@ -1430,22 +1525,37 @@ def _run_multiturn_user_audio_inference( os.makedirs(output_dir, exist_ok=True) self._delete_old_generated_files(output_dir) - debug_user_dir = os.path.join(output_dir, "debug_user_turns") - debug_mixed_dir = os.path.join(output_dir, "debug_mixed_user_agent") + mt_debug_output_dir = self._get_multiturn_debug_output_dir(output_dir) + debug_user_dir = os.path.join(mt_debug_output_dir, "debug_user_turns") + debug_mixed_dir = os.path.join(mt_debug_output_dir, "debug_mixed_user_agent") + debug_full_agent_dir = os.path.join(mt_debug_output_dir, "debug_full_agent") if self.config.save_debug_multiturn_audio: os.makedirs(debug_user_dir, exist_ok=True) os.makedirs(debug_mixed_dir, exist_ok=True) + os.makedirs(debug_full_agent_dir, exist_ok=True) + logging.info( + f"Saving multiturn debug/listening audios under audios_MT: {mt_debug_output_dir}" + ) rank = int(getattr(self, "distributed_rank", 0)) world_size = int(getattr(self, "distributed_world_size", 1)) + total_samples = len(dataset) + if world_size > 1: - rank_indices = list(range(rank, len(dataset), world_size)) + rank_indices = list(range(rank, total_samples, world_size)) logging.info( - f"multiturn_user_audio distributed sharding: rank={rank}/{world_size}, " - f"local_samples={len(rank_indices)}, total_samples={len(dataset)}" + f"[MT_USER_AUDIO_SPLIT] rank={rank}/{world_size} " + f"total_samples={total_samples} local_samples={len(rank_indices)} " + f"indices={rank_indices}" ) dataset_for_rank = _InferenceSubset(dataset, rank_indices) else: + rank_indices = list(range(total_samples)) + logging.info( + f"[MT_USER_AUDIO_SPLIT] rank=0/1 " + f"total_samples={total_samples} local_samples={len(rank_indices)} " + f"indices={rank_indices}" + ) dataset_for_rank = dataset dataloader = torch.utils.data.DataLoader( @@ -1470,6 +1580,11 @@ def _run_multiturn_user_audio_inference( sample_idx = int(batch["idx"][0].item()) raw_record = batch["raw_record"] raw_turn_texts = batch["raw_turn_texts"][0] + logging.info( + f"[MT_USER_AUDIO_PROCESS] rank={rank}/{world_size} " + f"local_batch={batch_idx} global_sample_idx={sample_idx} " + f"num_turns={len(raw_turn_texts)}" + ) start_time = time.time() output, turn_frame_ranges, decode_start_frame, generated_codes = self._run_multiturn_generation(batch) @@ -1556,15 +1671,29 @@ def _run_multiturn_user_audio_inference( "turn_id": int(turn_id), } ) + logging.info( + f"[MT_USER_AUDIO_TURN] rank={rank}/{world_size} " + f"global_sample_idx={sample_idx} turn_id={int(turn_id)} " + f"rank_local_item_idx={item_idx} " + f"predicted_audio=predicted_audio_{item_idx}.wav " + f"target_audio=target_audio_{item_idx}.wav " + f"context_audio=context_audio_{item_idx}.wav" + ) item_idx += 1 - full_agent_path = os.path.join(output_dir, f"predicted_audio_sample_{sample_idx}_full_agent.wav") - sf.write(full_agent_path, aligned_agent.numpy(), sample_rate) + if self.config.save_debug_multiturn_audio: + debug_sample_stem = self._target_audio_stem_for_debug(raw_record, sample_idx) + full_agent_path = os.path.join( + debug_full_agent_dir, + f"{debug_sample_stem}__sample_{sample_idx}__predicted_full_agent.wav", + ) + sf.write(full_agent_path, aligned_agent.numpy(), sample_rate) if self.config.save_debug_multiturn_audio and "user_audio_turns" in batch: self._save_debug_user_agent_audio( batch=batch, sample_idx=sample_idx, + raw_record=raw_record, turn_frame_ranges=turn_frame_ranges, decode_start_frame=decode_start_frame, aligned_agent=aligned_agent, @@ -1593,6 +1722,14 @@ def _run_multiturn_user_audio_inference( for record in turn_manifest_records: f.write(json.dumps(record, ensure_ascii=False) + "\n") + logging.info( + f"[MT_USER_AUDIO_SUMMARY] rank={rank}/{world_size} " + f"assigned_samples={rank_indices} " + f"generated_turns={len(turn_manifest_records)} " + f"generated_audio_paths={len(generated_audio_paths)} " + f"predicted_code_paths={len(codec_file_paths)} " + f"manifest={self.evaluation_manifest_path}" + ) logging.info(f"Wrote multiturn turn-level evaluation manifest: {self.evaluation_manifest_path}") return all_rtf_metrics, generated_audio_paths, codec_file_paths @@ -1600,6 +1737,7 @@ def _save_debug_user_agent_audio( self, batch: Dict[str, Any], sample_idx: int, + raw_record: dict, turn_frame_ranges: List[Tuple[int, int, int]], decode_start_frame: int, aligned_agent: torch.Tensor, @@ -1609,6 +1747,7 @@ def _save_debug_user_agent_audio( debug_mixed_dir: str, ) -> None: sample_rate = getattr(self.model, "output_sample_rate", self.model.sample_rate) + debug_sample_stem = self._target_audio_stem_for_debug(raw_record, sample_idx) first_user_len_in = int(batch["user_audio_turns_lens"][0][0].detach().cpu().item()) first_user_delay_out = int(round(first_user_len_in * sample_rate / self.model.sample_rate)) @@ -1621,7 +1760,15 @@ def _save_debug_user_agent_audio( turn_audio = turn_audio[:turn_audio_len] turn_audio_out = resample(turn_audio.unsqueeze(0), self.model.sample_rate, sample_rate).squeeze(0) - user_turn_path = os.path.join(debug_user_dir, f"sample_{sample_idx}_user_turn_{turn_id}.wav") + debug_turn_stem = self._target_audio_stem_for_debug( + raw_record, + sample_idx, + local_turn_idx=int(turn_id), + ) + user_turn_path = os.path.join( + debug_user_dir, + f"{debug_turn_stem}__sample_{sample_idx}__turn_{turn_id}__user.wav", + ) sf.write(user_turn_path, turn_audio_out.numpy(), sample_rate) if turn_id == 0: @@ -1647,6 +1794,20 @@ def _save_debug_user_agent_audio( agent_pad[: agent_ch.numel()] = agent_ch mono_mix = torch.clamp(user_pad + agent_pad, min=-1.0, max=1.0) - sf.write(os.path.join(debug_mixed_dir, f"sample_{sample_idx}_user_agent_mixed_mono.wav"), mono_mix.numpy(), sample_rate) + sf.write( + os.path.join( + debug_mixed_dir, + f"{debug_sample_stem}__sample_{sample_idx}__user_agent_mixed_mono.wav", + ), + mono_mix.numpy(), + sample_rate, + ) stereo = torch.stack([user_pad, agent_pad], dim=1).numpy() - sf.write(os.path.join(debug_mixed_dir, f"sample_{sample_idx}_user_agent_aligned_stereo.wav"), stereo, sample_rate) + sf.write( + os.path.join( + debug_mixed_dir, + f"{debug_sample_stem}__sample_{sample_idx}__user_agent_aligned_stereo.wav", + ), + stereo, + sample_rate, + ) From 2aa37fef334d2c019ab1cd5f022e3c4e05b7df30 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 9 Jun 2026 06:10:06 -0700 Subject: [PATCH 078/109] Remove old multiturn eval scripts Signed-off-by: Edresson Casanova --- .../tts/easy_magpietts_inference_multiturn.py | 3240 ----------------- ..._magpietts_inference_multiturn_multigpu.py | 1992 ---------- ...nference_multiturn_multigpu_turn_metric.py | 2247 ------------ ...sy_magpietts_inference_multiturn_runner.py | 750 ---- 4 files changed, 8229 deletions(-) delete mode 100644 examples/tts/easy_magpietts_inference_multiturn.py delete mode 100644 examples/tts/easy_magpietts_inference_multiturn_multigpu.py delete mode 100644 examples/tts/easy_magpietts_inference_multiturn_multigpu_turn_metric.py delete mode 100644 examples/tts/easy_magpietts_inference_multiturn_runner.py diff --git a/examples/tts/easy_magpietts_inference_multiturn.py b/examples/tts/easy_magpietts_inference_multiturn.py deleted file mode 100644 index f9c54f1c5393..000000000000 --- a/examples/tts/easy_magpietts_inference_multiturn.py +++ /dev/null @@ -1,3240 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -""" -Multi-GPU EasyMagpieTTS / NemotronTTS multiturn inference evaluation. - -Key behavior: - - Uses torchrun env vars RANK, LOCAL_RANK, WORLD_SIZE for sharding/GPU assignment. - - Does NOT initialize torch.distributed. This avoids NeMo ASR doing distributed - collectives during metric computation. - - Generation runs first for all assigned samples. - - ASR and speaker-similarity models are loaded only after generation is done and the TTS/codec model - has been deleted from GPU memory. - - ASR and speaker-similarity models are loaded sequentially: ASR first, then released; speaker-similarity second. - - Supports multiturn-user-audio and regular single-turn inference; metrics are turn/file based. - Final filewise outputs are grouped back to one row per original sample, with - lists for asr_hyp/reference_text/cer_turns/wer_turns/ssim_turns. - - Uses DistributedSampler with explicit rank/world_size. A few repeated samples - may appear when len(dataset) is not divisible by world_size. Filewise final - metrics deduplicate sampler-padding repeats by (run_id, dataset_index, - turn_id), then group turns into one row per sample with metric lists, while - preserving --num_eval_runs repetitions. - - --sort_by_text_token_count sorts samples by total text-token count before - sharding to improve GPU load balance. - - Saves audio in out_dir/audios/. - - Saves metrics in out_dir/. - -Recommended single-node torchrun: - CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ - torchrun --standalone --nproc_per_node=8 easy_magpietts_inference_multiturn_multigpu_postgen_metrics.py ... - -Recommended single-node srun wrapper: - srun --nodes=1 --ntasks=1 --ntasks-per-node=1 --container-image=... \ - bash -lc 'torchrun --standalone --nproc_per_node=8 easy_magpietts_inference_multiturn_multigpu_postgen_metrics.py ...' -""" - -import argparse -import csv -import json -import math -import os -import socket -import shutil -import time -from collections import Counter -from copy import deepcopy -from functools import partial -from typing import Any, Dict, Iterable, List, Tuple - -import librosa -import soundfile as sf -import torch -from omegaconf import open_dict -from torch.nn.utils.rnn import pad_sequence -from torch.utils.data import DataLoader, Dataset, DistributedSampler, SequentialSampler - -from nemo.collections.audio.parts.utils.transforms import resample -from nemo.collections.asr.metrics.wer import word_error_rate, word_error_rate_detail -from nemo.collections.speechlm2.parts.precision import fp32_precision -from nemo.collections.tts.models import AudioCodecModel -from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel -from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter -from nemo.collections.tts.modules.magpietts_modules import CodecHelper -from nemo.collections.tts.parts.utils.tts_dataset_utils import normalize_volume -from nemo.utils import logging -from whisper_normalizer.english import EnglishTextNormalizer - -try: - import nemo.collections.asr as nemo_asr -except Exception: - nemo_asr = None - -try: - from nemo.collections.asr.models import ASRModel -except Exception: - ASRModel = None - -try: - from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector -except Exception: - Wav2Vec2FeatureExtractor = None - WavLMForXVector = None - -try: - from nemo.collections.tts.modules.magpietts_inference.evaluate_generated_audio import ( - compute_utmosv2_scores, - extract_embedding, - ) -except Exception: - compute_utmosv2_scores = None - extract_embedding = None - -try: - from nemo.collections.tts.metrics.eou_classifier import EoUClassifier, EoUType -except Exception: - EoUClassifier = None - EoUType = None - -try: - from nemo.collections.tts.modules.magpietts_inference.evaluation import DEFAULT_VIOLIN_METRICS -except Exception: - DEFAULT_VIOLIN_METRICS = ['cer', 'pred_context_ssim', 'utmosv2'] - -try: - from nemo.collections.tts.modules.magpietts_inference.visualization import create_violin_plot -except Exception: - create_violin_plot = None - -try: - from nemo.collections.tts.metrics.frechet_codec_distance import FrechetCodecDistance -except Exception: - FrechetCodecDistance = None - - - -torch.set_float32_matmul_precision("medium") -torch.backends.cudnn.allow_tf32 = True -torch.backends.cuda.matmul.allow_tf32 = True - - -# ----------------------------- -# Rank / file helpers -# ----------------------------- - - -def get_rank_info() -> Tuple[bool, int, int, int]: - world_size = int(os.environ.get("WORLD_SIZE", "1")) - rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", "0"))) - local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))) - distributed = world_size > 1 - return distributed, rank, local_rank, world_size - - -def get_visible_device_index(local_rank: int) -> int: - if not torch.cuda.is_available(): - return -1 - ndev = torch.cuda.device_count() - if ndev <= 0: - return -1 - return local_rank % ndev - - -def setup_distributed(): - """ - Do not initialize torch.distributed. - - We only need RANK/LOCAL_RANK/WORLD_SIZE for rank assignment and dataset - sharding. Initializing a process group can cause NeMo ASR to run distributed - collectives during transcribe(), which may hang when ranks have different - audio lengths or workloads. - """ - distributed, rank, local_rank, world_size = get_rank_info() - device_index = get_visible_device_index(local_rank) - - if torch.cuda.is_available() and device_index >= 0: - torch.cuda.set_device(device_index) - - return distributed, rank, local_rank, world_size, device_index - - -def cleanup_distributed(): - return - - -def all_rank_print(rank: int, msg: str): - print(f"[rank={rank}] {msg}", flush=True) - - -def rank0_print(rank: int, msg: str): - if rank == 0: - print(msg, flush=True) - - -def get_audio_out_dir(args) -> str: - return os.path.join(args.out_dir, "audios") - - -def get_generated_turn_audio_dir(args) -> str: - return os.path.join(get_audio_out_dir(args), "metric_turns") - - -def get_context_metric_audio_dir(args) -> str: - return os.path.join(get_audio_out_dir(args), "metric_context") - - -def get_predicted_codes_dir(args) -> str: - return os.path.join(get_audio_out_dir(args), "predicted_codes") - - -def write_json(path: str, obj: Dict[str, Any]): - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - with open(tmp_path, "w", encoding="utf-8") as f: - json.dump(obj, f, indent=2, sort_keys=True, ensure_ascii=False) - os.replace(tmp_path, path) - - -def write_text_atomic(path: str, text: str): - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - with open(tmp_path, "w", encoding="utf-8") as f: - f.write(text) - os.replace(tmp_path, path) - - -def write_jsonl(path: str, rows: List[Dict[str, Any]]): - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - with open(tmp_path, "w", encoding="utf-8") as f: - for row in rows: - f.write(json.dumps(row, sort_keys=True, ensure_ascii=False) + "\n") - os.replace(tmp_path, path) - -def write_csv_header_if_needed(csv_path: str, header: str) -> None: - os.makedirs(os.path.dirname(csv_path), exist_ok=True) - if not os.path.exists(csv_path): - with open(csv_path, "w", encoding="utf-8") as f: - f.write(header + "\n") - - -def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, metrics: Dict[str, Any]) -> None: - """Append metrics using the same column order as MagpieTTS inference/eval.""" - csv_header = ( - "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative," - "wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg," - "ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate," - "ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative," - "utmosv2_avg,total_gen_audio_seconds,frechet_codec_distance," - "eou_cutoff_rate,eou_silence_rate,eou_noise_rate,eou_error_rate" - ) - write_csv_header_if_needed(csv_path, csv_header) - - values = [ - checkpoint_name, - dataset, - metrics.get("cer_filewise_avg", ""), - metrics.get("wer_filewise_avg", ""), - metrics.get("cer_cumulative", ""), - metrics.get("wer_cumulative", ""), - metrics.get("ssim_pred_gt_avg", ""), - metrics.get("ssim_pred_context_avg", ""), - metrics.get("ssim_gt_context_avg", ""), - metrics.get("ssim_pred_gt_avg_alternate", ""), - metrics.get("ssim_pred_context_avg_alternate", ""), - metrics.get("ssim_gt_context_avg_alternate", ""), - metrics.get("cer_gt_audio_cumulative", ""), - metrics.get("wer_gt_audio_cumulative", ""), - metrics.get("utmosv2_avg", ""), - metrics.get("total_gen_audio_seconds", ""), - metrics.get("frechet_codec_distance", ""), - metrics.get("eou_cutoff_rate", ""), - metrics.get("eou_silence_rate", ""), - metrics.get("eou_noise_rate", ""), - metrics.get("eou_error_rate", ""), - ] - - def clean_csv_value(v): - if v is None: - return "" - if isinstance(v, float) and not math.isfinite(v): - return "nan" - return str(v).replace(",", " ") - - with open(csv_path, "a", encoding="utf-8") as f: - f.write(",".join(clean_csv_value(v) for v in values) + "\n") - logging.info(f"Metrics appended to: {csv_path}") - - -def get_checkpoint_name(args) -> str: - checkpoint_path = getattr(args, "checkpoint_path", None) - if checkpoint_path: - stem = os.path.basename(checkpoint_path) - if stem.endswith(".nemo"): - stem = stem[:-5] - return stem - return "checkpoint" - - -def get_dataset_name(args) -> str: - out_name = os.path.basename(os.path.normpath(args.out_dir)) - if out_name: - return out_name - dataset_path = getattr(args, "datasets_json_path", None) - return os.path.splitext(os.path.basename(dataset_path))[0] if dataset_path else "dataset" - - -def create_violin_plot_if_available(metrics: List[Dict[str, Any]], metric_keys: List[str], output_path: str): - if create_violin_plot is None: - logging.warning( - "create_violin_plot is unavailable; skipping violin plot. " - "Make sure nemo.collections.tts.modules.magpietts_inference.visualization is importable." - ) - return - - if not metrics: - logging.warning(f"No metrics available for violin plot: {output_path}") - return - - available_keys = [] - for key in metric_keys: - for row in metrics: - value = row.get(key, None) - if value is None: - continue - try: - value = float(value) - except Exception: - continue - if math.isfinite(value): - available_keys.append(key) - break - - if not available_keys: - logging.warning(f"No finite requested plot metrics available for violin plot: {output_path}") - return - - os.makedirs(os.path.dirname(output_path), exist_ok=True) - create_violin_plot(metrics, available_keys, output_path) - - -def _copy_or_link(src: str, dst: str): - if src is None or not src or not os.path.exists(src): - return None - os.makedirs(os.path.dirname(dst), exist_ok=True) - try: - if os.path.lexists(dst): - os.remove(dst) - os.symlink(os.path.abspath(src), dst) - except Exception: - shutil.copyfile(src, dst) - return dst - - -def write_easymagpie_generated_audio_dir(args, sample_rows: List[Dict[str, Any]]): - """Write EasyMagpie/MagpieTTS-style generated audio/code files. - - This creates files named predicted_audio_*.wav and predicted_codes_*.pt, - plus target/context audio files, so downstream EasyMagpie demo/report tools - can consume this output directory. - """ - generated_audio_dir = os.path.join(args.out_dir, "easy_magpie_generated_audio") - os.makedirs(generated_audio_dir, exist_ok=True) - - manifest_rows = [] - filewise_rows = [] - - rows = sorted(sample_rows, key=lambda r: (int(r.get("run_id", 0)), int(r.get("dataset_index", -1)))) - - for item_idx, row in enumerate(rows): - pred_src = row.get("sample_pred_audio_path") or ( - row.get("pred_audio_paths", [None])[0] if isinstance(row.get("pred_audio_paths"), list) else None - ) - code_src = row.get("sample_predicted_codes_path") or ( - row.get("predicted_codes_paths", [None])[0] if isinstance(row.get("predicted_codes_paths"), list) else None - ) - target_src = _resolve_audio_path(row.get("target_audio_path"), args.audio_dir) - context_src = row.get("context_audio_path") - - pred_dst = os.path.join(generated_audio_dir, f"predicted_audio_{item_idx}.wav") - code_dst = os.path.join(generated_audio_dir, f"predicted_codes_{item_idx}.pt") - target_dst = os.path.join(generated_audio_dir, f"target_audio_{item_idx}.wav") - context_dst = os.path.join(generated_audio_dir, f"context_audio_{item_idx}.wav") - - _copy_or_link(pred_src, pred_dst) - _copy_or_link(code_src, code_dst) - _copy_or_link(target_src, target_dst) - _copy_or_link(context_src, context_dst) - - reference_text = row.get("reference_text", "") - if isinstance(reference_text, list): - manifest_text = " ".join(str(x) for x in reference_text) - else: - manifest_text = str(reference_text) - - manifest_rows.append( - { - "audio_filepath": f"target_audio_{item_idx}.wav", - "context_audio_filepath": f"context_audio_{item_idx}.wav", - "text": manifest_text, - "speaker": row.get("dataset_index", item_idx), - "original_dataset_index": row.get("dataset_index"), - "run_id": row.get("run_id", 0), - } - ) - - metric_row = dict(row) - metric_row.update( - { - "easy_magpie_item_idx": item_idx, - "gt_audio_filepath": target_dst if os.path.exists(target_dst) else target_src, - "pred_audio_filepath": pred_dst if os.path.exists(pred_dst) else pred_src, - "context_audio_filepath": context_dst if os.path.exists(context_dst) else context_src, - "predicted_codes_path": code_dst if os.path.exists(code_dst) else code_src, - } - ) - filewise_rows.append(metric_row) - - manifest_path = os.path.join(args.out_dir, "easy_magpie_generated_manifest.jsonl") - filewise_path = os.path.join(args.out_dir, "easy_magpie_generated_filewise_metrics.json") - write_jsonl(manifest_path, manifest_rows) - write_json(filewise_path, {"filewise_metrics": filewise_rows}) - - logging.info(f"Saved EasyMagpie-style generated audio dir to: {generated_audio_dir}") - logging.info(f"Saved EasyMagpie-style generated manifest to: {manifest_path}") - logging.info(f"Saved EasyMagpie-style generated filewise metrics to: {filewise_path}") - - return { - "generated_audio_dir": generated_audio_dir, - "manifest_path": manifest_path, - "filewise_metrics_path": filewise_path, - } - - -def save_easymagpie_style_eval_outputs(args, sample_rows: List[Dict[str, Any]], filewise_summary: Dict[str, Any]): - """Save CSV, plots, and generated-audio artifacts following EasyMagpie conventions.""" - easy_magpie_artifacts = write_easymagpie_generated_audio_dir(args, sample_rows) - filewise_summary["easy_magpie_generated_audio_dir"] = easy_magpie_artifacts["generated_audio_dir"] - filewise_summary["easy_magpie_generated_manifest"] = easy_magpie_artifacts["manifest_path"] - - checkpoint_name = get_checkpoint_name(args) - dataset_name = get_dataset_name(args) - - per_run_csv = os.path.join(args.out_dir, "all_experiment_metrics.csv") - append_metrics_to_csv(per_run_csv, checkpoint_name, dataset_name, filewise_summary) - - # Keep this alias because EasyMagpie aggregation scripts often look for the CI CSV. - ci_csv = os.path.join(args.out_dir, "all_experiment_metrics_with_ci.csv") - append_metrics_to_csv(ci_csv, checkpoint_name, dataset_name, filewise_summary) - - if not args.save_plots: - return - - violin_metrics = list(args.violin_plot_metrics) - if args.disable_utmosv2 and "utmosv2" in violin_metrics: - violin_metrics.remove("utmosv2") - - plot_dir = os.path.join(args.out_dir, "plots") - create_violin_plot_if_available( - sample_rows, - violin_metrics, - os.path.join(plot_dir, f"{dataset_name}_violin.png"), - ) - - # Also write in eval_dir root with the same style used by MagpieTTS: - # f"{dataset}_violin_{repeat_idx}.png". Here the merged final output is repeat 0. - create_violin_plot_if_available( - sample_rows, - violin_metrics, - os.path.join(args.out_dir, f"{dataset_name}_violin_0.png"), - ) - - -def wait_for_files(paths: List[str], timeout_sec: float = 7200.0, poll_sec: float = 5.0): - start = time.time() - while True: - missing = [p for p in paths if not os.path.exists(p)] - if not missing: - return - if time.time() - start > timeout_sec: - raise TimeoutError("Timed out waiting for files:\n" + "\n".join(missing)) - time.sleep(poll_sec) - - -def wait_for_rank_metric_files(args, world_size: int): - paths = [os.path.join(args.out_dir, f"metrics_rank{r:04d}.json") for r in range(world_size)] - wait_for_files(paths) - - -def wait_for_rank_filewise_metric_files(args, world_size: int): - paths = [os.path.join(args.out_dir, f"filewise_metrics_rank{r:04d}.jsonl") for r in range(world_size)] - wait_for_files(paths) - - -def scalarize_metric_value(v: Any): - if torch.is_tensor(v): - if v.numel() == 1: - return float(v.detach().cpu().item()) - return v.detach().cpu().tolist() - try: - import numpy as np - - if isinstance(v, np.generic): - return float(v.item()) - except Exception: - pass - if isinstance(v, (int, float, str, bool)) or v is None: - return v - return str(v) - - -def metric_dict_to_jsonable(d: Dict[str, Any]) -> Dict[str, Any]: - return {str(k): scalarize_metric_value(v) for k, v in d.items()} - - -def safe_metric_scalar(metric_dict: Dict[str, Any], preferred_keys: List[str]): - for key in preferred_keys: - if key in metric_dict: - value = metric_dict[key] - if torch.is_tensor(value): - return float(value.detach().cpu().item()) - return float(value) - return None - - -def get_first_metric(metrics: Dict[str, Any], names: List[str], default=None): - for name in names: - if name in metrics: - return metrics[name] - return default - - - -def format_final_metric_text(final_metrics: Dict[str, Any]) -> str: - intelligibility = final_metrics.get("intelligibility", {}) - speaker_similarity = final_metrics.get("speaker_similarity", {}) - - cer = get_first_metric(intelligibility, ["cer", "cer_dataset", "cer_cumulative"]) - wer = get_first_metric(intelligibility, ["wer", "wer_dataset", "wer_cumulative"]) - ssim_value = get_first_metric( - speaker_similarity, - ["ssim", "ssim_dataset", "ssim_pred_context_avg", "pred_context_ssim"], - ) - - def fmt(x): - if x is None: - return "nan" - try: - return f"{float(x):.10f}" - except Exception: - return str(x) - - return f"Average CER: {fmt(cer)}\nAverage WER: {fmt(wer)}\nSSIM: {fmt(ssim_value)}\n" - - - -def format_filewise_final_metric_text(filewise_summary: Dict[str, Any]) -> str: - def fmt(x): - if x is None: - return "nan" - try: - return f"{float(x):.10f}" - except Exception: - return str(x) - - ordered_keys = [ - ("cer", "CER filewise avg"), - ("wer", "WER filewise avg"), - ("cer_cumulative", "CER cumulative"), - ("wer_cumulative", "WER cumulative"), - ("ssim", "SSIM"), - ("ssim_pred_gt_avg", "SSIM pred/GT avg"), - ("ssim_pred_context_avg", "SSIM pred/context avg"), - ("ssim_gt_context_avg", "SSIM GT/context avg"), - ("ssim_pred_gt_avg_alternate", "SSIM pred/GT avg alternate"), - ("ssim_pred_context_avg_alternate", "SSIM pred/context avg alternate"), - ("ssim_gt_context_avg_alternate", "SSIM GT/context avg alternate"), - ("cer_gt_audio_cumulative", "CER GT-audio cumulative"), - ("wer_gt_audio_cumulative", "WER GT-audio cumulative"), - ("utmosv2_avg", "UTMOSv2 avg"), - ("total_gen_audio_seconds", "Total generated audio seconds"), - ("frechet_codec_distance", "Frechet codec distance"), - ("eou_cutoff_rate", "EOU cutoff rate"), - ("eou_silence_rate", "EOU silence rate"), - ("eou_noise_rate", "EOU noise rate"), - ("eou_error_rate", "EOU error rate"), - ] - - lines = [ - f"Average CER: {fmt(filewise_summary.get('cer'))}", - f"Average WER: {fmt(filewise_summary.get('wer'))}", - f"SSIM: {fmt(filewise_summary.get('ssim'))}", - ] - - for key, label in ordered_keys: - if key in {"cer", "wer", "ssim"}: - continue - if key in filewise_summary: - lines.append(f"{label}: {fmt(filewise_summary.get(key))}") - - return "\n".join(lines) + "\n" - - - -def write_filewise_csv(path: str, rows: List[Dict[str, Any]]): - """Write sample-level filewise metrics. - - Several fields are lists (turn_ids, reference_text, asr_hyp, cer_turns, - etc.), so they are JSON-encoded inside CSV cells. - """ - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - - fieldnames = [ - "run_id", - "dataset_index", - "rank", - "num_turns", - "cer", - "wer", - "ssim", - "pred_gt_ssim", - "pred_context_ssim", - "gt_context_ssim", - "pred_gt_ssim_alternate", - "pred_context_ssim_alternate", - "gt_context_ssim_alternate", - "utmosv2", - "eou_error", - "turn_ids", - "cer_turns", - "wer_turns", - "ssim_turns", - "pred_gt_ssim_turns", - "pred_context_ssim_turns", - "gt_context_ssim_turns", - "pred_gt_ssim_alternate_turns", - "pred_context_ssim_alternate_turns", - "gt_context_ssim_alternate_turns", - "utmosv2_turns", - "eou_type_turns", - "eou_trailing_duration_turns", - "eou_trail_rms_ratio_turns", - "pred_audio_seconds_turns", - "target_audio_path", - "context_audio_path", - "pred_audio_paths", - "predicted_codes_paths", - "sample_pred_audio_path", - "sample_predicted_codes_path", - "reference_text", - "asr_hyp", - ] - - def csv_value(v): - if isinstance(v, (list, dict)): - return json.dumps(v, ensure_ascii=False) - return v - - with open(tmp_path, "w", encoding="utf-8", newline="") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - for row in rows: - writer.writerow({k: csv_value(row.get(k, None)) for k in fieldnames}) - - os.replace(tmp_path, path) - - - -def write_turnwise_csv(path: str, rows: List[Dict[str, Any]]): - """Write merged turn-level filewise metrics sorted by CER.""" - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - - fieldnames = [ - "run_id", - "dataset_index", - "turn_id", - "rank", - "cer", - "wer", - "ssim", - "pred_gt_ssim", - "pred_context_ssim", - "gt_context_ssim", - "pred_gt_ssim_alternate", - "pred_context_ssim_alternate", - "gt_context_ssim_alternate", - "utmosv2", - "eou_type", - "eou_trailing_duration", - "eou_trail_rms_ratio", - "pred_audio_seconds", - "target_audio_path", - "context_audio_path", - "pred_audio_path", - "predicted_codes_path", - "sample_pred_audio_path", - "sample_predicted_codes_path", - "reference_text", - "asr_hyp", - ] - - def csv_value(v): - if isinstance(v, (list, dict)): - return json.dumps(v, ensure_ascii=False) - return v - - with open(tmp_path, "w", encoding="utf-8", newline="") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - for row in rows: - writer.writerow({k: csv_value(row.get(k, None)) for k in fieldnames}) - - os.replace(tmp_path, path) - - -# ----------------------------- -# Dataset helpers -# ----------------------------- - - -def _combined_audio_name(first_audio_filepath: str, paths: List[str]) -> str: - base_names = [os.path.splitext(os.path.basename(p))[0] for p in paths if p] - ext = os.path.splitext(paths[-1])[1] if paths and paths[-1] else "" - combined_name = "_".join(base_names) + ext - dir_name = os.path.dirname(first_audio_filepath) - return os.path.join(dir_name, combined_name) if dir_name else combined_name - - -class EvalJSONLDataset(Dataset): - def __init__(self, file_path: str, emulate_multiturn_num_turns: int = 1): - self.samples = [] - raw_samples = [] - - with open(file_path, "r", encoding="utf-8") as f: - for line_idx, line in enumerate(f, 1): - line = line.strip() - if not line: - continue - try: - sample = json.loads(line) - sample["__dataset_index__"] = len(raw_samples) - raw_samples.append(sample) - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON on line {line_idx}: {e}") - - if emulate_multiturn_num_turns <= 1: - self.samples = raw_samples - return - - single_turn_by_speaker = {} - for sample in raw_samples: - if isinstance(sample["text"], list): - self.samples.append(sample) - else: - speaker = sample.get("speaker", "unknown") - single_turn_by_speaker.setdefault(speaker, []).append(sample) - - synthetic_index = len(raw_samples) - for _, speaker_samples in single_turn_by_speaker.items(): - buffer_texts, buffer_paths = [], [] - first_sample_meta = None - - for sample in speaker_samples: - if not buffer_texts: - first_sample_meta = dict(sample) - - buffer_texts.append(sample["text"]) - buffer_paths.append(sample.get("audio_filepath", "")) - - if len(buffer_texts) == emulate_multiturn_num_turns: - first_sample_meta["text"] = buffer_texts - first_sample_meta["audio_filepath"] = _combined_audio_name( - first_sample_meta.get("audio_filepath", ""), - buffer_paths, - ) - first_sample_meta["__dataset_index__"] = synthetic_index - synthetic_index += 1 - self.samples.append(first_sample_meta) - buffer_texts, buffer_paths, first_sample_meta = [], [], None - - if buffer_texts and first_sample_meta is not None: - first_sample_meta["text"] = buffer_texts - first_sample_meta["audio_filepath"] = _combined_audio_name( - first_sample_meta.get("audio_filepath", ""), - buffer_paths, - ) - first_sample_meta["__dataset_index__"] = synthetic_index - synthetic_index += 1 - self.samples.append(first_sample_meta) - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - return self.samples[idx] - - -def _sample_text_segments_for_count(sample: Dict[str, Any], max_eval_turns=None) -> List[str]: - text_data = sample.get("text", "") - if isinstance(text_data, list): - segments = text_data - if max_eval_turns is not None: - segments = segments[: int(max_eval_turns)] - return [str(x) for x in segments] - return [str(text_data)] - - -def estimate_text_token_count(sample: Dict[str, Any], model, max_eval_turns=None) -> int: - main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] - total = 0 - for segment in _sample_text_segments_for_count(sample, max_eval_turns=max_eval_turns): - total += len(model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name)) + 1 - return int(total) - - -class SortedByTextTokenCountDataset(Dataset): - def __init__(self, dataset: Dataset, model, max_eval_turns=None, descending: bool = True): - self.dataset = dataset - scored = [] - for i in range(len(dataset)): - sample = dict(dataset[i]) - token_count = estimate_text_token_count(sample, model=model, max_eval_turns=max_eval_turns) - sample["__text_token_count__"] = int(token_count) - scored.append((token_count, i, sample)) - - scored.sort(key=lambda x: (x[0], -x[1]), reverse=bool(descending)) - self.indices = [i for _, i, _ in scored] - self.token_counts = {i: int(tok) for tok, i, _ in scored} - - def __len__(self): - return len(self.indices) - - def __getitem__(self, local_idx): - original_idx = self.indices[local_idx] - sample = dict(self.dataset[original_idx]) - sample["__text_token_count__"] = self.token_counts[original_idx] - return sample - - -# ----------------------------- -# Audio / collate helpers -# ----------------------------- - - -def _resolve_audio_path(path, root_path): - if path is None: - return None - if root_path is not None and not os.path.isabs(path): - return os.path.join(root_path, path) - return path - - -def _load_audio(path, sample_rate, normalize=True, use_librosa=False): - if path is None or not os.path.exists(path): - return torch.zeros(1, dtype=torch.float32) - - if use_librosa: - wav, sr = librosa.load(path, sr=sample_rate, mono=True) - if normalize: - wav = normalize_volume(wav) - return torch.as_tensor(wav, dtype=torch.float32) - - wav, sr = sf.read(path, dtype="float32") - if wav.ndim > 1: - wav = wav.mean(axis=1) - - if normalize: - wav = normalize_volume(wav) - - wav = torch.as_tensor(wav, dtype=torch.float32).unsqueeze(0) - return resample(wav, sr, sample_rate).squeeze(0) - - - - -def collate_and_tokenize_custom( - batch, - model, - sample_rate=22050, - root_path=None, - normalize_audio_volume=True, - use_librosa=False, - max_eval_turns=None, - inference_mode="auto", -): - """Collate for either multiturn-user-audio or regular single-turn inference. - - Mode selection: - - multiturn_user_audio: turn-based multiturn user-audio prefill with user_audio_file_path. - - single_turn: regular batched TTS, no user-speech/silence prefill. - - auto: multiturn_user_audio when samples look multiturn/user-conditioned; otherwise - single_turn. This keeps old LibriTTS commands working with batch_size=32. - """ - main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] - - if max_eval_turns is not None: - max_eval_turns = int(max_eval_turns) - if max_eval_turns <= 0: - raise ValueError("--max_eval_turns must be > 0 when provided.") - truncated_batch = [] - for sample in batch: - sample = dict(sample) - if isinstance(sample["text"], list): - sample["text"] = sample["text"][:max_eval_turns] - if isinstance(sample.get("user_audio_file_path"), list): - sample["user_audio_file_path"] = sample["user_audio_file_path"][:max_eval_turns] - truncated_batch.append(sample) - batch = truncated_batch - - def looks_multiturn_user_audio(sample): - return isinstance(sample.get("text"), list) or bool(sample.get("user_audio_file_path", None)) - - if inference_mode == "multiturn_user_audio": - is_multiturn_user_audio = True - elif inference_mode == "single_turn": - is_multiturn_user_audio = False - elif inference_mode == "auto": - is_multiturn_user_audio = any(looks_multiturn_user_audio(sample) for sample in batch) - else: - raise ValueError(f"Unknown inference_mode={inference_mode}") - - out_dict = { - "multiturn_user_audio": bool(is_multiturn_user_audio), - "dataset_indices": [int(s.get("__dataset_index__", -1)) for s in batch], - } - - if is_multiturn_user_audio: - max_turns = 1 - for sample in batch: - if isinstance(sample["text"], list): - max_turns = max(max_turns, len(sample["text"])) - - raw_turn_texts = [] - for sample in batch: - if isinstance(sample["text"], list): - raw_turn_texts.append([str(x) for x in sample["text"]]) - else: - raw_turn_texts.append([str(sample["text"])]) - - batched_turns = [] - batched_turn_lens = [] - valid_turn_masks = [] - - for turn_id in range(max_turns): - turn_tokens = [] - turn_lens = [] - turn_valid = [] - for sample in batch: - text_data = sample["text"] - if isinstance(text_data, list): - if turn_id < len(text_data): - seg_ids = model.tokenizer.encode(text_data[turn_id], tokenizer_name=main_tokenizer_name) + [ - model.eos_id - ] - turn_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) - turn_lens.append(len(seg_ids)) - turn_valid.append(True) - else: - turn_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) - turn_lens.append(1) - turn_valid.append(False) - else: - if turn_id == 0: - seg_ids = model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id] - turn_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) - turn_lens.append(len(seg_ids)) - turn_valid.append(True) - else: - turn_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) - turn_lens.append(1) - turn_valid.append(False) - - batched_turns.append(pad_sequence(turn_tokens, batch_first=True, padding_value=model.pad_id)) - batched_turn_lens.append(torch.tensor(turn_lens, dtype=torch.long)) - valid_turn_masks.append(torch.tensor(turn_valid, dtype=torch.bool)) - - user_audio_by_turn = [[] for _ in range(max_turns)] - user_audio_lens_by_turn = [[] for _ in range(max_turns)] - - else: - # Single-turn regular inference: one text segment per sample, batched. - raw_turn_texts = [] - single_turn_tokens = [] - single_turn_lens = [] - for sample in batch: - text_data = sample["text"] - if isinstance(text_data, list): - text = " ".join(str(x) for x in text_data) - else: - text = str(text_data) - raw_turn_texts.append([text]) - seg_ids = model.tokenizer.encode(text, tokenizer_name=main_tokenizer_name) + [model.eos_id] - single_turn_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) - single_turn_lens.append(len(seg_ids)) - - out_dict["input_ids"] = pad_sequence(single_turn_tokens, batch_first=True, padding_value=model.pad_id) - out_dict["input_lengths"] = torch.tensor(single_turn_lens, dtype=torch.long) - user_audio_by_turn = [] - user_audio_lens_by_turn = [] - - audio_list = [] - audio_lengths = [] - - for i, sample in enumerate(batch): - context_path = _resolve_audio_path(sample.get("context_audio_filepath"), root_path) - context_wav = _load_audio(context_path, sample_rate, normalize=normalize_audio_volume, use_librosa=use_librosa) - audio_list.append(context_wav) - audio_lengths.append(len(context_wav)) - - if is_multiturn_user_audio: - user_audio_paths = sample.get("user_audio_file_path", None) - for turn_id in range(len(user_audio_by_turn)): - has_valid_text_turn = ( - isinstance(sample["text"], list) and turn_id < len(sample["text"]) - ) or ((not isinstance(sample["text"], list)) and turn_id == 0) - - if ( - isinstance(user_audio_paths, list) - and turn_id < len(user_audio_paths) - and user_audio_paths[turn_id] - and has_valid_text_turn - ): - user_path = _resolve_audio_path(user_audio_paths[turn_id], root_path) - user_wav = _load_audio( - user_path, - sample_rate=sample_rate, - normalize=normalize_audio_volume, - use_librosa=use_librosa, - ) - else: - user_wav = torch.zeros(int(2 * sample_rate), dtype=torch.float32) - - user_audio_by_turn[turn_id].append(user_wav) - user_audio_lens_by_turn[turn_id].append(len(user_wav)) - - max_audio_len = max(audio_lengths) - batch_size = len(audio_lengths) - padded_audio = torch.zeros((batch_size, max_audio_len), dtype=torch.float32) - for i, wav in enumerate(audio_list): - padded_audio[i, : len(wav)] = wav - - if is_multiturn_user_audio: - padded_user_audio_turns = [] - padded_user_audio_turn_lens = [] - for turn_id in range(len(user_audio_by_turn)): - turn_lens = user_audio_lens_by_turn[turn_id] - max_turn_audio_len = max(turn_lens) - padded_turn_audio = torch.zeros((batch_size, max_turn_audio_len), dtype=torch.float32) - for i, wav in enumerate(user_audio_by_turn[turn_id]): - padded_turn_audio[i, : len(wav)] = wav - padded_user_audio_turns.append(padded_turn_audio) - padded_user_audio_turn_lens.append(torch.tensor(turn_lens, dtype=torch.long)) - - out_dict["batched_turns"] = batched_turns - out_dict["batched_turn_lens"] = batched_turn_lens - out_dict["valid_turn_masks"] = valid_turn_masks - out_dict["user_audio_turns"] = padded_user_audio_turns - out_dict["user_audio_turns_lens"] = padded_user_audio_turn_lens - - out_dict["context_audio"] = padded_audio - out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) - out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] - out_dict["raw_text"] = [" ".join(x) for x in raw_turn_texts] - out_dict["raw_turn_texts"] = raw_turn_texts - - return out_dict - - -# ----------------------------- -# Model / generation -# ----------------------------- - - -def attach_dtype_counter(model): - handles = [] - stats = {} - examples = {} - - def is_leaf(module): - return len(list(module.children())) == 0 - - def get_dtype(x): - if torch.is_tensor(x): - return str(x.dtype) - if isinstance(x, (list, tuple)): - for t in x: - if torch.is_tensor(t): - return str(t.dtype) - return "other" - - def get_module_group(name): - return name.split(".")[0] if "." in name else name - - def hook_fn(name): - def fn(module, inputs, outputs): - dtype = get_dtype(outputs) - if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: - dtype = "other" - group = get_module_group(name) - if group not in stats: - stats[group] = { - "torch.float16": 0, - "torch.bfloat16": 0, - "torch.float32": 0, - "other": 0, - } - examples[group] = { - "torch.float16": [], - "torch.bfloat16": [], - "torch.float32": [], - "other": [], - } - stats[group][dtype] += 1 - if len(examples[group][dtype]) < 3: - examples[group][dtype].append(module.__class__.__name__) - - return fn - - for name, module in model.named_modules(): - if is_leaf(module): - handles.append(module.register_forward_hook(hook_fn(name))) - return handles, stats, examples - - -def report_dtype_stats(handles, stats, examples, rank=0): - for h in handles: - h.remove() - logging.info(f"[rank={rank}] === DTYPE USAGE PER MODULE ===") - for group, group_stats in stats.items(): - total = sum(group_stats.values()) - if total == 0: - continue - logging.info(f"[rank={rank}] --- {group} ---") - for dtype, count in group_stats.items(): - if count > 0: - logging.info(f"[rank={rank}] {dtype}: {count} ({100 * count / total:.2f}%)") - logging.info(f"[rank={rank}] === DTYPE EXAMPLES ===") - for group, group_examples in examples.items(): - for dtype, mods in group_examples.items(): - if mods: - logging.info(f"[rank={rank}] {group} {dtype}: {mods}") - - -def build_model_and_codec(args, target_device, target_dtype): - model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) - - with open_dict(model_cfg): - model_cfg.target = "nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel" - model_cfg.codecmodel_path = args.codec_model_path - model_cfg.train_ds = None - model_cfg.validation_ds = None - model_cfg.run_val_inference = False - model_cfg.use_utmos = False - model_cfg.use_meta_init_for_decoder = True - - if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: - model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path - - model = EasyMagpieTTSInferenceModel.restore_from( - args.checkpoint_path, - override_config_path=model_cfg, - map_location=torch.device("cpu"), - ) - model.use_kv_cache_for_inference = True - model.to(dtype=target_dtype) - model.eval().to(target_device) - - model.input_samples_per_frame = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) - model.target_samples_per_frame = model.input_samples_per_frame / (model.sample_rate / model.output_sample_rate) - - codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=torch.device("cpu")) - if hasattr(codec_model, "discriminator"): - del codec_model.discriminator - codec_model.freeze() - codec_model = codec_model.to(target_device).eval() - - codec_converter = None - if getattr(model, "_codec_converter", None) is not None: - vq_new = deepcopy(model._codec_converter.vector_quantizer_new).to(target_device).eval() - codec_converter = VectorQuantizerIndexConverter( - vector_quantizer_original=codec_model.vector_quantizer, - vector_quantizer_new=vq_new, - ).to(target_device).eval() - - model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) - model._generate_codec_silence_buffer() - - return model - - -def prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=None): - B = inputs["context_audio"].size(0) - device = model.device - - inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) - inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) - - if args.user_custom_speaker_reference and speaker_wav is not None: - inputs["context_audio"] = speaker_wav.repeat(B, 1).detach() - inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) - - if "user_audio_turns" in inputs: - inputs["user_audio_turns"] = [x.to(device, dtype=target_dtype) for x in inputs["user_audio_turns"]] - inputs["user_audio_turns_lens"] = [x.to(device) for x in inputs["user_audio_turns_lens"]] - - return inputs - - - - -def run_single_turn_generation(model, inputs, args): - """Regular batched single-turn EasyMagpieTTS generation. - - This path does not prefill with user speech or synthetic silence. It is for - classic single-turn datasets such as LibriTTS and supports batch_size > 1. - """ - B = inputs["context_audio"].size(0) - device = model.device - - with torch.inference_mode(): - wav = inputs["context_audio"] - wav_len = inputs["context_audio_lengths"] - codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) - - use_lang = bool(getattr(model, "add_language_to_context_text", False)) - ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" - ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) - ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) - ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) - - state = model.streaming_init( - context_audio_codes=codes, - context_audio_codes_lens=codes_lens, - context_text_tokens=ctx_toks, - context_text_tokens_lens=ctx_toks_lens, - use_cfg=args.use_cfg, - cfg_scale=args.cfg_scale, - use_local_transformer=True, - temperature=args.temperature, - topk=args.topk, - phoneme_input_type="pred", - phoneme_sampling_method="argmax", - use_inference_mode=True, - ) - - text = inputs["input_ids"].to(device) - text_lens = inputs["input_lengths"].to(device) - - turn_offsets = torch.zeros(B, dtype=torch.long, device=device) - turn_steps = 0 - - while not state.finished.all() and turn_steps < args.max_tts_steps: - turn_steps += 1 - relative_positions = state.text_tokens_seen - turn_offsets - positions = relative_positions.clamp(min=0, max=text.size(1) - 1) - current_tokens = text[torch.arange(B, device=device), positions] - exhausted = relative_positions >= text_lens - current_tokens = torch.where( - exhausted, - torch.full_like(current_tokens, model.eos_id), - current_tokens, - ) - state, _, _ = model.streaming_step( - state=state, - text_tokens=current_tokens, - use_inference_mode=True, - ) - - generated_codes = None - if getattr(state, "all_predictions", None): - try: - generated_codes = torch.cat(state.all_predictions, dim=-1).detach() - except Exception: - generated_codes = None - - finalize_output = model.streaming_finalize(state, use_inference_mode=True) - - # For single turn, there is one generated segment per sample and no - # multiturn frame alignment needed. - return finalize_output, [], 0, generated_codes - -def run_generation(model, inputs, args, codec_sil_codes): - """Run either multiturn-user-audio or regular single-turn generation.""" - if not inputs.get("multiturn_user_audio", False): - return run_single_turn_generation(model, inputs, args) - - B = inputs["context_audio"].size(0) - if B != 1: - raise RuntimeError("Multiturn user-audio inference requires --batch_size=1 per process.") - - device = model.device - multiturn_turn_frame_ranges = [] - multiturn_decode_start_frame = 0 - - with torch.inference_mode(): - wav = inputs["context_audio"] - wav_len = inputs["context_audio_lengths"] - codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) - - use_lang = bool(getattr(model, "add_language_to_context_text", False)) - ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" - ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) - ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) - ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) - - state = model.streaming_init( - context_audio_codes=codes, - context_audio_codes_lens=codes_lens, - context_text_tokens=ctx_toks, - context_text_tokens_lens=ctx_toks_lens, - use_cfg=args.use_cfg, - cfg_scale=args.cfg_scale, - use_local_transformer=True, - temperature=args.temperature, - topk=args.topk, - phoneme_input_type="pred", - phoneme_sampling_method="argmax", - use_inference_mode=True, - ) - - batched_turns = inputs["batched_turns"] - batched_turn_lens = inputs["batched_turn_lens"] - valid_turn_masks = inputs["valid_turn_masks"] - - for turn_id in range(len(batched_turns)): - turn_text = batched_turns[turn_id].to(device) - turn_lens = batched_turn_lens[turn_id].to(device) - valid_mask = valid_turn_masks[turn_id].to(device) - if not bool(valid_mask[0].item()): - continue - - state.finished.zero_() - state.text_finished.zero_() - state.audio_prediction_end_idx.fill_(-1) - if hasattr(state, "turn_text_tokens_seen"): - state.turn_text_tokens_seen.zero_() - if hasattr(state, "phoneme_steps"): - state.phoneme_steps.zero_() - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended.zero_() - if hasattr(state, "phoneme_eos_detected"): - state.phoneme_eos_detected.zero_() - state.last_phoneme_tokens = None - - if not model.cfg.get("condition_on_user_speech", False): - user_audio = inputs["user_audio_turns"][turn_id] - user_audio_prefill_steps = int(round(user_audio.size(-1) / model.input_samples_per_frame)) - user_audio_prefill_seconds = user_audio_prefill_steps * model.input_samples_per_frame / model.sample_rate - user_audio_prefill_tokens = torch.full((1, user_audio_prefill_steps), model.pad_id, dtype=torch.long, device=device) - user_audio_channel_embedding = None - else: - user_audio = inputs["user_audio_turns"][turn_id] - user_audio_lens = inputs["user_audio_turns_lens"][turn_id] - user_audio_codes, user_audio_codes_lens = model._codec_helper.audio_to_codes(user_audio, user_audio_lens) - - if model._codec_converter is not None: - user_audio_codes = model._codec_converter.convert_original_to_new( - audio_tokens=user_audio_codes, - audio_lens=user_audio_codes_lens, - ).long() - - user_audio_codes, user_audio_codes_lens = model.stack_codes( - user_audio_codes, - user_audio_codes_lens, - model.audio_bos_id, - model.audio_eos_id, - model.frame_stacking_factor, - model.num_audio_codebooks, - ) - user_audio_embedded = model.embed_audio_tokens(user_audio_codes) - - boundary_trim = model.cfg.get("user_audio_boundary_trim", 4) - boundary_trim = 0 if boundary_trim is None else int(boundary_trim) - if boundary_trim == 0: - real_start = 0 - real_end = int(user_audio_codes_lens[0].item()) - else: - turn_len_with_special = int(user_audio_codes_lens[0].item()) - real_start = 1 - real_end = max(real_start, turn_len_with_special - 1) - - user_audio_embedded = user_audio_embedded[:, real_start:real_end] - copy_len = user_audio_embedded.size(1) - if boundary_trim > 0: - trim = min(boundary_trim, copy_len // 2) - if trim > 0: - user_audio_embedded[:, :trim] = 0.0 - user_audio_embedded[:, copy_len - trim :] = 0.0 - - bos_user_pad = torch.zeros( - user_audio_embedded.size(0), - 1, - user_audio_embedded.size(2), - device=user_audio_embedded.device, - dtype=user_audio_embedded.dtype, - ) - user_audio_embedded = torch.cat([bos_user_pad, user_audio_embedded], dim=1) - user_audio_prefill_steps = user_audio_embedded.size(1) - user_audio_prefill_tokens = torch.full((B, user_audio_prefill_steps), model.pad_id, dtype=torch.long, device=device) - user_audio_channel_embedding = user_audio_embedded - user_audio_prefill_seconds = user_audio_prefill_steps * model.input_samples_per_frame / model.sample_rate - - delay_tokens = int(state.config.training_mode.streaming_speech_delay) - delay_tokens = min(delay_tokens, int(turn_lens[0].item()), user_audio_prefill_steps) - - warmup_tokens = turn_text[:, :delay_tokens] - turn_text = turn_text[:, delay_tokens:] - turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) - - if user_audio_channel_embedding is not None and delay_tokens > 0: - warmup_user_audio = user_audio_channel_embedding[:, -delay_tokens:] - user_audio_channel_embedding = user_audio_channel_embedding[:, :-delay_tokens] - user_audio_prefill_tokens = user_audio_prefill_tokens[:, :-delay_tokens] - else: - warmup_user_audio = None - - if user_audio_prefill_tokens.size(1) > 0: - state = model.streaming_prefill_profile( - state=state, - text_tokens=user_audio_prefill_tokens, - use_inference_mode=True, - user_audio_channel_embedding=user_audio_channel_embedding, - ) - - for i in range(delay_tokens): - user_step_emb = warmup_user_audio[:, i] if warmup_user_audio is not None else None - state.finished.zero_() - state, _, _ = model.streaming_step( - state=state, - text_tokens=warmup_tokens[:, i], - user_audio_channel_embedding=user_step_emb, - prefill_like_step=not bool(model.cfg.get("agent_mask_include_transition_prefix", False)), - prefill_like_is_last_step=(i == delay_tokens - 1), - use_inference_mode=True, - ) - - logging.info(f"[multiturn_user_audio] turn={turn_id} prefilled {user_audio_prefill_steps} steps ({user_audio_prefill_seconds:.2f}s)") - - turn_start_frame = sum(p.size(-1) for p in state.all_predictions) - if turn_id == 0: - state.audio_prediction_start_idx.fill_(turn_start_frame) - multiturn_decode_start_frame = turn_start_frame - - turn_offset = state.text_tokens_seen.clone() - turn_steps = 0 - saw_audio = False - turn_ended_with_audio_eos = False - - while turn_steps < args.max_tts_steps: - turn_steps += 1 - state.finished.zero_() - relative_position = state.text_tokens_seen - turn_offset - text_exhausted = relative_position >= turn_lens - - if turn_text.size(1) == 0: - current_tokens = torch.full((B,), model.eos_id, dtype=torch.long, device=device) - else: - position = relative_position.clamp(min=0, max=turn_text.size(1) - 1) - current_tokens = turn_text[torch.arange(B, device=device), position] - current_tokens = torch.where( - text_exhausted, - torch.full_like(current_tokens, model.eos_id), - current_tokens, - ) - - state, audio_codes, _ = model.streaming_step( - state=state, - text_tokens=current_tokens, - use_inference_mode=True, - ) - - if audio_codes is not None and not saw_audio: - saw_audio = True - - if bool(text_exhausted[0].item()) and bool(state.finished[0].item()): - turn_ended_with_audio_eos = True - break - - state.audio_prediction_end_idx.fill_(-1) - state.finished.zero_() - turn_end_frame = sum(p.size(-1) for p in state.all_predictions) - multiturn_turn_frame_ranges.append((turn_id, turn_start_frame, turn_end_frame)) - logging.info( - f"[multiturn_user_audio] turn={turn_id} steps={turn_steps} " - f"saw_audio={saw_audio} ended_with_audio_eos={turn_ended_with_audio_eos}" - ) - - bos_id = getattr(model, "audio_bos_id", -1) - eos_id = getattr(model, "audio_eos_id", -1) - speaking_id = getattr(model, "audio_user_speaking_id", -1) - speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) - sil_injection = codec_sil_codes.view(1, -1, 1) - - for step_idx in range(len(state.all_predictions)): - pred = state.all_predictions[step_idx] - mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) - frame_mask = mask.any(dim=1, keepdim=True) - if frame_mask.any(): - state.all_predictions[step_idx] = torch.where(frame_mask, sil_injection.expand_as(pred), pred) - - state.audio_prediction_end_idx.fill_(-1) - - generated_codes = None - if getattr(state, "all_predictions", None): - try: - generated_codes = torch.cat(state.all_predictions, dim=-1).detach() - except Exception: - generated_codes = None - - finalize_output = model.streaming_finalize(state, use_inference_mode=True) - - return finalize_output, multiturn_turn_frame_ranges, multiturn_decode_start_frame, generated_codes - - -def load_speaker_wav_if_needed(args, model, target_dtype): - if args.user_custom_speaker_reference and args.inference_speaker_reference: - return _load_audio( - args.inference_speaker_reference, - model.sample_rate, - normalize=args.normalize_volume, - use_librosa=args.use_librosa, - ).unsqueeze(0).to(model.device, dtype=target_dtype) - - return None - - -# ----------------------------- -# Save generation outputs and metric manifests -# ----------------------------- - - -def write_audio_1d(path: str, wav: torch.Tensor, sr: int): - os.makedirs(os.path.dirname(path), exist_ok=True) - wav_np = wav.detach().cpu().float().numpy() - sf.write(path, wav_np, samplerate=sr) - - -def build_metric_item( - run_id: int, - rank: int, - dataset_index: int, - turn_id: int, - target_audio_path: str, - reference_text: str, - pred_audio_path: str, - context_audio_path: str, - pred_audio_samples: int, - context_audio_samples: int, - output_sample_rate: int, - context_sample_rate: int, - predicted_codes_path: str = None, - sample_pred_audio_path: str = None, - sample_predicted_codes_path: str = None, -): - return { - "run_id": int(run_id), - "rank": int(rank), - "dataset_index": int(dataset_index), - "turn_id": int(turn_id), - "target_audio_path": target_audio_path, - "reference_text": reference_text, - "pred_audio_path": pred_audio_path, - "context_audio_path": context_audio_path, - "pred_audio_samples": int(pred_audio_samples), - "context_audio_samples": int(context_audio_samples), - "pred_audio_seconds": float(pred_audio_samples / output_sample_rate), - "context_audio_seconds": float(context_audio_samples / context_sample_rate), - "output_sample_rate": int(output_sample_rate), - "context_sample_rate": int(context_sample_rate), - "predicted_codes_path": predicted_codes_path, - "sample_pred_audio_path": sample_pred_audio_path or pred_audio_path, - "sample_predicted_codes_path": sample_predicted_codes_path or predicted_codes_path, - } - - -def save_generated_code_slice(generated_codes, batch_idx: int, start_frame: int, end_frame: int, path: str): - """Save predicted codec codes as [num_codebooks, T] for MagpieTTS FCD.""" - if generated_codes is None: - return None - try: - os.makedirs(os.path.dirname(path), exist_ok=True) - T = int(generated_codes.size(-1)) - start_frame = max(0, min(int(start_frame), T)) - end_frame = max(start_frame, min(int(end_frame), T)) - if end_frame <= start_frame: - return None - codes = generated_codes[batch_idx, :, start_frame:end_frame].detach().cpu().long() - if codes.numel() == 0: - return None - torch.save(codes, path) - return path - except Exception as e: - logging.warning(f"Could not save predicted codes to {path}: {repr(e)}") - return None - - -def save_generation_outputs_and_build_metric_items( - model, - inputs, - finalize_output, - multiturn_turn_frame_ranges, - multiturn_decode_start_frame, - generated_codes, - args, - rank: int, - run_id: int, -): - device = model.device - B = inputs["context_audio"].size(0) - - with fp32_precision(): - audio_f32 = finalize_output.audio.float() - audio_len = finalize_output.audio_len.int() - - # Use model-reported generated audio lengths for both supported modes. - audio_len = torch.clamp(audio_len, max=audio_f32.size(1)) - - audio_out_dir = get_audio_out_dir(args) - metric_turn_dir = get_generated_turn_audio_dir(args) - metric_context_dir = get_context_metric_audio_dir(args) - predicted_codes_dir = get_predicted_codes_dir(args) - os.makedirs(audio_out_dir, exist_ok=True) - os.makedirs(metric_turn_dir, exist_ok=True) - os.makedirs(metric_context_dir, exist_ok=True) - os.makedirs(predicted_codes_dir, exist_ok=True) - - audio_f32_cpu = audio_f32.detach().cpu() - audio_len_cpu = audio_len.detach().cpu() - metric_items = [] - - for i in range(B): - target_path = inputs["target_audio_paths"][i] - base_name = os.path.basename(target_path) - stem, ext = os.path.splitext(base_name) - if not ext: - ext = ".wav" - - dataset_idx = int(inputs.get("dataset_indices", [-1] * B)[i]) - safe_stem = ( - f"run{run_id:02d}_idx{dataset_idx:08d}_{stem}" - if dataset_idx >= 0 - else f"run{run_id:02d}_rank{rank}_{stem}" - ) - - context_len = int(inputs["context_audio_lengths"][i].detach().cpu().item()) - context_wav = inputs["context_audio"][i, :context_len].detach().cpu().float() - context_metric_path = os.path.join(metric_context_dir, f"{safe_stem}_context.wav") - write_audio_1d(context_metric_path, context_wav, model.sample_rate) - - if inputs.get("multiturn_user_audio", False): - full_len = int(audio_len_cpu[i].item()) - full_wav_t = audio_f32_cpu[i, :full_len].float() - full_out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") - - full_codes_path = os.path.join(predicted_codes_dir, f"{safe_stem}_sample.pt") - sample_predicted_codes_path = save_generated_code_slice( - generated_codes, - i, - multiturn_decode_start_frame, - generated_codes.size(-1) if generated_codes is not None else multiturn_decode_start_frame, - full_codes_path, - ) - - samples_per_prediction_frame = model.codec_model_samples_per_frame / ( - model.sample_rate / model.output_sample_rate - ) - - aligned_agent = torch.zeros_like(full_wav_t) - raw_turn_texts = inputs.get("raw_turn_texts", [[] for _ in range(B)]) - - for turn_id, start_frame, end_frame in multiturn_turn_frame_ranges: - rel_start_frame = start_frame - multiturn_decode_start_frame - rel_end_frame = end_frame - multiturn_decode_start_frame - - start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) - end_sample = int(round(rel_end_frame * samples_per_prediction_frame)) - - start_sample = max(0, min(start_sample, full_len)) - end_sample = max(start_sample, min(end_sample, full_len)) - - aligned_agent[start_sample:end_sample] = full_wav_t[start_sample:end_sample] - - turn_wav = aligned_agent[start_sample:end_sample].float() - turn_out_path = os.path.join(audio_out_dir, f"{safe_stem}_turn_{turn_id}{ext}") - write_audio_1d(turn_out_path, turn_wav, model.output_sample_rate) - - metric_turn_path = os.path.join(metric_turn_dir, f"{safe_stem}_turn_{turn_id}.wav") - write_audio_1d(metric_turn_path, turn_wav, model.output_sample_rate) - - turn_codes_path = os.path.join(predicted_codes_dir, f"{safe_stem}_turn_{turn_id}.pt") - predicted_codes_path = save_generated_code_slice( - generated_codes, - i, - start_frame, - end_frame, - turn_codes_path, - ) - - if turn_id < len(raw_turn_texts[i]): - metric_items.append( - build_metric_item( - run_id=run_id, - rank=rank, - dataset_index=dataset_idx, - turn_id=turn_id, - target_audio_path=target_path, - reference_text=str(raw_turn_texts[i][turn_id]), - pred_audio_path=metric_turn_path, - context_audio_path=context_metric_path, - pred_audio_samples=int(turn_wav.numel()), - context_audio_samples=int(context_wav.numel()), - output_sample_rate=model.output_sample_rate, - context_sample_rate=model.sample_rate, - predicted_codes_path=predicted_codes_path, - sample_pred_audio_path=full_out_path, - sample_predicted_codes_path=sample_predicted_codes_path, - ) - ) - - write_audio_1d(full_out_path, aligned_agent, model.output_sample_rate) - - if "user_audio_turns" in inputs: - user_segments = [] - - first_user_len_in = int(inputs["user_audio_turns_lens"][0][i].item()) - first_user_delay_out = int(round(first_user_len_in * model.output_sample_rate / model.sample_rate)) - - for turn_id, start_frame, _ in multiturn_turn_frame_ranges: - if turn_id >= len(inputs["user_audio_turns"]): - continue - - turn_audio = inputs["user_audio_turns"][turn_id][i].detach().cpu().float() - turn_audio_len = int(inputs["user_audio_turns_lens"][turn_id][i].item()) - turn_audio = turn_audio[:turn_audio_len] - - turn_audio_out = resample( - turn_audio.unsqueeze(0), - model.sample_rate, - model.output_sample_rate, - ).squeeze(0) - - if turn_id == 0: - user_start_sample = 0 - else: - prev_turn_end_frame = multiturn_turn_frame_ranges[turn_id - 1][2] - rel_prev_end_frame = prev_turn_end_frame - multiturn_decode_start_frame - user_start_sample = first_user_delay_out + int( - round(rel_prev_end_frame * samples_per_prediction_frame) - ) - - user_segments.append((user_start_sample, turn_audio_out.detach().cpu().float())) - - total_user_len = 0 - for s, wav_seg in user_segments: - total_user_len = max(total_user_len, s + wav_seg.numel()) - - user_ch = torch.zeros(total_user_len) - for s, wav_seg in user_segments: - e = s + wav_seg.numel() - user_ch[s:e] += wav_seg - - agent_ch = torch.cat([torch.zeros(first_user_delay_out, dtype=aligned_agent.dtype), aligned_agent]) - - stereo_len = max(user_ch.numel(), agent_ch.numel()) - user_pad = torch.zeros(stereo_len) - agent_pad = torch.zeros(stereo_len) - - user_pad[: user_ch.numel()] = user_ch - agent_pad[: agent_ch.numel()] = agent_ch - - stereo = torch.stack([user_pad, agent_pad], dim=1).numpy() - aligned_path = os.path.join(audio_out_dir, f"{safe_stem}_user_agent_aligned{ext}") - sf.write(aligned_path, stereo, samplerate=model.output_sample_rate) - - else: - full_len = int(audio_len_cpu[i].item()) - wav = audio_f32_cpu[i, :full_len].float() - out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") - write_audio_1d(out_path, wav, model.output_sample_rate) - - metric_turn_path = os.path.join(metric_turn_dir, f"{safe_stem}_turn_0.wav") - write_audio_1d(metric_turn_path, wav, model.output_sample_rate) - - codes_path = os.path.join(predicted_codes_dir, f"{safe_stem}_turn_0.pt") - predicted_codes_path = save_generated_code_slice( - generated_codes, - i, - 0, - generated_codes.size(-1) if generated_codes is not None else 0, - codes_path, - ) - - metric_items.append( - build_metric_item( - run_id=run_id, - rank=rank, - dataset_index=dataset_idx, - turn_id=0, - target_audio_path=target_path, - reference_text=str(inputs["raw_text"][i]), - pred_audio_path=metric_turn_path, - context_audio_path=context_metric_path, - pred_audio_samples=int(wav.numel()), - context_audio_samples=int(context_wav.numel()), - output_sample_rate=model.output_sample_rate, - context_sample_rate=model.sample_rate, - predicted_codes_path=predicted_codes_path, - sample_pred_audio_path=out_path, - sample_predicted_codes_path=predicted_codes_path, - ) - ) - - return metric_items - - -# ----------------------------- -# Metrics after generation -# ----------------------------- - - -def torch_rms_norm(wav: torch.Tensor, db_level: float = -27.0) -> torch.Tensor: - denom = torch.sum(wav**2) - if denom <= 0: - return wav - r = 10 ** (db_level / 20) - a = torch.sqrt((wav.size(-1) * (r**2)) / denom) - return wav * a - - -def _load_audio_for_metric(path: str, sample_rate: int): - wav = _load_audio(path, sample_rate=sample_rate, normalize=False, use_librosa=False) - if wav.numel() == 0: - wav = torch.zeros(1, dtype=torch.float32) - return wav.float() - - -def _pad_audio_1d_list(wavs: List[torch.Tensor], device, dtype=torch.float32): - if len(wavs) == 0: - return torch.zeros((0, 1), device=device, dtype=dtype), torch.zeros((0,), device=device, dtype=torch.long) - - lens = torch.tensor([max(1, int(w.numel())) for w in wavs], device=device, dtype=torch.long) - max_len = int(lens.max().item()) - out = torch.zeros((len(wavs), max_len), device=device, dtype=dtype) - - for i, w in enumerate(wavs): - w = w.to(device=device, dtype=dtype).flatten() - if w.numel() == 0: - continue - out[i, : w.numel()] = w - - return out, lens - - -def chunk_list(xs: List[Any], chunk_size: int) -> Iterable[List[Any]]: - chunk_size = max(1, int(chunk_size)) - for start in range(0, len(xs), chunk_size): - yield xs[start : start + chunk_size] - - -def _metric_device(): - return "cuda" if torch.cuda.is_available() else "cpu" - - -def _load_metric_batch_audio(batch_items: List[Dict[str, Any]], args): - pred_wavs = [] - context_wavs = [] - - for item in batch_items: - pred = _load_audio_for_metric(item["pred_audio_path"], sample_rate=int(item["output_sample_rate"])) - context = _load_audio_for_metric(item["context_audio_path"], sample_rate=int(item["context_sample_rate"])) - - if args.max_metric_audio_sec is not None: - max_pred_len = int(float(args.max_metric_audio_sec) * int(item["output_sample_rate"])) - pred = pred[: max(1, max_pred_len)] - - pred_wavs.append(pred) - context_wavs.append(context) - - device = _metric_device() - pred_audio, pred_lens = _pad_audio_1d_list(pred_wavs, device=device) - context_audio, context_lens = _pad_audio_1d_list(context_wavs, device=device) - output_sample_rate = int(batch_items[0]["output_sample_rate"]) - context_sample_rate = int(batch_items[0]["context_sample_rate"]) - - return pred_audio, pred_lens, context_audio, context_lens, output_sample_rate, context_sample_rate - - - -def _nan(): - return float("nan") - - -def finite_avg(values): - finite_values = [] - for value in values: - if value is None: - continue - try: - value = float(value) - except Exception: - continue - if math.isfinite(value): - finite_values.append(value) - if not finite_values: - return None - return sum(finite_values) / len(finite_values) - - -def _safe_word_error_detail(hyp_text: str, ref_text: str, use_cer: bool): - ref_text = "" if ref_text is None else str(ref_text).strip() - hyp_text = "" if hyp_text is None else str(hyp_text).strip() - if ref_text == "": - return None - try: - detailed = word_error_rate_detail(hypotheses=[hyp_text], references=[ref_text], use_cer=use_cer) - value = float(detailed[0]) - if not math.isfinite(value): - return None - return detailed - except Exception: - return None - - -def _safe_detail_value(detailed): - if detailed is None: - return None - try: - value = float(detailed[0]) - except Exception: - return None - if not math.isfinite(value): - return None - return value - - -def _load_speaker_eval_models(args, device: str): - """Load speaker verification models in the same style as MagpieTTS evaluation utils.""" - models = { - "feature_extractor": None, - "sv_model": None, - "sv_model_alternate": None, - } - - if nemo_asr is None or extract_embedding is None: - logging.warning("Speaker metric dependencies are unavailable; speaker similarity metrics will be NaN.") - return models - - try: - if args.sv_model_type == "wavlm": - if Wav2Vec2FeatureExtractor is None or WavLMForXVector is None: - raise RuntimeError("transformers WavLM dependencies are unavailable") - models["feature_extractor"] = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sv") - models["sv_model"] = WavLMForXVector.from_pretrained("microsoft/wavlm-base-plus-sv").to(device).eval() - else: - models["sv_model"] = ( - nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name="titanet_large").to(device).eval() - ) - - logging.info("Loading alternate speaker model `titanet_small`.") - with logging.temp_verbosity(logging.ERROR): - models["sv_model_alternate"] = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( - model_name="titanet_small" - ) - models["sv_model_alternate"] = models["sv_model_alternate"].to(device).eval() - except Exception as e: - logging.warning(f"Could not load speaker evaluation models: {repr(e)}") - models = {"feature_extractor": None, "sv_model": None, "sv_model_alternate": None} - - return models - - -def _compute_speaker_similarity_rows(args, rows: List[Dict[str, Any]]): - """Populate pred/GT/context speaker similarity metrics per turn.""" - if args.disable_speaker_metrics: - for row in rows: - row["pred_gt_ssim"] = _nan() - row["pred_context_ssim"] = _nan() - row["gt_context_ssim"] = _nan() - row["pred_gt_ssim_alternate"] = _nan() - row["pred_context_ssim_alternate"] = _nan() - row["gt_context_ssim_alternate"] = _nan() - row["ssim"] = _nan() - return - - device = _metric_device() - models = _load_speaker_eval_models(args, device=device) - sv_model = models.get("sv_model") - sv_model_alt = models.get("sv_model_alternate") - extractor = models.get("feature_extractor") - - if sv_model is None or sv_model_alt is None or extract_embedding is None: - for row in rows: - row["pred_gt_ssim"] = _nan() - row["pred_context_ssim"] = _nan() - row["gt_context_ssim"] = _nan() - row["pred_gt_ssim_alternate"] = _nan() - row["pred_context_ssim_alternate"] = _nan() - row["gt_context_ssim_alternate"] = _nan() - row["ssim"] = _nan() - return - - emb_cache = {} - emb_alt_cache = {} - - def get_emb(path: str, alternate: bool = False): - if path is None or not path or not os.path.exists(path): - return None - cache = emb_alt_cache if alternate else emb_cache - if path in cache: - return cache[path] - model = sv_model_alt if alternate else sv_model - sv_type = "titanet" if alternate else args.sv_model_type - try: - with torch.inference_mode(): - emb = extract_embedding( - model=model, - extractor=extractor, - audio_path=path, - device=device, - sv_model_type=sv_type, - ) - cache[path] = emb - return emb - except Exception as e: - logging.warning(f"Could not extract speaker embedding for {path}: {repr(e)}") - cache[path] = None - return None - - def cosine(a, b): - if a is None or b is None: - return _nan() - try: - return torch.nn.functional.cosine_similarity(a, b, dim=0).item() - except Exception: - return _nan() - - for row in rows: - pred_path = row.get("pred_audio_path") - gt_path = row.get("target_audio_path") - context_path = row.get("context_audio_path") - - pred = get_emb(pred_path, alternate=False) - gt = get_emb(gt_path, alternate=False) - context = get_emb(context_path, alternate=False) - - pred_alt = get_emb(pred_path, alternate=True) - gt_alt = get_emb(gt_path, alternate=True) - context_alt = get_emb(context_path, alternate=True) - - row["pred_gt_ssim"] = cosine(pred, gt) - row["pred_context_ssim"] = cosine(pred, context) - row["gt_context_ssim"] = cosine(gt, context) - row["pred_gt_ssim_alternate"] = cosine(pred_alt, gt_alt) - row["pred_context_ssim_alternate"] = cosine(pred_alt, context_alt) - row["gt_context_ssim_alternate"] = cosine(gt_alt, context_alt) - row["ssim"] = row["pred_context_ssim"] - - del models - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -def _compute_utmos_rows(args, rows: List[Dict[str, Any]], rank: int): - if args.disable_utmosv2: - for row in rows: - row["utmosv2"] = _nan() - return - - if compute_utmosv2_scores is None: - logging.warning("UTMOSv2 utility is unavailable; setting utmosv2 to NaN.") - for row in rows: - row["utmosv2"] = _nan() - return - - try: - # All predicted metric turns are written into the same directory. - audio_dir = get_generated_turn_audio_dir(args) - scores = compute_utmosv2_scores(audio_dir, _metric_device()) - for row in rows: - row["utmosv2"] = scores.get(os.path.normpath(row.get("pred_audio_path", "")), _nan()) - except Exception as e: - all_rank_print(rank, f"UTMOSv2 computation failed; setting utmosv2 to NaN: {repr(e)}") - for row in rows: - row["utmosv2"] = _nan() - - -def _compute_eou_rows(args, rows: List[Dict[str, Any]], rank: int): - if args.disable_eou or args.language != "en" or EoUClassifier is None: - for row in rows: - row["eou_type"] = None - row["eou_trailing_duration"] = _nan() - row["eou_trail_rms_ratio"] = _nan() - return - - try: - kwargs = {"device": _metric_device()} - if args.eou_model_name: - kwargs["model_name"] = args.eou_model_name - classifier = EoUClassifier(**kwargs) - items = [(row.get("pred_audio_path"), row.get("reference_text", "")) for row in rows] - - results = [] - batch_size = max(1, int(args.eou_batch_size)) - for start in range(0, len(items), batch_size): - results.extend(classifier.classify_batch(items[start : start + batch_size])) - - for row, result in zip(rows, results): - row["eou_type"] = result.eou_type.value - row["eou_trailing_duration"] = result.trailing_duration - row["eou_trail_rms_ratio"] = result.trail_rms_ratio - except Exception as e: - all_rank_print(rank, f"EOU computation failed; setting EOU metrics to NaN: {repr(e)}") - for row in rows: - row["eou_type"] = None - row["eou_trailing_duration"] = _nan() - row["eou_trail_rms_ratio"] = _nan() - - -def compute_magpie_style_global_metrics(rows: List[Dict[str, Any]]) -> Dict[str, Any]: - """Aggregate the same metric keys used by MagpieTTS evaluate_generated_audio.""" - n = len(rows) - if n == 0: - return { - "cer_filewise_avg": None, - "wer_filewise_avg": None, - "cer_cumulative": None, - "wer_cumulative": None, - "ssim_pred_gt_avg": None, - "ssim_pred_context_avg": None, - "ssim_gt_context_avg": None, - "ssim_pred_gt_avg_alternate": None, - "ssim_pred_context_avg_alternate": None, - "ssim_gt_context_avg_alternate": None, - "cer_gt_audio_cumulative": _nan(), - "wer_gt_audio_cumulative": _nan(), - "utmosv2_avg": None, - "total_gen_audio_seconds": 0.0, - "frechet_codec_distance": _nan(), - "eou_cutoff_rate": _nan(), - "eou_silence_rate": _nan(), - "eou_noise_rate": _nan(), - "eou_error_rate": _nan(), - } - - pred_texts = [str(r.get("pred_text", r.get("asr_hyp", ""))) for r in rows if r.get("gt_text", r.get("reference_text", ""))] - gt_texts = [str(r.get("gt_text", r.get("reference_text", ""))) for r in rows if r.get("gt_text", r.get("reference_text", ""))] - - out = {} - out["cer_filewise_avg"] = finite_avg([r.get("cer") for r in rows]) - out["wer_filewise_avg"] = finite_avg([r.get("wer") for r in rows]) - - if pred_texts and gt_texts: - try: - out["cer_cumulative"] = float(word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=True)[0]) - except Exception: - out["cer_cumulative"] = None - try: - out["wer_cumulative"] = float(word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=False)[0]) - except Exception: - out["wer_cumulative"] = None - else: - out["cer_cumulative"] = None - out["wer_cumulative"] = None - - out["ssim_pred_gt_avg"] = finite_avg([r.get("pred_gt_ssim") for r in rows]) - out["ssim_pred_context_avg"] = finite_avg([r.get("pred_context_ssim") for r in rows]) - out["ssim_gt_context_avg"] = finite_avg([r.get("gt_context_ssim") for r in rows]) - out["ssim_pred_gt_avg_alternate"] = finite_avg([r.get("pred_gt_ssim_alternate") for r in rows]) - out["ssim_pred_context_avg_alternate"] = finite_avg([r.get("pred_context_ssim_alternate") for r in rows]) - out["ssim_gt_context_avg_alternate"] = finite_avg([r.get("gt_context_ssim_alternate") for r in rows]) - - gt_audio_texts = [r.get("gt_audio_text") for r in rows] - if gt_audio_texts and all(x is not None for x in gt_audio_texts): - try: - out["cer_gt_audio_cumulative"] = float( - word_error_rate_detail(hypotheses=gt_audio_texts, references=gt_texts, use_cer=True)[0] - ) - except Exception: - out["cer_gt_audio_cumulative"] = _nan() - try: - out["wer_gt_audio_cumulative"] = float( - word_error_rate_detail(hypotheses=gt_audio_texts, references=gt_texts, use_cer=False)[0] - ) - except Exception: - out["wer_gt_audio_cumulative"] = _nan() - else: - out["cer_gt_audio_cumulative"] = _nan() - out["wer_gt_audio_cumulative"] = _nan() - - out["utmosv2_avg"] = finite_avg([r.get("utmosv2") for r in rows]) - out["total_gen_audio_seconds"] = sum(float(r.get("total_gen_audio_seconds", r.get("pred_audio_seconds", 0.0)) or 0.0) for r in rows) - out["frechet_codec_distance"] = _nan() - - eou_types = [r.get("eou_type") for r in rows] - if eou_types and eou_types[0] is not None: - counts = Counter(eou_types) - if EoUType is not None: - labels = list(EoUType.error_types()) - good_label = EoUType.GOOD - else: - labels = ["cutoff", "silence", "noise"] - good_label = "good" - - for label in labels: - out[f"eou_{label}_rate"] = counts.get(label, 0) / n - out["eou_error_rate"] = 1.0 - counts.get(good_label, 0) / n - else: - out["eou_cutoff_rate"] = _nan() - out["eou_silence_rate"] = _nan() - out["eou_noise_rate"] = _nan() - out["eou_error_rate"] = _nan() - - return out - - -def _load_asr_model_for_metrics(args, rank: int): - """Load ASR directly, matching the EasyMagpie/MagpieTTS evaluation style.""" - asr_cls = ASRModel - if asr_cls is None and nemo_asr is not None: - asr_cls = getattr(getattr(nemo_asr, "models", None), "ASRModel", None) - if asr_cls is None: - raise RuntimeError("NeMo ASRModel is unavailable, cannot load ASR model.") - - all_rank_print(rank, f"loading ASR model after generation: {args.asr_model_name}") - with fp32_precision(): - asr_model = asr_cls.from_pretrained(model_name=args.asr_model_name) - asr_model = asr_model.to(_metric_device()).eval() - - return asr_model - - -def _asr_transcribe_audio_batch(asr_model, audio: torch.Tensor, audio_lens: torch.Tensor, batch_size: int): - audio_list = [a[: int(alen.item())].detach().cpu() for a, alen in zip(audio, audio_lens)] - with fp32_precision(), torch.inference_mode(): - hyps = asr_model.transcribe(audio_list, batch_size=batch_size, verbose=False) - - out = [] - for hyp in hyps: - if hasattr(hyp, "text"): - out.append(str(hyp.text)) - else: - out.append(str(hyp)) - return out - - - -def compute_metrics_after_generation(args, rank: int, world_size: int, metric_items: List[Dict[str, Any]]): - """ - Compute metrics after generation without the speechlm2 metric wrappers. - - This follows the MagpieTTS/EasyMagpieTTS evaluation style more closely: - - ASR is loaded directly from args.asr_model_name and used for transcription. - - CER/WER are computed from ASR hypotheses with word_error_rate_detail. - - Speaker similarity is computed with the MagpieTTS embedding helper and - reported as SSIM, especially pred_context_ssim / ssim. - """ - metric_start = time.time() - - if len(metric_items) == 0: - return { - "rank": int(rank), - "world_size": int(world_size), - "num_processed": 0, - "num_metric_items": 0, - "metric_elapsed_sec": 0.0, - "intelligibility": {}, - "speaker_similarity": {}, - "magpie_style_metrics": {}, - }, [] - - normalizer = EnglishTextNormalizer() - normalizer.ignore_patterns = r"$^" - filewise_rows = [] - - # ASR pass, directly using ASRModel as in MagpieTTS evaluation. - asr_model = _load_asr_model_for_metrics(args, rank=rank) - - for batch_items in chunk_list(metric_items, args.metric_batch_size): - pred_audio, pred_lens, _, _, output_sr, _ = _load_metric_batch_audio(batch_items, args) - - with fp32_precision(): - pred_16k = resample(pred_audio, output_sr, 16000) - pred_16k_lens = (pred_lens / output_sr * 16000).to(torch.long) - - asr_hyps = _asr_transcribe_audio_batch( - asr_model=asr_model, - audio=pred_16k, - audio_lens=pred_16k_lens, - batch_size=len(batch_items), - ) - - for item, hyp in zip(batch_items, asr_hyps): - ref_norm = normalizer(str(item["reference_text"])).strip() - hyp_norm = normalizer(str(hyp)).strip() - - detailed_cer = _safe_word_error_detail(hyp_norm, ref_norm, use_cer=True) - detailed_wer = _safe_word_error_detail(hyp_norm, ref_norm, use_cer=False) - cer = _safe_detail_value(detailed_cer) - wer = _safe_detail_value(detailed_wer) - - row = dict(item) - row["asr_hyp"] = hyp - row["pred_text"] = hyp_norm - row["gt_text"] = ref_norm - row["detailed_cer"] = detailed_cer - row["detailed_wer"] = detailed_wer - row["cer"] = cer - row["wer"] = wer - row["ssim"] = _nan() - row["gt_audio_text"] = None - row["utmosv2"] = _nan() - row["eou_type"] = None - row["eou_trailing_duration"] = _nan() - row["eou_trail_rms_ratio"] = _nan() - row["pred_gt_ssim"] = _nan() - row["pred_context_ssim"] = _nan() - row["gt_context_ssim"] = _nan() - row["pred_gt_ssim_alternate"] = _nan() - row["pred_context_ssim_alternate"] = _nan() - row["gt_context_ssim_alternate"] = _nan() - row["total_gen_audio_seconds"] = float(row.get("pred_audio_seconds", 0.0) or 0.0) - filewise_rows.append(row) - - del asr_model - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # Speaker similarity pass. This is the standardized "SSIM" used by the - # MagpieTTS evaluation scripts: pred_context_ssim is the main speaker - # similarity against the conditioning/context audio. - _compute_speaker_similarity_rows(args, filewise_rows) - for row in filewise_rows: - row["ssim"] = row.get("pred_context_ssim", _nan()) - - _compute_utmos_rows(args, filewise_rows, rank=rank) - _compute_eou_rows(args, filewise_rows, rank=rank) - - magpie_style_metrics = compute_magpie_style_global_metrics(filewise_rows) - - cer_wer = { - "cer": magpie_style_metrics.get("cer_cumulative"), - "wer": magpie_style_metrics.get("wer_cumulative"), - "cer_dataset": magpie_style_metrics.get("cer_cumulative"), - "wer_dataset": magpie_style_metrics.get("wer_cumulative"), - "cer_filewise_avg": magpie_style_metrics.get("cer_filewise_avg"), - "wer_filewise_avg": magpie_style_metrics.get("wer_filewise_avg"), - } - - speaker_similarity = { - "ssim": magpie_style_metrics.get("ssim_pred_context_avg"), - "ssim_dataset": magpie_style_metrics.get("ssim_pred_context_avg"), - "ssim_pred_gt_avg": magpie_style_metrics.get("ssim_pred_gt_avg"), - "ssim_pred_context_avg": magpie_style_metrics.get("ssim_pred_context_avg"), - "ssim_gt_context_avg": magpie_style_metrics.get("ssim_gt_context_avg"), - "ssim_pred_gt_avg_alternate": magpie_style_metrics.get("ssim_pred_gt_avg_alternate"), - "ssim_pred_context_avg_alternate": magpie_style_metrics.get("ssim_pred_context_avg_alternate"), - "ssim_gt_context_avg_alternate": magpie_style_metrics.get("ssim_gt_context_avg_alternate"), - } - - metric_elapsed = time.time() - metric_start - - rank_metrics = { - "rank": int(rank), - "world_size": int(world_size), - "num_processed": len({(x["run_id"], x["dataset_index"]) for x in metric_items}), - "num_metric_items": int(len(metric_items)), - "metric_elapsed_sec": float(metric_elapsed), - "intelligibility": cer_wer, - "speaker_similarity": speaker_similarity, - "magpie_style_metrics": magpie_style_metrics, - } - - return rank_metrics, filewise_rows - - -# ----------------------------- -# Merge helpers -# ----------------------------- - - -def compute_and_save_rank_metrics_file(args, rank_metrics: Dict[str, Any], rank: int): - rank_path = os.path.join(args.out_dir, f"metrics_rank{rank:04d}.json") - write_json(rank_path, rank_metrics) - return rank_metrics - - -def merge_metrics_on_rank0(args, rank, world_size): - if rank != 0: - return None - - rank_metric_files = [os.path.join(args.out_dir, f"metrics_rank{r:04d}.json") for r in range(world_size)] - - rank_metrics = [] - for path in rank_metric_files: - if not os.path.exists(path): - logging.warning(f"Missing rank metric file: {path}") - continue - with open(path, "r", encoding="utf-8") as f: - rank_metrics.append(json.load(f)) - - total_n = sum(int(m.get("num_metric_items", m.get("num_processed", 0))) for m in rank_metrics) - - def weighted_average(section: str): - keys = set() - for m in rank_metrics: - keys.update(m.get(section, {}).keys()) - - out = {} - for k in sorted(keys): - numerator = 0.0 - denominator = 0 - - for m in rank_metrics: - n = int(m.get("num_metric_items", m.get("num_processed", 0))) - if n <= 0: - continue - - value = m.get(section, {}).get(k, None) - if value is None or isinstance(value, str): - continue - - try: - value = float(value) - except Exception: - continue - - numerator += value * n - denominator += n - - if denominator > 0: - out[k] = numerator / denominator - - return out - - final_metrics = { - "world_size": int(world_size), - "num_metric_items": int(total_n), - "aggregation": "sum(rank_metric * rank_num_metric_items) / total_num_metric_items", - "intelligibility": weighted_average("intelligibility"), - "speaker_similarity": weighted_average("speaker_similarity"), - "ranks": rank_metrics, - } - - final_json_path = os.path.join(args.out_dir, "metrics_final.json") - final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") - - write_json(final_json_path, final_metrics) - - final_text = format_final_metric_text(final_metrics) - write_text_atomic(final_txt_path, final_text) - - print("\n--- Final Evaluation Metrics ---", flush=True) - print(final_text, flush=True) - - logging.info(f"Final metrics JSON saved to: {final_json_path}") - logging.info(f"Final metrics TXT saved to: {final_txt_path}") - logging.info(json.dumps(final_metrics, indent=2, sort_keys=True)) - - return final_metrics - - -def _cer_sort_value(row: Dict[str, Any]) -> float: - """Return finite CER for sorting; missing/non-finite values go last.""" - value = row.get("cer", None) - if value is None: - return float("-inf") - try: - value = float(value) - except Exception: - return float("-inf") - if not math.isfinite(value): - return float("-inf") - return value - - -def merge_filewise_metrics_on_rank0(args, rank: int, world_size: int): - """Merge per-turn rank metric rows and write global CER-sorted outputs. - - Writes: - - filewise_metrics_turns_sorted_by_cer.jsonl/csv: - one row per turn, merged across ranks, sorted by turn CER. - - filewise_metrics_global_sorted_by_cer.jsonl/csv: - compatibility alias for the same turn-level global output. - - filewise_metrics_sorted_by_cer.jsonl/csv: - one row per original sample, with turn metric lists, sorted by - sample-average CER. - """ - if rank != 0 or not args.save_filewise_metrics: - return [] - - turn_rows = [] - - for r in range(world_size): - path = os.path.join(args.out_dir, f"filewise_metrics_rank{r:04d}.jsonl") - if not os.path.exists(path): - logging.warning(f"Missing filewise metrics file: {path}") - continue - - with open(path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line: - turn_rows.append(json.loads(line)) - - # Deduplicate DistributedSampler padding repeats, but preserve --num_eval_runs. - deduped_turns = {} - for row in turn_rows: - run_id = int(row.get("run_id", 0)) - idx = int(row.get("dataset_index", -1)) - turn_id = int(row.get("turn_id", 0)) - key = (run_id, idx, turn_id) - if key not in deduped_turns: - deduped_turns[key] = row - - turn_rows = list(deduped_turns.values()) - - # Global turn-level output sorted by CER descending. - turn_rows_sorted = sorted( - turn_rows, - key=lambda x: ( - x.get("cer") is not None, - _cer_sort_value(x), - ), - reverse=True, - ) - - turn_jsonl_path = os.path.join(args.out_dir, "filewise_metrics_turns_sorted_by_cer.jsonl") - turn_csv_path = os.path.join(args.out_dir, "filewise_metrics_turns_sorted_by_cer.csv") - turn_json_path = os.path.join(args.out_dir, "filewise_metrics_turns_sorted_by_cer.json") - write_jsonl(turn_jsonl_path, turn_rows_sorted) - write_json(turn_json_path, {"filewise_metrics": turn_rows_sorted}) - write_turnwise_csv(turn_csv_path, turn_rows_sorted) - - # Compatibility alias using the "global" name. - global_jsonl_path = os.path.join(args.out_dir, "filewise_metrics_global_sorted_by_cer.jsonl") - global_csv_path = os.path.join(args.out_dir, "filewise_metrics_global_sorted_by_cer.csv") - write_jsonl(global_jsonl_path, turn_rows_sorted) - write_turnwise_csv(global_csv_path, turn_rows_sorted) - - turn_global_metrics = compute_magpie_style_global_metrics(turn_rows_sorted) - turn_global_metrics_path = os.path.join(args.out_dir, "metrics_final_turn_global.json") - write_json( - turn_global_metrics_path, - { - "aggregation": "magpie_style_global_metrics_over_turn_rows", - **turn_global_metrics, - }, - ) - - logging.info(f"Saved global turn-level filewise metrics JSONL to: {turn_jsonl_path}") - logging.info(f"Saved global turn-level filewise metrics JSON to: {turn_json_path}") - logging.info(f"Saved global turn-level filewise metrics CSV to: {turn_csv_path}") - logging.info(f"Saved global filewise compatibility JSONL to: {global_jsonl_path}") - logging.info(f"Saved global filewise compatibility CSV to: {global_csv_path}") - logging.info(f"Saved turn global metrics JSON to: {turn_global_metrics_path}") - - # Group turn rows into one row per original file/sample. - grouped = {} - for row in turn_rows: - run_id = int(row.get("run_id", 0)) - idx = int(row.get("dataset_index", -1)) - key = (run_id, idx) - - if key not in grouped: - grouped[key] = { - "run_id": run_id, - "dataset_index": idx, - "rank": int(row.get("rank", -1)), - "target_audio_path": row.get("target_audio_path", ""), - "context_audio_path": row.get("context_audio_path", ""), - "turn_rows": [], - } - - grouped[key]["turn_rows"].append(row) - - def avg(vals): - finite_vals = [] - for x in vals: - if x is None: - continue - try: - x = float(x) - except Exception: - continue - if math.isfinite(x): - finite_vals.append(x) - return None if not finite_vals else sum(finite_vals) / len(finite_vals) - - sample_rows = [] - for _, group in grouped.items(): - turns = sorted(group["turn_rows"], key=lambda x: int(x.get("turn_id", 0))) - - cer_turns = [r.get("cer") for r in turns] - wer_turns = [r.get("wer") for r in turns] - ssim_turns = [r.get("ssim") for r in turns] - - pred_gt_ssim_turns = [r.get("pred_gt_ssim") for r in turns] - pred_context_ssim_turns = [r.get("pred_context_ssim") for r in turns] - gt_context_ssim_turns = [r.get("gt_context_ssim") for r in turns] - pred_gt_ssim_alternate_turns = [r.get("pred_gt_ssim_alternate") for r in turns] - pred_context_ssim_alternate_turns = [r.get("pred_context_ssim_alternate") for r in turns] - gt_context_ssim_alternate_turns = [r.get("gt_context_ssim_alternate") for r in turns] - utmosv2_turns = [r.get("utmosv2") for r in turns] - eou_type_turns = [r.get("eou_type") for r in turns] - eou_trailing_duration_turns = [r.get("eou_trailing_duration") for r in turns] - eou_trail_rms_ratio_turns = [r.get("eou_trail_rms_ratio") for r in turns] - - sample_row = { - "run_id": group["run_id"], - "dataset_index": group["dataset_index"], - "rank": group["rank"], - "num_turns": len(turns), - "turn_ids": [int(r.get("turn_id", 0)) for r in turns], - "target_audio_path": group["target_audio_path"], - "context_audio_path": group["context_audio_path"], - "pred_audio_paths": [r.get("pred_audio_path", "") for r in turns], - "predicted_codes_paths": [r.get("predicted_codes_path") for r in turns], - "sample_pred_audio_path": turns[0].get("sample_pred_audio_path", turns[0].get("pred_audio_path", "")), - "sample_predicted_codes_path": turns[0].get( - "sample_predicted_codes_path", - turns[0].get("predicted_codes_path"), - ), - "pred_audio_seconds_turns": [r.get("pred_audio_seconds") for r in turns], - "reference_text": [r.get("reference_text", "") for r in turns], - "asr_hyp": [r.get("asr_hyp", "") for r in turns], - "gt_text": [r.get("gt_text", "") for r in turns], - "pred_text": [r.get("pred_text", "") for r in turns], - "cer_turns": cer_turns, - "wer_turns": wer_turns, - "ssim_turns": ssim_turns, - "pred_gt_ssim_turns": pred_gt_ssim_turns, - "pred_context_ssim_turns": pred_context_ssim_turns, - "gt_context_ssim_turns": gt_context_ssim_turns, - "pred_gt_ssim_alternate_turns": pred_gt_ssim_alternate_turns, - "pred_context_ssim_alternate_turns": pred_context_ssim_alternate_turns, - "gt_context_ssim_alternate_turns": gt_context_ssim_alternate_turns, - "utmosv2_turns": utmosv2_turns, - "eou_type_turns": eou_type_turns, - "eou_trailing_duration_turns": eou_trailing_duration_turns, - "eou_trail_rms_ratio_turns": eou_trail_rms_ratio_turns, - "cer": avg(cer_turns), - "wer": avg(wer_turns), - "ssim": avg(ssim_turns), - "pred_gt_ssim": avg(pred_gt_ssim_turns), - "pred_context_ssim": avg(pred_context_ssim_turns), - "gt_context_ssim": avg(gt_context_ssim_turns), - "pred_gt_ssim_alternate": avg(pred_gt_ssim_alternate_turns), - "pred_context_ssim_alternate": avg(pred_context_ssim_alternate_turns), - "gt_context_ssim_alternate": avg(gt_context_ssim_alternate_turns), - "utmosv2": avg(utmosv2_turns), - "eou_error": None if not eou_type_turns or eou_type_turns[0] is None else float( - sum(1 for x in eou_type_turns if str(x).lower() != "good") / len(eou_type_turns) - ), - "total_gen_audio_seconds": sum(float(r.get("total_gen_audio_seconds", r.get("pred_audio_seconds", 0.0)) or 0.0) for r in turns), - } - - sample_rows.append(sample_row) - - # Sample-level output sorted by average CER descending. - sample_rows.sort( - key=lambda x: ( - x.get("cer") is not None, - _cer_sort_value(x), - ), - reverse=True, - ) - - jsonl_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.jsonl") - json_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.json") - csv_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.csv") - - write_jsonl(jsonl_path, sample_rows) - write_json(json_path, {"filewise_metrics": sample_rows}) - write_filewise_csv(csv_path, sample_rows) - - logging.info(f"Saved sample-level filewise metrics JSONL to: {jsonl_path}") - logging.info(f"Saved sample-level filewise metrics JSON to: {json_path}") - logging.info(f"Saved sample-level filewise metrics CSV to: {csv_path}") - - topk = min(int(args.filewise_metrics_topk_log), len(sample_rows)) - if topk > 0: - logging.info(f"Top {topk} worst CER samples:") - for row in sample_rows[:topk]: - logging.info( - "run_id=%s dataset_index=%s num_turns=%s cer=%s wer=%s ssim=%s path=%s" - % ( - row.get("run_id"), - row.get("dataset_index"), - row.get("num_turns"), - row.get("cer"), - row.get("wer"), - row.get("ssim"), - row.get("target_audio_path"), - ) - ) - - topk_turns = min(int(args.filewise_metrics_topk_log), len(turn_rows_sorted)) - if topk_turns > 0: - logging.info(f"Top {topk_turns} worst CER turns:") - for row in turn_rows_sorted[:topk_turns]: - logging.info( - "run_id=%s dataset_index=%s turn_id=%s cer=%s wer=%s ssim=%s path=%s text=%s" - % ( - row.get("run_id"), - row.get("dataset_index"), - row.get("turn_id"), - row.get("cer"), - row.get("wer"), - row.get("ssim"), - row.get("pred_audio_path"), - row.get("reference_text"), - ) - ) - - return sample_rows - -def compute_frechet_codec_distance_from_sample_rows(args, rows: List[Dict[str, Any]]): - """Compute FCD in the same spirit as MagpieTTS: GT audio vs predicted codec codes.""" - if args.disable_fcd: - return _nan() - if FrechetCodecDistance is None: - logging.warning("FrechetCodecDistance is unavailable; setting frechet_codec_distance to NaN.") - return _nan() - - gt_paths = [] - code_paths = [] - seen = set() - - for row in rows: - key = (int(row.get("run_id", 0)), int(row.get("dataset_index", -1))) - if key in seen: - continue - seen.add(key) - - gt_path = _resolve_audio_path(row.get("target_audio_path"), args.audio_dir) - code_path = row.get("sample_predicted_codes_path") or row.get("predicted_codes_path") - if gt_path and code_path and os.path.exists(gt_path) and os.path.exists(code_path): - gt_paths.append(gt_path) - code_paths.append(code_path) - - if not gt_paths: - logging.warning("No valid GT-audio/predicted-code pairs found for FCD; setting FCD to NaN.") - return _nan() - - device = _metric_device() - try: - fcd_metric = FrechetCodecDistance(codec_name=args.codec_model_path).to(device) - for gt_path, code_path in zip(gt_paths, code_paths): - fcd_metric.update_from_audio_file(gt_path, True) - predicted_codes = torch.load(code_path, map_location="cpu").unsqueeze(0).to(device) - predicted_codes_lens = torch.tensor([predicted_codes.size(-1)], dtype=torch.int, device=device) - fcd_metric.update(predicted_codes, predicted_codes_lens, False) - - fcd = fcd_metric.compute().detach().cpu().item() - fcd_metric.reset() - return float(fcd) - except Exception as e: - logging.warning(f"Frechet Codec Distance computation failed: {repr(e)}") - return _nan() - - -def compute_aggregates_from_filewise_rows(rows: List[Dict[str, Any]]): - """Aggregate over sample-level rows using the MagpieTTS evaluation metric set.""" - if len(rows) == 0: - out = compute_magpie_style_global_metrics([]) - out["cer"] = None - out["wer"] = None - out["ssim"] = None - out["num_samples"] = 0 - return out - - def avg_key(key): - return finite_avg([r.get(key) for r in rows]) - - out = { - "cer": avg_key("cer"), - "wer": avg_key("wer"), - "ssim": avg_key("ssim"), - "num_samples": len(rows), - "cer_filewise_avg": avg_key("cer"), - "wer_filewise_avg": avg_key("wer"), - "ssim_pred_gt_avg": avg_key("pred_gt_ssim"), - "ssim_pred_context_avg": avg_key("pred_context_ssim"), - "ssim_gt_context_avg": avg_key("gt_context_ssim"), - "ssim_pred_gt_avg_alternate": avg_key("pred_gt_ssim_alternate"), - "ssim_pred_context_avg_alternate": avg_key("pred_context_ssim_alternate"), - "ssim_gt_context_avg_alternate": avg_key("gt_context_ssim_alternate"), - "utmosv2_avg": avg_key("utmosv2"), - "total_gen_audio_seconds": sum(float(r.get("total_gen_audio_seconds", 0.0) or 0.0) for r in rows), - "frechet_codec_distance": _nan(), - } - - # Sample rows contain lists, so cumulative CER/WER are computed by flattening - # the normalized turn text lists. - pred_texts = [] - gt_texts = [] - for row in rows: - preds = row.get("pred_text", row.get("asr_hyp", [])) - refs = row.get("gt_text", row.get("reference_text", [])) - if not isinstance(preds, list): - preds = [preds] - if not isinstance(refs, list): - refs = [refs] - for pred, ref in zip(preds, refs): - ref = "" if ref is None else str(ref).strip() - pred = "" if pred is None else str(pred).strip() - if ref: - pred_texts.append(pred) - gt_texts.append(ref) - - if pred_texts and gt_texts: - try: - out["cer_cumulative"] = float(word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=True)[0]) - except Exception: - out["cer_cumulative"] = None - try: - out["wer_cumulative"] = float(word_error_rate_detail(hypotheses=pred_texts, references=gt_texts, use_cer=False)[0]) - except Exception: - out["wer_cumulative"] = None - else: - out["cer_cumulative"] = None - out["wer_cumulative"] = None - - out["cer_gt_audio_cumulative"] = _nan() - out["wer_gt_audio_cumulative"] = _nan() - - eou_types = [] - for row in rows: - values = row.get("eou_type_turns", []) - if isinstance(values, list): - eou_types.extend(values) - eou_types = [x for x in eou_types if x is not None] - - if eou_types: - counts = Counter(eou_types) - n = len(eou_types) - if EoUType is not None: - labels = list(EoUType.error_types()) - good_label = EoUType.GOOD - else: - labels = ["cutoff", "silence", "noise"] - good_label = "good" - for label in labels: - out[f"eou_{label}_rate"] = counts.get(label, 0) / n - out["eou_error_rate"] = 1.0 - counts.get(good_label, 0) / n - else: - out["eou_cutoff_rate"] = _nan() - out["eou_silence_rate"] = _nan() - out["eou_noise_rate"] = _nan() - out["eou_error_rate"] = _nan() - - return out - -def save_filewise_final_summary(args, filewise_rows: List[Dict[str, Any]]): - filewise_summary = compute_aggregates_from_filewise_rows(filewise_rows) - filewise_summary["frechet_codec_distance"] = compute_frechet_codec_distance_from_sample_rows(args, filewise_rows) - save_easymagpie_style_eval_outputs(args, filewise_rows, filewise_summary) - - obj = { - "aggregation": "mean_over_sample_metrics_each_sample_contains_turn_metric_lists", - **filewise_summary, - } - - path = os.path.join(args.out_dir, "metrics_final_filewise_average.json") - write_json(path, obj) - - sample_metrics_final_path = os.path.join(args.out_dir, "metrics_final_sample_average.json") - write_json(sample_metrics_final_path, obj) - - final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") - final_text = format_filewise_final_metric_text(filewise_summary) - write_text_atomic(final_txt_path, final_text) - - print("\n--- Final Sample-Averaged Evaluation Metrics ---", flush=True) - print(final_text, flush=True) - - logging.info(f"Filewise averaged final metrics saved to: {path}") - logging.info(f"Sample averaged metrics_final JSON saved to: {sample_metrics_final_path}") - logging.info(f"Final metrics TXT saved to: {final_txt_path}") - - return obj - - -# ----------------------------- -# Args / main -# ----------------------------- - - -def parse_args(): - parser = argparse.ArgumentParser(description="EasyMagpieTTS Multi-GPU Inference Evaluation") - - parser.add_argument("--checkpoint_path", type=str, required=True) - parser.add_argument("--codec_model_path", type=str, required=True) - parser.add_argument("--datasets_json_path", type=str, required=True) - parser.add_argument("--out_dir", type=str, required=True) - - parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) - parser.add_argument("--audio_dir", type=str, default=None) - parser.add_argument("--inference_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) - parser.add_argument("--debug_dtype", action="store_true") - parser.add_argument("--debug_gpu_assignment", action="store_true") - parser.add_argument("--use_librosa", action="store_true") - - parser.add_argument("--batch_size", type=int, default=6) - parser.add_argument("--num_workers", type=int, default=4) - parser.add_argument( - "--emulate_multiturn", - action="store_true", - help=( - "Group scalar single-turn JSONL rows by speaker into synthetic multiturn samples. " - "This replaces the older --num_turns behavior." - ), - ) - parser.add_argument( - "--emulate_multiturn_num_turns", - type=int, - default=1, - help="Number of scalar single-turn rows to group when --emulate_multiturn is enabled.", - ) - parser.add_argument("--max_eval_turns", type=int, default=6) - - parser.add_argument( - "--inference_mode", - type=str, - default="auto", - choices=["auto", "multiturn_user_audio", "single_turn"], - help=( - "auto selects multiturn_user_audio for samples with list text or user_audio_file_path, " - "and single_turn for classic scalar-text datasets such as LibriTTS. " - "single_turn does not prefill with user/silence audio and supports batch_size > 1." - ), - ) - - parser.add_argument("--user_custom_speaker_reference", action="store_true") - parser.add_argument("--inference_speaker_reference", type=str, default=None) - parser.add_argument("--language", type=str, default="en") - - parser.add_argument("--use_cfg", action="store_true") - parser.add_argument("--cfg_scale", type=float, default=2.5) - parser.add_argument("--temperature", type=float, default=0.7) - parser.add_argument("--topk", type=int, default=80) - parser.add_argument("--max_tts_steps", type=int, default=2000) - parser.add_argument("--normalize_volume", type=lambda x: str(x).lower() in ["true", "1", "yes"], default=True) - - parser.add_argument( - "--save_filewise_metrics", - action=argparse.BooleanOptionalAction, - default=True, - help="Save filewise metrics. Enabled by default. Use --no-save_filewise_metrics to disable.", - ) - parser.add_argument( - "--filewise_metrics_topk_log", - type=int, - default=20, - help="Number of worst CER samples to print on rank 0.", - ) - parser.add_argument( - "--num_eval_runs", - type=int, - default=1, - help="Repeat the full eval set N times. Repetitions are preserved in final filewise average.", - ) - parser.add_argument( - "--sort_by_text_token_count", - action="store_true", - help="Sort eval samples by total text token count before distributed sharding for better load balancing.", - ) - parser.add_argument( - "--metric_batch_size", - type=int, - default=8, - help="Batch size used for post-generation ASR/SSIM metric computation.", - ) - parser.add_argument( - "--max_metric_audio_sec", - type=float, - default=120.0, - help="Clamp generated audio length used for ASR/SSIM metrics to avoid metric OOM/hangs.", - ) - parser.add_argument( - "--asr_model_name", - type=str, - default="nvidia/parakeet-tdt-1.1b", - help="Pretrained ASR model used for CER/WER, matching the EasyMagpie/MagpieTTS eval default.", - ) - parser.add_argument( - "--sv_model_type", - type=str, - default="titanet", - choices=["titanet", "wavlm"], - help="Speaker verification model type for MagpieTTS-style SSIM metrics.", - ) - parser.add_argument( - "--disable_speaker_metrics", - action="store_true", - help="Disable pred/GT/context speaker similarity metrics.", - ) - parser.add_argument( - "--disable_utmosv2", - action="store_true", - help="Disable UTMOSv2. By default UTMOSv2 is computed when the dependency is available.", - ) - parser.add_argument( - "--disable_eou", - action="store_true", - help="Disable end-of-utterance classification metrics.", - ) - parser.add_argument( - "--disable_fcd", - action="store_true", - help="Disable Frechet Codec Distance. By default FCD is computed from saved predicted codec codes.", - ) - parser.add_argument( - "--eou_model_name", - type=str, - default="facebook/wav2vec2-base-960h", - help="Hugging Face model id or local path for the EOU classifier.", - ) - parser.add_argument( - "--eou_batch_size", - type=int, - default=32, - help="Batch size for EOU classification.", - ) - - parser.add_argument( - "--save_plots", - action=argparse.BooleanOptionalAction, - default=True, - help="Save EasyMagpie/MagpieTTS-style violin plots. Enabled by default.", - ) - parser.add_argument( - "--violin_plot_metrics", - type=str, - nargs="*", - default=list(DEFAULT_VIOLIN_METRICS), - help="Metrics to include in violin plots.", - ) - - return parser.parse_args() - - -def main(): - args = parse_args() - - os.makedirs(args.out_dir, exist_ok=True) - os.makedirs(get_audio_out_dir(args), exist_ok=True) - os.makedirs(get_generated_turn_audio_dir(args), exist_ok=True) - os.makedirs(get_context_metric_audio_dir(args), exist_ok=True) - os.makedirs(get_predicted_codes_dir(args), exist_ok=True) - - distributed, rank, local_rank, world_size, device_index = setup_distributed() - - if args.inference_mode == "multiturn_user_audio" and args.batch_size != 1: - raise RuntimeError( - "--inference_mode multiturn_user_audio requires --batch_size=1 per process. " - "Use multiple GPUs/processes for parallelism instead of increasing batch_size." - ) - - if args.num_eval_runs <= 0: - raise RuntimeError("--num_eval_runs must be >= 1.") - - if args.emulate_multiturn and args.emulate_multiturn_num_turns <= 1: - raise RuntimeError("--emulate_multiturn_num_turns must be > 1 when --emulate_multiturn is enabled.") - - target_device = torch.device(f"cuda:{device_index}" if torch.cuda.is_available() and device_index >= 0 else "cpu") - target_dtype = getattr(torch, args.inference_dtype) - torch.set_default_dtype(target_dtype) - - hostname = socket.gethostname() - cuda_name = torch.cuda.get_device_name(target_device) if torch.cuda.is_available() and device_index >= 0 else "cpu" - - all_rank_print( - rank, - f"host={hostname} local_rank={local_rank} world_size={world_size} " - f"device={target_device} device_name={cuda_name}", - ) - - model = build_model_and_codec(args, target_device, target_dtype) - codec_sil_codes = model.codec_sil_codes - - if args.debug_dtype: - handles, stats, examples = attach_dtype_counter(model) - else: - handles = stats = examples = None - - emulate_multiturn_num_turns = args.emulate_multiturn_num_turns if args.emulate_multiturn else 1 - full_eval_dataset = EvalJSONLDataset( - args.datasets_json_path, - emulate_multiturn_num_turns=emulate_multiturn_num_turns, - ) - # debug - # full_eval_dataset.samples = full_eval_dataset.samples[:7] - - if args.sort_by_text_token_count: - full_eval_dataset = SortedByTextTokenCountDataset( - full_eval_dataset, - model=model, - max_eval_turns=args.max_eval_turns, - descending=True, - ) - - collate_fn = partial( - collate_and_tokenize_custom, - model=model, - sample_rate=model.sample_rate, - root_path=args.audio_dir, - normalize_audio_volume=args.normalize_volume, - use_librosa=args.use_librosa, - max_eval_turns=args.max_eval_turns, - inference_mode=args.inference_mode, - ) - - speaker_wav = load_speaker_wav_if_needed(args, model, target_dtype) - - generation_start = time.time() - all_metric_items = [] - total_batches = 0 - total_generated_samples = 0 - - for run_id in range(args.num_eval_runs): - if distributed: - sampler = DistributedSampler( - full_eval_dataset, - num_replicas=world_size, - rank=rank, - shuffle=False, - drop_last=False, - ) - sampler.set_epoch(run_id) - else: - sampler = SequentialSampler(full_eval_dataset) - - if args.debug_gpu_assignment: - try: - assigned_indices = list(iter(sampler)) - assigned_dataset_indices = [ - int(full_eval_dataset[i].get("__dataset_index__", -1)) for i in assigned_indices - ] - all_rank_print( - rank, - f"run_id={run_id} assigned {len(assigned_dataset_indices)} / {len(full_eval_dataset)} " - f"samples to gpu={local_rank}: dataset_indices={assigned_dataset_indices}", - ) - except Exception as e: - all_rank_print(rank, f"Could not print assigned indices: {repr(e)}") - - dataloader = DataLoader( - dataset=full_eval_dataset, - batch_size=args.batch_size, - sampler=sampler, - collate_fn=collate_fn, - num_workers=args.num_workers, - pin_memory=True, - drop_last=False, - ) - - for batch_id, inputs in enumerate(dataloader): - total_batches += 1 - batch_indices = inputs.get("dataset_indices", []) - total_generated_samples += len(batch_indices) - - if args.debug_gpu_assignment: - all_rank_print( - rank, - f"run_id={run_id} gpu={local_rank} batch_id={batch_id} " - f"dataset_indices={batch_indices} text_token_counts={inputs.get('text_token_counts', [])} " - f"target_paths={inputs.get('target_audio_paths', [])}", - ) - - inputs = prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=speaker_wav) - - finalize_output, multiturn_turn_frame_ranges, multiturn_decode_start_frame, generated_codes = run_generation( - model=model, - inputs=inputs, - args=args, - codec_sil_codes=codec_sil_codes, - ) - - metric_items = save_generation_outputs_and_build_metric_items( - model=model, - inputs=inputs, - finalize_output=finalize_output, - multiturn_turn_frame_ranges=multiturn_turn_frame_ranges, - multiturn_decode_start_frame=multiturn_decode_start_frame, - generated_codes=generated_codes, - args=args, - rank=rank, - run_id=run_id, - ) - all_metric_items.extend(metric_items) - - if args.debug_dtype and batch_id == 0 and run_id == 0: - report_dtype_stats(handles, stats, examples, rank=rank) - - generation_elapsed = time.time() - generation_start - - # Save pre-metric manifest for debugging and restartability. - metric_manifest_path = os.path.join(args.out_dir, f"metric_items_rank{rank:04d}.jsonl") - write_jsonl(metric_manifest_path, all_metric_items) - - # Free TTS/codec model memory before loading ASR and speaker encoder metrics. - del model - if speaker_wav is not None: - del speaker_wav - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - all_rank_print( - rank, - f"generation done: batches={total_batches} generated_samples_with_sampler_padding={total_generated_samples} " - f"metric_items={len(all_metric_items)} elapsed_sec={generation_elapsed:.2f}. " - "Loading ASR/SSIM metrics now.", - ) - - rank_metrics, rank_filewise_rows = compute_metrics_after_generation( - args=args, - rank=rank, - world_size=world_size, - metric_items=all_metric_items, - ) - rank_metrics["generation_elapsed_sec"] = float(generation_elapsed) - rank_metrics["num_generated_samples_with_sampler_padding"] = int(total_generated_samples) - - rank_metrics = compute_and_save_rank_metrics_file(args, rank_metrics, rank) - all_rank_print(rank, f"saved rank metrics: {json.dumps(rank_metrics, sort_keys=True)}") - - if args.save_filewise_metrics: - rank_filewise_rows.sort( - key=lambda x: ( - x.get("cer") is not None, - float(x.get("cer")) if x.get("cer") is not None else -1.0, - ), - reverse=True, - ) - - rank_filewise_path = os.path.join(args.out_dir, f"filewise_metrics_rank{rank:04d}.jsonl") - write_jsonl(rank_filewise_path, rank_filewise_rows) - all_rank_print(rank, f"saved filewise metrics: {rank_filewise_path}") - - if rank == 0: - wait_for_rank_metric_files(args, world_size) - - merge_metrics_on_rank0(args, rank, world_size) - - if args.save_filewise_metrics: - if rank == 0: - wait_for_rank_filewise_metric_files(args, world_size) - - filewise_rows = merge_filewise_metrics_on_rank0(args, rank, world_size) - - if rank == 0: - save_filewise_final_summary(args, filewise_rows) - - cleanup_distributed() - - -if __name__ == "__main__": - main() diff --git a/examples/tts/easy_magpietts_inference_multiturn_multigpu.py b/examples/tts/easy_magpietts_inference_multiturn_multigpu.py deleted file mode 100644 index fbe5a24e4712..000000000000 --- a/examples/tts/easy_magpietts_inference_multiturn_multigpu.py +++ /dev/null @@ -1,1992 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -""" -Multi-GPU evaluation script for custom EasyMagpieTTS models. - -Properties: - - One process per GPU with torch.distributed when WORLD_SIZE > 1. - - Uses DistributedSampler(drop_last=False). If len(dataset) is not divisible by - world_size, PyTorch may repeat a few samples so all ranks have equal work. - This avoids last-rank/last-batch distributed hangs. Repeated samples are - deduplicated from filewise final metrics. - - Optional --num_eval_runs N repeats the same eval set N times and reports - run-averaged filewise metrics when --save_filewise_metrics is enabled. - - Optional --sort_by_text_token_count orders samples by total text token count - so each GPU step receives similarly-sized examples. By default it sorts - ascending, so DistributedSampler padding repeats short examples. - - profile_multiturn_inference remains batch_size=1 per rank, but runs in - parallel across ranks/GPUs. - - Saves global metrics in out_dir: - metrics_rankXXXX.json - metrics_final.json - metrics_final.txt - - Saves generated audio files in: - out_dir/audios/ - - Optional filewise metrics: - --save_filewise_metrics - Saves: - filewise_metrics_rankXXXX.jsonl - filewise_metrics_sorted_by_cer.jsonl - filewise_metrics_sorted_by_cer.csv - metrics_final_filewise_average.json - The merged filewise outputs deduplicate repeated DistributedSampler samples - by dataset_index. - - Prints the final text metric summary on rank 0: - Average CER: value - Average WER: value - SECS: value - -Recommended torchrun: - CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ - torchrun --nproc_per_node=8 easy_magpietts_inference_multiturn_multigpu.py ... - -Recommended SLURM/srun: - srun --ntasks-per-node=8 --gpus-per-task=1 --gpu-bind=single:1 \ - python easy_magpietts_inference_multiturn_multigpu.py ... -""" - -import argparse -import csv -import json -import os -import socket -import time -from copy import deepcopy -from functools import partial -from typing import Any, Dict, List - -import librosa -import soundfile as sf -import torch -import torch.distributed as dist -from omegaconf import open_dict -from torch.nn.utils.rnn import pad_sequence -from torch.utils.data import DataLoader, Dataset, DistributedSampler, SequentialSampler - -from nemo.collections.audio.parts.utils.transforms import resample -from nemo.collections.asr.metrics.wer import word_error_rate -from whisper_normalizer.english import EnglishTextNormalizer -from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility -from nemo.collections.speechlm2.parts.metrics.secs import SECS -from nemo.collections.speechlm2.parts.precision import fp32_precision -from nemo.collections.tts.models import AudioCodecModel -from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel -from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter -from nemo.collections.tts.modules.magpietts_modules import CodecHelper -from nemo.collections.tts.parts.utils.tts_dataset_utils import normalize_volume -from nemo.utils import logging - - -torch.set_float32_matmul_precision("medium") -torch.backends.cudnn.allow_tf32 = True -torch.backends.cuda.matmul.allow_tf32 = True - - -def get_rank_info(): - world_size = int(os.environ.get("WORLD_SIZE", "1")) - rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", "0"))) - local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))) - distributed = world_size > 1 - return distributed, rank, local_rank, world_size - - -def get_visible_device_index(local_rank: int) -> int: - if not torch.cuda.is_available(): - return -1 - ndev = torch.cuda.device_count() - if ndev <= 0: - return -1 - return local_rank % ndev - - -def setup_distributed(): - distributed, rank, local_rank, world_size = get_rank_info() - device_index = get_visible_device_index(local_rank) - - if torch.cuda.is_available() and device_index >= 0: - torch.cuda.set_device(device_index) - - if distributed and not dist.is_initialized(): - dist.init_process_group(backend="nccl") - dist.barrier() - - return distributed, rank, local_rank, world_size, device_index - - -def cleanup_distributed(): - if dist.is_available() and dist.is_initialized(): - dist.barrier() - dist.destroy_process_group() - - -def all_rank_print(rank: int, msg: str): - print(f"[rank={rank}] {msg}", flush=True) - - -def rank0_print(rank: int, msg: str): - if rank == 0: - print(msg, flush=True) - - -def get_audio_out_dir(args) -> str: - return os.path.join(args.out_dir, "audios") - - -def torch_rms_norm(wav: torch.Tensor, db_level: float = -27.0) -> torch.Tensor: - denom = torch.sum(wav**2) - if denom <= 0: - return wav - r = 10 ** (db_level / 20) - a = torch.sqrt((wav.size(-1) * (r**2)) / denom) - return wav * a - - -def scalarize_metric_value(v: Any): - if torch.is_tensor(v): - if v.numel() == 1: - return float(v.detach().cpu().item()) - return v.detach().cpu().tolist() - - try: - import numpy as np - - if isinstance(v, np.generic): - return float(v.item()) - except Exception: - pass - - if isinstance(v, (int, float, str, bool)) or v is None: - return v - - return str(v) - - -def metric_dict_to_jsonable(d: Dict[str, Any]) -> Dict[str, Any]: - return {str(k): scalarize_metric_value(v) for k, v in d.items()} - - -def write_json(path: str, obj: Dict[str, Any]): - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - with open(tmp_path, "w", encoding="utf-8") as f: - json.dump(obj, f, indent=2, sort_keys=True) - os.replace(tmp_path, path) - - -def write_text_atomic(path: str, text: str): - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - with open(tmp_path, "w", encoding="utf-8") as f: - f.write(text) - os.replace(tmp_path, path) - - -def write_jsonl(path: str, rows: List[Dict[str, Any]]): - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - with open(tmp_path, "w", encoding="utf-8") as f: - for row in rows: - f.write(json.dumps(row, sort_keys=True) + "\n") - os.replace(tmp_path, path) - - -def write_filewise_csv(path: str, rows: List[Dict[str, Any]]): - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - - fieldnames = [ - "run_id", - "dataset_index", - "rank", - "cer", - "wer", - "secs", - "pred_audio_seconds", - "target_audio_path", - "reference_text", - "asr_hyp", - ] - - with open(tmp_path, "w", encoding="utf-8", newline="") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - for row in rows: - writer.writerow({k: row.get(k, None) for k in fieldnames}) - - os.replace(tmp_path, path) - - -def get_first_metric(metrics: Dict[str, Any], names: List[str], default=None): - for name in names: - if name in metrics: - return metrics[name] - return default - - -def safe_metric_scalar(metric_dict: Dict[str, Any], preferred_keys: List[str]): - for key in preferred_keys: - if key in metric_dict: - value = metric_dict[key] - if torch.is_tensor(value): - return float(value.detach().cpu().item()) - return float(value) - return None - - -def format_final_metric_text(final_metrics: Dict[str, Any]) -> str: - intelligibility = final_metrics.get("intelligibility", {}) - secs = final_metrics.get("secs", {}) - - cer = get_first_metric(intelligibility, ["cer", "cer_dataset"]) - wer = get_first_metric(intelligibility, ["wer", "wer_dataset"]) - secs_value = get_first_metric(secs, ["secs", "secs_dataset"]) - - def fmt(x): - if x is None: - return "nan" - try: - return f"{float(x):.10f}" - except Exception: - return str(x) - - return ( - f"Average CER: {fmt(cer)}\n" - f"Average WER: {fmt(wer)}\n" - f"SECS: {fmt(secs_value)}\n" - ) - - -def format_filewise_final_metric_text(filewise_summary: Dict[str, Any]) -> str: - def fmt(x): - if x is None: - return "nan" - try: - return f"{float(x):.10f}" - except Exception: - return str(x) - - return ( - f"Average CER: {fmt(filewise_summary.get('cer'))}\n" - f"Average WER: {fmt(filewise_summary.get('wer'))}\n" - f"SECS: {fmt(filewise_summary.get('secs'))}\n" - ) - - -def attach_dtype_counter(model): - handles = [] - stats = {} - examples = {} - - def is_leaf(module): - return len(list(module.children())) == 0 - - def get_dtype(x): - if torch.is_tensor(x): - return str(x.dtype) - if isinstance(x, (list, tuple)): - for t in x: - if torch.is_tensor(t): - return str(t.dtype) - return "other" - - def get_module_group(name): - return name.split(".")[0] if "." in name else name - - def hook_fn(name): - def fn(module, inputs, outputs): - dtype = get_dtype(outputs) - if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: - dtype = "other" - group = get_module_group(name) - if group not in stats: - stats[group] = { - "torch.float16": 0, - "torch.bfloat16": 0, - "torch.float32": 0, - "other": 0, - } - examples[group] = { - "torch.float16": [], - "torch.bfloat16": [], - "torch.float32": [], - "other": [], - } - stats[group][dtype] += 1 - if len(examples[group][dtype]) < 3: - examples[group][dtype].append(module.__class__.__name__) - - return fn - - for name, module in model.named_modules(): - if is_leaf(module): - handles.append(module.register_forward_hook(hook_fn(name))) - return handles, stats, examples - - -def report_dtype_stats(handles, stats, examples, rank=0): - for h in handles: - h.remove() - logging.info(f"[rank={rank}] === DTYPE USAGE PER MODULE ===") - for group, group_stats in stats.items(): - total = sum(group_stats.values()) - if total == 0: - continue - logging.info(f"[rank={rank}] --- {group} ---") - for dtype, count in group_stats.items(): - if count > 0: - logging.info(f"[rank={rank}] {dtype}: {count} ({100 * count / total:.2f}%)") - logging.info(f"[rank={rank}] === DTYPE EXAMPLES ===") - for group, group_examples in examples.items(): - for dtype, mods in group_examples.items(): - if mods: - logging.info(f"[rank={rank}] {group} {dtype}: {mods}") - - -def _combined_audio_name(first_audio_filepath: str, paths: List[str]) -> str: - base_names = [os.path.splitext(os.path.basename(p))[0] for p in paths if p] - ext = os.path.splitext(paths[-1])[1] if paths and paths[-1] else "" - combined_name = "_".join(base_names) + ext - dir_name = os.path.dirname(first_audio_filepath) - return os.path.join(dir_name, combined_name) if dir_name else combined_name - - -class EvalJSONLDataset(Dataset): - def __init__(self, file_path: str, num_turns: int = 1): - self.samples = [] - raw_samples = [] - - with open(file_path, "r", encoding="utf-8") as f: - for line_idx, line in enumerate(f, 1): - line = line.strip() - if not line: - continue - try: - sample = json.loads(line) - sample["__dataset_index__"] = len(raw_samples) - raw_samples.append(sample) - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON on line {line_idx}: {e}") - - if num_turns <= 1: - self.samples = raw_samples - return - - single_turn_by_speaker = {} - for sample in raw_samples: - if isinstance(sample["text"], list): - self.samples.append(sample) - else: - speaker = sample.get("speaker", "unknown") - single_turn_by_speaker.setdefault(speaker, []).append(sample) - - synthetic_index = len(raw_samples) - for _, speaker_samples in single_turn_by_speaker.items(): - buffer_texts, buffer_paths = [], [] - first_sample_meta = None - - for sample in speaker_samples: - if not buffer_texts: - first_sample_meta = dict(sample) - - buffer_texts.append(sample["text"]) - buffer_paths.append(sample.get("audio_filepath", "")) - - if len(buffer_texts) == num_turns: - first_sample_meta["text"] = buffer_texts - first_sample_meta["audio_filepath"] = _combined_audio_name( - first_sample_meta.get("audio_filepath", ""), - buffer_paths, - ) - first_sample_meta["__dataset_index__"] = synthetic_index - synthetic_index += 1 - self.samples.append(first_sample_meta) - buffer_texts, buffer_paths, first_sample_meta = [], [], None - - if buffer_texts and first_sample_meta is not None: - first_sample_meta["text"] = buffer_texts - first_sample_meta["audio_filepath"] = _combined_audio_name( - first_sample_meta.get("audio_filepath", ""), - buffer_paths, - ) - first_sample_meta["__dataset_index__"] = synthetic_index - synthetic_index += 1 - self.samples.append(first_sample_meta) - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - return self.samples[idx] - - -def _sample_text_segments_for_count(sample: Dict[str, Any], max_eval_turns=None) -> List[str]: - text_data = sample.get("text", "") - if isinstance(text_data, list): - segments = text_data - if max_eval_turns is not None: - segments = segments[: int(max_eval_turns)] - return [str(x) for x in segments] - return [str(text_data)] - - -def estimate_text_token_count(sample: Dict[str, Any], model, max_eval_turns=None) -> int: - """Approximate generation cost by summing text tokenizer lengths over all turns.""" - main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] - total = 0 - for segment in _sample_text_segments_for_count(sample, max_eval_turns=max_eval_turns): - total += len(model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name)) + 1 # +EOS - return int(total) - - -class SortedByTextTokenCountDataset(Dataset): - """ - Dataset wrapper that orders examples by total text-token count. - - With DistributedSampler(shuffle=False), rank r sees positions: - r, r + world_size, r + 2 * world_size, ... - - If the wrapper is sorted by length descending, then each GPU step gets a - block of examples with similar token lengths. This usually reduces - straggler effects for autoregressive/profile inference. - """ - - def __init__(self, dataset: Dataset, model, max_eval_turns=None, descending: bool = True): - self.dataset = dataset - scored = [] - for i in range(len(dataset)): - sample = dict(dataset[i]) - token_count = estimate_text_token_count(sample, model=model, max_eval_turns=max_eval_turns) - sample["__text_token_count__"] = int(token_count) - scored.append((token_count, i, sample)) - - scored.sort(key=lambda x: (x[0], -x[1]), reverse=bool(descending)) - self.indices = [i for _, i, _ in scored] - self.token_counts = {i: int(tok) for tok, i, _ in scored} - - def __len__(self): - return len(self.indices) - - def __getitem__(self, local_idx): - original_idx = self.indices[local_idx] - sample = dict(self.dataset[original_idx]) - sample["__text_token_count__"] = self.token_counts[original_idx] - return sample - - -def _resolve_audio_path(path, root_path): - if path is None: - return None - if root_path is not None and not os.path.isabs(path): - return os.path.join(root_path, path) - return path - - -def _load_audio(path, sample_rate, normalize=True, use_librosa=False): - if path is None or not os.path.exists(path): - return torch.zeros(1, dtype=torch.float32) - - if use_librosa: - wav, sr = librosa.load(path, sr=sample_rate, mono=True) - if normalize: - wav = normalize_volume(wav) - return torch.as_tensor(wav, dtype=torch.float32) - - wav, sr = sf.read(path, dtype="float32") - if wav.ndim > 1: - wav = wav.mean(axis=1) - - if normalize: - wav = normalize_volume(wav) - - wav = torch.as_tensor(wav, dtype=torch.float32).unsqueeze(0) - return resample(wav, sr, sample_rate).squeeze(0) - - -def collate_and_tokenize_custom( - batch, - model, - extra_duration_thrshould=1.3, - sample_rate=22050, - root_path=None, - emulate_duplex_inference=False, - add_interruption_token=False, - pad_factor_text_speech=10, - force_interruption=False, - normalize_audio_volume=True, - use_librosa=False, - profile_multiturn_inference=False, - max_eval_turns=None, -): - main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] - - if max_eval_turns is not None: - max_eval_turns = int(max_eval_turns) - if max_eval_turns <= 0: - raise ValueError("--max_eval_turns must be > 0 when provided.") - - truncated_batch = [] - for s in batch: - s = dict(s) - if isinstance(s["text"], list): - s["text"] = s["text"][:max_eval_turns] - if isinstance(s.get("user_audio_file_path"), list): - s["user_audio_file_path"] = s["user_audio_file_path"][:max_eval_turns] - truncated_batch.append(s) - batch = truncated_batch - - is_profile = profile_multiturn_inference - is_duplex = emulate_duplex_inference and not is_profile - - out_dict = { - "duplex_multiturn": is_duplex, - "regular_multiturn": (not is_duplex) and (not is_profile), - "profile_multiturn": is_profile, - "dataset_indices": [int(s.get("__dataset_index__", -1)) for s in batch], - "text_token_counts": [int(s.get("__text_token_count__", -1)) for s in batch], - } - - tokenized_list = [] - batched_turns = [] - batched_turn_lens = [] - valid_turn_masks = [] - - if is_duplex: - for s in batch: - text_data = s["text"] - - if isinstance(text_data, list): - full_ids = [] - for segment in text_data: - seg_ids = model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name) + [model.eos_id] - pad_ids = [model.pad_id] * (len(seg_ids) * pad_factor_text_speech) - - if force_interruption: - fname = s["audio_filepath"] - no_ext = fname.split(".")[0] - sample_id = int(no_ext.split("_")[-1]) - case = sample_id % 3 - - if case == 0: - if len(seg_ids) >= 2: - seg_ids[-2] = model.interruption_token_id - seg_ids[-1] = model.pad_id - else: - pad_ids[0] = model.interruption_token_id - elif case == 1: - eos_idx = min(6, len(pad_ids) - 1) - pad_ids[eos_idx] = model.interruption_token_id - else: - pad_ids[0] = model.interruption_token_id - - elif add_interruption_token: - eos_idx = int(len(pad_ids) * 0.7) - pad_ids[eos_idx] = model.interruption_token_id - - full_ids.extend(seg_ids) - full_ids.extend(pad_ids) - - tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long)) - else: - tokenized_list.append( - torch.as_tensor( - model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id], - dtype=torch.long, - ) - ) - - prefix = torch.full((25,), model.pad_id, dtype=torch.long) - tokenized_list = [torch.cat([prefix, x]) for x in tokenized_list] - out_dict["input_lengths"] = torch.tensor([len(x) for x in tokenized_list], dtype=torch.long) - out_dict["input_ids"] = pad_sequence(tokenized_list, batch_first=True, padding_value=model.pad_id) - - else: - max_turns = 1 - for s in batch: - if isinstance(s["text"], list): - max_turns = max(max_turns, len(s["text"])) - - for t in range(max_turns): - turn_t_tokens, turn_t_lens, turn_t_valid = [], [], [] - - for s in batch: - text_data = s["text"] - - if isinstance(text_data, list): - if t < len(text_data): - seg_ids = model.tokenizer.encode(text_data[t], tokenizer_name=main_tokenizer_name) + [ - model.eos_id - ] - turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) - turn_t_lens.append(len(seg_ids)) - turn_t_valid.append(True) - else: - turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) - turn_t_lens.append(1) - turn_t_valid.append(False) - else: - if t == 0: - seg_ids = model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [ - model.eos_id - ] - turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) - turn_t_lens.append(len(seg_ids)) - turn_t_valid.append(True) - else: - turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) - turn_t_lens.append(1) - turn_t_valid.append(False) - - batched_turns.append(pad_sequence(turn_t_tokens, batch_first=True, padding_value=model.pad_id)) - batched_turn_lens.append(torch.tensor(turn_t_lens, dtype=torch.long)) - valid_turn_masks.append(torch.tensor(turn_t_valid, dtype=torch.bool)) - - out_dict["batched_turns"] = batched_turns - out_dict["batched_turn_lens"] = batched_turn_lens - out_dict["valid_turn_masks"] = valid_turn_masks - - audio_list, audio_lengths, target_num_frames = [], [], [] - max_turns_for_user_audio = len(batched_turns) if not is_duplex else 0 - - if is_profile and max_turns_for_user_audio > 0: - user_audio_by_turn = [[] for _ in range(max_turns_for_user_audio)] - user_audio_lens_by_turn = [[] for _ in range(max_turns_for_user_audio)] - else: - user_audio_by_turn, user_audio_lens_by_turn = [], [] - - for i, s in enumerate(batch): - audio_path = _resolve_audio_path(s.get("context_audio_filepath"), root_path) - wav = _load_audio(audio_path, sample_rate, normalize=normalize_audio_volume, use_librosa=use_librosa) - audio_list.append(wav) - audio_lengths.append(len(wav)) - - if is_profile and max_turns_for_user_audio > 0: - user_audio_paths = s.get("user_audio_file_path", None) - - for t in range(max_turns_for_user_audio): - has_valid_text_turn = (isinstance(s["text"], list) and t < len(s["text"])) or ( - not isinstance(s["text"], list) and t == 0 - ) - - if ( - isinstance(user_audio_paths, list) - and t < len(user_audio_paths) - and user_audio_paths[t] - and has_valid_text_turn - ): - ua_path = _resolve_audio_path(user_audio_paths[t], root_path) - ua_wav = _load_audio( - ua_path, - sample_rate=sample_rate, - normalize=normalize_audio_volume, - use_librosa=use_librosa, - ) - else: - ua_wav = torch.zeros(int(2 * sample_rate), dtype=torch.float32) - - user_audio_by_turn[t].append(ua_wav) - user_audio_lens_by_turn[t].append(len(ua_wav)) - - tdur_audio_path = _resolve_audio_path(s["audio_filepath"], root_path) - - if tdur_audio_path and os.path.exists(tdur_audio_path): - wav_dur = _load_audio( - tdur_audio_path, - sample_rate, - normalize=normalize_audio_volume, - use_librosa=use_librosa, - ) - tdur = wav_dur.shape[0] // model.input_samples_per_frame - target_num_frames.append(tdur * extra_duration_thrshould) - else: - if is_duplex: - current_text_len = len(tokenized_list[i]) - target_num_frames.append(current_text_len if isinstance(s["text"], list) else current_text_len * 5) - else: - target_num_frames.append(sum([l[i].item() for l in batched_turn_lens]) * 5) - - max_audio_len = max(audio_lengths) - B = len(audio_lengths) - padded_audio = torch.zeros((B, max_audio_len), dtype=torch.float32) - - for i, wav in enumerate(audio_list): - padded_audio[i, : len(wav)] = wav - - if is_profile and max_turns_for_user_audio > 0: - padded_user_audio_turns, padded_user_audio_turns_lens = [], [] - - for t in range(max_turns_for_user_audio): - turn_lens = user_audio_lens_by_turn[t] - max_turn_audio_len = max(turn_lens) - padded_turn_audio = torch.zeros((B, max_turn_audio_len), dtype=torch.float32) - - for i, wav in enumerate(user_audio_by_turn[t]): - padded_turn_audio[i, : len(wav)] = wav - - padded_user_audio_turns.append(padded_turn_audio) - padded_user_audio_turns_lens.append(torch.tensor(turn_lens, dtype=torch.long)) - - out_dict["user_audio_turns"] = padded_user_audio_turns - out_dict["user_audio_turns_lens"] = padded_user_audio_turns_lens - - out_dict["context_audio"] = padded_audio - out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) - out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] - out_dict["target_num_frames"] = target_num_frames - out_dict["raw_text"] = [" ".join(s["text"]) if isinstance(s["text"], list) else s["text"] for s in batch] - - return out_dict - - -def build_model_and_codec(args, target_device, target_dtype): - model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) - - with open_dict(model_cfg): - model_cfg.target = "nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel" - model_cfg.codecmodel_path = args.codec_model_path - model_cfg.train_ds = None - model_cfg.validation_ds = None - model_cfg.run_val_inference = False - model_cfg.use_utmos = False - model_cfg.use_meta_init_for_decoder = True - - if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: - model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path - - model = EasyMagpieTTSInferenceModel.restore_from( - args.checkpoint_path, - override_config_path=model_cfg, - map_location=torch.device("cpu"), - ) - model.use_kv_cache_for_inference = True - model.to(dtype=target_dtype) - model.eval().to(target_device) - - model.input_samples_per_frame = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) - model.target_samples_per_frame = model.input_samples_per_frame / (model.sample_rate / model.output_sample_rate) - - codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=torch.device("cpu")) - if hasattr(codec_model, "discriminator"): - del codec_model.discriminator - codec_model.freeze() - codec_model = codec_model.to(target_device).eval() - - codec_converter = None - if getattr(model, "_codec_converter", None) is not None: - vq_new = deepcopy(model._codec_converter.vector_quantizer_new).to(target_device).eval() - codec_converter = VectorQuantizerIndexConverter( - vector_quantizer_original=codec_model.vector_quantizer, - vector_quantizer_new=vq_new, - ).to(target_device).eval() - - model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) - model._generate_codec_silence_buffer() - - return model - - -def prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=None): - B = inputs["context_audio"].size(0) - device = model.device - - inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) - inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) - - if args.user_custom_speaker_reference and speaker_wav is not None: - inputs["context_audio"] = speaker_wav.repeat(B, 1).detach() - inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) - - if "user_audio_turns" in inputs: - inputs["user_audio_turns"] = [x.to(device, dtype=target_dtype) for x in inputs["user_audio_turns"]] - inputs["user_audio_turns_lens"] = [x.to(device) for x in inputs["user_audio_turns_lens"]] - - return inputs - - -def run_generation(model, inputs, args, codec_sil_codes): - B = inputs["context_audio"].size(0) - device = model.device - profile_turn_frame_ranges = [] - profile_decode_start_frame = 0 - - with torch.inference_mode(): - wav = inputs["context_audio"] - wav_len = inputs["context_audio_lengths"] - codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) - - use_lang = bool(getattr(model, "add_language_to_context_text", False)) - ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" - ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) - - ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) - ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) - - state = model.streaming_init( - context_audio_codes=codes, - context_audio_codes_lens=codes_lens, - context_text_tokens=ctx_toks, - context_text_tokens_lens=ctx_toks_lens, - use_cfg=args.use_cfg, - cfg_scale=args.cfg_scale, - use_local_transformer=True, - temperature=args.temperature, - topk=args.topk, - phoneme_input_type="pred", - phoneme_sampling_method="argmax", - use_inference_mode=True, - ) - - if inputs["duplex_multiturn"]: - text = inputs["input_ids"].to(device) - text_lens = inputs["input_lengths"].to(device) - - in_initial_silence = torch.ones(B, dtype=torch.bool, device=device) - in_post_speech_silence = torch.zeros(B, dtype=torch.bool, device=device) - - text_exhausted = state.text_tokens_seen >= text_lens - - while not text_exhausted.all(): - state.finished = state.finished & text_exhausted - state.text_finished = state.text_finished & text_exhausted - - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended = state.phoneme_stream_ended & text_exhausted - - positions = state.text_tokens_seen.clamp(max=text.size(1) - 1) - current_tokens = text[torch.arange(B, device=device), positions] - current_tokens = torch.where( - text_exhausted, - torch.full_like(current_tokens, model.eos_id), - current_tokens, - ) - - is_pad_or_eos = (current_tokens == model.pad_id) | (current_tokens == model.eos_id) - in_initial_silence = in_initial_silence & is_pad_or_eos - in_post_speech_silence = in_post_speech_silence & is_pad_or_eos - - state, audio_codes, _ = model.streaming_step( - state=state, - text_tokens=current_tokens, - use_inference_mode=True, - ) - - if audio_codes is not None and args.force_speech_sil_codes: - force_silence_mask = in_initial_silence | in_post_speech_silence - if force_silence_mask.any(): - expanded_sil = codec_sil_codes.view(1, -1, 1).expand_as(audio_codes) - mask_3d = force_silence_mask.view(B, 1, 1) - state.all_predictions[-1] = torch.where(mask_3d, expanded_sil, audio_codes) - - in_post_speech_silence = in_post_speech_silence | state.finished - text_exhausted = state.text_tokens_seen >= text_lens - - elif inputs["regular_multiturn"]: - batched_turns = inputs["batched_turns"] - batched_turn_lens = inputs["batched_turn_lens"] - valid_turn_masks = inputs["valid_turn_masks"] - - turn_offsets = torch.zeros(B, dtype=torch.long, device=device) - - for t in range(len(batched_turns)): - turn_text = batched_turns[t].to(device) - turn_lens = batched_turn_lens[t].to(device) - valid_mask = valid_turn_masks[t].to(device) - - state.finished = state.finished & (~valid_mask) - state.text_finished = state.text_finished & (~valid_mask) - - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended = state.phoneme_stream_ended & (~valid_mask) - - if state.finished.all(): - continue - - turn_offsets = torch.where(valid_mask, state.text_tokens_seen, turn_offsets) - turn_steps = 0 - - while not state.finished.all() and turn_steps < args.max_tts_steps: - turn_steps += 1 - - relative_positions = state.text_tokens_seen - turn_offsets - positions = relative_positions.clamp(min=0, max=turn_text.size(1) - 1) - current_tokens = turn_text[torch.arange(B, device=device), positions] - - exhausted = relative_positions >= turn_lens - current_tokens = torch.where( - exhausted, - torch.full_like(current_tokens, model.eos_id), - current_tokens, - ) - - state, _, _ = model.streaming_step( - state=state, - text_tokens=current_tokens, - use_inference_mode=True, - ) - - elif inputs["profile_multiturn"]: - if B != 1: - raise RuntimeError("--profile_multiturn_inference requires --batch_size=1 per process.") - - batched_turns = inputs["batched_turns"] - batched_turn_lens = inputs["batched_turn_lens"] - valid_turn_masks = inputs["valid_turn_masks"] - - for t in range(len(batched_turns)): - turn_text = batched_turns[t].to(device) - turn_lens = batched_turn_lens[t].to(device) - valid_mask = valid_turn_masks[t].to(device) - - if not bool(valid_mask[0].item()): - continue - - state.finished.zero_() - state.text_finished.zero_() - state.audio_prediction_end_idx.fill_(-1) - - if hasattr(state, "turn_text_tokens_seen"): - state.turn_text_tokens_seen.zero_() - if hasattr(state, "phoneme_steps"): - state.phoneme_steps.zero_() - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended.zero_() - if hasattr(state, "phoneme_eos_detected"): - state.phoneme_eos_detected.zero_() - - state.last_phoneme_tokens = None - - if not model.cfg.get("condition_on_user_speech", False): - if "user_audio_turns" in inputs: - profile_T = int(round(inputs["user_audio_turns"][t].size(-1) / model.input_samples_per_frame)) - profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate - else: - profile_seconds = args.profile_pad_min_sec + torch.rand((), device=device).item() * ( - args.profile_pad_max_sec - args.profile_pad_min_sec - ) - profile_T = max( - 1, - int(round(profile_seconds * model.sample_rate / model.input_samples_per_frame)), - ) - - profile_tokens = torch.full((1, profile_T), model.pad_id, dtype=torch.long, device=device) - user_audio_channel_embedding = None - - else: - if "user_audio_turns" in inputs: - user_audio = inputs["user_audio_turns"][t] - user_audio_lens = inputs["user_audio_turns_lens"][t] - else: - user_audio = inputs["context_audio"] - user_audio_lens = inputs["context_audio_lengths"] - - user_audio_codes, user_audio_codes_lens = model._codec_helper.audio_to_codes( - user_audio, - user_audio_lens, - ) - - if model._codec_converter is not None: - user_audio_codes = model._codec_converter.convert_original_to_new( - audio_tokens=user_audio_codes, - audio_lens=user_audio_codes_lens, - ).long() - - user_audio_codes, user_audio_codes_lens = model.stack_codes( - user_audio_codes, - user_audio_codes_lens, - model.audio_bos_id, - model.audio_eos_id, - model.frame_stacking_factor, - model.num_audio_codebooks, - ) - - user_audio_embedded = model.embed_audio_tokens(user_audio_codes) - - boundary_trim = model.cfg.get("user_audio_boundary_trim", 4) - boundary_trim = 0 if boundary_trim is None else int(boundary_trim) - - if boundary_trim == 0: - real_start = 0 - real_end = int(user_audio_codes_lens[0].item()) - else: - turn_len_with_special = int(user_audio_codes_lens[0].item()) - real_start = 1 - real_end = max(real_start, turn_len_with_special - 1) - - user_audio_embedded = user_audio_embedded[:, real_start:real_end] - copy_len = user_audio_embedded.size(1) - - if boundary_trim > 0: - trim = min(boundary_trim, copy_len // 2) - if trim > 0: - user_audio_embedded[:, :trim] = 0.0 - user_audio_embedded[:, copy_len - trim :] = 0.0 - - bos_user_pad = torch.zeros( - user_audio_embedded.size(0), - 1, - user_audio_embedded.size(2), - device=user_audio_embedded.device, - dtype=user_audio_embedded.dtype, - ) - user_audio_embedded = torch.cat([bos_user_pad, user_audio_embedded], dim=1) - - profile_T = user_audio_embedded.size(1) - profile_tokens = torch.full((B, profile_T), model.pad_id, dtype=torch.long, device=device) - user_audio_channel_embedding = user_audio_embedded - profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate - - delay_tokens = int(state.config.training_mode.streaming_speech_delay) - delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) - - warmup_tokens = turn_text[:, :delay_tokens] - turn_text = turn_text[:, delay_tokens:] - turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) - - if user_audio_channel_embedding is not None and delay_tokens > 0: - warmup_user_audio = user_audio_channel_embedding[:, -delay_tokens:] - user_audio_channel_embedding = user_audio_channel_embedding[:, :-delay_tokens] - profile_tokens = profile_tokens[:, :-delay_tokens] - else: - warmup_user_audio = None - - if profile_tokens.size(1) > 0: - state = model.streaming_prefill_profile( - state=state, - text_tokens=profile_tokens, - use_inference_mode=True, - user_audio_channel_embedding=user_audio_channel_embedding, - ) - - for i in range(delay_tokens): - user_step_emb = warmup_user_audio[:, i] if warmup_user_audio is not None else None - - state.finished.zero_() - state, _, _ = model.streaming_step( - state=state, - text_tokens=warmup_tokens[:, i], - user_audio_channel_embedding=user_step_emb, - prefill_like_step=not bool(model.cfg.get("agent_mask_include_transition_prefix", False)), - prefill_like_is_last_step=(i == delay_tokens - 1), - use_inference_mode=True, - ) - - logging.info(f"[profile_multiturn] turn={t} prefilled {profile_T} steps ({profile_seconds:.2f}s)") - - turn_start_frame = sum(p.size(-1) for p in state.all_predictions) - if t == 0: - state.audio_prediction_start_idx.fill_(turn_start_frame) - profile_decode_start_frame = turn_start_frame - - turn_offset = state.text_tokens_seen.clone() - turn_steps = 0 - saw_audio = False - turn_ended_with_audio_eos = False - - while turn_steps < args.max_tts_steps: - turn_steps += 1 - state.finished.zero_() - - relative_position = state.text_tokens_seen - turn_offset - text_exhausted = relative_position >= turn_lens - - if turn_text.size(1) == 0: - current_tokens = torch.full((B,), model.eos_id, dtype=torch.long, device=device) - else: - position = relative_position.clamp(min=0, max=turn_text.size(1) - 1) - current_tokens = turn_text[torch.arange(B, device=device), position] - current_tokens = torch.where( - text_exhausted, - torch.full_like(current_tokens, model.eos_id), - current_tokens, - ) - - state, audio_codes, _ = model.streaming_step( - state=state, - text_tokens=current_tokens, - use_inference_mode=True, - ) - - if audio_codes is not None and not saw_audio: - saw_audio = True - - if bool(text_exhausted[0].item()) and bool(state.finished[0].item()): - turn_ended_with_audio_eos = True - break - - state.audio_prediction_end_idx.fill_(-1) - state.finished.zero_() - - turn_end_frame = sum(p.size(-1) for p in state.all_predictions) - profile_turn_frame_ranges.append((t, turn_start_frame, turn_end_frame)) - - logging.info( - f"[profile_multiturn] turn={t} steps={turn_steps} " - f"saw_audio={saw_audio} ended_with_audio_eos={turn_ended_with_audio_eos}" - ) - - bos_id = getattr(model, "audio_bos_id", -1) - eos_id = getattr(model, "audio_eos_id", -1) - speaking_id = getattr(model, "audio_user_speaking_id", -1) - speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) - sil_injection = codec_sil_codes.view(1, -1, 1) - - for step_idx in range(len(state.all_predictions)): - pred = state.all_predictions[step_idx] - mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) - frame_mask = mask.any(dim=1, keepdim=True) - if frame_mask.any(): - state.all_predictions[step_idx] = torch.where(frame_mask, sil_injection.expand_as(pred), pred) - - if inputs["duplex_multiturn"] or inputs["profile_multiturn"]: - state.audio_prediction_end_idx.fill_(-1) - - finalize_output = model.streaming_finalize(state, use_inference_mode=True) - - return finalize_output, profile_turn_frame_ranges, profile_decode_start_frame - - -def update_metrics_and_save_audio( - model, - inputs, - finalize_output, - profile_turn_frame_ranges, - profile_decode_start_frame, - intelligibility, - secs_metric, - args, - rank, - run_id: int = 0, -): - device = model.device - B = inputs["context_audio"].size(0) - - with fp32_precision(): - audio_f32 = finalize_output.audio.float() - audio_len = finalize_output.audio_len.int() - - expected_audio_lens = ( - torch.tensor(inputs["target_num_frames"], device=device) * model.target_samples_per_frame - ).int() - - if inputs["duplex_multiturn"]: - text_lens = inputs["input_lengths"].to(device) - audio_len = (text_lens * model.target_samples_per_frame).int() - audio_len = torch.min(audio_len, torch.tensor(audio_f32.size(1), device=device)) - elif inputs["profile_multiturn"]: - audio_len = finalize_output.audio_len.int() - else: - audio_len = torch.min(audio_len, expected_audio_lens) - - metric_audio_pred = resample(audio_f32, model.output_sample_rate, 16000) - metric_audio_pred_lens = (audio_len / model.output_sample_rate * 16000).to(torch.long) - - context_audio = resample(inputs["context_audio"].float(), model.sample_rate, 16000) - context_audio_lens = (inputs["context_audio_lengths"] / model.sample_rate * 16000).to(torch.long) - - metric_audio_pred = torch_rms_norm(metric_audio_pred) - context_audio = torch_rms_norm(context_audio) - - asr_hyps = intelligibility.update( - name="dataset", - refs=inputs["raw_text"], - pred_audio=metric_audio_pred, - pred_audio_lens=metric_audio_pred_lens, - asr_hyps=None, - ) - - secs_metric.update( - name="dataset", - target_audio=context_audio, - target_audio_lens=context_audio_lens, - pred_audio=metric_audio_pred, - pred_audio_lens=metric_audio_pred_lens, - ) - - audio_out_dir = get_audio_out_dir(args) - os.makedirs(audio_out_dir, exist_ok=True) - - audio_f32_cpu = audio_f32.detach().cpu() - audio_len_cpu = audio_len.detach().cpu() - - for i in range(B): - target_path = inputs["target_audio_paths"][i] - base_name = os.path.basename(target_path) - stem, ext = os.path.splitext(base_name) - if not ext: - ext = ".wav" - - dataset_idx = inputs.get("dataset_indices", [-1] * B)[i] - run_prefix = f"run{int(run_id):03d}_" if int(getattr(args, "num_eval_runs", 1)) > 1 else "" - safe_stem = f"{run_prefix}idx{dataset_idx:08d}_{stem}" if dataset_idx >= 0 else f"{run_prefix}rank{rank}_{stem}" - - if inputs["profile_multiturn"]: - full_len = int(audio_len_cpu[i].item()) - full_wav_t = audio_f32_cpu[i, :full_len].float() - - samples_per_prediction_frame = model.codec_model_samples_per_frame / ( - model.sample_rate / model.output_sample_rate - ) - - aligned_agent = torch.zeros_like(full_wav_t) - - for turn_id, start_frame, end_frame in profile_turn_frame_ranges: - rel_start_frame = start_frame - profile_decode_start_frame - rel_end_frame = end_frame - profile_decode_start_frame - - start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) - end_sample = int(round(rel_end_frame * samples_per_prediction_frame)) - - start_sample = max(0, min(start_sample, full_len)) - end_sample = max(start_sample, min(end_sample, full_len)) - - aligned_agent[start_sample:end_sample] = full_wav_t[start_sample:end_sample] - - turn_wav = aligned_agent[start_sample:end_sample].numpy() - out_path = os.path.join(audio_out_dir, f"{safe_stem}_turn_{turn_id}{ext}") - sf.write(out_path, turn_wav, samplerate=model.output_sample_rate) - - out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") - sf.write(out_path, aligned_agent.numpy(), samplerate=model.output_sample_rate) - - if "user_audio_turns" in inputs: - user_segments = [] - - first_user_len_in = int(inputs["user_audio_turns_lens"][0][i].item()) - first_user_delay_out = int(round(first_user_len_in * model.output_sample_rate / model.sample_rate)) - - for turn_id, start_frame, _ in profile_turn_frame_ranges: - if turn_id >= len(inputs["user_audio_turns"]): - continue - - turn_audio = inputs["user_audio_turns"][turn_id][i].detach().cpu().float() - turn_audio_len = int(inputs["user_audio_turns_lens"][turn_id][i].item()) - turn_audio = turn_audio[:turn_audio_len] - - turn_audio_out = resample( - turn_audio.unsqueeze(0), - model.sample_rate, - model.output_sample_rate, - ).squeeze(0) - - if turn_id == 0: - user_start_sample = 0 - else: - prev_turn_end_frame = profile_turn_frame_ranges[turn_id - 1][2] - rel_prev_end_frame = prev_turn_end_frame - profile_decode_start_frame - user_start_sample = first_user_delay_out + int( - round(rel_prev_end_frame * samples_per_prediction_frame) - ) - - user_segments.append((user_start_sample, turn_audio_out)) - - total_user_len = 0 - for s, wav_seg in user_segments: - total_user_len = max(total_user_len, s + wav_seg.numel()) - - user_ch = torch.zeros(total_user_len) - for s, wav_seg in user_segments: - e = s + wav_seg.numel() - user_ch[s:e] += wav_seg - - agent_ch = torch.cat([torch.zeros(first_user_delay_out, dtype=aligned_agent.dtype), aligned_agent]) - - stereo_len = max(user_ch.numel(), agent_ch.numel()) - user_pad = torch.zeros(stereo_len) - agent_pad = torch.zeros(stereo_len) - - user_pad[: user_ch.numel()] = user_ch - agent_pad[: agent_ch.numel()] = agent_ch - - stereo = torch.stack([user_pad, agent_pad], dim=1).numpy() - aligned_path = os.path.join(audio_out_dir, f"{safe_stem}_user_agent_aligned{ext}") - sf.write(aligned_path, stereo, samplerate=model.output_sample_rate) - - else: - wav = audio_f32_cpu[i, : audio_len_cpu[i]].numpy() - out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") - sf.write(out_path, wav, samplerate=model.output_sample_rate) - - return audio_f32.detach(), audio_len.detach(), asr_hyps - - -def compute_filewise_metrics_for_batch( - rank: int, - model, - inputs, - audio_f32: torch.Tensor, - audio_len: torch.Tensor, - asr_hyps: List[str], - run_id: int = 0, -): - filewise_rows = [] - B = audio_f32.size(0) - device = model.device - - for i in range(B): - dataset_idx = int(inputs.get("dataset_indices", [-1] * B)[i]) - target_path = inputs["target_audio_paths"][i] - ref_text = inputs["raw_text"][i] - asr_hyp_text = asr_hyps[i] if asr_hyps is not None and i < len(asr_hyps) else None - - pred_len_i = int(audio_len[i].item()) - pred_audio_i = audio_f32[i : i + 1, :pred_len_i].float() - pred_audio_len_i = torch.tensor([pred_len_i], dtype=torch.long, device=device) - - context_len_i = int(inputs["context_audio_lengths"][i].item()) - context_audio_i = inputs["context_audio"][i : i + 1, :context_len_i].float() - context_audio_len_i = torch.tensor([context_len_i], dtype=torch.long, device=device) - - with fp32_precision(): - pred_16k = resample(pred_audio_i, model.output_sample_rate, 16000) - pred_16k_len = (pred_audio_len_i / model.output_sample_rate * 16000).to(torch.long) - - context_16k = resample(context_audio_i, model.sample_rate, 16000) - context_16k_len = (context_audio_len_i / model.sample_rate * 16000).to(torch.long) - - pred_16k = torch_rms_norm(pred_16k) - context_16k = torch_rms_norm(context_16k) - - one_intelligibility = Intelligibility( - "stt_en_fastconformer_transducer_large", - reuse_asr_hyps=True, - ).reset() - one_intelligibility.update( - name="dataset", - refs=[ref_text], - pred_audio=None, - pred_audio_lens=None, - asr_hyps=[asr_hyp_text or ""], - ) - one_intel_metrics = metric_dict_to_jsonable(one_intelligibility.compute()) - - one_secs = SECS("titanet_large").reset() - one_secs.update( - name="dataset", - target_audio=context_16k, - target_audio_lens=context_16k_len, - pred_audio=pred_16k, - pred_audio_lens=pred_16k_len, - ) - one_secs_metrics = metric_dict_to_jsonable(one_secs.compute()) - - cer = safe_metric_scalar(one_intel_metrics, ["cer", "cer_dataset"]) - wer = safe_metric_scalar(one_intel_metrics, ["wer", "wer_dataset"]) - secs = safe_metric_scalar(one_secs_metrics, ["secs", "secs_dataset"]) - - filewise_rows.append( - { - "run_id": int(run_id), - "rank": int(rank), - "dataset_index": int(dataset_idx), - "target_audio_path": target_path, - "reference_text": ref_text, - "asr_hyp": asr_hyp_text, - "cer": cer, - "wer": wer, - "secs": secs, - "pred_audio_samples": int(pred_len_i), - "pred_audio_seconds": float(pred_len_i / model.output_sample_rate), - "intelligibility": one_intel_metrics, - "secs_metrics": one_secs_metrics, - } - ) - - return filewise_rows - - -def load_speaker_wav_if_needed(args, model, target_dtype): - if args.user_custom_speaker_reference and args.inference_speaker_reference: - return _load_audio( - args.inference_speaker_reference, - model.sample_rate, - normalize=args.normalize_volume, - use_librosa=args.use_librosa, - ).unsqueeze(0).to(model.device, dtype=target_dtype) - - return None - - -def compute_and_save_rank_metrics(args, rank, world_size, num_processed, elapsed, intelligibility, secs_metric): - if num_processed > 0: - with fp32_precision(): - cer_wer = metric_dict_to_jsonable(intelligibility.compute()) - secs_scores = metric_dict_to_jsonable(secs_metric.compute()) - else: - cer_wer = {} - secs_scores = {} - - rank_metrics = { - "rank": int(rank), - "world_size": int(world_size), - "num_processed": int(num_processed), - "elapsed_sec": float(elapsed), - "num_eval_runs": int(getattr(args, "num_eval_runs", 1)), - "intelligibility": cer_wer, - "secs": secs_scores, - } - - rank_path = os.path.join(args.out_dir, f"metrics_rank{rank:04d}.json") - write_json(rank_path, rank_metrics) - - return rank_metrics - - -def merge_metrics_on_rank0(args, rank, world_size): - if rank != 0: - return None - - rank_metric_files = [os.path.join(args.out_dir, f"metrics_rank{r:04d}.json") for r in range(world_size)] - - rank_metrics = [] - for path in rank_metric_files: - if not os.path.exists(path): - logging.warning(f"Missing rank metric file: {path}") - continue - with open(path, "r", encoding="utf-8") as f: - rank_metrics.append(json.load(f)) - - total_n = sum(int(m.get("num_processed", 0)) for m in rank_metrics) - - def weighted_average(section: str): - keys = set() - for m in rank_metrics: - keys.update(m.get(section, {}).keys()) - - out = {} - for k in sorted(keys): - numerator = 0.0 - denominator = 0 - - for m in rank_metrics: - n = int(m.get("num_processed", 0)) - if n <= 0: - continue - - value = m.get(section, {}).get(k, None) - if value is None or isinstance(value, str): - continue - - try: - value = float(value) - except Exception: - continue - - numerator += value * n - denominator += n - - if denominator > 0: - out[k] = numerator / denominator - - return out - - final_metrics = { - "world_size": int(world_size), - "num_processed": int(total_n), - "aggregation": "sum(rank_metric * rank_num_samples) / total_num_samples; repeated DistributedSampler samples included", - "intelligibility": weighted_average("intelligibility"), - "secs": weighted_average("secs"), - "ranks": rank_metrics, - } - - final_json_path = os.path.join(args.out_dir, "metrics_final.json") - final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") - - write_json(final_json_path, final_metrics) - - final_text = format_final_metric_text(final_metrics) - write_text_atomic(final_txt_path, final_text) - - print("\n--- Final Evaluation Metrics ---", flush=True) - print(final_text, flush=True) - - logging.info(f"Final metrics JSON saved to: {final_json_path}") - logging.info(f"Final metrics TXT saved to: {final_txt_path}") - logging.info(json.dumps(final_metrics, indent=2, sort_keys=True)) - - return final_metrics - - -def merge_filewise_metrics_on_rank0(args, rank: int, world_size: int): - if rank != 0 or not args.save_filewise_metrics: - return [] - - all_rows = [] - - for r in range(world_size): - path = os.path.join(args.out_dir, f"filewise_metrics_rank{r:04d}.jsonl") - if not os.path.exists(path): - logging.warning(f"Missing filewise metrics file: {path}") - continue - - with open(path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line: - all_rows.append(json.loads(line)) - - deduped = {} - for row in all_rows: - run_id = int(row.get("run_id", 0)) - idx = int(row.get("dataset_index", -1)) - key = (run_id, idx) - if key not in deduped: - deduped[key] = row - - all_rows = list(deduped.values()) - - all_rows.sort( - key=lambda x: ( - x.get("cer") is not None, - float(x.get("cer")) if x.get("cer") is not None else -1.0, - ), - reverse=True, - ) - - jsonl_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.jsonl") - csv_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.csv") - - write_jsonl(jsonl_path, all_rows) - write_filewise_csv(csv_path, all_rows) - - logging.info(f"Saved sorted filewise metrics JSONL to: {jsonl_path}") - logging.info(f"Saved sorted filewise metrics CSV to: {csv_path}") - - topk = min(int(args.filewise_metrics_topk_log), len(all_rows)) - if topk > 0: - logging.info(f"Top {topk} worst CER samples:") - for row in all_rows[:topk]: - logging.info( - "run_id=%s dataset_index=%s cer=%s wer=%s secs=%s path=%s text=%s" - % ( - row.get("run_id"), - row.get("dataset_index"), - row.get("cer"), - row.get("wer"), - row.get("secs"), - row.get("target_audio_path"), - row.get("reference_text"), - ) - ) - - return all_rows - - -def compute_aggregates_from_filewise_rows(rows: List[Dict[str, Any]]): - """ - Compute final metrics after deduplicating DistributedSampler repeats. - - CER/WER are computed the same way Intelligibility.compute() does: - word_error_rate(all_normalized_hyps, all_normalized_refs, use_cer=True/False) - - SECS is averaged over the deduplicated per-file SECS values. - """ - if len(rows) == 0: - return { - "cer": None, - "wer": None, - "secs": None, - "num_samples": 0, - } - - normalizer = EnglishTextNormalizer() - refs = [] - hyps = [] - secs_vals = [] - - for row in rows: - ref = row.get("reference_text", "") - hyp = row.get("asr_hyp", "") - refs.append(normalizer(ref)) - hyps.append(normalizer(hyp)) - - if row.get("secs") is not None: - secs_vals.append(float(row["secs"])) - - cer = float(word_error_rate(hyps, refs, use_cer=True)) if refs else None - wer = float(word_error_rate(hyps, refs, use_cer=False)) if refs else None - secs = (sum(secs_vals) / len(secs_vals)) if secs_vals else None - - return { - "cer": cer, - "wer": wer, - "secs": secs, - "num_samples": len(rows), - } - - -def compute_run_averaged_aggregates_from_filewise_rows(rows: List[Dict[str, Any]]): - """ - Compute a metric per run, then average those run metrics equally. - - This is useful when --num_eval_runs > 1 and you want the final number to mean: - average(metric(run_0), metric(run_1), ..., metric(run_N-1)) - """ - if len(rows) == 0: - return { - "cer": None, - "wer": None, - "secs": None, - "num_runs": 0, - "num_samples_per_run": {}, - "runs": [], - } - - grouped = {} - for row in rows: - run_id = int(row.get("run_id", 0)) - grouped.setdefault(run_id, []).append(row) - - run_summaries = [] - for run_id in sorted(grouped.keys()): - summary = compute_aggregates_from_filewise_rows(grouped[run_id]) - summary["run_id"] = int(run_id) - run_summaries.append(summary) - - def avg_key(key): - vals = [float(r[key]) for r in run_summaries if r.get(key) is not None] - if not vals: - return None - return sum(vals) / len(vals) - - return { - "cer": avg_key("cer"), - "wer": avg_key("wer"), - "secs": avg_key("secs"), - "num_runs": len(run_summaries), - "num_samples_per_run": {str(r["run_id"]): int(r.get("num_samples", 0)) for r in run_summaries}, - "runs": run_summaries, - } - -def save_filewise_final_summary(args, filewise_rows: List[Dict[str, Any]]): - all_observation_summary = compute_aggregates_from_filewise_rows(filewise_rows) - run_averaged_summary = compute_run_averaged_aggregates_from_filewise_rows(filewise_rows) - - obj = { - "aggregation": ( - "deduplicated_by_(run_id,dataset_index); " - "cer_wer_use_corpus_word_error_rate_matching_Intelligibility_compute; " - "primary_summary_is_mean_over_runs" - ), - "run_averaged": run_averaged_summary, - "all_observations": all_observation_summary, - } - - path = os.path.join(args.out_dir, "metrics_final_filewise_average.json") - write_json(path, obj) - - final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") - final_text = format_filewise_final_metric_text(run_averaged_summary) - write_text_atomic(final_txt_path, final_text) - - print("\n--- Final Filewise Run-Averaged Evaluation Metrics ---", flush=True) - print(final_text, flush=True) - - logging.info(f"Filewise averaged final metrics saved to: {path}") - logging.info(f"Final metrics TXT saved to: {final_txt_path}") - - return obj - - -def parse_args(): - parser = argparse.ArgumentParser(description="EasyMagpieTTS Multi-GPU Inference Evaluation") - - parser.add_argument("--checkpoint_path", type=str, required=True) - parser.add_argument("--codec_model_path", type=str, required=True) - parser.add_argument("--datasets_json_path", type=str, required=True) - parser.add_argument("--out_dir", type=str, required=True) - - parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) - parser.add_argument("--audio_dir", type=str, default=None) - parser.add_argument("--inference_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) - parser.add_argument("--debug_dtype", action="store_true") - parser.add_argument("--debug_gpu_assignment", action="store_true") - parser.add_argument("--use_librosa", action="store_true") - - parser.add_argument("--batch_size", type=int, default=6) - parser.add_argument( - "--num_eval_runs", - type=int, - default=1, - help="Repeat the same evaluation set N times. Final filewise metrics are averaged across runs.", - ) - parser.add_argument( - "--sort_by_text_token_count", - action="store_true", - help="Sort samples by summed text-token count before DistributedSampler sharding for better GPU load balance.", - ) - parser.add_argument( - "--sort_text_token_count_descending", - action="store_true", - help="When sorting by token count, sort longest first. Default is shortest first to make DistributedSampler padding cheap.", - ) - parser.add_argument("--num_workers", type=int, default=4) - parser.add_argument("--num_turns", type=int, default=1) - parser.add_argument("--pad_factor_text_speech", type=int, default=10) - - parser.add_argument("--emulate_duplex_inference", action="store_true") - parser.add_argument("--add_interruption_token", action="store_true") - parser.add_argument("--force_interruption", action="store_true") - parser.add_argument("--profile_multiturn_inference", action="store_true") - parser.add_argument("--profile_pad_min_sec", type=float, default=2.0) - parser.add_argument("--profile_pad_max_sec", type=float, default=2.0) - parser.add_argument("--max_eval_turns", type=int, default=6) - - parser.add_argument("--user_custom_speaker_reference", action="store_true") - parser.add_argument("--inference_speaker_reference", type=str, default=None) - parser.add_argument("--language", type=str, default="en") - - parser.add_argument("--use_cfg", action="store_true") - parser.add_argument("--cfg_scale", type=float, default=2.5) - parser.add_argument("--temperature", type=float, default=0.7) - parser.add_argument("--topk", type=int, default=80) - parser.add_argument("--max_tts_steps", type=int, default=2000) - parser.add_argument("--force_speech_sil_codes", action="store_true") - parser.add_argument("--normalize_volume", type=lambda x: str(x).lower() in ["true", "1", "yes"], default=True) - - parser.add_argument( - "--save_filewise_metrics", - action="store_true", - help="Save per-file CER/WER/SECS metrics sorted by CER descending.", - ) - parser.add_argument( - "--filewise_metrics_topk_log", - type=int, - default=20, - help="Number of worst CER samples to print on rank 0.", - ) - - return parser.parse_args() - - -def main(): - args = parse_args() - if int(args.num_eval_runs) <= 0: - raise RuntimeError("--num_eval_runs must be >= 1") - if int(args.num_eval_runs) > 1 and not args.save_filewise_metrics: - args.save_filewise_metrics = True - print("[info] --num_eval_runs > 1, enabling --save_filewise_metrics for run-averaged final metrics.", flush=True) - os.makedirs(args.out_dir, exist_ok=True) - os.makedirs(get_audio_out_dir(args), exist_ok=True) - - distributed, rank, local_rank, world_size, device_index = setup_distributed() - - if args.profile_multiturn_inference and args.batch_size != 1: - raise RuntimeError( - "--profile_multiturn_inference requires --batch_size=1 per process. " - "Use multiple GPUs/processes for parallelism instead of increasing batch_size." - ) - - if args.profile_pad_max_sec < args.profile_pad_min_sec: - raise RuntimeError("--profile_pad_max_sec must be >= --profile_pad_min_sec.") - - target_device = torch.device(f"cuda:{device_index}" if torch.cuda.is_available() and device_index >= 0 else "cpu") - target_dtype = getattr(torch, args.inference_dtype) - torch.set_default_dtype(target_dtype) - - hostname = socket.gethostname() - cuda_name = torch.cuda.get_device_name(target_device) if torch.cuda.is_available() and device_index >= 0 else "cpu" - - all_rank_print( - rank, - f"host={hostname} local_rank={local_rank} world_size={world_size} " - f"device={target_device} device_name={cuda_name}", - ) - - model = build_model_and_codec(args, target_device, target_dtype) - codec_sil_codes = model.codec_sil_codes - - if args.debug_dtype: - handles, stats, examples = attach_dtype_counter(model) - else: - handles = stats = examples = None - - with fp32_precision(): - intelligibility = Intelligibility("stt_en_fastconformer_transducer_large", reuse_asr_hyps=False).reset() - secs_metric = SECS("titanet_large").reset() - - eval_dataset = EvalJSONLDataset(args.datasets_json_path, num_turns=args.num_turns) - if args.sort_by_text_token_count: - eval_dataset = SortedByTextTokenCountDataset( - dataset=eval_dataset, - model=model, - max_eval_turns=args.max_eval_turns, - descending=bool(args.sort_text_token_count_descending), - ) - sort_dir = "descending" if args.sort_text_token_count_descending else "ascending" - rank0_print(rank, f"[info] Sorted evaluation samples by summed text-token count {sort_dir}.") - - if distributed: - sampler = DistributedSampler( - eval_dataset, - num_replicas=world_size, - rank=rank, - shuffle=False, - drop_last=False, - ) - else: - sampler = SequentialSampler(eval_dataset) - - if args.debug_gpu_assignment: - if distributed: - assigned_sampler_indices = list(iter(sampler)) - assigned_dataset_indices = [ - int(eval_dataset[i].get("__dataset_index__", -1)) - for i in assigned_sampler_indices - ] - repeated_on_rank = len(assigned_dataset_indices) - len(set(assigned_dataset_indices)) - all_rank_print( - rank, - f"assigned {len(assigned_dataset_indices)} / {len(eval_dataset)} samples " - f"to gpu={local_rank}; repeated_on_this_rank={repeated_on_rank}; " - f"dataset_indices={assigned_dataset_indices}; " - f"text_token_counts={[int(eval_dataset[i].get('__text_token_count__', -1)) for i in assigned_sampler_indices]}", - ) - else: - assigned_dataset_indices = [ - int(eval_dataset[i].get("__dataset_index__", -1)) - for i in range(len(eval_dataset)) - ] - all_rank_print( - rank, - f"assigned {len(assigned_dataset_indices)} samples to single process: " - f"dataset_indices={assigned_dataset_indices}; " - f"text_token_counts={[int(eval_dataset[i].get('__text_token_count__', -1)) for i in range(len(eval_dataset))]}", - ) - - collate_fn = partial( - collate_and_tokenize_custom, - model=model, - extra_duration_thrshould=1.5, - sample_rate=model.sample_rate, - root_path=args.audio_dir, - emulate_duplex_inference=args.emulate_duplex_inference, - add_interruption_token=args.add_interruption_token, - pad_factor_text_speech=args.pad_factor_text_speech, - force_interruption=args.force_interruption, - normalize_audio_volume=args.normalize_volume, - use_librosa=args.use_librosa, - profile_multiturn_inference=args.profile_multiturn_inference, - max_eval_turns=args.max_eval_turns, - ) - - dataloader = DataLoader( - dataset=eval_dataset, - batch_size=args.batch_size, - sampler=sampler, - collate_fn=collate_fn, - num_workers=args.num_workers, - pin_memory=True, - drop_last=False, - ) - - speaker_wav = load_speaker_wav_if_needed(args, model, target_dtype) - - if distributed: - dist.barrier() - - start_time = time.time() - num_processed = 0 - rank_filewise_rows = [] - - for run_id in range(int(args.num_eval_runs)): - if distributed and hasattr(sampler, "set_epoch"): - sampler.set_epoch(run_id) - - if args.debug_gpu_assignment: - all_rank_print(rank, f"starting eval run {run_id + 1}/{int(args.num_eval_runs)}") - - for batch_id, inputs in enumerate(dataloader): - batch_indices = inputs.get("dataset_indices", []) - num_processed += len(batch_indices) - - if args.debug_gpu_assignment: - all_rank_print( - rank, - f"run_id={run_id} gpu={local_rank} batch_id={batch_id} " - f"dataset_indices={batch_indices} " - f"text_token_counts={inputs.get('text_token_counts', [])} " - f"target_paths={inputs.get('target_audio_paths', [])}", - ) - - inputs = prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=speaker_wav) - - finalize_output, profile_turn_frame_ranges, profile_decode_start_frame = run_generation( - model=model, - inputs=inputs, - args=args, - codec_sil_codes=codec_sil_codes, - ) - - audio_f32_for_metrics, audio_len_for_metrics, asr_hyps_for_metrics = update_metrics_and_save_audio( - model=model, - inputs=inputs, - finalize_output=finalize_output, - profile_turn_frame_ranges=profile_turn_frame_ranges, - profile_decode_start_frame=profile_decode_start_frame, - intelligibility=intelligibility, - secs_metric=secs_metric, - args=args, - rank=rank, - run_id=run_id, - ) - - if args.save_filewise_metrics: - filewise_rows = compute_filewise_metrics_for_batch( - rank=rank, - model=model, - inputs=inputs, - audio_f32=audio_f32_for_metrics, - audio_len=audio_len_for_metrics, - asr_hyps=asr_hyps_for_metrics, - run_id=run_id, - ) - rank_filewise_rows.extend(filewise_rows) - - if args.debug_dtype and batch_id == 0 and run_id == 0: - report_dtype_stats(handles, stats, examples, rank=rank) - - elapsed = time.time() - start_time - - rank_metrics = compute_and_save_rank_metrics( - args=args, - rank=rank, - world_size=world_size, - num_processed=num_processed, - elapsed=elapsed, - intelligibility=intelligibility, - secs_metric=secs_metric, - ) - - all_rank_print(rank, f"saved rank metrics: {json.dumps(rank_metrics, sort_keys=True)}") - - if args.save_filewise_metrics: - rank_filewise_rows.sort( - key=lambda x: ( - x.get("cer") is not None, - float(x.get("cer")) if x.get("cer") is not None else -1.0, - ), - reverse=True, - ) - - rank_filewise_path = os.path.join(args.out_dir, f"filewise_metrics_rank{rank:04d}.jsonl") - write_jsonl(rank_filewise_path, rank_filewise_rows) - all_rank_print(rank, f"saved filewise metrics: {rank_filewise_path}") - - if distributed: - dist.barrier() - - merge_metrics_on_rank0(args, rank, world_size) - - if args.save_filewise_metrics: - if distributed: - dist.barrier() - - filewise_rows = merge_filewise_metrics_on_rank0(args, rank, world_size) - - if rank == 0: - save_filewise_final_summary(args, filewise_rows) - - cleanup_distributed() - - -if __name__ == "__main__": - main() diff --git a/examples/tts/easy_magpietts_inference_multiturn_multigpu_turn_metric.py b/examples/tts/easy_magpietts_inference_multiturn_multigpu_turn_metric.py deleted file mode 100644 index deb2244f2376..000000000000 --- a/examples/tts/easy_magpietts_inference_multiturn_multigpu_turn_metric.py +++ /dev/null @@ -1,2247 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -""" -Multi-GPU EasyMagpieTTS / NemotronTTS multiturn inference evaluation. - -Key behavior: - - Uses torchrun env vars RANK, LOCAL_RANK, WORLD_SIZE for sharding/GPU assignment. - - Does NOT initialize torch.distributed. This avoids NeMo ASR doing distributed - collectives during metric computation. - - Generation runs first for all assigned samples. - - ASR and SECS are loaded only after generation is done and the TTS/codec model - has been deleted from GPU memory. - - ASR and SECS are loaded sequentially: ASR first, then released; SECS second. - - For --profile_multiturn_inference, metrics are computed turn-by-turn. - Final filewise outputs are grouped back to one row per original sample, with - lists for asr_hyp/reference_text/cer_turns/wer_turns/secs_turns. - - Uses DistributedSampler with explicit rank/world_size. A few repeated samples - may appear when len(dataset) is not divisible by world_size. Filewise final - metrics deduplicate sampler-padding repeats by (run_id, dataset_index, - turn_id), then group turns into one row per sample with metric lists, while - preserving --num_eval_runs repetitions. - - --sort_by_text_token_count sorts samples by total text-token count before - sharding to improve GPU load balance. - - Saves audio in out_dir/audios/. - - Saves metrics in out_dir/. - -Recommended single-node torchrun: - CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ - torchrun --standalone --nproc_per_node=8 easy_magpietts_inference_multiturn_multigpu_postgen_metrics.py ... - -Recommended single-node srun wrapper: - srun --nodes=1 --ntasks=1 --ntasks-per-node=1 --container-image=... \ - bash -lc 'torchrun --standalone --nproc_per_node=8 easy_magpietts_inference_multiturn_multigpu_postgen_metrics.py ...' -""" - -import argparse -import csv -import json -import os -import socket -import time -from copy import deepcopy -from functools import partial -from typing import Any, Dict, Iterable, List, Tuple - -import librosa -import soundfile as sf -import torch -from omegaconf import open_dict -from torch.nn.utils.rnn import pad_sequence -from torch.utils.data import DataLoader, Dataset, DistributedSampler, SequentialSampler - -from nemo.collections.audio.parts.utils.transforms import resample -from nemo.collections.asr.metrics.wer import word_error_rate -from nemo.collections.speechlm2.parts.metrics.asr_cer_wer import Intelligibility -from nemo.collections.speechlm2.parts.metrics.secs import SECS -from nemo.collections.speechlm2.parts.precision import fp32_precision -from nemo.collections.tts.models import AudioCodecModel -from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel -from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter -from nemo.collections.tts.modules.magpietts_modules import CodecHelper -from nemo.collections.tts.parts.utils.tts_dataset_utils import normalize_volume -from nemo.utils import logging -from whisper_normalizer.english import EnglishTextNormalizer - - -torch.set_float32_matmul_precision("medium") -torch.backends.cudnn.allow_tf32 = True -torch.backends.cuda.matmul.allow_tf32 = True - - -# ----------------------------- -# Rank / file helpers -# ----------------------------- - - -def get_rank_info() -> Tuple[bool, int, int, int]: - world_size = int(os.environ.get("WORLD_SIZE", "1")) - rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", "0"))) - local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))) - distributed = world_size > 1 - return distributed, rank, local_rank, world_size - - -def get_visible_device_index(local_rank: int) -> int: - if not torch.cuda.is_available(): - return -1 - ndev = torch.cuda.device_count() - if ndev <= 0: - return -1 - return local_rank % ndev - - -def setup_distributed(): - """ - Do not initialize torch.distributed. - - We only need RANK/LOCAL_RANK/WORLD_SIZE for rank assignment and dataset - sharding. Initializing a process group can cause NeMo ASR to run distributed - collectives during transcribe(), which may hang when ranks have different - audio lengths or workloads. - """ - distributed, rank, local_rank, world_size = get_rank_info() - device_index = get_visible_device_index(local_rank) - - if torch.cuda.is_available() and device_index >= 0: - torch.cuda.set_device(device_index) - - return distributed, rank, local_rank, world_size, device_index - - -def cleanup_distributed(): - return - - -def all_rank_print(rank: int, msg: str): - print(f"[rank={rank}] {msg}", flush=True) - - -def rank0_print(rank: int, msg: str): - if rank == 0: - print(msg, flush=True) - - -def get_audio_out_dir(args) -> str: - return os.path.join(args.out_dir, "audios") - - -def get_generated_turn_audio_dir(args) -> str: - return os.path.join(get_audio_out_dir(args), "metric_turns") - - -def get_context_metric_audio_dir(args) -> str: - return os.path.join(get_audio_out_dir(args), "metric_context") - - -def write_json(path: str, obj: Dict[str, Any]): - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - with open(tmp_path, "w", encoding="utf-8") as f: - json.dump(obj, f, indent=2, sort_keys=True, ensure_ascii=False) - os.replace(tmp_path, path) - - -def write_text_atomic(path: str, text: str): - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - with open(tmp_path, "w", encoding="utf-8") as f: - f.write(text) - os.replace(tmp_path, path) - - -def write_jsonl(path: str, rows: List[Dict[str, Any]]): - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - with open(tmp_path, "w", encoding="utf-8") as f: - for row in rows: - f.write(json.dumps(row, sort_keys=True, ensure_ascii=False) + "\n") - os.replace(tmp_path, path) - - -def wait_for_files(paths: List[str], timeout_sec: float = 7200.0, poll_sec: float = 5.0): - start = time.time() - while True: - missing = [p for p in paths if not os.path.exists(p)] - if not missing: - return - if time.time() - start > timeout_sec: - raise TimeoutError("Timed out waiting for files:\n" + "\n".join(missing)) - time.sleep(poll_sec) - - -def wait_for_rank_metric_files(args, world_size: int): - paths = [os.path.join(args.out_dir, f"metrics_rank{r:04d}.json") for r in range(world_size)] - wait_for_files(paths) - - -def wait_for_rank_filewise_metric_files(args, world_size: int): - paths = [os.path.join(args.out_dir, f"filewise_metrics_rank{r:04d}.jsonl") for r in range(world_size)] - wait_for_files(paths) - - -def scalarize_metric_value(v: Any): - if torch.is_tensor(v): - if v.numel() == 1: - return float(v.detach().cpu().item()) - return v.detach().cpu().tolist() - try: - import numpy as np - - if isinstance(v, np.generic): - return float(v.item()) - except Exception: - pass - if isinstance(v, (int, float, str, bool)) or v is None: - return v - return str(v) - - -def metric_dict_to_jsonable(d: Dict[str, Any]) -> Dict[str, Any]: - return {str(k): scalarize_metric_value(v) for k, v in d.items()} - - -def safe_metric_scalar(metric_dict: Dict[str, Any], preferred_keys: List[str]): - for key in preferred_keys: - if key in metric_dict: - value = metric_dict[key] - if torch.is_tensor(value): - return float(value.detach().cpu().item()) - return float(value) - return None - - -def get_first_metric(metrics: Dict[str, Any], names: List[str], default=None): - for name in names: - if name in metrics: - return metrics[name] - return default - - -def format_final_metric_text(final_metrics: Dict[str, Any]) -> str: - intelligibility = final_metrics.get("intelligibility", {}) - secs = final_metrics.get("secs", {}) - - cer = get_first_metric(intelligibility, ["cer", "cer_dataset"]) - wer = get_first_metric(intelligibility, ["wer", "wer_dataset"]) - secs_value = get_first_metric(secs, ["secs", "secs_dataset"]) - - def fmt(x): - if x is None: - return "nan" - try: - return f"{float(x):.10f}" - except Exception: - return str(x) - - return f"Average CER: {fmt(cer)}\nAverage WER: {fmt(wer)}\nSECS: {fmt(secs_value)}\n" - - -def format_filewise_final_metric_text(filewise_summary: Dict[str, Any]) -> str: - def fmt(x): - if x is None: - return "nan" - try: - return f"{float(x):.10f}" - except Exception: - return str(x) - - return ( - f"Average CER: {fmt(filewise_summary.get('cer'))}\n" - f"Average WER: {fmt(filewise_summary.get('wer'))}\n" - f"SECS: {fmt(filewise_summary.get('secs'))}\n" - ) - - -def write_filewise_csv(path: str, rows: List[Dict[str, Any]]): - """Write sample-level filewise metrics. - - Several fields are lists (turn_ids, reference_text, asr_hyp, cer_turns, - etc.), so they are JSON-encoded inside CSV cells. - """ - os.makedirs(os.path.dirname(path), exist_ok=True) - tmp_path = path + ".tmp" - - fieldnames = [ - "run_id", - "dataset_index", - "rank", - "num_turns", - "cer", - "wer", - "secs", - "turn_ids", - "cer_turns", - "wer_turns", - "secs_turns", - "pred_audio_seconds_turns", - "target_audio_path", - "context_audio_path", - "pred_audio_paths", - "reference_text", - "asr_hyp", - ] - - def csv_value(v): - if isinstance(v, (list, dict)): - return json.dumps(v, ensure_ascii=False) - return v - - with open(tmp_path, "w", encoding="utf-8", newline="") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - for row in rows: - writer.writerow({k: csv_value(row.get(k, None)) for k in fieldnames}) - - os.replace(tmp_path, path) - -# ----------------------------- -# Dataset helpers -# ----------------------------- - - -def _combined_audio_name(first_audio_filepath: str, paths: List[str]) -> str: - base_names = [os.path.splitext(os.path.basename(p))[0] for p in paths if p] - ext = os.path.splitext(paths[-1])[1] if paths and paths[-1] else "" - combined_name = "_".join(base_names) + ext - dir_name = os.path.dirname(first_audio_filepath) - return os.path.join(dir_name, combined_name) if dir_name else combined_name - - -class EvalJSONLDataset(Dataset): - def __init__(self, file_path: str, num_turns: int = 1): - self.samples = [] - raw_samples = [] - - with open(file_path, "r", encoding="utf-8") as f: - for line_idx, line in enumerate(f, 1): - line = line.strip() - if not line: - continue - try: - sample = json.loads(line) - sample["__dataset_index__"] = len(raw_samples) - raw_samples.append(sample) - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON on line {line_idx}: {e}") - - if num_turns <= 1: - self.samples = raw_samples - return - - single_turn_by_speaker = {} - for sample in raw_samples: - if isinstance(sample["text"], list): - self.samples.append(sample) - else: - speaker = sample.get("speaker", "unknown") - single_turn_by_speaker.setdefault(speaker, []).append(sample) - - synthetic_index = len(raw_samples) - for _, speaker_samples in single_turn_by_speaker.items(): - buffer_texts, buffer_paths = [], [] - first_sample_meta = None - - for sample in speaker_samples: - if not buffer_texts: - first_sample_meta = dict(sample) - - buffer_texts.append(sample["text"]) - buffer_paths.append(sample.get("audio_filepath", "")) - - if len(buffer_texts) == num_turns: - first_sample_meta["text"] = buffer_texts - first_sample_meta["audio_filepath"] = _combined_audio_name( - first_sample_meta.get("audio_filepath", ""), - buffer_paths, - ) - first_sample_meta["__dataset_index__"] = synthetic_index - synthetic_index += 1 - self.samples.append(first_sample_meta) - buffer_texts, buffer_paths, first_sample_meta = [], [], None - - if buffer_texts and first_sample_meta is not None: - first_sample_meta["text"] = buffer_texts - first_sample_meta["audio_filepath"] = _combined_audio_name( - first_sample_meta.get("audio_filepath", ""), - buffer_paths, - ) - first_sample_meta["__dataset_index__"] = synthetic_index - synthetic_index += 1 - self.samples.append(first_sample_meta) - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - return self.samples[idx] - - -def _sample_text_segments_for_count(sample: Dict[str, Any], max_eval_turns=None) -> List[str]: - text_data = sample.get("text", "") - if isinstance(text_data, list): - segments = text_data - if max_eval_turns is not None: - segments = segments[: int(max_eval_turns)] - return [str(x) for x in segments] - return [str(text_data)] - - -def estimate_text_token_count(sample: Dict[str, Any], model, max_eval_turns=None) -> int: - main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] - total = 0 - for segment in _sample_text_segments_for_count(sample, max_eval_turns=max_eval_turns): - total += len(model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name)) + 1 - return int(total) - - -class SortedByTextTokenCountDataset(Dataset): - def __init__(self, dataset: Dataset, model, max_eval_turns=None, descending: bool = True): - self.dataset = dataset - scored = [] - for i in range(len(dataset)): - sample = dict(dataset[i]) - token_count = estimate_text_token_count(sample, model=model, max_eval_turns=max_eval_turns) - sample["__text_token_count__"] = int(token_count) - scored.append((token_count, i, sample)) - - scored.sort(key=lambda x: (x[0], -x[1]), reverse=bool(descending)) - self.indices = [i for _, i, _ in scored] - self.token_counts = {i: int(tok) for tok, i, _ in scored} - - def __len__(self): - return len(self.indices) - - def __getitem__(self, local_idx): - original_idx = self.indices[local_idx] - sample = dict(self.dataset[original_idx]) - sample["__text_token_count__"] = self.token_counts[original_idx] - return sample - - -# ----------------------------- -# Audio / collate helpers -# ----------------------------- - - -def _resolve_audio_path(path, root_path): - if path is None: - return None - if root_path is not None and not os.path.isabs(path): - return os.path.join(root_path, path) - return path - - -def _load_audio(path, sample_rate, normalize=True, use_librosa=False): - if path is None or not os.path.exists(path): - return torch.zeros(1, dtype=torch.float32) - - if use_librosa: - wav, sr = librosa.load(path, sr=sample_rate, mono=True) - if normalize: - wav = normalize_volume(wav) - return torch.as_tensor(wav, dtype=torch.float32) - - wav, sr = sf.read(path, dtype="float32") - if wav.ndim > 1: - wav = wav.mean(axis=1) - - if normalize: - wav = normalize_volume(wav) - - wav = torch.as_tensor(wav, dtype=torch.float32).unsqueeze(0) - return resample(wav, sr, sample_rate).squeeze(0) - - -def collate_and_tokenize_custom( - batch, - model, - extra_duration_thrshould=1.3, - sample_rate=22050, - root_path=None, - emulate_duplex_inference=False, - add_interruption_token=False, - pad_factor_text_speech=10, - force_interruption=False, - normalize_audio_volume=True, - use_librosa=False, - profile_multiturn_inference=False, - max_eval_turns=None, -): - main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] - - if max_eval_turns is not None: - max_eval_turns = int(max_eval_turns) - if max_eval_turns <= 0: - raise ValueError("--max_eval_turns must be > 0 when provided.") - - truncated_batch = [] - for s in batch: - s = dict(s) - if isinstance(s["text"], list): - s["text"] = s["text"][:max_eval_turns] - if isinstance(s.get("user_audio_file_path"), list): - s["user_audio_file_path"] = s["user_audio_file_path"][:max_eval_turns] - truncated_batch.append(s) - batch = truncated_batch - - is_profile = profile_multiturn_inference - is_duplex = emulate_duplex_inference and not is_profile - - out_dict = { - "duplex_multiturn": is_duplex, - "regular_multiturn": (not is_duplex) and (not is_profile), - "profile_multiturn": is_profile, - "dataset_indices": [int(s.get("__dataset_index__", -1)) for s in batch], - "text_token_counts": [int(s.get("__text_token_count__", -1)) for s in batch], - } - - tokenized_list = [] - batched_turns = [] - batched_turn_lens = [] - valid_turn_masks = [] - - if is_duplex: - for s in batch: - text_data = s["text"] - - if isinstance(text_data, list): - full_ids = [] - for segment in text_data: - seg_ids = model.tokenizer.encode(segment, tokenizer_name=main_tokenizer_name) + [model.eos_id] - pad_ids = [model.pad_id] * (len(seg_ids) * pad_factor_text_speech) - - if force_interruption: - fname = s["audio_filepath"] - no_ext = fname.split(".")[0] - sample_id = int(no_ext.split("_")[-1]) - case = sample_id % 3 - - if case == 0: - if len(seg_ids) >= 2: - seg_ids[-2] = model.interruption_token_id - seg_ids[-1] = model.pad_id - else: - pad_ids[0] = model.interruption_token_id - elif case == 1: - eos_idx = min(6, len(pad_ids) - 1) - pad_ids[eos_idx] = model.interruption_token_id - else: - pad_ids[0] = model.interruption_token_id - - elif add_interruption_token: - eos_idx = int(len(pad_ids) * 0.7) - pad_ids[eos_idx] = model.interruption_token_id - - full_ids.extend(seg_ids) - full_ids.extend(pad_ids) - - tokenized_list.append(torch.as_tensor(full_ids, dtype=torch.long)) - else: - tokenized_list.append( - torch.as_tensor( - model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [model.eos_id], - dtype=torch.long, - ) - ) - - prefix = torch.full((25,), model.pad_id, dtype=torch.long) - tokenized_list = [torch.cat([prefix, x]) for x in tokenized_list] - out_dict["input_lengths"] = torch.tensor([len(x) for x in tokenized_list], dtype=torch.long) - out_dict["input_ids"] = pad_sequence(tokenized_list, batch_first=True, padding_value=model.pad_id) - - else: - max_turns = 1 - for s in batch: - if isinstance(s["text"], list): - max_turns = max(max_turns, len(s["text"])) - - for t in range(max_turns): - turn_t_tokens, turn_t_lens, turn_t_valid = [], [], [] - - for s in batch: - text_data = s["text"] - - if isinstance(text_data, list): - if t < len(text_data): - seg_ids = model.tokenizer.encode(text_data[t], tokenizer_name=main_tokenizer_name) + [ - model.eos_id - ] - turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) - turn_t_lens.append(len(seg_ids)) - turn_t_valid.append(True) - else: - turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) - turn_t_lens.append(1) - turn_t_valid.append(False) - else: - if t == 0: - seg_ids = model.tokenizer.encode(text_data, tokenizer_name=main_tokenizer_name) + [ - model.eos_id - ] - turn_t_tokens.append(torch.as_tensor(seg_ids, dtype=torch.long)) - turn_t_lens.append(len(seg_ids)) - turn_t_valid.append(True) - else: - turn_t_tokens.append(torch.as_tensor([model.pad_id], dtype=torch.long)) - turn_t_lens.append(1) - turn_t_valid.append(False) - - batched_turns.append(pad_sequence(turn_t_tokens, batch_first=True, padding_value=model.pad_id)) - batched_turn_lens.append(torch.tensor(turn_t_lens, dtype=torch.long)) - valid_turn_masks.append(torch.tensor(turn_t_valid, dtype=torch.bool)) - - out_dict["batched_turns"] = batched_turns - out_dict["batched_turn_lens"] = batched_turn_lens - out_dict["valid_turn_masks"] = valid_turn_masks - - audio_list, audio_lengths, target_num_frames = [], [], [] - context_audio_paths = [] - max_turns_for_user_audio = len(batched_turns) if not is_duplex else 0 - - if is_profile and max_turns_for_user_audio > 0: - user_audio_by_turn = [[] for _ in range(max_turns_for_user_audio)] - user_audio_lens_by_turn = [[] for _ in range(max_turns_for_user_audio)] - else: - user_audio_by_turn, user_audio_lens_by_turn = [], [] - - for i, s in enumerate(batch): - audio_path = _resolve_audio_path(s.get("context_audio_filepath"), root_path) - context_audio_paths.append(audio_path) - wav = _load_audio(audio_path, sample_rate, normalize=normalize_audio_volume, use_librosa=use_librosa) - audio_list.append(wav) - audio_lengths.append(len(wav)) - - if is_profile and max_turns_for_user_audio > 0: - user_audio_paths = s.get("user_audio_file_path", None) - - for t in range(max_turns_for_user_audio): - has_valid_text_turn = (isinstance(s["text"], list) and t < len(s["text"])) or ( - not isinstance(s["text"], list) and t == 0 - ) - - if ( - isinstance(user_audio_paths, list) - and t < len(user_audio_paths) - and user_audio_paths[t] - and has_valid_text_turn - ): - ua_path = _resolve_audio_path(user_audio_paths[t], root_path) - ua_wav = _load_audio( - ua_path, - sample_rate=sample_rate, - normalize=normalize_audio_volume, - use_librosa=use_librosa, - ) - else: - ua_wav = torch.zeros(int(2 * sample_rate), dtype=torch.float32) - - user_audio_by_turn[t].append(ua_wav) - user_audio_lens_by_turn[t].append(len(ua_wav)) - - tdur_audio_path = _resolve_audio_path(s["audio_filepath"], root_path) - - if tdur_audio_path and os.path.exists(tdur_audio_path): - wav_dur = _load_audio( - tdur_audio_path, - sample_rate, - normalize=normalize_audio_volume, - use_librosa=use_librosa, - ) - tdur = wav_dur.shape[0] // model.input_samples_per_frame - target_num_frames.append(tdur * extra_duration_thrshould) - else: - if is_duplex: - current_text_len = len(tokenized_list[i]) - target_num_frames.append(current_text_len if isinstance(s["text"], list) else current_text_len * 5) - else: - target_num_frames.append(sum([l[i].item() for l in batched_turn_lens]) * 5) - - max_audio_len = max(audio_lengths) - B = len(audio_lengths) - padded_audio = torch.zeros((B, max_audio_len), dtype=torch.float32) - - for i, wav in enumerate(audio_list): - padded_audio[i, : len(wav)] = wav - - if is_profile and max_turns_for_user_audio > 0: - padded_user_audio_turns, padded_user_audio_turns_lens = [], [] - - for t in range(max_turns_for_user_audio): - turn_lens = user_audio_lens_by_turn[t] - max_turn_audio_len = max(turn_lens) - padded_turn_audio = torch.zeros((B, max_turn_audio_len), dtype=torch.float32) - - for i, wav in enumerate(user_audio_by_turn[t]): - padded_turn_audio[i, : len(wav)] = wav - - padded_user_audio_turns.append(padded_turn_audio) - padded_user_audio_turns_lens.append(torch.tensor(turn_lens, dtype=torch.long)) - - out_dict["user_audio_turns"] = padded_user_audio_turns - out_dict["user_audio_turns_lens"] = padded_user_audio_turns_lens - - raw_turn_texts = [] - for s in batch: - if isinstance(s["text"], list): - raw_turn_texts.append([str(x) for x in s["text"]]) - else: - raw_turn_texts.append([str(s["text"])]) - - out_dict["context_audio"] = padded_audio - out_dict["context_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long) - out_dict["context_audio_paths"] = context_audio_paths - out_dict["target_audio_paths"] = [s["audio_filepath"] for s in batch] - out_dict["target_num_frames"] = target_num_frames - out_dict["raw_turn_texts"] = raw_turn_texts - out_dict["raw_text"] = [" ".join(x) for x in raw_turn_texts] - - return out_dict - - -# ----------------------------- -# Model / generation -# ----------------------------- - - -def attach_dtype_counter(model): - handles = [] - stats = {} - examples = {} - - def is_leaf(module): - return len(list(module.children())) == 0 - - def get_dtype(x): - if torch.is_tensor(x): - return str(x.dtype) - if isinstance(x, (list, tuple)): - for t in x: - if torch.is_tensor(t): - return str(t.dtype) - return "other" - - def get_module_group(name): - return name.split(".")[0] if "." in name else name - - def hook_fn(name): - def fn(module, inputs, outputs): - dtype = get_dtype(outputs) - if dtype not in ["torch.float16", "torch.bfloat16", "torch.float32"]: - dtype = "other" - group = get_module_group(name) - if group not in stats: - stats[group] = { - "torch.float16": 0, - "torch.bfloat16": 0, - "torch.float32": 0, - "other": 0, - } - examples[group] = { - "torch.float16": [], - "torch.bfloat16": [], - "torch.float32": [], - "other": [], - } - stats[group][dtype] += 1 - if len(examples[group][dtype]) < 3: - examples[group][dtype].append(module.__class__.__name__) - - return fn - - for name, module in model.named_modules(): - if is_leaf(module): - handles.append(module.register_forward_hook(hook_fn(name))) - return handles, stats, examples - - -def report_dtype_stats(handles, stats, examples, rank=0): - for h in handles: - h.remove() - logging.info(f"[rank={rank}] === DTYPE USAGE PER MODULE ===") - for group, group_stats in stats.items(): - total = sum(group_stats.values()) - if total == 0: - continue - logging.info(f"[rank={rank}] --- {group} ---") - for dtype, count in group_stats.items(): - if count > 0: - logging.info(f"[rank={rank}] {dtype}: {count} ({100 * count / total:.2f}%)") - logging.info(f"[rank={rank}] === DTYPE EXAMPLES ===") - for group, group_examples in examples.items(): - for dtype, mods in group_examples.items(): - if mods: - logging.info(f"[rank={rank}] {group} {dtype}: {mods}") - - -def build_model_and_codec(args, target_device, target_dtype): - model_cfg = EasyMagpieTTSInferenceModel.restore_from(args.checkpoint_path, return_config=True) - - with open_dict(model_cfg): - model_cfg.target = "nemo.collections.tts.models.easy_magpietts_inference.EasyMagpieTTSInferenceModel" - model_cfg.codecmodel_path = args.codec_model_path - model_cfg.train_ds = None - model_cfg.validation_ds = None - model_cfg.run_val_inference = False - model_cfg.use_utmos = False - model_cfg.use_meta_init_for_decoder = True - - if args.phoneme_tokenizer_path and getattr(model_cfg, "phoneme_tokenizer", None) is not None: - model_cfg.phoneme_tokenizer.tokenizer_path = args.phoneme_tokenizer_path - - model = EasyMagpieTTSInferenceModel.restore_from( - args.checkpoint_path, - override_config_path=model_cfg, - map_location=torch.device("cpu"), - ) - model.use_kv_cache_for_inference = True - model.to(dtype=target_dtype) - model.eval().to(target_device) - - model.input_samples_per_frame = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) - model.target_samples_per_frame = model.input_samples_per_frame / (model.sample_rate / model.output_sample_rate) - - codec_model = AudioCodecModel.restore_from(args.codec_model_path, strict=False, map_location=torch.device("cpu")) - if hasattr(codec_model, "discriminator"): - del codec_model.discriminator - codec_model.freeze() - codec_model = codec_model.to(target_device).eval() - - codec_converter = None - if getattr(model, "_codec_converter", None) is not None: - vq_new = deepcopy(model._codec_converter.vector_quantizer_new).to(target_device).eval() - codec_converter = VectorQuantizerIndexConverter( - vector_quantizer_original=codec_model.vector_quantizer, - vector_quantizer_new=vq_new, - ).to(target_device).eval() - - model._codec_helper = CodecHelper(codec_model=codec_model, codec_converter=codec_converter) - model._generate_codec_silence_buffer() - - return model - - -def prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=None): - B = inputs["context_audio"].size(0) - device = model.device - - inputs["context_audio"] = inputs["context_audio"].to(device, dtype=target_dtype) - inputs["context_audio_lengths"] = inputs["context_audio_lengths"].to(device) - - if args.user_custom_speaker_reference and speaker_wav is not None: - inputs["context_audio"] = speaker_wav.repeat(B, 1).detach() - inputs["context_audio_lengths"] = torch.full((B,), speaker_wav.size(-1), dtype=torch.long, device=device) - - if "user_audio_turns" in inputs: - inputs["user_audio_turns"] = [x.to(device, dtype=target_dtype) for x in inputs["user_audio_turns"]] - inputs["user_audio_turns_lens"] = [x.to(device) for x in inputs["user_audio_turns_lens"]] - - return inputs - - -def run_generation(model, inputs, args, codec_sil_codes): - B = inputs["context_audio"].size(0) - device = model.device - profile_turn_frame_ranges = [] - profile_decode_start_frame = 0 - - with torch.inference_mode(): - wav = inputs["context_audio"] - wav_len = inputs["context_audio_lengths"] - codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) - - use_lang = bool(getattr(model, "add_language_to_context_text", False)) - ctx_text = f"[{args.language.upper()}]" if use_lang else "[NO TEXT CONTEXT]" - ctx_text_ids = model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) - - ctx_toks = torch.tensor([ctx_text_ids], dtype=torch.long, device=device).expand(B, -1) - ctx_toks_lens = torch.tensor([len(ctx_text_ids)] * B, dtype=torch.long, device=device) - - state = model.streaming_init( - context_audio_codes=codes, - context_audio_codes_lens=codes_lens, - context_text_tokens=ctx_toks, - context_text_tokens_lens=ctx_toks_lens, - use_cfg=args.use_cfg, - cfg_scale=args.cfg_scale, - use_local_transformer=True, - temperature=args.temperature, - topk=args.topk, - phoneme_input_type="pred", - phoneme_sampling_method="argmax", - use_inference_mode=True, - ) - - if inputs["duplex_multiturn"]: - text = inputs["input_ids"].to(device) - text_lens = inputs["input_lengths"].to(device) - - in_initial_silence = torch.ones(B, dtype=torch.bool, device=device) - in_post_speech_silence = torch.zeros(B, dtype=torch.bool, device=device) - text_exhausted = state.text_tokens_seen >= text_lens - - while not text_exhausted.all(): - state.finished = state.finished & text_exhausted - state.text_finished = state.text_finished & text_exhausted - - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended = state.phoneme_stream_ended & text_exhausted - - positions = state.text_tokens_seen.clamp(max=text.size(1) - 1) - current_tokens = text[torch.arange(B, device=device), positions] - current_tokens = torch.where( - text_exhausted, - torch.full_like(current_tokens, model.eos_id), - current_tokens, - ) - - is_pad_or_eos = (current_tokens == model.pad_id) | (current_tokens == model.eos_id) - in_initial_silence = in_initial_silence & is_pad_or_eos - in_post_speech_silence = in_post_speech_silence & is_pad_or_eos - - state, audio_codes, _ = model.streaming_step( - state=state, - text_tokens=current_tokens, - use_inference_mode=True, - ) - - if audio_codes is not None and args.force_speech_sil_codes: - force_silence_mask = in_initial_silence | in_post_speech_silence - if force_silence_mask.any(): - expanded_sil = codec_sil_codes.view(1, -1, 1).expand_as(audio_codes) - mask_3d = force_silence_mask.view(B, 1, 1) - state.all_predictions[-1] = torch.where(mask_3d, expanded_sil, audio_codes) - - in_post_speech_silence = in_post_speech_silence | state.finished - text_exhausted = state.text_tokens_seen >= text_lens - - elif inputs["regular_multiturn"]: - batched_turns = inputs["batched_turns"] - batched_turn_lens = inputs["batched_turn_lens"] - valid_turn_masks = inputs["valid_turn_masks"] - turn_offsets = torch.zeros(B, dtype=torch.long, device=device) - - for t in range(len(batched_turns)): - turn_text = batched_turns[t].to(device) - turn_lens = batched_turn_lens[t].to(device) - valid_mask = valid_turn_masks[t].to(device) - - state.finished = state.finished & (~valid_mask) - state.text_finished = state.text_finished & (~valid_mask) - - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended = state.phoneme_stream_ended & (~valid_mask) - - if state.finished.all(): - continue - - turn_offsets = torch.where(valid_mask, state.text_tokens_seen, turn_offsets) - turn_steps = 0 - - while not state.finished.all() and turn_steps < args.max_tts_steps: - turn_steps += 1 - relative_positions = state.text_tokens_seen - turn_offsets - positions = relative_positions.clamp(min=0, max=turn_text.size(1) - 1) - current_tokens = turn_text[torch.arange(B, device=device), positions] - - exhausted = relative_positions >= turn_lens - current_tokens = torch.where( - exhausted, - torch.full_like(current_tokens, model.eos_id), - current_tokens, - ) - - state, _, _ = model.streaming_step( - state=state, - text_tokens=current_tokens, - use_inference_mode=True, - ) - - elif inputs["profile_multiturn"]: - if B != 1: - raise RuntimeError("--profile_multiturn_inference requires --batch_size=1 per process.") - - batched_turns = inputs["batched_turns"] - batched_turn_lens = inputs["batched_turn_lens"] - valid_turn_masks = inputs["valid_turn_masks"] - - for t in range(len(batched_turns)): - turn_text = batched_turns[t].to(device) - turn_lens = batched_turn_lens[t].to(device) - valid_mask = valid_turn_masks[t].to(device) - - if not bool(valid_mask[0].item()): - continue - - state.finished.zero_() - state.text_finished.zero_() - state.audio_prediction_end_idx.fill_(-1) - - if hasattr(state, "turn_text_tokens_seen"): - state.turn_text_tokens_seen.zero_() - if hasattr(state, "phoneme_steps"): - state.phoneme_steps.zero_() - if hasattr(state, "phoneme_stream_ended"): - state.phoneme_stream_ended.zero_() - if hasattr(state, "phoneme_eos_detected"): - state.phoneme_eos_detected.zero_() - state.last_phoneme_tokens = None - - if not model.cfg.get("condition_on_user_speech", False): - if "user_audio_turns" in inputs: - profile_T = int(round(inputs["user_audio_turns"][t].size(-1) / model.input_samples_per_frame)) - profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate - else: - profile_seconds = args.profile_pad_min_sec + torch.rand((), device=device).item() * ( - args.profile_pad_max_sec - args.profile_pad_min_sec - ) - profile_T = max( - 1, - int(round(profile_seconds * model.sample_rate / model.input_samples_per_frame)), - ) - - profile_tokens = torch.full((1, profile_T), model.pad_id, dtype=torch.long, device=device) - user_audio_channel_embedding = None - - else: - if "user_audio_turns" in inputs: - user_audio = inputs["user_audio_turns"][t] - user_audio_lens = inputs["user_audio_turns_lens"][t] - else: - user_audio = inputs["context_audio"] - user_audio_lens = inputs["context_audio_lengths"] - - user_audio_codes, user_audio_codes_lens = model._codec_helper.audio_to_codes( - user_audio, - user_audio_lens, - ) - - if model._codec_converter is not None: - user_audio_codes = model._codec_converter.convert_original_to_new( - audio_tokens=user_audio_codes, - audio_lens=user_audio_codes_lens, - ).long() - - user_audio_codes, user_audio_codes_lens = model.stack_codes( - user_audio_codes, - user_audio_codes_lens, - model.audio_bos_id, - model.audio_eos_id, - model.frame_stacking_factor, - model.num_audio_codebooks, - ) - - user_audio_embedded = model.embed_audio_tokens(user_audio_codes) - - boundary_trim = model.cfg.get("user_audio_boundary_trim", 4) - boundary_trim = 0 if boundary_trim is None else int(boundary_trim) - - if boundary_trim == 0: - real_start = 0 - real_end = int(user_audio_codes_lens[0].item()) - else: - turn_len_with_special = int(user_audio_codes_lens[0].item()) - real_start = 1 - real_end = max(real_start, turn_len_with_special - 1) - - user_audio_embedded = user_audio_embedded[:, real_start:real_end] - copy_len = user_audio_embedded.size(1) - - if boundary_trim > 0: - trim = min(boundary_trim, copy_len // 2) - if trim > 0: - user_audio_embedded[:, :trim] = 0.0 - user_audio_embedded[:, copy_len - trim :] = 0.0 - - bos_user_pad = torch.zeros( - user_audio_embedded.size(0), - 1, - user_audio_embedded.size(2), - device=user_audio_embedded.device, - dtype=user_audio_embedded.dtype, - ) - user_audio_embedded = torch.cat([bos_user_pad, user_audio_embedded], dim=1) - - profile_T = user_audio_embedded.size(1) - profile_tokens = torch.full((B, profile_T), model.pad_id, dtype=torch.long, device=device) - user_audio_channel_embedding = user_audio_embedded - profile_seconds = profile_T * model.input_samples_per_frame / model.sample_rate - - delay_tokens = int(state.config.training_mode.streaming_speech_delay) - delay_tokens = min(delay_tokens, int(turn_lens[0].item()), profile_T) - - warmup_tokens = turn_text[:, :delay_tokens] - turn_text = turn_text[:, delay_tokens:] - turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) - - if user_audio_channel_embedding is not None and delay_tokens > 0: - warmup_user_audio = user_audio_channel_embedding[:, -delay_tokens:] - user_audio_channel_embedding = user_audio_channel_embedding[:, :-delay_tokens] - profile_tokens = profile_tokens[:, :-delay_tokens] - else: - warmup_user_audio = None - - if profile_tokens.size(1) > 0: - state = model.streaming_prefill_profile( - state=state, - text_tokens=profile_tokens, - use_inference_mode=True, - user_audio_channel_embedding=user_audio_channel_embedding, - ) - - for i in range(delay_tokens): - user_step_emb = warmup_user_audio[:, i] if warmup_user_audio is not None else None - - state.finished.zero_() - state, _, _ = model.streaming_step( - state=state, - text_tokens=warmup_tokens[:, i], - user_audio_channel_embedding=user_step_emb, - prefill_like_step=not bool(model.cfg.get("agent_mask_include_transition_prefix", False)), - prefill_like_is_last_step=(i == delay_tokens - 1), - use_inference_mode=True, - ) - - logging.info(f"[profile_multiturn] turn={t} prefilled {profile_T} steps ({profile_seconds:.2f}s)") - - turn_start_frame = sum(p.size(-1) for p in state.all_predictions) - if t == 0: - state.audio_prediction_start_idx.fill_(turn_start_frame) - profile_decode_start_frame = turn_start_frame - - turn_offset = state.text_tokens_seen.clone() - turn_steps = 0 - saw_audio = False - turn_ended_with_audio_eos = False - - while turn_steps < args.max_tts_steps: - turn_steps += 1 - state.finished.zero_() - - relative_position = state.text_tokens_seen - turn_offset - text_exhausted = relative_position >= turn_lens - - if turn_text.size(1) == 0: - current_tokens = torch.full((B,), model.eos_id, dtype=torch.long, device=device) - else: - position = relative_position.clamp(min=0, max=turn_text.size(1) - 1) - current_tokens = turn_text[torch.arange(B, device=device), position] - current_tokens = torch.where( - text_exhausted, - torch.full_like(current_tokens, model.eos_id), - current_tokens, - ) - - state, audio_codes, _ = model.streaming_step( - state=state, - text_tokens=current_tokens, - use_inference_mode=True, - ) - - if audio_codes is not None and not saw_audio: - saw_audio = True - - if bool(text_exhausted[0].item()) and bool(state.finished[0].item()): - turn_ended_with_audio_eos = True - break - - state.audio_prediction_end_idx.fill_(-1) - state.finished.zero_() - - turn_end_frame = sum(p.size(-1) for p in state.all_predictions) - profile_turn_frame_ranges.append((t, turn_start_frame, turn_end_frame)) - - logging.info( - f"[profile_multiturn] turn={t} steps={turn_steps} " - f"saw_audio={saw_audio} ended_with_audio_eos={turn_ended_with_audio_eos}" - ) - - bos_id = getattr(model, "audio_bos_id", -1) - eos_id = getattr(model, "audio_eos_id", -1) - speaking_id = getattr(model, "audio_user_speaking_id", -1) - speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) - sil_injection = codec_sil_codes.view(1, -1, 1) - - for step_idx in range(len(state.all_predictions)): - pred = state.all_predictions[step_idx] - mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) - frame_mask = mask.any(dim=1, keepdim=True) - if frame_mask.any(): - state.all_predictions[step_idx] = torch.where(frame_mask, sil_injection.expand_as(pred), pred) - - if inputs["duplex_multiturn"] or inputs["profile_multiturn"]: - state.audio_prediction_end_idx.fill_(-1) - - finalize_output = model.streaming_finalize(state, use_inference_mode=True) - - return finalize_output, profile_turn_frame_ranges, profile_decode_start_frame - - -def load_speaker_wav_if_needed(args, model, target_dtype): - if args.user_custom_speaker_reference and args.inference_speaker_reference: - return _load_audio( - args.inference_speaker_reference, - model.sample_rate, - normalize=args.normalize_volume, - use_librosa=args.use_librosa, - ).unsqueeze(0).to(model.device, dtype=target_dtype) - - return None - - -# ----------------------------- -# Save generation outputs and metric manifests -# ----------------------------- - - -def write_audio_1d(path: str, wav: torch.Tensor, sr: int): - os.makedirs(os.path.dirname(path), exist_ok=True) - wav_np = wav.detach().cpu().float().numpy() - sf.write(path, wav_np, samplerate=sr) - - -def build_metric_item( - run_id: int, - rank: int, - dataset_index: int, - turn_id: int, - target_audio_path: str, - reference_text: str, - pred_audio_path: str, - context_audio_path: str, - pred_audio_samples: int, - context_audio_samples: int, - output_sample_rate: int, - context_sample_rate: int, -): - return { - "run_id": int(run_id), - "rank": int(rank), - "dataset_index": int(dataset_index), - "turn_id": int(turn_id), - "target_audio_path": target_audio_path, - "reference_text": reference_text, - "pred_audio_path": pred_audio_path, - "context_audio_path": context_audio_path, - "pred_audio_samples": int(pred_audio_samples), - "context_audio_samples": int(context_audio_samples), - "pred_audio_seconds": float(pred_audio_samples / output_sample_rate), - "context_audio_seconds": float(context_audio_samples / context_sample_rate), - "output_sample_rate": int(output_sample_rate), - "context_sample_rate": int(context_sample_rate), - } - - -def save_generation_outputs_and_build_metric_items( - model, - inputs, - finalize_output, - profile_turn_frame_ranges, - profile_decode_start_frame, - args, - rank: int, - run_id: int, -): - device = model.device - B = inputs["context_audio"].size(0) - - with fp32_precision(): - audio_f32 = finalize_output.audio.float() - audio_len = finalize_output.audio_len.int() - - expected_audio_lens = ( - torch.tensor(inputs["target_num_frames"], device=device) * model.target_samples_per_frame - ).int() - - if inputs["duplex_multiturn"]: - text_lens = inputs["input_lengths"].to(device) - audio_len = (text_lens * model.target_samples_per_frame).int() - audio_len = torch.min(audio_len, torch.tensor(audio_f32.size(1), device=device)) - elif inputs["profile_multiturn"]: - audio_len = finalize_output.audio_len.int() - else: - audio_len = torch.min(audio_len, expected_audio_lens) - - audio_out_dir = get_audio_out_dir(args) - metric_turn_dir = get_generated_turn_audio_dir(args) - metric_context_dir = get_context_metric_audio_dir(args) - os.makedirs(audio_out_dir, exist_ok=True) - os.makedirs(metric_turn_dir, exist_ok=True) - os.makedirs(metric_context_dir, exist_ok=True) - - audio_f32_cpu = audio_f32.detach().cpu() - audio_len_cpu = audio_len.detach().cpu() - metric_items = [] - - for i in range(B): - target_path = inputs["target_audio_paths"][i] - base_name = os.path.basename(target_path) - stem, ext = os.path.splitext(base_name) - if not ext: - ext = ".wav" - - dataset_idx = int(inputs.get("dataset_indices", [-1] * B)[i]) - safe_stem = ( - f"run{run_id:02d}_idx{dataset_idx:08d}_{stem}" - if dataset_idx >= 0 - else f"run{run_id:02d}_rank{rank}_{stem}" - ) - - context_len = int(inputs["context_audio_lengths"][i].detach().cpu().item()) - context_wav = inputs["context_audio"][i, :context_len].detach().cpu().float() - context_metric_path = os.path.join(metric_context_dir, f"{safe_stem}_context.wav") - write_audio_1d(context_metric_path, context_wav, model.sample_rate) - - if inputs["profile_multiturn"]: - full_len = int(audio_len_cpu[i].item()) - full_wav_t = audio_f32_cpu[i, :full_len].float() - - samples_per_prediction_frame = model.codec_model_samples_per_frame / ( - model.sample_rate / model.output_sample_rate - ) - - aligned_agent = torch.zeros_like(full_wav_t) - raw_turn_texts = inputs.get("raw_turn_texts", [[] for _ in range(B)]) - - for turn_id, start_frame, end_frame in profile_turn_frame_ranges: - rel_start_frame = start_frame - profile_decode_start_frame - rel_end_frame = end_frame - profile_decode_start_frame - - start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) - end_sample = int(round(rel_end_frame * samples_per_prediction_frame)) - - start_sample = max(0, min(start_sample, full_len)) - end_sample = max(start_sample, min(end_sample, full_len)) - - aligned_agent[start_sample:end_sample] = full_wav_t[start_sample:end_sample] - - turn_wav = aligned_agent[start_sample:end_sample].float() - turn_out_path = os.path.join(audio_out_dir, f"{safe_stem}_turn_{turn_id}{ext}") - write_audio_1d(turn_out_path, turn_wav, model.output_sample_rate) - - metric_turn_path = os.path.join(metric_turn_dir, f"{safe_stem}_turn_{turn_id}.wav") - write_audio_1d(metric_turn_path, turn_wav, model.output_sample_rate) - - if turn_id < len(raw_turn_texts[i]): - metric_items.append( - build_metric_item( - run_id=run_id, - rank=rank, - dataset_index=dataset_idx, - turn_id=turn_id, - target_audio_path=target_path, - reference_text=str(raw_turn_texts[i][turn_id]), - pred_audio_path=metric_turn_path, - context_audio_path=context_metric_path, - pred_audio_samples=int(turn_wav.numel()), - context_audio_samples=int(context_wav.numel()), - output_sample_rate=model.output_sample_rate, - context_sample_rate=model.sample_rate, - ) - ) - - full_out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") - write_audio_1d(full_out_path, aligned_agent, model.output_sample_rate) - - if "user_audio_turns" in inputs: - user_segments = [] - - first_user_len_in = int(inputs["user_audio_turns_lens"][0][i].item()) - first_user_delay_out = int(round(first_user_len_in * model.output_sample_rate / model.sample_rate)) - - for turn_id, start_frame, _ in profile_turn_frame_ranges: - if turn_id >= len(inputs["user_audio_turns"]): - continue - - turn_audio = inputs["user_audio_turns"][turn_id][i].detach().cpu().float() - turn_audio_len = int(inputs["user_audio_turns_lens"][turn_id][i].item()) - turn_audio = turn_audio[:turn_audio_len] - - turn_audio_out = resample( - turn_audio.unsqueeze(0), - model.sample_rate, - model.output_sample_rate, - ).squeeze(0) - - if turn_id == 0: - user_start_sample = 0 - else: - prev_turn_end_frame = profile_turn_frame_ranges[turn_id - 1][2] - rel_prev_end_frame = prev_turn_end_frame - profile_decode_start_frame - user_start_sample = first_user_delay_out + int( - round(rel_prev_end_frame * samples_per_prediction_frame) - ) - - user_segments.append((user_start_sample, turn_audio_out.detach().cpu().float())) - - total_user_len = 0 - for s, wav_seg in user_segments: - total_user_len = max(total_user_len, s + wav_seg.numel()) - - user_ch = torch.zeros(total_user_len) - for s, wav_seg in user_segments: - e = s + wav_seg.numel() - user_ch[s:e] += wav_seg - - agent_ch = torch.cat([torch.zeros(first_user_delay_out, dtype=aligned_agent.dtype), aligned_agent]) - - stereo_len = max(user_ch.numel(), agent_ch.numel()) - user_pad = torch.zeros(stereo_len) - agent_pad = torch.zeros(stereo_len) - - user_pad[: user_ch.numel()] = user_ch - agent_pad[: agent_ch.numel()] = agent_ch - - stereo = torch.stack([user_pad, agent_pad], dim=1).numpy() - aligned_path = os.path.join(audio_out_dir, f"{safe_stem}_user_agent_aligned{ext}") - sf.write(aligned_path, stereo, samplerate=model.output_sample_rate) - - else: - full_len = int(audio_len_cpu[i].item()) - wav = audio_f32_cpu[i, :full_len].float() - out_path = os.path.join(audio_out_dir, f"{safe_stem}{ext}") - write_audio_1d(out_path, wav, model.output_sample_rate) - - metric_turn_path = os.path.join(metric_turn_dir, f"{safe_stem}_turn_0.wav") - write_audio_1d(metric_turn_path, wav, model.output_sample_rate) - - metric_items.append( - build_metric_item( - run_id=run_id, - rank=rank, - dataset_index=dataset_idx, - turn_id=0, - target_audio_path=target_path, - reference_text=str(inputs["raw_text"][i]), - pred_audio_path=metric_turn_path, - context_audio_path=context_metric_path, - pred_audio_samples=int(wav.numel()), - context_audio_samples=int(context_wav.numel()), - output_sample_rate=model.output_sample_rate, - context_sample_rate=model.sample_rate, - ) - ) - - return metric_items - - -# ----------------------------- -# Metrics after generation -# ----------------------------- - - -def torch_rms_norm(wav: torch.Tensor, db_level: float = -27.0) -> torch.Tensor: - denom = torch.sum(wav**2) - if denom <= 0: - return wav - r = 10 ** (db_level / 20) - a = torch.sqrt((wav.size(-1) * (r**2)) / denom) - return wav * a - - -def _load_audio_for_metric(path: str, sample_rate: int): - wav = _load_audio(path, sample_rate=sample_rate, normalize=False, use_librosa=False) - if wav.numel() == 0: - wav = torch.zeros(1, dtype=torch.float32) - return wav.float() - - -def _pad_audio_1d_list(wavs: List[torch.Tensor], device, dtype=torch.float32): - if len(wavs) == 0: - return torch.zeros((0, 1), device=device, dtype=dtype), torch.zeros((0,), device=device, dtype=torch.long) - - lens = torch.tensor([max(1, int(w.numel())) for w in wavs], device=device, dtype=torch.long) - max_len = int(lens.max().item()) - out = torch.zeros((len(wavs), max_len), device=device, dtype=dtype) - - for i, w in enumerate(wavs): - w = w.to(device=device, dtype=dtype).flatten() - if w.numel() == 0: - continue - out[i, : w.numel()] = w - - return out, lens - - -def chunk_list(xs: List[Any], chunk_size: int) -> Iterable[List[Any]]: - chunk_size = max(1, int(chunk_size)) - for start in range(0, len(xs), chunk_size): - yield xs[start : start + chunk_size] - - -def _metric_device(): - return "cuda" if torch.cuda.is_available() else "cpu" - - -def _load_metric_batch_audio(batch_items: List[Dict[str, Any]], args): - pred_wavs = [] - context_wavs = [] - - for item in batch_items: - pred = _load_audio_for_metric(item["pred_audio_path"], sample_rate=int(item["output_sample_rate"])) - context = _load_audio_for_metric(item["context_audio_path"], sample_rate=int(item["context_sample_rate"])) - - if args.max_metric_audio_sec is not None: - max_pred_len = int(float(args.max_metric_audio_sec) * int(item["output_sample_rate"])) - pred = pred[: max(1, max_pred_len)] - - pred_wavs.append(pred) - context_wavs.append(context) - - device = _metric_device() - pred_audio, pred_lens = _pad_audio_1d_list(pred_wavs, device=device) - context_audio, context_lens = _pad_audio_1d_list(context_wavs, device=device) - output_sample_rate = int(batch_items[0]["output_sample_rate"]) - context_sample_rate = int(batch_items[0]["context_sample_rate"]) - - return pred_audio, pred_lens, context_audio, context_lens, output_sample_rate, context_sample_rate - - -def compute_metrics_after_generation(args, rank: int, world_size: int, metric_items: List[Dict[str, Any]]): - """ - Load metric models only after generation is complete. - - Order: - 1. Load ASR, compute turn-level CER/WER and ASR hyps, then free ASR. - 2. Load SECS speaker encoder and compute turn-level SECS. - 3. Save rank-level aggregate metrics from the same turn-level rows. - - SECS is always computed turn-by-turn, like CER/WER. The grouped filewise - output stores secs_turns and sample-level secs, and metrics_final.* receives - the turn-level aggregate SECS. - """ - metric_start = time.time() - - if len(metric_items) == 0: - return { - "rank": int(rank), - "world_size": int(world_size), - "num_processed": 0, - "num_metric_items": 0, - "metric_elapsed_sec": 0.0, - "intelligibility": {}, - "secs": {}, - }, [] - - normalizer = EnglishTextNormalizer() - normalizer.ignore_patterns = r"$^" - filewise_rows = [] - - # ASR pass. - all_rank_print(rank, f"loading ASR after generation: {args.asr_model_name}") - with fp32_precision(): - intelligibility = Intelligibility(args.asr_model_name, reuse_asr_hyps=False).reset() - - for batch_items in chunk_list(metric_items, args.metric_batch_size): - refs = [x["reference_text"] for x in batch_items] - pred_audio, pred_lens, _, _, output_sr, _ = _load_metric_batch_audio(batch_items, args) - - with fp32_precision(): - pred_16k = resample(pred_audio, output_sr, 16000) - pred_16k_lens = (pred_lens / output_sr * 16000).to(torch.long) - pred_16k = torch_rms_norm(pred_16k) - - asr_hyps = intelligibility.update( - name="dataset", - refs=refs, - pred_audio=pred_16k, - pred_audio_lens=pred_16k_lens, - asr_hyps=None, - ) - - for item, hyp in zip(batch_items, asr_hyps): - ref_norm = normalizer(str(item["reference_text"])).strip() - hyp_norm = normalizer(str(hyp)).strip() - if ref_norm == "": - cer = None - wer = None - else: - cer = float(word_error_rate([hyp_norm], [ref_norm], use_cer=True)) - wer = float(word_error_rate([hyp_norm], [ref_norm], use_cer=False)) - - row = dict(item) - row["asr_hyp"] = hyp - row["cer"] = cer - row["wer"] = wer - row["secs"] = None - filewise_rows.append(row) - - with fp32_precision(): - cer_wer = metric_dict_to_jsonable(intelligibility.compute()) - del intelligibility - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # SECS pass. This is intentionally turn-level, matching CER/WER. - # We keep one aggregate SECS metric for metrics_final.* and also compute - # one SECS value per filewise turn row so grouped outputs have secs_turns. - all_rank_print(rank, f"loading speaker encoder after ASR is released: {args.secs_model_name}") - with fp32_precision(): - secs_metric = SECS(args.secs_model_name).reset() - - # Aggregate turn-level SECS for metrics_final.json / metrics_final.txt. - for batch_items in chunk_list(metric_items, args.metric_batch_size): - pred_audio, pred_lens, context_audio, context_lens, output_sr, context_sr = _load_metric_batch_audio( - batch_items, args - ) - - with fp32_precision(): - pred_16k = resample(pred_audio, output_sr, 16000) - pred_16k_lens = (pred_lens / output_sr * 16000).to(torch.long) - context_16k = resample(context_audio, context_sr, 16000) - context_16k_lens = (context_lens / context_sr * 16000).to(torch.long) - - pred_16k = torch_rms_norm(pred_16k) - context_16k = torch_rms_norm(context_16k) - - secs_metric.update( - name="dataset", - target_audio=context_16k, - target_audio_lens=context_16k_lens, - pred_audio=pred_16k, - pred_audio_lens=pred_16k_lens, - ) - - with fp32_precision(): - secs_scores = metric_dict_to_jsonable(secs_metric.compute()) - del secs_metric - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # Per-turn SECS for filewise/grouped outputs. This is always computed so - # secs_turns and sample-level secs are never null in final filewise metrics. - # It is slower than aggregate-only SECS, but it matches the turn-level - # semantics requested for CER/WER/SECS. - all_rank_print(rank, "computing per-turn SECS rows") - for row in filewise_rows: - pred_audio, pred_lens, context_audio, context_lens, output_sr, context_sr = _load_metric_batch_audio([row], args) - - with fp32_precision(): - one_secs = SECS(args.secs_model_name).reset() - pred_16k = resample(pred_audio, output_sr, 16000) - pred_16k_lens = (pred_lens / output_sr * 16000).to(torch.long) - context_16k = resample(context_audio, context_sr, 16000) - context_16k_lens = (context_lens / context_sr * 16000).to(torch.long) - - pred_16k = torch_rms_norm(pred_16k) - context_16k = torch_rms_norm(context_16k) - - one_secs.update( - name="dataset", - target_audio=context_16k, - target_audio_lens=context_16k_lens, - pred_audio=pred_16k, - pred_audio_lens=pred_16k_lens, - ) - one_secs_metrics = metric_dict_to_jsonable(one_secs.compute()) - - row["secs"] = safe_metric_scalar(one_secs_metrics, ["secs", "secs_dataset"]) - row["secs_metrics"] = one_secs_metrics - del one_secs - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - metric_elapsed = time.time() - metric_start - - rank_metrics = { - "rank": int(rank), - "world_size": int(world_size), - "num_processed": len({(x["run_id"], x["dataset_index"]) for x in metric_items}), - "num_metric_items": int(len(metric_items)), - "metric_elapsed_sec": float(metric_elapsed), - "intelligibility": cer_wer, - "secs": secs_scores, - } - - return rank_metrics, filewise_rows - - -# ----------------------------- -# Merge helpers -# ----------------------------- - - -def compute_and_save_rank_metrics_file(args, rank_metrics: Dict[str, Any], rank: int): - rank_path = os.path.join(args.out_dir, f"metrics_rank{rank:04d}.json") - write_json(rank_path, rank_metrics) - return rank_metrics - - -def merge_metrics_on_rank0(args, rank, world_size): - if rank != 0: - return None - - rank_metric_files = [os.path.join(args.out_dir, f"metrics_rank{r:04d}.json") for r in range(world_size)] - - rank_metrics = [] - for path in rank_metric_files: - if not os.path.exists(path): - logging.warning(f"Missing rank metric file: {path}") - continue - with open(path, "r", encoding="utf-8") as f: - rank_metrics.append(json.load(f)) - - total_n = sum(int(m.get("num_metric_items", m.get("num_processed", 0))) for m in rank_metrics) - - def weighted_average(section: str): - keys = set() - for m in rank_metrics: - keys.update(m.get(section, {}).keys()) - - out = {} - for k in sorted(keys): - numerator = 0.0 - denominator = 0 - - for m in rank_metrics: - n = int(m.get("num_metric_items", m.get("num_processed", 0))) - if n <= 0: - continue - - value = m.get(section, {}).get(k, None) - if value is None or isinstance(value, str): - continue - - try: - value = float(value) - except Exception: - continue - - numerator += value * n - denominator += n - - if denominator > 0: - out[k] = numerator / denominator - - return out - - final_metrics = { - "world_size": int(world_size), - "num_metric_items": int(total_n), - "aggregation": "sum(rank_metric * rank_num_metric_items) / total_num_metric_items", - "intelligibility": weighted_average("intelligibility"), - "secs": weighted_average("secs"), - "ranks": rank_metrics, - } - - final_json_path = os.path.join(args.out_dir, "metrics_final.json") - final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") - - write_json(final_json_path, final_metrics) - - final_text = format_final_metric_text(final_metrics) - write_text_atomic(final_txt_path, final_text) - - print("\n--- Final Evaluation Metrics ---", flush=True) - print(final_text, flush=True) - - logging.info(f"Final metrics JSON saved to: {final_json_path}") - logging.info(f"Final metrics TXT saved to: {final_txt_path}") - logging.info(json.dumps(final_metrics, indent=2, sort_keys=True)) - - return final_metrics - - -def merge_filewise_metrics_on_rank0(args, rank: int, world_size: int): - """Merge per-turn rank metric rows into one row per original sample. - - Rank files still contain one row per turn because metrics are computed - turn-by-turn. The final filewise outputs group those turn rows by - (run_id, dataset_index), producing one JSONL/CSV row per original sample - with list fields: - reference_text, asr_hyp, cer_turns, wer_turns, secs_turns. - - DistributedSampler padding repeats are deduplicated by - (run_id, dataset_index, turn_id), but repetitions from --num_eval_runs are - preserved because run_id is part of the key. - """ - if rank != 0 or not args.save_filewise_metrics: - return [] - - turn_rows = [] - - for r in range(world_size): - path = os.path.join(args.out_dir, f"filewise_metrics_rank{r:04d}.jsonl") - if not os.path.exists(path): - logging.warning(f"Missing filewise metrics file: {path}") - continue - - with open(path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line: - turn_rows.append(json.loads(line)) - - # Deduplicate DistributedSampler padding repeats, but preserve --num_eval_runs. - deduped_turns = {} - for row in turn_rows: - run_id = int(row.get("run_id", 0)) - idx = int(row.get("dataset_index", -1)) - turn_id = int(row.get("turn_id", 0)) - key = (run_id, idx, turn_id) - if key not in deduped_turns: - deduped_turns[key] = row - - turn_rows = list(deduped_turns.values()) - - # Group turn rows into one row per original file/sample. - grouped = {} - for row in turn_rows: - run_id = int(row.get("run_id", 0)) - idx = int(row.get("dataset_index", -1)) - key = (run_id, idx) - - if key not in grouped: - grouped[key] = { - "run_id": run_id, - "dataset_index": idx, - "rank": int(row.get("rank", -1)), - "target_audio_path": row.get("target_audio_path", ""), - "context_audio_path": row.get("context_audio_path", ""), - "turn_rows": [], - } - - grouped[key]["turn_rows"].append(row) - - def avg(vals): - vals = [float(x) for x in vals if x is not None and math.isfinite(float(x))] - return None if not vals else sum(vals) / len(vals) - - sample_rows = [] - for _, group in grouped.items(): - turns = sorted(group["turn_rows"], key=lambda x: int(x.get("turn_id", 0))) - - cer_turns = [r.get("cer") for r in turns] - wer_turns = [r.get("wer") for r in turns] - secs_turns = [r.get("secs") for r in turns] - - sample_row = { - "run_id": group["run_id"], - "dataset_index": group["dataset_index"], - "rank": group["rank"], - "num_turns": len(turns), - "turn_ids": [int(r.get("turn_id", 0)) for r in turns], - "target_audio_path": group["target_audio_path"], - "context_audio_path": group["context_audio_path"], - "pred_audio_paths": [r.get("pred_audio_path", "") for r in turns], - "pred_audio_seconds_turns": [r.get("pred_audio_seconds") for r in turns], - "reference_text": [r.get("reference_text", "") for r in turns], - "asr_hyp": [r.get("asr_hyp", "") for r in turns], - "cer_turns": cer_turns, - "wer_turns": wer_turns, - "secs_turns": secs_turns, - "cer": avg(cer_turns), - "wer": avg(wer_turns), - "secs": avg(secs_turns), - } - - sample_rows.append(sample_row) - - # Sort samples by average CER descending for failure analysis. - sample_rows.sort( - key=lambda x: ( - x.get("cer") is not None, - float(x.get("cer")) if x.get("cer") is not None else -1.0, - ), - reverse=True, - ) - - jsonl_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.jsonl") - csv_path = os.path.join(args.out_dir, "filewise_metrics_sorted_by_cer.csv") - - write_jsonl(jsonl_path, sample_rows) - write_filewise_csv(csv_path, sample_rows) - - logging.info(f"Saved sample-level filewise metrics JSONL to: {jsonl_path}") - logging.info(f"Saved sample-level filewise metrics CSV to: {csv_path}") - - topk = min(int(args.filewise_metrics_topk_log), len(sample_rows)) - if topk > 0: - logging.info(f"Top {topk} worst CER samples:") - for row in sample_rows[:topk]: - logging.info( - "run_id=%s dataset_index=%s num_turns=%s cer=%s wer=%s secs=%s path=%s" - % ( - row.get("run_id"), - row.get("dataset_index"), - row.get("num_turns"), - row.get("cer"), - row.get("wer"), - row.get("secs"), - row.get("target_audio_path"), - ) - ) - - return sample_rows - -def compute_aggregates_from_filewise_rows(rows: List[Dict[str, Any]]): - """Aggregate over sample-level rows. - - Each row may internally contain multiple turn metrics in cer_turns/wer_turns, - but the final filewise average is over original samples/files. - """ - if len(rows) == 0: - return { - "cer": None, - "wer": None, - "secs": None, - "num_samples": 0, - } - - def avg_key(key): - vals = [float(r[key]) for r in rows if r.get(key) is not None] - if len(vals) == 0: - return None - return sum(vals) / len(vals) - - return { - "cer": avg_key("cer"), - "wer": avg_key("wer"), - "secs": avg_key("secs"), - "num_samples": len(rows), - } - -def save_filewise_final_summary(args, filewise_rows: List[Dict[str, Any]]): - filewise_summary = compute_aggregates_from_filewise_rows(filewise_rows) - - obj = { - "aggregation": "mean_over_sample_metrics_each_sample_contains_turn_metric_lists", - **filewise_summary, - } - - path = os.path.join(args.out_dir, "metrics_final_filewise_average.json") - write_json(path, obj) - - sample_metrics_final_path = os.path.join(args.out_dir, "metrics_final_sample_average.json") - write_json(sample_metrics_final_path, obj) - - final_txt_path = os.path.join(args.out_dir, "metrics_final.txt") - final_text = format_filewise_final_metric_text(filewise_summary) - write_text_atomic(final_txt_path, final_text) - - print("\n--- Final Sample-Averaged Evaluation Metrics ---", flush=True) - print(final_text, flush=True) - - logging.info(f"Filewise averaged final metrics saved to: {path}") - logging.info(f"Sample averaged metrics_final JSON saved to: {sample_metrics_final_path}") - logging.info(f"Final metrics TXT saved to: {final_txt_path}") - - return obj - - -# ----------------------------- -# Args / main -# ----------------------------- - - -def parse_args(): - parser = argparse.ArgumentParser(description="EasyMagpieTTS Multi-GPU Inference Evaluation") - - parser.add_argument("--checkpoint_path", type=str, required=True) - parser.add_argument("--codec_model_path", type=str, required=True) - parser.add_argument("--datasets_json_path", type=str, required=True) - parser.add_argument("--out_dir", type=str, required=True) - - parser.add_argument("--phoneme_tokenizer_path", type=str, default=None) - parser.add_argument("--audio_dir", type=str, default=None) - parser.add_argument("--inference_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) - parser.add_argument("--debug_dtype", action="store_true") - parser.add_argument("--debug_gpu_assignment", action="store_true") - parser.add_argument("--use_librosa", action="store_true") - - parser.add_argument("--batch_size", type=int, default=6) - parser.add_argument("--num_workers", type=int, default=4) - parser.add_argument("--num_turns", type=int, default=1) - parser.add_argument("--pad_factor_text_speech", type=int, default=10) - - parser.add_argument("--emulate_duplex_inference", action="store_true") - parser.add_argument("--add_interruption_token", action="store_true") - parser.add_argument("--force_interruption", action="store_true") - parser.add_argument("--profile_multiturn_inference", action="store_true") - parser.add_argument("--profile_pad_min_sec", type=float, default=2.0) - parser.add_argument("--profile_pad_max_sec", type=float, default=2.0) - parser.add_argument("--max_eval_turns", type=int, default=6) - - parser.add_argument("--user_custom_speaker_reference", action="store_true") - parser.add_argument("--inference_speaker_reference", type=str, default=None) - parser.add_argument("--language", type=str, default="en") - - parser.add_argument("--use_cfg", action="store_true") - parser.add_argument("--cfg_scale", type=float, default=2.5) - parser.add_argument("--temperature", type=float, default=0.7) - parser.add_argument("--topk", type=int, default=80) - parser.add_argument("--max_tts_steps", type=int, default=2000) - parser.add_argument("--force_speech_sil_codes", action="store_true") - parser.add_argument("--normalize_volume", type=lambda x: str(x).lower() in ["true", "1", "yes"], default=True) - - parser.add_argument( - "--save_filewise_metrics", - action="store_true", - help="Save per-turn/file CER/WER metrics sorted by CER descending.", - ) - parser.add_argument( - "--compute_filewise_secs", - action="store_true", - help="Also compute per-turn/file SECS. Slower because it runs SECS per row.", - ) - parser.add_argument( - "--filewise_metrics_topk_log", - type=int, - default=20, - help="Number of worst CER samples to print on rank 0.", - ) - parser.add_argument( - "--num_eval_runs", - type=int, - default=1, - help="Repeat the full eval set N times. Repetitions are preserved in final filewise average.", - ) - parser.add_argument( - "--sort_by_text_token_count", - action="store_true", - help="Sort eval samples by total text token count before distributed sharding for better load balancing.", - ) - parser.add_argument( - "--metric_batch_size", - type=int, - default=8, - help="Batch size used for post-generation ASR/SECS metric computation.", - ) - parser.add_argument( - "--max_metric_audio_sec", - type=float, - default=120.0, - help="Clamp generated audio length used for ASR/SECS metrics to avoid metric OOM/hangs.", - ) - parser.add_argument( - "--asr_model_name", - type=str, - default="stt_en_fastconformer_transducer_large", - help="Pretrained NeMo ASR model used for CER/WER.", - ) - parser.add_argument( - "--secs_model_name", - type=str, - default="titanet_large", - help="Pretrained speaker encoder model used for SECS.", - ) - - return parser.parse_args() - - -def main(): - args = parse_args() - os.makedirs(args.out_dir, exist_ok=True) - os.makedirs(get_audio_out_dir(args), exist_ok=True) - os.makedirs(get_generated_turn_audio_dir(args), exist_ok=True) - os.makedirs(get_context_metric_audio_dir(args), exist_ok=True) - - distributed, rank, local_rank, world_size, device_index = setup_distributed() - - if args.profile_multiturn_inference and args.batch_size != 1: - raise RuntimeError( - "--profile_multiturn_inference requires --batch_size=1 per process. " - "Use multiple GPUs/processes for parallelism instead of increasing batch_size." - ) - - if args.profile_pad_max_sec < args.profile_pad_min_sec: - raise RuntimeError("--profile_pad_max_sec must be >= --profile_pad_min_sec.") - - if args.num_eval_runs <= 0: - raise RuntimeError("--num_eval_runs must be >= 1.") - - target_device = torch.device(f"cuda:{device_index}" if torch.cuda.is_available() and device_index >= 0 else "cpu") - target_dtype = getattr(torch, args.inference_dtype) - torch.set_default_dtype(target_dtype) - - hostname = socket.gethostname() - cuda_name = torch.cuda.get_device_name(target_device) if torch.cuda.is_available() and device_index >= 0 else "cpu" - - all_rank_print( - rank, - f"host={hostname} local_rank={local_rank} world_size={world_size} " - f"device={target_device} device_name={cuda_name}", - ) - - model = build_model_and_codec(args, target_device, target_dtype) - codec_sil_codes = model.codec_sil_codes - - if args.debug_dtype: - handles, stats, examples = attach_dtype_counter(model) - else: - handles = stats = examples = None - - full_eval_dataset = EvalJSONLDataset(args.datasets_json_path, num_turns=args.num_turns) - # debug - # full_eval_dataset.samples = full_eval_dataset.samples[:7] - - if args.sort_by_text_token_count: - full_eval_dataset = SortedByTextTokenCountDataset( - full_eval_dataset, - model=model, - max_eval_turns=args.max_eval_turns, - descending=True, - ) - - collate_fn = partial( - collate_and_tokenize_custom, - model=model, - extra_duration_thrshould=1.5, - sample_rate=model.sample_rate, - root_path=args.audio_dir, - emulate_duplex_inference=args.emulate_duplex_inference, - add_interruption_token=args.add_interruption_token, - pad_factor_text_speech=args.pad_factor_text_speech, - force_interruption=args.force_interruption, - normalize_audio_volume=args.normalize_volume, - use_librosa=args.use_librosa, - profile_multiturn_inference=args.profile_multiturn_inference, - max_eval_turns=args.max_eval_turns, - ) - - speaker_wav = load_speaker_wav_if_needed(args, model, target_dtype) - - generation_start = time.time() - all_metric_items = [] - total_batches = 0 - total_generated_samples = 0 - - for run_id in range(args.num_eval_runs): - if distributed: - sampler = DistributedSampler( - full_eval_dataset, - num_replicas=world_size, - rank=rank, - shuffle=False, - drop_last=False, - ) - sampler.set_epoch(run_id) - else: - sampler = SequentialSampler(full_eval_dataset) - - if args.debug_gpu_assignment: - try: - assigned_indices = list(iter(sampler)) - assigned_dataset_indices = [ - int(full_eval_dataset[i].get("__dataset_index__", -1)) for i in assigned_indices - ] - all_rank_print( - rank, - f"run_id={run_id} assigned {len(assigned_dataset_indices)} / {len(full_eval_dataset)} " - f"samples to gpu={local_rank}: dataset_indices={assigned_dataset_indices}", - ) - except Exception as e: - all_rank_print(rank, f"Could not print assigned indices: {repr(e)}") - - dataloader = DataLoader( - dataset=full_eval_dataset, - batch_size=args.batch_size, - sampler=sampler, - collate_fn=collate_fn, - num_workers=args.num_workers, - pin_memory=True, - drop_last=False, - ) - - for batch_id, inputs in enumerate(dataloader): - total_batches += 1 - batch_indices = inputs.get("dataset_indices", []) - total_generated_samples += len(batch_indices) - - if args.debug_gpu_assignment: - all_rank_print( - rank, - f"run_id={run_id} gpu={local_rank} batch_id={batch_id} " - f"dataset_indices={batch_indices} text_token_counts={inputs.get('text_token_counts', [])} " - f"target_paths={inputs.get('target_audio_paths', [])}", - ) - - inputs = prepare_inputs_for_device(inputs, model, args, target_dtype, speaker_wav=speaker_wav) - - finalize_output, profile_turn_frame_ranges, profile_decode_start_frame = run_generation( - model=model, - inputs=inputs, - args=args, - codec_sil_codes=codec_sil_codes, - ) - - metric_items = save_generation_outputs_and_build_metric_items( - model=model, - inputs=inputs, - finalize_output=finalize_output, - profile_turn_frame_ranges=profile_turn_frame_ranges, - profile_decode_start_frame=profile_decode_start_frame, - args=args, - rank=rank, - run_id=run_id, - ) - all_metric_items.extend(metric_items) - - if args.debug_dtype and batch_id == 0 and run_id == 0: - report_dtype_stats(handles, stats, examples, rank=rank) - - generation_elapsed = time.time() - generation_start - - # Save pre-metric manifest for debugging and restartability. - metric_manifest_path = os.path.join(args.out_dir, f"metric_items_rank{rank:04d}.jsonl") - write_jsonl(metric_manifest_path, all_metric_items) - - # Free TTS/codec model memory before loading ASR and speaker encoder metrics. - del model - if speaker_wav is not None: - del speaker_wav - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - all_rank_print( - rank, - f"generation done: batches={total_batches} generated_samples_with_sampler_padding={total_generated_samples} " - f"metric_items={len(all_metric_items)} elapsed_sec={generation_elapsed:.2f}. " - "Loading ASR/SECS metrics now.", - ) - - rank_metrics, rank_filewise_rows = compute_metrics_after_generation( - args=args, - rank=rank, - world_size=world_size, - metric_items=all_metric_items, - ) - rank_metrics["generation_elapsed_sec"] = float(generation_elapsed) - rank_metrics["num_generated_samples_with_sampler_padding"] = int(total_generated_samples) - - rank_metrics = compute_and_save_rank_metrics_file(args, rank_metrics, rank) - all_rank_print(rank, f"saved rank metrics: {json.dumps(rank_metrics, sort_keys=True)}") - - if args.save_filewise_metrics: - rank_filewise_rows.sort( - key=lambda x: ( - x.get("cer") is not None, - float(x.get("cer")) if x.get("cer") is not None else -1.0, - ), - reverse=True, - ) - - rank_filewise_path = os.path.join(args.out_dir, f"filewise_metrics_rank{rank:04d}.jsonl") - write_jsonl(rank_filewise_path, rank_filewise_rows) - all_rank_print(rank, f"saved filewise metrics: {rank_filewise_path}") - - if rank == 0: - wait_for_rank_metric_files(args, world_size) - - merge_metrics_on_rank0(args, rank, world_size) - - if args.save_filewise_metrics: - if rank == 0: - wait_for_rank_filewise_metric_files(args, world_size) - - filewise_rows = merge_filewise_metrics_on_rank0(args, rank, world_size) - - if rank == 0: - save_filewise_final_summary(args, filewise_rows) - - cleanup_distributed() - - -if __name__ == "__main__": - main() diff --git a/examples/tts/easy_magpietts_inference_multiturn_runner.py b/examples/tts/easy_magpietts_inference_multiturn_runner.py deleted file mode 100644 index 5d089bf4b95e..000000000000 --- a/examples/tts/easy_magpietts_inference_multiturn_runner.py +++ /dev/null @@ -1,750 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -TTS inference and evaluation entry point for MagpieTTS/EasyMagpieTTS. - -This version adds EasyMagpie multiturn user-audio inference as a first-class -runner mode while keeping the existing EasyMagpie evaluation pipeline. The new -runner writes turn-level EasyMagpie-compatible generated files and a generated -turn-level manifest, so ``evaluate_generated_audio_dir`` can compute CER/WER, -SSIM, UTMOSv2, EOU, FCD, CSVs and plots without custom metric code. -""" -from __future__ import annotations - -import argparse -import copy -import json -import os -import random -import shutil -import time -from dataclasses import fields -from pathlib import Path -from typing import List, Optional, Tuple - -import numpy as np -import torch - -from nemo.collections.asr.parts.utils.manifest_utils import read_manifest -from nemo.collections.tts.models.easy_magpietts_inference import EasyModelInferenceParameters -from nemo.collections.tts.models.magpietts import ModelInferenceParameters -from nemo.collections.tts.modules.magpietts_inference.evaluate_generated_audio import load_evalset_config -from nemo.collections.tts.modules.magpietts_inference.evaluation import ( - DEFAULT_VIOLIN_METRICS, - EvaluationConfig, - compute_mean_with_confidence_interval, - evaluate_generated_audio_dir, -) -from nemo.collections.tts.modules.magpietts_inference.inference import ( - BaseInferenceConfig, - BaseInferenceRunner, - EasyMagpieInferenceConfig, - EasyMagpieInferenceRunner, - EasyMagpieMultiturnUserAudioInferenceConfig, - EasyMagpieMultiturnUserAudioInferenceRunner, - MagpieInferenceConfig, - MagpieInferenceRunner, -) -from nemo.collections.tts.modules.magpietts_inference.utils import ( - ModelLoadConfig, - get_experiment_name_from_checkpoint_path, - load_easy_magpie_model, - load_magpie_model, - log_model_architecture_summary, -) -from nemo.collections.tts.modules.magpietts_inference.visualization import create_combined_box_plot, create_violin_plot -from nemo.collections.tts.modules.magpietts_modules import EOSDetectionMethod -from nemo.utils import logging - - -def parse_layer_list(layer_str: Optional[str]) -> Optional[List[int]]: - if layer_str is None: - return None - return [int(l.strip()) for l in layer_str.split(",")] - - -def write_csv_header_if_needed(csv_path: str, header: str) -> None: - if not os.path.exists(csv_path): - os.makedirs(os.path.dirname(csv_path), exist_ok=True) - with open(csv_path, "w", encoding="utf-8") as f: - f.write(header + "\n") - - -def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, metrics: dict) -> None: - values = [ - checkpoint_name, - dataset, - metrics.get('cer_filewise_avg', ''), - metrics.get('wer_filewise_avg', ''), - metrics.get('cer_cumulative', ''), - metrics.get('wer_cumulative', ''), - metrics.get('ssim_pred_gt_avg', ''), - metrics.get('ssim_pred_context_avg', ''), - metrics.get('ssim_gt_context_avg', ''), - metrics.get('ssim_pred_gt_avg_alternate', ''), - metrics.get('ssim_pred_context_avg_alternate', ''), - metrics.get('ssim_gt_context_avg_alternate', ''), - metrics.get('cer_gt_audio_cumulative', ''), - metrics.get('wer_gt_audio_cumulative', ''), - metrics.get('utmosv2_avg', ''), - metrics.get('total_gen_audio_seconds', ''), - metrics.get('frechet_codec_distance', ''), - metrics.get('eou_cutoff_rate', ''), - metrics.get('eou_silence_rate', ''), - metrics.get('eou_noise_rate', ''), - metrics.get('eou_error_rate', ''), - ] - with open(csv_path, "a", encoding="utf-8") as f: - f.write(",".join(str(v).replace(",", " ") for v in values) + "\n") - logging.info(f"Metrics appended to: {csv_path}") - - -def create_formatted_metrics_mean_ci(metrics_mean_ci: dict) -> dict: - for k, v in metrics_mean_ci.items(): - if isinstance(v, list): - mean, ci = float(v[0]), float(v[1]) - logging.info(f"Metric {k}: {mean:.4f} ± {ci:.4f}") - metrics_mean_ci[k] = f"{mean:.4f} ± {ci:.4f}" - return metrics_mean_ci - - -def filter_datasets(dataset_meta_info: dict, datasets: Optional[str]) -> List[str]: - if datasets is None: - return list(dataset_meta_info.keys()) - selected = datasets.split(",") - for dataset in selected: - if dataset not in dataset_meta_info: - raise ValueError(f"Dataset {dataset} not found in dataset meta info") - return selected - - -def _runner_eval_manifest_and_audio_dir(runner: BaseInferenceRunner, default_manifest: str, default_audio_dir: str): - """Return evaluation manifest/audio dir produced by the runner, if any.""" - eval_manifest = getattr(runner, "evaluation_manifest_path", None) or default_manifest - eval_audio_dir = getattr(runner, "evaluation_audio_dir", None) or default_audio_dir - return eval_manifest, eval_audio_dir - - - -def _get_torchrun_rank_info() -> Tuple[int, int, int]: - """Return (rank, world_size, local_rank) from torchrun/SLURM env vars. - - We intentionally do not initialize torch.distributed here. The inference - script only needs env-based sharding, while NeMo evaluation models can run - without distributed collectives. - """ - rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", "0"))) - world_size = int(os.environ.get("WORLD_SIZE", "1")) - local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))) - return rank, world_size, local_rank - - -def _configure_cuda_for_rank() -> Tuple[int, int, int]: - rank, world_size, local_rank = _get_torchrun_rank_info() - if torch.cuda.is_available(): - device_count = torch.cuda.device_count() - if device_count > 0: - torch.cuda.set_device(local_rank % device_count) - logging.info( - f"Using CUDA device {local_rank % device_count}; " - f"rank={rank}, local_rank={local_rank}, world_size={world_size}" - ) - return rank, world_size, local_rank - - -def _wait_for_multiturn_rank_manifests(repeat_audio_dir: str, world_size: int, timeout_sec: int = 7200) -> None: - deadline = time.time() + timeout_sec - while time.time() < deadline: - missing = [] - for rank in range(world_size): - path = os.path.join( - repeat_audio_dir, - f"rank_{rank:04d}", - f"multiturn_user_audio_turn_manifest_rank{rank:04d}.jsonl", - ) - if not os.path.exists(path): - missing.append(path) - if not missing: - return - time.sleep(5) - raise RuntimeError(f"Timed out waiting for multiturn rank manifests: {missing}") - - -def _copy_or_link(src: str, dst: str) -> None: - if src is None or not os.path.exists(src): - return - os.makedirs(os.path.dirname(dst), exist_ok=True) - try: - if os.path.lexists(dst): - os.remove(dst) - os.symlink(os.path.abspath(src), dst) - except Exception: - shutil.copyfile(src, dst) - - -def _merge_multiturn_rank_outputs(repeat_audio_dir: str, world_size: int, save_predicted_codes: bool) -> str: - """Merge rank-local multiturn outputs into one EasyMagpie-compatible dir. - - Each rank writes local files named predicted_audio_0.wav, target_audio_0.wav, - context_audio_0.wav, predicted_codes_0.pt, ... inside rank_XXXX/. This - function remaps them to contiguous global indices in repeat_audio_dir/ and - writes a merged turn-level manifest. - """ - merged_records = [] - global_idx = 0 - - for rank in range(world_size): - rank_dir = os.path.join(repeat_audio_dir, f"rank_{rank:04d}") - rank_manifest = os.path.join(rank_dir, f"multiturn_user_audio_turn_manifest_rank{rank:04d}.jsonl") - if not os.path.exists(rank_manifest): - raise FileNotFoundError(f"Missing rank manifest: {rank_manifest}") - - with open(rank_manifest, "r", encoding="utf-8") as f: - rank_records = [json.loads(line) for line in f if line.strip()] - - for local_idx, record in enumerate(rank_records): - pred_src = os.path.join(rank_dir, f"predicted_audio_{local_idx}.wav") - pred_dst = os.path.join(repeat_audio_dir, f"predicted_audio_{global_idx}.wav") - _copy_or_link(pred_src, pred_dst) - - if save_predicted_codes: - code_src = os.path.join(rank_dir, f"predicted_codes_{local_idx}.pt") - code_dst = os.path.join(repeat_audio_dir, f"predicted_codes_{global_idx}.pt") - _copy_or_link(code_src, code_dst) - - target_src = os.path.join(rank_dir, record.get("audio_filepath", f"target_audio_{local_idx}.wav")) - target_dst = os.path.join(repeat_audio_dir, f"target_audio_{global_idx}.wav") - _copy_or_link(target_src, target_dst) - - context_src = os.path.join( - rank_dir, - record.get("context_audio_filepath", f"context_audio_{local_idx}.wav"), - ) - context_dst = os.path.join(repeat_audio_dir, f"context_audio_{global_idx}.wav") - _copy_or_link(context_src, context_dst) - - merged = dict(record) - merged["audio_filepath"] = f"target_audio_{global_idx}.wav" - merged["context_audio_filepath"] = f"context_audio_{global_idx}.wav" - merged["rank"] = rank - merged["rank_local_idx"] = local_idx - merged_records.append(merged) - global_idx += 1 - - merged_manifest = os.path.join(repeat_audio_dir, "multiturn_user_audio_turn_manifest.jsonl") - with open(merged_manifest, "w", encoding="utf-8") as f: - for record in merged_records: - f.write(json.dumps(record, ensure_ascii=False) + "\n") - - logging.info(f"Merged {len(merged_records)} multiturn turn records into {merged_manifest}") - return merged_manifest - - -def run_inference_and_evaluation( - runner: BaseInferenceRunner, - checkpoint_name: str, - inference_config: BaseInferenceConfig, - eval_config: EvaluationConfig, - dataset_meta_info: dict, - datasets: List[str], - out_dir: str, - flops_per_component: dict, - moe_info: str, - num_repeats: int = 1, - confidence_level: float = 0.95, - violin_plot_metrics: Optional[List[str]] = None, - clean_up_disk: bool = False, - skip_evaluation: bool = False, -) -> Tuple[Optional[float], Optional[float]]: - if violin_plot_metrics is None: - violin_plot_metrics = list(DEFAULT_VIOLIN_METRICS) - if not eval_config.with_utmosv2 and 'utmosv2' in violin_plot_metrics: - violin_plot_metrics.remove('utmosv2') - - rank, world_size, _ = _get_torchrun_rank_info() - is_distributed = world_size > 1 - is_multiturn_user_audio = getattr(runner, "produces_turn_level_evaluation", False) - - if hasattr(runner, "set_distributed_context"): - runner.set_distributed_context(rank=rank, world_size=world_size) - - full_checkpoint_name = ( - f"{checkpoint_name}_{moe_info}{inference_config.build_identifier()}_SV_{eval_config.sv_model}" - ) - - ssim_per_dataset = [] - cer_per_dataset = [] - all_datasets_filewise_metrics = {} - - csv_header = ( - "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative," - "wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg," - "ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate," - "ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative," - "utmosv2_avg,total_gen_audio_seconds,frechet_codec_distance," - "eou_cutoff_rate,eou_silence_rate,eou_noise_rate,eou_error_rate" - ) - - for dataset in datasets: - logging.info(f"Processing dataset: {dataset}") - meta = dataset_meta_info[dataset] - manifest_records = read_manifest(meta['manifest_path']) - language = meta.get('whisper_language', 'en') - - dataset_meta_for_dl = copy.deepcopy(meta) - for key in ["whisper_language", "load_cached_codes_if_available"]: - dataset_meta_for_dl.pop(key, None) - - eval_dir = os.path.join(out_dir, f"{full_checkpoint_name}_{dataset}") - audio_dir = os.path.join(eval_dir, "audio") - os.makedirs(eval_dir, exist_ok=True) - - per_run_csv = os.path.join(eval_dir, "all_experiment_metrics.csv") - if rank == 0: - write_csv_header_if_needed(per_run_csv, csv_header) - - metrics_all_repeats = [] - filewise_metrics_all_repeats = [] - - for repeat_idx in range(num_repeats): - logging.info(f"Repeat {repeat_idx + 1}/{num_repeats} for dataset {dataset}, rank {rank}/{world_size}") - repeat_audio_dir = os.path.join(audio_dir, f"repeat_{repeat_idx}") - os.makedirs(repeat_audio_dir, exist_ok=True) - - test_dataset = runner.create_dataset({dataset: dataset_meta_for_dl}) - - if not is_multiturn_user_audio: - if is_distributed: - raise RuntimeError( - "torchrun multi-GPU sharding is currently implemented for " - "--easy_magpie_inference_mode multiturn_user_audio only. " - "Use the existing single-process path for single_turn/magpie, or add a " - "rank-safe merge path for those runners." - ) - if len(test_dataset) != len(manifest_records): - raise ValueError( - f"Dataset length mismatch: {len(test_dataset)} vs {len(manifest_records)} manifest records" - ) - - if is_distributed and is_multiturn_user_audio: - rank_audio_dir = os.path.join(repeat_audio_dir, f"rank_{rank:04d}") - inference_output_dir = rank_audio_dir - else: - inference_output_dir = repeat_audio_dir - - rtf_metrics_list, _, codec_file_paths = runner.run_inference_on_dataset( - dataset=test_dataset, - output_dir=inference_output_dir, - manifest_records=manifest_records, - audio_base_dir=meta['audio_dir'], - save_cross_attention_maps=True, - save_context_audio=(repeat_idx == 0), - save_predicted_codes=eval_config.with_fcd, - ) - - mean_rtf = runner.compute_mean_rtf_metrics(rtf_metrics_list) - for component_name, component_flops in flops_per_component.items(): - for key, value in component_flops.items(): - mean_rtf[f"{component_name}_{key}"] = value - logging.info(f"{component_name} FLOPs per token: {component_flops['total_flops_per_token']:,}") - - rtf_path = os.path.join(eval_dir, f"{dataset}_rtf_metrics_{repeat_idx}_rank{rank:04d}.json") - with open(rtf_path, "w", encoding="utf-8") as f: - json.dump(mean_rtf, f, indent=4) - - if skip_evaluation: - logging.info("Skipping evaluation as requested.") - continue - - if is_distributed and is_multiturn_user_audio: - if rank != 0: - # Non-zero ranks only generate. Rank 0 waits and evaluates merged outputs. - continue - - _wait_for_multiturn_rank_manifests(repeat_audio_dir, world_size) - merged_manifest_path = _merge_multiturn_rank_outputs( - repeat_audio_dir=repeat_audio_dir, - world_size=world_size, - save_predicted_codes=eval_config.with_fcd, - ) - eval_manifest_path = merged_manifest_path - eval_audio_dir = repeat_audio_dir - else: - eval_manifest_path, eval_audio_dir = _runner_eval_manifest_and_audio_dir( - runner, - default_manifest=meta['manifest_path'], - default_audio_dir=meta['audio_dir'], - ) - - eval_config_for_dataset = EvaluationConfig( - sv_model=eval_config.sv_model, - asr_model_name=eval_config.asr_model_name, - eou_model_name=eval_config.eou_model_name, - language=language, - with_utmosv2=eval_config.with_utmosv2, - with_fcd=eval_config.with_fcd, - codec_model_path=eval_config.codec_model_path, - device=eval_config.device, - ) - - metrics, filewise_metrics = evaluate_generated_audio_dir( - manifest_path=eval_manifest_path, - audio_dir=eval_audio_dir, - generated_audio_dir=repeat_audio_dir, - config=eval_config_for_dataset, - ) - - metrics_all_repeats.append(metrics) - filewise_metrics_all_repeats.extend(filewise_metrics) - - with open(os.path.join(eval_dir, f"{dataset}_metrics_{repeat_idx}.json"), "w", encoding="utf-8") as f: - json.dump(metrics, f, indent=4) - - sorted_filewise = sorted(filewise_metrics, key=lambda x: x.get('cer', 0), reverse=True) - with open(os.path.join(eval_dir, f"{dataset}_filewise_metrics_{repeat_idx}.json"), "w", encoding="utf-8") as f: - json.dump(sorted_filewise, f, indent=4) - - append_metrics_to_csv(per_run_csv, full_checkpoint_name, dataset, metrics) - create_violin_plot( - filewise_metrics, - violin_plot_metrics, - Path(eval_dir) / f"{dataset}_violin_{repeat_idx}.png", - ) - - # EasyMagpie deletes codec files after evaluation. For distributed - # multiturn, the merged predicted_codes_*.pt live in repeat_audio_dir. - cleanup_code_paths = codec_file_paths - if is_distributed and is_multiturn_user_audio: - cleanup_code_paths = list(Path(repeat_audio_dir).glob("predicted_codes_*.pt")) - for codec_file_path in cleanup_code_paths: - if os.path.exists(codec_file_path): - os.remove(codec_file_path) - - if rank != 0: - continue - - if skip_evaluation or not metrics_all_repeats: - continue - - all_datasets_filewise_metrics[dataset] = filewise_metrics_all_repeats - metrics_mean_ci = compute_mean_with_confidence_interval(metrics_all_repeats, confidence=confidence_level) - formatted_metrics_mean_ci = create_formatted_metrics_mean_ci(metrics_mean_ci) - - ci_csv = os.path.join(out_dir, "all_experiment_metrics_with_ci.csv") - write_csv_header_if_needed(ci_csv, csv_header) - append_metrics_to_csv(ci_csv, full_checkpoint_name, dataset, formatted_metrics_mean_ci) - - ssim_values = [m['ssim_pred_context_avg'] for m in metrics_all_repeats] - cer_values = [m['cer_cumulative'] for m in metrics_all_repeats] - ssim_per_dataset.append(np.mean(ssim_values)) - cer_per_dataset.append(np.mean(cer_values)) - - if rank == 0 and len(all_datasets_filewise_metrics) > 1: - combined_plot_path = os.path.join(out_dir, f"{full_checkpoint_name}_combined_violin_plot.png") - create_combined_box_plot(all_datasets_filewise_metrics, violin_plot_metrics, combined_plot_path) - - if rank == 0 and clean_up_disk: - logging.info(f"Cleaning up output directory: {out_dir}") - shutil.rmtree(out_dir) - - if rank == 0 and ssim_per_dataset and cer_per_dataset: - return np.mean(cer_per_dataset), np.mean(ssim_per_dataset) - return None, None - - -def _get_shared_inference_param_names() -> set: - magpie_fields = {f.name for f in fields(ModelInferenceParameters)} - easy_fields = {f.name for f in fields(EasyModelInferenceParameters)} - return magpie_fields & easy_fields - - -def _add_inference_param_fields( - group: argparse._ArgumentGroup, - param_cls: type, - skip_fields: Optional[set] = None, - only_fields: Optional[set] = None, -) -> None: - if skip_fields is None: - skip_fields = set() - for f in fields(param_cls): - if f.name in skip_fields: - continue - if only_fields is not None and f.name not in only_fields: - continue - extra_args: dict = {"type": f.type} - if f.type == bool: - extra_args = {"action": "store_true"} - if f.name in ("estimate_alignment_from_layers", "apply_prior_to_layers"): - extra_args = {"help": "Must be a comma separate string. Not enclosed in brackets", "type": str} - elif f.name == "eos_detection_method": - extra_args["choices"] = [m.value for m in EOSDetectionMethod] - group.add_argument(f"--{f.name}", **extra_args) - - -def _add_common_args(parser: argparse.ArgumentParser) -> None: - parser.add_argument('--model_type', type=str, default='magpie', choices=['magpie', 'easy_magpie']) - parser.add_argument('--deterministic', action='store_true') - - model_group = parser.add_argument_group('Model Loading') - model_group.add_argument('--hparams_files', type=str, default=None) - model_group.add_argument('--checkpoint_files', type=str, default=None) - model_group.add_argument('--nemo_files', type=str, default=None) - model_group.add_argument('--codecmodel_path', type=str, required=True) - model_group.add_argument('--hparams_file_from_wandb', action='store_true') - model_group.add_argument('--legacy_codebooks', action='store_true') - model_group.add_argument('--legacy_text_conditioning', action='store_true') - - data_group = parser.add_argument_group('Dataset and Output') - data_group.add_argument('--datasets_json_path', type=str, required=True, default=None) - data_group.add_argument('--datasets_base_path', type=Path, default=None) - data_group.add_argument('--datasets', type=str, default=None) - data_group.add_argument('--out_dir', type=str, required=True) - data_group.add_argument('--log_exp_name', action='store_true') - data_group.add_argument('--clean_up_disk', action='store_true') - - infer_group = parser.add_argument_group('Common Inference Parameters') - infer_group.add_argument('--batch_size', type=int, default=32) - infer_group.add_argument('--use_cfg', action='store_true') - infer_group.add_argument('--use_local_transformer', action='store_true') - shared_param_names = _get_shared_inference_param_names() - _add_inference_param_fields(infer_group, ModelInferenceParameters, only_fields=shared_param_names) - - eval_group = parser.add_argument_group('Evaluation') - eval_group.add_argument('--run_evaluation', action='store_true') - eval_group.add_argument('--sv_model', type=str, default='titanet', choices=['titanet', 'wavlm']) - eval_group.add_argument('--asr_model_name', type=str, default='nvidia/parakeet-tdt-1.1b') - eval_group.add_argument('--eou_model_name', type=str, default='facebook/wav2vec2-base-960h') - eval_group.add_argument('--num_repeats', type=int, default=1) - eval_group.add_argument('--confidence_level', type=float, default=0.95) - eval_group.add_argument('--disable_utmosv2', action='store_true') - eval_group.add_argument('--violin_plot_metrics', type=str, nargs='*', default=['cer', 'pred_context_ssim', 'utmosv2']) - eval_group.add_argument('--disable_fcd', action='store_true') - - target_group = parser.add_argument_group('Quality Targets') - target_group.add_argument('--cer_target', type=float, default=None) - target_group.add_argument('--ssim_target', type=float, default=None) - - -def seed_all(seed: int): - torch.manual_seed(seed) - random.seed(seed) - np.random.seed(seed) - torch.backends.cudnn.benchmark = False - torch.use_deterministic_algorithms(True) - - -def _add_magpie_args(parser: argparse.ArgumentParser) -> None: - group = parser.add_argument_group('MagpieTTS-specific Parameters') - shared_param_names = _get_shared_inference_param_names() - _add_inference_param_fields(group, ModelInferenceParameters, skip_fields=shared_param_names) - group.add_argument('--maskgit_n_steps', type=int, default=3) - group.add_argument('--maskgit_noise_scale', type=float, default=0.0) - group.add_argument('--maskgit_fixed_schedule', type=int, nargs='+', default=None) - group.add_argument('--maskgit_sampling_type', default=None, choices=['default', 'causal', 'purity_causal', 'purity_default']) - - -def _add_easy_magpie_args(parser: argparse.ArgumentParser) -> None: - group = parser.add_argument_group('EasyMagpieTTS-specific Parameters') - group.add_argument('--easy_magpie_inference_mode', type=str, default='single_turn', choices=['single_turn', 'multiturn_user_audio']) - group.add_argument('--max_eval_turns', type=int, default=6) - group.add_argument('--no_save_debug_multiturn_audio', action='store_true') - group.add_argument('--phoneme_input_type', type=str, default='gt', choices=['gt', 'predicted']) - group.add_argument('--phoneme_sampling_method', type=str, default='argmax', choices=['argmax', 'multinomial', 'greedy']) - group.add_argument('--dropout_text_input', action='store_true') - group.add_argument('--phoneme_tokenizer_path', type=str, default=None) - group.add_argument('--disable_cas_for_context_text', action='store_true') - - -def create_argument_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description='TTS Inference and Evaluation (MagpieTTS & EasyMagpieTTS)') - _add_common_args(parser) - _add_magpie_args(parser) - _add_easy_magpie_args(parser) - return parser - - -def _build_inference_params_from_args(param_cls: type, args): - params = {} - for f in fields(param_cls): - arg_val = vars(args).get(f.name) - if arg_val is not None: - if f.name in ('estimate_alignment_from_layers', 'apply_prior_to_layers'): - params[f.name] = parse_layer_list(arg_val) - else: - params[f.name] = arg_val - return param_cls.from_dict(params) - - -def _build_magpie_config(args) -> MagpieInferenceConfig: - return MagpieInferenceConfig( - model_inference_parameters=_build_inference_params_from_args(ModelInferenceParameters, args), - batch_size=args.batch_size, - use_cfg=args.use_cfg, - apply_attention_prior=args.apply_attention_prior, - use_local_transformer=args.use_local_transformer, - maskgit_n_steps=args.maskgit_n_steps, - maskgit_noise_scale=args.maskgit_noise_scale, - maskgit_fixed_schedule=args.maskgit_fixed_schedule, - maskgit_sampling_type=args.maskgit_sampling_type, - ) - - -def _build_easy_magpie_config(args) -> EasyMagpieInferenceConfig: - cfg_cls = ( - EasyMagpieMultiturnUserAudioInferenceConfig - if args.easy_magpie_inference_mode == 'multiturn_user_audio' - else EasyMagpieInferenceConfig - ) - kwargs = dict( - model_inference_parameters=_build_inference_params_from_args(EasyModelInferenceParameters, args), - batch_size=args.batch_size, - use_cfg=args.use_cfg, - use_local_transformer=args.use_local_transformer, - phoneme_input_type=args.phoneme_input_type, - phoneme_sampling_method=args.phoneme_sampling_method, - dropout_text_input=args.dropout_text_input, - ) - if cfg_cls is EasyMagpieMultiturnUserAudioInferenceConfig: - kwargs.update( - max_eval_turns=args.max_eval_turns, - save_debug_multiturn_audio=not args.no_save_debug_multiturn_audio, - ) - return cfg_cls(**kwargs) - - -def _select_runner_cls(args): - if args.model_type == 'magpie': - if args.easy_magpie_inference_mode != 'single_turn': - raise ValueError('--easy_magpie_inference_mode is only supported with --model_type easy_magpie') - return MagpieInferenceRunner - if args.easy_magpie_inference_mode == 'multiturn_user_audio': - return EasyMagpieMultiturnUserAudioInferenceRunner - return EasyMagpieInferenceRunner - - -def main(argv=None): - parser = create_argument_parser() - args = parser.parse_args(argv) - rank, world_size, local_rank = _configure_cuda_for_rank() - - if args.model_type == 'easy_magpie' and args.easy_magpie_inference_mode == 'multiturn_user_audio' and args.batch_size > 1: - parser.error("--easy_magpie_inference_mode multiturn_user_audio requires --batch_size 1.") - - if args.deterministic: - seed_all(seed=9) - - dataset_meta_info = load_evalset_config(config_path=args.datasets_json_path, dataset_base_path=args.datasets_base_path) - datasets = filter_datasets(dataset_meta_info, args.datasets) - logging.info(f"Loaded {len(datasets)} datasets: {', '.join(datasets)}") - - has_checkpoint_mode = ( - args.hparams_files is not None - and args.checkpoint_files is not None - and args.hparams_files != 'null' - and args.checkpoint_files != 'null' - ) - has_nemo_mode = args.nemo_files is not None and args.nemo_files != 'null' - - if not has_checkpoint_mode and not has_nemo_mode: - parser.error('You must provide either --hparams_files/--checkpoint_files or --nemo_files') - - is_easy_magpie = args.model_type == 'easy_magpie' - load_fn = load_easy_magpie_model if is_easy_magpie else load_magpie_model - inference_config = _build_easy_magpie_config(args) if is_easy_magpie else _build_magpie_config(args) - runner_cls = _select_runner_cls(args) - - eval_config = EvaluationConfig( - sv_model=args.sv_model, - asr_model_name=args.asr_model_name, - eou_model_name=args.eou_model_name, - with_utmosv2=not args.disable_utmosv2, - with_fcd=not args.disable_fcd, - codec_model_path=args.codecmodel_path if not args.disable_fcd else None, - ) - - cer, ssim = None, None - - def run_one_model(model_config: ModelLoadConfig): - nonlocal cer, ssim - model, checkpoint_name = load_fn(model_config) - moe_info, flops_per_component = log_model_architecture_summary(model) - if args.log_exp_name and model_config.checkpoint_file: - exp_name = get_experiment_name_from_checkpoint_path(model_config.checkpoint_file) - checkpoint_name = f'{exp_name}__{checkpoint_name}' - runner = runner_cls(model, inference_config) - cer, ssim = run_inference_and_evaluation( - runner=runner, - checkpoint_name=checkpoint_name, - inference_config=inference_config, - eval_config=eval_config, - dataset_meta_info=dataset_meta_info, - datasets=datasets, - out_dir=args.out_dir, - flops_per_component=flops_per_component, - moe_info=moe_info, - num_repeats=args.num_repeats, - confidence_level=args.confidence_level, - violin_plot_metrics=args.violin_plot_metrics, - clean_up_disk=args.clean_up_disk, - skip_evaluation=not args.run_evaluation, - ) - - if has_checkpoint_mode: - hparam_files = args.hparams_files.split(',') - checkpoint_files = args.checkpoint_files.split(',') - if len(hparam_files) != len(checkpoint_files): - parser.error('Number of hparams_files must match number of checkpoint_files') - for hparams_file, checkpoint_file in zip(hparam_files, checkpoint_files): - logging.info(f'Processing checkpoint: {checkpoint_file}') - run_one_model( - ModelLoadConfig( - hparams_file=hparams_file, - checkpoint_file=checkpoint_file, - codecmodel_path=args.codecmodel_path, - legacy_codebooks=args.legacy_codebooks, - legacy_text_conditioning=args.legacy_text_conditioning, - hparams_from_wandb=args.hparams_file_from_wandb, - phoneme_tokenizer_path=getattr(args, 'phoneme_tokenizer_path', None), - disable_cas_for_context_text=args.disable_cas_for_context_text, - ) - ) - else: - for nemo_file in args.nemo_files.split(','): - logging.info(f'Processing NeMo file: {nemo_file}') - run_one_model( - ModelLoadConfig( - nemo_file=nemo_file, - codecmodel_path=args.codecmodel_path, - legacy_codebooks=args.legacy_codebooks, - legacy_text_conditioning=args.legacy_text_conditioning, - phoneme_tokenizer_path=getattr(args, 'phoneme_tokenizer_path', None), - disable_cas_for_context_text=args.disable_cas_for_context_text, - ) - ) - - if cer is not None and args.cer_target is not None: - if cer > args.cer_target: - raise ValueError(f'CER {cer:.4f} exceeds target {args.cer_target:.4f}') - logging.info(f'CER {cer:.4f} meets target {args.cer_target:.4f}') - - if ssim is not None and args.ssim_target is not None: - if ssim < args.ssim_target: - raise ValueError(f'SSIM {ssim:.4f} below target {args.ssim_target:.4f}') - logging.info(f'SSIM {ssim:.4f} meets target {args.ssim_target:.4f}') - - logging.info('Inference and evaluation completed successfully.') - - -if __name__ == '__main__': - main() From 0149169b2e95eae3a8643abb0a84f6bf730407bd Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 9 Jun 2026 06:15:23 -0700 Subject: [PATCH 079/109] Undo unecessary changes on cutset Signed-off-by: Edresson Casanova --- nemo/collections/common/data/lhotse/cutset.py | 90 ++----------------- 1 file changed, 7 insertions(+), 83 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 5dc3e944cfcb..085536458c17 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -14,8 +14,6 @@ """Lhotse CutSet utilities and Parquet manifest support for NeMo.""" import io -import os -import json import logging import random import re @@ -26,8 +24,6 @@ from pathlib import Path from typing import KeysView, List, Mapping, Sequence, Tuple, Union -from copy import deepcopy - import numpy as np import omegaconf import soundfile as sf @@ -59,9 +55,6 @@ ) from nemo.collections.common.parts.preprocessing.manifest import get_full_path -from lhotse import Recording, AudioSource, SupervisionSegment, MonoCut, CutSet - -from pydub.utils import mediainfo def temperature_reweighting(weights: List[Union[float, int]], temperature: float = 1.0) -> List[float]: """ @@ -946,8 +939,6 @@ def read_lhotse_magpietts_data_as_s2s_duplex(config) -> Tuple[CutSet, bool]: add_extra_end_sil = config.get("add_extra_end_silence", False) extra_end_silence_range = config.get("extra_end_silence_range", [0.5, 6.0]) - add_extra_begin_sil = config.get("add_extra_begin_sil", False) - extra_begin_silence_range = config.get("extra_begin_silence_range", [0.5, 6.0]) sample_rate = config.get("sample_rate", 22050) max_cer = config.get("max_cer", 0.03) @@ -962,53 +953,14 @@ def create_recording_from_array(samples: np.ndarray, sampling_rate: int, recordi buffer.seek(0) return Recording.from_bytes(buffer.read(), recording_id=recording_id) - def materialize_to_monocut(cut_like: Cut, cut_id: str, sample_rate: int) -> MonoCut: - audio = cut_like.load_audio() # renders mix -> (C, N) - rec = create_recording_from_array(audio, sample_rate, recording_id=f"{cut_id}_rec") - return MonoCut( - id=cut_id, - start=0.0, - duration=cut_like.duration, - channel=0, - recording=rec, - supervisions=[], - ).move_to_memory(audio_format="wav") - - def prepend_silence_monocut( - cut: MonoCut, - sil_duration: float, - sample_rate: int, - recording_id: str, - cut_id: str, - ) -> MonoCut: - audio = cut.load_audio() # (C, N) - n_pad = int(round(sil_duration * sample_rate)) - if n_pad <= 0: - return cut - - pad = np.zeros((audio.shape[0], n_pad), dtype=audio.dtype) - audio2 = np.concatenate([pad, audio], axis=1) - - rec = create_recording_from_array(audio2, sample_rate, recording_id=recording_id) - return MonoCut( - id=cut_id, - start=0.0, - duration=audio2.shape[1] / sample_rate, - channel=0, - recording=rec, - supervisions=[], - ).move_to_memory(audio_format="wav") - def convert_cut_fn(cut: Cut) -> Cut: """Convert a single cut into the continuation format.""" - - orig_agent_sup = fastcopy(cut.supervisions[0]) + orig_agent_sup = deepcopy(cut.supervisions[0]) target_audio_orig_dur = cut.target_audio.duration # Resample audios cut.target_audio = cut.target_audio.resample(sample_rate) - if cut.has_custom("context_audio"): - cut.context_audio = cut.context_audio.resample(sample_rate) + cut.context_audio = cut.context_audio.resample(sample_rate) total_duration = cut.target_audio.duration # Prepare MonoCuts @@ -1041,7 +993,7 @@ def convert_cut_fn(cut: Cut) -> Cut: user_sup = fastcopy(orig_agent_sup, start=0.0, duration=0.08, speaker="user", text="dummy text") agent_sup = fastcopy(orig_agent_sup, start=0.0, duration=target_audio_orig_dur - 0.08, speaker="agent") - # Optionally add extra silence on the end + # Optionally add extra silence if add_extra_end_sil: sil_duration = random.uniform(*extra_end_silence_range) cut_target = cut_target.pad(duration=total_duration + sil_duration, direction="right") @@ -1051,41 +1003,12 @@ def convert_cut_fn(cut: Cut) -> Cut: agent_sup.duration += sil_duration + 1.0 user_sup.duration += sil_duration - # Optionally add extra silence on the start - if add_extra_begin_sil: - sil_duration = random.uniform(*extra_begin_silence_range) - # Pad both streams on the left (adds zeros at the start) - prev_target_dur = cut_target.duration - prev_source_dur = cut_source.duration - # prepend zeros explicitly - cut_target = prepend_silence_monocut( - cut_target, sil_duration, sample_rate, - recording_id=f"{cut.id}_target_pre", cut_id=f"{cut.id}_target" - ) - cut_source = prepend_silence_monocut( - cut_source, sil_duration, sample_rate, - recording_id=f"{cut.id}_source_pre", cut_id=f"{cut.id}_source" - ) - - # Shift supervision start times forward, because audio got longer at the beginning - user_sup.start += sil_duration - agent_sup.start += sil_duration - # Assemble final cut cut_source.supervisions = [user_sup, agent_sup] cut_source.target_audio = cut_target.recording cut_source.duration = cut_target.duration - if cut.has_custom("context_audio"): - cut_source.context_audio = cut.context_audio - if cut.has_custom("context_codes"): - cut_source.context_codes = cut.context_codes - if cut.has_custom("target_codes"): - cut_source.target_codes = cut.target_codes - if cut.has_custom("lang"): - cut_source.lang = cut.lang - if cut.has_custom("ipa"): - cut_source.ipa = cut.ipa - cut_source.formatter = "lhotse_magpietts_data_as_continuation" + cut_source.context_audio = cut.context_audio + cut_source.task = "lhotse_magpietts_data_as_continuation" return cut_source @@ -1125,6 +1048,7 @@ def filter_target_speaker_fn(cut: Cut) -> bool: def read_s2s_duplex_reverse_role(config) -> Tuple[CutSet, bool]: """ Reverse the speaker roles and swap the source/target audio streams in a Duplex S2S CutSet. + This parser takes an existing conversational dataset and inverts the perspective by swapping the "user" and "agent" supervision labels. It also swaps the primary `recording` (usually source audio) with the `target_audio` to fully simulate the @@ -1837,4 +1761,4 @@ def read_nemo_tarred_to_duplex(config) -> tuple[CutSet, bool]: convert_fn = partial(_convert_tarred_to_duplex, agent_silence_duration=agent_silence_duration) cuts = cuts.map(convert_fn) - return cuts, is_tarred + return cuts, is_tarred \ No newline at end of file From 54f94f48086457a0931b6a6eb5e949617e279eab Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 9 Jun 2026 06:25:25 -0700 Subject: [PATCH 080/109] Remove unused user_audio_mask Signed-off-by: Edresson Casanova --- ...text_to_speech_dataset_lhotse_multiturn.py | 3 - nemo/collections/tts/models/easy_magpietts.py | 76 ------------------- 2 files changed, 79 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 0bdb8c2009a6..9aa637426bc4 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -258,9 +258,6 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: # Get all agent supervisions in this cut agent_sups = [sup for sup in cut.supervisions if sup.speaker in self.output_roles] - # It is a multiturn if there's more than 1 agent turn - is_multiturn = not (len(agent_sups) <= 1) - def _align_codebooks(t): C = t.shape[1] if C < self.num_audio_codebooks: diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 43c8ec13bd9b..73dc175807ce 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -1376,7 +1376,6 @@ def training_step(self, batch, batch_idx): D = user_audio_embedded.shape[-1] user_audio_embedded_restored = user_audio_embedded.new_zeros(B, T, D) - user_audio_embedded_mask = torch.zeros(B, T, device=user_audio_embedded.device, dtype=torch.bool) sample_prob = float(self.cfg.get("user_cond_trim_augmentation_sample_prob", 0.0) or 0.0) turn_prob = float(self.cfg.get("user_cond_trim_augmentation_turn_prob", 0.0) or 0.0) @@ -1421,15 +1420,12 @@ def training_step(self, batch, batch_idx): continue turn_emb = turn_emb[:copy_len].clone() - turn_mask = torch.ones(copy_len, device=user_audio_embedded.device, dtype=torch.bool) if boundary_trim > 0: trim = min(boundary_trim, copy_len // 2) if trim > 0: turn_emb[:trim] = 0.0 turn_emb[copy_len - trim:] = 0.0 - turn_mask[:trim] = False - turn_mask[copy_len - trim:] = False if bool(sample_trim_aug[b].item()): do_turn_aug = torch.rand((), device=user_audio_embedded.device).item() < turn_prob @@ -1454,23 +1450,16 @@ def training_step(self, batch, batch_idx): )[0] zero_emb_pad = turn_emb.new_zeros(trim_amount, turn_emb.size(-1)) - zero_mask_pad = torch.zeros(trim_amount, device=turn_mask.device, dtype=turn_mask.dtype) if aug_choice == "left": # Remove tokens from the left, then right-pad zeros. kept_emb = turn_emb[trim_amount:] - kept_mask = turn_mask[trim_amount:] turn_emb = torch.cat([kept_emb, zero_emb_pad], dim=0) - turn_mask = torch.cat([kept_mask, zero_mask_pad], dim=0) elif aug_choice == "right": # Remove tokens from the right, then right-pad zeros. - # This preserves timing of the left side and removes the transition/right edge. kept_emb = turn_emb[: copy_len - trim_amount] - kept_mask = turn_mask[: copy_len - trim_amount] - turn_emb = torch.cat([kept_emb, zero_emb_pad], dim=0) - turn_mask = torch.cat([kept_mask, zero_mask_pad], dim=0) else: # "both" # Remove trim_amount total tokens split across left and right. @@ -1482,84 +1471,19 @@ def training_step(self, batch, batch_idx): left_trim, right_trim = right_trim, left_trim kept_emb = turn_emb[left_trim : copy_len - right_trim] - kept_mask = turn_mask[left_trim : copy_len - right_trim] - turn_emb = torch.cat([kept_emb, zero_emb_pad], dim=0) - turn_mask = torch.cat([kept_mask, zero_mask_pad], dim=0) # Safety: keep exact same length for restore assignment. turn_emb = turn_emb[:copy_len] - turn_mask = turn_mask[:copy_len] dst_start = start_frame dst_end = start_frame + copy_len user_audio_embedded_restored[b, dst_start:dst_end] = turn_emb - user_audio_embedded_mask[b, dst_start:dst_end] = turn_mask user_audio_embedded = user_audio_embedded_restored - user_audio_mask = user_audio_embedded_mask - - # compare these two masks showing count batch level overlaps, left and right overlap per item batch. Consider only where both are ones. - """ - if "agent_mask" in batch and batch["agent_mask"] is not None: - user_cmp = user_audio_mask.bool() - agent_cmp = batch["agent_mask"].to(user_cmp.device).bool() - T_cmp = min(user_cmp.size(1), agent_cmp.size(1)) - valid = torch.arange(T_cmp, device=user_cmp.device)[None, :] < batch["text_lens"].to(user_cmp.device)[:, None] - valid = valid[:, :T_cmp] - user_cmp = user_cmp[:, :T_cmp] & valid - agent_cmp = agent_cmp[:, :T_cmp] & valid - - overlap = user_cmp & agent_cmp - - left_overlap = torch.zeros(B, device=user_cmp.device, dtype=torch.long) - right_overlap = torch.zeros(B, device=user_cmp.device, dtype=torch.long) - middle_overlap = torch.zeros(B, device=user_cmp.device, dtype=torch.long) - - boundary_width = int(self.cfg.get("user_agent_overlap_boundary_width", 10)) - - for bi in range(B): - user_idx = user_cmp[bi].nonzero(as_tuple=False).flatten() - if user_idx.numel() == 0: - continue - - breaks = torch.where(user_idx[1:] != user_idx[:-1] + 1)[0] + 1 - spans = torch.tensor_split(user_idx, breaks.cpu().tolist()) - - for span in spans: - s = int(span[0].item()) - e = int(span[-1].item()) + 1 - - overlap_span = overlap[bi, s:e] - if not overlap_span.any(): - continue - - left_end = min(e, s + boundary_width) - right_start = max(s, e - boundary_width) - - left_overlap[bi] += overlap[bi, s:left_end].sum() - right_overlap[bi] += overlap[bi, right_start:e].sum() - - boundary_mask = torch.zeros_like(overlap_span) - boundary_mask[: left_end - s] = True - boundary_mask[right_start - s :] = True - - middle_overlap[bi] += (overlap_span & ~boundary_mask).sum() - - logging.info( - "[user/agent-mask overlap debug] " - f"overlap_frames={overlap.sum(dim=1).detach().cpu().tolist()} " - f"left_overlap={left_overlap.detach().cpu().tolist()} " - f"right_overlap={right_overlap.detach().cpu().tolist()} " - f"middle_overlap={middle_overlap.detach().cpu().tolist()} " - f"user_frames={user_cmp.sum(dim=1).detach().cpu().tolist()} " - f"agent_frames={agent_cmp.sum(dim=1).detach().cpu().tolist()}" - ) - """ else: user_audio_embedded = None - user_audio_mask = None batch_output = self.process_batch( text=batch['text'], From 0910c680b7d9e1fb2fff4b5570d5780aac7c8e28 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 9 Jun 2026 06:29:58 -0700 Subject: [PATCH 081/109] Aplly Black Signed-off-by: Edresson Casanova --- examples/tts/easy_magpietts.py | 3 +- examples/tts/magpietts_inference.py | 4 +- nemo/collections/common/data/lhotse/cutset.py | 2 +- ...text_to_speech_dataset_lhotse_multiturn.py | 200 ++++++++++++------ nemo/collections/tts/models/easy_magpietts.py | 69 +++--- .../tts/models/easy_magpietts_inference.py | 51 +++-- .../modules/magpietts_inference/inference.py | 35 +-- .../tts/modules/nemotron_h_decoder.py | 28 +-- 8 files changed, 239 insertions(+), 153 deletions(-) diff --git a/examples/tts/easy_magpietts.py b/examples/tts/easy_magpietts.py index 26872c8edb87..74ac2c6e3965 100644 --- a/examples/tts/easy_magpietts.py +++ b/examples/tts/easy_magpietts.py @@ -54,13 +54,12 @@ def main(cfg): model = EasyMagpieTTSModel(cfg=cfg.model, trainer=trainer) else: raise NotImplementedError(f"Only train, onlinepo_train and test modes are supported. Got {mode}") - + if cfg.get("pretrained_model", None): model.restore_from_pretrained_checkpoint(cfg.pretrained_model) model.maybe_init_from_pretrained_checkpoint(cfg=cfg) - if mode in ['train', 'onlinepo_train']: trainer.fit(model) elif mode == 'test': diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index 033fe0efd8e2..99f780f476e1 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -252,7 +252,6 @@ def turn_sort_key(r): "speaker": group["speaker"], "rank": group["rank"], "num_turns": len(turns), - # Sample-level averages over all turns. "cer": _mean_finite(cer_turns), "wer": _mean_finite(wer_turns), @@ -262,7 +261,6 @@ def turn_sort_key(r): "utmosv2": _mean_finite(utmosv2_turns), "eou_trailing_duration": _mean_finite(eou_trailing_duration_turns), "eou_trail_rms_ratio": _mean_finite(eou_trail_rms_ratio_turns), - # Turn-by-turn values, old-script style. "turn_ids": [r.get("turn_id", i) for i, r in enumerate(turns)], "cer_turns": cer_turns, @@ -1252,4 +1250,4 @@ def main(argv=None): if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 085536458c17..8e613091b7e1 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -1761,4 +1761,4 @@ def read_nemo_tarred_to_duplex(config) -> tuple[CutSet, bool]: convert_fn = partial(_convert_tarred_to_duplex, agent_silence_duration=agent_silence_duration) cuts = cuts.map(convert_fn) - return cuts, is_tarred \ No newline at end of file + return cuts, is_tarred diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 9aa637426bc4..32dfb9f79bd3 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -216,7 +216,9 @@ def __init__( self.phoneme_turn_dropout_turn_prob = phoneme_turn_dropout_turn_prob self.phoneme_turn_max_words_to_drop = phoneme_turn_max_words_to_drop - self.frame_length = (self.codec_model_samples_per_frame / codec_model_input_sample_rate) * frame_stacking_factor + self.frame_length = ( + self.codec_model_samples_per_frame / codec_model_input_sample_rate + ) * frame_stacking_factor def get_num_audio_samples_to_slice(self, duration, sample_rate): num_codec_frames = int(duration * sample_rate / self.codec_model_samples_per_frame) @@ -263,7 +265,7 @@ def _align_codebooks(t): if C < self.num_audio_codebooks: return F.pad(t, (0, self.num_audio_codebooks - C)) elif C > self.num_audio_codebooks: - return t[:, :self.num_audio_codebooks] + return t[:, : self.num_audio_codebooks] return t with fp32_precision(): @@ -277,7 +279,7 @@ def _align_codebooks(t): for i, cut in enumerate(cuts): # Extract the raw, unpadded 1D numpy array for this specific cut - t_audio = target_audio[i, :target_audio_lens[i]].numpy() + t_audio = target_audio[i, : target_audio_lens[i]].numpy() # For tts/single-turn cuts, source audio should be zeros like target audio # For multi-turn cuts, keep loading source audio from the cut recording. if str(getattr(cut, "task", "tts")).lower() == "tts": @@ -297,32 +299,54 @@ def _align_codebooks(t): # 3. Re-pad the newly stitched arrays to the batch's new maximum length target_audio = collate_vectors(target_audio_list, padding_value=0.0) target_audio_lens = torch.tensor([len(a) for a in target_audio_list], dtype=torch.long) - + source_audio = collate_vectors(source_audio_list, padding_value=0.0) source_audio_lens = torch.tensor([len(a) for a in source_audio_list], dtype=torch.long) - user_audio_turn_splitted, user_audio_turn_splitted_lens, user_audio_turn_splitted_indices = extract_turn_audio_channel( - cuts=cuts, - source_audio_list=source_audio_list, - source_sample_rate=self.source_sample_rate, - roles=self.input_roles, + user_audio_turn_splitted, user_audio_turn_splitted_lens, user_audio_turn_splitted_indices = ( + extract_turn_audio_channel( + cuts=cuts, + source_audio_list=source_audio_list, + source_sample_rate=self.source_sample_rate, + roles=self.input_roles, + ) ) target_text_tokens, target_token_lens = collate_token_channel( - cuts, self.text_tokenizer, self.frame_length, roles=self.output_roles, - add_text_bos=self.add_text_bos, tokenizer_names=batch_tokenizer_names, - pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, interruption_token_id=self.interruption_token_id + cuts, + self.text_tokenizer, + self.frame_length, + roles=self.output_roles, + add_text_bos=self.add_text_bos, + tokenizer_names=batch_tokenizer_names, + pad_id=self.pad_id, + eos_id=self.eos_id, + bos_id=self.bos_id, + interruption_token_id=self.interruption_token_id, ) source_tokens, source_token_lens = collate_token_channel( - cuts, self.text_tokenizer, self.frame_length, roles=self.input_roles, - add_text_bos=self.add_text_bos, tokenizer_names=batch_tokenizer_names, - pad_id=self.pad_id, eos_id=self.eos_id, bos_id=self.bos_id, interruption_token_id=self.interruption_token_id + cuts, + self.text_tokenizer, + self.frame_length, + roles=self.input_roles, + add_text_bos=self.add_text_bos, + tokenizer_names=batch_tokenizer_names, + pad_id=self.pad_id, + eos_id=self.eos_id, + bos_id=self.bos_id, + interruption_token_id=self.interruption_token_id, ) if self.phoneme_tokenizer is not None: target_phoneme_tokens, target_phoneme_lens, phoneme_turn_dropout = collate_phoneme_channel( - cuts, self.phoneme_tokenizer, self.frame_length, roles=self.output_roles, - ignore_phoneme_languages=self.ignore_phoneme_languages, pad_id=self.phoneme_tokenizer.pad, eos_id=self.phoneme_tokenizer.eos_token_id, bos_id=self.phoneme_tokenizer.bos_token_id, + cuts, + self.phoneme_tokenizer, + self.frame_length, + roles=self.output_roles, + ignore_phoneme_languages=self.ignore_phoneme_languages, + pad_id=self.phoneme_tokenizer.pad, + eos_id=self.phoneme_tokenizer.eos_token_id, + bos_id=self.phoneme_tokenizer.bos_token_id, phoneme_turn_dropout_batch_prob=self.phoneme_turn_dropout_batch_prob, phoneme_turn_dropout_turn_prob=self.phoneme_turn_dropout_turn_prob, phoneme_turn_max_words_to_drop=self.phoneme_turn_max_words_to_drop, @@ -335,10 +359,10 @@ def _align_codebooks(t): audio_list_16khz = [] audio_len_list_16khz = [] prior_list = [] - + target_codes_list = [] source_codes_list = [] - + context_audio_list = [] context_audio_len_list = [] context_audio_codes_list = [] @@ -366,7 +390,11 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) dataset_name = "unknown" dataset_name_list.append(dataset_name) - language = cut.lang if cut.has_custom("lang") else next((sup.language for sup in reversed(cut.supervisions) if sup.has_custom("language")), "en") + language = ( + cut.lang + if cut.has_custom("lang") + else next((sup.language for sup in reversed(cut.supervisions) if sup.has_custom("language")), "en") + ) language_list.append(language) # Target and Source Codes @@ -382,10 +410,14 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) if self.load_cached_codes_if_available and cut.has_custom("context_codes"): context_audio_codes_array = cut.context_codes.load().astype(np.int32) context_audio_codes = torch.from_numpy(context_audio_codes_array) - _available_context_duration = (context_audio_codes.shape[1] * self.codec_model_samples_per_frame / self.sample_rate) + _available_context_duration = ( + context_audio_codes.shape[1] * self.codec_model_samples_per_frame / self.sample_rate + ) _context_duration_to_slice = _sample_context_duration_with_available_limit(_available_context_duration) - _num_frames_to_slice = int(_context_duration_to_slice * self.sample_rate / self.codec_model_samples_per_frame) - + _num_frames_to_slice = int( + _context_duration_to_slice * self.sample_rate / self.codec_model_samples_per_frame + ) + if _num_frames_to_slice < context_audio_codes.shape[1]: start_idx = random.randint(0, context_audio_codes.shape[1] - _num_frames_to_slice) context_audio_codes = context_audio_codes[:, start_idx : start_idx + _num_frames_to_slice] @@ -396,7 +428,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) context_audio_codes = _align_codebooks(context_audio_codes.T) context_audio_codes_list.append(context_audio_codes) context_audio_codes_len_list.append(context_audio_codes.shape[0]) - + elif cut.has_custom("context_audio"): with fp32_precision(): context_audio_array = cut.context_audio.resample(self.sample_rate).load_audio().squeeze(0) @@ -405,7 +437,9 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) _available_context_duration = len(context_audio_array) / self.sample_rate _context_duration_to_slice = _sample_context_duration_with_available_limit(_available_context_duration) - _num_samples_to_slice = self.get_num_audio_samples_to_slice(_context_duration_to_slice, self.sample_rate) + _num_samples_to_slice = self.get_num_audio_samples_to_slice( + _context_duration_to_slice, self.sample_rate + ) if _num_samples_to_slice < len(context_audio_array): start_idx = random.randint(0, len(context_audio_array) - _num_samples_to_slice) @@ -413,7 +447,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) else: _num_repeats = int(np.ceil(_num_samples_to_slice / len(context_audio_array))) context_audio_array = np.tile(context_audio_array, _num_repeats)[:_num_samples_to_slice] - + context_audio = torch.from_numpy(context_audio_array) context_audio_list.append(context_audio) context_audio_len_list.append(context_audio.shape[0]) @@ -427,7 +461,9 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) codes_array = cut.target_codes.load().astype(np.int32) start_frame = int(max(0, sup.start) * self.sample_rate / self.codec_model_samples_per_frame) num_frames = int(sup.duration * self.sample_rate / self.codec_model_samples_per_frame) - context_audio_codes = torch.from_numpy(codes_array)[:, start_frame : start_frame + num_frames].T + context_audio_codes = torch.from_numpy(codes_array)[ + :, start_frame : start_frame + num_frames + ].T context_audio_codes = _align_codebooks(context_audio_codes) else: context_audio_codes = torch.zeros([0, self.num_audio_codebooks], dtype=torch.int32) @@ -437,14 +473,16 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) if len(matching_supervisions) > 0: sup = random.choice(matching_supervisions) with fp32_precision(): - turn_cut = cut.resample(self.sample_rate, recording_field="target_audio").truncate(offset=max(0, sup.start), duration=sup.duration) + turn_cut = cut.resample(self.sample_rate, recording_field="target_audio").truncate( + offset=max(0, sup.start), duration=sup.duration + ) context_audio_array = turn_cut.load_custom("target_audio").squeeze(0) if self.volume_norm: context_audio_array = normalize_volume(context_audio_array) context_audio = torch.from_numpy(context_audio_array) else: context_audio = torch.zeros(self.codec_model_samples_per_frame, dtype=torch.float32) - + context_audio_list.append(context_audio) context_audio_len_list.append(context_audio.shape[0]) @@ -457,7 +495,9 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) audio_array_16khz = normalize_volume(audio_array_16khz) _available_context_duration = len(audio_array_16khz) / 16_000 - _context_duration_to_slice = _sample_context_duration_with_available_limit(_available_context_duration) + _context_duration_to_slice = _sample_context_duration_with_available_limit( + _available_context_duration + ) _num_samples_to_slice = int(_context_duration_to_slice * 16_000) if _num_samples_to_slice < len(audio_array_16khz): start_idx = random.randint(0, len(audio_array_16khz) - _num_samples_to_slice) @@ -466,7 +506,9 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) matching_supervisions = [s for s in cut.supervisions if s.speaker in self.output_roles] if len(matching_supervisions) > 0: sup = random.choice(matching_supervisions) - turn_cut = cut.resample(16_000, recording_field="target_audio").truncate(offset=max(0, sup.start), duration=sup.duration) + turn_cut = cut.resample(16_000, recording_field="target_audio").truncate( + offset=max(0, sup.start), duration=sup.duration + ) audio_array_16khz = turn_cut.load_custom("target_audio").squeeze(0) else: audio_array_16khz = np.zeros(16000, dtype=np.float32) @@ -480,20 +522,30 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) # Context Text if self.use_text_conditioning_tokenizer: - context_text = next((sup.context_text for sup in cut.supervisions if sup.has_custom("context_text")), None) + context_text = next( + (sup.context_text for sup in cut.supervisions if sup.has_custom("context_text")), None + ) if context_text is not None: if self.text_context_remapping is not None and context_text in self.text_context_remapping: if self.dataset_type == 'train' and random.random() < self.text_context_remapping_prob: context_text = self.text_context_remapping[context_text] - context_text_tokens = self.text_tokenizer.encode(context_text, tokenizer_name=self.text_conditioning_tokenizer_name) + context_text_tokens = self.text_tokenizer.encode( + context_text, tokenizer_name=self.text_conditioning_tokenizer_name + ) has_text_context = True else: - context_text = f"[{language.upper()}]" if self.add_language_to_context_text else "[NO TEXT CONTEXT]" - context_text_tokens = self.text_tokenizer.encode(context_text, tokenizer_name=self.text_conditioning_tokenizer_name) + context_text = ( + f"[{language.upper()}]" if self.add_language_to_context_text else "[NO TEXT CONTEXT]" + ) + context_text_tokens = self.text_tokenizer.encode( + context_text, tokenizer_name=self.text_conditioning_tokenizer_name + ) has_text_context = False - + if self.pad_context_text_to_max_duration: - _required_len = int(self.context_duration_max * self.sample_rate / self.codec_model_samples_per_frame) + 2 + _required_len = ( + int(self.context_duration_max * self.sample_rate / self.codec_model_samples_per_frame) + 2 + ) if len(context_text_tokens) < _required_len: _pad_id = self.text_tokenizer.tokenizer_pad_ids[self.text_conditioning_tokenizer_name] context_text_tokens += [_pad_id] * (_required_len - len(context_text_tokens)) @@ -508,7 +560,13 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) # Align Prior (Note: Using full target length to preserve shape compatibility) if self.include_align_prior: tok_name = batch_tokenizer_names[i] - full_text_len = sum([len(self.text_tokenizer.encode(sup.text, tokenizer_name=tok_name)) for sup in cut.supervisions if sup.speaker in self.output_roles]) + full_text_len = sum( + [ + len(self.text_tokenizer.encode(sup.text, tokenizer_name=tok_name)) + for sup in cut.supervisions + if sup.speaker in self.output_roles + ] + ) if self.add_text_bos: full_text_len += 2 * sum([1 for sup in cut.supervisions if sup.speaker in self.output_roles]) @@ -521,9 +579,13 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) if self.load_cached_codes_if_available and cut.has_custom("target_codes"): spec_len = int(target_codes_list[-1].shape[0]) + 1 else: - spec_len = int(target_audio_lens[i] / self.codec_model_samples_per_frame) + 2 # +1 extra in case it was truncated + spec_len = ( + int(target_audio_lens[i] / self.codec_model_samples_per_frame) + 2 + ) # +1 extra in case it was truncated - align_prior = beta_binomial_prior_distribution(phoneme_count=full_text_len, mel_count=spec_len, scaling_factor=self.prior_scaling_factor) + align_prior = beta_binomial_prior_distribution( + phoneme_count=full_text_len, mel_count=spec_len, scaling_factor=self.prior_scaling_factor + ) prior_list.append(torch.tensor(align_prior, dtype=torch.float32)) reward = next((sup.reward for sup in reversed(cut.supervisions) if sup.has_custom("reward")), None) @@ -542,7 +604,9 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) "source_token_lens": source_token_lens, "text": target_text_tokens, "text_lens": target_token_lens, - "raw_texts": [" ".join(s.text for s in cut.supervisions if s.speaker in self.output_roles) for cut in cuts], + "raw_texts": [ + " ".join(s.text for s in cut.supervisions if s.speaker in self.output_roles) for cut in cuts + ], "task": [getattr(cut, "task", "tts") for cut in cuts], "user_audio_turn_splitted": user_audio_turn_splitted, "user_audio_turn_splitted_lens": user_audio_turn_splitted_lens, @@ -552,7 +616,7 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) if target_codes_list: batch_dict["audio_codes"] = collate_matrices(target_codes_list, padding_value=0).transpose(1, 2) batch_dict["audio_codes_lens"] = torch.IntTensor([c.shape[0] for c in target_codes_list]) - + if source_codes_list: batch_dict["source_codes"] = collate_matrices(source_codes_list, padding_value=0).transpose(1, 2) batch_dict["source_codes_lens"] = torch.IntTensor([c.shape[0] for c in source_codes_list]) @@ -569,9 +633,11 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) if len(context_audio_list) > 0: batch_dict["context_audio"] = collate_vectors(context_audio_list, padding_value=0.0) batch_dict["context_audio_lens"] = torch.IntTensor(context_audio_len_list) - + if len(context_audio_codes_list) > 0: - batch_dict["context_audio_codes"] = collate_matrices(context_audio_codes_list, padding_value=0).transpose(1, 2) + batch_dict["context_audio_codes"] = collate_matrices(context_audio_codes_list, padding_value=0).transpose( + 1, 2 + ) batch_dict["context_audio_codes_lens"] = torch.IntTensor(context_audio_codes_len_list) if self.use_text_conditioning_tokenizer: @@ -628,8 +694,16 @@ def collate_token_channel( tok_name = tokenizer_names[i] if tokenizer_names else "english_phoneme" tokens.append( build_token_channel( - c, tokenizer, frame_length, roles, pad_id, eos_id, bos_id, interruption_token_id, - add_text_bos, tok_name, + c, + tokenizer, + frame_length, + roles, + pad_id, + eos_id, + bos_id, + interruption_token_id, + add_text_bos, + tok_name, ) ) token_lens = torch.tensor([len(tt) for tt in tokens]) @@ -653,15 +727,13 @@ def build_speaker_mask_channel( return mask + def collate_speaker_mask_channel( cuts: CutSet, frame_length: Seconds, output_roles: set[str], ): - masks = [ - build_speaker_mask_channel(cut, frame_length, output_roles) - for cut in cuts - ] + masks = [build_speaker_mask_channel(cut, frame_length, output_roles) for cut in cuts] mask_lens = torch.tensor([len(m) for m in masks]) return collate_vectors(masks, padding_value=0.0), mask_lens @@ -685,7 +757,7 @@ def build_token_channel( for supervision in cut.supervisions: if supervision.speaker in roles: text = supervision.text - + if hasattr(tokenizer, "encode"): try: raw_ids = tokenizer.encode(text=text, tokenizer_name=tokenizer_name) @@ -705,8 +777,8 @@ def build_token_channel( endpos = pos + len(text_ids) if endpos > len(tokens): - text_ids = text_ids[:len(tokens) - pos] - tokens[pos:pos+len(text_ids)] = text_ids + text_ids = text_ids[: len(tokens) - pos] + tokens[pos : pos + len(text_ids)] = text_ids # add interruption token, used for add speech eos and interrupt the model interruption_pos = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) @@ -786,8 +858,14 @@ def collate_phoneme_channel( dropout_flags = [] for i, c in enumerate(cuts): token, dropout_applied = build_phoneme_channel( - c, phoneme_tokenizer, frame_length, roles, - ignore_phoneme_languages, pad_id, eos_id, bos_id, + c, + phoneme_tokenizer, + frame_length, + roles, + ignore_phoneme_languages, + pad_id, + eos_id, + bos_id, phoneme_turn_dropout_batch_prob=phoneme_turn_dropout_batch_prob, phoneme_turn_dropout_turn_prob=phoneme_turn_dropout_turn_prob, phoneme_turn_max_words_to_drop=phoneme_turn_max_words_to_drop, @@ -813,7 +891,11 @@ def build_phoneme_channel( phoneme_turn_max_words_to_drop: int = 2, apply_turn_dropout: bool = False, ) -> tuple[torch.Tensor, bool]: - language = cut.lang if cut.has_custom("lang") else next((sup.language for sup in cut.supervisions if sup.has_custom("language")), "en") + language = ( + cut.lang + if cut.has_custom("lang") + else next((sup.language for sup in cut.supervisions if sup.has_custom("language")), "en") + ) total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) tokens = torch.ones(total, dtype=torch.long) * pad_id @@ -850,7 +932,7 @@ def build_phoneme_channel( endpos = pos + len(phoneme_ids) if endpos > len(tokens): - phoneme_ids = phoneme_ids[:len(tokens) - pos] - tokens[pos:pos+len(phoneme_ids)] = phoneme_ids + phoneme_ids = phoneme_ids[: len(tokens) - pos] + tokens[pos : pos + len(phoneme_ids)] = phoneme_ids - return tokens, dropout_applied \ No newline at end of file + return tokens, dropout_applied diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index 73dc175807ce..a32b75f85daa 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -63,6 +63,7 @@ from transformers import WhisperForConditionalGeneration, WhisperProcessor from typing import List + @dataclass class ProcessBatchOutput: """ @@ -171,7 +172,6 @@ def _get_state_dict_keys_to_exclude(self): '_utmos_calculator', ] - def compute_loss( self, logits, @@ -272,11 +272,7 @@ def compute_phoneme_loss( phoneme_loss = raw_loss * effective_mask phoneme_loss = phoneme_loss.sum() / effective_mask.sum().clamp_min(1.0) - total_phoneme_loss = ( - phoneme_loss - if total_phoneme_loss is None - else total_phoneme_loss + phoneme_loss - ) + total_phoneme_loss = phoneme_loss if total_phoneme_loss is None else total_phoneme_loss + phoneme_loss total_phoneme_loss = total_phoneme_loss / self.phoneme_stacking_factor return total_phoneme_loss, loss_mask @@ -498,7 +494,7 @@ def prepare_phoneme_channel_embeddings( # Apply mask to zero out padding if self.cfg.get("use_multiturn_dataset", False): phoneme_pad_id = getattr(self.phoneme_tokenizer, "pad", -1) - phoneme_mask = (phoneme_tokens_stacked[:, 0, :] != phoneme_pad_id) # Check the first layer of the stack + phoneme_mask = phoneme_tokens_stacked[:, 0, :] != phoneme_pad_id # Check the first layer of the stack # Apply mask to zero out padding phoneme_embedded = phoneme_embedded * phoneme_mask.unsqueeze(2) # (B, T', E) else: @@ -772,10 +768,7 @@ def prepare_audio_channel_embeddings( user_to_agent = torch.zeros_like(target_agent_mask) user_to_agent[:, 1:] = ( - target_agent_mask[:, 1:] - & (~target_agent_mask[:, :-1]) - & valid[:, 1:] - & valid[:, :-1] + target_agent_mask[:, 1:] & (~target_agent_mask[:, :-1]) & valid[:, 1:] & valid[:, :-1] ) end_tok_input = torch.full_like(audio_codes_input, self.audio_user_speaking_end_id) @@ -817,7 +810,13 @@ def prepare_audio_channel_embeddings( lengths=[delay, audio_codes_lens_target], ) - return audio_channel_embedding, audio_channel_lens, audio_codes_target, audio_codes_lens_target, loss_agent_mask + return ( + audio_channel_embedding, + audio_channel_lens, + audio_codes_target, + audio_codes_lens_target, + loss_agent_mask, + ) def slice_sequence_embeddings(self, sequence_embeddings, context_lens, target_lens): """ @@ -863,7 +862,7 @@ def process_batch( training_mode: Optional[TrainingMode] = None, task: Optional[List[str]] = None, agent_mask: Optional[torch.Tensor] = None, - user_audio_embedded: Optional[torch.Tensor] = None + user_audio_embedded: Optional[torch.Tensor] = None, ) -> ProcessBatchOutput: """ Simplified batch processing using channel-based embedding architecture. @@ -946,7 +945,7 @@ def process_batch( speech_eos_mask = None if self.cfg.get("use_multiturn_dataset", False): - speech_eos_mask = (text == self.interruption_token_id) # (B, T) + speech_eos_mask = text == self.interruption_token_id # (B, T) # remove the interruption token for all task, expect for interruption if not task or "interruption" not in str(task[0]): text[speech_eos_mask] = self.tokenizer.pad # Clean up the text channel @@ -1175,7 +1174,12 @@ def process_batch( logits = self.final_proj(pred_embeddings_audio) # Compute codebook loss - codebook_loss, _ = self.compute_loss(logits, audio_codes_target, audio_codes_lens_target, agent_mask_target=agent_mask if self.cfg.get("mask_user_on_loss", False) else None) + codebook_loss, _ = self.compute_loss( + logits, + audio_codes_target, + audio_codes_lens_target, + agent_mask_target=agent_mask if self.cfg.get("mask_user_on_loss", False) else None, + ) loss = self.parallel_codebook_loss_scale * codebook_loss # Compute local transformer loss if applicable @@ -1187,7 +1191,10 @@ def process_batch( pred_embeddings, audio_codes_target, targets_offset_by_one=False ) local_transformer_loss, _ = self.compute_loss( - local_transformer_logits, audio_codes_target, audio_codes_lens_target, agent_mask_target=agent_mask if self.cfg.get("mask_user_on_loss", False) else None + local_transformer_logits, + audio_codes_target, + audio_codes_lens_target, + agent_mask_target=agent_mask if self.cfg.get("mask_user_on_loss", False) else None, ) loss = loss + self.local_transformer_loss_scale * local_transformer_loss @@ -1308,10 +1315,13 @@ def training_step(self, batch, batch_idx): # randomly drop individual turns. if turn_silence_prob > 0.0: - turn_silence_mask = torch.rand( - user_audio.size(0), - device=user_audio.device, - ) < turn_silence_prob + turn_silence_mask = ( + torch.rand( + user_audio.size(0), + device=user_audio.device, + ) + < turn_silence_prob + ) else: turn_silence_mask = torch.zeros( user_audio.size(0), @@ -1323,10 +1333,13 @@ def training_step(self, batch, batch_idx): if sample_silence_prob > 0.0: B = batch["text"].shape[0] - sample_silence_mask = torch.rand( - B, - device=user_audio.device, - ) < sample_silence_prob + sample_silence_mask = ( + torch.rand( + B, + device=user_audio.device, + ) + < sample_silence_prob + ) indices = batch["user_audio_turn_splitted_indices"].to(user_audio.device) turn_batch_indices = indices[:, 0].long() @@ -1338,9 +1351,7 @@ def training_step(self, batch, batch_idx): dtype=torch.bool, ) - sample_drop_turn_mask[valid_turns] = sample_silence_mask[ - turn_batch_indices[valid_turns] - ] + sample_drop_turn_mask[valid_turns] = sample_silence_mask[turn_batch_indices[valid_turns]] silence_mask = turn_silence_mask | sample_drop_turn_mask else: @@ -1425,7 +1436,7 @@ def training_step(self, batch, batch_idx): trim = min(boundary_trim, copy_len // 2) if trim > 0: turn_emb[:trim] = 0.0 - turn_emb[copy_len - trim:] = 0.0 + turn_emb[copy_len - trim :] = 0.0 if bool(sample_trim_aug[b].item()): do_turn_aug = torch.rand((), device=user_audio_embedded.device).item() < turn_prob @@ -1500,7 +1511,7 @@ def training_step(self, batch, batch_idx): mode="train", task=batch["task"] if self.cfg.get("use_multiturn_dataset", False) else None, agent_mask=batch["agent_mask"] if self.cfg.get("use_multiturn_dataset", False) else None, - user_audio_embedded=user_audio_embedded + user_audio_embedded=user_audio_embedded, ) loss = batch_output.loss codebook_loss = batch_output.codebook_loss diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index fb34090aecd8..d76223466967 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -152,7 +152,7 @@ class StreamingState: audio_prediction_end_idx: torch.Tensor phoneme_prediction_start_idx: torch.Tensor phoneme_prediction_end_idx: torch.Tensor - + gt_phoneme_embeddings: Optional[torch.Tensor] = None # (B, T', E) pre-computed GT embeddings gt_phoneme_lens: Optional[torch.Tensor] = None # (B,) lengths after stacking gt_audio_embeddings: Optional[torch.Tensor] = None # (B, T', E) pre-computed GT audio embeddings @@ -346,7 +346,6 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): mode='train', ) - base_num_tokens = len(self.tokenizer.tokens) # Assign standard special tokens sequentially @@ -655,7 +654,9 @@ def restore_from_pretrained_checkpoint(self, checkpoint_path): checkpoint_state = torch.load(checkpoint_path, map_location='cpu') else: checkpoint_state = torch.load(checkpoint_path, map_location='cpu') - checkpoint_state = set_model_dict_for_partial_init(checkpoint_state, self.state_dict(), allow_partial_copy=True) + checkpoint_state = set_model_dict_for_partial_init( + checkpoint_state, self.state_dict(), allow_partial_copy=True + ) self.load_state_dict(checkpoint_state, strict=True) logging.info(f"Model restored from the checkpoint: {checkpoint_path} !") @@ -691,7 +692,7 @@ def _generate_codec_silence_buffer(self): else: self._codec_sil_codes_buffer.copy_(sil_tensor_converted) self._codec_sil_codes_buffer_unconverted.copy_(sil_tensor_unconverted) - + def streaming_prefill_profile( self, state: StreamingState, @@ -712,7 +713,9 @@ def streaming_prefill_profile( # ----------------------- # TEXT CHANNEL # ----------------------- - text_emb = self.embed_text_tokens(text_tokens, text_lens=None, is_multiturn=self.cfg.get("use_multiturn_dataset", False)) + text_emb = self.embed_text_tokens( + text_tokens, text_lens=None, is_multiturn=self.cfg.get("use_multiturn_dataset", False) + ) if self.cfg.get("use_multiturn_dataset", False): text_emb[text_tokens == self.pad_id] = 0.0 @@ -750,7 +753,7 @@ def streaming_prefill_profile( # Match training channel sum: text + audio profile-token/silence input. combined_emb = text_emb + audio_emb - + if self.cfg.get("condition_on_user_speech", False): combined_emb = combined_emb + user_audio_channel_embedding @@ -791,7 +794,9 @@ def streaming_prefill_profile( state.audio_steps += T # Make the next normal streaming_step continue from silence, not AUDIO_BOS. - state.all_predictions.append(sil_codes_unstacked) # keep silence so that in the target audio user will be silence + state.all_predictions.append( + sil_codes_unstacked + ) # keep silence so that in the target audio user will be silence if self.cfg.get("use_user_speaking_token", False): state.last_audio_codes = torch.full( (B, C * S), @@ -906,7 +911,7 @@ def embed_text_tokens( ) if is_multiturn: - text_mask = (text_tokens != self.tokenizer.pad) + text_mask = text_tokens != self.tokenizer.pad else: text_mask = get_mask_from_lengths(text_lens) @@ -1232,7 +1237,6 @@ def prepare_context_tensors( return context_embedding, context_lens, context_audio_codes, context_audio_codes_lens - def stack_codes(self, codes, codes_lens, bos_id, eos_id, stacking_factor, num_codebooks): """ Stack multiple time steps into the channel dimension to reduce sequence length. @@ -1516,7 +1520,7 @@ def streaming_init( if self.cfg.get("use_multiturn_dataset", False): phoneme_pad_id = getattr(self.phoneme_tokenizer, "pad", -1) - phoneme_mask = (gt_phoneme_stacked[:, 0, :] != phoneme_pad_id) + phoneme_mask = gt_phoneme_stacked[:, 0, :] != phoneme_pad_id gt_phoneme_embeddings = gt_phoneme_embeddings * phoneme_mask.unsqueeze(2) # Process GT audio codes if provided (for teacher forcing) @@ -1615,7 +1619,10 @@ def streaming_step( # Phase 1: Prepare input embedding and determine per-item phase masks next_input, needs_context, needs_phoneme, needs_audio = self._prepare_streaming_input( - state, text_tokens, force_dropout_text, user_audio_channel_embedding=user_audio_channel_embedding, + state, + text_tokens, + force_dropout_text, + user_audio_channel_embedding=user_audio_channel_embedding, ) # Phase 2: Transformer forward pass @@ -1719,7 +1726,6 @@ def streaming_step( state.last_audio_codes = sil.reshape(B, C * S) return state, None, pred_phoneme_tokens - # Phase 3: Update counters and extract predictions audio_codes_next, pred_phoneme_tokens = self._process_predictions( state, needs_context, needs_phoneme, needs_audio @@ -1755,16 +1761,10 @@ def _prepare_streaming_input( turn_text_tokens_seen = getattr(state, "turn_text_tokens_seen", state.text_tokens_seen) needs_phoneme = ( - (~needs_context) - & (turn_text_tokens_seen >= streaming_phonemes_delay) - & (~state.phoneme_stream_ended) + (~needs_context) & (turn_text_tokens_seen >= streaming_phonemes_delay) & (~state.phoneme_stream_ended) ) - needs_audio = ( - (~needs_context) - & (turn_text_tokens_seen >= streaming_speech_delay) - & (~state.finished) - ) + needs_audio = (~needs_context) & (turn_text_tokens_seen >= streaming_speech_delay) & (~state.finished) next_input = torch.zeros(batch_size, 1, self.cfg.embedding_dim, device=device) @@ -1794,7 +1794,7 @@ def _prepare_streaming_input( # Zero out padding tokens exactly like in process_batch if self.cfg.get("use_multiturn_dataset", False): - is_pad = (text_tokens_2d == self.tokenizer.pad) + is_pad = text_tokens_2d == self.tokenizer.pad text_embedded[is_pad] = 0.0 is_eos_token = (text_tokens == self.eos_id) & needs_text # (B,) bool @@ -1837,9 +1837,7 @@ def _prepare_streaming_input( if has_last_phoneme.any() and state.last_phoneme_tokens is not None: last_phoneme_tokens = state.last_phoneme_tokens # (B, S_ph) - last_phoneme_emb = self.embed_phoneme_tokens( - last_phoneme_tokens.unsqueeze(2) - ) # (B, 1, E) + last_phoneme_emb = self.embed_phoneme_tokens(last_phoneme_tokens.unsqueeze(2)) # (B, 1, E) # Match training: PAD phoneme inputs contribute zero embedding. if self.cfg.get("use_multiturn_dataset", False): @@ -1849,7 +1847,9 @@ def _prepare_streaming_input( phoneme_is_pad = last_phoneme_tokens[:, 0] == phoneme_pad_id # (B,) last_phoneme_emb = last_phoneme_emb * (~phoneme_is_pad).view(batch_size, 1, 1).float() else: - raise ValueError("self.phoneme_tokenizer.pad is not defined, so it is not possible to zero-out the phoneme on the padding positon, please verify it!") + raise ValueError( + "self.phoneme_tokenizer.pad is not defined, so it is not possible to zero-out the phoneme on the padding positon, please verify it!" + ) last_mask = has_last_phoneme.view(batch_size, 1, 1).float() phoneme_emb = phoneme_emb + last_phoneme_emb * last_mask @@ -1918,7 +1918,6 @@ def _prepare_streaming_input( return next_input, needs_context, needs_phoneme, needs_audio - def _process_predictions( self, state: StreamingState, diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index 947579471dbc..906eb3991eaf 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -682,8 +682,6 @@ def _compute_end_of_text_flags( return is_end_of_text - - @dataclass class EasyMagpieMultiturnUserAudioInferenceConfig(EasyMagpieInferenceConfig): """Configuration for EasyMagpie multiturn user-audio inference. @@ -830,7 +828,6 @@ def collate_fn(self, batch: List[dict]) -> Dict[str, Any]: } - class _InferenceSubset(torch.utils.data.Dataset): """Subset wrapper that preserves the wrapped dataset collate_fn.""" @@ -845,6 +842,7 @@ def __len__(self): def __getitem__(self, idx: int): return self.dataset[self.indices[idx]] + class EasyMagpieInferenceRunner(BaseInferenceRunner): """Runner for decoder-only EasyMagpieTTSInferenceModel. @@ -989,6 +987,7 @@ def _run_decoder_only_inference( return all_rtf_metrics, generated_audio_paths, codec_file_paths + class EasyMagpieMultiturnUserAudioInferenceRunner(BaseInferenceRunner): """Runner for decoder-only EasyMagpieTTS multiturn user-audio inference. @@ -1065,7 +1064,9 @@ def _move_batch_to_device(batch: Dict[str, Any], device: torch.device) -> Dict[s return out @staticmethod - def _copy_or_link(src: Optional[str], dst: str, required: bool = False, description: str = "audio") -> Optional[str]: + def _copy_or_link( + src: Optional[str], dst: str, required: bool = False, description: str = "audio" + ) -> Optional[str]: """Copy/symlink an audio artifact and optionally fail fast if missing. Evaluation later expects target_audio_*.wav/context_audio_*.wav to exist. @@ -1384,7 +1385,9 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): return finalize_output, turn_frame_ranges, decode_start_frame, generated_codes @staticmethod - def _save_code_slice(generated_codes, batch_idx: int, start_frame: int, end_frame: int, path: str) -> Optional[str]: + def _save_code_slice( + generated_codes, batch_idx: int, start_frame: int, end_frame: int, path: str + ) -> Optional[str]: if generated_codes is None: return None os.makedirs(os.path.dirname(path), exist_ok=True) @@ -1533,9 +1536,7 @@ def _run_multiturn_user_audio_inference( os.makedirs(debug_user_dir, exist_ok=True) os.makedirs(debug_mixed_dir, exist_ok=True) os.makedirs(debug_full_agent_dir, exist_ok=True) - logging.info( - f"Saving multiturn debug/listening audios under audios_MT: {mt_debug_output_dir}" - ) + logging.info(f"Saving multiturn debug/listening audios under audios_MT: {mt_debug_output_dir}") rank = int(getattr(self, "distributed_rank", 0)) world_size = int(getattr(self, "distributed_world_size", 1)) @@ -1703,9 +1704,13 @@ def _run_multiturn_user_audio_inference( debug_mixed_dir=debug_mixed_dir, ) - audio_seconds = sum( - max(0, int(round((end - start) * samples_per_prediction_frame))) for _, start, end in turn_frame_ranges - ) / sample_rate + audio_seconds = ( + sum( + max(0, int(round((end - start) * samples_per_prediction_frame))) + for _, start, end in turn_frame_ranges + ) + / sample_rate + ) all_rtf_metrics.append( { "inference_time": elapsed, @@ -1716,7 +1721,9 @@ def _run_multiturn_user_audio_inference( self.evaluation_audio_dir = output_dir rank = int(getattr(self, "distributed_rank", 0)) - self.evaluation_manifest_path = os.path.join(output_dir, f"multiturn_user_audio_turn_manifest_rank{rank:04d}.jsonl") + self.evaluation_manifest_path = os.path.join( + output_dir, f"multiturn_user_audio_turn_manifest_rank{rank:04d}.jsonl" + ) self.evaluation_manifest_records = turn_manifest_records with open(self.evaluation_manifest_path, "w", encoding="utf-8") as f: for record in turn_manifest_records: @@ -1776,7 +1783,9 @@ def _save_debug_user_agent_audio( else: prev_turn_end_frame = turn_frame_ranges[turn_id - 1][2] rel_prev_end_frame = prev_turn_end_frame - decode_start_frame - user_start_sample = first_user_delay_out + int(round(rel_prev_end_frame * samples_per_prediction_frame)) + user_start_sample = first_user_delay_out + int( + round(rel_prev_end_frame * samples_per_prediction_frame) + ) user_segments.append((user_start_sample, turn_audio_out)) total_user_len = 0 diff --git a/nemo/collections/tts/modules/nemotron_h_decoder.py b/nemo/collections/tts/modules/nemotron_h_decoder.py index 75a47aeb9d5b..8e487da639e6 100644 --- a/nemo/collections/tts/modules/nemotron_h_decoder.py +++ b/nemo/collections/tts/modules/nemotron_h_decoder.py @@ -106,6 +106,7 @@ def make_mamba_conv_cache_from_sequence( return F.pad(x, (conv_kernel_size - x.size(-1), 0)).contiguous() + def get_cached_mamba_ssm_state( cache_params: "HybridMambaAttentionDynamicCache", layer_idx: int, @@ -132,6 +133,7 @@ def get_cached_mamba_ssm_state( return state + def get_activation_fn(activation: str): """Get activation function by name.""" if activation == "silu" or activation == "swish": @@ -537,11 +539,7 @@ def cuda_kernels_forward( - self.num_heads ) // 2 - has_cache_prefix = ( - cache_params is not None - and cache_position is not None - and cache_position[0] > 0 - ) + has_cache_prefix = cache_params is not None and cache_position is not None and cache_position[0] > 0 is_decode_step = has_cache_prefix and seq_len == 1 if is_decode_step: @@ -645,9 +643,7 @@ def cuda_kernels_forward( raw_for_cache = raw_hidden_states_B_C if self.activation not in ["silu", "swish"]: - conv_out = self.act( - self.conv1d(conv_input)[..., : conv_input.size(-1)].transpose(1, 2) - ) + conv_out = self.act(self.conv1d(conv_input)[..., : conv_input.size(-1)].transpose(1, 2)) else: conv_out = causal_conv1d_fn( x=conv_input, @@ -747,11 +743,7 @@ def torch_forward( dim=-1, ) - has_cache_prefix = ( - cache_params is not None - and cache_position is not None - and cache_position[0] > 0 - ) + has_cache_prefix = cache_params is not None and cache_position is not None and cache_position[0] > 0 is_decode_step = has_cache_prefix and seq_len == 1 # ----------------------- @@ -795,9 +787,7 @@ def torch_forward( conv_input = raw_hidden_states_B_C.transpose(1, 2) raw_for_cache = raw_hidden_states_B_C - conv_out = self.act( - self.conv1d(conv_input)[..., : conv_input.size(-1)].transpose(1, 2) - ) + conv_out = self.act(self.conv1d(conv_input)[..., : conv_input.size(-1)].transpose(1, 2)) # Keep only outputs corresponding to the new chunk. hidden_states_B_C = conv_out[:, -seq_len:, :].contiguous() @@ -836,9 +826,7 @@ def torch_forward( dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) A_expanded = ( - A[..., None, None] - .expand(self.num_heads, self.head_dim, self.ssm_state_size) - .to(dtype=torch.float32) + A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) ) dA = (torch.exp(dt[..., None] * A_expanded)).to(device=cache_device) @@ -1544,8 +1532,8 @@ def custom_forward(hidden_states): cache_position=None, attention_mask=layer_mask, ) - return custom_forward + return custom_forward if self.gradient_checkpointing and self.training: hidden_states = torch.utils.checkpoint.checkpoint( From cf4ae239dbb1b561add2b6f4ba9e5040da86d3b4 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 9 Jun 2026 11:03:20 -0700 Subject: [PATCH 082/109] Remove unused params Signed-off-by: Edresson Casanova --- .../tts/data/text_to_speech_dataset_lhotse_multiturn.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 32dfb9f79bd3..1504366324ee 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -257,9 +257,6 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: else: batch_tokenizer_names.append("english_phoneme") - # Get all agent supervisions in this cut - agent_sups = [sup for sup in cut.supervisions if sup.speaker in self.output_roles] - def _align_codebooks(t): C = t.shape[1] if C < self.num_audio_codebooks: From eb239d352a0e8c5130aae4743c216dd159d37bfa Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Tue, 9 Jun 2026 15:25:34 -0700 Subject: [PATCH 083/109] short phoneme turn handling Signed-off-by: Shehzeen Hussain Signed-off-by: Edresson Casanova --- .../tts/data/text_to_speech_dataset_lhotse_multiturn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 1504366324ee..72c60c05bebf 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -148,7 +148,7 @@ class MagpieTTSLhotseMultiturnDataset(torch.utils.data.Dataset): are not set externally. Defaults to None. text_context_remapping: Dict defining mapping of multiple text contexts to a single text context. text_context_remapping_prob: Probability of remapping the original text context to a remapped text context. - phoneme_turn_max_words_to_drop: Turns with this many words or fewer keep phoneme tokens as pad_id. + phoneme_turn_max_words_to_drop: Turns with this many words or fewer keep empty phoneme string. """ def __init__( @@ -910,9 +910,6 @@ def build_phoneme_channel( dropout_applied = True continue - if _count_words_ignoring_punctuation(supervision.text) <= phoneme_turn_max_words_to_drop: - continue - if isinstance(phoneme_tokenizer, IPABPETokenizer): ipa_text = _get_supervision_ipa_text(supervision) if language in ignore_phoneme_languages: @@ -920,6 +917,9 @@ def build_phoneme_channel( else: ipa_text = supervision.text + if _count_words_ignoring_punctuation(supervision.text) <= phoneme_turn_max_words_to_drop: + ipa_text = "" + phoneme_ids = phoneme_tokenizer.encode(ipa_text) phoneme_ids = [bos_id] + phoneme_ids + [eos_id] phoneme_ids = torch.as_tensor(phoneme_ids, dtype=torch.long) From 341af8913451739a7077bef66195e49e53a49a06 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 10 Jun 2026 09:19:01 -0700 Subject: [PATCH 084/109] Bug fix on metrics computation Signed-off-by: Edresson Casanova --- examples/tts/magpietts_inference.py | 38 +++++++++++++------ .../modules/magpietts_inference/inference.py | 4 +- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index 99f780f476e1..7e2458ceaeea 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -439,16 +439,21 @@ def _wait_for_multiturn_rank_manifests(repeat_audio_dir: str, world_size: int, t raise RuntimeError(f"Timed out waiting for multiturn rank manifests: {missing}") -def _copy_or_link(src: str, dst: str) -> None: +def _copy_or_link(src: str, dst: str, required: bool = False) -> None: if src is None or not os.path.exists(src): - return - os.makedirs(os.path.dirname(dst), exist_ok=True) - try: if os.path.lexists(dst): os.remove(dst) - os.symlink(os.path.abspath(src), dst) - except Exception: - shutil.copyfile(src, dst) + if required: + raise FileNotFoundError(f"Missing required merge source: {src} -> {dst}") + return + + os.makedirs(os.path.dirname(dst), exist_ok=True) + + if os.path.lexists(dst): + os.remove(dst) + + # Prefer real files for evaluator inputs; broken symlinks confuse librosa/UTMOS. + shutil.copyfile(src, dst) def _merge_multiturn_rank_outputs(repeat_audio_dir: str, world_size: int, save_predicted_codes: bool) -> str: @@ -459,6 +464,17 @@ def _merge_multiturn_rank_outputs(repeat_audio_dir: str, world_size: int, save_p function remaps them to contiguous global indices in repeat_audio_dir/ and writes a merged turn-level manifest. """ + # clean previous merged files + for pattern in [ + "predicted_audio_*.wav", + "target_audio_*.wav", + "context_audio_*.wav", + "predicted_codes_*.pt", + ]: + for path in Path(repeat_audio_dir).glob(pattern): + if path.is_symlink() or path.exists(): + path.unlink(missing_ok=True) + merged_records = [] global_idx = 0 @@ -474,23 +490,23 @@ def _merge_multiturn_rank_outputs(repeat_audio_dir: str, world_size: int, save_p for local_idx, record in enumerate(rank_records): pred_src = os.path.join(rank_dir, f"predicted_audio_{local_idx}.wav") pred_dst = os.path.join(repeat_audio_dir, f"predicted_audio_{global_idx}.wav") - _copy_or_link(pred_src, pred_dst) + _copy_or_link(pred_src, pred_dst, required=True) if save_predicted_codes: code_src = os.path.join(rank_dir, f"predicted_codes_{local_idx}.pt") code_dst = os.path.join(repeat_audio_dir, f"predicted_codes_{global_idx}.pt") - _copy_or_link(code_src, code_dst) + _copy_or_link(code_src, code_dst, required=False) target_src = os.path.join(rank_dir, record.get("audio_filepath", f"target_audio_{local_idx}.wav")) target_dst = os.path.join(repeat_audio_dir, f"target_audio_{global_idx}.wav") - _copy_or_link(target_src, target_dst) + _copy_or_link(target_src, target_dst, required=True) context_src = os.path.join( rank_dir, record.get("context_audio_filepath", f"context_audio_{local_idx}.wav"), ) context_dst = os.path.join(repeat_audio_dir, f"context_audio_{global_idx}.wav") - _copy_or_link(context_src, context_dst) + _copy_or_link(context_src, context_dst, required=True) merged = dict(record) merged["audio_filepath"] = f"target_audio_{global_idx}.wav" diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index 906eb3991eaf..15c6973a4c2d 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -867,8 +867,8 @@ def create_dataset( dataset = MagpieTTSDataset( dataset_meta=dataset_meta, sample_rate=self.model.sample_rate, - min_duration=0.5, - max_duration=20, + min_duration=None, + max_duration=None, codec_model_samples_per_frame=self.model.codec_model_samples_per_frame, bos_id=getattr(self.model, "bos_id", None), eos_id=self.model.eos_id, From 20f777baea721e07c79472c6817908a26d5db735 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 16 Jun 2026 05:30:46 -0700 Subject: [PATCH 085/109] Add ECAPA2 SSIM Signed-off-by: Edresson Casanova --- examples/tts/magpietts_inference.py | 31 +++- .../evaluate_generated_audio.py | 172 +++++++++++++++++- .../modules/magpietts_inference/evaluation.py | 9 + 3 files changed, 204 insertions(+), 8 deletions(-) diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index 7e2458ceaeea..0a072d35dae3 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -129,6 +129,9 @@ def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, met metrics.get('ssim_pred_gt_avg_alternate', ''), metrics.get('ssim_pred_context_avg_alternate', ''), metrics.get('ssim_gt_context_avg_alternate', ''), + metrics.get('ssim_pred_gt_avg_ecapa2', ''), + metrics.get('ssim_pred_context_avg_ecapa2', ''), + metrics.get('ssim_gt_context_avg_ecapa2', ''), metrics.get('cer_gt_audio_cumulative', ''), metrics.get('wer_gt_audio_cumulative', ''), metrics.get('utmosv2_avg', ''), @@ -241,6 +244,9 @@ def turn_sort_key(r): pred_context_ssim_turns = [r.get("pred_context_ssim") for r in turns] pred_gt_ssim_turns = [r.get("pred_gt_ssim") for r in turns] gt_context_ssim_turns = [r.get("gt_context_ssim") for r in turns] + pred_context_ssim_ecapa2_turns = [r.get("pred_context_ssim_ecapa2") for r in turns] + pred_gt_ssim_ecapa2_turns = [r.get("pred_gt_ssim_ecapa2") for r in turns] + gt_context_ssim_ecapa2_turns = [r.get("gt_context_ssim_ecapa2") for r in turns] utmosv2_turns = [r.get("utmosv2") for r in turns] eou_type_turns = [r.get("eou_type") for r in turns] eou_trailing_duration_turns = [r.get("eou_trailing_duration") for r in turns] @@ -258,6 +264,9 @@ def turn_sort_key(r): "pred_context_ssim": _mean_finite(pred_context_ssim_turns), "pred_gt_ssim": _mean_finite(pred_gt_ssim_turns), "gt_context_ssim": _mean_finite(gt_context_ssim_turns), + "pred_context_ssim_ecapa2": _mean_finite(pred_context_ssim_ecapa2_turns), + "pred_gt_ssim_ecapa2": _mean_finite(pred_gt_ssim_ecapa2_turns), + "gt_context_ssim_ecapa2": _mean_finite(gt_context_ssim_ecapa2_turns), "utmosv2": _mean_finite(utmosv2_turns), "eou_trailing_duration": _mean_finite(eou_trailing_duration_turns), "eou_trail_rms_ratio": _mean_finite(eou_trail_rms_ratio_turns), @@ -268,6 +277,9 @@ def turn_sort_key(r): "pred_context_ssim_turns": pred_context_ssim_turns, "pred_gt_ssim_turns": pred_gt_ssim_turns, "gt_context_ssim_turns": gt_context_ssim_turns, + "pred_context_ssim_ecapa2_turns": pred_context_ssim_ecapa2_turns, + "pred_gt_ssim_ecapa2_turns": pred_gt_ssim_ecapa2_turns, + "gt_context_ssim_ecapa2_turns": gt_context_ssim_ecapa2_turns, "utmosv2_turns": utmosv2_turns, "eou_type_turns": eou_type_turns, "eou_trailing_duration_turns": eou_trailing_duration_turns, @@ -302,6 +314,9 @@ def _write_grouped_multiturn_filewise_metrics_csv(csv_path: str, grouped_rows: l "pred_context_ssim", "pred_gt_ssim", "gt_context_ssim", + "pred_context_ssim_ecapa2", + "pred_gt_ssim_ecapa2", + "gt_context_ssim_ecapa2", "utmosv2", "eou_trailing_duration", "eou_trail_rms_ratio", @@ -311,6 +326,9 @@ def _write_grouped_multiturn_filewise_metrics_csv(csv_path: str, grouped_rows: l "pred_context_ssim_turns", "pred_gt_ssim_turns", "gt_context_ssim_turns", + "pred_context_ssim_ecapa2_turns", + "pred_gt_ssim_ecapa2_turns", + "gt_context_ssim_ecapa2_turns", "utmosv2_turns", "eou_type_turns", "eou_trailing_duration_turns", @@ -594,7 +612,9 @@ def run_inference_and_evaluation( "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative," "wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg," "ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate," - "ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative," + "ssim_gt_context_avg_alternate,ssim_pred_gt_avg_ecapa2," + "ssim_pred_context_avg_ecapa2,ssim_gt_context_avg_ecapa2," + "cer_gt_audio_cumulative,wer_gt_audio_cumulative," "utmosv2_avg,total_gen_audio_seconds,frechet_codec_distance," "eou_cutoff_rate,eou_silence_rate,eou_noise_rate,eou_error_rate" ) @@ -711,6 +731,9 @@ def run_inference_and_evaluation( with_utmosv2=eval_config.with_utmosv2, with_fcd=eval_config.with_fcd, codec_model_path=eval_config.codec_model_path, + normalize_volume=eval_config.normalize_volume, + with_ecapa2=eval_config.with_ecapa2, + ecapa2_cache_dir=eval_config.ecapa2_cache_dir, device=eval_config.device, asr_batch_size=eval_config.asr_batch_size, eou_batch_size=eval_config.eou_batch_size, @@ -955,6 +978,9 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None: eval_group.add_argument('--num_repeats', type=int, default=1) eval_group.add_argument('--confidence_level', type=float, default=0.95) eval_group.add_argument('--disable_utmosv2', action='store_true') + eval_group.add_argument('--normalize_volume', action='store_true') + eval_group.add_argument('--use_ecapa2', action='store_true', help='Compute ECAPA2 speaker similarity metrics') + eval_group.add_argument('--ecapa2_cache_dir', type=str, default=None) eval_group.add_argument( '--violin_plot_metrics', type=str, @@ -1155,6 +1181,9 @@ def main(argv=None): with_utmosv2=not args.disable_utmosv2, with_fcd=not args.disable_fcd, codec_model_path=args.codecmodel_path if not args.disable_fcd else None, + normalize_volume=args.normalize_volume, + with_ecapa2=args.use_ecapa2, + ecapa2_cache_dir=args.ecapa2_cache_dir, asr_batch_size=args.asr_batch_size, eou_batch_size=args.eou_batch_size, ) diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py index 9d2540a694ff..08eaa34cef70 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py @@ -35,7 +35,7 @@ from nemo.collections.asr.metrics.wer import word_error_rate_detail from nemo.collections.tts.metrics.eou_classifier import EoUClassification, EoUClassifier, EoUType from nemo.collections.tts.metrics.frechet_codec_distance import FrechetCodecDistance -from nemo.collections.tts.parts.utils.tts_dataset_utils import get_text_processor +from nemo.collections.tts.parts.utils.tts_dataset_utils import get_text_processor, normalize_volume as normalize_audio_volume from nemo.utils import logging # Optional import for UTMOSv2 (audio quality metric) @@ -51,11 +51,25 @@ "To install utmosv2 run `pip install git+https://github.com/sarulab-speech/UTMOSv2.git@v1.2.1`." ) +try: + from huggingface_hub import hf_hub_download + + HF_HUB_AVAILABLE = True +except (ImportError, ModuleNotFoundError) as e: + HF_HUB_AVAILABLE = False + logging.warning( + f"huggingface_hub not available: {e}. " + "ECAPA2 speaker similarity metrics will be disabled." + ) + FILEWISE_METRICS_TO_SAVE = [ 'cer', 'wer', 'pred_context_ssim', + 'pred_gt_ssim_ecapa2', + 'pred_context_ssim_ecapa2', + 'gt_context_ssim_ecapa2', 'pred_text', 'gt_text', 'gt_audio_filepath', @@ -209,6 +223,28 @@ def pad_audio_to_min_length(audio_np: np.ndarray, sampling_rate: int, min_second return audio_np +def _normalize_audio_for_metrics(audio_path: Optional[str], output_path: str) -> Optional[str]: + if audio_path is None or audio_path == "": + return None + + audio_np, sampling_rate = sf.read(audio_path, dtype="float32", always_2d=False) + if audio_np.ndim == 2: + audio_np = audio_np.mean(axis=1) + audio_np = normalize_audio_volume(audio_np.flatten().astype(np.float32)) + audio_np = np.asarray(audio_np, dtype=np.float32) + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + sf.write(output_path, audio_np, sampling_rate) + return output_path + + +def _normalize_audio_paths_for_metrics(audio_paths: list, output_dir: str, prefix: str) -> list: + return [ + _normalize_audio_for_metrics(audio_path, os.path.join(output_dir, f"{prefix}_{idx}.wav")) + for idx, audio_path in enumerate(audio_paths) + ] + + def extract_embedding(model, extractor, audio_path, device, sv_model_type): speech_array, sampling_rate = librosa.load(audio_path, sr=16000) # pad to 0.5 seconds as the extractor may not be able to handle very short signals @@ -227,6 +263,27 @@ def extract_embedding(model, extractor, audio_path, device, sv_model_type): return embeddings.squeeze() +def load_ecapa2_model(device, cache_dir=None): + if not HF_HUB_AVAILABLE: + raise ImportError("huggingface_hub is required to download/load ECAPA2") + + # automatically checks for cached file, optionally set `cache_dir` location + ecapa2_file = hf_hub_download(repo_id='Jenthe/ECAPA2', filename='ecapa2.pt', cache_dir=cache_dir) + return torch.jit.load(ecapa2_file, map_location='cpu').to(device).eval() + + +def extract_ecapa2_embedding(model, audio_path, device, model_sr=16000): + speech_array, sampling_rate = librosa.load(audio_path, sr=model_sr) + # pad to 0.5 seconds as the extractor may not be able to handle very short signals + speech_array = pad_audio_to_min_length(speech_array, int(sampling_rate), min_seconds=0.5) + audio = torch.from_numpy(speech_array).float().unsqueeze(0).to(device) + + with torch.inference_mode(): + embeddings = model(audio) + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + return embeddings.cpu().detach().squeeze() + + def compute_utmosv2_scores(audio_dir, device): if not UTMOSV2_AVAILABLE: logging.warning("UTMOSv2Calculator not available. Skipping UTMOSv2 score computation.") @@ -265,7 +322,12 @@ def transcribed_batched( def load_evaluation_models( - language="en", sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", device="cuda" + language="en", + sv_model_type="titanet", + asr_model_name="stt_en_conformer_transducer_large", + device="cuda", + with_ecapa2=False, + ecapa2_cache_dir=None, ): """Load ASR and speaker verification models used for evaluation. @@ -284,6 +346,7 @@ def load_evaluation_models( 'whisper_model': None, 'whisper_processor': None, 'feature_extractor': None, + 'ecapa2_model': None, } if language == "en": @@ -314,6 +377,16 @@ def load_evaluation_models( ) models['sv_model_alternate'] = models['sv_model_alternate'].to(device).eval() + if with_ecapa2: + if HF_HUB_AVAILABLE: + logging.info("Loading ECAPA2 model...") + try: + models['ecapa2_model'] = load_ecapa2_model(device=device, cache_dir=ecapa2_cache_dir) + except Exception as e: + logging.warning(f"ECAPA2 model could not be loaded: {e}. ECAPA2 metrics will be set to NaN.") + else: + logging.warning("ECAPA2 requested but huggingface_hub is not available. ECAPA2 metrics will be set to NaN.") + return models @@ -345,6 +418,9 @@ def evaluate_dir( sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", with_utmosv2=True, + normalize_volume=False, + with_ecapa2=False, + ecapa2_cache_dir=None, asr_batch_size=32, eou_batch_size=32, device="cuda", @@ -375,9 +451,36 @@ def evaluate_dir( # Resolve ground-truth and context audio paths for all records gt_audio_paths = [_resolve_path(audio_dir, r.get('audio_filepath')) for r in records] context_audio_paths = [_resolve_path(audio_dir, r.get('context_audio_filepath')) for r in records] + original_audio_file_lists = list(audio_file_lists) + original_gt_audio_paths = list(gt_audio_paths) + original_context_audio_paths = list(context_audio_paths) + generated_audio_dir_for_metrics = generated_audio_dir + normalization_temp_dir = None + + if normalize_volume: + logging.info("Normalizing audio volume before metric computation...") + normalization_temp_dir = tempfile.TemporaryDirectory() + normalized_generated_audio_dir = os.path.join(normalization_temp_dir.name, "generated_audio") + audio_file_lists = _normalize_audio_paths_for_metrics( + audio_file_lists, normalized_generated_audio_dir, "predicted_audio" + ) + gt_audio_paths = _normalize_audio_paths_for_metrics( + gt_audio_paths, os.path.join(normalization_temp_dir.name, "gt_audio"), "gt_audio" + ) + context_audio_paths = _normalize_audio_paths_for_metrics( + context_audio_paths, os.path.join(normalization_temp_dir.name, "context_audio"), "context_audio" + ) + generated_audio_dir_for_metrics = normalized_generated_audio_dir # 2. Load models - models = load_evaluation_models(language, sv_model_type, asr_model_name, device) + models = load_evaluation_models( + language, + sv_model_type, + asr_model_name, + device, + with_ecapa2=with_ecapa2, + ecapa2_cache_dir=ecapa2_cache_dir, + ) asr_model = models['asr_model'] whisper_model = models['whisper_model'] @@ -385,6 +488,7 @@ def evaluate_dir( feature_extractor = models['feature_extractor'] speaker_verification_model = models['sv_model'] speaker_verification_model_alternate = models['sv_model_alternate'] + ecapa2_model = models['ecapa2_model'] # 3. EoU classifier (support for English only) if language == "en": @@ -403,7 +507,7 @@ def evaluate_dir( "UTMOSv2 was requested (with_utmosv2=True) but the UTMOSv2 library is not available. " "UTMOSv2 scores will be set to NaN for all files." ) - utmosv2_scores = compute_utmosv2_scores(generated_audio_dir, device) + utmosv2_scores = compute_utmosv2_scores(generated_audio_dir_for_metrics, device) # 5. ASR transcription in batches logging.info(f"Doing batched ASR transcription with batch size {asr_batch_size}...") @@ -496,6 +600,13 @@ def evaluate_dir( device=device, sv_model_type=sv_model_type, ) + extract_ecapa2_embedding_fn = None + if ecapa2_model is not None: + extract_ecapa2_embedding_fn = partial( + extract_ecapa2_embedding, + model=ecapa2_model, + device=device, + ) # Initialize SSIMs with a default since the context or ground truth audio # may be unavailable. @@ -505,6 +616,14 @@ def evaluate_dir( gt_context_ssim_alternate = float('NaN') pred_gt_ssim = float('NaN') pred_gt_ssim_alternate = float('NaN') + pred_context_ssim_ecapa2 = float('NaN') + gt_context_ssim_ecapa2 = float('NaN') + pred_gt_ssim_ecapa2 = float('NaN') + pred_speaker_embedding_ecapa2 = None + gt_speaker_embedding_ecapa2 = None + + if extract_ecapa2_embedding_fn is not None and (gt_audio_filepath is not None or context_audio_filepath is not None): + pred_speaker_embedding_ecapa2 = extract_ecapa2_embedding_fn(audio_path=pred_audio_filepath) if gt_audio_filepath is not None: # Ground truth vs. predicted @@ -521,6 +640,12 @@ def evaluate_dir( gt_speaker_embedding_alternate, pred_speaker_embedding_alternate, dim=0 ).item() + if extract_ecapa2_embedding_fn is not None: + gt_speaker_embedding_ecapa2 = extract_ecapa2_embedding_fn(audio_path=gt_audio_filepath) + pred_gt_ssim_ecapa2 = torch.nn.functional.cosine_similarity( + gt_speaker_embedding_ecapa2, pred_speaker_embedding_ecapa2, dim=0 + ).item() + if context_audio_filepath is not None: context_speaker_embedding = extract_embedding_fn(audio_path=context_audio_filepath) context_speaker_embedding_alternate = extract_embedding_fn_alternate(audio_path=context_audio_filepath) @@ -544,6 +669,16 @@ def evaluate_dir( gt_context_ssim_alternate = torch.nn.functional.cosine_similarity( gt_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0 ).item() + + if extract_ecapa2_embedding_fn is not None: + context_speaker_embedding_ecapa2 = extract_ecapa2_embedding_fn(audio_path=context_audio_filepath) + pred_context_ssim_ecapa2 = torch.nn.functional.cosine_similarity( + pred_speaker_embedding_ecapa2, context_speaker_embedding_ecapa2, dim=0 + ).item() + if gt_audio_filepath is not None: + gt_context_ssim_ecapa2 = torch.nn.functional.cosine_similarity( + gt_speaker_embedding_ecapa2, context_speaker_embedding_ecapa2, dim=0 + ).item() file_duration = get_wav_file_duration(pred_audio_filepath) total_generated_audio_seconds += file_duration @@ -576,9 +711,12 @@ def evaluate_dir( 'pred_gt_ssim_alternate': pred_gt_ssim_alternate, 'pred_context_ssim_alternate': pred_context_ssim_alternate, 'gt_context_ssim_alternate': gt_context_ssim_alternate, - 'gt_audio_filepath': gt_audio_filepath, - 'pred_audio_filepath': pred_audio_filepath, - 'context_audio_filepath': context_audio_filepath, + 'pred_gt_ssim_ecapa2': pred_gt_ssim_ecapa2, + 'pred_context_ssim_ecapa2': pred_context_ssim_ecapa2, + 'gt_context_ssim_ecapa2': gt_context_ssim_ecapa2, + 'gt_audio_filepath': original_gt_audio_paths[ridx], + 'pred_audio_filepath': original_audio_file_lists[ridx], + 'context_audio_filepath': original_context_audio_paths[ridx], 'utmosv2': utmosv2_score, 'eou_type': eou_type, 'eou_trailing_duration': eou_trailing, @@ -588,6 +726,9 @@ def evaluate_dir( } ) + if normalization_temp_dir is not None: + normalization_temp_dir.cleanup() + return filewise_metrics @@ -601,6 +742,9 @@ def evaluate( with_utmosv2=True, with_fcd=True, codec_model_path=None, + normalize_volume=False, + with_ecapa2=False, + ecapa2_cache_dir=None, asr_batch_size=32, eou_batch_size=32, device="cuda", @@ -636,6 +780,9 @@ def evaluate( sv_model_type=sv_model_type, asr_model_name=asr_model_name, with_utmosv2=with_utmosv2, + normalize_volume=normalize_volume, + with_ecapa2=with_ecapa2, + ecapa2_cache_dir=ecapa2_cache_dir, asr_batch_size=asr_batch_size, eou_batch_size=eou_batch_size, device=device, @@ -725,6 +872,11 @@ def compute_global_metrics( sum(m['pred_context_ssim_alternate'] for m in filewise_metrics) / n ) avg_metrics['ssim_gt_context_avg_alternate'] = sum(m['gt_context_ssim_alternate'] for m in filewise_metrics) / n + avg_metrics['ssim_pred_gt_avg_ecapa2'] = sum(m.get('pred_gt_ssim_ecapa2', float('nan')) for m in filewise_metrics) / n + avg_metrics['ssim_pred_context_avg_ecapa2'] = ( + sum(m.get('pred_context_ssim_ecapa2', float('nan')) for m in filewise_metrics) / n + ) + avg_metrics['ssim_gt_context_avg_ecapa2'] = sum(m.get('gt_context_ssim_ecapa2', float('nan')) for m in filewise_metrics) / n # Cumulative WER/CER on ground-truth audio transcriptions (if available) gt_audio_texts = [m['gt_audio_text'] for m in filewise_metrics] @@ -778,6 +930,9 @@ def main(): parser.add_argument('--generated_audio_dir', type=str, default=None) parser.add_argument('--whisper_language', type=str, default="en") parser.add_argument('--evalset', type=str, default=None) + parser.add_argument('--normalize_volume', action='store_true') + parser.add_argument('--use_ecapa2', action='store_true', help='Compute ECAPA2 speaker similarity metrics') + parser.add_argument('--ecapa2_cache_dir', type=str, default=None) args = parser.parse_args() if args.evalset is not None: @@ -793,6 +948,9 @@ def main(): args.whisper_language, sv_model_type="wavlm", asr_model_name="nvidia/parakeet-ctc-0.6b", + normalize_volume=args.normalize_volume, + with_ecapa2=args.use_ecapa2, + ecapa2_cache_dir=args.ecapa2_cache_dir, ) diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluation.py b/nemo/collections/tts/modules/magpietts_inference/evaluation.py index bb9013cc9ff1..c9c2b5b7e52c 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluation.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluation.py @@ -43,6 +43,9 @@ class EvaluationConfig: with_utmosv2: Whether to compute UTMOSv2 (Mean Opinion Score) metrics. with_fcd: Whether to compute Frechet Codec Distance metric. codec_model_path: Path to the audio codec model. If None, will skip computing Frechet Codec Distance metric. + normalize_volume: Whether to apply normalize_volume before computing evaluation metrics. + with_ecapa2: Whether to compute ECAPA2 speaker similarity metrics. Disabled by default. + ecapa2_cache_dir: Optional Hugging Face cache directory for the ECAPA2 checkpoint. device: Device to use for running models used during evaluation. """ @@ -53,6 +56,9 @@ class EvaluationConfig: with_utmosv2: bool = True with_fcd: bool = True codec_model_path: str = None + normalize_volume: bool = False + with_ecapa2: bool = False + ecapa2_cache_dir: str = None device: str = "cuda" asr_batch_size: int = 32 eou_batch_size: int = 32 @@ -95,6 +101,9 @@ def evaluate_generated_audio_dir( with_utmosv2=config.with_utmosv2, with_fcd=config.with_fcd, codec_model_path=config.codec_model_path, + normalize_volume=config.normalize_volume, + with_ecapa2=config.with_ecapa2, + ecapa2_cache_dir=config.ecapa2_cache_dir, device=config.device, eou_model_name=config.eou_model_name, asr_batch_size=config.asr_batch_size, From 23a44f27d4fc29600258c32d4a3ebaacc4266eb2 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 16 Jun 2026 05:31:53 -0700 Subject: [PATCH 086/109] Revert "Add ECAPA2 SSIM" This reverts commit cea3a1a5dd8b601908430bcef0c529710beb339e. Signed-off-by: Edresson Casanova --- examples/tts/magpietts_inference.py | 31 +--- .../evaluate_generated_audio.py | 172 +----------------- .../modules/magpietts_inference/evaluation.py | 9 - 3 files changed, 8 insertions(+), 204 deletions(-) diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index 0a072d35dae3..7e2458ceaeea 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -129,9 +129,6 @@ def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, met metrics.get('ssim_pred_gt_avg_alternate', ''), metrics.get('ssim_pred_context_avg_alternate', ''), metrics.get('ssim_gt_context_avg_alternate', ''), - metrics.get('ssim_pred_gt_avg_ecapa2', ''), - metrics.get('ssim_pred_context_avg_ecapa2', ''), - metrics.get('ssim_gt_context_avg_ecapa2', ''), metrics.get('cer_gt_audio_cumulative', ''), metrics.get('wer_gt_audio_cumulative', ''), metrics.get('utmosv2_avg', ''), @@ -244,9 +241,6 @@ def turn_sort_key(r): pred_context_ssim_turns = [r.get("pred_context_ssim") for r in turns] pred_gt_ssim_turns = [r.get("pred_gt_ssim") for r in turns] gt_context_ssim_turns = [r.get("gt_context_ssim") for r in turns] - pred_context_ssim_ecapa2_turns = [r.get("pred_context_ssim_ecapa2") for r in turns] - pred_gt_ssim_ecapa2_turns = [r.get("pred_gt_ssim_ecapa2") for r in turns] - gt_context_ssim_ecapa2_turns = [r.get("gt_context_ssim_ecapa2") for r in turns] utmosv2_turns = [r.get("utmosv2") for r in turns] eou_type_turns = [r.get("eou_type") for r in turns] eou_trailing_duration_turns = [r.get("eou_trailing_duration") for r in turns] @@ -264,9 +258,6 @@ def turn_sort_key(r): "pred_context_ssim": _mean_finite(pred_context_ssim_turns), "pred_gt_ssim": _mean_finite(pred_gt_ssim_turns), "gt_context_ssim": _mean_finite(gt_context_ssim_turns), - "pred_context_ssim_ecapa2": _mean_finite(pred_context_ssim_ecapa2_turns), - "pred_gt_ssim_ecapa2": _mean_finite(pred_gt_ssim_ecapa2_turns), - "gt_context_ssim_ecapa2": _mean_finite(gt_context_ssim_ecapa2_turns), "utmosv2": _mean_finite(utmosv2_turns), "eou_trailing_duration": _mean_finite(eou_trailing_duration_turns), "eou_trail_rms_ratio": _mean_finite(eou_trail_rms_ratio_turns), @@ -277,9 +268,6 @@ def turn_sort_key(r): "pred_context_ssim_turns": pred_context_ssim_turns, "pred_gt_ssim_turns": pred_gt_ssim_turns, "gt_context_ssim_turns": gt_context_ssim_turns, - "pred_context_ssim_ecapa2_turns": pred_context_ssim_ecapa2_turns, - "pred_gt_ssim_ecapa2_turns": pred_gt_ssim_ecapa2_turns, - "gt_context_ssim_ecapa2_turns": gt_context_ssim_ecapa2_turns, "utmosv2_turns": utmosv2_turns, "eou_type_turns": eou_type_turns, "eou_trailing_duration_turns": eou_trailing_duration_turns, @@ -314,9 +302,6 @@ def _write_grouped_multiturn_filewise_metrics_csv(csv_path: str, grouped_rows: l "pred_context_ssim", "pred_gt_ssim", "gt_context_ssim", - "pred_context_ssim_ecapa2", - "pred_gt_ssim_ecapa2", - "gt_context_ssim_ecapa2", "utmosv2", "eou_trailing_duration", "eou_trail_rms_ratio", @@ -326,9 +311,6 @@ def _write_grouped_multiturn_filewise_metrics_csv(csv_path: str, grouped_rows: l "pred_context_ssim_turns", "pred_gt_ssim_turns", "gt_context_ssim_turns", - "pred_context_ssim_ecapa2_turns", - "pred_gt_ssim_ecapa2_turns", - "gt_context_ssim_ecapa2_turns", "utmosv2_turns", "eou_type_turns", "eou_trailing_duration_turns", @@ -612,9 +594,7 @@ def run_inference_and_evaluation( "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative," "wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg," "ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate," - "ssim_gt_context_avg_alternate,ssim_pred_gt_avg_ecapa2," - "ssim_pred_context_avg_ecapa2,ssim_gt_context_avg_ecapa2," - "cer_gt_audio_cumulative,wer_gt_audio_cumulative," + "ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative," "utmosv2_avg,total_gen_audio_seconds,frechet_codec_distance," "eou_cutoff_rate,eou_silence_rate,eou_noise_rate,eou_error_rate" ) @@ -731,9 +711,6 @@ def run_inference_and_evaluation( with_utmosv2=eval_config.with_utmosv2, with_fcd=eval_config.with_fcd, codec_model_path=eval_config.codec_model_path, - normalize_volume=eval_config.normalize_volume, - with_ecapa2=eval_config.with_ecapa2, - ecapa2_cache_dir=eval_config.ecapa2_cache_dir, device=eval_config.device, asr_batch_size=eval_config.asr_batch_size, eou_batch_size=eval_config.eou_batch_size, @@ -978,9 +955,6 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None: eval_group.add_argument('--num_repeats', type=int, default=1) eval_group.add_argument('--confidence_level', type=float, default=0.95) eval_group.add_argument('--disable_utmosv2', action='store_true') - eval_group.add_argument('--normalize_volume', action='store_true') - eval_group.add_argument('--use_ecapa2', action='store_true', help='Compute ECAPA2 speaker similarity metrics') - eval_group.add_argument('--ecapa2_cache_dir', type=str, default=None) eval_group.add_argument( '--violin_plot_metrics', type=str, @@ -1181,9 +1155,6 @@ def main(argv=None): with_utmosv2=not args.disable_utmosv2, with_fcd=not args.disable_fcd, codec_model_path=args.codecmodel_path if not args.disable_fcd else None, - normalize_volume=args.normalize_volume, - with_ecapa2=args.use_ecapa2, - ecapa2_cache_dir=args.ecapa2_cache_dir, asr_batch_size=args.asr_batch_size, eou_batch_size=args.eou_batch_size, ) diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py index 08eaa34cef70..9d2540a694ff 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py @@ -35,7 +35,7 @@ from nemo.collections.asr.metrics.wer import word_error_rate_detail from nemo.collections.tts.metrics.eou_classifier import EoUClassification, EoUClassifier, EoUType from nemo.collections.tts.metrics.frechet_codec_distance import FrechetCodecDistance -from nemo.collections.tts.parts.utils.tts_dataset_utils import get_text_processor, normalize_volume as normalize_audio_volume +from nemo.collections.tts.parts.utils.tts_dataset_utils import get_text_processor from nemo.utils import logging # Optional import for UTMOSv2 (audio quality metric) @@ -51,25 +51,11 @@ "To install utmosv2 run `pip install git+https://github.com/sarulab-speech/UTMOSv2.git@v1.2.1`." ) -try: - from huggingface_hub import hf_hub_download - - HF_HUB_AVAILABLE = True -except (ImportError, ModuleNotFoundError) as e: - HF_HUB_AVAILABLE = False - logging.warning( - f"huggingface_hub not available: {e}. " - "ECAPA2 speaker similarity metrics will be disabled." - ) - FILEWISE_METRICS_TO_SAVE = [ 'cer', 'wer', 'pred_context_ssim', - 'pred_gt_ssim_ecapa2', - 'pred_context_ssim_ecapa2', - 'gt_context_ssim_ecapa2', 'pred_text', 'gt_text', 'gt_audio_filepath', @@ -223,28 +209,6 @@ def pad_audio_to_min_length(audio_np: np.ndarray, sampling_rate: int, min_second return audio_np -def _normalize_audio_for_metrics(audio_path: Optional[str], output_path: str) -> Optional[str]: - if audio_path is None or audio_path == "": - return None - - audio_np, sampling_rate = sf.read(audio_path, dtype="float32", always_2d=False) - if audio_np.ndim == 2: - audio_np = audio_np.mean(axis=1) - audio_np = normalize_audio_volume(audio_np.flatten().astype(np.float32)) - audio_np = np.asarray(audio_np, dtype=np.float32) - - os.makedirs(os.path.dirname(output_path), exist_ok=True) - sf.write(output_path, audio_np, sampling_rate) - return output_path - - -def _normalize_audio_paths_for_metrics(audio_paths: list, output_dir: str, prefix: str) -> list: - return [ - _normalize_audio_for_metrics(audio_path, os.path.join(output_dir, f"{prefix}_{idx}.wav")) - for idx, audio_path in enumerate(audio_paths) - ] - - def extract_embedding(model, extractor, audio_path, device, sv_model_type): speech_array, sampling_rate = librosa.load(audio_path, sr=16000) # pad to 0.5 seconds as the extractor may not be able to handle very short signals @@ -263,27 +227,6 @@ def extract_embedding(model, extractor, audio_path, device, sv_model_type): return embeddings.squeeze() -def load_ecapa2_model(device, cache_dir=None): - if not HF_HUB_AVAILABLE: - raise ImportError("huggingface_hub is required to download/load ECAPA2") - - # automatically checks for cached file, optionally set `cache_dir` location - ecapa2_file = hf_hub_download(repo_id='Jenthe/ECAPA2', filename='ecapa2.pt', cache_dir=cache_dir) - return torch.jit.load(ecapa2_file, map_location='cpu').to(device).eval() - - -def extract_ecapa2_embedding(model, audio_path, device, model_sr=16000): - speech_array, sampling_rate = librosa.load(audio_path, sr=model_sr) - # pad to 0.5 seconds as the extractor may not be able to handle very short signals - speech_array = pad_audio_to_min_length(speech_array, int(sampling_rate), min_seconds=0.5) - audio = torch.from_numpy(speech_array).float().unsqueeze(0).to(device) - - with torch.inference_mode(): - embeddings = model(audio) - embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) - return embeddings.cpu().detach().squeeze() - - def compute_utmosv2_scores(audio_dir, device): if not UTMOSV2_AVAILABLE: logging.warning("UTMOSv2Calculator not available. Skipping UTMOSv2 score computation.") @@ -322,12 +265,7 @@ def transcribed_batched( def load_evaluation_models( - language="en", - sv_model_type="titanet", - asr_model_name="stt_en_conformer_transducer_large", - device="cuda", - with_ecapa2=False, - ecapa2_cache_dir=None, + language="en", sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", device="cuda" ): """Load ASR and speaker verification models used for evaluation. @@ -346,7 +284,6 @@ def load_evaluation_models( 'whisper_model': None, 'whisper_processor': None, 'feature_extractor': None, - 'ecapa2_model': None, } if language == "en": @@ -377,16 +314,6 @@ def load_evaluation_models( ) models['sv_model_alternate'] = models['sv_model_alternate'].to(device).eval() - if with_ecapa2: - if HF_HUB_AVAILABLE: - logging.info("Loading ECAPA2 model...") - try: - models['ecapa2_model'] = load_ecapa2_model(device=device, cache_dir=ecapa2_cache_dir) - except Exception as e: - logging.warning(f"ECAPA2 model could not be loaded: {e}. ECAPA2 metrics will be set to NaN.") - else: - logging.warning("ECAPA2 requested but huggingface_hub is not available. ECAPA2 metrics will be set to NaN.") - return models @@ -418,9 +345,6 @@ def evaluate_dir( sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", with_utmosv2=True, - normalize_volume=False, - with_ecapa2=False, - ecapa2_cache_dir=None, asr_batch_size=32, eou_batch_size=32, device="cuda", @@ -451,36 +375,9 @@ def evaluate_dir( # Resolve ground-truth and context audio paths for all records gt_audio_paths = [_resolve_path(audio_dir, r.get('audio_filepath')) for r in records] context_audio_paths = [_resolve_path(audio_dir, r.get('context_audio_filepath')) for r in records] - original_audio_file_lists = list(audio_file_lists) - original_gt_audio_paths = list(gt_audio_paths) - original_context_audio_paths = list(context_audio_paths) - generated_audio_dir_for_metrics = generated_audio_dir - normalization_temp_dir = None - - if normalize_volume: - logging.info("Normalizing audio volume before metric computation...") - normalization_temp_dir = tempfile.TemporaryDirectory() - normalized_generated_audio_dir = os.path.join(normalization_temp_dir.name, "generated_audio") - audio_file_lists = _normalize_audio_paths_for_metrics( - audio_file_lists, normalized_generated_audio_dir, "predicted_audio" - ) - gt_audio_paths = _normalize_audio_paths_for_metrics( - gt_audio_paths, os.path.join(normalization_temp_dir.name, "gt_audio"), "gt_audio" - ) - context_audio_paths = _normalize_audio_paths_for_metrics( - context_audio_paths, os.path.join(normalization_temp_dir.name, "context_audio"), "context_audio" - ) - generated_audio_dir_for_metrics = normalized_generated_audio_dir # 2. Load models - models = load_evaluation_models( - language, - sv_model_type, - asr_model_name, - device, - with_ecapa2=with_ecapa2, - ecapa2_cache_dir=ecapa2_cache_dir, - ) + models = load_evaluation_models(language, sv_model_type, asr_model_name, device) asr_model = models['asr_model'] whisper_model = models['whisper_model'] @@ -488,7 +385,6 @@ def evaluate_dir( feature_extractor = models['feature_extractor'] speaker_verification_model = models['sv_model'] speaker_verification_model_alternate = models['sv_model_alternate'] - ecapa2_model = models['ecapa2_model'] # 3. EoU classifier (support for English only) if language == "en": @@ -507,7 +403,7 @@ def evaluate_dir( "UTMOSv2 was requested (with_utmosv2=True) but the UTMOSv2 library is not available. " "UTMOSv2 scores will be set to NaN for all files." ) - utmosv2_scores = compute_utmosv2_scores(generated_audio_dir_for_metrics, device) + utmosv2_scores = compute_utmosv2_scores(generated_audio_dir, device) # 5. ASR transcription in batches logging.info(f"Doing batched ASR transcription with batch size {asr_batch_size}...") @@ -600,13 +496,6 @@ def evaluate_dir( device=device, sv_model_type=sv_model_type, ) - extract_ecapa2_embedding_fn = None - if ecapa2_model is not None: - extract_ecapa2_embedding_fn = partial( - extract_ecapa2_embedding, - model=ecapa2_model, - device=device, - ) # Initialize SSIMs with a default since the context or ground truth audio # may be unavailable. @@ -616,14 +505,6 @@ def evaluate_dir( gt_context_ssim_alternate = float('NaN') pred_gt_ssim = float('NaN') pred_gt_ssim_alternate = float('NaN') - pred_context_ssim_ecapa2 = float('NaN') - gt_context_ssim_ecapa2 = float('NaN') - pred_gt_ssim_ecapa2 = float('NaN') - pred_speaker_embedding_ecapa2 = None - gt_speaker_embedding_ecapa2 = None - - if extract_ecapa2_embedding_fn is not None and (gt_audio_filepath is not None or context_audio_filepath is not None): - pred_speaker_embedding_ecapa2 = extract_ecapa2_embedding_fn(audio_path=pred_audio_filepath) if gt_audio_filepath is not None: # Ground truth vs. predicted @@ -640,12 +521,6 @@ def evaluate_dir( gt_speaker_embedding_alternate, pred_speaker_embedding_alternate, dim=0 ).item() - if extract_ecapa2_embedding_fn is not None: - gt_speaker_embedding_ecapa2 = extract_ecapa2_embedding_fn(audio_path=gt_audio_filepath) - pred_gt_ssim_ecapa2 = torch.nn.functional.cosine_similarity( - gt_speaker_embedding_ecapa2, pred_speaker_embedding_ecapa2, dim=0 - ).item() - if context_audio_filepath is not None: context_speaker_embedding = extract_embedding_fn(audio_path=context_audio_filepath) context_speaker_embedding_alternate = extract_embedding_fn_alternate(audio_path=context_audio_filepath) @@ -669,16 +544,6 @@ def evaluate_dir( gt_context_ssim_alternate = torch.nn.functional.cosine_similarity( gt_speaker_embedding_alternate, context_speaker_embedding_alternate, dim=0 ).item() - - if extract_ecapa2_embedding_fn is not None: - context_speaker_embedding_ecapa2 = extract_ecapa2_embedding_fn(audio_path=context_audio_filepath) - pred_context_ssim_ecapa2 = torch.nn.functional.cosine_similarity( - pred_speaker_embedding_ecapa2, context_speaker_embedding_ecapa2, dim=0 - ).item() - if gt_audio_filepath is not None: - gt_context_ssim_ecapa2 = torch.nn.functional.cosine_similarity( - gt_speaker_embedding_ecapa2, context_speaker_embedding_ecapa2, dim=0 - ).item() file_duration = get_wav_file_duration(pred_audio_filepath) total_generated_audio_seconds += file_duration @@ -711,12 +576,9 @@ def evaluate_dir( 'pred_gt_ssim_alternate': pred_gt_ssim_alternate, 'pred_context_ssim_alternate': pred_context_ssim_alternate, 'gt_context_ssim_alternate': gt_context_ssim_alternate, - 'pred_gt_ssim_ecapa2': pred_gt_ssim_ecapa2, - 'pred_context_ssim_ecapa2': pred_context_ssim_ecapa2, - 'gt_context_ssim_ecapa2': gt_context_ssim_ecapa2, - 'gt_audio_filepath': original_gt_audio_paths[ridx], - 'pred_audio_filepath': original_audio_file_lists[ridx], - 'context_audio_filepath': original_context_audio_paths[ridx], + 'gt_audio_filepath': gt_audio_filepath, + 'pred_audio_filepath': pred_audio_filepath, + 'context_audio_filepath': context_audio_filepath, 'utmosv2': utmosv2_score, 'eou_type': eou_type, 'eou_trailing_duration': eou_trailing, @@ -726,9 +588,6 @@ def evaluate_dir( } ) - if normalization_temp_dir is not None: - normalization_temp_dir.cleanup() - return filewise_metrics @@ -742,9 +601,6 @@ def evaluate( with_utmosv2=True, with_fcd=True, codec_model_path=None, - normalize_volume=False, - with_ecapa2=False, - ecapa2_cache_dir=None, asr_batch_size=32, eou_batch_size=32, device="cuda", @@ -780,9 +636,6 @@ def evaluate( sv_model_type=sv_model_type, asr_model_name=asr_model_name, with_utmosv2=with_utmosv2, - normalize_volume=normalize_volume, - with_ecapa2=with_ecapa2, - ecapa2_cache_dir=ecapa2_cache_dir, asr_batch_size=asr_batch_size, eou_batch_size=eou_batch_size, device=device, @@ -872,11 +725,6 @@ def compute_global_metrics( sum(m['pred_context_ssim_alternate'] for m in filewise_metrics) / n ) avg_metrics['ssim_gt_context_avg_alternate'] = sum(m['gt_context_ssim_alternate'] for m in filewise_metrics) / n - avg_metrics['ssim_pred_gt_avg_ecapa2'] = sum(m.get('pred_gt_ssim_ecapa2', float('nan')) for m in filewise_metrics) / n - avg_metrics['ssim_pred_context_avg_ecapa2'] = ( - sum(m.get('pred_context_ssim_ecapa2', float('nan')) for m in filewise_metrics) / n - ) - avg_metrics['ssim_gt_context_avg_ecapa2'] = sum(m.get('gt_context_ssim_ecapa2', float('nan')) for m in filewise_metrics) / n # Cumulative WER/CER on ground-truth audio transcriptions (if available) gt_audio_texts = [m['gt_audio_text'] for m in filewise_metrics] @@ -930,9 +778,6 @@ def main(): parser.add_argument('--generated_audio_dir', type=str, default=None) parser.add_argument('--whisper_language', type=str, default="en") parser.add_argument('--evalset', type=str, default=None) - parser.add_argument('--normalize_volume', action='store_true') - parser.add_argument('--use_ecapa2', action='store_true', help='Compute ECAPA2 speaker similarity metrics') - parser.add_argument('--ecapa2_cache_dir', type=str, default=None) args = parser.parse_args() if args.evalset is not None: @@ -948,9 +793,6 @@ def main(): args.whisper_language, sv_model_type="wavlm", asr_model_name="nvidia/parakeet-ctc-0.6b", - normalize_volume=args.normalize_volume, - with_ecapa2=args.use_ecapa2, - ecapa2_cache_dir=args.ecapa2_cache_dir, ) diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluation.py b/nemo/collections/tts/modules/magpietts_inference/evaluation.py index c9c2b5b7e52c..bb9013cc9ff1 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluation.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluation.py @@ -43,9 +43,6 @@ class EvaluationConfig: with_utmosv2: Whether to compute UTMOSv2 (Mean Opinion Score) metrics. with_fcd: Whether to compute Frechet Codec Distance metric. codec_model_path: Path to the audio codec model. If None, will skip computing Frechet Codec Distance metric. - normalize_volume: Whether to apply normalize_volume before computing evaluation metrics. - with_ecapa2: Whether to compute ECAPA2 speaker similarity metrics. Disabled by default. - ecapa2_cache_dir: Optional Hugging Face cache directory for the ECAPA2 checkpoint. device: Device to use for running models used during evaluation. """ @@ -56,9 +53,6 @@ class EvaluationConfig: with_utmosv2: bool = True with_fcd: bool = True codec_model_path: str = None - normalize_volume: bool = False - with_ecapa2: bool = False - ecapa2_cache_dir: str = None device: str = "cuda" asr_batch_size: int = 32 eou_batch_size: int = 32 @@ -101,9 +95,6 @@ def evaluate_generated_audio_dir( with_utmosv2=config.with_utmosv2, with_fcd=config.with_fcd, codec_model_path=config.codec_model_path, - normalize_volume=config.normalize_volume, - with_ecapa2=config.with_ecapa2, - ecapa2_cache_dir=config.ecapa2_cache_dir, device=config.device, eou_model_name=config.eou_model_name, asr_batch_size=config.asr_batch_size, From 94af99efe49bc521afc7a2f34222f95e64fd3a25 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 18 Jun 2026 08:09:46 -0700 Subject: [PATCH 087/109] Add Speaker filter Signed-off-by: Edresson Casanova --- .../common/data/lhotse/dataloader.py | 18 +++++ .../common/data/lhotse/sampling.py | 39 +++++++++++ .../evaluate_generated_audio.py | 2 +- .../common/test_lhotse_tts_filters.py | 68 +++++++++++++++++++ 4 files changed, 126 insertions(+), 1 deletion(-) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 5af5f5d004d7..f99d615d61b4 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -56,6 +56,7 @@ FixedBucketBatchSizeConstraint2D, MultimodalFixedBucketBatchSizeConstraint2D, MultimodalSamplingConstraint, + SpeakerFilter, TokenCountFilter, TokenPerSecondFilter, TokenPerTokenFilter, @@ -149,6 +150,8 @@ class LhotseDataLoadingConfig: # 2.3 Filters on CER and/or cosine speaker similarity of the context audio serving for TTS use cases. max_cer: float | None = float("inf") min_context_speaker_similarity: float | None = -1 + excluded_speaker_ids: Any = None + speaker_filter_fields: list[str] | None = None # 2.4 Filters on validation status. If the validation status is not "pass", the cut will be filtered out. keep: str = "pass" @@ -255,6 +258,19 @@ class LhotseDataLoadingConfig: # our support of object stores and gzipped files that generally don't have indexes of byte offsets per line. slice_length: Optional[int] = None +def resolve_excluded_speaker_ids(excluded_speaker_ids): + if excluded_speaker_ids is None: + return None + + if isinstance(excluded_speaker_ids, str): + excluded_speaker_ids = OmegaConf.load(excluded_speaker_ids) + + excluded_speaker_ids = OmegaConf.to_container(excluded_speaker_ids, resolve=True) + + if isinstance(excluded_speaker_ids, dict): + excluded_speaker_ids = excluded_speaker_ids["excluded_speaker_ids"] + + return excluded_speaker_ids def determine_use_iterable_dataset(use_iterable_dataset: bool, config: DictConfig) -> bool: """Determine whether to use iterable dataset for a given configuration.""" @@ -612,6 +628,8 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No # validation status filtering cuts = cuts.filter(ValidationStatusFilter(config.keep)) + # Exclude cuts that contain known test speakers. + cuts = cuts.filter(SpeakerFilter(resolve_excluded_speaker_ids(config.excluded_speaker_ids), speaker_fields=config.speaker_filter_fields)) # CER filtering, same as native NeMo dataloaders. cuts = cuts.filter(CERFilter(config.max_cer)) # Context speaker similarity filtering, same as native NeMo dataloaders. diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index 0c927058c234..e73039d252f9 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -308,6 +308,45 @@ def __call__(self, example) -> bool: return True +class SpeakerFilter: + """ + Callable, returns ``False`` if any supervision in a cut belongs to an excluded speaker. + Checks configured supervision attributes/custom fields. + Acts as a pass-through for objects of other type than Cut. + """ + + def __init__( + self, + excluded_speaker_ids: Sequence[str] | None = None, + speaker_fields: Sequence[str] = None, + ) -> None: + self.excluded_speaker_ids = set(ifnone(excluded_speaker_ids, ())) + self.enabled = len(self.excluded_speaker_ids) > 0 + if self.enabled and speaker_fields is None: + raise ValueError( + "SpeakerFilter requires speaker_fields when excluded_speaker_ids is set. " + "Example: speaker_filter_fields=[speaker_id]" + ) + self.speaker_fields = tuple(ifnone(speaker_fields, ())) + + def __call__(self, example) -> bool: + if not self.enabled or not isinstance(example, Cut): + return True + + excluded_speaker_ids = self.excluded_speaker_ids + + for supervision in example.supervisions: + for field in self.speaker_fields: + if supervision.has_custom(field): + speaker_id = getattr(supervision, field) + else: + speaker_id = getattr(supervision, field, None) + + if speaker_id in excluded_speaker_ids: + return False + return True + + class ContextSpeakerSimilarityFilter: """ Callable, returns ``True`` if a cut's context speaker similarity is greater than min_context_speaker_similarity and ``False`` otherwise. diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py index 9d2540a694ff..42640dee124f 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py @@ -494,7 +494,7 @@ def evaluate_dir( model=speaker_verification_model_alternate, extractor=feature_extractor, device=device, - sv_model_type=sv_model_type, + sv_model_type="titanet", # alternate is always titanet ) # Initialize SSIMs with a default since the context or ground truth audio diff --git a/tests/collections/common/test_lhotse_tts_filters.py b/tests/collections/common/test_lhotse_tts_filters.py index 877b0886ab97..e03afa7a73da 100644 --- a/tests/collections/common/test_lhotse_tts_filters.py +++ b/tests/collections/common/test_lhotse_tts_filters.py @@ -19,6 +19,7 @@ from nemo.collections.common.data.lhotse.sampling import ( CERFilter, + SpeakerFilter, ContextSpeakerSimilarityFilter, ValidationStatusFilter, ) @@ -141,3 +142,70 @@ def test_cut_validation_status_filter(cut_example): f = ValidationStatusFilter("any_other_status") assert f(cut_example) == False + +def test_cut_speaker_filter_by_speaker(cut_example): + f = SpeakerFilter( + excluded_speaker_ids=["| Language:en Dataset:nvyt2505 Speaker:Zdud2gXLTXY_SPEAKER_02 |"], + speaker_fields=["speaker"], + ) + assert f(cut_example) == False + + f = SpeakerFilter( + excluded_speaker_ids=["some_other_speaker"], + speaker_fields=["speaker"], + ) + assert f(cut_example) == True + + +def test_cut_speaker_filter_by_custom_speaker_id(cut_example): + cut_example.supervisions[0].custom["speaker_id"] = "test_speaker_001" + + f = SpeakerFilter( + excluded_speaker_ids=["test_speaker_001"], + speaker_fields=["speaker_id"], + ) + assert f(cut_example) == False + + f = SpeakerFilter( + excluded_speaker_ids=["some_other_speaker"], + speaker_fields=["speaker_id"], + ) + assert f(cut_example) == True + + +def test_cut_speaker_filter_multiple_fields(cut_example): + cut_example.supervisions[0].custom["speaker_id"] = "test_speaker_001" + + f = SpeakerFilter( + excluded_speaker_ids=["test_speaker_001"], + speaker_fields=["speaker", "speaker_id"], + ) + assert f(cut_example) == False + + f = SpeakerFilter( + excluded_speaker_ids=["some_other_speaker"], + speaker_fields=["speaker", "speaker_id"], + ) + assert f(cut_example) == True + + +def test_cut_speaker_filter_disabled(cut_example): + f = SpeakerFilter( + excluded_speaker_ids=None, + speaker_fields=["speaker_id"], + ) + assert f(cut_example) == True + + f = SpeakerFilter( + excluded_speaker_ids=[], + speaker_fields=["speaker_id"], + ) + assert f(cut_example) == True + + +def test_cut_speaker_filter_requires_fields_when_enabled(): + with pytest.raises(ValueError): + SpeakerFilter( + excluded_speaker_ids=["test_speaker_001"], + speaker_fields=None, + ) \ No newline at end of file From a94514f9131383365b66997fd0933efab43214fb Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Sat, 20 Jun 2026 06:31:18 -0700 Subject: [PATCH 088/109] Add support for the text pretraining checkpoint Signed-off-by: Edresson Casanova --- examples/tts/easy_magpietts.py | 3 + .../tts/models/easy_magpietts_inference.py | 115 ++++++++++++++++++ .../tts/modules/nemotron_h_decoder.py | 11 +- 3 files changed, 120 insertions(+), 9 deletions(-) diff --git a/examples/tts/easy_magpietts.py b/examples/tts/easy_magpietts.py index 74ac2c6e3965..c4c15626bd68 100644 --- a/examples/tts/easy_magpietts.py +++ b/examples/tts/easy_magpietts.py @@ -58,6 +58,9 @@ def main(cfg): if cfg.get("pretrained_model", None): model.restore_from_pretrained_checkpoint(cfg.pretrained_model) + if cfg.get("pretrained_text_model_path", None): + model.restore_from_text_pretrained_checkpoint(cfg.pretrained_text_model_path) + model.maybe_init_from_pretrained_checkpoint(cfg=cfg) if mode in ['train', 'onlinepo_train']: diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index d76223466967..0cd117e4d3cb 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json +import os + import random import time import random @@ -28,6 +31,7 @@ from omegaconf import DictConfig from torch import nn from transformers import AutoConfig, AutoModelForCausalLM +from safetensors.torch import load_file from nemo.core.connectors.save_restore_connector import SaveRestoreConnector from nemo.collections.tts.data.text_to_speech_dataset_lhotse import setup_tokenizers @@ -661,6 +665,117 @@ def restore_from_pretrained_checkpoint(self, checkpoint_path): self.load_state_dict(checkpoint_state, strict=True) logging.info(f"Model restored from the checkpoint: {checkpoint_path} !") + def restore_from_text_pretrained_checkpoint( + self, + checkpoint_dir: str, + *, + strict: bool = True, + allow_partial_copy: bool = False, + load_embeddings: bool = True, + ): + """ + Restore only the decoder/backbone weights from a text-pretrained HF safetensors checkpoint. + + This is intended for checkpoints laid out like: + + model.safetensors.index.json + model-00001-of-00003.safetensors + model-00002-of-00003.safetensors + model-00003-of-00003.safetensors + + The HF NemotronH checkpoint stores decoder weights under `backbone.*`. + This model stores the decoder directly as `self.decoder`, so this method maps: + + backbone.layers.0... -> layers.0... + backbone.norm_f... -> norm_f... + backbone.embeddings... -> embeddings... # optional, skipped by default + + Args: + checkpoint_dir: Directory containing HF safetensors checkpoint shards. + strict: Passed to `self.decoder.load_state_dict`. + allow_partial_copy: Whether to use NeMo partial-init helper for compatible partial tensors. + load_embeddings: Whether to load `backbone.embeddings.*`. + Defaults to False because EasyMagpie replaces decoder embeddings with `self.text_embedding`. + + Returns: + The result of `self.decoder.load_state_dict(...)`. + """ + + if checkpoint_dir is None: + raise ValueError("`checkpoint_dir` must not be None.") + + if not os.path.isdir(checkpoint_dir): + raise ValueError(f"`checkpoint_dir` must be a directory: {checkpoint_dir}") + + index_path = os.path.join(checkpoint_dir, "model.safetensors.index.json") + if not os.path.exists(index_path): + raise FileNotFoundError(f"Missing safetensors index file: {index_path}") + + with open(index_path, "r", encoding="utf-8") as f: + index = json.load(f) + + weight_map = index.get("weight_map", {}) + if not weight_map: + raise ValueError(f"No `weight_map` found in: {index_path}") + + # Load only shards that contain decoder/backbone weights. + shard_to_keys = {} + for hf_key, shard_name in weight_map.items(): + if not hf_key.startswith("backbone."): + continue + + decoder_key = hf_key[len("backbone.") :] + + if not load_embeddings and decoder_key.startswith("embeddings."): + continue + + shard_to_keys.setdefault(shard_name, []).append((hf_key, decoder_key)) + + if not shard_to_keys: + raise ValueError( + f"No `backbone.*` weights found in safetensors index: {index_path}" + ) + + decoder_state = {} + for shard_name, key_pairs in shard_to_keys.items(): + shard_path = os.path.join(checkpoint_dir, shard_name) + if not os.path.exists(shard_path): + raise FileNotFoundError(f"Missing checkpoint shard: {shard_path}") + + shard_state = load_file(shard_path, device="cpu") + + for hf_key, decoder_key in key_pairs: + if hf_key in shard_state: + decoder_state[decoder_key] = shard_state[hf_key] + + del shard_state + + target_decoder_state = self.decoder.state_dict() + + # Drop keys that are not present in the current decoder, unless strict=True. + # This keeps the method robust to HF-only keys. + if not strict: + decoder_state = { + k: v for k, v in decoder_state.items() + if k in target_decoder_state + } + + if allow_partial_copy: + decoder_state = set_model_dict_for_partial_init( + decoder_state, + target_decoder_state, + allow_partial_copy=True, + ) + + self.decoder.load_state_dict(decoder_state, strict=strict) + + logging.info( + "Decoder restored from text-pretrained safetensors checkpoint: " + f"{checkpoint_dir}. " + f"Loaded {len(decoder_state)} tensors. " + f"load_embeddings={load_embeddings}, strict={strict}." + ) + def _generate_codec_silence_buffer(self): codec_device = next(self._codec_model.parameters()).device diff --git a/nemo/collections/tts/modules/nemotron_h_decoder.py b/nemo/collections/tts/modules/nemotron_h_decoder.py index 8e487da639e6..d566ad6c5d20 100644 --- a/nemo/collections/tts/modules/nemotron_h_decoder.py +++ b/nemo/collections/tts/modules/nemotron_h_decoder.py @@ -32,6 +32,7 @@ from nemo.utils import logging +from transformers.activations import ACT2FN # Try to import optimized kernels, fall back to pure PyTorch if unavailable try: @@ -136,15 +137,7 @@ def get_cached_mamba_ssm_state( def get_activation_fn(activation: str): """Get activation function by name.""" - if activation == "silu" or activation == "swish": - return F.silu - elif activation == "gelu": - return F.gelu - elif activation == "relu": - return F.relu - else: - raise ValueError(f"Unsupported activation: {activation}") - + return ACT2FN[activation] @dataclass class NemotronHConfig: From 84ba0c92ad8c3ca90d934f85061beb40b6a0167c Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Sat, 20 Jun 2026 06:31:54 -0700 Subject: [PATCH 089/109] Revert "Add support for the text pretraining checkpoint" This reverts commit 33eeb603b2261b88725fb9ac63b8242fb57ea239. Signed-off-by: Edresson Casanova --- examples/tts/easy_magpietts.py | 3 - .../tts/models/easy_magpietts_inference.py | 115 ------------------ .../tts/modules/nemotron_h_decoder.py | 11 +- 3 files changed, 9 insertions(+), 120 deletions(-) diff --git a/examples/tts/easy_magpietts.py b/examples/tts/easy_magpietts.py index c4c15626bd68..74ac2c6e3965 100644 --- a/examples/tts/easy_magpietts.py +++ b/examples/tts/easy_magpietts.py @@ -58,9 +58,6 @@ def main(cfg): if cfg.get("pretrained_model", None): model.restore_from_pretrained_checkpoint(cfg.pretrained_model) - if cfg.get("pretrained_text_model_path", None): - model.restore_from_text_pretrained_checkpoint(cfg.pretrained_text_model_path) - model.maybe_init_from_pretrained_checkpoint(cfg=cfg) if mode in ['train', 'onlinepo_train']: diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 0cd117e4d3cb..d76223466967 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -11,9 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json -import os - import random import time import random @@ -31,7 +28,6 @@ from omegaconf import DictConfig from torch import nn from transformers import AutoConfig, AutoModelForCausalLM -from safetensors.torch import load_file from nemo.core.connectors.save_restore_connector import SaveRestoreConnector from nemo.collections.tts.data.text_to_speech_dataset_lhotse import setup_tokenizers @@ -665,117 +661,6 @@ def restore_from_pretrained_checkpoint(self, checkpoint_path): self.load_state_dict(checkpoint_state, strict=True) logging.info(f"Model restored from the checkpoint: {checkpoint_path} !") - def restore_from_text_pretrained_checkpoint( - self, - checkpoint_dir: str, - *, - strict: bool = True, - allow_partial_copy: bool = False, - load_embeddings: bool = True, - ): - """ - Restore only the decoder/backbone weights from a text-pretrained HF safetensors checkpoint. - - This is intended for checkpoints laid out like: - - model.safetensors.index.json - model-00001-of-00003.safetensors - model-00002-of-00003.safetensors - model-00003-of-00003.safetensors - - The HF NemotronH checkpoint stores decoder weights under `backbone.*`. - This model stores the decoder directly as `self.decoder`, so this method maps: - - backbone.layers.0... -> layers.0... - backbone.norm_f... -> norm_f... - backbone.embeddings... -> embeddings... # optional, skipped by default - - Args: - checkpoint_dir: Directory containing HF safetensors checkpoint shards. - strict: Passed to `self.decoder.load_state_dict`. - allow_partial_copy: Whether to use NeMo partial-init helper for compatible partial tensors. - load_embeddings: Whether to load `backbone.embeddings.*`. - Defaults to False because EasyMagpie replaces decoder embeddings with `self.text_embedding`. - - Returns: - The result of `self.decoder.load_state_dict(...)`. - """ - - if checkpoint_dir is None: - raise ValueError("`checkpoint_dir` must not be None.") - - if not os.path.isdir(checkpoint_dir): - raise ValueError(f"`checkpoint_dir` must be a directory: {checkpoint_dir}") - - index_path = os.path.join(checkpoint_dir, "model.safetensors.index.json") - if not os.path.exists(index_path): - raise FileNotFoundError(f"Missing safetensors index file: {index_path}") - - with open(index_path, "r", encoding="utf-8") as f: - index = json.load(f) - - weight_map = index.get("weight_map", {}) - if not weight_map: - raise ValueError(f"No `weight_map` found in: {index_path}") - - # Load only shards that contain decoder/backbone weights. - shard_to_keys = {} - for hf_key, shard_name in weight_map.items(): - if not hf_key.startswith("backbone."): - continue - - decoder_key = hf_key[len("backbone.") :] - - if not load_embeddings and decoder_key.startswith("embeddings."): - continue - - shard_to_keys.setdefault(shard_name, []).append((hf_key, decoder_key)) - - if not shard_to_keys: - raise ValueError( - f"No `backbone.*` weights found in safetensors index: {index_path}" - ) - - decoder_state = {} - for shard_name, key_pairs in shard_to_keys.items(): - shard_path = os.path.join(checkpoint_dir, shard_name) - if not os.path.exists(shard_path): - raise FileNotFoundError(f"Missing checkpoint shard: {shard_path}") - - shard_state = load_file(shard_path, device="cpu") - - for hf_key, decoder_key in key_pairs: - if hf_key in shard_state: - decoder_state[decoder_key] = shard_state[hf_key] - - del shard_state - - target_decoder_state = self.decoder.state_dict() - - # Drop keys that are not present in the current decoder, unless strict=True. - # This keeps the method robust to HF-only keys. - if not strict: - decoder_state = { - k: v for k, v in decoder_state.items() - if k in target_decoder_state - } - - if allow_partial_copy: - decoder_state = set_model_dict_for_partial_init( - decoder_state, - target_decoder_state, - allow_partial_copy=True, - ) - - self.decoder.load_state_dict(decoder_state, strict=strict) - - logging.info( - "Decoder restored from text-pretrained safetensors checkpoint: " - f"{checkpoint_dir}. " - f"Loaded {len(decoder_state)} tensors. " - f"load_embeddings={load_embeddings}, strict={strict}." - ) - def _generate_codec_silence_buffer(self): codec_device = next(self._codec_model.parameters()).device diff --git a/nemo/collections/tts/modules/nemotron_h_decoder.py b/nemo/collections/tts/modules/nemotron_h_decoder.py index d566ad6c5d20..8e487da639e6 100644 --- a/nemo/collections/tts/modules/nemotron_h_decoder.py +++ b/nemo/collections/tts/modules/nemotron_h_decoder.py @@ -32,7 +32,6 @@ from nemo.utils import logging -from transformers.activations import ACT2FN # Try to import optimized kernels, fall back to pure PyTorch if unavailable try: @@ -137,7 +136,15 @@ def get_cached_mamba_ssm_state( def get_activation_fn(activation: str): """Get activation function by name.""" - return ACT2FN[activation] + if activation == "silu" or activation == "swish": + return F.silu + elif activation == "gelu": + return F.gelu + elif activation == "relu": + return F.relu + else: + raise ValueError(f"Unsupported activation: {activation}") + @dataclass class NemotronHConfig: From d766b7af27cd88b6722d3391183255bb329aa00c Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Sat, 20 Jun 2026 10:47:37 -0700 Subject: [PATCH 090/109] Add emotion cosine similarity and emotion match rate metric Signed-off-by: Edresson Casanova --- examples/tts/magpietts_inference.py | 32 +- .../tts/metrics/emotion_encoder.py | 1313 +++++++++++++++++ .../evaluate_generated_audio.py | 149 +- .../modules/magpietts_inference/evaluation.py | 12 + 4 files changed, 1477 insertions(+), 29 deletions(-) create mode 100644 nemo/collections/tts/metrics/emotion_encoder.py diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index 7e2458ceaeea..85899679584c 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -129,6 +129,8 @@ def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, met metrics.get('ssim_pred_gt_avg_alternate', ''), metrics.get('ssim_pred_context_avg_alternate', ''), metrics.get('ssim_gt_context_avg_alternate', ''), + metrics.get('esim_pred_gt_avg', ''), + metrics.get('ems_pred_gt_avg', ''), metrics.get('cer_gt_audio_cumulative', ''), metrics.get('wer_gt_audio_cumulative', ''), metrics.get('utmosv2_avg', ''), @@ -241,6 +243,8 @@ def turn_sort_key(r): pred_context_ssim_turns = [r.get("pred_context_ssim") for r in turns] pred_gt_ssim_turns = [r.get("pred_gt_ssim") for r in turns] gt_context_ssim_turns = [r.get("gt_context_ssim") for r in turns] + pred_gt_esim_turns = [r.get("pred_gt_esim") for r in turns] + pred_gt_ems_turns = [r.get("pred_gt_ems") for r in turns] utmosv2_turns = [r.get("utmosv2") for r in turns] eou_type_turns = [r.get("eou_type") for r in turns] eou_trailing_duration_turns = [r.get("eou_trailing_duration") for r in turns] @@ -258,6 +262,8 @@ def turn_sort_key(r): "pred_context_ssim": _mean_finite(pred_context_ssim_turns), "pred_gt_ssim": _mean_finite(pred_gt_ssim_turns), "gt_context_ssim": _mean_finite(gt_context_ssim_turns), + "pred_gt_esim": _mean_finite(pred_gt_esim_turns), + "pred_gt_ems": _mean_finite(pred_gt_ems_turns), "utmosv2": _mean_finite(utmosv2_turns), "eou_trailing_duration": _mean_finite(eou_trailing_duration_turns), "eou_trail_rms_ratio": _mean_finite(eou_trail_rms_ratio_turns), @@ -268,6 +274,8 @@ def turn_sort_key(r): "pred_context_ssim_turns": pred_context_ssim_turns, "pred_gt_ssim_turns": pred_gt_ssim_turns, "gt_context_ssim_turns": gt_context_ssim_turns, + "pred_gt_esim_turns": pred_gt_esim_turns, + "pred_gt_ems_turns": pred_gt_ems_turns, "utmosv2_turns": utmosv2_turns, "eou_type_turns": eou_type_turns, "eou_trailing_duration_turns": eou_trailing_duration_turns, @@ -302,6 +310,8 @@ def _write_grouped_multiturn_filewise_metrics_csv(csv_path: str, grouped_rows: l "pred_context_ssim", "pred_gt_ssim", "gt_context_ssim", + "pred_gt_esim", + "pred_gt_ems", "utmosv2", "eou_trailing_duration", "eou_trail_rms_ratio", @@ -311,6 +321,8 @@ def _write_grouped_multiturn_filewise_metrics_csv(csv_path: str, grouped_rows: l "pred_context_ssim_turns", "pred_gt_ssim_turns", "gt_context_ssim_turns", + "pred_gt_esim_turns", + "pred_gt_ems_turns", "utmosv2_turns", "eou_type_turns", "eou_trailing_duration_turns", @@ -594,7 +606,8 @@ def run_inference_and_evaluation( "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative," "wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg," "ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate," - "ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative," + "ssim_gt_context_avg_alternate,esim_pred_gt_avg,ems_pred_gt_avg," + "cer_gt_audio_cumulative,wer_gt_audio_cumulative," "utmosv2_avg,total_gen_audio_seconds,frechet_codec_distance," "eou_cutoff_rate,eou_silence_rate,eou_noise_rate,eou_error_rate" ) @@ -711,6 +724,10 @@ def run_inference_and_evaluation( with_utmosv2=eval_config.with_utmosv2, with_fcd=eval_config.with_fcd, codec_model_path=eval_config.codec_model_path, + with_emotion_metrics=eval_config.with_emotion_metrics, + emotion_model_size=eval_config.emotion_model_size, + emotion_embedding_type=eval_config.emotion_embedding_type, + emotion_cache_dir=eval_config.emotion_cache_dir, device=eval_config.device, asr_batch_size=eval_config.asr_batch_size, eou_batch_size=eval_config.eou_batch_size, @@ -955,6 +972,15 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None: eval_group.add_argument('--num_repeats', type=int, default=1) eval_group.add_argument('--confidence_level', type=float, default=0.95) eval_group.add_argument('--disable_utmosv2', action='store_true') + eval_group.add_argument('--with_emotion_metrics', action='store_true') + eval_group.add_argument('--emotion_model_size', type=str, default="small", choices=["small", "large"]) + eval_group.add_argument( + '--emotion_embedding_type', + type=str, + default="head_concat", + choices=["head_concat", "head_mean", "score_vector"], + ) + eval_group.add_argument('--emotion_cache_dir', type=str, default=None) eval_group.add_argument( '--violin_plot_metrics', type=str, @@ -1155,6 +1181,10 @@ def main(argv=None): with_utmosv2=not args.disable_utmosv2, with_fcd=not args.disable_fcd, codec_model_path=args.codecmodel_path if not args.disable_fcd else None, + with_emotion_metrics=args.with_emotion_metrics, + emotion_model_size=args.emotion_model_size, + emotion_embedding_type=args.emotion_embedding_type, + emotion_cache_dir=args.emotion_cache_dir, asr_batch_size=args.asr_batch_size, eou_batch_size=args.eou_batch_size, ) diff --git a/nemo/collections/tts/metrics/emotion_encoder.py b/nemo/collections/tts/metrics/emotion_encoder.py new file mode 100644 index 000000000000..7b5d0de9cb26 --- /dev/null +++ b/nemo/collections/tts/metrics/emotion_encoder.py @@ -0,0 +1,1313 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Lightweight LAION Empathic Insight Voice interface. + +This script provides a Hugging Face-style Python class for LAION's +Empathic Insight Voice models without using ModelScope. + +It supports: + + 1. Restoring the Whisper encoder from Hugging Face. + 2. Restoring LAION classifier MLP heads from Hugging Face .pth files. + 3. Extracting the full Whisper encoder embedding: + [B, 1500, 768] + 4. Extracting one classifier projection embedding: + Small: [B, 64] + Large: [B, 128] + 5. Extracting an SV-style emotion similarity embedding by concatenating + multiple classifier projection embeddings: + Small, 40 labels: [B, 40 * 64] = [B, 2560] + Large, 40 labels: [B, 40 * 128] = [B, 5120] + 6. Extracting an official-style emotion score vector: + [B, num_labels] + 7. Computing ranked emotion predictions and cosine similarity. + +Recommended for emotion similarity: + + model = EmpathicInsightVoice.from_pretrained(size="small", device="cuda") + emb = model.extract_emotion_embedding("audio.wav", embedding_type="head_concat") + sim = model.emotion_similarity("a.wav", "b.wav", embedding_type="head_concat") + +Notes: + + - The official model is a collection of independent expert heads. Each head + predicts one emotion or attribute score. + - The "head_concat" embedding is an engineering adaptation for + speaker-verification-style similarity. It concatenates the learned + projection outputs from multiple classifier heads. + - The "score_vector" embedding is closer to the documented inference output: + a vector of raw emotion intensity scores. +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any, Optional, Sequence, Union + +import librosa +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from huggingface_hub import hf_hub_download +from transformers import WhisperForConditionalGeneration, WhisperProcessor + + +# ============================================================================= +# Label mapping +# ============================================================================= + +# MIRRORING-style 12 emotions (https://bench.theliva.ai/legacy/mirroring.html): +# Amusement, Anger, Elation, Impatience, Surprise, +# Emotional Numbness, Contemplation, Disappointment, +# Confusion, Pride, Affection, Sadness. +# +# The public-facing labels below are Python-friendly. +# Some labels map to LAION's original longer checkpoint names: +# impatience -> model_Impatience_and_Irritability_best.pth +# surprise -> model_Astonishment_Surprise_best.pth + +LAION_LABEL_TO_FILENAME: dict[str, str] = { + "amusement": "model_Amusement_best.pth", + "anger": "model_Anger_best.pth", + "elation": "model_Elation_best.pth", + "impatience": "model_Impatience_and_Irritability_best.pth", + "surprise": "model_Astonishment_Surprise_best.pth", + "emotional_numbness": "model_Emotional_Numbness_best.pth", + "contemplation": "model_Contemplation_best.pth", + "disappointment": "model_Disappointment_best.pth", + "confusion": "model_Confusion_best.pth", + "pride": "model_Pride_best.pth", + "affection": "model_Affection_best.pth", + "sadness": "model_Sadness_best.pth", +} + + +PRIMARY_EMOTION_LABELS: list[str] = [ + "amusement", + "anger", + "elation", + "impatience", + "surprise", + "emotional_numbness", + "contemplation", + "disappointment", + "confusion", + "pride", + "affection", + "sadness", +] + + +AUXILIARY_SIMILARITY_LABELS: list[str] = [] + +# ============================================================================= +# Model architecture specs +# ============================================================================= + +MODEL_SPECS: dict[str, dict[str, Any]] = { + "small": { + "repo_id": "laion/Empathic-Insight-Voice-Small", + "whisper_model_id": "mkrausio/EmoWhisper-AnS-Small-v0.1", + "sample_rate": 16000, + "max_audio_seconds": 30.0, + "seq_len": 1500, + "embed_dim": 768, + "projection_dim": 64, + "mlp_hidden_dims": [64, 32, 16], + "mlp_dropouts": [0.0, 0.1, 0.1, 0.1], + }, + "large": { + "repo_id": "laion/Empathic-Insight-Voice-Large", + "whisper_model_id": "mkrausio/EmoWhisper-AnS-Small-v0.1", + "sample_rate": 16000, + "max_audio_seconds": 30.0, + "seq_len": 1500, + "embed_dim": 768, + "projection_dim": 128, + "mlp_hidden_dims": [128, 64, 32], + "mlp_dropouts": [0.0, 0.1, 0.1, 0.1], + }, +} + + +# ============================================================================= +# MLP head +# ============================================================================= + +class FullEmbeddingMLP(nn.Module): + """Classifier head used by Empathic Insight Voice. + + The model receives a full Whisper encoder sequence embedding: + + [batch, seq_len, embed_dim] + + For Empathic Insight Voice this is normally: + + [batch, 1500, 768] + + It then performs: + + flatten -> projection -> MLP -> scalar score + + The projection output is useful as an SV-style emotion embedding: + + Small: [batch, 64] + Large: [batch, 128] + + Each restored classifier head has its own projection layer. Therefore, + "anger" projection, "sadness" projection, and "arousal" projection are + all different learned spaces. + """ + + def __init__( + self, + seq_len: int, + embed_dim: int, + projection_dim: int, + mlp_hidden_dims: Sequence[int], + mlp_dropout_rates: Sequence[float], + ) -> None: + super().__init__() + + if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1: + raise ValueError( + "Dropout rates length error. " + f"Expected {len(mlp_hidden_dims) + 1}, " + f"got {len(mlp_dropout_rates)}." + ) + + self.seq_len = seq_len + self.embed_dim = embed_dim + self.projection_dim = projection_dim + + self.flatten = nn.Flatten() + self.proj = nn.Linear(seq_len * embed_dim, projection_dim) + + layers: list[nn.Module] = [ + nn.ReLU(), + nn.Dropout(mlp_dropout_rates[0]), + ] + + current_dim = projection_dim + for i, hidden_dim in enumerate(mlp_hidden_dims): + layers.extend( + [ + nn.Linear(current_dim, hidden_dim), + nn.ReLU(), + nn.Dropout(mlp_dropout_rates[i + 1]), + ] + ) + current_dim = hidden_dim + + layers.append(nn.Linear(current_dim, 1)) + self.mlp = nn.Sequential(*layers) + + def extract_projected_embedding(self, x: torch.Tensor) -> torch.Tensor: + """Return the classifier projection embedding before the MLP. + + Args: + x: + Whisper embedding with shape [B, seq_len, embed_dim], or + [B, 1, seq_len, embed_dim]. + + Returns: + Projected embedding with shape [B, projection_dim]. + """ + if x.ndim == 4 and x.shape[1] == 1: + x = x.squeeze(1) + + if x.ndim != 3: + raise ValueError( + f"Expected x with shape [B, T, C], got shape {tuple(x.shape)}." + ) + + return self.proj(self.flatten(x)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return one scalar score per input example.""" + projected = self.extract_projected_embedding(x) + return self.mlp(projected) + + +# ============================================================================= +# Main class +# ============================================================================= + +class EmpathicInsightVoice(nn.Module): + """Lightweight Hugging Face-style Empathic Insight Voice class. + + This class intentionally does not inherit from NeMo ModelPT. It behaves + like a normal PyTorch/Hugging Face utility model. + + Main methods: + + - extract_whisper_embedding(audio_path) + Returns [1, 1500, 768]. + + - extract_classifier_projection(audio_path, label) + Returns one head-specific projection: + Small: [1, 64] + Large: [1, 128] + + - extract_emotion_embedding(audio_path, embedding_type="head_concat") + Recommended SV-style emotion similarity embedding. + + - predict_emotions_from_embedding(embedding) + Returns raw scores and ranked softmax-like top emotions. + + - compute(audio_path) + Returns emotion predictions and optionally an embedding. + + - emotion_similarity(audio_path_a, audio_path_b) + Returns cosine similarity between extracted emotion embeddings. + """ + + def __init__( + self, + size: str = "small", + device: Union[str, torch.device] = "cuda", + mlp_device: Optional[Union[str, torch.device]] = None, + cache_dir: Optional[Union[str, Path]] = None, + cache_classifiers: bool = True, + load_all_classifiers: bool = False, + top_k_emotions: int = 5, + torch_dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ) -> None: + super().__init__() + + if size not in MODEL_SPECS: + raise ValueError( + f"Unsupported size={size!r}. Expected one of {sorted(MODEL_SPECS)}." + ) + + self.size = size + self.spec = MODEL_SPECS[size] + self.cache_classifiers = cache_classifiers + self.top_k_emotions = top_k_emotions + self.cache_dir = Path(cache_dir) if cache_dir is not None else None + + requested_device = torch.device(device) + if requested_device.type == "cuda" and not torch.cuda.is_available(): + requested_device = torch.device("cpu") + + self.device = requested_device + self.mlp_device = ( + torch.device(mlp_device) + if mlp_device is not None + else self.device + ) + + self.sample_rate = int(self.spec["sample_rate"]) + self.max_audio_seconds = float(self.spec["max_audio_seconds"]) + + # Load Whisper processor and encoder model. + self.processor = WhisperProcessor.from_pretrained( + self.spec["whisper_model_id"], + cache_dir=str(self.cache_dir) if self.cache_dir is not None else None, + trust_remote_code=trust_remote_code, + ) + + whisper_kwargs: dict[str, Any] = { + "cache_dir": str(self.cache_dir) if self.cache_dir is not None else None, + "trust_remote_code": trust_remote_code, + } + if torch_dtype is not None: + whisper_kwargs["torch_dtype"] = torch_dtype + + self.whisper_model = WhisperForConditionalGeneration.from_pretrained( + self.spec["whisper_model_id"], + **whisper_kwargs, + ).to(self.device) + + self.whisper_model.eval() + + # Restored MLP heads are stored here. + # + # ModuleDict keys must be sanitized because labels can contain punctuation. + self.classifiers = nn.ModuleDict() + + if load_all_classifiers: + self.load_classifiers() + + @classmethod + def from_pretrained( + cls, + size: str = "small", + **kwargs: Any, + ) -> "EmpathicInsightVoice": + """Construct the model using Hugging Face checkpoints. + + Example: + model = EmpathicInsightVoice.from_pretrained( + size="small", + device="cuda", + mlp_device="cuda", + ) + """ + return cls(size=size, **kwargs) + + @property + def repo_id(self) -> str: + """Hugging Face repo ID for the selected model size.""" + return str(self.spec["repo_id"]) + + @property + def available_labels(self) -> list[str]: + """All labels known to this script.""" + return list(LAION_LABEL_TO_FILENAME.keys()) + + @property + def projection_dim(self) -> int: + """Classifier projection dimension for the selected model size.""" + return int(self.spec["projection_dim"]) + + # ------------------------------------------------------------------------- + # Audio and Whisper embedding extraction + # ------------------------------------------------------------------------- + + @torch.no_grad() + def extract_whisper_embedding(self, audio_path: Union[str, Path]) -> torch.Tensor: + """Extract the full Whisper encoder embedding from an audio file. + + Args: + audio_path: + Path to an audio file readable by librosa. + + Returns: + Tensor with shape [1, 1500, 768]. + """ + waveform, _ = librosa.load(str(audio_path), sr=self.sample_rate, mono=True) + waveform = self._prepare_waveform(waveform) + return self.extract_whisper_embedding_from_waveform(waveform) + + @torch.no_grad() + def extract_whisper_embedding_from_waveform( + self, + waveform: np.ndarray, + ) -> torch.Tensor: + """Extract the full Whisper encoder embedding from a waveform. + + Args: + waveform: + Mono waveform at self.sample_rate. + + Returns: + Tensor with shape [1, 1500, 768]. + """ + waveform = self._prepare_waveform(waveform) + + input_features = self.processor( + waveform, + sampling_rate=self.sample_rate, + return_tensors="pt", + ).input_features + + input_features = input_features.to(self.device) + input_features = input_features.to(self.whisper_model.dtype) + + encoder_outputs = self.whisper_model.get_encoder()( + input_features=input_features + ) + + embedding = encoder_outputs.last_hidden_state + embedding = self._pad_or_trim_embedding(embedding) + + return embedding + + # ------------------------------------------------------------------------- + # Classifier projection extraction + # ------------------------------------------------------------------------- + + @torch.no_grad() + def extract_classifier_projection( + self, + audio_path: Union[str, Path], + label: str, + normalize: bool = True, + ) -> torch.Tensor: + """Extract one head-specific classifier projection embedding. + + This is the closest equivalent to extracting an x-vector-like embedding + from one specific emotion classifier head. + + Flow: + audio -> Whisper encoder -> label-specific classifier projection + + Args: + audio_path: + Input audio path. + label: + Label whose classifier projection should be used, for example: + "anger", "sadness", "arousal". + normalize: + If True, apply L2 normalization. + + Returns: + Small: [1, 64] + Large: [1, 128] + """ + whisper_embedding = self.extract_whisper_embedding(audio_path) + return self.extract_classifier_projection_from_whisper_embedding( + whisper_embedding=whisper_embedding, + label=label, + normalize=normalize, + ) + + @torch.no_grad() + def extract_classifier_projection_from_whisper_embedding( + self, + whisper_embedding: torch.Tensor, + label: str, + normalize: bool = True, + ) -> torch.Tensor: + """Extract one classifier projection from an existing Whisper embedding.""" + classifier = self._get_classifier(label) + param = next(classifier.parameters()) + + working_embedding = whisper_embedding.to(device=param.device, dtype=param.dtype) + projected = classifier.extract_projected_embedding(working_embedding) + + projected = projected.float() + if normalize: + projected = F.normalize(projected, p=2, dim=-1) + + return projected + + # ------------------------------------------------------------------------- + # Emotion similarity embeddings + # ------------------------------------------------------------------------- + + @torch.no_grad() + def extract_emotion_embedding( + self, + audio_path: Union[str, Path], + labels: Optional[Sequence[str]] = None, + embedding_type: str = "head_concat", + normalize: bool = True, + include_auxiliary: bool = False, + ) -> torch.Tensor: + """Extract a fixed-dimensional emotion embedding. + + Recommended for SV-style emotion similarity: + embedding_type="head_concat" + + Supported embedding types: + + 1. "head_concat" + Concatenate the projection output of each selected classifier head. + + Small, 40 primary labels: + [1, 40 * 64] = [1, 2560] + + Large, 40 primary labels: + [1, 40 * 128] = [1, 5120] + + 2. "head_mean" + Average the projection outputs across selected heads. + + Small: + [1, 64] + + Large: + [1, 128] + + 3. "score_vector" + Use raw scalar outputs from the selected classifier heads. + + Shape: + [1, num_labels] + + This is closest to the official annotation output. + + Args: + audio_path: + Input audio path. + labels: + Labels to use. If None, PRIMARY_EMOTION_LABELS are used. + embedding_type: + "head_concat", "head_mean", or "score_vector". + normalize: + If True, apply L2 normalization to the final embedding. + include_auxiliary: + If labels is None, append AUXILIARY_SIMILARITY_LABELS. + + Returns: + torch.Tensor fixed-dimensional emotion embedding. + """ + whisper_embedding = self.extract_whisper_embedding(audio_path) + labels_to_run = self._default_similarity_labels( + labels=labels, + include_auxiliary=include_auxiliary, + ) + + return self.extract_emotion_embedding_from_whisper_embedding( + whisper_embedding=whisper_embedding, + labels=labels_to_run, + embedding_type=embedding_type, + normalize=normalize, + ) + + @torch.no_grad() + def extract_emotion_embedding_from_whisper_embedding( + self, + whisper_embedding: torch.Tensor, + labels: Optional[Sequence[str]] = None, + embedding_type: str = "head_concat", + normalize: bool = True, + include_auxiliary: bool = False, + ) -> torch.Tensor: + """Extract a fixed-dimensional emotion embedding from Whisper features.""" + labels_to_run = self._default_similarity_labels( + labels=labels, + include_auxiliary=include_auxiliary, + ) + + if embedding_type == "score_vector": + prediction = self.predict_emotions_from_embedding( + embedding=whisper_embedding, + labels=labels_to_run, + return_raw_scores=True, + rank_scores=False, + ) + raw_scores = prediction["raw_scores"] + + output = torch.tensor( + [raw_scores[label] for label in labels_to_run], + dtype=torch.float32, + ).unsqueeze(0) + + elif embedding_type in {"head_concat", "head_mean"}: + projected_embeddings: list[torch.Tensor] = [] + + for label in labels_to_run: + classifier = self._get_classifier(label) + param = next(classifier.parameters()) + + working_embedding = whisper_embedding.to( + device=param.device, + dtype=param.dtype, + ) + + projected = classifier.extract_projected_embedding(working_embedding) + projected_embeddings.append(projected.float().cpu()) + + if embedding_type == "head_concat": + output = torch.cat(projected_embeddings, dim=-1) + else: + output = torch.stack(projected_embeddings, dim=0).mean(dim=0) + + else: + raise ValueError( + f"Unsupported embedding_type={embedding_type!r}. " + "Expected 'head_concat', 'head_mean', or 'score_vector'." + ) + + if normalize: + output = F.normalize(output, p=2, dim=-1) + + return output + + @torch.no_grad() + def emotion_similarity( + self, + audio_path_a: Union[str, Path], + audio_path_b: Union[str, Path], + labels: Optional[Sequence[str]] = None, + embedding_type: str = "head_concat", + include_auxiliary: bool = False, + ) -> float: + """Compute cosine similarity between two audios in emotion space. + + Args: + audio_path_a: + First audio path. + audio_path_b: + Second audio path. + labels: + Optional label subset. + embedding_type: + "head_concat", "head_mean", or "score_vector". + include_auxiliary: + If labels is None, append auxiliary labels. + + Returns: + Cosine similarity as a Python float. + """ + emb_a = self.extract_emotion_embedding( + audio_path=audio_path_a, + labels=labels, + embedding_type=embedding_type, + normalize=True, + include_auxiliary=include_auxiliary, + ) + emb_b = self.extract_emotion_embedding( + audio_path=audio_path_b, + labels=labels, + embedding_type=embedding_type, + normalize=True, + include_auxiliary=include_auxiliary, + ) + + return float(F.cosine_similarity(emb_a, emb_b, dim=-1).item()) + + # ------------------------------------------------------------------------- + # Prediction + # ------------------------------------------------------------------------- + @torch.no_grad() + def compare_emotion_pair( + self, + audio_path_a: Union[str, Path], + audio_path_b: Union[str, Path], + labels: Optional[Sequence[str]] = None, + embedding_type: str = "score_vector", + ) -> dict[str, Any]: + """Compare two audio files using the 12-emotion set. + + This method does not perform corpus-level ranking. It only returns: + + - top emotion for audio A + - top emotion for audio B + - matched emotion label if both top emotions match + - emotion similarity + + Args: + audio_path_a: + First audio file. + audio_path_b: + Second audio file. + labels: + Optional subset of labels. Defaults to PRIMARY_EMOTION_LABELS, + which is the 12-emotion set. + embedding_type: + Similarity representation: + - "score_vector": cosine over raw 12-emotion score vector. + - "head_concat": cosine over concatenated classifier projections. + - "head_mean": cosine over averaged classifier projections. + + For MIRRORING-style emotion-vector similarity, use "score_vector". + + Returns: + { + "audio_path_a": str, + "audio_path_b": str, + "audio_a_top_emotion": str | None, + "audio_b_top_emotion": str | None, + "top_emotion_match": bool , + "emotion_similarity": float, + "audio_a_raw_scores": dict[str, float], + "audio_b_raw_scores": dict[str, float], + } + """ + labels_to_run = self._validate_labels(labels or PRIMARY_EMOTION_LABELS) + + result_a = self.compute( + audio_path=audio_path_a, + labels=labels_to_run, + return_embedding=False, + return_raw_scores=True, + ) + + result_b = self.compute( + audio_path=audio_path_b, + labels=labels_to_run, + return_embedding=False, + return_raw_scores=True, + ) + + top_a = result_a["top_emotion"] + top_b = result_b["top_emotion"] + + similarity = self.emotion_similarity( + audio_path_a=audio_path_a, + audio_path_b=audio_path_b, + labels=labels_to_run, + embedding_type=embedding_type, + include_auxiliary=False, + ) + + return { + "audio_path_a": str(audio_path_a), + "audio_path_b": str(audio_path_b), + "audio_a_top_emotion": top_a, + "audio_b_top_emotion": top_b, + "top_emotion_match": top_a is not None and top_a == top_b, + "emotion_similarity": similarity, + "audio_a_raw_scores": result_a["raw_scores"], + "audio_b_raw_scores": result_b["raw_scores"], + } + + @torch.no_grad() + def compute( + self, + audio_path: Union[str, Path], + labels: Optional[Sequence[str]] = None, + return_embedding: bool = True, + embedding_type: str = "head_concat", + return_raw_scores: bool = True, + include_auxiliary_for_embedding: bool = False, + ) -> dict[str, Any]: + """Compute emotion predictions and optionally an embedding. + + Args: + audio_path: + Input audio file. + labels: + Prediction labels. If None, all known labels are attempted. + For the official 40-emotion profile, pass PRIMARY_EMOTION_LABELS. + return_embedding: + If True, return an emotion embedding. + embedding_type: + Embedding type to return: + "head_concat", "head_mean", or "score_vector". + return_raw_scores: + If True, return raw classifier outputs. + include_auxiliary_for_embedding: + If True and return_embedding=True, include auxiliary labels in the + returned embedding. + + Returns: + { + "audio_path": str, + "model_size": "small" | "large", + "top_emotion": str | None, + "emotions": { + label: {"score": float, "rank": int} + }, + "raw_scores": { + label: float + }, + "embedding": torch.Tensor, + "embedding_type": str + } + """ + whisper_embedding = self.extract_whisper_embedding(audio_path) + + prediction = self.predict_emotions_from_embedding( + embedding=whisper_embedding, + labels=labels, + return_raw_scores=return_raw_scores, + rank_scores=True, + ) + + top_emotion = None + if prediction["emotions"]: + top_emotion = next(iter(prediction["emotions"])) + + output: dict[str, Any] = { + "audio_path": str(audio_path), + "model_size": self.size, + "top_emotion": top_emotion, + "emotions": prediction["emotions"], + } + + if return_raw_scores: + output["raw_scores"] = prediction["raw_scores"] + + if return_embedding: + output["embedding"] = self.extract_emotion_embedding_from_whisper_embedding( + whisper_embedding=whisper_embedding, + labels=None, + embedding_type=embedding_type, + normalize=True, + include_auxiliary=include_auxiliary_for_embedding, + ) + output["embedding_type"] = embedding_type + + return output + + @torch.no_grad() + def predict_emotions_from_embedding( + self, + embedding: torch.Tensor, + labels: Optional[Sequence[str]] = None, + return_raw_scores: bool = True, + rank_scores: bool = True, + ) -> dict[str, Any]: + """Predict emotion or attribute scores from a Whisper embedding. + + Args: + embedding: + Whisper encoder embedding, normally [1, 1500, 768]. + labels: + Labels to evaluate. If None, all known labels are attempted. + return_raw_scores: + If True, include raw classifier scores. + rank_scores: + If True, softmax and rank scores into top-k emotions. + + Returns: + Dict containing: + "emotions": ranked top-k scores if rank_scores=True + "raw_scores": raw scalar outputs if return_raw_scores=True + """ + labels_to_run = self._validate_labels(labels) + + raw_scores: dict[str, float] = {} + + for label in labels_to_run: + classifier = self._get_classifier(label) + param = next(classifier.parameters()) + + working_embedding = embedding.to(device=param.device, dtype=param.dtype) + score = classifier(working_embedding).detach().cpu().item() + raw_scores[label] = float(score) + + if not self.cache_classifiers: + cache_key = self._cache_key(label) + if cache_key in self.classifiers: + del self.classifiers[cache_key] + + output: dict[str, Any] = {} + + if rank_scores: + output["emotions"] = self._softmax_and_rank(raw_scores) + else: + output["emotions"] = {} + + if return_raw_scores: + output["raw_scores"] = raw_scores + + return output + + # ------------------------------------------------------------------------- + # Classifier loading + # ------------------------------------------------------------------------- + + def load_classifiers( + self, + labels: Optional[Sequence[str]] = None, + ) -> None: + """Eagerly download and restore classifier MLPs. + + By default the class lazy-loads heads when needed. This method is useful + when you want to pre-load selected heads before repeated inference. + """ + labels_to_load = self._validate_labels(labels) + for label in labels_to_load: + self._get_classifier(label) + + def _get_classifier(self, label: str) -> FullEmbeddingMLP: + """Download, reconstruct, and restore one classifier MLP head. + + This is where the classifier MLP is restored: + + classifier = FullEmbeddingMLP(...) + state_dict = torch.load(...) + state_dict = strip "_orig_mod." prefix if needed + classifier.load_state_dict(state_dict) + + Args: + label: + Python-friendly label key, such as "anger" or "arousal". + + Returns: + Restored FullEmbeddingMLP. + """ + if label not in LAION_LABEL_TO_FILENAME: + raise ValueError( + f"Unknown label {label!r}. Available labels: " + f"{sorted(LAION_LABEL_TO_FILENAME)}" + ) + + cache_key = self._cache_key(label) + + if cache_key in self.classifiers: + classifier = self.classifiers[cache_key] + if not isinstance(classifier, FullEmbeddingMLP): + raise TypeError( + f"Cached classifier for {label!r} has unexpected type " + f"{type(classifier)}." + ) + return classifier + + filename = LAION_LABEL_TO_FILENAME[label] + + local_path = hf_hub_download( + repo_id=self.repo_id, + filename=filename, + cache_dir=str(self.cache_dir) if self.cache_dir is not None else None, + repo_type="model", + ) + + classifier = FullEmbeddingMLP( + seq_len=int(self.spec["seq_len"]), + embed_dim=int(self.spec["embed_dim"]), + projection_dim=int(self.spec["projection_dim"]), + mlp_hidden_dims=list(self.spec["mlp_hidden_dims"]), + mlp_dropout_rates=list(self.spec["mlp_dropouts"]), + ) + + state_dict = torch.load(local_path, map_location="cpu") + + if not isinstance(state_dict, dict): + raise RuntimeError( + f"Expected {local_path} to contain a state_dict, " + f"but got {type(state_dict)}." + ) + + state_dict = self._strip_orig_mod_prefix_if_needed(state_dict) + + try: + classifier.load_state_dict(state_dict) + except RuntimeError as exc: + raise RuntimeError( + f"Failed to load classifier for label={label!r} from {local_path}. " + f"This often means the selected size={self.size!r} has different " + "MLP dimensions than MODEL_SPECS declares." + ) from exc + + classifier.eval() + classifier = classifier.to(self.mlp_device) + + # Keep classifier dtype compatible with Whisper dtype if user loaded + # Whisper in fp16/bf16. + if self.whisper_model.dtype in (torch.float16, torch.bfloat16): + classifier = classifier.to(dtype=self.whisper_model.dtype) + + if self.cache_classifiers: + self.classifiers[cache_key] = classifier + + return classifier + + # ------------------------------------------------------------------------- + # Internal helpers + # ------------------------------------------------------------------------- + + def _validate_labels( + self, + labels: Optional[Sequence[str]], + ) -> list[str]: + """Validate label names and return a concrete list.""" + if labels is None: + return list(LAION_LABEL_TO_FILENAME.keys()) + + labels_list = list(labels) + unknown = sorted(set(labels_list) - set(LAION_LABEL_TO_FILENAME.keys())) + + if unknown: + raise ValueError( + f"Unknown labels: {unknown}. " + f"Available labels: {sorted(LAION_LABEL_TO_FILENAME.keys())}" + ) + + return labels_list + + def _default_similarity_labels( + self, + labels: Optional[Sequence[str]], + include_auxiliary: bool, + ) -> list[str]: + """Choose the default labels for emotion similarity embeddings.""" + if labels is not None: + return self._validate_labels(labels) + + labels_to_run = list(PRIMARY_EMOTION_LABELS) + + if include_auxiliary: + labels_to_run.extend(AUXILIARY_SIMILARITY_LABELS) + + return self._validate_labels(labels_to_run) + + def _prepare_waveform(self, waveform: np.ndarray) -> np.ndarray: + """Convert waveform to mono float32 and trim to max_audio_seconds.""" + waveform = np.asarray(waveform, dtype=np.float32) + + if waveform.ndim > 1: + waveform = np.mean(waveform, axis=0).astype(np.float32) + + max_samples = int(self.sample_rate * self.max_audio_seconds) + + if waveform.shape[0] > max_samples: + waveform = waveform[:max_samples] + + return waveform + + def _pad_or_trim_embedding(self, embedding: torch.Tensor) -> torch.Tensor: + """Pad or trim Whisper encoder output to the expected sequence length.""" + seq_len = int(self.spec["seq_len"]) + embed_dim = int(self.spec["embed_dim"]) + + if embedding.ndim != 3: + raise RuntimeError( + f"Expected Whisper embedding with shape [B, T, C], " + f"got {tuple(embedding.shape)}." + ) + + if embedding.shape[-1] != embed_dim: + raise RuntimeError( + f"Unexpected embedding dim. Expected {embed_dim}, " + f"got {embedding.shape[-1]}." + ) + + current_seq_len = embedding.shape[1] + + if current_seq_len < seq_len: + padding = torch.zeros( + ( + embedding.shape[0], + seq_len - current_seq_len, + embed_dim, + ), + device=embedding.device, + dtype=embedding.dtype, + ) + embedding = torch.cat([embedding, padding], dim=1) + + elif current_seq_len > seq_len: + embedding = embedding[:, :seq_len, :] + + return embedding + + def _softmax_and_rank( + self, + raw_scores: dict[str, float], + ) -> dict[str, dict[str, Union[float, int]]]: + """Convert raw scores to a sorted top-k softmax dictionary. + + The raw MLP outputs are independent regression scores. This method is + mainly for producing an easy top emotion label. For similarity, prefer + raw score vectors or projection embeddings. + """ + if not raw_scores: + return {} + + labels = list(raw_scores.keys()) + values = np.array([raw_scores[label] for label in labels], dtype=np.float32) + + values = values - np.max(values) + exp_values = np.exp(values) + probs = exp_values / np.sum(exp_values) + + ranked = sorted( + zip(labels, probs.tolist()), + key=lambda item: item[1], + reverse=True, + ) + + ranked = ranked[: self.top_k_emotions] + + return { + label: { + "score": float(prob), + "rank": rank, + } + for rank, (label, prob) in enumerate(ranked, start=1) + } + + @staticmethod + def _strip_orig_mod_prefix_if_needed( + state_dict: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + """Strip torch.compile '_orig_mod.' prefixes if present.""" + if not any(key.startswith("_orig_mod.") for key in state_dict.keys()): + return state_dict + + return { + key[len("_orig_mod.") :] if key.startswith("_orig_mod.") else key: value + for key, value in state_dict.items() + } + + @staticmethod + def _cache_key(label: str) -> str: + """Convert an arbitrary label into a safe ModuleDict key.""" + return ( + label.replace(".", "_") + .replace("/", "_") + .replace("-", "_") + .replace(" ", "_") + .replace("&", "and") + ) + + def cleanup(self) -> None: + """Move modules to CPU and clear classifier cache.""" + for key in list(self.classifiers.keys()): + self.classifiers[key].cpu() + + self.classifiers.clear() + self.whisper_model.cpu() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +# ============================================================================= +# CLI utilities +# ============================================================================= + +def _tensor_info(tensor: torch.Tensor) -> dict[str, Any]: + return { + "shape": list(tensor.shape), + "dtype": str(tensor.dtype), + "device": str(tensor.device), + } + + +def _parse_labels(labels: Optional[str]) -> Optional[list[str]]: + if labels is None or labels.strip() == "": + return None + + return [item.strip() for item in labels.split(",") if item.strip()] + + +def main() -> None: + parser = argparse.ArgumentParser( + description="LAION Empathic Insight Voice embeddings and similarity." + ) + parser.add_argument( + "--audio", + type=str, + required=True, + help="Input audio path.", + ) + parser.add_argument( + "--audio-b", + type=str, + default=None, + help="Optional second audio path for similarity.", + ) + parser.add_argument( + "--size", + type=str, + default="small", + choices=["small", "large"], + help="Model size.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device for Whisper encoder.", + ) + parser.add_argument( + "--mlp-device", + type=str, + default=None, + help="Device for MLP classifier heads. Defaults to --device.", + ) + parser.add_argument( + "--cache-dir", + type=str, + default=None, + help="Optional Hugging Face cache directory.", + ) + parser.add_argument( + "--embedding-type", + type=str, + default="head_concat", + choices=["head_concat", "head_mean", "score_vector"], + help="Emotion embedding type.", + ) + parser.add_argument( + "--labels", + type=str, + default=None, + help=( + "Comma-separated labels to use. " + "Example: anger,sadness,arousal. " + "If omitted, primary emotion labels are used for similarity." + ), + ) + parser.add_argument( + "--include-auxiliary", + action="store_true", + help="Include auxiliary similarity labels when --labels is omitted.", + ) + parser.add_argument( + "--load-all-classifiers", + action="store_true", + help="Eagerly load all known classifiers at startup.", + ) + parser.add_argument( + "--top-k", + type=int, + default=5, + help="Number of ranked emotions to return.", + ) + + args = parser.parse_args() + + labels = _parse_labels(args.labels) + + model = EmpathicInsightVoice.from_pretrained( + size=args.size, + device=args.device, + mlp_device=args.mlp_device, + cache_dir=args.cache_dir, + cache_classifiers=True, + load_all_classifiers=args.load_all_classifiers, + top_k_emotions=args.top_k, + ) + + result = model.compute( + audio_path=args.audio, + labels=labels, + return_embedding=True, + embedding_type=args.embedding_type, + return_raw_scores=True, + include_auxiliary_for_embedding=args.include_auxiliary, + ) + + printable: dict[str, Any] = { + "audio_path": result["audio_path"], + "model_size": result["model_size"], + "top_emotion": result["top_emotion"], + "embedding_type": result["embedding_type"], + "embedding": _tensor_info(result["embedding"]), + "emotions": result["emotions"], + "raw_scores": result["raw_scores"], + } + + if args.audio_b is not None: + printable["audio_b"] = args.audio_b + printable["similarity"] = model.emotion_similarity( + audio_path_a=args.audio, + audio_path_b=args.audio_b, + labels=labels, + embedding_type=args.embedding_type, + include_auxiliary=args.include_auxiliary, + ) + + result = model.compare_emotion_pair( + audio_path_a=args.audio, + audio_path_b=args.audio_b, + embedding_type="head_concat", + ) + + result_score_vector = model.compare_emotion_pair( + audio_path_a=args.audio, + audio_path_b=args.audio_b, + embedding_type="score_vector", + ) + + result_score_mean = model.compare_emotion_pair( + audio_path_a=args.audio, + audio_path_b=args.audio_b, + embedding_type="head_mean", + ) + + print("embedding_type=head_concat") + print(result["audio_a_top_emotion"]) + print(result["audio_b_top_emotion"]) + print(result["top_emotion_match"]) + print(result["emotion_similarity"]) + + + print("embedding_type=score_vector") + print(result_score_vector["audio_a_top_emotion"]) + print(result_score_vector["audio_b_top_emotion"]) + print(result_score_vector["top_emotion_match"]) + print(result_score_vector["emotion_similarity"]) + + print("embedding_type=head_mean") + print(result_score_mean["audio_a_top_emotion"]) + print(result_score_mean["audio_b_top_emotion"]) + print(result_score_mean["top_emotion_match"]) + print(result_score_mean["emotion_similarity"]) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py index 42640dee124f..14d6892b3a0d 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py @@ -56,6 +56,8 @@ 'cer', 'wer', 'pred_context_ssim', + 'pred_gt_esim', + 'pred_gt_ems', 'pred_text', 'gt_text', 'gt_audio_filepath', @@ -265,7 +267,13 @@ def transcribed_batched( def load_evaluation_models( - language="en", sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", device="cuda" + language="en", + sv_model_type="titanet", + asr_model_name="stt_en_conformer_transducer_large", + device="cuda", + with_emotion_metrics=False, + emotion_model_size="small", + emotion_cache_dir=None, ): """Load ASR and speaker verification models used for evaluation. @@ -284,6 +292,7 @@ def load_evaluation_models( 'whisper_model': None, 'whisper_processor': None, 'feature_extractor': None, + 'emotion_model': None, } if language == "en": @@ -314,9 +323,43 @@ def load_evaluation_models( ) models['sv_model_alternate'] = models['sv_model_alternate'].to(device).eval() + if with_emotion_metrics: + logging.info("Loading emotion encoder for ESIM/EMS metrics...") + try: + from nemo.collections.tts.metrics.emotion_encoder import EmpathicInsightVoice + + models['emotion_model'] = EmpathicInsightVoice.from_pretrained( + size=emotion_model_size, + device=device, + mlp_device=device, + cache_dir=emotion_cache_dir, + cache_classifiers=True, + load_all_classifiers=False, + top_k_emotions=1, + ).eval() + except Exception as e: + logging.warning(f"Emotion encoder could not be loaded: {e}. ESIM/EMS metrics will be set to NaN.") + return models +def compute_emotion_pair_metrics(emotion_model, gt_audio_path, pred_audio_path, embedding_type="score_vector"): + """Compute ground-truth to predicted emotion similarity and top-emotion match.""" + if emotion_model is None or gt_audio_path is None or pred_audio_path is None: + return float('NaN'), float('NaN') + + try: + result = emotion_model.compare_emotion_pair( + audio_path_a=gt_audio_path, + audio_path_b=pred_audio_path, + embedding_type=embedding_type, + ) + return float(result["emotion_similarity"]), float(result["top_emotion_match"]) + except Exception as e: + logging.warning(f"Could not compute ESIM/EMS for {gt_audio_path} and {pred_audio_path}: {e}") + return float('NaN'), float('NaN') + + def classify_eou_batched( eou_classifier: EoUClassifier, items: list[tuple[Union[str, np.ndarray], str]], batch_size: int = 32 ) -> list[EoUClassification]: @@ -345,6 +388,10 @@ def evaluate_dir( sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", with_utmosv2=True, + with_emotion_metrics=False, + emotion_model_size="small", + emotion_embedding_type="score_vector", + emotion_cache_dir=None, asr_batch_size=32, eou_batch_size=32, device="cuda", @@ -377,7 +424,15 @@ def evaluate_dir( context_audio_paths = [_resolve_path(audio_dir, r.get('context_audio_filepath')) for r in records] # 2. Load models - models = load_evaluation_models(language, sv_model_type, asr_model_name, device) + models = load_evaluation_models( + language, + sv_model_type, + asr_model_name, + device, + with_emotion_metrics=with_emotion_metrics, + emotion_model_size=emotion_model_size, + emotion_cache_dir=emotion_cache_dir, + ) asr_model = models['asr_model'] whisper_model = models['whisper_model'] @@ -385,6 +440,7 @@ def evaluate_dir( feature_extractor = models['feature_extractor'] speaker_verification_model = models['sv_model'] speaker_verification_model_alternate = models['sv_model_alternate'] + emotion_model = models['emotion_model'] # 3. EoU classifier (support for English only) if language == "en": @@ -476,6 +532,16 @@ def evaluate_dir( detailed_cer = word_error_rate_detail(hypotheses=[pred_text], references=[gt_text], use_cer=True) detailed_wer = word_error_rate_detail(hypotheses=[pred_text], references=[gt_text], use_cer=False) + pred_gt_esim = float('NaN') + pred_gt_ems = float('NaN') + if with_emotion_metrics: + pred_gt_esim, pred_gt_ems = compute_emotion_pair_metrics( + emotion_model, + gt_audio_filepath, + pred_audio_filepath, + embedding_type=emotion_embedding_type, + ) + logging.info(f"{ridx} GT Text: {gt_text}") logging.info(f"{ridx} Pr Text: {pred_text}") # Format cer and wer to 2 decimal places @@ -561,32 +627,35 @@ def evaluate_dir( eou_trailing = float('nan') eou_rms_ratio = float('nan') - filewise_metrics.append( - { - 'gt_text': gt_text, - 'pred_text': pred_text, - 'gt_audio_text': gt_audio_text, - 'detailed_cer': detailed_cer, - 'detailed_wer': detailed_wer, - 'cer': detailed_cer[0], - 'wer': detailed_wer[0], - 'pred_gt_ssim': pred_gt_ssim, - 'pred_context_ssim': pred_context_ssim, - 'gt_context_ssim': gt_context_ssim, - 'pred_gt_ssim_alternate': pred_gt_ssim_alternate, - 'pred_context_ssim_alternate': pred_context_ssim_alternate, - 'gt_context_ssim_alternate': gt_context_ssim_alternate, - 'gt_audio_filepath': gt_audio_filepath, - 'pred_audio_filepath': pred_audio_filepath, - 'context_audio_filepath': context_audio_filepath, - 'utmosv2': utmosv2_score, - 'eou_type': eou_type, - 'eou_trailing_duration': eou_trailing, - 'eou_trail_rms_ratio': eou_rms_ratio, - 'total_gen_audio_seconds': file_duration, - 'predicted_codes_path': codes_file_lists[ridx] if has_codes else None, - } - ) + metric_row = { + 'gt_text': gt_text, + 'pred_text': pred_text, + 'gt_audio_text': gt_audio_text, + 'detailed_cer': detailed_cer, + 'detailed_wer': detailed_wer, + 'cer': detailed_cer[0], + 'wer': detailed_wer[0], + 'pred_gt_ssim': pred_gt_ssim, + 'pred_context_ssim': pred_context_ssim, + 'gt_context_ssim': gt_context_ssim, + 'pred_gt_ssim_alternate': pred_gt_ssim_alternate, + 'pred_context_ssim_alternate': pred_context_ssim_alternate, + 'gt_context_ssim_alternate': gt_context_ssim_alternate, + 'gt_audio_filepath': gt_audio_filepath, + 'pred_audio_filepath': pred_audio_filepath, + 'context_audio_filepath': context_audio_filepath, + 'utmosv2': utmosv2_score, + 'eou_type': eou_type, + 'eou_trailing_duration': eou_trailing, + 'eou_trail_rms_ratio': eou_rms_ratio, + 'total_gen_audio_seconds': file_duration, + 'predicted_codes_path': codes_file_lists[ridx] if has_codes else None, + } + if with_emotion_metrics: + metric_row['pred_gt_esim'] = pred_gt_esim + metric_row['pred_gt_ems'] = pred_gt_ems + + filewise_metrics.append(metric_row) return filewise_metrics @@ -601,6 +670,10 @@ def evaluate( with_utmosv2=True, with_fcd=True, codec_model_path=None, + with_emotion_metrics=False, + emotion_model_size="small", + emotion_embedding_type="head_concat", + emotion_cache_dir=None, asr_batch_size=32, eou_batch_size=32, device="cuda", @@ -636,6 +709,10 @@ def evaluate( sv_model_type=sv_model_type, asr_model_name=asr_model_name, with_utmosv2=with_utmosv2, + with_emotion_metrics=with_emotion_metrics, + emotion_model_size=emotion_model_size, + emotion_embedding_type=emotion_embedding_type, + emotion_cache_dir=emotion_cache_dir, asr_batch_size=asr_batch_size, eou_batch_size=eou_batch_size, device=device, @@ -725,6 +802,9 @@ def compute_global_metrics( sum(m['pred_context_ssim_alternate'] for m in filewise_metrics) / n ) avg_metrics['ssim_gt_context_avg_alternate'] = sum(m['gt_context_ssim_alternate'] for m in filewise_metrics) / n + if 'pred_gt_esim' in filewise_metrics[0]: + avg_metrics['esim_pred_gt_avg'] = sum(m['pred_gt_esim'] for m in filewise_metrics) / n + avg_metrics['ems_pred_gt_avg'] = sum(m['pred_gt_ems'] for m in filewise_metrics) / n # Cumulative WER/CER on ground-truth audio transcriptions (if available) gt_audio_texts = [m['gt_audio_text'] for m in filewise_metrics] @@ -778,6 +858,15 @@ def main(): parser.add_argument('--generated_audio_dir', type=str, default=None) parser.add_argument('--whisper_language', type=str, default="en") parser.add_argument('--evalset', type=str, default=None) + parser.add_argument('--with_emotion_metrics', action='store_true') + parser.add_argument('--emotion_model_size', type=str, default="small", choices=["small", "large"]) + parser.add_argument( + '--emotion_embedding_type', + type=str, + default="score_vector", + choices=["head_concat", "head_mean", "score_vector"], + ) + parser.add_argument('--emotion_cache_dir', type=str, default=None) args = parser.parse_args() if args.evalset is not None: @@ -793,6 +882,10 @@ def main(): args.whisper_language, sv_model_type="wavlm", asr_model_name="nvidia/parakeet-ctc-0.6b", + with_emotion_metrics=args.with_emotion_metrics, + emotion_model_size=args.emotion_model_size, + emotion_embedding_type=args.emotion_embedding_type, + emotion_cache_dir=args.emotion_cache_dir, ) diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluation.py b/nemo/collections/tts/modules/magpietts_inference/evaluation.py index bb9013cc9ff1..5f1ee4a1652f 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluation.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluation.py @@ -43,6 +43,10 @@ class EvaluationConfig: with_utmosv2: Whether to compute UTMOSv2 (Mean Opinion Score) metrics. with_fcd: Whether to compute Frechet Codec Distance metric. codec_model_path: Path to the audio codec model. If None, will skip computing Frechet Codec Distance metric. + with_emotion_metrics: Whether to compute ESIM and EMS using the emotion encoder. + emotion_model_size: Emotion encoder size ("small" or "large"). + emotion_embedding_type: Emotion embedding type used for ESIM. + emotion_cache_dir: Optional Hugging Face cache directory for the emotion encoder. device: Device to use for running models used during evaluation. """ @@ -53,6 +57,10 @@ class EvaluationConfig: with_utmosv2: bool = True with_fcd: bool = True codec_model_path: str = None + with_emotion_metrics: bool = False + emotion_model_size: str = "small" + emotion_embedding_type: str = "head_concat" + emotion_cache_dir: str = None device: str = "cuda" asr_batch_size: int = 32 eou_batch_size: int = 32 @@ -95,6 +103,10 @@ def evaluate_generated_audio_dir( with_utmosv2=config.with_utmosv2, with_fcd=config.with_fcd, codec_model_path=config.codec_model_path, + with_emotion_metrics=config.with_emotion_metrics, + emotion_model_size=config.emotion_model_size, + emotion_embedding_type=config.emotion_embedding_type, + emotion_cache_dir=config.emotion_cache_dir, device=config.device, eou_model_name=config.eou_model_name, asr_batch_size=config.asr_batch_size, From 106e4feaba988af2d731162c6c71a0c85cffe6cf Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Sat, 20 Jun 2026 10:51:31 -0700 Subject: [PATCH 091/109] Update emotion_embedding_type default to score_vector Signed-off-by: Edresson Casanova --- examples/tts/magpietts_inference.py | 2 +- nemo/collections/tts/modules/magpietts_inference/evaluation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index 85899679584c..61dcff60122a 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -977,7 +977,7 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None: eval_group.add_argument( '--emotion_embedding_type', type=str, - default="head_concat", + default="score_vector", choices=["head_concat", "head_mean", "score_vector"], ) eval_group.add_argument('--emotion_cache_dir', type=str, default=None) diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluation.py b/nemo/collections/tts/modules/magpietts_inference/evaluation.py index 5f1ee4a1652f..4d6ef1d5e846 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluation.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluation.py @@ -59,7 +59,7 @@ class EvaluationConfig: codec_model_path: str = None with_emotion_metrics: bool = False emotion_model_size: str = "small" - emotion_embedding_type: str = "head_concat" + emotion_embedding_type: str = "score_vector" emotion_cache_dir: str = None device: str = "cuda" asr_batch_size: int = 32 From e4864eda693d2c2a6e11b07e8040890298912572 Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Sat, 20 Jun 2026 17:07:28 -0700 Subject: [PATCH 092/109] infererence user turn end token fix Signed-off-by: Shehzeen Hussain Signed-off-by: Edresson Casanova --- nemo/collections/tts/models/easy_magpietts_inference.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index d76223466967..96135df233d1 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -1700,12 +1700,11 @@ def streaming_step( state.phoneme_steps += needs_phoneme.long() state.audio_steps += needs_audio.long() - # Same behavior as your previous code when phoneme is disabled. + # Match training: the input slot that predicts the first agent frame after + # a user/non-agent region should receive the learned user-speaking-end token. use_end_token = ( prefill_like_is_last_step and self.cfg.get("use_user_speaking_end_token", False) - and not self.cfg.get("agent_mask_include_transition_prefix", False) - and self.phoneme_tokenizer is None ) if use_end_token: From f2c4b23cf05a61e376ce387813e838cbf36051f6 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 22 Jun 2026 06:11:35 -0700 Subject: [PATCH 093/109] Update Signed-off-by: Edresson Casanova --- nemo/collections/tts/modules/magpietts_inference/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index 15c6973a4c2d..ec0802b68051 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -807,7 +807,7 @@ def collate_fn(self, batch: List[dict]) -> Dict[str, Any]: user_audio_turns.append(wav.unsqueeze(0)) user_audio_turns_lens.append(torch.tensor([wav.numel()], dtype=torch.long)) - target_turn_audio_paths = sample.get("target_audio_file_path", sample.get("target_audio_filepath", None)) + target_turn_audio_paths = sample.get("target_audio_file_path", None) if target_turn_audio_paths is not None and not isinstance(target_turn_audio_paths, list): target_turn_audio_paths = [target_turn_audio_paths] From c849cba9638fdaddbd4a7e02787719089b46f91d Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 22 Jun 2026 07:20:39 -0700 Subject: [PATCH 094/109] Add strip_text_annotations_for_metrics parameter Signed-off-by: Edresson Casanova --- examples/tts/magpietts_inference.py | 7 +++ .../evaluate_generated_audio.py | 49 ++++++++++++++++++- .../modules/magpietts_inference/evaluation.py | 3 ++ 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index 61dcff60122a..135d12ebdb4d 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -728,6 +728,7 @@ def run_inference_and_evaluation( emotion_model_size=eval_config.emotion_model_size, emotion_embedding_type=eval_config.emotion_embedding_type, emotion_cache_dir=eval_config.emotion_cache_dir, + strip_text_annotations_for_metrics=eval_config.strip_text_annotations_for_metrics, device=eval_config.device, asr_batch_size=eval_config.asr_batch_size, eou_batch_size=eval_config.eou_batch_size, @@ -973,6 +974,11 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None: eval_group.add_argument('--confidence_level', type=float, default=0.95) eval_group.add_argument('--disable_utmosv2', action='store_true') eval_group.add_argument('--with_emotion_metrics', action='store_true') + eval_group.add_argument( + '--strip_text_annotations_for_metrics', + action='store_true', + help='Strip bracket/tag/control annotations from reference and ASR hypothesis text while computing text metrics.', + ) eval_group.add_argument('--emotion_model_size', type=str, default="small", choices=["small", "large"]) eval_group.add_argument( '--emotion_embedding_type', @@ -1185,6 +1191,7 @@ def main(argv=None): emotion_model_size=args.emotion_model_size, emotion_embedding_type=args.emotion_embedding_type, emotion_cache_dir=args.emotion_cache_dir, + strip_text_annotations_for_metrics=args.strip_text_annotations_for_metrics, asr_batch_size=args.asr_batch_size, eou_batch_size=args.eou_batch_size, ) diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py index 14d6892b3a0d..3bcbde1a3d97 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py @@ -18,6 +18,7 @@ import json import os import pprint +import re import tempfile import time from collections import Counter @@ -52,6 +53,36 @@ ) +# Regexes mirrored from the IPA preprocessing script that creates +# custom["text_without_annotation"]. This is used only for text inputs +# during metric computation when requested. +_WS_RE = re.compile(r"\s+") +_SPACE_BEFORE_PUNCT_RE = re.compile(r"\s+([,.;:!?؟،؛])") +_TATWEEL_RE = re.compile("\u0640+") +_ANNOTATION_OR_MARKER_RE = re.compile( + r""" + \[[^\[\]\n]{1,512}\] # square annotation: [breath], [نقر] + | \n]{1,512}> # XML/style/language tags + | \{/?[^{}\n]{1,512}\} # curly control/pronunciation tags + | [-–—]{2,} # multi-dash cutoff: --, ---, —— + | (?<=\S)[-–—](?=\s|$) # trailing single dash after a token: word- + | (?:^|(?<=\s))[-–—](?=\s|$) # standalone dash + | \.{3,} # ASCII ellipsis + | …+ # Unicode ellipsis + | \*+ # emphasis marker: *word* + """, + re.VERBOSE, +) + + +def strip_text_annotations_from_text(text: str) -> str: + """Return orthographic text with annotation/control tokens removed.""" + text = _ANNOTATION_OR_MARKER_RE.sub(" ", str(text)) + text = _TATWEEL_RE.sub("", text) + text = _WS_RE.sub(" ", text).strip() + text = _SPACE_BEFORE_PUNCT_RE.sub(r"\1", text) + return text.strip() + FILEWISE_METRICS_TO_SAVE = [ 'cer', 'wer', @@ -388,6 +419,7 @@ def evaluate_dir( sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", with_utmosv2=True, + strip_text_annotations_for_metrics=False, with_emotion_metrics=False, emotion_model_size="small", emotion_embedding_type="score_vector", @@ -476,6 +508,8 @@ def evaluate_dir( asr_batch_size, label="predicted", ) + if strip_text_annotations_for_metrics: + pred_texts = [strip_text_annotations_from_text(text) for text in pred_texts] pred_texts = [text_processor.process_text_for_wer(text) for text in pred_texts] # Transcribe ground truth audios if len(gt_audio_paths) > 0: @@ -489,6 +523,8 @@ def evaluate_dir( asr_batch_size, label="ground truth", ) + if strip_text_annotations_for_metrics: + gt_audio_texts = [strip_text_annotations_from_text(text) for text in gt_audio_texts] gt_audio_texts = [text_processor.process_text_for_wer(text) for text in gt_audio_texts] else: gt_audio_texts = [None] * len(records) @@ -502,7 +538,10 @@ def evaluate_dir( text_field = 'normalized_text' else: text_field = 'text' - processed_text = text_processor.process_text_for_wer(record[text_field]) + text = record[text_field] + if strip_text_annotations_for_metrics: + text = strip_text_annotations_from_text(text) + processed_text = text_processor.process_text_for_wer(text) gt_texts_processed.append(processed_text) # 7. Batched EoU classification @@ -668,6 +707,7 @@ def evaluate( sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", with_utmosv2=True, + strip_text_annotations_for_metrics=False, with_fcd=True, codec_model_path=None, with_emotion_metrics=False, @@ -709,6 +749,7 @@ def evaluate( sv_model_type=sv_model_type, asr_model_name=asr_model_name, with_utmosv2=with_utmosv2, + strip_text_annotations_for_metrics=strip_text_annotations_for_metrics, with_emotion_metrics=with_emotion_metrics, emotion_model_size=emotion_model_size, emotion_embedding_type=emotion_embedding_type, @@ -859,6 +900,11 @@ def main(): parser.add_argument('--whisper_language', type=str, default="en") parser.add_argument('--evalset', type=str, default=None) parser.add_argument('--with_emotion_metrics', action='store_true') + parser.add_argument( + '--strip_text_annotations_for_metrics', + action='store_true', + help='Strip bracket/tag/control annotations from reference and ASR hypothesis text while computing text metrics.', + ) parser.add_argument('--emotion_model_size', type=str, default="small", choices=["small", "large"]) parser.add_argument( '--emotion_embedding_type', @@ -883,6 +929,7 @@ def main(): sv_model_type="wavlm", asr_model_name="nvidia/parakeet-ctc-0.6b", with_emotion_metrics=args.with_emotion_metrics, + strip_text_annotations_for_metrics=args.strip_text_annotations_for_metrics, emotion_model_size=args.emotion_model_size, emotion_embedding_type=args.emotion_embedding_type, emotion_cache_dir=args.emotion_cache_dir, diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluation.py b/nemo/collections/tts/modules/magpietts_inference/evaluation.py index 4d6ef1d5e846..a92659209f31 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluation.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluation.py @@ -47,6 +47,7 @@ class EvaluationConfig: emotion_model_size: Emotion encoder size ("small" or "large"). emotion_embedding_type: Emotion embedding type used for ESIM. emotion_cache_dir: Optional Hugging Face cache directory for the emotion encoder. + strip_text_annotations_for_metrics: Whether to strip annotation/control markers from reference and ASR hypothesis text before text metrics. device: Device to use for running models used during evaluation. """ @@ -61,6 +62,7 @@ class EvaluationConfig: emotion_model_size: str = "small" emotion_embedding_type: str = "score_vector" emotion_cache_dir: str = None + strip_text_annotations_for_metrics: bool = False device: str = "cuda" asr_batch_size: int = 32 eou_batch_size: int = 32 @@ -107,6 +109,7 @@ def evaluate_generated_audio_dir( emotion_model_size=config.emotion_model_size, emotion_embedding_type=config.emotion_embedding_type, emotion_cache_dir=config.emotion_cache_dir, + strip_text_annotations_for_metrics=config.strip_text_annotations_for_metrics, device=config.device, eou_model_name=config.eou_model_name, asr_batch_size=config.asr_batch_size, From a98bf4729d7102bdf39e6d70eee5a5c3c006500f Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 22 Jun 2026 09:35:47 -0700 Subject: [PATCH 095/109] Remove symbolic links on multiturn eval and added ground truth multiturn audio Signed-off-by: Edresson Casanova --- examples/tts/magpietts_inference.py | 30 ++-- .../modules/magpietts_inference/inference.py | 131 +++++++++++------- 2 files changed, 102 insertions(+), 59 deletions(-) diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index 135d12ebdb4d..a9ebb5c1bfda 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -451,7 +451,14 @@ def _wait_for_multiturn_rank_manifests(repeat_audio_dir: str, world_size: int, t raise RuntimeError(f"Timed out waiting for multiturn rank manifests: {missing}") -def _copy_or_link(src: str, dst: str, required: bool = False) -> None: +def _path_is_under_dir(path: str, directory: str) -> bool: + try: + return os.path.commonpath([os.path.realpath(path), os.path.realpath(directory)]) == os.path.realpath(directory) + except ValueError: + return False + + +def _move_or_copy_rank_output(src: str, dst: str, required: bool = False, rank_dir: str = None) -> None: if src is None or not os.path.exists(src): if os.path.lexists(dst): os.remove(dst) @@ -464,8 +471,13 @@ def _copy_or_link(src: str, dst: str, required: bool = False) -> None: if os.path.lexists(dst): os.remove(dst) - # Prefer real files for evaluator inputs; broken symlinks confuse librosa/UTMOS. - shutil.copyfile(src, dst) + # Move files produced inside rank_XXXX/ into the merged evaluation directory. + # If a manifest unexpectedly points outside rank_XXXX/ (for example an + # absolute dataset path), copy instead so original input data is not moved. + if rank_dir is not None and not _path_is_under_dir(src, rank_dir): + shutil.copyfile(src, dst) + else: + shutil.move(src, dst) def _merge_multiturn_rank_outputs(repeat_audio_dir: str, world_size: int, save_predicted_codes: bool) -> str: @@ -473,8 +485,8 @@ def _merge_multiturn_rank_outputs(repeat_audio_dir: str, world_size: int, save_p Each rank writes local files named predicted_audio_0.wav, target_audio_0.wav, context_audio_0.wav, predicted_codes_0.pt, ... inside rank_XXXX/. This - function remaps them to contiguous global indices in repeat_audio_dir/ and - writes a merged turn-level manifest. + function moves rank-local artifacts to contiguous global indices in + repeat_audio_dir/ and writes a merged turn-level manifest. """ # clean previous merged files for pattern in [ @@ -502,23 +514,23 @@ def _merge_multiturn_rank_outputs(repeat_audio_dir: str, world_size: int, save_p for local_idx, record in enumerate(rank_records): pred_src = os.path.join(rank_dir, f"predicted_audio_{local_idx}.wav") pred_dst = os.path.join(repeat_audio_dir, f"predicted_audio_{global_idx}.wav") - _copy_or_link(pred_src, pred_dst, required=True) + _move_or_copy_rank_output(pred_src, pred_dst, required=True, rank_dir=rank_dir) if save_predicted_codes: code_src = os.path.join(rank_dir, f"predicted_codes_{local_idx}.pt") code_dst = os.path.join(repeat_audio_dir, f"predicted_codes_{global_idx}.pt") - _copy_or_link(code_src, code_dst, required=False) + _move_or_copy_rank_output(code_src, code_dst, required=False, rank_dir=rank_dir) target_src = os.path.join(rank_dir, record.get("audio_filepath", f"target_audio_{local_idx}.wav")) target_dst = os.path.join(repeat_audio_dir, f"target_audio_{global_idx}.wav") - _copy_or_link(target_src, target_dst, required=True) + _move_or_copy_rank_output(target_src, target_dst, required=True, rank_dir=rank_dir) context_src = os.path.join( rank_dir, record.get("context_audio_filepath", f"context_audio_{local_idx}.wav"), ) context_dst = os.path.join(repeat_audio_dir, f"context_audio_{global_idx}.wav") - _copy_or_link(context_src, context_dst, required=True) + _move_or_copy_rank_output(context_src, context_dst, required=True, rank_dir=rank_dir) merged = dict(record) merged["audio_filepath"] = f"target_audio_{global_idx}.wav" diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index ec0802b68051..90fe035090aa 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -797,6 +797,9 @@ def collate_fn(self, batch: List[dict]) -> Dict[str, Any]: if not isinstance(user_audio_paths, list): user_audio_paths = [] + # make the user audios path absolute if needed + user_audio_paths = [self._resolve_path(path) for path in user_audio_paths] + user_audio_turns = [] user_audio_turns_lens = [] for turn_id in range(max_turns): @@ -807,10 +810,16 @@ def collate_fn(self, batch: List[dict]) -> Dict[str, Any]: user_audio_turns.append(wav.unsqueeze(0)) user_audio_turns_lens.append(torch.tensor([wav.numel()], dtype=torch.long)) - target_turn_audio_paths = sample.get("target_audio_file_path", None) - if target_turn_audio_paths is not None and not isinstance(target_turn_audio_paths, list): + target_turn_audio_paths = sample.get("target_audio_file_path", sample.get("target_audio_filepath", None)) + if target_turn_audio_paths is None: + target_turn_audio_paths = [] + elif not isinstance(target_turn_audio_paths, list): target_turn_audio_paths = [target_turn_audio_paths] + # make the target audio path absolute if needed + target_turn_audio_paths = [self._resolve_path(path) for path in target_turn_audio_paths] + target_audio_path = self._resolve_path(sample.get("audio_filepath")) + return { "idx": torch.tensor([int(sample["idx"])], dtype=torch.long), "raw_record": sample, @@ -822,7 +831,7 @@ def collate_fn(self, batch: List[dict]) -> Dict[str, Any]: "context_audio_lengths": context_audio_lens, "user_audio_turns": user_audio_turns, "user_audio_turns_lens": user_audio_turns_lens, - "target_audio_path": sample.get("audio_filepath"), + "target_audio_path": target_audio_path, "target_turn_audio_paths": target_turn_audio_paths, "languages": [sample.get("language", "en")], } @@ -1064,10 +1073,10 @@ def _move_batch_to_device(batch: Dict[str, Any], device: torch.device) -> Dict[s return out @staticmethod - def _copy_or_link( + def _copy_file( src: Optional[str], dst: str, required: bool = False, description: str = "audio" ) -> Optional[str]: - """Copy/symlink an audio artifact and optionally fail fast if missing. + """Copy an audio artifact and optionally fail fast if missing. Evaluation later expects target_audio_*.wav/context_audio_*.wav to exist. Silently skipping those files makes evaluate_generated_audio_dir fail much @@ -1086,17 +1095,13 @@ def _copy_or_link( return None os.makedirs(os.path.dirname(dst), exist_ok=True) - try: - if os.path.lexists(dst): - os.remove(dst) - os.symlink(os.path.abspath(src), dst) - except Exception: - shutil.copyfile(src, dst) + if os.path.lexists(dst): + os.remove(dst) + shutil.copy(src, dst) if required and not os.path.exists(dst): raise FileNotFoundError( - f"Failed to materialize required {description}: src={src}, dst={dst}. " - "The destination may be a broken symlink." + f"Failed to materialize required {description}: src={src}, dst={dst}." ) return dst @@ -1529,13 +1534,9 @@ def _run_multiturn_user_audio_inference( self._delete_old_generated_files(output_dir) mt_debug_output_dir = self._get_multiturn_debug_output_dir(output_dir) - debug_user_dir = os.path.join(mt_debug_output_dir, "debug_user_turns") debug_mixed_dir = os.path.join(mt_debug_output_dir, "debug_mixed_user_agent") - debug_full_agent_dir = os.path.join(mt_debug_output_dir, "debug_full_agent") if self.config.save_debug_multiturn_audio: - os.makedirs(debug_user_dir, exist_ok=True) os.makedirs(debug_mixed_dir, exist_ok=True) - os.makedirs(debug_full_agent_dir, exist_ok=True) logging.info(f"Saving multiturn debug/listening audios under audios_MT: {mt_debug_output_dir}") rank = int(getattr(self, "distributed_rank", 0)) @@ -1632,7 +1633,7 @@ def _run_multiturn_user_audio_inference( codec_file_paths.append(saved_code_path) turn_context_path = os.path.join(output_dir, f"context_audio_{item_idx}.wav") - self._copy_or_link( + self._copy_file( context_audio_path, turn_context_path, required=True, @@ -1655,7 +1656,7 @@ def _run_multiturn_user_audio_inference( target_src = context_audio_path target_dst = os.path.join(output_dir, f"target_audio_{item_idx}.wav") - self._copy_or_link( + self._copy_file( target_src, target_dst, required=True, @@ -1682,14 +1683,6 @@ def _run_multiturn_user_audio_inference( ) item_idx += 1 - if self.config.save_debug_multiturn_audio: - debug_sample_stem = self._target_audio_stem_for_debug(raw_record, sample_idx) - full_agent_path = os.path.join( - debug_full_agent_dir, - f"{debug_sample_stem}__sample_{sample_idx}__predicted_full_agent.wav", - ) - sf.write(full_agent_path, aligned_agent.numpy(), sample_rate) - if self.config.save_debug_multiturn_audio and "user_audio_turns" in batch: self._save_debug_user_agent_audio( batch=batch, @@ -1700,7 +1693,6 @@ def _run_multiturn_user_audio_inference( aligned_agent=aligned_agent, samples_per_prediction_frame=samples_per_prediction_frame, output_dir=output_dir, - debug_user_dir=debug_user_dir, debug_mixed_dir=debug_mixed_dir, ) @@ -1750,7 +1742,6 @@ def _save_debug_user_agent_audio( aligned_agent: torch.Tensor, samples_per_prediction_frame: float, output_dir: str, - debug_user_dir: str, debug_mixed_dir: str, ) -> None: sample_rate = getattr(self.model, "output_sample_rate", self.model.sample_rate) @@ -1758,25 +1749,28 @@ def _save_debug_user_agent_audio( first_user_len_in = int(batch["user_audio_turns_lens"][0][0].detach().cpu().item()) first_user_delay_out = int(round(first_user_len_in * sample_rate / self.model.sample_rate)) + def load_debug_audio(path: Optional[str]) -> Optional[torch.Tensor]: + if path is None or path == "" or not os.path.exists(path): + return None + audio, sr = sf.read(path, dtype="float32", always_2d=False) + wav = torch.as_tensor(audio, dtype=torch.float32) + if wav.ndim == 2: + wav = wav.mean(dim=1) + wav = wav.flatten() + if sr != sample_rate: + wav = resample(wav.unsqueeze(0), sr, sample_rate).squeeze(0) + return wav.contiguous() + user_segments = [] - for turn_id, _, _ in turn_frame_ranges: + user_turns_for_gt = [] + for local_turn_idx, (turn_id, _, _) in enumerate(turn_frame_ranges): if turn_id >= len(batch["user_audio_turns"]): continue turn_audio = batch["user_audio_turns"][turn_id][0].detach().cpu().float() turn_audio_len = int(batch["user_audio_turns_lens"][turn_id][0].detach().cpu().item()) turn_audio = turn_audio[:turn_audio_len] turn_audio_out = resample(turn_audio.unsqueeze(0), self.model.sample_rate, sample_rate).squeeze(0) - - debug_turn_stem = self._target_audio_stem_for_debug( - raw_record, - sample_idx, - local_turn_idx=int(turn_id), - ) - user_turn_path = os.path.join( - debug_user_dir, - f"{debug_turn_stem}__sample_{sample_idx}__turn_{turn_id}__user.wav", - ) - sf.write(user_turn_path, turn_audio_out.numpy(), sample_rate) + user_turns_for_gt.append((local_turn_idx, turn_audio_out)) if turn_id == 0: user_start_sample = 0 @@ -1802,15 +1796,6 @@ def _save_debug_user_agent_audio( user_pad[: user_ch.numel()] = user_ch agent_pad[: agent_ch.numel()] = agent_ch - mono_mix = torch.clamp(user_pad + agent_pad, min=-1.0, max=1.0) - sf.write( - os.path.join( - debug_mixed_dir, - f"{debug_sample_stem}__sample_{sample_idx}__user_agent_mixed_mono.wav", - ), - mono_mix.numpy(), - sample_rate, - ) stereo = torch.stack([user_pad, agent_pad], dim=1).numpy() sf.write( os.path.join( @@ -1820,3 +1805,49 @@ def _save_debug_user_agent_audio( stereo, sample_rate, ) + + target_turn_audio_paths = batch.get("target_turn_audio_paths", []) or [] + if isinstance(target_turn_audio_paths, str): + target_turn_audio_paths = [target_turn_audio_paths] + + gt_user_segments = [] + gt_agent_segments = [] + cursor = 0 + for local_turn_idx, user_wav in user_turns_for_gt: + gt_user_segments.append((cursor, user_wav)) + cursor += user_wav.numel() + + target_wav = None + if local_turn_idx < len(target_turn_audio_paths): + target_wav = load_debug_audio(target_turn_audio_paths[local_turn_idx]) + + if target_wav is None: + logging.warning( + "Could not load target turn audio for multiturn ground-truth debug stereo; " + f"sample_idx={sample_idx}, local_turn_idx={local_turn_idx}" + ) + continue + + gt_agent_segments.append((cursor, target_wav)) + cursor += target_wav.numel() + + if gt_user_segments or gt_agent_segments: + gt_len = 0 + for start, wav in gt_user_segments + gt_agent_segments: + gt_len = max(gt_len, start + wav.numel()) + gt_user_ch = torch.zeros(gt_len) + gt_agent_ch = torch.zeros(gt_len) + for start, wav in gt_user_segments: + gt_user_ch[start : start + wav.numel()] += wav + for start, wav in gt_agent_segments: + gt_agent_ch[start : start + wav.numel()] += wav + + gt_stereo = torch.stack([gt_user_ch, gt_agent_ch], dim=1).numpy() + sf.write( + os.path.join( + debug_mixed_dir, + f"{debug_sample_stem}__sample_{sample_idx}__user_agent_ground_truth_stereo.wav", + ), + gt_stereo, + sample_rate, + ) From f03723a091247429ebf5659befe3c6894c60468b Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 22 Jun 2026 11:52:49 -0700 Subject: [PATCH 096/109] Fix empty turn issue for no annotation data Signed-off-by: Edresson Casanova --- .../modules/magpietts_inference/inference.py | 110 ++++++++++++------ 1 file changed, 74 insertions(+), 36 deletions(-) diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index 90fe035090aa..a21be2fadc81 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -768,6 +768,10 @@ def _as_turn_list(value) -> List[str]: return [str(x) for x in value] return [str(value)] + @staticmethod + def _has_valid_turn_text(text: str) -> bool: + return any(ch.isalnum() for ch in str(text or "")) + def collate_fn(self, batch: List[dict]) -> Dict[str, Any]: if len(batch) != 1: raise RuntimeError("multiturn_user_audio inference currently requires batch_size=1.") @@ -780,35 +784,58 @@ def collate_fn(self, batch: List[dict]) -> Dict[str, Any]: raw_turn_texts = self._as_turn_list(sample["text"])[: self.max_eval_turns] max_turns = len(raw_turn_texts) - batched_turns = [] - batched_turn_lens = [] - valid_turn_masks = [] - for turn_text in raw_turn_texts: - ids = model.tokenizer.encode(turn_text, tokenizer_name=main_tokenizer_name) + [model.eos_id] - batched_turns.append(torch.tensor([ids], dtype=torch.long)) - batched_turn_lens.append(torch.tensor([len(ids)], dtype=torch.long)) - valid_turn_masks.append(torch.tensor([True], dtype=torch.bool)) - - context_path = self._resolve_path(sample.get("context_audio_filepath")) - context_audio = self._load_audio_1d(context_path, sample_rate).unsqueeze(0) - context_audio_lens = torch.tensor([context_audio.size(1)], dtype=torch.long) - user_audio_paths = sample.get("user_audio_file_path", None) if not isinstance(user_audio_paths, list): user_audio_paths = [] # make the user audios path absolute if needed user_audio_paths = [self._resolve_path(path) for path in user_audio_paths] - - user_audio_turns = [] - user_audio_turns_lens = [] + + raw_user_audio_turns = [] for turn_id in range(max_turns): if turn_id < len(user_audio_paths) and user_audio_paths[turn_id]: wav = self._load_audio_1d(user_audio_paths[turn_id], sample_rate) else: wav = torch.zeros(int(2 * sample_rate), dtype=torch.float32) + raw_user_audio_turns.append(wav) + + batched_turns = [] + batched_turn_lens = [] + valid_turn_masks = [] + user_audio_turns = [] + user_audio_turns_lens = [] + turn_ids = [] + pending_user_audio_turns = [] + skipped_turns = 0 + for turn_id, turn_text in enumerate(raw_turn_texts): + pending_user_audio_turns.append(raw_user_audio_turns[turn_id]) + if not self._has_valid_turn_text(turn_text): + skipped_turns += 1 + continue + + ids = model.tokenizer.encode(turn_text, tokenizer_name=main_tokenizer_name) + [model.eos_id] + batched_turns.append(torch.tensor([ids], dtype=torch.long)) + batched_turn_lens.append(torch.tensor([len(ids)], dtype=torch.long)) + valid_turn_masks.append(torch.tensor([True], dtype=torch.bool)) + wav = ( + torch.cat(pending_user_audio_turns, dim=0) + if pending_user_audio_turns + else raw_user_audio_turns[turn_id] + ) user_audio_turns.append(wav.unsqueeze(0)) user_audio_turns_lens.append(torch.tensor([wav.numel()], dtype=torch.long)) + turn_ids.append(turn_id) + pending_user_audio_turns = [] + + if skipped_turns > 0: + logging.info( + f"Skipping {skipped_turns} empty or punctuation-only multiturn agent turns " + f"for sample_idx={sample.get('idx')}" + ) + + context_path = self._resolve_path(sample.get("context_audio_filepath")) + context_audio = self._load_audio_1d(context_path, sample_rate).unsqueeze(0) + context_audio_lens = torch.tensor([context_audio.size(1)], dtype=torch.long) target_turn_audio_paths = sample.get("target_audio_file_path", sample.get("target_audio_filepath", None)) if target_turn_audio_paths is None: @@ -827,6 +854,7 @@ def collate_fn(self, batch: List[dict]) -> Dict[str, Any]: "batched_turns": batched_turns, "batched_turn_lens": batched_turn_lens, "valid_turn_masks": valid_turn_masks, + "turn_ids": turn_ids, "context_audio": context_audio, "context_audio_lengths": context_audio_lens, "user_audio_turns": user_audio_turns, @@ -1212,10 +1240,12 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): decode_start_frame = 0 max_decoder_steps = params.max_decoder_steps - for turn_id in range(len(batch["batched_turns"])): - turn_text = batch["batched_turns"][turn_id].to(device) - turn_lens = batch["batched_turn_lens"][turn_id].to(device) - valid_mask = batch["valid_turn_masks"][turn_id].to(device) + turn_ids = batch.get("turn_ids", list(range(len(batch["batched_turns"])))) + for local_turn_idx in range(len(batch["batched_turns"])): + turn_id = int(turn_ids[local_turn_idx]) + turn_text = batch["batched_turns"][local_turn_idx].to(device) + turn_lens = batch["batched_turn_lens"][local_turn_idx].to(device) + valid_mask = batch["valid_turn_masks"][local_turn_idx].to(device) if not bool(valid_mask[0].item()): continue @@ -1233,15 +1263,15 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): state.last_phoneme_tokens = None if not model.cfg.get("condition_on_user_speech", False): - user_audio = batch["user_audio_turns"][turn_id] + user_audio = batch["user_audio_turns"][local_turn_idx] user_audio_prefill_steps = int(round(user_audio.size(-1) / model.input_samples_per_frame)) user_audio_prefill_tokens = torch.full( (1, user_audio_prefill_steps), model.pad_id, dtype=torch.long, device=device ) user_audio_channel_embedding = None else: - user_audio = batch["user_audio_turns"][turn_id] - user_audio_lens = batch["user_audio_turns_lens"][turn_id] + user_audio = batch["user_audio_turns"][local_turn_idx] + user_audio_lens = batch["user_audio_turns_lens"][local_turn_idx] user_audio_codes, user_audio_codes_lens = model._codec_helper.audio_to_codes( user_audio, user_audio_lens ) @@ -1329,7 +1359,7 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): ) turn_start_frame = sum(p.size(-1) for p in state.all_predictions) - if turn_id == 0: + if not turn_frame_ranges: state.audio_prediction_start_idx.fill_(turn_start_frame) decode_start_frame = turn_start_frame @@ -1588,6 +1618,13 @@ def _run_multiturn_user_audio_inference( f"num_turns={len(raw_turn_texts)}" ) + if not batch["batched_turns"]: + logging.warning( + "Skipping multiturn_user_audio sample with no valid agent text turns after filtering; " + f"sample_idx={sample_idx}" + ) + continue + start_time = time.time() output, turn_frame_ranges, decode_start_frame, generated_codes = self._run_multiturn_generation(batch) elapsed = time.time() - start_time @@ -1612,6 +1649,7 @@ def _run_multiturn_user_audio_inference( target_turn_audio_paths = batch.get("target_turn_audio_paths") for local_turn_idx, (turn_id, start_frame, end_frame) in enumerate(turn_frame_ranges): + source_turn_idx = int(turn_id) rel_start_frame = start_frame - decode_start_frame rel_end_frame = end_frame - decode_start_frame start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) @@ -1643,7 +1681,7 @@ def _run_multiturn_user_audio_inference( target_src = self._resolve_target_audio_for_turn( raw_record=raw_record, target_turn_audio_paths=target_turn_audio_paths, - local_turn_idx=local_turn_idx, + local_turn_idx=source_turn_idx, audio_base_dir=audio_base_dir, ) if target_src is None or not os.path.exists(target_src): @@ -1667,7 +1705,7 @@ def _run_multiturn_user_audio_inference( { "audio_filepath": f"target_audio_{item_idx}.wav", "context_audio_filepath": f"context_audio_{item_idx}.wav", - "text": raw_turn_texts[local_turn_idx] if local_turn_idx < len(raw_turn_texts) else "", + "text": raw_turn_texts[source_turn_idx] if source_turn_idx < len(raw_turn_texts) else "", "speaker": str(sample_idx), "source_sample_idx": sample_idx, "turn_id": int(turn_id), @@ -1764,18 +1802,18 @@ def load_debug_audio(path: Optional[str]) -> Optional[torch.Tensor]: user_segments = [] user_turns_for_gt = [] for local_turn_idx, (turn_id, _, _) in enumerate(turn_frame_ranges): - if turn_id >= len(batch["user_audio_turns"]): + if local_turn_idx >= len(batch["user_audio_turns"]): continue - turn_audio = batch["user_audio_turns"][turn_id][0].detach().cpu().float() - turn_audio_len = int(batch["user_audio_turns_lens"][turn_id][0].detach().cpu().item()) + turn_audio = batch["user_audio_turns"][local_turn_idx][0].detach().cpu().float() + turn_audio_len = int(batch["user_audio_turns_lens"][local_turn_idx][0].detach().cpu().item()) turn_audio = turn_audio[:turn_audio_len] turn_audio_out = resample(turn_audio.unsqueeze(0), self.model.sample_rate, sample_rate).squeeze(0) - user_turns_for_gt.append((local_turn_idx, turn_audio_out)) + user_turns_for_gt.append((int(turn_id), turn_audio_out)) - if turn_id == 0: + if local_turn_idx == 0: user_start_sample = 0 else: - prev_turn_end_frame = turn_frame_ranges[turn_id - 1][2] + prev_turn_end_frame = turn_frame_ranges[local_turn_idx - 1][2] rel_prev_end_frame = prev_turn_end_frame - decode_start_frame user_start_sample = first_user_delay_out + int( round(rel_prev_end_frame * samples_per_prediction_frame) @@ -1813,18 +1851,18 @@ def load_debug_audio(path: Optional[str]) -> Optional[torch.Tensor]: gt_user_segments = [] gt_agent_segments = [] cursor = 0 - for local_turn_idx, user_wav in user_turns_for_gt: + for source_turn_idx, user_wav in user_turns_for_gt: gt_user_segments.append((cursor, user_wav)) cursor += user_wav.numel() target_wav = None - if local_turn_idx < len(target_turn_audio_paths): - target_wav = load_debug_audio(target_turn_audio_paths[local_turn_idx]) + if source_turn_idx < len(target_turn_audio_paths): + target_wav = load_debug_audio(target_turn_audio_paths[source_turn_idx]) if target_wav is None: logging.warning( "Could not load target turn audio for multiturn ground-truth debug stereo; " - f"sample_idx={sample_idx}, local_turn_idx={local_turn_idx}" + f"sample_idx={sample_idx}, source_turn_idx={source_turn_idx}" ) continue From 1b160913200bf0d153c892897e7d4b4f1162a90d Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 22 Jun 2026 16:54:09 -0700 Subject: [PATCH 097/109] Fix short words issue Signed-off-by: Edresson Casanova --- .../modules/magpietts_inference/inference.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index a21be2fadc81..5d0105f2c626 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -1325,11 +1325,22 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): user_audio_channel_embedding = user_audio_embedded delay_tokens = int(state.config.training_mode.streaming_speech_delay) - delay_tokens = min(delay_tokens, int(turn_lens[0].item()), user_audio_prefill_steps) + delay_tokens = min(delay_tokens, user_audio_prefill_steps) + + num_warmup_text_tokens = min(delay_tokens, int(turn_lens[0].item()), turn_text.size(1)) + # handle short turns so that it does not advance the text channel and keep the delay_tokens. + warmup_tokens = torch.full( + (B, delay_tokens), + model.pad_id, + dtype=turn_text.dtype, + device=device, + ) + + if num_warmup_text_tokens > 0: + warmup_tokens[:, :num_warmup_text_tokens] = turn_text[:, :num_warmup_text_tokens] - warmup_tokens = turn_text[:, :delay_tokens] - turn_text = turn_text[:, delay_tokens:] - turn_lens = torch.clamp(turn_lens - delay_tokens, min=0) + turn_text = turn_text[:, num_warmup_text_tokens:] + turn_lens = torch.clamp(turn_lens - num_warmup_text_tokens, min=0) if user_audio_channel_embedding is not None and delay_tokens > 0: warmup_user_audio = user_audio_channel_embedding[:, -delay_tokens:] From 0b65f2ffb68fe2dedcb0cb7be1e69c3822152388 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 23 Jun 2026 08:36:08 -0700 Subject: [PATCH 098/109] Add phoneme prediction on exportable .json and .csv files Signed-off-by: Edresson Casanova --- examples/tts/magpietts_inference.py | 44 +-- .../evaluate_generated_audio.py | 6 + .../modules/magpietts_inference/inference.py | 254 +++++++++++++++++- 3 files changed, 281 insertions(+), 23 deletions(-) diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index a9ebb5c1bfda..1ec31a4414bd 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -190,6 +190,9 @@ def _enrich_filewise_metrics_with_manifest(filewise_metrics: list, manifest_path "rank_local_idx", "audio_filepath", "context_audio_filepath", + "predicted_phoneme_text", + "predicted_phoneme_tokens", + "predicted_phoneme_token_labels", ]: if key in record and key not in new_row: new_row[key] = record[key] @@ -249,6 +252,11 @@ def turn_sort_key(r): eou_type_turns = [r.get("eou_type") for r in turns] eou_trailing_duration_turns = [r.get("eou_trailing_duration") for r in turns] eou_trail_rms_ratio_turns = [r.get("eou_trail_rms_ratio") for r in turns] + predicted_phoneme_text_turns = [r.get("predicted_phoneme_text", "") for r in turns] + predicted_phoneme_tokens_turns = [r.get("predicted_phoneme_tokens", []) for r in turns] + predicted_phoneme_token_labels_turns = [ + r.get("predicted_phoneme_token_labels", []) for r in turns + ] grouped_rows.append( { @@ -280,6 +288,9 @@ def turn_sort_key(r): "eou_type_turns": eou_type_turns, "eou_trailing_duration_turns": eou_trailing_duration_turns, "eou_trail_rms_ratio_turns": eou_trail_rms_ratio_turns, + "predicted_phoneme_text_turns": predicted_phoneme_text_turns, + "predicted_phoneme_tokens_turns": predicted_phoneme_tokens_turns, + "predicted_phoneme_token_labels_turns": predicted_phoneme_token_labels_turns, "reference_text": [r.get("gt_text", "") for r in turns], "asr_hyp": [r.get("pred_text", "") for r in turns], "pred_audio_paths": [r.get("pred_audio_filepath", "") for r in turns], @@ -332,6 +343,9 @@ def _write_grouped_multiturn_filewise_metrics_csv(csv_path: str, grouped_rows: l "pred_audio_paths", "reference_text", "asr_hyp", + "predicted_phoneme_text_turns", + "predicted_phoneme_tokens_turns", + "predicted_phoneme_token_labels_turns", ] def csv_value(value): @@ -451,14 +465,7 @@ def _wait_for_multiturn_rank_manifests(repeat_audio_dir: str, world_size: int, t raise RuntimeError(f"Timed out waiting for multiturn rank manifests: {missing}") -def _path_is_under_dir(path: str, directory: str) -> bool: - try: - return os.path.commonpath([os.path.realpath(path), os.path.realpath(directory)]) == os.path.realpath(directory) - except ValueError: - return False - - -def _move_or_copy_rank_output(src: str, dst: str, required: bool = False, rank_dir: str = None) -> None: +def _copy_or_link(src: str, dst: str, required: bool = False) -> None: if src is None or not os.path.exists(src): if os.path.lexists(dst): os.remove(dst) @@ -471,13 +478,8 @@ def _move_or_copy_rank_output(src: str, dst: str, required: bool = False, rank_d if os.path.lexists(dst): os.remove(dst) - # Move files produced inside rank_XXXX/ into the merged evaluation directory. - # If a manifest unexpectedly points outside rank_XXXX/ (for example an - # absolute dataset path), copy instead so original input data is not moved. - if rank_dir is not None and not _path_is_under_dir(src, rank_dir): - shutil.copyfile(src, dst) - else: - shutil.move(src, dst) + # Prefer real files for evaluator inputs; broken symlinks confuse librosa/UTMOS. + shutil.copyfile(src, dst) def _merge_multiturn_rank_outputs(repeat_audio_dir: str, world_size: int, save_predicted_codes: bool) -> str: @@ -485,8 +487,8 @@ def _merge_multiturn_rank_outputs(repeat_audio_dir: str, world_size: int, save_p Each rank writes local files named predicted_audio_0.wav, target_audio_0.wav, context_audio_0.wav, predicted_codes_0.pt, ... inside rank_XXXX/. This - function moves rank-local artifacts to contiguous global indices in - repeat_audio_dir/ and writes a merged turn-level manifest. + function remaps them to contiguous global indices in repeat_audio_dir/ and + writes a merged turn-level manifest. """ # clean previous merged files for pattern in [ @@ -514,23 +516,23 @@ def _merge_multiturn_rank_outputs(repeat_audio_dir: str, world_size: int, save_p for local_idx, record in enumerate(rank_records): pred_src = os.path.join(rank_dir, f"predicted_audio_{local_idx}.wav") pred_dst = os.path.join(repeat_audio_dir, f"predicted_audio_{global_idx}.wav") - _move_or_copy_rank_output(pred_src, pred_dst, required=True, rank_dir=rank_dir) + _copy_or_link(pred_src, pred_dst, required=True) if save_predicted_codes: code_src = os.path.join(rank_dir, f"predicted_codes_{local_idx}.pt") code_dst = os.path.join(repeat_audio_dir, f"predicted_codes_{global_idx}.pt") - _move_or_copy_rank_output(code_src, code_dst, required=False, rank_dir=rank_dir) + _copy_or_link(code_src, code_dst, required=False) target_src = os.path.join(rank_dir, record.get("audio_filepath", f"target_audio_{local_idx}.wav")) target_dst = os.path.join(repeat_audio_dir, f"target_audio_{global_idx}.wav") - _move_or_copy_rank_output(target_src, target_dst, required=True, rank_dir=rank_dir) + _copy_or_link(target_src, target_dst, required=True) context_src = os.path.join( rank_dir, record.get("context_audio_filepath", f"context_audio_{local_idx}.wav"), ) context_dst = os.path.join(repeat_audio_dir, f"context_audio_{global_idx}.wav") - _move_or_copy_rank_output(context_src, context_dst, required=True, rank_dir=rank_dir) + _copy_or_link(context_src, context_dst, required=True) merged = dict(record) merged["audio_filepath"] = f"target_audio_{global_idx}.wav" diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py index 3bcbde1a3d97..a3682ce09462 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py @@ -91,6 +91,9 @@ def strip_text_annotations_from_text(text: str) -> str: 'pred_gt_ems', 'pred_text', 'gt_text', + 'predicted_phoneme_text', + 'predicted_phoneme_tokens', + 'predicted_phoneme_token_labels', 'gt_audio_filepath', 'pred_audio_filepath', 'context_audio_filepath', @@ -670,6 +673,9 @@ def evaluate_dir( 'gt_text': gt_text, 'pred_text': pred_text, 'gt_audio_text': gt_audio_text, + 'predicted_phoneme_text': record.get('predicted_phoneme_text', ''), + 'predicted_phoneme_tokens': record.get('predicted_phoneme_tokens', []), + 'predicted_phoneme_token_labels': record.get('predicted_phoneme_token_labels', []), 'detailed_cer': detailed_cer, 'detailed_wer': detailed_wer, 'cer': detailed_cer[0], diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index 5d0105f2c626..4b3ede84b014 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -1237,6 +1237,7 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): ) turn_frame_ranges = [] + turn_phoneme_outputs = [] decode_start_frame = 0 max_decoder_steps = params.max_decoder_steps @@ -1249,6 +1250,8 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): if not bool(valid_mask[0].item()): continue + phoneme_start_step = len(getattr(state, "all_phoneme_predictions", []) or []) + state.finished.zero_() state.text_finished.zero_() state.audio_prediction_end_idx.fill_(-1) @@ -1405,7 +1408,25 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): state.audio_prediction_end_idx.fill_(-1) state.finished.zero_() turn_end_frame = sum(p.size(-1) for p in state.all_predictions) + phoneme_end_step = len(getattr(state, "all_phoneme_predictions", []) or []) + ( + predicted_phoneme_text, + predicted_phoneme_tokens, + predicted_phoneme_token_labels, + ) = self._decode_phoneme_prediction_slice( + getattr(state, "all_phoneme_predictions", []), + phoneme_start_step, + phoneme_end_step, + ) turn_frame_ranges.append((turn_id, turn_start_frame, turn_end_frame)) + turn_phoneme_outputs.append( + { + "turn_id": turn_id, + "predicted_phoneme_text": predicted_phoneme_text, + "predicted_phoneme_tokens": predicted_phoneme_tokens, + "predicted_phoneme_token_labels": predicted_phoneme_token_labels, + } + ) codec_sil_codes = self._ensure_codec_silence_codes() bos_id = getattr(model, "audio_bos_id", -1) @@ -1428,7 +1449,224 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): finalize_output = model.streaming_finalize(state, use_inference_mode=True) - return finalize_output, turn_frame_ranges, decode_start_frame, generated_codes + return finalize_output, turn_frame_ranges, turn_phoneme_outputs, decode_start_frame, generated_codes + + @staticmethod + def _gpt2_byte_decoder() -> Dict[str, int]: + """Return GPT-2/byte-level-BPE unicode-char -> byte lookup. + + Some NeMo tokenizer wrappers expose byte-level BPE pieces instead of + fully decoded text. Those pieces look like ``ËĪ``/``Ġ`` in JSON. They + are not mojibake from JSON itself; they are the reversible GPT-2 byte + alphabet. Reversing that alphabet gives readable IPA again. + """ + byte_values = list(range(ord("!"), ord("~") + 1)) + byte_values += list(range(ord("¡"), ord("¬") + 1)) + byte_values += list(range(ord("®"), ord("ÿ") + 1)) + unicode_values = list(byte_values) + n = 0 + for byte in range(256): + if byte not in byte_values: + byte_values.append(byte) + unicode_values.append(256 + n) + n += 1 + return {chr(unicode_value): byte for byte, unicode_value in zip(byte_values, unicode_values)} + + @classmethod + def _decode_byte_level_bpe_text(cls, text: str) -> str: + """Decode GPT-2 byte-level BPE artifacts such as ``ËĪ`` and ``Ġ``. + + If ``text`` is already normal Unicode IPA, it is returned unchanged. + """ + if not text: + return "" + + byte_artifact_markers = ("Ġ", "Ċ", "Ë", "É", "Ê", "Ã", "Å", "Î") + if not any(marker in text for marker in byte_artifact_markers): + return text.strip() + + byte_decoder = cls._gpt2_byte_decoder() + byte_values = bytearray() + for ch in text: + value = byte_decoder.get(ch) + if value is not None: + byte_values.append(value) + + if not byte_values: + return text.strip() + + try: + decoded = byte_values.decode("utf-8") + except UnicodeDecodeError: + return text.replace("Ġ", " ").replace("Ċ", "\n").strip() + + return " ".join(decoded.split()) + + @staticmethod + def _phoneme_special_token_map(tokenizer) -> Dict[int, str]: + """Map phoneme special token ids to stable debug labels.""" + special_attrs = [ + ("bos_token_id", ""), + ("eos_token_id", ""), + ("pad_token_id", ""), + ("pad", ""), + ("unk_token_id", ""), + ("mask_token_id", ""), + ] + special = {} + for attr, label in special_attrs: + value = getattr(tokenizer, attr, None) + if value is None: + continue + try: + token_id = int(value) + except Exception: + continue + special.setdefault(token_id, label) + return special + + def _decode_phoneme_id_run(self, token_ids: List[int]) -> str: + """Decode a contiguous non-special phoneme-id span to readable text.""" + if self.model.phoneme_tokenizer is None or not token_ids: + return "" + + tokenizer = self.model.phoneme_tokenizer + candidates = [] + + # Prefer a wrapped HuggingFace tokenizer if present; it knows how to + # merge byte-level BPE pieces correctly. + for decoder in (getattr(tokenizer, "tokenizer", None), getattr(tokenizer, "_tokenizer", None), tokenizer): + if decoder is None or not hasattr(decoder, "decode"): + continue + try: + candidates.append( + decoder.decode( + token_ids, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + ) + continue + except TypeError: + pass + except Exception: + continue + try: + candidates.append(decoder.decode(token_ids)) + except Exception: + continue + + # Fall back to token-piece lookup. Join without spaces because BPE pieces + # already encode whitespace via the byte-level ``Ġ`` marker. + for converter in (getattr(tokenizer, "tokenizer", None), getattr(tokenizer, "_tokenizer", None), tokenizer): + if converter is None or not hasattr(converter, "convert_ids_to_tokens"): + continue + try: + pieces = converter.convert_ids_to_tokens(token_ids) + if isinstance(pieces, (list, tuple)): + candidates.append("".join(str(piece) for piece in pieces)) + elif pieces is not None: + candidates.append(str(pieces)) + except Exception: + continue + + ids_to_tokens = getattr(tokenizer, "ids_to_tokens", None) + if ids_to_tokens is not None: + try: + if callable(ids_to_tokens): + pieces = [ids_to_tokens(int(token_id)) for token_id in token_ids] + elif isinstance(ids_to_tokens, dict): + pieces = [ids_to_tokens.get(int(token_id), "") for token_id in token_ids] + else: + pieces = [ids_to_tokens[int(token_id)] for token_id in token_ids] + candidates.append("".join(str(piece) for piece in pieces)) + except Exception: + pass + + for candidate in candidates: + text = self._decode_byte_level_bpe_text(str(candidate or "")) + if text: + return text + return "" + + def _format_phoneme_debug_sequence(self, token_ids: List[int]) -> Tuple[str, List[str]]: + """Return readable text and per-token labels while keeping specials. + + ``predicted_phoneme_tokens`` keeps the integer ids, including special + markers. This function creates the human-readable companion fields: + ``predicted_phoneme_text`` and ``predicted_phoneme_token_labels``. + """ + tokenizer = self.model.phoneme_tokenizer + special = self._phoneme_special_token_map(tokenizer) + labels: List[str] = [] + text_parts: List[str] = [] + run: List[int] = [] + + def flush_run() -> None: + if not run: + return + decoded_run = self._decode_phoneme_id_run(run) + if decoded_run: + text_parts.append(decoded_run) + run.clear() + + for token_id in token_ids: + token_id = int(token_id) + special_label = special.get(token_id) + if special_label is not None: + flush_run() + text_parts.append(special_label) + labels.append(f"{special_label}:{token_id}") + else: + run.append(token_id) + decoded_piece = self._decode_phoneme_id_run([token_id]) + labels.append(f"{token_id}:{decoded_piece}" if decoded_piece else str(token_id)) + flush_run() + + return " ".join(part for part in text_parts if part), labels + + def _decode_phoneme_prediction_slice( + self, + phoneme_predictions, + start_step: int, + end_step: int, + ) -> Tuple[str, List[int], List[str]]: + if self.model.phoneme_tokenizer is None or not phoneme_predictions: + return "", [], [] + + start_step = max(0, min(int(start_step), len(phoneme_predictions))) + end_step = max(start_step, min(int(end_step), len(phoneme_predictions))) + if end_step <= start_step: + return "", [], [] + + phoneme_tensor = torch.stack(phoneme_predictions[start_step:end_step], dim=-1) + raw_tokens = phoneme_tensor[0].detach().cpu().T.reshape(-1).long().tolist() + + tokenizer = self.model.phoneme_tokenizer + eos_id = getattr(tokenizer, "eos_token_id", None) + bos_id = getattr(tokenizer, "bos_token_id", None) + + # The streaming phoneme predictor is seeded with BOS as input before the + # first predicted phoneme. Add it explicitly to the debug sequence so the + # exported JSON shows the same boundary condition used by inference. + debug_tokens: List[int] = [] + if bos_id is not None: + try: + debug_tokens.append(int(bos_id)) + except Exception: + pass + + for token in raw_tokens: + token = int(token) + debug_tokens.append(token) + if eos_id is not None and token == int(eos_id): + break + + if not debug_tokens: + return "", [], [] + + phoneme_text, token_labels = self._format_phoneme_debug_sequence(debug_tokens) + return phoneme_text, debug_tokens, token_labels @staticmethod def _save_code_slice( @@ -1637,7 +1875,9 @@ def _run_multiturn_user_audio_inference( continue start_time = time.time() - output, turn_frame_ranges, decode_start_frame, generated_codes = self._run_multiturn_generation(batch) + output, turn_frame_ranges, turn_phoneme_outputs, decode_start_frame, generated_codes = ( + self._run_multiturn_generation(batch) + ) elapsed = time.time() - start_time predicted_audio = output.audio.float().detach().cpu() @@ -1661,6 +1901,11 @@ def _run_multiturn_user_audio_inference( for local_turn_idx, (turn_id, start_frame, end_frame) in enumerate(turn_frame_ranges): source_turn_idx = int(turn_id) + phoneme_output = ( + turn_phoneme_outputs[local_turn_idx] + if local_turn_idx < len(turn_phoneme_outputs) + else {} + ) rel_start_frame = start_frame - decode_start_frame rel_end_frame = end_frame - decode_start_frame start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) @@ -1717,6 +1962,11 @@ def _run_multiturn_user_audio_inference( "audio_filepath": f"target_audio_{item_idx}.wav", "context_audio_filepath": f"context_audio_{item_idx}.wav", "text": raw_turn_texts[source_turn_idx] if source_turn_idx < len(raw_turn_texts) else "", + "predicted_phoneme_text": phoneme_output.get("predicted_phoneme_text", ""), + "predicted_phoneme_tokens": phoneme_output.get("predicted_phoneme_tokens", []), + "predicted_phoneme_token_labels": phoneme_output.get( + "predicted_phoneme_token_labels", [] + ), "speaker": str(sample_idx), "source_sample_idx": sample_idx, "turn_id": int(turn_id), From 6a2190d2bbd6692d4ac4b76246b15f7e20ebfa47 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 23 Jun 2026 09:43:23 -0700 Subject: [PATCH 099/109] Fix try except errors and fix normalization Signed-off-by: Edresson Casanova --- examples/tts/magpietts_inference.py | 4 +- .../tts/data/text_to_speech_dataset_lhotse.py | 10 - .../tts/metrics/emotion_encoder.py | 1 - .../modules/magpietts_inference/inference.py | 189 +++++++++--------- 4 files changed, 101 insertions(+), 103 deletions(-) diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index 1ec31a4414bd..7381e95538f5 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -151,7 +151,7 @@ def _mean_finite(values: list): for value in values: try: value = float(value) - except Exception: + except (TypeError, ValueError): continue if np.isfinite(value): vals.append(value) @@ -236,7 +236,7 @@ def _group_multiturn_filewise_metrics_by_sample(filewise_metrics: list) -> list: def turn_sort_key(r): try: return int(r.get("turn_id", 0)) - except Exception: + except (TypeError, ValueError): return 0 turns = sorted(turns, key=turn_sort_key) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py index 2905de072f96..fe791d09608d 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py @@ -464,16 +464,6 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) else: text_str = cut.supervisions[0].text - should_use_ipa_as_text = ( - self.dataset_type == 'train' - and self.ipa_as_text_prob > 0.0 - and random.random() < self.ipa_as_text_prob - and cut.supervisions[0].has_custom("ipa") - and language not in self.ignore_phoneme_languages - ) - if should_use_ipa_as_text: - text_str = cut.supervisions[0].ipa - raw_text_list.append(text_str) if cut.has_custom("tokenizer_names"): # Pick a random tokenizer from the list of tokenizers diff --git a/nemo/collections/tts/metrics/emotion_encoder.py b/nemo/collections/tts/metrics/emotion_encoder.py index 7b5d0de9cb26..89224be0e4d4 100644 --- a/nemo/collections/tts/metrics/emotion_encoder.py +++ b/nemo/collections/tts/metrics/emotion_encoder.py @@ -43,7 +43,6 @@ from __future__ import annotations import argparse -import json from pathlib import Path from typing import Any, Optional, Sequence, Union diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index 4b3ede84b014..efbe5a5f4c5f 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -745,20 +745,17 @@ def _load_audio_1d(self, path: str, sample_rate: int) -> torch.Tensor: raise FileNotFoundError(f"Missing audio path: {path}") audio, sr = sf.read(path, dtype="float32", always_2d=False) - wav = torch.as_tensor(audio, dtype=torch.float32) - if wav.ndim == 2: - wav = wav.mean(dim=1) - wav = wav.flatten() - if sr != sample_rate: - wav = resample(wav.unsqueeze(0), sr, sample_rate).squeeze(0) + if audio.ndim == 2: + audio = audio.mean(axis=1) if self.normalize_audio: - try: - wav = normalize_volume(wav) - except Exception: - # Keep evaluation robust across normalize_volume signature changes. - pass + audio = normalize_volume(audio) + + wav = torch.as_tensor(audio, dtype=torch.float32).flatten() + + if sr != sample_rate: + wav = resample(wav.unsqueeze(0), sr, sample_rate).squeeze(0) return wav.contiguous() @@ -1502,92 +1499,107 @@ def _decode_byte_level_bpe_text(cls, text: str) -> str: return " ".join(decoded.split()) + @staticmethod + def _token_id_as_int(value) -> Optional[int]: + """Return a scalar token id as int, or None when it is not a valid id.""" + if value is None: + return None + if isinstance(value, torch.Tensor): + if value.numel() != 1: + return None + value = value.detach().cpu().item() + if isinstance(value, int): + return value + if isinstance(value, str): + value = value.strip() + if not value: + return None + signless = value[1:] if value[0] in ("+", "-") else value + if signless.isdigit(): + return int(value) + return None + @staticmethod def _phoneme_special_token_map(tokenizer) -> Dict[int, str]: - """Map phoneme special token ids to stable debug labels.""" + """Map phoneme special token ids to stable debug labels. + + IPABPETokenizer stores its specials directly in ``_inv_vocab`` as + ````, ````, and ````. Prefer those labels so + the exported debug JSON matches the tokenizer vocabulary. + """ + special: Dict[int, str] = {} + + inv_vocab = getattr(tokenizer, "_inv_vocab", None) + if isinstance(inv_vocab, dict): + for token_id, piece in inv_vocab.items(): + token_id = EasyMagpieMultiturnUserAudioInferenceRunner._token_id_as_int(token_id) + if token_id is None: + continue + piece = str(piece) + if piece.startswith("<") and piece.endswith(">"): + special[token_id] = piece + special_attrs = [ - ("bos_token_id", ""), - ("eos_token_id", ""), - ("pad_token_id", ""), - ("pad", ""), - ("unk_token_id", ""), - ("mask_token_id", ""), + ("bos_token_id", ""), + ("bos", ""), + ("eos_token_id", ""), + ("eos", ""), + ("pad_token_id", ""), + ("pad", ""), + ("unk_token_id", ""), + ("mask_token_id", ""), ] - special = {} for attr, label in special_attrs: - value = getattr(tokenizer, attr, None) - if value is None: - continue - try: - token_id = int(value) - except Exception: - continue - special.setdefault(token_id, label) + token_id = EasyMagpieMultiturnUserAudioInferenceRunner._token_id_as_int(getattr(tokenizer, attr, None)) + if token_id is not None: + special.setdefault(token_id, label) + return special + @classmethod + def _phoneme_token_pieces(cls, tokenizer, token_ids: List[int]) -> List[str]: + """Return tokenizer pieces for ids without dropping special tokens.""" + token_ids = [int(token_id) for token_id in token_ids] + + inv_vocab = getattr(tokenizer, "_inv_vocab", None) + if isinstance(inv_vocab, dict): + return [str(inv_vocab.get(token_id, "")) for token_id in token_ids] + + internal_tokenizer = getattr(tokenizer, "_tokenizer", None) + if internal_tokenizer is not None and hasattr(internal_tokenizer, "id_to_token"): + pieces = [] + for token_id in token_ids: + piece = internal_tokenizer.id_to_token(token_id) + pieces.append("" if piece is None else str(piece)) + return pieces + + tokens = getattr(tokenizer, "tokens", None) + if isinstance(tokens, dict): + inv = {int(idx): str(piece) for piece, idx in tokens.items() if cls._token_id_as_int(idx) is not None} + return [inv.get(token_id, "") for token_id in token_ids] + + return [str(token_id) for token_id in token_ids] + def _decode_phoneme_id_run(self, token_ids: List[int]) -> str: - """Decode a contiguous non-special phoneme-id span to readable text.""" + """Decode a contiguous non-special phoneme-id span to readable IPA.""" if self.model.phoneme_tokenizer is None or not token_ids: return "" tokenizer = self.model.phoneme_tokenizer - candidates = [] + pieces = self._phoneme_token_pieces(tokenizer, token_ids) + if pieces: + # IPABPETokenizer piece strings are byte-level BPE symbols. Joining + # and byte-decoding them produces readable IPA, e.g. ``ËĪÉijËIJ`` -> ``ˈɑː``. + decoded_from_pieces = self._decode_byte_level_bpe_text("".join(pieces)) + if decoded_from_pieces: + return decoded_from_pieces - # Prefer a wrapped HuggingFace tokenizer if present; it knows how to - # merge byte-level BPE pieces correctly. - for decoder in (getattr(tokenizer, "tokenizer", None), getattr(tokenizer, "_tokenizer", None), tokenizer): - if decoder is None or not hasattr(decoder, "decode"): - continue - try: - candidates.append( - decoder.decode( - token_ids, - skip_special_tokens=False, - clean_up_tokenization_spaces=False, - ) - ) - continue - except TypeError: - pass - except Exception: - continue - try: - candidates.append(decoder.decode(token_ids)) - except Exception: - continue + decoder = getattr(tokenizer, "decode", None) + if callable(decoder): + decoded = decoder([int(token_id) for token_id in token_ids]) + return self._decode_byte_level_bpe_text(str(decoded or "")) - # Fall back to token-piece lookup. Join without spaces because BPE pieces - # already encode whitespace via the byte-level ``Ġ`` marker. - for converter in (getattr(tokenizer, "tokenizer", None), getattr(tokenizer, "_tokenizer", None), tokenizer): - if converter is None or not hasattr(converter, "convert_ids_to_tokens"): - continue - try: - pieces = converter.convert_ids_to_tokens(token_ids) - if isinstance(pieces, (list, tuple)): - candidates.append("".join(str(piece) for piece in pieces)) - elif pieces is not None: - candidates.append(str(pieces)) - except Exception: - continue - - ids_to_tokens = getattr(tokenizer, "ids_to_tokens", None) - if ids_to_tokens is not None: - try: - if callable(ids_to_tokens): - pieces = [ids_to_tokens(int(token_id)) for token_id in token_ids] - elif isinstance(ids_to_tokens, dict): - pieces = [ids_to_tokens.get(int(token_id), "") for token_id in token_ids] - else: - pieces = [ids_to_tokens[int(token_id)] for token_id in token_ids] - candidates.append("".join(str(piece) for piece in pieces)) - except Exception: - pass - - for candidate in candidates: - text = self._decode_byte_level_bpe_text(str(candidate or "")) - if text: - return text - return "" + return " ".join(pieces) def _format_phoneme_debug_sequence(self, token_ids: List[int]) -> Tuple[str, List[str]]: """Return readable text and per-token labels while keeping specials. @@ -1643,23 +1655,20 @@ def _decode_phoneme_prediction_slice( raw_tokens = phoneme_tensor[0].detach().cpu().T.reshape(-1).long().tolist() tokenizer = self.model.phoneme_tokenizer - eos_id = getattr(tokenizer, "eos_token_id", None) - bos_id = getattr(tokenizer, "bos_token_id", None) + eos_id = self._token_id_as_int(getattr(tokenizer, "eos_token_id", None)) + bos_id = self._token_id_as_int(getattr(tokenizer, "bos_token_id", None)) # The streaming phoneme predictor is seeded with BOS as input before the # first predicted phoneme. Add it explicitly to the debug sequence so the # exported JSON shows the same boundary condition used by inference. debug_tokens: List[int] = [] if bos_id is not None: - try: - debug_tokens.append(int(bos_id)) - except Exception: - pass + debug_tokens.append(bos_id) for token in raw_tokens: token = int(token) debug_tokens.append(token) - if eos_id is not None and token == int(eos_id): + if eos_id is not None and token == eos_id: break if not debug_tokens: From 59788974fc774647c912d21959227a9ff9e18711 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 23 Jun 2026 10:09:06 -0700 Subject: [PATCH 100/109] Clean up unused code Signed-off-by: Edresson Casanova --- examples/tts/magpietts_inference.py | 2 +- .../tts/data/text_to_speech_dataset_lhotse.py | 1 - nemo/collections/tts/models/easy_magpietts.py | 44 ------------------- .../modules/magpietts_inference/inference.py | 2 +- 4 files changed, 2 insertions(+), 47 deletions(-) diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index 7381e95538f5..d3efec180588 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -1068,7 +1068,7 @@ def _add_easy_magpie_args(parser: argparse.ArgumentParser) -> None: '--phoneme_sampling_method', type=str, default='argmax', - choices=['argmax', 'multinomial', 'greedy'], + choices=['argmax', 'multinomial'], help='Sampling method for phoneme prediction', ) group.add_argument('--dropout_text_input', action='store_true', help='Force dropout on text input') diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py index fe791d09608d..c9baaaa96324 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse.py @@ -463,7 +463,6 @@ def _sample_context_duration_with_available_limit(available_duration_sec: float) text_str = cut.supervisions[0].normalized_text else: text_str = cut.supervisions[0].text - raw_text_list.append(text_str) if cut.has_custom("tokenizer_names"): # Pick a random tokenizer from the list of tokenizers diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index a32b75f85daa..a95e94e79fd4 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -779,24 +779,6 @@ def prepare_audio_channel_embeddings( audio_codes_input, ) - # Note that consider the current_streaming_speech_delay tokens/user speaking tokens on the loss, - # allowing to predict them in autoregressive way - transition_prefix = int(current_streaming_speech_delay or 0) - if self.cfg.get("agent_mask_include_transition_prefix", False) and transition_prefix > 0: - agent_i = target_agent_mask.float().unsqueeze(1) - - agent_i = torch.nn.functional.pad(agent_i, (0, transition_prefix)) - loss_agent_mask = ( - torch.nn.functional.max_pool1d( - agent_i, - kernel_size=transition_prefix + 1, - stride=1, - ) - .squeeze(1) - .bool() - & valid - ) - # Embed audio tokens audio_embedded = self.embed_audio_tokens(audio_codes_input) # (B, T'-1, E) @@ -1224,32 +1206,6 @@ def process_batch( custom_mask = None if self.cfg.get("phoneme_loss_mask_padding", False): custom_mask = pb_phoneme_tokens_target[:, 0, :] != self.phoneme_tokenizer.pad # (B, T') - elif self.cfg.get("phoneme_loss_mask_include_transition", False): - # agent_mask is aligned to the speech/audio supervision region. - # Expand it left so phoneme loss also covers the phoneme->speech transition. - # The optional +1 gives one extra supervised boundary step, useful for PAD -> BOS / target-shift robustness. - transition_prefix = max( - 0, - int(current_streaming_speech_delay - current_streaming_phonemes_delay) + 1, - ) - agent_i = agent_mask.float().unsqueeze(1) # (B, 1, T) - - if transition_prefix > 0: - # Right padding + max_pool expands the active region to the left. - agent_i = torch.nn.functional.pad(agent_i, (0, transition_prefix)) - - custom_mask = ( - torch.nn.functional.max_pool1d( - agent_i, - kernel_size=transition_prefix + 1, - stride=1, - ) - .squeeze(1) - .bool() - ) - else: - custom_mask = agent_mask - elif self.cfg.get("mask_user_on_loss", False): custom_mask = agent_mask diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index efbe5a5f4c5f..16929924a9ef 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -1364,7 +1364,7 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): state=state, text_tokens=warmup_tokens[:, i], user_audio_channel_embedding=user_step_emb, - prefill_like_step=not bool(model.cfg.get("agent_mask_include_transition_prefix", False)), + prefill_like_step=True, prefill_like_is_last_step=(i == delay_tokens - 1), use_inference_mode=True, ) From 609bce6258401122e7948722b1ff3a533c89e405 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 23 Jun 2026 12:10:41 -0700 Subject: [PATCH 101/109] Apply Black Signed-off-by: Edresson Casanova --- examples/tts/magpietts_inference.py | 4 +- .../common/data/lhotse/dataloader.py | 8 +- .../tts/metrics/emotion_encoder.py | 83 +++++++------------ .../tts/models/easy_magpietts_inference.py | 5 +- .../evaluate_generated_audio.py | 3 +- .../modules/magpietts_inference/inference.py | 16 +--- .../common/test_lhotse_tts_filters.py | 3 +- 7 files changed, 49 insertions(+), 73 deletions(-) diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index d3efec180588..7d4666497c44 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -254,9 +254,7 @@ def turn_sort_key(r): eou_trail_rms_ratio_turns = [r.get("eou_trail_rms_ratio") for r in turns] predicted_phoneme_text_turns = [r.get("predicted_phoneme_text", "") for r in turns] predicted_phoneme_tokens_turns = [r.get("predicted_phoneme_tokens", []) for r in turns] - predicted_phoneme_token_labels_turns = [ - r.get("predicted_phoneme_token_labels", []) for r in turns - ] + predicted_phoneme_token_labels_turns = [r.get("predicted_phoneme_token_labels", []) for r in turns] grouped_rows.append( { diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index f99d615d61b4..3be6fa1162a9 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -258,6 +258,7 @@ class LhotseDataLoadingConfig: # our support of object stores and gzipped files that generally don't have indexes of byte offsets per line. slice_length: Optional[int] = None + def resolve_excluded_speaker_ids(excluded_speaker_ids): if excluded_speaker_ids is None: return None @@ -272,6 +273,7 @@ def resolve_excluded_speaker_ids(excluded_speaker_ids): return excluded_speaker_ids + def determine_use_iterable_dataset(use_iterable_dataset: bool, config: DictConfig) -> bool: """Determine whether to use iterable dataset for a given configuration.""" assert not ( @@ -629,7 +631,11 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No # validation status filtering cuts = cuts.filter(ValidationStatusFilter(config.keep)) # Exclude cuts that contain known test speakers. - cuts = cuts.filter(SpeakerFilter(resolve_excluded_speaker_ids(config.excluded_speaker_ids), speaker_fields=config.speaker_filter_fields)) + cuts = cuts.filter( + SpeakerFilter( + resolve_excluded_speaker_ids(config.excluded_speaker_ids), speaker_fields=config.speaker_filter_fields + ) + ) # CER filtering, same as native NeMo dataloaders. cuts = cuts.filter(CERFilter(config.max_cer)) # Context speaker similarity filtering, same as native NeMo dataloaders. diff --git a/nemo/collections/tts/metrics/emotion_encoder.py b/nemo/collections/tts/metrics/emotion_encoder.py index 89224be0e4d4..019d6b14d6cb 100644 --- a/nemo/collections/tts/metrics/emotion_encoder.py +++ b/nemo/collections/tts/metrics/emotion_encoder.py @@ -1,5 +1,17 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ Lightweight LAION Empathic Insight Voice interface. @@ -137,6 +149,7 @@ # MLP head # ============================================================================= + class FullEmbeddingMLP(nn.Module): """Classifier head used by Empathic Insight Voice. @@ -220,9 +233,7 @@ def extract_projected_embedding(self, x: torch.Tensor) -> torch.Tensor: x = x.squeeze(1) if x.ndim != 3: - raise ValueError( - f"Expected x with shape [B, T, C], got shape {tuple(x.shape)}." - ) + raise ValueError(f"Expected x with shape [B, T, C], got shape {tuple(x.shape)}.") return self.proj(self.flatten(x)) @@ -236,6 +247,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Main class # ============================================================================= + class EmpathicInsightVoice(nn.Module): """Lightweight Hugging Face-style Empathic Insight Voice class. @@ -280,9 +292,7 @@ def __init__( super().__init__() if size not in MODEL_SPECS: - raise ValueError( - f"Unsupported size={size!r}. Expected one of {sorted(MODEL_SPECS)}." - ) + raise ValueError(f"Unsupported size={size!r}. Expected one of {sorted(MODEL_SPECS)}.") self.size = size self.spec = MODEL_SPECS[size] @@ -295,11 +305,7 @@ def __init__( requested_device = torch.device("cpu") self.device = requested_device - self.mlp_device = ( - torch.device(mlp_device) - if mlp_device is not None - else self.device - ) + self.mlp_device = torch.device(mlp_device) if mlp_device is not None else self.device self.sample_rate = int(self.spec["sample_rate"]) self.max_audio_seconds = float(self.spec["max_audio_seconds"]) @@ -409,9 +415,7 @@ def extract_whisper_embedding_from_waveform( input_features = input_features.to(self.device) input_features = input_features.to(self.whisper_model.dtype) - encoder_outputs = self.whisper_model.get_encoder()( - input_features=input_features - ) + encoder_outputs = self.whisper_model.get_encoder()(input_features=input_features) embedding = encoder_outputs.last_hidden_state embedding = self._pad_or_trim_embedding(embedding) @@ -908,20 +912,14 @@ def _get_classifier(self, label: str) -> FullEmbeddingMLP: Restored FullEmbeddingMLP. """ if label not in LAION_LABEL_TO_FILENAME: - raise ValueError( - f"Unknown label {label!r}. Available labels: " - f"{sorted(LAION_LABEL_TO_FILENAME)}" - ) + raise ValueError(f"Unknown label {label!r}. Available labels: " f"{sorted(LAION_LABEL_TO_FILENAME)}") cache_key = self._cache_key(label) if cache_key in self.classifiers: classifier = self.classifiers[cache_key] if not isinstance(classifier, FullEmbeddingMLP): - raise TypeError( - f"Cached classifier for {label!r} has unexpected type " - f"{type(classifier)}." - ) + raise TypeError(f"Cached classifier for {label!r} has unexpected type " f"{type(classifier)}.") return classifier filename = LAION_LABEL_TO_FILENAME[label] @@ -944,10 +942,7 @@ def _get_classifier(self, label: str) -> FullEmbeddingMLP: state_dict = torch.load(local_path, map_location="cpu") if not isinstance(state_dict, dict): - raise RuntimeError( - f"Expected {local_path} to contain a state_dict, " - f"but got {type(state_dict)}." - ) + raise RuntimeError(f"Expected {local_path} to contain a state_dict, " f"but got {type(state_dict)}.") state_dict = self._strip_orig_mod_prefix_if_needed(state_dict) @@ -990,8 +985,7 @@ def _validate_labels( if unknown: raise ValueError( - f"Unknown labels: {unknown}. " - f"Available labels: {sorted(LAION_LABEL_TO_FILENAME.keys())}" + f"Unknown labels: {unknown}. " f"Available labels: {sorted(LAION_LABEL_TO_FILENAME.keys())}" ) return labels_list @@ -1032,16 +1026,10 @@ def _pad_or_trim_embedding(self, embedding: torch.Tensor) -> torch.Tensor: embed_dim = int(self.spec["embed_dim"]) if embedding.ndim != 3: - raise RuntimeError( - f"Expected Whisper embedding with shape [B, T, C], " - f"got {tuple(embedding.shape)}." - ) + raise RuntimeError(f"Expected Whisper embedding with shape [B, T, C], " f"got {tuple(embedding.shape)}.") if embedding.shape[-1] != embed_dim: - raise RuntimeError( - f"Unexpected embedding dim. Expected {embed_dim}, " - f"got {embedding.shape[-1]}." - ) + raise RuntimeError(f"Unexpected embedding dim. Expected {embed_dim}, " f"got {embedding.shape[-1]}.") current_seq_len = embedding.shape[1] @@ -1114,13 +1102,7 @@ def _strip_orig_mod_prefix_if_needed( @staticmethod def _cache_key(label: str) -> str: """Convert an arbitrary label into a safe ModuleDict key.""" - return ( - label.replace(".", "_") - .replace("/", "_") - .replace("-", "_") - .replace(" ", "_") - .replace("&", "and") - ) + return label.replace(".", "_").replace("/", "_").replace("-", "_").replace(" ", "_").replace("&", "and") def cleanup(self) -> None: """Move modules to CPU and clear classifier cache.""" @@ -1138,6 +1120,7 @@ def cleanup(self) -> None: # CLI utilities # ============================================================================= + def _tensor_info(tensor: torch.Tensor) -> dict[str, Any]: return { "shape": list(tensor.shape), @@ -1154,9 +1137,7 @@ def _parse_labels(labels: Optional[str]) -> Optional[list[str]]: def main() -> None: - parser = argparse.ArgumentParser( - description="LAION Empathic Insight Voice embeddings and similarity." - ) + parser = argparse.ArgumentParser(description="LAION Empathic Insight Voice embeddings and similarity.") parser.add_argument( "--audio", type=str, @@ -1270,7 +1251,7 @@ def main() -> None: embedding_type=args.embedding_type, include_auxiliary=args.include_auxiliary, ) - + result = model.compare_emotion_pair( audio_path_a=args.audio, audio_path_b=args.audio_b, @@ -1295,7 +1276,6 @@ def main() -> None: print(result["top_emotion_match"]) print(result["emotion_similarity"]) - print("embedding_type=score_vector") print(result_score_vector["audio_a_top_emotion"]) print(result_score_vector["audio_b_top_emotion"]) @@ -1308,5 +1288,6 @@ def main() -> None: print(result_score_mean["top_emotion_match"]) print(result_score_mean["emotion_similarity"]) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 96135df233d1..0dfea3b653e0 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -1702,10 +1702,7 @@ def streaming_step( # Match training: the input slot that predicts the first agent frame after # a user/non-agent region should receive the learned user-speaking-end token. - use_end_token = ( - prefill_like_is_last_step - and self.cfg.get("use_user_speaking_end_token", False) - ) + use_end_token = prefill_like_is_last_step and self.cfg.get("use_user_speaking_end_token", False) if use_end_token: state.last_audio_codes = torch.full( diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py index a3682ce09462..5601366addfe 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py @@ -83,6 +83,7 @@ def strip_text_annotations_from_text(text: str) -> str: text = _SPACE_BEFORE_PUNCT_RE.sub(r"\1", text) return text.strip() + FILEWISE_METRICS_TO_SAVE = [ 'cer', 'wer', @@ -602,7 +603,7 @@ def evaluate_dir( model=speaker_verification_model_alternate, extractor=feature_extractor, device=device, - sv_model_type="titanet", # alternate is always titanet + sv_model_type="titanet", # alternate is always titanet ) # Initialize SSIMs with a default since the context or ground truth audio diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index 16929924a9ef..265686fa2a1b 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -1098,9 +1098,7 @@ def _move_batch_to_device(batch: Dict[str, Any], device: torch.device) -> Dict[s return out @staticmethod - def _copy_file( - src: Optional[str], dst: str, required: bool = False, description: str = "audio" - ) -> Optional[str]: + def _copy_file(src: Optional[str], dst: str, required: bool = False, description: str = "audio") -> Optional[str]: """Copy an audio artifact and optionally fail fast if missing. Evaluation later expects target_audio_*.wav/context_audio_*.wav to exist. @@ -1125,9 +1123,7 @@ def _copy_file( shutil.copy(src, dst) if required and not os.path.exists(dst): - raise FileNotFoundError( - f"Failed to materialize required {description}: src={src}, dst={dst}." - ) + raise FileNotFoundError(f"Failed to materialize required {description}: src={src}, dst={dst}.") return dst def _resolve_audio_path(self, path: Optional[str], audio_base_dir: str) -> Optional[str]: @@ -1911,9 +1907,7 @@ def _run_multiturn_user_audio_inference( for local_turn_idx, (turn_id, start_frame, end_frame) in enumerate(turn_frame_ranges): source_turn_idx = int(turn_id) phoneme_output = ( - turn_phoneme_outputs[local_turn_idx] - if local_turn_idx < len(turn_phoneme_outputs) - else {} + turn_phoneme_outputs[local_turn_idx] if local_turn_idx < len(turn_phoneme_outputs) else {} ) rel_start_frame = start_frame - decode_start_frame rel_end_frame = end_frame - decode_start_frame @@ -1973,9 +1967,7 @@ def _run_multiturn_user_audio_inference( "text": raw_turn_texts[source_turn_idx] if source_turn_idx < len(raw_turn_texts) else "", "predicted_phoneme_text": phoneme_output.get("predicted_phoneme_text", ""), "predicted_phoneme_tokens": phoneme_output.get("predicted_phoneme_tokens", []), - "predicted_phoneme_token_labels": phoneme_output.get( - "predicted_phoneme_token_labels", [] - ), + "predicted_phoneme_token_labels": phoneme_output.get("predicted_phoneme_token_labels", []), "speaker": str(sample_idx), "source_sample_idx": sample_idx, "turn_id": int(turn_id), diff --git a/tests/collections/common/test_lhotse_tts_filters.py b/tests/collections/common/test_lhotse_tts_filters.py index e03afa7a73da..4242820200d7 100644 --- a/tests/collections/common/test_lhotse_tts_filters.py +++ b/tests/collections/common/test_lhotse_tts_filters.py @@ -143,6 +143,7 @@ def test_cut_validation_status_filter(cut_example): f = ValidationStatusFilter("any_other_status") assert f(cut_example) == False + def test_cut_speaker_filter_by_speaker(cut_example): f = SpeakerFilter( excluded_speaker_ids=["| Language:en Dataset:nvyt2505 Speaker:Zdud2gXLTXY_SPEAKER_02 |"], @@ -208,4 +209,4 @@ def test_cut_speaker_filter_requires_fields_when_enabled(): SpeakerFilter( excluded_speaker_ids=["test_speaker_001"], speaker_fields=None, - ) \ No newline at end of file + ) From 82d1efaf573ac25854558eeeb15590f0f4fd1c26 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 23 Jun 2026 12:30:29 -0700 Subject: [PATCH 102/109] Fix imports order Signed-off-by: Edresson Casanova --- .../data/text_to_speech_dataset_lhotse_multiturn.py | 2 +- nemo/collections/tts/models/easy_magpietts.py | 6 +++--- .../tts/models/easy_magpietts_inference.py | 13 +++++-------- .../tts/modules/magpietts_inference/inference.py | 2 +- tests/collections/common/test_lhotse_tts_filters.py | 2 +- 5 files changed, 11 insertions(+), 14 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 72c60c05bebf..618f5de3033e 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -22,7 +22,7 @@ from hydra.utils import instantiate from lhotse import CutSet, Seconds, compute_num_frames from lhotse.cut import Cut -from lhotse.dataset.collation import collate_matrices, collate_vectors, collate_audio +from lhotse.dataset.collation import collate_audio, collate_matrices, collate_vectors from lhotse.utils import ifnone from omegaconf import DictConfig from transformers import AutoTokenizer, T5Tokenizer diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index a95e94e79fd4..f6ad19eb14b5 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -30,11 +30,10 @@ import nemo.collections.asr as nemo_asr from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.mixins.transcription import TranscribeConfig -from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.data.fallback import FallbackDataset +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.tts.data.text_to_speech_dataset_lhotse import MagpieTTSLhotseDataset, setup_tokenizers from nemo.collections.tts.data.text_to_speech_dataset_lhotse_multiturn import MagpieTTSLhotseMultiturnDataset - from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel, TrainingMode from nemo.collections.tts.modules.magpietts_modules import ( LocalTransformerType, @@ -60,9 +59,10 @@ except (ImportError, ModuleNotFoundError): HAVE_UTMOSV2 = False -from transformers import WhisperForConditionalGeneration, WhisperProcessor from typing import List +from transformers import WhisperForConditionalGeneration, WhisperProcessor + @dataclass class ProcessBatchOutput: diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 0dfea3b653e0..7f18322c9963 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import random -import time -import random -import time import tempfile +import time +from collections import Counter from dataclasses import dataclass, fields from functools import partial from typing import Any, Dict, List, Optional, Sequence, Tuple -from collections import Counter import numpy as np import soundfile as sf @@ -29,11 +27,11 @@ from torch import nn from transformers import AutoConfig, AutoModelForCausalLM -from nemo.core.connectors.save_restore_connector import SaveRestoreConnector +from nemo.collections.audio.parts.utils.transforms import resample +from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init from nemo.collections.tts.data.text_to_speech_dataset_lhotse import setup_tokenizers from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules import transformer_2501 -from nemo.collections.audio.parts.utils.transforms import resample from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter from nemo.collections.tts.modules.magpietts_modules import ( CharAwareSubwordEncoder, @@ -46,11 +44,10 @@ from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths from nemo.core.classes import ModelPT from nemo.core.classes.common import PretrainedModelInfo, safe_instantiate +from nemo.core.connectors.save_restore_connector import SaveRestoreConnector from nemo.utils import logging from nemo.utils.exceptions import NeMoBaseException -from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init - @dataclass class TrainingMode: diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index 265686fa2a1b..eaa0759cbd1a 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -39,11 +39,11 @@ import torch from nemo.collections.asr.parts.utils.manifest_utils import read_manifest +from nemo.collections.audio.parts.utils.transforms import resample from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer, IPATokenizer from nemo.collections.tts.data.text_to_speech_dataset import ChunkedTTSInferenceDataset, MagpieTTSDataset from nemo.collections.tts.models.easy_magpietts_inference import EasyModelInferenceParameters from nemo.collections.tts.models.magpietts import ModelInferenceParameters -from nemo.collections.audio.parts.utils.transforms import resample from nemo.collections.tts.parts.utils.tts_dataset_utils import normalize_volume, stack_tensors from nemo.utils import logging diff --git a/tests/collections/common/test_lhotse_tts_filters.py b/tests/collections/common/test_lhotse_tts_filters.py index 4242820200d7..90a08f32d55e 100644 --- a/tests/collections/common/test_lhotse_tts_filters.py +++ b/tests/collections/common/test_lhotse_tts_filters.py @@ -19,8 +19,8 @@ from nemo.collections.common.data.lhotse.sampling import ( CERFilter, - SpeakerFilter, ContextSpeakerSimilarityFilter, + SpeakerFilter, ValidationStatusFilter, ) From e987ee1a4be5b1ec36feb8089fd24674656da98d Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 23 Jun 2026 14:00:53 -0700 Subject: [PATCH 103/109] Add unit tests Signed-off-by: Edresson Casanova --- .../collections/speechlm2/parts/pretrained.py | 3 +- .../tts/data/test_magpietts_dataset_lhotse.py | 257 +++++++++ .../tts/models/test_easy_magpietts.py | 520 ++++++++++++++++++ 3 files changed, 779 insertions(+), 1 deletion(-) create mode 100644 tests/collections/tts/data/test_magpietts_dataset_lhotse.py create mode 100644 tests/collections/tts/models/test_easy_magpietts.py diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index 301f31c2179e..a557dd9291aa 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -332,7 +332,8 @@ def set_model_dict_for_partial_init( dimension 0 differs and all trailing dimensions match. Defaults to False. Returns: - Updated model state dictionary with compatible pretrained weights loaded. + Dict[str, torch.Tensor]: + The updated model state dictionary with compatible pretrained weights loaded. Example: >>> model_dict = model.state_dict() diff --git a/tests/collections/tts/data/test_magpietts_dataset_lhotse.py b/tests/collections/tts/data/test_magpietts_dataset_lhotse.py new file mode 100644 index 000000000000..9f0dbaab58f1 --- /dev/null +++ b/tests/collections/tts/data/test_magpietts_dataset_lhotse.py @@ -0,0 +1,257 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from io import BytesIO +from pathlib import Path + +import numpy as np +import pytest +import torch +from lhotse import CutSet, SupervisionSegment +from lhotse.array import Array, TemporalArray +from lhotse.testing.dummies import dummy_cut, dummy_recording +from omegaconf import OmegaConf + +from nemo.collections.tts.data.text_to_speech_dataset_lhotse import MagpieTTSLhotseDataset +from nemo.collections.tts.data.text_to_speech_dataset_lhotse_multiturn import MagpieTTSLhotseMultiturnDataset + + +pytestmark = pytest.mark.unit + +SAMPLE_RATE = 24000 +CODEC_MODEL_SAMPLES_PER_FRAME = 480 +CODEC_MODEL_INPUT_SAMPLE_RATE = 24000 +FRAME_STACKING_FACTOR = 1 +NUM_AUDIO_CODEBOOKS = 8 + +BPE_TOKENIZER_NAME = "nemotron_bpe" +BPE_TOKENIZER_MODEL = "nvidia/NVIDIA-Nemotron-Nano-9B-v2" +BPE_TOKENIZER_CACHED_PATH = Path("/home/TestData/nvidia--NVIDIA-Nemotron-Nano-9B-v2/") +if BPE_TOKENIZER_CACHED_PATH.exists(): + BPE_TOKENIZER_MODEL = str(BPE_TOKENIZER_CACHED_PATH) + + +def _seed_everything(): + random.seed(42) + np.random.seed(42) + torch.manual_seed(42) + + +def _tokenizer_config(): + return OmegaConf.create( + { + BPE_TOKENIZER_NAME: { + "_target_": "AutoTokenizer", + "pretrained_model": BPE_TOKENIZER_MODEL, + } + } + ) + + +def _memory_temporal_array(values, frame_shift=0.02): + buffer = BytesIO() + np.save(buffer, values) + return TemporalArray( + array=Array( + storage_type="memory_npy", + storage_path="", + storage_key=buffer.getvalue(), + shape=list(values.shape), + ), + temporal_dim=-1, + frame_shift=frame_shift, + start=0, + ) + + +def _cached_codes(num_codebooks=NUM_AUDIO_CODEBOOKS, num_frames=5, offset=0): + codes = np.arange(num_codebooks * num_frames, dtype=np.int32).reshape(num_codebooks, num_frames) + return (codes + offset) % 16 + + +def _single_turn_cutset(): + cut = dummy_cut( + 0, + duration=0.5, + recording=dummy_recording(0, duration=0.5, with_data=True, sampling_rate=SAMPLE_RATE), + ) + cut.target_audio = dummy_recording(10, duration=0.5, with_data=True, sampling_rate=SAMPLE_RATE) + cut.supervisions = [ + SupervisionSegment( + id="single-turn", + recording_id=cut.recording_id, + start=0.0, + duration=0.2, + text="hello", + language="en", + speaker="| Language:en Dataset:Unit Speaker:spk |", + custom={"context_text": "speaker prompt", "normalized_text": "hello"}, + ) + ] + cut.custom = { + **(cut.custom or {}), + "target_codes": _memory_temporal_array(_cached_codes(num_frames=4)), + "context_codes": _memory_temporal_array(_cached_codes(num_frames=3, offset=3)), + "tokenizer_names": [BPE_TOKENIZER_NAME], + "lang": "en", + } + return CutSet.from_cuts([cut]) + + +def _multiturn_cutset(): + cut = dummy_cut( + 1, + duration=0.8, + recording=dummy_recording(1, duration=0.8, with_data=True, sampling_rate=SAMPLE_RATE), + ) + cut.target_audio = dummy_recording(11, duration=0.8, with_data=True, sampling_rate=SAMPLE_RATE) + cut.supervisions = [ + SupervisionSegment( + id="turn-user-0", + recording_id=cut.recording_id, + start=0.0, + duration=0.2, + text="hi", + language="en", + speaker="user", + custom={"context_text": "chat prompt"}, + ), + SupervisionSegment( + id="turn-agent-0", + recording_id=cut.recording_id, + start=0.3, + duration=0.2, + text="hello", + language="en", + speaker="assistant", + ), + SupervisionSegment( + id="turn-user-1", + recording_id=cut.recording_id, + start=0.55, + duration=0.1, + text="ok", + language="en", + speaker="user", + ), + SupervisionSegment( + id="turn-agent-1", + recording_id=cut.recording_id, + start=0.68, + duration=0.1, + text="okay", + language="en", + speaker="assistant", + custom={"reward": 0.5}, + ), + ] + cut.custom = { + **(cut.custom or {}), + "task": "dialog", + "target_codes": _memory_temporal_array(_cached_codes(num_frames=8)), + "source_codes": _memory_temporal_array(_cached_codes(num_frames=8, offset=1)), + "context_codes": _memory_temporal_array(_cached_codes(num_frames=4, offset=2)), + "tokenizer_names": [BPE_TOKENIZER_NAME], + "lang": "en", + } + return CutSet.from_cuts([cut]) + + +def _dataset_kwargs(): + return { + "sample_rate": SAMPLE_RATE, + "volume_norm": False, + "codec_model_samples_per_frame": CODEC_MODEL_SAMPLES_PER_FRAME, + "num_audio_codebooks": NUM_AUDIO_CODEBOOKS, + "prior_scaling_factor": None, + "load_cached_codes_if_available": True, + "dataset_type": "train", + "load_16khz_audio": False, + "pad_context_text_to_max_duration": False, + "context_duration_min": 0.04, + "context_duration_max": 0.04, + "use_text_conditioning_tokenizer": True, + "text_conditioning_tokenizer_name": BPE_TOKENIZER_NAME, + "tokenizer_config": _tokenizer_config(), + } + + +class TestMagpieTTSLhotseDatasets: + def test_single_turn_dataset_uses_bpe_and_cached_codes(self): + _seed_everything() + dataset = MagpieTTSLhotseDataset(**_dataset_kwargs()) + + batch = dataset[_single_turn_cutset()] + + assert batch["dataset_names"] == ["Unit"] + assert batch["languages"] == ["en"] + assert batch["raw_texts"] == ["hello"] + assert "audio" not in batch + assert "context_audio" not in batch + assert batch["audio_codes"].shape == (1, NUM_AUDIO_CODEBOOKS, 4) + assert batch["audio_codes_lens"].tolist() == [4] + assert batch["context_audio_codes"].shape == (1, NUM_AUDIO_CODEBOOKS, 2) + assert batch["context_audio_codes_lens"].tolist() == [2] + assert batch["text"].shape[0] == 1 + assert batch["text_lens"].item() > 0 + assert batch["context_text_tokens"].shape[0] == 1 + assert batch["context_text_tokens_lens"].item() > 0 + assert batch["has_text_context"].tolist() == [True] + + def test_multiturn_dataset_uses_bpe_and_cached_codes(self): + _seed_everything() + kwargs = _dataset_kwargs() + kwargs.update( + { + "codec_model_input_sample_rate": CODEC_MODEL_INPUT_SAMPLE_RATE, + "frame_stacking_factor": FRAME_STACKING_FACTOR, + "source_sample_rate": SAMPLE_RATE, + "input_roles": ["user"], + "output_roles": ["assistant"], + "add_text_bos": False, + } + ) + dataset = MagpieTTSLhotseMultiturnDataset(**kwargs) + + batch = dataset[_multiturn_cutset()] + + assert batch["sample_id"] == ["dummy-mono-cut-0001"] + assert batch["dataset_names"] == ["unknown"] + assert batch["languages"] == ["en"] + assert batch["raw_texts"] == ["hello okay"] + assert batch["task"] == ["dialog"] + assert batch["audio"].shape[0] == 1 + assert batch["source_audio"].shape[0] == 1 + assert batch["audio_codes"].shape == (1, NUM_AUDIO_CODEBOOKS, 8) + assert batch["audio_codes_lens"].tolist() == [8] + assert batch["source_codes"].shape == (1, NUM_AUDIO_CODEBOOKS, 8) + assert batch["source_codes_lens"].tolist() == [8] + assert batch["context_audio_codes"].shape == (1, NUM_AUDIO_CODEBOOKS, 2) + assert batch["context_audio_codes_lens"].tolist() == [2] + assert batch["source_tokens"].shape[0] == 1 + assert batch["source_token_lens"].item() > 0 + assert batch["text"].shape[0] == 1 + assert batch["text_lens"].item() > 0 + assert batch["agent_mask"].shape == batch["user_mask"].shape + assert batch["agent_mask_lens"].tolist() == batch["user_mask_lens"].tolist() + assert batch["agent_mask"].sum().item() > 0 + assert batch["user_mask"].sum().item() > 0 + assert batch["user_audio_turn_splitted"].shape[0] == 2 + assert batch["user_audio_turn_splitted_lens"].tolist() == [4800, 2400] + assert batch["user_audio_turn_splitted_indices"].shape == (2, 3) + assert batch["context_text_tokens"].shape[0] == 1 + assert batch["context_text_tokens_lens"].item() > 0 + assert batch["has_text_context"].tolist() == [True] + torch.testing.assert_close(batch["rewards"], torch.tensor([0.5], device=batch["rewards"].device)) diff --git a/tests/collections/tts/models/test_easy_magpietts.py b/tests/collections/tts/models/test_easy_magpietts.py new file mode 100644 index 000000000000..6f66f6a76a96 --- /dev/null +++ b/tests/collections/tts/models/test_easy_magpietts.py @@ -0,0 +1,520 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from contextlib import contextmanager +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch + +import numpy as np +import pytest +import torch +from omegaconf import OmegaConf +from torch import nn + +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.models.easy_magpietts import EasyMagpieTTSModel +from nemo.collections.tts.models.easy_magpietts_inference import EasyModelInferenceParameters, TrainingMode +from tests.collections.tts.models.test_audio_codec import create_codec_config + + +if torch.cuda.is_available(): + torch.set_default_device("cuda") + + +pytestmark = pytest.mark.unit + +BPE_TOKENIZER_NAME = "nemotron_bpe" +BPE_TOKENIZER_MODEL = "nvidia/NVIDIA-Nemotron-Nano-9B-v2" +BPE_TOKENIZER_CACHED_PATH = Path("/home/TestData/nvidia--NVIDIA-Nemotron-Nano-9B-v2/") +if BPE_TOKENIZER_CACHED_PATH.exists(): + BPE_TOKENIZER_MODEL = str(BPE_TOKENIZER_CACHED_PATH) + + +def _restore_codec_as_random_initialized_model(*args, **kwargs): + del args + if kwargs.get("return_config", False): + return create_codec_config() + + codec_cfg = kwargs.get("override_config_path", None) + if codec_cfg is None: + codec_cfg = create_codec_config() + codec_model = AudioCodecModel(cfg=codec_cfg) + codec_model.freeze() + return codec_model + + +@contextmanager +def _codec_restore_uses_random_initialized_audio_codec(): + from nemo.collections.tts.models import easy_magpietts_inference + + with patch.object( + easy_magpietts_inference.AudioCodecModel, + "restore_from", + staticmethod(_restore_codec_as_random_initialized_model), + ): + yield + + +def _seed_everything(): + random.seed(42) + np.random.seed(42) + torch.manual_seed(42) + + +def tiny_easy_magpie_cfg(overrides=None): + cfg = OmegaConf.create( + { + "codecmodel_path": "dummy_codec.nemo", + "decoder_type": "nemotron_h", + "embedding_dim": 32, + "hidden_dim": 32, + "audio_embedding_dim": 16, + "frame_stacking_factor": 1, + "local_transformer_type": "none", + "disable_lm_text_head": True, + "disable_subword_embedding": False, + "use_bpe_char_tokenizer": True, + "text_conditioning_tokenizer_name": BPE_TOKENIZER_NAME, + "use_multiturn_dataset": False, + "run_val_inference": False, + "use_utmos": False, + "cfg_unconditional_prob": 0.0, + "dropout_text_input_prob": 0.0, + "phoneme_corruption_batch_prob": 0.0, + "phoneme_corruption_timestep_ratio": 0.0, + "phoneme_as_text_prob": 0.0, + "mask_user_on_loss": True, + "text_tokenizers": { + BPE_TOKENIZER_NAME: { + "_target_": "AutoTokenizer", + "pretrained_model": BPE_TOKENIZER_MODEL, + } + }, + "training_modes": [ + { + "text_input_mode": "streaming", + "streaming_phonemes_delay": 1, + "streaming_speech_delay": 2, + } + ], + "nemotron_h_config": { + "hidden_size": 32, + "num_hidden_layers": 1, + "vocab_size": 64, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "mamba_num_heads": 4, + "mamba_head_dim": 8, + "ssm_state_size": 8, + "n_groups": 2, + "intermediate_size": 64, + "hybrid_override_pattern": "*", + "use_cache": True, + "_attn_implementation": "sdpa", + }, + "optimizer": { + "_target_": "torch.optim.AdamW", + "lr": 0.001, + }, + } + ) + if overrides is not None: + cfg = OmegaConf.merge(cfg, overrides) + return cfg + + +def _make_easy_magpie_model(cfg=None): + _seed_everything() + with _codec_restore_uses_random_initialized_audio_codec(): + model = EasyMagpieTTSModel(cfg or tiny_easy_magpie_cfg()) + model.eval() + return model + + +@pytest.fixture() +def model(): + return _make_easy_magpie_model() + + +def _padded_token_tensor(model, texts): + tokenized = [ + model.tokenizer.encode(text, tokenizer_name=BPE_TOKENIZER_NAME) + [model.eos_id] for text in texts + ] + lens = torch.tensor([len(tokens) for tokens in tokenized], dtype=torch.long) + max_len = int(lens.max().item()) + padded = torch.full((len(tokenized), max_len), model.pad_id, dtype=torch.long) + for idx, tokens in enumerate(tokenized): + padded[idx, : len(tokens)] = torch.tensor(tokens, dtype=torch.long) + return padded, lens + + +def _toy_codes(model, batch_size, num_frames): + codes = torch.zeros(batch_size, model.num_audio_codebooks, num_frames, dtype=torch.long) + frame_ids = torch.arange(num_frames, dtype=torch.long) + for batch_idx in range(batch_size): + for codebook_idx in range(model.num_audio_codebooks): + codes[batch_idx, codebook_idx] = (frame_ids + batch_idx + codebook_idx * 7) % model.codebook_size + return codes + + +def _toy_batch(model): + text, text_lens = _padded_token_tensor(model, ["abc", "de"]) + context_text_tokens, context_text_tokens_lens = _padded_token_tensor(model, ["hi", "ok"]) + + audio_codes = _toy_codes(model, batch_size=2, num_frames=4) + audio_codes_lens = torch.tensor([4, 3], dtype=torch.long) + context_audio_codes = _toy_codes(model, batch_size=2, num_frames=2) + context_audio_codes[1, :, 1] = 0 + context_audio_codes_lens = torch.tensor([2, 1], dtype=torch.long) + agent_mask = torch.tensor( + [ + [False, True, True, False, False], + [True, False, True, False, False], + ], + dtype=torch.bool, + ) + + return { + "text": text, + "text_lens": text_lens, + "context_text_tokens": context_text_tokens, + "context_text_tokens_lens": context_text_tokens_lens, + "audio_codes": audio_codes, + "audio_codes_lens": audio_codes_lens, + "context_audio_codes": context_audio_codes, + "context_audio_codes_lens": context_audio_codes_lens, + "agent_mask": agent_mask, + "task": ["tts", "tts"], + } + + +@pytest.fixture() +def toy_batch(model): + return _toy_batch(model) + + +def test_training_mode_and_inference_parameters(): + mode = TrainingMode( + text_input_mode="streaming", + streaming_phonemes_delay=4, + streaming_speech_delay=8, + mode_idx=2, + ) + assert mode.name == "streaming_4_8" + + params = EasyModelInferenceParameters.from_dict( + { + "max_decoder_steps": 11, + "temperature": 0.25, + "topk": 7, + "cfg_scale": 1.5, + "unknown_key": "ignored", + } + ) + assert params == EasyModelInferenceParameters( + max_decoder_steps=11, + temperature=0.25, + topk=7, + cfg_scale=1.5, + ) + + +def test_easy_magpietts_model_construction(model): + expected_device = "cuda" if torch.cuda.is_available() else "cpu" + assert next(model.parameters()).device.type == expected_device + assert model.tokenizer is not None + assert model.decoder is not None + assert len(model.audio_embeddings) == model.num_audio_codebooks * model.frame_stacking_factor + assert isinstance(model.audio_in_projection, nn.Linear) + assert isinstance(model.audio_out_projection, nn.Linear) + assert model.final_proj.out_features == model.num_audio_codebooks * model.num_all_tokens_per_codebook + assert model.audio_bos_id == model.codebook_size + assert model.audio_eos_id == model.codebook_size + 1 + assert model.training_modes[0].name == "streaming_1_2" + assert model.default_inference_mode == "streaming_1_2" + assert model.lm_text_head is None + assert model.use_bpe_char_tokenizer + assert model.text_conditioning_tokenizer_name == BPE_TOKENIZER_NAME + assert hasattr(model, "cas_encoder") + + +def test_state_dict_excludes_codec(model): + state = model.state_dict() + assert state + assert not any("_codec_model" in key for key in state) + assert any(key.startswith("audio_embeddings.") for key in state) + assert any(key.startswith("final_proj.") for key in state) + + +def test_audio_and_text_embedding_shapes(model): + audio_tokens = _toy_codes(model, batch_size=2, num_frames=3) + audio_tokens[0, :, -1] = model.audio_eos_id + audio_embedded = model.embed_audio_tokens(audio_tokens) + assert audio_embedded.shape == (2, 3, model.cfg.embedding_dim) + assert audio_embedded.dtype == torch.float32 + assert torch.isfinite(audio_embedded).all() + + text_tokens, text_lens = _padded_token_tensor(model, ["abc", "de"]) + text_embedded = model.embed_text_tokens(text_tokens, text_lens=text_lens) + assert text_embedded.shape == (2, text_tokens.size(1), model.cfg.embedding_dim) + assert text_embedded.dtype == torch.float32 + assert torch.isfinite(text_embedded).all() + + +def test_stack_codes_round_trip_expected_shape(model): + codes = _toy_codes(model, batch_size=2, num_frames=4) + codes_lens = torch.tensor([4, 4], dtype=torch.long) + + stacked, stacked_lens = model.stack_codes( + codes, + codes_lens, + bos_id=model.audio_bos_id, + eos_id=model.audio_eos_id, + stacking_factor=2, + num_codebooks=model.num_audio_codebooks, + ) + unstacked, unstacked_lens = model.unstack_codes(stacked, stacked_lens, stacking_factor=2) + + assert stacked.shape == (2, model.num_audio_codebooks * 2, 2) + assert stacked_lens.tolist() == [2, 2] + assert unstacked.shape == codes.shape + assert unstacked_lens.tolist() == codes_lens.tolist() + torch.testing.assert_close(unstacked, codes) + + +def test_compute_loss_with_and_without_agent_mask(model): + _seed_everything() + batch_size, num_frames = 2, 5 + audio_codes = torch.randint( + low=0, + high=model.num_all_tokens_per_codebook, + size=(batch_size, model.num_audio_codebooks, num_frames), + ) + audio_codes_lens = torch.tensor([5, 3], dtype=torch.long) + logits = torch.randn( + batch_size, + num_frames, + model.num_audio_codebooks * model.num_all_tokens_per_codebook, + ) + agent_mask = torch.tensor( + [ + [True, True, False, False, False], + [False, True, True, False, False], + ], + dtype=torch.bool, + ) + + loss, loss_mask = model.compute_loss(logits, audio_codes, audio_codes_lens) + masked_loss, masked_loss_mask = model.compute_loss( + logits, + audio_codes, + audio_codes_lens, + agent_mask_target=agent_mask, + ) + + assert loss.ndim == 0 + assert masked_loss.ndim == 0 + assert torch.isfinite(loss) + assert torch.isfinite(masked_loss) + assert loss_mask.shape == (batch_size, model.num_audio_codebooks, num_frames) + assert masked_loss_mask.shape == loss_mask.shape + assert loss_mask.dtype == torch.bool + + +def test_prepare_audio_channel_embeddings_shapes(model): + audio_codes = _toy_codes(model, batch_size=2, num_frames=3) + audio_codes[1, :, 2] = 0 + audio_codes_lens = torch.tensor([3, 2], dtype=torch.long) + delay = torch.tensor([2, 1], dtype=torch.long) + agent_mask = torch.tensor( + [ + [True, False, False, False], + [False, True, False, False], + ], + dtype=torch.bool, + ) + + embeddings, lens, targets, target_lens, loss_agent_mask = model.prepare_audio_channel_embeddings( + audio_codes=audio_codes, + audio_codes_lens=audio_codes_lens, + delay=delay, + agent_mask=agent_mask, + ) + + assert embeddings.shape == (2, int((delay + target_lens).max().item()), model.cfg.embedding_dim) + assert embeddings.dtype == torch.float32 + assert torch.isfinite(embeddings).all() + assert lens.tolist() == (delay + target_lens).tolist() + assert targets.shape == (2, model.num_audio_codebooks, int(target_lens.max().item())) + assert target_lens.tolist() == [4, 3] + assert loss_agent_mask.shape == (2, targets.size(2)) + assert loss_agent_mask.dtype == torch.bool + + +def test_forward_with_inputs_embeds(model): + _seed_everything() + inputs_embeds = torch.randn(2, 6, model.cfg.embedding_dim) + attention_mask = torch.ones(2, 6, dtype=torch.bool) + + output = model.forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, use_cache=True) + + assert output.last_hidden_state.shape == inputs_embeds.shape + assert output.last_hidden_state.dtype == torch.float32 + assert torch.isfinite(output.last_hidden_state).all() + assert output.past_key_values is not None + + +def test_logits_to_audio_codes_schema(model): + logits = torch.zeros(2, 4, model.num_audio_codebooks * model.num_all_tokens_per_codebook) + expected_tokens = [] + for codebook_idx in range(model.num_audio_codebooks): + token_id = codebook_idx + 3 + expected_tokens.append(token_id) + offset = codebook_idx * model.num_all_tokens_per_codebook + logits[:, :, offset + token_id] = 5.0 + audio_codes_lens = torch.tensor([4, 2], dtype=torch.long) + + audio_codes = model.logits_to_audio_codes(logits, audio_codes_lens) + + assert audio_codes.shape == (2, model.num_audio_codebooks, 4) + assert audio_codes.dtype == torch.long + for codebook_idx, token_id in enumerate(expected_tokens): + assert audio_codes[0, codebook_idx].tolist() == [token_id] * 4 + assert audio_codes[1, :, 2:].eq(0).all() + + +def test_process_batch_smoke(model, toy_batch): + _seed_everything() + output = model.process_batch( + text=toy_batch["text"], + text_lens=toy_batch["text_lens"], + context_text_tokens=toy_batch["context_text_tokens"], + context_text_tokens_lens=toy_batch["context_text_tokens_lens"], + audio_codes=toy_batch["audio_codes"], + audio_codes_lens=toy_batch["audio_codes_lens"], + context_audio_codes=toy_batch["context_audio_codes"], + context_audio_codes_lens=toy_batch["context_audio_codes_lens"], + mode="val", + training_mode=model.training_modes[0], + agent_mask=toy_batch["agent_mask"], + ) + + assert output.selected_training_mode == model.default_inference_mode + assert torch.isfinite(output.loss) + assert torch.isfinite(output.codebook_loss) + assert output.phoneme_loss is None + assert output.local_transformer_loss is None + assert output.logits.shape[:2] == output.audio_codes_target.shape[0::2] + assert output.logits.shape[-1] == model.num_audio_codebooks * model.num_all_tokens_per_codebook + assert output.audio_codes_target.dtype == torch.long + assert output.context_audio_codes.shape[1] == model.num_audio_codebooks + + +def test_process_batch_with_multiturn_dataset_enabled(): + _seed_everything() + model = _make_easy_magpie_model(tiny_easy_magpie_cfg({"use_multiturn_dataset": True})) + batch = _toy_batch(model) + text = batch["text"].clone() + text[0, 0] = model.interruption_token_id + + output = model.process_batch( + text=text, + text_lens=batch["text_lens"], + context_text_tokens=batch["context_text_tokens"], + context_text_tokens_lens=batch["context_text_tokens_lens"], + audio_codes=batch["audio_codes"], + audio_codes_lens=batch["audio_codes_lens"], + context_audio_codes=batch["context_audio_codes"], + context_audio_codes_lens=batch["context_audio_codes_lens"], + mode="val", + training_mode=model.training_modes[0], + task=batch["task"], + agent_mask=batch["agent_mask"], + ) + + assert text[0, 0].item() == model.pad_id + assert torch.isfinite(output.loss) + assert torch.isfinite(output.codebook_loss) + assert output.local_transformer_loss is None + + +def test_process_batch_with_autoregressive_local_transformer(): + _seed_everything() + model = _make_easy_magpie_model( + tiny_easy_magpie_cfg( + { + "local_transformer_type": "autoregressive", + "local_transformer_hidden_dim": 32, + "local_transformer_n_layers": 1, + "local_transformer_n_heads": 4, + "local_transformer_loss_scale": 0.5, + } + ) + ) + batch = _toy_batch(model) + + output = model.process_batch( + text=batch["text"], + text_lens=batch["text_lens"], + context_text_tokens=batch["context_text_tokens"], + context_text_tokens_lens=batch["context_text_tokens_lens"], + audio_codes=batch["audio_codes"], + audio_codes_lens=batch["audio_codes_lens"], + context_audio_codes=batch["context_audio_codes"], + context_audio_codes_lens=batch["context_audio_codes_lens"], + mode="val", + training_mode=model.training_modes[0], + agent_mask=batch["agent_mask"], + ) + + assert torch.isfinite(output.loss) + assert torch.isfinite(output.codebook_loss) + assert torch.isfinite(output.local_transformer_loss) + assert output.local_transformer_logits is not None + assert output.local_transformer_logits.shape == output.logits.shape + + +def test_training_step_smoke(model, toy_batch): + _seed_everything() + model.train() + + with patch.object(model, "log", lambda *args, **kwargs: None), patch.object( + model, "log_dict", lambda *args, **kwargs: None + ): + loss = model.training_step(toy_batch, batch_idx=0) + + assert loss.ndim == 0 + assert torch.isfinite(loss) + assert loss.item() > 0 + + +def test_validation_step_smoke(model, toy_batch, tmp_path): + _seed_everything() + model.eval() + object.__setattr__( + model, + "_trainer", + SimpleNamespace(world_size=1, global_rank=0, local_rank=0, log_dir=str(tmp_path), current_epoch=0), + ) + + with patch.object(model, "log_val_audio_example", lambda *args, **kwargs: {}): + output = model.validation_step(toy_batch, batch_idx=1) + + assert set(output.keys()) == {"val_loss", "val_codebook_loss", "val_local_transformer_loss"} + assert torch.isfinite(output["val_loss"]) + assert torch.isfinite(output["val_codebook_loss"]) + assert output["val_local_transformer_loss"] is None + assert model.validation_step_outputs[-1] == output From da882835a91543112d2b1ac8277e54ee7ed40d01 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 23 Jun 2026 14:04:18 -0700 Subject: [PATCH 104/109] Fix black check Signed-off-by: Edresson Casanova --- tests/collections/tts/models/test_easy_magpietts.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/collections/tts/models/test_easy_magpietts.py b/tests/collections/tts/models/test_easy_magpietts.py index 6f66f6a76a96..9ff273be8a3f 100644 --- a/tests/collections/tts/models/test_easy_magpietts.py +++ b/tests/collections/tts/models/test_easy_magpietts.py @@ -150,9 +150,7 @@ def model(): def _padded_token_tensor(model, texts): - tokenized = [ - model.tokenizer.encode(text, tokenizer_name=BPE_TOKENIZER_NAME) + [model.eos_id] for text in texts - ] + tokenized = [model.tokenizer.encode(text, tokenizer_name=BPE_TOKENIZER_NAME) + [model.eos_id] for text in texts] lens = torch.tensor([len(tokens) for tokens in tokenized], dtype=torch.long) max_len = int(lens.max().item()) padded = torch.full((len(tokenized), max_len), model.pad_id, dtype=torch.long) @@ -491,8 +489,9 @@ def test_training_step_smoke(model, toy_batch): _seed_everything() model.train() - with patch.object(model, "log", lambda *args, **kwargs: None), patch.object( - model, "log_dict", lambda *args, **kwargs: None + with ( + patch.object(model, "log", lambda *args, **kwargs: None), + patch.object(model, "log_dict", lambda *args, **kwargs: None), ): loss = model.training_step(toy_batch, batch_idx=0) From 500c8454a46e748a0a2cc786e23547f2b5ad4f90 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 23 Jun 2026 19:13:00 -0700 Subject: [PATCH 105/109] Replace instantiate with safe_instantiate Signed-off-by: Edresson Casanova --- .../tts/data/text_to_speech_dataset_lhotse_multiturn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index 618f5de3033e..fa4c73977c3d 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -19,7 +19,6 @@ import torch import torch.nn.functional as F import torch.utils.data -from hydra.utils import instantiate from lhotse import CutSet, Seconds, compute_num_frames from lhotse.cut import Cut from lhotse.dataset.collation import collate_audio, collate_matrices, collate_vectors @@ -34,6 +33,7 @@ normalize_volume, stack_tensors, ) +from nemo.core.classes.common import safe_instantiate from nemo.utils import logging @@ -49,8 +49,8 @@ def setup_tokenizers(all_tokenizers_config, mode='train'): else: text_tokenizer_kwargs = {} if "g2p" in tokenizer_config: - text_tokenizer_kwargs["g2p"] = instantiate(tokenizer_config.g2p) - tokenizer = instantiate(tokenizer_config, **text_tokenizer_kwargs) + text_tokenizer_kwargs["g2p"] = safe_instantiate(tokenizer_config.g2p) + tokenizer = safe_instantiate(tokenizer_config, **text_tokenizer_kwargs) if mode == 'test' and hasattr(tokenizer, "set_phone_prob"): tokenizer.set_phone_prob(1.0) tokenizers.append(tokenizer) @@ -242,7 +242,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: self.pad_id = self.text_tokenizer.pad if self.phoneme_tokenizer is None and self.phoneme_tokenizer_config is not None: - self.phoneme_tokenizer = instantiate(self.phoneme_tokenizer_config) + self.phoneme_tokenizer = safe_instantiate(self.phoneme_tokenizer_config) cuts = cuts.transform_text(_strip_timestamps) for cut in cuts: From dc1d48b5298da2dd5ae256694bbbff542096ee0d Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 24 Jun 2026 05:38:35 -0700 Subject: [PATCH 106/109] Fix unit tests Signed-off-by: Edresson Casanova --- nemo/collections/tts/modules/ffn_modules.py | 1 + .../tts/modules/transformer_2501.py | 7 ++++--- .../tts/models/test_easy_magpietts.py | 19 +++++++++++++++---- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/nemo/collections/tts/modules/ffn_modules.py b/nemo/collections/tts/modules/ffn_modules.py index 853514dd1f58..7f477fcbcd19 100644 --- a/nemo/collections/tts/modules/ffn_modules.py +++ b/nemo/collections/tts/modules/ffn_modules.py @@ -89,6 +89,7 @@ def forward(self, signal, signal_mask=None): # signal: (B, C, T) # signal_mask: (B, T) or None (if None, assumes all positions are valid) if signal_mask is not None: + signal_mask = signal_mask.to(device=signal.device, dtype=signal.dtype) signal = signal * signal_mask.unsqueeze(1) if self.is_causal: # TODO: maybe replace with identify rather than keep conditional if in forward signal = F.pad(signal, self.causal_padding) diff --git a/nemo/collections/tts/modules/transformer_2501.py b/nemo/collections/tts/modules/transformer_2501.py index f0e245c9f258..5deaf898c8be 100644 --- a/nemo/collections/tts/modules/transformer_2501.py +++ b/nemo/collections/tts/modules/transformer_2501.py @@ -20,7 +20,6 @@ from nemo.collections.tts.modules.ffn_modules import PositionwiseConvFF from nemo.collections.tts.modules.moe_modules import PositionwiseConvFFMoE -from nemo.utils import logging # TODO: Move the cache implementation out of the Module class, and pass it as part of the forward so we can reset # as needed in the inference pipeline. @@ -120,6 +119,7 @@ def attn_naive( # attn_prior or square mask or vanilla attention if attn_prior is not None: + attn_prior = attn_prior.to(device=attn_score.device, dtype=attn_score.dtype) eps = torch.finfo(attn_prior.dtype).tiny attn_prior = attn_prior[:, :T] # trim for inference attn_prior = attn_prior[:, None] + eps @@ -129,7 +129,7 @@ def attn_naive( attn_score_log = F.log_softmax(attn_score, dim=-1) + attn_prior_log if self.make_prior_window_strict: # Make sure attention scores are lowest (eps) where prior is zero. - min_score = torch.log(torch.tensor(eps)).to(attn_score_log.device) + min_score = torch.log(torch.tensor(eps, device=attn_score_log.device, dtype=attn_score_log.dtype)) attn_score_log = attn_score_log.masked_fill( attn_prior == 0, min_score ) # Wherever prior is zero, set scores to eps. @@ -244,6 +244,7 @@ def compute_qkv_and_mask( mask = None if query_mask is not None: + query_mask = query_mask.to(device=query.device) # query_mask is a boolean mask of shape (B, T) # mask should be of shape (B, 1, T, T) where mask[:,0,i,:] == mask[:,0,:,i] == query_mask mask = query_mask.unsqueeze(1) * query_mask.unsqueeze(2) @@ -310,7 +311,7 @@ def compute_qkv_and_mask( self.cache['cross_k'] = k self.cache['cross_v'] = v - mask = memory_mask[:, None, None] if memory_mask is not None else None + mask = memory_mask.to(device=query.device)[:, None, None] if memory_mask is not None else None return q, k, v, mask diff --git a/tests/collections/tts/models/test_easy_magpietts.py b/tests/collections/tts/models/test_easy_magpietts.py index 9ff273be8a3f..742db9e1b8f4 100644 --- a/tests/collections/tts/models/test_easy_magpietts.py +++ b/tests/collections/tts/models/test_easy_magpietts.py @@ -30,10 +30,6 @@ from tests.collections.tts.models.test_audio_codec import create_codec_config -if torch.cuda.is_available(): - torch.set_default_device("cuda") - - pytestmark = pytest.mark.unit BPE_TOKENIZER_NAME = "nemotron_bpe" @@ -43,6 +39,21 @@ BPE_TOKENIZER_MODEL = str(BPE_TOKENIZER_CACHED_PATH) +@pytest.fixture(autouse=True, scope="module") +def _default_device_cuda(): + """Run this module with a CUDA default device without leaking it to later modules.""" + if not torch.cuda.is_available(): + yield + return + + prev = torch.get_default_device() + torch.set_default_device("cuda") + try: + yield + finally: + torch.set_default_device(prev) + + def _restore_codec_as_random_initialized_model(*args, **kwargs): del args if kwargs.get("return_config", False): From eb6b252077815bb9f690c1cf69acd28eb70cc4e2 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 24 Jun 2026 10:48:27 -0700 Subject: [PATCH 107/109] Add turn level volume normalization Signed-off-by: Edresson Casanova --- .../tts/data/text_to_speech_dataset_lhotse_multiturn.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py index fa4c73977c3d..708d8f2b4a20 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -306,6 +306,7 @@ def _align_codebooks(t): source_audio_list=source_audio_list, source_sample_rate=self.source_sample_rate, roles=self.input_roles, + volume_norm=self.volume_norm, ) ) @@ -790,6 +791,7 @@ def extract_turn_audio_channel( source_audio_list: list[torch.Tensor], source_sample_rate: int, roles: set[str], + volume_norm: bool = True, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Packs role-specific speech turns into batch dimension. @@ -823,6 +825,12 @@ def extract_turn_audio_channel( continue chunk = audio[start:end] + + # normalize turn level audio + if volume_norm: + chunk = normalize_volume(chunk.numpy()) + chunk = torch.from_numpy(chunk) + turn_audio.append(chunk) turn_lens.append(chunk.shape[0]) turn_indices.append([batch_idx, start, end]) From b4aef5e93ad62ccd34d921e6932f13cd165dffbe Mon Sep 17 00:00:00 2001 From: Shehzeen Hussain Date: Wed, 24 Jun 2026 15:23:55 -0700 Subject: [PATCH 108/109] remove configurable cas encoder layers since it is not needed Signed-off-by: Shehzeen Hussain --- examples/tts/conf/magpietts/easy_magpietts.yaml | 1 - .../tts/conf/magpietts/easy_magpietts_lhotse.yaml | 1 - .../tts/models/easy_magpietts_inference.py | 1 - nemo/collections/tts/modules/magpietts_modules.py | 12 ++---------- 4 files changed, 2 insertions(+), 13 deletions(-) diff --git a/examples/tts/conf/magpietts/easy_magpietts.yaml b/examples/tts/conf/magpietts/easy_magpietts.yaml index f1a58d0ae71c..2d0c274eb7e0 100644 --- a/examples/tts/conf/magpietts/easy_magpietts.yaml +++ b/examples/tts/conf/magpietts/easy_magpietts.yaml @@ -17,7 +17,6 @@ model: disable_lm_text_head: false disable_subword_embedding: false use_bpe_char_tokenizer: true - cas_encoder_n_layers: 1 # HuggingFace backend config (used when decoder_type: "huggingface") transformer_hf_backend: "Qwen/Qwen2.5-1.5B" diff --git a/examples/tts/conf/magpietts/easy_magpietts_lhotse.yaml b/examples/tts/conf/magpietts/easy_magpietts_lhotse.yaml index 01d3666ba958..d6b02adc4b97 100644 --- a/examples/tts/conf/magpietts/easy_magpietts_lhotse.yaml +++ b/examples/tts/conf/magpietts/easy_magpietts_lhotse.yaml @@ -15,7 +15,6 @@ model: disable_lm_text_head: false disable_subword_embedding: false use_bpe_char_tokenizer: true - cas_encoder_n_layers: 1 # HuggingFace backend config (used when decoder_type: "huggingface") transformer_hf_backend: "Qwen/Qwen2.5-1.5B" diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 7f18322c9963..5ade43603b7b 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -557,7 +557,6 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): llm_tokenizer_vocab=subword_vocab, subword_padding_idx=self.tokenizer.pad, special_vocab=special_vocab, - n_layers=cfg.get('cas_encoder_n_layers', 1), ) if self.disable_subword_embedding and not hasattr(self, 'cas_encoder'): diff --git a/nemo/collections/tts/modules/magpietts_modules.py b/nemo/collections/tts/modules/magpietts_modules.py index fcffeb21e96e..efaeea3d09c8 100644 --- a/nemo/collections/tts/modules/magpietts_modules.py +++ b/nemo/collections/tts/modules/magpietts_modules.py @@ -181,14 +181,7 @@ class CharAwareSubwordEncoder(NeuralModule): The output is a tensor of shape (batch_size, max_subword_length, d_embed). """ - def __init__( - self, - d_embed: int, - llm_tokenizer_vocab: dict, - subword_padding_idx: int, - special_vocab: dict = None, - n_layers: int = 1, - ): + def __init__(self, d_embed: int, llm_tokenizer_vocab: dict, subword_padding_idx: int, special_vocab: dict = None): """ Args: d_embed (int): The dimension of the embedding. @@ -198,7 +191,6 @@ def __init__( subword_padding_idx (int): The padding index for the subword vocabulary. special_vocab (dict): items of special token dictionary (usually BOS, EOS) eg. special_vocab = {'': 30001, '': 30002} - n_layers (int): Number of transformer layers used in the char-aware encoder. """ super().__init__() self.subword_id_to_char_ids, self.char_vocab = build_vocabs( @@ -206,7 +198,7 @@ def __init__( ) self.embed_tokens = torch.nn.Embedding(self.vocab_size + 1, d_embed, padding_idx=self.vocab_size) self.encoder = transformer_2501.Transformer( - n_layers=n_layers, + n_layers=1, d_model=d_embed, d_ffn=d_embed * 4, sa_n_heads=8, From 3b5077f7a6a8661eeb78370afd0ad46dd0e23d12 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 25 Jun 2026 04:50:53 -0700 Subject: [PATCH 109/109] Left silence pad short user audios for safety Signed-off-by: Edresson Casanova --- .../modules/magpietts_inference/inference.py | 48 +++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index eaa0759cbd1a..7bd5f2ac8c63 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -1160,6 +1160,33 @@ def _ensure_codec_silence_codes(self) -> torch.Tensor: return self.model._codec_sil_codes_buffer.to(self.model.device).long() + @staticmethod + def _left_pad_raw_audio_if_short( + user_audio: torch.Tensor, + user_audio_lens: torch.Tensor, + min_len: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Left-pad one raw user-audio turn with silence only when it is shorter than min_len.""" + min_len = int(min_len) + if min_len <= 0: + return user_audio, user_audio_lens + if user_audio_lens.numel() != 1: + raise RuntimeError("multiturn_user_audio raw-audio left padding expects batch_size=1.") + + # In this runner each user turn is kept as its own [1, T] tensor, so T + # is the valid raw-audio length. Checking size(-1) avoids a CUDA sync + # from user_audio_lens[0].item() in the common no-padding path. + current_len = int(user_audio.size(-1)) + if current_len >= min_len: + return user_audio, user_audio_lens + + padded_audio = user_audio.new_zeros(user_audio.shape[:-1] + (min_len,)) + if current_len > 0: + padded_audio[..., -current_len:] = user_audio + + padded_lens = user_audio_lens.new_full(user_audio_lens.shape, min_len) + return padded_audio, padded_lens + def _run_multiturn_generation(self, batch: Dict[str, Any]): model = self.model device = model.device @@ -1258,9 +1285,12 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): getattr(state, attr).zero_() state.last_phoneme_tokens = None + delay_tokens = int(state.config.training_mode.streaming_speech_delay) + samples_per_streaming_step = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) + if not model.cfg.get("condition_on_user_speech", False): user_audio = batch["user_audio_turns"][local_turn_idx] - user_audio_prefill_steps = int(round(user_audio.size(-1) / model.input_samples_per_frame)) + user_audio_prefill_steps = int(round(user_audio.size(-1) / samples_per_streaming_step)) user_audio_prefill_tokens = torch.full( (1, user_audio_prefill_steps), model.pad_id, dtype=torch.long, device=device ) @@ -1268,6 +1298,13 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): else: user_audio = batch["user_audio_turns"][local_turn_idx] user_audio_lens = batch["user_audio_turns_lens"][local_turn_idx] + min_user_audio_len = delay_tokens * samples_per_streaming_step + user_audio, user_audio_lens = self._left_pad_raw_audio_if_short( + user_audio=user_audio, + user_audio_lens=user_audio_lens, + min_len=min_user_audio_len, + ) + user_audio_codes, user_audio_codes_lens = model._codec_helper.audio_to_codes( user_audio, user_audio_lens ) @@ -1320,8 +1357,13 @@ def _run_multiturn_generation(self, batch: Dict[str, Any]): ) user_audio_channel_embedding = user_audio_embedded - delay_tokens = int(state.config.training_mode.streaming_speech_delay) - delay_tokens = min(delay_tokens, user_audio_prefill_steps) + if user_audio_channel_embedding is None: + delay_tokens = min(delay_tokens, user_audio_prefill_steps) + elif user_audio_prefill_steps < delay_tokens: + raise RuntimeError( + "Raw user-audio left padding did not produce enough warmup steps: " + f"user_audio_prefill_steps={user_audio_prefill_steps}, delay_tokens={delay_tokens}." + ) num_warmup_text_tokens = min(delay_tokens, int(turn_lens[0].item()), turn_text.size(1)) # handle short turns so that it does not advance the text channel and keep the delay_tokens.