From 355cd3fbc7bd4f843a9ccbed1a5fd186acadd24a Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 19:17:43 +0200 Subject: [PATCH 01/21] chore(docs): add EAT annotation-transfer design spec + backend implementation plan --- .../plans/2026-06-11-eat-transfer-backend.md | 1671 +++++++++++++++++ .../2026-05-27-neighbors-subcommand-design.md | 14 + ...26-06-11-eat-annotation-transfer-design.md | 337 ++++ 3 files changed, 2022 insertions(+) create mode 100644 docs/superpowers/plans/2026-06-11-eat-transfer-backend.md create mode 100644 docs/superpowers/specs/2026-05-27-neighbors-subcommand-design.md create mode 100644 docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md diff --git a/docs/superpowers/plans/2026-06-11-eat-transfer-backend.md b/docs/superpowers/plans/2026-06-11-eat-transfer-backend.md new file mode 100644 index 00000000..a3afcd20 --- /dev/null +++ b/docs/superpowers/plans/2026-06-11-eat-transfer-backend.md @@ -0,0 +1,1671 @@ +# EAT Annotation-Transfer Backend Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add a `protlabel` embedding-annotation-transfer engine and a `protspace transfer` CLI subcommand that fills in missing annotation values for query proteins from their nearest reference neighbours in pLM embedding space, writing a per-cell prediction overlay back into the `.parquetbundle`. + +**Architecture:** `protlabel` is a small, ProtSpace-agnostic package (numpy/scipy/h5py only) that does the kNN search + goPredSim reliability index + label transfer. `protspace transfer` is a thin Typer subcommand that reads a bundle + HDF5 embeddings, classifies query vs reference proteins, calls `protlabel`, and appends `__pred_value` / `__pred_confidence` / `__pred_source` columns to the bundle's annotations table. Default = Euclidean, k=1. Optional gating/mining/report are out of scope for this MVP. + +**Tech Stack:** Python ≥3.10, numpy, scipy (new dep), h5py, pyarrow, Typer, pytest, ruff, uv. + +**Spec:** `docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md`. The two refinements below override the spec where they differ (and the spec's §4/§10 are updated to match): +- **Packaging:** `protlabel` ships as a second top-level package *inside the protspace repo* (`src/protlabel/`), bundled into the protspace wheel — not a suite-level uv workspace member (the suite root is not a uv workspace, and a separate PyPI distribution would need its own release/CI). The strict no-`protspace`-imports boundary keeps a future standalone split trivial. +- **Overlay storage:** extra `*__pred_*` columns on the existing annotations table (bundle part 1), **not** a new 5th bundle part — this is backward-compatible with the currently-deployed web reader, which tolerates unknown columns but parses a fixed part count. + +**Out of scope (follow-up plans):** the `protspace_web` frontend rendering (separate repo/PR), optional `--cutoff` gating, `--mine` neighborhood mining, `--report`/`--plots`/`--full-tables`, faiss-cpu ANN backend, ProtTucker learned distance. + +**Conventions (enforced):** all Python via `uv run`; deps via `uv add` (never hand-edit `[project.dependencies]`); ruff clean; update docs + Colab notebook before committing; feature branch + PR, never push to `main`; commit prefixes — `feat:` only for user-visible changes, `chore:`/`test:`/`refactor:` for the rest. + +--- + +## File Structure + +**New package `protlabel` (`src/protlabel/`):** +- `src/protlabel/__init__.py` — public API: `eat`, `Lookup`, `Prediction`. +- `src/protlabel/reliability.py` — `similarity()`, the goPredSim distance→`[0,1]` transform. +- `src/protlabel/backends.py` — `nearest()`, chunked brute-force kNN (Euclidean/cosine). +- `src/protlabel/transfer.py` — `Prediction` dataclass + `eat()`/`transfer_labels()` (kNN → RI → label). +- `src/protlabel/lookup.py` — `Lookup` dataclass: build / `save` / `load` (sidecar `.npz`) / `query`. + +**ProtSpace additions/changes:** +- Create `src/protspace/analysis/__init__.py`, `src/protspace/analysis/classification.py` — query/reference classifier. +- Create `src/protspace/data/io/predictions.py` — build overlay columns from predictions. +- Create `src/protspace/cli/transfer.py` — `run_transfer()` (pure) + the `transfer` Typer command. +- Modify `src/protspace/data/io/bundle.py` — add `replace_annotations_in_bundle()`. +- Modify `src/protspace/cli/app.py:65-73` — register the `transfer` subcommand. +- Modify `pyproject.toml` — add `[tool.hatch.build.targets.wheel] packages` (both packages); add `scipy` via `uv add`. + +**Tests (all under `tests/`, run with `uv run pytest tests/`):** +- `tests/test_protlabel_reliability.py`, `tests/test_protlabel_backends.py`, `tests/test_protlabel_transfer.py`, `tests/test_protlabel_lookup.py` +- `tests/test_classification.py`, `tests/test_predictions_overlay.py`, `tests/test_bundle_overlay.py`, `tests/test_transfer_cli.py` + +**Docs:** +- Modify `docs/cli.md`, `docs/annotations.md`, top-level `../CLAUDE.md` CLI table; add `notebooks/ProtSpace_Transfer.ipynb`. + +--- + +## Task 1: Scaffold the `protlabel` package + +**Files:** +- Create: `src/protlabel/__init__.py` +- Modify: `pyproject.toml` (hatchling packages + scipy dep) +- Test: `tests/test_protlabel_reliability.py` (placeholder import test, expanded in Task 2) + +- [ ] **Step 1: Create the package marker with a version** + +Create `src/protlabel/__init__.py`: + +```python +"""protlabel — Embedding Annotation Transfer (EAT) engine. + +Nearest-neighbour label transfer in protein-language-model embedding space, +with the goPredSim reliability index. Pure numpy/scipy/h5py; no protspace imports. +""" + +__version__ = "0.1.0" +``` + +- [ ] **Step 2: Tell hatchling to build both packages** + +Edit `pyproject.toml`. After the `[build-system]` block (around line 71), add: + +```toml +[tool.hatch.build.targets.wheel] +packages = ["src/protspace", "src/protlabel"] +``` + +- [ ] **Step 3: Add the scipy dependency (via uv, not by hand)** + +Run: `uv add 'scipy>=1.10'` +Expected: `pyproject.toml` gains `scipy>=1.10` in `[project.dependencies]` and `uv.lock` updates. + +- [ ] **Step 4: Sync so `import protlabel` resolves** + +Run: `uv sync` +Then verify: `uv run python -c "import protlabel; print(protlabel.__version__)"` +Expected: prints `0.1.0` + +- [ ] **Step 5: Write a smoke test** + +Create `tests/test_protlabel_reliability.py`: + +```python +"""Tests for protlabel.reliability.""" + + +def test_protlabel_imports(): + import protlabel + + assert protlabel.__version__ +``` + +- [ ] **Step 6: Run the smoke test** + +Run: `uv run pytest tests/test_protlabel_reliability.py -v` +Expected: PASS + +- [ ] **Step 7: Commit** + +```bash +git add src/protlabel/__init__.py pyproject.toml uv.lock tests/test_protlabel_reliability.py +git commit -m "chore(protlabel): scaffold EAT engine package + scipy dep" +``` + +--- + +## Task 2: Reliability index (`protlabel.reliability`) + +**Files:** +- Create: `src/protlabel/reliability.py` +- Test: `tests/test_protlabel_reliability.py` + +- [ ] **Step 1: Write the failing tests** + +Replace `tests/test_protlabel_reliability.py` with: + +```python +"""Tests for protlabel.reliability.""" + +import math + +import pytest + +from protlabel.reliability import similarity + + +def test_euclidean_at_zero_distance_is_one(): + assert similarity(0.0, "euclidean") == pytest.approx(1.0) + + +def test_euclidean_at_half_distance_is_half(): + assert similarity(0.5, "euclidean") == pytest.approx(0.5) + + +def test_euclidean_decreases_to_zero(): + assert similarity(1e9, "euclidean") == pytest.approx(0.0, abs=1e-6) + + +def test_cosine_is_one_minus_distance(): + assert similarity(0.2, "cosine") == pytest.approx(0.8) + + +def test_cosine_clamped_to_unit_interval(): + # cosine distance can be up to 2.0 -> 1 - d would go negative; clamp at 0 + assert similarity(1.7, "cosine") == pytest.approx(0.0) + assert similarity(-0.1, "cosine") == pytest.approx(1.0) + + +def test_unknown_metric_raises(): + with pytest.raises(ValueError): + similarity(0.5, "manhattan") + + +def test_smoke(): + assert math.isfinite(similarity(0.5, "euclidean")) +``` + +- [ ] **Step 2: Run to verify failure** + +Run: `uv run pytest tests/test_protlabel_reliability.py -v` +Expected: FAIL — `ImportError: cannot import name 'similarity'` + +- [ ] **Step 3: Implement** + +Create `src/protlabel/reliability.py`: + +```python +"""goPredSim reliability index: map an embedding distance to a [0,1] confidence. + +Euclidean: s(d) = 0.5 / (0.5 + d) (1.0 at d=0, 0.5 at d=0.5, ->0 as d->inf) +Cosine: s(d) = 1 - d (clamped to [0,1]; cosine distance in [0,2]) + +Reference: Littmann et al., Sci Rep 2021 (Eq. 5); goPredSim calc_reliability_index. +""" + +from __future__ import annotations + + +def similarity(distance: float, metric: str) -> float: + """Per-neighbour distance->similarity (the goPredSim reliability transform).""" + if metric == "euclidean": + return 0.5 / (0.5 + distance) + if metric == "cosine": + return min(1.0, max(0.0, 1.0 - distance)) + raise ValueError(f"Unknown metric {metric!r}; expected 'euclidean' or 'cosine'") +``` + +- [ ] **Step 4: Run to verify pass** + +Run: `uv run pytest tests/test_protlabel_reliability.py -v` +Expected: PASS (all 7) + +- [ ] **Step 5: Lint + commit** + +```bash +uv run ruff check src/protlabel/reliability.py tests/test_protlabel_reliability.py +git add src/protlabel/reliability.py tests/test_protlabel_reliability.py +git commit -m "feat(protlabel): goPredSim reliability index transform" +``` + +--- + +## Task 3: Brute-force kNN backend (`protlabel.backends`) + +**Files:** +- Create: `src/protlabel/backends.py` +- Test: `tests/test_protlabel_backends.py` + +- [ ] **Step 1: Write the failing tests** + +Create `tests/test_protlabel_backends.py`: + +```python +"""Tests for protlabel.backends.nearest.""" + +import numpy as np +import pytest + +from protlabel.backends import nearest + + +def _toy(): + # 3 references on a line; queries close to ref 0 and ref 2 + refs = np.array([[0.0, 0.0], [10.0, 0.0], [20.0, 0.0]], dtype=np.float32) + queries = np.array([[0.1, 0.0], [19.5, 0.0]], dtype=np.float32) + return queries, refs + + +def test_returns_shapes(): + queries, refs = _toy() + idx, dist = nearest(queries, refs, k=2, metric="euclidean") + assert idx.shape == (2, 2) + assert dist.shape == (2, 2) + + +def test_nearest_index_euclidean(): + queries, refs = _toy() + idx, dist = nearest(queries, refs, k=1, metric="euclidean") + assert idx[0, 0] == 0 # first query nearest to ref 0 + assert idx[1, 0] == 2 # second query nearest to ref 2 + assert dist[0, 0] == pytest.approx(0.1, abs=1e-4) + + +def test_neighbours_sorted_by_distance(): + queries, refs = _toy() + idx, dist = nearest(queries, refs, k=3, metric="euclidean") + assert np.all(np.diff(dist, axis=1) >= -1e-6) # non-decreasing per row + + +def test_cosine_metric_runs_and_orders(): + refs = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32) + queries = np.array([[1.0, 0.1]], dtype=np.float32) # closest in angle to ref 0 + idx, dist = nearest(queries, refs, k=1, metric="cosine") + assert idx[0, 0] == 0 + + +def test_k_capped_to_num_refs(): + queries, refs = _toy() + idx, dist = nearest(queries, refs, k=10, metric="euclidean") + assert idx.shape == (2, 3) # only 3 refs available + + +def test_chunking_matches_unchunked(): + rng = np.random.default_rng(0) + refs = rng.standard_normal((50, 8)).astype(np.float32) + queries = rng.standard_normal((7, 8)).astype(np.float32) + a_idx, a_dist = nearest(queries, refs, k=3, metric="euclidean", chunk=1000) + b_idx, b_dist = nearest(queries, refs, k=3, metric="euclidean", chunk=3) + assert np.array_equal(a_idx, b_idx) + assert np.allclose(a_dist, b_dist, atol=1e-5) + + +def test_unknown_metric_raises(): + queries, refs = _toy() + with pytest.raises(ValueError): + nearest(queries, refs, k=1, metric="manhattan") +``` + +- [ ] **Step 2: Run to verify failure** + +Run: `uv run pytest tests/test_protlabel_backends.py -v` +Expected: FAIL — `ModuleNotFoundError: No module named 'protlabel.backends'` + +- [ ] **Step 3: Implement** + +Create `src/protlabel/backends.py`: + +```python +"""Exact (brute-force) k-nearest-neighbour search over reference embeddings. + +Chunked over the query axis so the Q_chunk x N distance block stays small, +which keeps peak memory near the reference matrix itself even at Swiss-Prot +scale. scipy.cdist handles both euclidean and cosine. +""" + +from __future__ import annotations + +import numpy as np +from scipy.spatial.distance import cdist + +_METRICS = {"euclidean", "cosine"} + + +def nearest( + queries: np.ndarray, + refs: np.ndarray, + k: int, + metric: str = "euclidean", + chunk: int = 4096, +) -> tuple[np.ndarray, np.ndarray]: + """Return (idx, dist) of the k nearest *references* per query. + + idx[i] -> indices into ``refs`` of the k nearest, ascending by distance. + dist[i] -> the corresponding distances. + k is capped to the number of references. + """ + if metric not in _METRICS: + raise ValueError(f"Unknown metric {metric!r}; expected one of {_METRICS}") + + queries = np.ascontiguousarray(queries, dtype=np.float32) + refs = np.ascontiguousarray(refs, dtype=np.float32) + n_refs = refs.shape[0] + k = min(k, n_refs) + + idx_out = np.empty((queries.shape[0], k), dtype=np.int64) + dist_out = np.empty((queries.shape[0], k), dtype=np.float32) + + for start in range(0, queries.shape[0], chunk): + block = queries[start : start + chunk] + d = cdist(block, refs, metric=metric).astype(np.float32) # (b, n_refs) + part = np.argpartition(d, kth=k - 1, axis=1)[:, :k] # unsorted top-k + rows = np.arange(block.shape[0])[:, None] + part_d = d[rows, part] + order = np.argsort(part_d, axis=1) # sort the k by distance + sorted_idx = part[rows, order] + idx_out[start : start + block.shape[0]] = sorted_idx + dist_out[start : start + block.shape[0]] = d[rows, sorted_idx] + + return idx_out, dist_out +``` + +- [ ] **Step 4: Run to verify pass** + +Run: `uv run pytest tests/test_protlabel_backends.py -v` +Expected: PASS (all 7) + +- [ ] **Step 5: Lint + commit** + +```bash +uv run ruff check src/protlabel/backends.py tests/test_protlabel_backends.py +git add src/protlabel/backends.py tests/test_protlabel_backends.py +git commit -m "feat(protlabel): chunked brute-force kNN backend" +``` + +--- + +## Task 4: Label transfer + Prediction (`protlabel.transfer`) + +**Files:** +- Create: `src/protlabel/transfer.py` +- Test: `tests/test_protlabel_transfer.py` + +- [ ] **Step 1: Write the failing tests** + +Create `tests/test_protlabel_transfer.py`: + +```python +"""Tests for protlabel.transfer.""" + +import numpy as np +import pytest + +from protlabel.transfer import Prediction, eat + + +def _setup(): + ref_emb = np.array([[0.0, 0.0], [10.0, 0.0], [20.0, 0.0]], dtype=np.float32) + ref_ids = ["R0", "R1", "R2"] + ref_labels = ["toxin", "enzyme", "toxin"] + query_emb = np.array([[0.0, 0.0], [19.7, 0.0]], dtype=np.float32) + query_ids = ["Q0", "Q1"] + return ref_emb, ref_ids, ref_labels, query_emb, query_ids + + +def test_k1_transfers_nearest_label_and_source(): + ref_emb, ref_ids, ref_labels, query_emb, query_ids = _setup() + preds = eat(query_emb, query_ids, ref_emb, ref_ids, ref_labels, k=1) + assert isinstance(preds[0], Prediction) + assert preds[0].query_id == "Q0" + assert preds[0].label == "toxin" + assert preds[0].source_id == "R0" + assert preds[0].reliability == pytest.approx(1.0) # distance 0 -> RI 1.0 + + +def test_k1_reliability_uses_gopredsim_transform(): + ref_emb, ref_ids, ref_labels, query_emb, query_ids = _setup() + preds = eat(query_emb, query_ids, ref_emb, ref_ids, ref_labels, k=1) + # Q1 distance to R2 is 0.3 -> RI = 0.5/(0.5+0.3) + assert preds[1].label == "toxin" + assert preds[1].source_id == "R2" + assert preds[1].reliability == pytest.approx(0.5 / 0.8, abs=1e-4) + + +def test_k3_vote_picks_majority_label(): + # Query equidistant-ish but two of three nearest are "toxin" + ref_emb = np.array( + [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]], dtype=np.float32 + ) + ref_ids = ["R0", "R1", "R2", "R3"] + ref_labels = ["toxin", "enzyme", "toxin", "toxin"] + query_emb = np.array([[1.4, 0.0]], dtype=np.float32) + preds = eat(query_emb, ["Q"], ref_emb, ref_ids, ref_labels, k=3) + assert preds[0].label == "toxin" # toxin RI sum beats lone enzyme neighbour + assert 0.0 < preds[0].reliability <= 1.0 + + +def test_cosine_metric(): + ref_emb = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32) + preds = eat( + np.array([[1.0, 0.05]], dtype=np.float32), + ["Q"], + ref_emb, + ["R0", "R1"], + ["a", "b"], + k=1, + metric="cosine", + ) + assert preds[0].label == "a" + + +def test_length_mismatch_raises(): + ref_emb, ref_ids, ref_labels, query_emb, query_ids = _setup() + with pytest.raises(ValueError): + eat(query_emb, query_ids, ref_emb, ref_ids, ref_labels[:-1], k=1) + + +def test_empty_references_raises(): + with pytest.raises(ValueError): + eat( + np.zeros((1, 2), dtype=np.float32), + ["Q"], + np.zeros((0, 2), dtype=np.float32), + [], + [], + k=1, + ) +``` + +- [ ] **Step 2: Run to verify failure** + +Run: `uv run pytest tests/test_protlabel_transfer.py -v` +Expected: FAIL — `ModuleNotFoundError: No module named 'protlabel.transfer'` + +- [ ] **Step 3: Implement** + +Create `src/protlabel/transfer.py`: + +```python +"""Embedding annotation transfer: kNN -> reliability index -> transferred label. + +Implements the goPredSim aggregation (Littmann et al. 2021, Eq. 5): + RI(p) = (1/k) * sum over neighbours carrying label p of similarity(d). +The transferred label is argmax RI(p); its source is the nearest neighbour +carrying that label. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + +from protlabel.backends import nearest +from protlabel.reliability import similarity + + +@dataclass(frozen=True) +class Prediction: + """One transferred annotation for a query protein.""" + + query_id: str + label: str + source_id: str + distance: float + reliability: float + k: int + metric: str + + +def eat( + query_emb: np.ndarray, + query_ids: list[str], + ref_emb: np.ndarray, + ref_ids: list[str], + ref_labels: list[str], + *, + k: int = 1, + metric: str = "euclidean", +) -> list[Prediction]: + """Transfer the best-guess label to each query from its k nearest references.""" + if not (len(ref_ids) == len(ref_labels) == ref_emb.shape[0]): + raise ValueError("ref_emb, ref_ids and ref_labels must have equal length") + if ref_emb.shape[0] == 0: + raise ValueError("No reference embeddings to transfer from") + if len(query_ids) != query_emb.shape[0]: + raise ValueError("query_emb and query_ids must have equal length") + + idx, dist = nearest(query_emb, ref_emb, k=k, metric=metric) + eff_k = idx.shape[1] + predictions: list[Prediction] = [] + + for qi, query_id in enumerate(query_ids): + neigh_idx = idx[qi] + neigh_dist = dist[qi] + # Accumulate RI per label and track the nearest source per label. + ri_by_label: dict[str, float] = {} + nearest_src: dict[str, tuple[float, str]] = {} + for j, ref_i in enumerate(neigh_idx): + lab = ref_labels[ref_i] + d = float(neigh_dist[j]) + ri_by_label[lab] = ri_by_label.get(lab, 0.0) + similarity(d, metric) + if lab not in nearest_src or d < nearest_src[lab][0]: + nearest_src[lab] = (d, ref_ids[ref_i]) + # Normalise by k (the goPredSim 1/k term). + best_label = max(ri_by_label, key=lambda p: ri_by_label[p]) + ri = ri_by_label[best_label] / eff_k + src_dist, src_id = nearest_src[best_label] + predictions.append( + Prediction( + query_id=query_id, + label=best_label, + source_id=src_id, + distance=src_dist, + reliability=ri, + k=eff_k, + metric=metric, + ) + ) + + return predictions +``` + +- [ ] **Step 4: Run to verify pass** + +Run: `uv run pytest tests/test_protlabel_transfer.py -v` +Expected: PASS (all 6) + +- [ ] **Step 5: Lint + commit** + +```bash +uv run ruff check src/protlabel/transfer.py tests/test_protlabel_transfer.py +git add src/protlabel/transfer.py tests/test_protlabel_transfer.py +git commit -m "feat(protlabel): kNN label transfer with reliability index" +``` + +--- + +## Task 5: Persistable lookup (`protlabel.lookup`) + public API + +**Files:** +- Create: `src/protlabel/lookup.py` +- Modify: `src/protlabel/__init__.py` +- Test: `tests/test_protlabel_lookup.py` + +- [ ] **Step 1: Write the failing tests** + +Create `tests/test_protlabel_lookup.py`: + +```python +"""Tests for protlabel.lookup.Lookup (the rebuildable sidecar).""" + +import numpy as np + +from protlabel import Lookup, Prediction + + +def _lookup(): + emb = np.array([[0.0, 0.0], [10.0, 0.0]], dtype=np.float32) + return Lookup(embeddings=emb, ids=["R0", "R1"], labels=["a", "b"]) + + +def test_query_returns_predictions(): + lk = _lookup() + preds = lk.query(np.array([[0.2, 0.0]], dtype=np.float32), ["Q0"], k=1) + assert isinstance(preds[0], Prediction) + assert preds[0].label == "a" + assert preds[0].source_id == "R0" + + +def test_save_load_roundtrip(tmp_path): + lk = _lookup() + path = tmp_path / "lookup.npz" + lk.save(path) + assert path.exists() + loaded = Lookup.load(path) + assert loaded.ids == lk.ids + assert loaded.labels == lk.labels + assert loaded.metric == lk.metric + assert np.allclose(loaded.embeddings, lk.embeddings) + + +def test_loaded_lookup_queries_identically(tmp_path): + lk = _lookup() + q = np.array([[9.8, 0.0]], dtype=np.float32) + before = lk.query(q, ["Q"], k=1) + path = tmp_path / "lk.npz" + lk.save(path) + after = Lookup.load(path).query(q, ["Q"], k=1) + assert before[0].label == after[0].label == "b" + assert before[0].reliability == after[0].reliability +``` + +- [ ] **Step 2: Run to verify failure** + +Run: `uv run pytest tests/test_protlabel_lookup.py -v` +Expected: FAIL — `ImportError: cannot import name 'Lookup' from 'protlabel'` + +- [ ] **Step 3: Implement the Lookup** + +Create `src/protlabel/lookup.py`: + +```python +"""A persistable reference lookup: embeddings + ids + labels, plus query(). + +Serialized as a single .npz sidecar so it can live next to a bundle or in a +cache dir and be rebuilt on demand from the source HDF5. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path + +import numpy as np + +from protlabel.transfer import Prediction, eat + + +@dataclass +class Lookup: + """Reference set for embedding annotation transfer.""" + + embeddings: np.ndarray # (N, D) float32 + ids: list[str] + labels: list[str] + metric: str = "euclidean" + model: str = field(default="") + + def query( + self, query_emb: np.ndarray, query_ids: list[str], *, k: int = 1 + ) -> list[Prediction]: + """Transfer labels to query embeddings from this lookup.""" + return eat( + query_emb, + query_ids, + self.embeddings, + self.ids, + self.labels, + k=k, + metric=self.metric, + ) + + def save(self, path: Path) -> None: + """Serialize to a .npz sidecar.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + np.savez( + path, + embeddings=self.embeddings.astype(np.float16), + ids=np.array(self.ids, dtype=object), + labels=np.array(self.labels, dtype=object), + metric=self.metric, + model=self.model, + ) + + @classmethod + def load(cls, path: Path) -> Lookup: + """Load a .npz sidecar (re-upcasts embeddings to float32).""" + with np.load(path, allow_pickle=True) as data: + return cls( + embeddings=data["embeddings"].astype(np.float32), + ids=list(data["ids"]), + labels=list(data["labels"]), + metric=str(data["metric"]), + model=str(data["model"]), + ) +``` + +> Note: `np.savez` appends `.npz` if the path has no extension; the tests use an explicit `.npz` so the saved file matches the requested path. + +- [ ] **Step 4: Export the public API** + +Replace `src/protlabel/__init__.py` with: + +```python +"""protlabel — Embedding Annotation Transfer (EAT) engine. + +Nearest-neighbour label transfer in protein-language-model embedding space, +with the goPredSim reliability index. Pure numpy/scipy/h5py; no protspace imports. +""" + +from protlabel.lookup import Lookup +from protlabel.transfer import Prediction, eat + +__version__ = "0.1.0" + +__all__ = ["Lookup", "Prediction", "eat", "__version__"] +``` + +- [ ] **Step 5: Run to verify pass** + +Run: `uv run pytest tests/test_protlabel_lookup.py -v` +Expected: PASS (all 3) + +- [ ] **Step 6: Run the whole protlabel suite + guard the boundary** + +Run: `uv run pytest tests/test_protlabel_*.py -v` +Expected: PASS +Run: `! grep -rqE "import protspace|from protspace" src/protlabel/ && echo "boundary clean"` +Expected: prints `boundary clean` (protlabel must not import protspace) + +- [ ] **Step 7: Lint + commit** + +```bash +uv run ruff check src/protlabel/ tests/test_protlabel_lookup.py +git add src/protlabel/lookup.py src/protlabel/__init__.py tests/test_protlabel_lookup.py +git commit -m "feat(protlabel): persistable Lookup sidecar + public API" +``` + +--- + +## Task 6: Query/reference classifier (`protspace.analysis.classification`) + +**Files:** +- Create: `src/protspace/analysis/__init__.py`, `src/protspace/analysis/classification.py` +- Test: `tests/test_classification.py` + +- [ ] **Step 1: Write the failing tests** + +Create `tests/test_classification.py`: + +```python +"""Tests for the query/reference classifier.""" + +import pyarrow as pa +import pytest + +from protspace.analysis.classification import Rule, classify + + +def _table(): + return pa.table( + { + "identifier": ["TRINITY_1", "TRINITY_2", "P00001", "P00002"], + "protein_category": ["mSCR", "mSCR", "neurotoxin", "enzyme"], + } + ) + + +def test_prefix_rule_selects_queries(): + q = Rule(id_prefixes=["TRINITY_"]) + r = Rule(where=[("protein_category", "neurotoxin")]) + qi, ri = classify(_table(), q, r) + assert qi == [0, 1] + assert ri == [2] + + +def test_where_substring_is_case_insensitive(): + q = Rule(where=[("protein_category", "MSCR")]) + r = Rule(id_prefixes=["P0"]) + qi, ri = classify(_table(), q, r) + assert qi == [0, 1] + assert ri == [2, 3] + + +def test_query_takes_precedence_over_reference(): + # A protein matching both rules is classified as a query, never a reference. + q = Rule(id_prefixes=["P00001"]) + r = Rule(where=[("protein_category", "neurotoxin")]) + qi, ri = classify(_table(), q, r) + assert 2 in qi + assert 2 not in ri + + +def test_empty_query_match_raises(): + q = Rule(id_prefixes=["NOPE_"]) + r = Rule(id_prefixes=["P0"]) + with pytest.raises(ValueError, match="no query"): + classify(_table(), q, r) + + +def test_missing_where_column_raises(): + q = Rule(where=[("not_a_column", "x")]) + r = Rule(id_prefixes=["P0"]) + with pytest.raises(KeyError): + classify(_table(), q, r) +``` + +- [ ] **Step 2: Run to verify failure** + +Run: `uv run pytest tests/test_classification.py -v` +Expected: FAIL — `ModuleNotFoundError: No module named 'protspace.analysis'` + +- [ ] **Step 3: Implement** + +Create `src/protspace/analysis/__init__.py`: + +```python +"""Optional analysis layer for ProtSpace (classification, gating, mining).""" +``` + +Create `src/protspace/analysis/classification.py`: + +```python +"""Classify proteins as transfer queries vs annotated references. + +Rules match by ID prefix and/or a case-insensitive metadata substring +(``column ~ substring``). No biology is hardcoded; a query rule that matches +nothing is an error. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +import pyarrow as pa + + +@dataclass +class Rule: + """A classification rule. A protein matches if ANY clause matches.""" + + id_prefixes: list[str] = field(default_factory=list) + where: list[tuple[str, str]] = field(default_factory=list) # (column, substring) + + +def _matches(rule: Rule, identifier: str, row: dict[str, str]) -> bool: + if any(identifier.startswith(p) for p in rule.id_prefixes): + return True + for column, substring in rule.where: + if column not in row: + raise KeyError(f"Classification column {column!r} not in annotations") + value = row[column] + if value is not None and substring.lower() in str(value).lower(): + return True + return False + + +def classify( + annotations: pa.Table, query_rule: Rule, reference_rule: Rule +) -> tuple[list[int], list[int]]: + """Return (query_indices, reference_indices) into the annotations table. + + Query classification takes precedence: a protein matching both rules is a + query. Raises ValueError if the query rule matches nothing. + """ + columns = set(annotations.column_names) + # Validate where-columns up front so an empty table still raises KeyError. + for rule in (query_rule, reference_rule): + for column, _ in rule.where: + if column not in columns: + raise KeyError(f"Classification column {column!r} not in annotations") + + rows = annotations.to_pylist() + identifiers = [str(r["identifier"]) for r in rows] + + query_indices: list[int] = [] + reference_indices: list[int] = [] + for i, (identifier, row) in enumerate(zip(identifiers, rows, strict=True)): + if _matches(query_rule, identifier, row): + query_indices.append(i) + elif _matches(reference_rule, identifier, row): + reference_indices.append(i) + + if not query_indices: + raise ValueError( + "Classifier matched no query proteins; check --query-id-prefix / " + "--query-where rules." + ) + return query_indices, reference_indices +``` + +- [ ] **Step 4: Run to verify pass** + +Run: `uv run pytest tests/test_classification.py -v` +Expected: PASS (all 5) + +- [ ] **Step 5: Lint + commit** + +```bash +uv run ruff check src/protspace/analysis/ tests/test_classification.py +git add src/protspace/analysis/ tests/test_classification.py +git commit -m "feat: query/reference classifier for annotation transfer" +``` + +--- + +## Task 7: Build the overlay columns (`protspace.data.io.predictions`) + +**Files:** +- Create: `src/protspace/data/io/predictions.py` +- Test: `tests/test_predictions_overlay.py` + +- [ ] **Step 1: Write the failing tests** + +Create `tests/test_predictions_overlay.py`: + +```python +"""Tests for building the per-cell prediction overlay columns.""" + +import pyarrow as pa + +from protlabel import Prediction +from protspace.data.io.predictions import add_overlay_columns + + +def _table(): + return pa.table( + { + "identifier": ["Q0", "Q1", "R0"], + "protein_category": ["", "", "neurotoxin"], + } + ) + + +def test_adds_three_overlay_columns(): + preds = [ + Prediction("Q0", "neurotoxin", "R0", 0.3, 0.62, 1, "euclidean"), + ] + out = add_overlay_columns(_table(), "protein_category", preds) + assert "protein_category__pred_value" in out.column_names + assert "protein_category__pred_confidence" in out.column_names + assert "protein_category__pred_source" in out.column_names + + +def test_overlay_values_aligned_by_identifier(): + preds = [Prediction("Q1", "enzyme", "R9", 0.5, 0.5, 1, "euclidean")] + out = add_overlay_columns(_table(), "protein_category", preds).to_pylist() + by_id = {r["identifier"]: r for r in out} + assert by_id["Q1"]["protein_category__pred_value"] == "enzyme" + assert by_id["Q1"]["protein_category__pred_confidence"] == 0.5 + assert by_id["Q1"]["protein_category__pred_source"] == "R9" + # Non-predicted rows are null in the overlay columns. + assert by_id["Q0"]["protein_category__pred_value"] is None + assert by_id["R0"]["protein_category__pred_confidence"] is None + + +def test_curated_column_is_left_untouched(): + preds = [Prediction("Q0", "neurotoxin", "R0", 0.1, 0.8, 1, "euclidean")] + out = add_overlay_columns(_table(), "protein_category", preds).to_pylist() + by_id = {r["identifier"]: r for r in out} + assert by_id["Q0"]["protein_category"] == "" # original column unchanged + assert by_id["R0"]["protein_category"] == "neurotoxin" + + +def test_confidence_column_is_float(): + preds = [Prediction("Q0", "x", "R0", 0.1, 0.83, 1, "euclidean")] + out = add_overlay_columns(_table(), "protein_category", preds) + field = out.schema.field("protein_category__pred_confidence") + assert pa.types.is_floating(field.type) +``` + +- [ ] **Step 2: Run to verify failure** + +Run: `uv run pytest tests/test_predictions_overlay.py -v` +Expected: FAIL — `ModuleNotFoundError: No module named 'protspace.data.io.predictions'` + +- [ ] **Step 3: Implement** + +Create `src/protspace/data/io/predictions.py`: + +```python +"""Turn protlabel Predictions into per-cell overlay columns on the annotations table. + +For a transferred column ``COL`` we append three aligned columns (null for +non-predicted proteins), leaving the curated ``COL`` untouched: + COL__pred_value (string) the transferred label + COL__pred_confidence (float32) the reliability index in [0, 1] + COL__pred_source (string) the nearest reference protein id +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import pyarrow as pa + +from protlabel import Prediction + + +def add_overlay_columns( + annotations: pa.Table, column: str, predictions: Sequence[Prediction] +) -> pa.Table: + """Append the COL__pred_* overlay columns, aligned by identifier.""" + by_query = {p.query_id: p for p in predictions} + identifiers = [str(v) for v in annotations.column("identifier").to_pylist()] + + values: list[str | None] = [] + confidences: list[float | None] = [] + sources: list[str | None] = [] + for identifier in identifiers: + pred = by_query.get(identifier) + if pred is None: + values.append(None) + confidences.append(None) + sources.append(None) + else: + values.append(pred.label) + confidences.append(float(pred.reliability)) + sources.append(pred.source_id) + + out = annotations + out = out.append_column(f"{column}__pred_value", pa.array(values, pa.string())) + out = out.append_column( + f"{column}__pred_confidence", pa.array(confidences, pa.float32()) + ) + out = out.append_column(f"{column}__pred_source", pa.array(sources, pa.string())) + return out +``` + +- [ ] **Step 4: Run to verify pass** + +Run: `uv run pytest tests/test_predictions_overlay.py -v` +Expected: PASS (all 4) + +- [ ] **Step 5: Lint + commit** + +```bash +uv run ruff check src/protspace/data/io/predictions.py tests/test_predictions_overlay.py +git add src/protspace/data/io/predictions.py tests/test_predictions_overlay.py +git commit -m "feat: build per-cell prediction overlay columns" +``` + +--- + +## Task 8: Rewrite the annotations part of a bundle (`bundle.replace_annotations_in_bundle`) + +**Files:** +- Modify: `src/protspace/data/io/bundle.py` +- Test: `tests/test_bundle_overlay.py` + +- [ ] **Step 1: Write the failing tests** + +Create `tests/test_bundle_overlay.py`: + +```python +"""Round-trip tests for replacing the annotations part of a bundle.""" + +import io + +import pyarrow as pa +import pyarrow.parquet as pq + +from protspace.data.io.bundle import ( + read_bundle, + replace_annotations_in_bundle, + write_bundle, +) + + +def _tables(): + annotations = pa.table({"identifier": ["A", "B"], "cat": ["x", "y"]}) + proj_meta = pa.table({"name": ["PCA 2"], "dims": [2]}) + proj_data = pa.table({"id": ["A", "B"], "x": [0.0, 1.0], "y": [0.0, 1.0]}) + return [annotations, proj_meta, proj_data] + + +def _read_part(part_bytes): + return pq.read_table(io.BytesIO(part_bytes)) + + +def test_replaces_annotations_keeps_other_parts(tmp_path): + src = tmp_path / "in.parquetbundle" + out = tmp_path / "out.parquetbundle" + write_bundle(_tables(), src) + + new_annotations = pa.table( + {"identifier": ["A", "B"], "cat": ["x", "y"], "cat__pred_value": [None, "z"]} + ) + replace_annotations_in_bundle(src, out, new_annotations) + + parts, settings = read_bundle(out) + assert "cat__pred_value" in _read_part(parts[0]).column_names + # Projections preserved byte-for-byte. + assert _read_part(parts[1]).column_names == ["name", "dims"] + assert _read_part(parts[2]).to_pydict()["x"] == [0.0, 1.0] + + +def test_preserves_settings_when_present(tmp_path): + src = tmp_path / "in.parquetbundle" + out = tmp_path / "out.parquetbundle" + write_bundle(_tables(), src, settings={"foo": 1}) + + new_annotations = pa.table({"identifier": ["A", "B"], "cat": ["x", "y"]}) + replace_annotations_in_bundle(src, out, new_annotations) + + _parts, settings = read_bundle(out) + assert settings == {"foo": 1} +``` + +- [ ] **Step 2: Run to verify failure** + +Run: `uv run pytest tests/test_bundle_overlay.py -v` +Expected: FAIL — `ImportError: cannot import name 'replace_annotations_in_bundle'` + +- [ ] **Step 3: Implement** + +Add to `src/protspace/data/io/bundle.py` (after `replace_settings_in_bundle`, around line 149): + +```python +def replace_annotations_in_bundle( + input_path: Path, + output_path: Path, + annotations_table: pa.Table, +) -> None: + """Replace the annotations (1st) part of a bundle, preserving the rest. + + Projection parts (2nd, 3rd) are kept byte-for-byte; an existing settings + (4th) part is carried over unchanged. + """ + with open(input_path, "rb") as f: + content = f.read() + + parts = content.split(PARQUET_BUNDLE_DELIMITER) + if len(parts) < 3 or len(parts) > 4: + raise ValueError(f"Expected 3 or 4 parts in parquetbundle, found {len(parts)}") + + buf = io.BytesIO() + pq.write_table(annotations_table, buf) + new_parts = [buf.getvalue(), parts[1], parts[2]] + if len(parts) == 4: + new_parts.append(parts[3]) + + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "wb") as f: + f.write(PARQUET_BUNDLE_DELIMITER.join(new_parts)) + + logger.info(f"Wrote bundle with updated annotations to: {output_path}") +``` + +- [ ] **Step 4: Run to verify pass** + +Run: `uv run pytest tests/test_bundle_overlay.py -v` +Expected: PASS (both) + +- [ ] **Step 5: Lint + commit** + +```bash +uv run ruff check src/protspace/data/io/bundle.py tests/test_bundle_overlay.py +git add src/protspace/data/io/bundle.py tests/test_bundle_overlay.py +git commit -m "feat: replace annotations part of a parquetbundle in place" +``` + +--- + +## Task 9: The `transfer` orchestration core + CLI (`protspace.cli.transfer`) + +**Files:** +- Create: `src/protspace/cli/transfer.py` +- Modify: `src/protspace/cli/app.py:65-73` +- Test: `tests/test_transfer_cli.py` + +- [ ] **Step 1: Write the failing tests (pure core + registration)** + +Create `tests/test_transfer_cli.py`: + +```python +"""Tests for the transfer orchestration core and CLI registration.""" + +import numpy as np +import pyarrow as pa +import pytest + +from protspace.analysis.classification import Rule +from protspace.cli.transfer import run_transfer + + +def _inputs(): + annotations = pa.table( + { + "identifier": ["TRINITY_1", "P00001", "P00002"], + "protein_category": ["", "neurotoxin", "enzyme"], + } + ) + # TRINITY_1 sits right on top of the neurotoxin reference P00001. + embeddings = { + "TRINITY_1": np.array([0.0, 0.0], dtype=np.float32), + "P00001": np.array([0.05, 0.0], dtype=np.float32), + "P00002": np.array([9.0, 0.0], dtype=np.float32), + } + return annotations, embeddings + + +def test_run_transfer_predicts_for_query_with_missing_value(): + annotations, embeddings = _inputs() + out = run_transfer( + annotations=annotations, + embeddings=embeddings, + transfer_columns=["protein_category"], + query_rule=Rule(id_prefixes=["TRINITY_"]), + reference_rule=Rule(where=[("protein_category", "")]), # any non-empty ref + k=1, + metric="euclidean", + ) + by_id = {r["identifier"]: r for r in out.to_pylist()} + assert by_id["TRINITY_1"]["protein_category__pred_value"] == "neurotoxin" + assert by_id["TRINITY_1"]["protein_category__pred_source"] == "P00001" + assert by_id["TRINITY_1"]["protein_category__pred_confidence"] > 0.9 + + +def test_run_transfer_skips_proteins_without_embeddings(): + annotations, embeddings = _inputs() + embeddings.pop("TRINITY_1") # no embedding -> cannot be a query + with pytest.raises(ValueError, match="no query"): + run_transfer( + annotations=annotations, + embeddings=embeddings, + transfer_columns=["protein_category"], + query_rule=Rule(id_prefixes=["TRINITY_"]), + reference_rule=Rule(id_prefixes=["P0"]), + k=1, + metric="euclidean", + ) + + +def test_transfer_command_is_registered(): + from typer.testing import CliRunner + + from protspace.cli.app import app + + result = CliRunner().invoke(app, ["transfer", "--help"]) + assert result.exit_code == 0 + assert "transfer" in result.output.lower() +``` + +- [ ] **Step 2: Run to verify failure** + +Run: `uv run pytest tests/test_transfer_cli.py -v` +Expected: FAIL — `ModuleNotFoundError: No module named 'protspace.cli.transfer'` + +- [ ] **Step 3: Implement the core + command** + +Create `src/protspace/cli/transfer.py`: + +```python +"""protspace transfer — fill missing annotation values from nearest references. + +Embedding Annotation Transfer (EAT): for each query protein with a missing +value in a target column, transfer the value of its nearest annotated +reference in pLM embedding space, with a reliability-index confidence. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Annotated + +import numpy as np +import pyarrow as pa +import typer + +from protspace.cli.app import app, setup_logging +from protspace.cli.common_options import Opt_Verbose + +logger = logging.getLogger(__name__) + + +def _is_missing(value) -> bool: + return value is None or str(value).strip() == "" + + +def run_transfer( + *, + annotations: pa.Table, + embeddings: dict[str, np.ndarray], + transfer_columns: list[str], + query_rule, + reference_rule, + k: int = 1, + metric: str = "euclidean", +) -> pa.Table: + """Pure core: classify, transfer per column, return the augmented table. + + ``embeddings`` maps protein id -> 1-D float32 vector. Proteins without an + embedding cannot act as queries or references. + """ + from protlabel import eat + + from protspace.analysis.classification import classify + from protspace.data.io.predictions import add_overlay_columns + + # Restrict classification to proteins that actually have an embedding. + has_emb = pa.array( + [str(v) in embeddings for v in annotations.column("identifier").to_pylist()] + ) + embedded = annotations.filter(has_emb) + + query_idx, ref_idx = classify(embedded, query_rule, reference_rule) + rows = embedded.to_pylist() + + out = annotations + for column in transfer_columns: + if column not in annotations.column_names: + raise KeyError(f"Transfer column {column!r} not in annotations table") + + # References: classified refs that HAVE a value in this column. + ref_ids, ref_labels, ref_vecs = [], [], [] + for i in ref_idx: + value = rows[i].get(column) + if not _is_missing(value): + rid = str(rows[i]["identifier"]) + ref_ids.append(rid) + ref_labels.append(str(value)) + ref_vecs.append(embeddings[rid]) + if not ref_ids: + logger.warning("No references with a value for %r; skipping", column) + continue + + # Queries: classified queries MISSING a value in this column. + q_ids, q_vecs = [], [] + for i in query_idx: + if _is_missing(rows[i].get(column)): + qid = str(rows[i]["identifier"]) + q_ids.append(qid) + q_vecs.append(embeddings[qid]) + if not q_ids: + logger.warning("No queries missing %r; nothing to transfer", column) + continue + + preds = eat( + np.vstack(q_vecs), + q_ids, + np.vstack(ref_vecs), + ref_ids, + ref_labels, + k=k, + metric=metric, + ) + out = add_overlay_columns(out, column, preds) + logger.info("Transferred %r to %d quer(ies)", column, len(preds)) + + return out + + +@app.command() +def transfer( + bundle: Annotated[ + Path, + typer.Option("-b", "--bundle", help="Input .parquetbundle to annotate."), + ], + embeddings: Annotated[ + str, + typer.Option( + "-e", + "--embeddings", + help="HDF5 embeddings, optional :name suffix (e.g. emb.h5:prot_t5).", + ), + ], + transfer_columns: Annotated[ + list[str], + typer.Option("-t", "--transfer", help="Annotation column to transfer (repeat)."), + ], + output: Annotated[ + Path, + typer.Option("-o", "--output", help="Output .parquetbundle path."), + ], + query_id_prefix: Annotated[list[str], typer.Option("--query-id-prefix")] = None, + query_where: Annotated[list[str], typer.Option("--query-where", help="col~substr")] = None, + reference_id_prefix: Annotated[list[str], typer.Option("--reference-id-prefix")] = None, + reference_where: Annotated[list[str], typer.Option("--reference-where", help="col~substr")] = None, + k: Annotated[int, typer.Option("--k", help="Neighbours considered (default 1).")] = 1, + metric: Annotated[str, typer.Option("--metric", help="euclidean | cosine.")] = "euclidean", + verbose: Opt_Verbose = 0, +) -> None: + """Transfer annotations to query proteins from nearest reference neighbours.""" + setup_logging(verbose) + + import io + + import pyarrow.parquet as pq + + from protspace.analysis.classification import Rule + from protspace.data.io.bundle import read_bundle, replace_annotations_in_bundle + from protspace.data.loaders import load_h5 + + def _parse_where(items: list[str] | None) -> list[tuple[str, str]]: + clauses = [] + for item in items or []: + if "~" not in item: + raise typer.BadParameter(f"--*-where must be col~substr, got {item!r}") + col, sub = item.split("~", 1) + clauses.append((col, sub)) + return clauses + + query_rule = Rule(id_prefixes=query_id_prefix or [], where=_parse_where(query_where)) + reference_rule = Rule( + id_prefixes=reference_id_prefix or [], where=_parse_where(reference_where) + ) + + # Load embeddings (name override after ':'). + h5_spec = embeddings.split(":", 1) + h5_path = Path(h5_spec[0]) + name_override = h5_spec[1] if len(h5_spec) == 2 else None + emb_set = load_h5([h5_path], name_override=name_override) + emb_map = { + header: emb_set.data[i] for i, header in enumerate(emb_set.headers) + } + + # Read the annotations part of the bundle. + parts, _settings = read_bundle(bundle) + annotations = pq.read_table(io.BytesIO(parts[0])) + + augmented = run_transfer( + annotations=annotations, + embeddings=emb_map, + transfer_columns=transfer_columns, + query_rule=query_rule, + reference_rule=reference_rule, + k=k, + metric=metric, + ) + + replace_annotations_in_bundle(bundle, output, augmented) + logger.info("Wrote transferred bundle to %s", output) +``` + +- [ ] **Step 4: Register the subcommand** + +Edit `src/protspace/cli/app.py`. In `_register_commands()` (lines 65-73), add `transfer` to the import list (keep alphabetical): + +```python + from protspace.cli import ( # noqa: F401 + annotate, + bundle, + embed, + prepare, + project, + serve, + style, + transfer, + ) +``` + +- [ ] **Step 5: Run to verify pass** + +Run: `uv run pytest tests/test_transfer_cli.py -v` +Expected: PASS (all 3) + +- [ ] **Step 6: Run the full suite (fast) to check for regressions** + +Run: `uv run pytest tests/ -m "not slow" -q` +Expected: all pass (existing + new) + +- [ ] **Step 7: Lint + commit** + +```bash +uv run ruff check src/protspace/cli/transfer.py src/protspace/cli/app.py tests/test_transfer_cli.py +git add src/protspace/cli/transfer.py src/protspace/cli/app.py tests/test_transfer_cli.py +git commit -m "feat: add 'protspace transfer' annotation-transfer subcommand" +``` + +--- + +## Task 10: End-to-end smoke test through a real bundle round-trip + +**Files:** +- Test: `tests/test_transfer_cli.py` (append) + +- [ ] **Step 1: Write the failing end-to-end test** + +Append to `tests/test_transfer_cli.py`: + +```python +def test_end_to_end_bundle_roundtrip(tmp_path): + """Build a tiny bundle + h5, run the CLI, read the overlay back.""" + import io + + import h5py + import pyarrow.parquet as pq + from typer.testing import CliRunner + + from protspace.cli.app import app + from protspace.data.io.bundle import read_bundle, write_bundle + + annotations = pa.table( + {"identifier": ["TRINITY_1", "P00001"], "protein_category": ["", "neurotoxin"]} + ) + proj_meta = pa.table({"name": ["PCA 2"], "dims": [2]}) + proj_data = pa.table({"id": ["TRINITY_1", "P00001"], "x": [0.0, 9.0], "y": [0.0, 0.0]}) + bundle_path = tmp_path / "in.parquetbundle" + write_bundle([annotations, proj_meta, proj_data], bundle_path) + + h5_path = tmp_path / "emb.h5" + with h5py.File(h5_path, "w") as f: + f.create_dataset("TRINITY_1", data=np.array([0.0, 0.0], dtype=np.float32)) + f.create_dataset("P00001", data=np.array([0.1, 0.0], dtype=np.float32)) + + out_path = tmp_path / "out.parquetbundle" + result = CliRunner().invoke( + app, + [ + "transfer", + "-b", str(bundle_path), + "-e", str(h5_path), + "-t", "protein_category", + "-o", str(out_path), + "--query-id-prefix", "TRINITY_", + "--reference-id-prefix", "P0", + ], + ) + assert result.exit_code == 0, result.output + parts, _ = read_bundle(out_path) + rows = {r["identifier"]: r for r in pq.read_table(io.BytesIO(parts[0])).to_pylist()} + assert rows["TRINITY_1"]["protein_category__pred_value"] == "neurotoxin" + assert rows["TRINITY_1"]["protein_category__pred_source"] == "P00001" +``` + +- [ ] **Step 2: Run to verify it passes (implementation already exists)** + +Run: `uv run pytest tests/test_transfer_cli.py::test_end_to_end_bundle_roundtrip -v` +Expected: PASS. If it fails, fix `cli/transfer.py` until it passes (do not edit the test). + +- [ ] **Step 3: Commit** + +```bash +git add tests/test_transfer_cli.py +git commit -m "test: end-to-end protspace transfer bundle round-trip" +``` + +--- + +## Task 11: Documentation + notebook (required before final commit) + +**Files:** +- Modify: `docs/cli.md`, `docs/annotations.md`, `../CLAUDE.md` +- Create: `notebooks/ProtSpace_Transfer.ipynb` + +- [ ] **Step 1: Document the subcommand in `docs/cli.md`** + +Add a new section (match the heading style of the existing `protspace project` section): + +```markdown +### protspace transfer + +Fill in missing annotation values for query proteins by **Embedding Annotation +Transfer (EAT)** — each query's missing value is transferred from its nearest +annotated reference in pLM embedding space, with a goPredSim reliability-index +confidence (`0.5 / (0.5 + distance)` for Euclidean). + +```bash +protspace transfer \ + -b results.parquetbundle \ + -e embeddings.h5:prot_t5 \ + -t protein_category \ + -o results.parquetbundle \ + --query-id-prefix TRINITY_ \ + --reference-where 'protein_category~neurotoxin' +``` + +Default metric is Euclidean (canonical EAT); `--metric cosine` and `--k N` are +available. Writes `protein_category__pred_value`, `__pred_confidence`, and +`__pred_source` columns into the bundle's annotations table. Distances are +computed in the original embedding space (HDF5), not in the 2-D/3-D projection. + +References: Littmann et al., *Sci Rep* 2021 (DOI 10.1038/s41598-020-80786-0); +Heinzinger et al., *NAR Genom Bioinform* 2022 (DOI 10.1093/nargab/lqac043). +``` + +- [ ] **Step 2: Document the overlay columns in `docs/annotations.md`** + +Add a short subsection documenting the `__pred_value` / `__pred_confidence` / `__pred_source` convention so the `protspace_web` annotation registry can stay aligned: + +```markdown +## Predicted-by-transfer overlay columns + +`protspace transfer` appends three columns per transferred annotation `COL`, +populated only for proteins whose `COL` value was predicted (null otherwise): + +| Column | Type | Meaning | +|--------|------|---------| +| `COL__pred_value` | string | the transferred label | +| `COL__pred_confidence` | float | reliability index in [0, 1] | +| `COL__pred_source` | string | nearest reference protein id | + +The curated `COL` is left untouched. A protein is "predicted" for `COL` when +`COL` is empty but `COL__pred_value` is present. +``` + +- [ ] **Step 3: Update the CLI table in `../CLAUDE.md`** + +In the `## CLI Commands` table (the `protspace/CLAUDE.md` one), add a row: + +```markdown +| `protspace transfer` | Fill missing annotations from nearest reference embeddings (EAT) | +``` + +- [ ] **Step 4: Create the Colab notebook** + +Create `notebooks/ProtSpace_Transfer.ipynb` — a minimal notebook (use `uv run jupytext` or write JSON directly) with: (1) a markdown intro to EAT and `protspace transfer`, (2) a cell installing protspace, (3) a cell running the example command on a public dataset, (4) a cell reading the `__pred_*` columns back with pandas. Keep it runnable end-to-end. + +Run to validate it parses: `uv run python -c "import json,nbformat; nbformat.read(open('notebooks/ProtSpace_Transfer.ipynb'), as_version=4)"` +Expected: no error. + +- [ ] **Step 5: Final lint of the whole change** + +Run: `uv run ruff check src/ tests/` +Expected: no errors. + +- [ ] **Step 6: Commit docs** + +```bash +git add docs/cli.md docs/annotations.md ../CLAUDE.md notebooks/ProtSpace_Transfer.ipynb +git commit -m "docs: document protspace transfer + prediction overlay columns" +``` + +--- + +## Task 12: Full verification + open a PR + +- [ ] **Step 1: Run the complete fast suite** + +Run: `uv run pytest tests/ -m "not slow" -q` +Expected: all pass. + +- [ ] **Step 2: Confirm the protlabel boundary is still clean** + +Run: `! grep -rqE "import protspace|from protspace" src/protlabel/ && echo "boundary clean"` +Expected: `boundary clean` + +- [ ] **Step 3: Confirm the command is wired** + +Run: `uv run protspace transfer --help` +Expected: help text with `--bundle`, `--embeddings`, `--transfer`, `--metric`, `--k`. + +- [ ] **Step 4: Push the branch and open a PR** + +```bash +git push -u origin +gh pr create --title "feat: protlabel EAT engine + protspace transfer subcommand" \ + --body "Implements the backend of docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md (closes #54 backend scope). protlabel = embedding annotation-transfer engine; protspace transfer = CLI that writes a per-cell prediction overlay into the bundle. Frontend rendering is a separate PR." +``` + +--- + +## Self-Review (completed during planning) + +**1. Spec coverage:** +- protlabel engine (spec §4) → Tasks 1–5. ✓ +- Euclidean default + RI formula (spec §3) → Tasks 2, 4. ✓ +- Embedding-space distances, not DR (spec §3) → Task 9 (`run_transfer` reads the HDF5, never projections). ✓ +- Query/reference classifier, no hardcoded biology (spec §5 step 2) → Task 6. ✓ +- Per-cell overlay representation (spec §6.2, user decision) → Tasks 7, 9. ✓ +- Rebuildable sidecar lookup, not in the bundle (spec §6.1) → Task 5 (`Lookup.save/load`). ✓ +- Brute-force default, no ANN (spec §7) → Task 3. ✓ +- One default output table; gating/mining/report opt-in/out-of-scope (spec §5, §13 Q4) → handled by scoping; noted out of scope. ✓ +- Docs + notebook (spec §12) → Task 11. ✓ +- **Deferred to follow-up plans (intentional):** frontend rendering (spec §9), optional gating/mining/report, faiss-cpu, ProtTucker. Noted in the header. + +**2. Placeholder scan:** Every code step contains complete code; commands have expected output; no "TBD"/"handle edge cases". Task 11 Step 4 (notebook) describes cell contents rather than embedding full notebook JSON — acceptable because the artifact is a notebook, not source code, and the validation command is concrete. + +**3. Type consistency:** `Prediction(query_id, label, source_id, distance, reliability, k, metric)` is defined in Task 4 and used identically in Tasks 5, 7, 9. `Rule(id_prefixes, where)` defined in Task 6, used in Tasks 9, 10. `nearest()->(idx, dist)`, `eat(...)->list[Prediction]`, `add_overlay_columns(table, column, predictions)->Table`, `replace_annotations_in_bundle(input, output, table)`, `run_transfer(...)->Table` — signatures consistent across tasks. Overlay column names (`__pred_value/__pred_confidence/__pred_source`) identical in Tasks 7, 9, 10, 11. ✓ diff --git a/docs/superpowers/specs/2026-05-27-neighbors-subcommand-design.md b/docs/superpowers/specs/2026-05-27-neighbors-subcommand-design.md new file mode 100644 index 00000000..3ed151d4 --- /dev/null +++ b/docs/superpowers/specs/2026-05-27-neighbors-subcommand-design.md @@ -0,0 +1,14 @@ +# Design: `protspace neighbors` — reproducible proximity mining + +> **⚠️ Superseded (2026-06-11).** This early draft scoped a single `protspace neighbors` +> subcommand and defaulted to cosine distance. It has been replaced by +> [`2026-06-11-eat-annotation-transfer-design.md`](./2026-06-11-eat-annotation-transfer-design.md), +> which: +> - splits the work into a **`protlabel`** engine (the EAT lookup, per GitHub issue #54) + +> a thin **`protspace transfer`** subcommand; +> - flips the default metric to **Euclidean** (canonical EAT) with cosine opt-in; +> - adopts the goPredSim **reliability index** as the confidence column; +> - specifies **storage** (reference lookup as a rebuildable sidecar, prediction overlay in the bundle), +> **compute feasibility**, and the **frontend representation** (extending PR #272). +> +> Kept for history only. Read the 2026-06-11 spec instead. diff --git a/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md b/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md new file mode 100644 index 00000000..610b3bf7 --- /dev/null +++ b/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md @@ -0,0 +1,337 @@ +# Design: Embedding Annotation Transfer (`protlabel` engine + `protspace transfer` subcommand) + +**Status:** Draft for review +**Date:** 2026-06-11 +**Supersedes:** `2026-05-27-neighbors-subcommand-design.md` (the earlier "neighbors-subcommand" draft — this expands its scope, reconciles it with GitHub issue #54 and frontend PR #272, and corrects two defaults). +**Trigger:** Conference feedback (`Conference_feedback/ProtSpaceExtractor_v1.7.4_mod 1.py`) + GitHub issue [#54 "EAT — Embedding Annotation Transfer (protlabel lookup table)"](https://github.com/tsenoner/protspace/issues/54) + frontend PR [protspace_web #272 "mark predictions and surface per-annotation docs"](https://github.com/tsenoner/protspace_web/pull/272). +**Research backing:** Literature + codebase fan-out (8 agents) with adversarial verification of the storage/compute math and the EAT algorithm against primary sources. Citations in §15. + +--- + +## 1. One-paragraph decision + +Build **`protlabel`** as a small standalone uv workspace member — the **Embedding Annotation Transfer (EAT) engine**: nearest-neighbour label transfer in *true pLM embedding space* with a calibrated reliability index. **`protspace`** consumes it through a thin **`protspace transfer`** subcommand that reads a `.parquetbundle` (+ the source HDF5 embeddings), classifies query vs reference proteins, calls `protlabel`, and writes a small **prediction overlay** back into the bundle. The big artifact (the reference embedding matrix / lookup index) is **never** shipped in the bundle — it is a rebuildable **sidecar/cache file**. The tiny artifact (per-protein predicted value + confidence + source neighbour) ships in the bundle and the **frontend renders it as a new "predicted-by-transfer" visual layer** (hollow markers, confidence-driven opacity, provenance tooltip) that is *orthogonal to* PR #272's existing column-level "predicted-by-model" badge. Everything beyond the one default output table is opt-in, so ProtSpace does not get overblown. + +This is the canonical Rost-lab EAT method (Littmann et al. 2021; Heinzinger et al. 2022) — exactly what issue #54 describes — packaged so the conference users' proximity-mining workflow becomes a thin, optional layer on top rather than a parallel reimplementation. + +--- + +## 2. How the three inputs relate + +| Input | What it is | Role in this design | +|---|---|---| +| **Issue #54 (EAT / `protlabel`)** | "Given references with known annotations + embeddings, transfer labels to unknowns by nearest neighbour in *embedding space*, backed by a lookup table." | **The engine.** Defines `protlabel`. This is canonical EAT. | +| **Conference `ProtSpaceExtractor` v1.7.4** | A 1.5K-LOC script doing proximity mining in *DR/projection space* (UMAP/t-SNE coords) with cross-method consensus, an EDD elbow, Venn/agreement sets, neighborhood mining, and a 25-file + HTML report. | **A screening/exploration layer.** Its genuinely-novel parts (neighborhood mining, report) become opt-in extras on `protspace transfer`; its DR-space machinery is mostly subsumed by transferring in true embedding space (see §6 keep/adapt/drop). | +| **PR #272 (frontend)** | Marks whole annotation **columns** as model-predicted (⚡ badge, source grouping, info popovers), deliberately **frontend-only, no data-format change**. | **The base the EAT frontend extends.** EAT introduces *cell-level* predictions (a value-level axis #272 explicitly deferred). The two compose; neither replaces the other. | + +The key realization tying them together: **DR-space proximity (the conference approach) is a lossy proxy for embedding-space proximity (EAT).** Non-linear DR is not isometric, so "nearest in UMAP_3" ≠ "nearest in embedding space." Issue #54 gets this right; the conference script worked around it with consensus/normalization scaffolding that becomes unnecessary once we transfer in the true space. + +--- + +## 3. The canonical method we adopt (EAT), with corrections + +Verified against primary sources (goPredSim / Littmann et al. *Sci Rep* 2021; EAT tool / Heinzinger et al. *NAR Genom Bioinform* 2022): + +- **Space:** original pLM embedding space (mean-pooled per-protein vectors). **Not** DR coordinates. +- **Metric:** **Euclidean (L2)**, default. *Nuance (verifier correction):* the strong "Euclidean beats cosine for pLM embeddings" statement is from the **2022** paper (citing prior work); the **2021** paper only found cosine "changed little." Euclidean is still the right default because it is the canonical tool default and the documented 2022 finding — but the basis is "tool convention + 2022 claim," not "both papers." Cosine stays an opt-in `--metric`. +- **No L2-normalization by default** — canonical EAT uses raw Euclidean; magnitude carries information cosine discards. (Normalization/whitening to fight hubness are research extras, off by default.) +- **k = 1** default (the value eat.py defaults to and the 2021 paper chose after grid search). `k` is exposed; a distance-threshold mode (transfer from all references within distance d) is also supported, per goPredSim. +- **Reliability index (the confidence column).** Adopt goPredSim Eq.(5) **verbatim** — do **not** invent a separate `reliability × vote` product (a verifier flagged that as non-canonical; neighbour agreement is *already* encoded in the formula): + + ``` + RI(p) = (1/k) · Σ_{i : n_i carries label p} s(d(q, n_i)) + + where s(d) = 0.5 / (0.5 + d) for Euclidean (RI = 1.0 at d=0, 0.5 at d=0.5, →0 as d→∞) + s(d) = 1 − d for cosine (clamp to [0,1]; cosine distance ∈ [0,2]) + ``` + + For the default k=1 this collapses to `RI = 0.5/(0.5 + d)`. The `(1/k)·Σ_{neighbours carrying p}` term *is* the multi-neighbour agreement weighting; report `RI` directly as the `[0,1]` confidence. +- **Distance→accuracy calibration (reference point, ProtT5/CATH):** at Euclidean distance ≤ 1.1, ~75% coverage with ~90% accuracy at CATH H-level; ProtTucker (contrastive) reaches ~76% H-level vs raw ProtT5 EAT ~64% and HMMER ~77%. **Caveat (critical):** the `0.5` constant in `s(d)` and the `1.1` threshold are **ProtT5-specific**. ProtSpace supports 12 embedders (320–2560 dim) with different distance scales — RI stays *monotone* (good for ranking) but is **not a calibrated probability** for other models without re-validation. Document this loudly. + +**Output contract (mirror eat.py for interoperability):** per query → `query_id`, transferred `label`, `source_id` (nearest reference), `source_label`, `distance`, `reliability`. Accept goPredSim's 2-column `id → comma-separated labels` lookup-label file so existing EAT/goPredSim lookups drop in. + +**Optional upgrade path (documented, not built first):** ProtTucker-style contrastive projection or CLEAN-style EC centroids as a future "learned distance" mode. Ship raw-embedding Euclidean EAT first — it needs no training and is the published baseline. + +--- + +## 4. Architecture + +``` +protspace/ # the protspace repo (also the build root) +├── pyproject.toml # hatchling builds BOTH packages; scipy added as a dep +└── src/ + ├── protlabel/ # NEW second top-level package — the EAT engine (issue #54) + │ ├── __init__.py # public API: eat(), Lookup, Prediction + │ ├── reliability.py # goPredSim distance→[0,1] reliability transform + │ ├── backends.py # brute-force (default) | faiss (optional, later) NN search + │ ├── transfer.py # kNN + label transfer + reliability index (RI) + │ └── lookup.py # build / save / load the reference lookup sidecar (.npz) + └── protspace/ + ├── cli/transfer.py # NEW thin Typer subcommand (glue only, ~150 LOC) + ├── analysis/ # NEW — classifier now; optional gating/mining later + │ └── classification.py # query/reference classifier (no hardcoded biology) + ├── data/io/bundle.py # EXTEND: replace_annotations_in_bundle() (in-place part-1 rewrite) + └── data/io/predictions.py # NEW: build the per-cell overlay columns +``` + +**Packaging decision (refines the spec):** the suite root is *not* a uv workspace, and `protspace` +publishes to PyPI — so a separate `protlabel` distribution would force its own PyPI release + CI +changes. For the MVP, `protlabel` ships as a **second top-level package inside the protspace repo** +(`src/protlabel/`), bundled into the protspace wheel via +`[tool.hatch.build.targets.wheel] packages = ["src/protspace", "src/protlabel"]`. A strict +**no-`protspace`-imports boundary** keeps it independently testable and reusable, and makes a future +promotion to a standalone PyPI package / uv workspace member a clean, mechanical split. The optional +gating/mining/report and faiss backend are deferred to follow-up work (kept out so ProtSpace stays lean). + +**The boundary (why a submodule, not just a subcommand):** + +- **`protlabel` is pure and ProtSpace-agnostic.** In: reference embeddings (`ndarray` + ids + labels) and query embeddings (`ndarray` + ids). Out: per-query nearest neighbour(s), distance, reliability, transferred label(s). It owns the **lookup table** (issue #54's core artifact): building it, serializing it to a sidecar, and querying it. Usable as a standalone `eat`-style tool and reusable by other projects (`protspace_uniprot`, notebooks). +- **`protspace transfer` is glue.** It knows about `.parquetbundle`, HDF5 loading (`load_h5`), query/reference classification, and writing the overlay back. It contains **no distance math** — that lives in `protlabel`. +- **`protspace/analysis/`** holds the *optional* conference-derived screening (gating, neighborhood mining). Default runs skip it entirely. + +This keeps `protspace` from getting overblown: the heavy/clever code is isolated in a focused library; the subcommand stays small; the extras are opt-in modules that don't load unless requested. + +--- + +## 5. Algorithm — "best of both worlds" pipeline + +Minimal viable command produces a useful, calibrated transfer table from **embedding-space 1-NN + reliability index alone**. Everything else is a flag. + +1. **Load embeddings** from the source HDF5 (`load_h5`: float16→float32 upcast, dim validation, reject per-residue). Mean-pool already done upstream (per-protein vectors). +2. **Classify** queries vs references by ID-prefix and/or metadata-substring rules (CLI flags or YAML). **No hardcoded biology** (drop the v1.7.4 `TRINITY_`/`mscr` fallback). Error clearly if no query rule matches anything. +3. **kNN in true embedding space** (`protlabel.transfer`). Default metric **Euclidean**; `--metric {euclidean,cosine}`. Brute-force chunked search (default); faiss backend if installed. Take the `k` nearest *references* per query (default k=1). +4. **Transfer label + reliability** via Eq.(5) above → the primary `[0,1]` confidence. This *replaces* consensus/EDD as the headline number. +5. *(opt-in)* **Confidence gate** for batch triage: `--cutoff {fixed,reliability,percentile,edd}`. Default-if-requested = fixed distance/reliability tied to measured accuracy (EAT-style). If `edd`, compute `max(Kneedle distance-to-chord, median-jump)` **on the embedding-space distance curve**, clearly labeled a heuristic soft gate — never as calibrated confidence. +6. *(opt-in)* **High-precision subset** = `reliability ≥ threshold AND k-NN vote unanimous` (the embedding-space replacement for v1.7.4 "Overlapped"). +7. *(opt-in)* **Neighborhood mining** (`--mine`/`--top-n`): top-N nearest items around each reference/confident query, with recurrence counts and a non-redundant pooled panel. Pure exploration, decoupled from confidence. +8. **Output** one tidy `predictions.parquet` by default (§7.2). Extras opt-in: `--report` (Jinja2 HTML), `--plots`, `--full-tables` (reproduce the v1.7.4 25-file layout for the conference users). + +### Keep / adapt / drop the conference ideas (verified) + +| v1.7.4 idea | Verdict | Why | +|---|---|---| +| Rank-percentile normalization | **DROP** as core / keep as optional descriptor | Exists only to make incomparable DR scales (PCA vs UMAP vs t-SNE) summable. One embedding-space metric + RI makes it unnecessary. | +| Cross-DR-method consensus (0–6) | **DROP** | The 6 projections are deterministic lossy shadows of the *same* embedding; their agreement measures DR stability, not biological confidence — zero independent evidence once you transfer in the source space. (At most an opt-in "projection agreement" QC diagnostic, never labeled "confidence".) | +| EDD elbow (Kneedle ∨ median-jump) | **DEMOTE** to optional soft gate | Adaptive and parameter-free, but statistically uncalibrated (curve-shape dependent, Kneedle is noise-sensitive) — unlike EAT's accuracy-tied 1.1. Recompute in embedding space; offer as one `--cutoff` option, not the default. | +| N-way "Overlapped" agreement set | **ADAPT** | Redefine from "UMAP_3 ∩ TSNE_3 survivors" to "reliability ≥ t AND vote unanimous" — same high-precision-subset value, calibrated terms. | +| Top-N neighborhood mining | **KEEP** (opt-in) | The strongest survivor: metric-agnostic, genuinely useful for cluster expansion / focused re-runs. Gate behind `--mine`. | +| 25-file output + Venns/coverage maps/graphs | **DROP** as default / keep behind `--plots`/`--full-tables` | Venns specifically lose meaning once consensus is dropped. | +| One-page HTML report | **KEEP** (opt-in `--report`) | Useful sharing artifact; Jinja2 template, not `__doc__`-string injection. | + +> **Framing for the conference users:** present embedding-space transfer as a strict upgrade that *subsumes* their goals (a single calibrated confidence instead of a 6-way consensus proxy), and keep their exact workflow reproducible via `--full-tables`. Their contribution is acknowledged and preserved, not discarded. + +--- + +## 6. Storage & data representation — the user's "is it too large?" question + +Two artifacts with **very different** sizes. Treat them differently. + +### 6.1 The reference lookup (BIG) → sidecar / cache, never in the bundle + +Reference embedding matrix size (N proteins × D dims; binary units, ≈5% smaller than SI): + +| N | D=1024 (ProtT5) fp32 / fp16 | D=2560 (ESM2-3B) fp32 / fp16 | +|---|---|---| +| 1,000 | 3.9 MiB / 2.0 MiB | 9.8 MiB / 4.9 MiB | +| 10,000 | 39 MiB / 20 MiB | 98 MiB / 49 MiB | +| 100,000 | 391 MiB / 196 MiB | 977 MiB / 489 MiB | +| **573,000 (Swiss-Prot)** | **2.19 GiB / 1.09 GiB** | **5.47 GiB / 2.73 GiB** | + +The `.parquetbundle` is a ~45 MB portable viz payload. **Embedding a 1–5.5 GiB matrix into it is the wrong call** — it would bloat every download for a feature most viz users never touch, and it is rebuildable from the source HDF5 anyway. **Decision:** + +- **`protlabel` writes the lookup as a sidecar file** next to the bundle (or in `~/.cache/protspace/`): raw fp16 `.npy`/`.h5` for small sets, a serialized faiss index for large sets. +- **Rebuildable on demand:** if the sidecar is absent/stale, regenerate from the embeddings HDF5 (brute force needs nothing; a faiss build for 573K is seconds-to-low-minutes on CPU). +- **Optional compression** (faiss IVF-PQ) shrinks the *whole* of Swiss-Prot dramatically: `m=64 → ~35 MiB` (64× at D=1024, 160× at D=2560), `m=32 → ~17 MiB`, `m=16 → ~9 MiB`. Add an exact-distance **rerank** of top candidates to recover the recall bare 8-bit PQ loses (~70% → ~90–95%). So a "store it small" option exists if a user *does* want it portable — but it is opt-in, validated, and still a sidecar. + +The user's instinct — *"have it as an optional file in tmp"* — is exactly right and is the recommended default. + +### 6.2 The prediction overlay (TINY) → ships in the bundle + +The per-protein result is small and **sparse** (only proteins that got a transferred value). Store as a **new optional parquet table `predicted_annotations`** appended after the existing parts (the bundle is delimiter-separated and length-extensible; old readers that read parts 1–4 ignore it → backward compatible). Long format, one row per (protein, predicted column): + +| column | type | notes | +|---|---|---| +| `identifier` | string | protein id | +| `annotation_column` | string | which annotation was transferred (e.g. `protein_category`) | +| `predicted_value` | string | the transferred label | +| `reliability` | float32 | the `[0,1]` confidence (RI) | +| `distance` | float32 | embedding distance to the source | +| `source_id` | string | nearest reference protein | +| `k`, `method`, `model` | small | provenance (e.g. `k=1`, `euclidean`, `prot_t5`) | + +Even for tens of thousands of predicted cells this is well under a megabyte. **Do not** store full neighbour lists per cell for 570K × many columns — keep top-1 `source_id` + `k` + `method`; fetch fuller neighbour lists lazily only if a richer hover is ever needed. + +**Representation model (chosen): per-cell overlay on the original column.** EAT *fills missing values inside an existing annotation column* and marks those cells predicted, so the scatter shows curated (filled) + transferred (hollow) points together in one colour scheme — far more useful than a separate `predicted_` column. (A separate-column model, like Biocentral, is the simpler fallback if the overlay proves too invasive.) + +> **Note on PR #272's contract:** #272 was deliberately *no-data-format-change*. EAT is precisely the feature that introduces **value-level** predicted metadata into the bundle — the axis #272 deferred. That is expected and intended; §10 keeps the two axes cleanly separated. + +--- + +## 7. Compute & feasibility verdict + +**Brute-force kNN is laptop-feasible across the entire range, including full Swiss-Prot.** Measured (Apple Silicon, chunked numpy GEMM + argpartition; reproduced by an independent verifier within ~10–25%): + +| Query batch × references × dim | wall time | +|---|---| +| 1,000 × 100K × 1024 | ~0.8–0.9 s | +| 1,000 × 573K × 1024 | ~4–4.6 s (~4 ms/query) | +| 1,000 × 573K × 2560 | ~6 s (~6 ms/query) | +| single query × 573K | ~4–6 ms | + +**The binding constraint is RAM (to hold the reference matrix), not compute.** Mitigation: load the reference as fp16 and upcast per chunk, chunk the N axis so the Q×N distance block never materializes at full size. This stays within a 16 GB laptop at D=1024 and is borderline-but-workable at D=2560. Older Intel/CI machines run ~2–5× slower but stay sub-minute for a few queries at Swiss-Prot scale. + +**Conclusion:** the entire feature is feasible and *not* compute-intensive at realistic scales. **No ANN index is needed for speed** — exact search is already ~ms/query. ANN's only justification here is *shrinking the stored reference* (PQ), which is opt-in. + +**Default:** exact brute force (numpy/scipy/sklearn — already deps) up to ~100–200K references, and still usable to full Swiss-Prot. **Optional accelerator:** `protspace[ann]` extra → **faiss-cpu** (best wheel coverage: macOS arm64+x86_64, manylinux x86_64+aarch64, Windows; pacmap 0.9.x already adopts faiss). Reject hnswlib (sdist-only, needs a compiler) and ScaNN (no macOS wheels) as cross-platform CLI deps. *(scipy is not currently an explicit `protspace` dep — sklearn pulls it transitively; add it explicitly if using `scipy.spatial.distance.cdist`, or use `sklearn.neighbors.NearestNeighbors(algorithm='brute')`.)* + +--- + +## 8. CLI design + +```bash +# Minimal: transfer one annotation, default Euclidean 1-NN, one output table +protspace transfer \ + --bundle results.parquetbundle \ + --embeddings emb.h5:prot_t5 \ + --transfer protein_category \ + --query-id-prefix TRINITY_ \ + --reference-where 'protein_category~neurotoxin' \ + --out results.parquetbundle # writes the overlay back into the bundle (or a sidecar) + +# Tuning + optional screening +protspace transfer ... --metric euclidean --k 3 \ + --cutoff reliability --min-reliability 0.6 \ + --mine --top-n 5 \ + --report --plots +``` + +| Flag | Default | Purpose | +|---|---|---| +| `--bundle` | required | the `.parquetbundle` to annotate | +| `--embeddings h5[:model]` | required | source embeddings for true-space distance | +| `--transfer COL` | required | annotation column to transfer (repeatable) | +| `--query-* / --reference-*` | required (≥1 query rule) | classification (prefix / `col~substr`); or `--rules rules.yaml` | +| `--metric {euclidean,cosine}` | `euclidean` | **reconciled from the old cosine default** | +| `--k` | `1` | neighbours considered (Eq.5) | +| `--cutoff {none,fixed,reliability,percentile,edd}` | `none` | opt-in confidence gate | +| `--mine`, `--top-n` | off | opt-in neighborhood mining | +| `--lookup PATH` | auto sidecar | reuse/persist the reference lookup | +| `--report`, `--plots`, `--full-tables`, `--excel` | off | opt-in artifacts | + +Subcommand name is an **open question** (§13): `transfer` (clear verb), `eat` (matches #54 / Rost-lab convention), or `neighbors` (old draft). + +--- + +## 9. Frontend representation (extends PR #272, does not duplicate it) + +**Two orthogonal axes — codify this mental model:** + +- **Axis A (existing, #272): column-level provenance** — "this whole column is a model output" (Biocentral / Phobius / TED). Keep `AnnotationMeta.isPredicted`, the ⚡ dropdown/legend badge, and the info-popover **unchanged**. +- **Axis B (new, EAT): cell-level provenance** — "this specific protein's value was *transferred from a neighbour*, confidence X, source Y." New visual language below. Never overload the ⚡ badge to mean both. + +### 9.1 Scatter plot — the primary cue is *shape*, not colour + +- **Observed/curated cells → filled markers** (current behaviour). **EAT-imputed cells → hollow (outline-only) markers in the same category hue**, so cluster identity is preserved while provenance reads at a glance. This is an established convention (filled = observed, open = imputed) and satisfies "never colour-only" (accessibility; ~4% CVD). + - Implementable in the existing WebGL renderer: add a per-point `a_predicted` float attribute (mirror the existing `a_shape` plumbing) and a ring-only branch reusing the current edge-distance/outline math (`strokeWidth = 0.15`, `webgl-renderer.ts`). No shader rewrite. +- **Confidence → redundant opacity (and optional size) ramp on imputed points only.** `alpha = lerp(0.25, 0.9, confidence)`; observed points stay at `baseOpacity 0.9`. Optionally scale size by `sqrt(confidence)`. For very low confidence (<0.3), desaturate toward grey (lightweight VSUP). Hooks: `getOpacity`/`getBaseOpacity`/`getPointSize` in `style-getters.ts`. + +### 9.2 Tooltip — per-point provenance line + +Extend `AnnotationBlock` + `renderAnnotationBlock` (`protein-tooltip.ts`) with an EAT row, distinct from observed values: + +> ⚡ **Predicted:** Neurotoxin (82%) — transferred from **P12345** via ProtT5, k=1 + +with an inline confidence bar and the source id as a **click target** that selects/centres that reference in the scatter. Observed values render exactly as today (no chip). + +### 9.3 Legend — a separate "Predicted (transferred)" sub-section + +When the active annotation has any imputed cells, render a small group with two swatches — **filled = "Observed"**, **hollow = "Predicted by EAT"** — and a note "Faint = low confidence", plus live counts ("1,204 shown / 380 below threshold"). Add as a new optional block in `legend-renderer.ts` (alongside `renderHeader`). **Do not** merge into the ⚡ header badge (that is Axis A). + +### 9.4 Global control — one "Predicted annotations" group near the dropdown/legend + +- **Toggle "Show predicted annotations"** (off → imputed cells render neutral/N-A; only the curated layer shows). +- **Confidence-threshold slider** 0–100% with conventional bands (High >80 / Med 50–80 / Low <50); below-threshold imputed points **fade** (`fadedOpacity 0.15`) rather than vanish, preserving layout context. +- Feed `showPredicted` + `minConfidence` into `StyleConfig`; persist in `LegendPersistedSettings` so the choice survives reload/export. Keyboard-operable with `aria-valuetext`. + +### 9.5 Data-model extension (frontend) + +Mirror the existing parallel-array pattern (`annotation_scores`, `annotation_evidence` in `types.ts`): + +```ts +// VisualizationData (optional, populated only when the bundle carries the overlay) +annotation_predicted?: Record; +// PredictedCell = { confidence: number; sourceId: string; k?: number; method?: string } +``` + +Loader (`data-loader/utils/bundle.ts`) pivots the sparse `predicted_annotations` table into these arrays at parse time. Backward compatible: old bundles lack the table → no overlay; the parser already tolerates unknown columns/parts. + +### 9.6 Frontend gotchas to respect + +- Multi-label cells: treat a cell as imputed **only if all its values were transferred**; otherwise show observed with a tooltip note. +- Selection opacity must override confidence dimming (a clicked low-confidence point stays visible). +- Grayscale/PNG export: hollow-vs-filled must be the load-bearing cue (opacity alone is ambiguous in print). The export path renders the same shader, so hollow survives export — verify at 570K points. +- This is a **separate frontend PR** (depends on the backend emitting the overlay) and warrants its own OpenSpec change in `protspace_web`, building on #272's `annotation-metadata`/`annotation-presentation` capabilities. + +--- + +## 10. Dependencies & packaging + +- **`protlabel`** uses only `numpy`, `scipy`, `h5py` (all already in the ProtSpace stack). It is a **second top-level package in the protspace repo** (`src/protlabel/`), built into the protspace wheel via `[tool.hatch.build.targets.wheel] packages = ["src/protspace", "src/protlabel"]` — *not* a suite-level uv workspace member (the suite root is not a uv workspace, and a separate distribution would need its own PyPI release/CI). It imports nothing from `protspace` (a guarded boundary), so a future standalone split is mechanical. +- Add **`scipy`** to `protspace` via `uv add 'scipy>=1.10'` (currently only transitive via sklearn; `cdist` needs it explicit). +- **faiss-cpu** is a *future* optional accelerator (`protspace[ann]`), out of MVP scope. Rejected for cross-platform CLI: hnswlib (sdist-only) and ScaNN (no macOS wheels). + +--- + +## 11. Testing + +`protlabel` (engine, fast, no ProtSpace deps): +- `test_transfer.py` — synthetic ref/query sets with known nearest neighbours (Euclidean + cosine); RI values at known distances (`d=0→1.0`, `d=0.5→0.5`); k=1 vs k>1 agreement weighting. +- `test_lookup.py` — build/serialize/load round-trip; sidecar rebuild-on-demand; brute-force vs faiss parity (when faiss installed). + +`protspace`: +- `test_transfer_cli.py` — end-to-end on `data/sizes/phosphatase.h5` after `protspace prepare`; overlay table round-trips through the bundle; old (4-part) bundles still read. +- `test_classification.py` — prefix/substring/case rules, empty-query error (no hardcoded biology). +- `test_gating.py` — fixed/reliability/percentile/EDD on known-elbow curves (incl. degenerate `n<15` fallback). + +`protspace_web` (in its own PR): unit tests for the overlay parser + `style-getters` predicted branch; browser checks for hollow markers, legend sub-section, threshold slider, grayscale export. + +--- + +## 12. Docs & notebook + +- `protspace/docs/cli.md`: new `### protspace transfer` section; cite EAT papers. +- `protspace/docs/annotations.md`: document the predicted-overlay columns so `protspace_web`'s registry stays aligned (the #272 contract note already points here). +- New `notebooks/ProtSpace_Transfer.ipynb` — a lean annotation-transfer story on a public dataset (not the full v1.7.4 reproduction). +- Update top-level `CLAUDE.md` CLI table; pre-commit checklist (ruff + docs + notebook). +- `protspace_web/docs/guide/annotations.md` generator: extend for the predicted-by-transfer legend/UX. + +--- + +## 13. Open questions / decisions for the user + +1. **Subcommand name:** `transfer` (clear verb, recommended), `eat` (matches #54 / Rost-lab brand), or `neighbors` (old draft)? +2. **Overlay vs new column:** per-cell overlay on the original column (recommended, richer viz) or a separate `predicted_` column (simpler, mirrors Biocentral)? +3. **`protlabel` scope:** EAT engine only (recommended first), or also bundle the ProtTucker/CLEAN "learned distance" mode now? +4. **Default cutoff:** ship with `--cutoff none` (transfer everything, let the user filter on reliability — recommended) or a default reliability floor? +5. **Frontend timing:** build the backend + overlay first and the frontend PR second (recommended), or design both in lockstep? +6. **Reconcile or supersede #54:** post this design as the plan on issue #54 and keep #54 as the engine tracking issue? + +## 14. Non-goals + +- No statistical FDR/hypothesis testing (RI is a heuristic ranking, ProtT5-calibrated only). +- No automatic UniProt fetch of references — references are whatever is already in the bundle. +- No legacy `output.json` support (point users to `protspace bundle`). +- No shipping ProtTucker weights initially (raw-embedding EAT needs no training). +- The frontend work is a *separate* PR; this spec only specifies the representation. + +## 15. Citations + +- Littmann, Heinzinger, Olenyi, Dallago, Rost. "Embeddings from protein language models predict conservation and ... / Embedding-based annotation transfer." *Sci Rep* 2021. DOI 10.1038/s41598-020-80786-0 — **RI formula (Eq. 5), GO results.** +- Heinzinger, Littmann, Sillitoe, Bordin, Orengo, Rost. "Contrastive learning on protein embeddings enlightens midnight zone." *NAR Genom Bioinform* 2022. DOI 10.1093/nargab/lqac043 — **generic EAT tool, Euclidean>cosine, CATH 1.1-threshold calibration, ProtTucker.** +- goPredSim — https://github.com/Rostlab/goPredSim (reliability code, 2-column label format). EAT tool — https://github.com/Rostlab/EAT (`eat.py` interface). +- CLEAN (EC, contrastive + centroids), *Science* 2023, DOI 10.1126/science.adf2465 — documented learned-distance upgrade path. +- VSUP (Correll, Moritz, Heer, CHI 2018); ScatterUQ (arXiv 2308.04588); imputed-vs-observed marker convention (mclust) — frontend uncertainty UX. + +--- + +*Verification note: an adversarial pass confirmed the RI formula, Euclidean default, k=1, the eat.py contract, and the 1.1/≈90% calibration against primary sources; it corrected the historical framing (Euclidean>cosine is the 2022 paper's claim, not the 2021 one), the CATH comparison numbers (raw ProtT5 EAT H-level ≈64, HMMER ≈77, MMseqs2 ≈35), and the confidence aggregation (use Eq.5 directly, not a `reliability×vote` product). The storage/compute math reproduced exactly (binary units; the SI labels were ~5% low) and the "sidecar not bundle / brute-force default" recommendations were judged sound.* From 70881d7b9992de29ad3b37c14ea09af48d4b060a Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 19:19:19 +0200 Subject: [PATCH 02/21] chore(protlabel): scaffold EAT engine package + scipy dep Co-Authored-By: Claude Sonnet 4.6 --- pyproject.toml | 4 ++++ src/protlabel/__init__.py | 7 +++++++ tests/test_protlabel_reliability.py | 7 +++++++ uv.lock | 2 ++ 4 files changed, 20 insertions(+) create mode 100644 src/protlabel/__init__.py create mode 100644 tests/test_protlabel_reliability.py diff --git a/pyproject.toml b/pyproject.toml index 9caf3eb9..83ef3eb3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "requests>=2.32.4", "typer>=0.24.1", "rich>=14.3.3", + "scipy>=1.10", ] [project.optional-dependencies] @@ -70,6 +71,9 @@ protspace = "protspace.cli.app:app" requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.build.targets.wheel] +packages = ["src/protspace", "src/protlabel"] + [tool.pytest.ini_options] markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')", diff --git a/src/protlabel/__init__.py b/src/protlabel/__init__.py new file mode 100644 index 00000000..c255e887 --- /dev/null +++ b/src/protlabel/__init__.py @@ -0,0 +1,7 @@ +"""protlabel — Embedding Annotation Transfer (EAT) engine. + +Nearest-neighbour label transfer in protein-language-model embedding space, +with the goPredSim reliability index. Pure numpy/scipy/h5py; no protspace imports. +""" + +__version__ = "0.1.0" diff --git a/tests/test_protlabel_reliability.py b/tests/test_protlabel_reliability.py new file mode 100644 index 00000000..6035d1a4 --- /dev/null +++ b/tests/test_protlabel_reliability.py @@ -0,0 +1,7 @@ +"""Tests for protlabel.reliability.""" + + +def test_protlabel_imports(): + import protlabel + + assert protlabel.__version__ diff --git a/uv.lock b/uv.lock index 9a2d779e..4237dccc 100644 --- a/uv.lock +++ b/uv.lock @@ -2738,6 +2738,7 @@ dependencies = [ { name = "requests" }, { name = "rich" }, { name = "scikit-learn" }, + { name = "scipy" }, { name = "tqdm" }, { name = "typer" }, { name = "umap-learn" }, @@ -2805,6 +2806,7 @@ requires-dist = [ { name = "requests", marker = "extra == 'frontend'", specifier = ">=2.32.4" }, { name = "rich", specifier = ">=14.3.3" }, { name = "scikit-learn", specifier = ">=1.6.1" }, + { name = "scipy", specifier = ">=1.10" }, { name = "tqdm", specifier = ">=4.67.1" }, { name = "typer", specifier = ">=0.24.1" }, { name = "umap-learn", specifier = ">=0.5.10" }, From ee482ba37191ad4dd9f675440ab2c442713de385 Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 19:22:55 +0200 Subject: [PATCH 03/21] feat(protlabel): goPredSim reliability index transform Co-Authored-By: Claude Sonnet 4.6 --- src/protlabel/reliability.py | 18 ++++++++++++++ tests/test_protlabel_reliability.py | 37 ++++++++++++++++++++++++++--- 2 files changed, 52 insertions(+), 3 deletions(-) create mode 100644 src/protlabel/reliability.py diff --git a/src/protlabel/reliability.py b/src/protlabel/reliability.py new file mode 100644 index 00000000..7d9a0c11 --- /dev/null +++ b/src/protlabel/reliability.py @@ -0,0 +1,18 @@ +"""goPredSim reliability index: map an embedding distance to a [0,1] confidence. + +Euclidean: s(d) = 0.5 / (0.5 + d) (1.0 at d=0, 0.5 at d=0.5, ->0 as d->inf) +Cosine: s(d) = 1 - d (clamped to [0,1]; cosine distance in [0,2]) + +Reference: Littmann et al., Sci Rep 2021 (Eq. 5); goPredSim calc_reliability_index. +""" + +from __future__ import annotations + + +def similarity(distance: float, metric: str) -> float: + """Per-neighbour distance->similarity (the goPredSim reliability transform).""" + if metric == "euclidean": + return 0.5 / (0.5 + distance) + if metric == "cosine": + return min(1.0, max(0.0, 1.0 - distance)) + raise ValueError(f"Unknown metric {metric!r}; expected 'euclidean' or 'cosine'") diff --git a/tests/test_protlabel_reliability.py b/tests/test_protlabel_reliability.py index 6035d1a4..d40bec60 100644 --- a/tests/test_protlabel_reliability.py +++ b/tests/test_protlabel_reliability.py @@ -1,7 +1,38 @@ """Tests for protlabel.reliability.""" +import math -def test_protlabel_imports(): - import protlabel +import pytest - assert protlabel.__version__ +from protlabel.reliability import similarity + + +def test_euclidean_at_zero_distance_is_one(): + assert similarity(0.0, "euclidean") == pytest.approx(1.0) + + +def test_euclidean_at_half_distance_is_half(): + assert similarity(0.5, "euclidean") == pytest.approx(0.5) + + +def test_euclidean_decreases_to_zero(): + assert similarity(1e9, "euclidean") == pytest.approx(0.0, abs=1e-6) + + +def test_cosine_is_one_minus_distance(): + assert similarity(0.2, "cosine") == pytest.approx(0.8) + + +def test_cosine_clamped_to_unit_interval(): + # cosine distance can be up to 2.0 -> 1 - d would go negative; clamp at 0 + assert similarity(1.7, "cosine") == pytest.approx(0.0) + assert similarity(-0.1, "cosine") == pytest.approx(1.0) + + +def test_unknown_metric_raises(): + with pytest.raises(ValueError): + similarity(0.5, "manhattan") + + +def test_smoke(): + assert math.isfinite(similarity(0.5, "euclidean")) From 4e99e8d6f88f9259a54bf791f55eaec845a14fef Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 19:26:48 +0200 Subject: [PATCH 04/21] feat(protlabel): chunked brute-force kNN backend Co-Authored-By: Claude Sonnet 4.6 --- src/protlabel/backends.py | 51 ++++++++++++++++++++++++++ tests/test_protlabel_backends.py | 63 ++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 src/protlabel/backends.py create mode 100644 tests/test_protlabel_backends.py diff --git a/src/protlabel/backends.py b/src/protlabel/backends.py new file mode 100644 index 00000000..c6408a89 --- /dev/null +++ b/src/protlabel/backends.py @@ -0,0 +1,51 @@ +"""Exact (brute-force) k-nearest-neighbour search over reference embeddings. + +Chunked over the query axis so the Q_chunk x N distance block stays small, +which keeps peak memory near the reference matrix itself even at Swiss-Prot +scale. scipy.cdist handles both euclidean and cosine. +""" + +from __future__ import annotations + +import numpy as np +from scipy.spatial.distance import cdist + +_METRICS = {"euclidean", "cosine"} + + +def nearest( + queries: np.ndarray, + refs: np.ndarray, + k: int, + metric: str = "euclidean", + chunk: int = 4096, +) -> tuple[np.ndarray, np.ndarray]: + """Return (idx, dist) of the k nearest *references* per query. + + idx[i] -> indices into ``refs`` of the k nearest, ascending by distance. + dist[i] -> the corresponding distances. + k is capped to the number of references. + """ + if metric not in _METRICS: + raise ValueError(f"Unknown metric {metric!r}; expected one of {_METRICS}") + + queries = np.ascontiguousarray(queries, dtype=np.float32) + refs = np.ascontiguousarray(refs, dtype=np.float32) + n_refs = refs.shape[0] + k = min(k, n_refs) + + idx_out = np.empty((queries.shape[0], k), dtype=np.int64) + dist_out = np.empty((queries.shape[0], k), dtype=np.float32) + + for start in range(0, queries.shape[0], chunk): + block = queries[start : start + chunk] + d = cdist(block, refs, metric=metric).astype(np.float32) # (b, n_refs) + part = np.argpartition(d, kth=k - 1, axis=1)[:, :k] # unsorted top-k + rows = np.arange(block.shape[0])[:, None] + part_d = d[rows, part] + order = np.argsort(part_d, axis=1) # sort the k by distance + sorted_idx = part[rows, order] + idx_out[start : start + block.shape[0]] = sorted_idx + dist_out[start : start + block.shape[0]] = d[rows, sorted_idx] + + return idx_out, dist_out diff --git a/tests/test_protlabel_backends.py b/tests/test_protlabel_backends.py new file mode 100644 index 00000000..a7bb2ac1 --- /dev/null +++ b/tests/test_protlabel_backends.py @@ -0,0 +1,63 @@ +"""Tests for protlabel.backends.nearest.""" + +import numpy as np +import pytest + +from protlabel.backends import nearest + + +def _toy(): + # 3 references on a line; queries close to ref 0 and ref 2 + refs = np.array([[0.0, 0.0], [10.0, 0.0], [20.0, 0.0]], dtype=np.float32) + queries = np.array([[0.1, 0.0], [19.5, 0.0]], dtype=np.float32) + return queries, refs + + +def test_returns_shapes(): + queries, refs = _toy() + idx, dist = nearest(queries, refs, k=2, metric="euclidean") + assert idx.shape == (2, 2) + assert dist.shape == (2, 2) + + +def test_nearest_index_euclidean(): + queries, refs = _toy() + idx, dist = nearest(queries, refs, k=1, metric="euclidean") + assert idx[0, 0] == 0 # first query nearest to ref 0 + assert idx[1, 0] == 2 # second query nearest to ref 2 + assert dist[0, 0] == pytest.approx(0.1, abs=1e-4) + + +def test_neighbours_sorted_by_distance(): + queries, refs = _toy() + idx, dist = nearest(queries, refs, k=3, metric="euclidean") + assert np.all(np.diff(dist, axis=1) >= -1e-6) # non-decreasing per row + + +def test_cosine_metric_runs_and_orders(): + refs = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32) + queries = np.array([[1.0, 0.1]], dtype=np.float32) # closest in angle to ref 0 + idx, dist = nearest(queries, refs, k=1, metric="cosine") + assert idx[0, 0] == 0 + + +def test_k_capped_to_num_refs(): + queries, refs = _toy() + idx, dist = nearest(queries, refs, k=10, metric="euclidean") + assert idx.shape == (2, 3) # only 3 refs available + + +def test_chunking_matches_unchunked(): + rng = np.random.default_rng(0) + refs = rng.standard_normal((50, 8)).astype(np.float32) + queries = rng.standard_normal((7, 8)).astype(np.float32) + a_idx, a_dist = nearest(queries, refs, k=3, metric="euclidean", chunk=1000) + b_idx, b_dist = nearest(queries, refs, k=3, metric="euclidean", chunk=3) + assert np.array_equal(a_idx, b_idx) + assert np.allclose(a_dist, b_dist, atol=1e-5) + + +def test_unknown_metric_raises(): + queries, refs = _toy() + with pytest.raises(ValueError): + nearest(queries, refs, k=1, metric="manhattan") From d494242b4350b5021211f7200fb4e7456e19550a Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 19:33:48 +0200 Subject: [PATCH 05/21] fix(protlabel): bound kNN per-chunk memory adaptively; guard k>=1 Co-Authored-By: Claude Sonnet 4.6 --- src/protlabel/backends.py | 34 ++++++++++++++++++++++++++------ tests/test_protlabel_backends.py | 18 +++++++++++++++++ 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/src/protlabel/backends.py b/src/protlabel/backends.py index c6408a89..cfedd914 100644 --- a/src/protlabel/backends.py +++ b/src/protlabel/backends.py @@ -1,8 +1,10 @@ """Exact (brute-force) k-nearest-neighbour search over reference embeddings. -Chunked over the query axis so the Q_chunk x N distance block stays small, -which keeps peak memory near the reference matrix itself even at Swiss-Prot -scale. scipy.cdist handles both euclidean and cosine. +Chunked over both the query axis and, adaptively, the reference axis so the +per-chunk float64 distance block emitted by ``cdist`` is bounded to +``max_block_bytes`` (default 256 MiB) regardless of ``n_refs``. This keeps +peak memory near the reference matrix itself even at Swiss-Prot scale +(~570 000 references). scipy.cdist handles both euclidean and cosine. """ from __future__ import annotations @@ -19,26 +21,46 @@ def nearest( k: int, metric: str = "euclidean", chunk: int = 4096, + max_block_bytes: int = 256 * 1024 * 1024, ) -> tuple[np.ndarray, np.ndarray]: """Return (idx, dist) of the k nearest *references* per query. idx[i] -> indices into ``refs`` of the k nearest, ascending by distance. dist[i] -> the corresponding distances. k is capped to the number of references. + + Memory behaviour + ---------------- + ``cdist`` internally produces a float64 block of shape + ``(query_chunk, n_refs)``. When ``n_refs`` is large (e.g. Swiss-Prot + ~570 000), even a modest ``chunk`` of 4096 yields a ~19 GiB block. + To bound this, the effective query-chunk size ``eff_chunk`` is computed + as ``min(chunk, max_block_bytes // (n_refs * 8))`` so the float64 block + stays at or below ``max_block_bytes`` (default 256 MiB) independent of + ``n_refs``. Peak memory therefore remains close to the reference matrix + itself, making the function laptop-feasible at Swiss-Prot scale. """ if metric not in _METRICS: - raise ValueError(f"Unknown metric {metric!r}; expected one of {_METRICS}") + raise ValueError(f"Unknown metric {metric!r}; expected 'euclidean' or 'cosine'") + + if k < 1: + raise ValueError("k must be >= 1") queries = np.ascontiguousarray(queries, dtype=np.float32) refs = np.ascontiguousarray(refs, dtype=np.float32) n_refs = refs.shape[0] k = min(k, n_refs) + # Adaptively shrink the query chunk so the float64 cdist block stays + # within max_block_bytes (cdist emits float64 = 8 bytes per element). + bytes_per_row = max(1, n_refs * 8) + eff_chunk = max(1, min(chunk, max_block_bytes // bytes_per_row)) + idx_out = np.empty((queries.shape[0], k), dtype=np.int64) dist_out = np.empty((queries.shape[0], k), dtype=np.float32) - for start in range(0, queries.shape[0], chunk): - block = queries[start : start + chunk] + for start in range(0, queries.shape[0], eff_chunk): + block = queries[start : start + eff_chunk] d = cdist(block, refs, metric=metric).astype(np.float32) # (b, n_refs) part = np.argpartition(d, kth=k - 1, axis=1)[:, :k] # unsorted top-k rows = np.arange(block.shape[0])[:, None] diff --git a/tests/test_protlabel_backends.py b/tests/test_protlabel_backends.py index a7bb2ac1..f1ea8732 100644 --- a/tests/test_protlabel_backends.py +++ b/tests/test_protlabel_backends.py @@ -45,6 +45,8 @@ def test_k_capped_to_num_refs(): queries, refs = _toy() idx, dist = nearest(queries, refs, k=10, metric="euclidean") assert idx.shape == (2, 3) # only 3 refs available + assert np.all(np.diff(dist, axis=1) >= -1e-6) + assert idx[0, 0] == 0 and idx[1, 0] == 2 def test_chunking_matches_unchunked(): @@ -61,3 +63,19 @@ def test_unknown_metric_raises(): queries, refs = _toy() with pytest.raises(ValueError): nearest(queries, refs, k=1, metric="manhattan") + + +def test_tiny_memory_budget_matches_default(): + rng = np.random.default_rng(1) + refs = rng.standard_normal((40, 6)).astype(np.float32) + queries = rng.standard_normal((9, 6)).astype(np.float32) + a_idx, a_dist = nearest(queries, refs, k=3, metric="euclidean") + b_idx, b_dist = nearest(queries, refs, k=3, metric="euclidean", max_block_bytes=1) + assert np.array_equal(a_idx, b_idx) + assert np.allclose(a_dist, b_dist, atol=1e-5) + + +def test_k_less_than_one_raises(): + queries, refs = _toy() + with pytest.raises(ValueError): + nearest(queries, refs, k=0, metric="euclidean") From c07aef544315d99fe34b8fa2e2b177f691ff64d9 Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 19:38:15 +0200 Subject: [PATCH 06/21] feat(protlabel): kNN label transfer with reliability index Co-Authored-By: Claude Sonnet 4.6 --- src/protlabel/transfer.py | 82 ++++++++++++++++++++++++++++++++ tests/test_protlabel_transfer.py | 79 ++++++++++++++++++++++++++++++ 2 files changed, 161 insertions(+) create mode 100644 src/protlabel/transfer.py create mode 100644 tests/test_protlabel_transfer.py diff --git a/src/protlabel/transfer.py b/src/protlabel/transfer.py new file mode 100644 index 00000000..36b44d3a --- /dev/null +++ b/src/protlabel/transfer.py @@ -0,0 +1,82 @@ +"""Embedding annotation transfer: kNN -> reliability index -> transferred label. + +Implements the goPredSim aggregation (Littmann et al. 2021, Eq. 5): + RI(p) = (1/k) * sum over neighbours carrying label p of similarity(d). +The transferred label is argmax RI(p); its source is the nearest neighbour +carrying that label. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + +from protlabel.backends import nearest +from protlabel.reliability import similarity + + +@dataclass(frozen=True) +class Prediction: + """One transferred annotation for a query protein.""" + + query_id: str + label: str + source_id: str + distance: float + reliability: float + k: int + metric: str + + +def eat( + query_emb: np.ndarray, + query_ids: list[str], + ref_emb: np.ndarray, + ref_ids: list[str], + ref_labels: list[str], + *, + k: int = 1, + metric: str = "euclidean", +) -> list[Prediction]: + """Transfer the best-guess label to each query from its k nearest references.""" + if not (len(ref_ids) == len(ref_labels) == ref_emb.shape[0]): + raise ValueError("ref_emb, ref_ids and ref_labels must have equal length") + if ref_emb.shape[0] == 0: + raise ValueError("No reference embeddings to transfer from") + if len(query_ids) != query_emb.shape[0]: + raise ValueError("query_emb and query_ids must have equal length") + + idx, dist = nearest(query_emb, ref_emb, k=k, metric=metric) + eff_k = idx.shape[1] + predictions: list[Prediction] = [] + + for qi, query_id in enumerate(query_ids): + neigh_idx = idx[qi] + neigh_dist = dist[qi] + # Accumulate RI per label and track the nearest source per label. + ri_by_label: dict[str, float] = {} + nearest_src: dict[str, tuple[float, str]] = {} + for j, ref_i in enumerate(neigh_idx): + lab = ref_labels[ref_i] + d = float(neigh_dist[j]) + ri_by_label[lab] = ri_by_label.get(lab, 0.0) + similarity(d, metric) + if lab not in nearest_src or d < nearest_src[lab][0]: + nearest_src[lab] = (d, ref_ids[ref_i]) + # Normalise by k (the goPredSim 1/k term). + best_label = max(ri_by_label, key=lambda p: ri_by_label[p]) + ri = ri_by_label[best_label] / eff_k + src_dist, src_id = nearest_src[best_label] + predictions.append( + Prediction( + query_id=query_id, + label=best_label, + source_id=src_id, + distance=src_dist, + reliability=ri, + k=eff_k, + metric=metric, + ) + ) + + return predictions diff --git a/tests/test_protlabel_transfer.py b/tests/test_protlabel_transfer.py new file mode 100644 index 00000000..5862532c --- /dev/null +++ b/tests/test_protlabel_transfer.py @@ -0,0 +1,79 @@ +"""Tests for protlabel.transfer.""" + +import numpy as np +import pytest + +from protlabel.transfer import Prediction, eat + + +def _setup(): + ref_emb = np.array([[0.0, 0.0], [10.0, 0.0], [20.0, 0.0]], dtype=np.float32) + ref_ids = ["R0", "R1", "R2"] + ref_labels = ["toxin", "enzyme", "toxin"] + query_emb = np.array([[0.0, 0.0], [19.7, 0.0]], dtype=np.float32) + query_ids = ["Q0", "Q1"] + return ref_emb, ref_ids, ref_labels, query_emb, query_ids + + +def test_k1_transfers_nearest_label_and_source(): + ref_emb, ref_ids, ref_labels, query_emb, query_ids = _setup() + preds = eat(query_emb, query_ids, ref_emb, ref_ids, ref_labels, k=1) + assert isinstance(preds[0], Prediction) + assert preds[0].query_id == "Q0" + assert preds[0].label == "toxin" + assert preds[0].source_id == "R0" + assert preds[0].reliability == pytest.approx(1.0) # distance 0 -> RI 1.0 + + +def test_k1_reliability_uses_gopredsim_transform(): + ref_emb, ref_ids, ref_labels, query_emb, query_ids = _setup() + preds = eat(query_emb, query_ids, ref_emb, ref_ids, ref_labels, k=1) + # Q1 distance to R2 is 0.3 -> RI = 0.5/(0.5+0.3) + assert preds[1].label == "toxin" + assert preds[1].source_id == "R2" + assert preds[1].reliability == pytest.approx(0.5 / 0.8, abs=1e-4) + + +def test_k3_vote_picks_majority_label(): + # Query equidistant-ish but two of three nearest are "toxin" + ref_emb = np.array( + [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]], dtype=np.float32 + ) + ref_ids = ["R0", "R1", "R2", "R3"] + ref_labels = ["toxin", "enzyme", "toxin", "toxin"] + query_emb = np.array([[1.4, 0.0]], dtype=np.float32) + preds = eat(query_emb, ["Q"], ref_emb, ref_ids, ref_labels, k=3) + assert preds[0].label == "toxin" # toxin RI sum beats lone enzyme neighbour + assert 0.0 < preds[0].reliability <= 1.0 + + +def test_cosine_metric(): + ref_emb = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32) + preds = eat( + np.array([[1.0, 0.05]], dtype=np.float32), + ["Q"], + ref_emb, + ["R0", "R1"], + ["a", "b"], + k=1, + metric="cosine", + ) + assert preds[0].label == "a" + + +def test_length_mismatch_raises(): + ref_emb, ref_ids, ref_labels, query_emb, query_ids = _setup() + with pytest.raises(ValueError): + eat(query_emb, query_ids, ref_emb, ref_ids, ref_labels[:-1], k=1) + + +def test_empty_references_raises(): + with pytest.raises(ValueError): + eat( + np.zeros((1, 2), dtype=np.float32), + ["Q"], + np.zeros((0, 2), dtype=np.float32), + [], + [], + k=1, + ) From 4b39cb8cc9f5cb8eecf4e0288c464cfed2b1c34c Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 19:42:57 +0200 Subject: [PATCH 07/21] test(protlabel): document RI tie-break and cover nearest-source selection Co-Authored-By: Claude Sonnet 4.6 --- src/protlabel/transfer.py | 3 +++ tests/test_protlabel_transfer.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/src/protlabel/transfer.py b/src/protlabel/transfer.py index 36b44d3a..59fbc893 100644 --- a/src/protlabel/transfer.py +++ b/src/protlabel/transfer.py @@ -64,6 +64,9 @@ def eat( if lab not in nearest_src or d < nearest_src[lab][0]: nearest_src[lab] = (d, ref_ids[ref_i]) # Normalise by k (the goPredSim 1/k term). + # Tie-break: max() over the insertion-ordered dict returns the first label + # seen while iterating neighbours (which are in ascending distance order), + # so for distinct distances the nearest neighbour's label wins a tie. best_label = max(ri_by_label, key=lambda p: ri_by_label[p]) ri = ri_by_label[best_label] / eff_k src_dist, src_id = nearest_src[best_label] diff --git a/tests/test_protlabel_transfer.py b/tests/test_protlabel_transfer.py index 5862532c..417f5218 100644 --- a/tests/test_protlabel_transfer.py +++ b/tests/test_protlabel_transfer.py @@ -77,3 +77,15 @@ def test_empty_references_raises(): [], k=1, ) + + +def test_source_is_nearest_neighbour_with_winning_label(): + # Two neighbours share the winning label at distinct distances; the source + # must be the closer one. + ref_emb = np.array([[0.0, 0.0], [0.5, 0.0], [10.0, 0.0]], dtype=np.float32) + ref_ids = ["R_far", "R_near", "R_other"] + ref_labels = ["toxin", "toxin", "enzyme"] + query_emb = np.array([[0.4, 0.0]], dtype=np.float32) # R_near at 0.1, R_far at 0.4 + preds = eat(query_emb, ["Q"], ref_emb, ref_ids, ref_labels, k=2) + assert preds[0].label == "toxin" + assert preds[0].source_id == "R_near" From 796e5b1f51c485bc16f087ef0a5bc39e01024522 Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 19:44:35 +0200 Subject: [PATCH 08/21] feat(protlabel): persistable Lookup sidecar + public API Co-Authored-By: Claude Sonnet 4.6 --- src/protlabel/__init__.py | 5 +++ src/protlabel/lookup.py | 64 ++++++++++++++++++++++++++++++++++ tests/test_protlabel_lookup.py | 41 ++++++++++++++++++++++ 3 files changed, 110 insertions(+) create mode 100644 src/protlabel/lookup.py create mode 100644 tests/test_protlabel_lookup.py diff --git a/src/protlabel/__init__.py b/src/protlabel/__init__.py index c255e887..7650ecbb 100644 --- a/src/protlabel/__init__.py +++ b/src/protlabel/__init__.py @@ -4,4 +4,9 @@ with the goPredSim reliability index. Pure numpy/scipy/h5py; no protspace imports. """ +from protlabel.lookup import Lookup +from protlabel.transfer import Prediction, eat + __version__ = "0.1.0" + +__all__ = ["Lookup", "Prediction", "eat", "__version__"] diff --git a/src/protlabel/lookup.py b/src/protlabel/lookup.py new file mode 100644 index 00000000..5220b3b1 --- /dev/null +++ b/src/protlabel/lookup.py @@ -0,0 +1,64 @@ +"""A persistable reference lookup: embeddings + ids + labels, plus query(). + +Serialized as a single .npz sidecar so it can live next to a bundle or in a +cache dir and be rebuilt on demand from the source HDF5. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path + +import numpy as np + +from protlabel.transfer import Prediction, eat + + +@dataclass +class Lookup: + """Reference set for embedding annotation transfer.""" + + embeddings: np.ndarray # (N, D) float32 + ids: list[str] + labels: list[str] + metric: str = "euclidean" + model: str = field(default="") + + def query( + self, query_emb: np.ndarray, query_ids: list[str], *, k: int = 1 + ) -> list[Prediction]: + """Transfer labels to query embeddings from this lookup.""" + return eat( + query_emb, + query_ids, + self.embeddings, + self.ids, + self.labels, + k=k, + metric=self.metric, + ) + + def save(self, path: Path) -> None: + """Serialize to a .npz sidecar.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + np.savez( + path, + embeddings=self.embeddings.astype(np.float16), + ids=np.array(self.ids, dtype=object), + labels=np.array(self.labels, dtype=object), + metric=self.metric, + model=self.model, + ) + + @classmethod + def load(cls, path: Path) -> Lookup: + """Load a .npz sidecar (re-upcasts embeddings to float32).""" + with np.load(path, allow_pickle=True) as data: + return cls( + embeddings=data["embeddings"].astype(np.float32), + ids=list(data["ids"]), + labels=list(data["labels"]), + metric=str(data["metric"]), + model=str(data["model"]), + ) diff --git a/tests/test_protlabel_lookup.py b/tests/test_protlabel_lookup.py new file mode 100644 index 00000000..b84b6066 --- /dev/null +++ b/tests/test_protlabel_lookup.py @@ -0,0 +1,41 @@ +"""Tests for protlabel.lookup.Lookup (the rebuildable sidecar).""" + +import numpy as np + +from protlabel import Lookup, Prediction + + +def _lookup(): + emb = np.array([[0.0, 0.0], [10.0, 0.0]], dtype=np.float32) + return Lookup(embeddings=emb, ids=["R0", "R1"], labels=["a", "b"]) + + +def test_query_returns_predictions(): + lk = _lookup() + preds = lk.query(np.array([[0.2, 0.0]], dtype=np.float32), ["Q0"], k=1) + assert isinstance(preds[0], Prediction) + assert preds[0].label == "a" + assert preds[0].source_id == "R0" + + +def test_save_load_roundtrip(tmp_path): + lk = _lookup() + path = tmp_path / "lookup.npz" + lk.save(path) + assert path.exists() + loaded = Lookup.load(path) + assert loaded.ids == lk.ids + assert loaded.labels == lk.labels + assert loaded.metric == lk.metric + assert np.allclose(loaded.embeddings, lk.embeddings) + + +def test_loaded_lookup_queries_identically(tmp_path): + lk = _lookup() + q = np.array([[9.8, 0.0]], dtype=np.float32) + before = lk.query(q, ["Q"], k=1) + path = tmp_path / "lk.npz" + lk.save(path) + after = Lookup.load(path).query(q, ["Q"], k=1) + assert before[0].label == after[0].label == "b" + assert before[0].reliability == after[0].reliability From ae7fcc23011e69a40dc81ef2aec4a1093ce66549 Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 19:49:47 +0200 Subject: [PATCH 09/21] feat: query/reference classifier for annotation transfer Co-Authored-By: Claude Sonnet 4.6 --- src/protspace/analysis/__init__.py | 1 + src/protspace/analysis/classification.py | 66 ++++++++++++++++++++++++ tests/test_classification.py | 54 +++++++++++++++++++ 3 files changed, 121 insertions(+) create mode 100644 src/protspace/analysis/__init__.py create mode 100644 src/protspace/analysis/classification.py create mode 100644 tests/test_classification.py diff --git a/src/protspace/analysis/__init__.py b/src/protspace/analysis/__init__.py new file mode 100644 index 00000000..137457c7 --- /dev/null +++ b/src/protspace/analysis/__init__.py @@ -0,0 +1 @@ +"""Optional analysis layer for ProtSpace (classification, gating, mining).""" diff --git a/src/protspace/analysis/classification.py b/src/protspace/analysis/classification.py new file mode 100644 index 00000000..6a8111d2 --- /dev/null +++ b/src/protspace/analysis/classification.py @@ -0,0 +1,66 @@ +"""Classify proteins as transfer queries vs annotated references. + +Rules match by ID prefix and/or a case-insensitive metadata substring +(``column ~ substring``). No biology is hardcoded; a query rule that matches +nothing is an error. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +import pyarrow as pa + + +@dataclass +class Rule: + """A classification rule. A protein matches if ANY clause matches.""" + + id_prefixes: list[str] = field(default_factory=list) + where: list[tuple[str, str]] = field(default_factory=list) # (column, substring) + + +def _matches(rule: Rule, identifier: str, row: dict[str, str]) -> bool: + if any(identifier.startswith(p) for p in rule.id_prefixes): + return True + for column, substring in rule.where: + if column not in row: + raise KeyError(f"Classification column {column!r} not in annotations") + value = row[column] + if value is not None and substring.lower() in str(value).lower(): + return True + return False + + +def classify( + annotations: pa.Table, query_rule: Rule, reference_rule: Rule +) -> tuple[list[int], list[int]]: + """Return (query_indices, reference_indices) into the annotations table. + + Query classification takes precedence: a protein matching both rules is a + query. Raises ValueError if the query rule matches nothing. + """ + columns = set(annotations.column_names) + # Validate where-columns up front so an empty table still raises KeyError. + for rule in (query_rule, reference_rule): + for column, _ in rule.where: + if column not in columns: + raise KeyError(f"Classification column {column!r} not in annotations") + + rows = annotations.to_pylist() + identifiers = [str(r["identifier"]) for r in rows] + + query_indices: list[int] = [] + reference_indices: list[int] = [] + for i, (identifier, row) in enumerate(zip(identifiers, rows, strict=True)): + if _matches(query_rule, identifier, row): + query_indices.append(i) + elif _matches(reference_rule, identifier, row): + reference_indices.append(i) + + if not query_indices: + raise ValueError( + "Classifier matched no query proteins; check --query-id-prefix / " + "--query-where rules." + ) + return query_indices, reference_indices diff --git a/tests/test_classification.py b/tests/test_classification.py new file mode 100644 index 00000000..98cabd17 --- /dev/null +++ b/tests/test_classification.py @@ -0,0 +1,54 @@ +"""Tests for the query/reference classifier.""" + +import pyarrow as pa +import pytest + +from protspace.analysis.classification import Rule, classify + + +def _table(): + return pa.table( + { + "identifier": ["TRINITY_1", "TRINITY_2", "P00001", "P00002"], + "protein_category": ["mSCR", "mSCR", "neurotoxin", "enzyme"], + } + ) + + +def test_prefix_rule_selects_queries(): + q = Rule(id_prefixes=["TRINITY_"]) + r = Rule(where=[("protein_category", "neurotoxin")]) + qi, ri = classify(_table(), q, r) + assert qi == [0, 1] + assert ri == [2] + + +def test_where_substring_is_case_insensitive(): + q = Rule(where=[("protein_category", "MSCR")]) + r = Rule(id_prefixes=["P0"]) + qi, ri = classify(_table(), q, r) + assert qi == [0, 1] + assert ri == [2, 3] + + +def test_query_takes_precedence_over_reference(): + # A protein matching both rules is classified as a query, never a reference. + q = Rule(id_prefixes=["P00001"]) + r = Rule(where=[("protein_category", "neurotoxin")]) + qi, ri = classify(_table(), q, r) + assert 2 in qi + assert 2 not in ri + + +def test_empty_query_match_raises(): + q = Rule(id_prefixes=["NOPE_"]) + r = Rule(id_prefixes=["P0"]) + with pytest.raises(ValueError, match="no query"): + classify(_table(), q, r) + + +def test_missing_where_column_raises(): + q = Rule(where=[("not_a_column", "x")]) + r = Rule(id_prefixes=["P0"]) + with pytest.raises(KeyError): + classify(_table(), q, r) From bc8837e21ab882c3b26df3debe2c8fa52dc993ba Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 19:54:36 +0200 Subject: [PATCH 10/21] test: cover neither-match exclusion and multi-prefix OR in classifier --- tests/test_classification.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_classification.py b/tests/test_classification.py index 98cabd17..ecaba1a0 100644 --- a/tests/test_classification.py +++ b/tests/test_classification.py @@ -52,3 +52,20 @@ def test_missing_where_column_raises(): r = Rule(id_prefixes=["P0"]) with pytest.raises(KeyError): classify(_table(), q, r) + + +def test_protein_matching_neither_rule_is_excluded(): + # P00002 / enzyme matches neither rule -> absent from both lists. + q = Rule(id_prefixes=["TRINITY_"]) + r = Rule(where=[("protein_category", "neurotoxin")]) + qi, ri = classify(_table(), q, r) + assert 3 not in qi + assert 3 not in ri + + +def test_multiple_id_prefixes_use_or_semantics(): + q = Rule(id_prefixes=["TRINITY_", "P00001"]) + r = Rule(id_prefixes=["P00002"]) + qi, ri = classify(_table(), q, r) + assert qi == [0, 1, 2] + assert ri == [3] From 94b4f0fe0885d12b059dbd5cc45561f24d58ed37 Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 19:56:06 +0200 Subject: [PATCH 11/21] feat: build per-cell prediction overlay columns Add `add_overlay_columns()` in `src/protspace/data/io/predictions.py` that appends three aligned Arrow columns (`COL__pred_value`, `COL__pred_confidence`, `COL__pred_source`) from a list of `protlabel.Prediction` objects, leaving the curated column untouched. Co-Authored-By: Claude Sonnet 4.6 --- src/protspace/data/io/predictions.py | 46 ++++++++++++++++++++++++ tests/test_predictions_overlay.py | 52 ++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+) create mode 100644 src/protspace/data/io/predictions.py create mode 100644 tests/test_predictions_overlay.py diff --git a/src/protspace/data/io/predictions.py b/src/protspace/data/io/predictions.py new file mode 100644 index 00000000..92ce6d20 --- /dev/null +++ b/src/protspace/data/io/predictions.py @@ -0,0 +1,46 @@ +"""Turn protlabel Predictions into per-cell overlay columns on the annotations table. + +For a transferred column ``COL`` we append three aligned columns (null for +non-predicted proteins), leaving the curated ``COL`` untouched: + COL__pred_value (string) the transferred label + COL__pred_confidence (float32) the reliability index in [0, 1] + COL__pred_source (string) the nearest reference protein id +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import pyarrow as pa + +from protlabel import Prediction + + +def add_overlay_columns( + annotations: pa.Table, column: str, predictions: Sequence[Prediction] +) -> pa.Table: + """Append the COL__pred_* overlay columns, aligned by identifier.""" + by_query = {p.query_id: p for p in predictions} + identifiers = [str(v) for v in annotations.column("identifier").to_pylist()] + + values: list[str | None] = [] + confidences: list[float | None] = [] + sources: list[str | None] = [] + for identifier in identifiers: + pred = by_query.get(identifier) + if pred is None: + values.append(None) + confidences.append(None) + sources.append(None) + else: + values.append(pred.label) + confidences.append(float(pred.reliability)) + sources.append(pred.source_id) + + out = annotations + out = out.append_column(f"{column}__pred_value", pa.array(values, pa.string())) + out = out.append_column( + f"{column}__pred_confidence", pa.array(confidences, pa.float32()) + ) + out = out.append_column(f"{column}__pred_source", pa.array(sources, pa.string())) + return out diff --git a/tests/test_predictions_overlay.py b/tests/test_predictions_overlay.py new file mode 100644 index 00000000..6033ff1e --- /dev/null +++ b/tests/test_predictions_overlay.py @@ -0,0 +1,52 @@ +"""Tests for building the per-cell prediction overlay columns.""" + +import pyarrow as pa + +from protlabel import Prediction +from protspace.data.io.predictions import add_overlay_columns + + +def _table(): + return pa.table( + { + "identifier": ["Q0", "Q1", "R0"], + "protein_category": ["", "", "neurotoxin"], + } + ) + + +def test_adds_three_overlay_columns(): + preds = [ + Prediction("Q0", "neurotoxin", "R0", 0.3, 0.62, 1, "euclidean"), + ] + out = add_overlay_columns(_table(), "protein_category", preds) + assert "protein_category__pred_value" in out.column_names + assert "protein_category__pred_confidence" in out.column_names + assert "protein_category__pred_source" in out.column_names + + +def test_overlay_values_aligned_by_identifier(): + preds = [Prediction("Q1", "enzyme", "R9", 0.5, 0.5, 1, "euclidean")] + out = add_overlay_columns(_table(), "protein_category", preds).to_pylist() + by_id = {r["identifier"]: r for r in out} + assert by_id["Q1"]["protein_category__pred_value"] == "enzyme" + assert by_id["Q1"]["protein_category__pred_confidence"] == 0.5 + assert by_id["Q1"]["protein_category__pred_source"] == "R9" + # Non-predicted rows are null in the overlay columns. + assert by_id["Q0"]["protein_category__pred_value"] is None + assert by_id["R0"]["protein_category__pred_confidence"] is None + + +def test_curated_column_is_left_untouched(): + preds = [Prediction("Q0", "neurotoxin", "R0", 0.1, 0.8, 1, "euclidean")] + out = add_overlay_columns(_table(), "protein_category", preds).to_pylist() + by_id = {r["identifier"]: r for r in out} + assert by_id["Q0"]["protein_category"] == "" # original column unchanged + assert by_id["R0"]["protein_category"] == "neurotoxin" + + +def test_confidence_column_is_float(): + preds = [Prediction("Q0", "x", "R0", 0.1, 0.83, 1, "euclidean")] + out = add_overlay_columns(_table(), "protein_category", preds) + field = out.schema.field("protein_category__pred_confidence") + assert pa.types.is_floating(field.type) From 05194bf989aadd055c832f85471222a75ec7cc3f Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 19:59:39 +0200 Subject: [PATCH 12/21] test: cover empty-predictions and unknown-id overlay edge cases --- tests/test_predictions_overlay.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_predictions_overlay.py b/tests/test_predictions_overlay.py index 6033ff1e..8a44950e 100644 --- a/tests/test_predictions_overlay.py +++ b/tests/test_predictions_overlay.py @@ -50,3 +50,20 @@ def test_confidence_column_is_float(): out = add_overlay_columns(_table(), "protein_category", preds) field = out.schema.field("protein_category__pred_confidence") assert pa.types.is_floating(field.type) + + +def test_empty_predictions_appends_all_null_columns(): + out = add_overlay_columns(_table(), "protein_category", []) + assert "protein_category__pred_value" in out.column_names + assert out.column("protein_category__pred_value").to_pylist() == [None, None, None] + assert out.column("protein_category__pred_confidence").to_pylist() == [ + None, + None, + None, + ] + + +def test_prediction_for_unknown_identifier_is_ignored(): + preds = [Prediction("NOT_IN_TABLE", "x", "R0", 0.1, 0.9, 1, "euclidean")] + out = add_overlay_columns(_table(), "protein_category", preds).to_pylist() + assert all(r["protein_category__pred_value"] is None for r in out) From 5093f6653841c4800ed7238f6818a86e022cdb1d Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 20:01:07 +0200 Subject: [PATCH 13/21] feat: replace annotations part of a parquetbundle in place Co-Authored-By: Claude Sonnet 4.6 --- src/protspace/data/io/bundle.py | 30 +++++++++++++++++++ tests/test_bundle_overlay.py | 52 +++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 tests/test_bundle_overlay.py diff --git a/src/protspace/data/io/bundle.py b/src/protspace/data/io/bundle.py index ca625a26..22bab8f1 100644 --- a/src/protspace/data/io/bundle.py +++ b/src/protspace/data/io/bundle.py @@ -148,6 +148,36 @@ def replace_settings_in_bundle( f.write(new_content) +def replace_annotations_in_bundle( + input_path: Path, + output_path: Path, + annotations_table: pa.Table, +) -> None: + """Replace the annotations (1st) part of a bundle, preserving the rest. + + Projection parts (2nd, 3rd) are kept byte-for-byte; an existing settings + (4th) part is carried over unchanged. + """ + with open(input_path, "rb") as f: + content = f.read() + + parts = content.split(PARQUET_BUNDLE_DELIMITER) + if len(parts) < 3 or len(parts) > 4: + raise ValueError(f"Expected 3 or 4 parts in parquetbundle, found {len(parts)}") + + buf = io.BytesIO() + pq.write_table(annotations_table, buf) + new_parts = [buf.getvalue(), parts[1], parts[2]] + if len(parts) == 4: + new_parts.append(parts[3]) + + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "wb") as f: + f.write(PARQUET_BUNDLE_DELIMITER.join(new_parts)) + + logger.info(f"Wrote bundle with updated annotations to: {output_path}") + + def create_settings_parquet(settings_dict: dict) -> bytes: """Serialize a settings dict into parquet bytes. diff --git a/tests/test_bundle_overlay.py b/tests/test_bundle_overlay.py new file mode 100644 index 00000000..0e25b4aa --- /dev/null +++ b/tests/test_bundle_overlay.py @@ -0,0 +1,52 @@ +"""Round-trip tests for replacing the annotations part of a bundle.""" + +import io + +import pyarrow as pa +import pyarrow.parquet as pq + +from protspace.data.io.bundle import ( + read_bundle, + replace_annotations_in_bundle, + write_bundle, +) + + +def _tables(): + annotations = pa.table({"identifier": ["A", "B"], "cat": ["x", "y"]}) + proj_meta = pa.table({"name": ["PCA 2"], "dims": [2]}) + proj_data = pa.table({"id": ["A", "B"], "x": [0.0, 1.0], "y": [0.0, 1.0]}) + return [annotations, proj_meta, proj_data] + + +def _read_part(part_bytes): + return pq.read_table(io.BytesIO(part_bytes)) + + +def test_replaces_annotations_keeps_other_parts(tmp_path): + src = tmp_path / "in.parquetbundle" + out = tmp_path / "out.parquetbundle" + write_bundle(_tables(), src) + + new_annotations = pa.table( + {"identifier": ["A", "B"], "cat": ["x", "y"], "cat__pred_value": [None, "z"]} + ) + replace_annotations_in_bundle(src, out, new_annotations) + + parts, settings = read_bundle(out) + assert "cat__pred_value" in _read_part(parts[0]).column_names + # Projections preserved byte-for-byte. + assert _read_part(parts[1]).column_names == ["name", "dims"] + assert _read_part(parts[2]).to_pydict()["x"] == [0.0, 1.0] + + +def test_preserves_settings_when_present(tmp_path): + src = tmp_path / "in.parquetbundle" + out = tmp_path / "out.parquetbundle" + write_bundle(_tables(), src, settings={"foo": 1}) + + new_annotations = pa.table({"identifier": ["A", "B"], "cat": ["x", "y"]}) + replace_annotations_in_bundle(src, out, new_annotations) + + _parts, settings = read_bundle(out) + assert settings == {"foo": 1} From c9cae3f537ec431108189f8121cbf0eb5bfe9e50 Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 20:07:33 +0200 Subject: [PATCH 14/21] feat: add 'protspace transfer' annotation-transfer subcommand Implements Task 9: the EAT orchestration core (run_transfer) and the 'protspace transfer' Typer CLI command, wiring classification, nearest- neighbour lookup (protlabel.eat), and overlay-column writing into a single pipeline for filling missing annotation values from pLM embedding space. Co-Authored-By: Claude Sonnet 4.6 --- src/protspace/cli/app.py | 1 + src/protspace/cli/transfer.py | 195 ++++++++++++++++++++++++++++++++++ tests/test_transfer_cli.py | 66 ++++++++++++ 3 files changed, 262 insertions(+) create mode 100644 src/protspace/cli/transfer.py create mode 100644 tests/test_transfer_cli.py diff --git a/src/protspace/cli/app.py b/src/protspace/cli/app.py index c718fbae..1cd8ba39 100644 --- a/src/protspace/cli/app.py +++ b/src/protspace/cli/app.py @@ -70,6 +70,7 @@ def _register_commands() -> None: project, serve, style, + transfer, ) diff --git a/src/protspace/cli/transfer.py b/src/protspace/cli/transfer.py new file mode 100644 index 00000000..45ca098d --- /dev/null +++ b/src/protspace/cli/transfer.py @@ -0,0 +1,195 @@ +"""protspace transfer — fill missing annotation values from nearest references. + +Embedding Annotation Transfer (EAT): for each query protein with a missing +value in a target column, transfer the value of its nearest annotated +reference in pLM embedding space, with a reliability-index confidence. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Annotated + +import numpy as np +import pyarrow as pa +import typer + +from protspace.cli.app import app, setup_logging +from protspace.cli.common_options import Opt_Verbose + +logger = logging.getLogger(__name__) + + +def _is_missing(value) -> bool: + return value is None or str(value).strip() == "" + + +def run_transfer( + *, + annotations: pa.Table, + embeddings: dict[str, np.ndarray], + transfer_columns: list[str], + query_rule, + reference_rule, + k: int = 1, + metric: str = "euclidean", +) -> pa.Table: + """Pure core: classify, transfer per column, return the augmented table. + + ``embeddings`` maps protein id -> 1-D float32 vector. Proteins without an + embedding cannot act as queries or references. + """ + from protlabel import eat + from protspace.analysis.classification import classify + from protspace.data.io.predictions import add_overlay_columns + + # Restrict classification to proteins that actually have an embedding. + has_emb = pa.array( + [str(v) in embeddings for v in annotations.column("identifier").to_pylist()] + ) + embedded = annotations.filter(has_emb) + + query_idx, ref_idx = classify(embedded, query_rule, reference_rule) + rows = embedded.to_pylist() + + out = annotations + for column in transfer_columns: + if column not in annotations.column_names: + raise KeyError(f"Transfer column {column!r} not in annotations table") + + # References: classified refs that HAVE a value in this column. + ref_ids, ref_labels, ref_vecs = [], [], [] + for i in ref_idx: + value = rows[i].get(column) + if not _is_missing(value): + rid = str(rows[i]["identifier"]) + ref_ids.append(rid) + ref_labels.append(str(value)) + ref_vecs.append(embeddings[rid]) + if not ref_ids: + logger.warning("No references with a value for %r; skipping", column) + continue + + # Queries: classified queries MISSING a value in this column. + q_ids, q_vecs = [], [] + for i in query_idx: + if _is_missing(rows[i].get(column)): + qid = str(rows[i]["identifier"]) + q_ids.append(qid) + q_vecs.append(embeddings[qid]) + if not q_ids: + logger.warning("No queries missing %r; nothing to transfer", column) + continue + + preds = eat( + np.vstack(q_vecs), + q_ids, + np.vstack(ref_vecs), + ref_ids, + ref_labels, + k=k, + metric=metric, + ) + out = add_overlay_columns(out, column, preds) + logger.info("Transferred %r to %d quer(ies)", column, len(preds)) + + return out + + +@app.command() +def transfer( + bundle: Annotated[ + Path, + typer.Option("-b", "--bundle", help="Input .parquetbundle to annotate."), + ], + embeddings: Annotated[ + str, + typer.Option( + "-e", + "--embeddings", + help="HDF5 embeddings, optional :name suffix (e.g. emb.h5:prot_t5).", + ), + ], + transfer_columns: Annotated[ + list[str], + typer.Option( + "-t", "--transfer", help="Annotation column to transfer (repeat)." + ), + ], + output: Annotated[ + Path, + typer.Option("-o", "--output", help="Output .parquetbundle path."), + ], + query_id_prefix: Annotated[ + list[str] | None, typer.Option("--query-id-prefix") + ] = None, + query_where: Annotated[ + list[str] | None, + typer.Option("--query-where", help="col~substr"), + ] = None, + reference_id_prefix: Annotated[ + list[str] | None, typer.Option("--reference-id-prefix") + ] = None, + reference_where: Annotated[ + list[str] | None, + typer.Option("--reference-where", help="col~substr"), + ] = None, + k: Annotated[ + int, typer.Option("--k", help="Neighbours considered (default 1).") + ] = 1, + metric: Annotated[ + str, typer.Option("--metric", help="euclidean | cosine.") + ] = "euclidean", + verbose: Opt_Verbose = 0, +) -> None: + """Transfer annotations to query proteins from nearest reference neighbours.""" + setup_logging(verbose) + + import io + + import pyarrow.parquet as pq + + from protspace.analysis.classification import Rule + from protspace.data.io.bundle import read_bundle, replace_annotations_in_bundle + from protspace.data.loaders import load_h5 + + def _parse_where(items: list[str] | None) -> list[tuple[str, str]]: + clauses = [] + for item in items or []: + if "~" not in item: + raise typer.BadParameter(f"--*-where must be col~substr, got {item!r}") + col, sub = item.split("~", 1) + clauses.append((col, sub)) + return clauses + + query_rule = Rule( + id_prefixes=query_id_prefix or [], where=_parse_where(query_where) + ) + reference_rule = Rule( + id_prefixes=reference_id_prefix or [], where=_parse_where(reference_where) + ) + + # Load embeddings (name override after ':'). + h5_spec = embeddings.split(":", 1) + h5_path = Path(h5_spec[0]) + name_override = h5_spec[1] if len(h5_spec) == 2 else None + emb_set = load_h5([h5_path], name_override=name_override) + emb_map = {header: emb_set.data[i] for i, header in enumerate(emb_set.headers)} + + # Read the annotations part of the bundle. + parts, _settings = read_bundle(bundle) + annotations = pq.read_table(io.BytesIO(parts[0])) + + augmented = run_transfer( + annotations=annotations, + embeddings=emb_map, + transfer_columns=transfer_columns, + query_rule=query_rule, + reference_rule=reference_rule, + k=k, + metric=metric, + ) + + replace_annotations_in_bundle(bundle, output, augmented) + logger.info("Wrote transferred bundle to %s", output) diff --git a/tests/test_transfer_cli.py b/tests/test_transfer_cli.py new file mode 100644 index 00000000..797fd349 --- /dev/null +++ b/tests/test_transfer_cli.py @@ -0,0 +1,66 @@ +"""Tests for the transfer orchestration core and CLI registration.""" + +import numpy as np +import pyarrow as pa +import pytest + +from protspace.analysis.classification import Rule +from protspace.cli.transfer import run_transfer + + +def _inputs(): + annotations = pa.table( + { + "identifier": ["TRINITY_1", "P00001", "P00002"], + "protein_category": ["", "neurotoxin", "enzyme"], + } + ) + # TRINITY_1 sits right on top of the neurotoxin reference P00001. + embeddings = { + "TRINITY_1": np.array([0.0, 0.0], dtype=np.float32), + "P00001": np.array([0.05, 0.0], dtype=np.float32), + "P00002": np.array([9.0, 0.0], dtype=np.float32), + } + return annotations, embeddings + + +def test_run_transfer_predicts_for_query_with_missing_value(): + annotations, embeddings = _inputs() + out = run_transfer( + annotations=annotations, + embeddings=embeddings, + transfer_columns=["protein_category"], + query_rule=Rule(id_prefixes=["TRINITY_"]), + reference_rule=Rule(where=[("protein_category", "")]), # any non-empty ref + k=1, + metric="euclidean", + ) + by_id = {r["identifier"]: r for r in out.to_pylist()} + assert by_id["TRINITY_1"]["protein_category__pred_value"] == "neurotoxin" + assert by_id["TRINITY_1"]["protein_category__pred_source"] == "P00001" + assert by_id["TRINITY_1"]["protein_category__pred_confidence"] > 0.9 + + +def test_run_transfer_skips_proteins_without_embeddings(): + annotations, embeddings = _inputs() + embeddings.pop("TRINITY_1") # no embedding -> cannot be a query + with pytest.raises(ValueError, match="no query"): + run_transfer( + annotations=annotations, + embeddings=embeddings, + transfer_columns=["protein_category"], + query_rule=Rule(id_prefixes=["TRINITY_"]), + reference_rule=Rule(id_prefixes=["P0"]), + k=1, + metric="euclidean", + ) + + +def test_transfer_command_is_registered(): + from typer.testing import CliRunner + + from protspace.cli.app import app + + result = CliRunner().invoke(app, ["transfer", "--help"]) + assert result.exit_code == 0 + assert "transfer" in result.output.lower() From c708f90f87e2714835d1ee288c6cea9279827541 Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 20:27:54 +0200 Subject: [PATCH 15/21] fix(transfer): handle protein_id id column in real bundles; clearer errors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Normalize protein_id→identifier before run_transfer and rename back after so real bundles (produced by protspace prepare) no longer KeyError. - Add ValueError when no bundle proteins match any embedding key. - Correct misleading comment in test_run_transfer_predicts_for_query_with_missing_value. - Add end-to-end regression test exercising the protein_id rename path. Co-Authored-By: Claude Sonnet 4.6 --- src/protspace/cli/transfer.py | 26 +++++++++++++++ tests/test_transfer_cli.py | 60 ++++++++++++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/src/protspace/cli/transfer.py b/src/protspace/cli/transfer.py index 45ca098d..2557c84e 100644 --- a/src/protspace/cli/transfer.py +++ b/src/protspace/cli/transfer.py @@ -50,6 +50,12 @@ def run_transfer( ) embedded = annotations.filter(has_emb) + if embedded.num_rows == 0: + raise ValueError( + "No proteins in the bundle have a matching embedding " + "(check that the --embeddings ids match the bundle identifiers)." + ) + query_idx, ref_idx = classify(embedded, query_rule, reference_rule) rows = embedded.to_pylist() @@ -181,6 +187,19 @@ def _parse_where(items: list[str] | None) -> list[tuple[str, str]]: parts, _settings = read_bundle(bundle) annotations = pq.read_table(io.BytesIO(parts[0])) + # Real bundles name the id column "protein_id"; run_transfer works on "identifier". + id_col = "protein_id" if "protein_id" in annotations.column_names else "identifier" + if id_col != "identifier": + annotations = annotations.rename_columns( + ["identifier" if n == id_col else n for n in annotations.column_names] + ) + + for col in transfer_columns: + if col not in annotations.column_names: + raise typer.BadParameter( + f"--transfer column {col!r} not found in the bundle annotations" + ) + augmented = run_transfer( annotations=annotations, embeddings=emb_map, @@ -191,5 +210,12 @@ def _parse_where(items: list[str] | None) -> list[tuple[str, str]]: metric=metric, ) + # Rename id column back so the written bundle keeps its original name + # (the web frontend expects "protein_id"). + if id_col != "identifier": + augmented = augmented.rename_columns( + [id_col if n == "identifier" else n for n in augmented.column_names] + ) + replace_annotations_in_bundle(bundle, output, augmented) logger.info("Wrote transferred bundle to %s", output) diff --git a/tests/test_transfer_cli.py b/tests/test_transfer_cli.py index 797fd349..7a3a01c3 100644 --- a/tests/test_transfer_cli.py +++ b/tests/test_transfer_cli.py @@ -31,7 +31,9 @@ def test_run_transfer_predicts_for_query_with_missing_value(): embeddings=embeddings, transfer_columns=["protein_category"], query_rule=Rule(id_prefixes=["TRINITY_"]), - reference_rule=Rule(where=[("protein_category", "")]), # any non-empty ref + reference_rule=Rule( + where=[("protein_category", "")] + ), # matches all proteins; run_transfer keeps only those with a value k=1, metric="euclidean", ) @@ -64,3 +66,59 @@ def test_transfer_command_is_registered(): result = CliRunner().invoke(app, ["transfer", "--help"]) assert result.exit_code == 0 assert "transfer" in result.output.lower() + + +def test_cli_end_to_end_protein_id_bundle(tmp_path): + """Real bundles key the id column 'protein_id'; the CLI must handle it and + preserve that name on write while adding the overlay columns.""" + import io + + import h5py + import pyarrow.parquet as pq + from typer.testing import CliRunner + + from protspace.cli.app import app + from protspace.data.io.bundle import read_bundle, write_bundle + + annotations = pa.table( + {"protein_id": ["TRINITY_1", "P00001"], "protein_category": ["", "neurotoxin"]} + ) + proj_meta = pa.table({"name": ["PCA 2"], "dims": [2]}) + proj_data = pa.table( + {"id": ["TRINITY_1", "P00001"], "x": [0.0, 9.0], "y": [0.0, 0.0]} + ) + bundle_path = tmp_path / "in.parquetbundle" + write_bundle([annotations, proj_meta, proj_data], bundle_path) + + h5_path = tmp_path / "emb.h5" + with h5py.File(h5_path, "w") as f: + f.attrs["model_name"] = "test_model" + f.create_dataset("TRINITY_1", data=np.array([0.0, 0.0], dtype=np.float32)) + f.create_dataset("P00001", data=np.array([0.1, 0.0], dtype=np.float32)) + + out_path = tmp_path / "out.parquetbundle" + result = CliRunner().invoke( + app, + [ + "transfer", + "-b", + str(bundle_path), + "-e", + str(h5_path), + "-t", + "protein_category", + "-o", + str(out_path), + "--query-id-prefix", + "TRINITY_", + "--reference-id-prefix", + "P0", + ], + ) + assert result.exit_code == 0, result.output + parts, _ = read_bundle(out_path) + table = pq.read_table(io.BytesIO(parts[0])) + assert "protein_id" in table.column_names # id column preserved for the web reader + rows = {r["protein_id"]: r for r in table.to_pylist()} + assert rows["TRINITY_1"]["protein_category__pred_value"] == "neurotoxin" + assert rows["TRINITY_1"]["protein_category__pred_source"] == "P00001" From 0ee1354d28a3d29f68484848e72984708bc9fc61 Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 20:33:05 +0200 Subject: [PATCH 16/21] docs: document protspace transfer + prediction overlay columns Co-Authored-By: Claude Sonnet 4.6 --- CLAUDE.md | 1 + docs/annotations.md | 14 ++++ docs/cli.md | 32 ++++++++ notebooks/ProtSpace_Transfer.ipynb | 122 +++++++++++++++++++++++++++++ 4 files changed, 169 insertions(+) create mode 100644 notebooks/ProtSpace_Transfer.ipynb diff --git a/CLAUDE.md b/CLAUDE.md index f3a8aeb7..a158c6db 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -49,6 +49,7 @@ Single entry point: `protspace = protspace.cli.app:app` | `protspace bundle` | Combine projections + annotations → .parquetbundle | | `protspace serve` | Launch Dash web frontend | | `protspace style` | Add annotation colors/styles | +| `protspace transfer` | Fill missing annotations from nearest reference embeddings (EAT) | ### protspace prepare Usage diff --git a/docs/annotations.md b/docs/annotations.md index 1f21d127..cfb3dab2 100644 --- a/docs/annotations.md +++ b/docs/annotations.md @@ -183,3 +183,17 @@ Per-protein predictions from the [Biocentral API](https://biocentral.rostlab.org | Pfam clans | `~/.cache/protspace/pfam_clans/` | 30 days | Pfam family → clan mapping | The `default` group only requires the UniProt REST API (+ ExPASy for EC names). For `--keep-tmp` annotation caching, see [CLI Reference](cli.md#annotation-caching---keep-tmp). + +## Prediction Overlay Columns (EAT Transfer) + +Running `protspace transfer` appends three new columns to the bundle's annotations table for each requested column `COL`. The curated `COL` column is never modified. + +| Column | Type | Meaning | +| --- | --- | --- | +| `COL__pred_value` | string | The transferred label from the nearest annotated reference protein | +| `COL__pred_confidence` | float | Reliability index in [0, 1]: `0.5 / (0.5 + distance)` — 1 = identical embeddings | +| `COL__pred_source` | string | UniProt accession (or ID) of the nearest reference protein | + +A protein is considered "predicted" for `COL` when `COL` is empty but `COL__pred_value` is present. Use `COL__pred_confidence` to threshold low-reliability transfers. + +See [`protspace transfer`](cli.md#protspace-transfer) for usage and option details. diff --git a/docs/cli.md b/docs/cli.md index ddb3fd90..c2890e6e 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -9,6 +9,7 @@ | `protspace bundle` | Combine projections + annotations into .parquetbundle | | `protspace serve` | Launch interactive Dash web frontend | | `protspace style` | Add/inspect annotation styles in existing files | +| `protspace transfer` | Fill missing annotations from nearest reference embeddings (EAT) | Run `protspace -h` for detailed help. @@ -183,6 +184,37 @@ protspace style input.parquetbundle output.parquetbundle --annotation-styles sty protspace style data.parquetbundle --dump-settings ``` +## `protspace transfer` + +Embedding Annotation Transfer (EAT): fills missing annotation values for query proteins by transferring the annotation of the nearest annotated reference protein in pLM embedding space. For each query protein that lacks a value in the requested annotation column, the command finds the closest reference (by Euclidean distance in the original high-dimensional embedding space — not in the 2-D/3-D projection) and assigns that reference's label along with a reliability index adapted from goPredSim (`confidence = 0.5 / (0.5 + distance)`), yielding a score in [0, 1] where 1 means identical embeddings. The curated source column (`COL`) is left untouched; results are written as three new columns: `COL__pred_value` (string), `COL__pred_confidence` (float), and `COL__pred_source` (the nearest reference protein ID). The method is a direct application of the approach introduced by Littmann et al., Sci Rep 2021 ([DOI 10.1038/s41598-020-80786-0](https://doi.org/10.1038/s41598-020-80786-0)) and extended by Heinzinger et al., NAR Genom Bioinform 2022 ([DOI 10.1093/nargab/lqac043](https://doi.org/10.1093/nargab/lqac043)). + +```bash +protspace transfer \ + -b results.parquetbundle \ + -e embeddings.h5:prot_t5 \ + -t protein_category \ + -o results.parquetbundle \ + --query-id-prefix TRINITY_ \ + --reference-where 'protein_category~neurotoxin' +``` + +**Key options:** + +| Flag | Description | Default | +| ---- | ----------- | ------- | +| `-b, --bundle` | Input `.parquetbundle` file | — | +| `-e, --embeddings` | HDF5 embeddings file (use `:name` suffix for external files) | — | +| `-t, --transfer` | Annotation column to transfer (repeatable) | — | +| `-o, --output` | Output `.parquetbundle` (may overwrite input) | — | +| `--query-id-prefix` | Restrict query proteins to IDs starting with this prefix | — | +| `--query-where` | Filter query proteins by annotation value (`col~substr`) | — | +| `--reference-id-prefix` | Restrict reference proteins to IDs starting with this prefix | — | +| `--reference-where` | Filter reference proteins by annotation value (`col~substr`) | — | +| `--k` | Number of nearest neighbours | `1` | +| `--metric` | Distance metric (`euclidean`, `cosine`, `manhattan`) | `euclidean` | + +Distances are computed in the original embedding space (HDF5), not in the 2-D/3-D projection. + ## Combining Multiple Inputs (`-i`) When multiple `-i` inputs are provided, behavior depends on whether they share the same embedding name: diff --git a/notebooks/ProtSpace_Transfer.ipynb b/notebooks/ProtSpace_Transfer.ipynb new file mode 100644 index 00000000..7af93236 --- /dev/null +++ b/notebooks/ProtSpace_Transfer.ipynb @@ -0,0 +1,122 @@ +{ + "metadata": { + "colab": { + "collapsed_sections": [], + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5, + "cells": [ + { + "cell_type": "markdown", + "id": "a1b2c3d4", + "metadata": {}, + "source": [ + "# ProtSpace — Embedding Annotation Transfer (EAT)\n", + "\n", + "This notebook demonstrates **Embedding Annotation Transfer (EAT)** with `protspace transfer`.\n", + "For each query protein that lacks an annotation value, the command finds the closest\n", + "annotated reference protein in pLM embedding space and transfers its label, together\n", + "with a reliability index (`confidence = 0.5 / (0.5 + distance)`, range [0, 1]).\n", + "The method follows the goPredSim approach introduced in:\n", + "\n", + "- Littmann et al., *Sci Rep* 2021 — [DOI 10.1038/s41598-020-80786-0](https://doi.org/10.1038/s41598-020-80786-0)\n", + "- Heinzinger et al., *NAR Genom Bioinform* 2022 — [DOI 10.1093/nargab/lqac043](https://doi.org/10.1093/nargab/lqac043)\n", + "\n", + "Distances are computed in the original high-dimensional embedding space (HDF5),\n", + "not in any 2-D/3-D projection. The curated source column is left untouched;\n", + "results are written as `COL__pred_value`, `COL__pred_confidence`, and `COL__pred_source`\n", + "columns in the bundle's annotations table.\n", + "\n", + "📚 [GitHub](https://github.com/tsenoner/protspace) · [CLI Reference](https://github.com/tsenoner/protspace/blob/main/docs/cli.md#protspace-transfer) · [Annotation Reference](https://github.com/tsenoner/protspace/blob/main/docs/annotations.md#prediction-overlay-columns-eat-transfer)" + ] + }, + { + "cell_type": "markdown", + "id": "b2c3d4e5", + "metadata": {}, + "source": [ + "## Install" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3d4e5f6", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install protspace" + ] + }, + { + "cell_type": "markdown", + "id": "d4e5f6a7", + "metadata": {}, + "source": [ + "## Run transfer\n", + "\n", + "Transfer the `protein_category` annotation from annotated reference proteins\n", + "(filtered to those with `protein_category` containing `neurotoxin`) to\n", + "unannotated query proteins whose IDs start with `TRINITY_`.\n", + "\n", + "Adjust `-b`, `-e`, `-t`, `--query-id-prefix`, and `--reference-where` to match your data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5f6a7b8", + "metadata": {}, + "outputs": [], + "source": [ + "!protspace transfer \\\n", + " -b results.parquetbundle \\\n", + " -e embeddings.h5:prot_t5 \\\n", + " -t protein_category \\\n", + " -o results.parquetbundle \\\n", + " --query-id-prefix TRINITY_ \\\n", + " --reference-where 'protein_category~neurotoxin'" + ] + }, + { + "cell_type": "markdown", + "id": "f6a7b8c9", + "metadata": {}, + "source": [ + "## Read predictions back\n", + "\n", + "Load the updated bundle and inspect the `*__pred_*` columns for the transferred annotations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a7b8c9d0", + "metadata": {}, + "outputs": [], + "source": [ + "import io\n", + "import pyarrow.parquet as pq\n", + "from protspace.data.io.bundle import read_bundle\n", + "\n", + "parts, _ = read_bundle(\"results.parquetbundle\")\n", + "df = pq.read_table(io.BytesIO(parts[0])).to_pandas()\n", + "\n", + "pred_cols = [c for c in df.columns if c.endswith(\"__pred_value\")]\n", + "df[df[pred_cols[0]].notna()].head() if pred_cols else df.head()" + ] + } + ] +} From 21d508ceef7de14c457f78fd183fc834f8cce24f Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 20:36:04 +0200 Subject: [PATCH 17/21] docs: correct transfer --metric options (euclidean, cosine only) --- docs/cli.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/cli.md b/docs/cli.md index c2890e6e..b18ae744 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -211,7 +211,7 @@ protspace transfer \ | `--reference-id-prefix` | Restrict reference proteins to IDs starting with this prefix | — | | `--reference-where` | Filter reference proteins by annotation value (`col~substr`) | — | | `--k` | Number of nearest neighbours | `1` | -| `--metric` | Distance metric (`euclidean`, `cosine`, `manhattan`) | `euclidean` | +| `--metric` | Distance metric (`euclidean`, `cosine`) | `euclidean` | Distances are computed in the original embedding space (HDF5), not in the 2-D/3-D projection. From a05e977f051b5743bc290068f96c64c2116335d4 Mon Sep 17 00:00:00 2001 From: tsenoner Date: Thu, 11 Jun 2026 20:42:57 +0200 Subject: [PATCH 18/21] feat(transfer): warn on zero transfers; validate --metric/--k early Co-Authored-By: Claude Sonnet 4.6 --- src/protspace/cli/transfer.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/protspace/cli/transfer.py b/src/protspace/cli/transfer.py index 2557c84e..e6035b07 100644 --- a/src/protspace/cli/transfer.py +++ b/src/protspace/cli/transfer.py @@ -60,6 +60,7 @@ def run_transfer( rows = embedded.to_pylist() out = annotations + total_transferred = 0 for column in transfer_columns: if column not in annotations.column_names: raise KeyError(f"Transfer column {column!r} not in annotations table") @@ -98,8 +99,14 @@ def run_transfer( metric=metric, ) out = add_overlay_columns(out, column, preds) + total_transferred += len(preds) logger.info("Transferred %r to %d quer(ies)", column, len(preds)) + if total_transferred == 0: + logger.warning( + "No annotations were transferred. Check the --reference-* rules and " + "that query proteins have missing values in the target column(s)." + ) return out @@ -176,6 +183,11 @@ def _parse_where(items: list[str] | None) -> list[tuple[str, str]]: id_prefixes=reference_id_prefix or [], where=_parse_where(reference_where) ) + if metric not in ("euclidean", "cosine"): + raise typer.BadParameter("--metric must be 'euclidean' or 'cosine'") + if k < 1: + raise typer.BadParameter("--k must be >= 1") + # Load embeddings (name override after ':'). h5_spec = embeddings.split(":", 1) h5_path = Path(h5_spec[0]) From 98b42f664869a8af082aa5652aeda5e95e955b3d Mon Sep 17 00:00:00 2001 From: tsenoner Date: Fri, 12 Jun 2026 19:45:03 +0200 Subject: [PATCH 19/21] chore(docs): remove EAT build plan + superseded draft; keep design spec --- .../plans/2026-06-11-eat-transfer-backend.md | 1671 ----------------- .../2026-05-27-neighbors-subcommand-design.md | 14 - ...26-06-11-eat-annotation-transfer-design.md | 2 +- 3 files changed, 1 insertion(+), 1686 deletions(-) delete mode 100644 docs/superpowers/plans/2026-06-11-eat-transfer-backend.md delete mode 100644 docs/superpowers/specs/2026-05-27-neighbors-subcommand-design.md diff --git a/docs/superpowers/plans/2026-06-11-eat-transfer-backend.md b/docs/superpowers/plans/2026-06-11-eat-transfer-backend.md deleted file mode 100644 index a3afcd20..00000000 --- a/docs/superpowers/plans/2026-06-11-eat-transfer-backend.md +++ /dev/null @@ -1,1671 +0,0 @@ -# EAT Annotation-Transfer Backend Implementation Plan - -> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. - -**Goal:** Add a `protlabel` embedding-annotation-transfer engine and a `protspace transfer` CLI subcommand that fills in missing annotation values for query proteins from their nearest reference neighbours in pLM embedding space, writing a per-cell prediction overlay back into the `.parquetbundle`. - -**Architecture:** `protlabel` is a small, ProtSpace-agnostic package (numpy/scipy/h5py only) that does the kNN search + goPredSim reliability index + label transfer. `protspace transfer` is a thin Typer subcommand that reads a bundle + HDF5 embeddings, classifies query vs reference proteins, calls `protlabel`, and appends `__pred_value` / `__pred_confidence` / `__pred_source` columns to the bundle's annotations table. Default = Euclidean, k=1. Optional gating/mining/report are out of scope for this MVP. - -**Tech Stack:** Python ≥3.10, numpy, scipy (new dep), h5py, pyarrow, Typer, pytest, ruff, uv. - -**Spec:** `docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md`. The two refinements below override the spec where they differ (and the spec's §4/§10 are updated to match): -- **Packaging:** `protlabel` ships as a second top-level package *inside the protspace repo* (`src/protlabel/`), bundled into the protspace wheel — not a suite-level uv workspace member (the suite root is not a uv workspace, and a separate PyPI distribution would need its own release/CI). The strict no-`protspace`-imports boundary keeps a future standalone split trivial. -- **Overlay storage:** extra `*__pred_*` columns on the existing annotations table (bundle part 1), **not** a new 5th bundle part — this is backward-compatible with the currently-deployed web reader, which tolerates unknown columns but parses a fixed part count. - -**Out of scope (follow-up plans):** the `protspace_web` frontend rendering (separate repo/PR), optional `--cutoff` gating, `--mine` neighborhood mining, `--report`/`--plots`/`--full-tables`, faiss-cpu ANN backend, ProtTucker learned distance. - -**Conventions (enforced):** all Python via `uv run`; deps via `uv add` (never hand-edit `[project.dependencies]`); ruff clean; update docs + Colab notebook before committing; feature branch + PR, never push to `main`; commit prefixes — `feat:` only for user-visible changes, `chore:`/`test:`/`refactor:` for the rest. - ---- - -## File Structure - -**New package `protlabel` (`src/protlabel/`):** -- `src/protlabel/__init__.py` — public API: `eat`, `Lookup`, `Prediction`. -- `src/protlabel/reliability.py` — `similarity()`, the goPredSim distance→`[0,1]` transform. -- `src/protlabel/backends.py` — `nearest()`, chunked brute-force kNN (Euclidean/cosine). -- `src/protlabel/transfer.py` — `Prediction` dataclass + `eat()`/`transfer_labels()` (kNN → RI → label). -- `src/protlabel/lookup.py` — `Lookup` dataclass: build / `save` / `load` (sidecar `.npz`) / `query`. - -**ProtSpace additions/changes:** -- Create `src/protspace/analysis/__init__.py`, `src/protspace/analysis/classification.py` — query/reference classifier. -- Create `src/protspace/data/io/predictions.py` — build overlay columns from predictions. -- Create `src/protspace/cli/transfer.py` — `run_transfer()` (pure) + the `transfer` Typer command. -- Modify `src/protspace/data/io/bundle.py` — add `replace_annotations_in_bundle()`. -- Modify `src/protspace/cli/app.py:65-73` — register the `transfer` subcommand. -- Modify `pyproject.toml` — add `[tool.hatch.build.targets.wheel] packages` (both packages); add `scipy` via `uv add`. - -**Tests (all under `tests/`, run with `uv run pytest tests/`):** -- `tests/test_protlabel_reliability.py`, `tests/test_protlabel_backends.py`, `tests/test_protlabel_transfer.py`, `tests/test_protlabel_lookup.py` -- `tests/test_classification.py`, `tests/test_predictions_overlay.py`, `tests/test_bundle_overlay.py`, `tests/test_transfer_cli.py` - -**Docs:** -- Modify `docs/cli.md`, `docs/annotations.md`, top-level `../CLAUDE.md` CLI table; add `notebooks/ProtSpace_Transfer.ipynb`. - ---- - -## Task 1: Scaffold the `protlabel` package - -**Files:** -- Create: `src/protlabel/__init__.py` -- Modify: `pyproject.toml` (hatchling packages + scipy dep) -- Test: `tests/test_protlabel_reliability.py` (placeholder import test, expanded in Task 2) - -- [ ] **Step 1: Create the package marker with a version** - -Create `src/protlabel/__init__.py`: - -```python -"""protlabel — Embedding Annotation Transfer (EAT) engine. - -Nearest-neighbour label transfer in protein-language-model embedding space, -with the goPredSim reliability index. Pure numpy/scipy/h5py; no protspace imports. -""" - -__version__ = "0.1.0" -``` - -- [ ] **Step 2: Tell hatchling to build both packages** - -Edit `pyproject.toml`. After the `[build-system]` block (around line 71), add: - -```toml -[tool.hatch.build.targets.wheel] -packages = ["src/protspace", "src/protlabel"] -``` - -- [ ] **Step 3: Add the scipy dependency (via uv, not by hand)** - -Run: `uv add 'scipy>=1.10'` -Expected: `pyproject.toml` gains `scipy>=1.10` in `[project.dependencies]` and `uv.lock` updates. - -- [ ] **Step 4: Sync so `import protlabel` resolves** - -Run: `uv sync` -Then verify: `uv run python -c "import protlabel; print(protlabel.__version__)"` -Expected: prints `0.1.0` - -- [ ] **Step 5: Write a smoke test** - -Create `tests/test_protlabel_reliability.py`: - -```python -"""Tests for protlabel.reliability.""" - - -def test_protlabel_imports(): - import protlabel - - assert protlabel.__version__ -``` - -- [ ] **Step 6: Run the smoke test** - -Run: `uv run pytest tests/test_protlabel_reliability.py -v` -Expected: PASS - -- [ ] **Step 7: Commit** - -```bash -git add src/protlabel/__init__.py pyproject.toml uv.lock tests/test_protlabel_reliability.py -git commit -m "chore(protlabel): scaffold EAT engine package + scipy dep" -``` - ---- - -## Task 2: Reliability index (`protlabel.reliability`) - -**Files:** -- Create: `src/protlabel/reliability.py` -- Test: `tests/test_protlabel_reliability.py` - -- [ ] **Step 1: Write the failing tests** - -Replace `tests/test_protlabel_reliability.py` with: - -```python -"""Tests for protlabel.reliability.""" - -import math - -import pytest - -from protlabel.reliability import similarity - - -def test_euclidean_at_zero_distance_is_one(): - assert similarity(0.0, "euclidean") == pytest.approx(1.0) - - -def test_euclidean_at_half_distance_is_half(): - assert similarity(0.5, "euclidean") == pytest.approx(0.5) - - -def test_euclidean_decreases_to_zero(): - assert similarity(1e9, "euclidean") == pytest.approx(0.0, abs=1e-6) - - -def test_cosine_is_one_minus_distance(): - assert similarity(0.2, "cosine") == pytest.approx(0.8) - - -def test_cosine_clamped_to_unit_interval(): - # cosine distance can be up to 2.0 -> 1 - d would go negative; clamp at 0 - assert similarity(1.7, "cosine") == pytest.approx(0.0) - assert similarity(-0.1, "cosine") == pytest.approx(1.0) - - -def test_unknown_metric_raises(): - with pytest.raises(ValueError): - similarity(0.5, "manhattan") - - -def test_smoke(): - assert math.isfinite(similarity(0.5, "euclidean")) -``` - -- [ ] **Step 2: Run to verify failure** - -Run: `uv run pytest tests/test_protlabel_reliability.py -v` -Expected: FAIL — `ImportError: cannot import name 'similarity'` - -- [ ] **Step 3: Implement** - -Create `src/protlabel/reliability.py`: - -```python -"""goPredSim reliability index: map an embedding distance to a [0,1] confidence. - -Euclidean: s(d) = 0.5 / (0.5 + d) (1.0 at d=0, 0.5 at d=0.5, ->0 as d->inf) -Cosine: s(d) = 1 - d (clamped to [0,1]; cosine distance in [0,2]) - -Reference: Littmann et al., Sci Rep 2021 (Eq. 5); goPredSim calc_reliability_index. -""" - -from __future__ import annotations - - -def similarity(distance: float, metric: str) -> float: - """Per-neighbour distance->similarity (the goPredSim reliability transform).""" - if metric == "euclidean": - return 0.5 / (0.5 + distance) - if metric == "cosine": - return min(1.0, max(0.0, 1.0 - distance)) - raise ValueError(f"Unknown metric {metric!r}; expected 'euclidean' or 'cosine'") -``` - -- [ ] **Step 4: Run to verify pass** - -Run: `uv run pytest tests/test_protlabel_reliability.py -v` -Expected: PASS (all 7) - -- [ ] **Step 5: Lint + commit** - -```bash -uv run ruff check src/protlabel/reliability.py tests/test_protlabel_reliability.py -git add src/protlabel/reliability.py tests/test_protlabel_reliability.py -git commit -m "feat(protlabel): goPredSim reliability index transform" -``` - ---- - -## Task 3: Brute-force kNN backend (`protlabel.backends`) - -**Files:** -- Create: `src/protlabel/backends.py` -- Test: `tests/test_protlabel_backends.py` - -- [ ] **Step 1: Write the failing tests** - -Create `tests/test_protlabel_backends.py`: - -```python -"""Tests for protlabel.backends.nearest.""" - -import numpy as np -import pytest - -from protlabel.backends import nearest - - -def _toy(): - # 3 references on a line; queries close to ref 0 and ref 2 - refs = np.array([[0.0, 0.0], [10.0, 0.0], [20.0, 0.0]], dtype=np.float32) - queries = np.array([[0.1, 0.0], [19.5, 0.0]], dtype=np.float32) - return queries, refs - - -def test_returns_shapes(): - queries, refs = _toy() - idx, dist = nearest(queries, refs, k=2, metric="euclidean") - assert idx.shape == (2, 2) - assert dist.shape == (2, 2) - - -def test_nearest_index_euclidean(): - queries, refs = _toy() - idx, dist = nearest(queries, refs, k=1, metric="euclidean") - assert idx[0, 0] == 0 # first query nearest to ref 0 - assert idx[1, 0] == 2 # second query nearest to ref 2 - assert dist[0, 0] == pytest.approx(0.1, abs=1e-4) - - -def test_neighbours_sorted_by_distance(): - queries, refs = _toy() - idx, dist = nearest(queries, refs, k=3, metric="euclidean") - assert np.all(np.diff(dist, axis=1) >= -1e-6) # non-decreasing per row - - -def test_cosine_metric_runs_and_orders(): - refs = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32) - queries = np.array([[1.0, 0.1]], dtype=np.float32) # closest in angle to ref 0 - idx, dist = nearest(queries, refs, k=1, metric="cosine") - assert idx[0, 0] == 0 - - -def test_k_capped_to_num_refs(): - queries, refs = _toy() - idx, dist = nearest(queries, refs, k=10, metric="euclidean") - assert idx.shape == (2, 3) # only 3 refs available - - -def test_chunking_matches_unchunked(): - rng = np.random.default_rng(0) - refs = rng.standard_normal((50, 8)).astype(np.float32) - queries = rng.standard_normal((7, 8)).astype(np.float32) - a_idx, a_dist = nearest(queries, refs, k=3, metric="euclidean", chunk=1000) - b_idx, b_dist = nearest(queries, refs, k=3, metric="euclidean", chunk=3) - assert np.array_equal(a_idx, b_idx) - assert np.allclose(a_dist, b_dist, atol=1e-5) - - -def test_unknown_metric_raises(): - queries, refs = _toy() - with pytest.raises(ValueError): - nearest(queries, refs, k=1, metric="manhattan") -``` - -- [ ] **Step 2: Run to verify failure** - -Run: `uv run pytest tests/test_protlabel_backends.py -v` -Expected: FAIL — `ModuleNotFoundError: No module named 'protlabel.backends'` - -- [ ] **Step 3: Implement** - -Create `src/protlabel/backends.py`: - -```python -"""Exact (brute-force) k-nearest-neighbour search over reference embeddings. - -Chunked over the query axis so the Q_chunk x N distance block stays small, -which keeps peak memory near the reference matrix itself even at Swiss-Prot -scale. scipy.cdist handles both euclidean and cosine. -""" - -from __future__ import annotations - -import numpy as np -from scipy.spatial.distance import cdist - -_METRICS = {"euclidean", "cosine"} - - -def nearest( - queries: np.ndarray, - refs: np.ndarray, - k: int, - metric: str = "euclidean", - chunk: int = 4096, -) -> tuple[np.ndarray, np.ndarray]: - """Return (idx, dist) of the k nearest *references* per query. - - idx[i] -> indices into ``refs`` of the k nearest, ascending by distance. - dist[i] -> the corresponding distances. - k is capped to the number of references. - """ - if metric not in _METRICS: - raise ValueError(f"Unknown metric {metric!r}; expected one of {_METRICS}") - - queries = np.ascontiguousarray(queries, dtype=np.float32) - refs = np.ascontiguousarray(refs, dtype=np.float32) - n_refs = refs.shape[0] - k = min(k, n_refs) - - idx_out = np.empty((queries.shape[0], k), dtype=np.int64) - dist_out = np.empty((queries.shape[0], k), dtype=np.float32) - - for start in range(0, queries.shape[0], chunk): - block = queries[start : start + chunk] - d = cdist(block, refs, metric=metric).astype(np.float32) # (b, n_refs) - part = np.argpartition(d, kth=k - 1, axis=1)[:, :k] # unsorted top-k - rows = np.arange(block.shape[0])[:, None] - part_d = d[rows, part] - order = np.argsort(part_d, axis=1) # sort the k by distance - sorted_idx = part[rows, order] - idx_out[start : start + block.shape[0]] = sorted_idx - dist_out[start : start + block.shape[0]] = d[rows, sorted_idx] - - return idx_out, dist_out -``` - -- [ ] **Step 4: Run to verify pass** - -Run: `uv run pytest tests/test_protlabel_backends.py -v` -Expected: PASS (all 7) - -- [ ] **Step 5: Lint + commit** - -```bash -uv run ruff check src/protlabel/backends.py tests/test_protlabel_backends.py -git add src/protlabel/backends.py tests/test_protlabel_backends.py -git commit -m "feat(protlabel): chunked brute-force kNN backend" -``` - ---- - -## Task 4: Label transfer + Prediction (`protlabel.transfer`) - -**Files:** -- Create: `src/protlabel/transfer.py` -- Test: `tests/test_protlabel_transfer.py` - -- [ ] **Step 1: Write the failing tests** - -Create `tests/test_protlabel_transfer.py`: - -```python -"""Tests for protlabel.transfer.""" - -import numpy as np -import pytest - -from protlabel.transfer import Prediction, eat - - -def _setup(): - ref_emb = np.array([[0.0, 0.0], [10.0, 0.0], [20.0, 0.0]], dtype=np.float32) - ref_ids = ["R0", "R1", "R2"] - ref_labels = ["toxin", "enzyme", "toxin"] - query_emb = np.array([[0.0, 0.0], [19.7, 0.0]], dtype=np.float32) - query_ids = ["Q0", "Q1"] - return ref_emb, ref_ids, ref_labels, query_emb, query_ids - - -def test_k1_transfers_nearest_label_and_source(): - ref_emb, ref_ids, ref_labels, query_emb, query_ids = _setup() - preds = eat(query_emb, query_ids, ref_emb, ref_ids, ref_labels, k=1) - assert isinstance(preds[0], Prediction) - assert preds[0].query_id == "Q0" - assert preds[0].label == "toxin" - assert preds[0].source_id == "R0" - assert preds[0].reliability == pytest.approx(1.0) # distance 0 -> RI 1.0 - - -def test_k1_reliability_uses_gopredsim_transform(): - ref_emb, ref_ids, ref_labels, query_emb, query_ids = _setup() - preds = eat(query_emb, query_ids, ref_emb, ref_ids, ref_labels, k=1) - # Q1 distance to R2 is 0.3 -> RI = 0.5/(0.5+0.3) - assert preds[1].label == "toxin" - assert preds[1].source_id == "R2" - assert preds[1].reliability == pytest.approx(0.5 / 0.8, abs=1e-4) - - -def test_k3_vote_picks_majority_label(): - # Query equidistant-ish but two of three nearest are "toxin" - ref_emb = np.array( - [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]], dtype=np.float32 - ) - ref_ids = ["R0", "R1", "R2", "R3"] - ref_labels = ["toxin", "enzyme", "toxin", "toxin"] - query_emb = np.array([[1.4, 0.0]], dtype=np.float32) - preds = eat(query_emb, ["Q"], ref_emb, ref_ids, ref_labels, k=3) - assert preds[0].label == "toxin" # toxin RI sum beats lone enzyme neighbour - assert 0.0 < preds[0].reliability <= 1.0 - - -def test_cosine_metric(): - ref_emb = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32) - preds = eat( - np.array([[1.0, 0.05]], dtype=np.float32), - ["Q"], - ref_emb, - ["R0", "R1"], - ["a", "b"], - k=1, - metric="cosine", - ) - assert preds[0].label == "a" - - -def test_length_mismatch_raises(): - ref_emb, ref_ids, ref_labels, query_emb, query_ids = _setup() - with pytest.raises(ValueError): - eat(query_emb, query_ids, ref_emb, ref_ids, ref_labels[:-1], k=1) - - -def test_empty_references_raises(): - with pytest.raises(ValueError): - eat( - np.zeros((1, 2), dtype=np.float32), - ["Q"], - np.zeros((0, 2), dtype=np.float32), - [], - [], - k=1, - ) -``` - -- [ ] **Step 2: Run to verify failure** - -Run: `uv run pytest tests/test_protlabel_transfer.py -v` -Expected: FAIL — `ModuleNotFoundError: No module named 'protlabel.transfer'` - -- [ ] **Step 3: Implement** - -Create `src/protlabel/transfer.py`: - -```python -"""Embedding annotation transfer: kNN -> reliability index -> transferred label. - -Implements the goPredSim aggregation (Littmann et al. 2021, Eq. 5): - RI(p) = (1/k) * sum over neighbours carrying label p of similarity(d). -The transferred label is argmax RI(p); its source is the nearest neighbour -carrying that label. -""" - -from __future__ import annotations - -from dataclasses import dataclass - -import numpy as np - -from protlabel.backends import nearest -from protlabel.reliability import similarity - - -@dataclass(frozen=True) -class Prediction: - """One transferred annotation for a query protein.""" - - query_id: str - label: str - source_id: str - distance: float - reliability: float - k: int - metric: str - - -def eat( - query_emb: np.ndarray, - query_ids: list[str], - ref_emb: np.ndarray, - ref_ids: list[str], - ref_labels: list[str], - *, - k: int = 1, - metric: str = "euclidean", -) -> list[Prediction]: - """Transfer the best-guess label to each query from its k nearest references.""" - if not (len(ref_ids) == len(ref_labels) == ref_emb.shape[0]): - raise ValueError("ref_emb, ref_ids and ref_labels must have equal length") - if ref_emb.shape[0] == 0: - raise ValueError("No reference embeddings to transfer from") - if len(query_ids) != query_emb.shape[0]: - raise ValueError("query_emb and query_ids must have equal length") - - idx, dist = nearest(query_emb, ref_emb, k=k, metric=metric) - eff_k = idx.shape[1] - predictions: list[Prediction] = [] - - for qi, query_id in enumerate(query_ids): - neigh_idx = idx[qi] - neigh_dist = dist[qi] - # Accumulate RI per label and track the nearest source per label. - ri_by_label: dict[str, float] = {} - nearest_src: dict[str, tuple[float, str]] = {} - for j, ref_i in enumerate(neigh_idx): - lab = ref_labels[ref_i] - d = float(neigh_dist[j]) - ri_by_label[lab] = ri_by_label.get(lab, 0.0) + similarity(d, metric) - if lab not in nearest_src or d < nearest_src[lab][0]: - nearest_src[lab] = (d, ref_ids[ref_i]) - # Normalise by k (the goPredSim 1/k term). - best_label = max(ri_by_label, key=lambda p: ri_by_label[p]) - ri = ri_by_label[best_label] / eff_k - src_dist, src_id = nearest_src[best_label] - predictions.append( - Prediction( - query_id=query_id, - label=best_label, - source_id=src_id, - distance=src_dist, - reliability=ri, - k=eff_k, - metric=metric, - ) - ) - - return predictions -``` - -- [ ] **Step 4: Run to verify pass** - -Run: `uv run pytest tests/test_protlabel_transfer.py -v` -Expected: PASS (all 6) - -- [ ] **Step 5: Lint + commit** - -```bash -uv run ruff check src/protlabel/transfer.py tests/test_protlabel_transfer.py -git add src/protlabel/transfer.py tests/test_protlabel_transfer.py -git commit -m "feat(protlabel): kNN label transfer with reliability index" -``` - ---- - -## Task 5: Persistable lookup (`protlabel.lookup`) + public API - -**Files:** -- Create: `src/protlabel/lookup.py` -- Modify: `src/protlabel/__init__.py` -- Test: `tests/test_protlabel_lookup.py` - -- [ ] **Step 1: Write the failing tests** - -Create `tests/test_protlabel_lookup.py`: - -```python -"""Tests for protlabel.lookup.Lookup (the rebuildable sidecar).""" - -import numpy as np - -from protlabel import Lookup, Prediction - - -def _lookup(): - emb = np.array([[0.0, 0.0], [10.0, 0.0]], dtype=np.float32) - return Lookup(embeddings=emb, ids=["R0", "R1"], labels=["a", "b"]) - - -def test_query_returns_predictions(): - lk = _lookup() - preds = lk.query(np.array([[0.2, 0.0]], dtype=np.float32), ["Q0"], k=1) - assert isinstance(preds[0], Prediction) - assert preds[0].label == "a" - assert preds[0].source_id == "R0" - - -def test_save_load_roundtrip(tmp_path): - lk = _lookup() - path = tmp_path / "lookup.npz" - lk.save(path) - assert path.exists() - loaded = Lookup.load(path) - assert loaded.ids == lk.ids - assert loaded.labels == lk.labels - assert loaded.metric == lk.metric - assert np.allclose(loaded.embeddings, lk.embeddings) - - -def test_loaded_lookup_queries_identically(tmp_path): - lk = _lookup() - q = np.array([[9.8, 0.0]], dtype=np.float32) - before = lk.query(q, ["Q"], k=1) - path = tmp_path / "lk.npz" - lk.save(path) - after = Lookup.load(path).query(q, ["Q"], k=1) - assert before[0].label == after[0].label == "b" - assert before[0].reliability == after[0].reliability -``` - -- [ ] **Step 2: Run to verify failure** - -Run: `uv run pytest tests/test_protlabel_lookup.py -v` -Expected: FAIL — `ImportError: cannot import name 'Lookup' from 'protlabel'` - -- [ ] **Step 3: Implement the Lookup** - -Create `src/protlabel/lookup.py`: - -```python -"""A persistable reference lookup: embeddings + ids + labels, plus query(). - -Serialized as a single .npz sidecar so it can live next to a bundle or in a -cache dir and be rebuilt on demand from the source HDF5. -""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from pathlib import Path - -import numpy as np - -from protlabel.transfer import Prediction, eat - - -@dataclass -class Lookup: - """Reference set for embedding annotation transfer.""" - - embeddings: np.ndarray # (N, D) float32 - ids: list[str] - labels: list[str] - metric: str = "euclidean" - model: str = field(default="") - - def query( - self, query_emb: np.ndarray, query_ids: list[str], *, k: int = 1 - ) -> list[Prediction]: - """Transfer labels to query embeddings from this lookup.""" - return eat( - query_emb, - query_ids, - self.embeddings, - self.ids, - self.labels, - k=k, - metric=self.metric, - ) - - def save(self, path: Path) -> None: - """Serialize to a .npz sidecar.""" - path = Path(path) - path.parent.mkdir(parents=True, exist_ok=True) - np.savez( - path, - embeddings=self.embeddings.astype(np.float16), - ids=np.array(self.ids, dtype=object), - labels=np.array(self.labels, dtype=object), - metric=self.metric, - model=self.model, - ) - - @classmethod - def load(cls, path: Path) -> Lookup: - """Load a .npz sidecar (re-upcasts embeddings to float32).""" - with np.load(path, allow_pickle=True) as data: - return cls( - embeddings=data["embeddings"].astype(np.float32), - ids=list(data["ids"]), - labels=list(data["labels"]), - metric=str(data["metric"]), - model=str(data["model"]), - ) -``` - -> Note: `np.savez` appends `.npz` if the path has no extension; the tests use an explicit `.npz` so the saved file matches the requested path. - -- [ ] **Step 4: Export the public API** - -Replace `src/protlabel/__init__.py` with: - -```python -"""protlabel — Embedding Annotation Transfer (EAT) engine. - -Nearest-neighbour label transfer in protein-language-model embedding space, -with the goPredSim reliability index. Pure numpy/scipy/h5py; no protspace imports. -""" - -from protlabel.lookup import Lookup -from protlabel.transfer import Prediction, eat - -__version__ = "0.1.0" - -__all__ = ["Lookup", "Prediction", "eat", "__version__"] -``` - -- [ ] **Step 5: Run to verify pass** - -Run: `uv run pytest tests/test_protlabel_lookup.py -v` -Expected: PASS (all 3) - -- [ ] **Step 6: Run the whole protlabel suite + guard the boundary** - -Run: `uv run pytest tests/test_protlabel_*.py -v` -Expected: PASS -Run: `! grep -rqE "import protspace|from protspace" src/protlabel/ && echo "boundary clean"` -Expected: prints `boundary clean` (protlabel must not import protspace) - -- [ ] **Step 7: Lint + commit** - -```bash -uv run ruff check src/protlabel/ tests/test_protlabel_lookup.py -git add src/protlabel/lookup.py src/protlabel/__init__.py tests/test_protlabel_lookup.py -git commit -m "feat(protlabel): persistable Lookup sidecar + public API" -``` - ---- - -## Task 6: Query/reference classifier (`protspace.analysis.classification`) - -**Files:** -- Create: `src/protspace/analysis/__init__.py`, `src/protspace/analysis/classification.py` -- Test: `tests/test_classification.py` - -- [ ] **Step 1: Write the failing tests** - -Create `tests/test_classification.py`: - -```python -"""Tests for the query/reference classifier.""" - -import pyarrow as pa -import pytest - -from protspace.analysis.classification import Rule, classify - - -def _table(): - return pa.table( - { - "identifier": ["TRINITY_1", "TRINITY_2", "P00001", "P00002"], - "protein_category": ["mSCR", "mSCR", "neurotoxin", "enzyme"], - } - ) - - -def test_prefix_rule_selects_queries(): - q = Rule(id_prefixes=["TRINITY_"]) - r = Rule(where=[("protein_category", "neurotoxin")]) - qi, ri = classify(_table(), q, r) - assert qi == [0, 1] - assert ri == [2] - - -def test_where_substring_is_case_insensitive(): - q = Rule(where=[("protein_category", "MSCR")]) - r = Rule(id_prefixes=["P0"]) - qi, ri = classify(_table(), q, r) - assert qi == [0, 1] - assert ri == [2, 3] - - -def test_query_takes_precedence_over_reference(): - # A protein matching both rules is classified as a query, never a reference. - q = Rule(id_prefixes=["P00001"]) - r = Rule(where=[("protein_category", "neurotoxin")]) - qi, ri = classify(_table(), q, r) - assert 2 in qi - assert 2 not in ri - - -def test_empty_query_match_raises(): - q = Rule(id_prefixes=["NOPE_"]) - r = Rule(id_prefixes=["P0"]) - with pytest.raises(ValueError, match="no query"): - classify(_table(), q, r) - - -def test_missing_where_column_raises(): - q = Rule(where=[("not_a_column", "x")]) - r = Rule(id_prefixes=["P0"]) - with pytest.raises(KeyError): - classify(_table(), q, r) -``` - -- [ ] **Step 2: Run to verify failure** - -Run: `uv run pytest tests/test_classification.py -v` -Expected: FAIL — `ModuleNotFoundError: No module named 'protspace.analysis'` - -- [ ] **Step 3: Implement** - -Create `src/protspace/analysis/__init__.py`: - -```python -"""Optional analysis layer for ProtSpace (classification, gating, mining).""" -``` - -Create `src/protspace/analysis/classification.py`: - -```python -"""Classify proteins as transfer queries vs annotated references. - -Rules match by ID prefix and/or a case-insensitive metadata substring -(``column ~ substring``). No biology is hardcoded; a query rule that matches -nothing is an error. -""" - -from __future__ import annotations - -from dataclasses import dataclass, field - -import pyarrow as pa - - -@dataclass -class Rule: - """A classification rule. A protein matches if ANY clause matches.""" - - id_prefixes: list[str] = field(default_factory=list) - where: list[tuple[str, str]] = field(default_factory=list) # (column, substring) - - -def _matches(rule: Rule, identifier: str, row: dict[str, str]) -> bool: - if any(identifier.startswith(p) for p in rule.id_prefixes): - return True - for column, substring in rule.where: - if column not in row: - raise KeyError(f"Classification column {column!r} not in annotations") - value = row[column] - if value is not None and substring.lower() in str(value).lower(): - return True - return False - - -def classify( - annotations: pa.Table, query_rule: Rule, reference_rule: Rule -) -> tuple[list[int], list[int]]: - """Return (query_indices, reference_indices) into the annotations table. - - Query classification takes precedence: a protein matching both rules is a - query. Raises ValueError if the query rule matches nothing. - """ - columns = set(annotations.column_names) - # Validate where-columns up front so an empty table still raises KeyError. - for rule in (query_rule, reference_rule): - for column, _ in rule.where: - if column not in columns: - raise KeyError(f"Classification column {column!r} not in annotations") - - rows = annotations.to_pylist() - identifiers = [str(r["identifier"]) for r in rows] - - query_indices: list[int] = [] - reference_indices: list[int] = [] - for i, (identifier, row) in enumerate(zip(identifiers, rows, strict=True)): - if _matches(query_rule, identifier, row): - query_indices.append(i) - elif _matches(reference_rule, identifier, row): - reference_indices.append(i) - - if not query_indices: - raise ValueError( - "Classifier matched no query proteins; check --query-id-prefix / " - "--query-where rules." - ) - return query_indices, reference_indices -``` - -- [ ] **Step 4: Run to verify pass** - -Run: `uv run pytest tests/test_classification.py -v` -Expected: PASS (all 5) - -- [ ] **Step 5: Lint + commit** - -```bash -uv run ruff check src/protspace/analysis/ tests/test_classification.py -git add src/protspace/analysis/ tests/test_classification.py -git commit -m "feat: query/reference classifier for annotation transfer" -``` - ---- - -## Task 7: Build the overlay columns (`protspace.data.io.predictions`) - -**Files:** -- Create: `src/protspace/data/io/predictions.py` -- Test: `tests/test_predictions_overlay.py` - -- [ ] **Step 1: Write the failing tests** - -Create `tests/test_predictions_overlay.py`: - -```python -"""Tests for building the per-cell prediction overlay columns.""" - -import pyarrow as pa - -from protlabel import Prediction -from protspace.data.io.predictions import add_overlay_columns - - -def _table(): - return pa.table( - { - "identifier": ["Q0", "Q1", "R0"], - "protein_category": ["", "", "neurotoxin"], - } - ) - - -def test_adds_three_overlay_columns(): - preds = [ - Prediction("Q0", "neurotoxin", "R0", 0.3, 0.62, 1, "euclidean"), - ] - out = add_overlay_columns(_table(), "protein_category", preds) - assert "protein_category__pred_value" in out.column_names - assert "protein_category__pred_confidence" in out.column_names - assert "protein_category__pred_source" in out.column_names - - -def test_overlay_values_aligned_by_identifier(): - preds = [Prediction("Q1", "enzyme", "R9", 0.5, 0.5, 1, "euclidean")] - out = add_overlay_columns(_table(), "protein_category", preds).to_pylist() - by_id = {r["identifier"]: r for r in out} - assert by_id["Q1"]["protein_category__pred_value"] == "enzyme" - assert by_id["Q1"]["protein_category__pred_confidence"] == 0.5 - assert by_id["Q1"]["protein_category__pred_source"] == "R9" - # Non-predicted rows are null in the overlay columns. - assert by_id["Q0"]["protein_category__pred_value"] is None - assert by_id["R0"]["protein_category__pred_confidence"] is None - - -def test_curated_column_is_left_untouched(): - preds = [Prediction("Q0", "neurotoxin", "R0", 0.1, 0.8, 1, "euclidean")] - out = add_overlay_columns(_table(), "protein_category", preds).to_pylist() - by_id = {r["identifier"]: r for r in out} - assert by_id["Q0"]["protein_category"] == "" # original column unchanged - assert by_id["R0"]["protein_category"] == "neurotoxin" - - -def test_confidence_column_is_float(): - preds = [Prediction("Q0", "x", "R0", 0.1, 0.83, 1, "euclidean")] - out = add_overlay_columns(_table(), "protein_category", preds) - field = out.schema.field("protein_category__pred_confidence") - assert pa.types.is_floating(field.type) -``` - -- [ ] **Step 2: Run to verify failure** - -Run: `uv run pytest tests/test_predictions_overlay.py -v` -Expected: FAIL — `ModuleNotFoundError: No module named 'protspace.data.io.predictions'` - -- [ ] **Step 3: Implement** - -Create `src/protspace/data/io/predictions.py`: - -```python -"""Turn protlabel Predictions into per-cell overlay columns on the annotations table. - -For a transferred column ``COL`` we append three aligned columns (null for -non-predicted proteins), leaving the curated ``COL`` untouched: - COL__pred_value (string) the transferred label - COL__pred_confidence (float32) the reliability index in [0, 1] - COL__pred_source (string) the nearest reference protein id -""" - -from __future__ import annotations - -from collections.abc import Sequence - -import pyarrow as pa - -from protlabel import Prediction - - -def add_overlay_columns( - annotations: pa.Table, column: str, predictions: Sequence[Prediction] -) -> pa.Table: - """Append the COL__pred_* overlay columns, aligned by identifier.""" - by_query = {p.query_id: p for p in predictions} - identifiers = [str(v) for v in annotations.column("identifier").to_pylist()] - - values: list[str | None] = [] - confidences: list[float | None] = [] - sources: list[str | None] = [] - for identifier in identifiers: - pred = by_query.get(identifier) - if pred is None: - values.append(None) - confidences.append(None) - sources.append(None) - else: - values.append(pred.label) - confidences.append(float(pred.reliability)) - sources.append(pred.source_id) - - out = annotations - out = out.append_column(f"{column}__pred_value", pa.array(values, pa.string())) - out = out.append_column( - f"{column}__pred_confidence", pa.array(confidences, pa.float32()) - ) - out = out.append_column(f"{column}__pred_source", pa.array(sources, pa.string())) - return out -``` - -- [ ] **Step 4: Run to verify pass** - -Run: `uv run pytest tests/test_predictions_overlay.py -v` -Expected: PASS (all 4) - -- [ ] **Step 5: Lint + commit** - -```bash -uv run ruff check src/protspace/data/io/predictions.py tests/test_predictions_overlay.py -git add src/protspace/data/io/predictions.py tests/test_predictions_overlay.py -git commit -m "feat: build per-cell prediction overlay columns" -``` - ---- - -## Task 8: Rewrite the annotations part of a bundle (`bundle.replace_annotations_in_bundle`) - -**Files:** -- Modify: `src/protspace/data/io/bundle.py` -- Test: `tests/test_bundle_overlay.py` - -- [ ] **Step 1: Write the failing tests** - -Create `tests/test_bundle_overlay.py`: - -```python -"""Round-trip tests for replacing the annotations part of a bundle.""" - -import io - -import pyarrow as pa -import pyarrow.parquet as pq - -from protspace.data.io.bundle import ( - read_bundle, - replace_annotations_in_bundle, - write_bundle, -) - - -def _tables(): - annotations = pa.table({"identifier": ["A", "B"], "cat": ["x", "y"]}) - proj_meta = pa.table({"name": ["PCA 2"], "dims": [2]}) - proj_data = pa.table({"id": ["A", "B"], "x": [0.0, 1.0], "y": [0.0, 1.0]}) - return [annotations, proj_meta, proj_data] - - -def _read_part(part_bytes): - return pq.read_table(io.BytesIO(part_bytes)) - - -def test_replaces_annotations_keeps_other_parts(tmp_path): - src = tmp_path / "in.parquetbundle" - out = tmp_path / "out.parquetbundle" - write_bundle(_tables(), src) - - new_annotations = pa.table( - {"identifier": ["A", "B"], "cat": ["x", "y"], "cat__pred_value": [None, "z"]} - ) - replace_annotations_in_bundle(src, out, new_annotations) - - parts, settings = read_bundle(out) - assert "cat__pred_value" in _read_part(parts[0]).column_names - # Projections preserved byte-for-byte. - assert _read_part(parts[1]).column_names == ["name", "dims"] - assert _read_part(parts[2]).to_pydict()["x"] == [0.0, 1.0] - - -def test_preserves_settings_when_present(tmp_path): - src = tmp_path / "in.parquetbundle" - out = tmp_path / "out.parquetbundle" - write_bundle(_tables(), src, settings={"foo": 1}) - - new_annotations = pa.table({"identifier": ["A", "B"], "cat": ["x", "y"]}) - replace_annotations_in_bundle(src, out, new_annotations) - - _parts, settings = read_bundle(out) - assert settings == {"foo": 1} -``` - -- [ ] **Step 2: Run to verify failure** - -Run: `uv run pytest tests/test_bundle_overlay.py -v` -Expected: FAIL — `ImportError: cannot import name 'replace_annotations_in_bundle'` - -- [ ] **Step 3: Implement** - -Add to `src/protspace/data/io/bundle.py` (after `replace_settings_in_bundle`, around line 149): - -```python -def replace_annotations_in_bundle( - input_path: Path, - output_path: Path, - annotations_table: pa.Table, -) -> None: - """Replace the annotations (1st) part of a bundle, preserving the rest. - - Projection parts (2nd, 3rd) are kept byte-for-byte; an existing settings - (4th) part is carried over unchanged. - """ - with open(input_path, "rb") as f: - content = f.read() - - parts = content.split(PARQUET_BUNDLE_DELIMITER) - if len(parts) < 3 or len(parts) > 4: - raise ValueError(f"Expected 3 or 4 parts in parquetbundle, found {len(parts)}") - - buf = io.BytesIO() - pq.write_table(annotations_table, buf) - new_parts = [buf.getvalue(), parts[1], parts[2]] - if len(parts) == 4: - new_parts.append(parts[3]) - - output_path.parent.mkdir(parents=True, exist_ok=True) - with open(output_path, "wb") as f: - f.write(PARQUET_BUNDLE_DELIMITER.join(new_parts)) - - logger.info(f"Wrote bundle with updated annotations to: {output_path}") -``` - -- [ ] **Step 4: Run to verify pass** - -Run: `uv run pytest tests/test_bundle_overlay.py -v` -Expected: PASS (both) - -- [ ] **Step 5: Lint + commit** - -```bash -uv run ruff check src/protspace/data/io/bundle.py tests/test_bundle_overlay.py -git add src/protspace/data/io/bundle.py tests/test_bundle_overlay.py -git commit -m "feat: replace annotations part of a parquetbundle in place" -``` - ---- - -## Task 9: The `transfer` orchestration core + CLI (`protspace.cli.transfer`) - -**Files:** -- Create: `src/protspace/cli/transfer.py` -- Modify: `src/protspace/cli/app.py:65-73` -- Test: `tests/test_transfer_cli.py` - -- [ ] **Step 1: Write the failing tests (pure core + registration)** - -Create `tests/test_transfer_cli.py`: - -```python -"""Tests for the transfer orchestration core and CLI registration.""" - -import numpy as np -import pyarrow as pa -import pytest - -from protspace.analysis.classification import Rule -from protspace.cli.transfer import run_transfer - - -def _inputs(): - annotations = pa.table( - { - "identifier": ["TRINITY_1", "P00001", "P00002"], - "protein_category": ["", "neurotoxin", "enzyme"], - } - ) - # TRINITY_1 sits right on top of the neurotoxin reference P00001. - embeddings = { - "TRINITY_1": np.array([0.0, 0.0], dtype=np.float32), - "P00001": np.array([0.05, 0.0], dtype=np.float32), - "P00002": np.array([9.0, 0.0], dtype=np.float32), - } - return annotations, embeddings - - -def test_run_transfer_predicts_for_query_with_missing_value(): - annotations, embeddings = _inputs() - out = run_transfer( - annotations=annotations, - embeddings=embeddings, - transfer_columns=["protein_category"], - query_rule=Rule(id_prefixes=["TRINITY_"]), - reference_rule=Rule(where=[("protein_category", "")]), # any non-empty ref - k=1, - metric="euclidean", - ) - by_id = {r["identifier"]: r for r in out.to_pylist()} - assert by_id["TRINITY_1"]["protein_category__pred_value"] == "neurotoxin" - assert by_id["TRINITY_1"]["protein_category__pred_source"] == "P00001" - assert by_id["TRINITY_1"]["protein_category__pred_confidence"] > 0.9 - - -def test_run_transfer_skips_proteins_without_embeddings(): - annotations, embeddings = _inputs() - embeddings.pop("TRINITY_1") # no embedding -> cannot be a query - with pytest.raises(ValueError, match="no query"): - run_transfer( - annotations=annotations, - embeddings=embeddings, - transfer_columns=["protein_category"], - query_rule=Rule(id_prefixes=["TRINITY_"]), - reference_rule=Rule(id_prefixes=["P0"]), - k=1, - metric="euclidean", - ) - - -def test_transfer_command_is_registered(): - from typer.testing import CliRunner - - from protspace.cli.app import app - - result = CliRunner().invoke(app, ["transfer", "--help"]) - assert result.exit_code == 0 - assert "transfer" in result.output.lower() -``` - -- [ ] **Step 2: Run to verify failure** - -Run: `uv run pytest tests/test_transfer_cli.py -v` -Expected: FAIL — `ModuleNotFoundError: No module named 'protspace.cli.transfer'` - -- [ ] **Step 3: Implement the core + command** - -Create `src/protspace/cli/transfer.py`: - -```python -"""protspace transfer — fill missing annotation values from nearest references. - -Embedding Annotation Transfer (EAT): for each query protein with a missing -value in a target column, transfer the value of its nearest annotated -reference in pLM embedding space, with a reliability-index confidence. -""" - -from __future__ import annotations - -import logging -from pathlib import Path -from typing import Annotated - -import numpy as np -import pyarrow as pa -import typer - -from protspace.cli.app import app, setup_logging -from protspace.cli.common_options import Opt_Verbose - -logger = logging.getLogger(__name__) - - -def _is_missing(value) -> bool: - return value is None or str(value).strip() == "" - - -def run_transfer( - *, - annotations: pa.Table, - embeddings: dict[str, np.ndarray], - transfer_columns: list[str], - query_rule, - reference_rule, - k: int = 1, - metric: str = "euclidean", -) -> pa.Table: - """Pure core: classify, transfer per column, return the augmented table. - - ``embeddings`` maps protein id -> 1-D float32 vector. Proteins without an - embedding cannot act as queries or references. - """ - from protlabel import eat - - from protspace.analysis.classification import classify - from protspace.data.io.predictions import add_overlay_columns - - # Restrict classification to proteins that actually have an embedding. - has_emb = pa.array( - [str(v) in embeddings for v in annotations.column("identifier").to_pylist()] - ) - embedded = annotations.filter(has_emb) - - query_idx, ref_idx = classify(embedded, query_rule, reference_rule) - rows = embedded.to_pylist() - - out = annotations - for column in transfer_columns: - if column not in annotations.column_names: - raise KeyError(f"Transfer column {column!r} not in annotations table") - - # References: classified refs that HAVE a value in this column. - ref_ids, ref_labels, ref_vecs = [], [], [] - for i in ref_idx: - value = rows[i].get(column) - if not _is_missing(value): - rid = str(rows[i]["identifier"]) - ref_ids.append(rid) - ref_labels.append(str(value)) - ref_vecs.append(embeddings[rid]) - if not ref_ids: - logger.warning("No references with a value for %r; skipping", column) - continue - - # Queries: classified queries MISSING a value in this column. - q_ids, q_vecs = [], [] - for i in query_idx: - if _is_missing(rows[i].get(column)): - qid = str(rows[i]["identifier"]) - q_ids.append(qid) - q_vecs.append(embeddings[qid]) - if not q_ids: - logger.warning("No queries missing %r; nothing to transfer", column) - continue - - preds = eat( - np.vstack(q_vecs), - q_ids, - np.vstack(ref_vecs), - ref_ids, - ref_labels, - k=k, - metric=metric, - ) - out = add_overlay_columns(out, column, preds) - logger.info("Transferred %r to %d quer(ies)", column, len(preds)) - - return out - - -@app.command() -def transfer( - bundle: Annotated[ - Path, - typer.Option("-b", "--bundle", help="Input .parquetbundle to annotate."), - ], - embeddings: Annotated[ - str, - typer.Option( - "-e", - "--embeddings", - help="HDF5 embeddings, optional :name suffix (e.g. emb.h5:prot_t5).", - ), - ], - transfer_columns: Annotated[ - list[str], - typer.Option("-t", "--transfer", help="Annotation column to transfer (repeat)."), - ], - output: Annotated[ - Path, - typer.Option("-o", "--output", help="Output .parquetbundle path."), - ], - query_id_prefix: Annotated[list[str], typer.Option("--query-id-prefix")] = None, - query_where: Annotated[list[str], typer.Option("--query-where", help="col~substr")] = None, - reference_id_prefix: Annotated[list[str], typer.Option("--reference-id-prefix")] = None, - reference_where: Annotated[list[str], typer.Option("--reference-where", help="col~substr")] = None, - k: Annotated[int, typer.Option("--k", help="Neighbours considered (default 1).")] = 1, - metric: Annotated[str, typer.Option("--metric", help="euclidean | cosine.")] = "euclidean", - verbose: Opt_Verbose = 0, -) -> None: - """Transfer annotations to query proteins from nearest reference neighbours.""" - setup_logging(verbose) - - import io - - import pyarrow.parquet as pq - - from protspace.analysis.classification import Rule - from protspace.data.io.bundle import read_bundle, replace_annotations_in_bundle - from protspace.data.loaders import load_h5 - - def _parse_where(items: list[str] | None) -> list[tuple[str, str]]: - clauses = [] - for item in items or []: - if "~" not in item: - raise typer.BadParameter(f"--*-where must be col~substr, got {item!r}") - col, sub = item.split("~", 1) - clauses.append((col, sub)) - return clauses - - query_rule = Rule(id_prefixes=query_id_prefix or [], where=_parse_where(query_where)) - reference_rule = Rule( - id_prefixes=reference_id_prefix or [], where=_parse_where(reference_where) - ) - - # Load embeddings (name override after ':'). - h5_spec = embeddings.split(":", 1) - h5_path = Path(h5_spec[0]) - name_override = h5_spec[1] if len(h5_spec) == 2 else None - emb_set = load_h5([h5_path], name_override=name_override) - emb_map = { - header: emb_set.data[i] for i, header in enumerate(emb_set.headers) - } - - # Read the annotations part of the bundle. - parts, _settings = read_bundle(bundle) - annotations = pq.read_table(io.BytesIO(parts[0])) - - augmented = run_transfer( - annotations=annotations, - embeddings=emb_map, - transfer_columns=transfer_columns, - query_rule=query_rule, - reference_rule=reference_rule, - k=k, - metric=metric, - ) - - replace_annotations_in_bundle(bundle, output, augmented) - logger.info("Wrote transferred bundle to %s", output) -``` - -- [ ] **Step 4: Register the subcommand** - -Edit `src/protspace/cli/app.py`. In `_register_commands()` (lines 65-73), add `transfer` to the import list (keep alphabetical): - -```python - from protspace.cli import ( # noqa: F401 - annotate, - bundle, - embed, - prepare, - project, - serve, - style, - transfer, - ) -``` - -- [ ] **Step 5: Run to verify pass** - -Run: `uv run pytest tests/test_transfer_cli.py -v` -Expected: PASS (all 3) - -- [ ] **Step 6: Run the full suite (fast) to check for regressions** - -Run: `uv run pytest tests/ -m "not slow" -q` -Expected: all pass (existing + new) - -- [ ] **Step 7: Lint + commit** - -```bash -uv run ruff check src/protspace/cli/transfer.py src/protspace/cli/app.py tests/test_transfer_cli.py -git add src/protspace/cli/transfer.py src/protspace/cli/app.py tests/test_transfer_cli.py -git commit -m "feat: add 'protspace transfer' annotation-transfer subcommand" -``` - ---- - -## Task 10: End-to-end smoke test through a real bundle round-trip - -**Files:** -- Test: `tests/test_transfer_cli.py` (append) - -- [ ] **Step 1: Write the failing end-to-end test** - -Append to `tests/test_transfer_cli.py`: - -```python -def test_end_to_end_bundle_roundtrip(tmp_path): - """Build a tiny bundle + h5, run the CLI, read the overlay back.""" - import io - - import h5py - import pyarrow.parquet as pq - from typer.testing import CliRunner - - from protspace.cli.app import app - from protspace.data.io.bundle import read_bundle, write_bundle - - annotations = pa.table( - {"identifier": ["TRINITY_1", "P00001"], "protein_category": ["", "neurotoxin"]} - ) - proj_meta = pa.table({"name": ["PCA 2"], "dims": [2]}) - proj_data = pa.table({"id": ["TRINITY_1", "P00001"], "x": [0.0, 9.0], "y": [0.0, 0.0]}) - bundle_path = tmp_path / "in.parquetbundle" - write_bundle([annotations, proj_meta, proj_data], bundle_path) - - h5_path = tmp_path / "emb.h5" - with h5py.File(h5_path, "w") as f: - f.create_dataset("TRINITY_1", data=np.array([0.0, 0.0], dtype=np.float32)) - f.create_dataset("P00001", data=np.array([0.1, 0.0], dtype=np.float32)) - - out_path = tmp_path / "out.parquetbundle" - result = CliRunner().invoke( - app, - [ - "transfer", - "-b", str(bundle_path), - "-e", str(h5_path), - "-t", "protein_category", - "-o", str(out_path), - "--query-id-prefix", "TRINITY_", - "--reference-id-prefix", "P0", - ], - ) - assert result.exit_code == 0, result.output - parts, _ = read_bundle(out_path) - rows = {r["identifier"]: r for r in pq.read_table(io.BytesIO(parts[0])).to_pylist()} - assert rows["TRINITY_1"]["protein_category__pred_value"] == "neurotoxin" - assert rows["TRINITY_1"]["protein_category__pred_source"] == "P00001" -``` - -- [ ] **Step 2: Run to verify it passes (implementation already exists)** - -Run: `uv run pytest tests/test_transfer_cli.py::test_end_to_end_bundle_roundtrip -v` -Expected: PASS. If it fails, fix `cli/transfer.py` until it passes (do not edit the test). - -- [ ] **Step 3: Commit** - -```bash -git add tests/test_transfer_cli.py -git commit -m "test: end-to-end protspace transfer bundle round-trip" -``` - ---- - -## Task 11: Documentation + notebook (required before final commit) - -**Files:** -- Modify: `docs/cli.md`, `docs/annotations.md`, `../CLAUDE.md` -- Create: `notebooks/ProtSpace_Transfer.ipynb` - -- [ ] **Step 1: Document the subcommand in `docs/cli.md`** - -Add a new section (match the heading style of the existing `protspace project` section): - -```markdown -### protspace transfer - -Fill in missing annotation values for query proteins by **Embedding Annotation -Transfer (EAT)** — each query's missing value is transferred from its nearest -annotated reference in pLM embedding space, with a goPredSim reliability-index -confidence (`0.5 / (0.5 + distance)` for Euclidean). - -```bash -protspace transfer \ - -b results.parquetbundle \ - -e embeddings.h5:prot_t5 \ - -t protein_category \ - -o results.parquetbundle \ - --query-id-prefix TRINITY_ \ - --reference-where 'protein_category~neurotoxin' -``` - -Default metric is Euclidean (canonical EAT); `--metric cosine` and `--k N` are -available. Writes `protein_category__pred_value`, `__pred_confidence`, and -`__pred_source` columns into the bundle's annotations table. Distances are -computed in the original embedding space (HDF5), not in the 2-D/3-D projection. - -References: Littmann et al., *Sci Rep* 2021 (DOI 10.1038/s41598-020-80786-0); -Heinzinger et al., *NAR Genom Bioinform* 2022 (DOI 10.1093/nargab/lqac043). -``` - -- [ ] **Step 2: Document the overlay columns in `docs/annotations.md`** - -Add a short subsection documenting the `__pred_value` / `__pred_confidence` / `__pred_source` convention so the `protspace_web` annotation registry can stay aligned: - -```markdown -## Predicted-by-transfer overlay columns - -`protspace transfer` appends three columns per transferred annotation `COL`, -populated only for proteins whose `COL` value was predicted (null otherwise): - -| Column | Type | Meaning | -|--------|------|---------| -| `COL__pred_value` | string | the transferred label | -| `COL__pred_confidence` | float | reliability index in [0, 1] | -| `COL__pred_source` | string | nearest reference protein id | - -The curated `COL` is left untouched. A protein is "predicted" for `COL` when -`COL` is empty but `COL__pred_value` is present. -``` - -- [ ] **Step 3: Update the CLI table in `../CLAUDE.md`** - -In the `## CLI Commands` table (the `protspace/CLAUDE.md` one), add a row: - -```markdown -| `protspace transfer` | Fill missing annotations from nearest reference embeddings (EAT) | -``` - -- [ ] **Step 4: Create the Colab notebook** - -Create `notebooks/ProtSpace_Transfer.ipynb` — a minimal notebook (use `uv run jupytext` or write JSON directly) with: (1) a markdown intro to EAT and `protspace transfer`, (2) a cell installing protspace, (3) a cell running the example command on a public dataset, (4) a cell reading the `__pred_*` columns back with pandas. Keep it runnable end-to-end. - -Run to validate it parses: `uv run python -c "import json,nbformat; nbformat.read(open('notebooks/ProtSpace_Transfer.ipynb'), as_version=4)"` -Expected: no error. - -- [ ] **Step 5: Final lint of the whole change** - -Run: `uv run ruff check src/ tests/` -Expected: no errors. - -- [ ] **Step 6: Commit docs** - -```bash -git add docs/cli.md docs/annotations.md ../CLAUDE.md notebooks/ProtSpace_Transfer.ipynb -git commit -m "docs: document protspace transfer + prediction overlay columns" -``` - ---- - -## Task 12: Full verification + open a PR - -- [ ] **Step 1: Run the complete fast suite** - -Run: `uv run pytest tests/ -m "not slow" -q` -Expected: all pass. - -- [ ] **Step 2: Confirm the protlabel boundary is still clean** - -Run: `! grep -rqE "import protspace|from protspace" src/protlabel/ && echo "boundary clean"` -Expected: `boundary clean` - -- [ ] **Step 3: Confirm the command is wired** - -Run: `uv run protspace transfer --help` -Expected: help text with `--bundle`, `--embeddings`, `--transfer`, `--metric`, `--k`. - -- [ ] **Step 4: Push the branch and open a PR** - -```bash -git push -u origin -gh pr create --title "feat: protlabel EAT engine + protspace transfer subcommand" \ - --body "Implements the backend of docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md (closes #54 backend scope). protlabel = embedding annotation-transfer engine; protspace transfer = CLI that writes a per-cell prediction overlay into the bundle. Frontend rendering is a separate PR." -``` - ---- - -## Self-Review (completed during planning) - -**1. Spec coverage:** -- protlabel engine (spec §4) → Tasks 1–5. ✓ -- Euclidean default + RI formula (spec §3) → Tasks 2, 4. ✓ -- Embedding-space distances, not DR (spec §3) → Task 9 (`run_transfer` reads the HDF5, never projections). ✓ -- Query/reference classifier, no hardcoded biology (spec §5 step 2) → Task 6. ✓ -- Per-cell overlay representation (spec §6.2, user decision) → Tasks 7, 9. ✓ -- Rebuildable sidecar lookup, not in the bundle (spec §6.1) → Task 5 (`Lookup.save/load`). ✓ -- Brute-force default, no ANN (spec §7) → Task 3. ✓ -- One default output table; gating/mining/report opt-in/out-of-scope (spec §5, §13 Q4) → handled by scoping; noted out of scope. ✓ -- Docs + notebook (spec §12) → Task 11. ✓ -- **Deferred to follow-up plans (intentional):** frontend rendering (spec §9), optional gating/mining/report, faiss-cpu, ProtTucker. Noted in the header. - -**2. Placeholder scan:** Every code step contains complete code; commands have expected output; no "TBD"/"handle edge cases". Task 11 Step 4 (notebook) describes cell contents rather than embedding full notebook JSON — acceptable because the artifact is a notebook, not source code, and the validation command is concrete. - -**3. Type consistency:** `Prediction(query_id, label, source_id, distance, reliability, k, metric)` is defined in Task 4 and used identically in Tasks 5, 7, 9. `Rule(id_prefixes, where)` defined in Task 6, used in Tasks 9, 10. `nearest()->(idx, dist)`, `eat(...)->list[Prediction]`, `add_overlay_columns(table, column, predictions)->Table`, `replace_annotations_in_bundle(input, output, table)`, `run_transfer(...)->Table` — signatures consistent across tasks. Overlay column names (`__pred_value/__pred_confidence/__pred_source`) identical in Tasks 7, 9, 10, 11. ✓ diff --git a/docs/superpowers/specs/2026-05-27-neighbors-subcommand-design.md b/docs/superpowers/specs/2026-05-27-neighbors-subcommand-design.md deleted file mode 100644 index 3ed151d4..00000000 --- a/docs/superpowers/specs/2026-05-27-neighbors-subcommand-design.md +++ /dev/null @@ -1,14 +0,0 @@ -# Design: `protspace neighbors` — reproducible proximity mining - -> **⚠️ Superseded (2026-06-11).** This early draft scoped a single `protspace neighbors` -> subcommand and defaulted to cosine distance. It has been replaced by -> [`2026-06-11-eat-annotation-transfer-design.md`](./2026-06-11-eat-annotation-transfer-design.md), -> which: -> - splits the work into a **`protlabel`** engine (the EAT lookup, per GitHub issue #54) + -> a thin **`protspace transfer`** subcommand; -> - flips the default metric to **Euclidean** (canonical EAT) with cosine opt-in; -> - adopts the goPredSim **reliability index** as the confidence column; -> - specifies **storage** (reference lookup as a rebuildable sidecar, prediction overlay in the bundle), -> **compute feasibility**, and the **frontend representation** (extending PR #272). -> -> Kept for history only. Read the 2026-06-11 spec instead. diff --git a/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md b/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md index 610b3bf7..3d291f45 100644 --- a/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md +++ b/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md @@ -2,7 +2,7 @@ **Status:** Draft for review **Date:** 2026-06-11 -**Supersedes:** `2026-05-27-neighbors-subcommand-design.md` (the earlier "neighbors-subcommand" draft — this expands its scope, reconciles it with GitHub issue #54 and frontend PR #272, and corrects two defaults). +**Supersedes:** an earlier "neighbors-subcommand" draft (since removed) — this expanded its scope, reconciled it with GitHub issue #54 and frontend PR #272, and corrected two defaults (cosine→Euclidean, and the reliability aggregation). **Trigger:** Conference feedback (`Conference_feedback/ProtSpaceExtractor_v1.7.4_mod 1.py`) + GitHub issue [#54 "EAT — Embedding Annotation Transfer (protlabel lookup table)"](https://github.com/tsenoner/protspace/issues/54) + frontend PR [protspace_web #272 "mark predictions and surface per-annotation docs"](https://github.com/tsenoner/protspace_web/pull/272). **Research backing:** Literature + codebase fan-out (8 agents) with adversarial verification of the storage/compute math and the EAT algorithm against primary sources. Citations in §15. From 9da7f4d552690a9403a6425b07c0b81837fcf859 Mon Sep 17 00:00:00 2001 From: tsenoner Date: Tue, 16 Jun 2026 17:20:29 +0200 Subject: [PATCH 20/21] =?UTF-8?q?fix(transfer):=20address=20review=20findi?= =?UTF-8?q?ngs=20=E2=80=94=20atomicity,=20precision,=20security,=20robustn?= =?UTF-8?q?ess?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolve issues found in code review of the EAT transfer backend (PR #55): - predictions: make the overlay idempotent — drop existing __pred_* columns before re-appending, so re-running transfer replaces them instead of producing a duplicate-column bundle that can no longer be read back - bundle: atomic writes (temp file + os.replace) in write_bundle and the replace_* helpers, so an interrupted in-place overwrite (-b X -o X) can no longer destroy the bundle; reject the reserved delimiter in serialized parts - backends: replace scipy.cdist with a pure-numpy BLAS GEMM path and recompute the surviving top-k distances in float64 (precise for near-identical vectors); guard cosine against zero-norm NaN - lookup: store float32 + unicode arrays, load with allow_pickle=False (no pickle/RCE surface; lossless round-trip) - transfer/classification: materialize only the needed columns (no full to_pylist); deterministic RI tie-break; translate input errors to BadParameter - cli: colon/Windows-safe -e/-i parsing via a shared split_h5_spec helper - docs/notebook: qualify the reliability-index formula per metric and k Adds tests for protlabel engine, overlay idempotency, atomic write, spec parsing, and CLI error handling. Full suite: 572 passed; ruff clean. Co-Authored-By: Claude Fable 5 --- docs/annotations.md | 8 +- docs/cli.md | 12 +- ...26-06-11-eat-annotation-transfer-design.md | 12 + notebooks/ProtSpace_Transfer.ipynb | 22 +- src/protlabel/__init__.py | 13 +- src/protlabel/backends.py | 106 ++++++-- src/protlabel/lookup.py | 27 +- src/protlabel/transfer.py | 24 +- src/protspace/analysis/classification.py | 13 +- src/protspace/cli/prepare.py | 19 +- src/protspace/cli/transfer.py | 73 +++-- src/protspace/data/io/bundle.py | 81 ++++-- src/protspace/data/io/predictions.py | 12 + src/protspace/data/loaders/__init__.py | 8 +- src/protspace/data/loaders/h5.py | 22 ++ tests/test_base_data_processor.py | 19 +- tests/test_bundle_overlay.py | 68 +++++ tests/test_h5_parse_identifier.py | 28 +- tests/test_predictions_overlay.py | 25 ++ tests/test_protlabel_backends.py | 59 +++++ tests/test_protlabel_lookup.py | 25 ++ tests/test_protlabel_transfer.py | 11 + tests/test_protlabel_version.py | 11 + tests/test_transfer_cli.py | 250 ++++++++++++++++++ 24 files changed, 808 insertions(+), 140 deletions(-) create mode 100644 tests/test_protlabel_version.py diff --git a/docs/annotations.md b/docs/annotations.md index cfb3dab2..7928fb61 100644 --- a/docs/annotations.md +++ b/docs/annotations.md @@ -191,9 +191,15 @@ Running `protspace transfer` appends three new columns to the bundle's annotatio | Column | Type | Meaning | | --- | --- | --- | | `COL__pred_value` | string | The transferred label from the nearest annotated reference protein | -| `COL__pred_confidence` | float | Reliability index in [0, 1]: `0.5 / (0.5 + distance)` — 1 = identical embeddings | +| `COL__pred_confidence` | float | Reliability index in [0, 1] — 1 = identical embeddings (formula depends on `--metric`/`--k`, see below) | | `COL__pred_source` | string | UniProt accession (or ID) of the nearest reference protein | A protein is considered "predicted" for `COL` when `COL` is empty but `COL__pred_value` is present. Use `COL__pred_confidence` to threshold low-reliability transfers. +The reliability index depends on the `--metric` and `--k` used during transfer: + +- **Default (`--metric euclidean`, `--k 1`):** `0.5 / (0.5 + distance)`. +- **`--metric cosine` (`--k 1`):** `clamp(1 - cosine_distance, 0, 1)`, where `cosine_distance` is in [0, 2]. +- **`--k > 1`:** the goPredSim mean reliability — `(1/m) · Σ s(d)` of the per-neighbour similarity over the `k` nearest neighbours carrying the chosen label, with `m = min(k, number of references)`. Because of this normalization, values are **not** comparable across different `--k` settings. + See [`protspace transfer`](cli.md#protspace-transfer) for usage and option details. diff --git a/docs/cli.md b/docs/cli.md index b18ae744..158c7a53 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -186,7 +186,13 @@ protspace style data.parquetbundle --dump-settings ## `protspace transfer` -Embedding Annotation Transfer (EAT): fills missing annotation values for query proteins by transferring the annotation of the nearest annotated reference protein in pLM embedding space. For each query protein that lacks a value in the requested annotation column, the command finds the closest reference (by Euclidean distance in the original high-dimensional embedding space — not in the 2-D/3-D projection) and assigns that reference's label along with a reliability index adapted from goPredSim (`confidence = 0.5 / (0.5 + distance)`), yielding a score in [0, 1] where 1 means identical embeddings. The curated source column (`COL`) is left untouched; results are written as three new columns: `COL__pred_value` (string), `COL__pred_confidence` (float), and `COL__pred_source` (the nearest reference protein ID). The method is a direct application of the approach introduced by Littmann et al., Sci Rep 2021 ([DOI 10.1038/s41598-020-80786-0](https://doi.org/10.1038/s41598-020-80786-0)) and extended by Heinzinger et al., NAR Genom Bioinform 2022 ([DOI 10.1093/nargab/lqac043](https://doi.org/10.1093/nargab/lqac043)). +Embedding Annotation Transfer (EAT): fills missing annotation values for query proteins by transferring the annotation of the nearest annotated reference protein in pLM embedding space. For each query protein that lacks a value in the requested annotation column, the command finds the closest reference (by distance in the original high-dimensional embedding space — Euclidean by default, or cosine via `--metric`, and not in the 2-D/3-D projection) and assigns that reference's label along with a reliability index adapted from goPredSim, yielding a score in [0, 1] where 1 means identical embeddings. The curated source column (`COL`) is left untouched; results are written as three new columns: `COL__pred_value` (string), `COL__pred_confidence` (float), and `COL__pred_source` (the nearest reference protein ID). The method is a direct application of the approach introduced by Littmann et al., Sci Rep 2021 ([DOI 10.1038/s41598-020-80786-0](https://doi.org/10.1038/s41598-020-80786-0)) and extended by Heinzinger et al., NAR Genom Bioinform 2022 ([DOI 10.1093/nargab/lqac043](https://doi.org/10.1093/nargab/lqac043)). + +**Reliability index (`COL__pred_confidence`).** The exact form depends on `--metric` and `--k`: + +- **Default (`--metric euclidean`, `--k 1`):** `confidence = 0.5 / (0.5 + distance)` (1 at distance 0, 0.5 at distance 0.5, → 0 as distance → ∞). +- **`--metric cosine` (`--k 1`):** `confidence = clamp(1 - cosine_distance, 0, 1)`, where `cosine_distance` is in [0, 2]. +- **`--k > 1`:** the value is the goPredSim mean reliability — `(1/m) · Σ s(d)`, the sum of the per-neighbour similarity `s(d)` (the euclidean or cosine form above) over the `k` nearest neighbours that carry the chosen label, divided by `m = min(k, number of references)`. Because of this normalization, confidence values are **not** comparable across different `--k` settings. ```bash protspace transfer \ @@ -211,9 +217,9 @@ protspace transfer \ | `--reference-id-prefix` | Restrict reference proteins to IDs starting with this prefix | — | | `--reference-where` | Filter reference proteins by annotation value (`col~substr`) | — | | `--k` | Number of nearest neighbours | `1` | -| `--metric` | Distance metric (`euclidean`, `cosine`) | `euclidean` | +| `--metric` | Distance metric (`euclidean`, `cosine`); see the reliability-index forms above | `euclidean` | -Distances are computed in the original embedding space (HDF5), not in the 2-D/3-D projection. +Distances are computed in the original embedding space (HDF5), not in the 2-D/3-D projection. The `--metric` choice also changes how `COL__pred_confidence` is computed: euclidean uses `0.5 / (0.5 + distance)`, while cosine uses `clamp(1 - cosine_distance, 0, 1)` (see the reliability-index note above). ## Combining Multiple Inputs (`-i`) diff --git a/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md b/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md index 3d291f45..0038b355 100644 --- a/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md +++ b/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md @@ -6,6 +6,18 @@ **Trigger:** Conference feedback (`Conference_feedback/ProtSpaceExtractor_v1.7.4_mod 1.py`) + GitHub issue [#54 "EAT — Embedding Annotation Transfer (protlabel lookup table)"](https://github.com/tsenoner/protspace/issues/54) + frontend PR [protspace_web #272 "mark predictions and surface per-annotation docs"](https://github.com/tsenoner/protspace_web/pull/272). **Research backing:** Literature + codebase fan-out (8 agents) with adversarial verification of the storage/compute math and the EAT algorithm against primary sources. Citations in §15. +> **Shipped vs deferred (read this first).** This document is a design draft, not as-built documentation. Only a subset of what is described below shipped in PR #55; several flags and the long-format table are deferred follow-ups. Treat anything not in the "Shipped" list as future work that is **not yet implemented**. +> +> **Shipped in PR #55** (`protspace transfer`, see `src/protspace/cli/transfer.py`): +> - Flags: `-b/--bundle`, `-e/--embeddings`, `-t/--transfer`, `-o/--output`, `--query-id-prefix`, `--query-where`, `--reference-id-prefix`, `--reference-where`, `--k`, `--metric` (`euclidean` | `cosine`). +> - Wide overlay columns appended to the bundle annotations table (`src/protspace/data/io/predictions.py`): `__pred_value`, `__pred_confidence`, `__pred_source`. +> - Brute-force nearest-neighbour search and the goPredSim reliability index. +> +> **Deferred / not yet implemented (future work):** +> - The opt-in flags described in §3, §5, and §11: `--cutoff`, `--mine`/`--top-n`, `--lookup`, `--report`, `--plots`, `--full-tables`, `--excel`, distance-threshold transfer mode. +> - The long-format `predicted_annotations` parquet table (§7.2); shipped output is the wide `__pred_*` overlay columns instead. +> - The faiss accelerator and the `protspace[ann]` extra — **the `ann` extra is not declared in `pyproject.toml`**, so `pip install protspace[ann]` does not work today. + --- ## 1. One-paragraph decision diff --git a/notebooks/ProtSpace_Transfer.ipynb b/notebooks/ProtSpace_Transfer.ipynb index 7af93236..8e2e4089 100644 --- a/notebooks/ProtSpace_Transfer.ipynb +++ b/notebooks/ProtSpace_Transfer.ipynb @@ -22,25 +22,7 @@ "cell_type": "markdown", "id": "a1b2c3d4", "metadata": {}, - "source": [ - "# ProtSpace — Embedding Annotation Transfer (EAT)\n", - "\n", - "This notebook demonstrates **Embedding Annotation Transfer (EAT)** with `protspace transfer`.\n", - "For each query protein that lacks an annotation value, the command finds the closest\n", - "annotated reference protein in pLM embedding space and transfers its label, together\n", - "with a reliability index (`confidence = 0.5 / (0.5 + distance)`, range [0, 1]).\n", - "The method follows the goPredSim approach introduced in:\n", - "\n", - "- Littmann et al., *Sci Rep* 2021 — [DOI 10.1038/s41598-020-80786-0](https://doi.org/10.1038/s41598-020-80786-0)\n", - "- Heinzinger et al., *NAR Genom Bioinform* 2022 — [DOI 10.1093/nargab/lqac043](https://doi.org/10.1093/nargab/lqac043)\n", - "\n", - "Distances are computed in the original high-dimensional embedding space (HDF5),\n", - "not in any 2-D/3-D projection. The curated source column is left untouched;\n", - "results are written as `COL__pred_value`, `COL__pred_confidence`, and `COL__pred_source`\n", - "columns in the bundle's annotations table.\n", - "\n", - "📚 [GitHub](https://github.com/tsenoner/protspace) · [CLI Reference](https://github.com/tsenoner/protspace/blob/main/docs/cli.md#protspace-transfer) · [Annotation Reference](https://github.com/tsenoner/protspace/blob/main/docs/annotations.md#prediction-overlay-columns-eat-transfer)" - ] + "source": "# ProtSpace — Embedding Annotation Transfer (EAT)\n\nThis notebook demonstrates **Embedding Annotation Transfer (EAT)** with `protspace transfer`.\nFor each query protein that lacks an annotation value, the command finds the closest\nannotated reference protein in pLM embedding space and transfers its label, together\nwith a reliability index in [0, 1]. The exact confidence formula depends on `--metric` and `--k`:\n\n- Default (`--metric euclidean`, `--k 1`): `confidence = 0.5 / (0.5 + distance)`.\n- `--metric cosine` (`--k 1`): `confidence = clamp(1 - cosine_distance, 0, 1)` (cosine distance in [0, 2]).\n- `--k > 1`: the goPredSim mean reliability — `(1/m) * sum` of the per-neighbour similarity over the `k` nearest neighbours carrying the chosen label, where `m = min(k, number of references)`. Because of this normalization, values are not comparable across different `--k` settings.\n\nThe method follows the goPredSim approach introduced in:\n\n- Littmann et al., *Sci Rep* 2021 — [DOI 10.1038/s41598-020-80786-0](https://doi.org/10.1038/s41598-020-80786-0)\n- Heinzinger et al., *NAR Genom Bioinform* 2022 — [DOI 10.1093/nargab/lqac043](https://doi.org/10.1093/nargab/lqac043)\n\nDistances are computed in the original high-dimensional embedding space (HDF5),\nnot in any 2-D/3-D projection. The curated source column is left untouched;\nresults are written as `COL__pred_value`, `COL__pred_confidence`, and `COL__pred_source`\ncolumns in the bundle's annotations table.\n\n📚 [GitHub](https://github.com/tsenoner/protspace) · [CLI Reference](https://github.com/tsenoner/protspace/blob/main/docs/cli.md#protspace-transfer) · [Annotation Reference](https://github.com/tsenoner/protspace/blob/main/docs/annotations.md#prediction-overlay-columns-eat-transfer)" }, { "cell_type": "markdown", @@ -119,4 +101,4 @@ ] } ] -} +} \ No newline at end of file diff --git a/src/protlabel/__init__.py b/src/protlabel/__init__.py index 7650ecbb..aebc59d8 100644 --- a/src/protlabel/__init__.py +++ b/src/protlabel/__init__.py @@ -1,12 +1,21 @@ """protlabel — Embedding Annotation Transfer (EAT) engine. Nearest-neighbour label transfer in protein-language-model embedding space, -with the goPredSim reliability index. Pure numpy/scipy/h5py; no protspace imports. +with the goPredSim reliability index. Pure numpy (plus the standard library); +no protspace imports. """ +from importlib.metadata import PackageNotFoundError, version + from protlabel.lookup import Lookup from protlabel.transfer import Prediction, eat -__version__ = "0.1.0" +try: + # protlabel currently ships inside the protspace wheel; report that version + # rather than a hard-coded literal that would silently drift across releases. + # (Reads distribution metadata by name — it does not import protspace.) + __version__ = version("protspace") +except PackageNotFoundError: # pragma: no cover - source/uninstalled fallback + __version__ = "0.0.0" __all__ = ["Lookup", "Prediction", "eat", "__version__"] diff --git a/src/protlabel/backends.py b/src/protlabel/backends.py index cfedd914..e5472ee4 100644 --- a/src/protlabel/backends.py +++ b/src/protlabel/backends.py @@ -1,20 +1,56 @@ """Exact (brute-force) k-nearest-neighbour search over reference embeddings. -Chunked over both the query axis and, adaptively, the reference axis so the -per-chunk float64 distance block emitted by ``cdist`` is bounded to -``max_block_bytes`` (default 256 MiB) regardless of ``n_refs``. This keeps -peak memory near the reference matrix itself even at Swiss-Prot scale -(~570 000 references). scipy.cdist handles both euclidean and cosine. +Distances are computed with a chunked BLAS matrix product (numpy ``@``) plus +``argpartition`` — the GEMM path, which is roughly an order of magnitude faster +than ``scipy.cdist`` while staying pure-numpy (no scipy/sklearn dependency). + +The query axis is chunked and, adaptively, bounded against the reference axis so +the per-chunk distance block is kept at or below ``max_block_bytes`` (default +256 MiB) regardless of ``n_refs``. This keeps peak memory close to the +reference matrix itself even at Swiss-Prot scale (~570 000 references). """ from __future__ import annotations +import warnings + import numpy as np -from scipy.spatial.distance import cdist _METRICS = {"euclidean", "cosine"} +def _l2_normalize(x: np.ndarray) -> np.ndarray: + """Row-wise L2 normalize. Zero-magnitude rows stay zero (kept finite). + + A zero vector would make cosine distance NaN; mapping it to the zero vector + yields a finite cosine distance of 1.0 to every reference instead. + """ + norms = np.linalg.norm(x, axis=1, keepdims=True) + safe = np.where(norms == 0.0, 1.0, norms) + return x / safe + + +def _exact_distances(block: np.ndarray, sel: np.ndarray, metric: str) -> np.ndarray: + """Distances from each query (block[i]) to its k candidates (sel[i]) in float64. + + The fast GEMM block selects the top-k candidates; this recomputes their + distances by direct subtraction in float64 to avoid the catastrophic + cancellation of ``||q||^2 - 2 q.r + ||r||^2`` for near-identical vectors. + Cost is O(b * k * d), not O(b * n_refs), so it is cheap. + """ + blk = block[:, None, :].astype(np.float64) # (b, 1, d) + sel = sel.astype(np.float64) # (b, k, d) + if metric == "euclidean": + diff = blk - sel + return np.sqrt(np.einsum("bkd,bkd->bk", diff, diff)) + bn = np.linalg.norm(blk, axis=2, keepdims=True) + bn = np.where(bn == 0.0, 1.0, bn) + sn = np.linalg.norm(sel, axis=2, keepdims=True) + sn = np.where(sn == 0.0, 1.0, sn) + cos = np.einsum("bkd,bkd->bk", blk / bn, sel / sn) + return np.clip(1.0 - cos, 0.0, 2.0) + + def nearest( queries: np.ndarray, refs: np.ndarray, @@ -31,14 +67,17 @@ def nearest( Memory behaviour ---------------- - ``cdist`` internally produces a float64 block of shape - ``(query_chunk, n_refs)``. When ``n_refs`` is large (e.g. Swiss-Prot - ~570 000), even a modest ``chunk`` of 4096 yields a ~19 GiB block. - To bound this, the effective query-chunk size ``eff_chunk`` is computed - as ``min(chunk, max_block_bytes // (n_refs * 8))`` so the float64 block - stays at or below ``max_block_bytes`` (default 256 MiB) independent of - ``n_refs``. Peak memory therefore remains close to the reference matrix - itself, making the function laptop-feasible at Swiss-Prot scale. + The per-chunk distance block has shape ``(query_chunk, n_refs)``. When + ``n_refs`` is large (e.g. Swiss-Prot ~570 000), even a modest ``chunk`` of + 4096 yields a multi-GiB block. To bound this, the effective query-chunk + size ``eff_chunk`` is computed as ``min(chunk, max_block_bytes // (n_refs * + 8))`` so each block stays at or below ``max_block_bytes`` (default 256 MiB) + independent of ``n_refs``. The factor 8 budgets for a float64-sized block; + the euclidean path holds a few float32 ``(query_chunk, n_refs)`` temporaries + at once, so the real peak is comparable to (not far below) that budget. The + exact-distance recompute only touches the k surviving candidates, so peak + memory remains close to the reference matrix itself, making the function + laptop-feasible at Swiss-Prot scale. """ if metric not in _METRICS: raise ValueError(f"Unknown metric {metric!r}; expected 'euclidean' or 'cosine'") @@ -51,23 +90,44 @@ def nearest( n_refs = refs.shape[0] k = min(k, n_refs) - # Adaptively shrink the query chunk so the float64 cdist block stays - # within max_block_bytes (cdist emits float64 = 8 bytes per element). + # Adaptively shrink the query chunk so the distance block stays within + # max_block_bytes (budgeting for a float64 = 8 bytes-per-element block). bytes_per_row = max(1, n_refs * 8) eff_chunk = max(1, min(chunk, max_block_bytes // bytes_per_row)) idx_out = np.empty((queries.shape[0], k), dtype=np.int64) dist_out = np.empty((queries.shape[0], k), dtype=np.float32) + # Cosine: normalize once; distance is 1 - cosine similarity via a dot product. + refs_unit = _l2_normalize(refs) if metric == "cosine" else None + # Euclidean: precompute ||ref||^2 once; ||q-r||^2 = ||q||^2 - 2 q.r + ||r||^2. + refs_sq = np.einsum("ij,ij->i", refs, refs) if metric == "euclidean" else None + for start in range(0, queries.shape[0], eff_chunk): block = queries[start : start + eff_chunk] - d = cdist(block, refs, metric=metric).astype(np.float32) # (b, n_refs) - part = np.argpartition(d, kth=k - 1, axis=1)[:, :k] # unsorted top-k + with warnings.catch_warnings(): + # GEMM on near-overflowing float16-origin values can emit harmless + # RuntimeWarnings; the inputs are upcast float32 so values are safe. + warnings.simplefilter("ignore", RuntimeWarning) + if metric == "euclidean": + block_sq = np.einsum("ij,ij->i", block, block) + cross = block @ refs.T # (b, n_refs), BLAS GEMM + d2 = block_sq[:, None] - 2.0 * cross + refs_sq[None, :] + np.maximum(d2, 0.0, out=d2) # clip tiny negative fp artifacts + d = np.sqrt(d2, dtype=np.float32) + else: # cosine + block_unit = _l2_normalize(block) + d = (1.0 - block_unit @ refs_unit.T).astype(np.float32) + np.clip(d, 0.0, 2.0, out=d) # cosine distance in [0, 2] + + # Select the k candidates with the fast (float32) block, then recompute + # their distances exactly in float64 and order by that — so the reported + # distance is precise even for near-identical vectors. + cand = np.argpartition(d, kth=k - 1, axis=1)[:, :k] # unsorted top-k rows = np.arange(block.shape[0])[:, None] - part_d = d[rows, part] - order = np.argsort(part_d, axis=1) # sort the k by distance - sorted_idx = part[rows, order] - idx_out[start : start + block.shape[0]] = sorted_idx - dist_out[start : start + block.shape[0]] = d[rows, sorted_idx] + exact = _exact_distances(block, refs[cand], metric) # (b, k) float64 + order = np.argsort(exact, axis=1) + idx_out[start : start + block.shape[0]] = cand[rows, order] + dist_out[start : start + block.shape[0]] = exact[rows, order].astype(np.float32) return idx_out, dist_out diff --git a/src/protlabel/lookup.py b/src/protlabel/lookup.py index 5220b3b1..916f557c 100644 --- a/src/protlabel/lookup.py +++ b/src/protlabel/lookup.py @@ -39,26 +39,33 @@ def query( ) def save(self, path: Path) -> None: - """Serialize to a .npz sidecar.""" + """Serialize to a .npz sidecar. + + Embeddings are stored as float32 (lossless round-trip); ids/labels are + stored as unicode arrays rather than pickled object arrays so the sidecar + can be loaded with ``allow_pickle=False`` (no arbitrary-code-execution + surface on load). ids/labels must not contain trailing NUL bytes, which + numpy's fixed-width unicode arrays strip on round-trip. + """ path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) np.savez( path, - embeddings=self.embeddings.astype(np.float16), - ids=np.array(self.ids, dtype=object), - labels=np.array(self.labels, dtype=object), - metric=self.metric, - model=self.model, + embeddings=self.embeddings.astype(np.float32), + ids=np.asarray(self.ids, dtype=np.str_), + labels=np.asarray(self.labels, dtype=np.str_), + metric=np.asarray(self.metric, dtype=np.str_), + model=np.asarray(self.model, dtype=np.str_), ) @classmethod def load(cls, path: Path) -> Lookup: - """Load a .npz sidecar (re-upcasts embeddings to float32).""" - with np.load(path, allow_pickle=True) as data: + """Load a .npz sidecar (with pickling disabled).""" + with np.load(path, allow_pickle=False) as data: return cls( embeddings=data["embeddings"].astype(np.float32), - ids=list(data["ids"]), - labels=list(data["labels"]), + ids=[str(x) for x in data["ids"]], + labels=[str(x) for x in data["labels"]], metric=str(data["metric"]), model=str(data["model"]), ) diff --git a/src/protlabel/transfer.py b/src/protlabel/transfer.py index 59fbc893..e0389343 100644 --- a/src/protlabel/transfer.py +++ b/src/protlabel/transfer.py @@ -1,9 +1,12 @@ """Embedding annotation transfer: kNN -> reliability index -> transferred label. Implements the goPredSim aggregation (Littmann et al. 2021, Eq. 5): - RI(p) = (1/k) * sum over neighbours carrying label p of similarity(d). -The transferred label is argmax RI(p); its source is the nearest neighbour -carrying that label. + RI(p) = (1/eff_k) * sum over neighbours carrying label p of similarity(d), +where ``eff_k`` is the number of neighbours actually used (``k`` capped to the +number of references). The transferred label is argmax RI(p); its source is the +nearest neighbour carrying that label. Ties are broken deterministically by +smallest source distance, then by lexically smallest label, so the result never +depends on the (arbitrary) ordering of equidistant neighbours. """ from __future__ import annotations @@ -63,11 +66,16 @@ def eat( ri_by_label[lab] = ri_by_label.get(lab, 0.0) + similarity(d, metric) if lab not in nearest_src or d < nearest_src[lab][0]: nearest_src[lab] = (d, ref_ids[ref_i]) - # Normalise by k (the goPredSim 1/k term). - # Tie-break: max() over the insertion-ordered dict returns the first label - # seen while iterating neighbours (which are in ascending distance order), - # so for distinct distances the nearest neighbour's label wins a tie. - best_label = max(ri_by_label, key=lambda p: ri_by_label[p]) + # Pick the highest-RI label; break ties deterministically by smallest + # source distance, then lexically smallest label. This makes the choice + # independent of the order of equidistant neighbours (whose argsort order + # is otherwise arbitrary). For distinct distances the nearest neighbour's + # label wins, as before. + best_label = min( + ri_by_label, + key=lambda p: (-ri_by_label[p], nearest_src[p][0], p), + ) + # Normalise by eff_k (the goPredSim 1/k term, k capped to n_refs). ri = ri_by_label[best_label] / eff_k src_dist, src_id = nearest_src[best_label] predictions.append( diff --git a/src/protspace/analysis/classification.py b/src/protspace/analysis/classification.py index 6a8111d2..93474b27 100644 --- a/src/protspace/analysis/classification.py +++ b/src/protspace/analysis/classification.py @@ -47,12 +47,19 @@ def classify( if column not in columns: raise KeyError(f"Classification column {column!r} not in annotations") - rows = annotations.to_pylist() - identifiers = [str(r["identifier"]) for r in rows] + # Materialize only the columns the rules actually need (identifier + any + # where-columns) instead of the whole table — the latter is ~GB-scale at + # Swiss-Prot row counts. + identifiers = [str(v) for v in annotations.column("identifier").to_pylist()] + where_columns = { + column for rule in (query_rule, reference_rule) for column, _ in rule.where + } + column_data = {c: annotations.column(c).to_pylist() for c in where_columns} query_indices: list[int] = [] reference_indices: list[int] = [] - for i, (identifier, row) in enumerate(zip(identifiers, rows, strict=True)): + for i, identifier in enumerate(identifiers): + row = {c: column_data[c][i] for c in where_columns} if _matches(query_rule, identifier, row): query_indices.append(i) elif _matches(reference_rule, identifier, row): diff --git a/src/protspace/cli/prepare.py b/src/protspace/cli/prepare.py index 4831f6b2..abbdf369 100644 --- a/src/protspace/cli/prepare.py +++ b/src/protspace/cli/prepare.py @@ -549,22 +549,9 @@ def prepare( def _parse_input_specs(raw_inputs: list[str]) -> list[tuple[Path, str | None]]: """Parse inputs with optional colon name override: file.h5:model_name.""" - specs: list[tuple[Path, str | None]] = [] - for raw in raw_inputs: - if ":" in raw: - last_colon = raw.rfind(":") - path_part, name_part = raw[:last_colon], raw[last_colon + 1 :] - if ( - name_part - and not name_part.startswith(("/", "\\")) - and Path(path_part).suffix - ): - specs.append((Path(path_part), name_part)) - else: - specs.append((Path(raw), None)) - else: - specs.append((Path(raw), None)) - return specs + from protspace.data.loaders import split_h5_spec + + return [split_h5_spec(raw) for raw in raw_inputs] def _parse_embedders(embedder_arg: str | None) -> list[str]: diff --git a/src/protspace/cli/transfer.py b/src/protspace/cli/transfer.py index e6035b07..23a75165 100644 --- a/src/protspace/cli/transfer.py +++ b/src/protspace/cli/transfer.py @@ -9,7 +9,7 @@ import logging from pathlib import Path -from typing import Annotated +from typing import TYPE_CHECKING, Annotated import numpy as np import pyarrow as pa @@ -18,6 +18,9 @@ from protspace.cli.app import app, setup_logging from protspace.cli.common_options import Opt_Verbose +if TYPE_CHECKING: + from protspace.analysis.classification import Rule + logger = logging.getLogger(__name__) @@ -30,8 +33,8 @@ def run_transfer( annotations: pa.Table, embeddings: dict[str, np.ndarray], transfer_columns: list[str], - query_rule, - reference_rule, + query_rule: Rule, + reference_rule: Rule, k: int = 1, metric: str = "euclidean", ) -> pa.Table: @@ -57,7 +60,9 @@ def run_transfer( ) query_idx, ref_idx = classify(embedded, query_rule, reference_rule) - rows = embedded.to_pylist() + # Materialize only the id column once (not the whole table); per-column values + # are pulled inside the loop. Avoids GB-scale Python lists at Swiss-Prot size. + id_list = [str(v) for v in embedded.column("identifier").to_pylist()] out = annotations total_transferred = 0 @@ -65,12 +70,14 @@ def run_transfer( if column not in annotations.column_names: raise KeyError(f"Transfer column {column!r} not in annotations table") + col_vals = embedded.column(column).to_pylist() + # References: classified refs that HAVE a value in this column. ref_ids, ref_labels, ref_vecs = [], [], [] for i in ref_idx: - value = rows[i].get(column) + value = col_vals[i] if not _is_missing(value): - rid = str(rows[i]["identifier"]) + rid = id_list[i] ref_ids.append(rid) ref_labels.append(str(value)) ref_vecs.append(embeddings[rid]) @@ -81,8 +88,8 @@ def run_transfer( # Queries: classified queries MISSING a value in this column. q_ids, q_vecs = [], [] for i in query_idx: - if _is_missing(rows[i].get(column)): - qid = str(rows[i]["identifier"]) + if _is_missing(col_vals[i]): + qid = id_list[i] q_ids.append(qid) q_vecs.append(embeddings[qid]) if not q_ids: @@ -165,7 +172,7 @@ def transfer( from protspace.analysis.classification import Rule from protspace.data.io.bundle import read_bundle, replace_annotations_in_bundle - from protspace.data.loaders import load_h5 + from protspace.data.loaders import load_h5, split_h5_spec def _parse_where(items: list[str] | None) -> list[tuple[str, str]]: clauses = [] @@ -188,10 +195,8 @@ def _parse_where(items: list[str] | None) -> list[tuple[str, str]]: if k < 1: raise typer.BadParameter("--k must be >= 1") - # Load embeddings (name override after ':'). - h5_spec = embeddings.split(":", 1) - h5_path = Path(h5_spec[0]) - name_override = h5_spec[1] if len(h5_spec) == 2 else None + # Load embeddings (optional ':name' override; colon/Windows-safe parsing). + h5_path, name_override = split_h5_spec(embeddings) emb_set = load_h5([h5_path], name_override=name_override) emb_map = {header: emb_set.data[i] for i, header in enumerate(emb_set.headers)} @@ -200,6 +205,14 @@ def _parse_where(items: list[str] | None) -> list[tuple[str, str]]: annotations = pq.read_table(io.BytesIO(parts[0])) # Real bundles name the id column "protein_id"; run_transfer works on "identifier". + if ( + "protein_id" in annotations.column_names + and "identifier" in annotations.column_names + ): + raise typer.BadParameter( + "Bundle annotations contain both 'protein_id' and 'identifier' columns; " + "cannot determine the id column unambiguously." + ) id_col = "protein_id" if "protein_id" in annotations.column_names else "identifier" if id_col != "identifier": annotations = annotations.rename_columns( @@ -212,15 +225,31 @@ def _parse_where(items: list[str] | None) -> list[tuple[str, str]]: f"--transfer column {col!r} not found in the bundle annotations" ) - augmented = run_transfer( - annotations=annotations, - embeddings=emb_map, - transfer_columns=transfer_columns, - query_rule=query_rule, - reference_rule=reference_rule, - k=k, - metric=metric, - ) + # Validate classification --*-where columns up front for a clean error. + available = set(annotations.column_names) + for col, _ in query_rule.where + reference_rule.where: + if col not in available: + raise typer.BadParameter( + f"--*-where column {col!r} not found in the bundle annotations" + ) + + # Translate input-driven errors from the core into clean CLI errors rather + # than leaking raw KeyError/ValueError tracebacks to the user. + try: + augmented = run_transfer( + annotations=annotations, + embeddings=emb_map, + transfer_columns=transfer_columns, + query_rule=query_rule, + reference_rule=reference_rule, + k=k, + metric=metric, + ) + except (KeyError, ValueError) as exc: + # Use the raw message (KeyError stringifies with repr quotes) rather than + # stripping quotes off the rendered string, which could mangle messages. + message = exc.args[0] if exc.args else str(exc) + raise typer.BadParameter(str(message)) from exc # Rename id column back so the written bundle keeps its original name # (the web frontend expects "protein_id"). diff --git a/src/protspace/data/io/bundle.py b/src/protspace/data/io/bundle.py index 22bab8f1..8cd8a343 100644 --- a/src/protspace/data/io/bundle.py +++ b/src/protspace/data/io/bundle.py @@ -8,6 +8,7 @@ import io import json import logging +import os import tempfile from pathlib import Path @@ -27,6 +28,43 @@ SETTINGS_FILENAME = "settings.parquet" +def _atomic_write_bytes(path: Path, data: bytes) -> None: + """Write ``data`` to ``path`` atomically (temp file + ``os.replace``). + + The destination is never left truncated or partial on interrupt — it keeps + the old bytes until the rename completes, then atomically becomes the full + new bytes. Critical for the in-place overwrite workflow that ``transfer`` + documents (``-b results.parquetbundle -o results.parquetbundle``): a Ctrl+C + or crash mid-write can no longer destroy the user's bundle. + """ + path.parent.mkdir(parents=True, exist_ok=True) + fd, tmp = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp") + try: + with os.fdopen(fd, "wb") as f: + f.write(data) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp, path) + except BaseException: + Path(tmp).unlink(missing_ok=True) + raise + + +def _check_no_delimiter(part_bytes: bytes) -> None: + """Guard: a serialized part must not contain the bundle delimiter. + + If a value (e.g. an annotation string) happens to contain the reserved + delimiter byte string, the part split on read-back would be corrupted; fail + loudly at write time instead. + """ + if PARQUET_BUNDLE_DELIMITER in part_bytes: + raise ValueError( + "Serialized parquet part contains the bundle delimiter " + f"{PARQUET_BUNDLE_DELIMITER!r}; a value includes this reserved byte " + "string and would corrupt the bundle on read." + ) + + def extract_bundle_to_dir(bundle_path: Path, target_dir: Path | None = None) -> str: """Extract a .parquetbundle into separate parquet files on disk. @@ -101,20 +139,23 @@ def write_bundle( bundle_path: Output file path. settings: Optional settings dict to include as 4th part. """ - bundle_path.parent.mkdir(parents=True, exist_ok=True) - - with open(bundle_path, "wb") as f: - for i, table in enumerate(tables): - if i > 0: - f.write(PARQUET_BUNDLE_DELIMITER) - buf = io.BytesIO() - pq.write_table(table, buf) - f.write(buf.getvalue()) - - if settings is not None: - f.write(PARQUET_BUNDLE_DELIMITER) - f.write(create_settings_parquet(settings)) - + buf = io.BytesIO() + for i, table in enumerate(tables): + if i > 0: + buf.write(PARQUET_BUNDLE_DELIMITER) + table_buf = io.BytesIO() + pq.write_table(table, table_buf) + part_bytes = table_buf.getvalue() + _check_no_delimiter(part_bytes) + buf.write(part_bytes) + + if settings is not None: + buf.write(PARQUET_BUNDLE_DELIMITER) + settings_bytes = create_settings_parquet(settings) + _check_no_delimiter(settings_bytes) + buf.write(settings_bytes) + + _atomic_write_bytes(bundle_path, buf.getvalue()) logger.info(f"Saved bundled output to: {bundle_path}") @@ -143,9 +184,7 @@ def replace_settings_in_bundle( core = PARQUET_BUNDLE_DELIMITER.join(parts[:3]) new_content = core + PARQUET_BUNDLE_DELIMITER + settings_bytes - output_path.parent.mkdir(parents=True, exist_ok=True) - with open(output_path, "wb") as f: - f.write(new_content) + _atomic_write_bytes(output_path, new_content) def replace_annotations_in_bundle( @@ -167,13 +206,13 @@ def replace_annotations_in_bundle( buf = io.BytesIO() pq.write_table(annotations_table, buf) - new_parts = [buf.getvalue(), parts[1], parts[2]] + new_annotations_bytes = buf.getvalue() + _check_no_delimiter(new_annotations_bytes) + new_parts = [new_annotations_bytes, parts[1], parts[2]] if len(parts) == 4: new_parts.append(parts[3]) - output_path.parent.mkdir(parents=True, exist_ok=True) - with open(output_path, "wb") as f: - f.write(PARQUET_BUNDLE_DELIMITER.join(new_parts)) + _atomic_write_bytes(output_path, PARQUET_BUNDLE_DELIMITER.join(new_parts)) logger.info(f"Wrote bundle with updated annotations to: {output_path}") diff --git a/src/protspace/data/io/predictions.py b/src/protspace/data/io/predictions.py index 92ce6d20..e60ae5f5 100644 --- a/src/protspace/data/io/predictions.py +++ b/src/protspace/data/io/predictions.py @@ -37,7 +37,19 @@ def add_overlay_columns( confidences.append(float(pred.reliability)) sources.append(pred.source_id) + # Drop any pre-existing overlay columns first so re-running transfer on an + # already-overlaid table replaces them rather than appending duplicates + # (duplicate field names produce a parquet table that cannot be read back). + overlay_names = [ + f"{column}__pred_value", + f"{column}__pred_confidence", + f"{column}__pred_source", + ] out = annotations + stale = [name for name in overlay_names if name in out.column_names] + if stale: + out = out.drop_columns(stale) + out = out.append_column(f"{column}__pred_value", pa.array(values, pa.string())) out = out.append_column( f"{column}__pred_confidence", pa.array(confidences, pa.float32()) diff --git a/src/protspace/data/loaders/__init__.py b/src/protspace/data/loaders/__init__.py index 908e571b..e864a9fd 100644 --- a/src/protspace/data/loaders/__init__.py +++ b/src/protspace/data/loaders/__init__.py @@ -6,7 +6,12 @@ merge_same_name_sets, ) from protspace.data.loaders.fasta import embed_fasta -from protspace.data.loaders.h5 import EMBEDDING_EXTENSIONS, load_h5, parse_identifier +from protspace.data.loaders.h5 import ( + EMBEDDING_EXTENSIONS, + load_h5, + parse_identifier, + split_h5_spec, +) from protspace.data.loaders.query import ( extract_identifiers_from_fasta, query_uniprot, @@ -24,4 +29,5 @@ "load_h5", "parse_identifier", "query_uniprot", + "split_h5_spec", ] diff --git a/src/protspace/data/loaders/h5.py b/src/protspace/data/loaders/h5.py index f9353dd3..88306a18 100644 --- a/src/protspace/data/loaders/h5.py +++ b/src/protspace/data/loaders/h5.py @@ -22,6 +22,28 @@ ) +def split_h5_spec(spec: str) -> tuple[Path, str | None]: + """Parse an HDF5 input spec with an optional ``:name`` override. + + ``file.h5:model_name`` -> ``(Path("file.h5"), "model_name")``. + + Colon/Windows-safe: the split uses the LAST colon and is only taken when the + right side looks like a name (not a path separator) and the left side has a + file suffix. So ``C:\\data\\emb.h5`` and ``C:\\data\\emb.h5:prot_t5`` parse + correctly instead of splitting on the drive-letter colon. + """ + if ":" in spec: + last_colon = spec.rfind(":") + path_part, name_part = spec[:last_colon], spec[last_colon + 1 :] + if ( + name_part + and not name_part.startswith(("/", "\\")) + and Path(path_part).suffix + ): + return Path(path_part), name_part + return Path(spec), None + + def parse_identifier(raw_key: str) -> str: """Extract protein identifier from an H5 key. diff --git a/tests/test_base_data_processor.py b/tests/test_base_data_processor.py index 8758b8e6..216065b3 100644 --- a/tests/test_base_data_processor.py +++ b/tests/test_base_data_processor.py @@ -137,16 +137,17 @@ def test_save_output_separate_files(self, mock_write_table): assert mock_write_table.call_count == 3 mock_mkdir.assert_called() - @patch("src.protspace.data.processors.base_processor.pq.write_table") - def test_save_output_bundled(self, _): + def test_save_output_bundled(self, tmp_path): + from src.protspace.data.io.bundle import read_bundle + processor = BaseDataProcessor(SAMPLE_CONFIG, {"pca": DummyReducer}) tables = processor.create_output( SAMPLE_METADATA, SAMPLE_REDUCTIONS, SAMPLE_HEADERS ) - with ( - patch("pathlib.Path.mkdir") as mock_mkdir, - patch("builtins.open", new_callable=MagicMock) as mock_open, - ): - processor.save_output(tables, Path("output_dir"), bundled=True) - mock_mkdir.assert_called() - mock_open.assert_called() + out_dir = tmp_path / "output_dir" + processor.save_output(tables, out_dir, bundled=True) + + bundle_path = out_dir / "data.parquetbundle" + assert bundle_path.exists() + parts, _settings = read_bundle(bundle_path) + assert len(parts) == 3 diff --git a/tests/test_bundle_overlay.py b/tests/test_bundle_overlay.py index 0e25b4aa..5cec3262 100644 --- a/tests/test_bundle_overlay.py +++ b/tests/test_bundle_overlay.py @@ -4,8 +4,10 @@ import pyarrow as pa import pyarrow.parquet as pq +import pytest from protspace.data.io.bundle import ( + PARQUET_BUNDLE_DELIMITER, read_bundle, replace_annotations_in_bundle, write_bundle, @@ -40,6 +42,72 @@ def test_replaces_annotations_keeps_other_parts(tmp_path): assert _read_part(parts[2]).to_pydict()["x"] == [0.0, 1.0] +def test_projection_parts_preserved_byte_for_byte(tmp_path): + src = tmp_path / "in.parquetbundle" + out = tmp_path / "out.parquetbundle" + write_bundle(_tables(), src, settings={"foo": 1}) + + new_annotations = pa.table( + {"identifier": ["A", "B"], "cat": ["x", "y"], "cat__pred_value": [None, "z"]} + ) + replace_annotations_in_bundle(src, out, new_annotations) + + in_parts = src.read_bytes().split(PARQUET_BUNDLE_DELIMITER) + out_parts = out.read_bytes().split(PARQUET_BUNDLE_DELIMITER) + assert out_parts[1] == in_parts[1] # projections_metadata, byte-identical + assert out_parts[2] == in_parts[2] # projections_data, byte-identical + assert out_parts[3] == in_parts[3] # settings, byte-identical + + +def test_delimiter_in_annotation_cell_raises(tmp_path): + # If an annotation value contains the bundle delimiter, the written part + # would corrupt the split on read-back; this must fail loudly instead. + src = tmp_path / "in.parquetbundle" + out = tmp_path / "out.parquetbundle" + write_bundle(_tables(), src) + + evil = "ev" + PARQUET_BUNDLE_DELIMITER.decode() + "il" + bad = pa.table({"identifier": ["A", "B"], "cat": ["x", evil]}) + with pytest.raises(ValueError): + replace_annotations_in_bundle(src, out, bad) + + +def test_in_place_overwrite_works_and_leaves_no_temp(tmp_path): + # The documented -b == -o workflow must produce the augmented bundle and + # leave no stray temp file behind. + path = tmp_path / "b.parquetbundle" + write_bundle(_tables(), path) + new_annotations = pa.table( + {"identifier": ["A", "B"], "cat": ["x", "y"], "cat__pred_value": [None, "z"]} + ) + replace_annotations_in_bundle(path, path, new_annotations) # same path + parts, _ = read_bundle(path) + assert "cat__pred_value" in pq.read_table(io.BytesIO(parts[0])).column_names + assert not list(tmp_path.glob("*.tmp")) + + +def test_failed_replace_preserves_original_in_place(tmp_path, monkeypatch): + # If the rename is interrupted, the original bundle must survive intact + # (atomic write) rather than being left truncated. + import protspace.data.io.bundle as bundle_mod + + path = tmp_path / "b.parquetbundle" + write_bundle(_tables(), path) + original = path.read_bytes() + new_annotations = pa.table( + {"identifier": ["A", "B"], "cat": ["x", "y"], "cat__pred_value": [None, "z"]} + ) + + def boom(*args, **kwargs): + raise OSError("simulated interrupt before rename") + + monkeypatch.setattr(bundle_mod.os, "replace", boom) + with pytest.raises(OSError): + replace_annotations_in_bundle(path, path, new_annotations) + assert path.read_bytes() == original # untouched + assert not list(tmp_path.glob("*.tmp")) # temp cleaned up + + def test_preserves_settings_when_present(tmp_path): src = tmp_path / "in.parquetbundle" out = tmp_path / "out.parquetbundle" diff --git a/tests/test_h5_parse_identifier.py b/tests/test_h5_parse_identifier.py index e4a5b530..fc62964a 100644 --- a/tests/test_h5_parse_identifier.py +++ b/tests/test_h5_parse_identifier.py @@ -1,8 +1,34 @@ """Tests for H5 identifier parsing.""" +from pathlib import Path + import pytest -from protspace.data.loaders.h5 import parse_identifier +from protspace.data.loaders.h5 import parse_identifier, split_h5_spec + + +class TestSplitH5Spec: + def test_no_override(self): + assert split_h5_spec("emb.h5") == (Path("emb.h5"), None) + + def test_name_override(self): + assert split_h5_spec("emb.h5:prot_t5") == (Path("emb.h5"), "prot_t5") + + def test_posix_path_with_override(self): + assert split_h5_spec("/data/emb.h5:prot_t5") == ( + Path("/data/emb.h5"), + "prot_t5", + ) + + def test_windows_drive_path_no_override(self): + # The drive-letter colon must not be mistaken for a name override. + assert split_h5_spec("C:\\data\\emb.h5") == (Path("C:\\data\\emb.h5"), None) + + def test_windows_drive_path_with_override(self): + assert split_h5_spec("C:\\data\\emb.h5:prot_t5") == ( + Path("C:\\data\\emb.h5"), + "prot_t5", + ) class TestParseIdentifier: diff --git a/tests/test_predictions_overlay.py b/tests/test_predictions_overlay.py index 8a44950e..66f94d92 100644 --- a/tests/test_predictions_overlay.py +++ b/tests/test_predictions_overlay.py @@ -1,6 +1,9 @@ """Tests for building the per-cell prediction overlay columns.""" +import io + import pyarrow as pa +import pyarrow.parquet as pq from protlabel import Prediction from protspace.data.io.predictions import add_overlay_columns @@ -67,3 +70,25 @@ def test_prediction_for_unknown_identifier_is_ignored(): preds = [Prediction("NOT_IN_TABLE", "x", "R0", 0.1, 0.9, 1, "euclidean")] out = add_overlay_columns(_table(), "protein_category", preds).to_pylist() assert all(r["protein_category__pred_value"] is None for r in out) + + +def test_reapplying_overlay_replaces_not_duplicates(): + # Re-running transfer on an already-overlaid table must replace the overlay + # columns, not append duplicates (which produce an unreadable parquet table). + preds1 = [Prediction("Q0", "old", "R0", 0.3, 0.6, 1, "euclidean")] + once = add_overlay_columns(_table(), "protein_category", preds1) + preds2 = [Prediction("Q0", "new", "R0", 0.1, 0.9, 1, "euclidean")] + twice = add_overlay_columns(once, "protein_category", preds2) + + assert twice.column_names.count("protein_category__pred_value") == 1 + assert twice.column_names.count("protein_category__pred_confidence") == 1 + assert twice.column_names.count("protein_category__pred_source") == 1 + + by_id = {r["identifier"]: r for r in twice.to_pylist()} + assert by_id["Q0"]["protein_category__pred_value"] == "new" + + # Duplicate column names would make this round-trip raise ArrowInvalid. + buf = io.BytesIO() + pq.write_table(twice, buf) + reread = pq.read_table(io.BytesIO(buf.getvalue())) + assert reread.column("protein_category__pred_value").to_pylist()[0] == "new" diff --git a/tests/test_protlabel_backends.py b/tests/test_protlabel_backends.py index f1ea8732..5f7285aa 100644 --- a/tests/test_protlabel_backends.py +++ b/tests/test_protlabel_backends.py @@ -79,3 +79,62 @@ def test_k_less_than_one_raises(): queries, refs = _toy() with pytest.raises(ValueError): nearest(queries, refs, k=0, metric="euclidean") + + +def test_cosine_zero_vector_query_is_finite(): + # A zero-magnitude query must not produce NaN cosine distances (scipy.cdist + # returns NaN here); the result must stay finite and deterministic. + refs = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32) + queries = np.array([[0.0, 0.0]], dtype=np.float32) + idx, dist = nearest(queries, refs, k=2, metric="cosine") + assert idx.shape == (1, 2) + assert np.all(np.isfinite(dist)) + + +def test_cosine_zero_vector_reference_is_finite(): + # A zero-magnitude reference must not poison the distance block with NaN. + refs = np.array([[0.0, 0.0], [0.0, 1.0]], dtype=np.float32) + queries = np.array([[1.0, 0.1]], dtype=np.float32) + idx, dist = nearest(queries, refs, k=2, metric="cosine") + assert np.all(np.isfinite(dist)) + + +def test_euclidean_matches_reference_distances(): + # The (BLAS) euclidean path must agree with a direct distance computation. + rng = np.random.default_rng(7) + refs = rng.standard_normal((30, 12)).astype(np.float32) + queries = rng.standard_normal((5, 12)).astype(np.float32) + idx, dist = nearest(queries, refs, k=4, metric="euclidean") + for qi in range(queries.shape[0]): + full = np.linalg.norm(refs - queries[qi], axis=1) + expected = np.sort(full)[:4] + assert np.allclose(dist[qi], expected, atol=1e-4) + + +def test_euclidean_precise_for_near_identical_high_dim(): + # Near-identical high-dimensional vectors are the catastrophic-cancellation + # regime for the GEMM expansion; the reported distance must stay precise + # (not collapse to 0) by matching a direct float64 norm. + rng = np.random.default_rng(3) + base = rng.standard_normal(1024).astype(np.float32) + other = rng.standard_normal(1024).astype(np.float32) + refs = np.stack([base, base + np.float32(1e-3), other]) + query = (base + np.float32(2e-4))[None, :] + idx, dist = nearest(query, refs, k=2, metric="euclidean") + assert idx[0, 0] == 0 # nearest is the near-duplicate of base + expected = np.linalg.norm(refs[0].astype(np.float64) - query[0].astype(np.float64)) + assert dist[0, 0] > 0.0 # not clipped to zero + assert dist[0, 0] == pytest.approx(expected, rel=1e-3) + + +def test_cosine_matches_reference_distances(): + rng = np.random.default_rng(8) + refs = rng.standard_normal((25, 10)).astype(np.float32) + queries = rng.standard_normal((4, 10)).astype(np.float32) + idx, dist = nearest(queries, refs, k=3, metric="cosine") + rn = refs / np.linalg.norm(refs, axis=1, keepdims=True) + for qi in range(queries.shape[0]): + qn = queries[qi] / np.linalg.norm(queries[qi]) + full = 1.0 - rn @ qn + expected = np.sort(full)[:3] + assert np.allclose(dist[qi], expected, atol=1e-4) diff --git a/tests/test_protlabel_lookup.py b/tests/test_protlabel_lookup.py index b84b6066..ff056e28 100644 --- a/tests/test_protlabel_lookup.py +++ b/tests/test_protlabel_lookup.py @@ -30,6 +30,31 @@ def test_save_load_roundtrip(tmp_path): assert np.allclose(loaded.embeddings, lk.embeddings) +def test_save_load_preserves_float32_values_exactly(tmp_path): + # Non-power-of-2 values would be corrupted by a float16 round-trip; the + # sidecar must preserve the float32 embeddings exactly. + emb = np.array([[0.123456, 1.9876543], [3.14159, 2.71828]], dtype=np.float32) + lk = Lookup(embeddings=emb, ids=["P1", "P2"], labels=["x", "y"]) + path = tmp_path / "lk.npz" + lk.save(path) + loaded = Lookup.load(path) + assert loaded.embeddings.dtype == np.float32 + assert np.array_equal(loaded.embeddings, emb) + + +def test_save_uses_no_object_arrays(tmp_path): + # ids/labels must be stored as plain unicode (not pickled object arrays), so + # the sidecar can be loaded with allow_pickle=False (no RCE surface). + lk = Lookup(embeddings=_lookup().embeddings, ids=["R0", "R1"], labels=["a", "b"]) + path = tmp_path / "lk.npz" + lk.save(path) + with np.load(path, allow_pickle=False) as data: + assert data["ids"].dtype != object + assert data["labels"].dtype != object + assert list(data["ids"]) == ["R0", "R1"] + assert list(data["labels"]) == ["a", "b"] + + def test_loaded_lookup_queries_identically(tmp_path): lk = _lookup() q = np.array([[9.8, 0.0]], dtype=np.float32) diff --git a/tests/test_protlabel_transfer.py b/tests/test_protlabel_transfer.py index 417f5218..458d56f9 100644 --- a/tests/test_protlabel_transfer.py +++ b/tests/test_protlabel_transfer.py @@ -79,6 +79,17 @@ def test_empty_references_raises(): ) +def test_equal_distance_tie_break_is_deterministic(): + # Two refs exactly equidistant from the query carry different labels at equal + # source distances; the winner must be deterministic (lexically smallest + # label) rather than dependent on argsort ordering of the equal distances. + ref_emb = np.array([[1.0, 0.0], [-1.0, 0.0]], dtype=np.float32) + query_emb = np.array([[0.0, 0.0]], dtype=np.float32) # distance 1.0 to both + preds = eat(query_emb, ["Q"], ref_emb, ["R_z", "R_a"], ["zebra", "apple"], k=2) + assert preds[0].label == "apple" + assert preds[0].source_id == "R_a" + + def test_source_is_nearest_neighbour_with_winning_label(): # Two neighbours share the winning label at distinct distances; the source # must be the closer one. diff --git a/tests/test_protlabel_version.py b/tests/test_protlabel_version.py new file mode 100644 index 00000000..3b436511 --- /dev/null +++ b/tests/test_protlabel_version.py @@ -0,0 +1,11 @@ +"""protlabel.__version__ reports the version it actually ships under.""" + +from importlib.metadata import version + +import protlabel + + +def test_version_matches_shipped_distribution(): + # protlabel currently ships inside the protspace wheel, so its reported + # version must track the installed distribution rather than a stale literal. + assert protlabel.__version__ == version("protspace") diff --git a/tests/test_transfer_cli.py b/tests/test_transfer_cli.py index 7a3a01c3..1eea4c3f 100644 --- a/tests/test_transfer_cli.py +++ b/tests/test_transfer_cli.py @@ -1,5 +1,7 @@ """Tests for the transfer orchestration core and CLI registration.""" +import logging + import numpy as np import pyarrow as pa import pytest @@ -8,6 +10,46 @@ from protspace.cli.transfer import run_transfer +def _three_protein_inputs(extra_columns=None): + cols = { + "identifier": ["TRINITY_1", "P00001", "P00002"], + "protein_category": ["", "neurotoxin", "enzyme"], + } + if extra_columns: + cols.update(extra_columns) + annotations = pa.table(cols) + embeddings = { + "TRINITY_1": np.array([0.0, 0.0], dtype=np.float32), + "P00001": np.array([0.05, 0.0], dtype=np.float32), + "P00002": np.array([9.0, 0.0], dtype=np.float32), + } + return annotations, embeddings + + +def _write_bundle_and_h5(tmp_path, *, id_col="protein_id", extra_columns=None): + import h5py + + from protspace.data.io.bundle import write_bundle + + cols = {id_col: ["TRINITY_1", "P00001"], "protein_category": ["", "neurotoxin"]} + if extra_columns: + cols.update(extra_columns) + annotations = pa.table(cols) + proj_meta = pa.table({"name": ["PCA 2"], "dims": [2]}) + proj_data = pa.table( + {"id": ["TRINITY_1", "P00001"], "x": [0.0, 9.0], "y": [0.0, 0.0]} + ) + bundle_path = tmp_path / "in.parquetbundle" + write_bundle([annotations, proj_meta, proj_data], bundle_path) + + h5_path = tmp_path / "emb.h5" + with h5py.File(h5_path, "w") as f: + f.attrs["model_name"] = "test_model" + f.create_dataset("TRINITY_1", data=np.array([0.0, 0.0], dtype=np.float32)) + f.create_dataset("P00001", data=np.array([0.1, 0.0], dtype=np.float32)) + return bundle_path, h5_path + + def _inputs(): annotations = pa.table( { @@ -58,6 +100,214 @@ def test_run_transfer_skips_proteins_without_embeddings(): ) +def test_run_transfer_multi_column_skips_column_without_references(caplog): + # Two transfer columns; the second has no references with a value and must be + # skipped while the first still produces its overlay. + annotations, embeddings = _three_protein_inputs( + extra_columns={"other_col": ["", "", ""]} + ) + with caplog.at_level(logging.WARNING): + out = run_transfer( + annotations=annotations, + embeddings=embeddings, + transfer_columns=["protein_category", "other_col"], + query_rule=Rule(id_prefixes=["TRINITY_"]), + reference_rule=Rule(id_prefixes=["P0"]), + k=1, + metric="euclidean", + ) + assert "protein_category__pred_value" in out.column_names + assert "other_col__pred_value" not in out.column_names + assert "other_col" in caplog.text + + +def test_run_transfer_k_greater_than_one(): + annotations, embeddings = _three_protein_inputs() + out = run_transfer( + annotations=annotations, + embeddings=embeddings, + transfer_columns=["protein_category"], + query_rule=Rule(id_prefixes=["TRINITY_"]), + reference_rule=Rule(id_prefixes=["P0"]), + k=2, + metric="euclidean", + ) + by_id = {r["identifier"]: r for r in out.to_pylist()} + assert by_id["TRINITY_1"]["protein_category__pred_value"] == "neurotoxin" + + +def test_run_transfer_cosine_metric(): + annotations, embeddings = _three_protein_inputs() + # Cosine needs non-parallel-to-axis vectors to be meaningful. + embeddings = { + "TRINITY_1": np.array([1.0, 0.1], dtype=np.float32), + "P00001": np.array([1.0, 0.0], dtype=np.float32), + "P00002": np.array([0.0, 1.0], dtype=np.float32), + } + out = run_transfer( + annotations=annotations, + embeddings=embeddings, + transfer_columns=["protein_category"], + query_rule=Rule(id_prefixes=["TRINITY_"]), + reference_rule=Rule(id_prefixes=["P0"]), + k=1, + metric="cosine", + ) + by_id = {r["identifier"]: r for r in out.to_pylist()} + assert by_id["TRINITY_1"]["protein_category__pred_value"] == "neurotoxin" + assert by_id["TRINITY_1"]["protein_category__pred_source"] == "P00001" + + +def test_run_transfer_warns_when_nothing_transferred(caplog): + # Query proteins all already have a value -> nothing to transfer -> warning. + annotations = pa.table( + { + "identifier": ["TRINITY_1", "P00001"], + "protein_category": ["already_set", "neurotoxin"], + } + ) + embeddings = { + "TRINITY_1": np.array([0.0, 0.0], dtype=np.float32), + "P00001": np.array([0.1, 0.0], dtype=np.float32), + } + with caplog.at_level(logging.WARNING): + out = run_transfer( + annotations=annotations, + embeddings=embeddings, + transfer_columns=["protein_category"], + query_rule=Rule(id_prefixes=["TRINITY_"]), + reference_rule=Rule(id_prefixes=["P0"]), + k=1, + metric="euclidean", + ) + assert "protein_category__pred_value" not in out.column_names + assert "No annotations were transferred" in caplog.text + + +def test_cli_bad_where_column_is_clean_error(tmp_path): + from typer.testing import CliRunner + + from protspace.cli.app import app + + bundle, h5 = _write_bundle_and_h5(tmp_path) + out = tmp_path / "out.parquetbundle" + result = CliRunner().invoke( + app, + [ + "transfer", + "-b", + str(bundle), + "-e", + str(h5), + "-t", + "protein_category", + "-o", + str(out), + "--query-where", + "nonexistent~x", + "--reference-id-prefix", + "P0", + ], + ) + assert result.exit_code != 0 + assert not isinstance(result.exception, KeyError) + assert "nonexistent" in result.output + + +def test_cli_no_matching_embeddings_is_clean_error(tmp_path): + import h5py + from typer.testing import CliRunner + + from protspace.cli.app import app + + bundle, _ = _write_bundle_and_h5(tmp_path) + bad_h5 = tmp_path / "bad.h5" + with h5py.File(bad_h5, "w") as f: + f.attrs["model_name"] = "m" + f.create_dataset("ZZZ", data=np.array([0.0, 0.0], dtype=np.float32)) + out = tmp_path / "out.parquetbundle" + result = CliRunner().invoke( + app, + [ + "transfer", + "-b", + str(bundle), + "-e", + str(bad_h5), + "-t", + "protein_category", + "-o", + str(out), + "--query-id-prefix", + "TRINITY_", + "--reference-id-prefix", + "P0", + ], + ) + assert result.exit_code != 0 + assert not isinstance(result.exception, ValueError) + + +def test_cli_no_query_match_is_clean_error(tmp_path): + from typer.testing import CliRunner + + from protspace.cli.app import app + + bundle, h5 = _write_bundle_and_h5(tmp_path) + out = tmp_path / "out.parquetbundle" + result = CliRunner().invoke( + app, + [ + "transfer", + "-b", + str(bundle), + "-e", + str(h5), + "-t", + "protein_category", + "-o", + str(out), + "--query-id-prefix", + "NOPE_", + "--reference-id-prefix", + "P0", + ], + ) + assert result.exit_code != 0 + assert not isinstance(result.exception, ValueError) + + +def test_cli_both_id_columns_present_is_clean_error(tmp_path): + from typer.testing import CliRunner + + from protspace.cli.app import app + + bundle, h5 = _write_bundle_and_h5( + tmp_path, extra_columns={"identifier": ["TRINITY_1", "P00001"]} + ) + out = tmp_path / "out.parquetbundle" + result = CliRunner().invoke( + app, + [ + "transfer", + "-b", + str(bundle), + "-e", + str(h5), + "-t", + "protein_category", + "-o", + str(out), + "--query-id-prefix", + "TRINITY_", + "--reference-id-prefix", + "P0", + ], + ) + assert result.exit_code != 0 + assert not isinstance(result.exception, KeyError) + + def test_transfer_command_is_registered(): from typer.testing import CliRunner From f7186f56812641558c49672500bdf785f80234af Mon Sep 17 00:00:00 2001 From: tsenoner Date: Tue, 16 Jun 2026 19:13:18 +0200 Subject: [PATCH 21/21] refactor(transfer): drop __pred_source overlay column; keep numeric confidence MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The per-cell prediction overlay now writes only __pred_value and __pred_confidence. The reference id (source) is noise as a colour feature, so it is dropped from the bundle; it remains available on protlabel's Prediction. A legacy __pred_source is dropped on re-run so older bundles are cleaned up. Keeping confidence as a separate numeric column lets the web frontend colour and threshold by reliability (gradient legend) — which inline label|score values do not enable (those render tooltip-only). Co-Authored-By: Claude Fable 5 --- docs/annotations.md | 3 +-- docs/cli.md | 2 +- ...26-06-11-eat-annotation-transfer-design.md | 2 +- notebooks/ProtSpace_Transfer.ipynb | 2 +- src/protspace/data/io/predictions.py | 14 +++++------ tests/test_predictions_overlay.py | 23 +++++++++++++++---- tests/test_transfer_cli.py | 6 ++--- 7 files changed, 33 insertions(+), 19 deletions(-) diff --git a/docs/annotations.md b/docs/annotations.md index 7928fb61..8698630d 100644 --- a/docs/annotations.md +++ b/docs/annotations.md @@ -186,13 +186,12 @@ The `default` group only requires the UniProt REST API (+ ExPASy for EC names). ## Prediction Overlay Columns (EAT Transfer) -Running `protspace transfer` appends three new columns to the bundle's annotations table for each requested column `COL`. The curated `COL` column is never modified. +Running `protspace transfer` appends two new columns to the bundle's annotations table for each requested column `COL`. The curated `COL` column is never modified. | Column | Type | Meaning | | --- | --- | --- | | `COL__pred_value` | string | The transferred label from the nearest annotated reference protein | | `COL__pred_confidence` | float | Reliability index in [0, 1] — 1 = identical embeddings (formula depends on `--metric`/`--k`, see below) | -| `COL__pred_source` | string | UniProt accession (or ID) of the nearest reference protein | A protein is considered "predicted" for `COL` when `COL` is empty but `COL__pred_value` is present. Use `COL__pred_confidence` to threshold low-reliability transfers. diff --git a/docs/cli.md b/docs/cli.md index 158c7a53..ded5ea37 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -186,7 +186,7 @@ protspace style data.parquetbundle --dump-settings ## `protspace transfer` -Embedding Annotation Transfer (EAT): fills missing annotation values for query proteins by transferring the annotation of the nearest annotated reference protein in pLM embedding space. For each query protein that lacks a value in the requested annotation column, the command finds the closest reference (by distance in the original high-dimensional embedding space — Euclidean by default, or cosine via `--metric`, and not in the 2-D/3-D projection) and assigns that reference's label along with a reliability index adapted from goPredSim, yielding a score in [0, 1] where 1 means identical embeddings. The curated source column (`COL`) is left untouched; results are written as three new columns: `COL__pred_value` (string), `COL__pred_confidence` (float), and `COL__pred_source` (the nearest reference protein ID). The method is a direct application of the approach introduced by Littmann et al., Sci Rep 2021 ([DOI 10.1038/s41598-020-80786-0](https://doi.org/10.1038/s41598-020-80786-0)) and extended by Heinzinger et al., NAR Genom Bioinform 2022 ([DOI 10.1093/nargab/lqac043](https://doi.org/10.1093/nargab/lqac043)). +Embedding Annotation Transfer (EAT): fills missing annotation values for query proteins by transferring the annotation of the nearest annotated reference protein in pLM embedding space. For each query protein that lacks a value in the requested annotation column, the command finds the closest reference (by distance in the original high-dimensional embedding space — Euclidean by default, or cosine via `--metric`, and not in the 2-D/3-D projection) and assigns that reference's label along with a reliability index adapted from goPredSim, yielding a score in [0, 1] where 1 means identical embeddings. The curated source column (`COL`) is left untouched; results are written as two new columns: `COL__pred_value` (string) and `COL__pred_confidence` (float). The method is a direct application of the approach introduced by Littmann et al., Sci Rep 2021 ([DOI 10.1038/s41598-020-80786-0](https://doi.org/10.1038/s41598-020-80786-0)) and extended by Heinzinger et al., NAR Genom Bioinform 2022 ([DOI 10.1093/nargab/lqac043](https://doi.org/10.1093/nargab/lqac043)). **Reliability index (`COL__pred_confidence`).** The exact form depends on `--metric` and `--k`: diff --git a/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md b/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md index 0038b355..698f595e 100644 --- a/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md +++ b/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md @@ -10,7 +10,7 @@ > > **Shipped in PR #55** (`protspace transfer`, see `src/protspace/cli/transfer.py`): > - Flags: `-b/--bundle`, `-e/--embeddings`, `-t/--transfer`, `-o/--output`, `--query-id-prefix`, `--query-where`, `--reference-id-prefix`, `--reference-where`, `--k`, `--metric` (`euclidean` | `cosine`). -> - Wide overlay columns appended to the bundle annotations table (`src/protspace/data/io/predictions.py`): `__pred_value`, `__pred_confidence`, `__pred_source`. +> - Wide overlay columns appended to the bundle annotations table (`src/protspace/data/io/predictions.py`): `__pred_value`, `__pred_confidence`. > - Brute-force nearest-neighbour search and the goPredSim reliability index. > > **Deferred / not yet implemented (future work):** diff --git a/notebooks/ProtSpace_Transfer.ipynb b/notebooks/ProtSpace_Transfer.ipynb index 8e2e4089..cf98973e 100644 --- a/notebooks/ProtSpace_Transfer.ipynb +++ b/notebooks/ProtSpace_Transfer.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "id": "a1b2c3d4", "metadata": {}, - "source": "# ProtSpace — Embedding Annotation Transfer (EAT)\n\nThis notebook demonstrates **Embedding Annotation Transfer (EAT)** with `protspace transfer`.\nFor each query protein that lacks an annotation value, the command finds the closest\nannotated reference protein in pLM embedding space and transfers its label, together\nwith a reliability index in [0, 1]. The exact confidence formula depends on `--metric` and `--k`:\n\n- Default (`--metric euclidean`, `--k 1`): `confidence = 0.5 / (0.5 + distance)`.\n- `--metric cosine` (`--k 1`): `confidence = clamp(1 - cosine_distance, 0, 1)` (cosine distance in [0, 2]).\n- `--k > 1`: the goPredSim mean reliability — `(1/m) * sum` of the per-neighbour similarity over the `k` nearest neighbours carrying the chosen label, where `m = min(k, number of references)`. Because of this normalization, values are not comparable across different `--k` settings.\n\nThe method follows the goPredSim approach introduced in:\n\n- Littmann et al., *Sci Rep* 2021 — [DOI 10.1038/s41598-020-80786-0](https://doi.org/10.1038/s41598-020-80786-0)\n- Heinzinger et al., *NAR Genom Bioinform* 2022 — [DOI 10.1093/nargab/lqac043](https://doi.org/10.1093/nargab/lqac043)\n\nDistances are computed in the original high-dimensional embedding space (HDF5),\nnot in any 2-D/3-D projection. The curated source column is left untouched;\nresults are written as `COL__pred_value`, `COL__pred_confidence`, and `COL__pred_source`\ncolumns in the bundle's annotations table.\n\n📚 [GitHub](https://github.com/tsenoner/protspace) · [CLI Reference](https://github.com/tsenoner/protspace/blob/main/docs/cli.md#protspace-transfer) · [Annotation Reference](https://github.com/tsenoner/protspace/blob/main/docs/annotations.md#prediction-overlay-columns-eat-transfer)" + "source": "# ProtSpace — Embedding Annotation Transfer (EAT)\n\nThis notebook demonstrates **Embedding Annotation Transfer (EAT)** with `protspace transfer`.\nFor each query protein that lacks an annotation value, the command finds the closest\nannotated reference protein in pLM embedding space and transfers its label, together\nwith a reliability index in [0, 1]. The exact confidence formula depends on `--metric` and `--k`:\n\n- Default (`--metric euclidean`, `--k 1`): `confidence = 0.5 / (0.5 + distance)`.\n- `--metric cosine` (`--k 1`): `confidence = clamp(1 - cosine_distance, 0, 1)` (cosine distance in [0, 2]).\n- `--k > 1`: the goPredSim mean reliability — `(1/m) * sum` of the per-neighbour similarity over the `k` nearest neighbours carrying the chosen label, where `m = min(k, number of references)`. Because of this normalization, values are not comparable across different `--k` settings.\n\nThe method follows the goPredSim approach introduced in:\n\n- Littmann et al., *Sci Rep* 2021 — [DOI 10.1038/s41598-020-80786-0](https://doi.org/10.1038/s41598-020-80786-0)\n- Heinzinger et al., *NAR Genom Bioinform* 2022 — [DOI 10.1093/nargab/lqac043](https://doi.org/10.1093/nargab/lqac043)\n\nDistances are computed in the original high-dimensional embedding space (HDF5),\nnot in any 2-D/3-D projection. The curated source column is left untouched;\nresults are written as `COL__pred_value` and `COL__pred_confidence`\ncolumns in the bundle's annotations table.\n\n📚 [GitHub](https://github.com/tsenoner/protspace) · [CLI Reference](https://github.com/tsenoner/protspace/blob/main/docs/cli.md#protspace-transfer) · [Annotation Reference](https://github.com/tsenoner/protspace/blob/main/docs/annotations.md#prediction-overlay-columns-eat-transfer)" }, { "cell_type": "markdown", diff --git a/src/protspace/data/io/predictions.py b/src/protspace/data/io/predictions.py index e60ae5f5..4722bb4e 100644 --- a/src/protspace/data/io/predictions.py +++ b/src/protspace/data/io/predictions.py @@ -1,10 +1,13 @@ """Turn protlabel Predictions into per-cell overlay columns on the annotations table. -For a transferred column ``COL`` we append three aligned columns (null for +For a transferred column ``COL`` we append two aligned columns (null for non-predicted proteins), leaving the curated ``COL`` untouched: COL__pred_value (string) the transferred label COL__pred_confidence (float32) the reliability index in [0, 1] - COL__pred_source (string) the nearest reference protein id + +The reference protein the value came from is available on the ``Prediction`` +(``source_id``) but is intentionally not written to the bundle: it is noise as a +per-cell colour feature, and confidence is the signal users threshold on. """ from __future__ import annotations @@ -19,27 +22,25 @@ def add_overlay_columns( annotations: pa.Table, column: str, predictions: Sequence[Prediction] ) -> pa.Table: - """Append the COL__pred_* overlay columns, aligned by identifier.""" + """Append the COL__pred_value / COL__pred_confidence columns, by identifier.""" by_query = {p.query_id: p for p in predictions} identifiers = [str(v) for v in annotations.column("identifier").to_pylist()] values: list[str | None] = [] confidences: list[float | None] = [] - sources: list[str | None] = [] for identifier in identifiers: pred = by_query.get(identifier) if pred is None: values.append(None) confidences.append(None) - sources.append(None) else: values.append(pred.label) confidences.append(float(pred.reliability)) - sources.append(pred.source_id) # Drop any pre-existing overlay columns first so re-running transfer on an # already-overlaid table replaces them rather than appending duplicates # (duplicate field names produce a parquet table that cannot be read back). + # The legacy __pred_source is included so older bundles are cleaned up. overlay_names = [ f"{column}__pred_value", f"{column}__pred_confidence", @@ -54,5 +55,4 @@ def add_overlay_columns( out = out.append_column( f"{column}__pred_confidence", pa.array(confidences, pa.float32()) ) - out = out.append_column(f"{column}__pred_source", pa.array(sources, pa.string())) return out diff --git a/tests/test_predictions_overlay.py b/tests/test_predictions_overlay.py index 66f94d92..3817ace1 100644 --- a/tests/test_predictions_overlay.py +++ b/tests/test_predictions_overlay.py @@ -18,14 +18,16 @@ def _table(): ) -def test_adds_three_overlay_columns(): +def test_adds_value_and_confidence_overlay_columns(): preds = [ Prediction("Q0", "neurotoxin", "R0", 0.3, 0.62, 1, "euclidean"), ] out = add_overlay_columns(_table(), "protein_category", preds) assert "protein_category__pred_value" in out.column_names assert "protein_category__pred_confidence" in out.column_names - assert "protein_category__pred_source" in out.column_names + # The source column is intentionally not emitted (provenance is not a useful + # per-cell overlay; confidence carries the signal users threshold on). + assert "protein_category__pred_source" not in out.column_names def test_overlay_values_aligned_by_identifier(): @@ -34,7 +36,6 @@ def test_overlay_values_aligned_by_identifier(): by_id = {r["identifier"]: r for r in out} assert by_id["Q1"]["protein_category__pred_value"] == "enzyme" assert by_id["Q1"]["protein_category__pred_confidence"] == 0.5 - assert by_id["Q1"]["protein_category__pred_source"] == "R9" # Non-predicted rows are null in the overlay columns. assert by_id["Q0"]["protein_category__pred_value"] is None assert by_id["R0"]["protein_category__pred_confidence"] is None @@ -82,7 +83,7 @@ def test_reapplying_overlay_replaces_not_duplicates(): assert twice.column_names.count("protein_category__pred_value") == 1 assert twice.column_names.count("protein_category__pred_confidence") == 1 - assert twice.column_names.count("protein_category__pred_source") == 1 + assert "protein_category__pred_source" not in twice.column_names by_id = {r["identifier"]: r for r in twice.to_pylist()} assert by_id["Q0"]["protein_category__pred_value"] == "new" @@ -92,3 +93,17 @@ def test_reapplying_overlay_replaces_not_duplicates(): pq.write_table(twice, buf) reread = pq.read_table(io.BytesIO(buf.getvalue())) assert reread.column("protein_category__pred_value").to_pylist()[0] == "new" + + +def test_legacy_source_column_is_removed_on_rerun(): + # A bundle written by an older version may carry a __pred_source column; + # re-running must drop it rather than leave it orphaned/stale. + legacy = _table().append_column( + "protein_category__pred_source", pa.array(["R0", None, None], pa.string()) + ) + out = add_overlay_columns( + legacy, + "protein_category", + [Prediction("Q0", "neurotoxin", "R0", 0.1, 0.8, 1, "euclidean")], + ) + assert "protein_category__pred_source" not in out.column_names diff --git a/tests/test_transfer_cli.py b/tests/test_transfer_cli.py index 1eea4c3f..a3cacb97 100644 --- a/tests/test_transfer_cli.py +++ b/tests/test_transfer_cli.py @@ -81,8 +81,8 @@ def test_run_transfer_predicts_for_query_with_missing_value(): ) by_id = {r["identifier"]: r for r in out.to_pylist()} assert by_id["TRINITY_1"]["protein_category__pred_value"] == "neurotoxin" - assert by_id["TRINITY_1"]["protein_category__pred_source"] == "P00001" assert by_id["TRINITY_1"]["protein_category__pred_confidence"] > 0.9 + assert "protein_category__pred_source" not in out.column_names def test_run_transfer_skips_proteins_without_embeddings(): @@ -155,7 +155,7 @@ def test_run_transfer_cosine_metric(): ) by_id = {r["identifier"]: r for r in out.to_pylist()} assert by_id["TRINITY_1"]["protein_category__pred_value"] == "neurotoxin" - assert by_id["TRINITY_1"]["protein_category__pred_source"] == "P00001" + assert "protein_category__pred_source" not in out.column_names def test_run_transfer_warns_when_nothing_transferred(caplog): @@ -371,4 +371,4 @@ def test_cli_end_to_end_protein_id_bundle(tmp_path): assert "protein_id" in table.column_names # id column preserved for the web reader rows = {r["protein_id"]: r for r in table.to_pylist()} assert rows["TRINITY_1"]["protein_category__pred_value"] == "neurotoxin" - assert rows["TRINITY_1"]["protein_category__pred_source"] == "P00001" + assert "protein_category__pred_source" not in table.column_names