diff --git a/docs/user-guide/gui.md b/docs/user-guide/gui.md index f303cfa7..6949314e 100644 --- a/docs/user-guide/gui.md +++ b/docs/user-guide/gui.md @@ -161,10 +161,20 @@ Project settings are saved within the project directory and apply only to the cu | Setting | Description | |-------------------------------|------------------------------------------------------------------| -| Cross Validation Grouping | Determines how cross-validation groups are defined. Options are "Individual Animal" (default) or "Video". | +| Cross Validation Grouping | Determines how cross-validation groups are defined. Options are "Individual Animal" (default), "Video", or "Filename Pattern". See [Cross-Validation Grouping](#cross-validation-grouping) below. | As new settings are added, they will appear in this dialog with inline documentation. +### Cross-Validation Grouping + +The **Cross Validation Grouping** setting controls how labeled data is partitioned into groups for leave-one-group-out cross-validation: + +- **Individual Animal** (default): each group is a single animal identity within a single video. +- **Video**: each group is a single video; all identities within a video are held out together. +- **Filename Pattern**: groups are defined by a regular expression applied to each video's filename. All videos whose filenames produce the same key are placed in the same group, which is useful for grouping videos by an identifier embedded in their names (for example, a cage ID). If the pattern contains a capture group, the captured text is used as the key; otherwise the entire match is used. Videos that do not match the pattern are each placed in their own group. + +When you select **Filename Pattern**, a text field appears for the regular expression. For example, if your videos are named like `cage_0042_2026-06-16.mp4`, the pattern `cage_(\d+)` extracts the cage number (`0042`) so that every video recorded from the same cage forms a single cross-validation group. 
A live preview below the field shows how your project's videos partition into groups under the current pattern (videos excluded from training are marked), so you can confirm the pattern before saving. + ## Overlays diff --git a/packages/jabs-core/src/jabs/core/constants.py b/packages/jabs-core/src/jabs/core/constants.py index 5ea79323..0b6a9bd1 100644 --- a/packages/jabs-core/src/jabs/core/constants.py +++ b/packages/jabs-core/src/jabs/core/constants.py @@ -13,6 +13,8 @@ # settings keys for project settings stored in the project.json file CV_GROUPING_KEY = "cv_grouping" +# regex used when CV_GROUPING_KEY is the "Filename Pattern" strategy +CV_GROUPING_REGEX_KEY = "cv_grouping_regex" CLASSIFIER_MODE_KEY = "classifier_mode" CACHE_FORMAT_KEY = "cache_format" diff --git a/packages/jabs-core/src/jabs/core/enums/__init__.py b/packages/jabs-core/src/jabs/core/enums/__init__.py index 18c8db90..5a2c7f5d 100644 --- a/packages/jabs-core/src/jabs/core/enums/__init__.py +++ b/packages/jabs-core/src/jabs/core/enums/__init__.py @@ -3,7 +3,12 @@ from .cache_format import CacheFormat from .classifier_mode import DEFAULT_CLASSIFIER_MODE, ClassifierMode from .classifier_types import ClassifierType -from .cv_grouping import DEFAULT_CV_GROUPING_STRATEGY, CrossValidationGroupingStrategy +from .cv_grouping import ( + DEFAULT_CV_GROUPING_STRATEGY, + CrossValidationGroupingStrategy, + compile_grouping_regex, + filename_group_key, +) from .inference import ConfidenceMetric, Method, SamplingStrategy from .prediction_type import PredictionType from .storage_format import StorageFormat @@ -22,4 +27,6 @@ "ProjectDistanceUnit", "SamplingStrategy", "StorageFormat", + "compile_grouping_regex", + "filename_group_key", ] diff --git a/packages/jabs-core/src/jabs/core/enums/cv_grouping.py b/packages/jabs-core/src/jabs/core/enums/cv_grouping.py index c3ce9cb0..3c592840 100644 --- a/packages/jabs-core/src/jabs/core/enums/cv_grouping.py +++ b/packages/jabs-core/src/jabs/core/enums/cv_grouping.py @@ -1,3 
+1,6 @@ +"""Cross-validation grouping strategy enum and filename-pattern helpers.""" + +import re from enum import Enum @@ -10,6 +13,56 @@ class CrossValidationGroupingStrategy(str, Enum): INDIVIDUAL = "Individual Animal" VIDEO = "Video" + FILENAME_PATTERN = "Filename Pattern" DEFAULT_CV_GROUPING_STRATEGY = CrossValidationGroupingStrategy.INDIVIDUAL + + +def compile_grouping_regex(regex: str) -> re.Pattern[str]: + """Compile a filename-pattern cross-validation grouping regular expression. + + Args: + regex: Regular expression used to extract a grouping key from a video + filename. + + Returns: + The compiled regular expression pattern. + + Raises: + ValueError: If ``regex`` is empty or not a valid regular expression. + """ + if not regex: + raise ValueError("Filename pattern grouping requires a non-empty regular expression") + try: + return re.compile(regex) + except re.error as e: + raise ValueError(f"Invalid filename grouping pattern: {e}") from e + + +def filename_group_key(video_name: str, pattern: re.Pattern[str]) -> str: + """Extract a cross-validation grouping key from a video filename. + + The pattern is applied with :meth:`re.Pattern.search`, so it matches anywhere + in ``video_name``. If the pattern defines a capturing group and it matched, the + first captured group is used as the key (so a pattern that captures the digits + in ``cage_1234.mp4`` yields ``"1234"``); otherwise the full matched text is used + (a pattern matching the whole ``cage_1234`` token yields ``"cage_1234"``). + Videos that do not match the pattern are placed in their own group, keyed by the + filename itself. + + Args: + video_name: Video filename to extract a grouping key from. + pattern: Compiled regular expression (see :func:`compile_grouping_regex`). + + Returns: + The grouping key string. All videos that yield the same key are placed in + the same cross-validation group. 
+ """ + match = pattern.search(video_name) + if match is None: + # No match: the video becomes its own group (keyed by its unique filename). + return video_name + if pattern.groups >= 1 and match.group(1) is not None: + return match.group(1) + return match.group(0) diff --git a/packages/jabs-core/tests/test_cv_grouping.py b/packages/jabs-core/tests/test_cv_grouping.py new file mode 100644 index 00000000..1c2982b9 --- /dev/null +++ b/packages/jabs-core/tests/test_cv_grouping.py @@ -0,0 +1,75 @@ +"""Tests for cross-validation grouping enum and filename-pattern helpers.""" + +import pytest + +from jabs.core.enums import ( + CrossValidationGroupingStrategy, + compile_grouping_regex, + filename_group_key, +) + + +def test_filename_pattern_member_exists() -> None: + """The FILENAME_PATTERN strategy is available with its display value.""" + assert CrossValidationGroupingStrategy.FILENAME_PATTERN.value == "Filename Pattern" + assert CrossValidationGroupingStrategy("Filename Pattern") is ( + CrossValidationGroupingStrategy.FILENAME_PATTERN + ) + + +def test_compile_grouping_regex_valid() -> None: + """A valid pattern compiles to a usable regex.""" + pattern = compile_grouping_regex(r"cage_(\d+)") + assert pattern.search("cage_0042.mp4") is not None + + +@pytest.mark.parametrize("regex", ["", None], ids=["empty", "none"]) +def test_compile_grouping_regex_empty_raises(regex) -> None: + """An empty (or falsy) pattern is rejected.""" + with pytest.raises(ValueError, match="non-empty"): + compile_grouping_regex(regex) + + +def test_compile_grouping_regex_invalid_raises() -> None: + """A syntactically invalid pattern raises ValueError, not re.error.""" + with pytest.raises(ValueError, match="Invalid filename grouping pattern"): + compile_grouping_regex("cage_(") + + +def test_filename_group_key_uses_capture_group() -> None: + """When the pattern has a capture group, the captured text is the key.""" + pattern = compile_grouping_regex(r"cage_(\d+)") + assert 
filename_group_key("cage_0042_2026-06-16.mp4", pattern) == "0042" + + +def test_filename_group_key_uses_full_match_without_capture_group() -> None: + """Without a capture group, the whole matched substring is the key.""" + pattern = compile_grouping_regex(r"cage_\d+") + assert filename_group_key("cage_0042_2026-06-16.mp4", pattern) == "cage_0042" + + +def test_filename_group_key_searches_anywhere() -> None: + """The pattern matches anywhere in the filename (re.search semantics).""" + pattern = compile_grouping_regex(r"cage_(\d+)") + assert filename_group_key("2026-06-16_cage_0007_cam1.avi", pattern) == "0007" + + +def test_filename_group_key_unmatched_returns_filename() -> None: + """A filename that does not match becomes its own group (keyed by the name).""" + pattern = compile_grouping_regex(r"cage_(\d+)") + assert filename_group_key("mouse_video.mp4", pattern) == "mouse_video.mp4" + + +def test_filename_group_key_same_cage_different_files_share_key() -> None: + """Different files from the same cage produce the same grouping key.""" + pattern = compile_grouping_regex(r"cage_(\d+)") + key_a = filename_group_key("cage_0042_day1.mp4", pattern) + key_b = filename_group_key("cage_0042_day2.avi", pattern) + assert key_a == key_b == "0042" + + +def test_filename_group_key_optional_capture_group_falls_back_to_full_match() -> None: + """An optional capture group that does not participate falls back to the full match.""" + pattern = compile_grouping_regex(r"cage(_extra)?_\d+") + # The optional group does not match here, so the full match is used. 
+ assert filename_group_key("cage_0042.mp4", pattern) == "cage_0042" diff --git a/src/jabs/classifier/classifier.py b/src/jabs/classifier/classifier.py index 0162fb18..3003d99a 100644 --- a/src/jabs/classifier/classifier.py +++ b/src/jabs/classifier/classifier.py @@ -12,6 +12,8 @@ DEFAULT_CV_GROUPING_STRATEGY, ClassifierType, CrossValidationGroupingStrategy, + compile_grouping_regex, + filename_group_key, ) from jabs.core.utils import hash_file from jabs.project import Project, load_training_data @@ -318,6 +320,7 @@ def confusion_matrix(truth: np.ndarray, predictions: np.ndarray) -> np.ndarray: def count_label_threshold( all_counts: dict, cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY, + cv_grouping_regex: str | None = None, ) -> int: """Count groups that meet the label-threshold criteria. @@ -326,6 +329,9 @@ def count_label_threshold( Structure is a dict[video_name][identity] of fragmented and unfragmented frame/bout count tuples. cv_grouping_strategy: Cross-validation grouping strategy. + cv_grouping_regex: Regex used to extract a grouping key from each video + filename when ``cv_grouping_strategy`` is ``FILENAME_PATTERN``. An + empty or invalid regex yields a count of 0 (no trainable groups). Returns: Number of groups that meet the labeling threshold criteria. 
@@ -356,6 +362,24 @@ def count_label_threshold( and not_behavior_sum >= Classifier.LABEL_THRESHOLD ): group_count += 1 + elif cv_grouping_strategy == CrossValidationGroupingStrategy.FILENAME_PATTERN: + try: + pattern = compile_grouping_regex(cv_grouping_regex or "") + except ValueError: + return 0 + group_sums: dict[str, list[int]] = {} + for video in all_counts: + label = filename_group_key(video, pattern) + sums = group_sums.setdefault(label, [0, 0]) + for identity_count in all_counts[video].values(): + sums[0] += identity_count["fragmented_frame_counts"][0] + sums[1] += identity_count["fragmented_frame_counts"][1] + for behavior_sum, not_behavior_sum in group_sums.values(): + if ( + behavior_sum >= Classifier.LABEL_THRESHOLD + and not_behavior_sum >= Classifier.LABEL_THRESHOLD + ): + group_count += 1 else: raise ValueError(f"Unknown cv_grouping_strategy: {cv_grouping_strategy}") return group_count @@ -365,6 +389,7 @@ def label_threshold_met( all_counts: dict, min_groups: int, cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY, + cv_grouping_regex: str | None = None, ) -> bool: """Determine whether the labeling threshold is met. @@ -372,11 +397,15 @@ def label_threshold_met( all_counts: Labeled frame and bout counts for the entire project. min_groups: Minimum number of groups required. cv_grouping_strategy: Cross-validation grouping strategy. + cv_grouping_regex: Regex used for ``FILENAME_PATTERN`` grouping (see + :meth:`count_label_threshold`). Returns: True if there are enough groups meeting the threshold. 
""" group_count = Classifier.count_label_threshold( - all_counts, cv_grouping_strategy=cv_grouping_strategy + all_counts, + cv_grouping_strategy=cv_grouping_strategy, + cv_grouping_regex=cv_grouping_regex, ) return 1 < group_count >= min_groups diff --git a/src/jabs/classifier/cross_validation.py b/src/jabs/classifier/cross_validation.py index d015334a..6ea2caeb 100644 --- a/src/jabs/classifier/cross_validation.py +++ b/src/jabs/classifier/cross_validation.py @@ -125,7 +125,15 @@ def _train_multiclass_fold( def _test_label_from_group(test_info: dict) -> str: - """Render a CV test-group label for the report (video name + optional identity).""" + """Render a CV test-group label for the report. + + Filename-pattern groups carry a ``label`` (the regex-extracted key, e.g. + ``"cage_1234"``); otherwise the label is the video name plus an optional + identity. + """ + label = test_info.get("label") + if label is not None: + return label if test_info["identity"] is not None: return f"{test_info['video']} [{test_info['identity']}]" return test_info["video"] diff --git a/src/jabs/classifier/multi_class_classifier.py b/src/jabs/classifier/multi_class_classifier.py index b1e2783b..cb28153d 100644 --- a/src/jabs/classifier/multi_class_classifier.py +++ b/src/jabs/classifier/multi_class_classifier.py @@ -17,6 +17,8 @@ DEFAULT_CV_GROUPING_STRATEGY, ClassifierType, CrossValidationGroupingStrategy, + compile_grouping_regex, + filename_group_key, ) from jabs.core.utils import hash_file from jabs.project import load_multiclass_training_data @@ -423,6 +425,7 @@ def count_label_threshold( counts_by_behavior: dict[str, dict], behavior_names: list[str], cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY, + cv_grouping_regex: str | None = None, ) -> int: """Count multi-class LOGO groups that satisfy the relaxed acceptance rule. @@ -439,6 +442,9 @@ def count_label_threshold( ``counts_by_behavior``. Typically includes ``MULTICLASS_NONE_BEHAVIOR``. 
cv_grouping_strategy: Cross-validation grouping strategy. + cv_grouping_regex: Regex used to extract a grouping key from each video + filename when ``cv_grouping_strategy`` is ``FILENAME_PATTERN``. An + empty or invalid regex yields a count of 0 (no trainable groups). Returns: Number of groups that can serve as a valid multi-class LOGO test split. @@ -450,13 +456,30 @@ def count_label_threshold( if not behavior_names: return 0 + # FILENAME_PATTERN aggregates like VIDEO grouping, but keys each group by + # the regex-extracted filename key instead of the video name (so several + # videos can share one group). An empty/invalid regex means no groups. + pattern = None + if cv_grouping_strategy == CrossValidationGroupingStrategy.FILENAME_PATTERN: + try: + pattern = compile_grouping_regex(cv_grouping_regex or "") + except ValueError: + return 0 + threshold = MultiClassClassifier.LABEL_THRESHOLD group_class_counts: dict[tuple[str, int] | str, dict[str, int]] = {} for behavior_name in behavior_names: behavior_counts = counts_by_behavior.get(behavior_name, {}) for video_name, video_counts in behavior_counts.items(): - if cv_grouping_strategy == CrossValidationGroupingStrategy.VIDEO: - key: tuple[str, int] | str = video_name + if cv_grouping_strategy in ( + CrossValidationGroupingStrategy.VIDEO, + CrossValidationGroupingStrategy.FILENAME_PATTERN, + ): + key: tuple[str, int] | str = ( + filename_group_key(video_name, pattern) + if pattern is not None + else video_name + ) group_entry = group_class_counts.setdefault(key, {}) group_entry[behavior_name] = group_entry.get(behavior_name, 0) + sum( identity_counts["fragmented_frame_counts"][0] @@ -498,6 +521,7 @@ def label_threshold_met( behavior_names: list[str], min_groups: int, cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY, + cv_grouping_regex: str | None = None, ) -> bool: """Determine whether multi-class labels support ``min_groups`` LOGO splits. 
@@ -511,6 +535,8 @@ class names are supplied. at 1, since multi-class training requires at least one valid split. cv_grouping_strategy: Cross-validation grouping strategy. + cv_grouping_regex: Regex used for ``FILENAME_PATTERN`` grouping (see + :meth:`count_label_threshold`). Returns: True if the count of valid splits meets ``max(1, min_groups)``. @@ -521,5 +547,6 @@ class names are supplied. counts_by_behavior=counts_by_behavior, behavior_names=behavior_names, cv_grouping_strategy=cv_grouping_strategy, + cv_grouping_regex=cv_grouping_regex, ) return valid_splits >= max(1, min_groups) diff --git a/src/jabs/classifier/training_report.py b/src/jabs/classifier/training_report.py index b14dd60f..948ce8ba 100644 --- a/src/jabs/classifier/training_report.py +++ b/src/jabs/classifier/training_report.py @@ -105,6 +105,8 @@ class TrainingReportData: training_time_ms: Total training time in milliseconds. timestamp: Datetime when training was completed. cv_grouping_strategy: Strategy used for cross-validation grouping. + cv_grouping_regex: Filename-pattern regex used for grouping. Only set when + the grouping strategy is "Filename Pattern". 
""" behavior_name: str @@ -124,6 +126,7 @@ class TrainingReportData: bouts_not_behavior: int = 0 class_frame_counts: dict[str, int] | None = None class_bout_counts: dict[str, int] | None = None + cv_grouping_regex: str | None = None def _escape_markdown(text: str) -> str: @@ -287,6 +290,8 @@ def generate_markdown_report(data: TrainingReportData) -> str: lines.append("### Iteration Details") lines.append(f"CV Grouping Strategy: {data.cv_grouping_strategy.value}") + if data.cv_grouping_regex: + lines.append(f"CV Grouping Pattern: `{data.cv_grouping_regex}`") lines.append("") lines.append(_format_iteration_table(data.cv_results)) lines.append("") @@ -410,6 +415,7 @@ def generate_json_report(data: TrainingReportData) -> dict: "training_time_ms": int(data.training_time_ms), "timestamp": timestamp_str, "cv_grouping_strategy": data.cv_grouping_strategy.value, + "cv_grouping_regex": data.cv_grouping_regex, "frames_behavior": int(data.frames_behavior), "frames_not_behavior": int(data.frames_not_behavior), "bouts_behavior": int(data.bouts_behavior), diff --git a/src/jabs/project/export_training.py b/src/jabs/project/export_training.py index 4ecaa032..577950f7 100644 --- a/src/jabs/project/export_training.py +++ b/src/jabs/project/export_training.py @@ -20,6 +20,34 @@ from jabs.project import Project +def _write_group_mapping( + out_h5: h5py.File, group_mapping: dict[int, dict], string_type: np.dtype +) -> None: + """Write the cross-validation group mapping into the exported HDF5 file. + + For each group, stores an ``identity`` (the animal identity, or ``-1`` when the + group is not identity-specific) and a ``video_name``. ``identity`` is ``-1`` for + ``VIDEO`` and ``FILENAME_PATTERN`` grouping. For ``FILENAME_PATTERN`` groups, + ``video_name`` holds the regex-extracted group label (e.g. ``"cage_1234"``) + rather than a single filename, since one group can span multiple videos. + + Args: + out_h5: Open HDF5 file to write into. 
+ group_mapping: Mapping of group id to its source descriptor. + string_type: h5py variable-length string dtype for the video_name dataset. + """ + for group, info in group_mapping.items(): + identity_dset = out_h5.create_dataset( + f"group_mapping/{group}/identity", (1,), dtype=np.int64 + ) + identity = info["identity"] + identity_dset[:] = identity if identity is not None else -1 + video_name_dset = out_h5.create_dataset( + f"group_mapping/{group}/video_name", (1,), dtype=string_type + ) + video_name_dset[:] = info["label"] if info.get("label") is not None else info["video"] + + def export_training_data( project: "Project", behavior: str, @@ -76,15 +104,7 @@ def export_training_data( out_h5.create_dataset("label", data=features["labels"]) # store the video/identity to group mapping in the h5 file - # identity is None when VIDEO grouping strategy is used; store -1 as a sentinel - for group in group_mapping: - dset = out_h5.create_dataset(f"group_mapping/{group}/identity", (1,), dtype=np.int64) - identity = group_mapping[group]["identity"] - dset[:] = identity if identity is not None else -1 - dset = out_h5.create_dataset( - f"group_mapping/{group}/video_name", (1,), dtype=string_type - ) - dset[:] = group_mapping[group]["video"] + _write_group_mapping(out_h5, group_mapping, string_type) # return output path, so if it was generated automatically the caller # will know @@ -159,14 +179,7 @@ def export_training_data_multiclass( out_h5.create_dataset("group", data=features["groups"]) - for group in group_mapping: - dset = out_h5.create_dataset(f"group_mapping/{group}/identity", (1,), dtype=np.int64) - identity = group_mapping[group]["identity"] - dset[:] = identity if identity is not None else -1 - dset = out_h5.create_dataset( - f"group_mapping/{group}/video_name", (1,), dtype=string_type - ) - dset[:] = group_mapping[group]["video"] + _write_group_mapping(out_h5, group_mapping, string_type) return out_file diff --git a/src/jabs/project/project.py 
b/src/jabs/project/project.py index a6746842..f25faad5 100644 --- a/src/jabs/project/project.py +++ b/src/jabs/project/project.py @@ -23,6 +23,8 @@ ClassifierMode, CrossValidationGroupingStrategy, ProjectDistanceUnit, + compile_grouping_regex, + filename_group_key, ) from jabs.pose_estimation import ( PoseEstimation, @@ -879,6 +881,7 @@ def _assign_cv_group_ids( all_group_keys: list[tuple[str, int]], videos: list[str], grouping_strategy: CrossValidationGroupingStrategy, + regex: str | None = None, ) -> tuple[dict[tuple[str, int], int], dict[int, dict]]: """Assign deterministic cross-validation group ids. @@ -886,12 +889,25 @@ def _assign_cv_group_ids( all_group_keys: ``(video, identity)`` tuples in row order. videos: Canonical list of project videos; ids are assigned in this order. grouping_strategy: ``INDIVIDUAL`` groups one (video, identity) pair per - gid; ``VIDEO`` groups all identities of a video together. + gid; ``VIDEO`` groups all identities of a video together; + ``FILENAME_PATTERN`` groups all videos (and their identities) whose + filename yields the same key under ``regex``. + regex: Regular expression used to extract a grouping key from each + video filename. Required for ``FILENAME_PATTERN`` grouping; ignored + otherwise. Returns: Tuple of ``(key_to_gid, group_mapping)`` where ``key_to_gid`` maps each ``(video, identity)`` pair to its group id and ``group_mapping`` maps - each group id back to ``{"video": ..., "identity": ...}``. + each group id back to its source. ``INDIVIDUAL``/``VIDEO`` entries are + ``{"video": ..., "identity": ...}``; ``FILENAME_PATTERN`` entries are + ``{"video": None, "identity": None, "label": ..., "videos": [...]}`` + where ``videos`` lists the labeled videos in the group. + + Raises: + ValueError: If ``grouping_strategy`` is ``FILENAME_PATTERN`` and + ``regex`` is empty or not a valid regular expression, or if the + strategy is unknown. 
""" key_to_gid: dict[tuple[str, int], int] = {} group_mapping: dict[int, dict] = {} @@ -918,6 +934,29 @@ def _assign_cv_group_ids( for video_name, ident in all_group_keys: if video_name == v: key_to_gid[(v, ident)] = video_to_gid[v] + elif grouping_strategy == CrossValidationGroupingStrategy.FILENAME_PATTERN: + pattern = compile_grouping_regex(regex or "") + label_to_gid: dict[str, int] = {} + # Group ids are created lazily in row order (which follows canonical + # video order). Videos whose filename does not match the pattern fall + # back to their own group, since filename_group_key returns the + # (unique) filename as the key. + for video_name, ident in all_group_keys: + label = filename_group_key(video_name, pattern) + if label not in label_to_gid: + label_to_gid[label] = gid + group_mapping[gid] = { + "video": None, + "identity": None, + "label": label, + "videos": [], + } + gid += 1 + group_gid = label_to_gid[label] + key_to_gid[(video_name, ident)] = group_gid + videos_in_group = group_mapping[group_gid]["videos"] + if video_name not in videos_in_group: + videos_in_group.append(video_name) else: raise ValueError(f"Unknown grouping strategy: {grouping_strategy}") return key_to_gid, group_mapping @@ -936,21 +975,32 @@ def _build_groups_array( return np.concatenate(groups_list) if groups_list else np.array([], dtype=np.int32) def _excluded_group_ids(self, group_mapping: dict[int, dict]) -> set[int]: - """Return CV group ids whose source video is excluded from training. + """Return CV group ids whose source video(s) are excluded from training. Args: - group_mapping: Mapping of group id to ``{"video": ..., "identity": ...}``. + group_mapping: Mapping of group id to its source. ``INDIVIDUAL``/``VIDEO`` + groups carry a single ``"video"``; ``FILENAME_PATTERN`` groups carry + a ``"videos"`` list (one group can span several videos). Returns: - Set of group ids belonging to videos marked excluded from training. 
- These groups are still eligible as the held-out test group in - leave-one-group-out cross-validation but are never used for training. + Set of group ids whose constituent videos are all marked excluded from + training. These groups are still eligible as the held-out test group in + leave-one-group-out cross-validation but are never used for training. A + filename-pattern group is excluded only when *every* labeled video in it + is excluded, so a partially-excluded group still contributes its + non-excluded videos' data to training folds. """ - return { - gid - for gid, info in group_mapping.items() - if self._settings_manager.is_video_excluded(info["video"]) - } + excluded: set[int] = set() + for gid, info in group_mapping.items(): + group_videos = info.get("videos") + if group_videos is None: + video = info.get("video") + group_videos = [video] if video is not None else [] + if group_videos and all( + self._settings_manager.is_video_excluded(v) for v in group_videos + ): + excluded.add(gid) + return excluded def get_labeled_features( self, @@ -958,6 +1008,7 @@ def get_labeled_features( progress_callable: Callable[[], None] | None = None, should_terminate_callable: Callable[[], None] | None = None, grouping_strategy: CrossValidationGroupingStrategy | None = None, + grouping_regex: str | None = None, ) -> tuple[dict, dict]: """Get labeled features for training (parallel per-video). @@ -977,6 +1028,8 @@ def get_labeled_features( and as results complete; it should raise a `ThreadTerminatedError` if the user has requested early termination. grouping_strategy: Optional override for cross-validation grouping strategy. If None, uses project settings. + grouping_regex: Optional override for the filename-pattern grouping regex + (only used when the strategy is ``FILENAME_PATTERN``). If None, uses project settings. Returns: tuple[dict, dict]: A tuple of (features, group_mapping). 
@@ -989,13 +1042,16 @@ def get_labeled_features( The values in the first dict are suitable for `Classifier.leave_one_group_out()`. - The second dict maps group ids to their source: + The second dict maps group ids to their source (the exact shape + depends on the grouping strategy; see ``_assign_cv_group_ids``): { : {'video':