Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 132 additions & 4 deletions backend/routers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,13 +666,17 @@ def decode_opus_file_to_wav(opus_file_path, wav_file_path, sample_rate=16000, ch
return False


def get_timestamp_seconds_from_path(path: str) -> float:
    """Return the timestamp encoded in a file name, in seconds, keeping fractions.

    The timestamp is the text after the last underscore of the base name with
    the final extension removed (e.g. ``audio_1700000000500.wav``).

    Args:
        path: File path whose base name ends in ``_<number>[.<ext>]``.

    Returns:
        The parsed number as a float. Values above 1e10 are divided by 1000
        (presumably epoch milliseconds -- confirm against the uploader);
        smaller values are returned unchanged.

    Raises:
        ValueError: If the trailing token cannot be parsed as a number.
    """
    file_name = os.path.basename(path)
    # rsplit keeps any earlier underscores/dots in the name out of the token.
    token = file_name.rsplit('_', 1)[-1].rsplit('.', 1)[0]
    timestamp = float(token)
    if timestamp > 1e10:
        # Do not truncate here: callers that need whole seconds wrap this in int().
        return timestamp / 1000
    return timestamp


def get_timestamp_from_path(path: str):
    """Return the file-name timestamp of *path* truncated to whole seconds."""
    seconds = get_timestamp_seconds_from_path(path)
    return int(seconds)


def retrieve_file_paths(files: List[UploadFile], uid: str):
directory = f'syncing/{uid}/'
os.makedirs(directory, exist_ok=True)
Expand Down Expand Up @@ -717,6 +721,116 @@ def get_wav_duration(wav_path: str) -> float:
return 0.0


# Longest wall-clock span, in seconds, that a merged ASR window may cover
# (first segment start to last segment end); a segment that would push the
# window past this starts a new window.
SYNC_ASR_WINDOW_MAX_SECONDS = 10 * 60
# Largest gap, in seconds, between the current window's end and the next
# segment's start for that segment to still join the window.
SYNC_ASR_WINDOW_MAX_GAP_SECONDS = 120


def _collect_asr_segment_infos(segment_paths: List[str]) -> List[dict]:
    """Describe each segment WAV by its path and wall-clock extent.

    Args:
        segment_paths: Paths of segment WAV files whose names carry a timestamp.

    Returns:
        One dict per path with keys ``'path'``, ``'start'`` and ``'end'``
        (start plus the WAV duration), ordered by ascending ``'start'``.

    Raises:
        ValueError: If a file's reported duration is not positive.
    """

    def _describe(path: str) -> dict:
        began_at = get_timestamp_seconds_from_path(path)
        length = get_wav_duration(path)
        if length <= 0:
            raise ValueError(f'Invalid WAV duration for {path}')
        return {'path': path, 'start': began_at, 'end': began_at + length}

    described = [_describe(segment_path) for segment_path in segment_paths]
    described.sort(key=lambda entry: entry['start'])
    return described


def _group_asr_segment_infos(segment_infos: List[dict]) -> List[List[dict]]:
groups = []
current_group = []
current_start = 0.0
current_end = 0.0

for info in segment_infos:
if not current_group:
current_group = [info]
current_start = info['start']
current_end = info['end']
continue

gap_seconds = info['start'] - current_end
projected_end = max(current_end, info['end'])
projected_window_seconds = projected_end - current_start
should_split = (
gap_seconds > SYNC_ASR_WINDOW_MAX_GAP_SECONDS or projected_window_seconds > SYNC_ASR_WINDOW_MAX_SECONDS
)
if should_split:
groups.append(current_group)
current_group = [info]
current_start = info['start']
current_end = info['end']
continue

current_group.append(info)
current_end = projected_end

if current_group:
groups.append(current_group)

return groups


def _merge_asr_window_group(group: List[dict]) -> str:
if len(group) == 1:
return group[0]['path']

output_dir = os.path.dirname(group[0]['path'])
output_path = os.path.join(output_dir, f"asr_window_{int(group[0]['start'] * 1000)}.wav")
combined = None

try:
combined = AudioSegment.from_wav(group[0]['path'])
expected_end = group[0]['end']
sample_width = combined.sample_width
channels = combined.channels
frame_rate = combined.frame_rate

for info in group[1:]:
gap_ms = max(0, int((info['start'] - expected_end) * 1000))
if gap_ms > 0:
silence = AudioSegment.silent(duration=gap_ms, frame_rate=frame_rate)
silence = silence.set_sample_width(sample_width).set_channels(channels)
combined += silence
del silence

next_audio = AudioSegment.from_wav(info['path'])
combined += next_audio
expected_end = max(expected_end, info['end'])
del next_audio

combined.export(output_path, format='wav')
return output_path
finally:
if combined is not None:
del combined


def build_asr_windows_from_segments(segment_paths: List[str]) -> Tuple[List[str], List[str]]:
    """Merge adjacent VAD WAVs into wider ASR windows while preserving wall-clock gaps.

    Args:
        segment_paths: Paths of the VAD segment WAV files.

    Returns:
        A pair ``(asr_paths, temp_paths)``. ``asr_paths`` holds one path per
        ASR window; ``temp_paths`` holds only the paths of windows that were
        merged from more than one segment. With zero or one input path, or if
        regrouping fails for any reason, the input list is returned unchanged
        together with an empty list.
    """
    segment_count = len(segment_paths)
    if segment_count <= 1:
        return segment_paths, []

    temp_paths = []
    try:
        ordered_infos = _collect_asr_segment_infos(segment_paths)
        windows = _group_asr_segment_infos(ordered_infos)

        asr_paths = []
        for window in windows:
            window_path = _merge_asr_window_group(window)
            asr_paths.append(window_path)
            # A single-segment window reuses the segment's own file.
            if len(window) > 1:
                temp_paths.append(window_path)

        if len(asr_paths) != segment_count:
            logger.info(f'sync: regrouped {len(segment_paths)} VAD segment(s) into {len(asr_paths)} ASR window(s)')
    except Exception as e:
        logger.error(f'sync: failed to regroup VAD segments for ASR, using original segments: {e}')
        # Drop any windows already merged before the failure.
        _cleanup_files(temp_paths)
        return segment_paths, []
    else:
        return asr_paths, temp_paths


def decode_pcm_file_to_wav(pcm_file_path, wav_file_path, sample_rate=16000, channels=1, sample_width=2):
"""Decode a length-prefixed PCM .bin file to WAV.

Expand Down Expand Up @@ -1358,6 +1472,7 @@ async def sync_local_files(
paths = []
wav_paths = []
segmented_paths = set()
asr_temp_paths = []

try:
try:
Expand Down Expand Up @@ -1451,6 +1566,10 @@ def _run_vad(path):
# assignment is serialized oldest-first so adjacent chunks merge instead of
# racing into separate conversations (#6551, #5747).
ordered_paths = sorted(segmented_paths, key=get_timestamp_from_path)
ordered_paths, asr_temp_paths = await run_blocking(
sync_executor, build_asr_windows_from_segments, ordered_paths
)
total_segments = len(ordered_paths)
assignment_turnstile = _OrderedTurnstile(ordered_paths)
await asyncio.gather(
*[
Expand Down Expand Up @@ -1531,6 +1650,7 @@ def _run_vad(path):
# Clean up any remaining temporary files
_cleanup_files(paths) # .bin files (in case decode_files_to_wav didn't finish)
_cleanup_files(wav_paths) # Original wav files (if VAD didn't complete)
_cleanup_files(asr_temp_paths) # ASR windows synthesized from VAD segments
_cleanup_files(segmented_paths) # Segmented wav files after processing


Expand Down Expand Up @@ -1611,6 +1731,7 @@ async def _run_full_pipeline_background_async(
concurrency_gate = contextlib.nullcontext() if task_mode else _get_sync_pipeline_semaphore()
async with concurrency_gate:
segmented_paths = set()
asr_temp_paths = []
wav_paths = []
stage_timings = {}
pipeline_start = time.monotonic()
Expand Down Expand Up @@ -1756,6 +1877,13 @@ def _run_vad_bg(path):
segment_errors = []
segment_lock = threading.Lock()

segment_list = sorted(segmented_paths, key=get_timestamp_from_path)
segment_list, asr_temp_paths = await run_blocking(
sync_executor, build_asr_windows_from_segments, segment_list
)
total_segments = len(segment_list)
await run_blocking(db_executor, update_sync_job, job_id, {'total_segments': total_segments})

# Segments that fully landed in a prior Cloud Tasks attempt are skipped
already_processed = set()
if task_mode:
Expand All @@ -1769,7 +1897,6 @@ def _run_vad_bg(path):
# Chronological order + turnstile: STT runs in parallel (per chunk), but
# conversation assignment is serialized oldest-first so adjacent chunks merge
# instead of racing into separate conversations (#6551, #5747).
segment_list = sorted(segmented_paths, key=get_timestamp_from_path)
assignment_turnstile = _OrderedTurnstile(segment_list)

def _process_one_segment(path):
Expand Down Expand Up @@ -1890,6 +2017,7 @@ def _process_one_segment(path):
pass
finally:
set_byok_keys({})
await run_blocking(storage_executor, _cleanup_files, asr_temp_paths)
await run_blocking(storage_executor, _cleanup_files, list(segmented_paths))
await run_blocking(storage_executor, _cleanup_files, wav_paths)
try:
Expand Down
Loading