Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/openbench/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
from .argmax_oss_engine import (
ArgmaxOpenSourceEngine,
ArgmaxOpenSourceEngineConfig,
DiarizeCliInput,
DiarizeCliOutput,
TranscriptionCliInput,
TranscriptionCliOutput,
resolve_argmax_oss_cache_dir,
)
from .deepgram_engine import DeepgramApi, DeepgramApiResponse
from .elevenlabs_engine import ElevenLabsApi, ElevenLabsApiResponse
from .openai_engine import OpenAIApi
Expand All @@ -18,6 +27,13 @@


__all__ = [
"ArgmaxOpenSourceEngine",
"ArgmaxOpenSourceEngineConfig",
"DiarizeCliInput",
"DiarizeCliOutput",
"TranscriptionCliInput",
"TranscriptionCliOutput",
"resolve_argmax_oss_cache_dir",
"DeepgramApi",
"DeepgramApiResponse",
"ElevenLabsApi",
Expand Down
203 changes: 203 additions & 0 deletions src/openbench/engine/argmax_oss_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2025 Argmax, Inc. All Rights Reserved.

"""Argmax SDK open-source CLI (`argmax-cli`) — clone/build, transcribe, and diarize."""

from __future__ import annotations

import os
import subprocess
from pathlib import Path

from argmaxtools.utils import _maybe_git_clone, get_logger
from pydantic import BaseModel, Field


# Module-level logger obtained from argmaxtools.
logger = get_logger(__name__)

# Git repository the CLI is built from; the engine splits this URL and uses its
# last two path segments as the repository owner and name when cloning.
ARGMAX_OSS_REPO_URL = "https://github.com/argmaxinc/argmax-oss-swift"
# SwiftPM product handed to `swift build --product`; also the file name of the
# binary looked up in the build's bin directory.
ARGMAX_OSS_PRODUCT = "argmax-cli"
# Default cache location, joined onto the user's home directory when neither an
# explicit directory nor ARGMAX_OSS_CACHE_DIR is provided.
DEFAULT_CACHE_SUBDIR = Path(".cache") / "openbench" / "argmax-oss"


def resolve_argmax_oss_cache_dir(explicit: str | Path | None = None) -> Path:
"""Absolute cache root for WhisperKit clone + `argmax-cli` build."""
if explicit is not None and str(explicit).strip():
return Path(explicit).expanduser().resolve()
env = os.environ.get("ARGMAX_OSS_CACHE_DIR")
if env:
return Path(env).expanduser().resolve()
return (Path.home() / DEFAULT_CACHE_SUBDIR).resolve()


class ArgmaxOpenSourceEngineConfig(BaseModel):
    """Engine config: cache, optional commit pin, optional prebuilt CLI path.

    When ``cli_path`` is set the engine uses that binary directly and performs
    no clone or build; otherwise ``cache_dir`` and ``commit_hash`` control the
    clone the CLI is built from.
    """

    # Passed to resolve_argmax_oss_cache_dir(); None falls back to the
    # ARGMAX_OSS_CACHE_DIR environment variable, then the home-directory default.
    cache_dir: str | None = Field(
        default=None,
        description="Directory for cloned repo and build artifacts (absolute after resolve). "
        "Overrides ARGMAX_OSS_CACHE_DIR when set.",
    )
    # Forwarded to the clone helper. After cloning, the engine overwrites this
    # field with the commit hash that helper returns.
    commit_hash: str | None = Field(
        default=None,
        description="Git commit to checkout in the cached WhisperKit clone.",
    )
    # Checked first by the engine constructor; a truthy value bypasses clone/build.
    cli_path: str | None = Field(
        default=None,
        description="If set, skip clone/build and use this argmax-cli binary.",
    )


class TranscriptionCliInput(BaseModel):
    """Input for `argmax-cli transcribe`."""

    # Audio file passed to the CLI as `--audio-path`.
    audio_path: Path
    # When False, the engine deletes `audio_path` after a successful run.
    keep_audio: bool = False
    # Appended as `--language <value>` when truthy; otherwise the flag is omitted.
    language: str | None = None


class TranscriptionCliOutput(BaseModel):
    """Output paths from `argmax-cli transcribe --report`."""

    # The engine derives both report paths as `<report_dir>/<audio stem>.<ext>`;
    # it does not itself check that the CLI actually wrote them.
    json_report_path: Path = Field(..., description="JSON report path")
    srt_report_path: Path = Field(..., description="SRT report path")
    # stdout followed by stderr of the CLI process; None unless capture was requested.
    cli_combined_output: str | None = Field(
        default=None,
        description="Concatenated stdout+stderr when transcribe was run with capture_combined_output=True.",
    )


class DiarizeCliInput(BaseModel):
    """Input for `argmax-cli diarize`."""

    # Audio file passed to the CLI as `--audio-path`.
    audio_path: Path
    # Destination passed as `--rttm-path`; the engine creates its parent directory.
    rttm_path: Path
    # When False, the engine deletes `audio_path` after a successful run.
    keep_audio: bool = False


class DiarizeCliOutput(BaseModel):
    """Result of `argmax-cli diarize`: the RTTM path that was requested as output."""

    # Same path that was supplied as `DiarizeCliInput.rttm_path`.
    rttm_path: Path = Field(..., description="Written RTTM path")


class ArgmaxOpenSourceEngine:
    """Resolve `argmax-cli`, then run `transcribe` / `diarize` subcommands.

    The binary comes either from ``config.cli_path`` (used as-is) or from a
    cached clone of ``ARGMAX_OSS_REPO_URL`` that is built with
    ``swift build -c release`` when the engine is constructed.
    """

    def __init__(self, config: ArgmaxOpenSourceEngineConfig) -> None:
        """Store ``config`` and determine ``self.cli_path``.

        Raises:
            RuntimeError: If the CLI has to be built and the build fails.
        """
        self.config = config
        if config.cli_path:
            self.cli_path = str(Path(config.cli_path).expanduser().resolve())
            logger.info("Using Argmax OSS CLI at %s", self.cli_path)
        else:
            self.cli_path = self._clone_and_build_cli()

    def _build_cli(self, repo_dir: str) -> str:
        """Run the release build in ``repo_dir`` and return the directory containing the binary.

        Args:
            repo_dir: Working directory for ``swift build`` (the cloned repository).

        Returns:
            The bin directory reported by ``swift build ... --show-bin-path``.

        Raises:
            RuntimeError: If ``swift`` cannot be launched, the build exits
                non-zero, or the expected binary is missing afterwards.
        """
        logger.info("Building %s with: swift build -c release (not debug)", ARGMAX_OSS_PRODUCT)
        # Argument lists (no shell) avoid quoting and injection pitfalls.
        build_cmd = ["swift", "build", "-c", "release", "--product", ARGMAX_OSS_PRODUCT]
        try:
            # Build output is intentionally not captured so progress streams to the console.
            subprocess.run(build_cmd, cwd=repo_dir, check=True)
            # Second invocation only prints the bin directory on stdout.
            result = subprocess.run(
                [*build_cmd, "--show-bin-path"],
                cwd=repo_dir,
                stdout=subprocess.PIPE,
                text=True,
                check=True,
            )
        except FileNotFoundError as e:
            # Without a shell a missing executable surfaces as FileNotFoundError
            # rather than exit code 127; keep reporting it as a build failure.
            raise RuntimeError(
                f"Failed to build {ARGMAX_OSS_PRODUCT}: could not launch `swift build` in {repo_dir} "
                f"(is the Swift toolchain installed?): {e}"
            ) from e
        except subprocess.CalledProcessError as e:
            logger.error("Build failed with return code %s", e.returncode)
            # Only streams that were actually captured carry text; the others are None.
            captured = "\n".join(
                f"{label}: {text}"
                for label, text in (("stdout", e.stdout), ("stderr", e.stderr))
                if text
            )
            if captured:
                logger.error("Captured build output:\n%s", captured)
            detail = captured or "build output was not captured; see the console output above"
            raise RuntimeError(
                f"Failed to build {ARGMAX_OSS_PRODUCT}: exit {e.returncode}\n{detail}"
            ) from e
        bin_dir = result.stdout.strip()
        cli = Path(bin_dir) / ARGMAX_OSS_PRODUCT
        if not cli.is_file():
            raise RuntimeError(f"Expected CLI binary not found: {cli}")
        logger.info("Built Argmax OSS CLI at %s", cli)
        return bin_dir

    def _clone_and_build_cli(self) -> str:
        """Ensure the repository is present under the cache root, build it, and return the CLI path.

        Side effect: ``self.config.commit_hash`` is overwritten with the commit
        hash returned by the clone helper.
        """
        cache_root = resolve_argmax_oss_cache_dir(self.config.cache_dir)
        cache_root.mkdir(parents=True, exist_ok=True)
        # ".../<owner>/<name>" -> the last two URL segments.
        repo_url_parts = ARGMAX_OSS_REPO_URL.rstrip("/").split("/")
        repo_name = repo_url_parts[-1]
        repo_owner = repo_url_parts[-2]

        logger.info("Ensuring WhisperKit clone under %s", cache_root)
        repo_dir, commit_hash = _maybe_git_clone(
            out_dir=str(cache_root),
            hub_url="github.com",
            repo_name=repo_name,
            repo_owner=repo_owner,
            commit_hash=self.config.commit_hash,
        )
        self.config.commit_hash = commit_hash
        logger.info("%s at commit %s — running sanity build", repo_name, commit_hash)
        bin_dir = self._build_cli(repo_dir)
        return str(Path(bin_dir) / ARGMAX_OSS_PRODUCT)

    def _run_cli(self, subcommand: str, cmd: list[str]) -> subprocess.CompletedProcess[str]:
        """Run ``cmd`` with captured text output, turning a non-zero exit into RuntimeError."""
        logger.debug("Argmax OSS %s: %s", subcommand, cmd)
        try:
            return subprocess.run(cmd, check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"argmax-cli {subcommand} failed: {e.stderr}") from e

    def transcribe(
        self,
        input: TranscriptionCliInput,
        transcription_args: list[str],
        report_dir: Path,
        capture_combined_output: bool = False,
    ) -> TranscriptionCliOutput:
        """Run `argmax-cli transcribe` with pre-built flag list (see transcription config).

        Args:
            input: Audio path, optional language, and whether to keep the audio file.
            transcription_args: Extra CLI flags, inserted after ``--audio-path``.
            report_dir: Directory the report paths are derived from
                (``<report_dir>/<audio stem>.json`` / ``.srt``). The files are
                not checked for existence here.
            capture_combined_output: When True, return stdout+stderr of the CLI.

        Returns:
            The derived report paths and, optionally, the combined CLI output.

        Raises:
            RuntimeError: If the CLI exits with a non-zero status.
        """
        cmd = [
            self.cli_path,
            "transcribe",
            "--audio-path",
            str(input.audio_path),
            *transcription_args,
        ]
        if input.language:
            cmd.extend(["--language", input.language])

        result = self._run_cli("transcribe", cmd)

        json_report_path = report_dir / f"{input.audio_path.stem}.json"
        srt_report_path = report_dir / f"{input.audio_path.stem}.srt"

        cli_combined_output: str | None = None
        if capture_combined_output:
            cli_combined_output = (result.stdout or "") + (result.stderr or "")

        # The input audio is removed only after a successful run.
        if not input.keep_audio:
            input.audio_path.unlink(missing_ok=True)

        return TranscriptionCliOutput(
            json_report_path=json_report_path,
            srt_report_path=srt_report_path,
            cli_combined_output=cli_combined_output,
        )

    def diarize(self, input: DiarizeCliInput, diarize_args: list[str]) -> DiarizeCliOutput:
        """Run `argmax-cli diarize` with pre-built flag list (see diarization config).

        Args:
            input: Audio path, RTTM destination, and whether to keep the audio file.
            diarize_args: Extra CLI flags, appended after ``--rttm-path``.

        Returns:
            A ``DiarizeCliOutput`` carrying ``input.rttm_path``.

        Raises:
            RuntimeError: If the CLI exits with a non-zero status.
        """
        # Make sure the CLI can write the RTTM file.
        input.rttm_path.parent.mkdir(parents=True, exist_ok=True)
        cmd = [
            self.cli_path,
            "diarize",
            "--audio-path",
            str(input.audio_path),
            "--rttm-path",
            str(input.rttm_path),
            *diarize_args,
        ]
        self._run_cli("diarize", cmd)

        # The input audio is removed only after a successful run.
        if not input.keep_audio:
            input.audio_path.unlink(missing_ok=True)

        return DiarizeCliOutput(rttm_path=input.rttm_path)
1 change: 1 addition & 0 deletions src/openbench/pipeline/diarization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from .aws import *
from .common import *
from .diarization_argmax_oss import *
from .diarization_deepgram import *
from .elevenlabs import *
from .nemo import *
Expand Down
132 changes: 132 additions & 0 deletions src/openbench/pipeline/diarization/diarization_argmax_oss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2025 Argmax, Inc. All Rights Reserved.

"""Diarization via Argmax SDK open-source `argmax-cli diarize`."""

from __future__ import annotations

from pathlib import Path
from typing import Callable

from argmaxtools.utils import get_logger
from pydantic import BaseModel, Field

from ...dataset import DiarizationSample
from ...engine.argmax_oss_engine import (
ArgmaxOpenSourceEngine,
ArgmaxOpenSourceEngineConfig,
DiarizeCliInput,
DiarizeCliOutput,
)
from ...pipeline_prediction import DiarizationAnnotation
from ..base import Pipeline, PipelineType, register_pipeline
from .common import DiarizationOutput, DiarizationPipelineConfig


# Module-level logger obtained from argmaxtools.
logger = get_logger(__name__)

# Names exported by `from .diarization_argmax_oss import *`.
__all__ = [
    "ArgmaxOpenSourceDiarizationConfig",
    "ArgmaxOpenSourceDiarizationPipeline",
    "ArgmaxOpenSourceDiarizePipelineInput",
]

# Directory handed to `DiarizationSample.save_audio`; the RTTM output path is
# derived from the saved audio path, so it lands next to the audio file.
# NOTE(review): the path is relative — presumably resolved against the current
# working directory; confirm against `save_audio`.
TEMP_AUDIO_DIR = Path("audio_temp_argmax_oss")


class ArgmaxOpenSourceDiarizationConfig(DiarizationPipelineConfig):
    """Configuration for the `argmax-cli diarize` pipeline.

    ``cache_dir``, ``commit_hash`` and ``cli_path`` are forwarded to the engine
    config; ``num_speakers`` is consumed by the pipeline's input parsing; the
    remaining fields are translated into CLI flags by
    ``generate_diarize_cli_args``.
    """

    cache_dir: str | None = Field(
        default=None,
        description="Cache directory for WhisperKit clone and CLI build. "
        "Defaults to ARGMAX_OSS_CACHE_DIR or ~/.cache/openbench/argmax-oss.",
    )
    commit_hash: str | None = Field(
        default=None,
        description="Optional git commit pin for the clone.",
    )
    cli_path: str | None = Field(
        default=None,
        description="Prebuilt argmax-cli path; skips clone/build.",
    )
    model_path: str | None = Field(
        default=None,
        description="--model-path (local diarization models).",
    )
    model_repo: str | None = Field(
        default=None,
        description="--model-repo (HuggingFace). For auth, rely on the CLI / Hugging Face env "
        "(e.g. HF_TOKEN) rather than OpenBench config.",
    )
    download_model_path: str | None = Field(
        default=None,
        description="--download-model-path.",
    )
    cluster_distance_threshold: float | None = Field(
        default=None,
        description="--cluster-distance-threshold (default on CLI is 0.6).",
    )
    use_exclusive_reconciliation: bool = Field(
        default=False,
        description="--use-exclusive-reconciliation.",
    )
    disable_full_redundancy: bool = Field(
        default=False,
        description="--disable-full-redundancy.",
    )
    verbose: bool = Field(
        default=False,
        description="--verbose.",
    )
    num_speakers: int | None = Field(
        default=None,
        description="Optional static --num-speakers. "
        "Ignored when use_exact_num_speakers derives count from reference.",
    )

    def generate_diarize_cli_args(self) -> list[str]:
        """Translate the model/clustering/flag fields into an `argmax-cli diarize` argument list.

        String options are emitted only when truthy, the distance threshold
        whenever it is not ``None``, and boolean switches only when enabled.
        ``--num-speakers`` is not produced here.
        """
        cli_args: list[str] = []

        # Options that carry a string value; empty strings count as unset.
        string_options = (
            ("--model-path", self.model_path),
            ("--model-repo", self.model_repo),
            ("--download-model-path", self.download_model_path),
        )
        for flag, value in string_options:
            if value:
                cli_args += [flag, value]

        # 0.0 is a legitimate threshold, so compare against None explicitly.
        threshold = self.cluster_distance_threshold
        if threshold is not None:
            cli_args += ["--cluster-distance-threshold", str(threshold)]

        # Value-less switches, in the order the CLI arguments are emitted.
        switches = (
            ("--use-exclusive-reconciliation", self.use_exclusive_reconciliation),
            ("--disable-full-redundancy", self.disable_full_redundancy),
            ("--verbose", self.verbose),
        )
        cli_args += [flag for flag, enabled in switches if enabled]
        return cli_args


class ArgmaxOpenSourceDiarizePipelineInput(BaseModel):
    """Per-sample input produced by `parse_input` and consumed by the built pipeline callable."""

    # Audio file saved from the dataset sample.
    audio_path: Path
    # Where the CLI should write its RTTM output.
    rttm_path: Path
    # Passed as `--num-speakers` when not None; otherwise the flag is omitted.
    num_speakers: int | None = None


@register_pipeline
class ArgmaxOpenSourceDiarizationPipeline(Pipeline):
    """Diarization pipeline backed by the `argmax-cli diarize` subcommand."""

    _config_class = ArgmaxOpenSourceDiarizationConfig
    pipeline_type = PipelineType.DIARIZATION

    def build_pipeline(self) -> Callable[[ArgmaxOpenSourceDiarizePipelineInput], DiarizeCliOutput]:
        """Create the engine once and return a callable that diarizes one parsed input."""
        engine_config = ArgmaxOpenSourceEngineConfig(
            cache_dir=self.config.cache_dir,
            commit_hash=self.config.commit_hash,
            cli_path=self.config.cli_path,
        )
        cli_engine = ArgmaxOpenSourceEngine(engine_config)
        # Flags that do not vary between samples are computed a single time.
        static_args = self.config.generate_diarize_cli_args()

        def run(inp: ArgmaxOpenSourceDiarizePipelineInput) -> DiarizeCliOutput:
            # The speaker count is per-sample, so it is appended to a fresh list on each call.
            if inp.num_speakers is None:
                speaker_args = []
            else:
                speaker_args = ["--num-speakers", str(inp.num_speakers)]
            cli_input = DiarizeCliInput(
                audio_path=inp.audio_path,
                rttm_path=inp.rttm_path,
                keep_audio=False,
            )
            return cli_engine.diarize(cli_input, [*static_args, *speaker_args])

        return run

    def parse_input(self, input_sample: DiarizationSample) -> ArgmaxOpenSourceDiarizePipelineInput:
        """Save the sample's audio under ``TEMP_AUDIO_DIR`` and choose the speaker count.

        With ``use_exact_num_speakers`` enabled the count is the number of
        distinct speakers in the sample's reference annotation; otherwise the
        configured ``num_speakers`` (possibly ``None``) is used.
        """
        saved_audio = input_sample.save_audio(TEMP_AUDIO_DIR)
        if self.config.use_exact_num_speakers:
            speaker_count = len(set(input_sample.annotation.speakers))
        else:
            speaker_count = self.config.num_speakers
        return ArgmaxOpenSourceDiarizePipelineInput(
            audio_path=saved_audio,
            # RTTM output sits next to the saved audio, same stem.
            rttm_path=saved_audio.with_suffix(".rttm"),
            num_speakers=speaker_count,
        )

    def parse_output(self, output: DiarizeCliOutput) -> DiarizationOutput:
        """Load the RTTM file written by the CLI and wrap it as the pipeline prediction."""
        loaded = DiarizationAnnotation.load_annotation_file(str(output.rttm_path))
        return DiarizationOutput(prediction=loaded)
6 changes: 6 additions & 0 deletions src/openbench/pipeline/orchestration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
# Copyright (C) 2025 Argmax, Inc. All Rights Reserved.

from .nemo import NeMoMTParakeetPipeline, NeMoMTParakeetPipelineConfig
from .orchestration_argmax_oss import (
ArgmaxOpenSourceOrchestrationConfig,
ArgmaxOpenSourceOrchestrationPipeline,
)
from .orchestration_deepgram import DeepgramOrchestrationPipeline, DeepgramOrchestrationPipelineConfig
from .orchestration_elevenlabs import ElevenLabsOrchestrationPipeline, ElevenLabsOrchestrationPipelineConfig
from .orchestration_openai import OpenAIOrchestrationPipeline, OpenAIOrchestrationPipelineConfig
Expand All @@ -11,6 +15,8 @@


__all__ = [
"ArgmaxOpenSourceOrchestrationPipeline",
"ArgmaxOpenSourceOrchestrationConfig",
"DeepgramOrchestrationPipeline",
"DeepgramOrchestrationPipelineConfig",
"ElevenLabsOrchestrationPipeline",
Expand Down
Loading
Loading