From 0a9ce0c94018652dbc70bf5708f7ca10e8805130 Mon Sep 17 00:00:00 2001 From: Yuante Li <1957922024@qq.com> Date: Fri, 17 Apr 2026 19:49:58 +0000 Subject: [PATCH 1/7] fix: fix a bug --- src/configs/base.yml | 2 +- src/configs/compare.yml | 2 +- src/utils/model.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/configs/base.yml b/src/configs/base.yml index a6450ba..79c555c 100644 --- a/src/configs/base.yml +++ b/src/configs/base.yml @@ -5,7 +5,7 @@ env: claim_pool_size: 50 reverse: False time_control: True - streaming_tts: False + streaming_tts: true debater: - side: for diff --git a/src/configs/compare.yml b/src/configs/compare.yml index c7d6bba..1243504 100644 --- a/src/configs/compare.yml +++ b/src/configs/compare.yml @@ -5,7 +5,7 @@ env: claim_pool_size: 50 reverse: False time_control: True - streaming_tts: False + streaming_tts: true debater: - side: for diff --git a/src/utils/model.py b/src/utils/model.py index 923820d..098a9ac 100644 --- a/src/utils/model.py +++ b/src/utils/model.py @@ -5,7 +5,7 @@ from utils.constants import ATTACK_RM_PATH, SUPPORT_RM_PATH, google_api_key from utils.tool import logger - +# litellm._turn_on_debug() safety_setting = [ { "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", @@ -79,11 +79,12 @@ def HelperClient( temperature=temperature, max_tokens=max_tokens, stop=stop, + num_retries=3, **kwargs, ) else: response = litellm.completion( - model=model_name, messages=messages, temperature=temperature, max_tokens=max_tokens, stop=stop, **kwargs + model=model_name, messages=messages, temperature=temperature, max_tokens=max_tokens, stop=stop, num_retries=3, **kwargs ) responses.append(response.choices[0].message.content) return responses From 27a0a845d5f91ab82ca5ecfec3bf75a75cdf167a Mon Sep 17 00:00:00 2001 From: Yuante Li <1957922024@qq.com> Date: Sat, 2 May 2026 19:02:54 +0000 Subject: [PATCH 2/7] feat: improve tts streaming with early-cut, speed adjustment, and chunk pre-start Add early-cut splitting for oversized chunks, TTS speed adjustment post-processing, fast initial compression for short-budget rounds, pre-start of last chunk, and next-chunk context in LLM rewrites. Also fix overlap_viz.sh to accept multiple args. --- src/scripts/overlap_viz.sh | 2 +- src/tts_streaming.py | 391 +++++++++++++++++++++++++++++++------ 2 files changed, 333 insertions(+), 60 deletions(-) diff --git a/src/scripts/overlap_viz.sh b/src/scripts/overlap_viz.sh index c8931bd..e500184 100755 --- a/src/scripts/overlap_viz.sh +++ b/src/scripts/overlap_viz.sh @@ -3,7 +3,7 @@ # Default: all _*_chunks/ directories under src/ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" -PATTERN="${1:-$SCRIPT_DIR/_*_chunks}" +PATTERN="${@:-$SCRIPT_DIR/_*_chunks}" found=0 for dir in $PATTERN; do diff --git a/src/tts_streaming.py b/src/tts_streaming.py index ed7ed91..ecb69ee 100644 --- a/src/tts_streaming.py +++ b/src/tts_streaming.py @@ -36,6 +36,11 @@ MAX_REFINEMENTS = 10 MAX_PARALLEL_TTS = 8 # max concurrent background TTS threads per chunk MIN_CHUNK_CHARS = 50 # chunks shorter than this are merged into the next one +MAX_CHUNK_CHARS = 900 # long paragraphs are split into sentence sub-chunks +EARLY_CUT_RATIO = 1.25 # fs_est/target_s threshold to trigger early-cut +SPEED_ADJUST_MIN = 0.85 # TTS speed clamp lower bound +SPEED_ADJUST_MAX = 1.15 # TTS speed clamp upper bound +TARGET_WORDS_PER_SECOND = 2.6 # initial compression heuristic for short-budget rounds # -------- dataclasses -------- @@ -98,10 +103,12 @@ def _fastspeech_estimate(text: str) -> float: return float(length) -def _revise_to_n_words(client, text: str, n_words: int, prev_texts: List[str]) -> str: +def _revise_to_n_words(client, text: str, n_words: int, prev_texts: List[str], next_chunk_text: str = "") -> str: context_block = "" + context_word_count = 0 if prev_texts: joined = "\n\n".join(prev_texts) + context_word_count = LengthEstimator.count_words(joined) context_block = ( f"Here is the debate speech text that has already been delivered " f"(spoken aloud before this paragraph):\n\n" @@ -109,6 +116,20 @@ def _revise_to_n_words(client, text: str, n_words: int, prev_texts: List[str]) - f"---\n\n" ) + context_note = ( + f" The preceding speech context contains approximately {context_word_count} words." + if context_word_count > 0 + else "" + ) + + next_block = "" + if next_chunk_text: + next_block = ( + f"\n\nThe paragraph you rewrite will be followed immediately by this next paragraph " + f"(do NOT rewrite it, just ensure your output flows naturally into it):\n\n" + f"{next_chunk_text[:2000]}" + ) + resp = client.chat.completions.create( model="gpt-5-mini", messages=[ @@ -120,6 +141,7 @@ def _revise_to_n_words(client, text: str, n_words: int, prev_texts: List[str]) - "Rewrite ONLY the paragraph provided by the user. " "Preserve the argument, logical flow, and debate rhetoric. " "Do NOT add new arguments or repeat points already made in the preceding text. " + "Ensure the rewritten paragraph connects smoothly with what comes before and after it. " "Output only the rewritten paragraph, no preamble." ), }, @@ -127,9 +149,11 @@ def _revise_to_n_words(client, text: str, n_words: int, prev_texts: List[str]) - "role": "user", "content": ( f"{context_block}" - f"Rewrite the following debate paragraph to be approximately {n_words} words. " + f"Rewrite the following debate paragraph to be approximately {n_words} words." + f"{context_note} " f"Keep the debating style and the core argument intact.\n\n" f"{text[:8000]}" + f"{next_block}" ), }, ], @@ -138,13 +162,14 @@ def _revise_to_n_words(client, text: str, n_words: int, prev_texts: List[str]) - return (resp.choices[0].message.content or "").strip() -def _query_time_profiled(client, content: str, voice: str = "echo") -> Dict[str, Any]: +def _query_time_profiled(client, content: str, voice: str = "echo", speed: float = 1.0) -> Dict[str, Any]: t0 = _now() response = client.audio.speech.create( model="tts-1", voice=voice, input=content[:4096], response_format="mp3", + speed=speed, ) t1 = _now() @@ -163,10 +188,10 @@ def _query_time_profiled(client, content: str, voice: str = "echo") -> Dict[str, } -def _tts_with_retry(client, content: str, voice: str = "echo", max_attempts: int = 5) -> Dict[str, Any]: +def _tts_with_retry(client, content: str, voice: str = "echo", speed: float = 1.0, max_attempts: int = 5) -> Dict[str, Any]: for attempt in range(max_attempts): try: - return _query_time_profiled(client, content, voice=voice) + return _query_time_profiled(client, content, voice=voice, speed=speed) except Exception: if attempt == max_attempts - 1: raise @@ -227,6 +252,8 @@ def _adaptive_refine_parallel( tolerance_upper_s: float, voice: str = "echo", max_ref: int = MAX_REFINEMENTS, + next_chunk_text: str = "", + fast_initial_compress: bool = False, ) -> Dict[str, Any]: executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_PARALLEL_TTS) candidates: List[_TtsCandidate] = [] @@ -243,21 +270,41 @@ def _refine(): llm_times: List[float] = [] fs_times: List[float] = [] t0 = _now() - - _t = _now(); est = _fastspeech_estimate(cur); fs_times.append(_now() - _t) - ok = _in_range(est, target_s, tolerance_s, tolerance_upper_s) + n_ref = 0 with candidates_lock: if not stop_event.is_set(): candidates.append(_TtsCandidate( - iteration=0, text=cur, fs_estimated_s=est, + iteration=0, text=cur, fs_estimated_s=0.0, future=executor.submit(_tts_with_retry, client, cur, voice), )) - while not ok and n_ref < max_ref and not stop_event.is_set(): + if fast_initial_compress and target_s > 0: + target_words = max(12, round(target_s * TARGET_WORDS_PER_SECOND)) + _t = _now(); compressed = _revise_to_n_words(client, cur, target_words, prev_texts, next_chunk_text); llm_times.append(_now() - _t) + n_ref = 1 + _t = _now(); compressed_est = _fastspeech_estimate(compressed); fs_times.append(_now() - _t) + compressed_ok = _in_range(compressed_est, target_s, tolerance_s, tolerance_upper_s) + if not stop_event.is_set(): + with candidates_lock: + candidates.append(_TtsCandidate( + iteration=n_ref, text=compressed, fs_estimated_s=compressed_est, + future=executor.submit(_tts_with_retry, client, compressed, voice), + )) + cur = compressed + est = compressed_est + ok = compressed_ok + else: + _t = _now(); est = _fastspeech_estimate(cur); fs_times.append(_now() - _t) + ok = _in_range(est, target_s, tolerance_s, tolerance_upper_s) + with candidates_lock: + if candidates: + candidates[0].fs_estimated_s = est + + while not ok and n_ref < max_ref and not stop_event.is_set() and target_s > 0: cw = LengthEstimator.count_words(cur) tw = max(10, round(cw * target_s / est)) - _t = _now(); cur = _revise_to_n_words(client, cur, tw, prev_texts); llm_times.append(_now() - _t) + _t = _now(); cur = _revise_to_n_words(client, cur, tw, prev_texts, next_chunk_text); llm_times.append(_now() - _t) n_ref += 1 _t = _now(); est = _fastspeech_estimate(cur); fs_times.append(_now() - _t) ok = _in_range(est, target_s, tolerance_s, tolerance_upper_s) @@ -325,6 +372,22 @@ def _refine(): refine_thread.join(timeout=5) executor.shutdown(wait=False) + # ---- speed adjustment: re-TTS if audio is outside tolerance but speed can fix it ---- + audio_s = float(tts_out["audio_seconds"]) + raw_speed = audio_s / target_s if target_s > 0 else 1.0 + speed_used = 1.0 + if not _in_range(audio_s, target_s, tolerance_s, tolerance_upper_s): + clamped = max(SPEED_ADJUST_MIN, min(SPEED_ADJUST_MAX, raw_speed)) + if abs(clamped - 1.0) > 0.01: + try: + speed_tts_out = _tts_with_retry(client, chosen_cand.text, voice=voice, speed=clamped) + if abs(speed_tts_out["audio_seconds"] - target_s) < abs(audio_s - target_s): + tts_out = speed_tts_out + audio_s = float(tts_out["audio_seconds"]) + speed_used = clamped + except Exception: + pass + return { "refined_text": chosen_cand.text, "n_ref_used": refine_stats.get("n_ref", 0), @@ -339,7 +402,8 @@ def _refine(): "llm_times_s": refine_stats.get("llm_times", []), "fs_times_s": refine_stats.get("fs_times", []), "iter_tts_times_s": iter_tts_times, - "audio_seconds": float(tts_out["audio_seconds"]), + "speed_used": speed_used, + "audio_seconds": audio_s, "tts_api_s": float(tts_out["tts_api_s"]), "mp3_parse_s": float(tts_out["mp3_parse_s"]), "mp3_bytes": tts_out["mp3_bytes"], @@ -347,6 +411,44 @@ def _refine(): # -------- chunk utilities -------- +def _split_sentences(text: str) -> List[str]: + """Split text into sentences on '.', '!', '?' boundaries.""" + import re + parts = re.split(r'(?<=[.!?])\s+', text.strip()) + return [p for p in parts if p.strip()] + + +def _early_cut_chunk( + text: str, + target_s: float, + early_cut_ratio: float = EARLY_CUT_RATIO, +) -> Tuple[str, str]: + """ + Split text into (head, tail) where head's FS estimate ≈ target_s. + Returns (head, tail); tail may be empty if the whole text fits. + Only called when fs_estimate(text) / target_s > early_cut_ratio. + """ + sentences = _split_sentences(text) + if len(sentences) <= 1: + return text, "" + + head_sentences: List[str] = [] + for sent in sentences: + candidate = " ".join(head_sentences + [sent]) + est = _fastspeech_estimate(candidate) + if est > target_s and head_sentences: + break + head_sentences.append(sent) + + if not head_sentences: + head_sentences = [sentences[0]] + + head = " ".join(head_sentences) + tail_sentences = sentences[len(head_sentences):] + tail = " ".join(tail_sentences) + return head, tail + + def _merge_short_chunks(segments: List[str], min_chars: int = MIN_CHUNK_CHARS) -> List[str]: result = list(segments) i = 0 @@ -359,6 +461,30 @@ def _merge_short_chunks(segments: List[str], min_chars: int = MIN_CHUNK_CHARS) - return result +def _split_long_chunks(segments: List[str], max_chars: int = MAX_CHUNK_CHARS) -> List[str]: + result: List[str] = [] + for segment in segments: + if len(segment) <= max_chars: + result.append(segment) + continue + + current: List[str] = [] + current_len = 0 + for sentence in _split_sentences(segment): + extra_len = len(sentence) + (1 if current else 0) + if current and current_len + extra_len > max_chars: + result.append(" ".join(current)) + current = [sentence] + current_len = len(sentence) + else: + current.append(sentence) + current_len += extra_len + if current: + result.append(" ".join(current)) + + return result + + def split_by_paragraphs(text: str) -> List[str]: """Split text on double newlines into non-empty paragraphs.""" parts = [p.strip() for p in text.split("\n\n") if p.strip()] @@ -373,6 +499,8 @@ def run_pipeline( tolerance_ratio: float = TOLERANCE_RATIO, voice: str = "echo", out_dir: Optional[Path] = None, + enable_early_cut: bool = True, + early_cut_ratio: float = EARLY_CUT_RATIO, ) -> Tuple[List[ChunkProfile], RoundProfile, bytes, List[str]]: """ Run the streaming TTS pipeline on a list of text segments. @@ -389,7 +517,7 @@ def run_pipeline( audio_total = 0.0 overrun_total = 0.0 - segments_list = _merge_short_chunks(segments_list) + segments_list = list(_merge_short_chunks(_split_long_chunks(segments_list))) n_chunks = len(segments_list) audio_budget_remaining = total_budget_s @@ -397,13 +525,24 @@ def run_pipeline( out_dir = Path(out_dir) out_dir.mkdir(parents=True, exist_ok=True) + fast_initial_compress = total_budget_s <= 120 + final_texts: List[str] = [] combined_audio = AudioSegment.silent(duration=0) all_mp3_bytes: List[bytes] = [] prev_audio_s: Optional[float] = None - for i, chunk in enumerate(segments_list): + # Pre-start future for the last chunk (filled at chunk n_chunks-2) + _last_chunk_future: Optional[concurrent.futures.Future] = None + _last_chunk_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None + _last_chunk_lead_s = 0.0 + + i = 0 + while i < len(segments_list): + chunk = segments_list[i] + n_chunks = len(segments_list) # may grow due to early-cut + chunk_t0 = _now() chunk_words = LengthEstimator.count_words(chunk) chunk_chars = len(chunk) @@ -411,6 +550,22 @@ def run_pipeline( remaining_chars_total = sum(len(c) for c in segments_list[i:]) target_s = audio_budget_remaining * (chunk_chars / remaining_chars_total) + # ---- early-cut: if chunk is too long relative to budget, split it now ---- + if enable_early_cut and i > 0: + fs_pre = _fastspeech_estimate(chunk) + if fs_pre / target_s > early_cut_ratio: + head, tail = _early_cut_chunk(chunk, target_s, early_cut_ratio) + if tail: + segments_list[i] = head + segments_list.insert(i + 1, tail) + chunk = head + n_chunks = len(segments_list) + chunk_chars = len(chunk) + chunk_words = LengthEstimator.count_words(chunk) + remaining_chars_total = sum(len(c) for c in segments_list[i:]) + target_s = audio_budget_remaining * (chunk_chars / remaining_chars_total) + print(f" chunk {i:03d} | early-cut: fs_pre={fs_pre:.1f}s > {early_cut_ratio}x target={target_s:.1f}s → split into head({len(head)}c)+tail({len(tail)}c)") + tol_s = max(MIN_TOLERANCE_S, target_s * tolerance_ratio) remaining_chunks = n_chunks - i tol_upper_s = ( @@ -420,6 +575,7 @@ def run_pipeline( ) max_ref = 3 if i < n_chunks // 2 else MAX_REFINEMENTS + next_chunk_text = segments_list[i + 1] if i + 1 < len(segments_list) else "" seg = None mp3_bytes = b"" @@ -427,59 +583,140 @@ def run_pipeline( tts_api_s = 0.0 mp3_parse_s = 0.0 - # ---- chunk 0: no refinement, sequential TTS ---- + # ---- chunk 0: optionally race original TTS against a target-sized rewrite ---- if i == 0: time_budget_s = target_s - refined = chunk - n_ref_used = 0 - fs_estimated_s = 0.0 - refine_total_s = 0.0 - in_range = True - timed_out = False - n_candidates_submitted = 1 - n_candidates_done = 1 - used_candidate_iter = 0 - iter_llm_times_s = "[]" - iter_fs_times_s = "[]" - iter_tts_times_s = "[]" - - for attempt in range(10): + if fast_initial_compress: + ref_out = _adaptive_refine_parallel( + client, chunk, + target_s=target_s, + time_budget_s=time_budget_s, + prev_texts=[], + tolerance_s=tol_s, + tolerance_upper_s=tol_upper_s, + voice=voice, + max_ref=1, + next_chunk_text=next_chunk_text, + fast_initial_compress=True, + ) + + refined = ref_out["refined_text"] + n_ref_used = ref_out["n_ref_used"] + fs_estimated_s = ref_out["fs_estimated_s"] + refine_total_s = ref_out["refine_total_s"] + in_range = ref_out["target_reached"] + timed_out = ref_out["timed_out"] + n_candidates_submitted = ref_out["n_candidates_submitted"] + n_candidates_done = ref_out["n_candidates_done"] + used_candidate_iter = ref_out["used_candidate_iter"] + iter_llm_times_s = json.dumps([round(t, 3) for t in ref_out["llm_times_s"]]) + iter_fs_times_s = json.dumps([round(t, 3) for t in ref_out["fs_times_s"]]) + iter_tts_times_s = json.dumps(ref_out["iter_tts_times_s"]) + total_elapsed_s = ref_out["total_elapsed_s"] + audio_seconds = ref_out["audio_seconds"] + tts_api_s = ref_out["tts_api_s"] + mp3_parse_s = ref_out["mp3_parse_s"] + mp3_bytes = ref_out["mp3_bytes"] + overrun_s = max(0.0, total_elapsed_s - time_budget_s) + try: - tts_out = _query_time_profiled(client, refined, voice=voice) - audio_seconds = float(tts_out["audio_seconds"]) - tts_api_s = float(tts_out["tts_api_s"]) - mp3_parse_s = float(tts_out["mp3_parse_s"]) - mp3_bytes = tts_out["mp3_bytes"] seg = AudioSegment.from_file(BytesIO(mp3_bytes), format="mp3") - break - except (Exception, CouldntDecodeError) as e: - if attempt == 9: - if isinstance(e, CouldntDecodeError): - import warnings - warnings.warn(f"Chunk {i}: pydub decode failed after 10 attempts: {e}") - seg = AudioSegment.silent(duration=int(audio_seconds * 1000)) - break - raise RuntimeError(f"TTS/decode failed after 10 attempts: {e}") from e - time.sleep(2.0 * (attempt + 1)) - - iter_tts_times_s = json.dumps([round(tts_api_s, 3)]) - total_elapsed_s = tts_api_s - overrun_s = tts_api_s + except CouldntDecodeError as e: + import warnings + warnings.warn(f"Chunk {i}: pydub decode failed: {e}") + seg = AudioSegment.silent(duration=int(audio_seconds * 1000)) + else: + refined = chunk + n_ref_used = 0 + fs_estimated_s = 0.0 + refine_total_s = 0.0 + in_range = True + timed_out = False + n_candidates_submitted = 1 + n_candidates_done = 1 + used_candidate_iter = 0 + iter_llm_times_s = "[]" + iter_fs_times_s = "[]" + iter_tts_times_s = "[]" + + for attempt in range(10): + try: + tts_out = _query_time_profiled(client, refined, voice=voice) + audio_seconds = float(tts_out["audio_seconds"]) + tts_api_s = float(tts_out["tts_api_s"]) + mp3_parse_s = float(tts_out["mp3_parse_s"]) + mp3_bytes = tts_out["mp3_bytes"] + seg = AudioSegment.from_file(BytesIO(mp3_bytes), format="mp3") + break + except (Exception, CouldntDecodeError) as e: + if attempt == 9: + if isinstance(e, CouldntDecodeError): + import warnings + warnings.warn(f"Chunk {i}: pydub decode failed after 10 attempts: {e}") + seg = AudioSegment.silent(duration=int(audio_seconds * 1000)) + break + raise RuntimeError(f"TTS/decode failed after 10 attempts: {e}") from e + time.sleep(2.0 * (attempt + 1)) + + iter_tts_times_s = json.dumps([round(tts_api_s, 3)]) + total_elapsed_s = tts_api_s + overrun_s = tts_api_s # ---- chunks 1+: parallel refinement with background TTS candidates ---- else: time_budget_s = prev_audio_s - ref_out = _adaptive_refine_parallel( - client, chunk, - target_s=target_s, - time_budget_s=time_budget_s, - prev_texts=final_texts, - tolerance_s=tol_s, - tolerance_upper_s=tol_upper_s, - voice=voice, - max_ref=max_ref, - ) + is_last = (i == len(segments_list) - 1) + if is_last and _last_chunk_future is not None: + try: + pre_out = _last_chunk_future.result(timeout=max(0.0, time_budget_s)) + except concurrent.futures.TimeoutError: + pre_out = _last_chunk_future.result() + except Exception: + pre_out = None + if pre_out is not None: + pre_out = dict(pre_out) + pre_out["total_elapsed_s"] = max( + 0.0, + float(pre_out["total_elapsed_s"]) - _last_chunk_lead_s, + ) + ref_out = pre_out + print( + f" [pre-start] adopted pre-start result: " + f"effective_wait={pre_out['total_elapsed_s']:.1f}s, " + f"lead={_last_chunk_lead_s:.1f}s, target={target_s:.1f}s" + ) + else: + ref_out = _adaptive_refine_parallel( + client, chunk, + target_s=target_s, + time_budget_s=time_budget_s, + prev_texts=final_texts, + tolerance_s=tol_s, + tolerance_upper_s=tol_upper_s, + voice=voice, + max_ref=max_ref, + next_chunk_text=next_chunk_text, + fast_initial_compress=fast_initial_compress, + ) + if _last_chunk_executor is not None: + _last_chunk_executor.shutdown(wait=False) + _last_chunk_executor = None + _last_chunk_future = None + _last_chunk_lead_s = 0.0 + else: + ref_out = _adaptive_refine_parallel( + client, chunk, + target_s=target_s, + time_budget_s=time_budget_s, + prev_texts=final_texts, + tolerance_s=tol_s, + tolerance_upper_s=tol_upper_s, + voice=voice, + max_ref=max_ref, + next_chunk_text=next_chunk_text, + fast_initial_compress=fast_initial_compress, + ) refined = ref_out["refined_text"] n_ref_used = ref_out["n_ref_used"] @@ -514,6 +751,36 @@ def run_pipeline( final_texts.append(refined) all_mp3_bytes.append(mp3_bytes) + # ---- pre-start last chunk refinement two chunks early ---- + if ( + i == len(segments_list) - 3 + and len(segments_list) >= 3 + and _last_chunk_future is None + ): + last_idx = len(segments_list) - 1 + last_text = segments_list[last_idx] + last_chars = len(last_text) + remaining_chars_est = sum(len(c) for c in segments_list[i + 1:]) + last_target_s_est = audio_budget_remaining * (last_chars / max(remaining_chars_est, 1)) + last_tol_s_est = max(MIN_TOLERANCE_S, last_target_s_est * tolerance_ratio) + last_tol_upper_s_est = max(MIN_TOLERANCE_S, last_target_s_est * TOLERANCE_RATIO_UPPER) + _last_chunk_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + _last_chunk_future = _last_chunk_executor.submit( + _adaptive_refine_parallel, + client, last_text, + last_target_s_est, + total_budget_s, + list(final_texts), + last_tol_s_est, + last_tol_upper_s_est, + voice, + MAX_REFINEMENTS, + "", + fast_initial_compress, + ) + _last_chunk_lead_s = prev_audio_s or 0.0 + print(f" [pre-start] chunk {last_idx} kicked off at chunk {i}, est_target={last_target_s_est:.1f}s") + if out_dir is not None: (out_dir / f"chunk_{i:03d}.txt").write_text(refined, encoding="utf-8") (out_dir / f"chunk_{i:03d}.mp3").write_bytes(mp3_bytes) @@ -576,6 +843,8 @@ def run_pipeline( f"overrun={overrun_s:.2f}s | remaining={audio_budget_remaining:.1f}s" ) + i += 1 + if out_dir is not None: sep = "\n\n" + ("=" * 80) + "\n\n" (out_dir / "chunks_final.txt").write_text(sep.join(final_texts), encoding="utf-8") @@ -588,7 +857,7 @@ def run_pipeline( round_t1 = _now() round_profile = RoundProfile( - n_chunks=n_chunks, + n_chunks=len(segments_list), total_budget_s=total_budget_s, tolerance_ratio=tolerance_ratio, round_total_s=round_t1 - round_t0, @@ -615,6 +884,8 @@ def convert_text_to_speech_streaming( output_path: str, total_budget_s: float, voice: str = "echo", + enable_early_cut: bool = True, + early_cut_ratio: float = EARLY_CUT_RATIO, ) -> Tuple[str, str, float]: """ Streaming TTS: split content into chunks, adaptively refine each chunk's @@ -647,6 +918,8 @@ def convert_text_to_speech_streaming( total_budget_s=total_budget_s, voice=voice, out_dir=output_path.parent / f"{output_path.stem}_chunks", + enable_early_cut=enable_early_cut, + early_cut_ratio=early_cut_ratio, ) # Save combined audio From 9f533cd396f2f821fb907396f278144d98495d5e Mon Sep 17 00:00:00 2001 From: Yuante Li <1957922024@qq.com> Date: Sun, 3 May 2026 03:26:06 +0000 Subject: [PATCH 3/7] refactor: simplify tts streaming pipeline - remove long-chunk splitting and fast-initial-compress paths - chunk 0 always uses sequential TTS without refinement - track prep_start_lead_s on ChunkProfile for adopted pre-started chunks - default enable_early_cut to False --- src/tts_streaming.py | 177 +++++++++++-------------------------------- 1 file changed, 44 insertions(+), 133 deletions(-) diff --git a/src/tts_streaming.py b/src/tts_streaming.py index ecb69ee..b054610 100644 --- a/src/tts_streaming.py +++ b/src/tts_streaming.py @@ -36,11 +36,9 @@ MAX_REFINEMENTS = 10 MAX_PARALLEL_TTS = 8 # max concurrent background TTS threads per chunk MIN_CHUNK_CHARS = 50 # chunks shorter than this are merged into the next one -MAX_CHUNK_CHARS = 900 # long paragraphs are split into sentence sub-chunks EARLY_CUT_RATIO = 1.25 # fs_est/target_s threshold to trigger early-cut SPEED_ADJUST_MIN = 0.85 # TTS speed clamp lower bound SPEED_ADJUST_MAX = 1.15 # TTS speed clamp upper bound -TARGET_WORDS_PER_SECOND = 2.6 # initial compression heuristic for short-budget rounds # -------- dataclasses -------- @@ -70,6 +68,7 @@ class ChunkProfile: iter_llm_times_s: str # JSON list iter_fs_times_s: str # JSON list iter_tts_times_s: str # JSON list + prep_start_lead_s: float = 0.0 # >0 only for adopted pre-started last chunk; how much earlier than events[i-1].play_start the prep actually began @dataclass @@ -253,7 +252,6 @@ def _adaptive_refine_parallel( voice: str = "echo", max_ref: int = MAX_REFINEMENTS, next_chunk_text: str = "", - fast_initial_compress: bool = False, ) -> Dict[str, Any]: executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_PARALLEL_TTS) candidates: List[_TtsCandidate] = [] @@ -270,37 +268,16 @@ def _refine(): llm_times: List[float] = [] fs_times: List[float] = [] t0 = _now() - n_ref = 0 + _t = _now(); est = _fastspeech_estimate(cur); fs_times.append(_now() - _t) + ok = _in_range(est, target_s, tolerance_s, tolerance_upper_s) with candidates_lock: if not stop_event.is_set(): candidates.append(_TtsCandidate( - iteration=0, text=cur, fs_estimated_s=0.0, + iteration=0, text=cur, fs_estimated_s=est, future=executor.submit(_tts_with_retry, client, cur, voice), )) - if fast_initial_compress and target_s > 0: - target_words = max(12, round(target_s * TARGET_WORDS_PER_SECOND)) - _t = _now(); compressed = _revise_to_n_words(client, cur, target_words, prev_texts, next_chunk_text); llm_times.append(_now() - _t) - n_ref = 1 - _t = _now(); compressed_est = _fastspeech_estimate(compressed); fs_times.append(_now() - _t) - compressed_ok = _in_range(compressed_est, target_s, tolerance_s, tolerance_upper_s) - if not stop_event.is_set(): - with candidates_lock: - candidates.append(_TtsCandidate( - iteration=n_ref, text=compressed, fs_estimated_s=compressed_est, - future=executor.submit(_tts_with_retry, client, compressed, voice), - )) - cur = compressed - est = compressed_est - ok = compressed_ok - else: - _t = _now(); est = _fastspeech_estimate(cur); fs_times.append(_now() - _t) - ok = _in_range(est, target_s, tolerance_s, tolerance_upper_s) - with candidates_lock: - if candidates: - candidates[0].fs_estimated_s = est - while not ok and n_ref < max_ref and not stop_event.is_set() and target_s > 0: cw = LengthEstimator.count_words(cur) tw = max(10, round(cw * target_s / est)) @@ -461,30 +438,6 @@ def _merge_short_chunks(segments: List[str], min_chars: int = MIN_CHUNK_CHARS) - return result -def _split_long_chunks(segments: List[str], max_chars: int = MAX_CHUNK_CHARS) -> List[str]: - result: List[str] = [] - for segment in segments: - if len(segment) <= max_chars: - result.append(segment) - continue - - current: List[str] = [] - current_len = 0 - for sentence in _split_sentences(segment): - extra_len = len(sentence) + (1 if current else 0) - if current and current_len + extra_len > max_chars: - result.append(" ".join(current)) - current = [sentence] - current_len = len(sentence) - else: - current.append(sentence) - current_len += extra_len - if current: - result.append(" ".join(current)) - - return result - - def split_by_paragraphs(text: str) -> List[str]: """Split text on double newlines into non-empty paragraphs.""" parts = [p.strip() for p in text.split("\n\n") if p.strip()] @@ -499,7 +452,7 @@ def run_pipeline( tolerance_ratio: float = TOLERANCE_RATIO, voice: str = "echo", out_dir: Optional[Path] = None, - enable_early_cut: bool = True, + enable_early_cut: bool = False, early_cut_ratio: float = EARLY_CUT_RATIO, ) -> Tuple[List[ChunkProfile], RoundProfile, bytes, List[str]]: """ @@ -517,7 +470,7 @@ def run_pipeline( audio_total = 0.0 overrun_total = 0.0 - segments_list = list(_merge_short_chunks(_split_long_chunks(segments_list))) + segments_list = list(_merge_short_chunks(segments_list)) n_chunks = len(segments_list) audio_budget_remaining = total_budget_s @@ -525,8 +478,6 @@ def run_pipeline( out_dir = Path(out_dir) out_dir.mkdir(parents=True, exist_ok=True) - fast_initial_compress = total_budget_s <= 120 - final_texts: List[str] = [] combined_audio = AudioSegment.silent(duration=0) all_mp3_bytes: List[bytes] = [] @@ -582,85 +533,46 @@ def run_pipeline( audio_seconds = 0.0 tts_api_s = 0.0 mp3_parse_s = 0.0 + chunk_lead_s = 0.0 - # ---- chunk 0: optionally race original TTS against a target-sized rewrite ---- + # ---- chunk 0: no refinement, sequential TTS ---- if i == 0: time_budget_s = target_s - if fast_initial_compress: - ref_out = _adaptive_refine_parallel( - client, chunk, - target_s=target_s, - time_budget_s=time_budget_s, - prev_texts=[], - tolerance_s=tol_s, - tolerance_upper_s=tol_upper_s, - voice=voice, - max_ref=1, - next_chunk_text=next_chunk_text, - fast_initial_compress=True, - ) - - refined = ref_out["refined_text"] - n_ref_used = ref_out["n_ref_used"] - fs_estimated_s = ref_out["fs_estimated_s"] - refine_total_s = ref_out["refine_total_s"] - in_range = ref_out["target_reached"] - timed_out = ref_out["timed_out"] - n_candidates_submitted = ref_out["n_candidates_submitted"] - n_candidates_done = ref_out["n_candidates_done"] - used_candidate_iter = ref_out["used_candidate_iter"] - iter_llm_times_s = json.dumps([round(t, 3) for t in ref_out["llm_times_s"]]) - iter_fs_times_s = json.dumps([round(t, 3) for t in ref_out["fs_times_s"]]) - iter_tts_times_s = json.dumps(ref_out["iter_tts_times_s"]) - total_elapsed_s = ref_out["total_elapsed_s"] - audio_seconds = ref_out["audio_seconds"] - tts_api_s = ref_out["tts_api_s"] - mp3_parse_s = ref_out["mp3_parse_s"] - mp3_bytes = ref_out["mp3_bytes"] - overrun_s = max(0.0, total_elapsed_s - time_budget_s) - + refined = chunk + n_ref_used = 0 + fs_estimated_s = 0.0 + refine_total_s = 0.0 + in_range = True + timed_out = False + n_candidates_submitted = 1 + n_candidates_done = 1 + used_candidate_iter = 0 + iter_llm_times_s = "[]" + iter_fs_times_s = "[]" + iter_tts_times_s = "[]" + + for attempt in range(10): try: + tts_out = _query_time_profiled(client, refined, voice=voice) + audio_seconds = float(tts_out["audio_seconds"]) + tts_api_s = float(tts_out["tts_api_s"]) + mp3_parse_s = float(tts_out["mp3_parse_s"]) + mp3_bytes = tts_out["mp3_bytes"] seg = AudioSegment.from_file(BytesIO(mp3_bytes), format="mp3") - except CouldntDecodeError as e: - import warnings - warnings.warn(f"Chunk {i}: pydub decode failed: {e}") - seg = AudioSegment.silent(duration=int(audio_seconds * 1000)) - else: - refined = chunk - n_ref_used = 0 - fs_estimated_s = 0.0 - refine_total_s = 0.0 - in_range = True - timed_out = False - n_candidates_submitted = 1 - n_candidates_done = 1 - used_candidate_iter = 0 - iter_llm_times_s = "[]" - iter_fs_times_s = "[]" - iter_tts_times_s = "[]" - - for attempt in range(10): - try: - tts_out = _query_time_profiled(client, refined, voice=voice) - audio_seconds = float(tts_out["audio_seconds"]) - tts_api_s = float(tts_out["tts_api_s"]) - mp3_parse_s = float(tts_out["mp3_parse_s"]) - mp3_bytes = tts_out["mp3_bytes"] - seg = AudioSegment.from_file(BytesIO(mp3_bytes), format="mp3") - break - except (Exception, CouldntDecodeError) as e: - if attempt == 9: - if isinstance(e, CouldntDecodeError): - import warnings - warnings.warn(f"Chunk {i}: pydub decode failed after 10 attempts: {e}") - seg = AudioSegment.silent(duration=int(audio_seconds * 1000)) - break - raise RuntimeError(f"TTS/decode failed after 10 attempts: {e}") from e - time.sleep(2.0 * (attempt + 1)) - - iter_tts_times_s = json.dumps([round(tts_api_s, 3)]) - total_elapsed_s = tts_api_s - overrun_s = tts_api_s + break + except (Exception, CouldntDecodeError) as e: + if attempt == 9: + if isinstance(e, CouldntDecodeError): + import warnings + warnings.warn(f"Chunk {i}: pydub decode failed after 10 attempts: {e}") + seg = AudioSegment.silent(duration=int(audio_seconds * 1000)) + break + raise RuntimeError(f"TTS/decode failed after 10 attempts: {e}") from e + time.sleep(2.0 * (attempt + 1)) + + iter_tts_times_s = json.dumps([round(tts_api_s, 3)]) + total_elapsed_s = tts_api_s + overrun_s = tts_api_s # ---- chunks 1+: parallel refinement with background TTS candidates ---- else: @@ -681,6 +593,7 @@ def run_pipeline( float(pre_out["total_elapsed_s"]) - _last_chunk_lead_s, ) ref_out = pre_out + chunk_lead_s = _last_chunk_lead_s print( f" [pre-start] adopted pre-start result: " f"effective_wait={pre_out['total_elapsed_s']:.1f}s, " @@ -697,7 +610,6 @@ def run_pipeline( voice=voice, max_ref=max_ref, next_chunk_text=next_chunk_text, - fast_initial_compress=fast_initial_compress, ) if _last_chunk_executor is not None: _last_chunk_executor.shutdown(wait=False) @@ -715,7 +627,6 @@ def run_pipeline( voice=voice, max_ref=max_ref, next_chunk_text=next_chunk_text, - fast_initial_compress=fast_initial_compress, ) refined = ref_out["refined_text"] @@ -776,7 +687,6 @@ def run_pipeline( voice, MAX_REFINEMENTS, "", - fast_initial_compress, ) _last_chunk_lead_s = prev_audio_s or 0.0 print(f" [pre-start] chunk {last_idx} kicked off at chunk {i}, est_target={last_target_s_est:.1f}s") @@ -820,6 +730,7 @@ def run_pipeline( iter_llm_times_s=iter_llm_times_s, iter_fs_times_s=iter_fs_times_s, iter_tts_times_s=iter_tts_times_s, + prep_start_lead_s=chunk_lead_s, ) chunk_profiles.append(cp) @@ -884,7 +795,7 @@ def convert_text_to_speech_streaming( output_path: str, total_budget_s: float, voice: str = "echo", - enable_early_cut: bool = True, + enable_early_cut: bool = False, early_cut_ratio: float = EARLY_CUT_RATIO, ) -> Tuple[str, str, float]: """ From b12afa4ed488b77c4d5c7bbde576af1a02f8d05c Mon Sep 17 00:00:00 2001 From: Yuante Li <1957922024@qq.com> Date: Sun, 3 May 2026 03:26:12 +0000 Subject: [PATCH 4/7] refactor: move streaming_tts flag to EnvConfig CompareEnv now reads streaming_tts from its own config instead of inheriting from each debater's config, so baseline and test runs share the same streaming setting. --- src/compare_env.py | 4 ++-- src/env.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/compare_env.py b/src/compare_env.py index ab6f366..e3ee6dd 100644 --- a/src/compare_env.py +++ b/src/compare_env.py @@ -85,7 +85,7 @@ def step_play(self, side, stage, history, max_time): history=history, max_time=max_time, time_control=self.time_control, - streaming_tts=self.baseline_debaters[side].config.streaming_tts, + streaming_tts=self.config.streaming_tts, ) # Generate test response using reference history @@ -93,7 +93,7 @@ def step_play(self, side, stage, history, max_time): history=history, max_time=max_time, time_control=self.time_control, - streaming_tts=self.test_debaters[side].config.streaming_tts, + streaming_tts=self.config.streaming_tts, ) return base_response, test_response diff --git a/src/env.py b/src/env.py index 9df487c..2830b0c 100644 --- a/src/env.py +++ b/src/env.py @@ -24,6 +24,7 @@ class EnvConfig: claim_pool_size: int = 50 reverse: bool = False time_control: bool = True + streaming_tts: bool = False def extract_overall_score(obj_scores): # larger is better From bce0c3a7ebcff4f73e6665638ebb460824f80029 Mon Sep 17 00:00:00 2001 From: Yuante Li <1957922024@qq.com> Date: Sun, 3 May 2026 03:26:21 +0000 Subject: [PATCH 5/7] feat: add pre-start lane and combined timeline to overlap viz - overlap_viz_par.py renders pre-started last chunks on a separate lane using prep_start_lead_s, so they don't visually collide with the main thread bars of preceding chunks - overlap_viz.sh stacks per-chunk-dir PNGs into a single combined overlap_timeline_combined.png when multiple dirs match - add stack_pngs.py helper that vertically stacks PNGs --- src/scripts/overlap_viz.sh | 14 ++++++++++ src/scripts/overlap_viz_par.py | 48 +++++++++++++++++++++++----------- src/scripts/stack_pngs.py | 26 ++++++++++++++++++ 3 files changed, 73 insertions(+), 15 deletions(-) create mode 100644 src/scripts/stack_pngs.py diff --git a/src/scripts/overlap_viz.sh b/src/scripts/overlap_viz.sh index e500184..c22df85 100755 --- a/src/scripts/overlap_viz.sh +++ b/src/scripts/overlap_viz.sh @@ -1,16 +1,23 @@ #!/bin/bash # Usage: bash overlap_viz.sh [chunks_dir_pattern] # Default: all _*_chunks/ directories under src/ +# When 2+ chunks dirs are matched, also produce a combined top-to-bottom PNG +# (overlap_timeline_combined.png) in the common parent directory. SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" PATTERN="${@:-$SCRIPT_DIR/_*_chunks}" +pngs=() found=0 for dir in $PATTERN; do csv="$dir/chunk_profile.csv" if [ -f "$csv" ]; then echo "[VIZ] $csv" python "$SCRIPT_DIR/overlap_viz_par.py" "$csv" + png="$dir/overlap_timeline_par.png" + if [ -f "$png" ]; then + pngs+=("$png") + fi found=$((found + 1)) else echo "[SKIP] $csv not found" @@ -19,4 +26,11 @@ done if [ "$found" -eq 0 ]; then echo "No chunk_profile.csv found. Run with streaming_tts: true first." +elif [ "${#pngs[@]}" -ge 2 ]; then + parent=$(python -c " +import os, sys +print(os.path.commonpath([os.path.dirname(p) for p in sys.argv[1:]])) +" "${pngs[@]}") + out="$parent/overlap_timeline_combined.png" + python "$SCRIPT_DIR/stack_pngs.py" "${pngs[@]}" -o "$out" fi diff --git a/src/scripts/overlap_viz_par.py b/src/scripts/overlap_viz_par.py index 062ac87..32a8d36 100644 --- a/src/scripts/overlap_viz_par.py +++ b/src/scripts/overlap_viz_par.py @@ -65,19 +65,25 @@ def compute_timeline(rows): events.append({ "chunk": 0, "prep_start": t, + "prep_end": play0_start, "main_segments": [], # no fs/llm "tts_candidates": _build_tts_candidates(rows[0], t), "play_start": play0_start, "play_end": play0_end, "gap": 0.0, + "is_prestart": False, }) for i in range(1, len(rows)): - prep_start = events[i - 1]["play_start"] - prep_duration = rows[i]["total_elapsed_s"] # true wall-clock (parallel) + prep_start_default = events[i - 1]["play_start"] + prep_duration = rows[i]["total_elapsed_s"] # streaming-adjusted (lead already subtracted for pre-started) + + lead_s = float(rows[i].get("prep_start_lead_s", 0.0) or 0.0) + is_prestart = lead_s > 0.01 + prep_start = prep_start_default - lead_s prev_play_end = events[i - 1]["play_end"] - ready_at = prep_start + prep_duration + ready_at = prep_start_default + prep_duration # wall-clock when prep is done gap = max(0.0, ready_at - prev_play_end) play_start = max(prev_play_end, ready_at) play_end = play_start + rows[i]["audio_seconds"] @@ -85,11 +91,13 @@ def compute_timeline(rows): events.append({ "chunk": i, "prep_start": prep_start, + "prep_end": ready_at, "main_segments": _build_main_segments(rows[i], prep_start), "tts_candidates": _build_tts_candidates(rows[i], prep_start), "play_start": play_start, "play_end": play_end, "gap": gap, + "is_prestart": is_prestart, }) return events @@ -148,10 +156,11 @@ def _build_tts_candidates(row, prep_start): "tts_nd": "#BBBBBB", # grey – TTS still running at selection } -BAR_H = 0.32 # playback / main-thread bar height -TTS_H = 0.18 # height of each TTS candidate sub-row -TTS_GAP = 0.06 # vertical gap between candidate sub-rows -Y_MAIN = 0.0 # centre of main-thread lane +BAR_H = 0.32 # playback / main-thread bar height +TTS_H = 0.18 # height of each TTS candidate sub-row +TTS_GAP = 0.06 # vertical gap between candidate sub-rows +Y_MAIN = 0.0 # centre of main-thread lane +Y_PRESTART = -0.6 # centre of pre-start lane (only shown when a pre-started chunk exists) def _tts_y(k): @@ -169,8 +178,10 @@ def plot_timeline(rows, events, out_path, title="Parallel Chunk Overlap Timeline n = len(rows) max_cands = max((len(ev["tts_candidates"]) for ev in events), default=1) Y_PLAY = _y_play(max_cands) + has_prestart = any(ev.get("is_prestart") for ev in events) + bottom_y = Y_PRESTART if has_prestart else Y_MAIN - fig_h = max(5.0, Y_PLAY + 1.2) + fig_h = max(5.0, Y_PLAY - bottom_y + 1.2) fig, ax = plt.subplots(figsize=(max(14, n * 2.5), fig_h)) for ev in events: @@ -195,7 +206,7 @@ def plot_timeline(rows, events, out_path, title="Parallel Chunk Overlap Timeline ha="center", va="center", fontsize=7, color="white") # ── TTS candidates (parallel band) ──────────────────────────────────── - prep_end = ev["prep_start"] + rows[i]["total_elapsed_s"] + prep_end = ev["prep_end"] for cand in cands: y = _tts_y(cand["k"]) if cand["duration"] > 0: @@ -221,13 +232,16 @@ def plot_timeline(rows, events, out_path, title="Parallel Chunk Overlap Timeline fontsize=5.5, color="#888888", va="bottom") # ── main thread (fs / llm) ───────────────────────────────────────────── + # pre-started chunks ran on a background executor: paint on a separate + # lane so they don't collide with the main thread's bars for chunks i-3..i-1. + y_main = Y_PRESTART if ev.get("is_prestart") else Y_MAIN for seg_start, seg_end, kind in ev["main_segments"]: seg_dur = seg_end - seg_start if seg_dur < 0.05: continue - ax.barh(Y_MAIN, seg_dur, left=seg_start, height=BAR_H, + ax.barh(y_main, seg_dur, left=seg_start, height=BAR_H, color=COLORS[kind], alpha=0.88, edgecolor="white", linewidth=0.5) - ax.text(seg_start + seg_dur / 2, Y_MAIN, + ax.text(seg_start + seg_dur / 2, y_main, f"{kind}\n{seg_dur:.1f}s", ha="center", va="center", fontsize=6.5, color="white") @@ -245,12 +259,16 @@ def plot_timeline(rows, events, out_path, title="Parallel Chunk Overlap Timeline # y-axis ticks tts_mid = _tts_y((max_cands - 1) / 2) - ax.set_yticks([Y_MAIN, tts_mid, Y_PLAY]) - ax.set_yticklabels(["Main thread\n(fs + llm)", "TTS candidates\n(parallel)", "Playback"], - fontsize=9) + yticks = [Y_MAIN, tts_mid, Y_PLAY] + ylabels = ["Main thread\n(fs + llm)", "TTS candidates\n(parallel)", "Playback"] + if has_prestart: + yticks = [Y_PRESTART] + yticks + ylabels = ["Pre-start lane\n(last chunk)"] + ylabels + ax.set_yticks(yticks) + ax.set_yticklabels(ylabels, fontsize=9) ax.set_xlabel("Wall-clock time (s)", fontsize=10) ax.set_title(title, fontsize=12, fontweight="bold", pad=10) - ax.set_ylim(Y_MAIN - 0.5, Y_PLAY + 0.7) + ax.set_ylim(bottom_y - 0.5, Y_PLAY + 0.7) ax.grid(axis="x", alpha=0.25) legend_patches = [ diff --git a/src/scripts/stack_pngs.py b/src/scripts/stack_pngs.py new file mode 100644 index 0000000..ad45fb6 --- /dev/null +++ b/src/scripts/stack_pngs.py @@ -0,0 +1,26 @@ +"""Stack PNG images vertically (top-to-bottom) into a single image.""" +import argparse +from pathlib import Path +from PIL import Image + + +def stack_vertical(image_paths, out_path): + images = [Image.open(p).convert("RGB") for p in image_paths] + width = max(im.width for im in images) + height = sum(im.height for im in images) + combined = Image.new("RGB", (width, height), "white") + y = 0 + for im in images: + x = (width - im.width) // 2 + combined.paste(im, (x, y)) + y += im.height + combined.save(out_path) + print(f"Saved → {out_path} ({len(images)} images, {width}x{height})") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("inputs", nargs="+", help="PNG files to stack top-to-bottom") + parser.add_argument("-o", "--output", required=True, help="Output PNG path") + args = parser.parse_args() + stack_vertical([Path(p) for p in args.inputs], Path(args.output)) From a3fa57c9b598cdcffebc172b1f328c56b6015531 Mon Sep 17 00:00:00 2001 From: Yuante Li <1957922024@qq.com> Date: Sun, 3 May 2026 03:26:33 +0000 Subject: [PATCH 6/7] fix: small robustness fixes - ouragents: fall back to self.config.model when helper_model is None, not just when the attribute is missing - utils/model: pass num_retries=3 to litellm.completion - utils/tool: include traceback in extraction-retry warning --- src/ouragents.py | 2 +- src/utils/model.py | 2 +- src/utils/tool.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/ouragents.py b/src/ouragents.py index bce4a41..6b4c424 100644 --- a/src/ouragents.py +++ b/src/ouragents.py @@ -66,7 +66,7 @@ def __init__(self, config, motion): + f"use_rehearsal_tree: {self.use_rehearsal_tree}, use_debate_flow_tree: {self.use_debate_flow_tree}" ) - helper_model = getattr(config, "helper_model", self.config.model) + helper_model = getattr(config, "helper_model", None) or self.config.model self.helper_client = partial(HelperClient, model=helper_model, temperature=0, max_tokens=config.max_tokens, n=1) self.simulated_audience = [Audience(AudienceConfig(model=self.config.model, temperature=1)) for _ in range(1)] diff --git a/src/utils/model.py b/src/utils/model.py index 75a4894..c380670 100644 --- a/src/utils/model.py +++ b/src/utils/model.py @@ -145,7 +145,7 @@ def _completion_text(model_name: str, messages, wants_json: bool, temperature: f ) if wants_json: call_kwargs["response_format"] = {"type": "json_object"} - return litellm.completion(**call_kwargs) + return litellm.completion(num_retries=3, **call_kwargs) def _completion_structured( diff --git a/src/utils/tool.py b/src/utils/tool.py index 2d89b56..c43b399 100644 --- a/src/utils/tool.py +++ b/src/utils/tool.py @@ -391,7 +391,8 @@ def get_response_with_retry(llm, prompt, required_key, *, response_model: type[T logger.debug(f"Retry {retry} times.") time.sleep(30) except Exception as e: - logger.warning(f"Unexpected error {e} in extracting {required_key} from: {response}") + import traceback + logger.warning(f"Unexpected error {e} in extracting {required_key} from: {response}\n{traceback.format_exc()}") content = {} retry += 1 logger.debug(f"Retry {retry} times.") From 5752faad6374c4d13595ff8d17b053c3347704be Mon Sep 17 00:00:00 2001 From: Yuante Li <1957922024@qq.com> Date: Mon, 4 May 2026 19:27:57 +0000 Subject: [PATCH 7/7] feat: shared candidate pool with ratio-based prestart workers Refactor TTS refinement around a shared _ChunkRefineContext where a prestart worker (kicked off two iters early on ratio or absolute size triggers, plus the existing last-chunk variant) and the normal worker contribute candidates to the same pool. First worker to confirm an in-range candidate wins via try_adopt(). Per-worker fs/llm/tts streams are recorded so overlap_viz_par renders both lanes in parallel with the chosen kind labelled on playback. Chunk merging now uses word count (MIN_CHUNK_WORDS=30) and also folds an undersized last chunk into its predecessor. --- src/scripts/overlap_viz_par.py | 184 ++++++--- src/tts_streaming.py | 733 ++++++++++++++++++++++----------- 2 files changed, 618 insertions(+), 299 deletions(-) diff --git a/src/scripts/overlap_viz_par.py b/src/scripts/overlap_viz_par.py index 32a8d36..0dfccbe 100644 --- a/src/scripts/overlap_viz_par.py +++ b/src/scripts/overlap_viz_par.py @@ -62,16 +62,32 @@ def compute_timeline(rows): play0_start = t + rows[0]["total_elapsed_s"] # = tts_api_s for chunk 0 play0_end = play0_start + rows[0]["audio_seconds"] + # chunk 0 has a single TTS candidate (no workers); legacy path. + legacy_tts_times = rows[0].get("iter_tts_times_s", []) or [] + legacy_cands = [] + for k, tts_dur in enumerate(legacy_tts_times): + legacy_cands.append({ + "k": k, + "submit_t": t, + "duration": float(tts_dur), + "is_chosen": True, + "worker": "normal", + "intra_iter": k, + }) + events.append({ - "chunk": 0, - "prep_start": t, - "prep_end": play0_start, - "main_segments": [], # no fs/llm - "tts_candidates": _build_tts_candidates(rows[0], t), - "play_start": play0_start, - "play_end": play0_end, - "gap": 0.0, - "is_prestart": False, + "chunk": 0, + "prep_start_normal": t, + "prep_start_prestart": None, + "prep_end": play0_start, + "normal_segments": [], + "prestart_segments": [], + "tts_candidates": legacy_cands, + "play_start": play0_start, + "play_end": play0_end, + "gap": 0.0, + "is_prestart": False, + "prestart_kind": "", }) for i in range(1, len(rows)): @@ -80,7 +96,9 @@ def compute_timeline(rows): lead_s = float(rows[i].get("prep_start_lead_s", 0.0) or 0.0) is_prestart = lead_s > 0.01 - prep_start = prep_start_default - lead_s + prestart_kind = rows[i].get("prestart_kind", "") or "" + prep_start_normal = prep_start_default + prep_start_prestart = (prep_start_default - lead_s) if is_prestart else None prev_play_end = events[i - 1]["play_end"] ready_at = prep_start_default + prep_duration # wall-clock when prep is done @@ -88,25 +106,41 @@ def compute_timeline(rows): play_start = max(prev_play_end, ready_at) play_end = play_start + rows[i]["audio_seconds"] + prestart_segments = _build_worker_segments(rows[i], "prestart", prep_start_prestart) if is_prestart else [] + normal_segments = _build_worker_segments(rows[i], "normal", prep_start_normal) + tts_candidates = _build_dual_candidates(rows[i], prep_start_prestart, prep_start_normal) + events.append({ - "chunk": i, - "prep_start": prep_start, - "prep_end": ready_at, - "main_segments": _build_main_segments(rows[i], prep_start), - "tts_candidates": _build_tts_candidates(rows[i], prep_start), - "play_start": play_start, - "play_end": play_end, - "gap": gap, - "is_prestart": is_prestart, + "chunk": i, + "prep_start_normal": prep_start_normal, + "prep_start_prestart": prep_start_prestart, + "prep_end": ready_at, + "normal_segments": normal_segments, + "prestart_segments": prestart_segments, + "tts_candidates": tts_candidates, + "play_start": play_start, + "play_end": play_end, + "gap": gap, + "is_prestart": is_prestart, + "prestart_kind": prestart_kind if is_prestart else "", }) return events -def _build_main_segments(row, prep_start): - """fs / llm alternating segments on the main thread.""" - fs_times = row.get("iter_fs_times_s", []) or [] - llm_times = row.get("iter_llm_times_s", []) or [] +def _build_worker_segments(row, worker, prep_start): + """fs / llm alternating segments for one worker, starting at prep_start.""" + if prep_start is None: + return [] + fs_times = row.get(f"{worker}_fs_times_s", []) or [] + llm_times = row.get(f"{worker}_llm_times_s", []) or [] + # Backward compat: if per-worker fields are absent (old CSV), fall back to combined. + if not fs_times and not llm_times: + if worker != "normal": + return [] + fs_times = row.get("iter_fs_times_s", []) or [] + llm_times = row.get("iter_llm_times_s", []) or [] + segments = [] t = prep_start n_iters = max(len(fs_times), len(llm_times)) @@ -120,28 +154,67 @@ def _build_main_segments(row, prep_start): return segments -def _build_tts_candidates(row, prep_start): +def _build_dual_candidates(row, prep_start_prestart, prep_start_normal): """ - Return list of dicts {k, submit_t, duration, is_chosen}. + Build TTS candidate list combining prestart + normal workers. + Each candidate is {k, submit_t, duration, is_chosen, worker, intra_iter}. + k is assigned by sorting candidates by submit_t (so vertical stacking + matches chronological submission order). - Candidate k is submitted after sum(fs[0..k]) + sum(llm[0..k-1]) - duration == -1.0 means the future had not completed at selection time. + Within a single worker: candidate at intra_iter j is submitted after + sum(fs[0..j]) + sum(llm[0..j-1]) of that worker's timeline. """ - fs_times = row.get("iter_fs_times_s", []) or [] - llm_times = row.get("iter_llm_times_s", []) or [] - tts_times = row.get("iter_tts_times_s", []) or [] - used_iter = int(row.get("used_candidate_iter", 0)) - - result = [] - for k, tts_dur in enumerate(tts_times): - submit_t = prep_start + sum(fs_times[: k + 1]) + sum(llm_times[:k]) - result.append({ - "k": k, - "submit_t": submit_t, - "duration": float(tts_dur), - "is_chosen": k == used_iter, - }) - return result + chosen_label = (row.get("chosen_worker_label", "") or "").strip() + chosen_intra = int(row.get("chosen_intra_iter", 0) or 0) + used_iter_legacy = int(row.get("used_candidate_iter", 0) or 0) + + cands = [] + + def _add_worker(worker, start_t): + if start_t is None: + return + fs_times = row.get(f"{worker}_fs_times_s", []) or [] + llm_times = row.get(f"{worker}_llm_times_s", []) or [] + tts_times = row.get(f"{worker}_tts_times_s", []) or [] + for j, tts_dur in enumerate(tts_times): + submit_t = start_t + sum(fs_times[: j + 1]) + sum(llm_times[:j]) + is_chosen = (chosen_label == worker and j == chosen_intra) + cands.append({ + "submit_t": submit_t, + "duration": float(tts_dur), + "is_chosen": is_chosen, + "worker": worker, + "intra_iter": j, + }) + + has_per_worker = bool( + row.get("normal_tts_times_s") + or row.get("prestart_tts_times_s") + ) + + if has_per_worker: + _add_worker("prestart", prep_start_prestart) + _add_worker("normal", prep_start_normal) + else: + # Backward compat: old CSV with combined iter_*_times_s + fs_times = row.get("iter_fs_times_s", []) or [] + llm_times = row.get("iter_llm_times_s", []) or [] + tts_times = row.get("iter_tts_times_s", []) or [] + start_t = prep_start_prestart if prep_start_prestart is not None else prep_start_normal + for j, tts_dur in enumerate(tts_times): + submit_t = start_t + sum(fs_times[: j + 1]) + sum(llm_times[:j]) + cands.append({ + "submit_t": submit_t, + "duration": float(tts_dur), + "is_chosen": j == used_iter_legacy, + "worker": "normal", + "intra_iter": j, + }) + + cands.sort(key=lambda c: c["submit_t"]) + for k, c in enumerate(cands): + c["k"] = k + return cands # ── plotting ────────────────────────────────────────────────────────────────── @@ -192,8 +265,9 @@ def plot_timeline(rows, events, out_path, title="Parallel Chunk Overlap Timeline dur = ev["play_end"] - ev["play_start"] ax.barh(Y_PLAY, dur, left=ev["play_start"], height=BAR_H, color=COLORS["play"], alpha=0.88, edgecolor="white", linewidth=0.6) + kind_suffix = f" [{ev['prestart_kind']}-prestart]" if ev.get("prestart_kind") else "" ax.text(ev["play_start"] + dur / 2, Y_PLAY, - f"▶{i} {dur:.1f}s", + f"▶{i} {dur:.1f}s{kind_suffix}", ha="center", va="center", fontsize=8, fontweight="bold", color="white") # ── gap / silence ────────────────────────────────────────────────────── @@ -231,17 +305,25 @@ def plot_timeline(rows, events, out_path, title="Parallel Chunk Overlap Timeline f"T{cand['k']} (still running)", fontsize=5.5, color="#888888", va="bottom") - # ── main thread (fs / llm) ───────────────────────────────────────────── - # pre-started chunks ran on a background executor: paint on a separate - # lane so they don't collide with the main thread's bars for chunks i-3..i-1. - y_main = Y_PRESTART if ev.get("is_prestart") else Y_MAIN - for seg_start, seg_end, kind in ev["main_segments"]: + # ── main thread (normal worker fs/llm) ──────────────────────────────── + for seg_start, seg_end, kind in ev.get("normal_segments", []): + seg_dur = seg_end - seg_start + if seg_dur < 0.05: + continue + ax.barh(Y_MAIN, seg_dur, left=seg_start, height=BAR_H, + color=COLORS[kind], alpha=0.88, edgecolor="white", linewidth=0.5) + ax.text(seg_start + seg_dur / 2, Y_MAIN, + f"{kind}\n{seg_dur:.1f}s", + ha="center", va="center", fontsize=6.5, color="white") + + # ── pre-start lane (prestart worker fs/llm; runs in parallel with normal) ── + for seg_start, seg_end, kind in ev.get("prestart_segments", []): seg_dur = seg_end - seg_start if seg_dur < 0.05: continue - ax.barh(y_main, seg_dur, left=seg_start, height=BAR_H, + ax.barh(Y_PRESTART, seg_dur, left=seg_start, height=BAR_H, color=COLORS[kind], alpha=0.88, edgecolor="white", linewidth=0.5) - ax.text(seg_start + seg_dur / 2, y_main, + ax.text(seg_start + seg_dur / 2, Y_PRESTART, f"{kind}\n{seg_dur:.1f}s", ha="center", va="center", fontsize=6.5, color="white") @@ -263,7 +345,7 @@ def plot_timeline(rows, events, out_path, title="Parallel Chunk Overlap Timeline ylabels = ["Main thread\n(fs + llm)", "TTS candidates\n(parallel)", "Playback"] if has_prestart: yticks = [Y_PRESTART] + yticks - ylabels = ["Pre-start lane\n(last chunk)"] + ylabels + ylabels = ["Pre-start lane\n(parallel worker)"] + ylabels ax.set_yticks(yticks) ax.set_yticklabels(ylabels, fontsize=9) ax.set_xlabel("Wall-clock time (s)", fontsize=10) diff --git a/src/tts_streaming.py b/src/tts_streaming.py index b054610..71e3a26 100644 --- a/src/tts_streaming.py +++ b/src/tts_streaming.py @@ -35,8 +35,10 @@ MIN_TOLERANCE_S = 1.0 # floor so very short chunks aren't impossible to hit MAX_REFINEMENTS = 10 MAX_PARALLEL_TTS = 8 # max concurrent background TTS threads per chunk -MIN_CHUNK_CHARS = 50 # chunks shorter than this are merged into the next one +MIN_CHUNK_WORDS = 30 # chunks shorter than this (word count) are merged into the next one; the last chunk uses the same threshold EARLY_CUT_RATIO = 1.25 # fs_est/target_s threshold to trigger early-cut +RATIO_PRESTART_THRESHOLD = 2.0 # if chunk[i+2].chars / chunk[i+1].chars >= this, pre-start chunk i+2 one chunk earlier +ABS_PRESTART_CHARS = 1000 # also pre-start when chunk[i+2] is absolutely large (>= this many chars), regardless of ratio SPEED_ADJUST_MIN = 0.85 # TTS speed clamp lower bound SPEED_ADJUST_MAX = 1.15 # TTS speed clamp upper bound @@ -65,10 +67,21 @@ class ChunkProfile: chunk_total_s: float tolerance_s: float tol_upper_s: float - iter_llm_times_s: str # JSON list - iter_fs_times_s: str # JSON list - iter_tts_times_s: str # JSON list - prep_start_lead_s: float = 0.0 # >0 only for adopted pre-started last chunk; how much earlier than events[i-1].play_start the prep actually began + iter_llm_times_s: str # JSON list — combined across both workers (legacy) + iter_fs_times_s: str # JSON list — combined across both workers (legacy) + iter_tts_times_s: str # JSON list — combined across both workers (legacy) + prep_start_lead_s: float = 0.0 # >0 only when a pre-started result was adopted (last-chunk or ratio-prestart); how much earlier than events[i-1].play_start the prep actually began + prestart_kind: str = "" # "" if standard refine; "ratio" if adopted from ratio-prestart; "last" if adopted from last-chunk pre-start + # Per-worker timings. Each worker has its own ordered fs/llm/tts streams that + # the visualizer uses to draw two parallel lanes. Empty when that worker did not run. + prestart_llm_times_s: str = "[]" + prestart_fs_times_s: str = "[]" + prestart_tts_times_s: str = "[]" + normal_llm_times_s: str = "[]" + normal_fs_times_s: str = "[]" + normal_tts_times_s: str = "[]" + chosen_worker_label: str = "" # "" / "prestart" / "normal" (chunk 0 has no worker) + chosen_intra_iter: int = 0 # iteration within the chosen worker (0 = raw text) @dataclass @@ -201,10 +214,12 @@ def _tts_with_retry(client, content: str, voice: str = "echo", speed: float = 1. # -------- TTS candidate tracking -------- @dataclass class _TtsCandidate: - iteration: int # 0 = original chunk text; k = after k-th rewrite + iteration: int # global index in shared candidate pool text: str fs_estimated_s: float future: Any # concurrent.futures.Future -> Dict from _query_time_profiled + worker_label: str = "" # "prestart" / "normal" — which worker produced this + intra_iter: int = 0 # iteration index within the producing worker (0 = raw, k = k-th refine) def _pick_best_completed( @@ -240,152 +255,202 @@ def _collect_done() -> List[Tuple[_TtsCandidate, Dict]]: return min(done, key=lambda x: abs(x[1]["audio_seconds"] - target_s)) -# -------- parallel refinement + TTS (chunks 1+) -------- -def _adaptive_refine_parallel( - client, - text: str, - target_s: float, - time_budget_s: float, - prev_texts: List[str], - tolerance_s: float, - tolerance_upper_s: float, - voice: str = "echo", - max_ref: int = MAX_REFINEMENTS, - next_chunk_text: str = "", -) -> Dict[str, Any]: - executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_PARALLEL_TTS) - candidates: List[_TtsCandidate] = [] - candidates_lock = threading.Lock() - stop_event = threading.Event() - done_event = threading.Event() - - t_wall_start = _now() - refine_stats: Dict[str, Any] = {} - - def _refine(): - cur = text - n_ref = 0 - llm_times: List[float] = [] - fs_times: List[float] = [] - t0 = _now() - - _t = _now(); est = _fastspeech_estimate(cur); fs_times.append(_now() - _t) - ok = _in_range(est, target_s, tolerance_s, tolerance_upper_s) - with candidates_lock: - if not stop_event.is_set(): - candidates.append(_TtsCandidate( - iteration=0, text=cur, fs_estimated_s=est, - future=executor.submit(_tts_with_retry, client, cur, voice), - )) - - while not ok and n_ref < max_ref and not stop_event.is_set() and target_s > 0: - cw = LengthEstimator.count_words(cur) - tw = max(10, round(cw * target_s / est)) - _t = _now(); cur = _revise_to_n_words(client, cur, tw, prev_texts, next_chunk_text); llm_times.append(_now() - _t) - n_ref += 1 - _t = _now(); est = _fastspeech_estimate(cur); fs_times.append(_now() - _t) - ok = _in_range(est, target_s, tolerance_s, tolerance_upper_s) - - if stop_event.is_set(): - break - with candidates_lock: - candidates.append(_TtsCandidate( - iteration=n_ref, text=cur, fs_estimated_s=est, - future=executor.submit(_tts_with_retry, client, cur, voice), - )) - - refine_stats.update({ - "n_ref": n_ref, "ok": ok, - "llm_times": llm_times, "fs_times": fs_times, - "refine_total_s": _now() - t0, - }) - - if ok and not stop_event.is_set(): - with candidates_lock: - last = candidates[-1] - tts_out = last.future.result() - refine_stats["chosen_cand"] = last - refine_stats["tts_out"] = tts_out - done_event.set() - - stop_event.set() - - refine_thread = threading.Thread(target=_refine, daemon=True) - refine_thread.start() - - done_event.wait(timeout=time_budget_s) - stop_event.set() - - total_elapsed_s = _now() - t_wall_start - timed_out = not done_event.is_set() - - if done_event.is_set() and "chosen_cand" in refine_stats: - chosen_cand = refine_stats["chosen_cand"] - tts_out = refine_stats["tts_out"] - used_iter = chosen_cand.iteration - else: - while True: - with candidates_lock: - snap = list(candidates) - if snap: - break - time.sleep(0.05) - chosen_cand, tts_out = _pick_best_completed(snap, target_s) - used_iter = chosen_cand.iteration - - with candidates_lock: - snap = list(candidates) - - iter_tts_times: List[float] = [] - for c in snap: - if c.future.done(): - try: - iter_tts_times.append(round(c.future.result()["tts_api_s"], 3)) - except Exception: - iter_tts_times.append(-1.0) - else: - iter_tts_times.append(-1.0) - - refine_thread.join(timeout=5) - executor.shutdown(wait=False) - - # ---- speed adjustment: re-TTS if audio is outside tolerance but speed can fix it ---- - audio_s = float(tts_out["audio_seconds"]) - raw_speed = audio_s / target_s if target_s > 0 else 1.0 - speed_used = 1.0 - if not _in_range(audio_s, target_s, tolerance_s, tolerance_upper_s): - clamped = max(SPEED_ADJUST_MIN, min(SPEED_ADJUST_MAX, raw_speed)) - if abs(clamped - 1.0) > 0.01: +# -------- shared refine context + worker (used by both prestart and normal refine) -------- +class _ChunkRefineContext: + """ + Holds the shared state for refining ONE chunk. Multiple workers (a prestart + worker, a normal-refine worker) can run concurrently against the same context, + contributing candidates to a shared pool until either: + - any candidate's fs_estimate hits target -> first worker to confirm sets done_event + - main loop's deadline (= prev_audio_s) elapses -> external code stops everything + """ + def __init__( + self, + client, + original_text: str, + target_s: float, + tol_s: float, + tol_upper_s: float, + prev_texts: List[str], + next_chunk_text: str, + voice: str, + max_ref: int, + kickoff_iter: int, + kickoff_kind: str, # "" | "ratio" | "last" + ): + self.client = client + self.original_text = original_text + + self._target_s = target_s + self._tol_s = tol_s + self._tol_upper_s = tol_upper_s + self._target_lock = threading.Lock() + + self.candidates: List[_TtsCandidate] = [] + self.candidates_lock = threading.Lock() + + self.stop_event = threading.Event() + self.done_event = threading.Event() + self._adopt_lock = threading.Lock() + + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_PARALLEL_TTS) + + self.prev_texts = list(prev_texts) + self.next_chunk_text = next_chunk_text + self.voice = voice + self.max_ref = max_ref + + self.kickoff_iter = kickoff_iter + self.kickoff_kind = kickoff_kind + + # per-worker stats: label -> {"n_ref", "llm_times", "fs_times"} + self._stats: Dict[str, Dict[str, Any]] = {} + self._stats_lock = threading.Lock() + + self.chosen_cand: Optional[_TtsCandidate] = None + self.chosen_tts_out: Optional[Dict[str, Any]] = None + + self.t_start_wall = _now() + self.workers: List[threading.Thread] = [] + + def get_target(self) -> Tuple[float, float, float]: + with self._target_lock: + return self._target_s, self._tol_s, self._tol_upper_s + + def update_target(self, target_s: float, tol_s: float, tol_upper_s: float) -> None: + with self._target_lock: + self._target_s = target_s + self._tol_s = tol_s + self._tol_upper_s = tol_upper_s + + def _stats_for(self, label: str) -> Dict[str, Any]: + with self._stats_lock: + if label not in self._stats: + self._stats[label] = {"n_ref": 0, "llm_times": [], "fs_times": []} + return self._stats[label] + + def add_fs_time(self, label: str, t: float) -> None: + with self._stats_lock: + self._stats.setdefault(label, {"n_ref": 0, "llm_times": [], "fs_times": []}) + self._stats[label]["fs_times"].append(t) + + def add_llm_time(self, label: str, t: float) -> None: + with self._stats_lock: + self._stats.setdefault(label, {"n_ref": 0, "llm_times": [], "fs_times": []}) + self._stats[label]["llm_times"].append(t) + self._stats[label]["n_ref"] += 1 + + def aggregate_stats(self) -> Tuple[int, List[float], List[float]]: + with self._stats_lock: + n_ref_total = sum(s["n_ref"] for s in self._stats.values()) + llm_times: List[float] = [] + fs_times: List[float] = [] + for s in self._stats.values(): + llm_times.extend(s["llm_times"]) + fs_times.extend(s["fs_times"]) + return n_ref_total, llm_times, fs_times + + def add_candidate(self, text: str, est: float, label: str, intra_iter: int) -> _TtsCandidate: + with self.candidates_lock: + iteration = len(self.candidates) + cand = _TtsCandidate( + iteration=iteration, + text=text, + fs_estimated_s=est, + future=self.executor.submit(_tts_with_retry, self.client, text, self.voice), + worker_label=label, + intra_iter=intra_iter, + ) + self.candidates.append(cand) + return cand + + def try_adopt(self, cand: _TtsCandidate, tts_out: Dict[str, Any]) -> bool: + with self._adopt_lock: + if self.done_event.is_set(): + return False + self.chosen_cand = cand + self.chosen_tts_out = tts_out + self.done_event.set() + self.stop_event.set() + return True + + +def _refine_worker(ctx: _ChunkRefineContext, label: str) -> None: + """ + Independent worker that drives one branch of refinement against ctx. + Reads target/tolerance from ctx (which may be updated externally) at the + start of each iteration. Pushes candidates into the shared pool. + Sets ctx.done_event via try_adopt() when its candidate hits target. + """ + cur = ctx.original_text + + if ctx.stop_event.is_set(): + return + + # ---- step 0: fs estimate raw text + submit raw TTS candidate ---- + t = _now() + try: + est = _fastspeech_estimate(cur) + except Exception: + return + ctx.add_fs_time(label, _now() - t) + + if ctx.stop_event.is_set(): + return + + target_s, tol_s, tol_upper_s = ctx.get_target() + cand = ctx.add_candidate(cur, est, label, intra_iter=0) + + if _in_range(est, target_s, tol_s, tol_upper_s): + try: + tts_out = cand.future.result() + if ctx.try_adopt(cand, tts_out): + return + except Exception: + pass + + # ---- LLM refinement loop ---- + n_ref_local = 0 + while not ctx.stop_event.is_set() and n_ref_local < ctx.max_ref: + target_s, tol_s, tol_upper_s = ctx.get_target() + if target_s <= 0: + break + + cw = LengthEstimator.count_words(cur) + tw = max(10, round(cw * target_s / max(est, 1.0))) + + t = _now() + try: + cur = _revise_to_n_words(ctx.client, cur, tw, ctx.prev_texts, ctx.next_chunk_text) + except Exception: + break + ctx.add_llm_time(label, _now() - t) + n_ref_local += 1 + + if ctx.stop_event.is_set(): + break + + t = _now() + try: + est = _fastspeech_estimate(cur) + except Exception: + break + ctx.add_fs_time(label, _now() - t) + + if ctx.stop_event.is_set(): + break + + cand = ctx.add_candidate(cur, est, label, intra_iter=n_ref_local) + + target_s, tol_s, tol_upper_s = ctx.get_target() + if _in_range(est, target_s, tol_s, tol_upper_s): try: - speed_tts_out = _tts_with_retry(client, chosen_cand.text, voice=voice, speed=clamped) - if abs(speed_tts_out["audio_seconds"] - target_s) < abs(audio_s - target_s): - tts_out = speed_tts_out - audio_s = float(tts_out["audio_seconds"]) - speed_used = clamped + tts_out = cand.future.result() + if ctx.try_adopt(cand, tts_out): + return except Exception: pass - return { - "refined_text": chosen_cand.text, - "n_ref_used": refine_stats.get("n_ref", 0), - "fs_estimated_s": chosen_cand.fs_estimated_s, - "refine_total_s": refine_stats.get("refine_total_s", total_elapsed_s), - "total_elapsed_s": total_elapsed_s, - "target_reached": refine_stats.get("ok", False), - "timed_out": timed_out, - "n_candidates_submitted": len(snap), - "n_candidates_done": sum(1 for t in iter_tts_times if t >= 0), - "used_candidate_iter": used_iter, - "llm_times_s": refine_stats.get("llm_times", []), - "fs_times_s": refine_stats.get("fs_times", []), - "iter_tts_times_s": iter_tts_times, - "speed_used": speed_used, - "audio_seconds": audio_s, - "tts_api_s": float(tts_out["tts_api_s"]), - "mp3_parse_s": float(tts_out["mp3_parse_s"]), - "mp3_bytes": tts_out["mp3_bytes"], - } - # -------- chunk utilities -------- def _split_sentences(text: str) -> List[str]: @@ -426,15 +491,21 @@ def _early_cut_chunk( return head, tail -def _merge_short_chunks(segments: List[str], min_chars: int = MIN_CHUNK_CHARS) -> List[str]: +def _merge_short_chunks( + segments: List[str], + min_words: int = MIN_CHUNK_WORDS, +) -> List[str]: result = list(segments) i = 0 while i < len(result): - if len(result[i]) < min_chars and i + 1 < len(result): + if LengthEstimator.count_words(result[i]) < min_words and i + 1 < len(result): result[i + 1] = result[i] + " " + result[i + 1] result.pop(i) else: i += 1 + if len(result) >= 2 and LengthEstimator.count_words(result[-1]) < min_words: + result[-2] = result[-2] + " " + result[-1] + result.pop() return result @@ -473,6 +544,7 @@ def run_pipeline( segments_list = list(_merge_short_chunks(segments_list)) n_chunks = len(segments_list) audio_budget_remaining = total_budget_s + total_chars_initial = sum(len(c) for c in segments_list) if out_dir is not None: out_dir = Path(out_dir) @@ -484,10 +556,100 @@ def run_pipeline( prev_audio_s: Optional[float] = None - # Pre-start future for the last chunk (filled at chunk n_chunks-2) - _last_chunk_future: Optional[concurrent.futures.Future] = None - _last_chunk_executor: Optional[concurrent.futures.ThreadPoolExecutor] = None - _last_chunk_lead_s = 0.0 + # Pre-start contexts indexed by target chunk idx. A context is created when + # we kick off a prestart worker for that chunk (ratio or last-chunk variant). + # When the main loop reaches that chunk, it pops the context and adds a + # normal refine worker that shares the same candidate pool. + _chunk_contexts: Dict[int, _ChunkRefineContext] = {} + + def _kickoff_ratio_prestart(iter_i: int) -> None: + """At the start of iter iter_i, maybe kick off prestart for c[i+2]. + + Triggers if EITHER: + - chunk[i+2].chars / chunk[i+1].chars >= RATIO_PRESTART_THRESHOLD, OR + - chunk[i+2].chars >= ABS_PRESTART_CHARS (absolutely long chunk) + """ + target_idx = iter_i + 2 + if target_idx >= len(segments_list) - 1: # would be the last chunk → handled by last-chunk kickoff + return + if target_idx in _chunk_contexts: + return + next_chars = len(segments_list[iter_i + 1]) + target_chars = len(segments_list[target_idx]) + ratio = (target_chars / next_chars) if next_chars > 0 else 0.0 + ratio_trigger = ratio >= RATIO_PRESTART_THRESHOLD + abs_trigger = target_chars >= ABS_PRESTART_CHARS + if not (ratio_trigger or abs_trigger): + return + target_text = segments_list[target_idx] + tgt_s_est = total_budget_s * target_chars / max(total_chars_initial, 1) + tol_est = max(MIN_TOLERANCE_S, tgt_s_est * tolerance_ratio) + tol_upper_est = max(MIN_TOLERANCE_S, tgt_s_est * TOLERANCE_RATIO_UPPER) + ctx = _ChunkRefineContext( + client=client, + original_text=target_text, + target_s=tgt_s_est, + tol_s=tol_est, + tol_upper_s=tol_upper_est, + prev_texts=list(final_texts), + next_chunk_text="", + voice=voice, + max_ref=MAX_REFINEMENTS, + kickoff_iter=iter_i, + kickoff_kind="ratio", + ) + _chunk_contexts[target_idx] = ctx + th = threading.Thread(target=_refine_worker, args=(ctx, "prestart"), daemon=True) + ctx.workers.append(th) + th.start() + triggers = [] + if ratio_trigger: + triggers.append(f"ratio={ratio:.2f}x") + if abs_trigger: + triggers.append(f"abs={target_chars}c>={ABS_PRESTART_CHARS}") + print( + f" [ratio-prestart] chunk {target_idx} kicked off at start of iter {iter_i}, " + f"trigger=[{', '.join(triggers)}], target_est={tgt_s_est:.1f}s" + ) + + def _kickoff_last_chunk_prestart(iter_i: int) -> None: + """At the start of iter iter_i, if iter_i == n-3, kick off prestart for the last chunk.""" + n_now = len(segments_list) + if n_now < 3: + return + if iter_i != n_now - 3: + return + last_idx = n_now - 1 + if last_idx in _chunk_contexts: + return + last_text = segments_list[last_idx] + last_chars = len(last_text) + # Use *initial* allocation for consistency with ratio-prestart; main loop + # will push the up-to-date target later via update_target(). + tgt_s_est = total_budget_s * last_chars / max(total_chars_initial, 1) + tol_est = max(MIN_TOLERANCE_S, tgt_s_est * tolerance_ratio) + tol_upper_est = max(MIN_TOLERANCE_S, tgt_s_est * TOLERANCE_RATIO_UPPER) + ctx = _ChunkRefineContext( + client=client, + original_text=last_text, + target_s=tgt_s_est, + tol_s=tol_est, + tol_upper_s=tol_upper_est, + prev_texts=list(final_texts), + next_chunk_text="", + voice=voice, + max_ref=MAX_REFINEMENTS, + kickoff_iter=iter_i, + kickoff_kind="last", + ) + _chunk_contexts[last_idx] = ctx + th = threading.Thread(target=_refine_worker, args=(ctx, "prestart"), daemon=True) + ctx.workers.append(th) + th.start() + print( + f" [last-prestart] chunk {last_idx} kicked off at start of iter {iter_i}, " + f"target_est={tgt_s_est:.1f}s" + ) i = 0 while i < len(segments_list): @@ -534,6 +696,23 @@ def run_pipeline( tts_api_s = 0.0 mp3_parse_s = 0.0 chunk_lead_s = 0.0 + chunk_prestart_kind = "" + # Per-worker times (default empty; chunks 1+ branch overrides) + prestart_fs_list: List[float] = [] + prestart_llm_list: List[float] = [] + prestart_tts_list: List[float] = [] + normal_fs_list: List[float] = [] + normal_llm_list: List[float] = [] + normal_tts_list: List[float] = [] + chosen_worker_label = "" + chosen_intra_iter = 0 + + # ---- (NEW) at start of every iteration: maybe kick off prestarts ---- + # For iter i, ratio check looks at c[i+2]/c[i+1]; last-chunk fires when i == n-3. + # Both kickoffs run BEFORE we process the current chunk, so chunk 0's TTS + # runs in parallel with the prestart for chunk 2 (if ratio triggered). + _kickoff_ratio_prestart(i) + _kickoff_last_chunk_prestart(i) # ---- chunk 0: no refinement, sequential TTS ---- if i == 0: @@ -574,80 +753,158 @@ def run_pipeline( total_elapsed_s = tts_api_s overrun_s = tts_api_s - # ---- chunks 1+: parallel refinement with background TTS candidates ---- + # ---- chunks 1+: shared candidate pool with prestart + normal workers ---- else: time_budget_s = prev_audio_s - is_last = (i == len(segments_list) - 1) - if is_last and _last_chunk_future is not None: - try: - pre_out = _last_chunk_future.result(timeout=max(0.0, time_budget_s)) - except concurrent.futures.TimeoutError: - pre_out = _last_chunk_future.result() - except Exception: - pre_out = None - if pre_out is not None: - pre_out = dict(pre_out) - pre_out["total_elapsed_s"] = max( - 0.0, - float(pre_out["total_elapsed_s"]) - _last_chunk_lead_s, - ) - ref_out = pre_out - chunk_lead_s = _last_chunk_lead_s - print( - f" [pre-start] adopted pre-start result: " - f"effective_wait={pre_out['total_elapsed_s']:.1f}s, " - f"lead={_last_chunk_lead_s:.1f}s, target={target_s:.1f}s" - ) - else: - ref_out = _adaptive_refine_parallel( - client, chunk, - target_s=target_s, - time_budget_s=time_budget_s, - prev_texts=final_texts, - tolerance_s=tol_s, - tolerance_upper_s=tol_upper_s, - voice=voice, - max_ref=max_ref, - next_chunk_text=next_chunk_text, - ) - if _last_chunk_executor is not None: - _last_chunk_executor.shutdown(wait=False) - _last_chunk_executor = None - _last_chunk_future = None - _last_chunk_lead_s = 0.0 + # Pop existing prestart context (if any), or build a fresh context + ctx = _chunk_contexts.pop(i, None) + if ctx is not None: + # Prestart was running; push the up-to-date target so its next + # iteration uses real budget instead of the initial estimate. + ctx.update_target(target_s, tol_s, tol_upper_s) + # update prev_texts for normal worker via its own field + ctx.prev_texts = list(final_texts) + ctx.next_chunk_text = next_chunk_text + chunk_prestart_kind = ctx.kickoff_kind else: - ref_out = _adaptive_refine_parallel( - client, chunk, + ctx = _ChunkRefineContext( + client=client, + original_text=chunk, target_s=target_s, - time_budget_s=time_budget_s, - prev_texts=final_texts, - tolerance_s=tol_s, - tolerance_upper_s=tol_upper_s, + tol_s=tol_s, + tol_upper_s=tol_upper_s, + prev_texts=list(final_texts), + next_chunk_text=next_chunk_text, voice=voice, max_ref=max_ref, - next_chunk_text=next_chunk_text, + kickoff_iter=i, + kickoff_kind="", + ) + chunk_prestart_kind = "" + + # Always start a normal worker for this chunk (in addition to any + # prestart worker that may already be running on the same context). + normal_th = threading.Thread(target=_refine_worker, args=(ctx, "normal"), daemon=True) + ctx.workers.append(normal_th) + normal_th.start() + + # Wait for ANY worker to find an ok candidate, OR until deadline. + ctx.done_event.wait(timeout=max(0.0, time_budget_s)) + ctx.stop_event.set() + + total_elapsed_s = _now() - ctx.t_start_wall + + # Pick the chosen candidate + if ctx.chosen_cand is not None and ctx.chosen_tts_out is not None: + chosen_cand = ctx.chosen_cand + tts_out = ctx.chosen_tts_out + in_range = True + timed_out = False + else: + # Deadline hit before any worker confirmed ok → take best from pool + with ctx.candidates_lock: + snap = list(ctx.candidates) + if not snap: + raise RuntimeError(f"Chunk {i}: no candidates produced") + target_now, _, _ = ctx.get_target() + chosen_cand, tts_out = _pick_best_completed(snap, target_now) + in_range = False + timed_out = True + + # Speed adjustment if still out of range + target_now, tol_now, tol_upper_now = ctx.get_target() + audio_s = float(tts_out["audio_seconds"]) + if not _in_range(audio_s, target_now, tol_now, tol_upper_now): + raw_speed = audio_s / target_now if target_now > 0 else 1.0 + clamped = max(SPEED_ADJUST_MIN, min(SPEED_ADJUST_MAX, raw_speed)) + if abs(clamped - 1.0) > 0.01: + try: + speed_tts_out = _tts_with_retry(client, chosen_cand.text, voice=voice, speed=clamped) + if abs(speed_tts_out["audio_seconds"] - target_now) < abs(audio_s - target_now): + tts_out = speed_tts_out + audio_s = float(tts_out["audio_seconds"]) + except Exception: + pass + + # Aggregate stats from all workers on this context + n_ref_total, all_llm_times, all_fs_times = ctx.aggregate_stats() + with ctx.candidates_lock: + snap = list(ctx.candidates) + iter_tts_times: List[float] = [] + for c in snap: + if c.future.done(): + try: + iter_tts_times.append(round(c.future.result()["tts_api_s"], 3)) + except Exception: + iter_tts_times.append(-1.0) + else: + iter_tts_times.append(-1.0) + + # Per-worker fs/llm/tts streams (in candidate-submit order within each worker) + def _worker_tts_in_order(label: str) -> List[float]: + out: List[float] = [] + for c in snap: + if c.worker_label != label: + continue + if c.future.done(): + try: + out.append(round(c.future.result()["tts_api_s"], 3)) + except Exception: + out.append(-1.0) + else: + out.append(-1.0) + return out + + prestart_stats = ctx._stats.get("prestart", {"fs_times": [], "llm_times": []}) + normal_stats = ctx._stats.get("normal", {"fs_times": [], "llm_times": []}) + prestart_fs_list = list(prestart_stats.get("fs_times", [])) + prestart_llm_list = list(prestart_stats.get("llm_times", [])) + prestart_tts_list = _worker_tts_in_order("prestart") + normal_fs_list = list(normal_stats.get("fs_times", [])) + normal_llm_list = list(normal_stats.get("llm_times", [])) + normal_tts_list = _worker_tts_in_order("normal") + + # Compute lead_s for prestarted chunks (how much earlier than the + # would-be normal prep_start the worker actually started). + if chunk_prestart_kind: + if ctx.kickoff_iter == 0: + chunk_lead_s = chunk_profiles[0].tts_api_s + chunk_profiles[0].audio_seconds + else: + # kickoff at start of iter k (k = i-2) → lead = audio_{k-1} + audio_k = audio_{i-3} + audio_{i-2} + chunk_lead_s = ( + chunk_profiles[i - 3].audio_seconds + chunk_profiles[i - 2].audio_seconds + ) + # subtract lead from elapsed for reporting parity with the old design + total_elapsed_s = max(0.0, total_elapsed_s - chunk_lead_s) + print( + f" [{chunk_prestart_kind}-prestart] adopted for chunk {i}: " + f"lead={chunk_lead_s:.1f}s, effective_elapsed={total_elapsed_s:.1f}s" ) - refined = ref_out["refined_text"] - n_ref_used = ref_out["n_ref_used"] - fs_estimated_s = ref_out["fs_estimated_s"] - refine_total_s = ref_out["refine_total_s"] - in_range = ref_out["target_reached"] - timed_out = ref_out["timed_out"] - n_candidates_submitted = ref_out["n_candidates_submitted"] - n_candidates_done = ref_out["n_candidates_done"] - used_candidate_iter = ref_out["used_candidate_iter"] - iter_llm_times_s = json.dumps([round(t, 3) for t in ref_out["llm_times_s"]]) - iter_fs_times_s = json.dumps([round(t, 3) for t in ref_out["fs_times_s"]]) - iter_tts_times_s = json.dumps(ref_out["iter_tts_times_s"]) - total_elapsed_s = ref_out["total_elapsed_s"] - audio_seconds = ref_out["audio_seconds"] - tts_api_s = ref_out["tts_api_s"] - mp3_parse_s = ref_out["mp3_parse_s"] - mp3_bytes = ref_out["mp3_bytes"] - - overrun_s = max(0.0, ref_out["total_elapsed_s"] - time_budget_s) + refined = chosen_cand.text + n_ref_used = n_ref_total + fs_estimated_s = chosen_cand.fs_estimated_s + refine_total_s = total_elapsed_s + n_candidates_submitted = len(snap) + n_candidates_done = sum(1 for t in iter_tts_times if t >= 0) + used_candidate_iter = chosen_cand.iteration + iter_llm_times_s = json.dumps([round(t, 3) for t in all_llm_times]) + iter_fs_times_s = json.dumps([round(t, 3) for t in all_fs_times]) + iter_tts_times_s = json.dumps(iter_tts_times) + chosen_worker_label = chosen_cand.worker_label + chosen_intra_iter = chosen_cand.intra_iter + audio_seconds = audio_s + tts_api_s = float(tts_out["tts_api_s"]) + mp3_parse_s = float(tts_out["mp3_parse_s"]) + mp3_bytes = tts_out["mp3_bytes"] + + overrun_s = max(0.0, total_elapsed_s - time_budget_s) + + # Best-effort cleanup (workers will exit at next stop_event check) + for w in ctx.workers: + w.join(timeout=2) + ctx.executor.shutdown(wait=False) try: seg = AudioSegment.from_file(BytesIO(mp3_bytes), format="mp3") @@ -662,35 +919,6 @@ def run_pipeline( final_texts.append(refined) all_mp3_bytes.append(mp3_bytes) - # ---- pre-start last chunk refinement two chunks early ---- - if ( - i == len(segments_list) - 3 - and len(segments_list) >= 3 - and _last_chunk_future is None - ): - last_idx = len(segments_list) - 1 - last_text = segments_list[last_idx] - last_chars = len(last_text) - remaining_chars_est = sum(len(c) for c in segments_list[i + 1:]) - last_target_s_est = audio_budget_remaining * (last_chars / max(remaining_chars_est, 1)) - last_tol_s_est = max(MIN_TOLERANCE_S, last_target_s_est * tolerance_ratio) - last_tol_upper_s_est = max(MIN_TOLERANCE_S, last_target_s_est * TOLERANCE_RATIO_UPPER) - _last_chunk_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) - _last_chunk_future = _last_chunk_executor.submit( - _adaptive_refine_parallel, - client, last_text, - last_target_s_est, - total_budget_s, - list(final_texts), - last_tol_s_est, - last_tol_upper_s_est, - voice, - MAX_REFINEMENTS, - "", - ) - _last_chunk_lead_s = prev_audio_s or 0.0 - print(f" [pre-start] chunk {last_idx} kicked off at chunk {i}, est_target={last_target_s_est:.1f}s") - if out_dir is not None: (out_dir / f"chunk_{i:03d}.txt").write_text(refined, encoding="utf-8") (out_dir / f"chunk_{i:03d}.mp3").write_bytes(mp3_bytes) @@ -731,6 +959,15 @@ def run_pipeline( iter_fs_times_s=iter_fs_times_s, iter_tts_times_s=iter_tts_times_s, prep_start_lead_s=chunk_lead_s, + prestart_kind=chunk_prestart_kind, + prestart_llm_times_s=json.dumps([round(t, 3) for t in prestart_llm_list]), + prestart_fs_times_s=json.dumps([round(t, 3) for t in prestart_fs_list]), + prestart_tts_times_s=json.dumps(prestart_tts_list), + normal_llm_times_s=json.dumps([round(t, 3) for t in normal_llm_list]), + normal_fs_times_s=json.dumps([round(t, 3) for t in normal_fs_list]), + normal_tts_times_s=json.dumps(normal_tts_list), + chosen_worker_label=chosen_worker_label, + chosen_intra_iter=chosen_intra_iter, ) chunk_profiles.append(cp)