Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ The Overlap Index can be used in several settings:
- Overlap is estimated by monitoring shared best-matching units (BMUs) or top prototype activations between class pairs.
- The global OI is computed as the macro mean of per-class minimum pairwise overlap scores, so each observed class contributes equally to `index`.
- A support-weighted companion score is available through `weighted_index` for workflows that need the score to reflect observed class frequencies.
- Global aggregation can exclude one or more label ids through `exclude_classes` without removing those labels from fitting, singleton scores, or pairwise scores.

---

Expand All @@ -127,6 +128,23 @@ score = oi.index

The fitted value is available through `oi.index`. For users who prefer update methods that return the current score directly, `add_batch(X, y)` is also supported.

### Excluding Classes From Global Aggregation

`exclude_classes` lets you keep a label fully involved in overlap evaluation
while omitting it from the two global summary scores:

```python
oi = OverlapIndex(exclude_classes=0)
oi = OverlapIndex(exclude_classes=[0, "unlabeled"])
```

This is useful for segmentation workflows where only foreground objects are
labeled but background-only samples should still contribute to pairwise overlap
counts. A common pattern is to create one background class containing those
samples, then pass that class id to `exclude_classes`. The background class will
still appear in `singleton_index`, `pairwise_index`, and prototype ownership;
only `index` and `weighted_index` omit it from aggregation.

### Online ARTMAP Usage

```python
Expand Down Expand Up @@ -336,6 +354,10 @@ backends.
- `ballcover_kwargs` *(dict, optional)*
Additional BallCover options such as `metric`, `cover_fraction`, `chunk_size`, `max_balls`, and `random_state`.

- `exclude_classes` *(None, scalar label, or iterable of labels)*
Label ids to omit from the global `index` and `weighted_index`
aggregation while leaving all fitting and per-class overlap outputs intact.

---

The default parameters are intended for offline batch use with `MiniBatchKMeans`. For online or continual-learning workflows, explicitly choose `model_type="Fuzzy"` or `model_type="Hypersphere"`. For very large ART-based runs, smaller `rho` values (0.5-0.7) may improve run-time performance.
Expand All @@ -345,14 +367,15 @@ The default parameters are intended for offline batch use with `MiniBatchKMeans`
## Output

- **`index`**
Global macro Overlap Index across all observed classes. This is the default
class-balanced score and is usually preferred for imbalance-sensitive
separation analysis.
Global macro Overlap Index across all observed classes that are not listed in
`exclude_classes`. This is the default class-balanced score and is usually
preferred for imbalance-sensitive separation analysis.

- **`weighted_index`**
Support-weighted Overlap Index across observed classes. This weights each
class's `singleton_index` value by its positive sample count, which can be
useful when reporting should reflect observed class frequencies.
Support-weighted Overlap Index across observed classes that are not listed in
`exclude_classes`. This weights each included class's `singleton_index` value
by its positive sample count, which can be useful when reporting should
reflect observed class frequencies.

- **`singleton_index[y]`**
Minimum pairwise overlap score for class `y`.
Expand Down
93 changes: 86 additions & 7 deletions overlapindex/OverlapIndex.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,21 @@ def _ordered_unique_labels(Y_sets: list[set]) -> np.ndarray:
labels.append(label)
return np.asarray(labels, dtype=object)


def _deduplicated_labels(labels: Iterable[Any]) -> list[Any]:
"""Return labels with first-seen order preserved after hashability checks."""
deduplicated = []
seen = set()
for label in labels:
try:
hash(label)
except TypeError as exc:
raise TypeError("exclude_classes entries must be hashable.") from exc
if label not in seen:
seen.add(label)
deduplicated.append(label)
return deduplicated

# ----------------------------
# OverlapIndex with model_type
# ----------------------------
Expand Down Expand Up @@ -167,6 +182,7 @@ def __init__(
offline_chunk_size: Optional[int] = 10_000,
multilabel_pair_mode: Literal["all", "top_m"] = "all",
top_m: Optional[int] = None,
exclude_classes: Optional[Any] = None,
) -> None:
"""
Initialize the overlap index and its clustering backend.
Expand Down Expand Up @@ -206,6 +222,11 @@ def __init__(
top_m : int, optional
Number of nearest competing labels per source label when
multilabel_pair_mode is "top_m".
exclude_classes : None, scalar label, or iterable of labels, optional
Label ids to exclude from the global ``index`` and
``weighted_index`` aggregation. Excluded labels remain fully
involved in fitting, singleton scoring, pairwise scoring, and all
bookkeeping outputs.
"""
self.rho = rho
self.r_hat = r_hat
Expand All @@ -219,6 +240,7 @@ def __init__(
self.offline_chunk_size = offline_chunk_size
self.multilabel_pair_mode = multilabel_pair_mode
self.top_m = top_m
self.exclude_classes = exclude_classes
self._validate_multilabel_params()

# indices / bookkeeping
Expand Down Expand Up @@ -340,6 +362,15 @@ def _warn_single_class() -> None:
stacklevel=2,
)

@staticmethod
def _warn_all_observed_classes_excluded() -> None:
"""Warn when exclusions remove every observed class from summaries."""
warnings.warn(
"All observed classes were excluded from global aggregation; leaving OverlapIndex scores at 1.0.",
RuntimeWarning,
stacklevel=2,
)

def _reset_indices(self) -> None:
"""Reset overlap-index bookkeeping without replacing the clustering backend."""
self.sparse_adj = defaultdict(lambda: 0)
Expand All @@ -366,19 +397,60 @@ def weighted_index(self) -> float:
The default ``index`` is a macro average over per-class scores. This
property weights each per-class score by that class's positive support.
"""
included_labels = self._included_singleton_labels()
if not included_labels:
return float(self.index)

total_support = sum(
int(self.cluster_cardinality.get(y, 0))
for y in self.singleton_index
for y in included_labels
)
if total_support <= 0:
return float(self.index)

weighted_sum = sum(
float(score) * int(self.cluster_cardinality.get(y, 0))
for y, score in self.singleton_index.items()
if y in included_labels
)
return float(weighted_sum / float(total_support))

def _normalized_exclude_classes(self) -> set[Any]:
"""Normalize the configured global-aggregation exclusions."""
excluded = self.exclude_classes
if excluded is None:
return set()

if isinstance(excluded, (str, bytes)) or not isinstance(excluded, Iterable):
labels = [excluded]
else:
labels = list(excluded)

return set(_deduplicated_labels(labels))

def _included_singleton_labels(self) -> list[Any]:
"""Return observed labels that still contribute to global summaries."""
observed_labels = list(self.singleton_index.keys())
if not observed_labels:
return []

excluded = self._normalized_exclude_classes()
return [label for label in observed_labels if label not in excluded]

def _recompute_global_index(self) -> float:
"""Recompute the exclusion-aware global macro overlap index."""
included_labels = self._included_singleton_labels()
if not included_labels:
if self.singleton_index:
self._warn_all_observed_classes_excluded()
self.index = 1.0
return self.index

self.index = float(
np.mean([self.singleton_index[label] for label in included_labels])
)
return self.index

@property
def module_a(self) -> Any:
"""Return the underlying ARTMAP module A object for ARTMAP backends."""
Expand Down Expand Up @@ -483,9 +555,12 @@ def add_sample(self, x: np.ndarray, y: Any) -> float:
self.singleton_index[y] = min(
[self.pairwise_index[(y, b)] for b in self.rev_map.keys() if b != y]
)
self.index = float(np.mean(list(self.singleton_index.values())))
self._recompute_global_index()
else:
self._warn_single_class()
if self._included_singleton_labels():
self._warn_single_class()
else:
self._warn_all_observed_classes_excluded()
return self.index

def add_batch(self, X: np.ndarray, Y: Any) -> float:
Expand Down Expand Up @@ -547,7 +622,7 @@ def add_batch(self, X: np.ndarray, Y: Any) -> float:
self.singleton_index[y] = min(
[self.pairwise_index[(y, b)] for b in self.rev_map.keys() if b != y]
)
self.index = float(np.mean(list(self.singleton_index.values())))
self._recompute_global_index()
return self.index

def fit(self, X: np.ndarray, Y: np.ndarray) -> "OverlapIndex":
Expand Down Expand Up @@ -836,7 +911,7 @@ def _fit_offline_centroid_optimized_multilabel(
if self.pairwise_cardinality[(y, b)] > 0
]
self.singleton_index[y] = min(valid_scores) if valid_scores else 1.0
self.index = float(np.mean(list(self.singleton_index.values())))
self._recompute_global_index()

return self.index

Expand Down Expand Up @@ -914,7 +989,7 @@ def _fit_offline_centroid_optimized(
self.singleton_index[y] = min(
[self.pairwise_index[(y, b)] for b in self.rev_map.keys() if b != y]
)
self.index = float(np.mean(list(self.singleton_index.values())))
self._recompute_global_index()
return self.index

def _fit_offline_replay(
Expand Down Expand Up @@ -949,7 +1024,7 @@ def _fit_offline_replay(
self.singleton_index[y] = min(
[self.pairwise_index[(y, b)] for b in self.rev_map.keys() if b != y]
)
self.index = float(np.mean(list(self.singleton_index.values())))
self._recompute_global_index()
return self.index

def fit_offline(self, X: np.ndarray, Y: Any, reset_state: bool = True) -> float:
Expand Down Expand Up @@ -1010,6 +1085,10 @@ def fit_offline(self, X: np.ndarray, Y: Any, reset_state: bool = True) -> float:
self.singleton_index[c] = 1.0

if len(classes) <= 1:
if self._included_singleton_labels():
self._warn_single_class()
else:
self._warn_all_observed_classes_excluded()
return self.index

if is_multilabel:
Expand Down
Loading