From e57d5751ecd39276acb9986c2992afa28de48bf6 Mon Sep 17 00:00:00 2001 From: Glen Beane <356266+gbeane@users.noreply.github.com> Date: Tue, 16 Jun 2026 15:51:31 -0400 Subject: [PATCH 1/2] Add filename pattern cross-validation grouping with group preview --- packages/jabs-core/src/jabs/core/constants.py | 2 + .../jabs-core/src/jabs/core/enums/__init__.py | 9 +- .../src/jabs/core/enums/cv_grouping.py | 53 ++++ packages/jabs-core/tests/test_cv_grouping.py | 75 +++++ src/jabs/classifier/classifier.py | 31 +- src/jabs/classifier/cross_validation.py | 10 +- src/jabs/classifier/multi_class_classifier.py | 31 +- src/jabs/classifier/training_report.py | 6 + src/jabs/project/export_training.py | 47 +-- src/jabs/project/project.py | 91 +++++- src/jabs/project/settings_manager.py | 14 +- src/jabs/scripts/cli/cross_validation.py | 1 + src/jabs/ui/main_window/central_widget.py | 4 + .../cross_validation_settings_group.py | 277 ++++++++++++++++-- .../ui/settings_dialog/settings_dialog.py | 25 +- src/jabs/ui/settings_dialog/settings_group.py | 12 + src/jabs/ui/training_strategy.py | 5 + src/jabs/ui/training_thread.py | 1 + tests/classifier/test_classifier.py | 51 ++++ .../classifier/test_multi_class_classifier.py | 48 +++ tests/project/test_cv_grouping.py | 106 +++++++ tests/project/test_settings_manager.py | 19 ++ tests/ui/_fakes.py | 1 + tests/ui/test_settings_dialog.py | 150 +++++++++- 24 files changed, 1006 insertions(+), 63 deletions(-) create mode 100644 packages/jabs-core/tests/test_cv_grouping.py create mode 100644 tests/project/test_cv_grouping.py diff --git a/packages/jabs-core/src/jabs/core/constants.py b/packages/jabs-core/src/jabs/core/constants.py index 5ea79323..0b6a9bd1 100644 --- a/packages/jabs-core/src/jabs/core/constants.py +++ b/packages/jabs-core/src/jabs/core/constants.py @@ -13,6 +13,8 @@ # settings keys for project settings stored in the project.json file CV_GROUPING_KEY = "cv_grouping" +# regex used when CV_GROUPING_KEY is the "Filename Pattern" strategy +CV_GROUPING_REGEX_KEY = "cv_grouping_regex" CLASSIFIER_MODE_KEY = "classifier_mode" CACHE_FORMAT_KEY = "cache_format" diff --git a/packages/jabs-core/src/jabs/core/enums/__init__.py b/packages/jabs-core/src/jabs/core/enums/__init__.py index 18c8db90..5a2c7f5d 100644 --- a/packages/jabs-core/src/jabs/core/enums/__init__.py +++ b/packages/jabs-core/src/jabs/core/enums/__init__.py @@ -3,7 +3,12 @@ from .cache_format import CacheFormat from .classifier_mode import DEFAULT_CLASSIFIER_MODE, ClassifierMode from .classifier_types import ClassifierType -from .cv_grouping import DEFAULT_CV_GROUPING_STRATEGY, CrossValidationGroupingStrategy +from .cv_grouping import ( + DEFAULT_CV_GROUPING_STRATEGY, + CrossValidationGroupingStrategy, + compile_grouping_regex, + filename_group_key, +) from .inference import ConfidenceMetric, Method, SamplingStrategy from .prediction_type import PredictionType from .storage_format import StorageFormat @@ -22,4 +27,6 @@ "ProjectDistanceUnit", "SamplingStrategy", "StorageFormat", + "compile_grouping_regex", + "filename_group_key", ] diff --git a/packages/jabs-core/src/jabs/core/enums/cv_grouping.py b/packages/jabs-core/src/jabs/core/enums/cv_grouping.py index c3ce9cb0..3c592840 100644 --- a/packages/jabs-core/src/jabs/core/enums/cv_grouping.py +++ b/packages/jabs-core/src/jabs/core/enums/cv_grouping.py @@ -1,3 +1,6 @@ +"""Cross-validation grouping strategy enum and filename-pattern helpers.""" + +import re from enum import Enum @@ -10,6 +13,56 @@ class CrossValidationGroupingStrategy(str, Enum): INDIVIDUAL = "Individual Animal" VIDEO = "Video" + FILENAME_PATTERN = "Filename Pattern" DEFAULT_CV_GROUPING_STRATEGY = CrossValidationGroupingStrategy.INDIVIDUAL + + +def compile_grouping_regex(regex: str) -> re.Pattern[str]: + """Compile a filename-pattern cross-validation grouping regular expression. + + Args: + regex: Regular expression used to extract a grouping key from a video + filename. + + Returns: + The compiled regular expression pattern. + + Raises: + ValueError: If ``regex`` is empty or not a valid regular expression. + """ + if not regex: + raise ValueError("Filename pattern grouping requires a non-empty regular expression") + try: + return re.compile(regex) + except re.error as e: + raise ValueError(f"Invalid filename grouping pattern: {e}") from e + + +def filename_group_key(video_name: str, pattern: re.Pattern[str]) -> str: + """Extract a cross-validation grouping key from a video filename. + + The pattern is applied with :meth:`re.Pattern.search`, so it matches anywhere + in ``video_name``. If the pattern defines a capturing group and it matched, the + first captured group is used as the key (so a pattern that captures the digits + in ``cage_1234.mp4`` yields ``"1234"``); otherwise the full matched text is used + (a pattern matching the whole ``cage_1234`` token yields ``"cage_1234"``). + Videos that do not match the pattern are placed in their own group, keyed by the + filename itself. + + Args: + video_name: Video filename to extract a grouping key from. + pattern: Compiled regular expression (see :func:`compile_grouping_regex`). + + Returns: + The grouping key string. All videos that yield the same key are placed in + the same cross-validation group. + """ + match = pattern.search(video_name) + if match is None: + # No match: the video becomes its own group (keyed by its unique filename). + return video_name + if pattern.groups >= 1 and match.group(1) is not None: + return match.group(1) + return match.group(0) diff --git a/packages/jabs-core/tests/test_cv_grouping.py b/packages/jabs-core/tests/test_cv_grouping.py new file mode 100644 index 00000000..1c2982b9 --- /dev/null +++ b/packages/jabs-core/tests/test_cv_grouping.py @@ -0,0 +1,75 @@ +"""Tests for cross-validation grouping enum and filename-pattern helpers.""" + +import pytest + +from jabs.core.enums import ( + CrossValidationGroupingStrategy, + compile_grouping_regex, + filename_group_key, +) + + +def test_filename_pattern_member_exists() -> None: + """The FILENAME_PATTERN strategy is available with its display value.""" + assert CrossValidationGroupingStrategy.FILENAME_PATTERN.value == "Filename Pattern" + assert CrossValidationGroupingStrategy("Filename Pattern") is ( + CrossValidationGroupingStrategy.FILENAME_PATTERN + ) + + +def test_compile_grouping_regex_valid() -> None: + """A valid pattern compiles to a usable regex.""" + pattern = compile_grouping_regex(r"cage_(\d+)") + assert pattern.search("cage_0042.mp4") is not None + + +@pytest.mark.parametrize("regex", ["", None], ids=["empty", "none"]) +def test_compile_grouping_regex_empty_raises(regex) -> None: + """An empty (or falsy) pattern is rejected.""" + with pytest.raises(ValueError, match="non-empty"): + compile_grouping_regex(regex) + + +def test_compile_grouping_regex_invalid_raises() -> None: + """A syntactically invalid pattern raises ValueError, not re.error.""" + with pytest.raises(ValueError, match="Invalid filename grouping pattern"): + compile_grouping_regex("cage_(") + + +def test_filename_group_key_uses_capture_group() -> None: + """When the pattern has a capture group, the captured text is the key.""" + pattern = compile_grouping_regex(r"cage_(\d+)") + assert filename_group_key("cage_0042_2026-06-16.mp4", pattern) == "0042" + + +def test_filename_group_key_uses_full_match_without_capture_group() -> None: + """Without a capture group, the whole matched substring is the key.""" + pattern = compile_grouping_regex(r"cage_\d+") + assert filename_group_key("cage_0042_2026-06-16.mp4", pattern) == "cage_0042" + + +def test_filename_group_key_searches_anywhere() -> None: + """The pattern matches anywhere in the filename (re.search semantics).""" + pattern = compile_grouping_regex(r"cage_(\d+)") + assert filename_group_key("2026-06-16_cage_0007_cam1.avi", pattern) == "0007" + + +def test_filename_group_key_unmatched_returns_filename() -> None: + """A filename that does not match becomes its own group (keyed by the name).""" + pattern = compile_grouping_regex(r"cage_(\d+)") + assert filename_group_key("mouse_video.mp4", pattern) == "mouse_video.mp4" + + +def test_filename_group_key_same_cage_different_files_share_key() -> None: + """Different files from the same cage produce the same grouping key.""" + pattern = compile_grouping_regex(r"cage_(\d+)") + key_a = filename_group_key("cage_0042_day1.mp4", pattern) + key_b = filename_group_key("cage_0042_day2.avi", pattern) + assert key_a == key_b == "0042" + + +def test_filename_group_key_optional_capture_group_falls_back_to_full_match() -> None: + """An optional capture group that does not participate falls back to the full match.""" + pattern = compile_grouping_regex(r"cage(_extra)?_\d+") + # The optional group does not match here, so the full match is used. + assert filename_group_key("cage_0042.mp4", pattern) == "cage_0042" diff --git a/src/jabs/classifier/classifier.py b/src/jabs/classifier/classifier.py index 0162fb18..3003d99a 100644 --- a/src/jabs/classifier/classifier.py +++ b/src/jabs/classifier/classifier.py @@ -12,6 +12,8 @@ DEFAULT_CV_GROUPING_STRATEGY, ClassifierType, CrossValidationGroupingStrategy, + compile_grouping_regex, + filename_group_key, ) from jabs.core.utils import hash_file from jabs.project import Project, load_training_data @@ -318,6 +320,7 @@ def confusion_matrix(truth: np.ndarray, predictions: np.ndarray) -> np.ndarray: def count_label_threshold( all_counts: dict, cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY, + cv_grouping_regex: str | None = None, ) -> int: """Count groups that meet the label-threshold criteria. @@ -326,6 +329,9 @@ def count_label_threshold( Structure is a dict[video_name][identity] of fragmented and unfragmented frame/bout count tuples. cv_grouping_strategy: Cross-validation grouping strategy. + cv_grouping_regex: Regex used to extract a grouping key from each video + filename when ``cv_grouping_strategy`` is ``FILENAME_PATTERN``. An + empty or invalid regex yields a count of 0 (no trainable groups). Returns: Number of groups that meet the labeling threshold criteria. @@ -356,6 +362,24 @@ def count_label_threshold( and not_behavior_sum >= Classifier.LABEL_THRESHOLD ): group_count += 1 + elif cv_grouping_strategy == CrossValidationGroupingStrategy.FILENAME_PATTERN: + try: + pattern = compile_grouping_regex(cv_grouping_regex or "") + except ValueError: + return 0 + group_sums: dict[str, list[int]] = {} + for video in all_counts: + label = filename_group_key(video, pattern) + sums = group_sums.setdefault(label, [0, 0]) + for identity_count in all_counts[video].values(): + sums[0] += identity_count["fragmented_frame_counts"][0] + sums[1] += identity_count["fragmented_frame_counts"][1] + for behavior_sum, not_behavior_sum in group_sums.values(): + if ( + behavior_sum >= Classifier.LABEL_THRESHOLD + and not_behavior_sum >= Classifier.LABEL_THRESHOLD + ): + group_count += 1 else: raise ValueError(f"Unknown cv_grouping_strategy: {cv_grouping_strategy}") return group_count @@ -365,6 +389,7 @@ def label_threshold_met( all_counts: dict, min_groups: int, cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY, + cv_grouping_regex: str | None = None, ) -> bool: """Determine whether the labeling threshold is met. @@ -372,11 +397,15 @@ def label_threshold_met( all_counts: Labeled frame and bout counts for the entire project. min_groups: Minimum number of groups required. cv_grouping_strategy: Cross-validation grouping strategy. + cv_grouping_regex: Regex used for ``FILENAME_PATTERN`` grouping (see + :meth:`count_label_threshold`). Returns: True if there are enough groups meeting the threshold. """ group_count = Classifier.count_label_threshold( - all_counts, cv_grouping_strategy=cv_grouping_strategy + all_counts, + cv_grouping_strategy=cv_grouping_strategy, + cv_grouping_regex=cv_grouping_regex, ) return 1 < group_count >= min_groups diff --git a/src/jabs/classifier/cross_validation.py b/src/jabs/classifier/cross_validation.py index d015334a..6ea2caeb 100644 --- a/src/jabs/classifier/cross_validation.py +++ b/src/jabs/classifier/cross_validation.py @@ -125,7 +125,15 @@ def _train_multiclass_fold( def _test_label_from_group(test_info: dict) -> str: - """Render a CV test-group label for the report (video name + optional identity).""" + """Render a CV test-group label for the report. + + Filename-pattern groups carry a ``label`` (the regex-extracted key, e.g. + ``"cage_1234"``); otherwise the label is the video name plus an optional + identity. + """ + label = test_info.get("label") + if label is not None: + return label if test_info["identity"] is not None: return f"{test_info['video']} [{test_info['identity']}]" return test_info["video"] diff --git a/src/jabs/classifier/multi_class_classifier.py b/src/jabs/classifier/multi_class_classifier.py index b1e2783b..cb28153d 100644 --- a/src/jabs/classifier/multi_class_classifier.py +++ b/src/jabs/classifier/multi_class_classifier.py @@ -17,6 +17,8 @@ DEFAULT_CV_GROUPING_STRATEGY, ClassifierType, CrossValidationGroupingStrategy, + compile_grouping_regex, + filename_group_key, ) from jabs.core.utils import hash_file from jabs.project import load_multiclass_training_data @@ -423,6 +425,7 @@ def count_label_threshold( counts_by_behavior: dict[str, dict], behavior_names: list[str], cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY, + cv_grouping_regex: str | None = None, ) -> int: """Count multi-class LOGO groups that satisfy the relaxed acceptance rule. @@ -439,6 +442,9 @@ def count_label_threshold( ``counts_by_behavior``. Typically includes ``MULTICLASS_NONE_BEHAVIOR``. cv_grouping_strategy: Cross-validation grouping strategy. + cv_grouping_regex: Regex used to extract a grouping key from each video + filename when ``cv_grouping_strategy`` is ``FILENAME_PATTERN``. An + empty or invalid regex yields a count of 0 (no trainable groups). Returns: Number of groups that can serve as a valid multi-class LOGO test split. @@ -450,13 +456,30 @@ def count_label_threshold( if not behavior_names: return 0 + # FILENAME_PATTERN aggregates like VIDEO grouping, but keys each group by + # the regex-extracted filename key instead of the video name (so several + # videos can share one group). An empty/invalid regex means no groups. + pattern = None + if cv_grouping_strategy == CrossValidationGroupingStrategy.FILENAME_PATTERN: + try: + pattern = compile_grouping_regex(cv_grouping_regex or "") + except ValueError: + return 0 + threshold = MultiClassClassifier.LABEL_THRESHOLD group_class_counts: dict[tuple[str, int] | str, dict[str, int]] = {} for behavior_name in behavior_names: behavior_counts = counts_by_behavior.get(behavior_name, {}) for video_name, video_counts in behavior_counts.items(): - if cv_grouping_strategy == CrossValidationGroupingStrategy.VIDEO: - key: tuple[str, int] | str = video_name + if cv_grouping_strategy in ( + CrossValidationGroupingStrategy.VIDEO, + CrossValidationGroupingStrategy.FILENAME_PATTERN, + ): + key: tuple[str, int] | str = ( + filename_group_key(video_name, pattern) + if pattern is not None + else video_name + ) group_entry = group_class_counts.setdefault(key, {}) group_entry[behavior_name] = group_entry.get(behavior_name, 0) + sum( identity_counts["fragmented_frame_counts"][0] @@ -498,6 +521,7 @@ def label_threshold_met( behavior_names: list[str], min_groups: int, cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY, + cv_grouping_regex: str | None = None, ) -> bool: """Determine whether multi-class labels support ``min_groups`` LOGO splits. @@ -511,6 +535,8 @@ class names are supplied. at 1, since multi-class training requires at least one valid split. cv_grouping_strategy: Cross-validation grouping strategy. + cv_grouping_regex: Regex used for ``FILENAME_PATTERN`` grouping (see + :meth:`count_label_threshold`). Returns: True if the count of valid splits meets ``max(1, min_groups)``. @@ -521,5 +547,6 @@ class names are supplied. counts_by_behavior=counts_by_behavior, behavior_names=behavior_names, cv_grouping_strategy=cv_grouping_strategy, + cv_grouping_regex=cv_grouping_regex, ) return valid_splits >= max(1, min_groups) diff --git a/src/jabs/classifier/training_report.py b/src/jabs/classifier/training_report.py index b14dd60f..948ce8ba 100644 --- a/src/jabs/classifier/training_report.py +++ b/src/jabs/classifier/training_report.py @@ -105,6 +105,8 @@ class TrainingReportData: training_time_ms: Total training time in milliseconds. timestamp: Datetime when training was completed. cv_grouping_strategy: Strategy used for cross-validation grouping. + cv_grouping_regex: Filename-pattern regex used for grouping. Only set when + the grouping strategy is "Filename Pattern". """ behavior_name: str @@ -124,6 +126,7 @@ class TrainingReportData: bouts_not_behavior: int = 0 class_frame_counts: dict[str, int] | None = None class_bout_counts: dict[str, int] | None = None + cv_grouping_regex: str | None = None def _escape_markdown(text: str) -> str: @@ -287,6 +290,8 @@ def generate_markdown_report(data: TrainingReportData) -> str: lines.append("### Iteration Details") lines.append(f"CV Grouping Strategy: {data.cv_grouping_strategy.value}") + if data.cv_grouping_regex: + lines.append(f"CV Grouping Pattern: `{data.cv_grouping_regex}`") lines.append("") lines.append(_format_iteration_table(data.cv_results)) lines.append("") @@ -410,6 +415,7 @@ def generate_json_report(data: TrainingReportData) -> dict: "training_time_ms": int(data.training_time_ms), "timestamp": timestamp_str, "cv_grouping_strategy": data.cv_grouping_strategy.value, + "cv_grouping_regex": data.cv_grouping_regex, "frames_behavior": int(data.frames_behavior), "frames_not_behavior": int(data.frames_not_behavior), "bouts_behavior": int(data.bouts_behavior), diff --git a/src/jabs/project/export_training.py b/src/jabs/project/export_training.py index 4ecaa032..577950f7 100644 --- a/src/jabs/project/export_training.py +++ b/src/jabs/project/export_training.py @@ -20,6 +20,34 @@ from jabs.project import Project +def _write_group_mapping( + out_h5: h5py.File, group_mapping: dict[int, dict], string_type: np.dtype +) -> None: + """Write the cross-validation group mapping into the exported HDF5 file. + + For each group, stores an ``identity`` (the animal identity, or ``-1`` when the + group is not identity-specific) and a ``video_name``. ``identity`` is ``-1`` for + ``VIDEO`` and ``FILENAME_PATTERN`` grouping. For ``FILENAME_PATTERN`` groups, + ``video_name`` holds the regex-extracted group label (e.g. ``"cage_1234"``) + rather than a single filename, since one group can span multiple videos. + + Args: + out_h5: Open HDF5 file to write into. + group_mapping: Mapping of group id to its source descriptor. + string_type: h5py variable-length string dtype for the video_name dataset. + """ + for group, info in group_mapping.items(): + identity_dset = out_h5.create_dataset( + f"group_mapping/{group}/identity", (1,), dtype=np.int64 + ) + identity = info["identity"] + identity_dset[:] = identity if identity is not None else -1 + video_name_dset = out_h5.create_dataset( + f"group_mapping/{group}/video_name", (1,), dtype=string_type + ) + video_name_dset[:] = info["label"] if info.get("label") is not None else info["video"] + + def export_training_data( project: "Project", behavior: str, @@ -76,15 +104,7 @@ def export_training_data( out_h5.create_dataset("label", data=features["labels"]) # store the video/identity to group mapping in the h5 file - # identity is None when VIDEO grouping strategy is used; store -1 as a sentinel - for group in group_mapping: - dset = out_h5.create_dataset(f"group_mapping/{group}/identity", (1,), dtype=np.int64) - identity = group_mapping[group]["identity"] - dset[:] = identity if identity is not None else -1 - dset = out_h5.create_dataset( - f"group_mapping/{group}/video_name", (1,), dtype=string_type - ) - dset[:] = group_mapping[group]["video"] + _write_group_mapping(out_h5, group_mapping, string_type) # return output path, so if it was generated automatically the caller # will know @@ -159,14 +179,7 @@ def export_training_data_multiclass( out_h5.create_dataset("group", data=features["groups"]) - for group in group_mapping: - dset = out_h5.create_dataset(f"group_mapping/{group}/identity", (1,), dtype=np.int64) - identity = group_mapping[group]["identity"] - dset[:] = identity if identity is not None else -1 - dset = out_h5.create_dataset( - f"group_mapping/{group}/video_name", (1,), dtype=string_type - ) - dset[:] = group_mapping[group]["video"] + _write_group_mapping(out_h5, group_mapping, string_type) return out_file diff --git a/src/jabs/project/project.py b/src/jabs/project/project.py index a6746842..f25faad5 100644 --- a/src/jabs/project/project.py +++ b/src/jabs/project/project.py @@ -23,6 +23,8 @@ ClassifierMode, CrossValidationGroupingStrategy, ProjectDistanceUnit, + compile_grouping_regex, + filename_group_key, ) from jabs.pose_estimation import ( PoseEstimation, @@ -879,6 +881,7 @@ def _assign_cv_group_ids( all_group_keys: list[tuple[str, int]], videos: list[str], grouping_strategy: CrossValidationGroupingStrategy, + regex: str | None = None, ) -> tuple[dict[tuple[str, int], int], dict[int, dict]]: """Assign deterministic cross-validation group ids. @@ -886,12 +889,25 @@ def _assign_cv_group_ids( all_group_keys: ``(video, identity)`` tuples in row order. videos: Canonical list of project videos; ids are assigned in this order. grouping_strategy: ``INDIVIDUAL`` groups one (video, identity) pair per - gid; ``VIDEO`` groups all identities of a video together. + gid; ``VIDEO`` groups all identities of a video together; + ``FILENAME_PATTERN`` groups all videos (and their identities) whose + filename yields the same key under ``regex``. + regex: Regular expression used to extract a grouping key from each + video filename. Required for ``FILENAME_PATTERN`` grouping; ignored + otherwise. Returns: Tuple of ``(key_to_gid, group_mapping)`` where ``key_to_gid`` maps each ``(video, identity)`` pair to its group id and ``group_mapping`` maps - each group id back to ``{"video": ..., "identity": ...}``. + each group id back to its source. ``INDIVIDUAL``/``VIDEO`` entries are + ``{"video": ..., "identity": ...}``; ``FILENAME_PATTERN`` entries are + ``{"video": None, "identity": None, "label": , "videos": [...]}`` + where ``videos`` lists the labeled videos in the group. + + Raises: + ValueError: If ``grouping_strategy`` is ``FILENAME_PATTERN`` and + ``regex`` is empty or not a valid regular expression, or if the + strategy is unknown. """ key_to_gid: dict[tuple[str, int], int] = {} group_mapping: dict[int, dict] = {} @@ -918,6 +934,29 @@ def _assign_cv_group_ids( for video_name, ident in all_group_keys: if video_name == v: key_to_gid[(v, ident)] = video_to_gid[v] + elif grouping_strategy == CrossValidationGroupingStrategy.FILENAME_PATTERN: + pattern = compile_grouping_regex(regex or "") + label_to_gid: dict[str, int] = {} + # Group ids are created lazily in row order (which follows canonical + # video order). Videos whose filename does not match the pattern fall + # back to their own group, since filename_group_key returns the + # (unique) filename as the key. + for video_name, ident in all_group_keys: + label = filename_group_key(video_name, pattern) + if label not in label_to_gid: + label_to_gid[label] = gid + group_mapping[gid] = { + "video": None, + "identity": None, + "label": label, + "videos": [], + } + gid += 1 + group_gid = label_to_gid[label] + key_to_gid[(video_name, ident)] = group_gid + videos_in_group = group_mapping[group_gid]["videos"] + if video_name not in videos_in_group: + videos_in_group.append(video_name) else: raise ValueError(f"Unknown grouping strategy: {grouping_strategy}") return key_to_gid, group_mapping @@ -936,21 +975,32 @@ def _build_groups_array( return np.concatenate(groups_list) if groups_list else np.array([], dtype=np.int32) def _excluded_group_ids(self, group_mapping: dict[int, dict]) -> set[int]: - """Return CV group ids whose source video is excluded from training. + """Return CV group ids whose source video(s) are excluded from training. Args: - group_mapping: Mapping of group id to ``{"video": ..., "identity": ...}``. + group_mapping: Mapping of group id to its source. ``INDIVIDUAL``/``VIDEO`` + groups carry a single ``"video"``; ``FILENAME_PATTERN`` groups carry + a ``"videos"`` list (one group can span several videos). Returns: - Set of group ids belonging to videos marked excluded from training. - These groups are still eligible as the held-out test group in - leave-one-group-out cross-validation but are never used for training. + Set of group ids whose constituent videos are all marked excluded from + training. These groups are still eligible as the held-out test group in + leave-one-group-out cross-validation but are never used for training. A + filename-pattern group is excluded only when *every* labeled video in it + is excluded, so a partially-excluded group still contributes its + non-excluded videos' data to training folds. """ - return { - gid - for gid, info in group_mapping.items() - if self._settings_manager.is_video_excluded(info["video"]) - } + excluded: set[int] = set() + for gid, info in group_mapping.items(): + group_videos = info.get("videos") + if group_videos is None: + video = info.get("video") + group_videos = [video] if video is not None else [] + if group_videos and all( + self._settings_manager.is_video_excluded(v) for v in group_videos + ): + excluded.add(gid) + return excluded def get_labeled_features( self, @@ -958,6 +1008,7 @@ def get_labeled_features( progress_callable: Callable[[], None] | None = None, should_terminate_callable: Callable[[], None] | None = None, grouping_strategy: CrossValidationGroupingStrategy | None = None, + grouping_regex: str | None = None, ) -> tuple[dict, dict]: """Get labeled features for training (parallel per-video). @@ -977,6 +1028,8 @@ def get_labeled_features( and as results complete; it should raise a `ThreadTerminatedError` if the user has requested early termination. grouping_strategy: Optional override for cross-validation grouping strategy. If None, uses project settings. + grouping_regex: Optional override for the filename-pattern grouping regex + (only used when the strategy is ``FILENAME_PATTERN``). If None, uses project settings. Returns: tuple[dict, dict]: A tuple of (features, group_mapping). @@ -989,13 +1042,16 @@ def get_labeled_features( The values in the first dict are suitable for `Classifier.leave_one_group_out()`. - The second dict maps group ids to their source: + The second dict maps group ids to their source (the exact shape + depends on the grouping strategy; see ``_assign_cv_group_ids``): { : {'video':