Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions docs/user-guide/cli-tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -530,3 +530,196 @@ jabs-cli convert-parquet session_poses.parquet \
--num-frames 3600 \
--out-dir /path/to/output
```

## jabs-cli cross-validation

The `jabs-cli cross-validation` command runs leave-one-group-out cross-validation for a single behavior in a JABS project, then trains a final model on all labeled data to report feature importance. It prints per-iteration metrics to the console and writes a training report file (the same report produced by the GUI). Use it to estimate how well a classifier generalizes before committing to a trained model.

Features must already be computed for the project (for example via [`jabs-init`](#jabs-init)); if they are missing this command will compute them, which can be slow.

**Usage:**

```bash
jabs-cli cross-validation DIRECTORY --behavior BEHAVIOR \
[-k SPLITS] \
[--grouping-strategy {video|individual|filename}] \
[--grouping-pattern REGEX] \
[--classifier {catboost|random_forest|xgboost}] \
[--report-file FILE] \
[--mlflow [ENV_FILE]] [--mlflow-experiment NAME] [--mlflow-tag KEY=VALUE] [--mlflow-no-report]
```

- `DIRECTORY`: Path to the JABS project directory.
- `--behavior BEHAVIOR` (required): Behavior to evaluate. Quote it if it contains spaces; must match an existing behavior in the project.
- `-k SPLITS`: Number of cross-validation iterations. `0` (the default) uses the maximum number of splits supported by the data and grouping strategy.
- `--grouping-strategy {video|individual|filename}`: How labeled frames are grouped into cross-validation folds (see [Grouping strategies](#grouping-strategies)). If omitted, the project's saved setting is used.
- `--grouping-pattern REGEX`: Regular expression applied to each video filename to derive a grouping key. Only used with `--grouping-strategy filename`. If omitted, the pattern saved in the project is used.
- `--classifier {catboost|random_forest|xgboost}`: Classifier to evaluate. Defaults to `xgboost`. The available choices depend on which classifier libraries are installed; see [Classifier Types](classifier-types.md).
- `--report-file FILE`: Where to write the training report. The format is chosen by extension: `.md` (Markdown) or `.json` (JSON). If omitted, a timestamped Markdown file is written to the current directory (`<behavior>_<timestamp>_training_report.md`).
- `--mlflow`, `--mlflow-experiment`, `--mlflow-tag`, `--mlflow-no-report`: Optional MLflow logging (see [MLflow logging](#mlflow-logging)).

### Grouping strategies

Cross-validation holds out one *group* of labeled data per iteration and trains on the rest, so groups define what "generalization" means for the score. JABS supports three strategies:

| Strategy | Each group is... | Use when |
|---|---|---|
| `individual` | one (video, identity) pair | you want to measure generalization across individual animals |
| `video` | one whole video (all identities in it) | you want to measure generalization across videos/sessions |
| `filename` | all videos whose filename yields the same key under `--grouping-pattern` | videos from the same cage/cohort/day share a filename component and should not be split across train and test |

For the `filename` strategy, the pattern is applied with `re.search`, so it matches anywhere in the filename. If the pattern has a capturing group, the first captured group is the key; otherwise the whole match is the key. Videos that do not match the pattern are placed in their own single-video group. For example, `--grouping-pattern '^(\w+?)_'` groups `cage12_2024-01-01.mp4` and `cage12_2024-01-02.mp4` together under the key `cage12`.

### Training report

The report (and the console output) include:

- Per-iteration accuracy, precision and recall for both classes, and F1 for the behavior class, plus the held-out test group label for each iteration.
- The top features (by importance) from a final model trained on all labeled data.
- Labeled frame and bout counts, the window size, distance unit, classifier type, and the grouping strategy/pattern used.

**Examples:**

```bash
# Cross-validate "grooming" with default settings (project grouping, all splits, xgboost)
jabs-cli cross-validation /path/to/project --behavior grooming

# 5-fold, grouped by individual animal, with a CatBoost classifier
jabs-cli cross-validation /path/to/project --behavior grooming \
-k 5 --grouping-strategy individual --classifier catboost

# Group videos by a shared filename prefix and write a JSON report
jabs-cli cross-validation /path/to/project --behavior grooming \
--grouping-strategy filename --grouping-pattern '^(\w+?)_' \
--report-file grooming_cv.json
```

### MLflow logging

The cross-validation command can optionally log each run to an [MLflow](https://mlflow.org/) tracking server, recording aggregate metrics, run parameters, descriptive tags, and the training report as an artifact. This is opt-in and off by default.

#### Installing the MLflow extra

MLflow is an optional dependency. Install it with the `mlflow` extra:

```bash
pip install 'jabs-behavior-classifier[mlflow]'
```

If you request MLflow logging (`--mlflow`) without the extra installed, the command fails immediately with an error and exits `1` (before running the cross-validation), so you can install the extra and re-run rather than discovering after a long run that nothing was logged.

#### Enabling logging

Add the `--mlflow` flag:

```bash
# Use connection settings from the ambient environment
jabs-cli cross-validation /path/to/project --behavior grooming --mlflow

# Use connection settings from a .env file
jabs-cli cross-validation /path/to/project --behavior grooming --mlflow settings.env
```

`--mlflow` optionally takes the path to a `.env` file. With no path, connection settings are read from the current environment.

#### Connection configuration

Connection details — tracking server URI, experiment, authentication, TLS — are **not** passed as command-line options. They come from standard `MLFLOW_*` environment variables, either exported in your shell or written to the `.env` file you pass to `--mlflow`. Only keys beginning with `MLFLOW_` are read from the `.env` file; everything else is ignored.

Common variables:

| Variable | Purpose |
|---------------------------------------------------------|---------------------------------------------------------------------------------------------------|
| `MLFLOW_TRACKING_URI` | URL (or local path) of the tracking server, e.g. `https://mlflow.example.org` |
| `MLFLOW_EXPERIMENT_NAME` | Overrides the default experiment name (see [Selecting the experiment](#selecting-the-experiment)) |
| `MLFLOW_TRACKING_USERNAME` / `MLFLOW_TRACKING_PASSWORD` | HTTP basic-auth credentials |
| `MLFLOW_TRACKING_TOKEN` | Bearer-token auth (alternative to username/password) |

Example `.env` file:

```
MLFLOW_TRACKING_URI=https://mlflow.example.org
MLFLOW_TRACKING_USERNAME=jabs
MLFLOW_TRACKING_PASSWORD=hunter2
```

#### Selecting the experiment

Each behavior is logged to its **own experiment** by default, named `jabs-<behavior>` (for example `jabs-grooming`). This keeps comparisons meaningful: an experiment's runs table is effectively a leaderboard, and you want to rank runs of the *same* behavior over time rather than mix behaviors, whose metrics are not comparable. The experiment is created automatically if it does not exist.

To override the experiment name, in order of precedence:

1. `--mlflow-experiment NAME` (highest) — use a specific experiment for this run.
2. `MLFLOW_EXPERIMENT_NAME` (in your shell or the `.env` file).
3. The default `jabs-<behavior>`.

```bash
# Default: logs to experiment "jabs-grooming"
jabs-cli cross-validation /path/to/project --behavior grooming --mlflow

# Override the experiment for this run
jabs-cli cross-validation /path/to/project --behavior grooming --mlflow \
--mlflow-experiment grooming-hyperparam-sweep
```

#### Comparing runs (leaderboard)

The aggregate scores below (`cv_f1_behavior_mean`, `cv_accuracy_mean`, etc.) are logged as MLflow **metrics**, so an experiment's runs table doubles as a leaderboard — sort by the `cv_f1_behavior_mean` column to rank a behavior's runs by mean F1. The full per-fold breakdown rides along as the training-report artifact.

#### What gets logged

Each invocation creates one MLflow run named `<behavior>-cv-<timestamp>`.

**Metrics** (aggregated across cross-validation iterations):

- `cv_accuracy_mean`, `cv_accuracy_std`
- `cv_precision_behavior_mean` / `_std`, `cv_recall_behavior_mean` / `_std`, `cv_f1_behavior_mean` / `_std`
- `cv_iterations` — number of folds run
- `frames_behavior`, `frames_not_behavior`, `bouts_behavior`, `bouts_not_behavior` — dataset composition
- `training_time_ms`

**Parameters:** `behavior`, `classifier`, `window_size`, `balance_labels`, `symmetric_behavior`, `distance_unit`, `cv_grouping_strategy`, and `cv_grouping_regex` (only for the `filename` strategy).

**Tags:** auto-derived `behavior`, `classifier`, `cv_grouping_strategy`, and `jabs_git` (the short git SHA of the JABS checkout, when available). Any `--mlflow-tag` entries are merged on top, so a user tag wins over an auto tag with the same key.

**Artifact:** the generated training report file, unless `--mlflow-no-report` is passed.

#### Free-form tags

Add searchable tags to the run with `--mlflow-tag`, which is repeatable:

```bash
jabs-cli cross-validation /path/to/project --behavior grooming --mlflow settings.env \
--mlflow-tag purpose=baseline --mlflow-tag cohort=2024Q1
```

Each entry is `KEY=VALUE`; only the first `=` splits the entry, so values may contain `=`.

#### Skipping the report artifact

To log metrics and parameters only (no report upload):

```bash
jabs-cli cross-validation /path/to/project --behavior grooming --mlflow --mlflow-no-report
```

#### Exit codes and failure handling

MLflow logging happens **after** the cross-validation results are printed and the report is saved, so a logging failure never costs you the results:

- **Extra not installed:** the command fails fast with an error and exits `1` before running the cross-validation, since logging was explicitly requested but cannot be honored. Install the extra (or drop `--mlflow`) and re-run.
- **Logging fails** (for example the tracking server is unreachable or authentication fails): the results and report are preserved, a warning is printed, and the command exits with code **`3`** — distinct from the generic error code `1`, so automation can tell a push failure apart from a cross-validation failure.

#### Full example

```bash
# settings.env contains:
# MLFLOW_TRACKING_URI=https://mlflow.example.org
# MLFLOW_EXPERIMENT_NAME=mouse-grooming
jabs-cli cross-validation /path/to/project \
--behavior grooming \
--grouping-strategy individual \
--classifier xgboost \
--mlflow settings.env \
--mlflow-tag purpose=baseline
```
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies = [
[project.optional-dependencies]
nwb = ["jabs-io[nwb]"]
yaml = ["pyyaml>=6.0.0"]
mlflow = ["mlflow>=3.8.1"]

[tool.uv.sources]
jabs-behavior = { workspace = true }
Expand Down
10 changes: 10 additions & 0 deletions src/jabs/classifier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

from .classifier import Classifier
from .cross_validation import run_leave_one_group_out_cv
from .mlflow_logging import (
MlflowLoggingError,
log_cross_validation_to_mlflow,
mlflow_available,
parse_kv_tags,
)
from .multi_class_classifier import MultiClassClassifier
from .protocols import ClassifierProtocol
from .training_report import (
Expand All @@ -23,10 +29,14 @@
"Classifier",
"ClassifierProtocol",
"CrossValidationResult",
"MlflowLoggingError",
"MultiClassCVResult",
"MultiClassClassifier",
"TrainingReportData",
"generate_markdown_report",
"log_cross_validation_to_mlflow",
"mlflow_available",
"parse_kv_tags",
"run_leave_one_group_out_cv",
"save_training_report",
]
Loading
Loading