From 9a36c51d3760575b4c4c8a8405593cf7c5dfb9ab Mon Sep 17 00:00:00 2001 From: MXuer Date: Thu, 11 Jun 2026 10:33:23 +0800 Subject: [PATCH 1/8] Add CLI output formats --- README.md | 9 + dolphin/transcribe.py | 85 +++++++- reports/issue-80-cli-output-test-report.md | 239 +++++++++++++++++++++ tests/test_cli_output.py | 152 +++++++++++++ 4 files changed, 484 insertions(+), 1 deletion(-) create mode 100644 reports/issue-80-cli-output-test-report.md create mode 100644 tests/test_cli_output.py diff --git a/README.md b/README.md index b3beeb6..12ba1d7 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,15 @@ Dolphin supports 40 Eastern languages and 22 Chinese dialects. For a complete li # default model:small dolphin audio.wav +# Write plain text output to a file +dolphin audio.wav --output result.txt + +# Write structured output with metadata +dolphin audio.wav --output result.json --output_format json + +# Write subtitle output +dolphin audio.wav --output result.srt --output_format srt + # Download model and specify the model path dolphin audio.wav --model small.cn diff --git a/dolphin/transcribe.py b/dolphin/transcribe.py index b3d92f3..1256042 100644 --- a/dolphin/transcribe.py +++ b/dolphin/transcribe.py @@ -2,6 +2,7 @@ import logging import warnings +import json LOGGING_FORMAT="[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d:%(funcName)s] %(message)s" logging.basicConfig(level=logging.INFO, format=LOGGING_FORMAT) @@ -75,6 +76,14 @@ def parser_args() -> Namespace: parser.add_argument("--use_two_stage_filter", type=str2bool, default=False, help="use two-stage filtering for hotwords (default: false)") parser.add_argument("--use_prompt_hotword", type=str2bool, default=False, help="use prompt-based hotword (default: false)") parser.add_argument("--prompt_filter_threshold", type=float, default=-2.0, help="filter threshold for prompt hotwords (default: -2.0)") + parser.add_argument("--output", type=Path, default=None, help="write transcription output to file") + parser.add_argument( + "--output_format", + type=str, + default="txt", + choices=("txt", "json", "srt"), + help="output format for stdout or --output (default: txt)", + ) args = parser.parse_args() return args @@ -570,6 +579,79 @@ def detect_language(model: ASRModel, audio: str) -> Tuple[str, str]: return (lang, dialect) +def _format_cli_output( + result: Union[TranscribeResult, List[TranscribeSegmentResult]], + output_format: str = "txt", +) -> str: + if output_format == "json": + if isinstance(result, list): + payload = [dataclasses.asdict(item) for item in result] + else: + payload = dataclasses.asdict(result) + return json.dumps(payload, ensure_ascii=False, indent=2) + + if output_format == "srt": + return _format_srt_output(result) + + if isinstance(result, list): + return "\n".join(item.text_nospecial for item in result) + + return result.text_nospecial + + +def _seconds_to_srt_time(seconds: float) -> str: + total_ms = max(0, int(round(seconds * 1000))) + hours = total_ms // 3600000 + total_ms %= 3600000 + minutes = total_ms // 60000 + total_ms %= 60000 + secs = total_ms // 1000 + millis = total_ms % 1000 + return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}" + + +def _format_srt_output(result: Union[TranscribeResult, List[TranscribeSegmentResult]]) -> str: + if isinstance(result, list): + cues = [ + (segment.start, segment.end, segment.text_nospecial) + for segment in result + if segment.text_nospecial + ] + else: + timestamps = result.word_timestamps or [] + if timestamps: + start = float(timestamps[0].get("start", 0.0)) + end = float(timestamps[-1].get("end", start)) + else: + start = 0.0 + end = 0.0 + cues = [(start, end, result.text_nospecial)] if result.text_nospecial else [] + + blocks = [] + for index, (start, end, text) in enumerate(cues, start=1): + blocks.append( + f"{index}\n" + f"{_seconds_to_srt_time(start)} --> {_seconds_to_srt_time(end)}\n" + f"{text}" + ) + + return "\n\n".join(blocks) + + +def _emit_cli_output( + result: Union[TranscribeResult, List[TranscribeSegmentResult]], + output_format: str, + output: Optional[Path], +): + text = _format_cli_output(result, output_format) + if output is None: + print(text) + return + + output.parent.mkdir(parents=True, exist_ok=True) + output.write_text(text + "\n", encoding="utf-8") + + def transcribe( model: ASRModel, audio: str, @@ -746,7 +828,8 @@ def cli(): "use_prompt_hotword": args.use_prompt_hotword, "prompt_filter_threshold": args.prompt_filter_threshold, } - transcribe_fn(**transcribe_params) + result = transcribe_fn(**transcribe_params) + _emit_cli_output(result, args.output_format, args.output) if __name__ == "__main__": diff --git a/reports/issue-80-cli-output-test-report.md b/reports/issue-80-cli-output-test-report.md new file mode 100644 index 0000000..3631d6c --- /dev/null +++ b/reports/issue-80-cli-output-test-report.md @@ -0,0 +1,239 @@ +# Issue #80 CLI Output Test Report + +Date: 2026-06-11 +Branch: `codex/issue-80-cli-output` +Issue: [#80 没有直接的输出文本](https://github.com/DataoceanAI/Dolphin/issues/80) + +## Scope + +Issue #80 requests direct CLI output to text instead of only printing mixed logs/results in the terminal. This branch adds: + +- `--output PATH` for writing transcription output to a file. +- `--output_format {txt,json,srt}` for plain text, structured JSON, and subtitle output. +- SRT support requested during validation. +- README CLI examples for all three output formats. +- Unit coverage for text, JSON, SRT, stdout, and nested output file writes. + +## Test Environment + +- Host Python: `Python 3.10.9` +- Runtime used for real ASR tests: `/private/tmp/dolphin-venv/bin/python` +- Runtime Python: `Python 3.10.9` +- Key packages: + - `numpy 1.23.5` + - `torch 2.4.1` + - `torchaudio 2.4.1` + - `modelscope 1.36.3` + - `funasr 1.1.5` +- Environment variables used: + - `TRANSFORMERS_NO_TF=1` + - `USE_TF=0` + - `MODELSCOPE_CACHE=/private/tmp/modelscope-cache` + - `NUMBA_CACHE_DIR=/private/tmp/numba-cache` + - `HOME=/private/tmp/dolphin-home` for long-audio VAD cache isolation +- Model: + - Dolphin `base` + - Local path: `/private/tmp/dolphin-models/base` +- VAD model for long audio: + - `iic/speech_fsmn_vad_zh-cn-16k-common-pytorch` + - Downloaded to `/private/tmp/dolphin-home/.cache/dolphin/speech_fsmn_vad` + +Note: the global environment currently has `numpy 2.2.6`, which causes noisy TensorFlow/Whisper/numba compatibility stderr during imports. Real ASR validation used the temporary venv above with `numpy 1.23.5`. + +## Test Audio + +| Case | Source | Local file | Duration | +| --- | --- | --- | --- | +| Short zh-CN | `https://so-algorithm-test.oss-cn-beijing.aliyuncs.com/samples/asr/zh-cn-demo.wav` | `/private/tmp/dolphin-test-audio/zh-cn-demo.wav` | 6.267s | +| Long zh-CN | `https://so-algorithm-test.oss-cn-beijing.aliyuncs.com/samples/asr/zh-cn-long.wav` | `/private/tmp/dolphin-test-audio/zh-cn-long.wav` | 752.880s | +| Long hi-IN | `https://so-algorithm-test.oss-cn-beijing.aliyuncs.com/samples/asr/hi_in/%E5%8D%95%E4%BA%BA%E5%BD%95%E9%9F%B3_%E5%B7%A5%E4%BD%9C%E6%8A%A5%E5%91%8A_12112.wav` | `/private/tmp/dolphin-test-audio/hi-in-work-report.wav` | 1274.009s | + +## Automated Tests + +Command: + +```bash +env TRANSFORMERS_NO_TF=1 USE_TF=0 python -m pytest +``` + +Result: + +```text +collected 7 items +tests/test_cli_output.py ....... [100%] +7 passed in 1.01s +``` + +Covered behavior: + +- Short result defaults to plain `text_nospecial`. +- Long segmented results are joined as one text block for `txt`. +- JSON preserves metadata and word timestamps. +- Short-audio SRT uses first and last word timestamp. +- Long-audio SRT uses segment start/end times. +- Stdout output still works when `--output` is omitted. +- Nested output directories are created automatically. + +## CLI Help Check + +Command: + +```bash +env TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -m dolphin --help +``` + +Result: + +```text +--output OUTPUT write transcription output to file +--output_format {txt,json,srt} + output format for stdout or --output (default: txt) +``` + +Exit code: `0`. + +## Real ASR Output Tests + +### Short zh-CN Plain Text + +Command: + +```bash +env MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -m dolphin /private/tmp/dolphin-test-audio/zh-cn-demo.wav --model base --model_dir /private/tmp/dolphin-models/base --device cpu --output /private/tmp/dolphin-test-output/demo.txt --output_format txt +``` + +Result: + +```text +诚然 , 时代正在推崇初心 , 但文化之精髓切虚传承。 +``` + +Output file: `/private/tmp/dolphin-test-output/demo.txt` +Exit code: `0`. + +### Short zh-CN JSON + +Command: + +```bash +env MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -m dolphin /private/tmp/dolphin-test-audio/zh-cn-demo.wav --model base --model_dir /private/tmp/dolphin-models/base --device cpu --output /private/tmp/dolphin-test-output/demo.json --output_format json +``` + +Validation: + +```bash +python -m json.tool /private/tmp/dolphin-test-output/demo.json +``` + +Result: + +- JSON parsed successfully. +- File contains `text`, `text_nospecial`, `language`, `region`, and `word_timestamps`. +- Detected language/region: `zh` / `CN`. + +Output file: `/private/tmp/dolphin-test-output/demo.json` +Exit code: `0`. + +### Short zh-CN SRT + +Command: + +```bash +env MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -m dolphin /private/tmp/dolphin-test-audio/zh-cn-demo.wav --model base --model_dir /private/tmp/dolphin-models/base --device cpu --output /private/tmp/dolphin-test-output/demo.srt --output_format srt +``` + +Result: + +```srt +1 +00:00:00,600 --> 00:00:05,867 +诚然 , 时代正在推崇初心 , 但文化之精髓切虚传承。 +``` + +Output file: `/private/tmp/dolphin-test-output/demo.srt` +Exit code: `0`. + +### Long zh-CN SRT + +Command: + +```bash +env HOME=/private/tmp/dolphin-home MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -m dolphin /private/tmp/dolphin-test-audio/zh-cn-long.wav --model base --model_dir /private/tmp/dolphin-models/base --device cpu --output /private/tmp/dolphin-test-output/zh-cn-long.srt --output_format srt +``` + +Result: + +- Output file: `/private/tmp/dolphin-test-output/zh-cn-long.srt` +- File size: `12K` +- Cue count: `60` +- First cue: `00:00:03,150 --> 00:00:05,000` +- Last cue: `00:12:26,870 --> 00:12:28,680` +- Exit code: `0` + +### Long hi-IN SRT + +Command: + +```bash +env HOME=/private/tmp/dolphin-home MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -m dolphin /private/tmp/dolphin-test-audio/hi-in-work-report.wav --model base --model_dir /private/tmp/dolphin-models/base --device cpu --output /private/tmp/dolphin-test-output/hi-in-work-report.srt --output_format srt +``` + +Result: + +- Output file: `/private/tmp/dolphin-test-output/hi-in-work-report.srt` +- File size: `32K` +- Cue count: `175` +- First cue: `00:00:02,690 --> 00:00:08,900` +- Last cue: `00:21:11,550 --> 00:21:13,240` +- Exit code: `0` + +## SRT Structural Validation + +Command: + +```bash +python -c 'import re, pathlib +for path in ["/private/tmp/dolphin-test-output/demo.srt", "/private/tmp/dolphin-test-output/zh-cn-long.srt", "/private/tmp/dolphin-test-output/hi-in-work-report.srt"]: + blocks = [b for b in pathlib.Path(path).read_text(encoding="utf-8").strip().split("\n\n") if b.strip()] + prev_end = -1.0 + ok = True + for expected, block in enumerate(blocks, 1): + lines = block.splitlines() + if int(lines[0]) != expected: + ok = False + break + m = re.match(r"(\d\d):(\d\d):(\d\d),(\d\d\d) --> (\d\d):(\d\d):(\d\d),(\d\d\d)$", lines[1]) + if not m: + ok = False + break + vals = list(map(int, m.groups())) + start = vals[0]*3600 + vals[1]*60 + vals[2] + vals[3]/1000 + end = vals[4]*3600 + vals[5]*60 + vals[6] + vals[7]/1000 + if end < start or start < prev_end: + ok = False + break + prev_end = end + print(f"{path}: cues={len(blocks)}, valid={ok}, last_end={prev_end:.3f}s")' +``` + +Result: + +```text +/private/tmp/dolphin-test-output/demo.srt: cues=1, valid=True, last_end=5.867s +/private/tmp/dolphin-test-output/zh-cn-long.srt: cues=60, valid=True, last_end=748.680s +/private/tmp/dolphin-test-output/hi-in-work-report.srt: cues=175, valid=True, last_end=1273.240s +``` + +## Notes And Risks + +- This change does not alter ASR decoding, VAD, language detection, or timestamp generation. +- Long-audio SRT uses VAD segment start/end times from `TranscribeSegmentResult`. +- Short-audio SRT uses first and last word timestamp from `TranscribeResult.word_timestamps`. +- If no text is returned, SRT output is empty rather than creating blank cues. +- JSON output is intentionally full dataclass data, including timestamps and language metadata. +- The existing CLI still prints the formatted result when `--output` is omitted. +- Some upstream ASR logs still go to stderr/stdout during real inference. The new file output keeps the requested transcription artifact clean. + +## Recommendation + +This branch is ready for user review for issue #80. The new SRT format also passed short and long real-audio validation. diff --git a/tests/test_cli_output.py b/tests/test_cli_output.py new file mode 100644 index 0000000..34813fb --- /dev/null +++ b/tests/test_cli_output.py @@ -0,0 +1,152 @@ +import importlib +import json +import sys +import types + + +def _install_modelscope_stub(): + modelscope = types.ModuleType("modelscope") + modelscope.snapshot_download = lambda *args, **kwargs: None + + model_module = types.ModuleType("modelscope.models.audio.funasr.model") + + class GenericFunASR: + pass + + model_module.GenericFunASR = GenericFunASR + + sys.modules.setdefault("modelscope", modelscope) + sys.modules.setdefault("modelscope.models", types.ModuleType("modelscope.models")) + sys.modules.setdefault("modelscope.models.audio", types.ModuleType("modelscope.models.audio")) + sys.modules.setdefault("modelscope.models.audio.funasr", types.ModuleType("modelscope.models.audio.funasr")) + sys.modules.setdefault("modelscope.models.audio.funasr.model", model_module) + + +_install_modelscope_stub() +transcribe = importlib.import_module("dolphin.transcribe") + + +def test_format_cli_output_short_text(): + result = transcribe.TranscribeResult( + text="你好", + text_nospecial="你好", + language="zh", + region="CN", + ) + + assert transcribe._format_cli_output(result) == "你好" + + +def test_format_cli_output_segments_text(): + segments = [ + transcribe.TranscribeSegmentResult( + text="a", + text_nospecial="第一段", + language="zh", + region="CN", + start=0.0, + end=1.0, + ), + transcribe.TranscribeSegmentResult( + text="b", + text_nospecial="第二段", + language="zh", + region="CN", + start=1.0, + end=2.0, + ), + ] + + assert transcribe._format_cli_output(segments) == "第一段\n第二段" + + +def test_format_cli_output_json_preserves_metadata(): + result = transcribe.TranscribeResult( + text="你好", + text_nospecial="你好", + language="zh", + region="CN", + word_timestamps=[{"word": "你好", "start": 0.0, "end": 0.4}], + ) + + payload = json.loads(transcribe._format_cli_output(result, "json")) + + assert payload["text_nospecial"] == "你好" + assert payload["language"] == "zh" + assert payload["word_timestamps"] == [{"word": "你好", "start": 0.0, "end": 0.4}] + + +def test_format_cli_output_short_result_srt_uses_word_timestamps(): + result = transcribe.TranscribeResult( + text="raw", + text_nospecial="你好", + language="zh", + region="CN", + word_timestamps=[ + {"word": "你", "start": 0.12, "end": 0.31}, + {"word": "好", "start": 0.31, "end": 0.62}, + ], + ) + + assert transcribe._format_cli_output(result, "srt") == ( + "1\n" + "00:00:00,120 --> 00:00:00,620\n" + "你好" + ) + + +def test_format_cli_output_segments_srt(): + segments = [ + transcribe.TranscribeSegmentResult( + text="a", + text_nospecial="第一段", + language="zh", + region="CN", + start=0.0, + end=1.25, + ), + transcribe.TranscribeSegmentResult( + text="b", + text_nospecial="第二段", + language="zh", + region="CN", + start=61.0, + end=62.5, + ), + ] + + assert transcribe._format_cli_output(segments, "srt") == ( + "1\n" + "00:00:00,000 --> 00:00:01,250\n" + "第一段\n\n" + "2\n" + "00:01:01,000 --> 00:01:02,500\n" + "第二段" + ) + + +def test_emit_cli_output_stdout(capsys): + result = transcribe.TranscribeResult( + text="raw", + text_nospecial="plain", + language="zh", + region="CN", + ) + + transcribe._emit_cli_output(result, "txt", None) + + assert capsys.readouterr().out == "plain\n" + + +def test_emit_cli_output_file(tmp_path): + result = transcribe.TranscribeResult( + text="raw", + text_nospecial="plain", + language="zh", + region="CN", + ) + output = tmp_path / "nested" / "result.txt" + + transcribe._emit_cli_output(result, "txt", output) + + assert output.read_text(encoding="utf-8") == "plain\n" From 5e429b25f3c03af991b20839565900b2165ea4de Mon Sep 17 00:00:00 2001 From: MXuer Date: Thu, 11 Jun 2026 10:44:25 +0800 Subject: [PATCH 2/8] Expose language detection task --- README.md | 10 + dolphin/__init__.py | 2 +- dolphin/transcribe.py | 51 ++++- ...issue-93-language-detection-test-report.md | 180 ++++++++++++++++++ tests/test_language_detection.py | 101 ++++++++++ 5 files changed, 342 insertions(+), 2 deletions(-) create mode 100644 reports/issue-93-language-detection-test-report.md create mode 100644 tests/test_language_detection.py diff --git a/README.md b/README.md index 12ba1d7..e3ff0ed 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,12 @@ dolphin audio.wav --model small.cn # Specify language and region dolphin audio.wav --model small.cn --lang_sym "zh" --region_sym "CN" +# Detect language and region only. This uses the Dolphin ASR model's built-in +# language identification head; this package does not provide a separate +# lightweight LID-only model. +dolphin audio.wav --model small.cn --task detect_language +dolphin long_audio.wav --model small.cn --task detect_language --lid_duration 30 + # Specify the hotwords file with Encoder-biased method dolphin audio.wav --model small.cn --hotword_list_path hotwords.txt --use_deep_biasing true @@ -113,6 +119,10 @@ model = dolphin.load_model(model_name, device="cuda") result = transcribe(model, 'audio.wav') print(result.text) +# Detect language and region only +language, region = dolphin.detect_language(model, 'audio.wav') +print(language, region) + # Specify language result = transcribe(model, 'audio.wav', lang_sym="zh") print(result.text) diff --git a/dolphin/__init__.py b/dolphin/__init__.py index 9db5ef0..b064235 100644 --- a/dolphin/__init__.py +++ b/dolphin/__init__.py @@ -1,6 +1,6 @@ # encoding: utf8 from .audio import load_audio -from .transcribe import load_model, transcribe +from .transcribe import detect_language, load_model, transcribe from .hotword import HotwordEncoder, apply_deep_biasing, two_stage_filtering from .version import __version__ diff --git a/dolphin/transcribe.py b/dolphin/transcribe.py index 1256042..0196a50 100644 --- a/dolphin/transcribe.py +++ b/dolphin/transcribe.py @@ -28,6 +28,7 @@ import torch import torch.nn as nn +import torchaudio import modelscope from modelscope.models.audio.funasr.model import GenericFunASR @@ -76,6 +77,14 @@ def parser_args() -> Namespace: parser.add_argument("--use_two_stage_filter", type=str2bool, default=False, help="use two-stage filtering for hotwords (default: false)") parser.add_argument("--use_prompt_hotword", type=str2bool, default=False, help="use prompt-based hotword (default: false)") parser.add_argument("--prompt_filter_threshold", type=float, default=-2.0, help="filter threshold for prompt hotwords (default: -2.0)") + parser.add_argument("--lid_duration", type=float, default=SPEECH_LENGTH, help="seconds of audio to use for language detection; set 0 to use full audio (default: 30)") + parser.add_argument( + "--task", + type=str, + default="transcribe", + choices=("transcribe", "detect_language"), + help="task to run: transcribe or detect_language (default: transcribe)", + ) parser.add_argument("--output", type=Path, default=None, help="write transcription output to file") parser.add_argument( "--output_format", @@ -562,10 +571,45 @@ def _filter_prompt_tokens(tokens: List[int], tokenizer: BaseTokenizer) -> Tuple[ return hotwords_text, tokens -def detect_language(model: ASRModel, audio: str) -> Tuple[str, str]: +def _limit_audio_duration( + audio: Union[str, Path, torch.Tensor], + max_duration: Optional[float] = SPEECH_LENGTH, +) -> Union[str, Path, torch.Tensor]: + if max_duration is None or max_duration <= 0: + return audio + + if isinstance(audio, torch.Tensor): + max_samples = int(max_duration * 16000) + return audio[..., :max_samples] + + info = torchaudio.info(str(audio)) + max_frames = int(max_duration * info.sample_rate) + if info.num_frames > 0 and info.num_frames <= max_frames: + return audio + + waveform, sample_rate = torchaudio.load(str(audio), num_frames=max_frames) + if waveform.size(0) != 1: + waveform = waveform[0, :].unsqueeze(0) + + if sample_rate != 16000: + waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform) + + return waveform + + +def detect_language( + model: ASRModel, + audio: Union[str, Path, torch.Tensor], + max_duration: Optional[float] = SPEECH_LENGTH, +) -> Tuple[str, str]: """ Detect language and dialect. + + Language detection only needs a short audio sample. By default this uses + the first ``SPEECH_LENGTH`` seconds to keep long-audio detection fast. + Pass ``max_duration=None`` or a non-positive value to use the full audio. """ + audio = _limit_audio_duration(audio, max_duration) batch = extract_feats([audio], model.model_configs) batch["feats"] = batch["feats"].to(model.device) batch["feats_lengths"] = batch["feats_lengths"].to(model.device) @@ -807,6 +851,11 @@ def cli(): model_instance = load_model(model, model_dir, device) logger.info(f"model loaded successfuly, device: {device}") + if args.task == "detect_language": + lang, region = detect_language(model_instance, args.audio, max_duration=args.lid_duration) + print(f"{lang}\t{region}") + return + # Parse hotwords hotwords = _parse_hotwords(args.hotword_str, args.hotword_list_path) diff --git a/reports/issue-93-language-detection-test-report.md b/reports/issue-93-language-detection-test-report.md new file mode 100644 index 0000000..1e83e1e --- /dev/null +++ b/reports/issue-93-language-detection-test-report.md @@ -0,0 +1,180 @@ +# Issue #93 Language Detection Test Report + +Date: 2026-06-11 +Branch: `codex/issue-93-language-detection` +Issue: [#93 Language Detection Model](https://github.com/DataoceanAI/Dolphin/issues/93) + +## Scope + +Issue #93 asks whether Dolphin can be used for language detection only, and whether a separate lightweight language-detection model exists. + +This branch makes the existing language-identification capability directly usable: + +- Exports `dolphin.detect_language(model, audio, max_duration=30)` from the package top level. +- Adds CLI task mode: `--task detect_language`. +- Adds `--lid_duration` to control how many seconds are used for LID. +- Documents that LID is built into the ASR model; this package does not include a separate lightweight LID-only model. +- Adds unit tests for the Python API, package export, CLI task, and tensor-duration limiting. + +## Implementation Notes + +Dolphin already had an internal `detect_language(model, audio)` function backed by `ASRModel.detect_language`. The model predicts the language token and region token from the ASR model's encoder/decoder path. + +During validation, sending full long audio directly into this function did not complete within about 90 seconds on CPU. The branch now limits language detection to the first `SPEECH_LENGTH` seconds by default, which is 30 seconds. This keeps language-detection-only usage responsive while preserving the option to use full audio: + +- CLI full audio: `--lid_duration 0` +- Python full audio: `dolphin.detect_language(model, audio, max_duration=None)` + +## Test Environment + +- Host Python: `Python 3.10.9` +- Runtime used for real ASR/LID tests: `/private/tmp/dolphin-venv/bin/python` +- Runtime Python: `Python 3.10.9` +- Key packages: + - `numpy 1.23.5` + - `torch 2.4.1` + - `torchaudio 2.4.1` + - `modelscope 1.36.3` + - `funasr 1.1.5` +- Environment variables used: + - `TRANSFORMERS_NO_TF=1` + - `USE_TF=0` + - `MODELSCOPE_CACHE=/private/tmp/modelscope-cache` + - `NUMBA_CACHE_DIR=/private/tmp/numba-cache` +- Model: + - Dolphin `base` + - Local path: `/private/tmp/dolphin-models/base` + +Note: the global environment currently has `numpy 2.2.6`, which causes noisy TensorFlow/Whisper/numba compatibility stderr during imports. Real validation used the temporary venv above with `numpy 1.23.5`. + +## Test Audio + +| Case | Source | Local file | Duration | +| --- | --- | --- | --- | +| Short zh-CN | `https://so-algorithm-test.oss-cn-beijing.aliyuncs.com/samples/asr/zh-cn-demo.wav` | `/private/tmp/dolphin-test-audio/zh-cn-demo.wav` | 6.267s | +| Long zh-CN | `https://so-algorithm-test.oss-cn-beijing.aliyuncs.com/samples/asr/zh-cn-long.wav` | `/private/tmp/dolphin-test-audio/zh-cn-long.wav` | 752.880s | +| Long hi-IN | `https://so-algorithm-test.oss-cn-beijing.aliyuncs.com/samples/asr/hi_in/%E5%8D%95%E4%BA%BA%E5%BD%95%E9%9F%B3_%E5%B7%A5%E4%BD%9C%E6%8A%A5%E5%91%8A_12112.wav` | `/private/tmp/dolphin-test-audio/hi-in-work-report.wav` | 1274.009s | + +## Automated Tests + +Command: + +```bash +env TRANSFORMERS_NO_TF=1 USE_TF=0 python -m pytest +``` + +Result: + +```text +collected 4 items +tests/test_language_detection.py .... [100%] +4 passed in 0.83s +``` + +Covered behavior: + +- `detect_language` returns language and region tokens through the existing model LID path. +- `dolphin.detect_language` is exported from package top level. +- `--task detect_language` prints `languageregion`. +- Tensor audio is cropped by `max_duration`. + +## CLI Help Check + +Command: + +```bash +env TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -m dolphin --help +``` + +Result: + +```text +--lid_duration LID_DURATION + seconds of audio to use for language detection; set 0 + to use full audio (default: 30) +--task {transcribe,detect_language} + task to run: transcribe or detect_language (default: + transcribe) +``` + +Exit code: `0`. + +## Real CLI LID Tests + +### Short zh-CN + +Command: + +```bash +env MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -m dolphin /private/tmp/dolphin-test-audio/zh-cn-demo.wav --model base --model_dir /private/tmp/dolphin-models/base --device cpu --task detect_language +``` + +Result: + +```text +zh CN +``` + +Exit code: `0`. + +### Long zh-CN + +Command: + +```bash +env MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -m dolphin /private/tmp/dolphin-test-audio/zh-cn-long.wav --model base --model_dir /private/tmp/dolphin-models/base --device cpu --task detect_language +``` + +Result: + +```text +zh CN +``` + +Exit code: `0`. + +### Long hi-IN + +Command: + +```bash +env MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -m dolphin /private/tmp/dolphin-test-audio/hi-in-work-report.wav --model base --model_dir /private/tmp/dolphin-models/base --device cpu --task detect_language +``` + +Result: + +```text +hi IN +``` + +Exit code: `0`. + +## Real Python API LID Test + +Command: + +```bash +env MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -c "import dolphin; model=dolphin.load_model('base', '/private/tmp/dolphin-models/base', 'cpu'); print(dolphin.detect_language(model, '/private/tmp/dolphin-test-audio/zh-cn-demo.wav')); print(dolphin.detect_language(model, '/private/tmp/dolphin-test-audio/zh-cn-long.wav')); print(dolphin.detect_language(model, '/private/tmp/dolphin-test-audio/hi-in-work-report.wav'))" +``` + +Result: + +```text +('zh', 'CN') +('zh', 'CN') +('hi', 'IN') +``` + +Exit code: `0`. + +## Notes And Risks + +- This branch does not introduce or claim a separate lightweight LID model. +- LID still requires loading a Dolphin ASR model. +- Default LID duration is 30 seconds to prevent long-audio language detection from sending an entire long recording into the decoder. +- `--lid_duration 0` and `max_duration=None` remain available for full-audio detection. +- ASR transcription behavior is unchanged when `--task` is omitted. + +## Recommendation + +This branch is ready for user review for issue #93. It answers the issue directly and gives both CLI and Python users a language-detection-only workflow. diff --git a/tests/test_language_detection.py b/tests/test_language_detection.py new file mode 100644 index 0000000..b8e7081 --- /dev/null +++ b/tests/test_language_detection.py @@ -0,0 +1,101 @@ +import importlib +import sys +import types + +import torch + + +def _install_modelscope_stub(): + modelscope = types.ModuleType("modelscope") + modelscope.snapshot_download = lambda *args, **kwargs: None + + model_module = types.ModuleType("modelscope.models.audio.funasr.model") + + class GenericFunASR: + pass + + model_module.GenericFunASR = GenericFunASR + + sys.modules.setdefault("modelscope", modelscope) + sys.modules.setdefault("modelscope.models", types.ModuleType("modelscope.models")) + sys.modules.setdefault("modelscope.models.audio", types.ModuleType("modelscope.models.audio")) + sys.modules.setdefault("modelscope.models.audio.funasr", types.ModuleType("modelscope.models.audio.funasr")) + sys.modules.setdefault("modelscope.models.audio.funasr.model", model_module) + + +_install_modelscope_stub() +transcribe = importlib.import_module("dolphin.transcribe") +dolphin = importlib.import_module("dolphin") + + +class FakeTokenizer: + def ids2tokens(self, ids): + assert ids == [1, 2] + return ["", ""] + + +class FakeModel: + device = torch.device("cpu") + model_configs = {"dummy": True} + + def __init__(self): + self.called = False + + def detect_language(self, feats, feats_lengths): + self.called = True + assert feats.device == self.device + assert feats_lengths.device == self.device + return torch.tensor([[1, 2]]) + + +def test_detect_language_returns_language_and_region(monkeypatch): + model = FakeModel() + monkeypatch.setattr( + transcribe, + "extract_feats", + lambda audio, configs: { + "feats": torch.zeros(1, 2, 3), + "feats_lengths": torch.tensor([2]), + }, + ) + monkeypatch.setattr(transcribe, "init_tokenizer", lambda configs: FakeTokenizer()) + + assert transcribe.detect_language(model, "audio.wav", max_duration=None) == ("zh", "CN") + assert model.called + + +def test_detect_language_is_exported_at_package_top_level(): + assert dolphin.detect_language is transcribe.detect_language + + +def test_cli_detect_language_task_prints_only_language_result(monkeypatch, capsys): + monkeypatch.setattr( + sys, + "argv", + [ + "dolphin", + "audio.wav", + "--model", + "base", + "--model_dir", + "/tmp/model", + "--device", + "cpu", + "--task", + "detect_language", + ], + ) + monkeypatch.setattr(transcribe, "load_model", lambda *args, **kwargs: object()) + monkeypatch.setattr(transcribe, "detect_language", lambda model, audio, max_duration=None: ("hi", "IN")) + + transcribe.cli() + + assert capsys.readouterr().out == "hi\tIN\n" + + +def test_limit_audio_duration_crops_tensor_audio(): + audio = torch.arange(20, dtype=torch.float32).unsqueeze(0) + + limited = transcribe._limit_audio_duration(audio, max_duration=0.0005) + + assert torch.equal(limited, audio[:, :8]) From 8ed0086a9a226d33f10c363fa74bfffe0abae700 Mon Sep 17 00:00:00 2001 From: MXuer Date: Thu, 11 Jun 2026 10:52:48 +0800 Subject: [PATCH 3/8] Add punctuation removal option --- README.md | 7 + dolphin/transcribe.py | 54 ++++- ...ssue-92-disable-punctuation-test-report.md | 191 ++++++++++++++++++ tests/test_punctuation.py | 62 ++++++ 4 files changed, 313 insertions(+), 1 deletion(-) create mode 100644 reports/issue-92-disable-punctuation-test-report.md create mode 100644 tests/test_punctuation.py diff --git a/README.md b/README.md index e3ff0ed..a12703b 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,9 @@ dolphin audio.wav --model small.cn.prompt --hotword_list_path hotwords.txt --use # predict word timestamp dolphin audio.wav --model small.cn.prompt --word_timestamp true +# Remove punctuation from transcription text +dolphin audio.wav --model small.cn --remove_punctuation true + ``` ### Python usage @@ -127,6 +130,10 @@ print(language, region) result = transcribe(model, 'audio.wav', lang_sym="zh") print(result.text) +# Remove punctuation from transcription text +result = transcribe(model, 'audio.wav', remove_punctuation=True) +print(result.text_nospecial) + # Specify language and region and encoder-biased hotwords result = transcribe(model, 'audio.wav', lang_sym="zh", region_sym="CN", hotwords=['诺香丹青牌科研胶囊'], use_deep_biasing=True, use_two_stage_filter=True) print(result.text) diff --git a/dolphin/transcribe.py b/dolphin/transcribe.py index 0196a50..6e236d9 100644 --- a/dolphin/transcribe.py +++ b/dolphin/transcribe.py @@ -3,6 +3,8 @@ import logging import warnings import json +import re +import unicodedata LOGGING_FORMAT="[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d:%(funcName)s] %(message)s" logging.basicConfig(level=logging.INFO, format=LOGGING_FORMAT) @@ -77,6 +79,7 @@ def parser_args() -> Namespace: parser.add_argument("--use_two_stage_filter", type=str2bool, default=False, help="use two-stage filtering for hotwords (default: false)") parser.add_argument("--use_prompt_hotword", type=str2bool, default=False, help="use prompt-based hotword (default: false)") parser.add_argument("--prompt_filter_threshold", type=float, default=-2.0, help="filter threshold for prompt hotwords (default: -2.0)") + parser.add_argument("--remove_punctuation", type=str2bool, default=False, help="remove punctuation from transcription text output (default: false)") parser.add_argument("--lid_duration", type=float, default=SPEECH_LENGTH, help="seconds of audio to use for language detection; set 0 to use full audio (default: 30)") parser.add_argument( "--task", @@ -272,6 +275,48 @@ def validate_lang_region(lang_sym: str, region_sym: str): return True +def _remove_punctuation(text: str) -> str: + return "".join(ch for ch in text if not unicodedata.category(ch).startswith("P")) + + +def _remove_punctuation_preserving_special_tokens(text: str) -> str: + parts = re.split(r"(<[^>]+>)", text) + return "".join( + part if part.startswith("<") and part.endswith(">") else _remove_punctuation(part) + for part in parts + ) + + +def _remove_punctuation_word_timestamps( + word_timestamps: Optional[List[Dict[str, Any]]], +) -> Optional[List[Dict[str, Any]]]: + if word_timestamps is None: + return None + + cleaned_timestamps = [] + for item in word_timestamps: + cleaned_word = _remove_punctuation(str(item.get("word", ""))) + if not cleaned_word: + continue + + cleaned_item = dict(item) + cleaned_item["word"] = cleaned_word + cleaned_timestamps.append(cleaned_item) + + return cleaned_timestamps + + +def _remove_result_punctuation( + result: Union[TranscribeResult, TranscribeSegmentResult], +) -> Union[TranscribeResult, TranscribeSegmentResult]: + return dataclasses.replace( + result, + text=_remove_punctuation_preserving_special_tokens(result.text), + text_nospecial=_remove_punctuation(result.text_nospecial), + word_timestamps=_remove_punctuation_word_timestamps(result.word_timestamps), + ) + + def transcribe_long( model: ASRModel, audio: str, @@ -287,6 +332,7 @@ def transcribe_long( use_two_stage_filter: bool = False, use_prompt_hotword: bool = False, prompt_filter_threshold: float = -2.0, + remove_punctuation: bool = False, **kwargs, ) -> List[TranscribeSegmentResult]: """ @@ -419,6 +465,8 @@ def transcribe_long( region=region, word_timestamps=word_ts, ) + if remove_punctuation: + result = _remove_result_punctuation(result) st = seconds_to_hms(s/1000) et = seconds_to_hms(e/1000) @@ -711,6 +759,7 @@ def transcribe( use_two_stage_filter: bool = False, use_prompt_hotword: bool = False, prompt_filter_threshold: float = -4.0, + remove_punctuation: bool = False, **kwargs, ) -> TranscribeResult: """ @@ -830,8 +879,10 @@ def transcribe( region=region, word_timestamps=word_ts, ) + if remove_punctuation: + result = _remove_result_punctuation(result) - logger.info(f"decode result, language: {result.language}, region: {result.region}, text: {result.text_nospecial} Timestamp: {word_ts}") + logger.info(f"decode result, language: {result.language}, region: {result.region}, text: {result.text_nospecial} Timestamp: {result.word_timestamps}") return result @@ -876,6 +927,7 @@ def cli(): "use_two_stage_filter": args.use_two_stage_filter, "use_prompt_hotword": args.use_prompt_hotword, "prompt_filter_threshold": args.prompt_filter_threshold, + "remove_punctuation": args.remove_punctuation, } result = transcribe_fn(**transcribe_params) _emit_cli_output(result, args.output_format, args.output) diff --git a/reports/issue-92-disable-punctuation-test-report.md b/reports/issue-92-disable-punctuation-test-report.md new file mode 100644 index 0000000..27468f2 --- /dev/null +++ b/reports/issue-92-disable-punctuation-test-report.md @@ -0,0 +1,191 @@ +# Issue #92 Disable Punctuation Test Report + +Date: 2026-06-11 +Branch: `codex/issue-92-disable-punctuation` +Issue: [#92 希望后面能有手动关闭标点符号的选项](https://github.com/DataoceanAI/Dolphin/issues/92) + +## Scope + +Issue #92 requests an option to manually disable punctuation in recognition output. + +This branch adds: + +- CLI option: `--remove_punctuation true` +- Python API option: `transcribe(..., remove_punctuation=True)` +- Long-audio support through `transcribe_long(..., remove_punctuation=True)` +- Shared Unicode punctuation removal for short and segmented results +- Preservation of special task/language tokens in `result.text` +- Punctuation cleanup for `result.text_nospecial` +- Filtering/cleaning punctuation tokens in `word_timestamps` +- README examples and unit tests + +This is an output post-processing option. It does not change model decoding, training, or punctuation generation behavior inside the model. + +## Test Environment + +- Host Python: `Python 3.10.9` +- Runtime used for real ASR tests: `/private/tmp/dolphin-venv/bin/python` +- Runtime Python: `Python 3.10.9` +- Key packages: + - `numpy 1.23.5` + - `torch 2.4.1` + - `torchaudio 2.4.1` + - `modelscope 1.36.3` + - `funasr 1.1.5` +- Environment variables used: + - `TRANSFORMERS_NO_TF=1` + - `USE_TF=0` + - `MODELSCOPE_CACHE=/private/tmp/modelscope-cache` + - `NUMBA_CACHE_DIR=/private/tmp/numba-cache` + - `HOME=/private/tmp/dolphin-home` for long-audio VAD cache isolation +- Model: + - Dolphin `base` + - Local path: `/private/tmp/dolphin-models/base` + +Note: the global environment currently has `numpy 2.2.6`, which causes noisy TensorFlow/Whisper/numba compatibility stderr during imports. Real validation used the temporary venv above with `numpy 1.23.5`. + +## Test Audio + +| Case | Source | Local file | Duration | +| --- | --- | --- | --- | +| Short zh-CN | `https://so-algorithm-test.oss-cn-beijing.aliyuncs.com/samples/asr/zh-cn-demo.wav` | `/private/tmp/dolphin-test-audio/zh-cn-demo.wav` | 6.267s | +| Long zh-CN | `https://so-algorithm-test.oss-cn-beijing.aliyuncs.com/samples/asr/zh-cn-long.wav` | `/private/tmp/dolphin-test-audio/zh-cn-long.wav` | 752.880s | +| Long hi-IN | `https://so-algorithm-test.oss-cn-beijing.aliyuncs.com/samples/asr/hi_in/%E5%8D%95%E4%BA%BA%E5%BD%95%E9%9F%B3_%E5%B7%A5%E4%BD%9C%E6%8A%A5%E5%91%8A_12112.wav` | `/private/tmp/dolphin-test-audio/hi-in-work-report.wav` | 1274.009s | + +## Automated Tests + +Command: + +```bash +env TRANSFORMERS_NO_TF=1 USE_TF=0 python -m pytest +``` + +Result: + +```text +collected 3 items +tests/test_punctuation.py ... [100%] +3 passed in 0.89s +``` + +Covered behavior: + +- Unicode punctuation is removed from recognition text. +- Special tokens such as `` are preserved in `result.text`. +- `text_nospecial` is cleaned. +- Pure punctuation word timestamps are removed. +- Punctuation attached to timestamp words is stripped. +- CLI parser accepts `--remove_punctuation true`. + +## CLI Help Check + +Command: + +```bash +env TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -m dolphin --help +``` + +Result: + +```text +--remove_punctuation REMOVE_PUNCTUATION + remove punctuation from transcription text output + (default: false) +``` + +Exit code: `0`. + +## Real ASR Tests + +### Short zh-CN Python API + +Command: + +```bash +env MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -c ' +import logging +import unicodedata +import dolphin +logging.getLogger("dolphin").setLevel(logging.ERROR) +model = dolphin.load_model("base", "/private/tmp/dolphin-models/base", "cpu") +result = dolphin.transcribe(model, "/private/tmp/dolphin-test-audio/zh-cn-demo.wav", predict_time=True, word_timestamp=True, remove_punctuation=True) +has_punctuation = any(unicodedata.category(ch).startswith("P") for ch in result.text_nospecial) +timestamp_punctuation = [item for item in result.word_timestamps or [] if any(unicodedata.category(ch).startswith("P") for ch in str(item.get("word", "")))] +print(result.text) +print(result.text_nospecial) +print("text_has_punctuation=", has_punctuation) +print("timestamp_punctuation_count=", len(timestamp_punctuation)) +' +``` + +Result: + +```text + 诚然 时代正在推崇初心 但文化之精髓切虚传承 +诚然 时代正在推崇初心 但文化之精髓切虚传承 +text_has_punctuation= False +timestamp_punctuation_count= 0 +``` + +Exit code: `0`. + +### Short zh-CN CLI + +Command: + +```bash +env MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -m dolphin /private/tmp/dolphin-test-audio/zh-cn-demo.wav --model base --model_dir /private/tmp/dolphin-models/base --device cpu --remove_punctuation true +``` + +Result: + +```text +decode result, language: zh, region: CN, text: 诚然 时代正在推崇初心 但文化之精髓切虚传承 +``` + +The logged word timestamps no longer include the comma token. + +Exit code: `0`. + +### Long zh-CN And hi-IN Python API + +Command: + +```bash +env HOME=/private/tmp/dolphin-home MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -c ' +import logging +import unicodedata +import dolphin +from dolphin.transcribe import transcribe_long +logging.getLogger("dolphin").setLevel(logging.ERROR) +model = dolphin.load_model("base", "/private/tmp/dolphin-models/base", "cpu") +for label, path in [("zh-long", "/private/tmp/dolphin-test-audio/zh-cn-long.wav"), ("hi-long", "/private/tmp/dolphin-test-audio/hi-in-work-report.wav")]: + results = transcribe_long(model, path, remove_punctuation=True) + merged = "".join(item.text_nospecial for item in results) + has_punctuation = any(unicodedata.category(ch).startswith("P") for ch in merged) + ts_count = sum(1 for item in results for ts in (item.word_timestamps or []) if any(unicodedata.category(ch).startswith("P") for ch in str(ts.get("word", "")))) + print(label, "segments=", len(results), "text_has_punctuation=", has_punctuation, "timestamp_punctuation_count=", ts_count) + print(results[0].text_nospecial[:120] if results else "") +' +``` + +Result: + +```text +zh-long segments= 61 text_has_punctuation= False timestamp_punctuation_count= 0 +hi-long segments= 177 text_has_punctuation= False timestamp_punctuation_count= 0 +मेरा नामेश कुमार मैं +``` + +Exit code: `0`. + +## Notes And Risks + +- This option removes punctuation as Unicode punctuation categories, so it covers CJK punctuation, ASCII punctuation, and punctuation attached to Devanagari/Latin tokens. +- Special task and language tags in `result.text` are preserved. +- The model can still internally generate punctuation; this option cleans the returned/logged result. +- Because punctuation is removed after recognition, spacing may contain doubled spaces where punctuation was originally surrounded by spaces. The branch keeps this conservative behavior instead of normalizing whitespace across languages. + +## Recommendation + +This branch is ready for user review for issue #92. It gives CLI and Python users an explicit way to disable punctuation in returned transcription text. diff --git a/tests/test_punctuation.py b/tests/test_punctuation.py new file mode 100644 index 0000000..e81194d --- /dev/null +++ b/tests/test_punctuation.py @@ -0,0 +1,62 @@ +import importlib +import sys +import types + + +def _install_modelscope_stub(): + modelscope = types.ModuleType("modelscope") + modelscope.snapshot_download = lambda *args, **kwargs: None + + model_module = types.ModuleType("modelscope.models.audio.funasr.model") + + class GenericFunASR: + pass + + model_module.GenericFunASR = GenericFunASR + + sys.modules.setdefault("modelscope", modelscope) + sys.modules.setdefault("modelscope.models", types.ModuleType("modelscope.models")) + sys.modules.setdefault("modelscope.models.audio", types.ModuleType("modelscope.models.audio")) + sys.modules.setdefault("modelscope.models.audio.funasr", types.ModuleType("modelscope.models.audio.funasr")) + sys.modules.setdefault("modelscope.models.audio.funasr.model", model_module) + + +_install_modelscope_stub() +transcribe = importlib.import_module("dolphin.transcribe") + + +def test_remove_punctuation_preserves_special_tokens(): + text = " 你好,world!" + + assert transcribe._remove_punctuation_preserving_special_tokens(text) == ( + " 你好world" + ) + + +def test_remove_result_punctuation_cleans_text_and_timestamps(): + result = transcribe.TranscribeResult( + text=" 你好,world!", + text_nospecial="你好,world!", + language="zh", + region="CN", + word_timestamps=[ + {"word": "你好", "start": 0.0, "end": 0.4}, + {"word": ",", "start": 0.4, "end": 0.5}, + {"word": "world!", "start": 0.5, "end": 0.9}, + ], + ) + + cleaned = transcribe._remove_result_punctuation(result) + + assert cleaned.text == " 你好world" + assert cleaned.text_nospecial == "你好world" + assert cleaned.word_timestamps == [ + {"word": "你好", "start": 0.0, "end": 0.4}, + {"word": "world", "start": 0.5, "end": 0.9}, + ] + + +def test_parser_remove_punctuation_flag(monkeypatch): + monkeypatch.setattr(sys, "argv", ["dolphin", "audio.wav", "--remove_punctuation", "true"]) + + assert transcribe.parser_args().remove_punctuation is True From 5ad9a0f0ceb20ef3ed776042e908777ca5bc50a9 Mon Sep 17 00:00:00 2001 From: MXuer Date: Thu, 11 Jun 2026 11:03:44 +0800 Subject: [PATCH 4/8] Add integration test report --- ...integration-issues-80-92-93-test-report.md | 178 ++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100644 reports/integration-issues-80-92-93-test-report.md diff --git a/reports/integration-issues-80-92-93-test-report.md b/reports/integration-issues-80-92-93-test-report.md new file mode 100644 index 0000000..c13d2ea --- /dev/null +++ b/reports/integration-issues-80-92-93-test-report.md @@ -0,0 +1,178 @@ +# Integration Test Report: Issues #80, #92, #93 + +Date: 2026-06-11 +Branch: `codex/integration-issues-80-92-93` +Base: `main` at `78ea615` + +## Scope + +This branch integrates the three issue branches that are ready for review: + +| Issue | Branch | Commit | Summary | +| --- | --- | --- | --- | +| #80 | `codex/issue-80-cli-output` | `5fff9c8` | CLI output files and `txt/json/srt` formats | +| #93 | `codex/issue-93-language-detection` | `7f6d188` | CLI/API language-detection-only mode | +| #92 | `codex/issue-92-disable-punctuation` | `a256cf7` | CLI/API punctuation removal option | + +The integration branch cherry-picks these changes onto current `main` and resolves overlapping CLI parser edits in `dolphin/transcribe.py`. + +## Integration Commits + +```text +8ed0086 Add punctuation removal option +5e429b2 Expose language detection task +9a36c51 Add CLI output formats +78ea615 Merge pull request #105 from DataoceanAI/update-readme +``` + +## Conflict Resolution + +The only conflicts were in `dolphin/transcribe.py`, where all three branches extended imports and CLI options. The resolved CLI keeps all options together: + +- `--remove_punctuation` +- `--lid_duration` +- `--task {transcribe,detect_language}` +- `--output` +- `--output_format {txt,json,srt}` + +The final CLI flow is: + +1. Load model and audio. +2. If `--task detect_language`, print `languageregion` and return. +3. Otherwise transcribe normally. +4. Optionally remove punctuation from returned text and timestamps. +5. Emit output to stdout or `--output` in `txt`, `json`, or `srt`. + +## Test Environment + +- Runtime used for validation: `/private/tmp/dolphin-venv/bin/python` +- Python: `3.10.9` +- Key packages: + - `numpy 1.23.5` + - `torch 2.4.1` + - `torchaudio 2.4.1` + - `modelscope 1.36.3` + - `funasr 1.1.5` +- Environment variables used: + - `TRANSFORMERS_NO_TF=1` + - `USE_TF=0` + - `MODELSCOPE_CACHE=/private/tmp/modelscope-cache` + - `NUMBA_CACHE_DIR=/private/tmp/numba-cache` + - `HOME=/private/tmp/dolphin-home` for long-audio VAD cache isolation +- Model: + - Dolphin `base` + - Local path: `/private/tmp/dolphin-models/base` + +## Test Audio + +| Case | Source | Local file | Duration | +| --- | --- | --- | --- | +| Short zh-CN | `https://so-algorithm-test.oss-cn-beijing.aliyuncs.com/samples/asr/zh-cn-demo.wav` | `/private/tmp/dolphin-test-audio/zh-cn-demo.wav` | 6.267s | +| Long zh-CN | `https://so-algorithm-test.oss-cn-beijing.aliyuncs.com/samples/asr/zh-cn-long.wav` | `/private/tmp/dolphin-test-audio/zh-cn-long.wav` | 752.880s | +| Long hi-IN | `https://so-algorithm-test.oss-cn-beijing.aliyuncs.com/samples/asr/hi_in/%E5%8D%95%E4%BA%BA%E5%BD%95%E9%9F%B3_%E5%B7%A5%E4%BD%9C%E6%8A%A5%E5%91%8A_12112.wav` | `/private/tmp/dolphin-test-audio/hi-in-work-report.wav` | 1274.009s | + +## Automated Tests + +Command: + +```bash +env TRANSFORMERS_NO_TF=1 USE_TF=0 python -m pytest +``` + +Result: + +```text +collected 14 items +tests/test_cli_output.py ....... [ 50%] +tests/test_language_detection.py .... [ 78%] +tests/test_punctuation.py ... [100%] +14 passed +``` + +Static whitespace check: + +```bash +git diff --check +``` + +Result: passed. + +## CLI Help Check + +Command: + +```bash +env TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python -m dolphin --help +``` + +Result: help output includes all integrated options: + +```text +--remove_punctuation REMOVE_PUNCTUATION +--lid_duration LID_DURATION +--task {transcribe,detect_language} +--output OUTPUT +--output_format {txt,json,srt} +``` + +Exit code: `0`. + +## Real Language Detection Tests + +All three requested audio files were tested with: + +```bash +/private/tmp/dolphin-venv/bin/python -m dolphin AUDIO --model base --model_dir /private/tmp/dolphin-models/base --device cpu --task detect_language +``` + +Results: + +| Audio | Result | +| --- | --- | +| Short zh-CN | `zh CN` | +| Long zh-CN | `zh CN` | +| Long hi-IN | `hi IN` | + +All commands exited with code `0`. + +## Real SRT And Punctuation Tests + +All three requested audio files were tested with integrated SRT output and punctuation removal: + +```bash +/private/tmp/dolphin-venv/bin/python -m dolphin AUDIO --model base --model_dir /private/tmp/dolphin-models/base --device cpu --remove_punctuation true --output OUTPUT.srt --output_format srt +``` + +Output files: + +| Audio | Output | Cue count | Last end | Body punctuation | +| --- | --- | ---: | ---: | ---: | +| Short zh-CN | `/private/tmp/dolphin-integration-test-output/demo-clean.srt` | 1 | 5.867s | 0 | +| Long zh-CN | `/private/tmp/dolphin-integration-test-output/zh-cn-long-clean.srt` | 60 | 748.680s | 0 | +| Long hi-IN | `/private/tmp/dolphin-integration-test-output/hi-in-clean.srt` | 174 | 1273.240s | 0 | + +Structural validation checked: + +- cue numbers are continuous from 1; +- time lines match `HH:MM:SS,mmm --> HH:MM:SS,mmm`; +- cue start/end times are ordered and non-overlapping; +- subtitle body text has no Unicode punctuation. + +Result: + +```text +demo-clean: cues=1, valid=True, last_end=5.867s, body_punctuation=0 +zh-cn-long-clean: cues=60, valid=True, last_end=748.680s, body_punctuation=0 +hi-in-clean: cues=174, valid=True, last_end=1273.240s, body_punctuation=0 +``` + +## Notes And Risks + +- The integration does not change model weights or decoding internals. +- `--remove_punctuation` is output post-processing, so spacing may contain doubled spaces where punctuation was removed. +- Language detection still loads a Dolphin ASR model; this does not add a separate lightweight LID-only model. +- The global Python environment has `numpy 2.2.6`, which emits noisy TensorFlow/Whisper/numba import warnings. Real validation used a temporary venv with `numpy 1.23.5`. + +## Recommendation + +For easiest review, keep the three per-issue branches available. For easiest merge, use `codex/integration-issues-80-92-93` after reviewing this report, because it already resolves the overlapping CLI changes and has been tested with the requested short audio, long zh-CN audio, long hi-IN audio, and SRT output. From 0ddfbbbaff6cbc1a179fad87452d2479fb7ff251 Mon Sep 17 00:00:00 2001 From: MXuer Date: Thu, 11 Jun 2026 11:14:02 +0800 Subject: [PATCH 5/8] Fix README repository and model links --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a12703b..400d6d0 100644 --- a/README.md +++ b/README.md @@ -42,14 +42,14 @@ pip install -U dataoceanai-dolphin Alternatively, it can also be installed from the source: ```shell -pip install git+https://github.com/SpeechOceanTech/Dolphin.git +pip install git+https://github.com/DataoceanAI/Dolphin.git ``` ## Available Models and Languages ### Models -There are 8 models in Dolphin, and 6 of them are available now. See details in [Dolphin](https://arxiv.org/abs/2503.20212) and [Dolphin-CN-Dialect](https://arxiv.org/abs/2605.08961). +Available Dolphin models are listed below. See details in [Dolphin](https://arxiv.org/abs/2503.20212) and [Dolphin-CN-Dialect](https://arxiv.org/abs/2605.08961). | Model | Parameters |Publicly Available | |:------:|:----------:|:------------------:| @@ -58,7 +58,7 @@ There are 8 models in Dolphin, and 6 of them are available now. See details in [ | medium | 0.9 B | | | large | 1.7B | | | [base.cn](https://modelscope.cn/models/DataoceanAI/dolphin-cn-dialect-base) | 0.1 B | ✅ | -| [base.cn.streaming](https://modelscope.cn/models/DataoceanAI/dolphin-cn-dialect-small-prompt) | 0.1 B | ✅ | +| [base.cn.streaming](https://modelscope.cn/models/DataoceanAI/dolphin-cn-dialect-base-streaming) | 0.1 B | ✅ | | [small.cn](https://modelscope.cn/models/DataoceanAI/dolphi-cn-dialect-small) | 0.4 B | ✅ | | [small.cn.streaming](https://modelscope.cn/models/DataoceanAI/dolphin-cn-dialect-small-streaming) | 0.4 B | ✅ | | [small.cn.prompt](https://modelscope.cn/models/DataoceanAI/dolphin-cn-dialect-small-prompt) | 0.4 B | ✅ | From 89ed38617e78043ef440275ee91f3c06164e63c5 Mon Sep 17 00:00:00 2001 From: MXuer Date: Thu, 11 Jun 2026 11:18:13 +0800 Subject: [PATCH 6/8] Fix small CN model spelling --- README.md | 2 +- dolphin/model_registry.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 400d6d0..22c4ff5 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ Available Dolphin models are listed below. See details in [Dolphin](https://arxi | large | 1.7B | | | [base.cn](https://modelscope.cn/models/DataoceanAI/dolphin-cn-dialect-base) | 0.1 B | ✅ | | [base.cn.streaming](https://modelscope.cn/models/DataoceanAI/dolphin-cn-dialect-base-streaming) | 0.1 B | ✅ | -| [small.cn](https://modelscope.cn/models/DataoceanAI/dolphi-cn-dialect-small) | 0.4 B | ✅ | +| [small.cn](https://modelscope.cn/models/DataoceanAI/dolphin-cn-dialect-small) | 0.4 B | ✅ | | [small.cn.streaming](https://modelscope.cn/models/DataoceanAI/dolphin-cn-dialect-small-streaming) | 0.4 B | ✅ | | [small.cn.prompt](https://modelscope.cn/models/DataoceanAI/dolphin-cn-dialect-small-prompt) | 0.4 B | ✅ | diff --git a/dolphin/model_registry.py b/dolphin/model_registry.py index cb9b7bb..760e8ea 100644 --- a/dolphin/model_registry.py +++ b/dolphin/model_registry.py @@ -17,7 +17,7 @@ "sha256": "62e4c11fe1e0e42bd34e444172c5a05e792c4b5a03750f794fa3206fc0649cd7" }, "small.cn": { - "model_id": "DataoceanAI/dolphi-cn-dialect-small", + "model_id": "DataoceanAI/dolphin-cn-dialect-small", "sha256": "1cee2b8d2133cabb36567625a832d4033569e27eaf5f98df9be1139ec6068bbb", }, "small.cn.streaming": { From c69ccc13555d82d8da5be8aa0bcdae858d447420 Mon Sep 17 00:00:00 2001 From: MXuer Date: Thu, 11 Jun 2026 11:41:03 +0800 Subject: [PATCH 7/8] Add experimental streaming demo --- README.md | 16 ++ dolphin/model.py | 15 +- dolphin/transcribe.py | 24 +++ examples/streaming_demo.py | 161 ++++++++++++++++++ .../issue-106-streaming-demo-test-report.md | 63 +++++++ reports/issue-triage.md | 49 ++++++ tests/test_streaming_demo.py | 36 ++++ tests/test_streaming_params.py | 127 ++++++++++++++ 8 files changed, 486 insertions(+), 5 deletions(-) create mode 100644 examples/streaming_demo.py create mode 100644 reports/issue-106-streaming-demo-test-report.md create mode 100644 reports/issue-triage.md create mode 100644 tests/test_streaming_demo.py create mode 100644 tests/test_streaming_params.py diff --git a/README.md b/README.md index 22c4ff5..d809c29 100644 --- a/README.md +++ b/README.md @@ -110,6 +110,22 @@ dolphin audio.wav --model small.cn --remove_punctuation true ``` +### Experimental streaming demo + +For Chinese dialect streaming models, the repository provides an experimental +file-streaming demo. It reads an existing audio file in chunks and prints each +chunk result as soon as it is decoded: + +```shell +python examples/streaming_demo.py audio.wav --model small.cn.streaming --device cuda +``` + +For CPU smoke tests, limit the number of chunks: + +```shell +python examples/streaming_demo.py audio.wav --model base.cn.streaming --device cpu --chunk_duration 4 --max_chunks 2 +``` + ### Python usage ```python diff --git a/dolphin/model.py b/dolphin/model.py index e2b2f7f..176b715 100644 --- a/dolphin/model.py +++ b/dolphin/model.py @@ -960,11 +960,16 @@ def forward( return self.forward_attention(v, scores, mask), new_cache else: - # NOTE(Mddct): we need mask bias, not boolean mask - assert mask.dtype != torch.bool - mask = mask.unsqueeze(1) - # matrix_bd as a mask bias - mask = (matrix_bd + mask) / math.sqrt(self.d_k) + if mask.size(-1) > 0: + # NOTE(Mddct): SDPA needs an attention bias here so the + # relative position logits can share the same attn_mask. + if mask.dtype == torch.bool: + mask = mask_to_bias(mask, query.dtype) + mask = mask.unsqueeze(1) + mask = matrix_bd + mask + else: + mask = matrix_bd + mask = mask / math.sqrt(self.d_k) output = torch.nn.functional.scaled_dot_product_attention( q_with_bias_u, k, diff --git a/dolphin/transcribe.py b/dolphin/transcribe.py index 6e236d9..41bf2bf 100644 --- a/dolphin/transcribe.py +++ b/dolphin/transcribe.py @@ -70,6 +70,9 @@ def parser_args() -> Namespace: parser.add_argument("--beam_size", type=int, default=10, help="number of beams in beam search (default: 10)") parser.add_argument("--decoding_method", type=str, default="attention_rescoring", help="decoding methods, supports: attention, attention_rescoring (default: attention_rescoring)") + parser.add_argument("--decoding_chunk_size", type=int, default=-1, help="decoding chunk size for streaming encoder simulation (default: -1)") + parser.add_argument("--num_decoding_left_chunks", type=int, default=-1, help="number of left chunks for streaming encoder simulation (default: -1)") + parser.add_argument("--simulate_streaming", type=str2bool, default=False, help="simulate streaming encoder decoding (default: false)") parser.add_argument("--maxlenratio", type=float, default=0.0, help="deprecated, Input length ratio to obtain max output length (default: 0.0)") parser.add_argument("--padding_speech", type=str2bool, default=False, help="deprecated, whether padding speech to 30 seconds (default: false)") parser.add_argument("--normalize_length", type=str2bool, default=False, help="deprecated, whether to normalize length (default: false)") @@ -333,6 +336,9 @@ def transcribe_long( use_prompt_hotword: bool = False, prompt_filter_threshold: float = -2.0, remove_punctuation: bool = False, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, **kwargs, ) -> List[TranscribeSegmentResult]: """ @@ -352,6 +358,9 @@ def transcribe_long( use_two_stage_filter: whether use two-stage filtering (default: false) use_prompt_hotword: whether use prompt-based hotword (default: false) prompt_filter_threshold: filter threshold for prompt hotwords (default: -2.0) + decoding_chunk_size: decoding chunk size for streaming encoder simulation (default: -1) + num_decoding_left_chunks: number of left chunks for streaming encoder simulation (default: -1) + simulate_streaming: whether simulate streaming encoder decoding (default: false) Returns: List[TranscribeSegmentResult] @@ -446,6 +455,9 @@ def transcribe_long( speech=batch["feats"], speech_lengths=batch["feats_lengths"], beam_size=beam_size, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming, infos=decoding_infos ) tokens = ret[decoding_method][0].tokens @@ -760,6 +772,9 @@ def transcribe( use_prompt_hotword: bool = False, prompt_filter_threshold: float = -4.0, remove_punctuation: bool = False, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, **kwargs, ) -> TranscribeResult: """ @@ -779,6 +794,9 @@ def transcribe( use_two_stage_filter: whether use two-stage filtering (default: false) use_prompt_hotword: whether use prompt-based hotword (default: false) prompt_filter_threshold: filter threshold for prompt hotwords (default: -4.0) + decoding_chunk_size: decoding chunk size for streaming encoder simulation (default: -1) + num_decoding_left_chunks: number of left chunks for streaming encoder simulation (default: -1) + simulate_streaming: whether simulate streaming encoder decoding (default: false) Returns: TranscribeResult @@ -859,6 +877,9 @@ def transcribe( speech=batch["feats"], speech_lengths=batch["feats_lengths"], beam_size=beam_size, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming, infos=decoding_infos ) @@ -922,6 +943,9 @@ def cli(): "padding_speech": args.padding_speech, "decoding_method": args.decoding_method, "beam_size": args.beam_size, + "decoding_chunk_size": args.decoding_chunk_size, + "num_decoding_left_chunks": args.num_decoding_left_chunks, + "simulate_streaming": args.simulate_streaming, "hotwords": hotwords, "use_deep_biasing": args.use_deep_biasing, "use_two_stage_filter": args.use_two_stage_filter, diff --git a/examples/streaming_demo.py b/examples/streaming_demo.py new file mode 100644 index 0000000..f6810ae --- /dev/null +++ b/examples/streaming_demo.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +"""Experimental file-streaming demo for Dolphin streaming models. + +This script reads an audio file, feeds it to Dolphin in chunks, and prints +recognition results as each chunk is processed. It is intended as a simple +terminal demo for streaming models such as ``small.cn.streaming``. +""" + +import argparse +import sys +import time +from pathlib import Path +from typing import Iterator, Optional, Tuple + + +SAMPLE_RATE = 16000 + + +def _format_time(seconds: float) -> str: + total_ms = max(0, int(round(seconds * 1000))) + minutes = total_ms // 60000 + total_ms %= 60000 + secs = total_ms // 1000 + millis = total_ms % 1000 + return f"{minutes:02d}:{secs:02d}.{millis:03d}" + + +def _iter_chunk_ranges( + total_samples: int, + chunk_samples: int, + max_chunks: Optional[int] = None, +) -> Iterator[Tuple[int, int, int]]: + if chunk_samples <= 0: + raise ValueError("chunk_samples must be positive") + + index = 0 + start = 0 + while start < total_samples: + if max_chunks is not None and index >= max_chunks: + break + end = min(start + chunk_samples, total_samples) + yield index, start, end + index += 1 + start = end + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Experimental Dolphin file-streaming ASR demo." + ) + parser.add_argument("audio", type=Path, help="audio file to stream") + parser.add_argument( + "--model", + default="small.cn.streaming", + help="streaming model name (default: small.cn.streaming)", + ) + parser.add_argument( + "--model_dir", + type=Path, + default=None, + help="model checkpoint directory; defaults to ~/.cache/dolphin/", + ) + parser.add_argument("--device", default=None, help="torch device, e.g. cuda or cpu") + parser.add_argument("--lang_sym", default="zh", help="language symbol (default: zh)") + parser.add_argument("--region_sym", default=None, help="region symbol, e.g. CN") + parser.add_argument("--chunk_duration", type=float, default=4.0, help="chunk size in seconds (default: 4.0)") + parser.add_argument("--beam_size", type=int, default=5, help="beam size (default: 5)") + parser.add_argument( + "--decoding_method", + default="attention_rescoring", + choices=("attention", "attention_rescoring"), + help="decoding method (default: attention_rescoring)", + ) + parser.add_argument("--decoding_chunk_size", type=int, default=16, help="encoder decoding chunk size (default: 16)") + parser.add_argument("--num_decoding_left_chunks", type=int, default=4, help="left chunks kept by encoder (default: 4)") + parser.add_argument( + "--mode", + choices=("chunk", "rolling"), + default="chunk", + help="chunk prints each chunk independently; rolling prints accumulated partial text (default: chunk)", + ) + parser.add_argument( + "--realtime", + action="store_true", + help="sleep between chunks to approximate real-time playback", + ) + parser.add_argument( + "--max_chunks", + type=int, + default=None, + help="limit chunks for quick smoke tests", + ) + return parser.parse_args() + + +def main() -> int: + args = _parse_args() + if args.chunk_duration <= 0: + raise ValueError("--chunk_duration must be positive") + + repo_root = Path(__file__).resolve().parents[1] + if str(repo_root) not in sys.path: + sys.path.insert(0, str(repo_root)) + + import torch + + import dolphin + from dolphin.audio import load_audio + from dolphin.transcribe import transcribe + + model_dir = args.model_dir or Path.home() / ".cache" / "dolphin" / args.model + + print(f"loading model: {args.model} ({model_dir})", flush=True) + model = dolphin.load_model(args.model, model_dir, args.device) + waveform = torch.from_numpy(load_audio(str(args.audio))).float().unsqueeze(0) + + chunk_samples = int(args.chunk_duration * SAMPLE_RATE) + total_samples = waveform.size(-1) + print( + f"streaming file: {args.audio} " + f"duration={total_samples / SAMPLE_RATE:.2f}s " + f"chunk={args.chunk_duration:.2f}s mode={args.mode}", + flush=True, + ) + + started_at = time.time() + for index, start, end in _iter_chunk_ranges(total_samples, chunk_samples, args.max_chunks): + chunk = waveform[:, :end] if args.mode == "rolling" else waveform[:, start:end] + result = transcribe( + model, + chunk, + lang_sym=args.lang_sym, + region_sym=args.region_sym, + predict_time=False, + word_timestamp=False, + decoding_method=args.decoding_method, + beam_size=args.beam_size, + decoding_chunk_size=args.decoding_chunk_size, + num_decoding_left_chunks=args.num_decoding_left_chunks, + simulate_streaming=True, + ) + + if args.mode == "rolling": + span = f"00:00.000-{_format_time(end / SAMPLE_RATE)}" + label = "partial" + else: + span = f"{_format_time(start / SAMPLE_RATE)}-{_format_time(end / SAMPLE_RATE)}" + label = "chunk" + print(f"[{index + 1:04d} {label} {span}] {result.text_nospecial}", flush=True) + + if args.realtime: + target_elapsed = end / SAMPLE_RATE + sleep_for = target_elapsed - (time.time() - started_at) + if sleep_for > 0: + time.sleep(sleep_for) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/reports/issue-106-streaming-demo-test-report.md b/reports/issue-106-streaming-demo-test-report.md new file mode 100644 index 0000000..9b6f8df --- /dev/null +++ b/reports/issue-106-streaming-demo-test-report.md @@ -0,0 +1,63 @@ +# Issue #106 Streaming Demo Test Report + +Date: 2026-06-11 +Branch: `codex/integration-issues-80-92-93` +Issue: https://github.com/DataoceanAI/Dolphin/issues/106 + +## Scope + +- Added `examples/streaming_demo.py` as an experimental terminal file-streaming demo for `small.cn.streaming` and other streaming models. +- Added CLI/API passthrough for `decoding_chunk_size`, `num_decoding_left_chunks`, and `simulate_streaming`. +- Fixed the SDPA relative-position attention path so streaming chunk inference can run when no explicit mask is supplied. +- Updated `README.md` with terminal demo commands. +- Added `reports/issue-triage.md` so handled and candidate issues are tracked in the repository. + +## How To Run + +```shell +python examples/streaming_demo.py audio.wav --model small.cn.streaming --device cuda +``` + +CPU smoke test: + +```shell +python examples/streaming_demo.py audio.wav --model small.cn.streaming --device cpu --chunk_duration 3 --max_chunks 1 +``` + +## Verification + +```shell +env TRANSFORMERS_NO_TF=1 USE_TF=0 python -m pytest +``` + +Result: `20 passed in 1.17s`. + +```shell +python -m py_compile examples/streaming_demo.py +``` + +Result: passed. + +```shell +python examples/streaming_demo.py --help +``` + +Result: passed. + +Real model smoke test: + +```shell +env MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python examples/streaming_demo.py /private/tmp/dolphin-test-audio/zh-cn-demo.wav --model small.cn.streaming --model_dir /private/tmp/dolphin-models/small.cn.streaming --device cpu --chunk_duration 3 --max_chunks 1 +``` + +Result: + +```text +[0001 chunk 00:00.000-00:03.000] 诚然时代正在推陈出新 +``` + +## Notes + +- The demo streams an existing audio file in fixed-size chunks and prints each decoded chunk immediately. +- `--mode chunk` decodes each chunk independently. `--mode rolling` re-decodes accumulated audio and prints partial text. +- This is intentionally an experimental terminal demo, not a production microphone or socket streaming server. diff --git a/reports/issue-triage.md b/reports/issue-triage.md new file mode 100644 index 0000000..0cc68c3 --- /dev/null +++ b/reports/issue-triage.md @@ -0,0 +1,49 @@ +# Issue Triage Record + +Date: 2026-06-11 +Branch: `codex/integration-issues-80-92-93` +PR: https://github.com/DataoceanAI/Dolphin/pull/109 + +This record tracks issues reviewed during the Codex pass so future work does not need to re-triage from scratch. + +## Done In PR #109 + +| Issue | Status | Notes | +| --- | --- | --- | +| #80 | Done | Added CLI `--output` and `--output_format {txt,json,srt}`. | +| #92 | Done | Added CLI/API punctuation removal via `--remove_punctuation` and `remove_punctuation=True`. | +| #93 | Done | Added `dolphin.detect_language(...)`, `--task detect_language`, and `--lid_duration`. | +| #106 | Done | Added experimental `examples/streaming_demo.py`, streaming decode options, and SDPA chunk mask support for `small.cn.streaming`. | + +## In Progress + +None. + +## Already Fixed Or Mostly Addressed Upstream + +| Issue | Status | Notes | +| --- | --- | --- | +| #72 | Already fixed | Current `main` exports `dolphin.load_audio`. | +| #81 | Partially fixed | VAD cache directory is created before download. Remaining dependency/cache guidance can be documented. | + +## Good Next Candidates + +| Issue | Type | Proposed work | +| --- | --- | --- | +| #44 | Code | Wire `maxlenratio` or a replacement max decode length setting through the decode path. | +| #83 | Docs | Add word-level timestamp examples and explain `--word_timestamp`. | +| #86 / #62 | Docs | Expand hotword docs with CLI/API examples and state current ONNX status. | +| #95 | Docs/example | Add a simple FastAPI or HTTP service deployment example. | +| #42 | Docs/API | Document long-audio Python usage and safer `transcribe_long` patterns. | +| #33 | Code/API | Explore `transcribe_batch` or CLI multi-file batch inference. | +| #50 / #67 / #41 / #89 | Docs/errors | Add installation and environment troubleshooting guidance. | +| #20 / #53 / #78 | Docs | Clarify base/small vs `*.cn` dialect models and dialect tag behavior. | + +## Not Directly Actionable Without Maintainer Or Model-Team Input + +| Issues | Reason | +| --- | --- | +| #7 / #43 | Large model release policy. | +| #8 / #10 / #35 / #56 / #88 | Fine-tuning code release and training recipe policy. | +| #15 / #49 / #84 / #108 | Benchmark methodology or paper-result discussion. | +| #18 / #24 / #51 / #90 / #107 | New model capability, resource release, or language support decisions. | diff --git a/tests/test_streaming_demo.py b/tests/test_streaming_demo.py new file mode 100644 index 0000000..ec400cc --- /dev/null +++ b/tests/test_streaming_demo.py @@ -0,0 +1,36 @@ +import importlib.util +from pathlib import Path + + +def _load_streaming_demo(): + path = Path(__file__).resolve().parents[1] / "examples" / "streaming_demo.py" + spec = importlib.util.spec_from_file_location("streaming_demo", path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_format_time(): + demo = _load_streaming_demo() + + assert demo._format_time(0) == "00:00.000" + assert demo._format_time(65.4321) == "01:05.432" + + +def test_iter_chunk_ranges(): + demo = _load_streaming_demo() + + assert list(demo._iter_chunk_ranges(10, 4)) == [ + (0, 0, 4), + (1, 4, 8), + (2, 8, 10), + ] + + +def test_iter_chunk_ranges_respects_max_chunks(): + demo = _load_streaming_demo() + + assert list(demo._iter_chunk_ranges(10, 4, max_chunks=2)) == [ + (0, 0, 4), + (1, 4, 8), + ] diff --git a/tests/test_streaming_params.py b/tests/test_streaming_params.py new file mode 100644 index 0000000..bc81ec2 --- /dev/null +++ b/tests/test_streaming_params.py @@ -0,0 +1,127 @@ +import importlib +import sys +import types + +import torch + + +def _install_modelscope_stub(): + modelscope = types.ModuleType("modelscope") + modelscope.snapshot_download = lambda *args, **kwargs: None + + model_module = types.ModuleType("modelscope.models.audio.funasr.model") + + class GenericFunASR: + pass + + model_module.GenericFunASR = GenericFunASR + + sys.modules.setdefault("modelscope", modelscope) + sys.modules.setdefault("modelscope.models", types.ModuleType("modelscope.models")) + sys.modules.setdefault("modelscope.models.audio", types.ModuleType("modelscope.models.audio")) + sys.modules.setdefault("modelscope.models.audio.funasr", types.ModuleType("modelscope.models.audio.funasr")) + sys.modules.setdefault("modelscope.models.audio.funasr.model", model_module) + + +_install_modelscope_stub() +transcribe = importlib.import_module("dolphin.transcribe") +model_module = importlib.import_module("dolphin.model") + + +class FakeTokenizer: + symbol_table = {} + + def tokens2ids(self, tokens): + mapping = {"<30.00>": 2} + return [mapping[item] for item in tokens] + + def ids2tokens(self, ids): + mapping = {1: "", 2: "", 3: "你", 4: "好"} + return [mapping[item] for item in ids] + + def detokenize(self, tokens): + token_text = "".join(self.ids2tokens(tokens)) + if tokens == [3, 4]: + token_text = "你好" + return [token_text] + + +class FakeHyp: + tokens = [1, 2, 3, 4] + times = None + + +class FakeModel: + device = torch.device("cpu") + model_configs = {"support_timestamp": False} + + def __init__(self): + self.decode_kwargs = None + + def decode(self, **kwargs): + self.decode_kwargs = kwargs + return {"attention_rescoring": [FakeHyp()]} + + +def test_transcribe_passes_streaming_decode_parameters(monkeypatch): + model = FakeModel() + monkeypatch.setattr( + transcribe, + "extract_feats", + lambda audio, configs: { + "feats": torch.zeros(1, 4, 8), + "feats_lengths": torch.tensor([4]), + }, + ) + monkeypatch.setattr(transcribe, "init_tokenizer", lambda configs: FakeTokenizer()) + + result = transcribe.transcribe( + model, + torch.zeros(1, 16000), + decoding_chunk_size=16, + num_decoding_left_chunks=4, + simulate_streaming=True, + ) + + assert result.text_nospecial == "你好" + assert model.decode_kwargs["decoding_chunk_size"] == 16 + assert model.decode_kwargs["num_decoding_left_chunks"] == 4 + assert model.decode_kwargs["simulate_streaming"] is True + + +def test_parser_streaming_flags(monkeypatch): + monkeypatch.setattr( + sys, + "argv", + [ + "dolphin", + "audio.wav", + "--decoding_chunk_size", + "16", + "--num_decoding_left_chunks", + "4", + "--simulate_streaming", + "true", + ], + ) + + args = transcribe.parser_args() + + assert args.decoding_chunk_size == 16 + assert args.num_decoding_left_chunks == 4 + assert args.simulate_streaming is True + + +def test_rel_position_sdpa_accepts_empty_streaming_mask(): + attention = model_module.RelPositionMultiHeadedAttention( + 2, + 8, + 0.0, + use_sdpa=True, + ) + query = torch.randn(1, 4, 8) + pos_emb = torch.randn(1, 4, 8) + + output, _ = attention(query, query, query, pos_emb=pos_emb) + + assert output.shape == query.shape From 65e4ca039fb9e5349e9ebe95e272aaac462ce392 Mon Sep 17 00:00:00 2001 From: MXuer Date: Thu, 11 Jun 2026 15:25:43 +0800 Subject: [PATCH 8/8] Add cache-level streaming demos --- README.md | 27 +- examples/microphone_streaming_demo.py | 310 +++++++++ examples/streaming_demo.py | 586 ++++++++++++++++-- .../issue-106-streaming-demo-test-report.md | 77 ++- reports/issue-triage.md | 2 +- tests/test_streaming_demo.py | 65 +- 6 files changed, 974 insertions(+), 93 deletions(-) create mode 100644 examples/microphone_streaming_demo.py diff --git a/README.md b/README.md index d809c29..dde02db 100644 --- a/README.md +++ b/README.md @@ -113,19 +113,35 @@ dolphin audio.wav --model small.cn --remove_punctuation true ### Experimental streaming demo For Chinese dialect streaming models, the repository provides an experimental -file-streaming demo. It reads an existing audio file in chunks and prints each -chunk result as soon as it is decoded: +cache-level streaming demo. It drives `forward_encoder_chunk` with encoder +caches and prints CTC partial results as each chunk is decoded. `--chunk_size` +controls the encoder streaming chunk size. CTC endpointing is enabled by +default, so a long silence or long utterance automatically finalizes the +current segment: ```shell -python examples/streaming_demo.py audio.wav --model small.cn.streaming --device cuda +python examples/streaming_demo.py audio.wav --model small.cn.streaming --device cuda --chunk_size 16 --final_rescore attention ``` -For CPU smoke tests, limit the number of chunks: +For timestamped partial lines or CPU smoke tests: ```shell -python examples/streaming_demo.py audio.wav --model base.cn.streaming --device cpu --chunk_duration 4 --max_chunks 2 +python examples/streaming_demo.py audio.wav --model base.cn.streaming --device cpu --chunk_size 16 --emit line --max_chunks 2 ``` +To stream from your microphone, install the optional recorder dependency and +run: + +```shell +python -m pip install sounddevice +python examples/microphone_streaming_demo.py --model small.cn.streaming --device cuda --chunk_size 16 --final_rescore attention +``` + +Endpoint defaults follow common CTC streaming behavior: 5s silence before any +decoded text, 1s silence after decoded text, or 20s maximum utterance length. +Use `--disable_endpoint` to turn this off, or tune +`--endpoint_rule2_min_trailing_silence_ms` for faster/slower segment finals. + ### Python usage ```python @@ -171,7 +187,6 @@ Thanks to the following excellent open-source works: - [Espnet](https://github.com/espnet/espnet) - [Wenet](https://github.com/wenet-e2e/wenet) - [FunASR](https://github.com/modelscope/FunASR) -- [FireRedASR2S](https://github.com/FireRedTeam/FireRedASR2S) ## License diff --git a/examples/microphone_streaming_demo.py b/examples/microphone_streaming_demo.py new file mode 100644 index 0000000..ec06fc2 --- /dev/null +++ b/examples/microphone_streaming_demo.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python3 +"""Experimental microphone streaming demo for Dolphin streaming models.""" + +import argparse +import queue +import sys +import time +from pathlib import Path +from typing import Optional + + +SAMPLE_RATE = 16000 + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Experimental Dolphin microphone streaming ASR demo." + ) + parser.add_argument( + "--model", + default="small.cn.streaming", + help="streaming model name (default: small.cn.streaming)", + ) + parser.add_argument( + "--model_dir", + type=Path, + default=None, + help="model checkpoint directory; defaults to ~/.cache/dolphin/", + ) + parser.add_argument("--device", default=None, help="torch device, e.g. cuda or cpu") + parser.add_argument("--lang_sym", default=None, help="language symbol for final attention rescoring, e.g. zh") + parser.add_argument("--region_sym", default=None, help="region symbol for final attention rescoring, e.g. CN") + parser.add_argument( + "--input_device", + default=None, + help="sounddevice input device id/name (default: system default)", + ) + parser.add_argument( + "--list_devices", + action="store_true", + help="list microphone devices and exit", + ) + parser.add_argument( + "--chunk_size", + "--decoding_chunk_size", + dest="chunk_size", + type=int, + default=16, + help="encoder streaming chunk size in subsampled frames (default: 16)", + ) + parser.add_argument( + "--left_chunks", + "--num_decoding_left_chunks", + dest="left_chunks", + type=int, + default=4, + help="left chunks kept by encoder cache; use -1 for all history (default: 4)", + ) + parser.add_argument( + "--block_ms", + type=int, + default=40, + help="microphone callback block size in milliseconds (default: 40)", + ) + parser.add_argument( + "--feature_update_ms", + type=int, + default=80, + help="minimum interval between feature refreshes in milliseconds (default: 80)", + ) + parser.add_argument( + "--duration", + type=float, + default=None, + help="optional maximum recording duration in seconds", + ) + parser.add_argument( + "--emit", + choices=("delta", "line"), + default="delta", + help="delta prints only newly recognized text; line prints timestamped partials (default: delta)", + ) + parser.add_argument( + "--max_chunks", + type=int, + default=None, + help="limit encoder chunks for quick smoke tests", + ) + parser.add_argument( + "--final_rescore", + choices=("none", "attention"), + default="none", + help="run final attention rescoring after streaming partials (default: none)", + ) + parser.add_argument( + "--rescore_beam_size", + type=int, + default=10, + help="beam size for final attention rescoring (default: 10)", + ) + parser.add_argument( + "--rescore_ctc_weight", + type=float, + default=0.0, + help="CTC score weight for final attention rescoring (default: 0.0)", + ) + parser.add_argument( + "--reverse_weight", + type=float, + default=0.0, + help="right-to-left decoder weight for final attention rescoring (default: 0.0)", + ) + parser.add_argument( + "--disable_endpoint", + action="store_true", + help="disable CTC endpoint segmentation", + ) + parser.add_argument( + "--endpoint_blank_threshold", + type=float, + default=0.8, + help="blank probability threshold treated as silence (default: 0.8)", + ) + parser.add_argument( + "--endpoint_rule1_min_trailing_silence_ms", + type=int, + default=5000, + help="endpoint rule1: silence timeout without decoded text (default: 5000)", + ) + parser.add_argument( + "--endpoint_rule2_min_trailing_silence_ms", + type=int, + default=1000, + help="endpoint rule2: silence timeout after decoded text (default: 1000)", + ) + parser.add_argument( + "--endpoint_rule3_min_utterance_length_ms", + type=int, + default=20000, + help="endpoint rule3: maximum utterance length (default: 20000)", + ) + return parser.parse_args() + + +def _load_sounddevice(): + try: + import sounddevice as sd + except ModuleNotFoundError as exc: + raise RuntimeError( + "The microphone demo requires sounddevice. Install it with: " + "python -m pip install sounddevice" + ) from exc + return sd + + +def _extract_feats_from_samples(samples, configs): + import torch + from dolphin.processor import extract_feats + + waveform = torch.from_numpy(samples.copy()).float().unsqueeze(0) + return extract_feats([waveform], configs)["feats"] + + +def _drain_audio(audio_queue: "queue.Queue", timeout: Optional[float]): + blocks = [] + try: + blocks.append(audio_queue.get(timeout=timeout)) + except queue.Empty: + return blocks + + while True: + try: + blocks.append(audio_queue.get_nowait()) + except queue.Empty: + break + return blocks + + +def main() -> int: + args = _parse_args() + if args.chunk_size <= 0: + raise ValueError("--chunk_size must be positive") + if args.block_ms <= 0: + raise ValueError("--block_ms must be positive") + if args.feature_update_ms <= 0: + raise ValueError("--feature_update_ms must be positive") + + sd = _load_sounddevice() + if args.list_devices: + print(sd.query_devices()) + return 0 + + repo_root = Path(__file__).resolve().parents[1] + if str(repo_root) not in sys.path: + sys.path.insert(0, str(repo_root)) + + import numpy as np + import dolphin + from dolphin.tokenizer import init_tokenizer + from streaming_demo import ( + StreamingCtcGreedyDecoder, + StreamingOutputState, + _emit_final_rescore, + _emit_streaming_event, + _endpoint_config_from_args, + _finish_current_segment, + _frame_shift_seconds, + ) + + model_dir = args.model_dir or Path.home() / ".cache" / "dolphin" / args.model + + print(f"loading model: {args.model} ({model_dir})", file=sys.stderr, flush=True) + model = dolphin.load_model(args.model, model_dir, args.device) + tokenizer = init_tokenizer(model.model_configs) + decoder = StreamingCtcGreedyDecoder( + model, + tokenizer, + chunk_size=args.chunk_size, + left_chunks=args.left_chunks, + endpoint_config=_endpoint_config_from_args(args), + ) + frame_shift_seconds = _frame_shift_seconds(model.model_configs) + chunk_ms = decoder.stride * frame_shift_seconds * 1000 + print( + f"listening: sample_rate={SAMPLE_RATE} chunk_size={args.chunk_size} " + f"(~{chunk_ms:.0f}ms) left_chunks={args.left_chunks}", + file=sys.stderr, + flush=True, + ) + + audio_queue: "queue.Queue" = queue.Queue() + + def callback(indata, frames, time_info, status): + if status: + print(status, file=sys.stderr, flush=True) + audio_queue.put(indata[:, 0].copy()) + + blocksize = int(SAMPLE_RATE * args.block_ms / 1000) + samples = np.empty((0,), dtype=np.float32) + last_feature_update = 0.0 + started_at = time.time() + output_state = StreamingOutputState() + + try: + with sd.InputStream( + samplerate=SAMPLE_RATE, + channels=1, + dtype="float32", + blocksize=blocksize, + callback=callback, + device=args.input_device, + ): + while True: + if args.duration is not None and time.time() - started_at >= args.duration: + break + if args.max_chunks is not None and decoder.chunks_processed >= args.max_chunks: + break + + blocks = _drain_audio(audio_queue, timeout=0.1) + if blocks: + samples = np.concatenate([samples, *blocks]) + + now = time.time() + if now - last_feature_update < args.feature_update_ms / 1000.0: + continue + last_feature_update = now + + if samples.size == 0: + continue + + feats = _extract_feats_from_samples(samples, model.model_configs).to(model.device) + for event in decoder.decode_available( + feats, + frame_shift_seconds, + max_chunks=args.max_chunks, + flush=False, + ): + if event.is_endpoint: + _finish_current_segment(decoder, args, output_state) + continue + _emit_streaming_event(event, args.emit, output_state) + except KeyboardInterrupt: + pass + + if samples.size: + feats = _extract_feats_from_samples(samples, model.model_configs).to(model.device) + for event in decoder.decode_available( + feats, + frame_shift_seconds, + max_chunks=args.max_chunks, + flush=True, + ): + if event.is_endpoint: + _finish_current_segment(decoder, args, output_state) + continue + _emit_streaming_event(event, args.emit, output_state) + + if args.emit == "delta" and output_state.wrote_delta: + print(flush=True) + if args.final_rescore == "attention" and decoder.encoder_out_chunks: + _emit_final_rescore(decoder, args) + return 0 + + +if __name__ == "__main__": + try: + sys.exit(main()) + except RuntimeError as exc: + print(f"error: {exc}", file=sys.stderr) + sys.exit(2) diff --git a/examples/streaming_demo.py b/examples/streaming_demo.py index f6810ae..683448b 100644 --- a/examples/streaming_demo.py +++ b/examples/streaming_demo.py @@ -1,19 +1,126 @@ #!/usr/bin/env python3 -"""Experimental file-streaming demo for Dolphin streaming models. +"""Experimental cache-level streaming demo for Dolphin streaming models. -This script reads an audio file, feeds it to Dolphin in chunks, and prints -recognition results as each chunk is processed. It is intended as a simple -terminal demo for streaming models such as ``small.cn.streaming``. +This script drives the model through ``forward_encoder_chunk`` with encoder +caches and prints CTC greedy partial results as each encoder chunk is decoded. +For a file demo, fbank features are prepared up front; recognition itself is +performed chunk by chunk instead of by hard-cutting audio and calling +``transcribe`` repeatedly. """ import argparse +import dataclasses import sys import time from pathlib import Path -from typing import Iterator, Optional, Tuple +from typing import Iterator, List, Optional, Tuple -SAMPLE_RATE = 16000 +DEFAULT_FRAME_SHIFT_SECONDS = 0.01 + + +@dataclasses.dataclass +class StreamingEvent: + index: int + processed_seconds: float + text: str + delta: str + is_endpoint: bool = False + + +@dataclasses.dataclass +class StreamingOutputState: + wrote_delta: bool = False + previous_display_text: str = "" + + def reset(self): + self.wrote_delta = False + self.previous_display_text = "" + + +@dataclasses.dataclass +class CtcEndpointRule: + must_decoded_something: bool = True + min_trailing_silence_ms: int = 1000 + min_utterance_length_ms: int = 0 + + +@dataclasses.dataclass +class CtcEndpointConfig: + blank_id: int = 0 + blank_scale: float = 1.0 + blank_threshold: float = 0.8 + rule1: CtcEndpointRule = dataclasses.field( + default_factory=lambda: CtcEndpointRule(False, 5000, 0) + ) + rule2: CtcEndpointRule = dataclasses.field( + default_factory=lambda: CtcEndpointRule(True, 1000, 0) + ) + rule3: CtcEndpointRule = dataclasses.field( + default_factory=lambda: CtcEndpointRule(False, 0, 20000) + ) + + +class CtcEndpoint: + def __init__(self, config: CtcEndpointConfig): + self.config = config + self.reset() + + def reset(self): + self.num_frames_decoded = 0 + self.num_frames_trailing_blank = 0 + + def _rule_activated( + self, + rule: CtcEndpointRule, + decoded_something: bool, + trailing_silence_ms: int, + utterance_length_ms: int, + ) -> bool: + return ( + (decoded_something or not rule.must_decoded_something) + and trailing_silence_ms >= rule.min_trailing_silence_ms + and utterance_length_ms >= rule.min_utterance_length_ms + ) + + def update( + self, + ctc_log_probs, + decoded_something: bool, + frame_shift_ms: int, + ) -> bool: + import math + + for logp_t in ctc_log_probs.squeeze(0): + blank_prob = math.exp(float(logp_t[self.config.blank_id])) + self.num_frames_decoded += 1 + if blank_prob > self.config.blank_threshold * self.config.blank_scale: + self.num_frames_trailing_blank += 1 + else: + self.num_frames_trailing_blank = 0 + + utterance_length_ms = self.num_frames_decoded * frame_shift_ms + trailing_silence_ms = self.num_frames_trailing_blank * frame_shift_ms + return ( + self._rule_activated( + self.config.rule1, + decoded_something, + trailing_silence_ms, + utterance_length_ms, + ) + or self._rule_activated( + self.config.rule2, + decoded_something, + trailing_silence_ms, + utterance_length_ms, + ) + or self._rule_activated( + self.config.rule3, + decoded_something, + trailing_silence_ms, + utterance_length_ms, + ) + ) def _format_time(seconds: float) -> str: @@ -25,28 +132,312 @@ def _format_time(seconds: float) -> str: return f"{minutes:02d}:{secs:02d}.{millis:03d}" -def _iter_chunk_ranges( - total_samples: int, - chunk_samples: int, +def _iter_encoder_chunk_ranges( + total_frames: int, + chunk_size: int, + subsampling: int, + right_context: int, max_chunks: Optional[int] = None, ) -> Iterator[Tuple[int, int, int]]: - if chunk_samples <= 0: - raise ValueError("chunk_samples must be positive") + if chunk_size <= 0: + raise ValueError("chunk_size must be positive") + if subsampling <= 0: + raise ValueError("subsampling must be positive") + + context = right_context + 1 + stride = subsampling * chunk_size + decoding_window = (chunk_size - 1) * subsampling + context index = 0 - start = 0 - while start < total_samples: + for start in range(0, total_frames - context + 1, stride): if max_chunks is not None and index >= max_chunks: break - end = min(start + chunk_samples, total_samples) + end = min(start + decoding_window, total_frames) yield index, start, end index += 1 - start = end + + +def _frame_shift_seconds(configs) -> float: + dataset_conf = configs.get("dataset_conf", {}) + if "frontend_conf" in dataset_conf: + frontend_conf = dataset_conf["frontend_conf"] + sample_rate = frontend_conf.get("fs", 16000) + hop_length = frontend_conf.get("hop_length", 128) + return float(hop_length) / float(sample_rate) + fbank_conf = dataset_conf.get("fbank_conf", {}) + return float(fbank_conf.get("frame_shift", 10)) / 1000.0 + + +def _filter_nonspecial_ids(token_ids: List[int], tokenizer) -> List[int]: + try: + last_time_id = tokenizer.tokens2ids(["<30.00>"])[0] + except Exception: + last_time_id = -1 + return [token_id for token_id in token_ids if token_id > last_time_id] + + +def _ids_to_text(token_ids: List[int], tokenizer) -> str: + if not token_ids: + return "" + return tokenizer.detokenize(_filter_nonspecial_ids(token_ids, tokenizer))[0] + + +def _endpoint_config_from_args(args) -> Optional[CtcEndpointConfig]: + if getattr(args, "disable_endpoint", False): + return None + return CtcEndpointConfig( + blank_threshold=args.endpoint_blank_threshold, + rule1=CtcEndpointRule( + False, + args.endpoint_rule1_min_trailing_silence_ms, + 0, + ), + rule2=CtcEndpointRule( + True, + args.endpoint_rule2_min_trailing_silence_ms, + 0, + ), + rule3=CtcEndpointRule( + False, + 0, + args.endpoint_rule3_min_utterance_length_ms, + ), + ) + + +def _emit_streaming_event( + event: StreamingEvent, + emit: str, + output_state: Optional[StreamingOutputState] = None, +) -> bool: + if event.is_endpoint: + return False + if emit == "line": + print( + f"[{event.index + 1:04d} {_format_time(event.processed_seconds)}] " + f"{event.text}", + flush=True, + ) + if output_state is not None: + output_state.previous_display_text = event.text + return False + + print(event.delta, end="", flush=True) + if output_state is not None: + output_state.previous_display_text = event.text + output_state.wrote_delta = True + return True + + +def _emit_final_rescore( + decoder: "StreamingCtcGreedyDecoder", + args, +) -> str: + final_text = decoder.final_attention_rescore( + beam_size=args.rescore_beam_size, + ctc_weight=args.rescore_ctc_weight, + reverse_weight=args.reverse_weight, + lang_sym=args.lang_sym, + region_sym=args.region_sym, + ) + if final_text: + print(f"[final attention_rescoring] {final_text}", flush=True) + return final_text + + +def _finish_current_segment( + decoder: "StreamingCtcGreedyDecoder", + args, + output_state: StreamingOutputState, + force_rescore: bool = False, +): + if args.emit == "delta" and output_state.wrote_delta: + print(flush=True) + if (force_rescore or args.final_rescore == "attention") and decoder.encoder_out_chunks: + _emit_final_rescore(decoder, args) + decoder.reset_segment() + output_state.reset() + + +class StreamingCtcGreedyDecoder: + def __init__( + self, + model, + tokenizer, + chunk_size: int, + left_chunks: int, + endpoint_config: Optional[CtcEndpointConfig] = None, + ): + import torch + + self.model = model + self.tokenizer = tokenizer + self.chunk_size = chunk_size + self.left_chunks = left_chunks + self.subsampling = int(model.subsampling_rate()) + self.right_context = int(model.right_context()) + self.context = self.right_context + 1 + self.stride = self.subsampling * chunk_size + self.decoding_window = (chunk_size - 1) * self.subsampling + self.context + self.required_cache_size = chunk_size * left_chunks + self.next_start = 0 + self.chunks_processed = 0 + self.endpoint = CtcEndpoint(endpoint_config) if endpoint_config else None + self.reset_segment() + + def reset_segment(self): + import torch + + self.att_cache = torch.zeros((0, 0, 0, 0), device=self.model.device) + self.cnn_cache = torch.zeros((0, 0, 0, 0), device=self.model.device) + self.emitted_ids: List[int] = [] + self.previous_frame_id = 0 + self.previous_text = "" + self.encoder_offset = 0 + self.encoder_out_chunks = [] + if self.endpoint is not None: + self.endpoint.reset() + + def decode_available( + self, + feats, + frame_shift_seconds: float, + max_chunks: Optional[int] = None, + flush: bool = False, + ) -> Iterator[StreamingEvent]: + import torch + + total_frames = feats.size(1) + endpoint_frame_shift_ms = max( + 1, + int(round(frame_shift_seconds * self.subsampling * 1000)), + ) + while self.next_start + self.context <= total_frames: + if max_chunks is not None and self.chunks_processed >= max_chunks: + break + if not flush and self.next_start + self.decoding_window > total_frames: + break + + end = min(self.next_start + self.decoding_window, total_frames) + chunk_feats = feats[:, self.next_start:end, :] + encoder_out, self.att_cache, self.cnn_cache = self.model.forward_encoder_chunk( + chunk_feats, + self.encoder_offset, + self.required_cache_size, + self.att_cache, + self.cnn_cache, + ) + self.encoder_offset += encoder_out.size(1) + self.encoder_out_chunks.append(encoder_out) + + ctc_log_probs = self.model.ctc_activation(encoder_out) + frame_ids = torch.argmax(ctc_log_probs, dim=-1).squeeze(0).tolist() + for token_id in frame_ids: + if token_id != 0 and token_id != self.previous_frame_id: + self.emitted_ids.append(token_id) + self.previous_frame_id = token_id + + text = _ids_to_text(self.emitted_ids, self.tokenizer) + event = None + if text != self.previous_text: + if text.startswith(self.previous_text): + delta = text[len(self.previous_text):] + else: + delta = text + self.previous_text = text + event = StreamingEvent( + index=self.chunks_processed, + processed_seconds=end * frame_shift_seconds, + text=text, + delta=delta, + ) + + self.chunks_processed += 1 + self.next_start += self.stride + if event is not None: + yield event + + if self.endpoint is not None and self.endpoint.update( + ctc_log_probs, + decoded_something=bool(text), + frame_shift_ms=endpoint_frame_shift_ms, + ): + yield StreamingEvent( + index=self.chunks_processed - 1, + processed_seconds=end * frame_shift_seconds, + text=text, + delta="", + is_endpoint=True, + ) + + def final_attention_rescore( + self, + beam_size: int = 10, + ctc_weight: float = 0.0, + reverse_weight: float = 0.0, + lang_sym: Optional[str] = None, + region_sym: Optional[str] = None, + ) -> str: + import torch + from dolphin.search import attention_rescoring, ctc_prefix_beam_search + + if not self.encoder_out_chunks: + return "" + + encoder_out = torch.cat(self.encoder_out_chunks, dim=1) + encoder_lens = torch.tensor( + [encoder_out.size(1)], + dtype=torch.long, + device=encoder_out.device, + ) + encoder_mask = torch.ones( + 1, + 1, + encoder_out.size(1), + dtype=torch.bool, + device=encoder_out.device, + ) + ctc_probs = self.model.ctc_activation(encoder_out) + ctc_prefix_results = ctc_prefix_beam_search( + ctc_probs, + encoder_lens, + beam_size, + ) + + rescore_encoder_out = encoder_out + rescore_encoder_mask = encoder_mask + if getattr(self.model, "apply_non_blank_embedding", False): + rescore_encoder_out, rescore_encoder_mask = self.model.filter_blank_embedding( + ctc_probs, + encoder_out, + ) + + infos = { + "tokenizer": self.tokenizer, + "need_timestamp": False, + "word_timestamp": False, + } + if lang_sym is not None: + infos["langs"] = [f"<{lang_sym}>"] + if region_sym is not None: + infos["regions"] = [f"<{region_sym}>"] + + results = attention_rescoring( + self.model, + ctc_prefix_results, + rescore_encoder_out, + encoder_lens, + ctc_weight=ctc_weight, + reverse_weight=reverse_weight, + infos=infos, + encoder_mask=rescore_encoder_mask, + ) + return _ids_to_text(results[0].tokens, self.tokenizer) def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="Experimental Dolphin file-streaming ASR demo." + description="Experimental Dolphin cache-level streaming ASR demo." ) parser.add_argument("audio", type=Path, help="audio file to stream") parser.add_argument( @@ -61,23 +452,29 @@ def _parse_args() -> argparse.Namespace: help="model checkpoint directory; defaults to ~/.cache/dolphin/", ) parser.add_argument("--device", default=None, help="torch device, e.g. cuda or cpu") - parser.add_argument("--lang_sym", default="zh", help="language symbol (default: zh)") - parser.add_argument("--region_sym", default=None, help="region symbol, e.g. CN") - parser.add_argument("--chunk_duration", type=float, default=4.0, help="chunk size in seconds (default: 4.0)") - parser.add_argument("--beam_size", type=int, default=5, help="beam size (default: 5)") + parser.add_argument("--lang_sym", default=None, help="language symbol for final attention rescoring, e.g. zh") + parser.add_argument("--region_sym", default=None, help="region symbol for final attention rescoring, e.g. CN") + parser.add_argument( + "--chunk_size", + "--decoding_chunk_size", + dest="chunk_size", + type=int, + default=16, + help="encoder streaming chunk size in subsampled frames (default: 16)", + ) parser.add_argument( - "--decoding_method", - default="attention_rescoring", - choices=("attention", "attention_rescoring"), - help="decoding method (default: attention_rescoring)", + "--left_chunks", + "--num_decoding_left_chunks", + dest="left_chunks", + type=int, + default=4, + help="left chunks kept by encoder cache; use -1 for all history (default: 4)", ) - parser.add_argument("--decoding_chunk_size", type=int, default=16, help="encoder decoding chunk size (default: 16)") - parser.add_argument("--num_decoding_left_chunks", type=int, default=4, help="left chunks kept by encoder (default: 4)") parser.add_argument( - "--mode", - choices=("chunk", "rolling"), - default="chunk", - help="chunk prints each chunk independently; rolling prints accumulated partial text (default: chunk)", + "--emit", + choices=("delta", "line"), + default="delta", + help="delta prints only newly recognized text; line prints timestamped partials (default: delta)", ) parser.add_argument( "--realtime", @@ -88,72 +485,127 @@ def _parse_args() -> argparse.Namespace: "--max_chunks", type=int, default=None, - help="limit chunks for quick smoke tests", + help="limit encoder chunks for quick smoke tests", + ) + parser.add_argument( + "--final_rescore", + choices=("none", "attention"), + default="none", + help="run final attention rescoring after streaming partials (default: none)", + ) + parser.add_argument( + "--rescore_beam_size", + type=int, + default=10, + help="beam size for final attention rescoring (default: 10)", + ) + parser.add_argument( + "--rescore_ctc_weight", + type=float, + default=0.0, + help="CTC score weight for final attention rescoring (default: 0.0)", + ) + parser.add_argument( + "--reverse_weight", + type=float, + default=0.0, + help="right-to-left decoder weight for final attention rescoring (default: 0.0)", + ) + parser.add_argument( + "--disable_endpoint", + action="store_true", + help="disable CTC endpoint segmentation", + ) + parser.add_argument( + "--endpoint_blank_threshold", + type=float, + default=0.8, + help="blank probability threshold treated as silence (default: 0.8)", + ) + parser.add_argument( + "--endpoint_rule1_min_trailing_silence_ms", + type=int, + default=5000, + help="endpoint rule1: silence timeout without decoded text (default: 5000)", + ) + parser.add_argument( + "--endpoint_rule2_min_trailing_silence_ms", + type=int, + default=1000, + help="endpoint rule2: silence timeout after decoded text (default: 1000)", + ) + parser.add_argument( + "--endpoint_rule3_min_utterance_length_ms", + type=int, + default=20000, + help="endpoint rule3: maximum utterance length (default: 20000)", ) return parser.parse_args() def main() -> int: args = _parse_args() - if args.chunk_duration <= 0: - raise ValueError("--chunk_duration must be positive") + if args.chunk_size <= 0: + raise ValueError("--chunk_size must be positive") repo_root = Path(__file__).resolve().parents[1] if str(repo_root) not in sys.path: sys.path.insert(0, str(repo_root)) - import torch - import dolphin - from dolphin.audio import load_audio - from dolphin.transcribe import transcribe + from dolphin.processor import extract_feats + from dolphin.tokenizer import init_tokenizer model_dir = args.model_dir or Path.home() / ".cache" / "dolphin" / args.model - print(f"loading model: {args.model} ({model_dir})", flush=True) + print(f"loading model: {args.model} ({model_dir})", file=sys.stderr, flush=True) model = dolphin.load_model(args.model, model_dir, args.device) - waveform = torch.from_numpy(load_audio(str(args.audio))).float().unsqueeze(0) + batch = extract_feats([str(args.audio)], model.model_configs) + feats = batch["feats"].to(model.device) + tokenizer = init_tokenizer(model.model_configs) + frame_shift_seconds = _frame_shift_seconds(model.model_configs) - chunk_samples = int(args.chunk_duration * SAMPLE_RATE) - total_samples = waveform.size(-1) + decoder = StreamingCtcGreedyDecoder( + model, + tokenizer, + chunk_size=args.chunk_size, + left_chunks=args.left_chunks, + endpoint_config=_endpoint_config_from_args(args), + ) + chunk_ms = decoder.stride * frame_shift_seconds * 1000 print( f"streaming file: {args.audio} " - f"duration={total_samples / SAMPLE_RATE:.2f}s " - f"chunk={args.chunk_duration:.2f}s mode={args.mode}", + f"frames={feats.size(1)} " + f"chunk_size={args.chunk_size} (~{chunk_ms:.0f}ms) " + f"left_chunks={args.left_chunks}", + file=sys.stderr, flush=True, ) started_at = time.time() - for index, start, end in _iter_chunk_ranges(total_samples, chunk_samples, args.max_chunks): - chunk = waveform[:, :end] if args.mode == "rolling" else waveform[:, start:end] - result = transcribe( - model, - chunk, - lang_sym=args.lang_sym, - region_sym=args.region_sym, - predict_time=False, - word_timestamp=False, - decoding_method=args.decoding_method, - beam_size=args.beam_size, - decoding_chunk_size=args.decoding_chunk_size, - num_decoding_left_chunks=args.num_decoding_left_chunks, - simulate_streaming=True, - ) + output_state = StreamingOutputState() + for event in decoder.decode_available( + feats, + frame_shift_seconds, + max_chunks=args.max_chunks, + flush=True, + ): + if event.is_endpoint: + _finish_current_segment(decoder, args, output_state) + continue - if args.mode == "rolling": - span = f"00:00.000-{_format_time(end / SAMPLE_RATE)}" - label = "partial" - else: - span = f"{_format_time(start / SAMPLE_RATE)}-{_format_time(end / SAMPLE_RATE)}" - label = "chunk" - print(f"[{index + 1:04d} {label} {span}] {result.text_nospecial}", flush=True) + _emit_streaming_event(event, args.emit, output_state) if args.realtime: - target_elapsed = end / SAMPLE_RATE - sleep_for = target_elapsed - (time.time() - started_at) + sleep_for = event.processed_seconds - (time.time() - started_at) if sleep_for > 0: time.sleep(sleep_for) + if args.emit == "delta" and output_state.wrote_delta: + print(flush=True) + + if args.final_rescore == "attention" and decoder.encoder_out_chunks: + _emit_final_rescore(decoder, args) return 0 diff --git a/reports/issue-106-streaming-demo-test-report.md b/reports/issue-106-streaming-demo-test-report.md index 9b6f8df..38b8af5 100644 --- a/reports/issue-106-streaming-demo-test-report.md +++ b/reports/issue-106-streaming-demo-test-report.md @@ -6,7 +6,10 @@ Issue: https://github.com/DataoceanAI/Dolphin/issues/106 ## Scope -- Added `examples/streaming_demo.py` as an experimental terminal file-streaming demo for `small.cn.streaming` and other streaming models. +- Added `examples/streaming_demo.py` as an experimental terminal cache-level streaming demo for `small.cn.streaming` and other streaming models. +- Added `examples/microphone_streaming_demo.py` for live microphone streaming output. +- Added optional final attention rescoring after streaming partials. +- Added WeNet-style CTC endpointing so silence or long utterances finalize and rescore segments automatically. - Added CLI/API passthrough for `decoding_chunk_size`, `num_decoding_left_chunks`, and `simulate_streaming`. - Fixed the SDPA relative-position attention path so streaming chunk inference can run when no explicit mask is supplied. - Updated `README.md` with terminal demo commands. @@ -15,13 +18,20 @@ Issue: https://github.com/DataoceanAI/Dolphin/issues/106 ## How To Run ```shell -python examples/streaming_demo.py audio.wav --model small.cn.streaming --device cuda +python examples/streaming_demo.py audio.wav --model small.cn.streaming --device cuda --chunk_size 16 --final_rescore attention ``` CPU smoke test: ```shell -python examples/streaming_demo.py audio.wav --model small.cn.streaming --device cpu --chunk_duration 3 --max_chunks 1 +python examples/streaming_demo.py audio.wav --model small.cn.streaming --device cpu --chunk_size 16 --emit line --max_chunks 2 +``` + +Microphone demo: + +```shell +python -m pip install sounddevice +python examples/microphone_streaming_demo.py --model small.cn.streaming --device cuda --chunk_size 16 --final_rescore attention ``` ## Verification @@ -30,16 +40,17 @@ python examples/streaming_demo.py audio.wav --model small.cn.streaming --device env TRANSFORMERS_NO_TF=1 USE_TF=0 python -m pytest ``` -Result: `20 passed in 1.17s`. +Result: `23 passed in 1.19s`. ```shell -python -m py_compile examples/streaming_demo.py +python -m py_compile examples/streaming_demo.py examples/microphone_streaming_demo.py ``` Result: passed. ```shell python examples/streaming_demo.py --help +python examples/microphone_streaming_demo.py --help ``` Result: passed. @@ -47,17 +58,63 @@ Result: passed. Real model smoke test: ```shell -env MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python examples/streaming_demo.py /private/tmp/dolphin-test-audio/zh-cn-demo.wav --model small.cn.streaming --model_dir /private/tmp/dolphin-models/small.cn.streaming --device cpu --chunk_duration 3 --max_chunks 1 +env MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python examples/streaming_demo.py /private/tmp/dolphin-test-audio/zh-cn-demo.wav --model small.cn.streaming --model_dir /private/tmp/dolphin-models/small.cn.streaming --device cpu --chunk_size 16 --emit line --max_chunks 2 --final_rescore attention +``` + +Result: + +```text +[0002 00:01.310] 诚然 +[final attention_rescoring] 诚然 +``` + +Full short-audio cache-level streaming smoke test: + +```shell +env MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python examples/streaming_demo.py /private/tmp/dolphin-test-audio/zh-cn-demo.wav --model small.cn.streaming --model_dir /private/tmp/dolphin-models/small.cn.streaming --device cpu --chunk_size 16 --emit line --final_rescore attention +``` + +Result: + +```text +[0002 00:01.310] 诚然 +[0003 00:01.950] 诚然时代正 +[0004 00:02.590] 诚然时代正在推陈 +[0005 00:03.230] 诚然时代正在推陈初新 +[0006 00:03.870] 诚然时代正在推陈初新但文化 +[0007 00:04.510] 诚然时代正在推陈初新但文化之精髓 +[0008 00:05.150] 诚然时代正在推陈初新但文化之精髓切需 +[0009 00:05.790] 诚然时代正在推陈初新但文化之精髓切需传承 +[final attention_rescoring] 诚然时代正在推陈出新但文化之精髓切需传承 +``` + +Forced endpoint smoke test: + +```shell +env MODELSCOPE_CACHE=/private/tmp/modelscope-cache NUMBA_CACHE_DIR=/private/tmp/numba-cache TRANSFORMERS_NO_TF=1 USE_TF=0 /private/tmp/dolphin-venv/bin/python examples/streaming_demo.py /private/tmp/dolphin-test-audio/zh-cn-demo.wav --model small.cn.streaming --model_dir /private/tmp/dolphin-models/small.cn.streaming --device cpu --chunk_size 16 --emit line --final_rescore attention --endpoint_rule3_min_utterance_length_ms 3000 ``` Result: ```text -[0001 chunk 00:00.000-00:03.000] 诚然时代正在推陈出新 +[0002 00:01.310] 诚然 +[0003 00:01.950] 诚然时代正 +[0004 00:02.590] 诚然时代正在推陈 +[0005 00:03.230] 诚然时代正在推陈初新 +[final attention_rescoring] 诚然时代正在推陈出新 +[0006 00:03.870] 文化 +[0007 00:04.510] 文化之精髓 +[0008 00:05.150] 文化之精髓切需 +[0009 00:05.790] 文化之精髓切需传承 +[final attention_rescoring] 文化之精髓切需传承 ``` ## Notes -- The demo streams an existing audio file in fixed-size chunks and prints each decoded chunk immediately. -- `--mode chunk` decodes each chunk independently. `--mode rolling` re-decodes accumulated audio and prints partial text. -- This is intentionally an experimental terminal demo, not a production microphone or socket streaming server. +- `--chunk_size 16` maps to the encoder streaming `decoding_chunk_size`. +- The file demo prepares fbank features from the file first, then runs model inference chunk by chunk with encoder caches. +- The microphone demo uses live audio input and emits CTC greedy partial text as new chunks decode. +- CTC greedy partial text may temporarily contain unstable words before later chunks arrive. +- `--final_rescore attention` runs attention rescoring once at the end and prints `[final attention_rescoring] ...`. +- CTC endpointing is enabled by default: 5s initial silence, 1s trailing silence after decoded text, or 20s max utterance length. +- This is intentionally an experimental terminal demo, not a production streaming server. diff --git a/reports/issue-triage.md b/reports/issue-triage.md index 0cc68c3..bc44340 100644 --- a/reports/issue-triage.md +++ b/reports/issue-triage.md @@ -13,7 +13,7 @@ This record tracks issues reviewed during the Codex pass so future work does not | #80 | Done | Added CLI `--output` and `--output_format {txt,json,srt}`. | | #92 | Done | Added CLI/API punctuation removal via `--remove_punctuation` and `remove_punctuation=True`. | | #93 | Done | Added `dolphin.detect_language(...)`, `--task detect_language`, and `--lid_duration`. | -| #106 | Done | Added experimental `examples/streaming_demo.py`, streaming decode options, and SDPA chunk mask support for `small.cn.streaming`. | +| #106 | Done | Added cache-level file streaming and microphone demos, streaming decode options, and SDPA chunk mask support for `small.cn.streaming`. | ## In Progress diff --git a/tests/test_streaming_demo.py b/tests/test_streaming_demo.py index ec400cc..5245c36 100644 --- a/tests/test_streaming_demo.py +++ b/tests/test_streaming_demo.py @@ -1,11 +1,15 @@ import importlib.util +import sys from pathlib import Path +import torch + def _load_streaming_demo(): path = Path(__file__).resolve().parents[1] / "examples" / "streaming_demo.py" spec = importlib.util.spec_from_file_location("streaming_demo", path) module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module spec.loader.exec_module(module) return module @@ -17,20 +21,63 @@ def test_format_time(): assert demo._format_time(65.4321) == "01:05.432" -def test_iter_chunk_ranges(): +def test_iter_encoder_chunk_ranges(): demo = _load_streaming_demo() - assert list(demo._iter_chunk_ranges(10, 4)) == [ - (0, 0, 4), - (1, 4, 8), - (2, 8, 10), + assert list(demo._iter_encoder_chunk_ranges(140, 16, 4, 6)) == [ + (0, 0, 67), + (1, 64, 131), + (2, 128, 140), ] -def test_iter_chunk_ranges_respects_max_chunks(): +def test_iter_encoder_chunk_ranges_respects_max_chunks(): demo = _load_streaming_demo() - assert list(demo._iter_chunk_ranges(10, 4, max_chunks=2)) == [ - (0, 0, 4), - (1, 4, 8), + assert list(demo._iter_encoder_chunk_ranges(140, 16, 4, 6, max_chunks=2)) == [ + (0, 0, 67), + (1, 64, 131), ] + + +def test_parser_chunk_size_alias(monkeypatch): + demo = _load_streaming_demo() + monkeypatch.setattr( + sys, + "argv", + [ + "streaming_demo.py", + "audio.wav", + "--chunk_size", + "16", + "--left_chunks", + "4", + "--final_rescore", + "attention", + "--rescore_beam_size", + "3", + ], + ) + + args = demo._parse_args() + + assert args.chunk_size == 16 + assert args.left_chunks == 4 + assert args.final_rescore == "attention" + assert args.rescore_beam_size == 3 + + +def test_endpoint_rule2_after_decoded_text_and_silence(): + demo = _load_streaming_demo() + endpoint = demo.CtcEndpoint(demo.CtcEndpointConfig()) + blank_frames = torch.log(torch.tensor([[[0.9, 0.1]] * 25])) + + assert endpoint.update(blank_frames, decoded_something=True, frame_shift_ms=40) + + +def test_endpoint_rule3_long_utterance(): + demo = _load_streaming_demo() + endpoint = demo.CtcEndpoint(demo.CtcEndpointConfig()) + speech_frames = torch.log(torch.tensor([[[0.1, 0.9]] * 500])) + + assert endpoint.update(speech_frames, decoded_something=True, frame_shift_ms=40)