Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion docs/user-guide/gui.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,20 @@ Project settings are saved within the project directory and apply only to the cu

| Setting | Description |
|-------------------------------|------------------------------------------------------------------|
| Cross Validation Grouping | Determines how cross-validation groups are defined. Options are "Individual Animal" (default) or "Video". |
| Cross Validation Grouping | Determines how cross-validation groups are defined. Options are "Individual Animal" (default), "Video", or "Filename Pattern". See [Cross-Validation Grouping](#cross-validation-grouping) below. |

As new settings are added, they will appear in this dialog with inline documentation.

### Cross-Validation Grouping

The **Cross Validation Grouping** setting controls how labeled data is partitioned into groups for leave-one-group-out cross-validation:

- **Individual Animal** (default): each group is a single animal identity within a single video.
- **Video**: each group is a single video; all identities within a video are held out together.
- **Filename Pattern**: groups are defined by a regular expression applied to each video's filename; the pattern may match anywhere in the filename. All videos whose filenames produce the same key are placed in the same group, which is useful for grouping videos by an identifier embedded in their names (for example, a cage ID). If the pattern contains a capture group that takes part in the match, the text captured by the first group is used as the key; otherwise the entire match is used. Videos that do not match the pattern are each placed in their own group.

When you select **Filename Pattern**, a text field appears for the regular expression. For example, if your videos are named like `cage_0042_2026-06-16.mp4`, the pattern `cage_(\d+)` extracts the cage number (`0042`) so that every video recorded from the same cage forms a single cross-validation group. A live preview below the field shows how your project's videos partition into groups under the current pattern (videos excluded from training are marked), so you can confirm the pattern before saving.


## Overlays

Expand Down
2 changes: 2 additions & 0 deletions packages/jabs-core/src/jabs/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

# settings keys for project settings stored in the project.json file
CV_GROUPING_KEY = "cv_grouping"
# regex used when CV_GROUPING_KEY is the "Filename Pattern" strategy
CV_GROUPING_REGEX_KEY = "cv_grouping_regex"
CLASSIFIER_MODE_KEY = "classifier_mode"
CACHE_FORMAT_KEY = "cache_format"

Expand Down
9 changes: 8 additions & 1 deletion packages/jabs-core/src/jabs/core/enums/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from .cache_format import CacheFormat
from .classifier_mode import DEFAULT_CLASSIFIER_MODE, ClassifierMode
from .classifier_types import ClassifierType
from .cv_grouping import DEFAULT_CV_GROUPING_STRATEGY, CrossValidationGroupingStrategy
from .cv_grouping import (
DEFAULT_CV_GROUPING_STRATEGY,
CrossValidationGroupingStrategy,
compile_grouping_regex,
filename_group_key,
)
from .inference import ConfidenceMetric, Method, SamplingStrategy
from .prediction_type import PredictionType
from .storage_format import StorageFormat
Expand All @@ -22,4 +27,6 @@
"ProjectDistanceUnit",
"SamplingStrategy",
"StorageFormat",
"compile_grouping_regex",
"filename_group_key",
]
53 changes: 53 additions & 0 deletions packages/jabs-core/src/jabs/core/enums/cv_grouping.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""Cross-validation grouping strategy enum and filename-pattern helpers."""

import re
from enum import Enum


Expand All @@ -10,6 +13,56 @@ class CrossValidationGroupingStrategy(str, Enum):

INDIVIDUAL = "Individual Animal"
VIDEO = "Video"
FILENAME_PATTERN = "Filename Pattern"


DEFAULT_CV_GROUPING_STRATEGY = CrossValidationGroupingStrategy.INDIVIDUAL


def compile_grouping_regex(regex: str) -> re.Pattern[str]:
    """Compile a filename-pattern cross-validation grouping regular expression.

    Args:
        regex: Regular expression used to extract a grouping key from a video
            filename.

    Returns:
        The compiled regular expression pattern.

    Raises:
        ValueError: If ``regex`` is empty (or otherwise falsy, such as ``None``),
            is of a type ``re.compile`` rejects, or is not a valid regular
            expression.
    """
    if not regex:
        raise ValueError("Filename pattern grouping requires a non-empty regular expression")
    try:
        return re.compile(regex)
    except (re.error, TypeError) as e:
        # ``re.compile`` raises TypeError for truthy non-string input (for example
        # a number). Surface it as ValueError too, so callers that only handle
        # ValueError for a bad pattern are not surprised by a different exception.
        raise ValueError(f"Invalid filename grouping pattern: {e}") from e


def filename_group_key(video_name: str, pattern: re.Pattern[str]) -> str:
    """Derive the cross-validation group key for a video from its filename.

    ``pattern`` is applied with :meth:`re.Pattern.search`, so a match may start
    anywhere inside ``video_name``. The key is chosen as follows:

    * no match: the filename itself, so the video forms its own group;
    * the pattern has at least one capturing group and the first one took part
      in the match: the text captured by that first group (capturing the digits
      of ``cage_1234.mp4`` yields ``"1234"``);
    * otherwise: the whole matched text (matching the ``cage_1234`` token
      yields ``"cage_1234"``).

    Args:
        video_name: Video filename to extract a grouping key from.
        pattern: Compiled regular expression (see :func:`compile_grouping_regex`).

    Returns:
        The grouping key string. Videos that yield the same key belong to the
        same cross-validation group.
    """
    found = pattern.search(video_name)
    if found is None:
        # Unmatched videos are keyed by their own filename.
        return video_name
    # group(1) only exists when the pattern declares a capturing group; it is
    # None when that group is optional and did not participate in the match.
    first_capture = found.group(1) if pattern.groups >= 1 else None
    return first_capture if first_capture is not None else found.group(0)
75 changes: 75 additions & 0 deletions packages/jabs-core/tests/test_cv_grouping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Tests for cross-validation grouping enum and filename-pattern helpers."""

import pytest

from jabs.core.enums import (
CrossValidationGroupingStrategy,
compile_grouping_regex,
filename_group_key,
)


def test_filename_pattern_member_exists() -> None:
    """FILENAME_PATTERN exists and round-trips through its display string."""
    member = CrossValidationGroupingStrategy.FILENAME_PATTERN
    assert member.value == "Filename Pattern"
    # Looking the enum up by value must return the very same member.
    assert CrossValidationGroupingStrategy("Filename Pattern") is member


def test_compile_grouping_regex_valid() -> None:
    """Compiling a well-formed pattern gives a regex that can be searched with."""
    compiled = compile_grouping_regex(r"cage_(\d+)")
    found = compiled.search("cage_0042.mp4")
    assert found is not None


@pytest.mark.parametrize("regex", ["", None], ids=["empty", "none"])
def test_compile_grouping_regex_empty_raises(regex) -> None:
    """Falsy patterns (empty string, ``None``) raise the "non-empty" ValueError."""
    # Both parametrized values are falsy, so neither reaches ``re.compile``.
    with pytest.raises(ValueError, match="non-empty"):
        compile_grouping_regex(regex)


def test_compile_grouping_regex_invalid_raises() -> None:
    """A malformed pattern surfaces as ValueError rather than re.error."""
    unbalanced = "cage_("  # unterminated group: not a valid regular expression
    with pytest.raises(ValueError, match="Invalid filename grouping pattern"):
        compile_grouping_regex(unbalanced)


def test_filename_group_key_uses_capture_group() -> None:
    """With a capture group in the pattern, the key is the captured text."""
    cage_digits = compile_grouping_regex(r"cage_(\d+)")
    key = filename_group_key("cage_0042_2026-06-16.mp4", cage_digits)
    assert key == "0042"


def test_filename_group_key_uses_full_match_without_capture_group() -> None:
    """With no capture group, the key is the entire matched substring."""
    cage_token = compile_grouping_regex(r"cage_\d+")
    key = filename_group_key("cage_0042_2026-06-16.mp4", cage_token)
    assert key == "cage_0042"


def test_filename_group_key_searches_anywhere() -> None:
    """The pattern need not match at the start of the filename (search semantics)."""
    cage_digits = compile_grouping_regex(r"cage_(\d+)")
    # The cage token sits in the middle of this filename.
    key = filename_group_key("2026-06-16_cage_0007_cam1.avi", cage_digits)
    assert key == "0007"


def test_filename_group_key_unmatched_returns_filename() -> None:
    """An unmatched filename is returned unchanged, making it its own group."""
    cage_digits = compile_grouping_regex(r"cage_(\d+)")
    unmatched = "mouse_video.mp4"
    assert filename_group_key(unmatched, cage_digits) == "mouse_video.mp4"


def test_filename_group_key_same_cage_different_files_share_key() -> None:
    """Two different files recorded from one cage map to the same key."""
    cage_digits = compile_grouping_regex(r"cage_(\d+)")
    day_one = filename_group_key("cage_0042_day1.mp4", cage_digits)
    day_two = filename_group_key("cage_0042_day2.avi", cage_digits)
    assert day_one == day_two == "0042"


def test_filename_group_key_optional_capture_group_falls_back_to_full_match() -> None:
    """If an optional capture group takes no part in the match, the full match is the key."""
    optional_extra = compile_grouping_regex(r"cage(_extra)?_\d+")
    # "_extra" is absent from this filename, so group 1 is None and the key
    # comes from the whole match instead.
    key = filename_group_key("cage_0042.mp4", optional_extra)
    assert key == "cage_0042"
31 changes: 30 additions & 1 deletion src/jabs/classifier/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
DEFAULT_CV_GROUPING_STRATEGY,
ClassifierType,
CrossValidationGroupingStrategy,
compile_grouping_regex,
filename_group_key,
)
from jabs.core.utils import hash_file
from jabs.project import Project, load_training_data
Expand Down Expand Up @@ -318,6 +320,7 @@ def confusion_matrix(truth: np.ndarray, predictions: np.ndarray) -> np.ndarray:
def count_label_threshold(
all_counts: dict,
cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY,
cv_grouping_regex: str | None = None,
) -> int:
"""Count groups that meet the label-threshold criteria.

Expand All @@ -326,6 +329,9 @@ def count_label_threshold(
Structure is a dict[video_name][identity] of fragmented and
unfragmented frame/bout count tuples.
cv_grouping_strategy: Cross-validation grouping strategy.
cv_grouping_regex: Regex used to extract a grouping key from each video
filename when ``cv_grouping_strategy`` is ``FILENAME_PATTERN``. An
empty or invalid regex yields a count of 0 (no trainable groups).

Returns:
Number of groups that meet the labeling threshold criteria.
Expand Down Expand Up @@ -356,6 +362,24 @@ def count_label_threshold(
and not_behavior_sum >= Classifier.LABEL_THRESHOLD
):
group_count += 1
elif cv_grouping_strategy == CrossValidationGroupingStrategy.FILENAME_PATTERN:
try:
pattern = compile_grouping_regex(cv_grouping_regex or "")
except ValueError:
return 0
group_sums: dict[str, list[int]] = {}
for video in all_counts:
label = filename_group_key(video, pattern)
sums = group_sums.setdefault(label, [0, 0])
for identity_count in all_counts[video].values():
sums[0] += identity_count["fragmented_frame_counts"][0]
sums[1] += identity_count["fragmented_frame_counts"][1]
for behavior_sum, not_behavior_sum in group_sums.values():
if (
behavior_sum >= Classifier.LABEL_THRESHOLD
and not_behavior_sum >= Classifier.LABEL_THRESHOLD
):
group_count += 1
else:
raise ValueError(f"Unknown cv_grouping_strategy: {cv_grouping_strategy}")
return group_count
Expand All @@ -365,18 +389,23 @@ def label_threshold_met(
all_counts: dict,
min_groups: int,
cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY,
cv_grouping_regex: str | None = None,
) -> bool:
"""Determine whether the labeling threshold is met.

Args:
all_counts: Labeled frame and bout counts for the entire project.
min_groups: Minimum number of groups required.
cv_grouping_strategy: Cross-validation grouping strategy.
cv_grouping_regex: Regex used for ``FILENAME_PATTERN`` grouping (see
:meth:`count_label_threshold`).

Returns:
True if there are enough groups meeting the threshold.
"""
group_count = Classifier.count_label_threshold(
all_counts, cv_grouping_strategy=cv_grouping_strategy
all_counts,
cv_grouping_strategy=cv_grouping_strategy,
cv_grouping_regex=cv_grouping_regex,
)
return 1 < group_count >= min_groups
10 changes: 9 additions & 1 deletion src/jabs/classifier/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,15 @@ def _train_multiclass_fold(


def _test_label_from_group(test_info: dict) -> str:
    """Build the report label for a cross-validation test group.

    A group that carries a non-``None`` ``"label"`` entry (filename-pattern
    groups store the regex-extracted key there, e.g. ``"cage_1234"``) is
    reported under that label. Any other group is reported as its video name,
    followed by the identity in square brackets when an identity is set.
    """
    explicit_label = test_info.get("label")
    if explicit_label is not None:
        return explicit_label
    identity = test_info["identity"]
    if identity is None:
        return test_info["video"]
    return f"{test_info['video']} [{identity}]"
Expand Down
31 changes: 29 additions & 2 deletions src/jabs/classifier/multi_class_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
DEFAULT_CV_GROUPING_STRATEGY,
ClassifierType,
CrossValidationGroupingStrategy,
compile_grouping_regex,
filename_group_key,
)
from jabs.core.utils import hash_file
from jabs.project import load_multiclass_training_data
Expand Down Expand Up @@ -423,6 +425,7 @@ def count_label_threshold(
counts_by_behavior: dict[str, dict],
behavior_names: list[str],
cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY,
cv_grouping_regex: str | None = None,
) -> int:
"""Count multi-class LOGO groups that satisfy the relaxed acceptance rule.

Expand All @@ -439,6 +442,9 @@ def count_label_threshold(
``counts_by_behavior``. Typically includes
``MULTICLASS_NONE_BEHAVIOR``.
cv_grouping_strategy: Cross-validation grouping strategy.
cv_grouping_regex: Regex used to extract a grouping key from each video
filename when ``cv_grouping_strategy`` is ``FILENAME_PATTERN``. An
empty or invalid regex yields a count of 0 (no trainable groups).

Returns:
Number of groups that can serve as a valid multi-class LOGO test split.
Expand All @@ -450,13 +456,30 @@ def count_label_threshold(
if not behavior_names:
return 0

# FILENAME_PATTERN aggregates like VIDEO grouping, but keys each group by
# the regex-extracted filename key instead of the video name (so several
# videos can share one group). An empty/invalid regex means no groups.
pattern = None
if cv_grouping_strategy == CrossValidationGroupingStrategy.FILENAME_PATTERN:
try:
pattern = compile_grouping_regex(cv_grouping_regex or "")
except ValueError:
return 0

threshold = MultiClassClassifier.LABEL_THRESHOLD
group_class_counts: dict[tuple[str, int] | str, dict[str, int]] = {}
for behavior_name in behavior_names:
behavior_counts = counts_by_behavior.get(behavior_name, {})
for video_name, video_counts in behavior_counts.items():
if cv_grouping_strategy == CrossValidationGroupingStrategy.VIDEO:
key: tuple[str, int] | str = video_name
if cv_grouping_strategy in (
CrossValidationGroupingStrategy.VIDEO,
CrossValidationGroupingStrategy.FILENAME_PATTERN,
):
key: tuple[str, int] | str = (
filename_group_key(video_name, pattern)
if pattern is not None
else video_name
)
group_entry = group_class_counts.setdefault(key, {})
group_entry[behavior_name] = group_entry.get(behavior_name, 0) + sum(
identity_counts["fragmented_frame_counts"][0]
Expand Down Expand Up @@ -498,6 +521,7 @@ def label_threshold_met(
behavior_names: list[str],
min_groups: int,
cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY,
cv_grouping_regex: str | None = None,
) -> bool:
"""Determine whether multi-class labels support ``min_groups`` LOGO splits.

Expand All @@ -511,6 +535,8 @@ class names are supplied.
at 1, since multi-class training requires at least one valid
split.
cv_grouping_strategy: Cross-validation grouping strategy.
cv_grouping_regex: Regex used for ``FILENAME_PATTERN`` grouping (see
:meth:`count_label_threshold`).

Returns:
True if the count of valid splits meets ``max(1, min_groups)``.
Expand All @@ -521,5 +547,6 @@ class names are supplied.
counts_by_behavior=counts_by_behavior,
behavior_names=behavior_names,
cv_grouping_strategy=cv_grouping_strategy,
cv_grouping_regex=cv_grouping_regex,
)
return valid_splits >= max(1, min_groups)
Loading
Loading