diff --git a/base.py b/base.py new file mode 100644 index 0000000..bbff50d --- /dev/null +++ b/base.py @@ -0,0 +1,445 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional, Union, List, Dict, Any +import json +import numpy as np +import logging +from huggingface_hub import hf_hub_download +from sea_g2p import Normalizer + +# Configure logging +logger = logging.getLogger("Vieneu") + +class BaseVieneuTTS(ABC): + """ + Abstract base class for VieNeu-TTS implementations. + Provides shared functionality for voice management and common operations. + """ + + def __init__(self, codec_repo: Optional[str] = None, codec_device: str = "cpu"): + self.sample_rate = 24_000 + self.max_context = 4096 + self.hop_length = 480 + + # Default streaming parameters + self.streaming_overlap_frames = 1 + self.streaming_frames_per_chunk = 50 + self.streaming_lookforward = 5 + self.streaming_lookback = 50 + self.streaming_stride_samples = self.streaming_frames_per_chunk * self.hop_length + + self.assets_dir = Path(__file__).parent / "assets" + self._preset_voices: Dict[str, Any] = {} + self._default_voice: Optional[str] = None + self.normalizer = Normalizer() + self._ref_phoneme_cache: Dict[str, str] = {} + + # Watermarker placeholder + self.watermarker = None + self._init_watermarker() + + if codec_repo: + self._load_codec(codec_repo, codec_device) + + def _load_codec(self, codec_repo: str, codec_device: str) -> None: + """Universal codec loader for all backends.""" + logger.info(f"📦 Loading codec from: {codec_repo} on {codec_device} ...") + + if any(x in codec_repo.lower() for x in ["onnx", "vieneu-codec"]) or codec_repo == "neuphonic/neucodec-onnx-decoder-int8": + if codec_device != "cpu": + logger.warning("⚠️ ONNX decoder only runs on CPU. Ignoring device selection.") + try: + from .utils import NeuCodecOnnx + self.codec = NeuCodecOnnx.from_pretrained(codec_repo) + self._is_onnx_codec = True + return + except Exception as e: + logger.warning(f"Failed to load standalone ONNX decoder: {e}. Trying via neucodec package...") + try: + from neucodec import NeuCodecOnnxDecoder + self.codec = NeuCodecOnnxDecoder.from_pretrained(codec_repo) + self._is_onnx_codec = True + return + except ImportError: + raise ImportError( + "The 'onnxruntime' package is required for ONNX decoder. \n" + "Please install it via: pip install onnxruntime" + ) from e + + # For PyTorch codecs, check for torch first + try: + import torch + from neucodec import NeuCodec, DistillNeuCodec + + # Check MPS + if codec_device == "mps" and not torch.backends.mps.is_available(): + logger.warning("⚠️ MPS not available for codec, falling back to CPU") + codec_device = "cpu" + + if codec_repo == "neuphonic/neucodec": + self.codec = NeuCodec.from_pretrained(codec_repo) + elif codec_repo == "neuphonic/distill-neucodec": + self.codec = DistillNeuCodec.from_pretrained(codec_repo) + else: + raise ValueError(f"Unrecognized codec repository: {codec_repo}") + + self.codec.eval().to(codec_device) + except ImportError: + raise ImportError( + f"Codec '{codec_repo}' requires PyTorch. \n" + "To remain lightweight in Remote mode, please use 'neuphonic/neucodec-onnx-decoder-int8'. \n" + "Or install torch via: pip install vieneu[gpu]" + ) + + + def _init_watermarker(self) -> None: + """Initialize optional audio watermarker.""" + try: + import perth + self.watermarker = perth.PerthImplicitWatermarker() + logger.info("🔒 Audio watermarking initialized (Perth)") + except (ImportError, AttributeError): + self.watermarker = None + + def _load_voices(self, backbone_repo: Optional[str], hf_token: Optional[str] = None, clear_existing: bool = False) -> None: + """Unified voice loading for Local and Remote paths.""" + if not backbone_repo: + return + + path_obj = Path(backbone_repo) + if path_obj.exists(): + # Local Path (Dir or File) + if path_obj.is_dir(): + json_path = path_obj / "voices.json" + else: + json_path = path_obj.parent / "voices.json" + + if json_path.exists(): + self._load_voices_from_file(json_path, clear_existing=clear_existing) + else: + if clear_existing: + self._preset_voices.clear() + logger.warning(f"Validation Warning: Local path '{backbone_repo}' missing 'voices.json'.") + logger.warning(f"Falling back to Custom Voice Cloning mode.") + else: + # Remote Repo + if clear_existing: + self._preset_voices.clear() + + try: + self._load_voices_from_repo(backbone_repo, hf_token) + except Exception as e: + logger.warning(f"Could not load voices from repo '{backbone_repo}': {e}") + logger.warning(f"Falling back to Custom Voice Cloning mode.") + + def _load_voices_from_file(self, file_path: Path, clear_existing: bool = False) -> None: + """Load voices from a local JSON file.""" + try: + if not file_path.exists(): + logger.error(f"Voice file not found: {file_path}") + return + + with open(file_path, 'r', encoding='utf-8') as f: + try: + data = json.load(f) + except json.JSONDecodeError as e: + logger.error(f"Invalid JSON in voice file {file_path}: {e}") + return + + if "presets" in data: + if clear_existing: + self._preset_voices.clear() + logger.info("🧹 Cleared existing voices for replacement") + + # Merge into existing presets + self._preset_voices.update(data["presets"]) + logger.info(f"📢 Loaded {len(data['presets'])} voices from {file_path.name}") + + # Update default voice if provided + if "default_voice" in data and data["default_voice"]: + self._default_voice = data["default_voice"] + + except Exception as e: + logger.error(f"Failed to load voices from {file_path}: {e}") + + def _load_voices_from_repo(self, repo_id: str, hf_token: Optional[str] = None) -> None: + """Download and load voices.json from a HuggingFace repo.""" + voices_file = None + try: + # 1. Try normal download (checks for updates from server) + voices_file = hf_hub_download( + repo_id=repo_id, + filename="voices.json", + token=hf_token, + repo_type="model" + ) + except Exception: + # 2. Network error? Try to use cached version if available + logger.warning(f"Network check failed for voices.json. Trying local cache...") + try: + voices_file = hf_hub_download( + repo_id=repo_id, + filename="voices.json", + token=hf_token, + repo_type="model", + local_files_only=True + ) + logger.info(f"✅ Using cached voices.json") + except Exception: + # 3. No cache available either + pass + + if voices_file: + self._load_voices_from_file(Path(voices_file)) + else: + logger.warning(f"Repository '{repo_id}' is missing 'voices.json'. Falling back to Custom Voice mode.") + + def list_preset_voices(self) -> List[tuple[str, str]]: + """List available preset voices as (description, id).""" + return [ + (v.get("description", k) if isinstance(v, dict) else str(v), k) + for k, v in self._preset_voices.items() + ] + + def get_preset_voice(self, voice_name: Optional[str] = None) -> Dict[str, Any]: + """ + Get reference codes and text for a preset voice. + + Args: + voice_name: Name of voice. If None, uses default_voice. + + Returns: + dict: { 'codes': Union[np.ndarray, 'torch.Tensor'], 'text': str } + """ + if voice_name is None: + voice_name = self._default_voice + if voice_name is None: + if self._preset_voices: + voice_name = next(iter(self._preset_voices)) + else: + raise ValueError("No voice specified and no preset voices available.") + + if voice_name not in self._preset_voices: + raise ValueError(f"Voice '{voice_name}' not found. Available: {self.list_preset_voices()}") + + voice_data = self._preset_voices[voice_name] + codes = voice_data["codes"] + + # Only convert to torch if explicitly requested or if we're not in turbo mode + if isinstance(codes, list): + if codes and isinstance(codes[0], float): + codes = np.array(codes, dtype=np.float32) + else: + # Là integer token sequence (Standard mode) + try: + import torch + codes = torch.tensor(codes, dtype=torch.long) + except ImportError: + codes = np.array(codes, dtype=np.int64) + + return {"codes": codes, "text": voice_data["text"]} + + def get_ref_phonemes(self, ref_text: str) -> str: + """ + Get phonemized version of reference text, using cache if available. + """ + if ref_text not in self._ref_phoneme_cache: + from vieneu_utils.phonemize_text import phonemize_with_dict + self._ref_phoneme_cache[ref_text] = phonemize_with_dict(ref_text) + return self._ref_phoneme_cache[ref_text] + + def save(self, audio: np.ndarray, output_path: Union[str, Path]) -> None: + """Save audio waveform to a file.""" + import soundfile as sf + sf.write(str(output_path), audio, self.sample_rate) + + def encode_reference(self, ref_audio_path: Union[str, Path]) -> Union[np.ndarray, 'torch.Tensor']: + """ + Encode reference audio to codes. + + Args: + ref_audio_path: Path to the reference audio file. + + Returns: + Union[np.ndarray, torch.Tensor]: Encoded codes. + """ + import librosa + wav, _ = librosa.load(ref_audio_path, sr=16000, mono=True) + + # If we have an ONNX encoder or specialized turbo encoder, handle it here + # For now, default backends still use torch + try: + import torch + wav_tensor = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0) # [1, 1, T] + + # Ensure device and dtype compatibility + if hasattr(self.codec, "device"): + wav_tensor = wav_tensor.to(self.codec.device) + + with torch.no_grad(): + ref_codes = self.codec.encode_code(audio_or_path=wav_tensor).squeeze(0).squeeze(0) + return ref_codes + except ImportError: + raise ImportError("Torch is required for encode_reference in the current backend. Please install torch or use a backend that supports standalone encoding.") + + def _decode(self, codes_str: str) -> np.ndarray: + """ + Decode speech tokens to audio waveform. + + Args: + codes_str: String containing speech tokens. + + Returns: + np.ndarray: Decoded audio waveform. + """ + from .utils import extract_speech_ids + speech_ids = extract_speech_ids(codes_str) + + if len(speech_ids) == 0: + raise ValueError("No valid speech tokens found in the output.") + + # Onnx decode + if getattr(self, "_is_onnx_codec", False): + codes = np.array(speech_ids, dtype=np.int32)[np.newaxis, np.newaxis, :] + recon = self.codec.decode_code(codes) + # Torch decode + else: + try: + import torch + with torch.no_grad(): + codes = torch.tensor(speech_ids, dtype=torch.long)[None, None, :] + if hasattr(self.codec, "device"): + codes = codes.to(self.codec.device) + + recon = self.codec.decode_code(codes) + if hasattr(recon, "cpu"): + recon = recon.cpu() + if hasattr(recon, "numpy"): + recon = recon.numpy() + except ImportError: + raise ImportError("Torch is required for the current codec backend. Please install torch or use an ONNX-based codec.") + + + return recon[0, 0, :] + + def _resolve_ref_voice( + self, + voice: Optional[Dict[str, Any]] = None, + ref_audio: Optional[Union[str, Path]] = None, + ref_codes: Optional[Union[np.ndarray, 'torch.Tensor']] = None, + ref_text: Optional[str] = None + ) -> tuple[Union[np.ndarray, 'torch.Tensor'], str]: + """Resolve reference voice codes and text.""" + if voice is not None: + ref_codes = voice.get('codes', ref_codes) + ref_text = voice.get('text', ref_text) + + if ref_audio is not None and ref_codes is None: + ref_codes = self.encode_reference(ref_audio) + elif self._default_voice and (ref_codes is None or ref_text is None): + try: + voice_data = self.get_preset_voice(None) + ref_codes = voice_data['codes'] + ref_text = voice_data['text'] + except Exception: + pass + + if ref_codes is None or ref_text is None: + raise ValueError("Must provide either 'voice' dict or both 'ref_codes' and 'ref_text'.") + + return ref_codes, ref_text + + def _apply_watermark(self, wav: np.ndarray) -> np.ndarray: + """Apply watermark to audio if enabled.""" + if self.watermarker: + return self.watermarker.apply_watermark(wav, sample_rate=self.sample_rate) + return wav + + def to_list(self, codes: Any) -> List[int]: + """Convert reference codes (Tensor, Array, List) to a Python list of integers.""" + if isinstance(codes, list): + return codes + if isinstance(codes, np.ndarray): + return codes.flatten().tolist() + + # Check for torch without importing it at module level + try: + import torch + if isinstance(codes, torch.Tensor): + return codes.flatten().tolist() + except ImportError: + pass + + # Fallback for other array-like types + if hasattr(codes, "tolist"): + return codes.flatten().tolist() if hasattr(codes, "flatten") else codes.tolist() + + return list(codes) + + def _format_prompt( + self, + ref_codes: Any, + ref_text: str, + input_text: str, + ref_phonemes: Optional[str] = None, + input_phonemes: Optional[str] = None, + use_chat_format: bool = False, + emotion_tag: Optional[str] = None + ) -> str: + """ + Format the prompt for the TTS model. + Common implementation for LMDeploy (Fast) and Remote backends. + Standard backend uses a specialized chat template via tokenizer. + + Args: + use_chat_format: If True, wraps the prompt with chat-style user/assistant + tokens (used by VieNeu-TTS GPU model). If False (default), + returns a compact prompt without those wrappers. + """ + ref_codes_list = self.to_list(ref_codes) + + # Import inside method to avoid potential circular dependencies between + # base TTS and phonemization utilities. + from vieneu_utils.phonemize_text import phonemize_with_dict + + ref_text_phones = ref_phonemes if ref_phonemes else self.get_ref_phonemes(ref_text) + input_text_phones = input_phonemes if input_phonemes else phonemize_with_dict(input_text, skip_normalize=True) + codes_str = "".join([f"<|speech_{idx}|>" for idx in ref_codes_list]) + + emotion_prefix = emotion_tag if emotion_tag else "" + + if use_chat_format: + return ( + f"user: Convert the text to speech:<|TEXT_PROMPT_START|>{emotion_prefix}{ref_text_phones} {input_text_phones}" + f"<|TEXT_PROMPT_END|>\nassistant:<|SPEECH_GENERATION_START|>{codes_str}" + ) + return ( + f"<|TEXT_PROMPT_START|>{emotion_prefix}{ref_text_phones} {input_text_phones}" + f"<|TEXT_PROMPT_END|><|SPEECH_GENERATION_START|>{codes_str}" + ) + + @abstractmethod + def infer(self, text: str, apply_watermark: bool = True, **kwargs: Any) -> np.ndarray: + """Main inference method for single text.""" + pass + + @abstractmethod + def infer_batch(self, texts: List[str], apply_watermark: bool = True, **kwargs: Any) -> List[np.ndarray]: + """Main inference method for batch processing.""" + pass + + def close(self) -> None: + """Release resources.""" + pass + + def __enter__(self) -> 'BaseVieneuTTS': + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.close() + + def __del__(self): + try: + self.close() + except Exception: + pass diff --git a/boards/WEACT_STM32WB55CGU6/STM32WB55CGU6.jpg b/boards/WEACT_STM32WB55CGU6/STM32WB55CGU6.jpg new file mode 100644 index 0000000..12d40a5 Binary files /dev/null and b/boards/WEACT_STM32WB55CGU6/STM32WB55CGU6.jpg differ diff --git a/gradio_main.py b/gradio_main.py new file mode 100644 index 0000000..32647dd --- /dev/null +++ b/gradio_main.py @@ -0,0 +1,2005 @@ +import sys +import io +if sys.platform == "win32": + try: + sys.stdout.reconfigure(encoding="utf-8") + sys.stderr.reconfigure(encoding="utf-8") + except Exception: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace") + sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace") + +import gradio as gr +print("⏳ Đang khởi động VieNeu-TTS... Vui lòng chờ...") +import soundfile as sf +import tempfile +from vieneu import Vieneu +import os +import time +import numpy as np +import queue +import threading +import yaml +import uuid +from vieneu_utils.core_utils import split_text_into_chunks, join_audio_chunks, env_bool, get_silence_duration_v2 +from vieneu_utils.phonemize_text import phonemize_to_chunks +from sea_g2p import Normalizer +import gc + +from apps.ui_utils import ( + _format_duration, + _split_estimate_status, + wrap_with_estimate, + cleanup_gpu_memory, + get_ref_text_cached, + on_codec_change, + validate_audio_duration, + on_custom_id_change +) +from apps.ui_constants import ( + theme, + css, + head_html, + DEFAULT_TEXT_GPU, + DEFAULT_TEXT_TURBO +) + +# --- CONSTANTS & CONFIG --- +CONFIG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "config.yaml") +try: + with open(CONFIG_PATH, "r", encoding="utf-8") as f: + _config = yaml.safe_load(f) or {} +except Exception as e: + raise RuntimeError(f"Không thể đọc config.yaml: {e}") + +BACKBONE_CONFIGS = _config.get("backbone_configs", {}) +CODEC_CONFIGS = _config.get("codec_configs", {}) + +# Refilter and Simplify Configs per requirements +HAS_GPU = False +try: + import torch + HAS_GPU = torch.cuda.is_available() or (sys.platform == "darwin" and torch.backends.mps.is_available()) +except ImportError: + pass + +filtered_backbones = {} +if HAS_GPU: + filtered_backbones["VieNeu-TTS-v2 (GPU)"] = { + "repo": "pnnbao-ump/VieNeu-TTS-v2", + "supports_streaming": False, + "description": "VieNeu-TTS Version 2 - hỗ trợ song ngữ (Anh-Việt) và chế độ podcast" + } + filtered_backbones["VieNeu-TTS (GPU)"] = { + "repo": "pnnbao-ump/VieNeu-TTS", + "supports_streaming": False, + "description": "VieNeu-TTS Version 1 - ổn định, production-ready" + } + filtered_backbones["VieNeu-TTS-0.3B-ngoc-huyen (GPU)"] = { + "repo": "pnnbao-ump/VieNeu-TTS-0.3B-ngoc-huyen", + "supports_streaming": False, + "description": "VieNeu-TTS-0.3B - Ngọc Huyền" + } + +filtered_backbones["VieNeu-TTS-v2 (CPU)"] = { + "repo": "pnnbao-ump/VieNeu-TTS-v2", + "gguf_filename": "VieNeu-TTS-v2-Q4-K-M.gguf", + "supports_streaming": False, + "description": "VieNeu-TTS-v2 (CPU) - GGUF Q4_K_M, hỗ trợ song ngữ & podcast" +} + +filtered_backbones["VieNeu-TTS-v2-Turbo (CPU)"] = { + "repo": "pnnbao-ump/VieNeu-TTS-v2-Turbo-GGUF", + "supports_streaming": True, + "description": "VieNeu-TTS-v2-Turbo - Siêu nhanh, tối ưu tuyệt đối cho CPU & Thiết bị yếu" +} + +BACKBONE_CONFIGS = filtered_backbones + +filtered_codecs = { + "NeuCodec (Distill)": { + "repo": "neuphonic/distill-neucodec", + "description": "Codec mặc định cho model GPU", + "use_preencoded": False + }, + "NeuCodec (ONNX)": { + "repo": "neuphonic/neucodec-onnx-decoder-int8", + "description": "Codec siêu nhẹ, tối ưu cho CPU (ONNX)", + "use_preencoded": False + }, + "VieNeu-Codec": { + "repo": "pnnbao-ump/VieNeu-Codec", + "description": "Codec tối ưu cho Turbo v2 (ONNX)", + "use_preencoded": False + } +} +CODEC_CONFIGS = filtered_codecs + +_text_settings = _config.get("text_settings", {}) +MAX_CHARS_PER_CHUNK = _text_settings.get("max_chars_per_chunk", 256) +MAX_TOTAL_CHARS_STREAMING = _text_settings.get("max_total_chars_streaming", 3000) + +if not BACKBONE_CONFIGS or not CODEC_CONFIGS: + raise ValueError("config.yaml thiếu backbone_configs hoặc codec_configs") + +# --- 1. MODEL CONFIGURATION --- +# Global model instance +tts = None +current_backbone = None +current_codec = None +model_loaded = False +using_lmdeploy = False +PRESET_VOICES_CACHE = [] # List of all voices (tuples or strings) +CONV_VOICES_CACHE = [] # Filtered list for conversation (podcast=True) +MAX_SPEAKERS = 8 # Max concurrent speakers in conversation tab + +# Normalizer (module-level singleton) +_text_normalizer = Normalizer() + +def get_available_devices() -> list[str]: + """Get list of available devices for current platform.""" + devices = ["Auto", "CPU"] + + try: + import torch + if sys.platform == "darwin" and torch.backends.mps.is_available(): + devices.append("MPS") + elif torch.cuda.is_available(): + devices.append("CUDA") + except ImportError: + pass + + return devices + +def get_model_status_message() -> str: + """Reconstruct status message from global state""" + global model_loaded, tts, using_lmdeploy, current_backbone, current_codec + if not model_loaded or tts is None: + return "⏳ Chưa tải model." + + if "v2-Turbo" in (current_backbone or ""): + backend_name = "⚡ Turbo (v2)" + elif using_lmdeploy: + backend_name = "🚀 LMDeploy (Optimized)" + else: + backend_name = "📦 Standard" + + # We don't track the exact device strings perfectly in global state, so we estimate + try: + import torch + has_mps = torch.backends.mps.is_available() + has_cuda = torch.cuda.is_available() + except: + has_mps = has_cuda = False + + device_info = "GPU (CUDA)" if (using_lmdeploy or "CUDA" in (current_backbone or "")) else ("MPS (Metal)" if has_mps else "Auto") + + if "v2-Turbo" in (current_backbone or ""): + codec_device = "GPU/MPS" if (has_cuda or has_mps) else "CPU" + elif "ONNX" in (current_codec or ""): + codec_device = "CPU" + else: + codec_device = "GPU/MPS" if (has_cuda or has_mps) else "CPU" + + preencoded_note = "" + opt_info = "" + if using_lmdeploy and hasattr(tts, 'get_optimization_stats'): + stats = tts.get_optimization_stats() + opt_info = ( + f"\n\n🔧 Tối ưu hóa:" + f"\n • Triton: {'✅' if stats['triton_enabled'] else '❌'}" + f"\n • Max Batch Size (Default): {stats.get('max_batch_size', 'N/A')}" + f"\n • Reference Cache: {stats['cached_references']} voices" + f"\n • Prefix Caching: ❌" + ) + + return ( + f"✅ Model đã tải thành công!\n\n" + f"🔧 Backend: {backend_name}\n" + f" Parrot: {current_backbone} on {device_info}\n" + f"🎵 Codec: {current_codec} on {codec_device}{preencoded_note}{opt_info}" + ) + +def restore_ui_state(): + """Update UI components based on persistence""" + global model_loaded + msg = get_model_status_message() + return ( + msg, + gr.update(interactive=model_loaded), # btn_generate + gr.update(interactive=model_loaded), # btn_generate_conv + gr.update(interactive=False) # btn_stop + ) + +def should_use_lmdeploy(backbone_choice: str, device_choice: str) -> bool: + """Determine if we should use LMDeploy backend.""" + # LMDeploy not supported on macOS + if sys.platform == "darwin": + return False + + if "gguf" in backbone_choice.lower() or "v2-turbo" in backbone_choice.lower(): + return False + + try: + import torch + if device_choice == "Auto": + has_gpu = torch.cuda.is_available() + elif device_choice == "CUDA": + has_gpu = torch.cuda.is_available() + else: + has_gpu = False + return has_gpu + except ImportError: + return False + +def load_model(backbone_choice: str, codec_choice: str, device_choice: str, + force_lmdeploy: bool, custom_model_id: str = "", custom_base_model: str = "", + custom_hf_token: str = ""): + """Load model with optimizations and max batch size control""" + global tts, current_backbone, current_codec, model_loaded, using_lmdeploy + lmdeploy_error_reason = None + model_loaded = False # Ensure we don't try to use a half-loaded model + + # Helper for slot updates (initially no change) + slot_no_updates = [gr.update()] * MAX_SPEAKERS + + yield ( + "⏳ Đang tải model với tối ưu hóa... Lưu ý: Quá trình này sẽ tốn thời gian. Vui lòng kiên nhẫn.", + gr.update(interactive=False), # btn_generate + gr.update(interactive=False), # btn_generate_conv + gr.update(interactive=False), # btn_load + gr.update(interactive=False), # btn_stop + gr.update(), # voice_select + gr.update(), gr.update(), gr.update(), gr.update(), # tab_p, tab_c, tab_sel, mode_state + gr.update(), # conv_tab + *slot_no_updates + ) + + try: + # Cleanup before loading new model + if tts is not None: + tts = None # Reset instead of del to avoid NameError if load fails + cleanup_gpu_memory() + + # Prepare Backbone Config/Repo + custom_loading = False + is_merged_lora = False + + if backbone_choice == "Custom Model": + custom_loading = True + if not custom_model_id or not custom_model_id.strip(): + yield ( + "❌ Lỗi: Vui lòng nhập Model ID cho Custom Model.", + gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=False), gr.update(), + gr.update(), gr.update(), gr.update(), gr.update(), + gr.update(), # conv_tab + *slot_no_updates + ) + return + + # Check if it is a LoRA to merge + if "lora" in custom_model_id.lower(): + # Merging mode + print(f"🔄 Detected LoRA in name. preparing merge with base: {custom_base_model}") + if custom_base_model not in BACKBONE_CONFIGS: + yield ( + f"❌ Lỗi: Base Model '{custom_base_model}' không hợp lệ.", + gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=False), + gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), + gr.update(), # conv_tab + *slot_no_updates + ) + return + + base_config = BACKBONE_CONFIGS[custom_base_model] + backbone_config = { + "repo": base_config["repo"], # Load base first + "supports_streaming": base_config["supports_streaming"], + "description": f"Custom Merged: {custom_model_id} + {custom_base_model}" + } + is_merged_lora = True + else: + # Normal custom model + backbone_config = { + "repo": custom_model_id.strip(), + "supports_streaming": False, # Assume false for unknown + "description": f"Custom Model: {custom_model_id}" + } + else: + backbone_config = BACKBONE_CONFIGS[backbone_choice] + + codec_config = CODEC_CONFIGS[codec_choice] + use_lmdeploy = False + + # Override LMDeploy if custom + if custom_loading: + if "gguf" in backbone_config['repo'].lower() or "v2-turbo" in backbone_config['repo'].lower(): + # GGUF must use Standard/Turbo backend + use_lmdeploy = False + elif is_merged_lora: + # LoRA can use LMDeploy if we merge first (checked logic below) or Standard + use_lmdeploy = force_lmdeploy and should_use_lmdeploy(custom_base_model, device_choice) + else: + # Full custom model (e.g. finetune) + use_lmdeploy = force_lmdeploy and should_use_lmdeploy("VieNeu-TTS (GPU)", device_choice) # Assume GPU compatible? + # Use LMDeploy only if Force LMDeploy is set and the model is compatible + # NOTE: For VieNeu-v2-Turbo, we handle LMDeploy inside TurboGPUVieNeuTTS class, + # so we set use_lmdeploy = False here to avoid generic FastVieNeuTTS loading. + # NOTE: For custom_loading, the block above already decided use_lmdeploy correctly + # (e.g. False for GGUF repos). Do NOT override that decision here. + if "v2-Turbo" in backbone_choice: + should_use_generic_fast = False + elif custom_loading: + should_use_generic_fast = False # already handled above per repo name + else: + should_use_generic_fast = force_lmdeploy and should_use_lmdeploy(backbone_choice, device_choice) + + if should_use_generic_fast: + use_lmdeploy = True + + if use_lmdeploy: + lmdeploy_error_reason = None + print(f"🚀 Using LMDeploy backend with optimizations") + + backbone_device = "cuda" + + if "ONNX" in codec_choice: + codec_device = "cpu" + else: + try: + import torch + codec_device = "cuda" if torch.cuda.is_available() else "cpu" + except ImportError: + codec_device = "cpu" + + # Special handling for Custom LoRA + LMDeploy -> Merge & Save + target_backbone_repo = backbone_config["repo"] + + if custom_loading and is_merged_lora: + safe_name = custom_model_id.strip().replace("/", "_").replace("\\", "_").replace(":", "") + cache_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "merged_models_cache", safe_name) + target_backbone_repo = os.path.abspath(cache_dir) + + # Check if already merged (and voices.json exists) + if not os.path.exists(cache_dir) or not os.path.exists(os.path.join(cache_dir, "vocab.json")): + print(f"🔄 Merging LoRA for LMDeploy optimization: {cache_dir}") + if os.path.exists(cache_dir): + print(" ⚠️ Detected incomplete cache, rebuilding...") + yield ( + f"⏳ Đang merge và lưu model LoRA để tối ưu cho LMDeploy (thao tác này chỉ chạy một lần)...", + gr.update(interactive=False), + gr.update(interactive=False), + gr.update(interactive=False), + gr.update(interactive=False), + gr.update(), + gr.update(), gr.update(), gr.update(), gr.update(), + gr.update(), # conv_tab + *slot_no_updates + ) + + try: + # Use GPU for merging if available for speed + # We use the Base Model specified + from vieneu.standard import VieNeuTTS + base_repo = BACKBONE_CONFIGS[custom_base_model]["repo"] + merge_device = "cuda" if torch.cuda.is_available() else "cpu" + + print(f" • Loading base: {base_repo} ({merge_device})") + temp_tts = VieNeuTTS( + backbone_repo=base_repo, + backbone_device=merge_device, + codec_repo=codec_config["repo"], + codec_device="cpu", # Codec unused for merging, keep on CPU + hf_token=custom_hf_token + ) + + print(f" • Loading Adapter: {custom_model_id}") + temp_tts.load_lora_adapter(custom_model_id.strip(), hf_token=custom_hf_token) + + print(f" • Merging...") + if hasattr(temp_tts.backbone, "merge_and_unload"): + temp_tts.backbone = temp_tts.backbone.merge_and_unload() + + print(f" • Saving to cache: {cache_dir}") + temp_tts.backbone.save_pretrained(cache_dir) + temp_tts.tokenizer.save_pretrained(cache_dir) + + # Fix for LMDeploy: Explicitly save legacy tokenizer files (vocab.json, merges.txt) + # because LMDeploy/Transformers might default to slow tokenizer if fast one has issues, + # and save_pretrained on fast tokenizer sometimes omits legacy files. + try: + print(" • Ensuring legacy tokenizer files...") + from transformers import AutoTokenizer + slow_tokenizer = AutoTokenizer.from_pretrained(base_repo, use_fast=False) + slow_tokenizer.save_pretrained(cache_dir) + except Exception as e: + print(f" ⚠️ Warning: Could not save slow tokenizer files: {e}") + + # Save voices.json to cache directory so FastVieNeuTTS can find it + print(f" • Saving voices definition...") + import json + voices_json_path = os.path.join(cache_dir, "voices.json") + voices_content = { + "meta": { "note": "Automatically generated during LoRA merge" }, + "default_voice": temp_tts._default_voice, + "presets": temp_tts._preset_voices + } + with open(voices_json_path, 'w', encoding='utf-8') as f: + json.dump(voices_content, f, ensure_ascii=False, indent=2) + + del temp_tts + cleanup_gpu_memory() + print(" ✅ Merge & Save successfully!") + + except Exception as e: + import traceback + traceback.print_exc() + raise RuntimeError(f"Failed to merge & save LoRA for LMDeploy: {e}") + + print(f"📦 Loading optimized model...") + print(f" Backbone: {target_backbone_repo} on {backbone_device}") + print(f" Codec: {codec_config['repo']} on {codec_device}") + print(f" Triton: Enabled") + + try: + from vieneu.fast import FastVieNeuTTS + tts = FastVieNeuTTS( + backbone_repo=target_backbone_repo, + backbone_device=backbone_device, + codec_repo=codec_config["repo"], + codec_device=codec_device, + memory_util=0.3, + tp=1, + enable_prefix_caching=False, + enable_triton=True, + hf_token=custom_hf_token + ) + using_lmdeploy = True + + # Legacy caching removed + print(f" ✅ Optimized backend initialized") + + except Exception as e: + import traceback + traceback.print_exc() + + error_str = str(e) + if "$env:CUDA_PATH" in error_str: + lmdeploy_error_reason = "Không tìm thấy biến môi trường CUDA_PATH. Vui lòng cài đặt NVIDIA GPU Computing Toolkit." + else: + lmdeploy_error_reason = f"{error_str}" + + yield ( + f"⚠️ LMDeploy Init Error: {lmdeploy_error_reason}. Đang loading model với backend mặc định - tốc độ chậm hơn so với lmdeploy...", + gr.update(interactive=False), + gr.update(interactive=False), + gr.update(interactive=False), + gr.update(interactive=False), + gr.update(), + gr.update(), gr.update(), gr.update(), gr.update(), + gr.update(), # conv_tab + *slot_no_updates + ) + time.sleep(1) + use_lmdeploy = False + using_lmdeploy = False + + if not use_lmdeploy: + print(f"📦 Using original backend") + + if device_choice == "Auto": + repo_lower = backbone_config['repo'].lower() + is_gguf_backbone = "gguf" in repo_lower + + if is_gguf_backbone: + # GGUF backbones (llama-cpp-python): Metal on Mac, CUDA on Windows/Linux + if sys.platform == "darwin": + backbone_device = "gpu" # llama-cpp-python uses Metal via n_gpu_layers + else: + try: + import torch + backbone_device = "gpu" if torch.cuda.is_available() else "cpu" + except ImportError: + backbone_device = "cpu" + else: + # PyTorch backbones (Standard, Turbo GPU): use native torch device + try: + import torch + if sys.platform == "darwin": + backbone_device = "mps" if torch.backends.mps.is_available() else "cpu" + else: + backbone_device = "cuda" if torch.cuda.is_available() else "cpu" + except ImportError: + backbone_device = "cpu" + + # Codec device + if "ONNX" in codec_choice: + codec_device = "cpu" + else: + try: + import torch + if sys.platform == "darwin": + codec_device = "mps" if torch.backends.mps.is_available() else "cpu" + else: + codec_device = "cuda" if torch.cuda.is_available() else "cpu" + except ImportError: + codec_device = "cpu" + + elif device_choice == "MPS": + backbone_device = "mps" + codec_device = "mps" if "ONNX" not in codec_choice else "cpu" + + else: + backbone_device = device_choice.lower() + codec_device = device_choice.lower() + + if "ONNX" in codec_choice: + codec_device = "cpu" + + if "gguf" in backbone_config['repo'].lower() and backbone_device == "cuda": + # Only Llama-cpp (GGUF) uses the 'gpu' string for CUDA + backbone_device = "gpu" + + print(f"📦 Loading model...") + print(f" Backbone: {backbone_config['repo']} on {backbone_device}") + print(f" Codec: {codec_config['repo']} on {codec_device}") + + if "v2-Turbo" in backbone_choice: + # VieNeu v2 Turbo uses the dedicated backend + print(" ⚡ Mode: Turbo") + mode = "turbo_gpu" if "GPU" in backbone_choice else "turbo" + tts = Vieneu( + mode=mode, + backbone_repo=backbone_config["repo"], + decoder_repo=codec_config["repo"], + device=backbone_device, + backend="lmdeploy" if force_lmdeploy and "GPU" in backbone_choice else "standard", + hf_token=custom_hf_token + ) + else: + from vieneu.standard import VieNeuTTS + tts = VieNeuTTS( + backbone_repo=backbone_config["repo"], + backbone_device=backbone_device, + codec_repo=codec_config["repo"], + codec_device=codec_device, + hf_token=custom_hf_token, + gguf_filename=backbone_config.get("gguf_filename") + ) + + # Perform LoRA Merge if needed (ONLY for Standard Backend) + # For LMDeploy, we handled it above by saving to disk + if is_merged_lora and custom_loading and not using_lmdeploy: + yield ( + f"🔄 Đang tải và merge LoRA adapter: {custom_model_id}...", + gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(), + gr.update(), gr.update(), gr.update(), gr.update(), + gr.update(), # conv_tab + *slot_no_updates + ) + try: + # 1. Load Adapter + tts.load_lora_adapter(custom_model_id.strip(), hf_token=custom_hf_token) + + # 2. Merge and Unload + # Check if backbone matches expected type for merge + if hasattr(tts, 'backbone') and hasattr(tts.backbone, 'merge_and_unload'): + print(" 🔄 Merging LoRA into backbone...") + tts.backbone = tts.backbone.merge_and_unload() + + # Reset LoRA state so it behaves like a normal model + tts._lora_loaded = False + tts._current_lora_repo = None + print(" ✅ Merged successfully!") + else: + print(" ⚠️ Warning: Model does not support merge_and_unload, keeping adapter active.") + + except Exception as e: + raise RuntimeError(f"Failed to merge LoRA: {e}") + + using_lmdeploy = False + + current_backbone = backbone_choice + current_codec = codec_choice + model_loaded = True + + # Success message with optimization info + backend_name = "🚀 LMDeploy (Optimized)" if using_lmdeploy else "📦 Standard" + device_info = "cuda" if use_lmdeploy else (backbone_device if not use_lmdeploy else "N/A") + + streaming_support = "✅ Có" if backbone_config['supports_streaming'] else "❌ Không" + preencoded_note = "\n⚠️ Codec này cần sử dụng pre-encoded codes (.pt files)" if codec_config['use_preencoded'] else "" + + opt_info = "" + if using_lmdeploy and hasattr(tts, 'get_optimization_stats'): + stats = tts.get_optimization_stats() + opt_info = ( + f"\n\n🔧 Tối ưu hóa:" + f"\n • Triton: {'✅' if stats['triton_enabled'] else '❌'}" + f"\n • Max Batch Size (Default): {stats.get('max_batch_size', 'N/A')}" + f"\n • Reference Cache: {stats['cached_references']} voices" + f"\n • Prefix Caching: ❌" + ) + + warning_msg = "" + if lmdeploy_error_reason: + warning_msg = ( + f"\n\n⚠️ **Cảnh báo:** Không thể kích hoạt LMDeploy (Optimized Backend) do lỗi sau:\n" + f"👉 {lmdeploy_error_reason}\n" + f"💡 Hệ thống đã tự động chuyển về chế độ Standard (chậm hơn)." + ) + + success_msg = get_model_status_message() + if warning_msg: + success_msg += warning_msg + + # Prepare voice update + try: + # Get voices with descriptions for UI from SDK + voices = tts.list_preset_voices() + except Exception: + voices = [] + + has_voices = len(voices) > 0 + + if has_voices: + default_v = tts._default_voice + + # Helper to get values list + is_tuple = (len(voices) > 0 and isinstance(voices[0], tuple)) + voice_values = [v[1] for v in voices] if is_tuple else voices + + if not default_v and voice_values: + default_v = voice_values[0] + + # Ensure default_v is in the list and selected correctly + if default_v and default_v not in voice_values: + if is_tuple: + # Try to find a nice description if possible, else use ID + voices.append((default_v, default_v)) + else: + voices.append(default_v) + + # Sort voices by name/label for better UX + if is_tuple: + voices.sort(key=lambda x: str(x[0])) + else: + voices.sort() + + voice_update = gr.update(choices=voices, value=default_v, interactive=True) + + global PRESET_VOICES_CACHE, CONV_VOICES_CACHE + PRESET_VOICES_CACHE = voices + + # Filter voices for conversation tab (podcast=True) + # Handle both boolean True/False and string "True"/"False" + def _check_podcast(v_id): + val = tts._preset_voices.get(v_id, {}).get('podcast', True) + if isinstance(val, str): + return val.strip().lower() == "true" + return bool(val) + + CONV_VOICES_CACHE = [v for v in voices if _check_podcast(v[1])] + + slot_dd_update = gr.update(choices=CONV_VOICES_CACHE) + + # Show Standard Tabs + tab_p = gr.update(visible=True) + tab_c = gr.update(visible=True) + tab_sel = gr.update(selected="preset_mode") + mode_state = "preset_mode" + else: + # Missing voices.json case + msg = "⚠️ Không tìm thấy file voices.json. Vui lòng dùng Tab Voice Cloning." + voice_update = gr.update(choices=[msg], value=msg, interactive=False) + slot_dd_update = gr.update(choices=[]) + + # Show Preset Tab (to see message) and Custom Tab + tab_p = gr.update(visible=True) + tab_c = gr.update(visible=True) + tab_sel = gr.update(selected="preset_mode") + mode_state = "preset_mode" + + # Check if v2 for conversation tab + is_v2 = (backbone_choice == "VieNeu-TTS-v2 (GPU)" or backbone_choice == "VieNeu-TTS-v2 (CPU)") + conv_tab_update = gr.update(visible=is_v2) + + # Update all MAX_SPEAKERS slot dropdowns + slot_updates = [slot_dd_update] * MAX_SPEAKERS + + yield ( + success_msg, + gr.update(interactive=True), # btn_generate + gr.update(interactive=True), # btn_generate_conv + gr.update(interactive=True), # btn_load + gr.update(interactive=False), # btn_stop + voice_update, + tab_p, tab_c, tab_sel, mode_state, + conv_tab_update, + *slot_updates + ) + + except Exception as e: + import traceback + traceback.print_exc() + model_loaded = False + using_lmdeploy = False + + if "$env:CUDA_PATH" in str(e): + yield ( + "❌ Lỗi khi tải model: Không tìm thấy biến môi trường CUDA_PATH. Vui lòng cài đặt NVIDIA GPU Computing Toolkit (https://developer.nvidia.com/cuda/toolkit)", + gr.update(interactive=False), + gr.update(interactive=False), # btn_generate_conv + gr.update(interactive=True), # btn_load + gr.update(interactive=False), # btn_stop + gr.update(), # voice_select + gr.update(), gr.update(), gr.update(), gr.update(), + gr.update(), # conv_tab + *slot_no_updates + ) + else: + yield ( + f"❌ Lỗi khi tải model: {str(e)}", + gr.update(interactive=False), + gr.update(interactive=False), + gr.update(interactive=True), + gr.update(interactive=False), + gr.update(), + gr.update(), gr.update(), gr.update(), gr.update(), + gr.update(), # conv_tab + *slot_no_updates + ) + + +def resolve_voice_id(v_id: str) -> str: + """Robustly resolve voice ID, handling both display labels and internal IDs.""" + if not v_id: + return v_id + + global PRESET_VOICES_CACHE + if not PRESET_VOICES_CACHE: + return v_id + + for item in PRESET_VOICES_CACHE: + if isinstance(item, (list, tuple)) and len(item) >= 2: + label, value = item[0], item[1] + if v_id == value or v_id == label: + return value + else: + if v_id == item: + return item + + return v_id + +# --- TEXT SAFETY HELPERS --- +# Chặn input vượt context 2048 token của backbone. +MAX_INPUT_TOKENS = _text_settings.get("max_input_tokens", 1800) + + +class _ChunkProxy: + """Giữ nguyên các thuộc tính gốc của chunk, chỉ thay text.""" + + def __init__(self, base_chunk, text: str): + self._base_chunk = base_chunk + self.text = text + + def __getattr__(self, name): + return getattr(self._base_chunk, name) + + +def _get_model_tokenizer(): + """Lấy tokenizer từ tts hiện tại nếu có.""" + global tts + if tts is None: + return None + + candidate_attrs = ("tokenizer", "llm_tokenizer", "text_tokenizer", "processor") + for attr in candidate_attrs: + tok = getattr(tts, attr, None) + if tok is not None: + return tok + + backbone = getattr(tts, "backbone", None) + if backbone is not None: + for attr in candidate_attrs: + tok = getattr(backbone, attr, None) + if tok is not None: + return tok + + return None + + +def _count_tokens(tokenizer, text: str) -> int: + """Đếm token an toàn, fallback 0 nếu tokenizer lỗi.""" + if tokenizer is None: + return 0 + + try: + encoded = tokenizer(text, add_special_tokens=False) + input_ids = getattr(encoded, "input_ids", None) + + if input_ids is None and isinstance(encoded, dict): + input_ids = encoded.get("input_ids") + + if input_ids is None: + return 0 + + if isinstance(input_ids, list) and input_ids and isinstance(input_ids[0], (list, tuple)): + return len(input_ids[0]) + + return len(input_ids) + except Exception: + return 0 + + +def _split_text_by_token_limit(text: str, tokenizer, max_tokens: int = MAX_INPUT_TOKENS): + """Cắt text theo token, giữ style cắt theo từ.""" + text = (text or "").strip() + if not text: + return [] + + if tokenizer is None: + return [text] + + words = text.split() + if not words: + return [text] + + chunks = [] + current_words = [] + + for word in words: + candidate = " ".join(current_words + [word]).strip() + + if current_words and _count_tokens(tokenizer, candidate) > max_tokens: + chunks.append(" ".join(current_words).strip()) + current_words = [word] + else: + current_words.append(word) + + if current_words: + chunks.append(" ".join(current_words).strip()) + + return chunks or [text] + + +def _build_safe_text_chunks(raw_text: str, max_chars: int, is_v2_turbo: bool): + """ + Tạo chunk theo flow cũ của repo, rồi chặn thêm theo token. + - V2 Turbo: giữ object chunk gốc, chỉ thay .text nếu cần. + - Standard: giữ list str như cũ. + """ + tokenizer = _get_model_tokenizer() + + if is_v2_turbo: + base_chunks = phonemize_to_chunks(raw_text, max_chars=max_chars) + if tokenizer is None: + return base_chunks + + safe_chunks = [] + for chunk in base_chunks: + chunk_text = getattr(chunk, "text", str(chunk)) + refined_texts = _split_text_by_token_limit(chunk_text, tokenizer) + + if len(refined_texts) == 1 and refined_texts[0] == chunk_text: + safe_chunks.append(chunk) + else: + for refined in refined_texts: + safe_chunks.append(_ChunkProxy(chunk, refined)) + + return safe_chunks + + safe_chunks = [] + for raw_chunk in split_text_into_chunks(raw_text, max_chars=max_chars): + normalized_chunk = _text_normalizer.normalize(raw_chunk) + refined_texts = _split_text_by_token_limit(normalized_chunk, tokenizer) + + for refined in refined_texts: + safe_chunks.extend(split_text_into_chunks(refined, max_chars=max_chars)) + + return safe_chunks + + +# --- 2. DATA & HELPERS --- + +def synthesize_speech(text: str, voice_choice: str, custom_audio, custom_text: str, + mode_tab: str, generation_mode: str, use_batch: bool, max_batch_size_run: int, + temperature: float, max_chars_chunk: int, session_id: str = None): + """Synthesis with optimization support and max batch size control""" + global tts, current_backbone, current_codec, model_loaded, using_lmdeploy + + _STOP_EVENT.clear() # Reset for new generation + + if not model_loaded or tts is None: + yield None, "⚠️ Vui lòng tải model trước!" + return + + if not text or text.strip() == "": + yield None, "⚠️ Vui lòng nhập văn bản!" + return + + raw_text = text.strip() + + codec_config = CODEC_CONFIGS[current_codec] + use_preencoded = codec_config['use_preencoded'] + + + # Setup Reference + yield None, "📄 Đang xử lý Reference..." + + try: + ref_codes = None + ref_text_raw = "" + + if mode_tab == "preset_mode": + if not voice_choice: + raise ValueError("Vui lòng chọn giọng mẫu.") + if "⚠️" in voice_choice: + raise ValueError("Không có giọng mẫu khả dụng. Vui lòng chuyển sang Tab Voice Cloning.") + + # Use SDK method - handles caching and JSON internally + v_id = resolve_voice_id(voice_choice) + voice_data = tts.get_preset_voice(v_id) + ref_codes = voice_data['codes'] + ref_text_raw = voice_data['text'] + + elif mode_tab == "custom_mode": + if custom_audio is None: + raise ValueError("Vui lòng upload file Audio mẫu (Reference Audio)!") + + is_turbo = "v2-Turbo" in (current_backbone or "") + if not is_turbo and (not custom_text or not custom_text.strip()): + raise ValueError("Vui lòng nhập nội dung văn bản của Audio mẫu (Reference Text)!") + + ref_text_raw = custom_text.strip() if custom_text else "" + ref_codes = tts.encode_reference(custom_audio) + + # Ensure numpy for inference + if 'torch' in sys.modules: + import torch + if isinstance(ref_codes, torch.Tensor): + ref_codes = ref_codes.cpu().numpy() + + except Exception as e: + yield None, f"❌ Lỗi xử lý Reference Audio: {str(e)}" + return + + # === STANDARD MODE === + if generation_mode == "Standard (Một lần)": + backend_name = "LMDeploy" if using_lmdeploy else "Standard" + + is_v2_turbo = "v2-Turbo" in (current_backbone or "") + text_chunks = _build_safe_text_chunks(raw_text, max_chars_chunk, is_v2_turbo) + total_chunks = len(text_chunks) + + batch_info = " (Batch Mode)" if use_batch and using_lmdeploy and total_chunks > 1 else "" + + # Show batch size info + batch_size_info = "" + if use_batch and using_lmdeploy and hasattr(tts, 'max_batch_size'): + batch_size_info = f" [Max batch: {tts.max_batch_size}]" + + yield None, f"🚀 Bắt đầu tổng hợp {backend_name}{batch_info}{batch_size_info} ({total_chunks} đoạn)..." + + all_wavs = [] + sr = 24000 + + start_time = time.time() + + try: + if is_v2_turbo: + # Sequential processing with progress updates + total_chunks = len(text_chunks) + for i, chunk in enumerate(text_chunks): + if _STOP_EVENT.is_set(): + yield None, "⏹️ Đã dừng tạo giọng nói." + return + yield None, f"⚡ Turbo v2: Đang xử lý đoạn {i+1}/{total_chunks}..." + + chunk_wav = tts.infer( + chunk.text, + ref_codes=ref_codes, + temperature=temperature, + max_chars=max_chars_chunk, + skip_normalize=True, + skip_phonemize=True + ) + + if chunk_wav is not None and len(chunk_wav) > 0: + all_wavs.append(chunk_wav) + # Add silence between Gradio-level chunks for Turbo + if i < total_chunks - 1: + sil_dur = get_silence_duration_v2(chunk) + sil_wav = np.zeros(int(sr * sil_dur), dtype=np.float32) + all_wavs.append(sil_wav) + + # Use batch processing if enabled and using LMDeploy (for v1) + elif use_batch and using_lmdeploy and hasattr(tts, 'infer_batch') and total_chunks > 1: + # Process in mini-batches to allow cancellation between batches + num_batches = (total_chunks + max_batch_size_run - 1) // max_batch_size_run + total_batch_duration = 0.0 + completed_batches = 0 + + for i in range(0, total_chunks, max_batch_size_run): + if _STOP_EVENT.is_set(): + print("🛑 Synthesis stopped during batch processing.") + yield None, "⏹️ Đã dừng tạo giọng nói." + return + + batch_idx = i // max_batch_size_run + estimate_info = "" + if completed_batches > 0: + average_batch_duration = total_batch_duration / completed_batches + estimated_total = average_batch_duration * num_batches + estimated_remaining = average_batch_duration * max(0, num_batches - batch_idx) + estimate_info = ( + f" | Ước tính còn lại: {_format_duration(estimated_remaining)}" + f" / tổng: {_format_duration(estimated_total)}" + ) + yield None, f"⚡ Đang xử lý batch {batch_idx+1}/{num_batches} (đoạn {i+1}-{min(i+max_batch_size_run, total_chunks)}){estimate_info}..." + + current_batch = text_chunks[i : i + max_batch_size_run] + batch_start_time = time.time() + batch_wavs = tts.infer_batch( + current_batch, + ref_codes=ref_codes, + ref_text=ref_text_raw, + max_batch_size=max_batch_size_run, + temperature=temperature, + skip_normalize=True + ) + batch_duration = time.time() - batch_start_time + total_batch_duration += batch_duration + completed_batches += 1 + average_batch_duration = total_batch_duration / completed_batches + estimated_total = average_batch_duration * num_batches + estimated_remaining = average_batch_duration * max(0, num_batches - completed_batches) + for chunk_wav in batch_wavs: + if chunk_wav is not None and len(chunk_wav) > 0: + all_wavs.append(chunk_wav) + yield None, ( + f"✅ Xong batch {batch_idx+1}/{num_batches} " + f"(trung bình batch: {_format_duration(average_batch_duration)}, " + f"ước tính còn lại: {_format_duration(estimated_remaining)}, " + f"tổng: {_format_duration(estimated_total)})" + ) + + else: + # Sequential processing (PyTorch or GGUF v1) + for i, chunk in enumerate(text_chunks): + if _STOP_EVENT.is_set(): + yield None, "⏹️ Đã dừng tạo giọng nói." + return + yield None, f"⏳ Đang xử lý đoạn {i+1}/{total_chunks}..." + chunk_wav = tts.infer( + chunk, + ref_codes=ref_codes, + ref_text=ref_text_raw, + temperature=temperature, + max_chars=max_chars_chunk, + skip_normalize=True + ) + if chunk_wav is not None and len(chunk_wav) > 0: + all_wavs.append(chunk_wav) + + if not all_wavs: + yield None, "❌ Không sinh được audio nào." + return + + yield None, "💾 Đang ghép file và lưu..." + + # Use utility function for joining with silence/crossfade + # Default silence=0.15s to match SDK + silence_p = 0.15 if not is_v2_turbo else 0.0 # Turbo adds silence internally + final_wav = join_audio_chunks(all_wavs, sr=sr, silence_p=silence_p) + + with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: + sf.write(tmp.name, final_wav, sr) + output_path = tmp.name + + process_time = time.time() - start_time + backend_info = f" (Backend: {'LMDeploy 🚀' if using_lmdeploy else 'Standard 📦'})" + speed_info = f", Tốc độ: {len(final_wav)/sr/process_time:.2f}x realtime" if process_time > 0 else "" + + + yield output_path, f"✅ Hoàn tất! (Thời gian: {process_time:.2f}s{speed_info}){backend_info}" + + # Cleanup memory + if using_lmdeploy and hasattr(tts, 'cleanup_memory'): + tts.cleanup_memory() + + cleanup_gpu_memory() + + except Exception as e: + # Check for CUDA OOM specifically if torch is loaded + if 'torch' in sys.modules: + import torch + if isinstance(e, torch.cuda.OutOfMemoryError): + cleanup_gpu_memory() + yield None, ( + f"❌ GPU hết VRAM! Hãy thử:\n" + f"• Giảm Max Batch Size (hiện tại: {tts.max_batch_size if hasattr(tts, 'max_batch_size') else 'N/A'})\n" + f"• Giảm độ dài văn bản\n\n" + f"Chi tiết: {str(e)}" + ) + return + + import traceback + traceback.print_exc() + cleanup_gpu_memory() + yield None, f"❌ Lỗi Standard Mode: {str(e)}" + return + + # === STREAMING MODE === + else: + sr = 24000 + crossfade_samples = int(sr * 0.03) + audio_queue = queue.Queue(maxsize=100) + PRE_BUFFER_SIZE = 3 + + end_event = threading.Event() + error_event = threading.Event() + error_msg = "" + + is_v2_turbo = "v2-Turbo" in (current_backbone or "") + text_chunks = _build_safe_text_chunks(raw_text, max_chars_chunk, is_v2_turbo) + + def producer_thread(): + nonlocal error_msg + try: + previous_tail = None + + for i, chunk_text in enumerate(text_chunks): + if _STOP_EVENT.is_set(): + break + + if is_v2_turbo: + stream_gen = tts.infer_stream( + chunk_text.text, + ref_codes=ref_codes, + temperature=temperature, + max_chars=max_chars_chunk, + skip_normalize=True, + skip_phonemize=True, + emotion_tag="" + ) + else: + stream_gen = tts.infer_stream( + chunk_text, + ref_codes=ref_codes, + ref_text=ref_text_raw, + temperature=temperature, + max_chars=max_chars_chunk, + skip_normalize=True, + emotion_tag="" + ) + + for part_idx, audio_part in enumerate(stream_gen): + if _STOP_EVENT.is_set(): + break + if audio_part is None or len(audio_part) == 0: + continue + + if previous_tail is not None and len(previous_tail) > 0: + overlap = min(len(previous_tail), len(audio_part), crossfade_samples) + if overlap > 0: + fade_out = np.linspace(1.0, 0.0, overlap, dtype=np.float32) + fade_in = np.linspace(0.0, 1.0, overlap, dtype=np.float32) + + blended = (audio_part[:overlap] * fade_in + + previous_tail[-overlap:] * fade_out) + + processed = np.concatenate([ + previous_tail[:-overlap] if len(previous_tail) > overlap else np.array([]), + blended, + audio_part[overlap:] + ]) + else: + processed = np.concatenate([previous_tail, audio_part]) + + tail_size = min(crossfade_samples, len(processed)) + previous_tail = processed[-tail_size:].copy() + output_chunk = processed[:-tail_size] if len(processed) > tail_size else processed + else: + tail_size = min(crossfade_samples, len(audio_part)) + previous_tail = audio_part[-tail_size:].copy() + output_chunk = audio_part[:-tail_size] if len(audio_part) > tail_size else audio_part + + if len(output_chunk) > 0: + audio_queue.put((sr, output_chunk)) + + # Add silence between chunks for Turbo v2 + if is_v2_turbo and i < len(text_chunks) - 1: + sil_dur = get_silence_duration_v2(chunk_text) + sil_wav = np.zeros(int(sr * sil_dur), dtype=np.float32) + audio_queue.put((sr, sil_wav)) + + if previous_tail is not None and len(previous_tail) > 0: + audio_queue.put((sr, previous_tail)) + + except Exception as e: + import traceback + traceback.print_exc() + error_msg = str(e) + error_event.set() + finally: + end_event.set() + audio_queue.put(None) + + threading.Thread(target=producer_thread, daemon=True).start() + + yield (sr, np.zeros(int(sr * 0.05))), "📄 Đang buffering..." + + pre_buffer = [] + while len(pre_buffer) < PRE_BUFFER_SIZE: + try: + item = audio_queue.get(timeout=5.0) + if item is None: + break + pre_buffer.append(item) + except queue.Empty: + if error_event.is_set(): + yield None, f"❌ Lỗi: {error_msg}" + return + break + + full_audio_buffer = [] + backend_info = "🚀 LMDeploy" if using_lmdeploy else "📦 Standard" + for sr, audio_data in pre_buffer: + full_audio_buffer.append(audio_data) + yield (sr, audio_data), f"🔊 Đang phát ({backend_info})..." + + while True: + try: + item = audio_queue.get(timeout=0.05) + if item is None: + break + sr, audio_data = item + full_audio_buffer.append(audio_data) + yield (sr, audio_data), f"🔊 Đang phát ({backend_info})..." + except queue.Empty: + if error_event.is_set(): + yield None, f"❌ Lỗi: {error_msg}" + break + if end_event.is_set() and audio_queue.empty(): + break + continue + + if full_audio_buffer: + final_wav = np.concatenate(full_audio_buffer) + with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: + sf.write(tmp.name, final_wav, sr) + + yield tmp.name, f"✅ Hoàn tất Streaming! ({backend_info})" + + # Cleanup memory + if using_lmdeploy and hasattr(tts, 'cleanup_memory'): + tts.cleanup_memory() + + cleanup_gpu_memory() + +synthesize_speech_with_estimate = wrap_with_estimate(synthesize_speech) + +def synthesize_conversation_with_empty_estimate(*args): + for audio_path, status in synthesize_conversation(*args): + yield audio_path, status, "" + +# --- CANCELLATION --- +# threading.Event is a mutable object: never reassigned, always the same reference. +# All threads share the exact same object — no scoping/serialization issues. +_STOP_EVENT = threading.Event() + +# --- 3. CONVERSATION LOGIC --- + +def synthesize_conversation( + script_text: str, + *args +): + """ + Synthesizes multi-speaker conversation from a script. + + Gradio passes speaker name boxes and voice dropdowns as individual positional args. + Layout: args[0..MAX_SPEAKERS-1] = speaker names, args[MAX_SPEAKERS..2*MAX_SPEAKERS-1] = voice IDs, + args[2*MAX_SPEAKERS] = silence_duration, args[2*MAX_SPEAKERS+1] = temperature, + args[2*MAX_SPEAKERS+2] = max_chars_chunk, args[2*MAX_SPEAKERS+3] = session_id + """ + speaker_names = list(args[:MAX_SPEAKERS]) + speaker_voices = list(args[MAX_SPEAKERS:MAX_SPEAKERS*2]) + silence_duration = args[MAX_SPEAKERS * 2] + temperature = args[MAX_SPEAKERS * 2 + 1] + max_chars_chunk = args[MAX_SPEAKERS * 2 + 2] + session_id = args[MAX_SPEAKERS * 2 + 3] if len(args) > MAX_SPEAKERS * 2 + 3 else None + + global tts, model_loaded, using_lmdeploy + + _STOP_EVENT.clear() + + if not model_loaded or tts is None: + yield None, "⚠️ Vui lòng tải model trước!" + return + + if not script_text or script_text.strip() == "": + yield None, "⚠️ Vui lòng nhập kịch bản hội thoại!" + return + + # 1. Parse Script + lines = [] + for line in script_text.strip().split('\n'): + if not line.strip(): continue + if ':' in line: + parts = line.split(':', 1) + lines.append({'speaker': parts[0].strip(), 'text': parts[1].strip()}) + else: + if lines: + lines[-1]['text'] += " " + line.strip() + else: + lines.append({'speaker': 'Narrator', 'text': line.strip()}) + + if not lines: + yield None, "⚠️ Không tìm thấy lời thoại hợp lệ (định dạng Nhân vật: Lời thoại)!" + return + + # 2. Build Speaker Mapping from individual slot components + mapping = {} + for name, voice in zip(speaker_names, speaker_voices): + name = str(name).strip() if name else "" + if not name: continue + # Use lowercase key for robust matching + v_id = resolve_voice_id(str(voice)) if voice else "" + mapping[name.lower()] = { + 'type': 'Preset', + 'voice': v_id, + 'ref_text': '' + } + + + # 3. Process Each Line + all_wavs = [] + sr = 24000 + total_lines = len(lines) + + yield None, f"🎭 Đang khởi tạo hội thoại ({total_lines} câu)..." + + start_time = time.time() + + try: + for i, line in enumerate(lines): + if _STOP_EVENT.is_set(): + yield None, "⏹️ Đã dừng hội thoại." + return + spk_name = line['speaker'] + text = line['text'] + + yield None, f"⏳ [{i+1}/{total_lines}] {spk_name}: {text[:30]}..." + + # Determine voice + ref_codes = None + ref_text_val = None + current_voice_obj = None + + # Case-insensitive lookup + config = mapping.get(spk_name.lower()) + + if not config: + print(f" ⚠️ Character '{spk_name}' not found in mapping. Fallback to default.") + # Fallback to default if speaker not mapped + try: + # Get default voice data + default_v_id = tts._default_voice + if not default_v_id: + dv_list = tts.list_preset_voices() + if dv_list: + first = dv_list[0] + default_v_id = first[1] if isinstance(first, tuple) else first + + if default_v_id: + current_voice_obj = tts.get_preset_voice(default_v_id) + ref_codes = current_voice_obj['codes'] + ref_text_val = current_voice_obj['text'] + except Exception as e: + print(f" ❌ Fallback failed: {e}") + else: + try: + v_id = config['voice'] + if config['type'] == "Preset": + current_voice_obj = tts.get_preset_voice(v_id) + if current_voice_obj and 'codes' in current_voice_obj: + ref_codes = current_voice_obj['codes'] + ref_text_val = current_voice_obj['text'] + else: + print(f" ❌ Could not find codes for voice '{v_id}'") + else: # Custom + if v_id and os.path.exists(v_id): + ref_codes = tts.encode_reference(v_id) + ref_text_val = config.get('ref_text', '') + current_voice_obj = {'codes': ref_codes, 'text': ref_text_val} + print(f" 🦜 Using custom voice for '{spk_name}'") + except Exception as e: + print(f" ❌ Lỗi nạp giọng cho {spk_name} (ID: {config.get('voice')}): {e}") + + # Ensure numpy for inference + if 'torch' in sys.modules: + import torch + if isinstance(ref_codes, torch.Tensor): + ref_codes = ref_codes.cpu().numpy() + + # Infer audio + try: + wav = tts.infer( + text, + voice=current_voice_obj, # Use full voice object + ref_codes=ref_codes, # Fallback if object not supported + ref_text=ref_text_val, + temperature=temperature, + max_chars=max_chars_chunk, + emotion_tag="<|emotion_0|>" # Emotion tag for conversation + ) + + all_wavs.append(wav) + + # Add silence between turns + if i < total_lines - 1 and silence_duration > 0: + silence_len = int(sr * silence_duration) + silence = np.zeros(silence_len) + all_wavs.append(silence) + + except Exception as e: + print(f"❌ Lỗi tổng hợp câu {i+1}: {e}") + continue + + if not all_wavs: + yield None, "❌ Không thể tạo được âm thanh nào!" + return + + # 4. Merge and Output + yield None, "🪄 Đang ghép nối âm thanh..." + final_wav = np.concatenate(all_wavs) + + with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: + sf.write(tmp.name, final_wav, sr) + elapsed = time.time() - start_time + yield tmp.name, f"✅ Hoàn tất hội thoại! ({total_lines} câu, xử lý trong {elapsed:.1f}s)" + + except Exception as e: + import traceback + traceback.print_exc() + yield None, f"❌ Lỗi hệ thống: {str(e)}" + +def extract_speakers_from_script(script): + """Find unique speakers and return gr.update() lists for the 8 slot components.""" + global CONV_VOICES_CACHE + if not script: + # Hide all slots + name_updates = [gr.update(value="", visible=False)] * MAX_SPEAKERS + dd_updates = [gr.update(value=None, visible=False)] * MAX_SPEAKERS + row_updates = [gr.update(visible=False)] * MAX_SPEAKERS + return name_updates + dd_updates + row_updates + + speakers = [] + seen = set() + for line in script.strip().split('\n'): + if ':' in line: + s = line.split(':', 1)[0].strip() + if s and s not in seen: + seen.add(s) + speakers.append(s) + + # Auto-match each speaker name to a preset voice + def _best_match(name): + if not CONV_VOICES_CACHE: + return None + + name_l = name.lower() + + # 0. Manual overrides for specific common names + overrides = { + "phương": "Trúc Ly", + "dũng": "Thanh Bình", + "hùng": "Thái Sơn" + } + if name_l in overrides: + target = overrides[name_l].lower() + for v in CONV_VOICES_CACHE: + label, value = (v[0], v[1]) if isinstance(v, tuple) else (v, v) + if target in label.lower() or target in value.lower(): + return value + + # 1. Try to find name in labels or values + for v in CONV_VOICES_CACHE: + label, value = (v[0], v[1]) if isinstance(v, tuple) else (v, v) + if name_l == label.lower() or name_l == value.lower(): + return value + + # 2. Fuzzy match (contains) + for v in CONV_VOICES_CACHE: + label, value = (v[0], v[1]) if isinstance(v, tuple) else (v, v) + if name_l in label.lower() or name_l in value.lower() or label.lower() in name_l or value.lower() in name_l: + return value + + # 3. Default to first voice if no match + first_voice = CONV_VOICES_CACHE[0] + return first_voice[1] if isinstance(first_voice, tuple) else first_voice + + name_updates, dd_updates, row_updates = [], [], [] + for i in range(MAX_SPEAKERS): + if i < len(speakers): + name_updates.append(gr.update(value=speakers[i], visible=True)) + dd_updates.append(gr.update(value=_best_match(speakers[i]), choices=CONV_VOICES_CACHE, visible=True)) + row_updates.append(gr.update(visible=True)) + else: + name_updates.append(gr.update(value="", visible=False)) + dd_updates.append(gr.update(value=None, choices=CONV_VOICES_CACHE, visible=False)) + row_updates.append(gr.update(visible=False)) + + return name_updates + dd_updates + row_updates + +EXAMPLES_LIST = [ + ["Về miền Tây không chỉ để ngắm nhìn sông nước hữu tình, mà còn để cảm nhận tấm chân tình của người dân nơi đây.", "Vĩnh (nam miền Nam)"], + ["Hà Nội những ngày vào thu mang một vẻ đẹp trầm mặc và cổ kính đến lạ thường.", "Bình (nam miền Bắc)"], +] + +with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS", head=head_html) as demo: + # Session ID for cancellation tracking + session_id_state = gr.State("") + + with gr.Column(elem_classes="container"): + gr.HTML(""" +