diff --git a/benchmarks/bench_de_metrics.py b/benchmarks/bench_de_metrics.py new file mode 100644 index 0000000..0a668c3 --- /dev/null +++ b/benchmarks/bench_de_metrics.py @@ -0,0 +1,191 @@ +"""Microbenchmark for the per-perturbation-scan DE metrics. + +`DENsigCounts` and `compute_generic_auc` (pr/roc) used to loop over every +perturbation doing a full-table `.filter(target == pert)` (or per-pert +`get_significant_genes`). They now slice the table once -- a grouped count for +`DENsigCounts`, a single `partition_by` for the AUC metrics. This script times +the pre-optimization implementations (verbatim baselines below) against the +current ones across a range of perturbation counts and confirms identical +output. + +Run with the package importable, e.g.:: + + python benchmarks/bench_de_metrics.py + python benchmarks/bench_de_metrics.py --n-pert 1000 4000 8000 --n-genes 50 + +`_reference_densig_counts` / `_reference_generic_auc` are verbatim copies of the +pre-optimization implementations, kept here only as benchmark baselines; they +are not used by the package. 
+""" + +from __future__ import annotations + +import argparse +import gc +import logging +import math +import time + +import numpy as np +import polars as pl +from sklearn.metrics import auc, average_precision_score, roc_curve + +from cell_eval._types import DEComparison, DEResults +from cell_eval.metrics._de import DENsigCounts, compute_pr_auc, compute_roc_auc + +logging.getLogger("cell_eval._types._de").setLevel(logging.WARNING) + +_FDR = 0.05 + + +def _reference_densig_counts( + data: DEComparison, fdr_threshold: float +) -> dict[str, dict[str, int]]: + """Verbatim pre-optimization DENsigCounts.__call__.""" + counts = {} + for pert in data.iter_perturbations(): + real_sig = data.real.get_significant_genes(pert, fdr_threshold) + pred_sig = data.pred.get_significant_genes(pert, fdr_threshold) + counts[pert] = {"real": int(real_sig.size), "pred": int(pred_sig.size)} + return counts + + +def _reference_generic_auc(data: DEComparison, method: str) -> dict[str, float]: + """Verbatim pre-optimization compute_generic_auc (per-pert .filter loop).""" + target_col = data.real.target_col + feature_col = data.real.feature_col + real_fdr_col = data.real.fdr_col + pred_fdr_col = data.pred.fdr_col + + labeled_real = data.real.data.with_columns( + (pl.col(real_fdr_col) < 0.05).cast(pl.Float32).alias("label") + ).select([target_col, feature_col, "label"]) + + pred_q = pl.col(pred_fdr_col).fill_null(1.0).clip(1e-10, 1.0) + merged = ( + labeled_real.join( + data.pred.data.select([target_col, feature_col, pred_fdr_col]), + on=[target_col, feature_col], + how="left", + coalesce=True, + ) + .drop_nulls(["label"]) + .with_columns(pred_q.alias(pred_fdr_col), (-pred_q.log10()).alias("nlp")) + ) + + results: dict[str, float] = {} + for pert in data.iter_perturbations(): + pert_data = merged.filter(pl.col(target_col) == pert) + if pert_data.shape[0] == 0: + results[pert] = float("nan") + continue + labels = pert_data["label"].to_numpy() + scores = pert_data["nlp"].to_numpy() + if not (0 < 
labels.sum() < len(labels)): + results[pert] = float("nan") + continue + match method: + case "pr": + results[pert] = float(average_precision_score(labels, scores)) + case "roc": + fpr, tpr, _ = roc_curve(labels, scores) + results[pert] = float(auc(fpr, tpr)) + case _: + raise ValueError(f"Invalid AUC method: {method}") + return results + + +def _make_side(n_pert: int, n_genes: int, sig_p: float, seed: int) -> pl.DataFrame: + rng = np.random.default_rng(seed) + n = n_pert * n_genes + target = np.repeat([f"P{i}" for i in range(n_pert)], n_genes) + feature = np.tile([f"g{j}" for j in range(n_genes)], n_pert) + is_sig = rng.random(n) < sig_p + # Distinct fdr values -> distinct AUC scores; sig genes below threshold. + fdr = np.where(is_sig, rng.uniform(1e-6, 0.049, n), rng.uniform(0.051, 1.0, n)) + lfc = rng.normal(0.0, 2.0, n) + return pl.DataFrame( + { + "target": target, + "feature": feature, + "log2_fold_change": lfc, + "p_value": fdr, + "fdr": fdr, + } + ) + + +def make_comparison(n_pert: int, n_genes: int, seed: int = 0) -> DEComparison: + return DEComparison( + real=DEResults(_make_side(n_pert, n_genes, 0.3, seed + 1), name="real"), + pred=DEResults(_make_side(n_pert, n_genes, 0.3, seed + 2), name="pred"), + ) + + +def _auc_equal(a: dict, b: dict) -> bool: + if list(a.keys()) != list(b.keys()): + return False + for k in a: + av, bv = a[k], b[k] + if isinstance(av, float) and math.isnan(av): + if not (isinstance(bv, float) and math.isnan(bv)): + return False + elif av != bv: + return False + return True + + +def _timed(fn): + gc.collect() + t0 = time.perf_counter() + out = fn() + return out, time.perf_counter() - t0 + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--n-pert", type=int, nargs="+", default=[1000, 2000, 4000, 8000] + ) + parser.add_argument("--n-genes", type=int, default=50) + args = parser.parse_args() + + hdr = ( + f"{'n_pert':>7} | {'rows':>8} | " + f"{'nsig old':>9} {'nsig new':>9} {'x':>5} | " + f"{'pr 
old':>8} {'pr new':>8} {'x':>5} | " + f"{'roc old':>8} {'roc new':>8} {'x':>5} | ok" + ) + print(hdr) + print("-" * len(hdr)) + + for n_pert in args.n_pert: + comp = make_comparison(n_pert, args.n_genes) + rows = n_pert * args.n_genes + + old_nsig, t_on = _timed(lambda: _reference_densig_counts(comp, _FDR)) + new_nsig, t_nn = _timed(lambda: DENsigCounts(fdr_threshold=_FDR)(comp)) + old_pr, t_op = _timed(lambda: _reference_generic_auc(comp, "pr")) + new_pr, t_np = _timed(lambda: compute_pr_auc(comp)) + old_roc, t_or = _timed(lambda: _reference_generic_auc(comp, "roc")) + new_roc, t_nr = _timed(lambda: compute_roc_auc(comp)) + + ok = ( + old_nsig == new_nsig + and _auc_equal(old_pr, new_pr) + and _auc_equal(old_roc, new_roc) + ) + + def sp(o: float, n: float) -> float: + return o / n if n else float("inf") + + print( + f"{n_pert:>7} | {rows:>8} | " + f"{t_on:>9.3f} {t_nn:>9.3f} {sp(t_on, t_nn):>4.1f}x | " + f"{t_op:>8.3f} {t_np:>8.3f} {sp(t_op, t_np):>4.1f}x | " + f"{t_or:>8.3f} {t_nr:>8.3f} {sp(t_or, t_nr):>4.1f}x | {ok}" + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/bench_de_overlap.py b/benchmarks/bench_de_overlap.py new file mode 100644 index 0000000..de21d22 --- /dev/null +++ b/benchmarks/bench_de_overlap.py @@ -0,0 +1,222 @@ +"""Microbenchmark for the memoized DE overlap metric. + +The DE ``full``/``de`` profile registers 10 overlap variants +(``{overlap, precision} x k in {None, 50, 100, 200, 500}``), and every one of +them calls ``DEComparison.compute_overlap`` -> ``DEResults.get_top_genes`` on the +same real/pred pair with the same default ``sort_by`` / ``fdr_threshold``. Before +the optimization each call rebuilt a polars ``.pivot()`` with one column per +perturbation, and the per-perturbation loop tested membership against +``matrix.columns`` (which rebuilds a fresh list on every access). At thousands of +perturbations that is 10 redundant wide pivots plus an O(n_pert^2) membership +loop. 
+ +This script runs the full 10-variant pattern with the pre-optimization +implementation (verbatim baseline below) and with the current memoized one, +across a range of perturbation counts, and confirms the two produce identical +results. + +Run with the package importable, e.g.:: + + python benchmarks/bench_de_overlap.py + python benchmarks/bench_de_overlap.py --n-pert 1000 4000 8000 + +``_reference_get_top_genes`` / ``_reference_compute_overlap`` are verbatim copies +of the pre-optimization implementation, kept here only as a benchmark baseline; +they are not used by the package. +""" + +from __future__ import annotations + +import argparse +import gc +import logging +import time +from typing import Literal + +import numpy as np +import polars as pl + +from cell_eval._types import DEComparison, DEResults +from cell_eval._types._enums import DESortBy + +# The pre-optimization metric logged INFO lines per DEResults construction; keep +# the benchmark output clean. +logging.getLogger("cell_eval._types._de").setLevel(logging.WARNING) + +# Mirrors metrics/_impl.py registration. 
+VARIANTS: list[tuple[Literal["overlap", "precision"], int | None]] = [ + (metric, k) + for metric in ("overlap", "precision") + for k in (None, 50, 100, 200, 500) +] + + +def _reference_get_top_genes( + de: DEResults, + sort_by: DESortBy, + fdr_threshold: float | None = None, +) -> pl.DataFrame: + """Verbatim pre-optimization DEResults.get_top_genes (no memoization).""" + fdr_threshold = fdr_threshold if fdr_threshold is not None else 0.05 + descending = sort_by in { + DESortBy.LOG2_FOLD_CHANGE, + DESortBy.ABS_LOG2_FOLD_CHANGE, + } + rank_matrix = ( + de.data.filter(pl.col(de.fdr_col) < fdr_threshold) + .with_columns( + rank=pl.struct(sort_by.value) + .rank("ordinal", descending=descending) + .over("target") + - 1 + ) + .pivot(index="rank", on="target", values="feature") + .sort("rank") + ) + missing_perts = set(de.get_perts()) - set(rank_matrix.columns) + if missing_perts: + rank_matrix = rank_matrix.with_columns( + [pl.lit(None).alias(p) for p in missing_perts] + ) + return rank_matrix + + +def _reference_compute_overlap( + comparison: DEComparison, + k: int | None, + metric: Literal["overlap", "precision"] = "overlap", + fdr_threshold: float | None = None, + sort_by: DESortBy = DESortBy.ABS_LOG2_FOLD_CHANGE, +) -> dict[str, float]: + """Verbatim pre-optimization DEComparison.compute_overlap. + + Rebuilds both rank matrices on every call and tests membership against the + polars ``.columns`` list (one fresh list per access). 
+ """ + real_sig_rank_matrix = _reference_get_top_genes( + comparison.real, sort_by=sort_by, fdr_threshold=fdr_threshold + ) + pred_sig_rank_matrix = _reference_get_top_genes( + comparison.pred, sort_by=sort_by, fdr_threshold=fdr_threshold + ) + + if real_sig_rank_matrix.shape[0] == 0 or pred_sig_rank_matrix.shape[0] == 0: + return {pert: 0.0 for pert in comparison.iter_perturbations()} + + overlaps = {} + for pert in comparison.iter_perturbations(): + if ( + pert not in real_sig_rank_matrix.columns + or pert not in pred_sig_rank_matrix.columns + ): + overlaps[pert] = 0.0 + continue + + real_genes = real_sig_rank_matrix[pert].drop_nulls().to_numpy() + pred_genes = pred_sig_rank_matrix[pert].drop_nulls().to_numpy() + + if metric == "overlap": + k_eff = real_genes.size if not k else k + k_eff = min(k_eff, real_genes.size) + elif metric == "precision": + k_eff = pred_genes.size if not k else k + k_eff = min(k_eff, pred_genes.size) + else: + raise ValueError(f"Invalid metric: {metric}") + + if k_eff == 0: + overlaps[pert] = 0.0 + else: + real_subset = real_genes[:k_eff] + pred_subset = pred_genes[:k_eff] + overlaps[pert] = np.intersect1d(real_subset, pred_subset).size / k_eff + + return overlaps + + +def _make_side( + perts: np.ndarray, genes: np.ndarray, n_sig: int, rng: np.random.Generator +) -> pl.DataFrame: + n_pert = perts.size + targets = np.repeat(perts, n_sig) + feats = np.empty(n_pert * n_sig, dtype=object) + for i in range(n_pert): + feats[i * n_sig : (i + 1) * n_sig] = rng.choice( + genes, size=n_sig, replace=False + ) + lfc = rng.normal(0.0, 2.0, size=n_pert * n_sig) + fdr = np.full(n_pert * n_sig, 0.01) + return pl.DataFrame( + { + "target": targets, + "feature": feats, + "log2_fold_change": lfc, + "p_value": fdr, + "fdr": fdr, + } + ) + + +def make_frames( + n_pert: int, n_genes: int = 2000, n_sig: int = 100, seed: int = 0 +) -> tuple[pl.DataFrame, pl.DataFrame]: + """Synthetic real/pred DE tables: every pert has `n_sig` significant genes.""" + rng = 
np.random.default_rng(seed) + perts = np.array([f"P{i}" for i in range(n_pert)]) + genes = np.array([f"g{j}" for j in range(n_genes)]) + real = _make_side(perts, genes, n_sig, rng) + pred = _make_side(perts, genes, n_sig, rng) + return real, pred + + +def _run(comparison: DEComparison, fn) -> dict: + return {(m, k): fn(comparison, k=k, metric=m) for (m, k) in VARIANTS} + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--n-pert", type=int, nargs="+", default=[1000, 2000, 4000]) + parser.add_argument("--n-genes", type=int, default=2000) + parser.add_argument("--n-sig", type=int, default=100) + args = parser.parse_args() + + print( + f"{'n_pert':>8} | {'old (s)':>10} | {'new (s)':>10} | " + f"{'speedup':>8} | identical" + ) + print("-" * 60) + + for n_pert in args.n_pert: + real_df, pred_df = make_frames(n_pert, args.n_genes, args.n_sig) + + comp_old = DEComparison( + real=DEResults(real_df, name="real"), + pred=DEResults(pred_df, name="pred"), + ) + comp_new = DEComparison( + real=DEResults(real_df, name="real"), + pred=DEResults(pred_df, name="pred"), + ) + + gc.collect() + t0 = time.perf_counter() + old_results = _run(comp_old, _reference_compute_overlap) + t_old = time.perf_counter() - t0 + + gc.collect() + t0 = time.perf_counter() + new_results = { + (m, k): comp_new.compute_overlap(k=k, metric=m) for (m, k) in VARIANTS + } + t_new = time.perf_counter() - t0 + + identical = all(old_results[v] == new_results[v] for v in VARIANTS) + speedup = t_old / t_new if t_new else float("inf") + print( + f"{n_pert:>8} | {t_old:>10.3f} | {t_new:>10.3f} | " + f"{speedup:>7.1f}x | {identical}" + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/bench_discrimination_score.py b/benchmarks/bench_discrimination_score.py new file mode 100644 index 0000000..1212f98 --- /dev/null +++ b/benchmarks/bench_discrimination_score.py @@ -0,0 +1,202 @@ +"""Microbenchmark for the vectorized ``discrimination_score``. 
+ +Compares the original per-perturbation loop against the vectorized +implementation across a range of perturbation counts, and confirms the two +produce identical normalized ranks. + +Run with the package importable, e.g.:: + + python benchmarks/bench_discrimination_score.py + python benchmarks/bench_discrimination_score.py --n-pert 100 1000 10000 + +The ``_reference_discrimination_score`` function below is a verbatim copy of +the pre-optimization implementation, kept here only as a benchmark baseline; +it is not used by the package. +""" + +from __future__ import annotations + +import argparse +import gc +import time + +import anndata as ad +import numpy as np +import pandas as pd +import sklearn.metrics as skm + +from cell_eval._types import PerturbationAnndataPair +from cell_eval.metrics._anndata import discrimination_score + +CONTROL = "non-targeting" +PERT_COL = "perturbation" + + +def make_pair( + n_pert: int, + n_genes: int = 2000, + n_cells: int = 5000, + seed: int = 0, + frac_targeting: float = 0.8, +) -> PerturbationAnndataPair: + """Synthetic pair where `frac_targeting` of perts are named after genes.""" + rng = np.random.default_rng(seed) + var_names = np.array([f"gene_{i}" for i in range(n_genes)]) + n_targeting = min(int(round(n_pert * frac_targeting)), n_genes) + pert_names = list(var_names[:n_targeting]) + [ + f"drug_{k}" for k in range(n_pert - n_targeting) + ] + all_labels = np.concatenate([pert_names, [CONTROL]]) + labels = rng.choice(all_labels, size=max(n_cells, all_labels.size)) + labels[: all_labels.size] = all_labels + + def build(off: int) -> ad.AnnData: + r = np.random.default_rng(seed + 1000 * off) + a = ad.AnnData(X=r.standard_normal((labels.size, n_genes))) + a.obs[PERT_COL] = pd.Categorical(labels) + a.var_names = var_names + a.obs_names = [f"cell_{i}" for i in range(labels.size)] + return a + + return PerturbationAnndataPair( + real=build(1), pred=build(2), pert_col=PERT_COL, control_pert=CONTROL + ) + + +def 
_reference_discrimination_score( + data, metric="l1", embed_key=None, exclude_target_gene=True +): + """Verbatim pre-optimization per-perturbation loop (benchmark baseline).""" + if metric in ("l1", "manhattan", "cityblock"): + embed_key = None + real_effects = np.vstack( + [ + d.perturbation_effect("real", abs=False) + for d in data.iter_bulk_arrays(embed_key=embed_key) + ] + ) + pred_effects = np.vstack( + [ + d.perturbation_effect("pred", abs=False) + for d in data.iter_bulk_arrays(embed_key=embed_key) + ] + ) + norm_ranks = {} + for p_idx, p in enumerate(data.perts): + if exclude_target_gene and not embed_key: + include_mask = np.flatnonzero(data.genes != p) + else: + include_mask = np.ones(real_effects.shape[1], dtype=bool) + distances = skm.pairwise_distances( + real_effects[:, include_mask], + pred_effects[p_idx, include_mask].reshape(1, -1), + metric=metric, + ).flatten() + sorted_indices = np.argsort(distances) + p_index = np.flatnonzero(data.perts == p)[0] + rank = np.flatnonzero(sorted_indices == p_index)[0] + norm_ranks[str(p)] = 1 - rank / data.perts.size + return norm_ranks + + +def _time(fn, *args, repeats=1, **kwargs): + best = float("inf") + out = None + for _ in range(repeats): + gc.collect() + t0 = time.perf_counter() + out = fn(*args, **kwargs) + best = min(best, time.perf_counter() - t0) + return best, out + + +def bench_discrimination(sizes, metrics, n_genes, repeats, run_old): + print("\n## discrimination_score: old (loop) vs new (vectorized)\n") + # Warm up numba/sklearn/BLAS on a tiny input so the first measured call is + # not penalized by one-time import/JIT cost. 
+ warm = make_pair(n_pert=50, n_genes=n_genes, n_cells=200, seed=99) + for metric in metrics: + discrimination_score(warm, metric=metric) + if run_old: + _reference_discrimination_score(warm, metric=metric) + header = "| n_pert | metric | old (s) | new (s) | speedup | ranks identical |" + print(header) + print("|---|---|---|---|---|---|") + for n_pert in sizes: + data = make_pair(n_pert=n_pert, n_genes=n_genes, n_cells=3 * n_pert, seed=0) + for metric in metrics: + t_new, new_out = _time( + discrimination_score, data, metric=metric, repeats=repeats + ) + if run_old: + t_old, old_out = _time( + _reference_discrimination_score, + data, + metric=metric, + repeats=repeats, + ) + keys = sorted(old_out) + identical = np.array_equal( + np.array([old_out[k] for k in keys]), + np.array([new_out[k] for k in keys]), + ) + speed = f"{t_old / t_new:.1f}x" + old_s = f"{t_old:.3f}" + else: + old_s, speed, identical = "(skipped)", "-", "-" + print( + f"| {n_pert} | {metric} | {old_s} | {t_new:.3f} | {speed} " + f"| {identical} |" + ) + del new_out + gc.collect() + + +def print_env(): + """Print the machine + library versions so the run is reproducible.""" + import platform + from importlib.metadata import version + + cpu = platform.processor() or "unknown" + if platform.system() == "Darwin": + import subprocess + + try: + out = subprocess.run( + ["sysctl", "-n", "machdep.cpu.brand_string"], + capture_output=True, + text=True, + check=True, + ).stdout.strip() + cpu = out or cpu + except Exception: + pass + libs = ", ".join(f"{p} {version(p)}" for p in ("numpy", "scipy", "scikit-learn")) + print("## environment\n") + print(f"- platform: {platform.platform()}") + print(f"- cpu: {cpu}") + print(f"- python: {platform.python_version()}") + print(f"- libs: {libs}") + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--n-pert", type=int, nargs="+", default=[100, 1000, 10000]) + ap.add_argument("--metrics", nargs="+", default=["l1", "l2", "cosine"]) + 
ap.add_argument("--n-genes", type=int, default=2000) + ap.add_argument("--repeats", type=int, default=1) + ap.add_argument( + "--no-old", + action="store_true", + help="skip the slow reference loop (new impl only)", + ) + args = ap.parse_args() + + print_env() + bench_discrimination( + args.n_pert, args.metrics, args.n_genes, args.repeats, run_old=not args.no_old + ) + + +if __name__ == "__main__": + main() diff --git a/src/cell_eval/_types/_de.py b/src/cell_eval/_types/_de.py index 7dd1eba..ad91657 100644 --- a/src/cell_eval/_types/_de.py +++ b/src/cell_eval/_types/_de.py @@ -51,6 +51,12 @@ class DEResults: fdr_col: str = "fdr" name: str = "de" + # Caches the (sort_by, fdr_threshold) -> rank-matrix pivot built by + # get_top_genes; excluded from init/repr/eq so it never affects identity. + _top_genes_cache: dict[tuple[DESortBy, float], pl.DataFrame] = field( + default_factory=dict, init=False, repr=False, compare=False + ) + def __post_init__(self) -> None: required_cols = { self.target_col, @@ -150,6 +156,16 @@ def get_top_genes( # Set FDR threshold if not provided fdr_threshold = fdr_threshold if fdr_threshold is not None else 0.05 + # The rank matrix depends only on (sort_by, fdr_threshold). The overlap + # family registers many variants (overlap/precision x several k) that all + # request this same matrix -- k only truncates the per-pert gene lists in + # compute_overlap, never the pivot -- so memoize it to rebuild the wide + # one-column-per-perturbation pivot once instead of once per variant. 
+ cache_key = (sort_by, fdr_threshold) + cached = self._top_genes_cache.get(cache_key) + if cached is not None: + return cached + descending = sort_by in { DESortBy.LOG2_FOLD_CHANGE, DESortBy.ABS_LOG2_FOLD_CHANGE, @@ -180,6 +196,7 @@ def get_top_genes( [pl.lit(None).alias(p) for p in missing_perts] ) + self._top_genes_cache[cache_key] = rank_matrix return rank_matrix @@ -244,13 +261,16 @@ def compute_overlap( # Cannot evaluate in this case so setting all perturbations to 0.0 return {pert: 0.0 for pert in self.iter_perturbations()} + # `.columns` rebuilds a fresh list on every access; hoist both into sets + # once so the per-pert membership test is O(1) rather than O(n_perts), + # which over the loop is the difference between O(n_perts) and O(n_perts^2). + real_cols = set(real_sig_rank_matrix.columns) + pred_cols = set(pred_sig_rank_matrix.columns) + overlaps = {} for pert in self.iter_perturbations(): # If perturbation is not in either real or pred, set overlap to 0.0 - if ( - pert not in real_sig_rank_matrix.columns - or pert not in pred_sig_rank_matrix.columns - ): + if pert not in real_cols or pert not in pred_cols: overlaps[pert] = 0.0 continue diff --git a/src/cell_eval/metrics/_anndata.py b/src/cell_eval/metrics/_anndata.py index 458de62..d822be7 100644 --- a/src/cell_eval/metrics/_anndata.py +++ b/src/cell_eval/metrics/_anndata.py @@ -149,7 +149,8 @@ def discrimination_score( # Ignore the embedding key for L1 embed_key = None - # Compute perturbation effects for all perturbations + # Compute perturbation effects for all perturbations. The underlying + # pseudobulk is memoized on the pair, so this is shared across metrics. 
real_effects = np.vstack( [ d.perturbation_effect(which="real", abs=False) @@ -163,39 +164,170 @@ def discrimination_score( ] ) - norm_ranks = {} - for p_idx, p in enumerate(data.perts): - # Determine which features to include in the comparison - if exclude_target_gene and not embed_key: - # For expression data, exclude the target gene - include_mask = np.flatnonzero(data.genes != p) - else: - # For embedding data or when not excluding target gene, use all features - include_mask = np.ones(real_effects.shape[1], dtype=bool) - - # Compute distances to all real effects - distances = skm.pairwise_distances( - real_effects[ - :, include_mask - ], # compare to all real effects across perturbations - pred_effects[p_idx, include_mask].reshape( - 1, -1 - ), # select pred effect for current perturbation - metric=metric, + # dist_matrix[i, j] = distance(pred_effect[i], real_effect[j]); each row is + # the vector of distances from one predicted effect to all real effects. + # + # When excluding the target gene on expression data, perturbation i drops a + # *different* feature column (the gene named like perturbation i), so a + # single unmasked pairwise call would not reproduce the per-perturbation + # masked distances. We instead compute the full matrix once and apply an + # exact, vectorized rank-1 correction that removes the target gene's + # contribution from each row. For metrics without a closed-form column + # correction we fall back to exact per-row masked distances. + do_exclude = exclude_target_gene and not embed_key + family = _distance_family(metric) + + if not do_exclude: + dist_matrix = skm.pairwise_distances(pred_effects, real_effects, metric=metric) + elif family is None: + dist_matrix = _masked_distance_matrix( + real_effects, pred_effects, data.genes, data.perts, metric + ) + else: + # `family` is narrowed to Literal["l1", "l2", "cosine"] here. 
+ dist_matrix = _excluded_distance_matrix( + real_effects, + pred_effects, + data.genes, + data.perts, + family, + ) + + # Rank of the matching perturbation within each row, by ascending distance. + # order[i] lists columns by increasing distance, so the rank of perturbation + # i is the position of column i within row i. A boolean match locates that + # position directly -- equivalent to argsort(argsort(.)) but without a + # second full-matrix sort (cheaper, and the mask is bool rather than int). + n_pert = data.perts.size + order = np.argsort(dist_matrix, axis=1) + ranks = np.where(order == np.arange(n_pert)[:, None])[1] + + return {str(p): float(1 - ranks[i] / n_pert) for i, p in enumerate(data.perts)} + + +def _distance_family(metric: str) -> Literal["l1", "l2", "cosine"] | None: + """Map a pairwise metric name onto a family with a closed-form column drop.""" + match metric.lower(): + case "l1" | "manhattan" | "cityblock": + return "l1" + case "l2" | "euclidean": + return "l2" + case "cosine": + return "cosine" + case _: + return None + + +def _target_gene_columns(genes: np.ndarray, perts: np.ndarray) -> list[list[int]]: + """Feature columns whose gene name matches each perturbation (usually 0 or 1).""" + gene_to_cols: dict[str, list[int]] = {} + for col, g in enumerate(genes): + gene_to_cols.setdefault(str(g), []).append(col) + return [gene_to_cols.get(str(p), []) for p in perts] + + +def _masked_row( + real_effects: np.ndarray, + pred_row: np.ndarray, + excluded_cols: Sequence[int], + metric: str, +) -> np.ndarray: + """Distances from one predicted effect to all real effects, dropping columns.""" + if excluded_cols: + mask = np.ones(real_effects.shape[1], dtype=bool) + mask[list(excluded_cols)] = False + return skm.pairwise_distances( + real_effects[:, mask], pred_row[mask].reshape(1, -1), metric=metric ).flatten() + return skm.pairwise_distances( + real_effects, pred_row.reshape(1, -1), metric=metric + ).flatten() + + +def _masked_distance_matrix( + 
real_effects: np.ndarray, + pred_effects: np.ndarray, + genes: np.ndarray, + perts: np.ndarray, + metric: str, +) -> np.ndarray: + """Per-row masked distance matrix for metrics without a column correction.""" + excluded = _target_gene_columns(genes, perts) + return np.vstack( + [ + _masked_row(real_effects, pred_effects[i], excluded[i], metric) + for i in range(perts.size) + ] + ) + - # Sort by distance (ascending - lower distance = better match) - sorted_indices = np.argsort(distances) +def _excluded_distance_matrix( + real_effects: np.ndarray, + pred_effects: np.ndarray, + genes: np.ndarray, + perts: np.ndarray, + family: Literal["l1", "l2", "cosine"], +) -> np.ndarray: + """Full distance matrix with each row's target gene contribution removed. - # Find rank of the correct perturbation - p_index = np.flatnonzero(data.perts == p)[0] - rank = np.flatnonzero(sorted_indices == p_index)[0] + Row i corresponds to perturbation i; the feature column named like + perturbation i is dropped from that row's distances only. The result is + numerically equivalent (up to floating-point summation order) to computing + `pairwise_distances` on the masked columns per perturbation. 
+ """ + n_pert = perts.size + excluded = _target_gene_columns(genes, perts) + + has_target = np.zeros(n_pert, dtype=bool) + tcol = np.zeros(n_pert, dtype=np.intp) + multi: list[int] = [] + for i, cols in enumerate(excluded): + if len(cols) == 1: + has_target[i] = True + tcol[i] = cols[0] + elif len(cols) > 1: + multi.append(i) + + rows = np.arange(n_pert) + mask2d = has_target[:, None] + # pred_at[i] = pred_effects[i, tcol[i]] + # real_at[i, j] = real_effects[j, tcol[i]] + pred_at = pred_effects[rows, tcol] + real_at = real_effects[:, tcol].T + + match family: + case "l1": + out = skm.pairwise_distances(pred_effects, real_effects, metric="l1") + out -= np.where(mask2d, np.abs(pred_at[:, None] - real_at), 0.0) + case "l2": + out = skm.pairwise_distances(pred_effects, real_effects, metric="l2") + corr = np.where(mask2d, (pred_at[:, None] - real_at) ** 2, 0.0) + out = np.sqrt(np.maximum(out**2 - corr, 0.0)) + case _: # cosine: drop the column from the dot product and both norms + dot = pred_effects @ real_effects.T + pred_sq = np.einsum("ij,ij->i", pred_effects, pred_effects) + real_sq = np.einsum("ij,ij->i", real_effects, real_effects) + dot -= np.where(mask2d, pred_at[:, None] * real_at, 0.0) + pred_sq_m = pred_sq - np.where(has_target, pred_at**2, 0.0) + real_sq_m = real_sq[None, :] - np.where(mask2d, real_at**2, 0.0) + # An effect dominated by its target gene can leave a masked squared + # norm at a tiny negative value from float rounding; clip to 0 so the + # norm is real (not NaN). The resulting zero norm is then handled + # like sklearn below (cosine similarity 0 -> distance 1). 
+ denom = np.sqrt(np.maximum(pred_sq_m, 0.0))[:, None] * np.sqrt( + np.maximum(real_sq_m, 0.0) + ) + with np.errstate(divide="ignore", invalid="ignore"): + cos = dot / denom + cos = np.where(denom == 0.0, 0.0, cos) + out = np.clip(1 - cos, 0.0, 2.0) - # Normalize rank by total number of perturbations - norm_rank = rank / data.perts.size - norm_ranks[str(p)] = 1 - norm_rank + # Safety net for the rare case of duplicate gene names matching a single + # perturbation (more than one column to drop): recompute those rows exactly. + for i in multi: + out[i] = _masked_row(real_effects, pred_effects[i], excluded[i], family) - return norm_ranks + return out def _generic_evaluation( diff --git a/src/cell_eval/metrics/_de.py b/src/cell_eval/metrics/_de.py index 5204c2c..da7c918 100644 --- a/src/cell_eval/metrics/_de.py +++ b/src/cell_eval/metrics/_de.py @@ -191,18 +191,36 @@ def __init__(self, fdr_threshold: float = 0.05) -> None: def __call__(self, data: DEComparison) -> dict[str, dict[str, int]]: """Compute counts of significant genes in real and predicted DE.""" - counts = {} - - for pert in data.iter_perturbations(): - real_sig = data.real.get_significant_genes(pert, self.fdr_threshold) - pred_sig = data.pred.get_significant_genes(pert, self.fdr_threshold) - - counts[pert] = { - "real": int(real_sig.size), - "pred": int(pred_sig.size), + # One grouped count per side instead of a per-pert full-table scan; only + # the count (the old per-pert get_significant_genes(...).size) is used. 
+ real_counts = { + str(pert): count + for pert, count in data.real.filter_to_significant( + fdr_threshold=self.fdr_threshold + ) + .group_by(data.real.target_col) + .len() + .iter_rows() + } + pred_counts = { + str(pert): count + for pert, count in data.pred.filter_to_significant( + fdr_threshold=self.fdr_threshold + ) + .group_by(data.pred.target_col) + .len() + .iter_rows() + } + + # Reindex over the full perturbation universe so perts with no + # significant genes on a side stay at 0 (matching the old empty .size). + return { + pert: { + "real": int(real_counts.get(str(pert), 0)), + "pred": int(pred_counts.get(str(pert), 0)), } - - return counts + for pert in data.iter_perturbations() + } def compute_pr_auc(data: DEComparison) -> dict[str, float]: @@ -245,10 +263,22 @@ def compute_generic_auc( ) ) + # Slice the table once instead of a full .filter(target == pert) scan per + # perturbation. maintain_order keeps each partition in the same row order the + # per-pert filter produced, so labels/scores -> sklearn stay bit-identical. + # partition_by(as_dict=True) keys are tuples on newer polars, scalars on + # older; normalize both to str so lookup by perturbation name is stable. + partitions = { + str(key[0] if isinstance(key, tuple) else key): frame + for key, frame in merged.partition_by( + target_col, as_dict=True, maintain_order=True + ).items() + } + results: dict[str, float] = {} for pert in data.iter_perturbations(): - pert_data = merged.filter(pl.col(target_col) == pert) - if pert_data.shape[0] == 0: + pert_data = partitions.get(str(pert)) + if pert_data is None or pert_data.shape[0] == 0: results[pert] = float("nan") continue diff --git a/tests/test_de_overlap_equivalence.py b/tests/test_de_overlap_equivalence.py new file mode 100644 index 0000000..d959bdf --- /dev/null +++ b/tests/test_de_overlap_equivalence.py @@ -0,0 +1,169 @@ +"""Equivalence + memoization guards for the DE overlap metric. 
def _rows_to_df(rows: list[tuple]) -> pl.DataFrame:
    """Build a DE table from ``(target, feature, lfc, p_value, fdr)`` tuples.

    Columns are emitted in the tuples' positional order: ``target``,
    ``feature``, ``log2_fold_change``, ``p_value``, ``fdr``.
    """
    column_names = ("target", "feature", "log2_fold_change", "p_value", "fdr")
    return pl.DataFrame(
        {
            name: [row[position] for row in rows]
            for position, name in enumerate(column_names)
        }
    )
list[str]]: + """Reference: per-pert features, FDR-filtered, sorted by |lfc| descending.""" + by_pert: dict[str, list[tuple[float, str]]] = {} + for target, feature, lfc, _p, fdr in rows: + by_pert.setdefault(target, []) + if fdr < fdr_threshold: + by_pert[target].append((abs(lfc), feature)) + return { + pert: [g for _, g in sorted(items, key=lambda t: t[0], reverse=True)] + for pert, items in by_pert.items() + } + + +def _ref_overlap( + real_order: dict[str, list[str]], + pred_order: dict[str, list[str]], + perts: list[str], + k: int | None, + metric: str, +) -> dict[str, float]: + """Reference overlap/precision mirroring compute_overlap's exact k_eff math.""" + out: dict[str, float] = {} + for pert in perts: + real_genes = real_order.get(pert, []) + pred_genes = pred_order.get(pert, []) + if metric == "overlap": + k_eff = len(real_genes) if not k else k + k_eff = min(k_eff, len(real_genes)) + else: # precision + k_eff = len(pred_genes) if not k else k + k_eff = min(k_eff, len(pred_genes)) + if k_eff == 0: + out[pert] = 0.0 + else: + inter = set(real_genes[:k_eff]) & set(pred_genes[:k_eff]) + out[pert] = len(inter) / k_eff + return out + + +def _make_comparison() -> DEComparison: + return DEComparison( + real=DEResults(_rows_to_df(_REAL_ROWS), name="real"), + pred=DEResults(_rows_to_df(_PRED_ROWS), name="pred"), + ) + + +def test_compute_overlap_matches_reference() -> None: + """Optimized compute_overlap is bit-identical to a from-scratch reference.""" + comparison = _make_comparison() + perts = list(comparison.get_perts()) + + real_order = _ordered_sig_genes(_REAL_ROWS, _FDR) + pred_order = _ordered_sig_genes(_PRED_ROWS, _FDR) + + for metric in ("overlap", "precision"): + for k in (None, 1, 2, 50, 500): + got = comparison.compute_overlap(k=k, metric=metric, fdr_threshold=_FDR) + expected = _ref_overlap(real_order, pred_order, perts, k, metric) + assert got == expected, f"metric={metric} k={k}: {got} != {expected}" + + +def 
def test_get_top_genes_cache_distinguishes_keys() -> None:
    """Different (sort_by, fdr_threshold) keys must produce separate entries.

    Runs ``compute_overlap`` once per distinct key and checks that the
    top-genes cache holds one entry per key on both the real and the
    predicted side.
    """
    comparison = _make_comparison()

    # 0.05/abs, 0.10/abs, 0.05/pvalue -> three distinct cache keys.
    variants = [
        (_FDR, DESortBy.ABS_LOG2_FOLD_CHANGE),
        (0.10, DESortBy.ABS_LOG2_FOLD_CHANGE),
        (_FDR, DESortBy.PVALUE),
    ]
    for fdr_threshold, sort_by in variants:
        comparison.compute_overlap(
            k=None,
            metric="overlap",
            fdr_threshold=fdr_threshold,
            sort_by=sort_by,
        )

    assert len(comparison.real._top_genes_cache) == len(variants)
    # Check the predicted side as well, as the memoization test in this module
    # does; a cache-key bug confined to pred would otherwise pass unnoticed.
    assert len(comparison.pred._top_genes_cache) == len(variants)
def _frame(fdr_by_pert: dict[str, list[float]]) -> pl.DataFrame:
    """Expand a ``{pert: [fdr per gene]}`` mapping into a long DE table.

    Each perturbation contributes one row per gene in ``_GENES``. The
    ``p_value`` column reuses the FDR values, and ``log2_fold_change`` is a
    deterministic nonzero placeholder that alternates in sign.

    Args:
        fdr_by_pert: FDR values per perturbation, aligned with ``_GENES``.

    Returns:
        A frame with columns ``target``, ``feature``, ``log2_fold_change``,
        ``p_value`` and ``fdr``.

    Raises:
        ValueError: If a perturbation's FDR list does not have exactly one
            value per gene in ``_GENES``.
    """
    target, feature, lfc, fdr = [], [], [], []
    for pert, fdrs in fdr_by_pert.items():
        # strict=True: a fixture row of the wrong length must fail loudly
        # instead of being silently truncated to the shorter sequence.
        for gi, (gene, f) in enumerate(zip(_GENES, fdrs, strict=True)):
            target.append(pert)
            feature.append(gene)
            # Deterministic nonzero lfc; unused by these metrics but required.
            lfc.append((gi + 1) * (1.0 if gi % 2 == 0 else -1.0))
            fdr.append(f)
    return pl.DataFrame(
        {
            "target": target,
            "feature": feature,
            "log2_fold_change": lfc,
            "p_value": fdr,
            "fdr": fdr,
        }
    )
def _assert_auc_equal(got: dict, expected: dict) -> None:
    """Assert two AUC mappings have the same keys, order and values.

    An expected float NaN matches only a float NaN in ``got``; every other
    value is compared with ``==``.
    """
    assert list(got) == list(expected)
    for key, want in expected.items():
        have = got[key]
        expects_nan = isinstance(want, float) and math.isnan(want)
        if not expects_nan:
            assert have == want, (key, have, want)
            continue
        assert isinstance(have, float) and math.isnan(have), (key, have)
def test_pr_auc_matches_reference() -> None:
    """PR-AUC from the optimized path equals the per-pert filter reference."""
    comparison = _make_comparison()
    got = compute_pr_auc(comparison)
    _assert_auc_equal(got, _reference_generic_auc(comparison, "pr"))

    # Sanity: degenerate label sets -> nan; mixed -> finite.
    for degenerate in ("P2", "P3"):
        assert math.isnan(got[np.str_(degenerate)])
    assert math.isfinite(got[np.str_("P0")])
def _reference_discrimination_score(
    data, metric="l1", embed_key=None, exclude_target_gene=True
):
    """Verbatim copy of the original per-perturbation loop implementation.

    Used by the tests in this module as the baseline that
    ``discrimination_score`` must reproduce, so the statements are kept as-is.

    For each perturbation ``p`` in ``data.perts`` the predicted effect row at
    ``p``'s position is compared with every real effect row under ``metric``.
    The score is ``1 - rank / data.perts.size``, where ``rank`` is the
    position of ``p``'s own real row in the ascending distance ordering.

    Args:
        data: Object exposing ``iter_bulk_arrays``, ``perts`` and ``genes``.
        metric: Metric name forwarded to ``pairwise_distances``.
        embed_key: Forwarded to ``iter_bulk_arrays``; forced to ``None`` for
            the l1-family metric names.
        exclude_target_gene: When true and no ``embed_key`` is in effect,
            columns whose gene name equals the perturbation name are dropped
            before computing distances.

    Returns:
        Dict mapping ``str(p)`` to the normalized rank score.
    """
    # The l1 family discards any embedding key.
    if metric in ("l1", "manhattan", "cityblock"):
        embed_key = None
    # One stacked row per item yielded by iter_bulk_arrays; presumably one per
    # perturbation in data.perts order, since rows are indexed by p_idx below.
    real_effects = np.vstack(
        [
            d.perturbation_effect("real", abs=False)
            for d in data.iter_bulk_arrays(embed_key=embed_key)
        ]
    )
    pred_effects = np.vstack(
        [
            d.perturbation_effect("pred", abs=False)
            for d in data.iter_bulk_arrays(embed_key=embed_key)
        ]
    )
    norm_ranks = {}
    for p_idx, p in enumerate(data.perts):
        if exclude_target_gene and not embed_key:
            # Integer indices of every column whose gene name differs from p.
            include_mask = np.flatnonzero(data.genes != p)
        else:
            # Boolean mask that keeps every column.
            include_mask = np.ones(real_effects.shape[1], dtype=bool)
        # Distance from p's predicted row to each real row, on kept columns.
        distances = skm.pairwise_distances(
            real_effects[:, include_mask],
            pred_effects[p_idx, include_mask].reshape(1, -1),
            metric=metric,
        ).flatten()
        sorted_indices = np.argsort(distances)
        # Position of p's own real row within the ascending distance order.
        p_index = np.flatnonzero(data.perts == p)[0]
        rank = np.flatnonzero(sorted_indices == p_index)[0]
        norm_ranks[str(p)] = 1 - rank / data.perts.size
    return norm_ranks
def _ranks(out):
    """Return the values of ``out`` as an array, ordered by sorted key."""
    # Keys are unique, so sorting the (key, value) pairs orders by key alone.
    return np.array([score for _, score in sorted(out.items())])
def _make_target_dominated_pair(
    n_pert=30, n_genes=200, n_cells=1200, seed=7, spike=30.0
):
    """Build a real/pred pair whose effects sit in each pert's own target gene.

    Every perturbation is named after a gene column, and each perturbed cell
    gets ``spike`` added to that column on top of standard-normal noise. The
    target-excluded (masked) vector is therefore near-zero -- the degenerate
    case for cosine, where a masked squared norm can round negative.
    """
    label_rng = np.random.default_rng(seed)
    gene_names = np.array([f"gene_{i}" for i in range(n_genes)])
    targets = list(gene_names[:n_pert])  # every pert is named after a gene
    universe = np.concatenate([np.asarray(targets), [CONTROL]])
    cell_labels = label_rng.choice(universe, size=max(n_cells, universe.size))
    # Overwrite the leading cells so every label occurs at least once.
    cell_labels[: universe.size] = universe
    column_index = {gene: idx for idx, gene in enumerate(gene_names)}

    def build(off):
        noise_rng = np.random.default_rng(seed + 1000 * off)
        matrix = noise_rng.standard_normal((cell_labels.size, n_genes))
        for row, label in enumerate(cell_labels):
            column = column_index.get(label)
            if column is not None:  # perturbed cell: spike its target gene column
                matrix[row, column] += spike
        adata = ad.AnnData(X=matrix)
        adata.obs[PERT_COL] = pd.Categorical(cell_labels)
        adata.var_names = gene_names
        adata.obs_names = [f"cell_{i}" for i in range(cell_labels.size)]
        return adata

    real = build(1)
    pred = build(2)
    return PerturbationAnndataPair(
        real=real, pred=pred, pert_col=PERT_COL, control_pert=CONTROL
    )