From 545f89551d1e2478a080c4ef6f781df855b4a2aa Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 18 Jun 2026 19:31:14 +0800 Subject: [PATCH] fix(backend): regroup offline sync segments before ASR --- backend/routers/sync.py | 136 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 132 insertions(+), 4 deletions(-) diff --git a/backend/routers/sync.py b/backend/routers/sync.py index e50cd2a4db..db0d14e3b0 100644 --- a/backend/routers/sync.py +++ b/backend/routers/sync.py @@ -666,13 +666,17 @@ def decode_opus_file_to_wav(opus_file_path, wav_file_path, sample_rate=16000, ch return False -def get_timestamp_from_path(path: str): - timestamp = int(path.split('/')[-1].split('_')[-1].split('.')[0]) +def get_timestamp_seconds_from_path(path: str) -> float: + timestamp = float(os.path.basename(path).rsplit('_', 1)[-1].rsplit('.', 1)[0]) if timestamp > 1e10: - return int(timestamp / 1000) + return timestamp / 1000 return timestamp +def get_timestamp_from_path(path: str): + return int(get_timestamp_seconds_from_path(path)) + + def retrieve_file_paths(files: List[UploadFile], uid: str): directory = f'syncing/{uid}/' os.makedirs(directory, exist_ok=True) @@ -717,6 +721,116 @@ def get_wav_duration(wav_path: str) -> float: return 0.0 +SYNC_ASR_WINDOW_MAX_SECONDS = 10 * 60 +SYNC_ASR_WINDOW_MAX_GAP_SECONDS = 120 + + +def _collect_asr_segment_infos(segment_paths: List[str]) -> List[dict]: + infos = [] + for path in segment_paths: + start = get_timestamp_seconds_from_path(path) + duration = get_wav_duration(path) + if duration <= 0: + raise ValueError(f'Invalid WAV duration for {path}') + infos.append({'path': path, 'start': start, 'end': start + duration}) + return sorted(infos, key=lambda item: item['start']) + + +def _group_asr_segment_infos(segment_infos: List[dict]) -> List[List[dict]]: + groups = [] + current_group = [] + current_start = 0.0 + current_end = 0.0 + + for info in segment_infos: + if not current_group: + current_group = [info] + current_start = info['start'] + current_end = info['end'] + continue + + gap_seconds = info['start'] - current_end + projected_end = max(current_end, info['end']) + projected_window_seconds = projected_end - current_start + should_split = ( + gap_seconds > SYNC_ASR_WINDOW_MAX_GAP_SECONDS or projected_window_seconds > SYNC_ASR_WINDOW_MAX_SECONDS + ) + if should_split: + groups.append(current_group) + current_group = [info] + current_start = info['start'] + current_end = info['end'] + continue + + current_group.append(info) + current_end = projected_end + + if current_group: + groups.append(current_group) + + return groups + + +def _merge_asr_window_group(group: List[dict]) -> str: + if len(group) == 1: + return group[0]['path'] + + output_dir = os.path.dirname(group[0]['path']) + output_path = os.path.join(output_dir, f"asr_window_{int(group[0]['start'] * 1000)}.wav") + combined = None + + try: + combined = AudioSegment.from_wav(group[0]['path']) + expected_end = group[0]['end'] + sample_width = combined.sample_width + channels = combined.channels + frame_rate = combined.frame_rate + + for info in group[1:]: + gap_ms = max(0, int((info['start'] - expected_end) * 1000)) + if gap_ms > 0: + silence = AudioSegment.silent(duration=gap_ms, frame_rate=frame_rate) + silence = silence.set_sample_width(sample_width).set_channels(channels) + combined += silence + del silence + + next_audio = AudioSegment.from_wav(info['path']) + combined += next_audio + expected_end = max(expected_end, info['end']) + del next_audio + + combined.export(output_path, format='wav') + return output_path + finally: + if combined is not None: + del combined + + +def build_asr_windows_from_segments(segment_paths: List[str]) -> Tuple[List[str], List[str]]: + """Merge adjacent VAD WAVs into wider ASR windows while preserving wall-clock gaps.""" + if len(segment_paths) <= 1: + return segment_paths, [] + + temp_paths = [] + try: + segment_infos = _collect_asr_segment_infos(segment_paths) + groups = _group_asr_segment_infos(segment_infos) + asr_paths = [] + for group in groups: + path = _merge_asr_window_group(group) + asr_paths.append(path) + if len(group) > 1: + temp_paths.append(path) + + if len(asr_paths) != len(segment_paths): + logger.info(f'sync: regrouped {len(segment_paths)} VAD segment(s) into {len(asr_paths)} ASR window(s)') + return asr_paths, temp_paths + except Exception as e: + logger.error(f'sync: failed to regroup VAD segments for ASR, using original segments: {e}') + _cleanup_files(temp_paths) + return segment_paths, [] + + def decode_pcm_file_to_wav(pcm_file_path, wav_file_path, sample_rate=16000, channels=1, sample_width=2): """Decode a length-prefixed PCM .bin file to WAV. @@ -1358,6 +1472,7 @@ async def sync_local_files( paths = [] wav_paths = [] segmented_paths = set() + asr_temp_paths = [] try: try: @@ -1451,6 +1566,10 @@ def _run_vad(path): # assignment is serialized oldest-first so adjacent chunks merge instead of # racing into separate conversations (#6551, #5747). ordered_paths = sorted(segmented_paths, key=get_timestamp_from_path) + ordered_paths, asr_temp_paths = await run_blocking( + sync_executor, build_asr_windows_from_segments, ordered_paths + ) + total_segments = len(ordered_paths) assignment_turnstile = _OrderedTurnstile(ordered_paths) await asyncio.gather( *[ @@ -1531,6 +1650,7 @@ def _run_vad(path): # Clean up any remaining temporary files _cleanup_files(paths) # .bin files (in case decode_files_to_wav didn't finish) _cleanup_files(wav_paths) # Original wav files (if VAD didn't complete) + _cleanup_files(asr_temp_paths) # ASR windows synthesized from VAD segments _cleanup_files(segmented_paths) # Segmented wav files after processing @@ -1611,6 +1731,7 @@ async def _run_full_pipeline_background_async( concurrency_gate = contextlib.nullcontext() if task_mode else _get_sync_pipeline_semaphore() async with concurrency_gate: segmented_paths = set() + asr_temp_paths = [] wav_paths = [] stage_timings = {} pipeline_start = time.monotonic() @@ -1756,6 +1877,13 @@ def _run_vad_bg(path): segment_errors = [] segment_lock = threading.Lock() + segment_list = sorted(segmented_paths, key=get_timestamp_from_path) + segment_list, asr_temp_paths = await run_blocking( + sync_executor, build_asr_windows_from_segments, segment_list + ) + total_segments = len(segment_list) + await run_blocking(db_executor, update_sync_job, job_id, {'total_segments': total_segments}) + # Segments that fully landed in a prior Cloud Tasks attempt are skipped already_processed = set() if task_mode: @@ -1769,7 +1897,6 @@ def _run_vad_bg(path): # Chronological order + turnstile: STT runs in parallel (per chunk), but # conversation assignment is serialized oldest-first so adjacent chunks merge # instead of racing into separate conversations (#6551, #5747). - segment_list = sorted(segmented_paths, key=get_timestamp_from_path) assignment_turnstile = _OrderedTurnstile(segment_list) def _process_one_segment(path): @@ -1890,6 +2017,7 @@ def _process_one_segment(path): pass finally: set_byok_keys({}) + await run_blocking(storage_executor, _cleanup_files, asr_temp_paths) await run_blocking(storage_executor, _cleanup_files, list(segmented_paths)) await run_blocking(storage_executor, _cleanup_files, wav_paths) try: