diff --git a/src/openbench/engine/__init__.py b/src/openbench/engine/__init__.py index 8fe912c..04dfd78 100644 --- a/src/openbench/engine/__init__.py +++ b/src/openbench/engine/__init__.py @@ -1,3 +1,12 @@ +from .argmax_oss_engine import ( + ArgmaxOpenSourceEngine, + ArgmaxOpenSourceEngineConfig, + DiarizeCliInput, + DiarizeCliOutput, + TranscriptionCliInput, + TranscriptionCliOutput, + resolve_argmax_oss_cache_dir, +) from .deepgram_engine import DeepgramApi, DeepgramApiResponse from .elevenlabs_engine import ElevenLabsApi, ElevenLabsApiResponse from .openai_engine import OpenAIApi @@ -18,6 +27,13 @@ __all__ = [ + "ArgmaxOpenSourceEngine", + "ArgmaxOpenSourceEngineConfig", + "DiarizeCliInput", + "DiarizeCliOutput", + "TranscriptionCliInput", + "TranscriptionCliOutput", + "resolve_argmax_oss_cache_dir", "DeepgramApi", "DeepgramApiResponse", "ElevenLabsApi", diff --git a/src/openbench/engine/argmax_oss_engine.py b/src/openbench/engine/argmax_oss_engine.py new file mode 100644 index 0000000..02aa9c2 --- /dev/null +++ b/src/openbench/engine/argmax_oss_engine.py @@ -0,0 +1,203 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. 
"""Argmax SDK open-source CLI (`argmax-cli`) — clone/build, transcribe, and diarize."""

from __future__ import annotations

import os
import subprocess
from pathlib import Path

from argmaxtools.utils import _maybe_git_clone, get_logger
from pydantic import BaseModel, Field


logger = get_logger(__name__)

ARGMAX_OSS_REPO_URL = "https://github.com/argmaxinc/argmax-oss-swift"
ARGMAX_OSS_PRODUCT = "argmax-cli"
DEFAULT_CACHE_SUBDIR = Path(".cache") / "openbench" / "argmax-oss"


def resolve_argmax_oss_cache_dir(explicit: str | Path | None = None) -> Path:
    """Return the absolute cache root for the repo clone and `argmax-cli` build.

    Precedence: a non-blank ``explicit`` argument, then a non-blank
    ``ARGMAX_OSS_CACHE_DIR`` environment variable, then
    ``~/.cache/openbench/argmax-oss``.
    """
    if explicit is not None and str(explicit).strip():
        return Path(explicit).expanduser().resolve()
    # A blank env value is treated as unset, mirroring the check on `explicit`;
    # otherwise it would resolve to a whitespace-named path under the CWD.
    env = os.environ.get("ARGMAX_OSS_CACHE_DIR", "")
    if env.strip():
        return Path(env).expanduser().resolve()
    return (Path.home() / DEFAULT_CACHE_SUBDIR).resolve()


class ArgmaxOpenSourceEngineConfig(BaseModel):
    """Engine config: cache, optional commit pin, optional prebuilt CLI path."""

    cache_dir: str | None = Field(
        default=None,
        description="Directory for cloned repo and build artifacts (absolute after resolve). "
        "Overrides ARGMAX_OSS_CACHE_DIR when set.",
    )
    commit_hash: str | None = Field(
        default=None,
        description="Git commit to checkout in the cached WhisperKit clone.",
    )
    cli_path: str | None = Field(
        default=None,
        description="If set, skip clone/build and use this argmax-cli binary.",
    )


class TranscriptionCliInput(BaseModel):
    """Input for `argmax-cli transcribe`."""

    audio_path: Path
    # When False, the engine deletes `audio_path` once the CLI call has finished.
    keep_audio: bool = False
    language: str | None = None


class TranscriptionCliOutput(BaseModel):
    """Output paths from `argmax-cli transcribe --report`."""

    json_report_path: Path = Field(..., description="JSON report path")
    srt_report_path: Path = Field(..., description="SRT report path")
    cli_combined_output: str | None = Field(
        default=None,
        description="Concatenated stdout+stderr when transcribe was run with capture_combined_output=True.",
    )


class DiarizeCliInput(BaseModel):
    """Input for `argmax-cli diarize`."""

    audio_path: Path
    rttm_path: Path
    # When False, the engine deletes `audio_path` once the CLI call has finished.
    keep_audio: bool = False


# Result of `argmax-cli diarize`: the RTTM path that was passed via --rttm-path.
class DiarizeCliOutput(BaseModel):
    rttm_path: Path = Field(..., description="Written RTTM path")


class ArgmaxOpenSourceEngine:
    """Resolve `argmax-cli`, then run `transcribe` / `diarize` subcommands."""

    def __init__(self, config: ArgmaxOpenSourceEngineConfig) -> None:
        """Use ``config.cli_path`` when given, otherwise clone the repo and build the CLI."""
        self.config = config
        if config.cli_path:
            self.cli_path = str(Path(config.cli_path).expanduser().resolve())
            logger.info(f"Using Argmax OSS CLI at {self.cli_path}")
        else:
            self.cli_path = self._clone_and_build_cli()

    def _build_cli(self, repo_dir: str) -> str:
        """Run a release build (``swift build -c release``) and return the bin directory.

        Args:
            repo_dir: Checkout directory in which ``swift build`` is executed.

        Returns:
            The directory reported by ``--show-bin-path`` that contains the binary.

        Raises:
            RuntimeError: If ``swift`` cannot be launched, the build exits non-zero,
                or the expected binary is missing after a successful build.
        """
        logger.info("Building %s with: swift build -c release (not debug)", ARGMAX_OSS_PRODUCT)
        # argv list instead of a shell string: no shell parsing/quoting involved.
        build_cmd = ["swift", "build", "-c", "release", "--product", ARGMAX_OSS_PRODUCT]
        try:
            # The build itself is not captured, so its output goes straight to the console.
            subprocess.run(build_cmd, cwd=repo_dir, check=True)
            result = subprocess.run(
                [*build_cmd, "--show-bin-path"],
                cwd=repo_dir,
                stdout=subprocess.PIPE,
                text=True,
                check=True,
            )
        except FileNotFoundError as e:
            # Without a shell, a missing `swift` executable surfaces here instead of as
            # exit status 127; keep reporting it as a RuntimeError like other build failures.
            raise RuntimeError(f"Failed to build {ARGMAX_OSS_PRODUCT}: could not launch `swift` ({e})") from e
        except subprocess.CalledProcessError as e:
            # stdout/stderr are None for streams that were not captured.
            stdout = e.stdout or ""
            stderr = e.stderr or ""
            logger.error("Build failed with return code %s", e.returncode)
            logger.error("Stdout:\n%s", stdout)
            logger.error("Stderr:\n%s", stderr)
            raise RuntimeError(
                f"Failed to build {ARGMAX_OSS_PRODUCT}: exit {e.returncode}\n"
                f"stdout: {stdout}\nstderr: {stderr}"
            ) from e
        bin_dir = result.stdout.strip()
        cli = Path(bin_dir) / ARGMAX_OSS_PRODUCT
        if not cli.is_file():
            raise RuntimeError(f"Expected CLI binary not found: {cli}")
        logger.info("Built Argmax OSS CLI at %s", cli)
        return bin_dir

    def _clone_and_build_cli(self) -> str:
        """Ensure a clone exists under the cache root, build it, and return the CLI path.

        Side effect: ``self.config.commit_hash`` is overwritten with the commit hash
        returned by ``_maybe_git_clone``.
        """
        cache_root = resolve_argmax_oss_cache_dir(self.config.cache_dir)
        cache_root.mkdir(parents=True, exist_ok=True)
        # ".../<owner>/<name>" -> the last two URL path components.
        repo_url_parts = ARGMAX_OSS_REPO_URL.rstrip("/").split("/")
        repo_name = repo_url_parts[-1]
        repo_owner = repo_url_parts[-2]

        logger.info("Ensuring WhisperKit clone under %s", cache_root)
        repo_dir, commit_hash = _maybe_git_clone(
            out_dir=str(cache_root),
            hub_url="github.com",
            repo_name=repo_name,
            repo_owner=repo_owner,
            commit_hash=self.config.commit_hash,
        )
        self.config.commit_hash = commit_hash
        logger.info("%s at commit %s — running sanity build", repo_name, commit_hash)
        bin_dir = self._build_cli(repo_dir)
        return str(Path(bin_dir) / ARGMAX_OSS_PRODUCT)

    def transcribe(
        self,
        input: TranscriptionCliInput,
        transcription_args: list[str],
        report_dir: Path,
        capture_combined_output: bool = False,
    ) -> TranscriptionCliOutput:
        """Run `argmax-cli transcribe` with a pre-built flag list.

        Args:
            input: Audio path, optional language, and whether to keep the audio file.
            transcription_args: Extra CLI flags appended after ``--audio-path``.
            report_dir: Directory in which ``<audio stem>.json`` / ``.srt`` are expected.
            capture_combined_output: When True, return stdout followed by stderr in
                ``cli_combined_output``.

        Returns:
            The expected report paths (not checked for existence) and, optionally,
            the combined CLI output.

        Raises:
            RuntimeError: If the CLI exits with a non-zero status.
        """
        cmd = [
            self.cli_path,
            "transcribe",
            "--audio-path",
            str(input.audio_path),
            *transcription_args,
        ]
        if input.language:
            cmd.extend(["--language", input.language])

        logger.debug("Argmax OSS transcribe: %s", cmd)
        try:
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"argmax-cli transcribe failed: {e.stderr}") from e
        finally:
            # Clean up on failure too, so failed samples do not leak temp audio files.
            if not input.keep_audio:
                input.audio_path.unlink(missing_ok=True)

        json_report_path = report_dir / f"{input.audio_path.stem}.json"
        srt_report_path = report_dir / f"{input.audio_path.stem}.srt"

        cli_combined_output: str | None = None
        if capture_combined_output:
            cli_combined_output = (result.stdout or "") + (result.stderr or "")

        return TranscriptionCliOutput(
            json_report_path=json_report_path,
            srt_report_path=srt_report_path,
            cli_combined_output=cli_combined_output,
        )

    def diarize(self, input: DiarizeCliInput, diarize_args: list[str]) -> DiarizeCliOutput:
        """Run `argmax-cli diarize` with a pre-built flag list.

        Args:
            input: Audio path, target RTTM path, and whether to keep the audio file.
            diarize_args: Extra CLI flags appended after ``--rttm-path``.

        Returns:
            The RTTM path that was passed to the CLI.

        Raises:
            RuntimeError: If the CLI exits with a non-zero status.
        """
        input.rttm_path.parent.mkdir(parents=True, exist_ok=True)
        cmd = [
            self.cli_path,
            "diarize",
            "--audio-path",
            str(input.audio_path),
            "--rttm-path",
            str(input.rttm_path),
            *diarize_args,
        ]
        logger.debug("Argmax OSS diarize: %s", cmd)
        try:
            subprocess.run(cmd, check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"argmax-cli diarize failed: {e.stderr}") from e
        finally:
            # Clean up on failure too, so failed samples do not leak temp audio files.
            if not input.keep_audio:
                input.audio_path.unlink(missing_ok=True)

        return DiarizeCliOutput(rttm_path=input.rttm_path)
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2025 Argmax, Inc. All Rights Reserved.

"""Speaker diarization pipeline backed by the open-source `argmax-cli diarize` command."""

from __future__ import annotations

from pathlib import Path
from typing import Callable

from argmaxtools.utils import get_logger
from pydantic import BaseModel, Field

from ...dataset import DiarizationSample
from ...engine.argmax_oss_engine import (
    ArgmaxOpenSourceEngine,
    ArgmaxOpenSourceEngineConfig,
    DiarizeCliInput,
    DiarizeCliOutput,
)
from ...pipeline_prediction import DiarizationAnnotation
from ..base import Pipeline, PipelineType, register_pipeline
from .common import DiarizationOutput, DiarizationPipelineConfig


logger = get_logger(__name__)

__all__ = [
    "ArgmaxOpenSourceDiarizationConfig",
    "ArgmaxOpenSourceDiarizationPipeline",
    "ArgmaxOpenSourceDiarizePipelineInput",
]

TEMP_AUDIO_DIR = Path("audio_temp_argmax_oss")


# Pipeline configuration: engine location (cache / commit / prebuilt CLI) plus the
# options that are translated into `argmax-cli diarize` flags.
class ArgmaxOpenSourceDiarizationConfig(DiarizationPipelineConfig):
    cache_dir: str | None = Field(
        default=None,
        description="Cache directory for WhisperKit clone and CLI build. "
        "Defaults to ARGMAX_OSS_CACHE_DIR or ~/.cache/openbench/argmax-oss.",
    )
    commit_hash: str | None = Field(default=None, description="Optional git commit pin for the clone.")
    cli_path: str | None = Field(default=None, description="Prebuilt argmax-cli path; skips clone/build.")
    model_path: str | None = Field(default=None, description="--model-path (local diarization models).")
    model_repo: str | None = Field(
        default=None,
        description="--model-repo (HuggingFace). For auth, rely on the CLI / Hugging Face env (e.g. HF_TOKEN) rather than OpenBench config.",
    )
    download_model_path: str | None = Field(default=None, description="--download-model-path.")
    cluster_distance_threshold: float | None = Field(
        default=None,
        description="--cluster-distance-threshold (default on CLI is 0.6).",
    )
    use_exclusive_reconciliation: bool = Field(default=False, description="--use-exclusive-reconciliation.")
    disable_full_redundancy: bool = Field(default=False, description="--disable-full-redundancy.")
    verbose: bool = Field(default=False, description="--verbose.")
    num_speakers: int | None = Field(
        default=None,
        description="Optional static --num-speakers. Ignored when use_exact_num_speakers derives count from reference.",
    )

    def generate_diarize_cli_args(self) -> list[str]:
        """Translate the configured options into `argmax-cli diarize` flags.

        ``--num-speakers`` is not emitted here; the pipeline adds it per sample.
        """
        cli_args: list[str] = []

        # Options that carry a value; unset or empty values are left out.
        valued_options = (
            ("--model-path", self.model_path),
            ("--model-repo", self.model_repo),
            ("--download-model-path", self.download_model_path),
        )
        for flag, value in valued_options:
            if value:
                cli_args += [flag, value]

        # Explicit None test so that a threshold of 0.0 is still forwarded.
        threshold = self.cluster_distance_threshold
        if threshold is not None:
            cli_args += ["--cluster-distance-threshold", str(threshold)]

        # Boolean switches, in the same order as the fields above.
        switches = (
            ("--use-exclusive-reconciliation", self.use_exclusive_reconciliation),
            ("--disable-full-redundancy", self.disable_full_redundancy),
            ("--verbose", self.verbose),
        )
        for flag, enabled in switches:
            if enabled:
                cli_args.append(flag)

        return cli_args


# Per-sample input handed to the callable returned by `build_pipeline`.
class ArgmaxOpenSourceDiarizePipelineInput(BaseModel):
    audio_path: Path
    rttm_path: Path
    num_speakers: int | None = None


@register_pipeline
class ArgmaxOpenSourceDiarizationPipeline(Pipeline):
    """Diarization pipeline that shells out to `argmax-cli diarize` and parses the RTTM."""

    _config_class = ArgmaxOpenSourceDiarizationConfig
    pipeline_type = PipelineType.DIARIZATION

    def build_pipeline(self) -> Callable[[ArgmaxOpenSourceDiarizePipelineInput], DiarizeCliOutput]:
        """Create the engine once and return a per-sample diarization callable."""
        cfg = self.config
        engine_config = ArgmaxOpenSourceEngineConfig(
            cache_dir=cfg.cache_dir,
            commit_hash=cfg.commit_hash,
            cli_path=cfg.cli_path,
        )
        cli_engine = ArgmaxOpenSourceEngine(engine_config)
        static_args = cfg.generate_diarize_cli_args()

        def run(inp: ArgmaxOpenSourceDiarizePipelineInput) -> DiarizeCliOutput:
            # The speaker count varies per sample, so it is appended to a fresh list.
            if inp.num_speakers is None:
                speaker_args: list[str] = []
            else:
                speaker_args = ["--num-speakers", str(inp.num_speakers)]
            cli_input = DiarizeCliInput(
                audio_path=inp.audio_path,
                rttm_path=inp.rttm_path,
                keep_audio=False,
            )
            return cli_engine.diarize(cli_input, [*static_args, *speaker_args])

        return run

    def parse_input(self, input_sample: DiarizationSample) -> ArgmaxOpenSourceDiarizePipelineInput:
        """Save the sample audio and choose the RTTM path and speaker count for it."""
        saved_audio = input_sample.save_audio(TEMP_AUDIO_DIR)
        if self.config.use_exact_num_speakers:
            # Number of distinct speakers in the reference annotation.
            speaker_count: int | None = len(set(input_sample.annotation.speakers))
        else:
            speaker_count = self.config.num_speakers
        return ArgmaxOpenSourceDiarizePipelineInput(
            audio_path=saved_audio,
            # The RTTM is written next to the saved audio, same stem.
            rttm_path=saved_audio.with_suffix(".rttm"),
            num_speakers=speaker_count,
        )

    def parse_output(self, output: DiarizeCliOutput) -> DiarizationOutput:
        """Load the RTTM referenced by ``output`` into a diarization prediction."""
        rttm_file = str(output.rttm_path)
        return DiarizationOutput(prediction=DiarizationAnnotation.load_annotation_file(rttm_file))
from .nemo import NeMoMTParakeetPipeline, NeMoMTParakeetPipelineConfig +from .orchestration_argmax_oss import ( + ArgmaxOpenSourceOrchestrationConfig, + ArgmaxOpenSourceOrchestrationPipeline, +) from .orchestration_deepgram import DeepgramOrchestrationPipeline, DeepgramOrchestrationPipelineConfig from .orchestration_elevenlabs import ElevenLabsOrchestrationPipeline, ElevenLabsOrchestrationPipelineConfig from .orchestration_openai import OpenAIOrchestrationPipeline, OpenAIOrchestrationPipelineConfig @@ -11,6 +15,8 @@ __all__ = [ + "ArgmaxOpenSourceOrchestrationPipeline", + "ArgmaxOpenSourceOrchestrationConfig", "DeepgramOrchestrationPipeline", "DeepgramOrchestrationPipelineConfig", "ElevenLabsOrchestrationPipeline", diff --git a/src/openbench/pipeline/orchestration/orchestration_argmax_oss.py b/src/openbench/pipeline/orchestration/orchestration_argmax_oss.py new file mode 100644 index 0000000..ea468fe --- /dev/null +++ b/src/openbench/pipeline/orchestration/orchestration_argmax_oss.py @@ -0,0 +1,180 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. 
"""Orchestration through `argmax-cli transcribe --diarization`.

The prediction is taken from the RTTM-like block of the CLI's verbose log,
which yields word + speaker pairs without timestamps.
"""

from __future__ import annotations

from pathlib import Path
from typing import Callable

from argmaxtools.utils import get_logger
from pydantic import BaseModel, Field

from ...dataset import OrchestrationSample
from ...engine.argmax_oss_engine import (
    ArgmaxOpenSourceEngine,
    ArgmaxOpenSourceEngineConfig,
    TranscriptionCliInput,
)
from ...pipeline_prediction import Transcript, Word
from ..base import Pipeline, PipelineType, register_pipeline
from .common import OrchestrationConfig, OrchestrationOutput


logger = get_logger(__name__)

TEMP_AUDIO_DIR = Path("./temp_audio_argmax_oss_orch")
ARGMAX_OSS_ORCH_DEFAULT_REPORT = "./argmax_oss_orchestration_reports"

SPEAKER_DIARIZATION_MARKER = "---- Speaker Diarization Results ----"
_MIN_SPEAKER_LINE_PARTS = 9


def _slice_rttm_like_block(cli_log: str) -> str:
    """Return the stripped log text after the first diarization marker, or "" if absent."""
    _, found_marker, tail = cli_log.partition(SPEAKER_DIARIZATION_MARKER)
    if found_marker:
        return tail.strip()
    logger.warning(
        "Diarization marker %r not found in CLI output; using empty RTTM-like block",
        SPEAKER_DIARIZATION_MARKER,
    )
    return ""


def _words_from_rttm_like_text(text: str) -> list[Word]:
    """Build speaker-attributed words from the ``SPEAKER`` rows of ``text``.

    Rows with fewer than ``_MIN_SPEAKER_LINE_PARTS`` whitespace-separated fields,
    or not starting with ``SPEAKER``, are ignored. Within a row the speaker label
    is the third field from the end and the words are fields ``5`` up to (not
    including) the fourth from the end.
    """
    collected: list[Word] = []
    for raw_line in text.splitlines():
        fields = raw_line.split()
        is_speaker_row = len(fields) >= _MIN_SPEAKER_LINE_PARTS and fields[0] == "SPEAKER"
        if not is_speaker_row:
            continue
        speaker_label = fields[-3]
        for token in fields[5:-4]:
            collected.append(Word(word=token, speaker=speaker_label, start=None, end=None))
    return collected


# Pipeline configuration: engine location plus the options that become
# `argmax-cli transcribe` flags (transcription and diarization).
class ArgmaxOpenSourceOrchestrationConfig(OrchestrationConfig):
    cache_dir: str | None = Field(
        default=None,
        description="Cache directory for WhisperKit clone / argmax-cli build.",
    )
    commit_hash: str | None = Field(default=None, description="Optional git commit pin.")
    cli_path: str | None = Field(default=None, description="Prebuilt argmax-cli binary.")
    model_version: str = Field(default="base", description="--model for transcribe.")
    model_prefix: str = Field(default="openai", description="--model-prefix.")
    word_timestamps: bool = Field(
        default=False,
        description="--word-timestamps on transcribe (optional; affects JSON/SRT written under report_path only).",
    )
    chunking_strategy: str | None = Field(default="vad")
    report_path: str | None = Field(
        default=ARGMAX_OSS_ORCH_DEFAULT_REPORT,
        description="Report directory for JSON/SRT (--report-path).",
    )
    prompt: str | None = None
    text_decoder_compute_units: str = Field(default="cpuAndNeuralEngine")
    audio_encoder_compute_units: str = Field(default="cpuAndNeuralEngine")
    diarization_model_path: str | None = Field(
        default=None,
        description="Optional --diarization-model-path for transcribe.",
    )
    diarization_model_repo: str | None = Field(
        default=None,
        description="Optional --diarization-model-repo for transcribe.",
    )

    def create_report_path(self) -> Path:
        """Create (if needed) and return the absolute report directory.

        An unset or empty ``report_path`` falls back to the current directory.
        """
        report_dir = Path(self.report_path or ".").resolve()
        report_dir.mkdir(parents=True, exist_ok=True)
        return report_dir

    def generate_transcribe_cli_args(self, report_dir: Path) -> list[str]:
        """Return the transcription flags for `argmax-cli transcribe`.

        Args:
            report_dir: Directory passed to ``--report-path``.
        """
        # An unset chunking strategy is sent as "vad".
        strategy = str(self.chunking_strategy or "vad")
        model_flags = [
            "--model",
            self.model_version,
            "--model-prefix",
            self.model_prefix,
            "--report",
            "--chunking-strategy",
            strategy,
        ]
        timestamp_flags = ["--word-timestamps"] if self.word_timestamps else []
        runtime_flags = [
            "--report-path",
            str(report_dir),
            "--text-decoder-compute-units",
            self.text_decoder_compute_units,
            "--audio-encoder-compute-units",
            self.audio_encoder_compute_units,
        ]
        prompt_flags = ["--prompt", self.prompt] if self.prompt else []
        return [*model_flags, *timestamp_flags, *runtime_flags, *prompt_flags]

    def generate_transcribe_diarization_args(self) -> list[str]:
        """Return the diarization flags; ``--diarization --verbose`` are always included."""
        flags: list[str] = ["--diarization", "--verbose"]
        optional_pairs = (
            ("--diarization-model-path", self.diarization_model_path),
            ("--diarization-model-repo", self.diarization_model_repo),
        )
        for flag, value in optional_pairs:
            if value:
                flags += [flag, value]
        return flags


# Per-sample input handed to the callable returned by `build_pipeline`.
class ArgmaxOpenSourceOrchestrationPipelineInput(BaseModel):
    audio_path: Path
    language: str | None = None


@register_pipeline
class ArgmaxOpenSourceOrchestrationPipeline(Pipeline):
    """Orchestration pipeline that reads word/speaker pairs from the CLI's verbose log."""

    _config_class = ArgmaxOpenSourceOrchestrationConfig
    pipeline_type = PipelineType.ORCHESTRATION

    def build_pipeline(
        self,
    ) -> Callable[[ArgmaxOpenSourceOrchestrationPipelineInput], tuple[str, Path]]:
        """Create the engine once and return a per-sample callable.

        The callable returns ``(rttm_like_text, json_report_path)``.
        """
        cfg = self.config
        cli_engine = ArgmaxOpenSourceEngine(
            ArgmaxOpenSourceEngineConfig(
                cache_dir=cfg.cache_dir,
                commit_hash=cfg.commit_hash,
                cli_path=cfg.cli_path,
            )
        )
        report_dir = cfg.create_report_path()
        cli_flags = [
            *cfg.generate_transcribe_cli_args(report_dir),
            *cfg.generate_transcribe_diarization_args(),
        ]

        def run(inp: ArgmaxOpenSourceOrchestrationPipelineInput) -> tuple[str, Path]:
            cli_input = TranscriptionCliInput(
                audio_path=inp.audio_path,
                keep_audio=False,
                language=inp.language,
            )
            # The combined stdout+stderr is requested because the diarization block is read from it.
            result = cli_engine.transcribe(
                cli_input,
                cli_flags,
                report_dir,
                capture_combined_output=True,
            )
            rttm_like = _slice_rttm_like_block(result.cli_combined_output or "")
            return (rttm_like, result.json_report_path)

        return run

    def parse_input(self, input_sample: OrchestrationSample) -> ArgmaxOpenSourceOrchestrationPipelineInput:
        """Save the sample audio; forward its language only when ``force_language`` is set."""
        forced_language = input_sample.language if self.config.force_language else None
        return ArgmaxOpenSourceOrchestrationPipelineInput(
            audio_path=input_sample.save_audio(TEMP_AUDIO_DIR),
            language=forced_language,
        )

    def parse_output(self, output: tuple[str, Path]) -> OrchestrationOutput:
        """Turn the RTTM-like text into a transcript; the JSON report path is not used."""
        rttm_like_text = output[0]
        return OrchestrationOutput(
            prediction=Transcript(words=_words_from_rttm_like_text(rttm_like_text)),
            diarization_output=None,
            transcription_output=None,
        )
a/src/openbench/pipeline/pipeline_aliases.py +++ b/src/openbench/pipeline/pipeline_aliases.py @@ -6,6 +6,7 @@ import os from .diarization import ( + ArgmaxOpenSourceDiarizationPipeline, AWSTranscribePipeline, DeepgramDiarizationPipeline, ElevenLabsDiarizationPipeline, @@ -16,6 +17,7 @@ SpeakerKitPipeline, ) from .orchestration import ( + ArgmaxOpenSourceOrchestrationPipeline, DeepgramOrchestrationPipeline, ElevenLabsOrchestrationPipeline, NeMoMTParakeetPipeline, @@ -33,6 +35,7 @@ OpenAIStreamingPipeline, ) from .transcription import ( + ArgmaxOpenSourceTranscriptionPipeline, AssemblyAITranscriptionPipeline, DeepgramTranscriptionPipeline, ElevenLabsTranscriptionPipeline, @@ -42,7 +45,6 @@ PyannoteTranscriptionPipeline, SpeechAnalyzerPipeline, WhisperKitProTranscriptionPipeline, - WhisperKitTranscriptionPipeline, WhisperOSSTranscriptionPipeline, ) @@ -117,6 +119,19 @@ def register_pipeline_aliases() -> None: description="SpeakerKit speaker diarization pipeline. Requires CLI installation and API key. Set `SPEAKERKIT_CLI_PATH` and `SPEAKERKIT_API_KEY` env vars. For access to the CLI binary contact speakerkitpro@argmaxinc.com", ) + PipelineRegistry.register_alias( + "argmax-oss-diarization", + ArgmaxOpenSourceDiarizationPipeline, + default_config={ + "out_dir": "./argmax_oss_diarization_reports", + }, + description=( + "Argmax SDK open-source diarization via `argmax-cli diarize`. " + "Clone/build under ARGMAX_OSS_CACHE_DIR (default ~/.cache/openbench/argmax-oss) unless `cli_path` is set." + "Uses pyannote's community-1 model for diarization." + ), + ) + PipelineRegistry.register_alias( "picovoice-diarization", PicovoicePipeline, @@ -339,39 +354,59 @@ def register_pipeline_aliases() -> None: description="PyannoteAI orchestration pipeline (diarization + transcription). Uses the precision-2 model with Nvidia Parakeet STT. 
Requires `PYANNOTE_TOKEN` env var from https://www.pyannote.ai/.", ) + PipelineRegistry.register_alias( + "argmax-oss-orchestration-tiny", + ArgmaxOpenSourceOrchestrationPipeline, + default_config={ + "out_dir": "./argmax_oss_orchestration_reports", + "model_version": "tiny", + "word_timestamps": False, + "chunking_strategy": "vad", + }, + description="Argmax SDK (OSS): `argmax-cli transcribe --diarization` with verbose RTTM-like diarization log " + "as the transcript (word + speaker per token). Cache: ARGMAX_OSS_CACHE_DIR or default; optional `cli_path`.", + ) + ################# TRANSCRIPTION PIPELINES ################# PipelineRegistry.register_alias( "whisperkit-tiny", - WhisperKitTranscriptionPipeline, + ArgmaxOpenSourceTranscriptionPipeline, default_config={ "model_version": "tiny", "word_timestamps": True, "chunking_strategy": "vad", }, - description="WhisperKit transcription pipeline (open-source version) using the tiny version of the model. Requires Swift and Xcode installed.", + description="Argmax SDK (open source) transcription via `argmax-cli` (Swift release build, not debug). Model tiny. " + "Cache: ARGMAX_OSS_CACHE_DIR or ~/.cache/openbench/argmax-oss unless `cli_path` is set. " + "For `openbench-cli evaluate`, prefer `-d earnings22-3hours` over `librispeech-200`: longer clips amortize " + "first-run model load so speed factor is less misleading than on very short utterances.", ) PipelineRegistry.register_alias( "whisperkit-large-v3", - WhisperKitTranscriptionPipeline, + ArgmaxOpenSourceTranscriptionPipeline, default_config={ "model_version": "large-v3", "word_timestamps": True, "chunking_strategy": "vad", }, - description="WhisperKit transcription pipeline (open-source version) using the large-v3 version of the model. Requires Swift and Xcode installed.", + description="Argmax SDK (open source) transcription via `argmax-cli` (release build). Model large-v3. " + "Cache: ARGMAX_OSS_CACHE_DIR or default. 
For transcription benchmarks use `-d earnings22-3hours`; " + "`librispeech-200` is better reserved for quick WER smoke tests.", ) PipelineRegistry.register_alias( "whisperkit-large-v3-turbo", - WhisperKitTranscriptionPipeline, + ArgmaxOpenSourceTranscriptionPipeline, default_config={ "model_version": "large-v3-v20240930", "word_timestamps": True, "chunking_strategy": "vad", }, - description="WhisperKit transcription pipeline (open-source version) using the large-v3-v20240930 version of the model (which is the same as large-v3-turbo from OpenAI). Requires Swift and Xcode installed.", + description="Argmax SDK (open source) transcription via `argmax-cli` (release build). Model large-v3-v20240930. " + "Cache: ARGMAX_OSS_CACHE_DIR or default. Prefer `-d earnings22-3hours` for evaluate; short LibriSpeech " + "clips skew speed factor because of fixed startup cost per file.", ) PipelineRegistry.register_alias( diff --git a/src/openbench/pipeline/transcription/__init__.py b/src/openbench/pipeline/transcription/__init__.py index f465293..ce7b34a 100644 --- a/src/openbench/pipeline/transcription/__init__.py +++ b/src/openbench/pipeline/transcription/__init__.py @@ -3,6 +3,10 @@ from .apple_speech_analyzer import SpeechAnalyzerConfig, SpeechAnalyzerPipeline from .common import TranscriptionOutput +from .transcription_argmax_oss import ( + ArgmaxOpenSourceTranscriptionConfig, + ArgmaxOpenSourceTranscriptionPipeline, +) from .transcription_assemblyai import AssemblyAITranscriptionPipeline, AssemblyAITranscriptionPipelineConfig from .transcription_deepgram import DeepgramTranscriptionPipeline, DeepgramTranscriptionPipelineConfig from .transcription_elevenlabs import ElevenLabsTranscriptionPipeline, ElevenLabsTranscriptionPipelineConfig @@ -12,7 +16,6 @@ from .transcription_oss_whisper import WhisperOSSTranscriptionPipeline, WhisperOSSTranscriptionPipelineConfig from .transcription_pyannote import PyannoteTranscriptionPipeline, PyannoteTranscriptionPipelineConfig from 
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2025 Argmax, Inc. All Rights Reserved.

"""Transcription pipeline backed by the open-source `argmax-cli transcribe` command."""

import json
from pathlib import Path
from typing import Callable

from argmaxtools.utils import get_logger
from pydantic import Field

from ...dataset import TranscriptionSample
from ...engine.argmax_oss_engine import (
    ArgmaxOpenSourceEngine,
    ArgmaxOpenSourceEngineConfig,
    TranscriptionCliInput,
    TranscriptionCliOutput,
)
from ...pipeline_prediction import Transcript
from ..base import Pipeline, PipelineType, register_pipeline
from .common import TranscriptionConfig, TranscriptionOutput


logger = get_logger(__name__)

TEMP_AUDIO_DIR = Path("./temp_audio")
ARGMAX_OSS_DEFAULT_REPORT_PATH = "./argmax_oss_transcription_reports"


class ArgmaxOpenSourceTranscriptionConfig(TranscriptionConfig):
    """Configuration for Argmax SDK open-source CLI (`argmax-cli`) transcription."""

    cache_dir: str | None = Field(
        default=None,
        description="Cache directory for WhisperKit clone and CLI build. "
        "Defaults to ARGMAX_OSS_CACHE_DIR or ~/.cache/openbench/argmax-oss.",
    )
    commit_hash: str | None = Field(
        default=None,
        description="Optional git commit to checkout before building argmax-cli.",
    )
    cli_path: str | None = Field(
        default=None,
        description="If set, use this argmax-cli binary instead of clone/build.",
    )
    model_version: str = Field(
        default="base",
        description="Passed as --model (e.g. tiny, base, small, large-v3, large-v3-v20240930).",
    )
    model_prefix: str = Field(
        default="openai",
        description="Passed as --model-prefix.",
    )
    word_timestamps: bool = Field(
        default=True,
        description="Whether to request --word-timestamps.",
    )
    chunking_strategy: str | None = Field(
        default="vad",
        description="Chunking strategy: none or vad.",
    )
    report_path: str | None = Field(
        default=ARGMAX_OSS_DEFAULT_REPORT_PATH,
        description="Directory for JSON/SRT reports (--report-path).",
    )
    prompt: str | None = Field(
        default=None,
        description="Optional --prompt for decoding.",
    )
    text_decoder_compute_units: str = Field(
        default="cpuAndNeuralEngine",
        description="--text-decoder-compute-units",
    )
    audio_encoder_compute_units: str = Field(
        default="cpuAndNeuralEngine",
        description="--audio-encoder-compute-units",
    )

    def create_report_path(self) -> Path:
        """Create (if needed) and return the absolute report directory.

        Returns the current working directory when ``report_path`` is ``None``.
        """
        if self.report_path is None:
            return Path.cwd()
        report_dir = Path(self.report_path)
        report_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Argmax OSS transcription report dir: %s", report_dir)
        return report_dir.resolve()

    def generate_cli_args(self) -> list[str]:
        """Translate the configured options into `argmax-cli transcribe` flags.

        ``--report`` is always passed; ``--report-path`` is passed as an absolute
        path and only when ``report_path`` is set.
        """
        args = [
            "--model",
            self.model_version,
            "--model-prefix",
            self.model_prefix,
            "--report",
        ]
        if self.chunking_strategy:
            args.extend(["--chunking-strategy", self.chunking_strategy])
        if self.word_timestamps:
            args.append("--word-timestamps")
        if self.report_path:
            args.extend(["--report-path", str(Path(self.report_path).resolve())])
        if self.prompt:
            args.extend(["--prompt", self.prompt])
        args.extend(
            [
                "--text-decoder-compute-units",
                self.text_decoder_compute_units,
                "--audio-encoder-compute-units",
                self.audio_encoder_compute_units,
            ]
        )
        logger.info("Argmax OSS transcribe CLI args: %s", args)
        return args


@register_pipeline
class ArgmaxOpenSourceTranscriptionPipeline(Pipeline):
    """Transcription pipeline that runs `argmax-cli transcribe` and parses its JSON report."""

    _config_class = ArgmaxOpenSourceTranscriptionConfig
    pipeline_type = PipelineType.TRANSCRIPTION

    def build_pipeline(self) -> Callable[[TranscriptionCliInput], TranscriptionCliOutput]:
        """Create the engine once and return a per-sample transcription callable."""
        engine = ArgmaxOpenSourceEngine(
            ArgmaxOpenSourceEngineConfig(
                cache_dir=self.config.cache_dir,
                commit_hash=self.config.commit_hash,
                cli_path=self.config.cli_path,
            )
        )
        transcription_args = self.config.generate_cli_args()
        report_dir = self.config.create_report_path()

        def transcribe(inp: TranscriptionCliInput) -> TranscriptionCliOutput:
            return engine.transcribe(inp, transcription_args, report_dir)

        return transcribe

    def parse_input(self, input_sample: TranscriptionSample) -> TranscriptionCliInput:
        """Save the sample audio; forward its language only when ``force_language`` is set."""
        language = None
        if self.config.force_language:
            language = input_sample.language

        return TranscriptionCliInput(
            audio_path=input_sample.save_audio(TEMP_AUDIO_DIR),
            keep_audio=False,
            language=language,
        )

    def parse_output(self, output: TranscriptionCliOutput) -> TranscriptionOutput:
        """Build the predicted transcript from the word entries of the JSON report.

        Only ``segments[*].words[*]`` is read; a word without ``start`` / ``end``
        gets ``None`` for that timestamp.

        Raises:
            FileNotFoundError: If the JSON report does not exist.
            KeyError: If the report has no ``segments`` key or a word has no ``word`` key.
        """
        # JSON is UTF-8; do not depend on the locale's default encoding for transcripts.
        with output.json_report_path.open(encoding="utf-8") as f:
            data = json.load(f)

        words: list[str] = []
        start: list[float | None] = []
        end: list[float | None] = []
        for segment in data["segments"]:
            for word in segment.get("words", []):
                words.append(word["word"])
                start.append(word.get("start"))
                end.append(word.get("end"))

        if not words:
            # Segments without a "words" list contribute nothing, so make an empty
            # prediction visible instead of letting it pass silently.
            logger.warning(
                "No word entries found in %s (word_timestamps=%s); predicted transcript is empty",
                output.json_report_path,
                self.config.word_timestamps,
            )

        return TranscriptionOutput(
            prediction=Transcript.from_words_info(words=words, start=start, end=end),
        )
mode 100644 index 2e0f444..0000000 --- a/src/openbench/pipeline/transcription/whisperkit.py +++ /dev/null @@ -1,284 +0,0 @@ -# For licensing see accompanying LICENSE.md file. -# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. - -import json -import os -import subprocess -from pathlib import Path -from typing import Callable - -from argmaxtools.utils import _maybe_git_clone, get_logger -from pydantic import BaseModel, Field - -from ...dataset import TranscriptionSample -from ...pipeline_prediction import Transcript -from ..base import Pipeline, PipelineType, register_pipeline -from .common import TranscriptionConfig, TranscriptionOutput - - -logger = get_logger(__name__) - -# Constants -WHISPERKIT_REPO_URL = "https://github.com/argmaxinc/WhisperKit" -PRODUCT_NAME = "whisperkit-cli" -TEMP_AUDIO_DIR = Path("./temp_audio") -WHISPERKIT_DEFAULT_REPORT_PATH = "./whisperkit_report" - - -class WhisperKitTranscriptionConfig(TranscriptionConfig): - """Configuration for WhisperKit transcription operations.""" - - model_version: str = Field( - default="base", - description="The version of the WhisperKit model to use (e.g., 'tiny', 'base', 'small', 'large-v3')", - ) - word_timestamps: bool = Field( - default=True, - description="Whether to include word timestamps in the output", - ) - chunking_strategy: str | None = Field( - default="vad", - description="The chunking strategy to use either `none` or `vad`", - ) - report_path: str | None = Field( - default=WHISPERKIT_DEFAULT_REPORT_PATH, - description="The path to the directory where the report files will be saved. 
If not provided, the report files will be saved in the current working directory.", - ) - prompt: str | None = Field( - default=None, - description="Initial prompt for transcription", - ) - text_decoder_compute_units: str = Field( - default="cpuAndNeuralEngine", - description="Compute units for text decoder", - ) - audio_encoder_compute_units: str = Field( - default="cpuAndNeuralEngine", - description="Compute units for audio encoder", - ) - - def create_report_path(self) -> Path: - if self.report_path is None: - return Path.cwd() - - report_dir = Path(self.report_path) - - if report_dir.exists(): - logger.info(f"Report dir already exists for WhisperKit at: {report_dir}") - return report_dir - - report_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"Created report dir for WhisperKit at: {report_dir}") - return report_dir - - def generate_cli_args(self) -> list[str]: - args = [ - "--model", - self.model_version, - "--report", # Always generate the report files - ] - - if self.chunking_strategy: - args.extend(["--chunking-strategy", self.chunking_strategy]) - if self.word_timestamps: - args.append("--word-timestamps") - if self.report_path: - args.extend(["--report-path", self.report_path]) - if self.prompt: - args.extend(["--prompt", f'"{self.prompt}"']) - if self.text_decoder_compute_units: - args.extend(["--text-decoder-compute-units", self.text_decoder_compute_units]) - if self.audio_encoder_compute_units: - args.extend(["--audio-encoder-compute-units", self.audio_encoder_compute_units]) - - logger.info(f"Generating CLI args for WhisperKit: {args}") - return args - - -class WhisperKitEngineConfig(BaseModel): - """Base configuration for WhisperKit operations.""" - - commit_hash: str | None = Field( - default=None, - description="The commit hash of the WhisperKit repo when cloning", - ) - cli_path: str | None = Field( - default=None, - description="The path to the WhisperKit CLI", - ) - clone_dir: str = Field( - default="./whisperkit_repo", - 
description="Directory to clone and build the CLI", - ) - - -class TranscriptionCliInput(BaseModel): - """Input for transcription CLI.""" - - audio_path: Path - keep_audio: bool = False - language: str | None = None - - -class TranscriptionCliOutput(BaseModel): - """Output for transcription CLI.""" - - json_report_path: Path = Field( - ..., - description="Path to the JSON report with transcription results", - ) - srt_report_path: Path = Field( - ..., - description="Path to the .srt file containing transcription results", - ) - - -class WhisperKitEngine: - """Unified CLI interface for WhisperKit operations.""" - - def __init__( - self, - config: WhisperKitEngineConfig, - transcription_config: WhisperKitTranscriptionConfig, - ): - self.config = config - self.cli_path = config.cli_path or self._clone_and_build_cli() - self.transcription_config = transcription_config - self.transcription_args = self.transcription_config.generate_cli_args() - self.transcription_config.create_report_path() - - def _clone_and_build_cli(self) -> str: - """Clone the repository and build the CLI.""" - os.makedirs(self.config.clone_dir, exist_ok=True) - if not WHISPERKIT_REPO_URL: - raise ValueError("Repository URL is not set") - - logger.info(f"Cloning repo {WHISPERKIT_REPO_URL} into {self.config.clone_dir}") - repo_name = WHISPERKIT_REPO_URL.split("/")[-1] - repo_owner = WHISPERKIT_REPO_URL.split("/")[-2] - - repo_dir, commit_hash = _maybe_git_clone( - out_dir=self.config.clone_dir, - hub_url="github.com", - repo_name=repo_name, - repo_owner=repo_owner, - commit_hash=self.config.commit_hash, - ) - logger.info(f"{repo_name} -> Commit hash: {commit_hash}") - - try: - build_dir = self._build_cli(repo_dir) - cli_path = os.path.join(build_dir, PRODUCT_NAME) - self.config.commit_hash = commit_hash - return cli_path - except subprocess.CalledProcessError as e: - logger.error(f"Build failed with return code {e.returncode}") - logger.error(f"Build stdout:\n{e.stdout}") - logger.error(f"Build 
stderr:\n{e.stderr}") - raise RuntimeError( - f"Failed to build CLI: Exit code {e.returncode}\nStdout: {e.stdout}\nStderr: {e.stderr}" - ) - - def _build_cli(self, repo_dir: str) -> str: - """Build the CLI and return the build directory path.""" - logger.info(f"Building {PRODUCT_NAME} CLI...") - - build_cmd = f"swift build -c release --product {PRODUCT_NAME}" - - subprocess.run( - build_cmd, - cwd=repo_dir, - shell=True, - check=True, - ) - logger.info(f"Successfully built {PRODUCT_NAME} CLI!") - - result = subprocess.run( - f"{build_cmd} --show-bin-path", - cwd=repo_dir, - stdout=subprocess.PIPE, - shell=True, - text=True, - check=True, - ) - return result.stdout.strip() - - def transcribe(self, input: TranscriptionCliInput) -> TranscriptionCliOutput: - """Run transcription on the given audio file.""" - cmd = [ - self.cli_path, - "transcribe", - "--audio-path", - str(input.audio_path), - *self.transcription_args, - ] - if input.language: - cmd.extend(["--language", input.language]) - - logger.debug(f"Running WhisperKit CLI: {cmd}") - - report_dir = self.transcription_config.create_report_path() - if not report_dir: - raise ValueError("Report directory not configured") - - try: - subprocess.run(cmd, check=True, capture_output=True, text=True) - except subprocess.CalledProcessError as e: - raise RuntimeError(f"CLI command failed: {e.stderr}") - - if not input.keep_audio: - input.audio_path.unlink(missing_ok=True) - - json_report_path = report_dir / input.audio_path.with_suffix(".json").name - srt_report_path = report_dir / input.audio_path.with_suffix(".srt").name - - return TranscriptionCliOutput( - json_report_path=json_report_path, - srt_report_path=srt_report_path, - ) - - -@register_pipeline -class WhisperKitTranscriptionPipeline(Pipeline): - _config_class = WhisperKitTranscriptionConfig - pipeline_type = PipelineType.TRANSCRIPTION - - def build_pipeline(self) -> Callable[[TranscriptionCliInput], TranscriptionCliOutput]: - # Create WhisperKit engine - 
engine_config = WhisperKitEngineConfig( - clone_dir="./whisperkit_repo", - ) - - engine = WhisperKitEngine( - config=engine_config, - transcription_config=self.config, - ) - - return engine.transcribe - - def parse_input(self, input_sample: TranscriptionSample) -> TranscriptionCliInput: - # Extract language if force_language is enabled - language = None - if self.config.force_language: - language = input_sample.language - - return TranscriptionCliInput( - audio_path=input_sample.save_audio(TEMP_AUDIO_DIR), - keep_audio=False, - language=language, - ) - - def parse_output(self, output: TranscriptionCliOutput) -> TranscriptionOutput: - """Parse JSON output file into TranscriptionOutput.""" - with output.json_report_path.open("r") as f: - data = json.load(f) - - transcript = Transcript.from_words_info( - words=[word["word"] for segment in data["segments"] for word in segment["words"]], - start=[word["start"] for segment in data["segments"] for word in segment["words"] if "start" in word], - end=[word["end"] for segment in data["segments"] for word in segment["words"] if "end" in word], - ) - - return TranscriptionOutput( - prediction=transcript, - )