diff --git a/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml b/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml new file mode 100644 index 000000000000..e0571290516b --- /dev/null +++ b/examples/tts/conf/magpietts/easy_magpietts_lhotse_multiturn.yaml @@ -0,0 +1,231 @@ +name: Magpie-TTS-DecoderOnly-EN + +quadratic_duration: 20 + +# Adjust batch size based on GPU memory +# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. +# If null, then weighted sampling is disabled. + +model: + use_lhotse: true + + # Decoder backend selection + # Options: "huggingface" (default), "nemotron_h" + decoder_type: "huggingface" + + # HuggingFace backend config (used when decoder_type: "huggingface") + transformer_hf_backend: "Qwen/Qwen2.5-1.5B" + + # NemotronH config (used when decoder_type: "nemotron_h") + # Hybrid Mamba2/MoE/Attention model (~3B total, ~600-800M active). Layer types via hybrid_override_pattern: + # 'M' = Mamba2 layer, '*' = Attention layer, '-' = MLP layer, 'E' = MoE layer + nemotron_h_config: + hidden_size: 1536 # Should match embedding_dim + num_hidden_layers: 48 + vocab_size: 131072 + # Attention config + num_attention_heads: 12 + num_key_value_heads: 4 + attention_dropout: 0.0 + attention_bias: false + max_position_embeddings: 8192 + # Mamba config + mamba_num_heads: 64 + mamba_head_dim: 24 + ssm_state_size: 128 + conv_kernel: 4 + n_groups: 8 + chunk_size: 256 + mamba_hidden_act: "silu" + use_conv_bias: true + use_bias: false + # MLP config + intermediate_size: 4096 + mlp_hidden_act: "silu" + mlp_bias: false + # MoE config (scaled from Nemotron-3-Nano-30B-A3B) + n_routed_experts: 48 + num_experts_per_tok: 6 + moe_intermediate_size: 1024 + moe_shared_expert_intermediate_size: 2048 + n_group: 1 + topk_group: 1 + routed_scaling_factor: 2.5 + norm_topk_prob: true + # Layer pattern: (M E M E M *) x 8 => 16 Mamba, 16 MoE, 8 Attention + hybrid_override_pattern: "MEMEM*MEMEM*MEMEM*MEMEM*MEMEM*MEMEM*MEMEM*MEMEM*" + # Normalization + layer_norm_epsilon: 1e-5 + residual_in_fp32: true + + use_text_conditioning_encoder: true # If true, distilbert will be used to encode context_text if provided. + context_duration_min: 5.0 + context_duration_max: 5.0 + load_cached_codes_if_available: true + + embedding_dim: 1536 + hidden_dim: 1536 + audio_embedding_dim: 1536 # Can set a smaller dimension for audio embeddings to reduce parameters. Set equal to hidden_dim for no projection. + codecmodel_path: ??? + + # Local transformer parameters for autoregressive codebook prediction within a frame + local_transformer_type: "autoregressive" # "none", "autoregressive" + # Below args are only relevant if use_local_transformer is autoregressive + local_transformer_loss_scale: 1.0 + phoneme_loss_weight: 1.0 + local_transformer_n_layers: 3 + local_transformer_n_heads: 12 + local_transformer_hidden_dim: 1536 + + cfg_unconditional_prob: 0.05 + + # Multi-mode training configuration + training_modes: + - text_input_mode: "streaming" # Options: "full", "streaming" + streaming_phonemes_delay: 0 + streaming_speech_delay: 1 + + frame_stacking_factor: 2 + phoneme_stacking_factor: 1 + phoneme_confidence_unk_threshold: 0.0 # If max phoneme probability is below this threshold at inference-time, replace the predicted timestep with UNK to reduce error propagation. + dropout_text_input_prob: 0.1 + phoneme_corruption_batch_prob: 0.1 + phoneme_corruption_timestep_ratio: 0.15 + phoneme_corruption_unk_mode_prob: 0.5 + phoneme_corruption_type: "repeat_skip_unk" # "repeat_skip_unk" or "complete_channel" + phoneme_turn_dropout_batch_prob: 0.0 # prob of applying turn dropout to a sample + phoneme_turn_dropout_turn_prob: 0.0 # prob of dropping each phoneme turn within a sample + phoneme_turn_max_words_to_drop: 0 # turns with <= this many words keep phoneme tokens as pad_id + + phoneme_tokenizer: + _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPABPETokenizer + tokenizer_path: ??? + + text_tokenizers: + nemotron_nano_30b: + _target_: AutoTokenizer + pretrained_model: "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16" + + train_ds: + use_lhotse: ${model.use_lhotse} + volume_norm: true + dataset: + multi_config: true + shuffle: true + seed: 42 + shard_seed: "trng" + + sampler_fusion: randomized_round_robin + sampler_weights: + tts_data: 0.5 + duplex_data: 0.5 + tts_data: + min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 + batch_duration : ??? # in seconds. Adjust based on your GPU memory. + quadratic_duration: ${quadratic_duration} + use_bucketing: true + num_buckets: 20 + bucket_buffer_size: 10_000 + shuffle_buffer_size: 10_000 + num_cuts_for_bins_estimate: 10_000 + shard_seed: "trng" + drop_last: true + shuffle: true + num_workers: 6 + pin_memory: true + + input_cfg: + - type: lhotse_shar + shar_path: ??? + weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] + + duplex_data: + input_cfg: /lustre/fsw/convai_convaird_nemo-speech/data/duplex/multispeaker_syn_duplex.yaml + use_bucketing: true + num_buckets: 20 + bucket_buffer_size: 1_000 + shuffle_buffer_size: 1_000 + num_cuts_for_bins_estimate: 1_000 + max_duration: 300 # 5 mi max duration + bucket_duration_bins: [4.0, 8.9, 10.2, 11.6, 13.2, 15.0, 17.0, 19.3, 25.0, 31.5, 38.5, 46.0, 55.5, 66.5, 79.5, 93.3, 110.0, 130.0, 156.8, 203.3] + bucket_batch_size: [75, 33, 29, 25, 23, 20, 18, 15, 12, 10, 8, 7, 5, 4, 3, 3, 2, 2, 1, 1] + + + validation_ds: + use_lhotse: ${model.use_lhotse} + volume_norm: true + + dataset: + min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 + batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset. + quadratic_duration: ${quadratic_duration} + use_bucketing: false + force_finite: true + force_map_dataset: true + drop_last: false + shuffle: false + num_workers: 2 + pin_memory: true + seed: 42 + shard_seed: "randomized" + + input_cfg: + - type: lhotse_shar + shar_path: ??? + weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] + + optim: + _target_: torch.optim.AdamW + lr: 1e-4 + + sched: + name: ExponentialLR + gamma: 0.998 + +trainer: + num_nodes: 1 + devices: -1 + accelerator: gpu + strategy: ddp_find_unused_parameters_true + precision: bf16-mixed + max_steps: ??? + accumulate_grad_batches: 1 + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 100 + limit_train_batches: 1_000 + val_check_interval: 1_000 + num_sanity_val_steps: 0 + benchmark: false + use_distributed_sampler: false # required because Lhotse has its own handling + gradient_clip_val: 2.5 + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_wandb_logger: false + wandb_logger_kwargs: + entity: null + name: ${name} + project: null + group: null + resume: true + create_checkpoint_callback: true + checkpoint_callback_params: + monitor: val_loss + mode: min + save_top_k: 5 + save_best_model: true + always_save_nemo: true + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.4f}-{step}-{epoch}' + resume_if_exists: true + resume_ignore_no_checkpoint: true \ No newline at end of file diff --git a/examples/tts/easy_magpietts.py b/examples/tts/easy_magpietts.py index 5e9be71a7805..74ac2c6e3965 100644 --- a/examples/tts/easy_magpietts.py +++ b/examples/tts/easy_magpietts.py @@ -55,6 +55,9 @@ def main(cfg): else: raise NotImplementedError(f"Only train, onlinepo_train and test modes are supported. Got {mode}") + if cfg.get("pretrained_model", None): + model.restore_from_pretrained_checkpoint(cfg.pretrained_model) + model.maybe_init_from_pretrained_checkpoint(cfg=cfg) if mode in ['train', 'onlinepo_train']: diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index 90f97fdedfa8..7d4666497c44 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -60,6 +60,7 @@ import os import random import shutil +import time from dataclasses import fields from pathlib import Path from typing import List, Optional, Tuple @@ -82,6 +83,8 @@ BaseInferenceRunner, EasyMagpieInferenceConfig, EasyMagpieInferenceRunner, + EasyMagpieMultiturnUserAudioInferenceConfig, + EasyMagpieMultiturnUserAudioInferenceRunner, MagpieInferenceConfig, MagpieInferenceRunner, ) @@ -126,6 +129,8 @@ def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, met metrics.get('ssim_pred_gt_avg_alternate', ''), metrics.get('ssim_pred_context_avg_alternate', ''), metrics.get('ssim_gt_context_avg_alternate', ''), + metrics.get('esim_pred_gt_avg', ''), + metrics.get('ems_pred_gt_avg', ''), metrics.get('cer_gt_audio_cumulative', ''), metrics.get('wer_gt_audio_cumulative', ''), metrics.get('utmosv2_avg', ''), @@ -141,6 +146,244 @@ def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, met logging.info(f"Metrics appended to: {csv_path}") +def _mean_finite(values: list): + vals = [] + for value in values: + try: + value = float(value) + except (TypeError, ValueError): + continue + if np.isfinite(value): + vals.append(value) + return None if not vals else float(np.mean(vals)) + + +def _enrich_filewise_metrics_with_manifest(filewise_metrics: list, manifest_path: str) -> list: + """Attach multiturn manifest metadata back to evaluator filewise rows. + + evaluate_generated_audio_dir() returns one filewise row per generated turn, + but the filtered row does not preserve source_sample_idx/turn_id. The + generated multiturn manifest has the same order as predicted_audio_*.wav, so + we merge by list index before grouping. + """ + if manifest_path is None or not os.path.exists(manifest_path): + logging.warning(f"Could not enrich multiturn filewise metrics; manifest missing: {manifest_path}") + return filewise_metrics + + manifest_records = read_manifest(manifest_path) + if len(manifest_records) != len(filewise_metrics): + logging.warning( + "Could not safely enrich multiturn filewise metrics; " + f"manifest rows={len(manifest_records)} filewise rows={len(filewise_metrics)} " + f"manifest_path={manifest_path}" + ) + return filewise_metrics + + enriched = [] + for row, record in zip(filewise_metrics, manifest_records): + new_row = dict(row) + for key in [ + "source_sample_idx", + "turn_id", + "speaker", + "rank", + "rank_local_idx", + "audio_filepath", + "context_audio_filepath", + "predicted_phoneme_text", + "predicted_phoneme_tokens", + "predicted_phoneme_token_labels", + ]: + if key in record and key not in new_row: + new_row[key] = record[key] + enriched.append(new_row) + + return enriched + + +def _group_multiturn_filewise_metrics_by_sample(filewise_metrics: list) -> list: + """Group turn-level multiturn metrics into one old-style row per sample. + + Each grouped row keeps turn-by-turn CER/WER/SSIM/UTMOS/text/audio lists plus + averaged sample-level values. Rows are sorted by averaged CER descending so + the worst conversations/samples appear first. + """ + grouped = {} + + for row_idx, row in enumerate(filewise_metrics): + source_sample_idx = row.get("source_sample_idx", None) + if source_sample_idx is None: + source_sample_idx = row.get("speaker", None) + if source_sample_idx is None: + source_sample_idx = row_idx + + key = str(source_sample_idx) + if key not in grouped: + grouped[key] = { + "source_sample_idx": source_sample_idx, + "speaker": row.get("speaker", source_sample_idx), + "rank": row.get("rank", None), + "target_audio_path": row.get("gt_audio_filepath", row.get("audio_filepath", "")), + "context_audio_path": row.get("context_audio_filepath", ""), + "turn_rows": [], + } + grouped[key]["turn_rows"].append(row) + + grouped_rows = [] + for _, group in grouped.items(): + turns = group["turn_rows"] + + def turn_sort_key(r): + try: + return int(r.get("turn_id", 0)) + except (TypeError, ValueError): + return 0 + + turns = sorted(turns, key=turn_sort_key) + + cer_turns = [r.get("cer") for r in turns] + wer_turns = [r.get("wer") for r in turns] + pred_context_ssim_turns = [r.get("pred_context_ssim") for r in turns] + pred_gt_ssim_turns = [r.get("pred_gt_ssim") for r in turns] + gt_context_ssim_turns = [r.get("gt_context_ssim") for r in turns] + pred_gt_esim_turns = [r.get("pred_gt_esim") for r in turns] + pred_gt_ems_turns = [r.get("pred_gt_ems") for r in turns] + utmosv2_turns = [r.get("utmosv2") for r in turns] + eou_type_turns = [r.get("eou_type") for r in turns] + eou_trailing_duration_turns = [r.get("eou_trailing_duration") for r in turns] + eou_trail_rms_ratio_turns = [r.get("eou_trail_rms_ratio") for r in turns] + predicted_phoneme_text_turns = [r.get("predicted_phoneme_text", "") for r in turns] + predicted_phoneme_tokens_turns = [r.get("predicted_phoneme_tokens", []) for r in turns] + predicted_phoneme_token_labels_turns = [r.get("predicted_phoneme_token_labels", []) for r in turns] + + grouped_rows.append( + { + "source_sample_idx": group["source_sample_idx"], + "speaker": group["speaker"], + "rank": group["rank"], + "num_turns": len(turns), + # Sample-level averages over all turns. + "cer": _mean_finite(cer_turns), + "wer": _mean_finite(wer_turns), + "pred_context_ssim": _mean_finite(pred_context_ssim_turns), + "pred_gt_ssim": _mean_finite(pred_gt_ssim_turns), + "gt_context_ssim": _mean_finite(gt_context_ssim_turns), + "pred_gt_esim": _mean_finite(pred_gt_esim_turns), + "pred_gt_ems": _mean_finite(pred_gt_ems_turns), + "utmosv2": _mean_finite(utmosv2_turns), + "eou_trailing_duration": _mean_finite(eou_trailing_duration_turns), + "eou_trail_rms_ratio": _mean_finite(eou_trail_rms_ratio_turns), + # Turn-by-turn values, old-script style. + "turn_ids": [r.get("turn_id", i) for i, r in enumerate(turns)], + "cer_turns": cer_turns, + "wer_turns": wer_turns, + "pred_context_ssim_turns": pred_context_ssim_turns, + "pred_gt_ssim_turns": pred_gt_ssim_turns, + "gt_context_ssim_turns": gt_context_ssim_turns, + "pred_gt_esim_turns": pred_gt_esim_turns, + "pred_gt_ems_turns": pred_gt_ems_turns, + "utmosv2_turns": utmosv2_turns, + "eou_type_turns": eou_type_turns, + "eou_trailing_duration_turns": eou_trailing_duration_turns, + "eou_trail_rms_ratio_turns": eou_trail_rms_ratio_turns, + "predicted_phoneme_text_turns": predicted_phoneme_text_turns, + "predicted_phoneme_tokens_turns": predicted_phoneme_tokens_turns, + "predicted_phoneme_token_labels_turns": predicted_phoneme_token_labels_turns, + "reference_text": [r.get("gt_text", "") for r in turns], + "asr_hyp": [r.get("pred_text", "") for r in turns], + "pred_audio_paths": [r.get("pred_audio_filepath", "") for r in turns], + "target_audio_path": group["target_audio_path"], + "context_audio_path": group["context_audio_path"], + "turn_metrics": turns, + } + ) + + grouped_rows.sort( + key=lambda r: ( + r.get("cer") is not None, + float(r["cer"]) if r.get("cer") is not None else -1.0, + ), + reverse=True, + ) + return grouped_rows + + +def _write_grouped_multiturn_filewise_metrics_csv(csv_path: str, grouped_rows: list) -> None: + fieldnames = [ + "source_sample_idx", + "speaker", + "rank", + "num_turns", + "cer", + "wer", + "pred_context_ssim", + "pred_gt_ssim", + "gt_context_ssim", + "pred_gt_esim", + "pred_gt_ems", + "utmosv2", + "eou_trailing_duration", + "eou_trail_rms_ratio", + "turn_ids", + "cer_turns", + "wer_turns", + "pred_context_ssim_turns", + "pred_gt_ssim_turns", + "gt_context_ssim_turns", + "pred_gt_esim_turns", + "pred_gt_ems_turns", + "utmosv2_turns", + "eou_type_turns", + "eou_trailing_duration_turns", + "eou_trail_rms_ratio_turns", + "target_audio_path", + "context_audio_path", + "pred_audio_paths", + "reference_text", + "asr_hyp", + "predicted_phoneme_text_turns", + "predicted_phoneme_tokens_turns", + "predicted_phoneme_token_labels_turns", + ] + + def csv_value(value): + if isinstance(value, (list, dict)): + value = json.dumps(value, ensure_ascii=False) + if value is None: + value = "" + value = str(value).replace('"', '""') + if "," in value or "\n" in value or "[" in value or "{" in value: + value = f'"{value}"' + return value + + with open(csv_path, "w", encoding="utf-8") as f: + f.write(",".join(fieldnames) + "\n") + for row in grouped_rows: + f.write(",".join(csv_value(row.get(k, "")) for k in fieldnames) + "\n") + + +def _save_grouped_multiturn_filewise_metrics( + eval_dir: str, + dataset: str, + repeat_idx: int, + filewise_metrics: list, + manifest_path: str, +) -> None: + enriched_filewise = _enrich_filewise_metrics_with_manifest(filewise_metrics, manifest_path) + grouped_rows = _group_multiturn_filewise_metrics_by_sample(enriched_filewise) + + json_path = os.path.join(eval_dir, f"{dataset}_grouped_filewise_metrics_{repeat_idx}.json") + csv_path = os.path.join(eval_dir, f"{dataset}_grouped_filewise_metrics_{repeat_idx}.csv") + + with open(json_path, "w", encoding="utf-8") as f: + json.dump(grouped_rows, f, indent=4, ensure_ascii=False) + + _write_grouped_multiturn_filewise_metrics_csv(csv_path, grouped_rows) + + logging.info(f"Saved grouped multiturn filewise metrics JSON to: {json_path}") + logging.info(f"Saved grouped multiturn filewise metrics CSV to: {csv_path}") + + def create_formatted_metrics_mean_ci(metrics_mean_ci: dict) -> dict: """Create formatted metrics mean CI.""" for k, v in metrics_mean_ci.items(): @@ -169,6 +412,143 @@ def filter_datasets( return datasets +def _runner_eval_manifest_and_audio_dir(runner: BaseInferenceRunner, default_manifest: str, default_audio_dir: str): + """Return evaluation manifest/audio dir produced by the runner, if any.""" + eval_manifest = getattr(runner, "evaluation_manifest_path", None) or default_manifest + eval_audio_dir = getattr(runner, "evaluation_audio_dir", None) or default_audio_dir + return eval_manifest, eval_audio_dir + + +def _get_torchrun_rank_info() -> Tuple[int, int, int]: + """Return (rank, world_size, local_rank) from torchrun/SLURM env vars. + + We intentionally do not initialize torch.distributed here. The inference + script only needs env-based sharding, while NeMo evaluation models can run + without distributed collectives. + """ + rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", "0"))) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))) + return rank, world_size, local_rank + + +def _configure_cuda_for_rank() -> Tuple[int, int, int]: + rank, world_size, local_rank = _get_torchrun_rank_info() + if torch.cuda.is_available(): + device_count = torch.cuda.device_count() + if device_count > 0: + torch.cuda.set_device(local_rank % device_count) + logging.info( + f"Using CUDA device {local_rank % device_count}; " + f"rank={rank}, local_rank={local_rank}, world_size={world_size}" + ) + return rank, world_size, local_rank + + +def _wait_for_multiturn_rank_manifests(repeat_audio_dir: str, world_size: int, timeout_sec: int = 7200) -> None: + deadline = time.time() + timeout_sec + while time.time() < deadline: + missing = [] + for rank in range(world_size): + path = os.path.join( + repeat_audio_dir, + f"rank_{rank:04d}", + f"multiturn_user_audio_turn_manifest_rank{rank:04d}.jsonl", + ) + if not os.path.exists(path): + missing.append(path) + if not missing: + return + time.sleep(5) + raise RuntimeError(f"Timed out waiting for multiturn rank manifests: {missing}") + + +def _copy_or_link(src: str, dst: str, required: bool = False) -> None: + if src is None or not os.path.exists(src): + if os.path.lexists(dst): + os.remove(dst) + if required: + raise FileNotFoundError(f"Missing required merge source: {src} -> {dst}") + return + + os.makedirs(os.path.dirname(dst), exist_ok=True) + + if os.path.lexists(dst): + os.remove(dst) + + # Prefer real files for evaluator inputs; broken symlinks confuse librosa/UTMOS. + shutil.copyfile(src, dst) + + +def _merge_multiturn_rank_outputs(repeat_audio_dir: str, world_size: int, save_predicted_codes: bool) -> str: + """Merge rank-local multiturn outputs into one EasyMagpie-compatible dir. + + Each rank writes local files named predicted_audio_0.wav, target_audio_0.wav, + context_audio_0.wav, predicted_codes_0.pt, ... inside rank_XXXX/. This + function remaps them to contiguous global indices in repeat_audio_dir/ and + writes a merged turn-level manifest. + """ + # clean previous merged files + for pattern in [ + "predicted_audio_*.wav", + "target_audio_*.wav", + "context_audio_*.wav", + "predicted_codes_*.pt", + ]: + for path in Path(repeat_audio_dir).glob(pattern): + if path.is_symlink() or path.exists(): + path.unlink(missing_ok=True) + + merged_records = [] + global_idx = 0 + + for rank in range(world_size): + rank_dir = os.path.join(repeat_audio_dir, f"rank_{rank:04d}") + rank_manifest = os.path.join(rank_dir, f"multiturn_user_audio_turn_manifest_rank{rank:04d}.jsonl") + if not os.path.exists(rank_manifest): + raise FileNotFoundError(f"Missing rank manifest: {rank_manifest}") + + with open(rank_manifest, "r", encoding="utf-8") as f: + rank_records = [json.loads(line) for line in f if line.strip()] + + for local_idx, record in enumerate(rank_records): + pred_src = os.path.join(rank_dir, f"predicted_audio_{local_idx}.wav") + pred_dst = os.path.join(repeat_audio_dir, f"predicted_audio_{global_idx}.wav") + _copy_or_link(pred_src, pred_dst, required=True) + + if save_predicted_codes: + code_src = os.path.join(rank_dir, f"predicted_codes_{local_idx}.pt") + code_dst = os.path.join(repeat_audio_dir, f"predicted_codes_{global_idx}.pt") + _copy_or_link(code_src, code_dst, required=False) + + target_src = os.path.join(rank_dir, record.get("audio_filepath", f"target_audio_{local_idx}.wav")) + target_dst = os.path.join(repeat_audio_dir, f"target_audio_{global_idx}.wav") + _copy_or_link(target_src, target_dst, required=True) + + context_src = os.path.join( + rank_dir, + record.get("context_audio_filepath", f"context_audio_{local_idx}.wav"), + ) + context_dst = os.path.join(repeat_audio_dir, f"context_audio_{global_idx}.wav") + _copy_or_link(context_src, context_dst, required=True) + + merged = dict(record) + merged["audio_filepath"] = f"target_audio_{global_idx}.wav" + merged["context_audio_filepath"] = f"context_audio_{global_idx}.wav" + merged["rank"] = rank + merged["rank_local_idx"] = local_idx + merged_records.append(merged) + global_idx += 1 + + merged_manifest = os.path.join(repeat_audio_dir, "multiturn_user_audio_turn_manifest.jsonl") + with open(merged_manifest, "w", encoding="utf-8") as f: + for record in merged_records: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + logging.info(f"Merged {len(merged_records)} multiturn turn records into {merged_manifest}") + return merged_manifest + + def run_inference_and_evaluation( runner: BaseInferenceRunner, checkpoint_name: str, @@ -216,6 +596,13 @@ def run_inference_and_evaluation( if not eval_config.with_utmosv2 and 'utmosv2' in violin_plot_metrics: violin_plot_metrics.remove('utmosv2') + rank, world_size, _ = _get_torchrun_rank_info() + is_distributed = world_size > 1 + is_multiturn_user_audio = getattr(runner, "produces_turn_level_evaluation", False) + + if hasattr(runner, "set_distributed_context"): + runner.set_distributed_context(rank=rank, world_size=world_size) + # Build full checkpoint identifier (include MoE info if present) full_checkpoint_name = ( f"{checkpoint_name}_{moe_info}{inference_config.build_identifier()}_SV_{eval_config.sv_model}" @@ -231,7 +618,8 @@ def run_inference_and_evaluation( "checkpoint_name,dataset,cer_filewise_avg,wer_filewise_avg,cer_cumulative," "wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg," "ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate," - "ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative," + "ssim_gt_context_avg_alternate,esim_pred_gt_avg,ems_pred_gt_avg," + "cer_gt_audio_cumulative,wer_gt_audio_cumulative," "utmosv2_avg,total_gen_audio_seconds,frechet_codec_distance," "eou_cutoff_rate,eou_silence_rate,eou_noise_rate,eou_error_rate" ) @@ -255,13 +643,17 @@ def run_inference_and_evaluation( # Setup CSV files per_run_csv = os.path.join(eval_dir, "all_experiment_metrics.csv") - write_csv_header_if_needed(per_run_csv, csv_header) + if rank == 0: + write_csv_header_if_needed(per_run_csv, csv_header) metrics_all_repeats = [] filewise_metrics_all_repeats = [] for repeat_idx in range(num_repeats): - logging.info(f"Repeat {repeat_idx + 1}/{num_repeats} for dataset {dataset}") + repeat_log_msg = f"Repeat {repeat_idx + 1}/{num_repeats} for dataset {dataset}" + if is_distributed: + repeat_log_msg += f", rank {rank}/{world_size}" + logging.info(repeat_log_msg) repeat_audio_dir = os.path.join(audio_dir, f"repeat_{repeat_idx}") os.makedirs(repeat_audio_dir, exist_ok=True) @@ -274,9 +666,21 @@ def run_inference_and_evaluation( f"Dataset length mismatch: {len(test_dataset)} vs {len(manifest_records)} manifest records" ) + if is_distributed and not is_multiturn_user_audio: + raise RuntimeError( + "torchrun multi-GPU sharding is currently implemented for " + "--easy_magpie_inference_mode multiturn_user_audio only. " + "Use the existing single-process path for single_turn/magpie, or add a " + "rank-safe merge path for those runners." + ) + + inference_output_dir = repeat_audio_dir + if is_distributed and is_multiturn_user_audio: + inference_output_dir = os.path.join(repeat_audio_dir, f"rank_{rank:04d}") + rtf_metrics_list, _, codec_file_paths = runner.run_inference_on_dataset( dataset=test_dataset, - output_dir=repeat_audio_dir, + output_dir=inference_output_dir, manifest_records=manifest_records, audio_base_dir=meta['audio_dir'], save_cross_attention_maps=True, @@ -293,7 +697,10 @@ def run_inference_and_evaluation( mean_rtf[f"{component_name}_{key}"] = value logging.info(f"{component_name} FLOPs per token: {component_flops['total_flops_per_token']:,}") - with open(os.path.join(eval_dir, f"{dataset}_rtf_metrics_{repeat_idx}.json"), "w") as f: + rtf_metrics_filename = f"{dataset}_rtf_metrics_{repeat_idx}.json" + if is_distributed: + rtf_metrics_filename = f"{dataset}_rtf_metrics_{repeat_idx}_rank{rank:04d}.json" + with open(os.path.join(eval_dir, rtf_metrics_filename), "w") as f: json.dump(mean_rtf, f, indent=4) if skip_evaluation: @@ -301,6 +708,26 @@ def run_inference_and_evaluation( continue # Run evaluation + if is_distributed and is_multiturn_user_audio: + if rank != 0: + # Non-zero ranks only generate. Rank 0 waits and evaluates merged outputs. + continue + + _wait_for_multiturn_rank_manifests(repeat_audio_dir, world_size) + merged_manifest_path = _merge_multiturn_rank_outputs( + repeat_audio_dir=repeat_audio_dir, + world_size=world_size, + save_predicted_codes=eval_config.with_fcd, + ) + eval_manifest_path = merged_manifest_path + eval_audio_dir = repeat_audio_dir + else: + eval_manifest_path, eval_audio_dir = _runner_eval_manifest_and_audio_dir( + runner, + default_manifest=meta['manifest_path'], + default_audio_dir=meta['audio_dir'], + ) + eval_config_for_dataset = EvaluationConfig( sv_model=eval_config.sv_model, asr_model_name=eval_config.asr_model_name, @@ -309,12 +736,19 @@ def run_inference_and_evaluation( with_utmosv2=eval_config.with_utmosv2, with_fcd=eval_config.with_fcd, codec_model_path=eval_config.codec_model_path, + with_emotion_metrics=eval_config.with_emotion_metrics, + emotion_model_size=eval_config.emotion_model_size, + emotion_embedding_type=eval_config.emotion_embedding_type, + emotion_cache_dir=eval_config.emotion_cache_dir, + strip_text_annotations_for_metrics=eval_config.strip_text_annotations_for_metrics, device=eval_config.device, + asr_batch_size=eval_config.asr_batch_size, + eou_batch_size=eval_config.eou_batch_size, ) metrics, filewise_metrics = evaluate_generated_audio_dir( - manifest_path=meta['manifest_path'], - audio_dir=meta['audio_dir'], + manifest_path=eval_manifest_path, + audio_dir=eval_audio_dir, generated_audio_dir=repeat_audio_dir, config=eval_config_for_dataset, ) @@ -330,6 +764,15 @@ def run_inference_and_evaluation( with open(os.path.join(eval_dir, f"{dataset}_filewise_metrics_{repeat_idx}.json"), "w") as f: json.dump(sorted_filewise, f, indent=4) + if is_multiturn_user_audio: + _save_grouped_multiturn_filewise_metrics( + eval_dir=eval_dir, + dataset=dataset, + repeat_idx=repeat_idx, + filewise_metrics=filewise_metrics, + manifest_path=eval_manifest_path, + ) + # Append to per-run CSV append_metrics_to_csv(per_run_csv, full_checkpoint_name, dataset, metrics) @@ -338,8 +781,16 @@ def run_inference_and_evaluation( create_violin_plot(filewise_metrics, violin_plot_metrics, violin_path) # Delete temporary predicted codes files - for codec_file_path in codec_file_paths: - os.remove(codec_file_path) + if is_distributed and is_multiturn_user_audio: + for codec_file_path in Path(repeat_audio_dir).glob("predicted_codes_*.pt"): + if os.path.exists(codec_file_path): + os.remove(codec_file_path) + else: + for codec_file_path in codec_file_paths: + os.remove(codec_file_path) + + if rank != 0: + continue if skip_evaluation or not metrics_all_repeats: continue @@ -367,17 +818,17 @@ def run_inference_and_evaluation( cer_per_dataset.append(np.mean(cer_values)) # Create combined plot if we have multiple datasets - if len(all_datasets_filewise_metrics) > 1: + if rank == 0 and len(all_datasets_filewise_metrics) > 1: combined_plot_path = os.path.join(out_dir, f"{full_checkpoint_name}_combined_violin_plot.png") create_combined_box_plot(all_datasets_filewise_metrics, violin_plot_metrics, combined_plot_path) # Clean up if requested - if clean_up_disk: + if rank == 0 and clean_up_disk: logging.info(f"Cleaning up output directory: {out_dir}") shutil.rmtree(out_dir) # Return averaged metrics - if ssim_per_dataset and cer_per_dataset: + if rank == 0 and ssim_per_dataset and cer_per_dataset: return np.mean(cer_per_dataset), np.mean(ssim_per_dataset) return None, None @@ -534,6 +985,20 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None: eval_group.add_argument('--num_repeats', type=int, default=1) eval_group.add_argument('--confidence_level', type=float, default=0.95) eval_group.add_argument('--disable_utmosv2', action='store_true') + eval_group.add_argument('--with_emotion_metrics', action='store_true') + eval_group.add_argument( + '--strip_text_annotations_for_metrics', + action='store_true', + help='Strip bracket/tag/control annotations from reference and ASR hypothesis text while computing text metrics.', + ) + eval_group.add_argument('--emotion_model_size', type=str, default="small", choices=["small", "large"]) + eval_group.add_argument( + '--emotion_embedding_type', + type=str, + default="score_vector", + choices=["head_concat", "head_mean", "score_vector"], + ) + eval_group.add_argument('--emotion_cache_dir', type=str, default=None) eval_group.add_argument( '--violin_plot_metrics', type=str, @@ -541,6 +1006,8 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None: default=['cer', 'pred_context_ssim', 'utmosv2'], ) eval_group.add_argument('--disable_fcd', action='store_true') + eval_group.add_argument("--asr_batch_size", type=int, default=32) + eval_group.add_argument("--eou_batch_size", type=int, default=32) # Quality targets target_group = parser.add_argument_group('Quality Targets') @@ -580,6 +1047,14 @@ def _add_magpie_args(parser: argparse.ArgumentParser) -> None: def _add_easy_magpie_args(parser: argparse.ArgumentParser) -> None: """Add arguments specific to decoder-only EasyMagpieTTSInferenceModel.""" group = parser.add_argument_group('EasyMagpieTTS-specific Parameters') + group.add_argument( + '--easy_magpie_inference_mode', + type=str, + default='single_turn', + choices=['single_turn', 'multiturn_user_audio'], + ) + group.add_argument('--max_eval_turns', type=int, default=6) + group.add_argument('--no_save_debug_multiturn_audio', action='store_true') group.add_argument( '--phoneme_input_type', type=str, @@ -649,7 +1124,12 @@ def _build_magpie_config(args) -> MagpieInferenceConfig: def _build_easy_magpie_config(args) -> EasyMagpieInferenceConfig: - return EasyMagpieInferenceConfig( + cfg_cls = ( + EasyMagpieMultiturnUserAudioInferenceConfig + if args.easy_magpie_inference_mode == 'multiturn_user_audio' + else EasyMagpieInferenceConfig + ) + kwargs = dict( model_inference_parameters=_build_inference_params_from_args(EasyModelInferenceParameters, args), batch_size=args.batch_size, use_cfg=args.use_cfg, @@ -658,12 +1138,33 @@ def _build_easy_magpie_config(args) -> EasyMagpieInferenceConfig: phoneme_sampling_method=args.phoneme_sampling_method, dropout_text_input=args.dropout_text_input, ) + if cfg_cls is EasyMagpieMultiturnUserAudioInferenceConfig: + kwargs.update( + max_eval_turns=args.max_eval_turns, + save_debug_multiturn_audio=not args.no_save_debug_multiturn_audio, + ) + return cfg_cls(**kwargs) + + +def _select_runner_cls(args): + if args.model_type == 'magpie': + if args.easy_magpie_inference_mode != 'single_turn': + raise ValueError('--easy_magpie_inference_mode is only supported with --model_type easy_magpie') + return MagpieInferenceRunner + if args.easy_magpie_inference_mode == 'multiturn_user_audio': + return EasyMagpieMultiturnUserAudioInferenceRunner + return EasyMagpieInferenceRunner def main(argv=None): """Entry point for TTS inference and evaluation.""" parser = create_argument_parser() args = parser.parse_args(argv) + if args.model_type == 'easy_magpie' and args.easy_magpie_inference_mode == 'multiturn_user_audio': + _configure_cuda_for_rank() + if args.batch_size > 1: + parser.error("--easy_magpie_inference_mode multiturn_user_audio requires --batch_size 1.") + if args.deterministic: seed_all(seed=9) @@ -689,7 +1190,7 @@ def main(argv=None): is_easy_magpie = args.model_type == 'easy_magpie' load_fn = load_easy_magpie_model if is_easy_magpie else load_magpie_model inference_config = _build_easy_magpie_config(args) if is_easy_magpie else _build_magpie_config(args) - runner_cls = EasyMagpieInferenceRunner if is_easy_magpie else MagpieInferenceRunner + runner_cls = _select_runner_cls(args) eval_config = EvaluationConfig( sv_model=args.sv_model, @@ -698,6 +1199,13 @@ def main(argv=None): with_utmosv2=not args.disable_utmosv2, with_fcd=not args.disable_fcd, codec_model_path=args.codecmodel_path if not args.disable_fcd else None, + with_emotion_metrics=args.with_emotion_metrics, + emotion_model_size=args.emotion_model_size, + emotion_embedding_type=args.emotion_embedding_type, + emotion_cache_dir=args.emotion_cache_dir, + strip_text_annotations_for_metrics=args.strip_text_annotations_for_metrics, + asr_batch_size=args.asr_batch_size, + eou_batch_size=args.eou_batch_size, ) cer, ssim = None, None diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 5af5f5d004d7..3be6fa1162a9 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -56,6 +56,7 @@ FixedBucketBatchSizeConstraint2D, MultimodalFixedBucketBatchSizeConstraint2D, MultimodalSamplingConstraint, + SpeakerFilter, TokenCountFilter, TokenPerSecondFilter, TokenPerTokenFilter, @@ -149,6 +150,8 @@ class LhotseDataLoadingConfig: # 2.3 Filters on CER and/or cosine speaker similarity of the context audio serving for TTS use cases. max_cer: float | None = float("inf") min_context_speaker_similarity: float | None = -1 + excluded_speaker_ids: Any = None + speaker_filter_fields: list[str] | None = None # 2.4 Filters on validation status. If the validation status is not "pass", the cut will be filtered out. keep: str = "pass" @@ -256,6 +259,21 @@ class LhotseDataLoadingConfig: slice_length: Optional[int] = None +def resolve_excluded_speaker_ids(excluded_speaker_ids): + if excluded_speaker_ids is None: + return None + + if isinstance(excluded_speaker_ids, str): + excluded_speaker_ids = OmegaConf.load(excluded_speaker_ids) + + excluded_speaker_ids = OmegaConf.to_container(excluded_speaker_ids, resolve=True) + + if isinstance(excluded_speaker_ids, dict): + excluded_speaker_ids = excluded_speaker_ids["excluded_speaker_ids"] + + return excluded_speaker_ids + + def determine_use_iterable_dataset(use_iterable_dataset: bool, config: DictConfig) -> bool: """Determine whether to use iterable dataset for a given configuration.""" assert not ( @@ -612,6 +630,12 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No # validation status filtering cuts = cuts.filter(ValidationStatusFilter(config.keep)) + # Exclude cuts that contain known test speakers. + cuts = cuts.filter( + SpeakerFilter( + resolve_excluded_speaker_ids(config.excluded_speaker_ids), speaker_fields=config.speaker_filter_fields + ) + ) # CER filtering, same as native NeMo dataloaders. cuts = cuts.filter(CERFilter(config.max_cer)) # Context speaker similarity filtering, same as native NeMo dataloaders. diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index 0c927058c234..e73039d252f9 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -308,6 +308,45 @@ def __call__(self, example) -> bool: return True +class SpeakerFilter: + """ + Callable, returns ``False`` if any supervision in a cut belongs to an excluded speaker. + Checks configured supervision attributes/custom fields. + Acts as a pass-through for objects of other type than Cut. + """ + + def __init__( + self, + excluded_speaker_ids: Sequence[str] | None = None, + speaker_fields: Sequence[str] = None, + ) -> None: + self.excluded_speaker_ids = set(ifnone(excluded_speaker_ids, ())) + self.enabled = len(self.excluded_speaker_ids) > 0 + if self.enabled and speaker_fields is None: + raise ValueError( + "SpeakerFilter requires speaker_fields when excluded_speaker_ids is set. " + "Example: speaker_filter_fields=[speaker_id]" + ) + self.speaker_fields = tuple(ifnone(speaker_fields, ())) + + def __call__(self, example) -> bool: + if not self.enabled or not isinstance(example, Cut): + return True + + excluded_speaker_ids = self.excluded_speaker_ids + + for supervision in example.supervisions: + for field in self.speaker_fields: + if supervision.has_custom(field): + speaker_id = getattr(supervision, field) + else: + speaker_id = getattr(supervision, field, None) + + if speaker_id in excluded_speaker_ids: + return False + return True + + class ContextSpeakerSimilarityFilter: """ Callable, returns ``True`` if a cut's context speaker similarity is greater than min_context_speaker_similarity and ``False`` otherwise. diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index e0d485bc0529..a557dd9291aa 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -302,46 +302,111 @@ def setup_parallel_expert_encoder(model: torch.nn.Module): def set_model_dict_for_partial_init( - pretrained_dict: Dict[str, torch.Tensor], model_dict: Dict[str, torch.Tensor] + pretrained_dict: Dict[str, torch.Tensor], + model_dict: Dict[str, torch.Tensor], + allow_partial_copy: bool = False, ) -> Dict[str, torch.Tensor]: """ Partially initialize a model's state dictionary with a pretrained state dictionary. + This function safely copies compatible layers from a pretrained model into a new model, - ignoring layers with mismatched shapes or missing keys. + ignoring layers with missing keys or incompatible shapes. + + By default, only tensors with exactly matching shapes are restored. - Steps: - 1. Remove layers from the pretrained dictionary if their shape does not match the target model. - 2. Keep only keys that exist in the target model. - 3. Update the model dictionary with the filtered pretrained weights. + If ``allow_partial_copy=True``, tensors whose shapes differ only in the first + dimension are partially restored by copying the overlapping rows from the + pretrained tensor into the target tensor. The remaining rows keep their + model-initialized values. This is useful when adding new vocabulary rows or + special tokens, e.g. adding an interruption token to an embedding table. Args: - pretrained_dict (Dict[str, torch.Tensor]): - The state dictionary of the pretrained model. - model_dict (Dict[str, torch.Tensor]): - The state dictionary of the target model to be partially initialized. + pretrained_dict: + State dictionary from the pretrained checkpoint. + + model_dict: + State dictionary of the target model. + + allow_partial_copy: + If True, allow partial row-wise restore for tensors where only + dimension 0 differs and all trailing dimensions match. Defaults to False. Returns: Dict[str, torch.Tensor]: - The updated model state dictionary with compatible layers loaded from the pretrained dictionary. + The updated model state dictionary with compatible pretrained weights loaded. Example: >>> model_dict = model.state_dict() >>> pretrained_dict = load_checkpoint("pretrained_model.ckpt") - >>> model_dict = set_model_dict_for_partial_init(pretrained_dict, model_dict) + >>> model_dict = set_model_dict_for_partial_init( + ... pretrained_dict, + ... model_dict, + ... allow_partial_copy=True, + ... ) >>> model.load_state_dict(model_dict) """ - # 1. Remove layers where pretrained shape differs from model shape - for k, v in list(pretrained_dict.items()): - if k in model_dict and hasattr(model_dict[k], "numel") and v.numel() != model_dict[k].numel(): - del pretrained_dict[k] - logging.info(f" | > Layer with shape mismatch in the model definition: {k}") - - # 2. Keep only keys that exist in the target model - pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} - - # 3. Update model dictionary with filtered pretrained layers - model_dict.update(pretrained_dict) - logging.info(f" | > {len(pretrained_dict)} / {len(model_dict)} layers are restored.") + restored_dict = {} + exact_restored = 0 + partial_restored = 0 + skipped_mismatch = 0 + + for key, pretrained_value in pretrained_dict.items(): + if key not in model_dict: + continue + + model_value = model_dict[key] + + if not hasattr(pretrained_value, "shape") or not hasattr(model_value, "shape"): + continue + + if pretrained_value.shape == model_value.shape: + restored_dict[key] = pretrained_value + exact_restored += 1 + continue + + can_partial_copy = ( + allow_partial_copy + and pretrained_value.ndim == model_value.ndim + and pretrained_value.ndim > 0 + and pretrained_value.shape[1:] == model_value.shape[1:] + ) + + if can_partial_copy: + merged_value = model_value.clone() + rows_to_copy = min(pretrained_value.shape[0], model_value.shape[0]) + + merged_value[:rows_to_copy].copy_( + pretrained_value[:rows_to_copy].to( + device=merged_value.device, + dtype=merged_value.dtype, + ) + ) + + restored_dict[key] = merged_value + partial_restored += 1 + + logging.info( + f" | > Partially restored resized tensor: {key} " + f"pretrained={tuple(pretrained_value.shape)} " + f"model={tuple(model_value.shape)} " + f"copied_rows={rows_to_copy}" + ) + continue + + skipped_mismatch += 1 + logging.info( + f" | > Layer with shape mismatch in the model definition: {key} " + f"pretrained={tuple(pretrained_value.shape)} " + f"model={tuple(model_value.shape)}" + ) + + model_dict.update(restored_dict) + + logging.info( + f" | > {len(restored_dict)} / {len(model_dict)} layers are restored " + f"({exact_restored} exact, {partial_restored} partial, " + f"{skipped_mismatch} skipped due to incompatible shape)." + ) return model_dict diff --git a/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py new file mode 100644 index 000000000000..708d8f2b4a20 --- /dev/null +++ b/nemo/collections/tts/data/text_to_speech_dataset_lhotse_multiturn.py @@ -0,0 +1,943 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random +import re +from typing import Dict, List, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.data +from lhotse import CutSet, Seconds, compute_num_frames +from lhotse.cut import Cut +from lhotse.dataset.collation import collate_audio, collate_matrices, collate_vectors +from lhotse.utils import ifnone +from omegaconf import DictConfig +from transformers import AutoTokenizer, T5Tokenizer + +from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer, IPABPETokenizer +from nemo.collections.speechlm2.parts.precision import fp32_precision +from nemo.collections.tts.parts.utils.tts_dataset_utils import ( + beta_binomial_prior_distribution, + normalize_volume, + stack_tensors, +) +from nemo.core.classes.common import safe_instantiate +from nemo.utils import logging + + +def setup_tokenizers(all_tokenizers_config, mode='train'): + tokenizers = [] + tokenizer_names = [] + for tokenizer_name in all_tokenizers_config: + tokenizer_config = all_tokenizers_config[tokenizer_name] + if tokenizer_config._target_ == 'AutoTokenizer': + tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.pretrained_model, trust_remote_code=True) + elif tokenizer_config._target_ == 'T5Tokenizer': + tokenizer = T5Tokenizer.from_pretrained(tokenizer_config.pretrained_model) + else: + text_tokenizer_kwargs = {} + if "g2p" in tokenizer_config: + text_tokenizer_kwargs["g2p"] = safe_instantiate(tokenizer_config.g2p) + tokenizer = safe_instantiate(tokenizer_config, **text_tokenizer_kwargs) + if mode == 'test' and hasattr(tokenizer, "set_phone_prob"): + tokenizer.set_phone_prob(1.0) + tokenizers.append(tokenizer) + tokenizer_names.append(tokenizer_name) + + aggregated_tokenizer = AggregatedTTSTokenizer(tokenizers, tokenizer_names) + return aggregated_tokenizer + + +def check_speaker_format(item: str): + pattern = r"\| Language:\w+ Dataset:[\w\d\W]+ Speaker:[\w\d\W]+ \|" + return bool(re.match(pattern, item)) + + +def _strip_timestamps( + text: str, _TIMESTAMP_PATTERN=re.compile(r"<\|\d+\|>"), _SPACE_PATTERN=re.compile(r"\s+") +) -> str: + if text is None: + return "" + text = _TIMESTAMP_PATTERN.sub("", text) + return _SPACE_PATTERN.sub(" ", text).strip() + + +def _count_words_ignoring_punctuation(text: str, _PUNCTUATION_PATTERN=re.compile(r"[^\w\s]+")) -> int: + text = _PUNCTUATION_PATTERN.sub("", _strip_timestamps(text)) + return len(text.split()) + + +def _get_supervision_ipa_text(supervision) -> str: + """Return IPA for a supervision, preferring top-level field over custom.""" + ipa_text = getattr(supervision, "ipa", None) + if isinstance(ipa_text, str) and ipa_text.strip(): + return ipa_text + + custom = getattr(supervision, "custom", None) + if isinstance(custom, dict): + custom_ipa = custom.get("ipa") + if isinstance(custom_ipa, str): + return custom_ipa + + return "" + + +class MagpieTTSLhotseMultiturnDataset(torch.utils.data.Dataset): + """ + A PyTorch Dataset for loading and processing Text-to-Speech data for + MagpieTTS models using Lhotse CutSets, specifically designed for datasets + with text or audio context. But either context can be optional. + + This dataset expects Lhotse Cut objects where each cut represents a + target utterance along with its preceding context. Context can be + audio (preferred) or text. It handles loading either pre-computed audio + codes or raw audio waveforms, applying volume normalization, and tokenizing + text transcripts. Context audio/codes are sliced or repeated to fit within + a specified duration range. Optionally, it loads 16kHz audio suitable for + speaker verification models and calculates alignment priors. + + Tokenizers (for target text and optional context text) are initialized lazily + within each dataloader worker process upon first access. + + Args: + sample_rate (int): Target sample rate for loading audio. Audio will be + resampled if necessary. + volume_norm (bool): If True, applies peak volume normalization to audio + waveforms. Defaults to True. + codec_model_samples_per_frame (int): The total downsampling factor of the + audio codec model used to generate codes. Used for padding audio + and calculating number of codec frames. + num_audio_codebooks (int): Number of codebooks used by the audio codec model. + Needed for creating dummy context codes if necessary. + prior_scaling_factor (Optional[float]): Scaling factor for the beta-binomial + alignment prior calculation. If None, priors are not computed. Defaults to None. + load_cached_codes_if_available (bool): If True, attempts to load pre-computed + audio codes from custom fields in the Lhotse Cut (e.g., 'codes_21fpsCausalDecoder', + 'context_codes_21fpsCausalDecoder'). Falls back to loading audio if codes + are not found. Defaults to True. + dataset_type (str): Specifies the mode ('train' or 'test'), mainly affecting + tokenizer settings like phoneme probability. Defaults to 'train'. + load_16khz_audio (bool): If True, loads 16kHz audio suitable for speaker + verification models. It prioritizes context audio ('context_audio' field) + if available, otherwise uses the target audio ('target_audio' field). + Defaults to True. + pad_context_text_to_max_duration (bool): If True and `use_text_conditioning_tokenizer` + is True, pads the tokenized context text to a length derived from + `context_duration_max`. Defaults to False. + context_duration_min (float): Minimum duration (in seconds) for the context + audio/codes. Context shorter than this will be repeated. Defaults to 3.0. + context_duration_max (float): Maximum duration (in seconds) for the context + audio/codes. Context longer than this will be sliced randomly. Defaults to 10.0. + use_text_conditioning_tokenizer (bool): If True, enables processing of context + text using a separate tokenizer (currently T5Tokenizer). Expects context text + in `cut.supervisions[0].custom['context_text']`. Defaults to False. + tokenizer_config (Optional[DictConfig]): Configuration for the text tokenizers. + Used for lazy initialization within workers. Must be provided if tokenizers + are not set externally. Defaults to None. + text_context_remapping: Dict defining mapping of multiple text contexts to a single text context. + text_context_remapping_prob: Probability of remapping the original text context to a remapped text context. + phoneme_turn_max_words_to_drop: Turns with this many words or fewer keep empty phoneme string. + """ + + def __init__( + self, + sample_rate: int, + volume_norm: bool = True, + codec_model_samples_per_frame: int = None, + codec_model_input_sample_rate: int = None, + frame_stacking_factor: int = None, + num_audio_codebooks: int = None, + prior_scaling_factor: float = None, + load_cached_codes_if_available: bool = True, + dataset_type: str = 'train', + load_16khz_audio: bool = False, + pad_context_text_to_max_duration: bool = False, + context_duration_min: float = 3.0, + context_duration_max: float = 10.0, + use_text_conditioning_tokenizer: bool = False, + text_conditioning_tokenizer_name: str = None, + tokenizer_config: DictConfig = None, + text_context_remapping: Dict[str, str] = None, + text_context_remapping_prob: float = 0.0, + phoneme_tokenizer_config: DictConfig = None, + ignore_phoneme_languages: List[str] = None, + add_language_to_context_text: bool = False, + source_sample_rate: int = 16000, + input_roles: List[str] = ["user", "User"], + output_roles: List[str] = ["assistant", "Assistant", "agent", "Agent"], + add_text_bos: bool = False, + phoneme_turn_dropout_batch_prob: float = 0.0, + phoneme_turn_dropout_turn_prob: float = 0.0, + phoneme_turn_max_words_to_drop: int = 2, + ): + super().__init__() + self.sample_rate = sample_rate + self.volume_norm = volume_norm + + self.codec_model_samples_per_frame = codec_model_samples_per_frame + self.num_audio_codebooks = num_audio_codebooks + + self.include_align_prior = prior_scaling_factor is not None + self.prior_scaling_factor = prior_scaling_factor + self.load_cached_codes_if_available = load_cached_codes_if_available + self.dataset_type = dataset_type + self.load_16khz_audio = load_16khz_audio + self.use_text_conditioning_tokenizer = use_text_conditioning_tokenizer + self.text_conditioning_tokenizer_name = text_conditioning_tokenizer_name + self.pad_context_text_to_max_duration = pad_context_text_to_max_duration + self.context_duration_min = context_duration_min + self.context_duration_max = context_duration_max + self.tokenizer_config = tokenizer_config + self.text_tokenizer = None + self.phoneme_tokenizer = None + self.text_context_remapping = text_context_remapping + self.text_context_remapping_prob = text_context_remapping_prob + self.phoneme_tokenizer_config = phoneme_tokenizer_config + self.ignore_phoneme_languages = ignore_phoneme_languages or [] + self.add_language_to_context_text = add_language_to_context_text + + self.source_sample_rate = source_sample_rate + self.input_roles = set(ifnone(input_roles, ["user"])) + self.output_roles = set(ifnone(output_roles, ["agent"])) + self.add_text_bos = add_text_bos + self.phoneme_turn_dropout_batch_prob = phoneme_turn_dropout_batch_prob + self.phoneme_turn_dropout_turn_prob = phoneme_turn_dropout_turn_prob + self.phoneme_turn_max_words_to_drop = phoneme_turn_max_words_to_drop + + self.frame_length = ( + self.codec_model_samples_per_frame / codec_model_input_sample_rate + ) * frame_stacking_factor + + def get_num_audio_samples_to_slice(self, duration, sample_rate): + num_codec_frames = int(duration * sample_rate / self.codec_model_samples_per_frame) + num_audio_samples = num_codec_frames * self.codec_model_samples_per_frame + return num_audio_samples + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]: + if self.text_tokenizer is None: + worker_info = torch.utils.data.get_worker_info() + worker_id = worker_info.id if worker_info is not None else 0 + logging.info(f"Worker {worker_id} initializing tokenizers...") + self.text_tokenizer = setup_tokenizers( + all_tokenizers_config=self.tokenizer_config, + mode=self.dataset_type, + ) + num_tokens = len(self.text_tokenizer.tokens) + self.bos_id = num_tokens + self.eos_id = num_tokens + 1 + self.cfg_unk_token_id = num_tokens + 2 + self.interruption_token_id = num_tokens + 3 + self.pad_id = self.text_tokenizer.pad + + if self.phoneme_tokenizer is None and self.phoneme_tokenizer_config is not None: + self.phoneme_tokenizer = safe_instantiate(self.phoneme_tokenizer_config) + + cuts = cuts.transform_text(_strip_timestamps) + for cut in cuts: + if str(getattr(cut, "task", "tts")).lower() == "tts": + for supervision in cut.supervisions: + supervision.speaker = "agent" + + batch_tokenizer_names = [] + for cut in cuts: + if cut.has_custom("tokenizer_names"): + batch_tokenizer_names.append(random.choice(cut.tokenizer_names)) + else: + batch_tokenizer_names.append("english_phoneme") + + def _align_codebooks(t): + C = t.shape[1] + if C < self.num_audio_codebooks: + return F.pad(t, (0, self.num_audio_codebooks - C)) + elif C > self.num_audio_codebooks: + return t[:, : self.num_audio_codebooks] + return t + + with fp32_precision(): + target_audio, target_audio_lens = collate_audio( + cuts.resample(self.sample_rate, recording_field="target_audio"), recording_field="target_audio" + ) + target_audio_list = [] + source_audio_list = [] + # normalize volume and apply audio the removal of user turn if needed + + for i, cut in enumerate(cuts): + + # Extract the raw, unpadded 1D numpy array for this specific cut + t_audio = target_audio[i, : target_audio_lens[i]].numpy() + # For tts/single-turn cuts, source audio should be zeros like target audio + # For multi-turn cuts, keep loading source audio from the cut recording. + if str(getattr(cut, "task", "tts")).lower() == "tts": + src_len = int(round(len(t_audio) / self.sample_rate * self.source_sample_rate)) + s_audio = np.zeros(src_len, dtype=t_audio.dtype) + else: + s_audio = cut.resample(self.source_sample_rate).load_audio().squeeze(0) + + # Apply volume norm locally (so we only normalize the stitched audio, saving math ops) + if self.volume_norm: + t_audio = normalize_volume(t_audio) + s_audio = normalize_volume(s_audio) + + target_audio_list.append(torch.from_numpy(t_audio)) + source_audio_list.append(torch.from_numpy(s_audio)) + + # 3. Re-pad the newly stitched arrays to the batch's new maximum length + target_audio = collate_vectors(target_audio_list, padding_value=0.0) + target_audio_lens = torch.tensor([len(a) for a in target_audio_list], dtype=torch.long) + + source_audio = collate_vectors(source_audio_list, padding_value=0.0) + source_audio_lens = torch.tensor([len(a) for a in source_audio_list], dtype=torch.long) + + user_audio_turn_splitted, user_audio_turn_splitted_lens, user_audio_turn_splitted_indices = ( + extract_turn_audio_channel( + cuts=cuts, + source_audio_list=source_audio_list, + source_sample_rate=self.source_sample_rate, + roles=self.input_roles, + volume_norm=self.volume_norm, + ) + ) + + target_text_tokens, target_token_lens = collate_token_channel( + cuts, + self.text_tokenizer, + self.frame_length, + roles=self.output_roles, + add_text_bos=self.add_text_bos, + tokenizer_names=batch_tokenizer_names, + pad_id=self.pad_id, + eos_id=self.eos_id, + bos_id=self.bos_id, + interruption_token_id=self.interruption_token_id, + ) + source_tokens, source_token_lens = collate_token_channel( + cuts, + self.text_tokenizer, + self.frame_length, + roles=self.input_roles, + add_text_bos=self.add_text_bos, + tokenizer_names=batch_tokenizer_names, + pad_id=self.pad_id, + eos_id=self.eos_id, + bos_id=self.bos_id, + interruption_token_id=self.interruption_token_id, + ) + + if self.phoneme_tokenizer is not None: + target_phoneme_tokens, target_phoneme_lens, phoneme_turn_dropout = collate_phoneme_channel( + cuts, + self.phoneme_tokenizer, + self.frame_length, + roles=self.output_roles, + ignore_phoneme_languages=self.ignore_phoneme_languages, + pad_id=self.phoneme_tokenizer.pad, + eos_id=self.phoneme_tokenizer.eos_token_id, + bos_id=self.phoneme_tokenizer.bos_token_id, + phoneme_turn_dropout_batch_prob=self.phoneme_turn_dropout_batch_prob, + phoneme_turn_dropout_turn_prob=self.phoneme_turn_dropout_turn_prob, + phoneme_turn_max_words_to_drop=self.phoneme_turn_max_words_to_drop, + apply_turn_dropout=self.dataset_type == 'train', + ) + else: + target_phoneme_tokens, target_phoneme_lens, phoneme_turn_dropout = None, None, None + + dataset_name_list = [] + audio_list_16khz = [] + audio_len_list_16khz = [] + prior_list = [] + + target_codes_list = [] + source_codes_list = [] + + context_audio_list = [] + context_audio_len_list = [] + context_audio_codes_list = [] + context_audio_codes_len_list = [] + context_text_tokens_list = [] + context_text_tokens_len_list = [] + context_has_text_context_list = [] + reward_list = [] + language_list = [] + + def _sample_context_duration_with_available_limit(available_duration_sec: float) -> float: + effective_duration_max = min(self.context_duration_max, available_duration_sec) + effective_duration_max = max(self.context_duration_min, effective_duration_max) + return random.uniform(self.context_duration_min, effective_duration_max) + + for i, cut in enumerate(cuts): + speaker_found = False + for sup in reversed(cut.supervisions): + if check_speaker_format(sup.speaker): + dataset_name = sup.speaker.strip().split()[2].split(":")[-1] + speaker_found = True + break + + if not speaker_found: + dataset_name = "unknown" + dataset_name_list.append(dataset_name) + + language = ( + cut.lang + if cut.has_custom("lang") + else next((sup.language for sup in reversed(cut.supervisions) if sup.has_custom("language")), "en") + ) + language_list.append(language) + + # Target and Source Codes + if self.load_cached_codes_if_available: + if cut.has_custom("target_codes"): + codes_array = cut.target_codes.load().astype(np.int32) + target_codes_list.append(torch.from_numpy(codes_array).T) + + if cut.has_custom("source_codes"): + source_codes_list.append(torch.from_numpy(cut.source_codes.load().astype(np.int32)).T) + + # Context Audio or Context Codes + if self.load_cached_codes_if_available and cut.has_custom("context_codes"): + context_audio_codes_array = cut.context_codes.load().astype(np.int32) + context_audio_codes = torch.from_numpy(context_audio_codes_array) + _available_context_duration = ( + context_audio_codes.shape[1] * self.codec_model_samples_per_frame / self.sample_rate + ) + _context_duration_to_slice = _sample_context_duration_with_available_limit(_available_context_duration) + _num_frames_to_slice = int( + _context_duration_to_slice * self.sample_rate / self.codec_model_samples_per_frame + ) + + if _num_frames_to_slice < context_audio_codes.shape[1]: + start_idx = random.randint(0, context_audio_codes.shape[1] - _num_frames_to_slice) + context_audio_codes = context_audio_codes[:, start_idx : start_idx + _num_frames_to_slice] + else: + _num_repeats = int(np.ceil(_num_frames_to_slice / context_audio_codes.shape[1])) + context_audio_codes = context_audio_codes.repeat(1, _num_repeats)[:, :_num_frames_to_slice] + + context_audio_codes = _align_codebooks(context_audio_codes.T) + context_audio_codes_list.append(context_audio_codes) + context_audio_codes_len_list.append(context_audio_codes.shape[0]) + + elif cut.has_custom("context_audio"): + with fp32_precision(): + context_audio_array = cut.context_audio.resample(self.sample_rate).load_audio().squeeze(0) + if self.volume_norm: + context_audio_array = normalize_volume(context_audio_array) + + _available_context_duration = len(context_audio_array) / self.sample_rate + _context_duration_to_slice = _sample_context_duration_with_available_limit(_available_context_duration) + _num_samples_to_slice = self.get_num_audio_samples_to_slice( + _context_duration_to_slice, self.sample_rate + ) + + if _num_samples_to_slice < len(context_audio_array): + start_idx = random.randint(0, len(context_audio_array) - _num_samples_to_slice) + context_audio_array = context_audio_array[start_idx : start_idx + _num_samples_to_slice] + else: + _num_repeats = int(np.ceil(_num_samples_to_slice / len(context_audio_array))) + context_audio_array = np.tile(context_audio_array, _num_repeats)[:_num_samples_to_slice] + + context_audio = torch.from_numpy(context_audio_array) + context_audio_list.append(context_audio) + context_audio_len_list.append(context_audio.shape[0]) + + else: + matching_supervisions = [s for s in cut.supervisions if s.speaker in self.output_roles] + + if self.load_cached_codes_if_available: + if len(matching_supervisions) > 0 and cut.has_custom("target_codes"): + sup = random.choice(matching_supervisions) + codes_array = cut.target_codes.load().astype(np.int32) + start_frame = int(max(0, sup.start) * self.sample_rate / self.codec_model_samples_per_frame) + num_frames = int(sup.duration * self.sample_rate / self.codec_model_samples_per_frame) + context_audio_codes = torch.from_numpy(codes_array)[ + :, start_frame : start_frame + num_frames + ].T + context_audio_codes = _align_codebooks(context_audio_codes) + else: + context_audio_codes = torch.zeros([0, self.num_audio_codebooks], dtype=torch.int32) + context_audio_codes_list.append(context_audio_codes) + context_audio_codes_len_list.append(context_audio_codes.shape[0]) + else: + if len(matching_supervisions) > 0: + sup = random.choice(matching_supervisions) + with fp32_precision(): + turn_cut = cut.resample(self.sample_rate, recording_field="target_audio").truncate( + offset=max(0, sup.start), duration=sup.duration + ) + context_audio_array = turn_cut.load_custom("target_audio").squeeze(0) + if self.volume_norm: + context_audio_array = normalize_volume(context_audio_array) + context_audio = torch.from_numpy(context_audio_array) + else: + context_audio = torch.zeros(self.codec_model_samples_per_frame, dtype=torch.float32) + + context_audio_list.append(context_audio) + context_audio_len_list.append(context_audio.shape[0]) + + # 16khz audio for SV + if self.load_16khz_audio: + with fp32_precision(): + if cut.has_custom("context_audio"): + audio_array_16khz = cut.context_audio.resample(16_000).load_audio().squeeze(0) + if self.volume_norm: + audio_array_16khz = normalize_volume(audio_array_16khz) + + _available_context_duration = len(audio_array_16khz) / 16_000 + _context_duration_to_slice = _sample_context_duration_with_available_limit( + _available_context_duration + ) + _num_samples_to_slice = int(_context_duration_to_slice * 16_000) + if _num_samples_to_slice < len(audio_array_16khz): + start_idx = random.randint(0, len(audio_array_16khz) - _num_samples_to_slice) + audio_array_16khz = audio_array_16khz[start_idx : start_idx + _num_samples_to_slice] + else: + matching_supervisions = [s for s in cut.supervisions if s.speaker in self.output_roles] + if len(matching_supervisions) > 0: + sup = random.choice(matching_supervisions) + turn_cut = cut.resample(16_000, recording_field="target_audio").truncate( + offset=max(0, sup.start), duration=sup.duration + ) + audio_array_16khz = turn_cut.load_custom("target_audio").squeeze(0) + else: + audio_array_16khz = np.zeros(16000, dtype=np.float32) + + if self.volume_norm: + audio_array_16khz = normalize_volume(audio_array_16khz) + + audio_16khz = torch.from_numpy(audio_array_16khz) + audio_list_16khz.append(audio_16khz) + audio_len_list_16khz.append(audio_16khz.shape[0]) + + # Context Text + if self.use_text_conditioning_tokenizer: + context_text = next( + (sup.context_text for sup in cut.supervisions if sup.has_custom("context_text")), None + ) + if context_text is not None: + if self.text_context_remapping is not None and context_text in self.text_context_remapping: + if self.dataset_type == 'train' and random.random() < self.text_context_remapping_prob: + context_text = self.text_context_remapping[context_text] + context_text_tokens = self.text_tokenizer.encode( + context_text, tokenizer_name=self.text_conditioning_tokenizer_name + ) + has_text_context = True + else: + context_text = ( + f"[{language.upper()}]" if self.add_language_to_context_text else "[NO TEXT CONTEXT]" + ) + context_text_tokens = self.text_tokenizer.encode( + context_text, tokenizer_name=self.text_conditioning_tokenizer_name + ) + has_text_context = False + + if self.pad_context_text_to_max_duration: + _required_len = ( + int(self.context_duration_max * self.sample_rate / self.codec_model_samples_per_frame) + 2 + ) + if len(context_text_tokens) < _required_len: + _pad_id = self.text_tokenizer.tokenizer_pad_ids[self.text_conditioning_tokenizer_name] + context_text_tokens += [_pad_id] * (_required_len - len(context_text_tokens)) + else: + context_text_tokens = context_text_tokens[:_required_len] + + context_text_tokens = torch.tensor(context_text_tokens, dtype=torch.int32) + context_text_tokens_list.append(context_text_tokens) + context_text_tokens_len_list.append(context_text_tokens.shape[0]) + context_has_text_context_list.append(has_text_context) + + # Align Prior (Note: Using full target length to preserve shape compatibility) + if self.include_align_prior: + tok_name = batch_tokenizer_names[i] + full_text_len = sum( + [ + len(self.text_tokenizer.encode(sup.text, tokenizer_name=tok_name)) + for sup in cut.supervisions + if sup.speaker in self.output_roles + ] + ) + + if self.add_text_bos: + full_text_len += 2 * sum([1 for sup in cut.supervisions if sup.speaker in self.output_roles]) + else: + # cont eos token + full_text_len += sum([1 for sup in cut.supervisions if sup.speaker in self.output_roles]) + + full_text_len = max(1, full_text_len) + + if self.load_cached_codes_if_available and cut.has_custom("target_codes"): + spec_len = int(target_codes_list[-1].shape[0]) + 1 + else: + spec_len = ( + int(target_audio_lens[i] / self.codec_model_samples_per_frame) + 2 + ) # +1 extra in case it was truncated + + align_prior = beta_binomial_prior_distribution( + phoneme_count=full_text_len, mel_count=spec_len, scaling_factor=self.prior_scaling_factor + ) + prior_list.append(torch.tensor(align_prior, dtype=torch.float32)) + + reward = next((sup.reward for sup in reversed(cut.supervisions) if sup.has_custom("reward")), None) + if reward is not None: + reward_list.append(reward) + + batch_dict = { + "sample_id": [str(cut.id) for cut in cuts], + "dataset_names": dataset_name_list, + "languages": language_list, + "source_audio": source_audio, + "source_audio_lens": source_audio_lens, + "audio": target_audio, + "audio_lens": target_audio_lens, + "source_tokens": source_tokens, + "source_token_lens": source_token_lens, + "text": target_text_tokens, + "text_lens": target_token_lens, + "raw_texts": [ + " ".join(s.text for s in cut.supervisions if s.speaker in self.output_roles) for cut in cuts + ], + "task": [getattr(cut, "task", "tts") for cut in cuts], + "user_audio_turn_splitted": user_audio_turn_splitted, + "user_audio_turn_splitted_lens": user_audio_turn_splitted_lens, + "user_audio_turn_splitted_indices": user_audio_turn_splitted_indices, + } + + if target_codes_list: + batch_dict["audio_codes"] = collate_matrices(target_codes_list, padding_value=0).transpose(1, 2) + batch_dict["audio_codes_lens"] = torch.IntTensor([c.shape[0] for c in target_codes_list]) + + if source_codes_list: + batch_dict["source_codes"] = collate_matrices(source_codes_list, padding_value=0).transpose(1, 2) + batch_dict["source_codes_lens"] = torch.IntTensor([c.shape[0] for c in source_codes_list]) + + if self.phoneme_tokenizer is not None: + batch_dict["phoneme_tokens"] = target_phoneme_tokens + batch_dict["phoneme_tokens_lens"] = target_phoneme_lens + batch_dict["phoneme_turn_dropout"] = phoneme_turn_dropout + + if len(audio_list_16khz) > 0: + batch_dict["audio_16khz"] = collate_vectors(audio_list_16khz, padding_value=0.0) + batch_dict["audio_lens_16khz"] = torch.IntTensor(audio_len_list_16khz) + + if len(context_audio_list) > 0: + batch_dict["context_audio"] = collate_vectors(context_audio_list, padding_value=0.0) + batch_dict["context_audio_lens"] = torch.IntTensor(context_audio_len_list) + + if len(context_audio_codes_list) > 0: + batch_dict["context_audio_codes"] = collate_matrices(context_audio_codes_list, padding_value=0).transpose( + 1, 2 + ) + batch_dict["context_audio_codes_lens"] = torch.IntTensor(context_audio_codes_len_list) + + if self.use_text_conditioning_tokenizer: + batch_dict['context_text_tokens'] = collate_vectors( + tensors=context_text_tokens_list, + padding_value=self.text_tokenizer.tokenizer_pad_ids[self.text_conditioning_tokenizer_name], + ) + batch_dict['context_text_tokens_lens'] = torch.IntTensor(context_text_tokens_len_list) + batch_dict['has_text_context'] = torch.BoolTensor(context_has_text_context_list) + + if self.include_align_prior: + spec_max_len = max([prior.shape[0] for prior in prior_list]) + text_max_len = max([prior.shape[1] for prior in prior_list]) + batch_dict["align_prior_matrix"] = stack_tensors(prior_list, max_lens=[text_max_len, spec_max_len]) + + if len(reward_list) > 0: + batch_dict['rewards'] = torch.FloatTensor(reward_list) + + agent_mask, agent_mask_lens = collate_speaker_mask_channel( + cuts, + self.frame_length, + self.output_roles, + ) + + user_mask, user_mask_lens = collate_speaker_mask_channel( + cuts, + self.frame_length, + self.input_roles, + ) + + batch_dict["agent_mask"] = agent_mask + batch_dict["agent_mask_lens"] = agent_mask_lens + batch_dict["user_mask"] = user_mask + batch_dict["user_mask_lens"] = user_mask_lens + return batch_dict + + +def collate_token_channel( + cuts: CutSet, + tokenizer, + frame_length: Seconds, + roles: set[str], + add_text_bos: bool = True, + tokenizer_names: list[str] = None, + pad_id: int = None, + eos_id: int = None, + bos_id: int = None, + interruption_token_id: int = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Build and collate token channels aligned to the audio frame grid.""" + tokens = [] + + for i, c in enumerate(cuts): + tok_name = tokenizer_names[i] if tokenizer_names else "english_phoneme" + tokens.append( + build_token_channel( + c, + tokenizer, + frame_length, + roles, + pad_id, + eos_id, + bos_id, + interruption_token_id, + add_text_bos, + tok_name, + ) + ) + token_lens = torch.tensor([len(tt) for tt in tokens]) + return collate_vectors(tokens, padding_value=pad_id), token_lens + + +def build_speaker_mask_channel( + cut: Cut, + frame_length: Seconds, + output_roles: set[str], +) -> torch.Tensor: + total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) + mask = torch.zeros(total, dtype=torch.float32) + + for supervision in cut.supervisions: + start = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) + end = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) + + if supervision.speaker in output_roles: + mask[start:end] = 1.0 + + return mask + + +def collate_speaker_mask_channel( + cuts: CutSet, + frame_length: Seconds, + output_roles: set[str], +): + masks = [build_speaker_mask_channel(cut, frame_length, output_roles) for cut in cuts] + mask_lens = torch.tensor([len(m) for m in masks]) + return collate_vectors(masks, padding_value=0.0), mask_lens + + +def build_token_channel( + cut: Cut, + tokenizer, + frame_length: Seconds, + roles: set[str], + pad_id: int = -1, + eos_id: int = -2, + bos_id: int = -3, + interruption_token_id: int = -4, + add_text_bos: bool = True, + tokenizer_name: str = "english_phoneme", +) -> torch.Tensor: + + total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) + tokens = torch.ones(total, dtype=torch.long) * pad_id + + for supervision in cut.supervisions: + if supervision.speaker in roles: + text = supervision.text + + if hasattr(tokenizer, "encode"): + try: + raw_ids = tokenizer.encode(text=text, tokenizer_name=tokenizer_name) + except TypeError: + raw_ids = tokenizer.encode(text) + else: + raw_ids = tokenizer.text_to_ids(text) + + if add_text_bos: + text_ids = torch.as_tensor([bos_id] + raw_ids + [eos_id]) + else: + text_ids = torch.as_tensor(raw_ids + [eos_id]) + + pos = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) + if pos >= len(tokens): + continue + + endpos = pos + len(text_ids) + if endpos > len(tokens): + text_ids = text_ids[: len(tokens) - pos] + tokens[pos : pos + len(text_ids)] = text_ids + + # add interruption token, used for add speech eos and interrupt the model + interruption_pos = compute_num_frames(supervision.end, frame_length, cut.sampling_rate) + if interruption_pos < len(tokens): + tokens[interruption_pos] = interruption_token_id + + return tokens + + +def extract_turn_audio_channel( + cuts: CutSet, + source_audio_list: list[torch.Tensor], + source_sample_rate: int, + roles: set[str], + volume_norm: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Packs role-specific speech turns into batch dimension. + + Returns: + user_audio: [N_turns, T_max] + user_audio_lens: [N_turns] + user_audio_indices: [N_turns, 3] + columns = [original_batch_idx, start_sample, end_sample] + in the original source_audio timeline. + """ + turn_audio = [] + turn_lens = [] + turn_indices = [] + + for batch_idx, cut in enumerate(cuts): + audio = source_audio_list[batch_idx] + audio_len = audio.shape[0] + + for sup in cut.supervisions: + if sup.speaker not in roles: + continue + + start = int(round(max(0.0, sup.start) * source_sample_rate)) + end = int(round(max(0.0, sup.end) * source_sample_rate)) + + start = min(max(start, 0), audio_len) + end = min(max(end, 0), audio_len) + + if end <= start: + continue + + chunk = audio[start:end] + + # normalize turn level audio + if volume_norm: + chunk = normalize_volume(chunk.numpy()) + chunk = torch.from_numpy(chunk) + + turn_audio.append(chunk) + turn_lens.append(chunk.shape[0]) + turn_indices.append([batch_idx, start, end]) + + if len(turn_audio) == 0: + return None, None, None + else: + user_audio = collate_vectors(turn_audio, padding_value=0.0) + user_audio_lens = torch.tensor(turn_lens, dtype=torch.long) + user_audio_indices = torch.tensor(turn_indices, dtype=torch.long) + + return user_audio, user_audio_lens, user_audio_indices + + +def collate_phoneme_channel( + cuts: CutSet, + phoneme_tokenizer, + frame_length: Seconds, + roles: set[str], + ignore_phoneme_languages: list[str], + pad_id: int = -1, + eos_id: int = -2, + bos_id: int = -3, + phoneme_turn_dropout_batch_prob: float = 0.0, + phoneme_turn_dropout_turn_prob: float = 0.0, + phoneme_turn_max_words_to_drop: int = 2, + apply_turn_dropout: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tokens = [] + dropout_flags = [] + for i, c in enumerate(cuts): + token, dropout_applied = build_phoneme_channel( + c, + phoneme_tokenizer, + frame_length, + roles, + ignore_phoneme_languages, + pad_id, + eos_id, + bos_id, + phoneme_turn_dropout_batch_prob=phoneme_turn_dropout_batch_prob, + phoneme_turn_dropout_turn_prob=phoneme_turn_dropout_turn_prob, + phoneme_turn_max_words_to_drop=phoneme_turn_max_words_to_drop, + apply_turn_dropout=apply_turn_dropout, + ) + tokens.append(token) + dropout_flags.append(dropout_applied) + token_lens = torch.tensor([len(tt) for tt in tokens]) + return collate_vectors(tokens, padding_value=pad_id), token_lens, torch.tensor(dropout_flags, dtype=torch.bool) + + +def build_phoneme_channel( + cut: Cut, + phoneme_tokenizer, + frame_length: Seconds, + roles: set[str], + ignore_phoneme_languages: list[str], + pad_id: int = -1, + eos_id: int = -2, + bos_id: int = -3, + phoneme_turn_dropout_batch_prob: float = 0.0, + phoneme_turn_dropout_turn_prob: float = 0.0, + phoneme_turn_max_words_to_drop: int = 2, + apply_turn_dropout: bool = False, +) -> tuple[torch.Tensor, bool]: + language = ( + cut.lang + if cut.has_custom("lang") + else next((sup.language for sup in cut.supervisions if sup.has_custom("language")), "en") + ) + + total = compute_num_frames(cut.duration, frame_length, cut.sampling_rate) + tokens = torch.ones(total, dtype=torch.long) * pad_id + dropout_applied = False + apply_dropout = ( + apply_turn_dropout + and phoneme_turn_dropout_batch_prob > 0.0 + and phoneme_turn_dropout_turn_prob > 0.0 + and random.random() < phoneme_turn_dropout_batch_prob + ) + + for supervision in cut.supervisions: + if supervision.speaker in roles: + if apply_dropout and random.random() < phoneme_turn_dropout_turn_prob: + dropout_applied = True + continue + + if isinstance(phoneme_tokenizer, IPABPETokenizer): + ipa_text = _get_supervision_ipa_text(supervision) + if language in ignore_phoneme_languages: + ipa_text = "" + else: + ipa_text = supervision.text + + if _count_words_ignoring_punctuation(supervision.text) <= phoneme_turn_max_words_to_drop: + ipa_text = "" + + phoneme_ids = phoneme_tokenizer.encode(ipa_text) + phoneme_ids = [bos_id] + phoneme_ids + [eos_id] + phoneme_ids = torch.as_tensor(phoneme_ids, dtype=torch.long) + pos = compute_num_frames(supervision.start, frame_length, cut.sampling_rate) + if pos >= len(tokens): + continue + + endpos = pos + len(phoneme_ids) + if endpos > len(tokens): + phoneme_ids = phoneme_ids[: len(tokens) - pos] + tokens[pos : pos + len(phoneme_ids)] = phoneme_ids + + return tokens, dropout_applied diff --git a/nemo/collections/tts/metrics/emotion_encoder.py b/nemo/collections/tts/metrics/emotion_encoder.py new file mode 100644 index 000000000000..019d6b14d6cb --- /dev/null +++ b/nemo/collections/tts/metrics/emotion_encoder.py @@ -0,0 +1,1293 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Lightweight LAION Empathic Insight Voice interface. + +This script provides a Hugging Face-style Python class for LAION's +Empathic Insight Voice models without using ModelScope. + +It supports: + + 1. Restoring the Whisper encoder from Hugging Face. + 2. Restoring LAION classifier MLP heads from Hugging Face .pth files. + 3. Extracting the full Whisper encoder embedding: + [B, 1500, 768] + 4. Extracting one classifier projection embedding: + Small: [B, 64] + Large: [B, 128] + 5. Extracting an SV-style emotion similarity embedding by concatenating + multiple classifier projection embeddings: + Small, 40 labels: [B, 40 * 64] = [B, 2560] + Large, 40 labels: [B, 40 * 128] = [B, 5120] + 6. Extracting an official-style emotion score vector: + [B, num_labels] + 7. Computing ranked emotion predictions and cosine similarity. + +Recommended for emotion similarity: + + model = EmpathicInsightVoice.from_pretrained(size="small", device="cuda") + emb = model.extract_emotion_embedding("audio.wav", embedding_type="head_concat") + sim = model.emotion_similarity("a.wav", "b.wav", embedding_type="head_concat") + +Notes: + + - The official model is a collection of independent expert heads. Each head + predicts one emotion or attribute score. + - The "head_concat" embedding is an engineering adaptation for + speaker-verification-style similarity. It concatenates the learned + projection outputs from multiple classifier heads. + - The "score_vector" embedding is closer to the documented inference output: + a vector of raw emotion intensity scores. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Any, Optional, Sequence, Union + +import librosa +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from huggingface_hub import hf_hub_download +from transformers import WhisperForConditionalGeneration, WhisperProcessor + + +# ============================================================================= +# Label mapping +# ============================================================================= + +# MIRRORING-style 12 emotions (https://bench.theliva.ai/legacy/mirroring.html): +# Amusement, Anger, Elation, Impatience, Surprise, +# Emotional Numbness, Contemplation, Disappointment, +# Confusion, Pride, Affection, Sadness. +# +# The public-facing labels below are Python-friendly. +# Some labels map to LAION's original longer checkpoint names: +# impatience -> model_Impatience_and_Irritability_best.pth +# surprise -> model_Astonishment_Surprise_best.pth + +LAION_LABEL_TO_FILENAME: dict[str, str] = { + "amusement": "model_Amusement_best.pth", + "anger": "model_Anger_best.pth", + "elation": "model_Elation_best.pth", + "impatience": "model_Impatience_and_Irritability_best.pth", + "surprise": "model_Astonishment_Surprise_best.pth", + "emotional_numbness": "model_Emotional_Numbness_best.pth", + "contemplation": "model_Contemplation_best.pth", + "disappointment": "model_Disappointment_best.pth", + "confusion": "model_Confusion_best.pth", + "pride": "model_Pride_best.pth", + "affection": "model_Affection_best.pth", + "sadness": "model_Sadness_best.pth", +} + + +PRIMARY_EMOTION_LABELS: list[str] = [ + "amusement", + "anger", + "elation", + "impatience", + "surprise", + "emotional_numbness", + "contemplation", + "disappointment", + "confusion", + "pride", + "affection", + "sadness", +] + + +AUXILIARY_SIMILARITY_LABELS: list[str] = [] + +# ============================================================================= +# Model architecture specs +# ============================================================================= + +MODEL_SPECS: dict[str, dict[str, Any]] = { + "small": { + "repo_id": "laion/Empathic-Insight-Voice-Small", + "whisper_model_id": "mkrausio/EmoWhisper-AnS-Small-v0.1", + "sample_rate": 16000, + "max_audio_seconds": 30.0, + "seq_len": 1500, + "embed_dim": 768, + "projection_dim": 64, + "mlp_hidden_dims": [64, 32, 16], + "mlp_dropouts": [0.0, 0.1, 0.1, 0.1], + }, + "large": { + "repo_id": "laion/Empathic-Insight-Voice-Large", + "whisper_model_id": "mkrausio/EmoWhisper-AnS-Small-v0.1", + "sample_rate": 16000, + "max_audio_seconds": 30.0, + "seq_len": 1500, + "embed_dim": 768, + "projection_dim": 128, + "mlp_hidden_dims": [128, 64, 32], + "mlp_dropouts": [0.0, 0.1, 0.1, 0.1], + }, +} + + +# ============================================================================= +# MLP head +# ============================================================================= + + +class FullEmbeddingMLP(nn.Module): + """Classifier head used by Empathic Insight Voice. + + The model receives a full Whisper encoder sequence embedding: + + [batch, seq_len, embed_dim] + + For Empathic Insight Voice this is normally: + + [batch, 1500, 768] + + It then performs: + + flatten -> projection -> MLP -> scalar score + + The projection output is useful as an SV-style emotion embedding: + + Small: [batch, 64] + Large: [batch, 128] + + Each restored classifier head has its own projection layer. Therefore, + "anger" projection, "sadness" projection, and "arousal" projection are + all different learned spaces. + """ + + def __init__( + self, + seq_len: int, + embed_dim: int, + projection_dim: int, + mlp_hidden_dims: Sequence[int], + mlp_dropout_rates: Sequence[float], + ) -> None: + super().__init__() + + if len(mlp_dropout_rates) != len(mlp_hidden_dims) + 1: + raise ValueError( + "Dropout rates length error. " + f"Expected {len(mlp_hidden_dims) + 1}, " + f"got {len(mlp_dropout_rates)}." + ) + + self.seq_len = seq_len + self.embed_dim = embed_dim + self.projection_dim = projection_dim + + self.flatten = nn.Flatten() + self.proj = nn.Linear(seq_len * embed_dim, projection_dim) + + layers: list[nn.Module] = [ + nn.ReLU(), + nn.Dropout(mlp_dropout_rates[0]), + ] + + current_dim = projection_dim + for i, hidden_dim in enumerate(mlp_hidden_dims): + layers.extend( + [ + nn.Linear(current_dim, hidden_dim), + nn.ReLU(), + nn.Dropout(mlp_dropout_rates[i + 1]), + ] + ) + current_dim = hidden_dim + + layers.append(nn.Linear(current_dim, 1)) + self.mlp = nn.Sequential(*layers) + + def extract_projected_embedding(self, x: torch.Tensor) -> torch.Tensor: + """Return the classifier projection embedding before the MLP. + + Args: + x: + Whisper embedding with shape [B, seq_len, embed_dim], or + [B, 1, seq_len, embed_dim]. + + Returns: + Projected embedding with shape [B, projection_dim]. + """ + if x.ndim == 4 and x.shape[1] == 1: + x = x.squeeze(1) + + if x.ndim != 3: + raise ValueError(f"Expected x with shape [B, T, C], got shape {tuple(x.shape)}.") + + return self.proj(self.flatten(x)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return one scalar score per input example.""" + projected = self.extract_projected_embedding(x) + return self.mlp(projected) + + +# ============================================================================= +# Main class +# ============================================================================= + + +class EmpathicInsightVoice(nn.Module): + """Lightweight Hugging Face-style Empathic Insight Voice class. + + This class intentionally does not inherit from NeMo ModelPT. It behaves + like a normal PyTorch/Hugging Face utility model. + + Main methods: + + - extract_whisper_embedding(audio_path) + Returns [1, 1500, 768]. + + - extract_classifier_projection(audio_path, label) + Returns one head-specific projection: + Small: [1, 64] + Large: [1, 128] + + - extract_emotion_embedding(audio_path, embedding_type="head_concat") + Recommended SV-style emotion similarity embedding. + + - predict_emotions_from_embedding(embedding) + Returns raw scores and ranked softmax-like top emotions. + + - compute(audio_path) + Returns emotion predictions and optionally an embedding. + + - emotion_similarity(audio_path_a, audio_path_b) + Returns cosine similarity between extracted emotion embeddings. + """ + + def __init__( + self, + size: str = "small", + device: Union[str, torch.device] = "cuda", + mlp_device: Optional[Union[str, torch.device]] = None, + cache_dir: Optional[Union[str, Path]] = None, + cache_classifiers: bool = True, + load_all_classifiers: bool = False, + top_k_emotions: int = 5, + torch_dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ) -> None: + super().__init__() + + if size not in MODEL_SPECS: + raise ValueError(f"Unsupported size={size!r}. Expected one of {sorted(MODEL_SPECS)}.") + + self.size = size + self.spec = MODEL_SPECS[size] + self.cache_classifiers = cache_classifiers + self.top_k_emotions = top_k_emotions + self.cache_dir = Path(cache_dir) if cache_dir is not None else None + + requested_device = torch.device(device) + if requested_device.type == "cuda" and not torch.cuda.is_available(): + requested_device = torch.device("cpu") + + self.device = requested_device + self.mlp_device = torch.device(mlp_device) if mlp_device is not None else self.device + + self.sample_rate = int(self.spec["sample_rate"]) + self.max_audio_seconds = float(self.spec["max_audio_seconds"]) + + # Load Whisper processor and encoder model. + self.processor = WhisperProcessor.from_pretrained( + self.spec["whisper_model_id"], + cache_dir=str(self.cache_dir) if self.cache_dir is not None else None, + trust_remote_code=trust_remote_code, + ) + + whisper_kwargs: dict[str, Any] = { + "cache_dir": str(self.cache_dir) if self.cache_dir is not None else None, + "trust_remote_code": trust_remote_code, + } + if torch_dtype is not None: + whisper_kwargs["torch_dtype"] = torch_dtype + + self.whisper_model = WhisperForConditionalGeneration.from_pretrained( + self.spec["whisper_model_id"], + **whisper_kwargs, + ).to(self.device) + + self.whisper_model.eval() + + # Restored MLP heads are stored here. + # + # ModuleDict keys must be sanitized because labels can contain punctuation. + self.classifiers = nn.ModuleDict() + + if load_all_classifiers: + self.load_classifiers() + + @classmethod + def from_pretrained( + cls, + size: str = "small", + **kwargs: Any, + ) -> "EmpathicInsightVoice": + """Construct the model using Hugging Face checkpoints. + + Example: + model = EmpathicInsightVoice.from_pretrained( + size="small", + device="cuda", + mlp_device="cuda", + ) + """ + return cls(size=size, **kwargs) + + @property + def repo_id(self) -> str: + """Hugging Face repo ID for the selected model size.""" + return str(self.spec["repo_id"]) + + @property + def available_labels(self) -> list[str]: + """All labels known to this script.""" + return list(LAION_LABEL_TO_FILENAME.keys()) + + @property + def projection_dim(self) -> int: + """Classifier projection dimension for the selected model size.""" + return int(self.spec["projection_dim"]) + + # ------------------------------------------------------------------------- + # Audio and Whisper embedding extraction + # ------------------------------------------------------------------------- + + @torch.no_grad() + def extract_whisper_embedding(self, audio_path: Union[str, Path]) -> torch.Tensor: + """Extract the full Whisper encoder embedding from an audio file. + + Args: + audio_path: + Path to an audio file readable by librosa. + + Returns: + Tensor with shape [1, 1500, 768]. + """ + waveform, _ = librosa.load(str(audio_path), sr=self.sample_rate, mono=True) + waveform = self._prepare_waveform(waveform) + return self.extract_whisper_embedding_from_waveform(waveform) + + @torch.no_grad() + def extract_whisper_embedding_from_waveform( + self, + waveform: np.ndarray, + ) -> torch.Tensor: + """Extract the full Whisper encoder embedding from a waveform. + + Args: + waveform: + Mono waveform at self.sample_rate. + + Returns: + Tensor with shape [1, 1500, 768]. + """ + waveform = self._prepare_waveform(waveform) + + input_features = self.processor( + waveform, + sampling_rate=self.sample_rate, + return_tensors="pt", + ).input_features + + input_features = input_features.to(self.device) + input_features = input_features.to(self.whisper_model.dtype) + + encoder_outputs = self.whisper_model.get_encoder()(input_features=input_features) + + embedding = encoder_outputs.last_hidden_state + embedding = self._pad_or_trim_embedding(embedding) + + return embedding + + # ------------------------------------------------------------------------- + # Classifier projection extraction + # ------------------------------------------------------------------------- + + @torch.no_grad() + def extract_classifier_projection( + self, + audio_path: Union[str, Path], + label: str, + normalize: bool = True, + ) -> torch.Tensor: + """Extract one head-specific classifier projection embedding. + + This is the closest equivalent to extracting an x-vector-like embedding + from one specific emotion classifier head. + + Flow: + audio -> Whisper encoder -> label-specific classifier projection + + Args: + audio_path: + Input audio path. + label: + Label whose classifier projection should be used, for example: + "anger", "sadness", "arousal". + normalize: + If True, apply L2 normalization. + + Returns: + Small: [1, 64] + Large: [1, 128] + """ + whisper_embedding = self.extract_whisper_embedding(audio_path) + return self.extract_classifier_projection_from_whisper_embedding( + whisper_embedding=whisper_embedding, + label=label, + normalize=normalize, + ) + + @torch.no_grad() + def extract_classifier_projection_from_whisper_embedding( + self, + whisper_embedding: torch.Tensor, + label: str, + normalize: bool = True, + ) -> torch.Tensor: + """Extract one classifier projection from an existing Whisper embedding.""" + classifier = self._get_classifier(label) + param = next(classifier.parameters()) + + working_embedding = whisper_embedding.to(device=param.device, dtype=param.dtype) + projected = classifier.extract_projected_embedding(working_embedding) + + projected = projected.float() + if normalize: + projected = F.normalize(projected, p=2, dim=-1) + + return projected + + # ------------------------------------------------------------------------- + # Emotion similarity embeddings + # ------------------------------------------------------------------------- + + @torch.no_grad() + def extract_emotion_embedding( + self, + audio_path: Union[str, Path], + labels: Optional[Sequence[str]] = None, + embedding_type: str = "head_concat", + normalize: bool = True, + include_auxiliary: bool = False, + ) -> torch.Tensor: + """Extract a fixed-dimensional emotion embedding. + + Recommended for SV-style emotion similarity: + embedding_type="head_concat" + + Supported embedding types: + + 1. "head_concat" + Concatenate the projection output of each selected classifier head. + + Small, 40 primary labels: + [1, 40 * 64] = [1, 2560] + + Large, 40 primary labels: + [1, 40 * 128] = [1, 5120] + + 2. "head_mean" + Average the projection outputs across selected heads. + + Small: + [1, 64] + + Large: + [1, 128] + + 3. "score_vector" + Use raw scalar outputs from the selected classifier heads. + + Shape: + [1, num_labels] + + This is closest to the official annotation output. + + Args: + audio_path: + Input audio path. + labels: + Labels to use. If None, PRIMARY_EMOTION_LABELS are used. + embedding_type: + "head_concat", "head_mean", or "score_vector". + normalize: + If True, apply L2 normalization to the final embedding. + include_auxiliary: + If labels is None, append AUXILIARY_SIMILARITY_LABELS. + + Returns: + torch.Tensor fixed-dimensional emotion embedding. + """ + whisper_embedding = self.extract_whisper_embedding(audio_path) + labels_to_run = self._default_similarity_labels( + labels=labels, + include_auxiliary=include_auxiliary, + ) + + return self.extract_emotion_embedding_from_whisper_embedding( + whisper_embedding=whisper_embedding, + labels=labels_to_run, + embedding_type=embedding_type, + normalize=normalize, + ) + + @torch.no_grad() + def extract_emotion_embedding_from_whisper_embedding( + self, + whisper_embedding: torch.Tensor, + labels: Optional[Sequence[str]] = None, + embedding_type: str = "head_concat", + normalize: bool = True, + include_auxiliary: bool = False, + ) -> torch.Tensor: + """Extract a fixed-dimensional emotion embedding from Whisper features.""" + labels_to_run = self._default_similarity_labels( + labels=labels, + include_auxiliary=include_auxiliary, + ) + + if embedding_type == "score_vector": + prediction = self.predict_emotions_from_embedding( + embedding=whisper_embedding, + labels=labels_to_run, + return_raw_scores=True, + rank_scores=False, + ) + raw_scores = prediction["raw_scores"] + + output = torch.tensor( + [raw_scores[label] for label in labels_to_run], + dtype=torch.float32, + ).unsqueeze(0) + + elif embedding_type in {"head_concat", "head_mean"}: + projected_embeddings: list[torch.Tensor] = [] + + for label in labels_to_run: + classifier = self._get_classifier(label) + param = next(classifier.parameters()) + + working_embedding = whisper_embedding.to( + device=param.device, + dtype=param.dtype, + ) + + projected = classifier.extract_projected_embedding(working_embedding) + projected_embeddings.append(projected.float().cpu()) + + if embedding_type == "head_concat": + output = torch.cat(projected_embeddings, dim=-1) + else: + output = torch.stack(projected_embeddings, dim=0).mean(dim=0) + + else: + raise ValueError( + f"Unsupported embedding_type={embedding_type!r}. " + "Expected 'head_concat', 'head_mean', or 'score_vector'." + ) + + if normalize: + output = F.normalize(output, p=2, dim=-1) + + return output + + @torch.no_grad() + def emotion_similarity( + self, + audio_path_a: Union[str, Path], + audio_path_b: Union[str, Path], + labels: Optional[Sequence[str]] = None, + embedding_type: str = "head_concat", + include_auxiliary: bool = False, + ) -> float: + """Compute cosine similarity between two audios in emotion space. + + Args: + audio_path_a: + First audio path. + audio_path_b: + Second audio path. + labels: + Optional label subset. + embedding_type: + "head_concat", "head_mean", or "score_vector". + include_auxiliary: + If labels is None, append auxiliary labels. + + Returns: + Cosine similarity as a Python float. + """ + emb_a = self.extract_emotion_embedding( + audio_path=audio_path_a, + labels=labels, + embedding_type=embedding_type, + normalize=True, + include_auxiliary=include_auxiliary, + ) + emb_b = self.extract_emotion_embedding( + audio_path=audio_path_b, + labels=labels, + embedding_type=embedding_type, + normalize=True, + include_auxiliary=include_auxiliary, + ) + + return float(F.cosine_similarity(emb_a, emb_b, dim=-1).item()) + + # ------------------------------------------------------------------------- + # Prediction + # ------------------------------------------------------------------------- + @torch.no_grad() + def compare_emotion_pair( + self, + audio_path_a: Union[str, Path], + audio_path_b: Union[str, Path], + labels: Optional[Sequence[str]] = None, + embedding_type: str = "score_vector", + ) -> dict[str, Any]: + """Compare two audio files using the 12-emotion set. + + This method does not perform corpus-level ranking. It only returns: + + - top emotion for audio A + - top emotion for audio B + - matched emotion label if both top emotions match + - emotion similarity + + Args: + audio_path_a: + First audio file. + audio_path_b: + Second audio file. + labels: + Optional subset of labels. Defaults to PRIMARY_EMOTION_LABELS, + which is the 12-emotion set. + embedding_type: + Similarity representation: + - "score_vector": cosine over raw 12-emotion score vector. + - "head_concat": cosine over concatenated classifier projections. + - "head_mean": cosine over averaged classifier projections. + + For MIRRORING-style emotion-vector similarity, use "score_vector". + + Returns: + { + "audio_path_a": str, + "audio_path_b": str, + "audio_a_top_emotion": str | None, + "audio_b_top_emotion": str | None, + "top_emotion_match": bool , + "emotion_similarity": float, + "audio_a_raw_scores": dict[str, float], + "audio_b_raw_scores": dict[str, float], + } + """ + labels_to_run = self._validate_labels(labels or PRIMARY_EMOTION_LABELS) + + result_a = self.compute( + audio_path=audio_path_a, + labels=labels_to_run, + return_embedding=False, + return_raw_scores=True, + ) + + result_b = self.compute( + audio_path=audio_path_b, + labels=labels_to_run, + return_embedding=False, + return_raw_scores=True, + ) + + top_a = result_a["top_emotion"] + top_b = result_b["top_emotion"] + + similarity = self.emotion_similarity( + audio_path_a=audio_path_a, + audio_path_b=audio_path_b, + labels=labels_to_run, + embedding_type=embedding_type, + include_auxiliary=False, + ) + + return { + "audio_path_a": str(audio_path_a), + "audio_path_b": str(audio_path_b), + "audio_a_top_emotion": top_a, + "audio_b_top_emotion": top_b, + "top_emotion_match": top_a is not None and top_a == top_b, + "emotion_similarity": similarity, + "audio_a_raw_scores": result_a["raw_scores"], + "audio_b_raw_scores": result_b["raw_scores"], + } + + @torch.no_grad() + def compute( + self, + audio_path: Union[str, Path], + labels: Optional[Sequence[str]] = None, + return_embedding: bool = True, + embedding_type: str = "head_concat", + return_raw_scores: bool = True, + include_auxiliary_for_embedding: bool = False, + ) -> dict[str, Any]: + """Compute emotion predictions and optionally an embedding. + + Args: + audio_path: + Input audio file. + labels: + Prediction labels. If None, all known labels are attempted. + For the official 40-emotion profile, pass PRIMARY_EMOTION_LABELS. + return_embedding: + If True, return an emotion embedding. + embedding_type: + Embedding type to return: + "head_concat", "head_mean", or "score_vector". + return_raw_scores: + If True, return raw classifier outputs. + include_auxiliary_for_embedding: + If True and return_embedding=True, include auxiliary labels in the + returned embedding. + + Returns: + { + "audio_path": str, + "model_size": "small" | "large", + "top_emotion": str | None, + "emotions": { + label: {"score": float, "rank": int} + }, + "raw_scores": { + label: float + }, + "embedding": torch.Tensor, + "embedding_type": str + } + """ + whisper_embedding = self.extract_whisper_embedding(audio_path) + + prediction = self.predict_emotions_from_embedding( + embedding=whisper_embedding, + labels=labels, + return_raw_scores=return_raw_scores, + rank_scores=True, + ) + + top_emotion = None + if prediction["emotions"]: + top_emotion = next(iter(prediction["emotions"])) + + output: dict[str, Any] = { + "audio_path": str(audio_path), + "model_size": self.size, + "top_emotion": top_emotion, + "emotions": prediction["emotions"], + } + + if return_raw_scores: + output["raw_scores"] = prediction["raw_scores"] + + if return_embedding: + output["embedding"] = self.extract_emotion_embedding_from_whisper_embedding( + whisper_embedding=whisper_embedding, + labels=None, + embedding_type=embedding_type, + normalize=True, + include_auxiliary=include_auxiliary_for_embedding, + ) + output["embedding_type"] = embedding_type + + return output + + @torch.no_grad() + def predict_emotions_from_embedding( + self, + embedding: torch.Tensor, + labels: Optional[Sequence[str]] = None, + return_raw_scores: bool = True, + rank_scores: bool = True, + ) -> dict[str, Any]: + """Predict emotion or attribute scores from a Whisper embedding. + + Args: + embedding: + Whisper encoder embedding, normally [1, 1500, 768]. + labels: + Labels to evaluate. If None, all known labels are attempted. + return_raw_scores: + If True, include raw classifier scores. + rank_scores: + If True, softmax and rank scores into top-k emotions. + + Returns: + Dict containing: + "emotions": ranked top-k scores if rank_scores=True + "raw_scores": raw scalar outputs if return_raw_scores=True + """ + labels_to_run = self._validate_labels(labels) + + raw_scores: dict[str, float] = {} + + for label in labels_to_run: + classifier = self._get_classifier(label) + param = next(classifier.parameters()) + + working_embedding = embedding.to(device=param.device, dtype=param.dtype) + score = classifier(working_embedding).detach().cpu().item() + raw_scores[label] = float(score) + + if not self.cache_classifiers: + cache_key = self._cache_key(label) + if cache_key in self.classifiers: + del self.classifiers[cache_key] + + output: dict[str, Any] = {} + + if rank_scores: + output["emotions"] = self._softmax_and_rank(raw_scores) + else: + output["emotions"] = {} + + if return_raw_scores: + output["raw_scores"] = raw_scores + + return output + + # ------------------------------------------------------------------------- + # Classifier loading + # ------------------------------------------------------------------------- + + def load_classifiers( + self, + labels: Optional[Sequence[str]] = None, + ) -> None: + """Eagerly download and restore classifier MLPs. + + By default the class lazy-loads heads when needed. This method is useful + when you want to pre-load selected heads before repeated inference. + """ + labels_to_load = self._validate_labels(labels) + for label in labels_to_load: + self._get_classifier(label) + + def _get_classifier(self, label: str) -> FullEmbeddingMLP: + """Download, reconstruct, and restore one classifier MLP head. + + This is where the classifier MLP is restored: + + classifier = FullEmbeddingMLP(...) + state_dict = torch.load(...) + state_dict = strip "_orig_mod." prefix if needed + classifier.load_state_dict(state_dict) + + Args: + label: + Python-friendly label key, such as "anger" or "arousal". + + Returns: + Restored FullEmbeddingMLP. + """ + if label not in LAION_LABEL_TO_FILENAME: + raise ValueError(f"Unknown label {label!r}. Available labels: " f"{sorted(LAION_LABEL_TO_FILENAME)}") + + cache_key = self._cache_key(label) + + if cache_key in self.classifiers: + classifier = self.classifiers[cache_key] + if not isinstance(classifier, FullEmbeddingMLP): + raise TypeError(f"Cached classifier for {label!r} has unexpected type " f"{type(classifier)}.") + return classifier + + filename = LAION_LABEL_TO_FILENAME[label] + + local_path = hf_hub_download( + repo_id=self.repo_id, + filename=filename, + cache_dir=str(self.cache_dir) if self.cache_dir is not None else None, + repo_type="model", + ) + + classifier = FullEmbeddingMLP( + seq_len=int(self.spec["seq_len"]), + embed_dim=int(self.spec["embed_dim"]), + projection_dim=int(self.spec["projection_dim"]), + mlp_hidden_dims=list(self.spec["mlp_hidden_dims"]), + mlp_dropout_rates=list(self.spec["mlp_dropouts"]), + ) + + state_dict = torch.load(local_path, map_location="cpu") + + if not isinstance(state_dict, dict): + raise RuntimeError(f"Expected {local_path} to contain a state_dict, " f"but got {type(state_dict)}.") + + state_dict = self._strip_orig_mod_prefix_if_needed(state_dict) + + try: + classifier.load_state_dict(state_dict) + except RuntimeError as exc: + raise RuntimeError( + f"Failed to load classifier for label={label!r} from {local_path}. " + f"This often means the selected size={self.size!r} has different " + "MLP dimensions than MODEL_SPECS declares." + ) from exc + + classifier.eval() + classifier = classifier.to(self.mlp_device) + + # Keep classifier dtype compatible with Whisper dtype if user loaded + # Whisper in fp16/bf16. + if self.whisper_model.dtype in (torch.float16, torch.bfloat16): + classifier = classifier.to(dtype=self.whisper_model.dtype) + + if self.cache_classifiers: + self.classifiers[cache_key] = classifier + + return classifier + + # ------------------------------------------------------------------------- + # Internal helpers + # ------------------------------------------------------------------------- + + def _validate_labels( + self, + labels: Optional[Sequence[str]], + ) -> list[str]: + """Validate label names and return a concrete list.""" + if labels is None: + return list(LAION_LABEL_TO_FILENAME.keys()) + + labels_list = list(labels) + unknown = sorted(set(labels_list) - set(LAION_LABEL_TO_FILENAME.keys())) + + if unknown: + raise ValueError( + f"Unknown labels: {unknown}. " f"Available labels: {sorted(LAION_LABEL_TO_FILENAME.keys())}" + ) + + return labels_list + + def _default_similarity_labels( + self, + labels: Optional[Sequence[str]], + include_auxiliary: bool, + ) -> list[str]: + """Choose the default labels for emotion similarity embeddings.""" + if labels is not None: + return self._validate_labels(labels) + + labels_to_run = list(PRIMARY_EMOTION_LABELS) + + if include_auxiliary: + labels_to_run.extend(AUXILIARY_SIMILARITY_LABELS) + + return self._validate_labels(labels_to_run) + + def _prepare_waveform(self, waveform: np.ndarray) -> np.ndarray: + """Convert waveform to mono float32 and trim to max_audio_seconds.""" + waveform = np.asarray(waveform, dtype=np.float32) + + if waveform.ndim > 1: + waveform = np.mean(waveform, axis=0).astype(np.float32) + + max_samples = int(self.sample_rate * self.max_audio_seconds) + + if waveform.shape[0] > max_samples: + waveform = waveform[:max_samples] + + return waveform + + def _pad_or_trim_embedding(self, embedding: torch.Tensor) -> torch.Tensor: + """Pad or trim Whisper encoder output to the expected sequence length.""" + seq_len = int(self.spec["seq_len"]) + embed_dim = int(self.spec["embed_dim"]) + + if embedding.ndim != 3: + raise RuntimeError(f"Expected Whisper embedding with shape [B, T, C], " f"got {tuple(embedding.shape)}.") + + if embedding.shape[-1] != embed_dim: + raise RuntimeError(f"Unexpected embedding dim. Expected {embed_dim}, " f"got {embedding.shape[-1]}.") + + current_seq_len = embedding.shape[1] + + if current_seq_len < seq_len: + padding = torch.zeros( + ( + embedding.shape[0], + seq_len - current_seq_len, + embed_dim, + ), + device=embedding.device, + dtype=embedding.dtype, + ) + embedding = torch.cat([embedding, padding], dim=1) + + elif current_seq_len > seq_len: + embedding = embedding[:, :seq_len, :] + + return embedding + + def _softmax_and_rank( + self, + raw_scores: dict[str, float], + ) -> dict[str, dict[str, Union[float, int]]]: + """Convert raw scores to a sorted top-k softmax dictionary. + + The raw MLP outputs are independent regression scores. This method is + mainly for producing an easy top emotion label. For similarity, prefer + raw score vectors or projection embeddings. + """ + if not raw_scores: + return {} + + labels = list(raw_scores.keys()) + values = np.array([raw_scores[label] for label in labels], dtype=np.float32) + + values = values - np.max(values) + exp_values = np.exp(values) + probs = exp_values / np.sum(exp_values) + + ranked = sorted( + zip(labels, probs.tolist()), + key=lambda item: item[1], + reverse=True, + ) + + ranked = ranked[: self.top_k_emotions] + + return { + label: { + "score": float(prob), + "rank": rank, + } + for rank, (label, prob) in enumerate(ranked, start=1) + } + + @staticmethod + def _strip_orig_mod_prefix_if_needed( + state_dict: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + """Strip torch.compile '_orig_mod.' prefixes if present.""" + if not any(key.startswith("_orig_mod.") for key in state_dict.keys()): + return state_dict + + return { + key[len("_orig_mod.") :] if key.startswith("_orig_mod.") else key: value + for key, value in state_dict.items() + } + + @staticmethod + def _cache_key(label: str) -> str: + """Convert an arbitrary label into a safe ModuleDict key.""" + return label.replace(".", "_").replace("/", "_").replace("-", "_").replace(" ", "_").replace("&", "and") + + def cleanup(self) -> None: + """Move modules to CPU and clear classifier cache.""" + for key in list(self.classifiers.keys()): + self.classifiers[key].cpu() + + self.classifiers.clear() + self.whisper_model.cpu() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +# ============================================================================= +# CLI utilities +# ============================================================================= + + +def _tensor_info(tensor: torch.Tensor) -> dict[str, Any]: + return { + "shape": list(tensor.shape), + "dtype": str(tensor.dtype), + "device": str(tensor.device), + } + + +def _parse_labels(labels: Optional[str]) -> Optional[list[str]]: + if labels is None or labels.strip() == "": + return None + + return [item.strip() for item in labels.split(",") if item.strip()] + + +def main() -> None: + parser = argparse.ArgumentParser(description="LAION Empathic Insight Voice embeddings and similarity.") + parser.add_argument( + "--audio", + type=str, + required=True, + help="Input audio path.", + ) + parser.add_argument( + "--audio-b", + type=str, + default=None, + help="Optional second audio path for similarity.", + ) + parser.add_argument( + "--size", + type=str, + default="small", + choices=["small", "large"], + help="Model size.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device for Whisper encoder.", + ) + parser.add_argument( + "--mlp-device", + type=str, + default=None, + help="Device for MLP classifier heads. Defaults to --device.", + ) + parser.add_argument( + "--cache-dir", + type=str, + default=None, + help="Optional Hugging Face cache directory.", + ) + parser.add_argument( + "--embedding-type", + type=str, + default="head_concat", + choices=["head_concat", "head_mean", "score_vector"], + help="Emotion embedding type.", + ) + parser.add_argument( + "--labels", + type=str, + default=None, + help=( + "Comma-separated labels to use. " + "Example: anger,sadness,arousal. " + "If omitted, primary emotion labels are used for similarity." + ), + ) + parser.add_argument( + "--include-auxiliary", + action="store_true", + help="Include auxiliary similarity labels when --labels is omitted.", + ) + parser.add_argument( + "--load-all-classifiers", + action="store_true", + help="Eagerly load all known classifiers at startup.", + ) + parser.add_argument( + "--top-k", + type=int, + default=5, + help="Number of ranked emotions to return.", + ) + + args = parser.parse_args() + + labels = _parse_labels(args.labels) + + model = EmpathicInsightVoice.from_pretrained( + size=args.size, + device=args.device, + mlp_device=args.mlp_device, + cache_dir=args.cache_dir, + cache_classifiers=True, + load_all_classifiers=args.load_all_classifiers, + top_k_emotions=args.top_k, + ) + + result = model.compute( + audio_path=args.audio, + labels=labels, + return_embedding=True, + embedding_type=args.embedding_type, + return_raw_scores=True, + include_auxiliary_for_embedding=args.include_auxiliary, + ) + + printable: dict[str, Any] = { + "audio_path": result["audio_path"], + "model_size": result["model_size"], + "top_emotion": result["top_emotion"], + "embedding_type": result["embedding_type"], + "embedding": _tensor_info(result["embedding"]), + "emotions": result["emotions"], + "raw_scores": result["raw_scores"], + } + + if args.audio_b is not None: + printable["audio_b"] = args.audio_b + printable["similarity"] = model.emotion_similarity( + audio_path_a=args.audio, + audio_path_b=args.audio_b, + labels=labels, + embedding_type=args.embedding_type, + include_auxiliary=args.include_auxiliary, + ) + + result = model.compare_emotion_pair( + audio_path_a=args.audio, + audio_path_b=args.audio_b, + embedding_type="head_concat", + ) + + result_score_vector = model.compare_emotion_pair( + audio_path_a=args.audio, + audio_path_b=args.audio_b, + embedding_type="score_vector", + ) + + result_score_mean = model.compare_emotion_pair( + audio_path_a=args.audio, + audio_path_b=args.audio_b, + embedding_type="head_mean", + ) + + print("embedding_type=head_concat") + print(result["audio_a_top_emotion"]) + print(result["audio_b_top_emotion"]) + print(result["top_emotion_match"]) + print(result["emotion_similarity"]) + + print("embedding_type=score_vector") + print(result_score_vector["audio_a_top_emotion"]) + print(result_score_vector["audio_b_top_emotion"]) + print(result_score_vector["top_emotion_match"]) + print(result_score_vector["emotion_similarity"]) + + print("embedding_type=head_mean") + print(result_score_mean["audio_a_top_emotion"]) + print(result_score_mean["audio_b_top_emotion"]) + print(result_score_mean["top_emotion_match"]) + print(result_score_mean["emotion_similarity"]) + + +if __name__ == "__main__": + main() diff --git a/nemo/collections/tts/models/easy_magpietts.py b/nemo/collections/tts/models/easy_magpietts.py index ffe4a0cfd055..f6ad19eb14b5 100644 --- a/nemo/collections/tts/models/easy_magpietts.py +++ b/nemo/collections/tts/models/easy_magpietts.py @@ -30,8 +30,10 @@ import nemo.collections.asr as nemo_asr from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.parts.mixins.transcription import TranscribeConfig +from nemo.collections.common.data.fallback import FallbackDataset from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.tts.data.text_to_speech_dataset_lhotse import MagpieTTSLhotseDataset, setup_tokenizers +from nemo.collections.tts.data.text_to_speech_dataset_lhotse_multiturn import MagpieTTSLhotseMultiturnDataset from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel, TrainingMode from nemo.collections.tts.modules.magpietts_modules import ( LocalTransformerType, @@ -57,6 +59,8 @@ except (ImportError, ModuleNotFoundError): HAVE_UTMOSV2 = False +from typing import List + from transformers import WhisperForConditionalGeneration, WhisperProcessor @@ -168,7 +172,13 @@ def _get_state_dict_keys_to_exclude(self): '_utmos_calculator', ] - def compute_loss(self, logits, audio_codes, audio_codes_lens): + def compute_loss( + self, + logits, + audio_codes, + audio_codes_lens, + agent_mask_target=None, + ): """ Computes the audio codebook loss. Used by (1) The main Magpie-TTS transformer @@ -180,40 +190,90 @@ def compute_loss(self, logits, audio_codes, audio_codes_lens): """ loss_mask = get_mask_from_lengths(audio_codes_lens) loss_mask = loss_mask.unsqueeze(1).repeat(1, audio_codes.size(1), 1) + + if agent_mask_target is not None: + agent_mask_target = agent_mask_target.to(device=audio_codes.device, dtype=loss_mask.dtype) + total_codebook_loss = None for codebook in range(audio_codes.size(1)): si = codebook * self.num_all_tokens_per_codebook ei = si + self.num_all_tokens_per_codebook - codebook_logits = logits[:, :, si:ei] # (B, T', num_tokens_per_codebook) - codebook_targets = audio_codes[:, codebook] # (B, T') - codebook_loss = self.cross_entropy_loss( - codebook_logits.permute(0, 2, 1), codebook_targets.long() # (B, num_tokens_per_codebook, T') + codebook_logits = logits[:, :, si:ei] + codebook_targets = audio_codes[:, codebook] + raw_loss = self.cross_entropy_loss( + codebook_logits.permute(0, 2, 1), + codebook_targets.long(), ) # (B, T') - codebook_loss = codebook_loss * loss_mask[:, codebook, :] - codebook_loss = codebook_loss.sum() / loss_mask[:, codebook, :].sum() - if total_codebook_loss is None: - total_codebook_loss = codebook_loss - else: - total_codebook_loss = total_codebook_loss + codebook_loss + effective_mask = loss_mask[:, codebook, :] + if agent_mask_target is not None: + effective_mask = effective_mask * agent_mask_target + codebook_loss = raw_loss * effective_mask + codebook_loss = codebook_loss.sum() / effective_mask.sum().clamp_min(1.0) + total_codebook_loss = codebook_loss if total_codebook_loss is None else total_codebook_loss + codebook_loss total_codebook_loss = total_codebook_loss / audio_codes.size(1) return total_codebook_loss, loss_mask - def compute_phoneme_loss(self, logits, phoneme_tokens, phoneme_tokens_lens): + def compute_phoneme_loss( + self, + logits, + phoneme_tokens, + phoneme_tokens_lens, + custom_mask=None, + ): + """ + logits: (B, T', phoneme_stacking_factor * phoneme_vocab_size) + phoneme_tokens: (B, S, T') + phoneme_tokens_lens: (B,) + custom_mask: optional (B, T') + """ loss_mask = get_mask_from_lengths(phoneme_tokens_lens) + loss_mask = loss_mask.unsqueeze(1).repeat(1, phoneme_tokens.size(1), 1) + + if custom_mask is not None: + custom_mask = custom_mask.bool() + target_T = phoneme_tokens.size(2) + + if custom_mask.size(1) < target_T: + pad = torch.zeros( + custom_mask.size(0), + target_T - custom_mask.size(1), + device=custom_mask.device, + dtype=custom_mask.dtype, + ) + custom_mask = torch.cat([custom_mask, pad], dim=1) + else: + custom_mask = custom_mask[:, :target_T] + + custom_mask = custom_mask.to( + device=phoneme_tokens.device, + dtype=loss_mask.dtype, + ) + total_phoneme_loss = None + for codebook in range(self.phoneme_stacking_factor): si = codebook * self.phoneme_vocab_size ei = si + self.phoneme_vocab_size + phoneme_logits = logits[:, :, si:ei] phoneme_targets = phoneme_tokens[:, codebook] - phoneme_loss = self.cross_entropy_loss(phoneme_logits.permute(0, 2, 1), phoneme_targets) - phoneme_loss = phoneme_loss * loss_mask - phoneme_loss = phoneme_loss.sum() / loss_mask.sum() - if total_phoneme_loss is None: - total_phoneme_loss = phoneme_loss - else: - total_phoneme_loss = total_phoneme_loss + phoneme_loss + + raw_loss = self.cross_entropy_loss( + phoneme_logits.permute(0, 2, 1), + phoneme_targets.long(), + ) # (B, T') + + effective_mask = loss_mask[:, codebook, :] + + if custom_mask is not None: + effective_mask = effective_mask * custom_mask + + phoneme_loss = raw_loss * effective_mask + phoneme_loss = phoneme_loss.sum() / effective_mask.sum().clamp_min(1.0) + + total_phoneme_loss = phoneme_loss if total_phoneme_loss is None else total_phoneme_loss + phoneme_loss + total_phoneme_loss = total_phoneme_loss / self.phoneme_stacking_factor return total_phoneme_loss, loss_mask @@ -326,6 +386,8 @@ def prepare_text_channel_embeddings( text_lens: torch.Tensor, delay: torch.Tensor, dropout_text_input: bool = False, + is_multiturn: bool = False, + text_pad_id: int = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Prepare text embeddings as a channel input with delay handling. @@ -350,12 +412,16 @@ def prepare_text_channel_embeddings( device = text.device # Embed text tokens (CAS-only when disable_subword_embedding=True). - text_embedded = self.embed_text_tokens(text, text_lens=text_lens) # (B, L, E) + text_embedded = self.embed_text_tokens(text, text_lens=text_lens, is_multiturn=is_multiturn) # (B, L, E) # Handle text dropout - zero out the embeddings if dropout_text_input: text_embedded = text_embedded * 0.0 + # multiturn dataset returns a special pad text tokens until it matches the audio len, to keep compatible with regular dataset zero-out those values + if is_multiturn: + text_embedded[text == text_pad_id] = 0.0 + # Create zero tensor for delay padding max_delay = delay.max().item() zero_delay_tensor = torch.zeros(batch_size, max_delay, self.cfg.embedding_dim, device=device) @@ -426,8 +492,14 @@ def prepare_phoneme_channel_embeddings( phoneme_embedded = self.embed_phoneme_tokens(phoneme_tokens_stacked) # (B, T', E) # Apply mask to zero out padding - phoneme_mask = get_mask_from_lengths(phoneme_tokens_lens_stacked) - phoneme_embedded = phoneme_embedded * phoneme_mask.unsqueeze(2) # (B, T', E) + if self.cfg.get("use_multiturn_dataset", False): + phoneme_pad_id = getattr(self.phoneme_tokenizer, "pad", -1) + phoneme_mask = phoneme_tokens_stacked[:, 0, :] != phoneme_pad_id # Check the first layer of the stack + # Apply mask to zero out padding + phoneme_embedded = phoneme_embedded * phoneme_mask.unsqueeze(2) # (B, T', E) + else: + phoneme_mask = get_mask_from_lengths(phoneme_tokens_lens_stacked) + phoneme_embedded = phoneme_embedded * phoneme_mask.unsqueeze(2) # (B, T', E) # Handle phoneme dropout - zero out the embeddings if dropout_complete_phoneme_channel: @@ -523,7 +595,10 @@ def prepare_audio_channel_embeddings( audio_codes: torch.Tensor, audio_codes_lens: torch.Tensor, delay: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + speech_eos_mask: Optional[torch.Tensor] = None, + agent_mask: Optional[torch.Tensor] = None, + current_streaming_speech_delay: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Prepare audio embeddings as a channel input with delay handling. @@ -544,6 +619,7 @@ def prepare_audio_channel_embeddings( - audio_channel_lens: Total length of audio channel for each batch item (B,) - audio_codes_target: Target audio codes for loss computation (B, C, T'-1) - audio_codes_lens_target: Length of target audio codes (B,) + - loss_agent_mask: Optional mask used for loss masking; None when no agent_mask is provided. """ batch_size = audio_codes.size(0) device = audio_codes.device @@ -560,6 +636,7 @@ def prepare_audio_channel_embeddings( codes_len=audio_codes_lens, bos_id=self.audio_bos_id, eos_id=self.audio_eos_id, + num_eos_tokens=1 if speech_eos_mask is None else 0, ) # Stack audio codes across codebooks @@ -572,6 +649,25 @@ def prepare_audio_channel_embeddings( self.num_audio_codebooks, ) + if speech_eos_mask is not None: + audio_codes_before_speech_eos = audio_codes.clone() + # Shift +1 for BOS alignment and +1 more so EOS is injected after the marked frame. + B_mask, T_mask = speech_eos_mask.shape + shifted_mask = torch.zeros((B_mask, T_mask + 2), dtype=torch.bool, device=device) + shifted_mask[:, 2:] = speech_eos_mask + + # 2. Find the minimum overlapping time dimension + t_mask = shifted_mask.size(1) + t_audio = audio_codes.size(2) + min_t = min(t_mask, t_audio) + + # 3. Slice both to the valid overlap and broadcast the C dimension + valid_mask = shifted_mask[:, :min_t] + expanded_mask = valid_mask.unsqueeze(1).expand(-1, audio_codes.size(1), -1) + + # Inject the EOS token only into the overlapping region + audio_codes[:, :, :min_t][expanded_mask] = self.audio_eos_id + # Prepare input and target for autoregressive training # Input: all tokens except the last (teacher forcing) # Target: all tokens except the first (shifted by one) @@ -579,6 +675,110 @@ def prepare_audio_channel_embeddings( audio_codes_target = audio_codes[:, :, 1:] # (B, C, T'-1) audio_codes_input = audio_codes[:, :, :-1] # (B, C, T'-1) + # Drop some EOS frames from audio input so the model learns recovery when inference misses EOS. + # sample_prob keeps some samples untouched, so the model still learns the normal EOS-input behavior. + if speech_eos_mask is not None and self.training: + drop_eos_sample_prob = float(self.cfg.get("drop_eos_from_audio_input_sample_prob", 0.0)) + drop_eos_frame_prob = float(self.cfg.get("drop_eos_from_audio_input_frame_prob", 0.5)) + + if drop_eos_sample_prob > 0.0 and drop_eos_frame_prob > 0.0: + eos_frame_mask = (audio_codes_input == self.audio_eos_id).any(dim=1) # [B, T] + + sample_drop_mask = torch.rand(batch_size, device=device) < drop_eos_sample_prob # [B] + + frame_drop_mask = ( + eos_frame_mask + & sample_drop_mask.unsqueeze(1) + & (torch.rand_like(eos_frame_mask.float()) < drop_eos_frame_prob) + ) # [B, T] + + audio_codes_input_backup = audio_codes_before_speech_eos[:, :, :-1] + + audio_codes_input = torch.where( + frame_drop_mask.unsqueeze(1), + audio_codes_input_backup, + audio_codes_input, + ) + + # deal with agent mask + loss_agent_mask = None + if agent_mask is not None: + target_T = audio_codes_target.size(2) + + # Align dataloader agent_mask to audio_codes_target time. + if agent_mask.size(1) < target_T: + pad = torch.zeros( + agent_mask.size(0), + target_T - agent_mask.size(1), + device=agent_mask.device, + dtype=torch.bool, + ) + agent_mask = torch.cat([agent_mask.bool(), pad], dim=1) + else: + agent_mask = agent_mask[:, :target_T].bool() + + agent_mask = agent_mask.to(audio_codes_target.device) + + valid = get_mask_from_lengths(audio_codes_lens_target).bool().to(audio_codes_target.device) + agent_mask = agent_mask & valid + + # Keep EOS and the frame before EOS supervised. + eos_any = (audio_codes_target == self.audio_eos_id).any(dim=1) & valid + + eos_prev1 = torch.zeros_like(eos_any) + eos_prev1[:, :-1] = eos_any[:, 1:] + + agent_mask = agent_mask | eos_prev1 | eos_any + target_agent_mask = agent_mask & valid + loss_agent_mask = target_agent_mask + + # Replace user/non-agent regions with a learned token. + # Important: audio_codes_input predicts audio_codes_target, so input mask must be shifted. + if self.cfg.get("use_user_speaking_token", False): + target_non_agent = (~target_agent_mask) & valid + + # audio_codes_input[:, :, t] is the previous token used to predict target t. + input_agent_mask = torch.zeros_like(target_agent_mask) + input_agent_mask[:, 1:] = target_agent_mask[:, :-1] + input_agent_mask[:, 0] = True # Keep first/BOS input untouched. + + input_valid = torch.zeros_like(valid) + input_valid[:, 1:] = valid[:, :-1] + input_valid[:, 0] = valid[:, 0] + + input_non_agent = (~input_agent_mask) & input_valid + + user_tok_input = torch.full_like(audio_codes_input, self.audio_user_speaking_id) + audio_codes_input = torch.where( + input_non_agent.unsqueeze(1), + user_tok_input, + audio_codes_input, + ) + + user_tok_target = torch.full_like(audio_codes_target, self.audio_user_speaking_id) + audio_codes_target = torch.where( + target_non_agent.unsqueeze(1), + user_tok_target, + audio_codes_target, + ) + + # Put audio_user_speaking_end_id in the input slot that predicts the first agent frame + # after a non-agent region. + if self.cfg.get("use_user_speaking_end_token", False): + user_to_agent = torch.zeros_like(target_agent_mask) + + user_to_agent[:, 1:] = ( + target_agent_mask[:, 1:] & (~target_agent_mask[:, :-1]) & valid[:, 1:] & valid[:, :-1] + ) + + end_tok_input = torch.full_like(audio_codes_input, self.audio_user_speaking_end_id) + + audio_codes_input = torch.where( + user_to_agent.unsqueeze(1), + end_tok_input, + audio_codes_input, + ) + # Embed audio tokens audio_embedded = self.embed_audio_tokens(audio_codes_input) # (B, T'-1, E) @@ -592,7 +792,13 @@ def prepare_audio_channel_embeddings( lengths=[delay, audio_codes_lens_target], ) - return audio_channel_embedding, audio_channel_lens, audio_codes_target, audio_codes_lens_target + return ( + audio_channel_embedding, + audio_channel_lens, + audio_codes_target, + audio_codes_lens_target, + loss_agent_mask, + ) def slice_sequence_embeddings(self, sequence_embeddings, context_lens, target_lens): """ @@ -633,8 +839,12 @@ def process_batch( context_audio_codes_lens: torch.Tensor, phoneme_tokens: Optional[torch.Tensor] = None, phoneme_tokens_lens: Optional[torch.Tensor] = None, + phoneme_turn_dropout: Optional[torch.Tensor] = None, mode: str = "train", training_mode: Optional[TrainingMode] = None, + task: Optional[List[str]] = None, + agent_mask: Optional[torch.Tensor] = None, + user_audio_embedded: Optional[torch.Tensor] = None, ) -> ProcessBatchOutput: """ Simplified batch processing using channel-based embedding architecture. @@ -715,12 +925,22 @@ def process_batch( # Streaming mode: context_lens + speech_delay audio_delay = context_lens + current_streaming_speech_delay + speech_eos_mask = None + if self.cfg.get("use_multiturn_dataset", False): + speech_eos_mask = text == self.interruption_token_id # (B, T) + # remove the interruption token for all task, expect for interruption + if not task or "interruption" not in str(task[0]): + text[speech_eos_mask] = self.tokenizer.pad # Clean up the text channel + # else: # ToDo: move self.interruption_token_id forward by audio_delay so that soon it saw the interruption token it is forced to stop instead of await audio_delay tokens + # 3. Prepare text channel embeddings text_channel_embedding, text_channel_lens = self.prepare_text_channel_embeddings( text=text, text_lens=text_lens, delay=text_delay, dropout_text_input=dropout_text_input or dropout_conditional_input, + is_multiturn=self.cfg.get("use_multiturn_dataset", False), + text_pad_id=self.pad_id, ) # 4. Prepare phoneme channel embeddings (if phoneme tokenizer is configured) @@ -766,10 +986,14 @@ def process_batch( audio_channel_lens, audio_codes_target, audio_codes_lens_target, + agent_mask, ) = self.prepare_audio_channel_embeddings( audio_codes=audio_codes, audio_codes_lens=audio_codes_lens, delay=audio_delay, + speech_eos_mask=speech_eos_mask, + agent_mask=agent_mask, + current_streaming_speech_delay=current_streaming_speech_delay, ) # 6. Sum the channel embeddings element-wise @@ -815,6 +1039,77 @@ def process_batch( phoneme_channel_embedding = torch.cat([phoneme_channel_embedding, padding], dim=1) combined_channel_embedding = combined_channel_embedding + phoneme_channel_embedding + if user_audio_embedded is not None: + bos_user_pad = torch.zeros( + user_audio_embedded.size(0), + 1, + user_audio_embedded.size(2), + device=user_audio_embedded.device, + dtype=user_audio_embedded.dtype, + ) + user_audio_embedded = torch.cat([bos_user_pad, user_audio_embedded], dim=1) + + # Align user conditioning to audio_codes_target timeline, + # same as agent_mask from prepare_audio_channel_embeddings(). + target_T = audio_codes_target.size(2) + + if user_audio_embedded.size(1) < target_T: + pad_len = target_T - user_audio_embedded.size(1) + user_audio_embedded = torch.cat( + [ + user_audio_embedded, + torch.zeros( + user_audio_embedded.size(0), + pad_len, + user_audio_embedded.size(2), + device=user_audio_embedded.device, + dtype=user_audio_embedded.dtype, + ), + ], + dim=1, + ) + else: + user_audio_embedded = user_audio_embedded[:, :target_T] + + batch_size = user_audio_embedded.size(0) + device = user_audio_embedded.device + + max_delay = audio_delay.max().item() + zero_delay_tensor = torch.zeros( + batch_size, + max_delay, + self.cfg.embedding_dim, + device=device, + dtype=user_audio_embedded.dtype, + ) + + user_audio_lens = audio_codes_lens_target.to(audio_delay.device) + + user_audio_channel_embedding, _ = self.join_embeddings_temporally( + embeddings=[zero_delay_tensor, user_audio_embedded], + lengths=[audio_delay, user_audio_lens], + ) + + if user_audio_channel_embedding.size(1) < max_channel_len: + pad_len = max_channel_len - user_audio_channel_embedding.size(1) + user_audio_channel_embedding = torch.cat( + [ + user_audio_channel_embedding, + torch.zeros( + batch_size, + pad_len, + user_audio_channel_embedding.size(2), + device=user_audio_channel_embedding.device, + dtype=user_audio_channel_embedding.dtype, + ), + ], + dim=1, + ) + else: + user_audio_channel_embedding = user_audio_channel_embedding[:, :max_channel_len] + + combined_channel_embedding = combined_channel_embedding + user_audio_channel_embedding + # 7. Join context with combined channel embeddings # The combined_channel_lens is the max of all channel lens for each batch item combined_channel_lens = ( @@ -861,7 +1156,12 @@ def process_batch( logits = self.final_proj(pred_embeddings_audio) # Compute codebook loss - codebook_loss, _ = self.compute_loss(logits, audio_codes_target, audio_codes_lens_target) + codebook_loss, _ = self.compute_loss( + logits, + audio_codes_target, + audio_codes_lens_target, + agent_mask_target=agent_mask if self.cfg.get("mask_user_on_loss", False) else None, + ) loss = self.parallel_codebook_loss_scale * codebook_loss # Compute local transformer loss if applicable @@ -873,8 +1173,12 @@ def process_batch( pred_embeddings, audio_codes_target, targets_offset_by_one=False ) local_transformer_loss, _ = self.compute_loss( - local_transformer_logits, audio_codes_target, audio_codes_lens_target + local_transformer_logits, + audio_codes_target, + audio_codes_lens_target, + agent_mask_target=agent_mask if self.cfg.get("mask_user_on_loss", False) else None, ) + loss = loss + self.local_transformer_loss_scale * local_transformer_loss # Compute phoneme loss if applicable @@ -894,10 +1198,19 @@ def process_batch( pb_phoneme_tokens_lens_target = phoneme_tokens_lens_stacked - 1 if (phoneme_corruption_mode != 'repeat_skip') and not ( - dropout_complete_phoneme_channel or dropout_conditional_input or dropout_text_input + dropout_complete_phoneme_channel + or (phoneme_turn_dropout is not None and phoneme_turn_dropout.any()) + or dropout_conditional_input + or dropout_text_input ): + custom_mask = None + if self.cfg.get("phoneme_loss_mask_padding", False): + custom_mask = pb_phoneme_tokens_target[:, 0, :] != self.phoneme_tokenizer.pad # (B, T') + elif self.cfg.get("mask_user_on_loss", False): + custom_mask = agent_mask + phoneme_loss, _ = self.compute_phoneme_loss( - pb_phoneme_logits, pb_phoneme_tokens_target, pb_phoneme_tokens_lens_target + pb_phoneme_logits, pb_phoneme_tokens_target, pb_phoneme_tokens_lens_target, custom_mask=custom_mask ) else: phoneme_loss = torch.tensor(0.0, device=logits.device) @@ -940,6 +1253,205 @@ def training_step(self, batch, batch_idx): audio_lens = batch['audio_lens'] audio_codes, audio_codes_lens = self._codec_helper.audio_to_codes(audio, audio_lens) + if ( + self.cfg.get("use_multiturn_dataset", False) + and batch["user_audio_turn_splitted"] is not None + and self.cfg.get("condition_on_user_speech", False) + ): + input_samples_per_frame = self.codec_model_samples_per_frame * self.frame_stacking_factor + + user_audio = batch["user_audio_turn_splitted"] + user_audio_lens = batch["user_audio_turn_splitted_lens"] + + turn_silence_prob = float(self.cfg.get("user_cond_silence_augmentation_prob", 0.0) or 0.0) + sample_silence_prob = float(self.cfg.get("user_cond_sample_silence_augmentation_prob", 0.0) or 0.0) + + if self.training and (turn_silence_prob > 0.0 or sample_silence_prob > 0.0): + user_audio = user_audio.clone() + + # randomly drop individual turns. + if turn_silence_prob > 0.0: + turn_silence_mask = ( + torch.rand( + user_audio.size(0), + device=user_audio.device, + ) + < turn_silence_prob + ) + else: + turn_silence_mask = torch.zeros( + user_audio.size(0), + device=user_audio.device, + dtype=torch.bool, + ) + + # randomly drop all turns for selected samples + if sample_silence_prob > 0.0: + B = batch["text"].shape[0] + + sample_silence_mask = ( + torch.rand( + B, + device=user_audio.device, + ) + < sample_silence_prob + ) + + indices = batch["user_audio_turn_splitted_indices"].to(user_audio.device) + turn_batch_indices = indices[:, 0].long() + + valid_turns = turn_batch_indices >= 0 + sample_drop_turn_mask = torch.zeros( + user_audio.size(0), + device=user_audio.device, + dtype=torch.bool, + ) + + sample_drop_turn_mask[valid_turns] = sample_silence_mask[turn_batch_indices[valid_turns]] + + silence_mask = turn_silence_mask | sample_drop_turn_mask + else: + silence_mask = turn_silence_mask + + if silence_mask.any(): + user_audio[silence_mask] = 0.0 + + user_audio_codes, user_audio_codes_lens = self._codec_helper.audio_to_codes( + user_audio, + user_audio_lens, + ) + + if self._codec_converter is not None: + user_audio_codes = self._codec_converter.convert_original_to_new( + audio_tokens=user_audio_codes, + audio_lens=user_audio_codes_lens, + ).long() + + user_audio_codes, user_audio_codes_lens = self.stack_codes( + user_audio_codes, + user_audio_codes_lens, + self.audio_bos_id, + self.audio_eos_id, + self.frame_stacking_factor, + self.num_audio_codebooks, + ) + + user_audio_embedded = self.embed_audio_tokens(user_audio_codes) + + B = batch["text"].shape[0] + T = batch["text"].shape[1] + D = user_audio_embedded.shape[-1] + + user_audio_embedded_restored = user_audio_embedded.new_zeros(B, T, D) + + sample_prob = float(self.cfg.get("user_cond_trim_augmentation_sample_prob", 0.0) or 0.0) + turn_prob = float(self.cfg.get("user_cond_trim_augmentation_turn_prob", 0.0) or 0.0) + base_trim = int(self.cfg.get("user_cond_trim_augmentation_base", 0) or 0) + + if self.training and sample_prob > 0.0 and turn_prob > 0.0 and base_trim > 0: + sample_trim_aug = torch.rand(B, device=user_audio_embedded.device) < sample_prob + else: + sample_trim_aug = torch.zeros(B, device=user_audio_embedded.device, dtype=torch.bool) + + indices = batch["user_audio_turn_splitted_indices"].to(user_audio_embedded.device) + for turn_idx, (b, start_sample, end_sample) in enumerate(indices): + b = int(b.item()) + if b < 0: + continue + + start_frame = int(torch.ceil(start_sample.float() / input_samples_per_frame).item()) + end_frame = int(end_sample.item()) // input_samples_per_frame + + start_frame = max(0, min(start_frame, T)) + end_frame = max(start_frame, min(end_frame, T)) + + seq_len = end_frame - start_frame + if seq_len <= 0: + continue + + boundary_trim = self.cfg.get("user_audio_boundary_trim", 0) + boundary_trim = 0 if boundary_trim is None else int(boundary_trim) + + if boundary_trim == 0: + real_start = 0 + real_end = int(user_audio_codes_lens[turn_idx].item()) + else: + turn_len_with_special = int(user_audio_codes_lens[turn_idx].item()) + real_start = 1 + real_end = max(real_start, turn_len_with_special - 1) + + turn_emb = user_audio_embedded[turn_idx, real_start:real_end] + + copy_len = min(seq_len, turn_emb.size(0)) + if copy_len <= 0: + continue + + turn_emb = turn_emb[:copy_len].clone() + + if boundary_trim > 0: + trim = min(boundary_trim, copy_len // 2) + if trim > 0: + turn_emb[:trim] = 0.0 + turn_emb[copy_len - trim :] = 0.0 + + if bool(sample_trim_aug[b].item()): + do_turn_aug = torch.rand((), device=user_audio_embedded.device).item() < turn_prob + + if do_turn_aug: + trim_delta = int( + torch.randint( + low=-1, + high=2, # {-1, 0, 1} + size=(), + device=user_audio_embedded.device, + ).item() + ) + + trim_amount = max(1, base_trim + trim_delta) + trim_amount = min(trim_amount, max(1, copy_len - 1)) + + aug_choice = random.choices( + ["left", "right", "both"], + weights=[0.3, 0.3, 0.4], + k=1, + )[0] + + zero_emb_pad = turn_emb.new_zeros(trim_amount, turn_emb.size(-1)) + + if aug_choice == "left": + # Remove tokens from the left, then right-pad zeros. + kept_emb = turn_emb[trim_amount:] + turn_emb = torch.cat([kept_emb, zero_emb_pad], dim=0) + + elif aug_choice == "right": + # Remove tokens from the right, then right-pad zeros. + kept_emb = turn_emb[: copy_len - trim_amount] + turn_emb = torch.cat([kept_emb, zero_emb_pad], dim=0) + + else: # "both" + # Remove trim_amount total tokens split across left and right. + left_trim = trim_amount // 2 + right_trim = trim_amount - left_trim + + # If trim_amount is odd, randomly decide which side loses the extra token. + if trim_amount % 2 == 1 and torch.rand((), device=user_audio_embedded.device).item() < 0.5: + left_trim, right_trim = right_trim, left_trim + + kept_emb = turn_emb[left_trim : copy_len - right_trim] + turn_emb = torch.cat([kept_emb, zero_emb_pad], dim=0) + + # Safety: keep exact same length for restore assignment. + turn_emb = turn_emb[:copy_len] + + dst_start = start_frame + dst_end = start_frame + copy_len + + user_audio_embedded_restored[b, dst_start:dst_end] = turn_emb + + user_audio_embedded = user_audio_embedded_restored + else: + user_audio_embedded = None + batch_output = self.process_batch( text=batch['text'], text_lens=batch['text_lens'], @@ -951,7 +1463,11 @@ def training_step(self, batch, batch_idx): context_audio_codes_lens=context_audio_codes_lens, phoneme_tokens=batch.get('phoneme_tokens'), phoneme_tokens_lens=batch.get('phoneme_tokens_lens'), + phoneme_turn_dropout=batch.get('phoneme_turn_dropout'), mode="train", + task=batch["task"] if self.cfg.get("use_multiturn_dataset", False) else None, + agent_mask=batch["agent_mask"] if self.cfg.get("use_multiturn_dataset", False) else None, + user_audio_embedded=user_audio_embedded, ) loss = batch_output.loss codebook_loss = batch_output.codebook_loss @@ -1049,6 +1565,8 @@ def validation_step(self, batch, batch_idx): phoneme_tokens=batch.get('phoneme_tokens'), phoneme_tokens_lens=batch.get('phoneme_tokens_lens'), mode="val", + task=batch["task"] if "task" in batch else None, + agent_mask=batch["agent_mask"] if "agent_mask" in batch else None, ) # Access ProcessBatchOutput dataclass attributes # logits come from the parallel prediction head @@ -1304,6 +1822,15 @@ def validation_step(self, batch, batch_idx): return val_output + def on_fit_start(self): + super().on_fit_start() + if not hasattr(self, "_codec_sil_codes_buffer"): + self._generate_codec_silence_buffer() + + def on_validation_epoch_start(self) -> None: + if torch.distributed.is_initialized(): + self.trainer.strategy.model.require_backward_grad_sync = False + def on_validation_epoch_end(self): collect = lambda key: torch.stack([x[key] for x in self.validation_step_outputs]).mean() val_loss = collect("val_loss") @@ -1360,6 +1887,9 @@ def collect_if_exists(key): self.validation_step_outputs.clear() # free memory + if torch.distributed.is_initialized(): + self.trainer.strategy.model.require_backward_grad_sync = True + def get_dataset(self, dataset_cfg, dataset_type): dataset = safe_instantiate( dataset_cfg.dataset, @@ -1393,27 +1923,58 @@ def get_dataset(self, dataset_cfg, dataset_type): def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.DataLoader: # TODO @xueyang: better to distinguish cfg. self.cfg is the model cfg, while cfg here is train_ds cfg. Also # cfg is a classifier-free guidance. - dataset = MagpieTTSLhotseDataset( - sample_rate=self.sample_rate, - volume_norm=dataset_cfg.volume_norm, - codec_model_samples_per_frame=self.codec_model_samples_per_frame, - num_audio_codebooks=self.data_num_audio_codebooks, - prior_scaling_factor=0.0, - load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, - dataset_type=mode, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) - load_16khz_audio=False, - pad_context_text_to_max_duration=self.pad_context_text_to_max_duration, - context_duration_min=self.cfg.context_duration_min, - context_duration_max=self.cfg.context_duration_max, - use_text_conditioning_tokenizer=True, - text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name, - tokenizer_config=self.cfg.text_tokenizers, - phoneme_tokenizer_config=self.cfg.get("phoneme_tokenizer", None), - ignore_phoneme_languages=self.cfg.get("ignore_phoneme_languages", []), - phoneme_as_text_prob=self.phoneme_as_text_prob if mode == 'train' else 0.0, - pronunciation_control_g2p=self.cfg.get("pronunciation_control_g2p", None), - add_language_to_context_text=self.add_language_to_context_text, - ) + if self.cfg.get("use_multiturn_dataset", False): + dataset = MagpieTTSLhotseMultiturnDataset( + sample_rate=self.sample_rate, + volume_norm=dataset_cfg.volume_norm, + codec_model_samples_per_frame=self.codec_model_samples_per_frame, + codec_model_input_sample_rate=self.codec_model_input_sample_rate, + frame_stacking_factor=self.frame_stacking_factor, + num_audio_codebooks=self.data_num_audio_codebooks, + prior_scaling_factor=0.0, + load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, + dataset_type=mode, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) + load_16khz_audio=False, + pad_context_text_to_max_duration=self.pad_context_text_to_max_duration, + context_duration_min=self.cfg.context_duration_min, + context_duration_max=self.cfg.context_duration_max, + use_text_conditioning_tokenizer=True, + text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name, + tokenizer_config=self.cfg.text_tokenizers, + phoneme_tokenizer_config=self.cfg.get("phoneme_tokenizer", None), + ignore_phoneme_languages=self.cfg.get("ignore_phoneme_languages", []), + add_language_to_context_text=self.add_language_to_context_text, + source_sample_rate=self.sample_rate, + input_roles=["user", "User"], + output_roles=["assistant", "Assistant", "agent", "Agent"], + add_text_bos=self.cfg.get("add_text_bos", False), + phoneme_turn_dropout_batch_prob=self.cfg.get("phoneme_turn_dropout_batch_prob", 0.0), + phoneme_turn_dropout_turn_prob=self.cfg.get("phoneme_turn_dropout_turn_prob", 0.0), + phoneme_turn_max_words_to_drop=self.cfg.get("phoneme_turn_max_words_to_drop", 2), + ) + dataset = FallbackDataset(dataset) + else: + dataset = MagpieTTSLhotseDataset( + sample_rate=self.sample_rate, + volume_norm=dataset_cfg.volume_norm, + codec_model_samples_per_frame=self.codec_model_samples_per_frame, + num_audio_codebooks=self.data_num_audio_codebooks, + prior_scaling_factor=0.0, + load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, + dataset_type=mode, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) + load_16khz_audio=False, + pad_context_text_to_max_duration=self.pad_context_text_to_max_duration, + context_duration_min=self.cfg.context_duration_min, + context_duration_max=self.cfg.context_duration_max, + use_text_conditioning_tokenizer=True, + text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name, + tokenizer_config=self.cfg.text_tokenizers, + phoneme_tokenizer_config=self.cfg.get("phoneme_tokenizer", None), + ignore_phoneme_languages=self.cfg.get("ignore_phoneme_languages", []), + phoneme_as_text_prob=self.phoneme_as_text_prob if mode == 'train' else 0.0, + pronunciation_control_g2p=self.cfg.get("pronunciation_control_g2p", None), + add_language_to_context_text=self.add_language_to_context_text, + ) data_loader = get_lhotse_dataloader_from_config( config=dataset_cfg.dataset, @@ -1421,6 +1982,7 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D world_size=self.world_size, dataset=dataset, ) + return data_loader def setup_training_data(self, dataset_cfg): diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index c0ef45d9a7a9..5ade43603b7b 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import random +import tempfile import time +from collections import Counter from dataclasses import dataclass, fields from functools import partial from typing import Any, Dict, List, Optional, Sequence, Tuple @@ -25,6 +27,8 @@ from torch import nn from transformers import AutoConfig, AutoModelForCausalLM +from nemo.collections.audio.parts.utils.transforms import resample +from nemo.collections.speechlm2.parts.pretrained import set_model_dict_for_partial_init from nemo.collections.tts.data.text_to_speech_dataset_lhotse import setup_tokenizers from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules import transformer_2501 @@ -40,6 +44,7 @@ from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths from nemo.core.classes import ModelPT from nemo.core.classes.common import PretrainedModelInfo, safe_instantiate +from nemo.core.connectors.save_restore_connector import SaveRestoreConnector from nemo.utils import logging from nemo.utils.exceptions import NeMoBaseException @@ -130,6 +135,7 @@ class StreamingState: full_context_lens: torch.Tensor context_position: torch.Tensor text_tokens_seen: torch.Tensor + turn_text_tokens_seen: torch.Tensor phoneme_steps: torch.Tensor audio_steps: torch.Tensor phoneme_stream_ended: torch.Tensor @@ -143,6 +149,7 @@ class StreamingState: audio_prediction_end_idx: torch.Tensor phoneme_prediction_start_idx: torch.Tensor phoneme_prediction_end_idx: torch.Tensor + gt_phoneme_embeddings: Optional[torch.Tensor] = None # (B, T', E) pre-computed GT embeddings gt_phoneme_lens: Optional[torch.Tensor] = None # (B,) lengths after stacking gt_audio_embeddings: Optional[torch.Tensor] = None # (B, T', E) pre-computed GT audio embeddings @@ -255,6 +262,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.codebook_size = codebook_size self.codec_model_samples_per_frame = codec_model.samples_per_frame + self.codec_model_input_sample_rate = codec_model.sample_rate # Our codebooks start with actual audio codec tokens, followed by special tokens. # The `forced_*` options are for backward compatibility for models trained with older code. get_token_index = partial(SpecialAudioToken.get_index, base_codebook_size=self.codebook_size) @@ -263,6 +271,8 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.context_audio_bos_id = get_token_index(SpecialAudioToken.AUDIO_CONTEXT_BOS) self.context_audio_eos_id = get_token_index(SpecialAudioToken.AUDIO_CONTEXT_EOS) self.mask_token_id = get_token_index(SpecialAudioToken.MASK_TOKEN) + self.audio_user_speaking_id = get_token_index(SpecialAudioToken.USER_SPEAKING) + self.audio_user_speaking_end_id = get_token_index(SpecialAudioToken.USER_SPEAKING_END) self.num_all_tokens_per_codebook = self.codebook_size + len(SpecialAudioToken) self.use_bpe_char_tokenizer = cfg.get('use_bpe_char_tokenizer', False) # If True, text tokens are embedded only with the char-aware subword (CAS) encoder, and the decoder token @@ -270,7 +280,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.disable_subword_embedding = cfg.get('disable_subword_embedding', False) # If True, remove the decoder LM head over text tokens to save parameters when the model does not train or # infer text-token logits from the decoder output. - self.disable_lm_text_head = cfg.get('disable_lm_text_head', False) + self.disable_lm_text_head = cfg.get('disable_lm_text_head', True) # Legacy checkpoints may have trained context text with decoder embeddings only, even when CAS is enabled for # regular text tokens. This flag skips adding CAS embeddings for context text to match those checkpoints. self.disable_cas_for_context_text = cfg.get('disable_cas_for_context_text', False) @@ -333,11 +343,24 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): mode='train', ) - num_tokens_tokenizer = len(self.tokenizer.tokens) - num_tokens = num_tokens_tokenizer + 3 # +3 for BOS, EOS, CFG_UNK - self.bos_id = num_tokens - 3 - self.eos_id = num_tokens - 2 - self.cfg_unk_token_id = num_tokens - 1 + base_num_tokens = len(self.tokenizer.tokens) + + # Assign standard special tokens sequentially + self.bos_id = base_num_tokens + self.eos_id = base_num_tokens + 1 + self.cfg_unk_token_id = base_num_tokens + 2 + special_tokens_added = 3 + + # Conditionally add the interruption token + if cfg.get("use_multiturn_dataset", False): + self.interruption_token_id = base_num_tokens + special_tokens_added + special_tokens_added += 1 + + # Calculate the final total vocabulary size + num_tokens = base_num_tokens + special_tokens_added + + self.pad_id = self.tokenizer.pad + self.phoneme_tokenizer = None if cfg.get('phoneme_tokenizer', None) is not None: self.phoneme_tokenizer = safe_instantiate(cfg.phoneme_tokenizer) @@ -480,6 +503,17 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): else: raise ValueError(f"Unknown decoder_type: {self.decoder_type}. Supported: 'huggingface', 'nemotron_h'") + self.activation_checkpointing = cfg.get("activation_checkpointing", False) + if self.activation_checkpointing: + logging.info("Enabling activation checkpointing for decoder") + + if self.decoder_type == "nemotron_h": + self.decoder.gradient_checkpointing = True + elif hasattr(self.decoder, "gradient_checkpointing_enable"): + self.decoder.gradient_checkpointing_enable() + elif hasattr(self.decoder, "gradient_checkpointing"): + self.decoder.gradient_checkpointing = True + if self.disable_lm_text_head and hasattr(self.decoder, 'lm_head'): self.decoder.lm_head = None @@ -515,6 +549,9 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): '': self.eos_id, '': self.cfg_unk_token_id, } + if cfg.get("use_multiturn_dataset", False): + special_vocab[""] = self.interruption_token_id + self.cas_encoder = CharAwareSubwordEncoder( d_embed=cfg.embedding_dim, llm_tokenizer_vocab=subword_vocab, @@ -587,6 +624,187 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): codebook_size=self.codebook_size, ) + @property + def codec_sil_codes(self): + return self._codec_sil_codes_buffer + + @property + def codec_sil_codes_unconverted(self): + return self._codec_sil_codes_buffer_unconverted + + def restore_from_pretrained_checkpoint(self, checkpoint_path): + """ + Loads model weights a pretrained checkpoint file, supporting partial loading from safetensor and PyTorch formats. + + Args: + checkpoint_path (str): Path to checkpoint file. + + Returns: + None. The model is updated in-place. + """ + if checkpoint_path is not None: + if '.nemo' in checkpoint_path: + with tempfile.TemporaryDirectory() as tmpdir: + SaveRestoreConnector._unpack_nemo_file(checkpoint_path, tmpdir) + checkpoint_path = f"{tmpdir}/model_weights.ckpt" + checkpoint_state = torch.load(checkpoint_path, map_location='cpu') + else: + checkpoint_state = torch.load(checkpoint_path, map_location='cpu') + checkpoint_state = set_model_dict_for_partial_init( + checkpoint_state, self.state_dict(), allow_partial_copy=True + ) + + self.load_state_dict(checkpoint_state, strict=True) + logging.info(f"Model restored from the checkpoint: {checkpoint_path} !") + + def _generate_codec_silence_buffer(self): + codec_device = next(self._codec_model.parameters()).device + + audio = torch.zeros(1, 5 * self.sample_rate, dtype=torch.float32, device=codec_device) + audio_len = torch.tensor([audio.size(-1)], dtype=torch.long, device=codec_device) + + with torch.no_grad(): + sil_codes_raw, sil_codes_lens = self._codec_helper.audio_to_codes(audio, audio_len) + + frames_raw = sil_codes_raw[0].transpose(0, 1) + combos_raw = [tuple(frame.tolist()) for frame in frames_raw] + most_common_raw, _ = Counter(combos_raw).most_common(1)[0] + sil_tensor_unconverted = torch.tensor(most_common_raw, device=codec_device, dtype=torch.long) + + if self._codec_converter is not None: + sil_codes_conv = self._codec_converter.convert_original_to_new( + audio_tokens=sil_codes_raw, audio_lens=sil_codes_lens + ).long() + frames_conv = sil_codes_conv[0].transpose(0, 1) + combos_conv = [tuple(frame.tolist()) for frame in frames_conv] + most_common_conv, _ = Counter(combos_conv).most_common(1)[0] + sil_tensor_converted = torch.tensor(most_common_conv, device=codec_device, dtype=torch.long) + else: + sil_tensor_converted = sil_tensor_unconverted.clone() + + if not hasattr(self, "_codec_sil_codes_buffer"): + self.register_buffer("_codec_sil_codes_buffer", sil_tensor_converted, persistent=False) + self.register_buffer("_codec_sil_codes_buffer_unconverted", sil_tensor_unconverted, persistent=False) + else: + self._codec_sil_codes_buffer.copy_(sil_tensor_converted) + self._codec_sil_codes_buffer_unconverted.copy_(sil_tensor_unconverted) + + def streaming_prefill_profile( + self, + state: StreamingState, + text_tokens: torch.Tensor, # (B, T) or (B,) + user_audio_channel_embedding: torch.Tensor = None, + use_inference_mode: bool = True, + # ToDo: implement audio direct support instead of use silence tokens + ) -> StreamingState: + grad_ctx = torch.inference_mode if use_inference_mode else torch.no_grad + with grad_ctx(): + if text_tokens.dim() == 1: + text_tokens = text_tokens[:, None] + + B, T = text_tokens.shape + device = state.config.device + text_tokens = text_tokens.to(device) + + # ----------------------- + # TEXT CHANNEL + # ----------------------- + text_emb = self.embed_text_tokens( + text_tokens, text_lens=None, is_multiturn=self.cfg.get("use_multiturn_dataset", False) + ) + if self.cfg.get("use_multiturn_dataset", False): + text_emb[text_tokens == self.pad_id] = 0.0 + + # ----------------------- + # AUDIO CHANNEL: previous-token input during profile + # ----------------------- + C = self.num_audio_codebooks + S = self.frame_stacking_factor + + sil_codes = self.codec_sil_codes.to(device=device, dtype=torch.long) # (C,) + + # Keep all_predictions as real silence so decoded waveform has silence during profile. + sil_codes_unstacked = sil_codes.view(1, C, 1).expand(B, C, T * S).contiguous() + + if self.cfg.get("use_user_speaking_token", False): + # Match training: during non-agent/user-speaking regions, the audio INPUT token + # is audio_user_speaking_id. + profile_audio_stacked = torch.full( + (B, C * S, T), + self.audio_user_speaking_id, + dtype=torch.long, + device=device, + ) + else: + profile_audio_stacked, _ = self.stack_codes( + sil_codes_unstacked, + torch.full((B,), T * S, dtype=torch.long, device=device), + bos_id=self.audio_bos_id, + eos_id=self.audio_eos_id, + stacking_factor=S, + num_codebooks=C, + ) # (B, C*S, T) + + audio_emb = self.embed_audio_tokens(profile_audio_stacked) + + # Match training channel sum: text + audio profile-token/silence input. + combined_emb = text_emb + audio_emb + + if self.cfg.get("condition_on_user_speech", False): + combined_emb = combined_emb + user_audio_channel_embedding + + # ----------------------- + # CFG handling + # ----------------------- + if state.config.use_cfg: + # Match regular streaming inference: + # conditional branch = text + audio + # unconditional branch = audio only + inputs_embeds = torch.cat([combined_emb, audio_emb], dim=0) + else: + inputs_embeds = combined_emb + + # ----------------------- + # KV CACHE EXTENSION + # ----------------------- + cache_position = torch.arange( + state.cache_seq_len, + state.cache_seq_len + T, + device=device, + ) + + out = self.forward( + inputs_embeds=inputs_embeds, + attention_mask=None, + use_cache=True, + past_key_values=state.past_key_values, + cache_position=cache_position, + ) + + state.past_key_values = out.past_key_values + state.cache_seq_len += T + state.last_hidden = out.last_hidden_state + + # Advance logical streams consumed by this profile prefill. + state.text_tokens_seen += T + state.audio_steps += T + + # Make the next normal streaming_step continue from silence, not AUDIO_BOS. + state.all_predictions.append( + sil_codes_unstacked + ) # keep silence so that in the target audio user will be silence + if self.cfg.get("use_user_speaking_token", False): + state.last_audio_codes = torch.full( + (B, C * S), + self.audio_user_speaking_id, + dtype=torch.long, + device=device, + ) + else: + state.last_audio_codes = profile_audio_stacked[:, :, -1].contiguous() + + return state + def _get_state_dict_keys_to_exclude(self) -> List[str]: return [ '_codec_model', @@ -672,6 +890,7 @@ def embed_text_tokens( text_tokens: torch.Tensor, text_lens: Optional[torch.Tensor] = None, disable_cas_embedding: bool = False, + is_multiturn: bool = False, ) -> torch.Tensor: """Embed text tokens using decoder embedding + optional CAS, or CAS-only when configured. @@ -680,13 +899,18 @@ def embed_text_tokens( text_lens: Optional valid token lengths for constructing the CAS mask. Defaults to the full sequence length. disable_cas_embedding: When True, skip adding CAS embeddings even if the model uses the BPE char tokenizer. This is needed for legacy models where context text was trained without CAS embeddings. + is_multiturn: When True creates the text_mask based on non text pad ids positions, so that it can support multiturn. """ if text_lens is None: text_lens = torch.full( (text_tokens.size(0),), text_tokens.size(1), dtype=torch.long, device=text_tokens.device ) - text_mask = get_mask_from_lengths(text_lens) + if is_multiturn: + text_mask = text_tokens != self.tokenizer.pad + else: + text_mask = get_mask_from_lengths(text_lens) + if self.disable_subword_embedding: if disable_cas_embedding: raise ValueError("Cannot disable CAS embedding when `disable_subword_embedding=True`.") @@ -952,7 +1176,6 @@ def prepare_context_tensors( ) context_audio_embedded = self.embed_audio_tokens(context_audio_codes) # (B, T', E) batch_size = context_audio_embedded.size(0) - if self.use_speaker_encoder: if ( self.training @@ -1291,6 +1514,11 @@ def streaming_init( ) gt_phoneme_embeddings = self.embed_phoneme_tokens(gt_phoneme_stacked) # (B, T', E) + if self.cfg.get("use_multiturn_dataset", False): + phoneme_pad_id = getattr(self.phoneme_tokenizer, "pad", -1) + phoneme_mask = gt_phoneme_stacked[:, 0, :] != phoneme_pad_id + gt_phoneme_embeddings = gt_phoneme_embeddings * phoneme_mask.unsqueeze(2) + # Process GT audio codes if provided (for teacher forcing) gt_audio_embeddings = None gt_audio_lens_state = None @@ -1326,6 +1554,7 @@ def streaming_init( full_context_lens=full_context_lens, context_position=torch.full((batch_size,), min_context_len, dtype=torch.long, device=device), text_tokens_seen=torch.zeros(batch_size, dtype=torch.long, device=device), + turn_text_tokens_seen=torch.zeros(batch_size, dtype=torch.long, device=device), phoneme_steps=torch.zeros(batch_size, dtype=torch.long, device=device), audio_steps=torch.zeros(batch_size, dtype=torch.long, device=device), phoneme_stream_ended=torch.zeros(batch_size, dtype=torch.bool, device=device), @@ -1352,7 +1581,10 @@ def streaming_step( state: StreamingState, text_tokens: Optional[torch.Tensor] = None, force_dropout_text: bool = False, + user_audio_channel_embedding: Optional[torch.Tensor] = None, + prefill_like_step: bool = False, use_inference_mode: bool = True, + prefill_like_is_last_step: bool = False, ) -> Tuple[StreamingState, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform one streaming inference step with batch support. @@ -1383,7 +1615,10 @@ def streaming_step( # Phase 1: Prepare input embedding and determine per-item phase masks next_input, needs_context, needs_phoneme, needs_audio = self._prepare_streaming_input( - state, text_tokens, force_dropout_text + state, + text_tokens, + force_dropout_text, + user_audio_channel_embedding=user_audio_channel_embedding, ) # Phase 2: Transformer forward pass @@ -1400,11 +1635,93 @@ def streaming_step( state.past_key_values = transformer_out.past_key_values state.cache_seq_len += 1 + if prefill_like_step: + # Advance logical streams, keep audio silent, but predict phonemes if enabled. + state.context_position += needs_context.long() + state.text_tokens_seen += (~needs_context).long() + + if hasattr(state, "turn_text_tokens_seen"): + state.turn_text_tokens_seen += (~needs_context).long() + + C = self.num_audio_codebooks + S = self.frame_stacking_factor + B = state.config.batch_size + + sil = self.codec_sil_codes.to(device=device, dtype=torch.long) + sil = sil.view(1, C, 1).expand(B, C, S).contiguous() + + # Keep decoded profile/warmup region silent. + state.all_predictions.append(sil) + + pred_phoneme_tokens = None + + if needs_phoneme.any() and self.phoneme_tokenizer is not None: + first_phoneme_step = needs_phoneme & (state.phoneme_prediction_start_idx == -1) + if first_phoneme_step.any(): + current_phoneme_step_idx = len(state.all_phoneme_predictions) + state.phoneme_prediction_start_idx = torch.where( + first_phoneme_step, + torch.full_like(state.phoneme_prediction_start_idx, current_phoneme_step_idx), + state.phoneme_prediction_start_idx, + ) + + pred_phoneme_tokens = self._predict_phoneme_tokens(state) + if state.last_phoneme_tokens is None: + state.last_phoneme_tokens = pred_phoneme_tokens + else: + update_mask = needs_phoneme.view(B, 1).expand_as(pred_phoneme_tokens) + state.last_phoneme_tokens = torch.where( + update_mask, + pred_phoneme_tokens, + state.last_phoneme_tokens, + ) + + state.all_phoneme_predictions.append(pred_phoneme_tokens) + + phoneme_eos_detected = needs_phoneme & ( + pred_phoneme_tokens == self.phoneme_tokenizer.eos_token_id + ).any(dim=1) + + state.phoneme_eos_detected = state.phoneme_eos_detected | phoneme_eos_detected + + newly_ended_phoneme = phoneme_eos_detected & (state.phoneme_prediction_end_idx == -1) + if newly_ended_phoneme.any(): + current_phoneme_step_idx = len(state.all_phoneme_predictions) + state.phoneme_prediction_end_idx = torch.where( + newly_ended_phoneme, + torch.full_like(state.phoneme_prediction_end_idx, current_phoneme_step_idx), + state.phoneme_prediction_end_idx, + ) + + state.phoneme_steps += needs_phoneme.long() + state.audio_steps += needs_audio.long() + + # Match training: the input slot that predicts the first agent frame after + # a user/non-agent region should receive the learned user-speaking-end token. + use_end_token = prefill_like_is_last_step and self.cfg.get("use_user_speaking_end_token", False) + + if use_end_token: + state.last_audio_codes = torch.full( + (B, C * S), + self.audio_user_speaking_end_id, + dtype=torch.long, + device=device, + ) + elif self.cfg.get("use_user_speaking_token", False): + state.last_audio_codes = torch.full( + (B, C * S), + self.audio_user_speaking_id, + dtype=torch.long, + device=device, + ) + else: + state.last_audio_codes = sil.reshape(B, C * S) + return state, None, pred_phoneme_tokens + # Phase 3: Update counters and extract predictions audio_codes_next, pred_phoneme_tokens = self._process_predictions( state, needs_context, needs_phoneme, needs_audio ) - return state, audio_codes_next, pred_phoneme_tokens def _prepare_streaming_input( @@ -1412,6 +1729,7 @@ def _prepare_streaming_input( state: StreamingState, text_tokens: Optional[torch.Tensor], force_dropout_text: bool, + user_audio_channel_embedding: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Build the input embedding for one streaming step. @@ -1432,10 +1750,13 @@ def _prepare_streaming_input( # Determine phases per batch item needs_context = state.context_position < state.full_context_lens # (B,) bool needs_text = (~needs_context) & (~state.text_finished) + + turn_text_tokens_seen = getattr(state, "turn_text_tokens_seen", state.text_tokens_seen) needs_phoneme = ( - (~needs_context) & (state.text_tokens_seen >= streaming_phonemes_delay) & (~state.phoneme_stream_ended) + (~needs_context) & (turn_text_tokens_seen >= streaming_phonemes_delay) & (~state.phoneme_stream_ended) ) - needs_audio = (~needs_context) & (state.text_tokens_seen >= streaming_speech_delay) & (~state.finished) + + needs_audio = (~needs_context) & (turn_text_tokens_seen >= streaming_speech_delay) & (~state.finished) next_input = torch.zeros(batch_size, 1, self.cfg.embedding_dim, device=device) @@ -1457,11 +1778,17 @@ def _prepare_streaming_input( text_embedded = self.embed_text_tokens( text_tokens_2d, text_lens=torch.ones(batch_size, dtype=torch.long, device=device), + is_multiturn=self.cfg.get("use_multiturn_dataset", False), ) # (B, 1, E) if force_dropout_text: text_embedded = text_embedded * 0 + # Zero out padding tokens exactly like in process_batch + if self.cfg.get("use_multiturn_dataset", False): + is_pad = text_tokens_2d == self.tokenizer.pad + text_embedded[is_pad] = 0.0 + is_eos_token = (text_tokens == self.eos_id) & needs_text # (B,) bool text_add_mask = needs_text.view(batch_size, 1, 1).float() next_input = next_input + text_embedded * text_add_mask @@ -1500,9 +1827,21 @@ def _prepare_streaming_input( phoneme_emb = phoneme_emb + phoneme_bos_emb * first_mask if has_last_phoneme.any() and state.last_phoneme_tokens is not None: - last_phoneme_emb = self.embed_phoneme_tokens( - state.last_phoneme_tokens.unsqueeze(2) - ) # (B, 1, E) + last_phoneme_tokens = state.last_phoneme_tokens # (B, S_ph) + + last_phoneme_emb = self.embed_phoneme_tokens(last_phoneme_tokens.unsqueeze(2)) # (B, 1, E) + + # Match training: PAD phoneme inputs contribute zero embedding. + if self.cfg.get("use_multiturn_dataset", False): + phoneme_pad_id = getattr(self.phoneme_tokenizer, "pad", None) + if phoneme_pad_id is not None: + # Same convention as training: check first stacked phoneme channel. + phoneme_is_pad = last_phoneme_tokens[:, 0] == phoneme_pad_id # (B,) + last_phoneme_emb = last_phoneme_emb * (~phoneme_is_pad).view(batch_size, 1, 1).float() + else: + raise ValueError( + "self.phoneme_tokenizer.pad is not defined, so it is not possible to zero-out the phoneme on the padding positon, please verify it!" + ) last_mask = has_last_phoneme.view(batch_size, 1, 1).float() phoneme_emb = phoneme_emb + last_phoneme_emb * last_mask @@ -1544,6 +1883,14 @@ def _prepare_streaming_input( next_input = next_input + audio_emb + if user_audio_channel_embedding is not None: + user_audio_channel_embedding = user_audio_channel_embedding.unsqueeze(1) + + next_input = next_input + user_audio_channel_embedding.to( + device=next_input.device, + dtype=next_input.dtype, + ) + # --- Handle CFG --- if state.config.use_cfg: next_input_unconditional_context = state.config.dummy_context_embedding_unconditional.expand( @@ -1588,6 +1935,9 @@ def _process_predictions( # Update counters state.context_position = state.context_position + needs_context.long() state.text_tokens_seen = state.text_tokens_seen + (~needs_context).long() + if hasattr(state, "turn_text_tokens_seen"): + state.turn_text_tokens_seen = state.turn_text_tokens_seen + (~needs_context).long() + state.phoneme_steps = state.phoneme_steps + needs_phoneme.long() state.audio_steps = state.audio_steps + needs_audio.long() @@ -2099,11 +2449,13 @@ def _load_audio_for_inference(audio_path: str, target_sample_rate: int) -> torch audio, sr = sf.read(audio_path, dtype='float32') if len(audio.shape) > 1: audio = audio.mean(axis=1) + + audio = torch.from_numpy(audio).unsqueeze(0) + if sr != target_sample_rate: - import librosa + audio = resample(audio.float(), sr, target_sample_rate) - audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sample_rate) - return torch.from_numpy(audio).unsqueeze(0) + return audio @staticmethod def _adjust_audio_to_duration_for_inference( diff --git a/nemo/collections/tts/modules/ffn_modules.py b/nemo/collections/tts/modules/ffn_modules.py index 853514dd1f58..7f477fcbcd19 100644 --- a/nemo/collections/tts/modules/ffn_modules.py +++ b/nemo/collections/tts/modules/ffn_modules.py @@ -89,6 +89,7 @@ def forward(self, signal, signal_mask=None): # signal: (B, C, T) # signal_mask: (B, T) or None (if None, assumes all positions are valid) if signal_mask is not None: + signal_mask = signal_mask.to(device=signal.device, dtype=signal.dtype) signal = signal * signal_mask.unsqueeze(1) if self.is_causal: # TODO: maybe replace with identify rather than keep conditional if in forward signal = F.pad(signal, self.causal_padding) diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py index 9d2540a694ff..5601366addfe 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py @@ -18,6 +18,7 @@ import json import os import pprint +import re import tempfile import time from collections import Counter @@ -52,12 +53,48 @@ ) +# Regexes mirrored from the IPA preprocessing script that creates +# custom["text_without_annotation"]. This is used only for text inputs +# during metric computation when requested. +_WS_RE = re.compile(r"\s+") +_SPACE_BEFORE_PUNCT_RE = re.compile(r"\s+([,.;:!?؟،؛])") +_TATWEEL_RE = re.compile("\u0640+") +_ANNOTATION_OR_MARKER_RE = re.compile( + r""" + \[[^\[\]\n]{1,512}\] # square annotation: [breath], [نقر] + | \n]{1,512}> # XML/style/language tags + | \{/?[^{}\n]{1,512}\} # curly control/pronunciation tags + | [-–—]{2,} # multi-dash cutoff: --, ---, —— + | (?<=\S)[-–—](?=\s|$) # trailing single dash after a token: word- + | (?:^|(?<=\s))[-–—](?=\s|$) # standalone dash + | \.{3,} # ASCII ellipsis + | …+ # Unicode ellipsis + | \*+ # emphasis marker: *word* + """, + re.VERBOSE, +) + + +def strip_text_annotations_from_text(text: str) -> str: + """Return orthographic text with annotation/control tokens removed.""" + text = _ANNOTATION_OR_MARKER_RE.sub(" ", str(text)) + text = _TATWEEL_RE.sub("", text) + text = _WS_RE.sub(" ", text).strip() + text = _SPACE_BEFORE_PUNCT_RE.sub(r"\1", text) + return text.strip() + + FILEWISE_METRICS_TO_SAVE = [ 'cer', 'wer', 'pred_context_ssim', + 'pred_gt_esim', + 'pred_gt_ems', 'pred_text', 'gt_text', + 'predicted_phoneme_text', + 'predicted_phoneme_tokens', + 'predicted_phoneme_token_labels', 'gt_audio_filepath', 'pred_audio_filepath', 'context_audio_filepath', @@ -265,7 +302,13 @@ def transcribed_batched( def load_evaluation_models( - language="en", sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", device="cuda" + language="en", + sv_model_type="titanet", + asr_model_name="stt_en_conformer_transducer_large", + device="cuda", + with_emotion_metrics=False, + emotion_model_size="small", + emotion_cache_dir=None, ): """Load ASR and speaker verification models used for evaluation. @@ -284,6 +327,7 @@ def load_evaluation_models( 'whisper_model': None, 'whisper_processor': None, 'feature_extractor': None, + 'emotion_model': None, } if language == "en": @@ -314,9 +358,43 @@ def load_evaluation_models( ) models['sv_model_alternate'] = models['sv_model_alternate'].to(device).eval() + if with_emotion_metrics: + logging.info("Loading emotion encoder for ESIM/EMS metrics...") + try: + from nemo.collections.tts.metrics.emotion_encoder import EmpathicInsightVoice + + models['emotion_model'] = EmpathicInsightVoice.from_pretrained( + size=emotion_model_size, + device=device, + mlp_device=device, + cache_dir=emotion_cache_dir, + cache_classifiers=True, + load_all_classifiers=False, + top_k_emotions=1, + ).eval() + except Exception as e: + logging.warning(f"Emotion encoder could not be loaded: {e}. ESIM/EMS metrics will be set to NaN.") + return models +def compute_emotion_pair_metrics(emotion_model, gt_audio_path, pred_audio_path, embedding_type="score_vector"): + """Compute ground-truth to predicted emotion similarity and top-emotion match.""" + if emotion_model is None or gt_audio_path is None or pred_audio_path is None: + return float('NaN'), float('NaN') + + try: + result = emotion_model.compare_emotion_pair( + audio_path_a=gt_audio_path, + audio_path_b=pred_audio_path, + embedding_type=embedding_type, + ) + return float(result["emotion_similarity"]), float(result["top_emotion_match"]) + except Exception as e: + logging.warning(f"Could not compute ESIM/EMS for {gt_audio_path} and {pred_audio_path}: {e}") + return float('NaN'), float('NaN') + + def classify_eou_batched( eou_classifier: EoUClassifier, items: list[tuple[Union[str, np.ndarray], str]], batch_size: int = 32 ) -> list[EoUClassification]: @@ -345,6 +423,11 @@ def evaluate_dir( sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", with_utmosv2=True, + strip_text_annotations_for_metrics=False, + with_emotion_metrics=False, + emotion_model_size="small", + emotion_embedding_type="score_vector", + emotion_cache_dir=None, asr_batch_size=32, eou_batch_size=32, device="cuda", @@ -377,7 +460,15 @@ def evaluate_dir( context_audio_paths = [_resolve_path(audio_dir, r.get('context_audio_filepath')) for r in records] # 2. Load models - models = load_evaluation_models(language, sv_model_type, asr_model_name, device) + models = load_evaluation_models( + language, + sv_model_type, + asr_model_name, + device, + with_emotion_metrics=with_emotion_metrics, + emotion_model_size=emotion_model_size, + emotion_cache_dir=emotion_cache_dir, + ) asr_model = models['asr_model'] whisper_model = models['whisper_model'] @@ -385,6 +476,7 @@ def evaluate_dir( feature_extractor = models['feature_extractor'] speaker_verification_model = models['sv_model'] speaker_verification_model_alternate = models['sv_model_alternate'] + emotion_model = models['emotion_model'] # 3. EoU classifier (support for English only) if language == "en": @@ -420,6 +512,8 @@ def evaluate_dir( asr_batch_size, label="predicted", ) + if strip_text_annotations_for_metrics: + pred_texts = [strip_text_annotations_from_text(text) for text in pred_texts] pred_texts = [text_processor.process_text_for_wer(text) for text in pred_texts] # Transcribe ground truth audios if len(gt_audio_paths) > 0: @@ -433,6 +527,8 @@ def evaluate_dir( asr_batch_size, label="ground truth", ) + if strip_text_annotations_for_metrics: + gt_audio_texts = [strip_text_annotations_from_text(text) for text in gt_audio_texts] gt_audio_texts = [text_processor.process_text_for_wer(text) for text in gt_audio_texts] else: gt_audio_texts = [None] * len(records) @@ -446,7 +542,10 @@ def evaluate_dir( text_field = 'normalized_text' else: text_field = 'text' - processed_text = text_processor.process_text_for_wer(record[text_field]) + text = record[text_field] + if strip_text_annotations_for_metrics: + text = strip_text_annotations_from_text(text) + processed_text = text_processor.process_text_for_wer(text) gt_texts_processed.append(processed_text) # 7. Batched EoU classification @@ -476,6 +575,16 @@ def evaluate_dir( detailed_cer = word_error_rate_detail(hypotheses=[pred_text], references=[gt_text], use_cer=True) detailed_wer = word_error_rate_detail(hypotheses=[pred_text], references=[gt_text], use_cer=False) + pred_gt_esim = float('NaN') + pred_gt_ems = float('NaN') + if with_emotion_metrics: + pred_gt_esim, pred_gt_ems = compute_emotion_pair_metrics( + emotion_model, + gt_audio_filepath, + pred_audio_filepath, + embedding_type=emotion_embedding_type, + ) + logging.info(f"{ridx} GT Text: {gt_text}") logging.info(f"{ridx} Pr Text: {pred_text}") # Format cer and wer to 2 decimal places @@ -494,7 +603,7 @@ def evaluate_dir( model=speaker_verification_model_alternate, extractor=feature_extractor, device=device, - sv_model_type=sv_model_type, + sv_model_type="titanet", # alternate is always titanet ) # Initialize SSIMs with a default since the context or ground truth audio @@ -561,32 +670,38 @@ def evaluate_dir( eou_trailing = float('nan') eou_rms_ratio = float('nan') - filewise_metrics.append( - { - 'gt_text': gt_text, - 'pred_text': pred_text, - 'gt_audio_text': gt_audio_text, - 'detailed_cer': detailed_cer, - 'detailed_wer': detailed_wer, - 'cer': detailed_cer[0], - 'wer': detailed_wer[0], - 'pred_gt_ssim': pred_gt_ssim, - 'pred_context_ssim': pred_context_ssim, - 'gt_context_ssim': gt_context_ssim, - 'pred_gt_ssim_alternate': pred_gt_ssim_alternate, - 'pred_context_ssim_alternate': pred_context_ssim_alternate, - 'gt_context_ssim_alternate': gt_context_ssim_alternate, - 'gt_audio_filepath': gt_audio_filepath, - 'pred_audio_filepath': pred_audio_filepath, - 'context_audio_filepath': context_audio_filepath, - 'utmosv2': utmosv2_score, - 'eou_type': eou_type, - 'eou_trailing_duration': eou_trailing, - 'eou_trail_rms_ratio': eou_rms_ratio, - 'total_gen_audio_seconds': file_duration, - 'predicted_codes_path': codes_file_lists[ridx] if has_codes else None, - } - ) + metric_row = { + 'gt_text': gt_text, + 'pred_text': pred_text, + 'gt_audio_text': gt_audio_text, + 'predicted_phoneme_text': record.get('predicted_phoneme_text', ''), + 'predicted_phoneme_tokens': record.get('predicted_phoneme_tokens', []), + 'predicted_phoneme_token_labels': record.get('predicted_phoneme_token_labels', []), + 'detailed_cer': detailed_cer, + 'detailed_wer': detailed_wer, + 'cer': detailed_cer[0], + 'wer': detailed_wer[0], + 'pred_gt_ssim': pred_gt_ssim, + 'pred_context_ssim': pred_context_ssim, + 'gt_context_ssim': gt_context_ssim, + 'pred_gt_ssim_alternate': pred_gt_ssim_alternate, + 'pred_context_ssim_alternate': pred_context_ssim_alternate, + 'gt_context_ssim_alternate': gt_context_ssim_alternate, + 'gt_audio_filepath': gt_audio_filepath, + 'pred_audio_filepath': pred_audio_filepath, + 'context_audio_filepath': context_audio_filepath, + 'utmosv2': utmosv2_score, + 'eou_type': eou_type, + 'eou_trailing_duration': eou_trailing, + 'eou_trail_rms_ratio': eou_rms_ratio, + 'total_gen_audio_seconds': file_duration, + 'predicted_codes_path': codes_file_lists[ridx] if has_codes else None, + } + if with_emotion_metrics: + metric_row['pred_gt_esim'] = pred_gt_esim + metric_row['pred_gt_ems'] = pred_gt_ems + + filewise_metrics.append(metric_row) return filewise_metrics @@ -599,8 +714,13 @@ def evaluate( sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", with_utmosv2=True, + strip_text_annotations_for_metrics=False, with_fcd=True, codec_model_path=None, + with_emotion_metrics=False, + emotion_model_size="small", + emotion_embedding_type="head_concat", + emotion_cache_dir=None, asr_batch_size=32, eou_batch_size=32, device="cuda", @@ -636,6 +756,11 @@ def evaluate( sv_model_type=sv_model_type, asr_model_name=asr_model_name, with_utmosv2=with_utmosv2, + strip_text_annotations_for_metrics=strip_text_annotations_for_metrics, + with_emotion_metrics=with_emotion_metrics, + emotion_model_size=emotion_model_size, + emotion_embedding_type=emotion_embedding_type, + emotion_cache_dir=emotion_cache_dir, asr_batch_size=asr_batch_size, eou_batch_size=eou_batch_size, device=device, @@ -725,6 +850,9 @@ def compute_global_metrics( sum(m['pred_context_ssim_alternate'] for m in filewise_metrics) / n ) avg_metrics['ssim_gt_context_avg_alternate'] = sum(m['gt_context_ssim_alternate'] for m in filewise_metrics) / n + if 'pred_gt_esim' in filewise_metrics[0]: + avg_metrics['esim_pred_gt_avg'] = sum(m['pred_gt_esim'] for m in filewise_metrics) / n + avg_metrics['ems_pred_gt_avg'] = sum(m['pred_gt_ems'] for m in filewise_metrics) / n # Cumulative WER/CER on ground-truth audio transcriptions (if available) gt_audio_texts = [m['gt_audio_text'] for m in filewise_metrics] @@ -778,6 +906,20 @@ def main(): parser.add_argument('--generated_audio_dir', type=str, default=None) parser.add_argument('--whisper_language', type=str, default="en") parser.add_argument('--evalset', type=str, default=None) + parser.add_argument('--with_emotion_metrics', action='store_true') + parser.add_argument( + '--strip_text_annotations_for_metrics', + action='store_true', + help='Strip bracket/tag/control annotations from reference and ASR hypothesis text while computing text metrics.', + ) + parser.add_argument('--emotion_model_size', type=str, default="small", choices=["small", "large"]) + parser.add_argument( + '--emotion_embedding_type', + type=str, + default="score_vector", + choices=["head_concat", "head_mean", "score_vector"], + ) + parser.add_argument('--emotion_cache_dir', type=str, default=None) args = parser.parse_args() if args.evalset is not None: @@ -793,6 +935,11 @@ def main(): args.whisper_language, sv_model_type="wavlm", asr_model_name="nvidia/parakeet-ctc-0.6b", + with_emotion_metrics=args.with_emotion_metrics, + strip_text_annotations_for_metrics=args.strip_text_annotations_for_metrics, + emotion_model_size=args.emotion_model_size, + emotion_embedding_type=args.emotion_embedding_type, + emotion_cache_dir=args.emotion_cache_dir, ) diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluation.py b/nemo/collections/tts/modules/magpietts_inference/evaluation.py index 00b25614a9ab..a92659209f31 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluation.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluation.py @@ -43,6 +43,11 @@ class EvaluationConfig: with_utmosv2: Whether to compute UTMOSv2 (Mean Opinion Score) metrics. with_fcd: Whether to compute Frechet Codec Distance metric. codec_model_path: Path to the audio codec model. If None, will skip computing Frechet Codec Distance metric. + with_emotion_metrics: Whether to compute ESIM and EMS using the emotion encoder. + emotion_model_size: Emotion encoder size ("small" or "large"). + emotion_embedding_type: Emotion embedding type used for ESIM. + emotion_cache_dir: Optional Hugging Face cache directory for the emotion encoder. + strip_text_annotations_for_metrics: Whether to strip annotation/control markers from reference and ASR hypothesis text before text metrics. device: Device to use for running models used during evaluation. """ @@ -53,7 +58,14 @@ class EvaluationConfig: with_utmosv2: bool = True with_fcd: bool = True codec_model_path: str = None + with_emotion_metrics: bool = False + emotion_model_size: str = "small" + emotion_embedding_type: str = "score_vector" + emotion_cache_dir: str = None + strip_text_annotations_for_metrics: bool = False device: str = "cuda" + asr_batch_size: int = 32 + eou_batch_size: int = 32 def evaluate_generated_audio_dir( @@ -93,8 +105,15 @@ def evaluate_generated_audio_dir( with_utmosv2=config.with_utmosv2, with_fcd=config.with_fcd, codec_model_path=config.codec_model_path, + with_emotion_metrics=config.with_emotion_metrics, + emotion_model_size=config.emotion_model_size, + emotion_embedding_type=config.emotion_embedding_type, + emotion_cache_dir=config.emotion_cache_dir, + strip_text_annotations_for_metrics=config.strip_text_annotations_for_metrics, device=config.device, eou_model_name=config.eou_model_name, + asr_batch_size=config.asr_batch_size, + eou_batch_size=config.eou_batch_size, ) return avg_metrics, filewise_metrics diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index 287149627988..7bd5f2ac8c63 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -28,6 +28,7 @@ import abc import glob +import json import os import shutil import time @@ -38,11 +39,12 @@ import torch from nemo.collections.asr.parts.utils.manifest_utils import read_manifest +from nemo.collections.audio.parts.utils.transforms import resample from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer, IPATokenizer from nemo.collections.tts.data.text_to_speech_dataset import ChunkedTTSInferenceDataset, MagpieTTSDataset from nemo.collections.tts.models.easy_magpietts_inference import EasyModelInferenceParameters from nemo.collections.tts.models.magpietts import ModelInferenceParameters -from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors +from nemo.collections.tts.parts.utils.tts_dataset_utils import normalize_volume, stack_tensors from nemo.utils import logging @@ -680,6 +682,201 @@ def _compute_end_of_text_flags( return is_end_of_text +@dataclass +class EasyMagpieMultiturnUserAudioInferenceConfig(EasyMagpieInferenceConfig): + """Configuration for EasyMagpie multiturn user-audio inference. + + This mode keeps the standard EasyMagpie/MagpieTTS evaluation contract by + writing one evaluation row per generated agent turn: + + predicted_audio_.wav + predicted_codes_.pt + target_audio_.wav + context_audio_.wav + + The generated turn-level manifest is cached on the runner as + ``evaluation_manifest_path`` so the top-level inference script can pass it + to ``evaluate_generated_audio_dir``. + """ + + max_eval_turns: int = 6 + save_debug_multiturn_audio: bool = True + + def build_identifier(self) -> str: + return super().build_identifier() + f"_MTUserAudio_MaxTurns{self.max_eval_turns}" + + +class EasyMagpieMultiturnUserAudioDataset(torch.utils.data.Dataset): + """Manifest dataset for turn-level multiturn user-audio EasyMagpie inference.""" + + def __init__( + self, + manifest_path: str, + audio_dir: str, + model, + max_eval_turns: int = 6, + normalize_audio: bool = True, + ): + self.manifest_path = manifest_path + self.audio_dir = audio_dir or "" + self.model = model + self.max_eval_turns = max_eval_turns + self.normalize_audio = normalize_audio + self.records = read_manifest(manifest_path) + + def __len__(self): + return len(self.records) + + def __getitem__(self, idx: int): + item = dict(self.records[idx]) + item["idx"] = idx + return item + + def _resolve_path(self, path: Optional[str]) -> Optional[str]: + if path is None or path == "": + return None + if os.path.isabs(path): + return path + return os.path.join(self.audio_dir, path) + + def _load_audio_1d(self, path: str, sample_rate: int) -> torch.Tensor: + path = self._resolve_path(path) + if path is None or not os.path.exists(path): + raise FileNotFoundError(f"Missing audio path: {path}") + + audio, sr = sf.read(path, dtype="float32", always_2d=False) + + if audio.ndim == 2: + audio = audio.mean(axis=1) + + if self.normalize_audio: + audio = normalize_volume(audio) + + wav = torch.as_tensor(audio, dtype=torch.float32).flatten() + + if sr != sample_rate: + wav = resample(wav.unsqueeze(0), sr, sample_rate).squeeze(0) + + return wav.contiguous() + + @staticmethod + def _as_turn_list(value) -> List[str]: + if isinstance(value, list): + return [str(x) for x in value] + return [str(value)] + + @staticmethod + def _has_valid_turn_text(text: str) -> bool: + return any(ch.isalnum() for ch in str(text or "")) + + def collate_fn(self, batch: List[dict]) -> Dict[str, Any]: + if len(batch) != 1: + raise RuntimeError("multiturn_user_audio inference currently requires batch_size=1.") + + sample = batch[0] + model = self.model + sample_rate = model.sample_rate + main_tokenizer_name = list(model.cfg.text_tokenizers.keys())[0] + + raw_turn_texts = self._as_turn_list(sample["text"])[: self.max_eval_turns] + max_turns = len(raw_turn_texts) + + user_audio_paths = sample.get("user_audio_file_path", None) + if not isinstance(user_audio_paths, list): + user_audio_paths = [] + + # make the user audios path absolute if needed + user_audio_paths = [self._resolve_path(path) for path in user_audio_paths] + + raw_user_audio_turns = [] + for turn_id in range(max_turns): + if turn_id < len(user_audio_paths) and user_audio_paths[turn_id]: + wav = self._load_audio_1d(user_audio_paths[turn_id], sample_rate) + else: + wav = torch.zeros(int(2 * sample_rate), dtype=torch.float32) + raw_user_audio_turns.append(wav) + + batched_turns = [] + batched_turn_lens = [] + valid_turn_masks = [] + user_audio_turns = [] + user_audio_turns_lens = [] + turn_ids = [] + pending_user_audio_turns = [] + skipped_turns = 0 + for turn_id, turn_text in enumerate(raw_turn_texts): + pending_user_audio_turns.append(raw_user_audio_turns[turn_id]) + if not self._has_valid_turn_text(turn_text): + skipped_turns += 1 + continue + + ids = model.tokenizer.encode(turn_text, tokenizer_name=main_tokenizer_name) + [model.eos_id] + batched_turns.append(torch.tensor([ids], dtype=torch.long)) + batched_turn_lens.append(torch.tensor([len(ids)], dtype=torch.long)) + valid_turn_masks.append(torch.tensor([True], dtype=torch.bool)) + wav = ( + torch.cat(pending_user_audio_turns, dim=0) + if pending_user_audio_turns + else raw_user_audio_turns[turn_id] + ) + user_audio_turns.append(wav.unsqueeze(0)) + user_audio_turns_lens.append(torch.tensor([wav.numel()], dtype=torch.long)) + turn_ids.append(turn_id) + pending_user_audio_turns = [] + + if skipped_turns > 0: + logging.info( + f"Skipping {skipped_turns} empty or punctuation-only multiturn agent turns " + f"for sample_idx={sample.get('idx')}" + ) + + context_path = self._resolve_path(sample.get("context_audio_filepath")) + context_audio = self._load_audio_1d(context_path, sample_rate).unsqueeze(0) + context_audio_lens = torch.tensor([context_audio.size(1)], dtype=torch.long) + + target_turn_audio_paths = sample.get("target_audio_file_path", sample.get("target_audio_filepath", None)) + if target_turn_audio_paths is None: + target_turn_audio_paths = [] + elif not isinstance(target_turn_audio_paths, list): + target_turn_audio_paths = [target_turn_audio_paths] + + # make the target audio path absolute if needed + target_turn_audio_paths = [self._resolve_path(path) for path in target_turn_audio_paths] + target_audio_path = self._resolve_path(sample.get("audio_filepath")) + + return { + "idx": torch.tensor([int(sample["idx"])], dtype=torch.long), + "raw_record": sample, + "raw_turn_texts": [raw_turn_texts], + "batched_turns": batched_turns, + "batched_turn_lens": batched_turn_lens, + "valid_turn_masks": valid_turn_masks, + "turn_ids": turn_ids, + "context_audio": context_audio, + "context_audio_lengths": context_audio_lens, + "user_audio_turns": user_audio_turns, + "user_audio_turns_lens": user_audio_turns_lens, + "target_audio_path": target_audio_path, + "target_turn_audio_paths": target_turn_audio_paths, + "languages": [sample.get("language", "en")], + } + + +class _InferenceSubset(torch.utils.data.Dataset): + """Subset wrapper that preserves the wrapped dataset collate_fn.""" + + def __init__(self, dataset, indices: List[int]): + self.dataset = dataset + self.indices = list(indices) + self.collate_fn = getattr(dataset, "collate_fn", None) + + def __len__(self): + return len(self.indices) + + def __getitem__(self, idx: int): + return self.dataset[self.indices[idx]] + + class EasyMagpieInferenceRunner(BaseInferenceRunner): """Runner for decoder-only EasyMagpieTTSInferenceModel. @@ -704,8 +901,8 @@ def create_dataset( dataset = MagpieTTSDataset( dataset_meta=dataset_meta, sample_rate=self.model.sample_rate, - min_duration=0.5, - max_duration=20, + min_duration=None, + max_duration=None, codec_model_samples_per_frame=self.model.codec_model_samples_per_frame, bos_id=getattr(self.model, "bos_id", None), eos_id=self.model.eos_id, @@ -823,3 +1020,1176 @@ def _run_decoder_only_inference( item_idx += 1 return all_rtf_metrics, generated_audio_paths, codec_file_paths + + +class EasyMagpieMultiturnUserAudioInferenceRunner(BaseInferenceRunner): + """Runner for decoder-only EasyMagpieTTS multiturn user-audio inference. + + It generates one agent turn at a time using user-audio prefill, but writes + outputs using the standard EasyMagpie evaluation contract. Therefore the + existing magpietts_inference.py evaluation code can call + evaluate_generated_audio_dir() unchanged. + """ + + produces_turn_level_evaluation: bool = True + + def __init__(self, model, config: EasyMagpieMultiturnUserAudioInferenceConfig): + if config.batch_size != 1: + raise ValueError("EasyMagpie multiturn user-audio inference requires batch_size=1.") + super().__init__(model, config) + self.evaluation_manifest_path: Optional[str] = None + self.evaluation_audio_dir: Optional[str] = None + self.evaluation_manifest_records: Optional[List[dict]] = None + + # Used by examples/tts/magpietts_inference.py for torchrun sharding. + self.distributed_rank: int = int(os.environ.get("RANK", "0")) + self.distributed_world_size: int = int(os.environ.get("WORLD_SIZE", "1")) + + def create_dataset( + self, + dataset_meta: dict, + context_duration_min: Optional[float] = None, + context_duration_max: Optional[float] = None, + ) -> EasyMagpieMultiturnUserAudioDataset: + manifest_path, audio_dir = self._read_and_cache_manifest(dataset_meta) + logging.info("Creating multiturn user-audio inference dataset for decoder-only model") + return EasyMagpieMultiturnUserAudioDataset( + manifest_path=manifest_path, + audio_dir=audio_dir, + model=self.model, + max_eval_turns=self.config.max_eval_turns, + normalize_audio=True, + ) + + def set_distributed_context(self, rank: int, world_size: int) -> None: + self.distributed_rank = int(rank) + self.distributed_world_size = int(world_size) + + def run_inference_on_dataset( + self, + dataset: EasyMagpieMultiturnUserAudioDataset, + output_dir: str, + manifest_records: Optional[List[dict]] = None, + audio_base_dir: Optional[str] = None, + save_cross_attention_maps: bool = True, + save_context_audio: bool = True, + save_predicted_codes: bool = True, + ) -> Tuple[List[dict], List[str], List[str]]: + manifest_records, audio_base_dir = self._resolve_manifest_and_audio_dir(manifest_records, audio_base_dir) + return self._run_multiturn_user_audio_inference( + dataset=dataset, + output_dir=output_dir, + manifest_records=manifest_records, + audio_base_dir=audio_base_dir, + save_context_audio=save_context_audio, + save_predicted_codes=save_predicted_codes, + ) + + @staticmethod + def _move_batch_to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]: + out = {} + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + out[key] = value.to(device) + elif isinstance(value, list): + out[key] = [v.to(device) if isinstance(v, torch.Tensor) else v for v in value] + else: + out[key] = value + return out + + @staticmethod + def _copy_file(src: Optional[str], dst: str, required: bool = False, description: str = "audio") -> Optional[str]: + """Copy an audio artifact and optionally fail fast if missing. + + Evaluation later expects target_audio_*.wav/context_audio_*.wav to exist. + Silently skipping those files makes evaluate_generated_audio_dir fail much + later with a less useful FileNotFoundError, so target/context paths should + call this with required=True. + """ + if src is None or src == "": + if required: + raise FileNotFoundError(f"Missing required {description}: source path is empty for destination {dst}") + return None + if not os.path.exists(src): + if required: + raise FileNotFoundError( + f"Missing required {description}: source path does not exist: {src}; destination would be: {dst}" + ) + return None + + os.makedirs(os.path.dirname(dst), exist_ok=True) + if os.path.lexists(dst): + os.remove(dst) + shutil.copy(src, dst) + + if required and not os.path.exists(dst): + raise FileNotFoundError(f"Failed to materialize required {description}: src={src}, dst={dst}.") + return dst + + def _resolve_audio_path(self, path: Optional[str], audio_base_dir: str) -> Optional[str]: + if path is None or path == "": + return None + if os.path.isabs(path): + return path + return os.path.join(audio_base_dir, path) + + def _ensure_codec_silence_codes(self) -> torch.Tensor: + """Ensure silence codec codes exist before streaming_prefill_profile. + + Newer EasyMagpieTTSInferenceModel exposes codec_sil_codes as a @property + backed by _codec_sil_codes_buffer. Some older branches/checkpoints do not + have that property, but still have _generate_codec_silence_buffer(). This + helper supports both cases and creates a plain module attribute fallback + when the property is absent. + """ + if not hasattr(self.model, "_codec_sil_codes_buffer"): + if not hasattr(self.model, "_generate_codec_silence_buffer"): + raise AttributeError( + "Model does not have _codec_sil_codes_buffer or _generate_codec_silence_buffer(); " + "cannot run multiturn_user_audio prefill." + ) + self.model._generate_codec_silence_buffer() + + class_codec_sil_codes = getattr(type(self.model), "codec_sil_codes", None) + if class_codec_sil_codes is None: + # Compatibility with branches where streaming_prefill_profile expects + # self.codec_sil_codes but the @property was not added. + self.model.codec_sil_codes = self.model._codec_sil_codes_buffer + if hasattr(self.model, "_codec_sil_codes_buffer_unconverted"): + self.model.codec_sil_codes_unconverted = self.model._codec_sil_codes_buffer_unconverted + + return self.model._codec_sil_codes_buffer.to(self.model.device).long() + + @staticmethod + def _left_pad_raw_audio_if_short( + user_audio: torch.Tensor, + user_audio_lens: torch.Tensor, + min_len: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Left-pad one raw user-audio turn with silence only when it is shorter than min_len.""" + min_len = int(min_len) + if min_len <= 0: + return user_audio, user_audio_lens + if user_audio_lens.numel() != 1: + raise RuntimeError("multiturn_user_audio raw-audio left padding expects batch_size=1.") + + # In this runner each user turn is kept as its own [1, T] tensor, so T + # is the valid raw-audio length. Checking size(-1) avoids a CUDA sync + # from user_audio_lens[0].item() in the common no-padding path. + current_len = int(user_audio.size(-1)) + if current_len >= min_len: + return user_audio, user_audio_lens + + padded_audio = user_audio.new_zeros(user_audio.shape[:-1] + (min_len,)) + if current_len > 0: + padded_audio[..., -current_len:] = user_audio + + padded_lens = user_audio_lens.new_full(user_audio_lens.shape, min_len) + return padded_audio, padded_lens + + def _run_multiturn_generation(self, batch: Dict[str, Any]): + model = self.model + device = model.device + B = int(batch["context_audio"].size(0)) + if B != 1: + raise RuntimeError("multiturn_user_audio generation requires batch_size=1.") + + with torch.inference_mode(): + # streaming_prefill_profile reads self.model.codec_sil_codes, so make + # sure the silence buffer/property exists before entering the turn loop. + self._ensure_codec_silence_codes() + + wav = batch["context_audio"] + wav_len = batch["context_audio_lengths"] + codes, codes_lens = model._codec_helper.audio_to_codes(wav, wav_len) + + # add language on context if needed + use_lang = bool(getattr(model, "add_language_to_context_text", False)) + + languages = batch.get("languages", None) + if languages is None: + raise RuntimeError("Missing batch['languages']; collate_fn must provide one language per sample.") + + if not isinstance(languages, (list, tuple)): + raise RuntimeError(f"Expected batch['languages'] to be a list/tuple, got {type(languages)}") + + if len(languages) != B: + raise RuntimeError(f"Expected {B} language entries from collate_fn, got {len(languages)}") + + ctx_texts = [] + for b, language in enumerate(languages): + if language is None or language == "": + raise RuntimeError(f"Missing language for batch item {b}") + ctx_texts.append(f"[{str(language).upper()}]" if use_lang else "[NO TEXT CONTEXT]") + + ctx_text_ids_list = [ + model.tokenizer.encode(ctx_text, tokenizer_name=model.text_conditioning_tokenizer_name) + for ctx_text in ctx_texts + ] + + ctx_toks_lens = torch.tensor( + [len(ctx_text_ids) for ctx_text_ids in ctx_text_ids_list], + dtype=torch.long, + device=device, + ) + + max_ctx_len = int(ctx_toks_lens.max().item()) + ctx_pad_id = int(getattr(model, "pad_id", 0)) + ctx_toks = torch.full((B, max_ctx_len), ctx_pad_id, dtype=torch.long, device=device) + + for b, ctx_text_ids in enumerate(ctx_text_ids_list): + ctx_toks[b, : len(ctx_text_ids)] = torch.tensor(ctx_text_ids, dtype=torch.long, device=device) + + params = self.config.model_inference_parameters + state = model.streaming_init( + context_audio_codes=codes, + context_audio_codes_lens=codes_lens, + context_text_tokens=ctx_toks, + context_text_tokens_lens=ctx_toks_lens, + use_cfg=self.config.use_cfg, + cfg_scale=params.cfg_scale, + use_local_transformer=self.config.use_local_transformer, + temperature=params.temperature, + topk=params.topk, + phoneme_input_type="pred", + phoneme_sampling_method=self.config.phoneme_sampling_method, + use_inference_mode=True, + ) + + turn_frame_ranges = [] + turn_phoneme_outputs = [] + decode_start_frame = 0 + max_decoder_steps = params.max_decoder_steps + + turn_ids = batch.get("turn_ids", list(range(len(batch["batched_turns"])))) + for local_turn_idx in range(len(batch["batched_turns"])): + turn_id = int(turn_ids[local_turn_idx]) + turn_text = batch["batched_turns"][local_turn_idx].to(device) + turn_lens = batch["batched_turn_lens"][local_turn_idx].to(device) + valid_mask = batch["valid_turn_masks"][local_turn_idx].to(device) + if not bool(valid_mask[0].item()): + continue + + phoneme_start_step = len(getattr(state, "all_phoneme_predictions", []) or []) + + state.finished.zero_() + state.text_finished.zero_() + state.audio_prediction_end_idx.fill_(-1) + for attr in [ + "turn_text_tokens_seen", + "phoneme_steps", + "phoneme_stream_ended", + "phoneme_eos_detected", + ]: + if hasattr(state, attr): + getattr(state, attr).zero_() + state.last_phoneme_tokens = None + + delay_tokens = int(state.config.training_mode.streaming_speech_delay) + samples_per_streaming_step = int(model.codec_model_samples_per_frame * model.frame_stacking_factor) + + if not model.cfg.get("condition_on_user_speech", False): + user_audio = batch["user_audio_turns"][local_turn_idx] + user_audio_prefill_steps = int(round(user_audio.size(-1) / samples_per_streaming_step)) + user_audio_prefill_tokens = torch.full( + (1, user_audio_prefill_steps), model.pad_id, dtype=torch.long, device=device + ) + user_audio_channel_embedding = None + else: + user_audio = batch["user_audio_turns"][local_turn_idx] + user_audio_lens = batch["user_audio_turns_lens"][local_turn_idx] + min_user_audio_len = delay_tokens * samples_per_streaming_step + user_audio, user_audio_lens = self._left_pad_raw_audio_if_short( + user_audio=user_audio, + user_audio_lens=user_audio_lens, + min_len=min_user_audio_len, + ) + + user_audio_codes, user_audio_codes_lens = model._codec_helper.audio_to_codes( + user_audio, user_audio_lens + ) + + if model._codec_converter is not None: + user_audio_codes = model._codec_converter.convert_original_to_new( + audio_tokens=user_audio_codes, + audio_lens=user_audio_codes_lens, + ).long() + + user_audio_codes, user_audio_codes_lens = model.stack_codes( + user_audio_codes, + user_audio_codes_lens, + model.audio_bos_id, + model.audio_eos_id, + model.frame_stacking_factor, + model.num_audio_codebooks, + ) + + user_audio_embedded = model.embed_audio_tokens(user_audio_codes) + boundary_trim = model.cfg.get("user_audio_boundary_trim", 4) + boundary_trim = 0 if boundary_trim is None else int(boundary_trim) + + if boundary_trim == 0: + real_start = 0 + real_end = int(user_audio_codes_lens[0].item()) + else: + real_start = 1 + real_end = max(real_start, int(user_audio_codes_lens[0].item()) - 1) + + user_audio_embedded = user_audio_embedded[:, real_start:real_end] + copy_len = user_audio_embedded.size(1) + if boundary_trim > 0: + trim = min(boundary_trim, copy_len // 2) + if trim > 0: + user_audio_embedded[:, :trim] = 0.0 + user_audio_embedded[:, copy_len - trim :] = 0.0 + + bos_user_pad = torch.zeros( + user_audio_embedded.size(0), + 1, + user_audio_embedded.size(2), + device=user_audio_embedded.device, + dtype=user_audio_embedded.dtype, + ) + user_audio_embedded = torch.cat([bos_user_pad, user_audio_embedded], dim=1) + user_audio_prefill_steps = user_audio_embedded.size(1) + user_audio_prefill_tokens = torch.full( + (B, user_audio_prefill_steps), model.pad_id, dtype=torch.long, device=device + ) + user_audio_channel_embedding = user_audio_embedded + + if user_audio_channel_embedding is None: + delay_tokens = min(delay_tokens, user_audio_prefill_steps) + elif user_audio_prefill_steps < delay_tokens: + raise RuntimeError( + "Raw user-audio left padding did not produce enough warmup steps: " + f"user_audio_prefill_steps={user_audio_prefill_steps}, delay_tokens={delay_tokens}." + ) + + num_warmup_text_tokens = min(delay_tokens, int(turn_lens[0].item()), turn_text.size(1)) + # handle short turns so that it does not advance the text channel and keep the delay_tokens. + warmup_tokens = torch.full( + (B, delay_tokens), + model.pad_id, + dtype=turn_text.dtype, + device=device, + ) + + if num_warmup_text_tokens > 0: + warmup_tokens[:, :num_warmup_text_tokens] = turn_text[:, :num_warmup_text_tokens] + + turn_text = turn_text[:, num_warmup_text_tokens:] + turn_lens = torch.clamp(turn_lens - num_warmup_text_tokens, min=0) + + if user_audio_channel_embedding is not None and delay_tokens > 0: + warmup_user_audio = user_audio_channel_embedding[:, -delay_tokens:] + user_audio_channel_embedding = user_audio_channel_embedding[:, :-delay_tokens] + user_audio_prefill_tokens = user_audio_prefill_tokens[:, :-delay_tokens] + else: + warmup_user_audio = None + + if user_audio_prefill_tokens.size(1) > 0: + state = model.streaming_prefill_profile( + state=state, + text_tokens=user_audio_prefill_tokens, + use_inference_mode=True, + user_audio_channel_embedding=user_audio_channel_embedding, + ) + + for i in range(delay_tokens): + user_step_emb = warmup_user_audio[:, i] if warmup_user_audio is not None else None + state.finished.zero_() + state, _, _ = model.streaming_step( + state=state, + text_tokens=warmup_tokens[:, i], + user_audio_channel_embedding=user_step_emb, + prefill_like_step=True, + prefill_like_is_last_step=(i == delay_tokens - 1), + use_inference_mode=True, + ) + + turn_start_frame = sum(p.size(-1) for p in state.all_predictions) + if not turn_frame_ranges: + state.audio_prediction_start_idx.fill_(turn_start_frame) + decode_start_frame = turn_start_frame + + turn_offset = state.text_tokens_seen.clone() + steps = 0 + while steps < max_decoder_steps: + steps += 1 + state.finished.zero_() + relative_position = state.text_tokens_seen - turn_offset + text_exhausted = relative_position >= turn_lens + + if turn_text.size(1) == 0: + current_tokens = torch.full((B,), model.eos_id, dtype=torch.long, device=device) + else: + position = relative_position.clamp(min=0, max=turn_text.size(1) - 1) + current_tokens = turn_text[torch.arange(B, device=device), position] + current_tokens = torch.where( + text_exhausted, + torch.full_like(current_tokens, model.eos_id), + current_tokens, + ) + + state, _, _ = model.streaming_step( + state=state, + text_tokens=current_tokens, + use_inference_mode=True, + ) + + if bool(text_exhausted[0].item()) and bool(state.finished[0].item()): + break + + state.audio_prediction_end_idx.fill_(-1) + state.finished.zero_() + turn_end_frame = sum(p.size(-1) for p in state.all_predictions) + phoneme_end_step = len(getattr(state, "all_phoneme_predictions", []) or []) + ( + predicted_phoneme_text, + predicted_phoneme_tokens, + predicted_phoneme_token_labels, + ) = self._decode_phoneme_prediction_slice( + getattr(state, "all_phoneme_predictions", []), + phoneme_start_step, + phoneme_end_step, + ) + turn_frame_ranges.append((turn_id, turn_start_frame, turn_end_frame)) + turn_phoneme_outputs.append( + { + "turn_id": turn_id, + "predicted_phoneme_text": predicted_phoneme_text, + "predicted_phoneme_tokens": predicted_phoneme_tokens, + "predicted_phoneme_token_labels": predicted_phoneme_token_labels, + } + ) + + codec_sil_codes = self._ensure_codec_silence_codes() + bos_id = getattr(model, "audio_bos_id", -1) + eos_id = getattr(model, "audio_eos_id", -1) + speaking_id = getattr(model, "audio_user_speaking_id", -1) + speaking_end_id = getattr(model, "audio_user_speaking_end_id", -1) + sil_injection = codec_sil_codes.view(1, -1, 1) + + for step_idx in range(len(state.all_predictions)): + pred = state.all_predictions[step_idx] + mask = (pred == bos_id) | (pred == eos_id) | (pred == speaking_id) | (pred == speaking_end_id) + frame_mask = mask.any(dim=1, keepdim=True) + if frame_mask.any(): + state.all_predictions[step_idx] = torch.where(frame_mask, sil_injection.expand_as(pred), pred) + + state.audio_prediction_end_idx.fill_(-1) + generated_codes = None + if getattr(state, "all_predictions", None): + generated_codes = torch.cat(state.all_predictions, dim=-1).detach() + + finalize_output = model.streaming_finalize(state, use_inference_mode=True) + + return finalize_output, turn_frame_ranges, turn_phoneme_outputs, decode_start_frame, generated_codes + + @staticmethod + def _gpt2_byte_decoder() -> Dict[str, int]: + """Return GPT-2/byte-level-BPE unicode-char -> byte lookup. + + Some NeMo tokenizer wrappers expose byte-level BPE pieces instead of + fully decoded text. Those pieces look like ``ËĪ``/``Ġ`` in JSON. They + are not mojibake from JSON itself; they are the reversible GPT-2 byte + alphabet. Reversing that alphabet gives readable IPA again. + """ + byte_values = list(range(ord("!"), ord("~") + 1)) + byte_values += list(range(ord("¡"), ord("¬") + 1)) + byte_values += list(range(ord("®"), ord("ÿ") + 1)) + unicode_values = list(byte_values) + n = 0 + for byte in range(256): + if byte not in byte_values: + byte_values.append(byte) + unicode_values.append(256 + n) + n += 1 + return {chr(unicode_value): byte for byte, unicode_value in zip(byte_values, unicode_values)} + + @classmethod + def _decode_byte_level_bpe_text(cls, text: str) -> str: + """Decode GPT-2 byte-level BPE artifacts such as ``ËĪ`` and ``Ġ``. + + If ``text`` is already normal Unicode IPA, it is returned unchanged. + """ + if not text: + return "" + + byte_artifact_markers = ("Ġ", "Ċ", "Ë", "É", "Ê", "Ã", "Å", "Î") + if not any(marker in text for marker in byte_artifact_markers): + return text.strip() + + byte_decoder = cls._gpt2_byte_decoder() + byte_values = bytearray() + for ch in text: + value = byte_decoder.get(ch) + if value is not None: + byte_values.append(value) + + if not byte_values: + return text.strip() + + try: + decoded = byte_values.decode("utf-8") + except UnicodeDecodeError: + return text.replace("Ġ", " ").replace("Ċ", "\n").strip() + + return " ".join(decoded.split()) + + @staticmethod + def _token_id_as_int(value) -> Optional[int]: + """Return a scalar token id as int, or None when it is not a valid id.""" + if value is None: + return None + if isinstance(value, torch.Tensor): + if value.numel() != 1: + return None + value = value.detach().cpu().item() + if isinstance(value, int): + return value + if isinstance(value, str): + value = value.strip() + if not value: + return None + signless = value[1:] if value[0] in ("+", "-") else value + if signless.isdigit(): + return int(value) + return None + + @staticmethod + def _phoneme_special_token_map(tokenizer) -> Dict[int, str]: + """Map phoneme special token ids to stable debug labels. + + IPABPETokenizer stores its specials directly in ``_inv_vocab`` as + ````, ````, and ````. Prefer those labels so + the exported debug JSON matches the tokenizer vocabulary. + """ + special: Dict[int, str] = {} + + inv_vocab = getattr(tokenizer, "_inv_vocab", None) + if isinstance(inv_vocab, dict): + for token_id, piece in inv_vocab.items(): + token_id = EasyMagpieMultiturnUserAudioInferenceRunner._token_id_as_int(token_id) + if token_id is None: + continue + piece = str(piece) + if piece.startswith("<") and piece.endswith(">"): + special[token_id] = piece + + special_attrs = [ + ("bos_token_id", ""), + ("bos", ""), + ("eos_token_id", ""), + ("eos", ""), + ("pad_token_id", ""), + ("pad", ""), + ("unk_token_id", ""), + ("mask_token_id", ""), + ] + for attr, label in special_attrs: + token_id = EasyMagpieMultiturnUserAudioInferenceRunner._token_id_as_int(getattr(tokenizer, attr, None)) + if token_id is not None: + special.setdefault(token_id, label) + + return special + + @classmethod + def _phoneme_token_pieces(cls, tokenizer, token_ids: List[int]) -> List[str]: + """Return tokenizer pieces for ids without dropping special tokens.""" + token_ids = [int(token_id) for token_id in token_ids] + + inv_vocab = getattr(tokenizer, "_inv_vocab", None) + if isinstance(inv_vocab, dict): + return [str(inv_vocab.get(token_id, "")) for token_id in token_ids] + + internal_tokenizer = getattr(tokenizer, "_tokenizer", None) + if internal_tokenizer is not None and hasattr(internal_tokenizer, "id_to_token"): + pieces = [] + for token_id in token_ids: + piece = internal_tokenizer.id_to_token(token_id) + pieces.append("" if piece is None else str(piece)) + return pieces + + tokens = getattr(tokenizer, "tokens", None) + if isinstance(tokens, dict): + inv = {int(idx): str(piece) for piece, idx in tokens.items() if cls._token_id_as_int(idx) is not None} + return [inv.get(token_id, "") for token_id in token_ids] + + return [str(token_id) for token_id in token_ids] + + def _decode_phoneme_id_run(self, token_ids: List[int]) -> str: + """Decode a contiguous non-special phoneme-id span to readable IPA.""" + if self.model.phoneme_tokenizer is None or not token_ids: + return "" + + tokenizer = self.model.phoneme_tokenizer + pieces = self._phoneme_token_pieces(tokenizer, token_ids) + if pieces: + # IPABPETokenizer piece strings are byte-level BPE symbols. Joining + # and byte-decoding them produces readable IPA, e.g. ``ËĪÉijËIJ`` -> ``ˈɑː``. + decoded_from_pieces = self._decode_byte_level_bpe_text("".join(pieces)) + if decoded_from_pieces: + return decoded_from_pieces + + decoder = getattr(tokenizer, "decode", None) + if callable(decoder): + decoded = decoder([int(token_id) for token_id in token_ids]) + return self._decode_byte_level_bpe_text(str(decoded or "")) + + return " ".join(pieces) + + def _format_phoneme_debug_sequence(self, token_ids: List[int]) -> Tuple[str, List[str]]: + """Return readable text and per-token labels while keeping specials. + + ``predicted_phoneme_tokens`` keeps the integer ids, including special + markers. This function creates the human-readable companion fields: + ``predicted_phoneme_text`` and ``predicted_phoneme_token_labels``. + """ + tokenizer = self.model.phoneme_tokenizer + special = self._phoneme_special_token_map(tokenizer) + labels: List[str] = [] + text_parts: List[str] = [] + run: List[int] = [] + + def flush_run() -> None: + if not run: + return + decoded_run = self._decode_phoneme_id_run(run) + if decoded_run: + text_parts.append(decoded_run) + run.clear() + + for token_id in token_ids: + token_id = int(token_id) + special_label = special.get(token_id) + if special_label is not None: + flush_run() + text_parts.append(special_label) + labels.append(f"{special_label}:{token_id}") + else: + run.append(token_id) + decoded_piece = self._decode_phoneme_id_run([token_id]) + labels.append(f"{token_id}:{decoded_piece}" if decoded_piece else str(token_id)) + flush_run() + + return " ".join(part for part in text_parts if part), labels + + def _decode_phoneme_prediction_slice( + self, + phoneme_predictions, + start_step: int, + end_step: int, + ) -> Tuple[str, List[int], List[str]]: + if self.model.phoneme_tokenizer is None or not phoneme_predictions: + return "", [], [] + + start_step = max(0, min(int(start_step), len(phoneme_predictions))) + end_step = max(start_step, min(int(end_step), len(phoneme_predictions))) + if end_step <= start_step: + return "", [], [] + + phoneme_tensor = torch.stack(phoneme_predictions[start_step:end_step], dim=-1) + raw_tokens = phoneme_tensor[0].detach().cpu().T.reshape(-1).long().tolist() + + tokenizer = self.model.phoneme_tokenizer + eos_id = self._token_id_as_int(getattr(tokenizer, "eos_token_id", None)) + bos_id = self._token_id_as_int(getattr(tokenizer, "bos_token_id", None)) + + # The streaming phoneme predictor is seeded with BOS as input before the + # first predicted phoneme. Add it explicitly to the debug sequence so the + # exported JSON shows the same boundary condition used by inference. + debug_tokens: List[int] = [] + if bos_id is not None: + debug_tokens.append(bos_id) + + for token in raw_tokens: + token = int(token) + debug_tokens.append(token) + if eos_id is not None and token == eos_id: + break + + if not debug_tokens: + return "", [], [] + + phoneme_text, token_labels = self._format_phoneme_debug_sequence(debug_tokens) + return phoneme_text, debug_tokens, token_labels + + @staticmethod + def _save_code_slice( + generated_codes, batch_idx: int, start_frame: int, end_frame: int, path: str + ) -> Optional[str]: + if generated_codes is None: + return None + os.makedirs(os.path.dirname(path), exist_ok=True) + total_frames = int(generated_codes.size(-1)) + start_frame = max(0, min(int(start_frame), total_frames)) + end_frame = max(start_frame, min(int(end_frame), total_frames)) + if end_frame <= start_frame: + return None + codes = generated_codes[batch_idx, :, start_frame:end_frame].detach().cpu().long() + torch.save(codes, path) + return path + + def _resolve_target_audio_for_turn( + self, + raw_record: dict, + target_turn_audio_paths, + local_turn_idx: int, + audio_base_dir: str, + ) -> Optional[str]: + """Resolve the GT target audio for one evaluation turn. + + Prefer per-turn GT if present; otherwise fall back to sample-level audio_filepath. + If no candidate exists, returns None so the caller can fall back to + context audio and keep EasyMagpie evaluation from failing on missing + target_audio_*.wav. + """ + candidates = [] + + if isinstance(target_turn_audio_paths, list) and local_turn_idx < len(target_turn_audio_paths): + candidates.append(target_turn_audio_paths[local_turn_idx]) + elif isinstance(target_turn_audio_paths, str): + candidates.append(target_turn_audio_paths) + + # Common manifest keys. + candidates.extend( + [ + raw_record.get("target_audio_file_path"), + raw_record.get("target_audio_filepath"), + raw_record.get("audio_filepath"), + ] + ) + + tried = [] + for candidate in candidates: + if candidate is None or candidate == "": + continue + if isinstance(candidate, list): + if local_turn_idx < len(candidate): + candidate = candidate[local_turn_idx] + else: + continue + resolved = self._resolve_audio_path(candidate, audio_base_dir) + tried.append(resolved) + if resolved is not None and os.path.exists(resolved): + return resolved + + logging.warning( + "Could not resolve target audio for multiturn_user_audio evaluation turn; " + "caller will fall back to context audio. " + f"sample_idx={raw_record.get('idx')}, local_turn_idx={local_turn_idx}, " + f"audio_base_dir={audio_base_dir}, tried={tried}, " + f"raw_record_keys={sorted(raw_record.keys())}" + ) + return None + + @staticmethod + def _get_multiturn_debug_output_dir(output_dir: str) -> str: + """Return a sibling audios_MT path for debug/listening artifacts. + + Examples: + /audio/repeat_0 + -> /audios_MT/repeat_0 + /audio/repeat_0/rank_0003 + -> /audios_MT/repeat_0/rank_0003 + + Evaluation files stay in the standard audio/ directory; only + human-listening/debug multiturn files go under audios_MT/. + """ + normalized = os.path.normpath(output_dir) + parts = normalized.split(os.sep) + for i in range(len(parts) - 1, -1, -1): + if parts[i] == "audio": + parts[i] = "audios_MT" + prefix = os.sep if normalized.startswith(os.sep) else "" + return prefix + os.path.join(*[p for p in parts if p != ""]) + return normalized + "_MT" + + @staticmethod + def _safe_audio_stem(path_or_name: Optional[str], fallback: str) -> str: + if path_or_name is None or path_or_name == "": + stem = fallback + else: + stem = os.path.splitext(os.path.basename(str(path_or_name)))[0] or fallback + safe = "".join(ch if ch.isalnum() or ch in ("-", "_", ".") else "_" for ch in stem) + return safe or fallback + + def _target_audio_stem_for_debug( + self, + raw_record: dict, + sample_idx: int, + local_turn_idx: Optional[int] = None, + ) -> str: + """Choose a readable debug/listening filename stem from original manifest target audio. + + This does not affect evaluator file names. It is only used under audios_MT/. + Prefer per-turn target fields when present; otherwise use sample-level + audio_filepath. Fall back to sample_. + """ + fallback = f"sample_{sample_idx}" + candidates = [ + raw_record.get("target_audio_file_path"), + raw_record.get("target_audio_filepath"), + raw_record.get("audio_filepath"), + ] + for candidate in candidates: + if candidate is None or candidate == "": + continue + if isinstance(candidate, list): + if local_turn_idx is None: + candidate = candidate[0] if candidate else None + elif local_turn_idx < len(candidate): + candidate = candidate[local_turn_idx] + else: + candidate = None + if candidate: + return self._safe_audio_stem(candidate, fallback) + return fallback + + def _run_multiturn_user_audio_inference( + self, + dataset: EasyMagpieMultiturnUserAudioDataset, + output_dir: str, + manifest_records: List[dict], + audio_base_dir: str, + save_context_audio: bool = True, + save_predicted_codes: bool = True, + ) -> Tuple[List[dict], List[str], List[str]]: + os.makedirs(output_dir, exist_ok=True) + self._delete_old_generated_files(output_dir) + + mt_debug_output_dir = self._get_multiturn_debug_output_dir(output_dir) + debug_mixed_dir = os.path.join(mt_debug_output_dir, "debug_mixed_user_agent") + if self.config.save_debug_multiturn_audio: + os.makedirs(debug_mixed_dir, exist_ok=True) + logging.info(f"Saving multiturn debug/listening audios under audios_MT: {mt_debug_output_dir}") + + rank = int(getattr(self, "distributed_rank", 0)) + world_size = int(getattr(self, "distributed_world_size", 1)) + total_samples = len(dataset) + + if world_size > 1: + rank_indices = list(range(rank, total_samples, world_size)) + logging.info( + f"[MT_USER_AUDIO_SPLIT] rank={rank}/{world_size} " + f"total_samples={total_samples} local_samples={len(rank_indices)} " + f"indices={rank_indices}" + ) + dataset_for_rank = _InferenceSubset(dataset, rank_indices) + else: + rank_indices = list(range(total_samples)) + logging.info( + f"[MT_USER_AUDIO_SPLIT] rank=0/1 " + f"total_samples={total_samples} local_samples={len(rank_indices)} " + f"indices={rank_indices}" + ) + dataset_for_rank = dataset + + dataloader = torch.utils.data.DataLoader( + dataset_for_rank, + batch_size=1, + collate_fn=dataset_for_rank.collate_fn, + num_workers=0, + shuffle=False, + ) + + all_rtf_metrics = [] + generated_audio_paths = [] + codec_file_paths = [] + turn_manifest_records = [] + item_idx = 0 + + sample_rate = getattr(self.model, "output_sample_rate", self.model.sample_rate) + + for batch_idx, batch in enumerate(dataloader): + logging.info(f"Processing multiturn user-audio sample {batch_idx + 1}/{len(dataloader)}") + batch = self._move_batch_to_device(batch, self.model.device) + sample_idx = int(batch["idx"][0].item()) + raw_record = batch["raw_record"] + raw_turn_texts = batch["raw_turn_texts"][0] + logging.info( + f"[MT_USER_AUDIO_PROCESS] rank={rank}/{world_size} " + f"local_batch={batch_idx} global_sample_idx={sample_idx} " + f"num_turns={len(raw_turn_texts)}" + ) + + if not batch["batched_turns"]: + logging.warning( + "Skipping multiturn_user_audio sample with no valid agent text turns after filtering; " + f"sample_idx={sample_idx}" + ) + continue + + start_time = time.time() + output, turn_frame_ranges, turn_phoneme_outputs, decode_start_frame, generated_codes = ( + self._run_multiturn_generation(batch) + ) + elapsed = time.time() - start_time + + predicted_audio = output.audio.float().detach().cpu() + predicted_audio_lens = output.audio_len.int().detach().cpu() + full_len = int(predicted_audio_lens[0].item()) + full_wav = predicted_audio[0, :full_len] + + samples_per_prediction_frame = self.model.codec_model_samples_per_frame / ( + self.model.sample_rate / sample_rate + ) + aligned_agent = torch.zeros_like(full_wav) + + context_len = int(batch["context_audio_lengths"][0].detach().cpu().item()) + context_wav = batch["context_audio"][0, :context_len].detach().cpu().float() + context_audio_path = os.path.join(output_dir, f"context_audio_sample_{sample_idx}.wav") + # Always write context audio because evaluate_generated_audio_dir reads + # context_audio_filepath from the generated turn-level manifest for every repeat. + sf.write(context_audio_path, context_wav.numpy(), self.model.sample_rate) + + target_turn_audio_paths = batch.get("target_turn_audio_paths") + + for local_turn_idx, (turn_id, start_frame, end_frame) in enumerate(turn_frame_ranges): + source_turn_idx = int(turn_id) + phoneme_output = ( + turn_phoneme_outputs[local_turn_idx] if local_turn_idx < len(turn_phoneme_outputs) else {} + ) + rel_start_frame = start_frame - decode_start_frame + rel_end_frame = end_frame - decode_start_frame + start_sample = int(round(rel_start_frame * samples_per_prediction_frame)) + end_sample = int(round(rel_end_frame * samples_per_prediction_frame)) + start_sample = max(0, min(start_sample, full_len)) + end_sample = max(start_sample, min(end_sample, full_len)) + + aligned_agent[start_sample:end_sample] = full_wav[start_sample:end_sample] + turn_wav = aligned_agent[start_sample:end_sample].float() + + predicted_audio_path = os.path.join(output_dir, f"predicted_audio_{item_idx}.wav") + sf.write(predicted_audio_path, turn_wav.numpy(), sample_rate) + generated_audio_paths.append(predicted_audio_path) + + if save_predicted_codes: + code_path = os.path.join(output_dir, f"predicted_codes_{item_idx}.pt") + saved_code_path = self._save_code_slice(generated_codes, 0, start_frame, end_frame, code_path) + if saved_code_path is not None: + codec_file_paths.append(saved_code_path) + + turn_context_path = os.path.join(output_dir, f"context_audio_{item_idx}.wav") + self._copy_file( + context_audio_path, + turn_context_path, + required=True, + description=f"context audio for sample_idx={sample_idx}, turn_id={turn_id}", + ) + + target_src = self._resolve_target_audio_for_turn( + raw_record=raw_record, + target_turn_audio_paths=target_turn_audio_paths, + local_turn_idx=source_turn_idx, + audio_base_dir=audio_base_dir, + ) + if target_src is None or not os.path.exists(target_src): + logging.warning( + "Target audio is missing for multiturn_user_audio evaluation; " + "using context audio as target fallback to avoid evaluator failure. " + f"sample_idx={sample_idx}, turn_id={turn_id}, missing_target={target_src}, " + f"context_audio_path={context_audio_path}" + ) + target_src = context_audio_path + + target_dst = os.path.join(output_dir, f"target_audio_{item_idx}.wav") + self._copy_file( + target_src, + target_dst, + required=True, + description=f"target audio fallback/context for sample_idx={sample_idx}, turn_id={turn_id}", + ) + + turn_manifest_records.append( + { + "audio_filepath": f"target_audio_{item_idx}.wav", + "context_audio_filepath": f"context_audio_{item_idx}.wav", + "text": raw_turn_texts[source_turn_idx] if source_turn_idx < len(raw_turn_texts) else "", + "predicted_phoneme_text": phoneme_output.get("predicted_phoneme_text", ""), + "predicted_phoneme_tokens": phoneme_output.get("predicted_phoneme_tokens", []), + "predicted_phoneme_token_labels": phoneme_output.get("predicted_phoneme_token_labels", []), + "speaker": str(sample_idx), + "source_sample_idx": sample_idx, + "turn_id": int(turn_id), + } + ) + logging.info( + f"[MT_USER_AUDIO_TURN] rank={rank}/{world_size} " + f"global_sample_idx={sample_idx} turn_id={int(turn_id)} " + f"rank_local_item_idx={item_idx} " + f"predicted_audio=predicted_audio_{item_idx}.wav " + f"target_audio=target_audio_{item_idx}.wav " + f"context_audio=context_audio_{item_idx}.wav" + ) + item_idx += 1 + + if self.config.save_debug_multiturn_audio and "user_audio_turns" in batch: + self._save_debug_user_agent_audio( + batch=batch, + sample_idx=sample_idx, + raw_record=raw_record, + turn_frame_ranges=turn_frame_ranges, + decode_start_frame=decode_start_frame, + aligned_agent=aligned_agent, + samples_per_prediction_frame=samples_per_prediction_frame, + output_dir=output_dir, + debug_mixed_dir=debug_mixed_dir, + ) + + audio_seconds = ( + sum( + max(0, int(round((end - start) * samples_per_prediction_frame))) + for _, start, end in turn_frame_ranges + ) + / sample_rate + ) + all_rtf_metrics.append( + { + "inference_time": elapsed, + "audio_seconds": audio_seconds, + "rtf": elapsed / audio_seconds if audio_seconds > 0 else 0.0, + } + ) + + self.evaluation_audio_dir = output_dir + rank = int(getattr(self, "distributed_rank", 0)) + self.evaluation_manifest_path = os.path.join( + output_dir, f"multiturn_user_audio_turn_manifest_rank{rank:04d}.jsonl" + ) + self.evaluation_manifest_records = turn_manifest_records + with open(self.evaluation_manifest_path, "w", encoding="utf-8") as f: + for record in turn_manifest_records: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + logging.info( + f"[MT_USER_AUDIO_SUMMARY] rank={rank}/{world_size} " + f"assigned_samples={rank_indices} " + f"generated_turns={len(turn_manifest_records)} " + f"generated_audio_paths={len(generated_audio_paths)} " + f"predicted_code_paths={len(codec_file_paths)} " + f"manifest={self.evaluation_manifest_path}" + ) + logging.info(f"Wrote multiturn turn-level evaluation manifest: {self.evaluation_manifest_path}") + return all_rtf_metrics, generated_audio_paths, codec_file_paths + + def _save_debug_user_agent_audio( + self, + batch: Dict[str, Any], + sample_idx: int, + raw_record: dict, + turn_frame_ranges: List[Tuple[int, int, int]], + decode_start_frame: int, + aligned_agent: torch.Tensor, + samples_per_prediction_frame: float, + output_dir: str, + debug_mixed_dir: str, + ) -> None: + sample_rate = getattr(self.model, "output_sample_rate", self.model.sample_rate) + debug_sample_stem = self._target_audio_stem_for_debug(raw_record, sample_idx) + first_user_len_in = int(batch["user_audio_turns_lens"][0][0].detach().cpu().item()) + first_user_delay_out = int(round(first_user_len_in * sample_rate / self.model.sample_rate)) + + def load_debug_audio(path: Optional[str]) -> Optional[torch.Tensor]: + if path is None or path == "" or not os.path.exists(path): + return None + audio, sr = sf.read(path, dtype="float32", always_2d=False) + wav = torch.as_tensor(audio, dtype=torch.float32) + if wav.ndim == 2: + wav = wav.mean(dim=1) + wav = wav.flatten() + if sr != sample_rate: + wav = resample(wav.unsqueeze(0), sr, sample_rate).squeeze(0) + return wav.contiguous() + + user_segments = [] + user_turns_for_gt = [] + for local_turn_idx, (turn_id, _, _) in enumerate(turn_frame_ranges): + if local_turn_idx >= len(batch["user_audio_turns"]): + continue + turn_audio = batch["user_audio_turns"][local_turn_idx][0].detach().cpu().float() + turn_audio_len = int(batch["user_audio_turns_lens"][local_turn_idx][0].detach().cpu().item()) + turn_audio = turn_audio[:turn_audio_len] + turn_audio_out = resample(turn_audio.unsqueeze(0), self.model.sample_rate, sample_rate).squeeze(0) + user_turns_for_gt.append((int(turn_id), turn_audio_out)) + + if local_turn_idx == 0: + user_start_sample = 0 + else: + prev_turn_end_frame = turn_frame_ranges[local_turn_idx - 1][2] + rel_prev_end_frame = prev_turn_end_frame - decode_start_frame + user_start_sample = first_user_delay_out + int( + round(rel_prev_end_frame * samples_per_prediction_frame) + ) + user_segments.append((user_start_sample, turn_audio_out)) + + total_user_len = 0 + for start, wav in user_segments: + total_user_len = max(total_user_len, start + wav.numel()) + user_ch = torch.zeros(total_user_len) + for start, wav in user_segments: + user_ch[start : start + wav.numel()] += wav + + agent_ch = torch.cat([torch.zeros(first_user_delay_out, dtype=aligned_agent.dtype), aligned_agent]) + stereo_len = max(user_ch.numel(), agent_ch.numel()) + user_pad = torch.zeros(stereo_len) + agent_pad = torch.zeros(stereo_len) + user_pad[: user_ch.numel()] = user_ch + agent_pad[: agent_ch.numel()] = agent_ch + + stereo = torch.stack([user_pad, agent_pad], dim=1).numpy() + sf.write( + os.path.join( + debug_mixed_dir, + f"{debug_sample_stem}__sample_{sample_idx}__user_agent_aligned_stereo.wav", + ), + stereo, + sample_rate, + ) + + target_turn_audio_paths = batch.get("target_turn_audio_paths", []) or [] + if isinstance(target_turn_audio_paths, str): + target_turn_audio_paths = [target_turn_audio_paths] + + gt_user_segments = [] + gt_agent_segments = [] + cursor = 0 + for source_turn_idx, user_wav in user_turns_for_gt: + gt_user_segments.append((cursor, user_wav)) + cursor += user_wav.numel() + + target_wav = None + if source_turn_idx < len(target_turn_audio_paths): + target_wav = load_debug_audio(target_turn_audio_paths[source_turn_idx]) + + if target_wav is None: + logging.warning( + "Could not load target turn audio for multiturn ground-truth debug stereo; " + f"sample_idx={sample_idx}, source_turn_idx={source_turn_idx}" + ) + continue + + gt_agent_segments.append((cursor, target_wav)) + cursor += target_wav.numel() + + if gt_user_segments or gt_agent_segments: + gt_len = 0 + for start, wav in gt_user_segments + gt_agent_segments: + gt_len = max(gt_len, start + wav.numel()) + gt_user_ch = torch.zeros(gt_len) + gt_agent_ch = torch.zeros(gt_len) + for start, wav in gt_user_segments: + gt_user_ch[start : start + wav.numel()] += wav + for start, wav in gt_agent_segments: + gt_agent_ch[start : start + wav.numel()] += wav + + gt_stereo = torch.stack([gt_user_ch, gt_agent_ch], dim=1).numpy() + sf.write( + os.path.join( + debug_mixed_dir, + f"{debug_sample_stem}__sample_{sample_idx}__user_agent_ground_truth_stereo.wav", + ), + gt_stereo, + sample_rate, + ) diff --git a/nemo/collections/tts/modules/magpietts_modules.py b/nemo/collections/tts/modules/magpietts_modules.py index 483bae678243..efaeea3d09c8 100644 --- a/nemo/collections/tts/modules/magpietts_modules.py +++ b/nemo/collections/tts/modules/magpietts_modules.py @@ -96,8 +96,8 @@ class SpecialAudioToken(Enum): AUDIO_CONTEXT_EOS = 3 MASK_TOKEN = 4 # Reserve these values so that if we need to add more special tokens in the future the codebook size will remain the same - RESERVED_1 = 5 - RESERVED_2 = 6 + USER_SPEAKING = 5 + USER_SPEAKING_END = 6 RESERVED_3 = 7 @staticmethod @@ -244,6 +244,11 @@ def forward(self, subword_ids: Tensor, subword_mask: Tensor | None = None) -> Te if subword_mask.ndim == 3: subword_mask = subword_mask.squeeze(-1) + if not subword_mask.any(): + B, T = subword_ids.shape + D = self.embed_tokens.embedding_dim + return torch.zeros((B, T, D), dtype=self.embed_tokens.weight.dtype, device=device) + char_ids, char_lengths = self.prepare_inputs(subword_ids, subword_mask) char_mask = get_mask_from_lengths(char_lengths) char_emb = self.embed_tokens(char_ids) diff --git a/nemo/collections/tts/modules/nemotron_h_decoder.py b/nemo/collections/tts/modules/nemotron_h_decoder.py index 986359c0e2b3..8e487da639e6 100644 --- a/nemo/collections/tts/modules/nemotron_h_decoder.py +++ b/nemo/collections/tts/modules/nemotron_h_decoder.py @@ -91,6 +91,49 @@ ) +def make_mamba_conv_cache_from_sequence( + hidden_states_B_C: torch.Tensor, + conv_kernel_size: int, +) -> torch.Tensor: + """ + hidden_states_B_C: (B, T, conv_dim) + returns conv cache: (B, conv_dim, conv_kernel_size) + """ + x = hidden_states_B_C.transpose(1, 2) # (B, conv_dim, T) + + if x.size(-1) >= conv_kernel_size: + return x[:, :, -conv_kernel_size:].contiguous() + + return F.pad(x, (conv_kernel_size - x.size(-1), 0)).contiguous() + + +def get_cached_mamba_ssm_state( + cache_params: "HybridMambaAttentionDynamicCache", + layer_idx: int, + batch_size: int, + num_heads: int, + head_dim: int, + ssm_state_size: int, + device: torch.device, +): + """ + Returns cached SSM state in the shape expected by full-sequence/chunk scan: + (B, num_heads, head_dim, ssm_state_size) + + The cache may contain either: + (B, num_heads * head_dim, ssm_state_size) + or: + (B, num_heads, head_dim, ssm_state_size) + depending on which path updated it last. + """ + state = cache_params.ssm_states[layer_idx].to(device=device) + + if state.dim() == 3: + state = state.view(batch_size, num_heads, head_dim, ssm_state_size) + + return state + + def get_activation_fn(activation: str): """Get activation function by name.""" if activation == "silu" or activation == "swish": @@ -240,6 +283,7 @@ def __init__(self, config: NemotronHConfig, batch_size: int, dtype=torch.float16 intermediate_size = config.mamba_num_heads * config.mamba_head_dim ssm_state_size = config.ssm_state_size conv_kernel_size = config.conv_kernel + conv_dim = intermediate_size + 2 * config.n_groups * config.ssm_state_size self.conv_states = [] self.ssm_states = [] @@ -250,7 +294,7 @@ def __init__(self, config: NemotronHConfig, batch_size: int, dtype=torch.float16 for i in range(config.num_hidden_layers): if config.layers_block_type[i] == "mamba": self.conv_states.append( - torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) + torch.zeros(batch_size, conv_dim, conv_kernel_size, device=device, dtype=dtype) ) self.ssm_states.append( torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) @@ -495,10 +539,14 @@ def cuda_kernels_forward( - self.num_heads ) // 2 - if cache_params is not None and cache_position is not None and cache_position[0] > 0: - # Cached forward (single token) + has_cache_prefix = cache_params is not None and cache_position is not None and cache_position[0] > 0 + is_decode_step = has_cache_prefix and seq_len == 1 + + if is_decode_step: + # Cached single-token decode path. _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( - [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], + dim=-1, ) hidden_states_B_C = causal_conv1d_update( @@ -539,8 +587,9 @@ def cuda_kernels_forward( hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) hidden_states = self.norm(hidden_states, gate) out = self.out_proj(hidden_states)[:, None, ...] + else: - # Full sequence forward + # Full sequence or cached multi-token prefill path. A = -torch.exp(self.A_log.float()) dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} @@ -567,31 +616,57 @@ def cuda_kernels_forward( ) else: _, _, gate, hidden_states_B_C, dt = projected_states.split( - [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], + dim=-1, ) - if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) - conv_states = F.pad( - hidden_states_B_C_transposed, - (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), - ) - cache_params.update_conv_state( - layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True - ) + raw_hidden_states_B_C = hidden_states_B_C + + # For cached T > 1 prefill, convolution must see the previous conv cache. + if has_cache_prefix: + prev_conv = cache_params.conv_states[self.layer_idx].to( + device=raw_hidden_states_B_C.device, + dtype=raw_hidden_states_B_C.dtype, + ) # (B, conv_dim, K) + + conv_input = torch.cat( + [prev_conv, raw_hidden_states_B_C.transpose(1, 2)], + dim=-1, + ) # (B, conv_dim, K + T) + + raw_for_cache = torch.cat( + [prev_conv.transpose(1, 2), raw_hidden_states_B_C], + dim=1, + ) # (B, K + T, conv_dim) + else: + conv_input = raw_hidden_states_B_C.transpose(1, 2) + raw_for_cache = raw_hidden_states_B_C if self.activation not in ["silu", "swish"]: - hidden_states_B_C = self.act( - self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) - ) + conv_out = self.act(self.conv1d(conv_input)[..., : conv_input.size(-1)].transpose(1, 2)) else: - hidden_states_B_C = causal_conv1d_fn( - x=hidden_states_B_C.transpose(1, 2), + conv_out = causal_conv1d_fn( + x=conv_input, weight=self.conv1d.weight.squeeze(1), bias=self.conv1d.bias, activation=self.activation, ).transpose(1, 2) + # Keep only outputs corresponding to the new chunk. + hidden_states_B_C = conv_out[:, -seq_len:, :].contiguous() + + # Update cache after using the previous cache. + if cache_params is not None: + conv_states = make_mamba_conv_cache_from_sequence( + raw_for_cache, + cache_params.conv_kernel_size, + ) + cache_params.update_conv_state( + layer_idx=self.layer_idx, + new_conv_state=conv_states, + cache_init=True, + ) + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( hidden_states_B_C, @@ -599,12 +674,7 @@ def cuda_kernels_forward( dim=-1, ) - scan_output, ssm_state = mamba_chunk_scan_combined( - hidden_states.view(batch_size, seq_len, -1, self.head_dim), - dt, - A, - B.view(batch_size, seq_len, self.n_groups, -1), - C.view(batch_size, seq_len, self.n_groups, -1), + scan_kwargs = dict( chunk_size=self.chunk_size, D=self.D, z=None, @@ -615,8 +685,32 @@ def cuda_kernels_forward( **dt_limit_kwargs, ) + # For cached T > 1 prefill, SSM must start from previous cache state. + if has_cache_prefix: + scan_kwargs["initial_states"] = get_cached_mamba_ssm_state( + cache_params=cache_params, + layer_idx=self.layer_idx, + batch_size=batch_size, + num_heads=self.num_heads, + head_dim=self.head_dim, + ssm_state_size=self.ssm_state_size, + device=hidden_states.device, + ) + + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + dt, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + **scan_kwargs, + ) + if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + cache_params.update_ssm_state( + layer_idx=self.layer_idx, + new_ssm_state=ssm_state, + ) scan_output = scan_output.view(batch_size, seq_len, -1) scan_output = self.norm(scan_output, gate) @@ -645,28 +739,70 @@ def torch_forward( - self.num_heads ) // 2 _, _, gate, hidden_states_B_C, dt = projected_states.split( - [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], + dim=-1, ) + has_cache_prefix = cache_params is not None and cache_position is not None and cache_position[0] > 0 + is_decode_step = has_cache_prefix and seq_len == 1 + + # ----------------------- # Convolution - if cache_params is not None and cache_position is not None and cache_position[0] > 0: + # ----------------------- + if is_decode_step: cache_params.update_conv_state( - layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False + layer_idx=self.layer_idx, + new_conv_state=hidden_states_B_C, + cache_init=False, + ) + conv_states = cache_params.conv_states[self.layer_idx].to( + device=self.conv1d.weight.device, + dtype=hidden_states_B_C.dtype, ) - conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) hidden_states_B_C = torch.sum(conv_states * self.conv1d.weight.squeeze(1), dim=-1) if self.use_conv_bias: hidden_states_B_C = hidden_states_B_C + self.conv1d.bias hidden_states_B_C = self.act(hidden_states_B_C) + else: + raw_hidden_states_B_C = hidden_states_B_C + + # For cached T > 1 prefill, convolution must see previous conv cache. + if has_cache_prefix: + prev_conv = cache_params.conv_states[self.layer_idx].to( + device=raw_hidden_states_B_C.device, + dtype=raw_hidden_states_B_C.dtype, + ) # (B, conv_dim, K) + + conv_input = torch.cat( + [prev_conv, raw_hidden_states_B_C.transpose(1, 2)], + dim=-1, + ) # (B, conv_dim, K + T) + + raw_for_cache = torch.cat( + [prev_conv.transpose(1, 2), raw_hidden_states_B_C], + dim=1, + ) # (B, K + T, conv_dim) + else: + conv_input = raw_hidden_states_B_C.transpose(1, 2) + raw_for_cache = raw_hidden_states_B_C + + conv_out = self.act(self.conv1d(conv_input)[..., : conv_input.size(-1)].transpose(1, 2)) + + # Keep only outputs corresponding to the new chunk. + hidden_states_B_C = conv_out[:, -seq_len:, :].contiguous() + + # Update cache after using the previous cache. if cache_params is not None: - hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) - conv_states = F.pad( - hidden_states_B_C_transposed, - (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), + conv_states = make_mamba_conv_cache_from_sequence( + raw_for_cache, + cache_params.conv_kernel_size, + ) + cache_params.update_conv_state( + layer_idx=self.layer_idx, + new_conv_state=conv_states, + cache_init=True, ) - cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True) - hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) hidden_states, B, C = torch.split( @@ -675,11 +811,13 @@ def torch_forward( dim=-1, ) + # ----------------------- # SSM + # ----------------------- A = -torch.exp(self.A_log.float()) - if cache_params is not None and cache_position is not None and cache_position[0] > 0: - # Single step SSM update + if is_decode_step: + # Single-step SSM update. cache_device = cache_params.ssm_states[self.layer_idx].device dt = dt[:, 0, :][:, None, ...] dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) @@ -701,7 +839,8 @@ def torch_forward( dBx = (dB * hidden_states[..., None]).to(device=cache_device) cache_params.update_ssm_state( - layer_idx=self.layer_idx, new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx + layer_idx=self.layer_idx, + new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx, ) C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] @@ -717,8 +856,9 @@ def torch_forward( D = self.D[..., None].expand(self.D.shape[0], self.head_dim) y = (y + hidden_states * D).to(y.dtype) y = y.reshape(batch_size, -1)[:, None, ...] + else: - # Full sequence SSM (chunked) + # Full-sequence or cached multi-token SSM. dt = F.softplus(dt + self.dt_bias) dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() @@ -751,8 +891,18 @@ def torch_forward( B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) - if cache_params is not None and cache_position is not None and cache_position[0] > 0: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + # This is the critical fix: + # cached T > 1 prefill must start from the previous SSM state. + if has_cache_prefix: + previous_states = get_cached_mamba_ssm_state( + cache_params=cache_params, + layer_idx=self.layer_idx, + batch_size=batch_size, + num_heads=self.num_heads, + head_dim=self.head_dim, + ssm_state_size=self.ssm_state_size, + device=states.device, + )[:, None, ...] else: previous_states = torch.zeros_like(states[:, :1]) @@ -776,7 +926,10 @@ def torch_forward( y = y.reshape(batch_size, seq_len, -1) if ssm_state is not None and cache_params is not None: - cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + cache_params.update_ssm_state( + layer_idx=self.layer_idx, + new_ssm_state=ssm_state, + ) scan_output = self.norm(y, gate) contextualized_states = self.out_proj(scan_output.to(dtype)) @@ -1371,9 +1524,22 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + def create_custom_forward(layer, layer_mask): + def custom_forward(hidden_states): + return layer( + hidden_states, + cache_params=None, + cache_position=None, + attention_mask=layer_mask, + ) + + return custom_forward + if self.gradient_checkpointing and self.training: hidden_states = torch.utils.checkpoint.checkpoint( - layer.__call__, hidden_states, cache_params, cache_position, layer_mask + create_custom_forward(layer, layer_mask), + hidden_states, + use_reentrant=False, ) else: hidden_states = layer( diff --git a/nemo/collections/tts/modules/transformer_2501.py b/nemo/collections/tts/modules/transformer_2501.py index f0e245c9f258..5deaf898c8be 100644 --- a/nemo/collections/tts/modules/transformer_2501.py +++ b/nemo/collections/tts/modules/transformer_2501.py @@ -20,7 +20,6 @@ from nemo.collections.tts.modules.ffn_modules import PositionwiseConvFF from nemo.collections.tts.modules.moe_modules import PositionwiseConvFFMoE -from nemo.utils import logging # TODO: Move the cache implementation out of the Module class, and pass it as part of the forward so we can reset # as needed in the inference pipeline. @@ -120,6 +119,7 @@ def attn_naive( # attn_prior or square mask or vanilla attention if attn_prior is not None: + attn_prior = attn_prior.to(device=attn_score.device, dtype=attn_score.dtype) eps = torch.finfo(attn_prior.dtype).tiny attn_prior = attn_prior[:, :T] # trim for inference attn_prior = attn_prior[:, None] + eps @@ -129,7 +129,7 @@ def attn_naive( attn_score_log = F.log_softmax(attn_score, dim=-1) + attn_prior_log if self.make_prior_window_strict: # Make sure attention scores are lowest (eps) where prior is zero. - min_score = torch.log(torch.tensor(eps)).to(attn_score_log.device) + min_score = torch.log(torch.tensor(eps, device=attn_score_log.device, dtype=attn_score_log.dtype)) attn_score_log = attn_score_log.masked_fill( attn_prior == 0, min_score ) # Wherever prior is zero, set scores to eps. @@ -244,6 +244,7 @@ def compute_qkv_and_mask( mask = None if query_mask is not None: + query_mask = query_mask.to(device=query.device) # query_mask is a boolean mask of shape (B, T) # mask should be of shape (B, 1, T, T) where mask[:,0,i,:] == mask[:,0,:,i] == query_mask mask = query_mask.unsqueeze(1) * query_mask.unsqueeze(2) @@ -310,7 +311,7 @@ def compute_qkv_and_mask( self.cache['cross_k'] = k self.cache['cross_v'] = v - mask = memory_mask[:, None, None] if memory_mask is not None else None + mask = memory_mask.to(device=query.device)[:, None, None] if memory_mask is not None else None return q, k, v, mask diff --git a/tests/collections/common/test_lhotse_tts_filters.py b/tests/collections/common/test_lhotse_tts_filters.py index 877b0886ab97..90a08f32d55e 100644 --- a/tests/collections/common/test_lhotse_tts_filters.py +++ b/tests/collections/common/test_lhotse_tts_filters.py @@ -20,6 +20,7 @@ from nemo.collections.common.data.lhotse.sampling import ( CERFilter, ContextSpeakerSimilarityFilter, + SpeakerFilter, ValidationStatusFilter, ) @@ -141,3 +142,71 @@ def test_cut_validation_status_filter(cut_example): f = ValidationStatusFilter("any_other_status") assert f(cut_example) == False + + +def test_cut_speaker_filter_by_speaker(cut_example): + f = SpeakerFilter( + excluded_speaker_ids=["| Language:en Dataset:nvyt2505 Speaker:Zdud2gXLTXY_SPEAKER_02 |"], + speaker_fields=["speaker"], + ) + assert f(cut_example) == False + + f = SpeakerFilter( + excluded_speaker_ids=["some_other_speaker"], + speaker_fields=["speaker"], + ) + assert f(cut_example) == True + + +def test_cut_speaker_filter_by_custom_speaker_id(cut_example): + cut_example.supervisions[0].custom["speaker_id"] = "test_speaker_001" + + f = SpeakerFilter( + excluded_speaker_ids=["test_speaker_001"], + speaker_fields=["speaker_id"], + ) + assert f(cut_example) == False + + f = SpeakerFilter( + excluded_speaker_ids=["some_other_speaker"], + speaker_fields=["speaker_id"], + ) + assert f(cut_example) == True + + +def test_cut_speaker_filter_multiple_fields(cut_example): + cut_example.supervisions[0].custom["speaker_id"] = "test_speaker_001" + + f = SpeakerFilter( + excluded_speaker_ids=["test_speaker_001"], + speaker_fields=["speaker", "speaker_id"], + ) + assert f(cut_example) == False + + f = SpeakerFilter( + excluded_speaker_ids=["some_other_speaker"], + speaker_fields=["speaker", "speaker_id"], + ) + assert f(cut_example) == True + + +def test_cut_speaker_filter_disabled(cut_example): + f = SpeakerFilter( + excluded_speaker_ids=None, + speaker_fields=["speaker_id"], + ) + assert f(cut_example) == True + + f = SpeakerFilter( + excluded_speaker_ids=[], + speaker_fields=["speaker_id"], + ) + assert f(cut_example) == True + + +def test_cut_speaker_filter_requires_fields_when_enabled(): + with pytest.raises(ValueError): + SpeakerFilter( + excluded_speaker_ids=["test_speaker_001"], + speaker_fields=None, + ) diff --git a/tests/collections/tts/data/test_magpietts_dataset_lhotse.py b/tests/collections/tts/data/test_magpietts_dataset_lhotse.py new file mode 100644 index 000000000000..9f0dbaab58f1 --- /dev/null +++ b/tests/collections/tts/data/test_magpietts_dataset_lhotse.py @@ -0,0 +1,257 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from io import BytesIO +from pathlib import Path + +import numpy as np +import pytest +import torch +from lhotse import CutSet, SupervisionSegment +from lhotse.array import Array, TemporalArray +from lhotse.testing.dummies import dummy_cut, dummy_recording +from omegaconf import OmegaConf + +from nemo.collections.tts.data.text_to_speech_dataset_lhotse import MagpieTTSLhotseDataset +from nemo.collections.tts.data.text_to_speech_dataset_lhotse_multiturn import MagpieTTSLhotseMultiturnDataset + + +pytestmark = pytest.mark.unit + +SAMPLE_RATE = 24000 +CODEC_MODEL_SAMPLES_PER_FRAME = 480 +CODEC_MODEL_INPUT_SAMPLE_RATE = 24000 +FRAME_STACKING_FACTOR = 1 +NUM_AUDIO_CODEBOOKS = 8 + +BPE_TOKENIZER_NAME = "nemotron_bpe" +BPE_TOKENIZER_MODEL = "nvidia/NVIDIA-Nemotron-Nano-9B-v2" +BPE_TOKENIZER_CACHED_PATH = Path("/home/TestData/nvidia--NVIDIA-Nemotron-Nano-9B-v2/") +if BPE_TOKENIZER_CACHED_PATH.exists(): + BPE_TOKENIZER_MODEL = str(BPE_TOKENIZER_CACHED_PATH) + + +def _seed_everything(): + random.seed(42) + np.random.seed(42) + torch.manual_seed(42) + + +def _tokenizer_config(): + return OmegaConf.create( + { + BPE_TOKENIZER_NAME: { + "_target_": "AutoTokenizer", + "pretrained_model": BPE_TOKENIZER_MODEL, + } + } + ) + + +def _memory_temporal_array(values, frame_shift=0.02): + buffer = BytesIO() + np.save(buffer, values) + return TemporalArray( + array=Array( + storage_type="memory_npy", + storage_path="", + storage_key=buffer.getvalue(), + shape=list(values.shape), + ), + temporal_dim=-1, + frame_shift=frame_shift, + start=0, + ) + + +def _cached_codes(num_codebooks=NUM_AUDIO_CODEBOOKS, num_frames=5, offset=0): + codes = np.arange(num_codebooks * num_frames, dtype=np.int32).reshape(num_codebooks, num_frames) + return (codes + offset) % 16 + + +def _single_turn_cutset(): + cut = dummy_cut( + 0, + duration=0.5, + recording=dummy_recording(0, duration=0.5, with_data=True, sampling_rate=SAMPLE_RATE), + ) + cut.target_audio = dummy_recording(10, duration=0.5, with_data=True, sampling_rate=SAMPLE_RATE) + cut.supervisions = [ + SupervisionSegment( + id="single-turn", + recording_id=cut.recording_id, + start=0.0, + duration=0.2, + text="hello", + language="en", + speaker="| Language:en Dataset:Unit Speaker:spk |", + custom={"context_text": "speaker prompt", "normalized_text": "hello"}, + ) + ] + cut.custom = { + **(cut.custom or {}), + "target_codes": _memory_temporal_array(_cached_codes(num_frames=4)), + "context_codes": _memory_temporal_array(_cached_codes(num_frames=3, offset=3)), + "tokenizer_names": [BPE_TOKENIZER_NAME], + "lang": "en", + } + return CutSet.from_cuts([cut]) + + +def _multiturn_cutset(): + cut = dummy_cut( + 1, + duration=0.8, + recording=dummy_recording(1, duration=0.8, with_data=True, sampling_rate=SAMPLE_RATE), + ) + cut.target_audio = dummy_recording(11, duration=0.8, with_data=True, sampling_rate=SAMPLE_RATE) + cut.supervisions = [ + SupervisionSegment( + id="turn-user-0", + recording_id=cut.recording_id, + start=0.0, + duration=0.2, + text="hi", + language="en", + speaker="user", + custom={"context_text": "chat prompt"}, + ), + SupervisionSegment( + id="turn-agent-0", + recording_id=cut.recording_id, + start=0.3, + duration=0.2, + text="hello", + language="en", + speaker="assistant", + ), + SupervisionSegment( + id="turn-user-1", + recording_id=cut.recording_id, + start=0.55, + duration=0.1, + text="ok", + language="en", + speaker="user", + ), + SupervisionSegment( + id="turn-agent-1", + recording_id=cut.recording_id, + start=0.68, + duration=0.1, + text="okay", + language="en", + speaker="assistant", + custom={"reward": 0.5}, + ), + ] + cut.custom = { + **(cut.custom or {}), + "task": "dialog", + "target_codes": _memory_temporal_array(_cached_codes(num_frames=8)), + "source_codes": _memory_temporal_array(_cached_codes(num_frames=8, offset=1)), + "context_codes": _memory_temporal_array(_cached_codes(num_frames=4, offset=2)), + "tokenizer_names": [BPE_TOKENIZER_NAME], + "lang": "en", + } + return CutSet.from_cuts([cut]) + + +def _dataset_kwargs(): + return { + "sample_rate": SAMPLE_RATE, + "volume_norm": False, + "codec_model_samples_per_frame": CODEC_MODEL_SAMPLES_PER_FRAME, + "num_audio_codebooks": NUM_AUDIO_CODEBOOKS, + "prior_scaling_factor": None, + "load_cached_codes_if_available": True, + "dataset_type": "train", + "load_16khz_audio": False, + "pad_context_text_to_max_duration": False, + "context_duration_min": 0.04, + "context_duration_max": 0.04, + "use_text_conditioning_tokenizer": True, + "text_conditioning_tokenizer_name": BPE_TOKENIZER_NAME, + "tokenizer_config": _tokenizer_config(), + } + + +class TestMagpieTTSLhotseDatasets: + def test_single_turn_dataset_uses_bpe_and_cached_codes(self): + _seed_everything() + dataset = MagpieTTSLhotseDataset(**_dataset_kwargs()) + + batch = dataset[_single_turn_cutset()] + + assert batch["dataset_names"] == ["Unit"] + assert batch["languages"] == ["en"] + assert batch["raw_texts"] == ["hello"] + assert "audio" not in batch + assert "context_audio" not in batch + assert batch["audio_codes"].shape == (1, NUM_AUDIO_CODEBOOKS, 4) + assert batch["audio_codes_lens"].tolist() == [4] + assert batch["context_audio_codes"].shape == (1, NUM_AUDIO_CODEBOOKS, 2) + assert batch["context_audio_codes_lens"].tolist() == [2] + assert batch["text"].shape[0] == 1 + assert batch["text_lens"].item() > 0 + assert batch["context_text_tokens"].shape[0] == 1 + assert batch["context_text_tokens_lens"].item() > 0 + assert batch["has_text_context"].tolist() == [True] + + def test_multiturn_dataset_uses_bpe_and_cached_codes(self): + _seed_everything() + kwargs = _dataset_kwargs() + kwargs.update( + { + "codec_model_input_sample_rate": CODEC_MODEL_INPUT_SAMPLE_RATE, + "frame_stacking_factor": FRAME_STACKING_FACTOR, + "source_sample_rate": SAMPLE_RATE, + "input_roles": ["user"], + "output_roles": ["assistant"], + "add_text_bos": False, + } + ) + dataset = MagpieTTSLhotseMultiturnDataset(**kwargs) + + batch = dataset[_multiturn_cutset()] + + assert batch["sample_id"] == ["dummy-mono-cut-0001"] + assert batch["dataset_names"] == ["unknown"] + assert batch["languages"] == ["en"] + assert batch["raw_texts"] == ["hello okay"] + assert batch["task"] == ["dialog"] + assert batch["audio"].shape[0] == 1 + assert batch["source_audio"].shape[0] == 1 + assert batch["audio_codes"].shape == (1, NUM_AUDIO_CODEBOOKS, 8) + assert batch["audio_codes_lens"].tolist() == [8] + assert batch["source_codes"].shape == (1, NUM_AUDIO_CODEBOOKS, 8) + assert batch["source_codes_lens"].tolist() == [8] + assert batch["context_audio_codes"].shape == (1, NUM_AUDIO_CODEBOOKS, 2) + assert batch["context_audio_codes_lens"].tolist() == [2] + assert batch["source_tokens"].shape[0] == 1 + assert batch["source_token_lens"].item() > 0 + assert batch["text"].shape[0] == 1 + assert batch["text_lens"].item() > 0 + assert batch["agent_mask"].shape == batch["user_mask"].shape + assert batch["agent_mask_lens"].tolist() == batch["user_mask_lens"].tolist() + assert batch["agent_mask"].sum().item() > 0 + assert batch["user_mask"].sum().item() > 0 + assert batch["user_audio_turn_splitted"].shape[0] == 2 + assert batch["user_audio_turn_splitted_lens"].tolist() == [4800, 2400] + assert batch["user_audio_turn_splitted_indices"].shape == (2, 3) + assert batch["context_text_tokens"].shape[0] == 1 + assert batch["context_text_tokens_lens"].item() > 0 + assert batch["has_text_context"].tolist() == [True] + torch.testing.assert_close(batch["rewards"], torch.tensor([0.5], device=batch["rewards"].device)) diff --git a/tests/collections/tts/models/test_easy_magpietts.py b/tests/collections/tts/models/test_easy_magpietts.py new file mode 100644 index 000000000000..742db9e1b8f4 --- /dev/null +++ b/tests/collections/tts/models/test_easy_magpietts.py @@ -0,0 +1,530 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from contextlib import contextmanager +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch + +import numpy as np +import pytest +import torch +from omegaconf import OmegaConf +from torch import nn + +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.models.easy_magpietts import EasyMagpieTTSModel +from nemo.collections.tts.models.easy_magpietts_inference import EasyModelInferenceParameters, TrainingMode +from tests.collections.tts.models.test_audio_codec import create_codec_config + + +pytestmark = pytest.mark.unit + +BPE_TOKENIZER_NAME = "nemotron_bpe" +BPE_TOKENIZER_MODEL = "nvidia/NVIDIA-Nemotron-Nano-9B-v2" +BPE_TOKENIZER_CACHED_PATH = Path("/home/TestData/nvidia--NVIDIA-Nemotron-Nano-9B-v2/") +if BPE_TOKENIZER_CACHED_PATH.exists(): + BPE_TOKENIZER_MODEL = str(BPE_TOKENIZER_CACHED_PATH) + + +@pytest.fixture(autouse=True, scope="module") +def _default_device_cuda(): + """Run this module with a CUDA default device without leaking it to later modules.""" + if not torch.cuda.is_available(): + yield + return + + prev = torch.get_default_device() + torch.set_default_device("cuda") + try: + yield + finally: + torch.set_default_device(prev) + + +def _restore_codec_as_random_initialized_model(*args, **kwargs): + del args + if kwargs.get("return_config", False): + return create_codec_config() + + codec_cfg = kwargs.get("override_config_path", None) + if codec_cfg is None: + codec_cfg = create_codec_config() + codec_model = AudioCodecModel(cfg=codec_cfg) + codec_model.freeze() + return codec_model + + +@contextmanager +def _codec_restore_uses_random_initialized_audio_codec(): + from nemo.collections.tts.models import easy_magpietts_inference + + with patch.object( + easy_magpietts_inference.AudioCodecModel, + "restore_from", + staticmethod(_restore_codec_as_random_initialized_model), + ): + yield + + +def _seed_everything(): + random.seed(42) + np.random.seed(42) + torch.manual_seed(42) + + +def tiny_easy_magpie_cfg(overrides=None): + cfg = OmegaConf.create( + { + "codecmodel_path": "dummy_codec.nemo", + "decoder_type": "nemotron_h", + "embedding_dim": 32, + "hidden_dim": 32, + "audio_embedding_dim": 16, + "frame_stacking_factor": 1, + "local_transformer_type": "none", + "disable_lm_text_head": True, + "disable_subword_embedding": False, + "use_bpe_char_tokenizer": True, + "text_conditioning_tokenizer_name": BPE_TOKENIZER_NAME, + "use_multiturn_dataset": False, + "run_val_inference": False, + "use_utmos": False, + "cfg_unconditional_prob": 0.0, + "dropout_text_input_prob": 0.0, + "phoneme_corruption_batch_prob": 0.0, + "phoneme_corruption_timestep_ratio": 0.0, + "phoneme_as_text_prob": 0.0, + "mask_user_on_loss": True, + "text_tokenizers": { + BPE_TOKENIZER_NAME: { + "_target_": "AutoTokenizer", + "pretrained_model": BPE_TOKENIZER_MODEL, + } + }, + "training_modes": [ + { + "text_input_mode": "streaming", + "streaming_phonemes_delay": 1, + "streaming_speech_delay": 2, + } + ], + "nemotron_h_config": { + "hidden_size": 32, + "num_hidden_layers": 1, + "vocab_size": 64, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "mamba_num_heads": 4, + "mamba_head_dim": 8, + "ssm_state_size": 8, + "n_groups": 2, + "intermediate_size": 64, + "hybrid_override_pattern": "*", + "use_cache": True, + "_attn_implementation": "sdpa", + }, + "optimizer": { + "_target_": "torch.optim.AdamW", + "lr": 0.001, + }, + } + ) + if overrides is not None: + cfg = OmegaConf.merge(cfg, overrides) + return cfg + + +def _make_easy_magpie_model(cfg=None): + _seed_everything() + with _codec_restore_uses_random_initialized_audio_codec(): + model = EasyMagpieTTSModel(cfg or tiny_easy_magpie_cfg()) + model.eval() + return model + + +@pytest.fixture() +def model(): + return _make_easy_magpie_model() + + +def _padded_token_tensor(model, texts): + tokenized = [model.tokenizer.encode(text, tokenizer_name=BPE_TOKENIZER_NAME) + [model.eos_id] for text in texts] + lens = torch.tensor([len(tokens) for tokens in tokenized], dtype=torch.long) + max_len = int(lens.max().item()) + padded = torch.full((len(tokenized), max_len), model.pad_id, dtype=torch.long) + for idx, tokens in enumerate(tokenized): + padded[idx, : len(tokens)] = torch.tensor(tokens, dtype=torch.long) + return padded, lens + + +def _toy_codes(model, batch_size, num_frames): + codes = torch.zeros(batch_size, model.num_audio_codebooks, num_frames, dtype=torch.long) + frame_ids = torch.arange(num_frames, dtype=torch.long) + for batch_idx in range(batch_size): + for codebook_idx in range(model.num_audio_codebooks): + codes[batch_idx, codebook_idx] = (frame_ids + batch_idx + codebook_idx * 7) % model.codebook_size + return codes + + +def _toy_batch(model): + text, text_lens = _padded_token_tensor(model, ["abc", "de"]) + context_text_tokens, context_text_tokens_lens = _padded_token_tensor(model, ["hi", "ok"]) + + audio_codes = _toy_codes(model, batch_size=2, num_frames=4) + audio_codes_lens = torch.tensor([4, 3], dtype=torch.long) + context_audio_codes = _toy_codes(model, batch_size=2, num_frames=2) + context_audio_codes[1, :, 1] = 0 + context_audio_codes_lens = torch.tensor([2, 1], dtype=torch.long) + agent_mask = torch.tensor( + [ + [False, True, True, False, False], + [True, False, True, False, False], + ], + dtype=torch.bool, + ) + + return { + "text": text, + "text_lens": text_lens, + "context_text_tokens": context_text_tokens, + "context_text_tokens_lens": context_text_tokens_lens, + "audio_codes": audio_codes, + "audio_codes_lens": audio_codes_lens, + "context_audio_codes": context_audio_codes, + "context_audio_codes_lens": context_audio_codes_lens, + "agent_mask": agent_mask, + "task": ["tts", "tts"], + } + + +@pytest.fixture() +def toy_batch(model): + return _toy_batch(model) + + +def test_training_mode_and_inference_parameters(): + mode = TrainingMode( + text_input_mode="streaming", + streaming_phonemes_delay=4, + streaming_speech_delay=8, + mode_idx=2, + ) + assert mode.name == "streaming_4_8" + + params = EasyModelInferenceParameters.from_dict( + { + "max_decoder_steps": 11, + "temperature": 0.25, + "topk": 7, + "cfg_scale": 1.5, + "unknown_key": "ignored", + } + ) + assert params == EasyModelInferenceParameters( + max_decoder_steps=11, + temperature=0.25, + topk=7, + cfg_scale=1.5, + ) + + +def test_easy_magpietts_model_construction(model): + expected_device = "cuda" if torch.cuda.is_available() else "cpu" + assert next(model.parameters()).device.type == expected_device + assert model.tokenizer is not None + assert model.decoder is not None + assert len(model.audio_embeddings) == model.num_audio_codebooks * model.frame_stacking_factor + assert isinstance(model.audio_in_projection, nn.Linear) + assert isinstance(model.audio_out_projection, nn.Linear) + assert model.final_proj.out_features == model.num_audio_codebooks * model.num_all_tokens_per_codebook + assert model.audio_bos_id == model.codebook_size + assert model.audio_eos_id == model.codebook_size + 1 + assert model.training_modes[0].name == "streaming_1_2" + assert model.default_inference_mode == "streaming_1_2" + assert model.lm_text_head is None + assert model.use_bpe_char_tokenizer + assert model.text_conditioning_tokenizer_name == BPE_TOKENIZER_NAME + assert hasattr(model, "cas_encoder") + + +def test_state_dict_excludes_codec(model): + state = model.state_dict() + assert state + assert not any("_codec_model" in key for key in state) + assert any(key.startswith("audio_embeddings.") for key in state) + assert any(key.startswith("final_proj.") for key in state) + + +def test_audio_and_text_embedding_shapes(model): + audio_tokens = _toy_codes(model, batch_size=2, num_frames=3) + audio_tokens[0, :, -1] = model.audio_eos_id + audio_embedded = model.embed_audio_tokens(audio_tokens) + assert audio_embedded.shape == (2, 3, model.cfg.embedding_dim) + assert audio_embedded.dtype == torch.float32 + assert torch.isfinite(audio_embedded).all() + + text_tokens, text_lens = _padded_token_tensor(model, ["abc", "de"]) + text_embedded = model.embed_text_tokens(text_tokens, text_lens=text_lens) + assert text_embedded.shape == (2, text_tokens.size(1), model.cfg.embedding_dim) + assert text_embedded.dtype == torch.float32 + assert torch.isfinite(text_embedded).all() + + +def test_stack_codes_round_trip_expected_shape(model): + codes = _toy_codes(model, batch_size=2, num_frames=4) + codes_lens = torch.tensor([4, 4], dtype=torch.long) + + stacked, stacked_lens = model.stack_codes( + codes, + codes_lens, + bos_id=model.audio_bos_id, + eos_id=model.audio_eos_id, + stacking_factor=2, + num_codebooks=model.num_audio_codebooks, + ) + unstacked, unstacked_lens = model.unstack_codes(stacked, stacked_lens, stacking_factor=2) + + assert stacked.shape == (2, model.num_audio_codebooks * 2, 2) + assert stacked_lens.tolist() == [2, 2] + assert unstacked.shape == codes.shape + assert unstacked_lens.tolist() == codes_lens.tolist() + torch.testing.assert_close(unstacked, codes) + + +def test_compute_loss_with_and_without_agent_mask(model): + _seed_everything() + batch_size, num_frames = 2, 5 + audio_codes = torch.randint( + low=0, + high=model.num_all_tokens_per_codebook, + size=(batch_size, model.num_audio_codebooks, num_frames), + ) + audio_codes_lens = torch.tensor([5, 3], dtype=torch.long) + logits = torch.randn( + batch_size, + num_frames, + model.num_audio_codebooks * model.num_all_tokens_per_codebook, + ) + agent_mask = torch.tensor( + [ + [True, True, False, False, False], + [False, True, True, False, False], + ], + dtype=torch.bool, + ) + + loss, loss_mask = model.compute_loss(logits, audio_codes, audio_codes_lens) + masked_loss, masked_loss_mask = model.compute_loss( + logits, + audio_codes, + audio_codes_lens, + agent_mask_target=agent_mask, + ) + + assert loss.ndim == 0 + assert masked_loss.ndim == 0 + assert torch.isfinite(loss) + assert torch.isfinite(masked_loss) + assert loss_mask.shape == (batch_size, model.num_audio_codebooks, num_frames) + assert masked_loss_mask.shape == loss_mask.shape + assert loss_mask.dtype == torch.bool + + +def test_prepare_audio_channel_embeddings_shapes(model): + audio_codes = _toy_codes(model, batch_size=2, num_frames=3) + audio_codes[1, :, 2] = 0 + audio_codes_lens = torch.tensor([3, 2], dtype=torch.long) + delay = torch.tensor([2, 1], dtype=torch.long) + agent_mask = torch.tensor( + [ + [True, False, False, False], + [False, True, False, False], + ], + dtype=torch.bool, + ) + + embeddings, lens, targets, target_lens, loss_agent_mask = model.prepare_audio_channel_embeddings( + audio_codes=audio_codes, + audio_codes_lens=audio_codes_lens, + delay=delay, + agent_mask=agent_mask, + ) + + assert embeddings.shape == (2, int((delay + target_lens).max().item()), model.cfg.embedding_dim) + assert embeddings.dtype == torch.float32 + assert torch.isfinite(embeddings).all() + assert lens.tolist() == (delay + target_lens).tolist() + assert targets.shape == (2, model.num_audio_codebooks, int(target_lens.max().item())) + assert target_lens.tolist() == [4, 3] + assert loss_agent_mask.shape == (2, targets.size(2)) + assert loss_agent_mask.dtype == torch.bool + + +def test_forward_with_inputs_embeds(model): + _seed_everything() + inputs_embeds = torch.randn(2, 6, model.cfg.embedding_dim) + attention_mask = torch.ones(2, 6, dtype=torch.bool) + + output = model.forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, use_cache=True) + + assert output.last_hidden_state.shape == inputs_embeds.shape + assert output.last_hidden_state.dtype == torch.float32 + assert torch.isfinite(output.last_hidden_state).all() + assert output.past_key_values is not None + + +def test_logits_to_audio_codes_schema(model): + logits = torch.zeros(2, 4, model.num_audio_codebooks * model.num_all_tokens_per_codebook) + expected_tokens = [] + for codebook_idx in range(model.num_audio_codebooks): + token_id = codebook_idx + 3 + expected_tokens.append(token_id) + offset = codebook_idx * model.num_all_tokens_per_codebook + logits[:, :, offset + token_id] = 5.0 + audio_codes_lens = torch.tensor([4, 2], dtype=torch.long) + + audio_codes = model.logits_to_audio_codes(logits, audio_codes_lens) + + assert audio_codes.shape == (2, model.num_audio_codebooks, 4) + assert audio_codes.dtype == torch.long + for codebook_idx, token_id in enumerate(expected_tokens): + assert audio_codes[0, codebook_idx].tolist() == [token_id] * 4 + assert audio_codes[1, :, 2:].eq(0).all() + + +def test_process_batch_smoke(model, toy_batch): + _seed_everything() + output = model.process_batch( + text=toy_batch["text"], + text_lens=toy_batch["text_lens"], + context_text_tokens=toy_batch["context_text_tokens"], + context_text_tokens_lens=toy_batch["context_text_tokens_lens"], + audio_codes=toy_batch["audio_codes"], + audio_codes_lens=toy_batch["audio_codes_lens"], + context_audio_codes=toy_batch["context_audio_codes"], + context_audio_codes_lens=toy_batch["context_audio_codes_lens"], + mode="val", + training_mode=model.training_modes[0], + agent_mask=toy_batch["agent_mask"], + ) + + assert output.selected_training_mode == model.default_inference_mode + assert torch.isfinite(output.loss) + assert torch.isfinite(output.codebook_loss) + assert output.phoneme_loss is None + assert output.local_transformer_loss is None + assert output.logits.shape[:2] == output.audio_codes_target.shape[0::2] + assert output.logits.shape[-1] == model.num_audio_codebooks * model.num_all_tokens_per_codebook + assert output.audio_codes_target.dtype == torch.long + assert output.context_audio_codes.shape[1] == model.num_audio_codebooks + + +def test_process_batch_with_multiturn_dataset_enabled(): + _seed_everything() + model = _make_easy_magpie_model(tiny_easy_magpie_cfg({"use_multiturn_dataset": True})) + batch = _toy_batch(model) + text = batch["text"].clone() + text[0, 0] = model.interruption_token_id + + output = model.process_batch( + text=text, + text_lens=batch["text_lens"], + context_text_tokens=batch["context_text_tokens"], + context_text_tokens_lens=batch["context_text_tokens_lens"], + audio_codes=batch["audio_codes"], + audio_codes_lens=batch["audio_codes_lens"], + context_audio_codes=batch["context_audio_codes"], + context_audio_codes_lens=batch["context_audio_codes_lens"], + mode="val", + training_mode=model.training_modes[0], + task=batch["task"], + agent_mask=batch["agent_mask"], + ) + + assert text[0, 0].item() == model.pad_id + assert torch.isfinite(output.loss) + assert torch.isfinite(output.codebook_loss) + assert output.local_transformer_loss is None + + +def test_process_batch_with_autoregressive_local_transformer(): + _seed_everything() + model = _make_easy_magpie_model( + tiny_easy_magpie_cfg( + { + "local_transformer_type": "autoregressive", + "local_transformer_hidden_dim": 32, + "local_transformer_n_layers": 1, + "local_transformer_n_heads": 4, + "local_transformer_loss_scale": 0.5, + } + ) + ) + batch = _toy_batch(model) + + output = model.process_batch( + text=batch["text"], + text_lens=batch["text_lens"], + context_text_tokens=batch["context_text_tokens"], + context_text_tokens_lens=batch["context_text_tokens_lens"], + audio_codes=batch["audio_codes"], + audio_codes_lens=batch["audio_codes_lens"], + context_audio_codes=batch["context_audio_codes"], + context_audio_codes_lens=batch["context_audio_codes_lens"], + mode="val", + training_mode=model.training_modes[0], + agent_mask=batch["agent_mask"], + ) + + assert torch.isfinite(output.loss) + assert torch.isfinite(output.codebook_loss) + assert torch.isfinite(output.local_transformer_loss) + assert output.local_transformer_logits is not None + assert output.local_transformer_logits.shape == output.logits.shape + + +def test_training_step_smoke(model, toy_batch): + _seed_everything() + model.train() + + with ( + patch.object(model, "log", lambda *args, **kwargs: None), + patch.object(model, "log_dict", lambda *args, **kwargs: None), + ): + loss = model.training_step(toy_batch, batch_idx=0) + + assert loss.ndim == 0 + assert torch.isfinite(loss) + assert loss.item() > 0 + + +def test_validation_step_smoke(model, toy_batch, tmp_path): + _seed_everything() + model.eval() + object.__setattr__( + model, + "_trainer", + SimpleNamespace(world_size=1, global_rank=0, local_rank=0, log_dir=str(tmp_path), current_epoch=0), + ) + + with patch.object(model, "log_val_audio_example", lambda *args, **kwargs: {}): + output = model.validation_step(toy_batch, batch_idx=1) + + assert set(output.keys()) == {"val_loss", "val_codebook_loss", "val_local_transformer_loss"} + assert torch.isfinite(output["val_loss"]) + assert torch.isfinite(output["val_codebook_loss"]) + assert output["val_local_transformer_loss"] is None + assert model.validation_step_outputs[-1] == output diff --git a/uv.lock b/uv.lock index a6be21cea190..31407662ce89 100644 --- a/uv.lock +++ b/uv.lock @@ -10168,4 +10168,4 @@ source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, -] +] \ No newline at end of file