From 2e900d8d9325caec3491fa74627157a05abefc55 Mon Sep 17 00:00:00 2001 From: Francis Chalissery <45127389+fctb12@users.noreply.github.com> Date: Thu, 23 Oct 2025 10:45:17 -0700 Subject: [PATCH] Add top-k accuracy metric for perturbation evaluation --- src/cell_eval/metrics/__init__.py | 2 + src/cell_eval/metrics/_anndata.py | 82 ++++++++++++++++++++++++++++ src/cell_eval/metrics/_impl.py | 9 ++++ tests/test_metrics_topk.py | 90 +++++++++++++++++++++++++++++++ 4 files changed, 183 insertions(+) create mode 100644 tests/test_metrics_topk.py diff --git a/src/cell_eval/metrics/__init__.py b/src/cell_eval/metrics/__init__.py index 350a365..fe3300b 100644 --- a/src/cell_eval/metrics/__init__.py +++ b/src/cell_eval/metrics/__init__.py @@ -8,6 +8,7 @@ mse, mse_delta, pearson_delta, + top_k_accuracy, ) from ._de import ( DEDirectionMatch, @@ -31,6 +32,7 @@ "mse_delta", "mae_delta", "discrimination_score", + "top_k_accuracy", # DE metrics "DEDirectionMatch", "DESpearmanSignificant", diff --git a/src/cell_eval/metrics/_anndata.py b/src/cell_eval/metrics/_anndata.py index 8bcdf7d..8bbde96 100644 --- a/src/cell_eval/metrics/_anndata.py +++ b/src/cell_eval/metrics/_anndata.py @@ -198,6 +198,88 @@ def discrimination_score( return norm_ranks +def _get_array( + adata: ad.AnnData, + embed_key: str | None = None, +) -> np.ndarray: + """Extract a dense numpy array from an AnnData object.""" + + matrix = adata.obsm[embed_key] if embed_key else adata.X + if issparse(matrix): + matrix = matrix.toarray() # type: ignore[assignment] + if matrix.dtype != np.float64: + matrix = matrix.astype(np.float64, copy=False) # type: ignore[assignment] + return matrix # type: ignore[return-value] + + +def top_k_accuracy( + data: PerturbationAnndataPair, + k: int = 5, + metric: str = "l2", + embed_key: str | None = None, +) -> dict[str, float]: + """Compute top-k accuracy for perturbation retrieval. + + For each real perturbed cell, we identify the *k* closest predicted cells (according + to the provided distance metric) and assign a score of ``1`` if any of those cells + share the same perturbation label. Otherwise we assign ``0``. The score for each + perturbation is the mean score across its constituent cells. + + Args: + data: Paired AnnData objects containing real and predicted perturbations. + k: Number of nearest neighbours to consider. + metric: Distance metric to use when computing neighbours. Defaults to ``"l2"``. + embed_key: Optional embedding key in ``obsm`` to use instead of expression. + + Returns: + Mapping from perturbation name to top-k accuracy. + """ + + if k <= 0: + raise ValueError("Parameter `k` must be a positive integer") + + metric_normalized = metric.lower() + if metric_normalized in {"l2", "euclidean"}: + metric_normalized = "euclidean" + elif metric_normalized in {"l1", "manhattan", "cityblock"}: + metric_normalized = "manhattan" + + real_matrix = _get_array(data.real, embed_key=embed_key) + pred_matrix = _get_array(data.pred, embed_key=embed_key) + pred_labels = data.pred.obs[data.pert_col].to_numpy(str) + + n_pred_cells = pred_matrix.shape[0] + if n_pred_cells == 0: + raise ValueError("Predicted AnnData does not contain any cells") + + topk = min(k, n_pred_cells) + + scores: dict[str, float] = {} + for pert in data.perts: + real_indices = data.pert_mask_real[pert] + if real_indices.size == 0: + scores[str(pert)] = float("nan") + continue + + real_cells = real_matrix[real_indices] + cell_scores = np.zeros(real_cells.shape[0], dtype=np.float64) + + distances = skm.pairwise_distances( + real_cells, + pred_matrix, + metric=metric_normalized, + ) + + for idx, neighbor_distances in enumerate(distances): + neighbor_indices = np.argpartition(neighbor_distances, topk - 1)[:topk] + neighbor_perts = pred_labels[neighbor_indices] + cell_scores[idx] = float(np.any(neighbor_perts == pert)) + + scores[str(pert)] = float(cell_scores.mean()) + + return scores + + def _generic_evaluation( data: PerturbationAnndataPair, func: Callable[[np.ndarray, np.ndarray], float], diff --git a/src/cell_eval/metrics/_impl.py b/src/cell_eval/metrics/_impl.py index 667f08a..f3a8409 100644 --- a/src/cell_eval/metrics/_impl.py +++ b/src/cell_eval/metrics/_impl.py @@ -8,6 +8,7 @@ mse, mse_delta, pearson_delta, + top_k_accuracy, ) from ._de import ( DEDirectionMatch, @@ -72,6 +73,14 @@ kwargs={"metric": distance_metric}, ) +metrics_registry.register( + name="top_k_accuracy", + metric_type=MetricType.ANNDATA_PAIR, + description="Top-k retrieval accuracy of predicted perturbation profiles", + best_value=MetricBestValue.ONE, + func=top_k_accuracy, +) + metrics_registry.register( name="pearson_edistance", metric_type=MetricType.ANNDATA_PAIR, diff --git a/tests/test_metrics_topk.py b/tests/test_metrics_topk.py new file mode 100644 index 0000000..f9aca6b --- /dev/null +++ b/tests/test_metrics_topk.py @@ -0,0 +1,90 @@ +import numpy as np +import pytest +import anndata as ad + +from cell_eval._types import PerturbationAnndataPair +from cell_eval.metrics import top_k_accuracy + + +def _make_anndata(matrix: np.ndarray, perts: list[str], genes: list[str]) -> ad.AnnData: + adata = ad.AnnData(X=matrix.astype(np.float64)) + adata.obs["pert"] = perts + adata.var_names = genes + return adata + + +def _build_pair(real_matrix: np.ndarray, pred_matrix: np.ndarray) -> PerturbationAnndataPair: + genes = ["g1", "g2"] + perts = ["ctrl", "ctrl", "A", "A", "B", "B"] + + adata_real = _make_anndata(real_matrix, perts, genes) + adata_pred = _make_anndata(pred_matrix, perts, genes) + + return PerturbationAnndataPair( + real=adata_real, + pred=adata_pred, + pert_col="pert", + control_pert="ctrl", + ) + + +def test_topk_accuracy_perfect_match() -> None: + real_matrix = np.array( + [ + [0.0, 0.0], + [0.1, -0.1], + [1.0, 0.0], + [1.0, 0.1], + [-1.0, 0.0], + [-1.0, -0.1], + ] + ) + + pred_matrix = np.array( + [ + [0.0, 0.05], + [0.05, -0.05], + [1.05, 0.05], + [0.95, -0.05], + [-1.05, 0.0], + [-0.95, 0.05], + ] + ) + + pair = _build_pair(real_matrix, pred_matrix) + + scores = top_k_accuracy(pair, k=1) + + assert scores["A"] == pytest.approx(1.0) + assert scores["B"] == pytest.approx(1.0) + + +def test_topk_accuracy_mismatch() -> None: + real_matrix = np.array( + [ + [0.0, 0.0], + [0.1, -0.1], + [1.0, 0.0], + [1.0, 0.1], + [-1.0, 0.0], + [-1.0, -0.1], + ] + ) + + pred_matrix = np.array( + [ + [0.0, 0.05], + [0.05, -0.05], + [1.05, 0.05], + [0.95, -0.05], + [2.0, 2.0], + [2.2, 1.8], + ] + ) + + pair = _build_pair(real_matrix, pred_matrix) + + scores = top_k_accuracy(pair, k=1) + + assert scores["A"] == pytest.approx(1.0) + assert scores["B"] == pytest.approx(0.0)