From 1c7672a669f0d115a5a8522ba04509aaa19895ff Mon Sep 17 00:00:00 2001 From: Glen Beane <356266+gbeane@users.noreply.github.com> Date: Tue, 23 Jun 2026 14:04:21 -0400 Subject: [PATCH 1/7] Add optional MLflow logging to cross-validation CLI --- pyproject.toml | 1 + src/jabs/classifier/__init__.py | 8 + src/jabs/classifier/mlflow_logging.py | 294 ++++++++++++++++++++ src/jabs/scripts/cli/cli.py | 52 +++- src/jabs/scripts/cli/cross_validation.py | 42 +++ tests/classifier/test_mlflow_logging.py | 298 +++++++++++++++++++++ tests/scripts/test_cross_validation_cli.py | 73 +++++ uv.lock | 6 +- 8 files changed, 772 insertions(+), 2 deletions(-) create mode 100644 src/jabs/classifier/mlflow_logging.py create mode 100644 tests/classifier/test_mlflow_logging.py diff --git a/pyproject.toml b/pyproject.toml index 9a20b28b..394895dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ [project.optional-dependencies] nwb = ["jabs-io[nwb]"] yaml = ["pyyaml>=6.0.0"] +mlflow = ["mlflow>=3.8.1"] [tool.uv.sources] jabs-behavior = { workspace = true } diff --git a/src/jabs/classifier/__init__.py b/src/jabs/classifier/__init__.py index e4751575..039fffcf 100644 --- a/src/jabs/classifier/__init__.py +++ b/src/jabs/classifier/__init__.py @@ -7,6 +7,11 @@ from .classifier import Classifier from .cross_validation import run_leave_one_group_out_cv +from .mlflow_logging import ( + MlflowLoggingError, + log_cross_validation_to_mlflow, + parse_kv_tags, +) from .multi_class_classifier import MultiClassClassifier from .protocols import ClassifierProtocol from .training_report import ( @@ -23,10 +28,13 @@ "Classifier", "ClassifierProtocol", "CrossValidationResult", + "MlflowLoggingError", "MultiClassCVResult", "MultiClassClassifier", "TrainingReportData", "generate_markdown_report", + "log_cross_validation_to_mlflow", + "parse_kv_tags", "run_leave_one_group_out_cv", "save_training_report", ] diff --git a/src/jabs/classifier/mlflow_logging.py b/src/jabs/classifier/mlflow_logging.py new file mode 100644 index 00000000..8b895256 --- /dev/null +++ b/src/jabs/classifier/mlflow_logging.py @@ -0,0 +1,294 @@ +"""MLflow run + artifact logging for classifier cross-validation results. + +Opt-in tracking for a JABS cross-validation run: one MLflow *run* per +invocation that records aggregate cross-validation metrics, a curated set of +configuration scalars as params, descriptive tags, and the generated training +report as an artifact. + +Connection configuration (tracking URI, experiment, auth, TLS) is **not** +hard-coded here; it is read from standard ``MLFLOW_*`` environment variables, +populated either from a ``.env`` file (see :func:`load_env_file`) or from the +ambient environment. The experiment is whatever ``MLFLOW_EXPERIMENT_NAME`` +names, falling back to MLflow's built-in "Default" experiment. + +``mlflow`` is an optional dependency. Install it with +``pip install 'jabs-behavior-classifier[mlflow]'`` (or, for a development +checkout, ``uv sync --extra mlflow``). :func:`log_cross_validation_to_mlflow` +raises :class:`MlflowLoggingError` with installation guidance if it is missing. +""" + +from __future__ import annotations + +import logging +import math +import os +import subprocess +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np + +from .training_report import BinaryCVResult, MultiClassCVResult + +if TYPE_CHECKING: + from .training_report import TrainingReportData + +logger = logging.getLogger(__name__) + + +class MlflowLoggingError(RuntimeError): + """Raised when pushing cross-validation results to MLflow fails. + + The cross-validation run itself has already completed and its report has + been saved by the time this is raised; it signals only that the optional + MLflow push did not succeed (e.g. missing dependency, network, auth, TLS). + """ + + +def parse_kv_tags(items: list[str] | None) -> dict[str, str]: + """Parse repeated ``KEY=VALUE`` ``--mlflow-tag`` entries into a dict. + + Args: + items: Raw ``KEY=VALUE`` strings (or None / empty). + + Returns: + Mapping of tag key to value. ``None`` / empty input yields ``{}``. The + first ``=`` splits the entry; later ``=`` characters go to the value. + + Raises: + ValueError: If an entry has no ``=`` or an empty key. + """ + tags: dict[str, str] = {} + for item in items or []: + key, sep, value = item.partition("=") + key = key.strip() + if not sep or not key: + raise ValueError(f"invalid --mlflow-tag (expected KEY=VALUE): {item!r}") + tags[key] = value.strip() + return tags + + +def load_env_file(env_file: Path | None, *, override: bool = True) -> dict[str, str]: + """Apply the ``MLFLOW_*`` settings from a ``.env`` file to ``os.environ``. + + Only keys beginning with ``MLFLOW_`` are applied; any other keys in the file + are ignored, so it cannot accidentally clobber unrelated environment + variables (``PATH``, ``HTTP_PROXY``, ...). Lines may be blank, ``#`` + comments, or ``KEY=VALUE`` (an optional leading ``export`` and surrounding + quotes on the value are stripped). + + Args: + env_file: Path to a ``.env`` file, or None to read connection config + from the ambient environment (a no-op returning ``{}``). + override: If True (default), file values win for keys they define, since + naming a file is an explicit request to use its settings. + + Returns: + The ``MLFLOW_*`` mapping found in the file (handy for diagnostics). + + Raises: + FileNotFoundError: If ``env_file`` is given but does not exist. + """ + if env_file is None: + return {} + env_file = Path(env_file) + if not env_file.is_file(): + raise FileNotFoundError(f"--mlflow env file not found: {env_file}") + + values: dict[str, str] = {} + for raw_line in env_file.read_text().splitlines(): + line = raw_line.strip() + if not line or line.startswith("#"): + continue + if line.startswith("export "): + line = line[len("export ") :].lstrip() + key, sep, value = line.partition("=") + key = key.strip() + if not sep or not key.startswith("MLFLOW_"): + continue + values[key] = value.strip().strip('"').strip("'") + + for key, value in values.items(): + if override or key not in os.environ: + os.environ[key] = value + return values + + +def _git_sha() -> str | None: + """Short git SHA of the jabs checkout, or None if unavailable.""" + try: + result = subprocess.run( + ["git", "rev-parse", "--short", "HEAD"], + cwd=str(Path(__file__).resolve().parent), + capture_output=True, + text=True, + check=True, + timeout=5, + ) + except (subprocess.SubprocessError, OSError): + return None + sha = result.stdout.strip() + return sha or None + + +def aggregate_cv_metrics(report_data: TrainingReportData) -> dict[str, float]: + """Aggregate per-iteration cross-validation results into scalar metrics. + + Computes the mean and population standard deviation of accuracy across CV + iterations, plus class-specific scores (behavior precision/recall/F1 for + binary results, macro precision/recall/F1 for multi-class results), and + records dataset composition and timing as additional metrics. + + Args: + report_data: Completed training report data. + + Returns: + Mapping of MLflow-legal metric key to finite float value. Empty if + ``report_data`` has no cross-validation iterations. + """ + cv_results = report_data.cv_results + if not cv_results: + return {} + + metrics: dict[str, float] = {} + accuracy = np.array([r.accuracy for r in cv_results], dtype=float) + metrics["cv_accuracy_mean"] = float(np.mean(accuracy)) + metrics["cv_accuracy_std"] = float(np.std(accuracy)) + metrics["cv_iterations"] = float(len(cv_results)) + + if all(isinstance(r, BinaryCVResult) for r in cv_results): + per_class_attrs = ("precision_behavior", "recall_behavior", "f1_behavior") + elif all(isinstance(r, MultiClassCVResult) for r in cv_results): + per_class_attrs = ("precision_macro", "recall_macro", "f1_macro") + else: + per_class_attrs = () + + for attr in per_class_attrs: + vals = np.array([getattr(r, attr) for r in cv_results], dtype=float) + metrics[f"cv_{attr}_mean"] = float(np.mean(vals)) + metrics[f"cv_{attr}_std"] = float(np.std(vals)) + + metrics["frames_behavior"] = float(report_data.frames_behavior) + metrics["frames_not_behavior"] = float(report_data.frames_not_behavior) + metrics["bouts_behavior"] = float(report_data.bouts_behavior) + metrics["bouts_not_behavior"] = float(report_data.bouts_not_behavior) + metrics["training_time_ms"] = float(report_data.training_time_ms) + + return {key: value for key, value in metrics.items() if math.isfinite(value)} + + +def build_params(report_data: TrainingReportData) -> dict[str, str]: + """Build the curated, filterable MLflow params for a CV run. + + These are the columns you sort/filter the leaderboard by. The full results + ride the training-report artifact. + + Args: + report_data: Completed training report data. + + Returns: + Mapping of param key to stringified value. + """ + params: dict[str, object] = { + "behavior": report_data.behavior_name, + "classifier": report_data.classifier_type, + "window_size": report_data.window_size, + "balance_labels": report_data.balance_training_labels, + "symmetric_behavior": report_data.symmetric_behavior, + "distance_unit": report_data.distance_unit, + "cv_grouping_strategy": report_data.cv_grouping_strategy.value, + } + if report_data.cv_grouping_regex: + params["cv_grouping_regex"] = report_data.cv_grouping_regex + return {key: str(value) for key, value in params.items()} + + +def build_tags(report_data: TrainingReportData) -> dict[str, str]: + """Build the auto-derived MLflow run tags for a CV run. + + Args: + report_data: Completed training report data. + + Returns: + Mapping of tag key to value, omitting any whose value is empty/None. + """ + tags = { + "behavior": report_data.behavior_name, + "classifier": report_data.classifier_type, + "cv_grouping_strategy": report_data.cv_grouping_strategy.value, + "jabs_git": _git_sha(), + } + return {key: value for key, value in tags.items() if value} + + +def log_cross_validation_to_mlflow( + *, + report_data: TrainingReportData, + report_file: Path | None = None, + env_file: Path | None = None, + run_name: str | None = None, + tags: dict[str, str] | None = None, + log_report_artifact: bool = True, +) -> tuple[str, str]: + """Create one MLflow run for a cross-validation run and return its ids. + + Logs aggregate CV metrics, curated params, auto-derived plus caller tags, + and (optionally) the training report as an artifact. Connection config comes + from the environment; ``env_file``, if given, is loaded into it first. + + Args: + report_data: Completed training report data. + report_file: Path to the saved training report to upload as an artifact. + Ignored if None or missing on disk, or if ``log_report_artifact`` is + False. + env_file: Optional ``.env`` file with ``MLFLOW_*`` connection settings. + If None, connection config comes from the ambient environment. + run_name: MLflow run name. Defaults to ``-cv-``, where the + timestamp is the report's completion time (so it matches the saved report). + tags: Caller-supplied run tags; merged over the auto-derived tags (so a + user tag with the same key wins). + log_report_artifact: Whether to upload the training report artifact. + + Returns: + A ``(run_id, tracking_uri)`` tuple for the created MLflow run. + + Raises: + MlflowLoggingError: If the ``mlflow`` package is not installed. + """ + try: + import mlflow + except ImportError as e: + raise MlflowLoggingError( + "MLflow logging requires the optional 'mlflow' dependency. Install it with " + "`pip install 'jabs-behavior-classifier[mlflow]'` " + "(or, for a development checkout, `uv sync --extra mlflow`)." + ) from e + + load_env_file(env_file) + + if run_name is None: + run_name = f"{report_data.behavior_name}-cv-{report_data.timestamp:%Y%m%d-%H%M%S}" + + logger.info("Logging cross-validation results to MLflow run %r", run_name) + with mlflow.start_run(run_name=run_name) as run: + metrics = aggregate_cv_metrics(report_data) + for key, value in metrics.items(): + mlflow.log_metric(key, value) + + params = build_params(report_data) + if params: + mlflow.log_params(params) + + merged_tags = build_tags(report_data) + merged_tags.update(tags or {}) + if merged_tags: + mlflow.set_tags(merged_tags) + + if log_report_artifact and report_file is not None and Path(report_file).is_file(): + mlflow.log_artifact(str(report_file)) + + run_id = run.info.run_id + + tracking_uri = mlflow.get_tracking_uri() + logger.info("Logged cross-validation results to MLflow run %s (%s)", run_id, tracking_uri) + return run_id, tracking_uri diff --git a/src/jabs/scripts/cli/cli.py b/src/jabs/scripts/cli/cli.py index 0e89c406..ffe4843e 100644 --- a/src/jabs/scripts/cli/cli.py +++ b/src/jabs/scripts/cli/cli.py @@ -13,7 +13,7 @@ import click from rich.console import Console -from jabs.classifier import Classifier +from jabs.classifier import Classifier, MlflowLoggingError, parse_kv_tags from jabs.core.enums import ClassifierMode, ClassifierType, CrossValidationGroupingStrategy from jabs.project import ( Project, @@ -356,6 +356,35 @@ def prune(ctx: click.Context, directory: Path, behavior: str | None): "Report format will be determined by extension (.md for Markdown, .json for JSON). " "If not provided, a default filename will be used.", ) +@click.option( + "--mlflow", + "mlflow_env", + is_flag=False, + flag_value="", + default=None, + metavar="ENV_FILE", + help="Enable opt-in MLflow logging of the cross-validation results (aggregate " + "metrics, params, and the training report artifact). Optionally takes a path to a " + ".env file holding MLFLOW_* settings (tracking URI, experiment, auth, TLS); with no " + "path, those are read from the ambient environment. Absent leaves the command's " + "behavior unchanged. Requires the optional 'mlflow' extra " + "(pip install 'jabs-behavior-classifier[mlflow]').", +) +@click.option( + "--mlflow-tag", + "mlflow_tags", + multiple=True, + metavar="KEY=VALUE", + help="With --mlflow, add a free-form run tag (repeatable), e.g. purpose=baseline. " + "Merges over the auto-derived tags. No-op without --mlflow.", +) +@click.option( + "--mlflow-no-report", + "mlflow_no_report", + is_flag=True, + help="With --mlflow, skip uploading the training report artifact (metrics + params " + "only). No-op without --mlflow.", +) @click.pass_context def cross_validation( ctx: click.Context, @@ -366,6 +395,9 @@ def cross_validation( grouping_pattern: str | None, classifier: str, report_file: Path | None, + mlflow_env: str | None, + mlflow_tags: tuple[str, ...], + mlflow_no_report: bool, ): """Run leave-one-group-out cross-validation for a JABS project.""" if report_file is not None and report_file.suffix.lower() not in {".md", ".json"}: @@ -380,6 +412,15 @@ def cross_validation( } cv_grouping = cv_grouping_by_name[grouping_strategy.lower()] if grouping_strategy else None + # --mlflow: absent -> None (disabled); bare flag -> "" (ambient env); + # with a path -> that .env file. + mlflow_enabled = mlflow_env is not None + mlflow_env_file = Path(mlflow_env) if mlflow_env else None + try: + parsed_mlflow_tags = parse_kv_tags(list(mlflow_tags)) + except ValueError as e: + raise click.ClickException(str(e)) from e + try: classifier_type = ClassifierType[classifier.upper()] run_cross_validation( @@ -390,7 +431,16 @@ def cross_validation( k, report_file, grouping_regex=grouping_pattern, + mlflow_enabled=mlflow_enabled, + mlflow_env_file=mlflow_env_file, + mlflow_tags=parsed_mlflow_tags, + mlflow_log_report=not mlflow_no_report, ) + except MlflowLoggingError: + # Cross-validation and the report succeeded; only the optional MLflow + # push failed (already reported on the console). Use a distinct, non-zero + # exit code so automation can tell this apart from a CV failure. + ctx.exit(3) except Exception as e: raise click.ClickException(str(e)) from e diff --git a/src/jabs/scripts/cli/cross_validation.py b/src/jabs/scripts/cli/cross_validation.py index a90fd897..52ea82d4 100644 --- a/src/jabs/scripts/cli/cross_validation.py +++ b/src/jabs/scripts/cli/cross_validation.py @@ -9,7 +9,9 @@ from jabs.classifier import ( Classifier, + MlflowLoggingError, TrainingReportData, + log_cross_validation_to_mlflow, run_leave_one_group_out_cv, save_training_report, ) @@ -28,6 +30,10 @@ def run_cross_validation( k: int, report_file: Path | None = None, grouping_regex: str | None = None, + mlflow_enabled: bool = False, + mlflow_env_file: Path | None = None, + mlflow_tags: dict[str, str] | None = None, + mlflow_log_report: bool = True, ) -> None: """Run cross-validation for a JABS project from the command line. @@ -45,6 +51,18 @@ def run_cross_validation( grouping_regex (str | None): Regular expression used to extract a grouping key from each video filename. Only used when ``grouping_strategy`` is ``FILENAME_PATTERN``. If None, uses the pattern saved in project settings. + mlflow_enabled (bool): If True, push the cross-validation results to MLflow + after the report is saved. + mlflow_env_file (Path | None): Optional ``.env`` file with ``MLFLOW_*`` connection + settings. If None, connection config comes from the ambient environment. + mlflow_tags (dict[str, str] | None): Optional free-form MLflow run tags, merged + over the auto-derived tags. + mlflow_log_report (bool): Whether to upload the training report as an MLflow + artifact. Only used when ``mlflow_enabled`` is True. + + Raises: + MlflowLoggingError: If MLflow logging is requested but fails. The + cross-validation results and the saved report are unaffected. """ if k < 0: raise ValueError("The number of cross-validation splits 'k' must be non-negative.") @@ -227,3 +245,27 @@ def progress_callback(): save_training_report(training_data, report_file) console.print(f"\nTraining report saved to: {report_file}", style="bold green") + + # Push results to MLflow last, so a logging failure (missing dependency, + # network, auth, TLS) never costs the cross-validation results -- they are + # already on screen and the report is already saved. + if mlflow_enabled: + try: + run_id, tracking_uri = log_cross_validation_to_mlflow( + report_data=training_data, + report_file=report_file, + env_file=mlflow_env_file, + tags=mlflow_tags, + log_report_artifact=mlflow_log_report, + ) + except Exception as e: + console.print(f"\nWarning: MLflow logging failed: {e}", style="bold yellow") + console.print( + " (cross-validation results above and the saved report are unaffected)", + style="yellow", + ) + raise MlflowLoggingError(str(e)) from e + console.print( + f"\nLogged cross-validation results to MLflow run {run_id} ({tracking_uri})", + style="bold green", + ) diff --git a/tests/classifier/test_mlflow_logging.py b/tests/classifier/test_mlflow_logging.py new file mode 100644 index 00000000..9c2dbddb --- /dev/null +++ b/tests/classifier/test_mlflow_logging.py @@ -0,0 +1,298 @@ +"""Tests for :mod:`jabs.classifier.mlflow_logging`. + +The actual MLflow client is never imported here; tests that exercise +:func:`log_cross_validation_to_mlflow` inject a fake ``mlflow`` module into +``sys.modules`` so no tracking server, filesystem, or network is touched. +""" + +import os +import sys +from datetime import datetime +from pathlib import Path +from types import SimpleNamespace + +import numpy as np +import pytest + +from jabs.classifier.mlflow_logging import ( + MlflowLoggingError, + aggregate_cv_metrics, + build_params, + build_tags, + load_env_file, + log_cross_validation_to_mlflow, + parse_kv_tags, +) +from jabs.classifier.training_report import BinaryCVResult, TrainingReportData +from jabs.core.enums import CrossValidationGroupingStrategy + + +@pytest.fixture +def binary_report() -> TrainingReportData: + """A binary cross-validation report with two iterations.""" + cv_results = [ + BinaryCVResult( + iteration=1, + test_label="cage_1", + accuracy=0.9, + confusion_matrix=np.zeros((2, 2)), + precision_behavior=0.8, + recall_behavior=0.7, + f1_behavior=0.75, + ), + BinaryCVResult( + iteration=2, + test_label="cage_2", + accuracy=0.8, + confusion_matrix=np.zeros((2, 2)), + precision_behavior=0.6, + recall_behavior=0.5, + f1_behavior=0.55, + ), + ] + return TrainingReportData( + behavior_name="Walk", + classifier_type="XGBoost", + window_size=5, + balance_training_labels=True, + symmetric_behavior=False, + distance_unit="cm", + cv_results=cv_results, + final_top_features=[("feat_a", 0.5)], + training_time_ms=1234, + timestamp=datetime(2026, 6, 23, 12, 0, 0), + cv_grouping_strategy=CrossValidationGroupingStrategy.FILENAME_PATTERN, + frames_behavior=100, + frames_not_behavior=200, + bouts_behavior=10, + bouts_not_behavior=20, + cv_grouping_regex=r"^(\w+?)_", + ) + + +class _FakeRun: + def __init__(self) -> None: + self.info = SimpleNamespace(run_id="run-123") + + def __enter__(self) -> "_FakeRun": + return self + + def __exit__(self, *exc: object) -> bool: + return False + + +class _FakeMlflow: + """Minimal stand-in recording the calls the logger makes.""" + + def __init__(self) -> None: + self.run_name: str | None = None + self.metrics: dict[str, float] = {} + self.params: dict[str, str] = {} + self.tags: dict[str, str] = {} + self.artifacts: list[str] = [] + + def start_run(self, run_name: str | None = None) -> _FakeRun: + self.run_name = run_name + return _FakeRun() + + def log_metric(self, key: str, value: float) -> None: + self.metrics[key] = value + + def log_params(self, params: dict[str, str]) -> None: + self.params.update(params) + + def set_tags(self, tags: dict[str, str]) -> None: + self.tags.update(tags) + + def log_artifact(self, path: str) -> None: + self.artifacts.append(path) + + def get_tracking_uri(self) -> str: + return "file:///tmp/mlruns" + + +# --------------------------------------------------------------------------- # +# parse_kv_tags +# --------------------------------------------------------------------------- # +def test_parse_kv_tags_basic() -> None: + """KEY=VALUE entries parse into a dict, values may contain spaces.""" + assert parse_kv_tags(["a=1", "purpose=release candidate"]) == { + "a": "1", + "purpose": "release candidate", + } + + +def test_parse_kv_tags_none_and_empty() -> None: + """None or empty input yields an empty dict.""" + assert parse_kv_tags(None) == {} + assert parse_kv_tags([]) == {} + + +def test_parse_kv_tags_value_may_contain_equals() -> None: + """Only the first '=' splits; later ones go to the value.""" + assert parse_kv_tags(["expr=a=b"]) == {"expr": "a=b"} + + +@pytest.mark.parametrize("bad", ["noequals", "=novalue"], ids=["no-eq", "empty-key"]) +def test_parse_kv_tags_invalid(bad: str) -> None: + """Entries with no '=' or an empty key are rejected.""" + with pytest.raises(ValueError, match="expected KEY=VALUE"): + parse_kv_tags([bad]) + + +# --------------------------------------------------------------------------- # +# load_env_file +# --------------------------------------------------------------------------- # +def test_load_env_file_none_is_noop() -> None: + """A None env file applies nothing and returns an empty dict.""" + assert load_env_file(None) == {} + + +def test_load_env_file_applies_only_mlflow_keys( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Only MLFLOW_* keys are applied to the environment; others are ignored.""" + monkeypatch.setattr(os, "environ", dict(os.environ)) + env_file = tmp_path / "mlflow.env" + env_file.write_text( + "# a comment\n" + 'MLFLOW_TRACKING_URI="https://mlflow.example.org"\n' + "export MLFLOW_EXPERIMENT_NAME=behaviors\n" + "OTHER_VAR=should-be-ignored\n" + ) + + applied = load_env_file(env_file) + + assert applied == { + "MLFLOW_TRACKING_URI": "https://mlflow.example.org", + "MLFLOW_EXPERIMENT_NAME": "behaviors", + } + assert os.environ["MLFLOW_TRACKING_URI"] == "https://mlflow.example.org" + assert os.environ["MLFLOW_EXPERIMENT_NAME"] == "behaviors" + assert "OTHER_VAR" not in os.environ + + +def test_load_env_file_missing_raises(tmp_path: Path) -> None: + """A given-but-missing env file raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError, match="env file not found"): + load_env_file(tmp_path / "does_not_exist.env") + + +# --------------------------------------------------------------------------- # +# aggregate_cv_metrics / build_params / build_tags +# --------------------------------------------------------------------------- # +def test_aggregate_cv_metrics_binary(binary_report: TrainingReportData) -> None: + """Binary CV results aggregate into mean/std and composition metrics.""" + metrics = aggregate_cv_metrics(binary_report) + assert metrics["cv_accuracy_mean"] == pytest.approx(0.85) + assert metrics["cv_accuracy_std"] == pytest.approx(0.05) + assert metrics["cv_iterations"] == pytest.approx(2.0) + assert metrics["cv_precision_behavior_mean"] == pytest.approx(0.7) + assert metrics["cv_recall_behavior_mean"] == pytest.approx(0.6) + assert metrics["cv_f1_behavior_mean"] == pytest.approx(0.65) + assert metrics["frames_behavior"] == pytest.approx(100.0) + assert metrics["bouts_not_behavior"] == pytest.approx(20.0) + assert metrics["training_time_ms"] == pytest.approx(1234.0) + + +def test_aggregate_cv_metrics_empty(binary_report: TrainingReportData) -> None: + """No CV iterations yields no metrics.""" + binary_report.cv_results = [] + assert aggregate_cv_metrics(binary_report) == {} + + +def test_build_params_includes_regex_for_filename_strategy( + binary_report: TrainingReportData, +) -> None: + """The grouping regex is recorded as a param under the filename strategy.""" + params = build_params(binary_report) + assert params["behavior"] == "Walk" + assert params["classifier"] == "XGBoost" + assert params["window_size"] == "5" + assert params["balance_labels"] == "True" + assert params["cv_grouping_strategy"] == "Filename Pattern" + assert params["cv_grouping_regex"] == r"^(\w+?)_" + + +def test_build_params_omits_regex_when_unset(binary_report: TrainingReportData) -> None: + """No regex param is recorded when no grouping regex is set.""" + binary_report.cv_grouping_strategy = CrossValidationGroupingStrategy.VIDEO + binary_report.cv_grouping_regex = None + params = build_params(binary_report) + assert "cv_grouping_regex" not in params + + +def test_build_tags_omits_empty( + binary_report: TrainingReportData, monkeypatch: pytest.MonkeyPatch +) -> None: + """Tags with empty/None values (e.g. an unavailable git sha) are dropped.""" + monkeypatch.setattr("jabs.classifier.mlflow_logging._git_sha", lambda: None) + tags = build_tags(binary_report) + assert tags == { + "behavior": "Walk", + "classifier": "XGBoost", + "cv_grouping_strategy": "Filename Pattern", + } + assert "jabs_git" not in tags + + +# --------------------------------------------------------------------------- # +# log_cross_validation_to_mlflow +# --------------------------------------------------------------------------- # +def test_log_cross_validation_to_mlflow( + binary_report: TrainingReportData, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A run logs metrics/params/tags/artifact and returns run id + tracking URI.""" + fake = _FakeMlflow() + monkeypatch.setitem(sys.modules, "mlflow", fake) + monkeypatch.setattr("jabs.classifier.mlflow_logging._git_sha", lambda: "abc1234") + report_file = tmp_path / "report.md" + report_file.write_text("# report") + + run_id, tracking_uri = log_cross_validation_to_mlflow( + report_data=binary_report, + report_file=report_file, + tags={"purpose": "baseline"}, + ) + + assert run_id == "run-123" + assert tracking_uri == "file:///tmp/mlruns" + assert fake.run_name == "Walk-cv-20260623-120000" + assert fake.metrics["cv_accuracy_mean"] == pytest.approx(0.85) + assert fake.params["behavior"] == "Walk" + # user tag merges over auto tags, auto git tag preserved + assert fake.tags["purpose"] == "baseline" + assert fake.tags["jabs_git"] == "abc1234" + assert fake.artifacts == [str(report_file)] + + +def test_log_cross_validation_skips_artifact_when_disabled( + binary_report: TrainingReportData, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """With log_report_artifact=False, the report is not uploaded.""" + fake = _FakeMlflow() + monkeypatch.setitem(sys.modules, "mlflow", fake) + report_file = tmp_path / "report.md" + report_file.write_text("# report") + + log_cross_validation_to_mlflow( + report_data=binary_report, + report_file=report_file, + log_report_artifact=False, + ) + + assert fake.artifacts == [] + + +def test_log_cross_validation_missing_mlflow_raises( + binary_report: TrainingReportData, monkeypatch: pytest.MonkeyPatch +) -> None: + """A missing mlflow install raises MlflowLoggingError with install guidance.""" + # Setting the module to None makes ``import mlflow`` raise ImportError. + monkeypatch.setitem(sys.modules, "mlflow", None) + with pytest.raises(MlflowLoggingError, match="optional 'mlflow' dependency"): + log_cross_validation_to_mlflow(report_data=binary_report) diff --git a/tests/scripts/test_cross_validation_cli.py b/tests/scripts/test_cross_validation_cli.py index 14ae00a2..0f00c153 100644 --- a/tests/scripts/test_cross_validation_cli.py +++ b/tests/scripts/test_cross_validation_cli.py @@ -14,6 +14,7 @@ from click.testing import CliRunner import jabs.scripts.cli.cli as cli_module +from jabs.classifier import MlflowLoggingError from jabs.core.enums import CrossValidationGroupingStrategy from jabs.scripts.cli.cli import cli @@ -92,3 +93,75 @@ def test_invalid_grouping_strategy_rejected(tmp_path: Path, run_cv_spy: mock.Moc assert result.exit_code != 0 run_cv_spy.assert_not_called() + + +# --------------------------------------------------------------------------- # +# MLflow options +# --------------------------------------------------------------------------- # +def test_mlflow_absent_disables_logging(tmp_path: Path, run_cv_spy: mock.Mock) -> None: + """Without --mlflow, logging is disabled and no env file is passed.""" + result = _invoke(tmp_path) + + assert result.exit_code == 0, result.output + kwargs = run_cv_spy.call_args.kwargs + assert kwargs["mlflow_enabled"] is False + assert kwargs["mlflow_env_file"] is None + assert kwargs["mlflow_log_report"] is True + + +def test_mlflow_bare_flag_enables_ambient(tmp_path: Path, run_cv_spy: mock.Mock) -> None: + """A bare --mlflow enables logging with no env file (ambient environment).""" + result = _invoke(tmp_path, "--mlflow") + + assert result.exit_code == 0, result.output + kwargs = run_cv_spy.call_args.kwargs + assert kwargs["mlflow_enabled"] is True + assert kwargs["mlflow_env_file"] is None + + +def test_mlflow_with_env_file(tmp_path: Path, run_cv_spy: mock.Mock) -> None: + """--mlflow with a path forwards that .env file.""" + result = _invoke(tmp_path, "--mlflow", "settings.env") + + assert result.exit_code == 0, result.output + kwargs = run_cv_spy.call_args.kwargs + assert kwargs["mlflow_enabled"] is True + assert kwargs["mlflow_env_file"] == Path("settings.env") + + +def test_mlflow_tags_parsed_and_forwarded(tmp_path: Path, run_cv_spy: mock.Mock) -> None: + """Repeated --mlflow-tag entries are parsed into a dict.""" + result = _invoke( + tmp_path, "--mlflow", "--mlflow-tag", "purpose=baseline", "--mlflow-tag", "owner=glen" + ) + + assert result.exit_code == 0, result.output + assert run_cv_spy.call_args.kwargs["mlflow_tags"] == { + "purpose": "baseline", + "owner": "glen", + } + + +def test_mlflow_no_report_flag(tmp_path: Path, run_cv_spy: mock.Mock) -> None: + """--mlflow-no-report disables the report artifact upload.""" + result = _invoke(tmp_path, "--mlflow", "--mlflow-no-report") + + assert result.exit_code == 0, result.output + assert run_cv_spy.call_args.kwargs["mlflow_log_report"] is False + + +def test_invalid_mlflow_tag_rejected(tmp_path: Path, run_cv_spy: mock.Mock) -> None: + """A malformed --mlflow-tag fails before run_cross_validation is called.""" + result = _invoke(tmp_path, "--mlflow", "--mlflow-tag", "noequals") + + assert result.exit_code != 0 + run_cv_spy.assert_not_called() + + +def test_mlflow_logging_failure_exits_with_code_3(tmp_path: Path, run_cv_spy: mock.Mock) -> None: + """An MlflowLoggingError maps to a distinct exit code (3), not the generic 1.""" + run_cv_spy.side_effect = MlflowLoggingError("push failed") + + result = _invoke(tmp_path, "--mlflow") + + assert result.exit_code == 3 diff --git a/uv.lock b/uv.lock index 4cfd6a40..02fecf3d 100644 --- a/uv.lock +++ b/uv.lock @@ -1755,6 +1755,9 @@ dependencies = [ ] [package.optional-dependencies] +mlflow = [ + { name = "mlflow" }, +] nwb = [ { name = "jabs-io", extra = ["nwb"] }, ] @@ -1796,6 +1799,7 @@ requires-dist = [ { name = "jabs-io", extras = ["nwb"], marker = "extra == 'nwb'", editable = "packages/jabs-io" }, { name = "jsonschema", specifier = ">=4.25.1,<5.0.0" }, { name = "markdown2", specifier = ">=2.5.1,<3.0.0" }, + { name = "mlflow", marker = "extra == 'mlflow'", specifier = ">=3.8.1" }, { name = "numpy", specifier = ">=2.0.0,<3.0.0" }, { name = "opencv-python-headless", specifier = ">=4.8.1.78,<5.0.0" }, { name = "packaging", specifier = ">=24.0" }, @@ -1812,7 +1816,7 @@ requires-dist = [ { name = "toml", specifier = ">=0.10.2,<0.11.0" }, { name = "xgboost", specifier = ">=2.0.0,<3.0.0" }, ] -provides-extras = ["nwb", "yaml"] +provides-extras = ["nwb", "yaml", "mlflow"] [package.metadata.requires-dev] dev = [ From 2e0ec334a7ce07d3b28eefd03e2f73557e36e62a Mon Sep 17 00:00:00 2001 From: Glen Beane <356266+gbeane@users.noreply.github.com> Date: Tue, 23 Jun 2026 14:58:38 -0400 Subject: [PATCH 2/7] Warn and skip MLflow logging when the mlflow extra is not installed --- src/jabs/classifier/__init__.py | 2 + src/jabs/classifier/mlflow_logging.py | 12 ++++++ src/jabs/scripts/cli/cli.py | 14 ++++++- src/jabs/scripts/cli/cross_validation.py | 3 +- tests/classifier/test_mlflow_logging.py | 14 +++++++ tests/scripts/test_cross_validation_cli.py | 45 +++++++++++++++++++--- 6 files changed, 83 insertions(+), 7 deletions(-) diff --git a/src/jabs/classifier/__init__.py b/src/jabs/classifier/__init__.py index 039fffcf..993caa9f 100644 --- a/src/jabs/classifier/__init__.py +++ b/src/jabs/classifier/__init__.py @@ -10,6 +10,7 @@ from .mlflow_logging import ( MlflowLoggingError, log_cross_validation_to_mlflow, + mlflow_available, parse_kv_tags, ) from .multi_class_classifier import MultiClassClassifier @@ -34,6 +35,7 @@ "TrainingReportData", "generate_markdown_report", "log_cross_validation_to_mlflow", + "mlflow_available", "parse_kv_tags", "run_leave_one_group_out_cv", "save_training_report", diff --git a/src/jabs/classifier/mlflow_logging.py b/src/jabs/classifier/mlflow_logging.py index 8b895256..5745e866 100644 --- a/src/jabs/classifier/mlflow_logging.py +++ b/src/jabs/classifier/mlflow_logging.py @@ -19,6 +19,7 @@ from __future__ import annotations +import importlib.util import logging import math import os @@ -36,6 +37,17 @@ logger = logging.getLogger(__name__) +def mlflow_available() -> bool: + """Return True if the optional ``mlflow`` package is importable. + + Uses :func:`importlib.util.find_spec` so the (heavy) ``mlflow`` package is + not actually imported just to test for its presence. Lets callers degrade + gracefully -- warning and skipping MLflow logging -- when the optional + 'mlflow' extra is not installed. + """ + return importlib.util.find_spec("mlflow") is not None + + class MlflowLoggingError(RuntimeError): """Raised when pushing cross-validation results to MLflow fails. diff --git a/src/jabs/scripts/cli/cli.py b/src/jabs/scripts/cli/cli.py index ffe4843e..3796daa0 100644 --- a/src/jabs/scripts/cli/cli.py +++ b/src/jabs/scripts/cli/cli.py @@ -13,7 +13,7 @@ import click from rich.console import Console -from jabs.classifier import Classifier, MlflowLoggingError, parse_kv_tags +from jabs.classifier import Classifier, MlflowLoggingError, mlflow_available, parse_kv_tags from jabs.core.enums import ClassifierMode, ClassifierType, CrossValidationGroupingStrategy from jabs.project import ( Project, @@ -421,6 +421,18 @@ def cross_validation( except ValueError as e: raise click.ClickException(str(e)) from e + # If MLflow logging was requested but the optional 'mlflow' extra is not + # installed, warn and ignore the MLflow options rather than failing -- the + # cross-validation still runs and the report is still produced. + if mlflow_enabled and not mlflow_available(): + click.echo( + "Warning: MLflow logging was requested (--mlflow) but the optional 'mlflow' " + "dependency is not installed; ignoring MLflow options. Install it with " + "\"pip install 'jabs-behavior-classifier[mlflow]'\" to enable logging.", + err=True, + ) + mlflow_enabled = False + try: classifier_type = ClassifierType[classifier.upper()] run_cross_validation( diff --git a/src/jabs/scripts/cli/cross_validation.py b/src/jabs/scripts/cli/cross_validation.py index 52ea82d4..1a60df4f 100644 --- a/src/jabs/scripts/cli/cross_validation.py +++ b/src/jabs/scripts/cli/cross_validation.py @@ -52,7 +52,8 @@ def run_cross_validation( from each video filename. Only used when ``grouping_strategy`` is ``FILENAME_PATTERN``. If None, uses the pattern saved in project settings. mlflow_enabled (bool): If True, push the cross-validation results to MLflow - after the report is saved. + after the report is saved. Callers should only enable this when the optional + 'mlflow' dependency is installed (the CLI checks this and warns otherwise). mlflow_env_file (Path | None): Optional ``.env`` file with ``MLFLOW_*`` connection settings. If None, connection config comes from the ambient environment. mlflow_tags (dict[str, str] | None): Optional free-form MLflow run tags, merged diff --git a/tests/classifier/test_mlflow_logging.py b/tests/classifier/test_mlflow_logging.py index 9c2dbddb..ce3e3b94 100644 --- a/tests/classifier/test_mlflow_logging.py +++ b/tests/classifier/test_mlflow_logging.py @@ -14,6 +14,7 @@ import numpy as np import pytest +from jabs.classifier import mlflow_logging from jabs.classifier.mlflow_logging import ( MlflowLoggingError, aggregate_cv_metrics, @@ -21,6 +22,7 @@ build_tags, load_env_file, log_cross_validation_to_mlflow, + mlflow_available, parse_kv_tags, ) from jabs.classifier.training_report import BinaryCVResult, TrainingReportData @@ -111,6 +113,18 @@ def get_tracking_uri(self) -> str: return "file:///tmp/mlruns" +# --------------------------------------------------------------------------- # +# mlflow_available +# --------------------------------------------------------------------------- # +@pytest.mark.parametrize( + ("spec", "expected"), [(object(), True), (None, False)], ids=["present", "absent"] +) +def test_mlflow_available(spec: object, expected: bool, monkeypatch: pytest.MonkeyPatch) -> None: + """mlflow_available() reflects whether find_spec locates the package.""" + monkeypatch.setattr(mlflow_logging.importlib.util, "find_spec", lambda name: spec) + assert mlflow_available() is expected + + # --------------------------------------------------------------------------- # # parse_kv_tags # --------------------------------------------------------------------------- # diff --git a/tests/scripts/test_cross_validation_cli.py b/tests/scripts/test_cross_validation_cli.py index 0f00c153..3df47573 100644 --- a/tests/scripts/test_cross_validation_cli.py +++ b/tests/scripts/test_cross_validation_cli.py @@ -27,6 +27,16 @@ def run_cv_spy(monkeypatch: pytest.MonkeyPatch) -> mock.Mock: return spy +@pytest.fixture +def mlflow_installed(monkeypatch: pytest.MonkeyPatch) -> None: + """Make the CLI treat the optional 'mlflow' extra as installed. + + The extra is not a root dependency, so it is typically absent from the test + environment; tests of the MLflow-enabled path patch this to be deterministic. + """ + monkeypatch.setattr(cli_module, "mlflow_available", lambda: True) + + def _invoke(tmp_path: Path, *extra_args: str): """Invoke the cross-validation command against ``tmp_path`` with extra args.""" runner = CliRunner() @@ -109,7 +119,9 @@ def test_mlflow_absent_disables_logging(tmp_path: Path, run_cv_spy: mock.Mock) - assert kwargs["mlflow_log_report"] is True -def test_mlflow_bare_flag_enables_ambient(tmp_path: Path, run_cv_spy: mock.Mock) -> None: +def test_mlflow_bare_flag_enables_ambient( + tmp_path: Path, run_cv_spy: mock.Mock, mlflow_installed: None +) -> None: """A bare --mlflow enables logging with no env file (ambient environment).""" result = _invoke(tmp_path, "--mlflow") @@ -119,7 +131,9 @@ def test_mlflow_bare_flag_enables_ambient(tmp_path: Path, run_cv_spy: mock.Mock) assert kwargs["mlflow_env_file"] is None -def test_mlflow_with_env_file(tmp_path: Path, run_cv_spy: mock.Mock) -> None: +def test_mlflow_with_env_file( + tmp_path: Path, run_cv_spy: mock.Mock, mlflow_installed: None +) -> None: """--mlflow with a path forwards that .env file.""" result = _invoke(tmp_path, "--mlflow", "settings.env") @@ -129,7 +143,9 @@ def test_mlflow_with_env_file(tmp_path: Path, run_cv_spy: mock.Mock) -> None: assert kwargs["mlflow_env_file"] == Path("settings.env") -def test_mlflow_tags_parsed_and_forwarded(tmp_path: Path, run_cv_spy: mock.Mock) -> None: +def test_mlflow_tags_parsed_and_forwarded( + tmp_path: Path, run_cv_spy: mock.Mock, mlflow_installed: None +) -> None: """Repeated --mlflow-tag entries are parsed into a dict.""" result = _invoke( tmp_path, "--mlflow", "--mlflow-tag", "purpose=baseline", "--mlflow-tag", "owner=glen" @@ -142,7 +158,9 @@ def test_mlflow_tags_parsed_and_forwarded(tmp_path: Path, run_cv_spy: mock.Mock) } -def test_mlflow_no_report_flag(tmp_path: Path, run_cv_spy: mock.Mock) -> None: +def test_mlflow_no_report_flag( + tmp_path: Path, run_cv_spy: mock.Mock, mlflow_installed: None +) -> None: """--mlflow-no-report disables the report artifact upload.""" result = _invoke(tmp_path, "--mlflow", "--mlflow-no-report") @@ -158,10 +176,27 @@ def test_invalid_mlflow_tag_rejected(tmp_path: Path, run_cv_spy: mock.Mock) -> N run_cv_spy.assert_not_called() -def test_mlflow_logging_failure_exits_with_code_3(tmp_path: Path, run_cv_spy: mock.Mock) -> None: +def test_mlflow_logging_failure_exits_with_code_3( + tmp_path: Path, run_cv_spy: mock.Mock, mlflow_installed: None +) -> None: """An MlflowLoggingError maps to a distinct exit code (3), not the generic 1.""" run_cv_spy.side_effect = MlflowLoggingError("push failed") result = _invoke(tmp_path, "--mlflow") assert result.exit_code == 3 + + +def test_mlflow_unavailable_warns_and_ignores( + tmp_path: Path, run_cv_spy: mock.Mock, monkeypatch: pytest.MonkeyPatch +) -> None: + """When the mlflow extra is absent, --mlflow is ignored with a warning (exit 0).""" + monkeypatch.setattr(cli_module, "mlflow_available", lambda: False) + + result = _invoke(tmp_path, "--mlflow", "--mlflow-tag", "purpose=baseline") + + assert result.exit_code == 0, result.output + assert "not installed" in result.stderr + # cross-validation still runs, but MLflow logging is disabled + run_cv_spy.assert_called_once() + assert run_cv_spy.call_args.kwargs["mlflow_enabled"] is False From a5e083268c53eae889f8eb3d334b0ff635216a86 Mon Sep 17 00:00:00 2001 From: Glen Beane <356266+gbeane@users.noreply.github.com> Date: Tue, 23 Jun 2026 15:10:11 -0400 Subject: [PATCH 3/7] Document cross-validation CLI and MLflow integration in user guide --- docs/user-guide/cli-tools.md | 180 ++++++++++++++++++ .../resources/docs/user_guide/cli-tools.md | 180 ++++++++++++++++++ 2 files changed, 360 insertions(+) diff --git a/docs/user-guide/cli-tools.md b/docs/user-guide/cli-tools.md index 6799bb6f..80bb2f73 100644 --- a/docs/user-guide/cli-tools.md +++ b/docs/user-guide/cli-tools.md @@ -530,3 +530,183 @@ jabs-cli convert-parquet session_poses.parquet \ --num-frames 3600 \ --out-dir /path/to/output ``` + +## jabs-cli cross-validation + +The `jabs-cli cross-validation` command runs leave-one-group-out cross-validation for a single behavior in a JABS project, then trains a final model on all labeled data to report feature importance. It prints per-iteration metrics to the console and writes a training report file (the same report produced by the GUI). Use it to estimate how well a classifier generalizes before committing to a trained model. + +Features must already be computed for the project (for example via [`jabs-init`](#jabs-init)); if they are missing this command will compute them, which can be slow. + +**Usage:** + +```bash +jabs-cli cross-validation DIRECTORY --behavior BEHAVIOR \ + [-k SPLITS] \ + [--grouping-strategy {video|individual|filename}] \ + [--grouping-pattern REGEX] \ + [--classifier {catboost|random_forest|xgboost}] \ + [--report-file FILE] \ + [--mlflow [ENV_FILE]] [--mlflow-tag KEY=VALUE] [--mlflow-no-report] +``` + +- `DIRECTORY`: Path to the JABS project directory. +- `--behavior BEHAVIOR` (required): Behavior to evaluate. Quote it if it contains spaces; must match an existing behavior in the project. +- `-k SPLITS`: Number of cross-validation iterations. `0` (the default) uses the maximum number of splits supported by the data and grouping strategy. +- `--grouping-strategy {video|individual|filename}`: How labeled frames are grouped into cross-validation folds (see [Grouping strategies](#grouping-strategies)). If omitted, the project's saved setting is used. +- `--grouping-pattern REGEX`: Regular expression applied to each video filename to derive a grouping key. Only used with `--grouping-strategy filename`. If omitted, the pattern saved in the project is used. +- `--classifier {catboost|random_forest|xgboost}`: Classifier to evaluate. Defaults to `xgboost`. The available choices depend on which classifier libraries are installed; see [Classifier Types](classifier-types.md). +- `--report-file FILE`: Where to write the training report. The format is chosen by extension: `.md` (Markdown) or `.json` (JSON). If omitted, a timestamped Markdown file is written to the current directory (`__training_report.md`). +- `--mlflow`, `--mlflow-tag`, `--mlflow-no-report`: Optional MLflow logging (see [MLflow logging](#mlflow-logging)). + +### Grouping strategies + +Cross-validation holds out one *group* of labeled data per iteration and trains on the rest, so groups define what "generalization" means for the score. JABS supports three strategies: + +| Strategy | Each group is... | Use when | +|---|---|---| +| `individual` | one (video, identity) pair | you want to measure generalization across individual animals | +| `video` | one whole video (all identities in it) | you want to measure generalization across videos/sessions | +| `filename` | all videos whose filename yields the same key under `--grouping-pattern` | videos from the same cage/cohort/day share a filename component and should not be split across train and test | + +For the `filename` strategy, the pattern is applied with `re.search`, so it matches anywhere in the filename. If the pattern has a capturing group, the first captured group is the key; otherwise the whole match is the key. Videos that do not match the pattern are placed in their own single-video group. For example, `--grouping-pattern '^(\w+?)_'` groups `cage12_2024-01-01.mp4` and `cage12_2024-01-02.mp4` together under the key `cage12`. + +### Training report + +The report (and the console output) include: + +- Per-iteration accuracy, precision and recall for both classes, and F1 for the behavior class, plus the held-out test group label for each iteration. +- The top features (by importance) from a final model trained on all labeled data. +- Labeled frame and bout counts, the window size, distance unit, classifier type, and the grouping strategy/pattern used. + +**Examples:** + +```bash +# Cross-validate "grooming" with default settings (project grouping, all splits, xgboost) +jabs-cli cross-validation /path/to/project --behavior grooming + +# 5-fold, grouped by individual animal, with a CatBoost classifier +jabs-cli cross-validation /path/to/project --behavior grooming \ + -k 5 --grouping-strategy individual --classifier catboost + +# Group videos by a shared filename prefix and write a JSON report +jabs-cli cross-validation /path/to/project --behavior grooming \ + --grouping-strategy filename --grouping-pattern '^(\w+?)_' \ + --report-file grooming_cv.json +``` + +### MLflow logging + +The cross-validation command can optionally log each run to an [MLflow](https://mlflow.org/) tracking server, recording aggregate metrics, run parameters, descriptive tags, and the training report as an artifact. This is opt-in and off by default. + +#### Installing the MLflow extra + +MLflow is an optional dependency. Install it with the `mlflow` extra: + +```bash +pip install 'jabs-behavior-classifier[mlflow]' +``` + +If you request MLflow logging without the extra installed, the command prints a warning, ignores the MLflow options, and still runs the cross-validation and writes the report (it exits `0`). + +#### Enabling logging + +Add the `--mlflow` flag: + +```bash +# Use connection settings from the ambient environment +jabs-cli cross-validation /path/to/project --behavior grooming --mlflow + +# Use connection settings from a .env file +jabs-cli cross-validation /path/to/project --behavior grooming --mlflow settings.env +``` + +`--mlflow` optionally takes the path to a `.env` file. With no path, connection settings are read from the current environment. + +#### Connection configuration + +Connection details — tracking server URI, experiment, authentication, TLS — are **not** passed as command-line options. They come from standard `MLFLOW_*` environment variables, either exported in your shell or written to the `.env` file you pass to `--mlflow`. Only keys beginning with `MLFLOW_` are read from the `.env` file; everything else is ignored. + +Common variables: + +| Variable | Purpose | +|---|---| +| `MLFLOW_TRACKING_URI` | URL (or local path) of the tracking server, e.g. `https://mlflow.example.org` | +| `MLFLOW_EXPERIMENT_NAME` | Name of the experiment the run is logged under | +| `MLFLOW_TRACKING_USERNAME` / `MLFLOW_TRACKING_PASSWORD` | HTTP basic-auth credentials | +| `MLFLOW_TRACKING_TOKEN` | Bearer-token auth (alternative to username/password) | + +Example `.env` file: + +``` +MLFLOW_TRACKING_URI=https://mlflow.example.org +MLFLOW_EXPERIMENT_NAME=mouse-grooming +MLFLOW_TRACKING_USERNAME=jabs +MLFLOW_TRACKING_PASSWORD=hunter2 +``` + +#### Selecting the experiment + +The run is logged under the experiment named by `MLFLOW_EXPERIMENT_NAME` (or `MLFLOW_EXPERIMENT_ID`). If neither is set, MLflow's built-in **Default** experiment is used. There is no dedicated command-line option for the experiment; set the environment variable in the `.env` file or your shell: + +```bash +export MLFLOW_EXPERIMENT_NAME=mouse-grooming +jabs-cli cross-validation /path/to/project --behavior grooming --mlflow +``` + +#### What gets logged + +Each invocation creates one MLflow run named `-cv-`. + +**Metrics** (aggregated across cross-validation iterations): + +- `cv_accuracy_mean`, `cv_accuracy_std` +- `cv_precision_behavior_mean` / `_std`, `cv_recall_behavior_mean` / `_std`, `cv_f1_behavior_mean` / `_std` +- `cv_iterations` — number of folds run +- `frames_behavior`, `frames_not_behavior`, `bouts_behavior`, `bouts_not_behavior` — dataset composition +- `training_time_ms` + +**Parameters:** `behavior`, `classifier`, `window_size`, `balance_labels`, `symmetric_behavior`, `distance_unit`, `cv_grouping_strategy`, and `cv_grouping_regex` (only for the `filename` strategy). + +**Tags:** auto-derived `behavior`, `classifier`, `cv_grouping_strategy`, and `jabs_git` (the short git SHA of the JABS checkout, when available). Any `--mlflow-tag` entries are merged on top, so a user tag wins over an auto tag with the same key. + +**Artifact:** the generated training report file, unless `--mlflow-no-report` is passed. + +#### Free-form tags + +Add searchable tags to the run with `--mlflow-tag`, which is repeatable: + +```bash +jabs-cli cross-validation /path/to/project --behavior grooming --mlflow settings.env \ + --mlflow-tag purpose=baseline --mlflow-tag cohort=2024Q1 +``` + +Each entry is `KEY=VALUE`; only the first `=` splits the entry, so values may contain `=`. + +#### Skipping the report artifact + +To log metrics and parameters only (no report upload): + +```bash +jabs-cli cross-validation /path/to/project --behavior grooming --mlflow --mlflow-no-report +``` + +#### Exit codes and failure handling + +MLflow logging happens **after** the cross-validation results are printed and the report is saved, so a logging failure never costs you the results: + +- **Extra not installed:** a warning is printed, the MLflow options are ignored, and the command exits `0`. +- **Logging fails** (for example the tracking server is unreachable or authentication fails): the results and report are preserved, a warning is printed, and the command exits with code **`3`** — distinct from the generic error code `1`, so automation can tell a push failure apart from a cross-validation failure. + +#### Full example + +```bash +# settings.env contains: +# MLFLOW_TRACKING_URI=https://mlflow.example.org +# MLFLOW_EXPERIMENT_NAME=mouse-grooming +jabs-cli cross-validation /path/to/project \ + --behavior grooming \ + --grouping-strategy individual \ + --classifier xgboost \ + --mlflow settings.env \ + --mlflow-tag purpose=baseline +``` diff --git a/src/jabs/resources/docs/user_guide/cli-tools.md b/src/jabs/resources/docs/user_guide/cli-tools.md index 6799bb6f..80bb2f73 100644 --- a/src/jabs/resources/docs/user_guide/cli-tools.md +++ b/src/jabs/resources/docs/user_guide/cli-tools.md @@ -530,3 +530,183 @@ jabs-cli convert-parquet session_poses.parquet \ --num-frames 3600 \ --out-dir /path/to/output ``` + +## jabs-cli cross-validation + +The `jabs-cli cross-validation` command runs leave-one-group-out cross-validation for a single behavior in a JABS project, then trains a final model on all labeled data to report feature importance. It prints per-iteration metrics to the console and writes a training report file (the same report produced by the GUI). Use it to estimate how well a classifier generalizes before committing to a trained model. + +Features must already be computed for the project (for example via [`jabs-init`](#jabs-init)); if they are missing this command will compute them, which can be slow. + +**Usage:** + +```bash +jabs-cli cross-validation DIRECTORY --behavior BEHAVIOR \ + [-k SPLITS] \ + [--grouping-strategy {video|individual|filename}] \ + [--grouping-pattern REGEX] \ + [--classifier {catboost|random_forest|xgboost}] \ + [--report-file FILE] \ + [--mlflow [ENV_FILE]] [--mlflow-tag KEY=VALUE] [--mlflow-no-report] +``` + +- `DIRECTORY`: Path to the JABS project directory. +- `--behavior BEHAVIOR` (required): Behavior to evaluate. Quote it if it contains spaces; must match an existing behavior in the project. +- `-k SPLITS`: Number of cross-validation iterations. `0` (the default) uses the maximum number of splits supported by the data and grouping strategy. +- `--grouping-strategy {video|individual|filename}`: How labeled frames are grouped into cross-validation folds (see [Grouping strategies](#grouping-strategies)). If omitted, the project's saved setting is used. +- `--grouping-pattern REGEX`: Regular expression applied to each video filename to derive a grouping key. Only used with `--grouping-strategy filename`. If omitted, the pattern saved in the project is used. +- `--classifier {catboost|random_forest|xgboost}`: Classifier to evaluate. Defaults to `xgboost`. The available choices depend on which classifier libraries are installed; see [Classifier Types](classifier-types.md). +- `--report-file FILE`: Where to write the training report. The format is chosen by extension: `.md` (Markdown) or `.json` (JSON). If omitted, a timestamped Markdown file is written to the current directory (`__training_report.md`). +- `--mlflow`, `--mlflow-tag`, `--mlflow-no-report`: Optional MLflow logging (see [MLflow logging](#mlflow-logging)). + +### Grouping strategies + +Cross-validation holds out one *group* of labeled data per iteration and trains on the rest, so groups define what "generalization" means for the score. JABS supports three strategies: + +| Strategy | Each group is... | Use when | +|---|---|---| +| `individual` | one (video, identity) pair | you want to measure generalization across individual animals | +| `video` | one whole video (all identities in it) | you want to measure generalization across videos/sessions | +| `filename` | all videos whose filename yields the same key under `--grouping-pattern` | videos from the same cage/cohort/day share a filename component and should not be split across train and test | + +For the `filename` strategy, the pattern is applied with `re.search`, so it matches anywhere in the filename. If the pattern has a capturing group, the first captured group is the key; otherwise the whole match is the key. Videos that do not match the pattern are placed in their own single-video group. For example, `--grouping-pattern '^(\w+?)_'` groups `cage12_2024-01-01.mp4` and `cage12_2024-01-02.mp4` together under the key `cage12`. + +### Training report + +The report (and the console output) include: + +- Per-iteration accuracy, precision and recall for both classes, and F1 for the behavior class, plus the held-out test group label for each iteration. +- The top features (by importance) from a final model trained on all labeled data. +- Labeled frame and bout counts, the window size, distance unit, classifier type, and the grouping strategy/pattern used. + +**Examples:** + +```bash +# Cross-validate "grooming" with default settings (project grouping, all splits, xgboost) +jabs-cli cross-validation /path/to/project --behavior grooming + +# 5-fold, grouped by individual animal, with a CatBoost classifier +jabs-cli cross-validation /path/to/project --behavior grooming \ + -k 5 --grouping-strategy individual --classifier catboost + +# Group videos by a shared filename prefix and write a JSON report +jabs-cli cross-validation /path/to/project --behavior grooming \ + --grouping-strategy filename --grouping-pattern '^(\w+?)_' \ + --report-file grooming_cv.json +``` + +### MLflow logging + +The cross-validation command can optionally log each run to an [MLflow](https://mlflow.org/) tracking server, recording aggregate metrics, run parameters, descriptive tags, and the training report as an artifact. This is opt-in and off by default. + +#### Installing the MLflow extra + +MLflow is an optional dependency. Install it with the `mlflow` extra: + +```bash +pip install 'jabs-behavior-classifier[mlflow]' +``` + +If you request MLflow logging without the extra installed, the command prints a warning, ignores the MLflow options, and still runs the cross-validation and writes the report (it exits `0`). + +#### Enabling logging + +Add the `--mlflow` flag: + +```bash +# Use connection settings from the ambient environment +jabs-cli cross-validation /path/to/project --behavior grooming --mlflow + +# Use connection settings from a .env file +jabs-cli cross-validation /path/to/project --behavior grooming --mlflow settings.env +``` + +`--mlflow` optionally takes the path to a `.env` file. With no path, connection settings are read from the current environment. + +#### Connection configuration + +Connection details — tracking server URI, experiment, authentication, TLS — are **not** passed as command-line options. They come from standard `MLFLOW_*` environment variables, either exported in your shell or written to the `.env` file you pass to `--mlflow`. Only keys beginning with `MLFLOW_` are read from the `.env` file; everything else is ignored. + +Common variables: + +| Variable | Purpose | +|---|---| +| `MLFLOW_TRACKING_URI` | URL (or local path) of the tracking server, e.g. `https://mlflow.example.org` | +| `MLFLOW_EXPERIMENT_NAME` | Name of the experiment the run is logged under | +| `MLFLOW_TRACKING_USERNAME` / `MLFLOW_TRACKING_PASSWORD` | HTTP basic-auth credentials | +| `MLFLOW_TRACKING_TOKEN` | Bearer-token auth (alternative to username/password) | + +Example `.env` file: + +``` +MLFLOW_TRACKING_URI=https://mlflow.example.org +MLFLOW_EXPERIMENT_NAME=mouse-grooming +MLFLOW_TRACKING_USERNAME=jabs +MLFLOW_TRACKING_PASSWORD=hunter2 +``` + +#### Selecting the experiment + +The run is logged under the experiment named by `MLFLOW_EXPERIMENT_NAME` (or `MLFLOW_EXPERIMENT_ID`). If neither is set, MLflow's built-in **Default** experiment is used. There is no dedicated command-line option for the experiment; set the environment variable in the `.env` file or your shell: + +```bash +export MLFLOW_EXPERIMENT_NAME=mouse-grooming +jabs-cli cross-validation /path/to/project --behavior grooming --mlflow +``` + +#### What gets logged + +Each invocation creates one MLflow run named `-cv-`. + +**Metrics** (aggregated across cross-validation iterations): + +- `cv_accuracy_mean`, `cv_accuracy_std` +- `cv_precision_behavior_mean` / `_std`, `cv_recall_behavior_mean` / `_std`, `cv_f1_behavior_mean` / `_std` +- `cv_iterations` — number of folds run +- `frames_behavior`, `frames_not_behavior`, `bouts_behavior`, `bouts_not_behavior` — dataset composition +- `training_time_ms` + +**Parameters:** `behavior`, `classifier`, `window_size`, `balance_labels`, `symmetric_behavior`, `distance_unit`, `cv_grouping_strategy`, and `cv_grouping_regex` (only for the `filename` strategy). + +**Tags:** auto-derived `behavior`, `classifier`, `cv_grouping_strategy`, and `jabs_git` (the short git SHA of the JABS checkout, when available). Any `--mlflow-tag` entries are merged on top, so a user tag wins over an auto tag with the same key. + +**Artifact:** the generated training report file, unless `--mlflow-no-report` is passed. + +#### Free-form tags + +Add searchable tags to the run with `--mlflow-tag`, which is repeatable: + +```bash +jabs-cli cross-validation /path/to/project --behavior grooming --mlflow settings.env \ + --mlflow-tag purpose=baseline --mlflow-tag cohort=2024Q1 +``` + +Each entry is `KEY=VALUE`; only the first `=` splits the entry, so values may contain `=`. + +#### Skipping the report artifact + +To log metrics and parameters only (no report upload): + +```bash +jabs-cli cross-validation /path/to/project --behavior grooming --mlflow --mlflow-no-report +``` + +#### Exit codes and failure handling + +MLflow logging happens **after** the cross-validation results are printed and the report is saved, so a logging failure never costs you the results: + +- **Extra not installed:** a warning is printed, the MLflow options are ignored, and the command exits `0`. +- **Logging fails** (for example the tracking server is unreachable or authentication fails): the results and report are preserved, a warning is printed, and the command exits with code **`3`** — distinct from the generic error code `1`, so automation can tell a push failure apart from a cross-validation failure. + +#### Full example + +```bash +# settings.env contains: +# MLFLOW_TRACKING_URI=https://mlflow.example.org +# MLFLOW_EXPERIMENT_NAME=mouse-grooming +jabs-cli cross-validation /path/to/project \ + --behavior grooming \ + --grouping-strategy individual \ + --classifier xgboost \ + --mlflow settings.env \ + --mlflow-tag purpose=baseline +``` From 94b4c8272bf3cb13b08391c663533a1c5bd36871 Mon Sep 17 00:00:00 2001 From: Glen Beane <356266+gbeane@users.noreply.github.com> Date: Tue, 23 Jun 2026 20:35:05 -0400 Subject: [PATCH 4/7] Log cross-validation runs to a per-behavior MLflow experiment --- docs/user-guide/cli-tools.md | 31 ++++++++++----- src/jabs/classifier/mlflow_logging.py | 39 ++++++++++++++++++- .../resources/docs/user_guide/cli-tools.md | 31 ++++++++++----- src/jabs/scripts/cli/cli.py | 12 ++++++ src/jabs/scripts/cli/cross_validation.py | 4 ++ tests/classifier/test_mlflow_logging.py | 35 +++++++++++++++++ tests/scripts/test_cross_validation_cli.py | 14 +++++++ 7 files changed, 147 insertions(+), 19 deletions(-) diff --git a/docs/user-guide/cli-tools.md b/docs/user-guide/cli-tools.md index 80bb2f73..33a43bc6 100644 --- a/docs/user-guide/cli-tools.md +++ b/docs/user-guide/cli-tools.md @@ -628,31 +628,44 @@ Connection details — tracking server URI, experiment, authentication, TLS — Common variables: -| Variable | Purpose | -|---|---| -| `MLFLOW_TRACKING_URI` | URL (or local path) of the tracking server, e.g. `https://mlflow.example.org` | -| `MLFLOW_EXPERIMENT_NAME` | Name of the experiment the run is logged under | -| `MLFLOW_TRACKING_USERNAME` / `MLFLOW_TRACKING_PASSWORD` | HTTP basic-auth credentials | -| `MLFLOW_TRACKING_TOKEN` | Bearer-token auth (alternative to username/password) | +| Variable | Purpose | +|---------------------------------------------------------|---------------------------------------------------------------------------------------------------| +| `MLFLOW_TRACKING_URI` | URL (or local path) of the tracking server, e.g. `https://mlflow.example.org` | +| `MLFLOW_EXPERIMENT_NAME` | Overrides the default experiment name (see [Selecting the experiment](#selecting-the-experiment)) | +| `MLFLOW_TRACKING_USERNAME` / `MLFLOW_TRACKING_PASSWORD` | HTTP basic-auth credentials | +| `MLFLOW_TRACKING_TOKEN` | Bearer-token auth (alternative to username/password) | Example `.env` file: ``` MLFLOW_TRACKING_URI=https://mlflow.example.org -MLFLOW_EXPERIMENT_NAME=mouse-grooming MLFLOW_TRACKING_USERNAME=jabs MLFLOW_TRACKING_PASSWORD=hunter2 ``` #### Selecting the experiment -The run is logged under the experiment named by `MLFLOW_EXPERIMENT_NAME` (or `MLFLOW_EXPERIMENT_ID`). If neither is set, MLflow's built-in **Default** experiment is used. There is no dedicated command-line option for the experiment; set the environment variable in the `.env` file or your shell: +Each behavior is logged to its **own experiment** by default, named `jabs-` (for example `jabs-grooming`). This keeps comparisons meaningful: an experiment's runs table is effectively a leaderboard, and you want to rank runs of the *same* behavior over time rather than mix behaviors, whose metrics are not comparable. The experiment is created automatically if it does not exist. + +To override the experiment name, in order of precedence: + +1. `--mlflow-experiment NAME` (highest) — use a specific experiment for this run. +2. `MLFLOW_EXPERIMENT_NAME` (in your shell or the `.env` file). +3. The default `jabs-`. ```bash -export MLFLOW_EXPERIMENT_NAME=mouse-grooming +# Default: logs to experiment "jabs-grooming" jabs-cli cross-validation /path/to/project --behavior grooming --mlflow + +# Override the experiment for this run +jabs-cli cross-validation /path/to/project --behavior grooming --mlflow \ + --mlflow-experiment grooming-hyperparam-sweep ``` +#### Comparing runs (leaderboard) + +The aggregate scores below (`cv_f1_behavior_mean`, `cv_accuracy_mean`, etc.) are logged as MLflow **metrics**, so an experiment's runs table doubles as a leaderboard — sort by the `cv_f1_behavior_mean` column to rank a behavior's runs by mean F1. The full per-fold breakdown rides along as the training-report artifact. + #### What gets logged Each invocation creates one MLflow run named `-cv-`. diff --git a/src/jabs/classifier/mlflow_logging.py b/src/jabs/classifier/mlflow_logging.py index 5745e866..4ece7f08 100644 --- a/src/jabs/classifier/mlflow_logging.py +++ b/src/jabs/classifier/mlflow_logging.py @@ -233,11 +233,37 @@ def build_tags(report_data: TrainingReportData) -> dict[str, str]: return {key: value for key, value in tags.items() if value} +def resolve_experiment_name(report_data: TrainingReportData, experiment_name: str | None) -> str: + """Resolve the MLflow experiment name for a cross-validation run. + + Each behavior is logged to its own experiment by default so that runs of the + same behavior are compared together (and not mixed with other behaviors, whose + metrics are not comparable). Precedence, highest first: + + 1. ``experiment_name`` -- an explicit override (e.g. from ``--mlflow-experiment``). + 2. The ``MLFLOW_EXPERIMENT_NAME`` environment variable, if set. + 3. The default ``jabs-``. + + Args: + report_data: Completed training report data (supplies the behavior name). + experiment_name: Explicit override, or None to fall back to the env var/default. + + Returns: + The experiment name to use. + """ + return ( + experiment_name + or os.environ.get("MLFLOW_EXPERIMENT_NAME") + or f"jabs-{report_data.behavior_name}" + ) + + def log_cross_validation_to_mlflow( *, report_data: TrainingReportData, report_file: Path | None = None, env_file: Path | None = None, + experiment_name: str | None = None, run_name: str | None = None, tags: dict[str, str] | None = None, log_report_artifact: bool = True, @@ -255,6 +281,10 @@ def log_cross_validation_to_mlflow( False. env_file: Optional ``.env`` file with ``MLFLOW_*`` connection settings. If None, connection config comes from the ambient environment. + experiment_name: Explicit MLflow experiment name. If None, the experiment is + resolved by :func:`resolve_experiment_name` (``MLFLOW_EXPERIMENT_NAME`` env + var, else the default ``jabs-``). The experiment is created if it + does not exist. run_name: MLflow run name. Defaults to ``-cv-``, where the timestamp is the report's completion time (so it matches the saved report). tags: Caller-supplied run tags; merged over the auto-derived tags (so a @@ -278,10 +308,17 @@ def log_cross_validation_to_mlflow( load_env_file(env_file) + resolved_experiment = resolve_experiment_name(report_data, experiment_name) + mlflow.set_experiment(resolved_experiment) + if run_name is None: run_name = f"{report_data.behavior_name}-cv-{report_data.timestamp:%Y%m%d-%H%M%S}" - logger.info("Logging cross-validation results to MLflow run %r", run_name) + logger.info( + "Logging cross-validation results to MLflow experiment %r run %r", + resolved_experiment, + run_name, + ) with mlflow.start_run(run_name=run_name) as run: metrics = aggregate_cv_metrics(report_data) for key, value in metrics.items(): diff --git a/src/jabs/resources/docs/user_guide/cli-tools.md b/src/jabs/resources/docs/user_guide/cli-tools.md index 80bb2f73..33a43bc6 100644 --- a/src/jabs/resources/docs/user_guide/cli-tools.md +++ b/src/jabs/resources/docs/user_guide/cli-tools.md @@ -628,31 +628,44 @@ Connection details — tracking server URI, experiment, authentication, TLS — Common variables: -| Variable | Purpose | -|---|---| -| `MLFLOW_TRACKING_URI` | URL (or local path) of the tracking server, e.g. `https://mlflow.example.org` | -| `MLFLOW_EXPERIMENT_NAME` | Name of the experiment the run is logged under | -| `MLFLOW_TRACKING_USERNAME` / `MLFLOW_TRACKING_PASSWORD` | HTTP basic-auth credentials | -| `MLFLOW_TRACKING_TOKEN` | Bearer-token auth (alternative to username/password) | +| Variable | Purpose | +|---------------------------------------------------------|---------------------------------------------------------------------------------------------------| +| `MLFLOW_TRACKING_URI` | URL (or local path) of the tracking server, e.g. `https://mlflow.example.org` | +| `MLFLOW_EXPERIMENT_NAME` | Overrides the default experiment name (see [Selecting the experiment](#selecting-the-experiment)) | +| `MLFLOW_TRACKING_USERNAME` / `MLFLOW_TRACKING_PASSWORD` | HTTP basic-auth credentials | +| `MLFLOW_TRACKING_TOKEN` | Bearer-token auth (alternative to username/password) | Example `.env` file: ``` MLFLOW_TRACKING_URI=https://mlflow.example.org -MLFLOW_EXPERIMENT_NAME=mouse-grooming MLFLOW_TRACKING_USERNAME=jabs MLFLOW_TRACKING_PASSWORD=hunter2 ``` #### Selecting the experiment -The run is logged under the experiment named by `MLFLOW_EXPERIMENT_NAME` (or `MLFLOW_EXPERIMENT_ID`). If neither is set, MLflow's built-in **Default** experiment is used. There is no dedicated command-line option for the experiment; set the environment variable in the `.env` file or your shell: +Each behavior is logged to its **own experiment** by default, named `jabs-` (for example `jabs-grooming`). This keeps comparisons meaningful: an experiment's runs table is effectively a leaderboard, and you want to rank runs of the *same* behavior over time rather than mix behaviors, whose metrics are not comparable. The experiment is created automatically if it does not exist. + +To override the experiment name, in order of precedence: + +1. `--mlflow-experiment NAME` (highest) — use a specific experiment for this run. +2. `MLFLOW_EXPERIMENT_NAME` (in your shell or the `.env` file). +3. The default `jabs-`. ```bash -export MLFLOW_EXPERIMENT_NAME=mouse-grooming +# Default: logs to experiment "jabs-grooming" jabs-cli cross-validation /path/to/project --behavior grooming --mlflow + +# Override the experiment for this run +jabs-cli cross-validation /path/to/project --behavior grooming --mlflow \ + --mlflow-experiment grooming-hyperparam-sweep ``` +#### Comparing runs (leaderboard) + +The aggregate scores below (`cv_f1_behavior_mean`, `cv_accuracy_mean`, etc.) are logged as MLflow **metrics**, so an experiment's runs table doubles as a leaderboard — sort by the `cv_f1_behavior_mean` column to rank a behavior's runs by mean F1. The full per-fold breakdown rides along as the training-report artifact. + #### What gets logged Each invocation creates one MLflow run named `-cv-`. diff --git a/src/jabs/scripts/cli/cli.py b/src/jabs/scripts/cli/cli.py index 3796daa0..e5bde322 100644 --- a/src/jabs/scripts/cli/cli.py +++ b/src/jabs/scripts/cli/cli.py @@ -370,6 +370,16 @@ def prune(ctx: click.Context, directory: Path, behavior: str | None): "behavior unchanged. Requires the optional 'mlflow' extra " "(pip install 'jabs-behavior-classifier[mlflow]').", ) +@click.option( + "--mlflow-experiment", + "mlflow_experiment", + type=str, + default=None, + metavar="NAME", + help="With --mlflow, the MLflow experiment to log the run under. If not provided, " + "defaults to the MLFLOW_EXPERIMENT_NAME environment variable, else 'jabs-' " + "(one experiment per behavior). No-op without --mlflow.", +) @click.option( "--mlflow-tag", "mlflow_tags", @@ -396,6 +406,7 @@ def cross_validation( classifier: str, report_file: Path | None, mlflow_env: str | None, + mlflow_experiment: str | None, mlflow_tags: tuple[str, ...], mlflow_no_report: bool, ): @@ -445,6 +456,7 @@ def cross_validation( grouping_regex=grouping_pattern, mlflow_enabled=mlflow_enabled, mlflow_env_file=mlflow_env_file, + mlflow_experiment=mlflow_experiment, mlflow_tags=parsed_mlflow_tags, mlflow_log_report=not mlflow_no_report, ) diff --git a/src/jabs/scripts/cli/cross_validation.py b/src/jabs/scripts/cli/cross_validation.py index 1a60df4f..ba27efc5 100644 --- a/src/jabs/scripts/cli/cross_validation.py +++ b/src/jabs/scripts/cli/cross_validation.py @@ -32,6 +32,7 @@ def run_cross_validation( grouping_regex: str | None = None, mlflow_enabled: bool = False, mlflow_env_file: Path | None = None, + mlflow_experiment: str | None = None, mlflow_tags: dict[str, str] | None = None, mlflow_log_report: bool = True, ) -> None: @@ -56,6 +57,8 @@ def run_cross_validation( 'mlflow' dependency is installed (the CLI checks this and warns otherwise). mlflow_env_file (Path | None): Optional ``.env`` file with ``MLFLOW_*`` connection settings. If None, connection config comes from the ambient environment. + mlflow_experiment (str | None): Explicit MLflow experiment name. If None, defaults + to the ``MLFLOW_EXPERIMENT_NAME`` env var, else ``jabs-``. mlflow_tags (dict[str, str] | None): Optional free-form MLflow run tags, merged over the auto-derived tags. mlflow_log_report (bool): Whether to upload the training report as an MLflow @@ -256,6 +259,7 @@ def progress_callback(): report_data=training_data, report_file=report_file, env_file=mlflow_env_file, + experiment_name=mlflow_experiment, tags=mlflow_tags, log_report_artifact=mlflow_log_report, ) diff --git a/tests/classifier/test_mlflow_logging.py b/tests/classifier/test_mlflow_logging.py index ce3e3b94..60d28ee0 100644 --- a/tests/classifier/test_mlflow_logging.py +++ b/tests/classifier/test_mlflow_logging.py @@ -24,6 +24,7 @@ log_cross_validation_to_mlflow, mlflow_available, parse_kv_tags, + resolve_experiment_name, ) from jabs.classifier.training_report import BinaryCVResult, TrainingReportData from jabs.core.enums import CrossValidationGroupingStrategy @@ -88,11 +89,15 @@ class _FakeMlflow: def __init__(self) -> None: self.run_name: str | None = None + self.experiment: str | None = None self.metrics: dict[str, float] = {} self.params: dict[str, str] = {} self.tags: dict[str, str] = {} self.artifacts: list[str] = [] + def set_experiment(self, name: str) -> None: + self.experiment = name + def start_run(self, run_name: str | None = None) -> _FakeRun: self.run_name = run_name return _FakeRun() @@ -250,6 +255,33 @@ def test_build_tags_omits_empty( assert "jabs_git" not in tags +# --------------------------------------------------------------------------- # +# resolve_experiment_name +# --------------------------------------------------------------------------- # +def test_resolve_experiment_name_default( + binary_report: TrainingReportData, monkeypatch: pytest.MonkeyPatch +) -> None: + """With no override or env var, the experiment defaults to jabs-.""" + monkeypatch.delenv("MLFLOW_EXPERIMENT_NAME", raising=False) + assert resolve_experiment_name(binary_report, None) == "jabs-Walk" + + +def test_resolve_experiment_name_env_var( + binary_report: TrainingReportData, monkeypatch: pytest.MonkeyPatch +) -> None: + """MLFLOW_EXPERIMENT_NAME overrides the default when no explicit name is given.""" + monkeypatch.setenv("MLFLOW_EXPERIMENT_NAME", "shared-experiment") + assert resolve_experiment_name(binary_report, None) == "shared-experiment" + + +def test_resolve_experiment_name_explicit_wins( + binary_report: TrainingReportData, monkeypatch: pytest.MonkeyPatch +) -> None: + """An explicit name takes precedence over both the env var and the default.""" + monkeypatch.setenv("MLFLOW_EXPERIMENT_NAME", "shared-experiment") + assert resolve_experiment_name(binary_report, "custom") == "custom" + + # --------------------------------------------------------------------------- # # log_cross_validation_to_mlflow # --------------------------------------------------------------------------- # @@ -261,6 +293,7 @@ def test_log_cross_validation_to_mlflow( """A run logs metrics/params/tags/artifact and returns run id + tracking URI.""" fake = _FakeMlflow() monkeypatch.setitem(sys.modules, "mlflow", fake) + monkeypatch.delenv("MLFLOW_EXPERIMENT_NAME", raising=False) monkeypatch.setattr("jabs.classifier.mlflow_logging._git_sha", lambda: "abc1234") report_file = tmp_path / "report.md" report_file.write_text("# report") @@ -273,6 +306,8 @@ def test_log_cross_validation_to_mlflow( assert run_id == "run-123" assert tracking_uri == "file:///tmp/mlruns" + # default per-behavior experiment + assert fake.experiment == "jabs-Walk" assert fake.run_name == "Walk-cv-20260623-120000" assert fake.metrics["cv_accuracy_mean"] == pytest.approx(0.85) assert fake.params["behavior"] == "Walk" diff --git a/tests/scripts/test_cross_validation_cli.py b/tests/scripts/test_cross_validation_cli.py index 3df47573..84f4997b 100644 --- a/tests/scripts/test_cross_validation_cli.py +++ b/tests/scripts/test_cross_validation_cli.py @@ -168,6 +168,20 @@ def test_mlflow_no_report_flag( assert run_cv_spy.call_args.kwargs["mlflow_log_report"] is False +def test_mlflow_experiment_forwarded( + tmp_path: Path, run_cv_spy: mock.Mock, mlflow_installed: None +) -> None: + """--mlflow-experiment is forwarded; default is None (resolved downstream).""" + result = _invoke(tmp_path, "--mlflow", "--mlflow-experiment", "my-experiment") + assert result.exit_code == 0, result.output + assert run_cv_spy.call_args.kwargs["mlflow_experiment"] == "my-experiment" + + run_cv_spy.reset_mock() + result = _invoke(tmp_path, "--mlflow") + assert result.exit_code == 0, result.output + assert run_cv_spy.call_args.kwargs["mlflow_experiment"] is None + + def test_invalid_mlflow_tag_rejected(tmp_path: Path, run_cv_spy: mock.Mock) -> None: """A malformed --mlflow-tag fails before run_cross_validation is called.""" result = _invoke(tmp_path, "--mlflow", "--mlflow-tag", "noequals") From 2de35e440edbacaca018208df68a870c60b7bb5f Mon Sep 17 00:00:00 2001 From: Glen Beane <356266+gbeane@users.noreply.github.com> Date: Tue, 23 Jun 2026 21:08:31 -0400 Subject: [PATCH 5/7] Address PR review: scope MLflow option parsing, fix docstring/docs/error wrapping --- docs/user-guide/cli-tools.md | 4 ++-- src/jabs/classifier/mlflow_logging.py | 11 +++++----- .../resources/docs/user_guide/cli-tools.md | 4 ++-- src/jabs/scripts/cli/cli.py | 17 ++++++++++----- src/jabs/scripts/cli/cross_validation.py | 4 ++++ tests/scripts/test_cross_validation_cli.py | 21 ++++++++++++++++--- 6 files changed, 44 insertions(+), 17 deletions(-) diff --git a/docs/user-guide/cli-tools.md b/docs/user-guide/cli-tools.md index 33a43bc6..d6673230 100644 --- a/docs/user-guide/cli-tools.md +++ b/docs/user-guide/cli-tools.md @@ -546,7 +546,7 @@ jabs-cli cross-validation DIRECTORY --behavior BEHAVIOR \ [--grouping-pattern REGEX] \ [--classifier {catboost|random_forest|xgboost}] \ [--report-file FILE] \ - [--mlflow [ENV_FILE]] [--mlflow-tag KEY=VALUE] [--mlflow-no-report] + [--mlflow [ENV_FILE]] [--mlflow-experiment NAME] [--mlflow-tag KEY=VALUE] [--mlflow-no-report] ``` - `DIRECTORY`: Path to the JABS project directory. @@ -556,7 +556,7 @@ jabs-cli cross-validation DIRECTORY --behavior BEHAVIOR \ - `--grouping-pattern REGEX`: Regular expression applied to each video filename to derive a grouping key. Only used with `--grouping-strategy filename`. If omitted, the pattern saved in the project is used. - `--classifier {catboost|random_forest|xgboost}`: Classifier to evaluate. Defaults to `xgboost`. The available choices depend on which classifier libraries are installed; see [Classifier Types](classifier-types.md). - `--report-file FILE`: Where to write the training report. The format is chosen by extension: `.md` (Markdown) or `.json` (JSON). If omitted, a timestamped Markdown file is written to the current directory (`__training_report.md`). -- `--mlflow`, `--mlflow-tag`, `--mlflow-no-report`: Optional MLflow logging (see [MLflow logging](#mlflow-logging)). +- `--mlflow`, `--mlflow-experiment`, `--mlflow-tag`, `--mlflow-no-report`: Optional MLflow logging (see [MLflow logging](#mlflow-logging)). ### Grouping strategies diff --git a/src/jabs/classifier/mlflow_logging.py b/src/jabs/classifier/mlflow_logging.py index 4ece7f08..714fe1d7 100644 --- a/src/jabs/classifier/mlflow_logging.py +++ b/src/jabs/classifier/mlflow_logging.py @@ -5,11 +5,12 @@ configuration scalars as params, descriptive tags, and the generated training report as an artifact. -Connection configuration (tracking URI, experiment, auth, TLS) is **not** -hard-coded here; it is read from standard ``MLFLOW_*`` environment variables, -populated either from a ``.env`` file (see :func:`load_env_file`) or from the -ambient environment. The experiment is whatever ``MLFLOW_EXPERIMENT_NAME`` -names, falling back to MLflow's built-in "Default" experiment. +Connection configuration (tracking URI, auth, TLS) is **not** hard-coded here; +it is read from standard ``MLFLOW_*`` environment variables, populated either +from a ``.env`` file (see :func:`load_env_file`) or from the ambient +environment. Each run is logged to a per-behavior experiment by default +(``jabs-``); see :func:`resolve_experiment_name` for the override +precedence (explicit name, then ``MLFLOW_EXPERIMENT_NAME``, then the default). ``mlflow`` is an optional dependency. Install it with ``pip install 'jabs-behavior-classifier[mlflow]'`` (or, for a development diff --git a/src/jabs/resources/docs/user_guide/cli-tools.md b/src/jabs/resources/docs/user_guide/cli-tools.md index 33a43bc6..d6673230 100644 --- a/src/jabs/resources/docs/user_guide/cli-tools.md +++ b/src/jabs/resources/docs/user_guide/cli-tools.md @@ -546,7 +546,7 @@ jabs-cli cross-validation DIRECTORY --behavior BEHAVIOR \ [--grouping-pattern REGEX] \ [--classifier {catboost|random_forest|xgboost}] \ [--report-file FILE] \ - [--mlflow [ENV_FILE]] [--mlflow-tag KEY=VALUE] [--mlflow-no-report] + [--mlflow [ENV_FILE]] [--mlflow-experiment NAME] [--mlflow-tag KEY=VALUE] [--mlflow-no-report] ``` - `DIRECTORY`: Path to the JABS project directory. @@ -556,7 +556,7 @@ jabs-cli cross-validation DIRECTORY --behavior BEHAVIOR \ - `--grouping-pattern REGEX`: Regular expression applied to each video filename to derive a grouping key. Only used with `--grouping-strategy filename`. If omitted, the pattern saved in the project is used. - `--classifier {catboost|random_forest|xgboost}`: Classifier to evaluate. Defaults to `xgboost`. The available choices depend on which classifier libraries are installed; see [Classifier Types](classifier-types.md). - `--report-file FILE`: Where to write the training report. The format is chosen by extension: `.md` (Markdown) or `.json` (JSON). If omitted, a timestamped Markdown file is written to the current directory (`__training_report.md`). -- `--mlflow`, `--mlflow-tag`, `--mlflow-no-report`: Optional MLflow logging (see [MLflow logging](#mlflow-logging)). +- `--mlflow`, `--mlflow-experiment`, `--mlflow-tag`, `--mlflow-no-report`: Optional MLflow logging (see [MLflow logging](#mlflow-logging)). ### Grouping strategies diff --git a/src/jabs/scripts/cli/cli.py b/src/jabs/scripts/cli/cli.py index e5bde322..fc0885ff 100644 --- a/src/jabs/scripts/cli/cli.py +++ b/src/jabs/scripts/cli/cli.py @@ -426,11 +426,8 @@ def cross_validation( # --mlflow: absent -> None (disabled); bare flag -> "" (ambient env); # with a path -> that .env file. mlflow_enabled = mlflow_env is not None - mlflow_env_file = Path(mlflow_env) if mlflow_env else None - try: - parsed_mlflow_tags = parse_kv_tags(list(mlflow_tags)) - except ValueError as e: - raise click.ClickException(str(e)) from e + mlflow_env_file: Path | None = None + parsed_mlflow_tags: dict[str, str] = {} # If MLflow logging was requested but the optional 'mlflow' extra is not # installed, warn and ignore the MLflow options rather than failing -- the @@ -444,6 +441,16 @@ def cross_validation( ) mlflow_enabled = False + # Only interpret the other MLflow options when logging is actually enabled. + # They are documented as no-ops without --mlflow, so e.g. a malformed + # --mlflow-tag is ignored rather than failing the command. + if mlflow_enabled: + mlflow_env_file = Path(mlflow_env) if mlflow_env else None + try: + parsed_mlflow_tags = parse_kv_tags(list(mlflow_tags)) + except ValueError as e: + raise click.ClickException(str(e)) from e + try: classifier_type = ClassifierType[classifier.upper()] run_cross_validation( diff --git a/src/jabs/scripts/cli/cross_validation.py b/src/jabs/scripts/cli/cross_validation.py index ba27efc5..586ae54a 100644 --- a/src/jabs/scripts/cli/cross_validation.py +++ b/src/jabs/scripts/cli/cross_validation.py @@ -269,6 +269,10 @@ def progress_callback(): " (cross-validation results above and the saved report are unaffected)", style="yellow", ) + # Preserve an MlflowLoggingError raised by the logger (e.g. missing + # dependency); only wrap genuinely unexpected exceptions. + if isinstance(e, MlflowLoggingError): + raise raise MlflowLoggingError(str(e)) from e console.print( f"\nLogged cross-validation results to MLflow run {run_id} ({tracking_uri})", diff --git a/tests/scripts/test_cross_validation_cli.py b/tests/scripts/test_cross_validation_cli.py index 84f4997b..af827fa5 100644 --- a/tests/scripts/test_cross_validation_cli.py +++ b/tests/scripts/test_cross_validation_cli.py @@ -182,14 +182,26 @@ def test_mlflow_experiment_forwarded( assert run_cv_spy.call_args.kwargs["mlflow_experiment"] is None -def test_invalid_mlflow_tag_rejected(tmp_path: Path, run_cv_spy: mock.Mock) -> None: - """A malformed --mlflow-tag fails before run_cross_validation is called.""" +def test_invalid_mlflow_tag_rejected( + tmp_path: Path, run_cv_spy: mock.Mock, mlflow_installed: None +) -> None: + """A malformed --mlflow-tag fails before run_cross_validation when MLflow is enabled.""" result = _invoke(tmp_path, "--mlflow", "--mlflow-tag", "noequals") assert result.exit_code != 0 run_cv_spy.assert_not_called() +def test_mlflow_tag_ignored_without_mlflow(tmp_path: Path, run_cv_spy: mock.Mock) -> None: + """--mlflow-tag is a no-op (even if malformed) when --mlflow is not given.""" + result = _invoke(tmp_path, "--mlflow-tag", "noequals") + + assert result.exit_code == 0, result.output + run_cv_spy.assert_called_once() + assert run_cv_spy.call_args.kwargs["mlflow_enabled"] is False + assert run_cv_spy.call_args.kwargs["mlflow_tags"] == {} + + def test_mlflow_logging_failure_exits_with_code_3( tmp_path: Path, run_cv_spy: mock.Mock, mlflow_installed: None ) -> None: @@ -207,10 +219,13 @@ def test_mlflow_unavailable_warns_and_ignores( """When the mlflow extra is absent, --mlflow is ignored with a warning (exit 0).""" monkeypatch.setattr(cli_module, "mlflow_available", lambda: False) - result = _invoke(tmp_path, "--mlflow", "--mlflow-tag", "purpose=baseline") + # A malformed tag must NOT error here: the options are ignored when the extra + # is missing, so the tag is never parsed. + result = _invoke(tmp_path, "--mlflow", "--mlflow-tag", "noequals") assert result.exit_code == 0, result.output assert "not installed" in result.stderr # cross-validation still runs, but MLflow logging is disabled run_cv_spy.assert_called_once() assert run_cv_spy.call_args.kwargs["mlflow_enabled"] is False + assert run_cv_spy.call_args.kwargs["mlflow_tags"] == {} From 75f86a623a942b8fd9d9f3fa1255586d8618d355 Mon Sep 17 00:00:00 2001 From: Glen Beane <356266+gbeane@users.noreply.github.com> Date: Tue, 23 Jun 2026 21:33:41 -0400 Subject: [PATCH 6/7] Make load_env_file test robust to ambient environment --- tests/classifier/test_mlflow_logging.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/classifier/test_mlflow_logging.py b/tests/classifier/test_mlflow_logging.py index 60d28ee0..7f746f11 100644 --- a/tests/classifier/test_mlflow_logging.py +++ b/tests/classifier/test_mlflow_logging.py @@ -171,13 +171,18 @@ def test_load_env_file_applies_only_mlflow_keys( tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: """Only MLFLOW_* keys are applied to the environment; others are ignored.""" + # Use a unique, clearly-unrelated key and ensure it is absent from the + # (copied) environment up front, so the negative assertion below proves + # load_env_file did not add it rather than depending on the ambient env. + non_mlflow_key = "JABS_TEST_NON_MLFLOW_VAR" monkeypatch.setattr(os, "environ", dict(os.environ)) + os.environ.pop(non_mlflow_key, None) env_file = tmp_path / "mlflow.env" env_file.write_text( "# a comment\n" 'MLFLOW_TRACKING_URI="https://mlflow.example.org"\n' "export MLFLOW_EXPERIMENT_NAME=behaviors\n" - "OTHER_VAR=should-be-ignored\n" + f"{non_mlflow_key}=should-be-ignored\n" ) applied = load_env_file(env_file) @@ -188,7 +193,7 @@ def test_load_env_file_applies_only_mlflow_keys( } assert os.environ["MLFLOW_TRACKING_URI"] == "https://mlflow.example.org" assert os.environ["MLFLOW_EXPERIMENT_NAME"] == "behaviors" - assert "OTHER_VAR" not in os.environ + assert non_mlflow_key not in os.environ def test_load_env_file_missing_raises(tmp_path: Path) -> None: From 3cd1525b5e9ddff93216da3f90a3e8618fa89407 Mon Sep 17 00:00:00 2001 From: Glen Beane <356266+gbeane@users.noreply.github.com> Date: Wed, 24 Jun 2026 15:40:45 -0400 Subject: [PATCH 7/7] Fail fast when --mlflow requested but mlflow extra is not installed --- docs/user-guide/cli-tools.md | 4 ++-- src/jabs/resources/docs/user_guide/cli-tools.md | 4 ++-- src/jabs/scripts/cli/cli.py | 16 +++++++--------- tests/scripts/test_cross_validation_cli.py | 16 ++++++---------- 4 files changed, 17 insertions(+), 23 deletions(-) diff --git a/docs/user-guide/cli-tools.md b/docs/user-guide/cli-tools.md index d6673230..c7e7d160 100644 --- a/docs/user-guide/cli-tools.md +++ b/docs/user-guide/cli-tools.md @@ -606,7 +606,7 @@ MLflow is an optional dependency. Install it with the `mlflow` extra: pip install 'jabs-behavior-classifier[mlflow]' ``` -If you request MLflow logging without the extra installed, the command prints a warning, ignores the MLflow options, and still runs the cross-validation and writes the report (it exits `0`). +If you request MLflow logging (`--mlflow`) without the extra installed, the command fails immediately with an error and exits `1` (before running the cross-validation), so you can install the extra and re-run rather than discovering after a long run that nothing was logged. #### Enabling logging @@ -707,7 +707,7 @@ jabs-cli cross-validation /path/to/project --behavior grooming --mlflow --mlflow MLflow logging happens **after** the cross-validation results are printed and the report is saved, so a logging failure never costs you the results: -- **Extra not installed:** a warning is printed, the MLflow options are ignored, and the command exits `0`. +- **Extra not installed:** the command fails fast with an error and exits `1` before running the cross-validation, since logging was explicitly requested but cannot be honored. Install the extra (or drop `--mlflow`) and re-run. - **Logging fails** (for example the tracking server is unreachable or authentication fails): the results and report are preserved, a warning is printed, and the command exits with code **`3`** — distinct from the generic error code `1`, so automation can tell a push failure apart from a cross-validation failure. #### Full example diff --git a/src/jabs/resources/docs/user_guide/cli-tools.md b/src/jabs/resources/docs/user_guide/cli-tools.md index d6673230..c7e7d160 100644 --- a/src/jabs/resources/docs/user_guide/cli-tools.md +++ b/src/jabs/resources/docs/user_guide/cli-tools.md @@ -606,7 +606,7 @@ MLflow is an optional dependency. Install it with the `mlflow` extra: pip install 'jabs-behavior-classifier[mlflow]' ``` -If you request MLflow logging without the extra installed, the command prints a warning, ignores the MLflow options, and still runs the cross-validation and writes the report (it exits `0`). +If you request MLflow logging (`--mlflow`) without the extra installed, the command fails immediately with an error and exits `1` (before running the cross-validation), so you can install the extra and re-run rather than discovering after a long run that nothing was logged. #### Enabling logging @@ -707,7 +707,7 @@ jabs-cli cross-validation /path/to/project --behavior grooming --mlflow --mlflow MLflow logging happens **after** the cross-validation results are printed and the report is saved, so a logging failure never costs you the results: -- **Extra not installed:** a warning is printed, the MLflow options are ignored, and the command exits `0`. +- **Extra not installed:** the command fails fast with an error and exits `1` before running the cross-validation, since logging was explicitly requested but cannot be honored. Install the extra (or drop `--mlflow`) and re-run. - **Logging fails** (for example the tracking server is unreachable or authentication fails): the results and report are preserved, a warning is printed, and the command exits with code **`3`** — distinct from the generic error code `1`, so automation can tell a push failure apart from a cross-validation failure. #### Full example diff --git a/src/jabs/scripts/cli/cli.py b/src/jabs/scripts/cli/cli.py index fc0885ff..0e71487c 100644 --- a/src/jabs/scripts/cli/cli.py +++ b/src/jabs/scripts/cli/cli.py @@ -429,17 +429,15 @@ def cross_validation( mlflow_env_file: Path | None = None parsed_mlflow_tags: dict[str, str] = {} - # If MLflow logging was requested but the optional 'mlflow' extra is not - # installed, warn and ignore the MLflow options rather than failing -- the - # cross-validation still runs and the report is still produced. + # If MLflow logging was explicitly requested but the optional 'mlflow' extra + # is not installed, fail fast before running the (potentially long) + # cross-validation rather than silently producing a run with no logging. if mlflow_enabled and not mlflow_available(): - click.echo( - "Warning: MLflow logging was requested (--mlflow) but the optional 'mlflow' " - "dependency is not installed; ignoring MLflow options. Install it with " - "\"pip install 'jabs-behavior-classifier[mlflow]'\" to enable logging.", - err=True, + raise click.ClickException( + "MLflow logging was requested (--mlflow) but the optional 'mlflow' " + "dependency is not installed. Install it with " + "\"pip install 'jabs-behavior-classifier[mlflow]'\", or omit --mlflow." ) - mlflow_enabled = False # Only interpret the other MLflow options when logging is actually enabled. # They are documented as no-ops without --mlflow, so e.g. a malformed diff --git a/tests/scripts/test_cross_validation_cli.py b/tests/scripts/test_cross_validation_cli.py index af827fa5..6742a57c 100644 --- a/tests/scripts/test_cross_validation_cli.py +++ b/tests/scripts/test_cross_validation_cli.py @@ -213,19 +213,15 @@ def test_mlflow_logging_failure_exits_with_code_3( assert result.exit_code == 3 -def test_mlflow_unavailable_warns_and_ignores( +def test_mlflow_unavailable_fails_fast( tmp_path: Path, run_cv_spy: mock.Mock, monkeypatch: pytest.MonkeyPatch ) -> None: - """When the mlflow extra is absent, --mlflow is ignored with a warning (exit 0).""" + """When the mlflow extra is absent, --mlflow fails fast (exit 1) before running CV.""" monkeypatch.setattr(cli_module, "mlflow_available", lambda: False) - # A malformed tag must NOT error here: the options are ignored when the extra - # is missing, so the tag is never parsed. - result = _invoke(tmp_path, "--mlflow", "--mlflow-tag", "noequals") + result = _invoke(tmp_path, "--mlflow") - assert result.exit_code == 0, result.output + assert result.exit_code == 1, result.output assert "not installed" in result.stderr - # cross-validation still runs, but MLflow logging is disabled - run_cv_spy.assert_called_once() - assert run_cv_spy.call_args.kwargs["mlflow_enabled"] is False - assert run_cv_spy.call_args.kwargs["mlflow_tags"] == {} + # cross-validation must not run when an explicitly requested feature is unavailable + run_cv_spy.assert_not_called()