diff --git a/README.md b/README.md index fb61184..4c96600 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ The Overlap Index can be used in several settings: - Overlap is estimated by monitoring shared best-matching units (BMUs) or top prototype activations between class pairs. - The global OI is computed as the macro mean of per-class minimum pairwise overlap scores, so each observed class contributes equally to `index`. - A support-weighted companion score is available through `weighted_index` for workflows that need the score to reflect observed class frequencies. +- Global aggregation can exclude one or more label ids through `exclude_classes` without removing those labels from fitting, singleton scores, or pairwise scores. --- @@ -127,6 +128,23 @@ score = oi.index The fitted value is available through `oi.index`. For users who prefer update methods that return the current score directly, `add_batch(X, y)` is also supported. +### Excluding Classes From Global Aggregation + +`exclude_classes` lets you keep a label fully involved in overlap evaluation +while omitting it from the two global summary scores: + +```python +oi = OverlapIndex(exclude_classes=0) +oi = OverlapIndex(exclude_classes=[0, "unlabeled"]) +``` + +This is useful for segmentation workflows where only foreground objects are +labeled but background-only samples should still contribute to pairwise overlap +counts. A common pattern is to create one background class containing those +samples, then pass that class id to `exclude_classes`. The background class will +still appear in `singleton_index`, `pairwise_index`, and prototype ownership; +only `index` and `weighted_index` omit it from aggregation. + ### Online ARTMAP Usage ```python @@ -336,6 +354,10 @@ backends. - `ballcover_kwargs` *(dict, optional)* Additional BallCover options such as `metric`, `cover_fraction`, `chunk_size`, `max_balls`, and `random_state`. +- `exclude_classes` *(None, scalar label, or iterable of labels)* + Label ids to omit from the global `index` and `weighted_index` + aggregation while leaving all fitting and per-class overlap outputs intact. + --- The default parameters are intended for offline batch use with `MiniBatchKMeans`. For online or continual-learning workflows, explicitly choose `model_type="Fuzzy"` or `model_type="Hypersphere"`. For very large ART-based runs, smaller `rho` values (0.5-0.7) may improve run-time performance. @@ -345,14 +367,15 @@ The default parameters are intended for offline batch use with `MiniBatchKMeans` ## Output - **`index`** - Global macro Overlap Index across all observed classes. This is the default - class-balanced score and is usually preferred for imbalance-sensitive - separation analysis. + Global macro Overlap Index across all observed classes that are not listed in + `exclude_classes`. This is the default class-balanced score and is usually + preferred for imbalance-sensitive separation analysis. - **`weighted_index`** - Support-weighted Overlap Index across observed classes. This weights each - class's `singleton_index` value by its positive sample count, which can be - useful when reporting should reflect observed class frequencies. + Support-weighted Overlap Index across observed classes that are not listed in + `exclude_classes`. This weights each included class's `singleton_index` value + by its positive sample count, which can be useful when reporting should + reflect observed class frequencies. - **`singleton_index[y]`** Minimum pairwise overlap score for class `y`. diff --git a/overlapindex/OverlapIndex.py b/overlapindex/OverlapIndex.py index 19cfdd4..ca9adc8 100644 --- a/overlapindex/OverlapIndex.py +++ b/overlapindex/OverlapIndex.py @@ -137,6 +137,21 @@ def _ordered_unique_labels(Y_sets: list[set]) -> np.ndarray: labels.append(label) return np.asarray(labels, dtype=object) + +def _deduplicated_labels(labels: Iterable[Any]) -> list[Any]: + """Return labels with first-seen order preserved after hashability checks.""" + deduplicated = [] + seen = set() + for label in labels: + try: + hash(label) + except TypeError as exc: + raise TypeError("exclude_classes entries must be hashable.") from exc + if label not in seen: + seen.add(label) + deduplicated.append(label) + return deduplicated + # ---------------------------- # OverlapIndex with model_type # ---------------------------- @@ -167,6 +182,7 @@ def __init__( offline_chunk_size: Optional[int] = 10_000, multilabel_pair_mode: Literal["all", "top_m"] = "all", top_m: Optional[int] = None, + exclude_classes: Optional[Any] = None, ) -> None: """ Initialize the overlap index and its clustering backend. @@ -206,6 +222,11 @@ def __init__( top_m : int, optional Number of nearest competing labels per source label when multilabel_pair_mode is "top_m". + exclude_classes : None, scalar label, or iterable of labels, optional + Label ids to exclude from the global ``index`` and + ``weighted_index`` aggregation. Excluded labels remain fully + involved in fitting, singleton scoring, pairwise scoring, and all + bookkeeping outputs. """ self.rho = rho self.r_hat = r_hat @@ -219,6 +240,7 @@ def __init__( self.offline_chunk_size = offline_chunk_size self.multilabel_pair_mode = multilabel_pair_mode self.top_m = top_m + self.exclude_classes = exclude_classes self._validate_multilabel_params() # indices / bookkeeping @@ -340,6 +362,15 @@ def _warn_single_class() -> None: stacklevel=2, ) + @staticmethod + def _warn_all_observed_classes_excluded() -> None: + """Warn when exclusions remove every observed class from summaries.""" + warnings.warn( + "All observed classes were excluded from global aggregation; leaving OverlapIndex scores at 1.0.", + RuntimeWarning, + stacklevel=2, + ) + def _reset_indices(self) -> None: """Reset overlap-index bookkeeping without replacing the clustering backend.""" self.sparse_adj = defaultdict(lambda: 0) @@ -366,9 +397,13 @@ def weighted_index(self) -> float: The default ``index`` is a macro average over per-class scores. This property weights each per-class score by that class's positive support. """ + included_labels = self._included_singleton_labels() + if not included_labels: + return float(self.index) + total_support = sum( int(self.cluster_cardinality.get(y, 0)) - for y in self.singleton_index + for y in included_labels ) if total_support <= 0: return float(self.index) @@ -376,9 +411,46 @@ def weighted_index(self) -> float: weighted_sum = sum( float(score) * int(self.cluster_cardinality.get(y, 0)) for y, score in self.singleton_index.items() + if y in included_labels ) return float(weighted_sum / float(total_support)) + def _normalized_exclude_classes(self) -> set[Any]: + """Normalize the configured global-aggregation exclusions.""" + excluded = self.exclude_classes + if excluded is None: + return set() + + if isinstance(excluded, (str, bytes)) or not isinstance(excluded, Iterable): + labels = [excluded] + else: + labels = list(excluded) + + return set(_deduplicated_labels(labels)) + + def _included_singleton_labels(self) -> list[Any]: + """Return observed labels that still contribute to global summaries.""" + observed_labels = list(self.singleton_index.keys()) + if not observed_labels: + return [] + + excluded = self._normalized_exclude_classes() + return [label for label in observed_labels if label not in excluded] + + def _recompute_global_index(self) -> float: + """Recompute the exclusion-aware global macro overlap index.""" + included_labels = self._included_singleton_labels() + if not included_labels: + if self.singleton_index: + self._warn_all_observed_classes_excluded() + self.index = 1.0 + return self.index + + self.index = float( + np.mean([self.singleton_index[label] for label in included_labels]) + ) + return self.index + @property def module_a(self) -> Any: """Return the underlying ARTMAP module A object for ARTMAP backends.""" @@ -483,9 +555,12 @@ def add_sample(self, x: np.ndarray, y: Any) -> float: self.singleton_index[y] = min( [self.pairwise_index[(y, b)] for b in self.rev_map.keys() if b != y] ) - self.index = float(np.mean(list(self.singleton_index.values()))) + self._recompute_global_index() else: - self._warn_single_class() + if self._included_singleton_labels(): + self._warn_single_class() + else: + self._warn_all_observed_classes_excluded() return self.index def add_batch(self, X: np.ndarray, Y: Any) -> float: @@ -547,7 +622,7 @@ def add_batch(self, X: np.ndarray, Y: Any) -> float: self.singleton_index[y] = min( [self.pairwise_index[(y, b)] for b in self.rev_map.keys() if b != y] ) - self.index = float(np.mean(list(self.singleton_index.values()))) + self._recompute_global_index() return self.index def fit(self, X: np.ndarray, Y: np.ndarray) -> "OverlapIndex": @@ -836,7 +911,7 @@ def _fit_offline_centroid_optimized_multilabel( if self.pairwise_cardinality[(y, b)] > 0 ] self.singleton_index[y] = min(valid_scores) if valid_scores else 1.0 - self.index = float(np.mean(list(self.singleton_index.values()))) + self._recompute_global_index() return self.index @@ -914,7 +989,7 @@ def _fit_offline_centroid_optimized( self.singleton_index[y] = min( [self.pairwise_index[(y, b)] for b in self.rev_map.keys() if b != y] ) - self.index = float(np.mean(list(self.singleton_index.values()))) + self._recompute_global_index() return self.index def _fit_offline_replay( @@ -949,7 +1024,7 @@ def _fit_offline_replay( self.singleton_index[y] = min( [self.pairwise_index[(y, b)] for b in self.rev_map.keys() if b != y] ) - self.index = float(np.mean(list(self.singleton_index.values()))) + self._recompute_global_index() return self.index def fit_offline(self, X: np.ndarray, Y: Any, reset_state: bool = True) -> float: @@ -1010,6 +1085,10 @@ def fit_offline(self, X: np.ndarray, Y: Any, reset_state: bool = True) -> float: self.singleton_index[c] = 1.0 if len(classes) <= 1: + if self._included_singleton_labels(): + self._warn_single_class() + else: + self._warn_all_observed_classes_excluded() return self.index if is_multilabel: diff --git a/tests/test_overlap_index_regression.py b/tests/test_overlap_index_regression.py index 566447c..b984776 100644 --- a/tests/test_overlap_index_regression.py +++ b/tests/test_overlap_index_regression.py @@ -2,6 +2,7 @@ Behavior-regression tests for OverlapIndex. """ +from collections.abc import Iterable from importlib import import_module from importlib.util import find_spec @@ -52,15 +53,17 @@ def _iris_data(): -def _make_model(model_type): +def _make_model(model_type, **overrides): if model_type == "KMeans": - return OverlapIndex( + params = dict( model_type="KMeans", kmeans_k=10, kmeans_kwargs={"random_state": 0, "n_init": 10}, ) + params.update(overrides) + return OverlapIndex(**params) if model_type == "MiniBatchKMeans": - return OverlapIndex( + params = dict( model_type="MiniBatchKMeans", kmeans_k=10, kmeans_kwargs={ @@ -70,8 +73,10 @@ def _make_model(model_type): "max_iter": 100, }, ) + params.update(overrides) + return OverlapIndex(**params) if model_type == "BallCover": - return OverlapIndex( + params = dict( model_type="BallCover", ballcover_k=20, ballcover_radius="auto", @@ -81,11 +86,57 @@ def _make_model(model_type): "random_state": 0, }, ) - return OverlapIndex( + params.update(overrides) + return OverlapIndex(**params) + params = dict( model_type=model_type, rho=0.95, r_hat=0.1, ) + params.update(overrides) + return OverlapIndex(**params) + + +def _normalize_excluded_for_test(excluded): + if excluded is None: + return set() + if isinstance(excluded, (str, bytes)) or not isinstance(excluded, Iterable): + return {excluded} + return set(excluded) + + +def _manual_global_scores(model, excluded): + excluded_set = _normalize_excluded_for_test(excluded) + included_labels = [ + label for label in model.singleton_index if label not in excluded_set + ] + if not included_labels: + return 1.0, 1.0 + + macro = float( + np.mean([model.singleton_index[label] for label in included_labels]) + ) + total_support = sum(model.cluster_cardinality[label] for label in included_labels) + if total_support <= 0: + return macro, macro + + weighted = sum( + model.singleton_index[label] * model.cluster_cardinality[label] + for label in included_labels + ) / total_support + return macro, float(weighted) + + +def _assert_float_mapping_close(received, expected): + assert set(received) == set(expected) + for key, expected_value in expected.items(): + assert np.isclose(received[key], expected_value, atol=0.0, rtol=0.0) + + +def _assert_array_mapping_equal(received, expected): + assert set(received) == set(expected) + for key, expected_value in expected.items(): + assert np.array_equal(received[key], expected_value) def _assert_index_close(received, expected, context): @@ -370,6 +421,81 @@ def test_get_params_and_set_params_follow_sklearn_conventions(): assert model.offline_chunk_size == 2048 +def test_exclude_classes_get_params_and_set_params_follow_sklearn_conventions(): + model = OverlapIndex( + model_type="MiniBatchKMeans", + exclude_classes=[0, "unlabeled"], + ) + + params = model.get_params() + assert params["exclude_classes"] == [0, "unlabeled"] + + model.set_params(exclude_classes="background") + assert model.exclude_classes == "background" + + +@pytest.mark.parametrize("model_type", ["KMeans", "MiniBatchKMeans", "BallCover"]) +def test_exclude_classes_only_changes_global_aggregation_single_label(model_type): + X, y = _iris_data() + + baseline = _make_model(model_type) + excluded = _make_model(model_type, exclude_classes=0) + + baseline.fit(X, y) + excluded.fit(X, y) + + _assert_float_mapping_close( + dict(excluded.singleton_index), + dict(baseline.singleton_index), + ) + _assert_float_mapping_close( + dict(excluded.pairwise_index), + dict(baseline.pairwise_index), + ) + assert dict(excluded.cluster_cardinality) == dict(baseline.cluster_cardinality) + + expected_index, expected_weighted = _manual_global_scores(baseline, 0) + assert np.isclose(excluded.index, expected_index, atol=0.0, rtol=0.0) + assert np.isclose(excluded.weighted_index, expected_weighted, atol=0.0, rtol=0.0) + + +def test_empty_and_absent_exclusions_preserve_current_behavior(): + X, y = _iris_data() + + baseline = _make_model("MiniBatchKMeans") + empty = _make_model("MiniBatchKMeans", exclude_classes=[]) + absent = _make_model( + "MiniBatchKMeans", + exclude_classes=[999, "missing"], + ) + + baseline.fit(X, y) + empty.fit(X, y) + absent.fit(X, y) + + assert np.isclose(empty.index, baseline.index, atol=0.0, rtol=0.0) + assert np.isclose(absent.index, baseline.index, atol=0.0, rtol=0.0) + assert np.isclose(empty.weighted_index, baseline.weighted_index, atol=0.0, rtol=0.0) + assert np.isclose(absent.weighted_index, baseline.weighted_index, atol=0.0, rtol=0.0) + + +def test_all_observed_classes_excluded_warns_and_leaves_default_scores(): + X, y = _iris_data() + model = _make_model("MiniBatchKMeans", exclude_classes=[0, 1, 2]) + + with pytest.warns( + RuntimeWarning, + match="All observed classes were excluded from global aggregation", + ): + returned = model.add_batch(X, y) + + assert returned == 1.0 + assert model.index == 1.0 + assert model.weighted_index == 1.0 + assert set(model.singleton_index) == {0, 1, 2} + assert dict(model.cluster_cardinality) == {0: 50, 1: 50, 2: 50} + + def test_multilabel_sequence_of_same_length_label_lists_is_supported(): X = np.array( [ @@ -464,6 +590,168 @@ def test_multilabel_top_m_limits_competitors(): assert label not in set(competitors) +def test_multilabel_sequence_exclusions_change_only_global_aggregation(): + X = np.array( + [ + [0.0, 0.0], + [0.1, 0.0], + [1.0, 0.0], + [1.1, 0.0], + [2.0, 0.0], + [2.1, 0.0], + ], + dtype=float, + ) + y = [ + ["A", "B"], + ["A"], + ["B"], + ["B", "C"], + ["C"], + ["A", "C"], + ] + baseline = OverlapIndex( + model_type="KMeans", + kmeans_k=1, + kmeans_kwargs={"random_state": 0, "n_init": 10}, + multilabel_pair_mode="top_m", + top_m=1, + ) + excluded = OverlapIndex( + model_type="KMeans", + kmeans_k=1, + kmeans_kwargs={"random_state": 0, "n_init": 10}, + multilabel_pair_mode="top_m", + top_m=1, + exclude_classes="B", + ) + + baseline.fit(X, y) + excluded.fit(X, y) + + _assert_array_mapping_equal(excluded.competitors_, baseline.competitors_) + _assert_float_mapping_close( + dict(excluded.singleton_index), + dict(baseline.singleton_index), + ) + _assert_float_mapping_close( + dict(excluded.pairwise_index), + dict(baseline.pairwise_index), + ) + assert dict(excluded.pairwise_cardinality) == dict(baseline.pairwise_cardinality) + assert dict(excluded.cluster_cardinality) == dict(baseline.cluster_cardinality) + + expected_index, expected_weighted = _manual_global_scores(baseline, "B") + assert np.isclose(excluded.index, expected_index, atol=0.0, rtol=0.0) + assert np.isclose(excluded.weighted_index, expected_weighted, atol=0.0, rtol=0.0) + + +def test_multilabel_indicator_exclusions_change_only_global_aggregation(): + X = np.array( + [ + [0.0, 0.0], + [0.2, 0.0], + [1.0, 0.0], + [1.2, 0.0], + ], + dtype=float, + ) + y = np.array( + [ + [1, 1, 0], + [1, 0, 1], + [0, 1, 1], + [1, 1, 0], + ], + dtype=int, + ) + baseline = OverlapIndex( + model_type="MiniBatchKMeans", + kmeans_k=1, + kmeans_kwargs={"random_state": 0, "n_init": 1, "batch_size": 4}, + ) + excluded = OverlapIndex( + model_type="MiniBatchKMeans", + kmeans_k=1, + kmeans_kwargs={"random_state": 0, "n_init": 1, "batch_size": 4}, + exclude_classes=[1, "absent"], + ) + + baseline.fit(X, y) + excluded.fit(X, y) + + assert excluded.label_to_index_ == baseline.label_to_index_ + _assert_float_mapping_close( + dict(excluded.singleton_index), + dict(baseline.singleton_index), + ) + _assert_float_mapping_close( + dict(excluded.pairwise_index), + dict(baseline.pairwise_index), + ) + assert dict(excluded.pairwise_cardinality) == dict(baseline.pairwise_cardinality) + assert dict(excluded.cluster_cardinality) == dict(baseline.cluster_cardinality) + + expected_index, expected_weighted = _manual_global_scores(baseline, [1, "absent"]) + assert np.isclose(excluded.index, expected_index, atol=0.0, rtol=0.0) + assert np.isclose(excluded.weighted_index, expected_weighted, atol=0.0, rtol=0.0) + + def test_multilabel_top_m_requires_positive_integer(): with pytest.raises(ValueError, match="top_m must be a positive integer"): OverlapIndex(multilabel_pair_mode="top_m", top_m=0) + + +@ARTLIB_REQUIRED +@pytest.mark.parametrize("model_type", ["Fuzzy", "Hypersphere"]) +def test_art_backends_exclude_classes_only_changes_global_aggregation_after_add_batch(model_type): + X, y = _iris_data() + + baseline = _make_model(model_type) + excluded = _make_model(model_type, exclude_classes=0) + + baseline.add_batch(X, y) + excluded.add_batch(X, y) + + _assert_float_mapping_close( + dict(excluded.singleton_index), + dict(baseline.singleton_index), + ) + _assert_float_mapping_close( + dict(excluded.pairwise_index), + dict(baseline.pairwise_index), + ) + assert dict(excluded.cluster_cardinality) == dict(baseline.cluster_cardinality) + + expected_index, expected_weighted = _manual_global_scores(baseline, 0) + assert np.isclose(excluded.index, expected_index, atol=0.0, rtol=0.0) + assert np.isclose(excluded.weighted_index, expected_weighted, atol=0.0, rtol=0.0) + + +@ARTLIB_REQUIRED +@pytest.mark.parametrize("model_type", ["Fuzzy", "Hypersphere"]) +def test_art_backends_exclude_classes_only_changes_global_aggregation_after_add_sample(model_type): + X, y = _iris_data() + + baseline = _make_model(model_type) + excluded = _make_model(model_type, exclude_classes=0) + + baseline.add_batch(X[:-ADD_SAMPLE_IDX], y[:-ADD_SAMPLE_IDX]) + excluded.add_batch(X[:-ADD_SAMPLE_IDX], y[:-ADD_SAMPLE_IDX]) + + baseline.add_sample(X[ADD_SAMPLE_IDX], int(y[ADD_SAMPLE_IDX])) + excluded.add_sample(X[ADD_SAMPLE_IDX], int(y[ADD_SAMPLE_IDX])) + + _assert_float_mapping_close( + dict(excluded.singleton_index), + dict(baseline.singleton_index), + ) + _assert_float_mapping_close( + dict(excluded.pairwise_index), + dict(baseline.pairwise_index), + ) + assert dict(excluded.cluster_cardinality) == dict(baseline.cluster_cardinality) + + expected_index, expected_weighted = _manual_global_scores(baseline, 0) + assert np.isclose(excluded.index, expected_index, atol=0.0, rtol=0.0) + assert np.isclose(excluded.weighted_index, expected_weighted, atol=0.0, rtol=0.0)