diff --git a/CLAUDE.md b/CLAUDE.md index f3a8aeb7..a158c6db 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -49,6 +49,7 @@ Single entry point: `protspace = protspace.cli.app:app` | `protspace bundle` | Combine projections + annotations → .parquetbundle | | `protspace serve` | Launch Dash web frontend | | `protspace style` | Add annotation colors/styles | +| `protspace transfer` | Fill missing annotations from nearest reference embeddings (EAT) | ### protspace prepare Usage diff --git a/docs/annotations.md b/docs/annotations.md index 1f21d127..8698630d 100644 --- a/docs/annotations.md +++ b/docs/annotations.md @@ -183,3 +183,22 @@ Per-protein predictions from the [Biocentral API](https://biocentral.rostlab.org | Pfam clans | `~/.cache/protspace/pfam_clans/` | 30 days | Pfam family → clan mapping | The `default` group only requires the UniProt REST API (+ ExPASy for EC names). For `--keep-tmp` annotation caching, see [CLI Reference](cli.md#annotation-caching---keep-tmp). + +## Prediction Overlay Columns (EAT Transfer) + +Running `protspace transfer` appends two new columns to the bundle's annotations table for each requested column `COL`. The curated `COL` column is never modified. + +| Column | Type | Meaning | +| --- | --- | --- | +| `COL__pred_value` | string | The transferred label from the nearest annotated reference protein | +| `COL__pred_confidence` | float | Reliability index in [0, 1] — 1 = identical embeddings (formula depends on `--metric`/`--k`, see below) | + +A protein is considered "predicted" for `COL` when `COL` is empty but `COL__pred_value` is present. Use `COL__pred_confidence` to threshold low-reliability transfers. + +The reliability index depends on the `--metric` and `--k` used during transfer: + +- **Default (`--metric euclidean`, `--k 1`):** `0.5 / (0.5 + distance)`. +- **`--metric cosine` (`--k 1`):** `clamp(1 - cosine_distance, 0, 1)`, where `cosine_distance` is in [0, 2]. +- **`--k > 1`:** the goPredSim mean reliability — `(1/m) · Σ s(d)` of the per-neighbour similarity over the `k` nearest neighbours carrying the chosen label, with `m = min(k, number of references)`. Because of this normalization, values are **not** comparable across different `--k` settings. + +See [`protspace transfer`](cli.md#protspace-transfer) for usage and option details. diff --git a/docs/cli.md b/docs/cli.md index ddb3fd90..ded5ea37 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -9,6 +9,7 @@ | `protspace bundle` | Combine projections + annotations into .parquetbundle | | `protspace serve` | Launch interactive Dash web frontend | | `protspace style` | Add/inspect annotation styles in existing files | +| `protspace transfer` | Fill missing annotations from nearest reference embeddings (EAT) | Run `protspace -h` for detailed help. @@ -183,6 +184,43 @@ protspace style input.parquetbundle output.parquetbundle --annotation-styles sty protspace style data.parquetbundle --dump-settings ``` +## `protspace transfer` + +Embedding Annotation Transfer (EAT): fills missing annotation values for query proteins by transferring the annotation of the nearest annotated reference protein in pLM embedding space. For each query protein that lacks a value in the requested annotation column, the command finds the closest reference (by distance in the original high-dimensional embedding space — Euclidean by default, or cosine via `--metric`, and not in the 2-D/3-D projection) and assigns that reference's label along with a reliability index adapted from goPredSim, yielding a score in [0, 1] where 1 means identical embeddings. The curated source column (`COL`) is left untouched; results are written as two new columns: `COL__pred_value` (string) and `COL__pred_confidence` (float). The method is a direct application of the approach introduced by Littmann et al., Sci Rep 2021 ([DOI 10.1038/s41598-020-80786-0](https://doi.org/10.1038/s41598-020-80786-0)) and extended by Heinzinger et al., NAR Genom Bioinform 2022 ([DOI 10.1093/nargab/lqac043](https://doi.org/10.1093/nargab/lqac043)). + +**Reliability index (`COL__pred_confidence`).** The exact form depends on `--metric` and `--k`: + +- **Default (`--metric euclidean`, `--k 1`):** `confidence = 0.5 / (0.5 + distance)` (1 at distance 0, 0.5 at distance 0.5, → 0 as distance → ∞). +- **`--metric cosine` (`--k 1`):** `confidence = clamp(1 - cosine_distance, 0, 1)`, where `cosine_distance` is in [0, 2]. +- **`--k > 1`:** the value is the goPredSim mean reliability — `(1/m) · Σ s(d)`, the sum of the per-neighbour similarity `s(d)` (the euclidean or cosine form above) over the `k` nearest neighbours that carry the chosen label, divided by `m = min(k, number of references)`. Because of this normalization, confidence values are **not** comparable across different `--k` settings. + +```bash +protspace transfer \ + -b results.parquetbundle \ + -e embeddings.h5:prot_t5 \ + -t protein_category \ + -o results.parquetbundle \ + --query-id-prefix TRINITY_ \ + --reference-where 'protein_category~neurotoxin' +``` + +**Key options:** + +| Flag | Description | Default | +| ---- | ----------- | ------- | +| `-b, --bundle` | Input `.parquetbundle` file | — | +| `-e, --embeddings` | HDF5 embeddings file (use `:name` suffix for external files) | — | +| `-t, --transfer` | Annotation column to transfer (repeatable) | — | +| `-o, --output` | Output `.parquetbundle` (may overwrite input) | — | +| `--query-id-prefix` | Restrict query proteins to IDs starting with this prefix | — | +| `--query-where` | Filter query proteins by annotation value (`col~substr`) | — | +| `--reference-id-prefix` | Restrict reference proteins to IDs starting with this prefix | — | +| `--reference-where` | Filter reference proteins by annotation value (`col~substr`) | — | +| `--k` | Number of nearest neighbours | `1` | +| `--metric` | Distance metric (`euclidean`, `cosine`); see the reliability-index forms above | `euclidean` | + +Distances are computed in the original embedding space (HDF5), not in the 2-D/3-D projection. The `--metric` choice also changes how `COL__pred_confidence` is computed: euclidean uses `0.5 / (0.5 + distance)`, while cosine uses `clamp(1 - cosine_distance, 0, 1)` (see the reliability-index note above). + ## Combining Multiple Inputs (`-i`) When multiple `-i` inputs are provided, behavior depends on whether they share the same embedding name: diff --git a/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md b/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md new file mode 100644 index 00000000..698f595e --- /dev/null +++ b/docs/superpowers/specs/2026-06-11-eat-annotation-transfer-design.md @@ -0,0 +1,349 @@ +# Design: Embedding Annotation Transfer (`protlabel` engine + `protspace transfer` subcommand) + +**Status:** Draft for review +**Date:** 2026-06-11 +**Supersedes:** an earlier "neighbors-subcommand" draft (since removed) — this expanded its scope, reconciled it with GitHub issue #54 and frontend PR #272, and corrected two defaults (cosine→Euclidean, and the reliability aggregation). +**Trigger:** Conference feedback (`Conference_feedback/ProtSpaceExtractor_v1.7.4_mod 1.py`) + GitHub issue [#54 "EAT — Embedding Annotation Transfer (protlabel lookup table)"](https://github.com/tsenoner/protspace/issues/54) + frontend PR [protspace_web #272 "mark predictions and surface per-annotation docs"](https://github.com/tsenoner/protspace_web/pull/272). +**Research backing:** Literature + codebase fan-out (8 agents) with adversarial verification of the storage/compute math and the EAT algorithm against primary sources. Citations in §15. + +> **Shipped vs deferred (read this first).** This document is a design draft, not as-built documentation. Only a subset of what is described below shipped in PR #55; several flags and the long-format table are deferred follow-ups. Treat anything not in the "Shipped" list as future work that is **not yet implemented**. +> +> **Shipped in PR #55** (`protspace transfer`, see `src/protspace/cli/transfer.py`): +> - Flags: `-b/--bundle`, `-e/--embeddings`, `-t/--transfer`, `-o/--output`, `--query-id-prefix`, `--query-where`, `--reference-id-prefix`, `--reference-where`, `--k`, `--metric` (`euclidean` | `cosine`). +> - Wide overlay columns appended to the bundle annotations table (`src/protspace/data/io/predictions.py`): `__pred_value`, `__pred_confidence`. +> - Brute-force nearest-neighbour search and the goPredSim reliability index. +> +> **Deferred / not yet implemented (future work):** +> - The opt-in flags described in §3, §5, and §11: `--cutoff`, `--mine`/`--top-n`, `--lookup`, `--report`, `--plots`, `--full-tables`, `--excel`, distance-threshold transfer mode. +> - The long-format `predicted_annotations` parquet table (§7.2); shipped output is the wide `__pred_*` overlay columns instead. +> - The faiss accelerator and the `protspace[ann]` extra — **the `ann` extra is not declared in `pyproject.toml`**, so `pip install protspace[ann]` does not work today. + +--- + +## 1. One-paragraph decision + +Build **`protlabel`** as a small standalone uv workspace member — the **Embedding Annotation Transfer (EAT) engine**: nearest-neighbour label transfer in *true pLM embedding space* with a calibrated reliability index. **`protspace`** consumes it through a thin **`protspace transfer`** subcommand that reads a `.parquetbundle` (+ the source HDF5 embeddings), classifies query vs reference proteins, calls `protlabel`, and writes a small **prediction overlay** back into the bundle. The big artifact (the reference embedding matrix / lookup index) is **never** shipped in the bundle — it is a rebuildable **sidecar/cache file**. The tiny artifact (per-protein predicted value + confidence + source neighbour) ships in the bundle and the **frontend renders it as a new "predicted-by-transfer" visual layer** (hollow markers, confidence-driven opacity, provenance tooltip) that is *orthogonal to* PR #272's existing column-level "predicted-by-model" badge. Everything beyond the one default output table is opt-in, so ProtSpace does not get overblown. + +This is the canonical Rost-lab EAT method (Littmann et al. 2021; Heinzinger et al. 2022) — exactly what issue #54 describes — packaged so the conference users' proximity-mining workflow becomes a thin, optional layer on top rather than a parallel reimplementation. + +--- + +## 2. How the three inputs relate + +| Input | What it is | Role in this design | +|---|---|---| +| **Issue #54 (EAT / `protlabel`)** | "Given references with known annotations + embeddings, transfer labels to unknowns by nearest neighbour in *embedding space*, backed by a lookup table." | **The engine.** Defines `protlabel`. This is canonical EAT. | +| **Conference `ProtSpaceExtractor` v1.7.4** | A 1.5K-LOC script doing proximity mining in *DR/projection space* (UMAP/t-SNE coords) with cross-method consensus, an EDD elbow, Venn/agreement sets, neighborhood mining, and a 25-file + HTML report. | **A screening/exploration layer.** Its genuinely-novel parts (neighborhood mining, report) become opt-in extras on `protspace transfer`; its DR-space machinery is mostly subsumed by transferring in true embedding space (see §6 keep/adapt/drop). | +| **PR #272 (frontend)** | Marks whole annotation **columns** as model-predicted (⚡ badge, source grouping, info popovers), deliberately **frontend-only, no data-format change**. | **The base the EAT frontend extends.** EAT introduces *cell-level* predictions (a value-level axis #272 explicitly deferred). The two compose; neither replaces the other. | + +The key realization tying them together: **DR-space proximity (the conference approach) is a lossy proxy for embedding-space proximity (EAT).** Non-linear DR is not isometric, so "nearest in UMAP_3" ≠ "nearest in embedding space." Issue #54 gets this right; the conference script worked around it with consensus/normalization scaffolding that becomes unnecessary once we transfer in the true space. + +--- + +## 3. The canonical method we adopt (EAT), with corrections + +Verified against primary sources (goPredSim / Littmann et al. *Sci Rep* 2021; EAT tool / Heinzinger et al. *NAR Genom Bioinform* 2022): + +- **Space:** original pLM embedding space (mean-pooled per-protein vectors). **Not** DR coordinates. +- **Metric:** **Euclidean (L2)**, default. *Nuance (verifier correction):* the strong "Euclidean beats cosine for pLM embeddings" statement is from the **2022** paper (citing prior work); the **2021** paper only found cosine "changed little." Euclidean is still the right default because it is the canonical tool default and the documented 2022 finding — but the basis is "tool convention + 2022 claim," not "both papers." Cosine stays an opt-in `--metric`. +- **No L2-normalization by default** — canonical EAT uses raw Euclidean; magnitude carries information cosine discards. (Normalization/whitening to fight hubness are research extras, off by default.) +- **k = 1** default (the value eat.py defaults to and the 2021 paper chose after grid search). `k` is exposed; a distance-threshold mode (transfer from all references within distance d) is also supported, per goPredSim. +- **Reliability index (the confidence column).** Adopt goPredSim Eq.(5) **verbatim** — do **not** invent a separate `reliability × vote` product (a verifier flagged that as non-canonical; neighbour agreement is *already* encoded in the formula): + + ``` + RI(p) = (1/k) · Σ_{i : n_i carries label p} s(d(q, n_i)) + + where s(d) = 0.5 / (0.5 + d) for Euclidean (RI = 1.0 at d=0, 0.5 at d=0.5, →0 as d→∞) + s(d) = 1 − d for cosine (clamp to [0,1]; cosine distance ∈ [0,2]) + ``` + + For the default k=1 this collapses to `RI = 0.5/(0.5 + d)`. The `(1/k)·Σ_{neighbours carrying p}` term *is* the multi-neighbour agreement weighting; report `RI` directly as the `[0,1]` confidence. +- **Distance→accuracy calibration (reference point, ProtT5/CATH):** at Euclidean distance ≤ 1.1, ~75% coverage with ~90% accuracy at CATH H-level; ProtTucker (contrastive) reaches ~76% H-level vs raw ProtT5 EAT ~64% and HMMER ~77%. **Caveat (critical):** the `0.5` constant in `s(d)` and the `1.1` threshold are **ProtT5-specific**. ProtSpace supports 12 embedders (320–2560 dim) with different distance scales — RI stays *monotone* (good for ranking) but is **not a calibrated probability** for other models without re-validation. Document this loudly. + +**Output contract (mirror eat.py for interoperability):** per query → `query_id`, transferred `label`, `source_id` (nearest reference), `source_label`, `distance`, `reliability`. Accept goPredSim's 2-column `id → comma-separated labels` lookup-label file so existing EAT/goPredSim lookups drop in. + +**Optional upgrade path (documented, not built first):** ProtTucker-style contrastive projection or CLEAN-style EC centroids as a future "learned distance" mode. Ship raw-embedding Euclidean EAT first — it needs no training and is the published baseline. + +--- + +## 4. Architecture + +``` +protspace/ # the protspace repo (also the build root) +├── pyproject.toml # hatchling builds BOTH packages; scipy added as a dep +└── src/ + ├── protlabel/ # NEW second top-level package — the EAT engine (issue #54) + │ ├── __init__.py # public API: eat(), Lookup, Prediction + │ ├── reliability.py # goPredSim distance→[0,1] reliability transform + │ ├── backends.py # brute-force (default) | faiss (optional, later) NN search + │ ├── transfer.py # kNN + label transfer + reliability index (RI) + │ └── lookup.py # build / save / load the reference lookup sidecar (.npz) + └── protspace/ + ├── cli/transfer.py # NEW thin Typer subcommand (glue only, ~150 LOC) + ├── analysis/ # NEW — classifier now; optional gating/mining later + │ └── classification.py # query/reference classifier (no hardcoded biology) + ├── data/io/bundle.py # EXTEND: replace_annotations_in_bundle() (in-place part-1 rewrite) + └── data/io/predictions.py # NEW: build the per-cell overlay columns +``` + +**Packaging decision (refines the spec):** the suite root is *not* a uv workspace, and `protspace` +publishes to PyPI — so a separate `protlabel` distribution would force its own PyPI release + CI +changes. For the MVP, `protlabel` ships as a **second top-level package inside the protspace repo** +(`src/protlabel/`), bundled into the protspace wheel via +`[tool.hatch.build.targets.wheel] packages = ["src/protspace", "src/protlabel"]`. A strict +**no-`protspace`-imports boundary** keeps it independently testable and reusable, and makes a future +promotion to a standalone PyPI package / uv workspace member a clean, mechanical split. The optional +gating/mining/report and faiss backend are deferred to follow-up work (kept out so ProtSpace stays lean). + +**The boundary (why a submodule, not just a subcommand):** + +- **`protlabel` is pure and ProtSpace-agnostic.** In: reference embeddings (`ndarray` + ids + labels) and query embeddings (`ndarray` + ids). Out: per-query nearest neighbour(s), distance, reliability, transferred label(s). It owns the **lookup table** (issue #54's core artifact): building it, serializing it to a sidecar, and querying it. Usable as a standalone `eat`-style tool and reusable by other projects (`protspace_uniprot`, notebooks). +- **`protspace transfer` is glue.** It knows about `.parquetbundle`, HDF5 loading (`load_h5`), query/reference classification, and writing the overlay back. It contains **no distance math** — that lives in `protlabel`. +- **`protspace/analysis/`** holds the *optional* conference-derived screening (gating, neighborhood mining). Default runs skip it entirely. + +This keeps `protspace` from getting overblown: the heavy/clever code is isolated in a focused library; the subcommand stays small; the extras are opt-in modules that don't load unless requested. + +--- + +## 5. Algorithm — "best of both worlds" pipeline + +Minimal viable command produces a useful, calibrated transfer table from **embedding-space 1-NN + reliability index alone**. Everything else is a flag. + +1. **Load embeddings** from the source HDF5 (`load_h5`: float16→float32 upcast, dim validation, reject per-residue). Mean-pool already done upstream (per-protein vectors). +2. **Classify** queries vs references by ID-prefix and/or metadata-substring rules (CLI flags or YAML). **No hardcoded biology** (drop the v1.7.4 `TRINITY_`/`mscr` fallback). Error clearly if no query rule matches anything. +3. **kNN in true embedding space** (`protlabel.transfer`). Default metric **Euclidean**; `--metric {euclidean,cosine}`. Brute-force chunked search (default); faiss backend if installed. Take the `k` nearest *references* per query (default k=1). +4. **Transfer label + reliability** via Eq.(5) above → the primary `[0,1]` confidence. This *replaces* consensus/EDD as the headline number. +5. *(opt-in)* **Confidence gate** for batch triage: `--cutoff {fixed,reliability,percentile,edd}`. Default-if-requested = fixed distance/reliability tied to measured accuracy (EAT-style). If `edd`, compute `max(Kneedle distance-to-chord, median-jump)` **on the embedding-space distance curve**, clearly labeled a heuristic soft gate — never as calibrated confidence. +6. *(opt-in)* **High-precision subset** = `reliability ≥ threshold AND k-NN vote unanimous` (the embedding-space replacement for v1.7.4 "Overlapped"). +7. *(opt-in)* **Neighborhood mining** (`--mine`/`--top-n`): top-N nearest items around each reference/confident query, with recurrence counts and a non-redundant pooled panel. Pure exploration, decoupled from confidence. +8. **Output** one tidy `predictions.parquet` by default (§7.2). Extras opt-in: `--report` (Jinja2 HTML), `--plots`, `--full-tables` (reproduce the v1.7.4 25-file layout for the conference users). + +### Keep / adapt / drop the conference ideas (verified) + +| v1.7.4 idea | Verdict | Why | +|---|---|---| +| Rank-percentile normalization | **DROP** as core / keep as optional descriptor | Exists only to make incomparable DR scales (PCA vs UMAP vs t-SNE) summable. One embedding-space metric + RI makes it unnecessary. | +| Cross-DR-method consensus (0–6) | **DROP** | The 6 projections are deterministic lossy shadows of the *same* embedding; their agreement measures DR stability, not biological confidence — zero independent evidence once you transfer in the source space. (At most an opt-in "projection agreement" QC diagnostic, never labeled "confidence".) | +| EDD elbow (Kneedle ∨ median-jump) | **DEMOTE** to optional soft gate | Adaptive and parameter-free, but statistically uncalibrated (curve-shape dependent, Kneedle is noise-sensitive) — unlike EAT's accuracy-tied 1.1. Recompute in embedding space; offer as one `--cutoff` option, not the default. | +| N-way "Overlapped" agreement set | **ADAPT** | Redefine from "UMAP_3 ∩ TSNE_3 survivors" to "reliability ≥ t AND vote unanimous" — same high-precision-subset value, calibrated terms. | +| Top-N neighborhood mining | **KEEP** (opt-in) | The strongest survivor: metric-agnostic, genuinely useful for cluster expansion / focused re-runs. Gate behind `--mine`. | +| 25-file output + Venns/coverage maps/graphs | **DROP** as default / keep behind `--plots`/`--full-tables` | Venns specifically lose meaning once consensus is dropped. | +| One-page HTML report | **KEEP** (opt-in `--report`) | Useful sharing artifact; Jinja2 template, not `__doc__`-string injection. | + +> **Framing for the conference users:** present embedding-space transfer as a strict upgrade that *subsumes* their goals (a single calibrated confidence instead of a 6-way consensus proxy), and keep their exact workflow reproducible via `--full-tables`. Their contribution is acknowledged and preserved, not discarded. + +--- + +## 6. Storage & data representation — the user's "is it too large?" question + +Two artifacts with **very different** sizes. Treat them differently. + +### 6.1 The reference lookup (BIG) → sidecar / cache, never in the bundle + +Reference embedding matrix size (N proteins × D dims; binary units, ≈5% smaller than SI): + +| N | D=1024 (ProtT5) fp32 / fp16 | D=2560 (ESM2-3B) fp32 / fp16 | +|---|---|---| +| 1,000 | 3.9 MiB / 2.0 MiB | 9.8 MiB / 4.9 MiB | +| 10,000 | 39 MiB / 20 MiB | 98 MiB / 49 MiB | +| 100,000 | 391 MiB / 196 MiB | 977 MiB / 489 MiB | +| **573,000 (Swiss-Prot)** | **2.19 GiB / 1.09 GiB** | **5.47 GiB / 2.73 GiB** | + +The `.parquetbundle` is a ~45 MB portable viz payload. **Embedding a 1–5.5 GiB matrix into it is the wrong call** — it would bloat every download for a feature most viz users never touch, and it is rebuildable from the source HDF5 anyway. **Decision:** + +- **`protlabel` writes the lookup as a sidecar file** next to the bundle (or in `~/.cache/protspace/`): raw fp16 `.npy`/`.h5` for small sets, a serialized faiss index for large sets. +- **Rebuildable on demand:** if the sidecar is absent/stale, regenerate from the embeddings HDF5 (brute force needs nothing; a faiss build for 573K is seconds-to-low-minutes on CPU). +- **Optional compression** (faiss IVF-PQ) shrinks the *whole* of Swiss-Prot dramatically: `m=64 → ~35 MiB` (64× at D=1024, 160× at D=2560), `m=32 → ~17 MiB`, `m=16 → ~9 MiB`. Add an exact-distance **rerank** of top candidates to recover the recall bare 8-bit PQ loses (~70% → ~90–95%). So a "store it small" option exists if a user *does* want it portable — but it is opt-in, validated, and still a sidecar. + +The user's instinct — *"have it as an optional file in tmp"* — is exactly right and is the recommended default. + +### 6.2 The prediction overlay (TINY) → ships in the bundle + +The per-protein result is small and **sparse** (only proteins that got a transferred value). Store as a **new optional parquet table `predicted_annotations`** appended after the existing parts (the bundle is delimiter-separated and length-extensible; old readers that read parts 1–4 ignore it → backward compatible). Long format, one row per (protein, predicted column): + +| column | type | notes | +|---|---|---| +| `identifier` | string | protein id | +| `annotation_column` | string | which annotation was transferred (e.g. `protein_category`) | +| `predicted_value` | string | the transferred label | +| `reliability` | float32 | the `[0,1]` confidence (RI) | +| `distance` | float32 | embedding distance to the source | +| `source_id` | string | nearest reference protein | +| `k`, `method`, `model` | small | provenance (e.g. `k=1`, `euclidean`, `prot_t5`) | + +Even for tens of thousands of predicted cells this is well under a megabyte. **Do not** store full neighbour lists per cell for 570K × many columns — keep top-1 `source_id` + `k` + `method`; fetch fuller neighbour lists lazily only if a richer hover is ever needed. + +**Representation model (chosen): per-cell overlay on the original column.** EAT *fills missing values inside an existing annotation column* and marks those cells predicted, so the scatter shows curated (filled) + transferred (hollow) points together in one colour scheme — far more useful than a separate `predicted_` column. (A separate-column model, like Biocentral, is the simpler fallback if the overlay proves too invasive.) + +> **Note on PR #272's contract:** #272 was deliberately *no-data-format-change*. EAT is precisely the feature that introduces **value-level** predicted metadata into the bundle — the axis #272 deferred. That is expected and intended; §10 keeps the two axes cleanly separated. + +--- + +## 7. Compute & feasibility verdict + +**Brute-force kNN is laptop-feasible across the entire range, including full Swiss-Prot.** Measured (Apple Silicon, chunked numpy GEMM + argpartition; reproduced by an independent verifier within ~10–25%): + +| Query batch × references × dim | wall time | +|---|---| +| 1,000 × 100K × 1024 | ~0.8–0.9 s | +| 1,000 × 573K × 1024 | ~4–4.6 s (~4 ms/query) | +| 1,000 × 573K × 2560 | ~6 s (~6 ms/query) | +| single query × 573K | ~4–6 ms | + +**The binding constraint is RAM (to hold the reference matrix), not compute.** Mitigation: load the reference as fp16 and upcast per chunk, chunk the N axis so the Q×N distance block never materializes at full size. This stays within a 16 GB laptop at D=1024 and is borderline-but-workable at D=2560. Older Intel/CI machines run ~2–5× slower but stay sub-minute for a few queries at Swiss-Prot scale. + +**Conclusion:** the entire feature is feasible and *not* compute-intensive at realistic scales. **No ANN index is needed for speed** — exact search is already ~ms/query. ANN's only justification here is *shrinking the stored reference* (PQ), which is opt-in. + +**Default:** exact brute force (numpy/scipy/sklearn — already deps) up to ~100–200K references, and still usable to full Swiss-Prot. **Optional accelerator:** `protspace[ann]` extra → **faiss-cpu** (best wheel coverage: macOS arm64+x86_64, manylinux x86_64+aarch64, Windows; pacmap 0.9.x already adopts faiss). Reject hnswlib (sdist-only, needs a compiler) and ScaNN (no macOS wheels) as cross-platform CLI deps. *(scipy is not currently an explicit `protspace` dep — sklearn pulls it transitively; add it explicitly if using `scipy.spatial.distance.cdist`, or use `sklearn.neighbors.NearestNeighbors(algorithm='brute')`.)* + +--- + +## 8. CLI design + +```bash +# Minimal: transfer one annotation, default Euclidean 1-NN, one output table +protspace transfer \ + --bundle results.parquetbundle \ + --embeddings emb.h5:prot_t5 \ + --transfer protein_category \ + --query-id-prefix TRINITY_ \ + --reference-where 'protein_category~neurotoxin' \ + --out results.parquetbundle # writes the overlay back into the bundle (or a sidecar) + +# Tuning + optional screening +protspace transfer ... --metric euclidean --k 3 \ + --cutoff reliability --min-reliability 0.6 \ + --mine --top-n 5 \ + --report --plots +``` + +| Flag | Default | Purpose | +|---|---|---| +| `--bundle` | required | the `.parquetbundle` to annotate | +| `--embeddings h5[:model]` | required | source embeddings for true-space distance | +| `--transfer COL` | required | annotation column to transfer (repeatable) | +| `--query-* / --reference-*` | required (≥1 query rule) | classification (prefix / `col~substr`); or `--rules rules.yaml` | +| `--metric {euclidean,cosine}` | `euclidean` | **reconciled from the old cosine default** | +| `--k` | `1` | neighbours considered (Eq.5) | +| `--cutoff {none,fixed,reliability,percentile,edd}` | `none` | opt-in confidence gate | +| `--mine`, `--top-n` | off | opt-in neighborhood mining | +| `--lookup PATH` | auto sidecar | reuse/persist the reference lookup | +| `--report`, `--plots`, `--full-tables`, `--excel` | off | opt-in artifacts | + +Subcommand name is an **open question** (§13): `transfer` (clear verb), `eat` (matches #54 / Rost-lab convention), or `neighbors` (old draft). + +--- + +## 9. Frontend representation (extends PR #272, does not duplicate it) + +**Two orthogonal axes — codify this mental model:** + +- **Axis A (existing, #272): column-level provenance** — "this whole column is a model output" (Biocentral / Phobius / TED). Keep `AnnotationMeta.isPredicted`, the ⚡ dropdown/legend badge, and the info-popover **unchanged**. +- **Axis B (new, EAT): cell-level provenance** — "this specific protein's value was *transferred from a neighbour*, confidence X, source Y." New visual language below. Never overload the ⚡ badge to mean both. + +### 9.1 Scatter plot — the primary cue is *shape*, not colour + +- **Observed/curated cells → filled markers** (current behaviour). **EAT-imputed cells → hollow (outline-only) markers in the same category hue**, so cluster identity is preserved while provenance reads at a glance. This is an established convention (filled = observed, open = imputed) and satisfies "never colour-only" (accessibility; ~4% CVD). + - Implementable in the existing WebGL renderer: add a per-point `a_predicted` float attribute (mirror the existing `a_shape` plumbing) and a ring-only branch reusing the current edge-distance/outline math (`strokeWidth = 0.15`, `webgl-renderer.ts`). No shader rewrite. +- **Confidence → redundant opacity (and optional size) ramp on imputed points only.** `alpha = lerp(0.25, 0.9, confidence)`; observed points stay at `baseOpacity 0.9`. Optionally scale size by `sqrt(confidence)`. For very low confidence (<0.3), desaturate toward grey (lightweight VSUP). Hooks: `getOpacity`/`getBaseOpacity`/`getPointSize` in `style-getters.ts`. + +### 9.2 Tooltip — per-point provenance line + +Extend `AnnotationBlock` + `renderAnnotationBlock` (`protein-tooltip.ts`) with an EAT row, distinct from observed values: + +> ⚡ **Predicted:** Neurotoxin (82%) — transferred from **P12345** via ProtT5, k=1 + +with an inline confidence bar and the source id as a **click target** that selects/centres that reference in the scatter. Observed values render exactly as today (no chip). + +### 9.3 Legend — a separate "Predicted (transferred)" sub-section + +When the active annotation has any imputed cells, render a small group with two swatches — **filled = "Observed"**, **hollow = "Predicted by EAT"** — and a note "Faint = low confidence", plus live counts ("1,204 shown / 380 below threshold"). Add as a new optional block in `legend-renderer.ts` (alongside `renderHeader`). **Do not** merge into the ⚡ header badge (that is Axis A). + +### 9.4 Global control — one "Predicted annotations" group near the dropdown/legend + +- **Toggle "Show predicted annotations"** (off → imputed cells render neutral/N-A; only the curated layer shows). +- **Confidence-threshold slider** 0–100% with conventional bands (High >80 / Med 50–80 / Low <50); below-threshold imputed points **fade** (`fadedOpacity 0.15`) rather than vanish, preserving layout context. +- Feed `showPredicted` + `minConfidence` into `StyleConfig`; persist in `LegendPersistedSettings` so the choice survives reload/export. Keyboard-operable with `aria-valuetext`. + +### 9.5 Data-model extension (frontend) + +Mirror the existing parallel-array pattern (`annotation_scores`, `annotation_evidence` in `types.ts`): + +```ts +// VisualizationData (optional, populated only when the bundle carries the overlay) +annotation_predicted?: Record; +// PredictedCell = { confidence: number; sourceId: string; k?: number; method?: string } +``` + +Loader (`data-loader/utils/bundle.ts`) pivots the sparse `predicted_annotations` table into these arrays at parse time. Backward compatible: old bundles lack the table → no overlay; the parser already tolerates unknown columns/parts. + +### 9.6 Frontend gotchas to respect + +- Multi-label cells: treat a cell as imputed **only if all its values were transferred**; otherwise show observed with a tooltip note. +- Selection opacity must override confidence dimming (a clicked low-confidence point stays visible). +- Grayscale/PNG export: hollow-vs-filled must be the load-bearing cue (opacity alone is ambiguous in print). The export path renders the same shader, so hollow survives export — verify at 570K points. +- This is a **separate frontend PR** (depends on the backend emitting the overlay) and warrants its own OpenSpec change in `protspace_web`, building on #272's `annotation-metadata`/`annotation-presentation` capabilities. + +--- + +## 10. Dependencies & packaging + +- **`protlabel`** uses only `numpy`, `scipy`, `h5py` (all already in the ProtSpace stack). It is a **second top-level package in the protspace repo** (`src/protlabel/`), built into the protspace wheel via `[tool.hatch.build.targets.wheel] packages = ["src/protspace", "src/protlabel"]` — *not* a suite-level uv workspace member (the suite root is not a uv workspace, and a separate distribution would need its own PyPI release/CI). It imports nothing from `protspace` (a guarded boundary), so a future standalone split is mechanical. +- Add **`scipy`** to `protspace` via `uv add 'scipy>=1.10'` (currently only transitive via sklearn; `cdist` needs it explicit). +- **faiss-cpu** is a *future* optional accelerator (`protspace[ann]`), out of MVP scope. Rejected for cross-platform CLI: hnswlib (sdist-only) and ScaNN (no macOS wheels). + +--- + +## 11. Testing + +`protlabel` (engine, fast, no ProtSpace deps): +- `test_transfer.py` — synthetic ref/query sets with known nearest neighbours (Euclidean + cosine); RI values at known distances (`d=0→1.0`, `d=0.5→0.5`); k=1 vs k>1 agreement weighting. +- `test_lookup.py` — build/serialize/load round-trip; sidecar rebuild-on-demand; brute-force vs faiss parity (when faiss installed). + +`protspace`: +- `test_transfer_cli.py` — end-to-end on `data/sizes/phosphatase.h5` after `protspace prepare`; overlay table round-trips through the bundle; old (4-part) bundles still read. +- `test_classification.py` — prefix/substring/case rules, empty-query error (no hardcoded biology). +- `test_gating.py` — fixed/reliability/percentile/EDD on known-elbow curves (incl. degenerate `n<15` fallback). + +`protspace_web` (in its own PR): unit tests for the overlay parser + `style-getters` predicted branch; browser checks for hollow markers, legend sub-section, threshold slider, grayscale export. + +--- + +## 12. Docs & notebook + +- `protspace/docs/cli.md`: new `### protspace transfer` section; cite EAT papers. +- `protspace/docs/annotations.md`: document the predicted-overlay columns so `protspace_web`'s registry stays aligned (the #272 contract note already points here). +- New `notebooks/ProtSpace_Transfer.ipynb` — a lean annotation-transfer story on a public dataset (not the full v1.7.4 reproduction). +- Update top-level `CLAUDE.md` CLI table; pre-commit checklist (ruff + docs + notebook). +- `protspace_web/docs/guide/annotations.md` generator: extend for the predicted-by-transfer legend/UX. + +--- + +## 13. Open questions / decisions for the user + +1. **Subcommand name:** `transfer` (clear verb, recommended), `eat` (matches #54 / Rost-lab brand), or `neighbors` (old draft)? +2. **Overlay vs new column:** per-cell overlay on the original column (recommended, richer viz) or a separate `predicted_` column (simpler, mirrors Biocentral)? +3. **`protlabel` scope:** EAT engine only (recommended first), or also bundle the ProtTucker/CLEAN "learned distance" mode now? +4. **Default cutoff:** ship with `--cutoff none` (transfer everything, let the user filter on reliability — recommended) or a default reliability floor? +5. **Frontend timing:** build the backend + overlay first and the frontend PR second (recommended), or design both in lockstep? +6. **Reconcile or supersede #54:** post this design as the plan on issue #54 and keep #54 as the engine tracking issue? + +## 14. Non-goals + +- No statistical FDR/hypothesis testing (RI is a heuristic ranking, ProtT5-calibrated only). +- No automatic UniProt fetch of references — references are whatever is already in the bundle. +- No legacy `output.json` support (point users to `protspace bundle`). +- No shipping ProtTucker weights initially (raw-embedding EAT needs no training). +- The frontend work is a *separate* PR; this spec only specifies the representation. + +## 15. Citations + +- Littmann, Heinzinger, Olenyi, Dallago, Rost. "Embeddings from protein language models predict conservation and ... / Embedding-based annotation transfer." *Sci Rep* 2021. DOI 10.1038/s41598-020-80786-0 — **RI formula (Eq. 5), GO results.** +- Heinzinger, Littmann, Sillitoe, Bordin, Orengo, Rost. "Contrastive learning on protein embeddings enlightens midnight zone." *NAR Genom Bioinform* 2022. DOI 10.1093/nargab/lqac043 — **generic EAT tool, Euclidean>cosine, CATH 1.1-threshold calibration, ProtTucker.** +- goPredSim — https://github.com/Rostlab/goPredSim (reliability code, 2-column label format). EAT tool — https://github.com/Rostlab/EAT (`eat.py` interface). +- CLEAN (EC, contrastive + centroids), *Science* 2023, DOI 10.1126/science.adf2465 — documented learned-distance upgrade path. +- VSUP (Correll, Moritz, Heer, CHI 2018); ScatterUQ (arXiv 2308.04588); imputed-vs-observed marker convention (mclust) — frontend uncertainty UX. + +--- + +*Verification note: an adversarial pass confirmed the RI formula, Euclidean default, k=1, the eat.py contract, and the 1.1/≈90% calibration against primary sources; it corrected the historical framing (Euclidean>cosine is the 2022 paper's claim, not the 2021 one), the CATH comparison numbers (raw ProtT5 EAT H-level ≈64, HMMER ≈77, MMseqs2 ≈35), and the confidence aggregation (use Eq.5 directly, not a `reliability×vote` product). The storage/compute math reproduced exactly (binary units; the SI labels were ~5% low) and the "sidecar not bundle / brute-force default" recommendations were judged sound.* diff --git a/notebooks/ProtSpace_Transfer.ipynb b/notebooks/ProtSpace_Transfer.ipynb new file mode 100644 index 00000000..cf98973e --- /dev/null +++ b/notebooks/ProtSpace_Transfer.ipynb @@ -0,0 +1,104 @@ +{ + "metadata": { + "colab": { + "collapsed_sections": [], + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5, + "cells": [ + { + "cell_type": "markdown", + "id": "a1b2c3d4", + "metadata": {}, + "source": "# ProtSpace — Embedding Annotation Transfer (EAT)\n\nThis notebook demonstrates **Embedding Annotation Transfer (EAT)** with `protspace transfer`.\nFor each query protein that lacks an annotation value, the command finds the closest\nannotated reference protein in pLM embedding space and transfers its label, together\nwith a reliability index in [0, 1]. The exact confidence formula depends on `--metric` and `--k`:\n\n- Default (`--metric euclidean`, `--k 1`): `confidence = 0.5 / (0.5 + distance)`.\n- `--metric cosine` (`--k 1`): `confidence = clamp(1 - cosine_distance, 0, 1)` (cosine distance in [0, 2]).\n- `--k > 1`: the goPredSim mean reliability — `(1/m) * sum` of the per-neighbour similarity over the `k` nearest neighbours carrying the chosen label, where `m = min(k, number of references)`. Because of this normalization, values are not comparable across different `--k` settings.\n\nThe method follows the goPredSim approach introduced in:\n\n- Littmann et al., *Sci Rep* 2021 — [DOI 10.1038/s41598-020-80786-0](https://doi.org/10.1038/s41598-020-80786-0)\n- Heinzinger et al., *NAR Genom Bioinform* 2022 — [DOI 10.1093/nargab/lqac043](https://doi.org/10.1093/nargab/lqac043)\n\nDistances are computed in the original high-dimensional embedding space (HDF5),\nnot in any 2-D/3-D projection. The curated source column is left untouched;\nresults are written as `COL__pred_value` and `COL__pred_confidence`\ncolumns in the bundle's annotations table.\n\n📚 [GitHub](https://github.com/tsenoner/protspace) · [CLI Reference](https://github.com/tsenoner/protspace/blob/main/docs/cli.md#protspace-transfer) · [Annotation Reference](https://github.com/tsenoner/protspace/blob/main/docs/annotations.md#prediction-overlay-columns-eat-transfer)" + }, + { + "cell_type": "markdown", + "id": "b2c3d4e5", + "metadata": {}, + "source": [ + "## Install" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3d4e5f6", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install protspace" + ] + }, + { + "cell_type": "markdown", + "id": "d4e5f6a7", + "metadata": {}, + "source": [ + "## Run transfer\n", + "\n", + "Transfer the `protein_category` annotation from annotated reference proteins\n", + "(filtered to those with `protein_category` containing `neurotoxin`) to\n", + "unannotated query proteins whose IDs start with `TRINITY_`.\n", + "\n", + "Adjust `-b`, `-e`, `-t`, `--query-id-prefix`, and `--reference-where` to match your data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5f6a7b8", + "metadata": {}, + "outputs": [], + "source": [ + "!protspace transfer \\\n", + " -b results.parquetbundle \\\n", + " -e embeddings.h5:prot_t5 \\\n", + " -t protein_category \\\n", + " -o results.parquetbundle \\\n", + " --query-id-prefix TRINITY_ \\\n", + " --reference-where 'protein_category~neurotoxin'" + ] + }, + { + "cell_type": "markdown", + "id": "f6a7b8c9", + "metadata": {}, + "source": [ + "## Read predictions back\n", + "\n", + "Load the updated bundle and inspect the `*__pred_*` columns for the transferred annotations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a7b8c9d0", + "metadata": {}, + "outputs": [], + "source": [ + "import io\n", + "import pyarrow.parquet as pq\n", + "from protspace.data.io.bundle import read_bundle\n", + "\n", + "parts, _ = read_bundle(\"results.parquetbundle\")\n", + "df = pq.read_table(io.BytesIO(parts[0])).to_pandas()\n", + "\n", + "pred_cols = [c for c in df.columns if c.endswith(\"__pred_value\")]\n", + "df[df[pred_cols[0]].notna()].head() if pred_cols else df.head()" + ] + } + ] +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 9caf3eb9..83ef3eb3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "requests>=2.32.4", "typer>=0.24.1", "rich>=14.3.3", + "scipy>=1.10", ] [project.optional-dependencies] @@ -70,6 +71,9 @@ protspace = "protspace.cli.app:app" requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.build.targets.wheel] +packages = ["src/protspace", "src/protlabel"] + [tool.pytest.ini_options] markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')", diff --git a/src/protlabel/__init__.py b/src/protlabel/__init__.py new file mode 100644 index 00000000..aebc59d8 --- /dev/null +++ b/src/protlabel/__init__.py @@ -0,0 +1,21 @@ +"""protlabel — Embedding Annotation Transfer (EAT) engine. + +Nearest-neighbour label transfer in protein-language-model embedding space, +with the goPredSim reliability index. Pure numpy (plus the standard library); +no protspace imports. +""" + +from importlib.metadata import PackageNotFoundError, version + +from protlabel.lookup import Lookup +from protlabel.transfer import Prediction, eat + +try: + # protlabel currently ships inside the protspace wheel; report that version + # rather than a hard-coded literal that would silently drift across releases. + # (Reads distribution metadata by name — it does not import protspace.) + __version__ = version("protspace") +except PackageNotFoundError: # pragma: no cover - source/uninstalled fallback + __version__ = "0.0.0" + +__all__ = ["Lookup", "Prediction", "eat", "__version__"] diff --git a/src/protlabel/backends.py b/src/protlabel/backends.py new file mode 100644 index 00000000..e5472ee4 --- /dev/null +++ b/src/protlabel/backends.py @@ -0,0 +1,133 @@ +"""Exact (brute-force) k-nearest-neighbour search over reference embeddings. + +Distances are computed with a chunked BLAS matrix product (numpy ``@``) plus +``argpartition`` — the GEMM path, which is roughly an order of magnitude faster +than ``scipy.cdist`` while staying pure-numpy (no scipy/sklearn dependency). + +The query axis is chunked and, adaptively, bounded against the reference axis so +the per-chunk distance block is kept at or below ``max_block_bytes`` (default +256 MiB) regardless of ``n_refs``. This keeps peak memory close to the +reference matrix itself even at Swiss-Prot scale (~570 000 references). +""" + +from __future__ import annotations + +import warnings + +import numpy as np + +_METRICS = {"euclidean", "cosine"} + + +def _l2_normalize(x: np.ndarray) -> np.ndarray: + """Row-wise L2 normalize. Zero-magnitude rows stay zero (kept finite). + + A zero vector would make cosine distance NaN; mapping it to the zero vector + yields a finite cosine distance of 1.0 to every reference instead. + """ + norms = np.linalg.norm(x, axis=1, keepdims=True) + safe = np.where(norms == 0.0, 1.0, norms) + return x / safe + + +def _exact_distances(block: np.ndarray, sel: np.ndarray, metric: str) -> np.ndarray: + """Distances from each query (block[i]) to its k candidates (sel[i]) in float64. + + The fast GEMM block selects the top-k candidates; this recomputes their + distances by direct subtraction in float64 to avoid the catastrophic + cancellation of ``||q||^2 - 2 q.r + ||r||^2`` for near-identical vectors. + Cost is O(b * k * d), not O(b * n_refs), so it is cheap. + """ + blk = block[:, None, :].astype(np.float64) # (b, 1, d) + sel = sel.astype(np.float64) # (b, k, d) + if metric == "euclidean": + diff = blk - sel + return np.sqrt(np.einsum("bkd,bkd->bk", diff, diff)) + bn = np.linalg.norm(blk, axis=2, keepdims=True) + bn = np.where(bn == 0.0, 1.0, bn) + sn = np.linalg.norm(sel, axis=2, keepdims=True) + sn = np.where(sn == 0.0, 1.0, sn) + cos = np.einsum("bkd,bkd->bk", blk / bn, sel / sn) + return np.clip(1.0 - cos, 0.0, 2.0) + + +def nearest( + queries: np.ndarray, + refs: np.ndarray, + k: int, + metric: str = "euclidean", + chunk: int = 4096, + max_block_bytes: int = 256 * 1024 * 1024, +) -> tuple[np.ndarray, np.ndarray]: + """Return (idx, dist) of the k nearest *references* per query. + + idx[i] -> indices into ``refs`` of the k nearest, ascending by distance. + dist[i] -> the corresponding distances. + k is capped to the number of references. + + Memory behaviour + ---------------- + The per-chunk distance block has shape ``(query_chunk, n_refs)``. When + ``n_refs`` is large (e.g. Swiss-Prot ~570 000), even a modest ``chunk`` of + 4096 yields a multi-GiB block. To bound this, the effective query-chunk + size ``eff_chunk`` is computed as ``min(chunk, max_block_bytes // (n_refs * + 8))`` so each block stays at or below ``max_block_bytes`` (default 256 MiB) + independent of ``n_refs``. The factor 8 budgets for a float64-sized block; + the euclidean path holds a few float32 ``(query_chunk, n_refs)`` temporaries + at once, so the real peak is comparable to (not far below) that budget. The + exact-distance recompute only touches the k surviving candidates, so peak + memory remains close to the reference matrix itself, making the function + laptop-feasible at Swiss-Prot scale. + """ + if metric not in _METRICS: + raise ValueError(f"Unknown metric {metric!r}; expected 'euclidean' or 'cosine'") + + if k < 1: + raise ValueError("k must be >= 1") + + queries = np.ascontiguousarray(queries, dtype=np.float32) + refs = np.ascontiguousarray(refs, dtype=np.float32) + n_refs = refs.shape[0] + k = min(k, n_refs) + + # Adaptively shrink the query chunk so the distance block stays within + # max_block_bytes (budgeting for a float64 = 8 bytes-per-element block). + bytes_per_row = max(1, n_refs * 8) + eff_chunk = max(1, min(chunk, max_block_bytes // bytes_per_row)) + + idx_out = np.empty((queries.shape[0], k), dtype=np.int64) + dist_out = np.empty((queries.shape[0], k), dtype=np.float32) + + # Cosine: normalize once; distance is 1 - cosine similarity via a dot product. + refs_unit = _l2_normalize(refs) if metric == "cosine" else None + # Euclidean: precompute ||ref||^2 once; ||q-r||^2 = ||q||^2 - 2 q.r + ||r||^2. + refs_sq = np.einsum("ij,ij->i", refs, refs) if metric == "euclidean" else None + + for start in range(0, queries.shape[0], eff_chunk): + block = queries[start : start + eff_chunk] + with warnings.catch_warnings(): + # GEMM on near-overflowing float16-origin values can emit harmless + # RuntimeWarnings; the inputs are upcast float32 so values are safe. + warnings.simplefilter("ignore", RuntimeWarning) + if metric == "euclidean": + block_sq = np.einsum("ij,ij->i", block, block) + cross = block @ refs.T # (b, n_refs), BLAS GEMM + d2 = block_sq[:, None] - 2.0 * cross + refs_sq[None, :] + np.maximum(d2, 0.0, out=d2) # clip tiny negative fp artifacts + d = np.sqrt(d2, dtype=np.float32) + else: # cosine + block_unit = _l2_normalize(block) + d = (1.0 - block_unit @ refs_unit.T).astype(np.float32) + np.clip(d, 0.0, 2.0, out=d) # cosine distance in [0, 2] + + # Select the k candidates with the fast (float32) block, then recompute + # their distances exactly in float64 and order by that — so the reported + # distance is precise even for near-identical vectors. + cand = np.argpartition(d, kth=k - 1, axis=1)[:, :k] # unsorted top-k + rows = np.arange(block.shape[0])[:, None] + exact = _exact_distances(block, refs[cand], metric) # (b, k) float64 + order = np.argsort(exact, axis=1) + idx_out[start : start + block.shape[0]] = cand[rows, order] + dist_out[start : start + block.shape[0]] = exact[rows, order].astype(np.float32) + + return idx_out, dist_out diff --git a/src/protlabel/lookup.py b/src/protlabel/lookup.py new file mode 100644 index 00000000..916f557c --- /dev/null +++ b/src/protlabel/lookup.py @@ -0,0 +1,71 @@ +"""A persistable reference lookup: embeddings + ids + labels, plus query(). + +Serialized as a single .npz sidecar so it can live next to a bundle or in a +cache dir and be rebuilt on demand from the source HDF5. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path + +import numpy as np + +from protlabel.transfer import Prediction, eat + + +@dataclass +class Lookup: + """Reference set for embedding annotation transfer.""" + + embeddings: np.ndarray # (N, D) float32 + ids: list[str] + labels: list[str] + metric: str = "euclidean" + model: str = field(default="") + + def query( + self, query_emb: np.ndarray, query_ids: list[str], *, k: int = 1 + ) -> list[Prediction]: + """Transfer labels to query embeddings from this lookup.""" + return eat( + query_emb, + query_ids, + self.embeddings, + self.ids, + self.labels, + k=k, + metric=self.metric, + ) + + def save(self, path: Path) -> None: + """Serialize to a .npz sidecar. + + Embeddings are stored as float32 (lossless round-trip); ids/labels are + stored as unicode arrays rather than pickled object arrays so the sidecar + can be loaded with ``allow_pickle=False`` (no arbitrary-code-execution + surface on load). ids/labels must not contain trailing NUL bytes, which + numpy's fixed-width unicode arrays strip on round-trip. + """ + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + np.savez( + path, + embeddings=self.embeddings.astype(np.float32), + ids=np.asarray(self.ids, dtype=np.str_), + labels=np.asarray(self.labels, dtype=np.str_), + metric=np.asarray(self.metric, dtype=np.str_), + model=np.asarray(self.model, dtype=np.str_), + ) + + @classmethod + def load(cls, path: Path) -> Lookup: + """Load a .npz sidecar (with pickling disabled).""" + with np.load(path, allow_pickle=False) as data: + return cls( + embeddings=data["embeddings"].astype(np.float32), + ids=[str(x) for x in data["ids"]], + labels=[str(x) for x in data["labels"]], + metric=str(data["metric"]), + model=str(data["model"]), + ) diff --git a/src/protlabel/reliability.py b/src/protlabel/reliability.py new file mode 100644 index 00000000..7d9a0c11 --- /dev/null +++ b/src/protlabel/reliability.py @@ -0,0 +1,18 @@ +"""goPredSim reliability index: map an embedding distance to a [0,1] confidence. + +Euclidean: s(d) = 0.5 / (0.5 + d) (1.0 at d=0, 0.5 at d=0.5, ->0 as d->inf) +Cosine: s(d) = 1 - d (clamped to [0,1]; cosine distance in [0,2]) + +Reference: Littmann et al., Sci Rep 2021 (Eq. 5); goPredSim calc_reliability_index. +""" + +from __future__ import annotations + + +def similarity(distance: float, metric: str) -> float: + """Per-neighbour distance->similarity (the goPredSim reliability transform).""" + if metric == "euclidean": + return 0.5 / (0.5 + distance) + if metric == "cosine": + return min(1.0, max(0.0, 1.0 - distance)) + raise ValueError(f"Unknown metric {metric!r}; expected 'euclidean' or 'cosine'") diff --git a/src/protlabel/transfer.py b/src/protlabel/transfer.py new file mode 100644 index 00000000..e0389343 --- /dev/null +++ b/src/protlabel/transfer.py @@ -0,0 +1,93 @@ +"""Embedding annotation transfer: kNN -> reliability index -> transferred label. + +Implements the goPredSim aggregation (Littmann et al. 2021, Eq. 5): + RI(p) = (1/eff_k) * sum over neighbours carrying label p of similarity(d), +where ``eff_k`` is the number of neighbours actually used (``k`` capped to the +number of references). The transferred label is argmax RI(p); its source is the +nearest neighbour carrying that label. Ties are broken deterministically by +smallest source distance, then by lexically smallest label, so the result never +depends on the (arbitrary) ordering of equidistant neighbours. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + +from protlabel.backends import nearest +from protlabel.reliability import similarity + + +@dataclass(frozen=True) +class Prediction: + """One transferred annotation for a query protein.""" + + query_id: str + label: str + source_id: str + distance: float + reliability: float + k: int + metric: str + + +def eat( + query_emb: np.ndarray, + query_ids: list[str], + ref_emb: np.ndarray, + ref_ids: list[str], + ref_labels: list[str], + *, + k: int = 1, + metric: str = "euclidean", +) -> list[Prediction]: + """Transfer the best-guess label to each query from its k nearest references.""" + if not (len(ref_ids) == len(ref_labels) == ref_emb.shape[0]): + raise ValueError("ref_emb, ref_ids and ref_labels must have equal length") + if ref_emb.shape[0] == 0: + raise ValueError("No reference embeddings to transfer from") + if len(query_ids) != query_emb.shape[0]: + raise ValueError("query_emb and query_ids must have equal length") + + idx, dist = nearest(query_emb, ref_emb, k=k, metric=metric) + eff_k = idx.shape[1] + predictions: list[Prediction] = [] + + for qi, query_id in enumerate(query_ids): + neigh_idx = idx[qi] + neigh_dist = dist[qi] + # Accumulate RI per label and track the nearest source per label. + ri_by_label: dict[str, float] = {} + nearest_src: dict[str, tuple[float, str]] = {} + for j, ref_i in enumerate(neigh_idx): + lab = ref_labels[ref_i] + d = float(neigh_dist[j]) + ri_by_label[lab] = ri_by_label.get(lab, 0.0) + similarity(d, metric) + if lab not in nearest_src or d < nearest_src[lab][0]: + nearest_src[lab] = (d, ref_ids[ref_i]) + # Pick the highest-RI label; break ties deterministically by smallest + # source distance, then lexically smallest label. This makes the choice + # independent of the order of equidistant neighbours (whose argsort order + # is otherwise arbitrary). For distinct distances the nearest neighbour's + # label wins, as before. + best_label = min( + ri_by_label, + key=lambda p: (-ri_by_label[p], nearest_src[p][0], p), + ) + # Normalise by eff_k (the goPredSim 1/k term, k capped to n_refs). + ri = ri_by_label[best_label] / eff_k + src_dist, src_id = nearest_src[best_label] + predictions.append( + Prediction( + query_id=query_id, + label=best_label, + source_id=src_id, + distance=src_dist, + reliability=ri, + k=eff_k, + metric=metric, + ) + ) + + return predictions diff --git a/src/protspace/analysis/__init__.py b/src/protspace/analysis/__init__.py new file mode 100644 index 00000000..137457c7 --- /dev/null +++ b/src/protspace/analysis/__init__.py @@ -0,0 +1 @@ +"""Optional analysis layer for ProtSpace (classification, gating, mining).""" diff --git a/src/protspace/analysis/classification.py b/src/protspace/analysis/classification.py new file mode 100644 index 00000000..93474b27 --- /dev/null +++ b/src/protspace/analysis/classification.py @@ -0,0 +1,73 @@ +"""Classify proteins as transfer queries vs annotated references. + +Rules match by ID prefix and/or a case-insensitive metadata substring +(``column ~ substring``). No biology is hardcoded; a query rule that matches +nothing is an error. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +import pyarrow as pa + + +@dataclass +class Rule: + """A classification rule. A protein matches if ANY clause matches.""" + + id_prefixes: list[str] = field(default_factory=list) + where: list[tuple[str, str]] = field(default_factory=list) # (column, substring) + + +def _matches(rule: Rule, identifier: str, row: dict[str, str]) -> bool: + if any(identifier.startswith(p) for p in rule.id_prefixes): + return True + for column, substring in rule.where: + if column not in row: + raise KeyError(f"Classification column {column!r} not in annotations") + value = row[column] + if value is not None and substring.lower() in str(value).lower(): + return True + return False + + +def classify( + annotations: pa.Table, query_rule: Rule, reference_rule: Rule +) -> tuple[list[int], list[int]]: + """Return (query_indices, reference_indices) into the annotations table. + + Query classification takes precedence: a protein matching both rules is a + query. Raises ValueError if the query rule matches nothing. + """ + columns = set(annotations.column_names) + # Validate where-columns up front so an empty table still raises KeyError. + for rule in (query_rule, reference_rule): + for column, _ in rule.where: + if column not in columns: + raise KeyError(f"Classification column {column!r} not in annotations") + + # Materialize only the columns the rules actually need (identifier + any + # where-columns) instead of the whole table — the latter is ~GB-scale at + # Swiss-Prot row counts. + identifiers = [str(v) for v in annotations.column("identifier").to_pylist()] + where_columns = { + column for rule in (query_rule, reference_rule) for column, _ in rule.where + } + column_data = {c: annotations.column(c).to_pylist() for c in where_columns} + + query_indices: list[int] = [] + reference_indices: list[int] = [] + for i, identifier in enumerate(identifiers): + row = {c: column_data[c][i] for c in where_columns} + if _matches(query_rule, identifier, row): + query_indices.append(i) + elif _matches(reference_rule, identifier, row): + reference_indices.append(i) + + if not query_indices: + raise ValueError( + "Classifier matched no query proteins; check --query-id-prefix / " + "--query-where rules." + ) + return query_indices, reference_indices diff --git a/src/protspace/cli/app.py b/src/protspace/cli/app.py index c718fbae..1cd8ba39 100644 --- a/src/protspace/cli/app.py +++ b/src/protspace/cli/app.py @@ -70,6 +70,7 @@ def _register_commands() -> None: project, serve, style, + transfer, ) diff --git a/src/protspace/cli/prepare.py b/src/protspace/cli/prepare.py index 4831f6b2..abbdf369 100644 --- a/src/protspace/cli/prepare.py +++ b/src/protspace/cli/prepare.py @@ -549,22 +549,9 @@ def prepare( def _parse_input_specs(raw_inputs: list[str]) -> list[tuple[Path, str | None]]: """Parse inputs with optional colon name override: file.h5:model_name.""" - specs: list[tuple[Path, str | None]] = [] - for raw in raw_inputs: - if ":" in raw: - last_colon = raw.rfind(":") - path_part, name_part = raw[:last_colon], raw[last_colon + 1 :] - if ( - name_part - and not name_part.startswith(("/", "\\")) - and Path(path_part).suffix - ): - specs.append((Path(path_part), name_part)) - else: - specs.append((Path(raw), None)) - else: - specs.append((Path(raw), None)) - return specs + from protspace.data.loaders import split_h5_spec + + return [split_h5_spec(raw) for raw in raw_inputs] def _parse_embedders(embedder_arg: str | None) -> list[str]: diff --git a/src/protspace/cli/transfer.py b/src/protspace/cli/transfer.py new file mode 100644 index 00000000..23a75165 --- /dev/null +++ b/src/protspace/cli/transfer.py @@ -0,0 +1,262 @@ +"""protspace transfer — fill missing annotation values from nearest references. + +Embedding Annotation Transfer (EAT): for each query protein with a missing +value in a target column, transfer the value of its nearest annotated +reference in pLM embedding space, with a reliability-index confidence. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Annotated + +import numpy as np +import pyarrow as pa +import typer + +from protspace.cli.app import app, setup_logging +from protspace.cli.common_options import Opt_Verbose + +if TYPE_CHECKING: + from protspace.analysis.classification import Rule + +logger = logging.getLogger(__name__) + + +def _is_missing(value) -> bool: + return value is None or str(value).strip() == "" + + +def run_transfer( + *, + annotations: pa.Table, + embeddings: dict[str, np.ndarray], + transfer_columns: list[str], + query_rule: Rule, + reference_rule: Rule, + k: int = 1, + metric: str = "euclidean", +) -> pa.Table: + """Pure core: classify, transfer per column, return the augmented table. + + ``embeddings`` maps protein id -> 1-D float32 vector. Proteins without an + embedding cannot act as queries or references. + """ + from protlabel import eat + from protspace.analysis.classification import classify + from protspace.data.io.predictions import add_overlay_columns + + # Restrict classification to proteins that actually have an embedding. + has_emb = pa.array( + [str(v) in embeddings for v in annotations.column("identifier").to_pylist()] + ) + embedded = annotations.filter(has_emb) + + if embedded.num_rows == 0: + raise ValueError( + "No proteins in the bundle have a matching embedding " + "(check that the --embeddings ids match the bundle identifiers)." + ) + + query_idx, ref_idx = classify(embedded, query_rule, reference_rule) + # Materialize only the id column once (not the whole table); per-column values + # are pulled inside the loop. Avoids GB-scale Python lists at Swiss-Prot size. + id_list = [str(v) for v in embedded.column("identifier").to_pylist()] + + out = annotations + total_transferred = 0 + for column in transfer_columns: + if column not in annotations.column_names: + raise KeyError(f"Transfer column {column!r} not in annotations table") + + col_vals = embedded.column(column).to_pylist() + + # References: classified refs that HAVE a value in this column. + ref_ids, ref_labels, ref_vecs = [], [], [] + for i in ref_idx: + value = col_vals[i] + if not _is_missing(value): + rid = id_list[i] + ref_ids.append(rid) + ref_labels.append(str(value)) + ref_vecs.append(embeddings[rid]) + if not ref_ids: + logger.warning("No references with a value for %r; skipping", column) + continue + + # Queries: classified queries MISSING a value in this column. + q_ids, q_vecs = [], [] + for i in query_idx: + if _is_missing(col_vals[i]): + qid = id_list[i] + q_ids.append(qid) + q_vecs.append(embeddings[qid]) + if not q_ids: + logger.warning("No queries missing %r; nothing to transfer", column) + continue + + preds = eat( + np.vstack(q_vecs), + q_ids, + np.vstack(ref_vecs), + ref_ids, + ref_labels, + k=k, + metric=metric, + ) + out = add_overlay_columns(out, column, preds) + total_transferred += len(preds) + logger.info("Transferred %r to %d quer(ies)", column, len(preds)) + + if total_transferred == 0: + logger.warning( + "No annotations were transferred. Check the --reference-* rules and " + "that query proteins have missing values in the target column(s)." + ) + return out + + +@app.command() +def transfer( + bundle: Annotated[ + Path, + typer.Option("-b", "--bundle", help="Input .parquetbundle to annotate."), + ], + embeddings: Annotated[ + str, + typer.Option( + "-e", + "--embeddings", + help="HDF5 embeddings, optional :name suffix (e.g. emb.h5:prot_t5).", + ), + ], + transfer_columns: Annotated[ + list[str], + typer.Option( + "-t", "--transfer", help="Annotation column to transfer (repeat)." + ), + ], + output: Annotated[ + Path, + typer.Option("-o", "--output", help="Output .parquetbundle path."), + ], + query_id_prefix: Annotated[ + list[str] | None, typer.Option("--query-id-prefix") + ] = None, + query_where: Annotated[ + list[str] | None, + typer.Option("--query-where", help="col~substr"), + ] = None, + reference_id_prefix: Annotated[ + list[str] | None, typer.Option("--reference-id-prefix") + ] = None, + reference_where: Annotated[ + list[str] | None, + typer.Option("--reference-where", help="col~substr"), + ] = None, + k: Annotated[ + int, typer.Option("--k", help="Neighbours considered (default 1).") + ] = 1, + metric: Annotated[ + str, typer.Option("--metric", help="euclidean | cosine.") + ] = "euclidean", + verbose: Opt_Verbose = 0, +) -> None: + """Transfer annotations to query proteins from nearest reference neighbours.""" + setup_logging(verbose) + + import io + + import pyarrow.parquet as pq + + from protspace.analysis.classification import Rule + from protspace.data.io.bundle import read_bundle, replace_annotations_in_bundle + from protspace.data.loaders import load_h5, split_h5_spec + + def _parse_where(items: list[str] | None) -> list[tuple[str, str]]: + clauses = [] + for item in items or []: + if "~" not in item: + raise typer.BadParameter(f"--*-where must be col~substr, got {item!r}") + col, sub = item.split("~", 1) + clauses.append((col, sub)) + return clauses + + query_rule = Rule( + id_prefixes=query_id_prefix or [], where=_parse_where(query_where) + ) + reference_rule = Rule( + id_prefixes=reference_id_prefix or [], where=_parse_where(reference_where) + ) + + if metric not in ("euclidean", "cosine"): + raise typer.BadParameter("--metric must be 'euclidean' or 'cosine'") + if k < 1: + raise typer.BadParameter("--k must be >= 1") + + # Load embeddings (optional ':name' override; colon/Windows-safe parsing). + h5_path, name_override = split_h5_spec(embeddings) + emb_set = load_h5([h5_path], name_override=name_override) + emb_map = {header: emb_set.data[i] for i, header in enumerate(emb_set.headers)} + + # Read the annotations part of the bundle. + parts, _settings = read_bundle(bundle) + annotations = pq.read_table(io.BytesIO(parts[0])) + + # Real bundles name the id column "protein_id"; run_transfer works on "identifier". + if ( + "protein_id" in annotations.column_names + and "identifier" in annotations.column_names + ): + raise typer.BadParameter( + "Bundle annotations contain both 'protein_id' and 'identifier' columns; " + "cannot determine the id column unambiguously." + ) + id_col = "protein_id" if "protein_id" in annotations.column_names else "identifier" + if id_col != "identifier": + annotations = annotations.rename_columns( + ["identifier" if n == id_col else n for n in annotations.column_names] + ) + + for col in transfer_columns: + if col not in annotations.column_names: + raise typer.BadParameter( + f"--transfer column {col!r} not found in the bundle annotations" + ) + + # Validate classification --*-where columns up front for a clean error. + available = set(annotations.column_names) + for col, _ in query_rule.where + reference_rule.where: + if col not in available: + raise typer.BadParameter( + f"--*-where column {col!r} not found in the bundle annotations" + ) + + # Translate input-driven errors from the core into clean CLI errors rather + # than leaking raw KeyError/ValueError tracebacks to the user. + try: + augmented = run_transfer( + annotations=annotations, + embeddings=emb_map, + transfer_columns=transfer_columns, + query_rule=query_rule, + reference_rule=reference_rule, + k=k, + metric=metric, + ) + except (KeyError, ValueError) as exc: + # Use the raw message (KeyError stringifies with repr quotes) rather than + # stripping quotes off the rendered string, which could mangle messages. + message = exc.args[0] if exc.args else str(exc) + raise typer.BadParameter(str(message)) from exc + + # Rename id column back so the written bundle keeps its original name + # (the web frontend expects "protein_id"). + if id_col != "identifier": + augmented = augmented.rename_columns( + [id_col if n == "identifier" else n for n in augmented.column_names] + ) + + replace_annotations_in_bundle(bundle, output, augmented) + logger.info("Wrote transferred bundle to %s", output) diff --git a/src/protspace/data/io/bundle.py b/src/protspace/data/io/bundle.py index ca625a26..8cd8a343 100644 --- a/src/protspace/data/io/bundle.py +++ b/src/protspace/data/io/bundle.py @@ -8,6 +8,7 @@ import io import json import logging +import os import tempfile from pathlib import Path @@ -27,6 +28,43 @@ SETTINGS_FILENAME = "settings.parquet" +def _atomic_write_bytes(path: Path, data: bytes) -> None: + """Write ``data`` to ``path`` atomically (temp file + ``os.replace``). + + The destination is never left truncated or partial on interrupt — it keeps + the old bytes until the rename completes, then atomically becomes the full + new bytes. Critical for the in-place overwrite workflow that ``transfer`` + documents (``-b results.parquetbundle -o results.parquetbundle``): a Ctrl+C + or crash mid-write can no longer destroy the user's bundle. + """ + path.parent.mkdir(parents=True, exist_ok=True) + fd, tmp = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp") + try: + with os.fdopen(fd, "wb") as f: + f.write(data) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp, path) + except BaseException: + Path(tmp).unlink(missing_ok=True) + raise + + +def _check_no_delimiter(part_bytes: bytes) -> None: + """Guard: a serialized part must not contain the bundle delimiter. + + If a value (e.g. an annotation string) happens to contain the reserved + delimiter byte string, the part split on read-back would be corrupted; fail + loudly at write time instead. + """ + if PARQUET_BUNDLE_DELIMITER in part_bytes: + raise ValueError( + "Serialized parquet part contains the bundle delimiter " + f"{PARQUET_BUNDLE_DELIMITER!r}; a value includes this reserved byte " + "string and would corrupt the bundle on read." + ) + + def extract_bundle_to_dir(bundle_path: Path, target_dir: Path | None = None) -> str: """Extract a .parquetbundle into separate parquet files on disk. @@ -101,20 +139,23 @@ def write_bundle( bundle_path: Output file path. settings: Optional settings dict to include as 4th part. """ - bundle_path.parent.mkdir(parents=True, exist_ok=True) - - with open(bundle_path, "wb") as f: - for i, table in enumerate(tables): - if i > 0: - f.write(PARQUET_BUNDLE_DELIMITER) - buf = io.BytesIO() - pq.write_table(table, buf) - f.write(buf.getvalue()) - - if settings is not None: - f.write(PARQUET_BUNDLE_DELIMITER) - f.write(create_settings_parquet(settings)) - + buf = io.BytesIO() + for i, table in enumerate(tables): + if i > 0: + buf.write(PARQUET_BUNDLE_DELIMITER) + table_buf = io.BytesIO() + pq.write_table(table, table_buf) + part_bytes = table_buf.getvalue() + _check_no_delimiter(part_bytes) + buf.write(part_bytes) + + if settings is not None: + buf.write(PARQUET_BUNDLE_DELIMITER) + settings_bytes = create_settings_parquet(settings) + _check_no_delimiter(settings_bytes) + buf.write(settings_bytes) + + _atomic_write_bytes(bundle_path, buf.getvalue()) logger.info(f"Saved bundled output to: {bundle_path}") @@ -143,9 +184,37 @@ def replace_settings_in_bundle( core = PARQUET_BUNDLE_DELIMITER.join(parts[:3]) new_content = core + PARQUET_BUNDLE_DELIMITER + settings_bytes - output_path.parent.mkdir(parents=True, exist_ok=True) - with open(output_path, "wb") as f: - f.write(new_content) + _atomic_write_bytes(output_path, new_content) + + +def replace_annotations_in_bundle( + input_path: Path, + output_path: Path, + annotations_table: pa.Table, +) -> None: + """Replace the annotations (1st) part of a bundle, preserving the rest. + + Projection parts (2nd, 3rd) are kept byte-for-byte; an existing settings + (4th) part is carried over unchanged. + """ + with open(input_path, "rb") as f: + content = f.read() + + parts = content.split(PARQUET_BUNDLE_DELIMITER) + if len(parts) < 3 or len(parts) > 4: + raise ValueError(f"Expected 3 or 4 parts in parquetbundle, found {len(parts)}") + + buf = io.BytesIO() + pq.write_table(annotations_table, buf) + new_annotations_bytes = buf.getvalue() + _check_no_delimiter(new_annotations_bytes) + new_parts = [new_annotations_bytes, parts[1], parts[2]] + if len(parts) == 4: + new_parts.append(parts[3]) + + _atomic_write_bytes(output_path, PARQUET_BUNDLE_DELIMITER.join(new_parts)) + + logger.info(f"Wrote bundle with updated annotations to: {output_path}") def create_settings_parquet(settings_dict: dict) -> bytes: diff --git a/src/protspace/data/io/predictions.py b/src/protspace/data/io/predictions.py new file mode 100644 index 00000000..4722bb4e --- /dev/null +++ b/src/protspace/data/io/predictions.py @@ -0,0 +1,58 @@ +"""Turn protlabel Predictions into per-cell overlay columns on the annotations table. + +For a transferred column ``COL`` we append two aligned columns (null for +non-predicted proteins), leaving the curated ``COL`` untouched: + COL__pred_value (string) the transferred label + COL__pred_confidence (float32) the reliability index in [0, 1] + +The reference protein the value came from is available on the ``Prediction`` +(``source_id``) but is intentionally not written to the bundle: it is noise as a +per-cell colour feature, and confidence is the signal users threshold on. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import pyarrow as pa + +from protlabel import Prediction + + +def add_overlay_columns( + annotations: pa.Table, column: str, predictions: Sequence[Prediction] +) -> pa.Table: + """Append the COL__pred_value / COL__pred_confidence columns, by identifier.""" + by_query = {p.query_id: p for p in predictions} + identifiers = [str(v) for v in annotations.column("identifier").to_pylist()] + + values: list[str | None] = [] + confidences: list[float | None] = [] + for identifier in identifiers: + pred = by_query.get(identifier) + if pred is None: + values.append(None) + confidences.append(None) + else: + values.append(pred.label) + confidences.append(float(pred.reliability)) + + # Drop any pre-existing overlay columns first so re-running transfer on an + # already-overlaid table replaces them rather than appending duplicates + # (duplicate field names produce a parquet table that cannot be read back). + # The legacy __pred_source is included so older bundles are cleaned up. + overlay_names = [ + f"{column}__pred_value", + f"{column}__pred_confidence", + f"{column}__pred_source", + ] + out = annotations + stale = [name for name in overlay_names if name in out.column_names] + if stale: + out = out.drop_columns(stale) + + out = out.append_column(f"{column}__pred_value", pa.array(values, pa.string())) + out = out.append_column( + f"{column}__pred_confidence", pa.array(confidences, pa.float32()) + ) + return out diff --git a/src/protspace/data/loaders/__init__.py b/src/protspace/data/loaders/__init__.py index 908e571b..e864a9fd 100644 --- a/src/protspace/data/loaders/__init__.py +++ b/src/protspace/data/loaders/__init__.py @@ -6,7 +6,12 @@ merge_same_name_sets, ) from protspace.data.loaders.fasta import embed_fasta -from protspace.data.loaders.h5 import EMBEDDING_EXTENSIONS, load_h5, parse_identifier +from protspace.data.loaders.h5 import ( + EMBEDDING_EXTENSIONS, + load_h5, + parse_identifier, + split_h5_spec, +) from protspace.data.loaders.query import ( extract_identifiers_from_fasta, query_uniprot, @@ -24,4 +29,5 @@ "load_h5", "parse_identifier", "query_uniprot", + "split_h5_spec", ] diff --git a/src/protspace/data/loaders/h5.py b/src/protspace/data/loaders/h5.py index f9353dd3..88306a18 100644 --- a/src/protspace/data/loaders/h5.py +++ b/src/protspace/data/loaders/h5.py @@ -22,6 +22,28 @@ ) +def split_h5_spec(spec: str) -> tuple[Path, str | None]: + """Parse an HDF5 input spec with an optional ``:name`` override. + + ``file.h5:model_name`` -> ``(Path("file.h5"), "model_name")``. + + Colon/Windows-safe: the split uses the LAST colon and is only taken when the + right side looks like a name (not a path separator) and the left side has a + file suffix. So ``C:\\data\\emb.h5`` and ``C:\\data\\emb.h5:prot_t5`` parse + correctly instead of splitting on the drive-letter colon. + """ + if ":" in spec: + last_colon = spec.rfind(":") + path_part, name_part = spec[:last_colon], spec[last_colon + 1 :] + if ( + name_part + and not name_part.startswith(("/", "\\")) + and Path(path_part).suffix + ): + return Path(path_part), name_part + return Path(spec), None + + def parse_identifier(raw_key: str) -> str: """Extract protein identifier from an H5 key. diff --git a/tests/test_base_data_processor.py b/tests/test_base_data_processor.py index 8758b8e6..216065b3 100644 --- a/tests/test_base_data_processor.py +++ b/tests/test_base_data_processor.py @@ -137,16 +137,17 @@ def test_save_output_separate_files(self, mock_write_table): assert mock_write_table.call_count == 3 mock_mkdir.assert_called() - @patch("src.protspace.data.processors.base_processor.pq.write_table") - def test_save_output_bundled(self, _): + def test_save_output_bundled(self, tmp_path): + from src.protspace.data.io.bundle import read_bundle + processor = BaseDataProcessor(SAMPLE_CONFIG, {"pca": DummyReducer}) tables = processor.create_output( SAMPLE_METADATA, SAMPLE_REDUCTIONS, SAMPLE_HEADERS ) - with ( - patch("pathlib.Path.mkdir") as mock_mkdir, - patch("builtins.open", new_callable=MagicMock) as mock_open, - ): - processor.save_output(tables, Path("output_dir"), bundled=True) - mock_mkdir.assert_called() - mock_open.assert_called() + out_dir = tmp_path / "output_dir" + processor.save_output(tables, out_dir, bundled=True) + + bundle_path = out_dir / "data.parquetbundle" + assert bundle_path.exists() + parts, _settings = read_bundle(bundle_path) + assert len(parts) == 3 diff --git a/tests/test_bundle_overlay.py b/tests/test_bundle_overlay.py new file mode 100644 index 00000000..5cec3262 --- /dev/null +++ b/tests/test_bundle_overlay.py @@ -0,0 +1,120 @@ +"""Round-trip tests for replacing the annotations part of a bundle.""" + +import io + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +from protspace.data.io.bundle import ( + PARQUET_BUNDLE_DELIMITER, + read_bundle, + replace_annotations_in_bundle, + write_bundle, +) + + +def _tables(): + annotations = pa.table({"identifier": ["A", "B"], "cat": ["x", "y"]}) + proj_meta = pa.table({"name": ["PCA 2"], "dims": [2]}) + proj_data = pa.table({"id": ["A", "B"], "x": [0.0, 1.0], "y": [0.0, 1.0]}) + return [annotations, proj_meta, proj_data] + + +def _read_part(part_bytes): + return pq.read_table(io.BytesIO(part_bytes)) + + +def test_replaces_annotations_keeps_other_parts(tmp_path): + src = tmp_path / "in.parquetbundle" + out = tmp_path / "out.parquetbundle" + write_bundle(_tables(), src) + + new_annotations = pa.table( + {"identifier": ["A", "B"], "cat": ["x", "y"], "cat__pred_value": [None, "z"]} + ) + replace_annotations_in_bundle(src, out, new_annotations) + + parts, settings = read_bundle(out) + assert "cat__pred_value" in _read_part(parts[0]).column_names + # Projections preserved byte-for-byte. + assert _read_part(parts[1]).column_names == ["name", "dims"] + assert _read_part(parts[2]).to_pydict()["x"] == [0.0, 1.0] + + +def test_projection_parts_preserved_byte_for_byte(tmp_path): + src = tmp_path / "in.parquetbundle" + out = tmp_path / "out.parquetbundle" + write_bundle(_tables(), src, settings={"foo": 1}) + + new_annotations = pa.table( + {"identifier": ["A", "B"], "cat": ["x", "y"], "cat__pred_value": [None, "z"]} + ) + replace_annotations_in_bundle(src, out, new_annotations) + + in_parts = src.read_bytes().split(PARQUET_BUNDLE_DELIMITER) + out_parts = out.read_bytes().split(PARQUET_BUNDLE_DELIMITER) + assert out_parts[1] == in_parts[1] # projections_metadata, byte-identical + assert out_parts[2] == in_parts[2] # projections_data, byte-identical + assert out_parts[3] == in_parts[3] # settings, byte-identical + + +def test_delimiter_in_annotation_cell_raises(tmp_path): + # If an annotation value contains the bundle delimiter, the written part + # would corrupt the split on read-back; this must fail loudly instead. + src = tmp_path / "in.parquetbundle" + out = tmp_path / "out.parquetbundle" + write_bundle(_tables(), src) + + evil = "ev" + PARQUET_BUNDLE_DELIMITER.decode() + "il" + bad = pa.table({"identifier": ["A", "B"], "cat": ["x", evil]}) + with pytest.raises(ValueError): + replace_annotations_in_bundle(src, out, bad) + + +def test_in_place_overwrite_works_and_leaves_no_temp(tmp_path): + # The documented -b == -o workflow must produce the augmented bundle and + # leave no stray temp file behind. + path = tmp_path / "b.parquetbundle" + write_bundle(_tables(), path) + new_annotations = pa.table( + {"identifier": ["A", "B"], "cat": ["x", "y"], "cat__pred_value": [None, "z"]} + ) + replace_annotations_in_bundle(path, path, new_annotations) # same path + parts, _ = read_bundle(path) + assert "cat__pred_value" in pq.read_table(io.BytesIO(parts[0])).column_names + assert not list(tmp_path.glob("*.tmp")) + + +def test_failed_replace_preserves_original_in_place(tmp_path, monkeypatch): + # If the rename is interrupted, the original bundle must survive intact + # (atomic write) rather than being left truncated. + import protspace.data.io.bundle as bundle_mod + + path = tmp_path / "b.parquetbundle" + write_bundle(_tables(), path) + original = path.read_bytes() + new_annotations = pa.table( + {"identifier": ["A", "B"], "cat": ["x", "y"], "cat__pred_value": [None, "z"]} + ) + + def boom(*args, **kwargs): + raise OSError("simulated interrupt before rename") + + monkeypatch.setattr(bundle_mod.os, "replace", boom) + with pytest.raises(OSError): + replace_annotations_in_bundle(path, path, new_annotations) + assert path.read_bytes() == original # untouched + assert not list(tmp_path.glob("*.tmp")) # temp cleaned up + + +def test_preserves_settings_when_present(tmp_path): + src = tmp_path / "in.parquetbundle" + out = tmp_path / "out.parquetbundle" + write_bundle(_tables(), src, settings={"foo": 1}) + + new_annotations = pa.table({"identifier": ["A", "B"], "cat": ["x", "y"]}) + replace_annotations_in_bundle(src, out, new_annotations) + + _parts, settings = read_bundle(out) + assert settings == {"foo": 1} diff --git a/tests/test_classification.py b/tests/test_classification.py new file mode 100644 index 00000000..ecaba1a0 --- /dev/null +++ b/tests/test_classification.py @@ -0,0 +1,71 @@ +"""Tests for the query/reference classifier.""" + +import pyarrow as pa +import pytest + +from protspace.analysis.classification import Rule, classify + + +def _table(): + return pa.table( + { + "identifier": ["TRINITY_1", "TRINITY_2", "P00001", "P00002"], + "protein_category": ["mSCR", "mSCR", "neurotoxin", "enzyme"], + } + ) + + +def test_prefix_rule_selects_queries(): + q = Rule(id_prefixes=["TRINITY_"]) + r = Rule(where=[("protein_category", "neurotoxin")]) + qi, ri = classify(_table(), q, r) + assert qi == [0, 1] + assert ri == [2] + + +def test_where_substring_is_case_insensitive(): + q = Rule(where=[("protein_category", "MSCR")]) + r = Rule(id_prefixes=["P0"]) + qi, ri = classify(_table(), q, r) + assert qi == [0, 1] + assert ri == [2, 3] + + +def test_query_takes_precedence_over_reference(): + # A protein matching both rules is classified as a query, never a reference. + q = Rule(id_prefixes=["P00001"]) + r = Rule(where=[("protein_category", "neurotoxin")]) + qi, ri = classify(_table(), q, r) + assert 2 in qi + assert 2 not in ri + + +def test_empty_query_match_raises(): + q = Rule(id_prefixes=["NOPE_"]) + r = Rule(id_prefixes=["P0"]) + with pytest.raises(ValueError, match="no query"): + classify(_table(), q, r) + + +def test_missing_where_column_raises(): + q = Rule(where=[("not_a_column", "x")]) + r = Rule(id_prefixes=["P0"]) + with pytest.raises(KeyError): + classify(_table(), q, r) + + +def test_protein_matching_neither_rule_is_excluded(): + # P00002 / enzyme matches neither rule -> absent from both lists. + q = Rule(id_prefixes=["TRINITY_"]) + r = Rule(where=[("protein_category", "neurotoxin")]) + qi, ri = classify(_table(), q, r) + assert 3 not in qi + assert 3 not in ri + + +def test_multiple_id_prefixes_use_or_semantics(): + q = Rule(id_prefixes=["TRINITY_", "P00001"]) + r = Rule(id_prefixes=["P00002"]) + qi, ri = classify(_table(), q, r) + assert qi == [0, 1, 2] + assert ri == [3] diff --git a/tests/test_h5_parse_identifier.py b/tests/test_h5_parse_identifier.py index e4a5b530..fc62964a 100644 --- a/tests/test_h5_parse_identifier.py +++ b/tests/test_h5_parse_identifier.py @@ -1,8 +1,34 @@ """Tests for H5 identifier parsing.""" +from pathlib import Path + import pytest -from protspace.data.loaders.h5 import parse_identifier +from protspace.data.loaders.h5 import parse_identifier, split_h5_spec + + +class TestSplitH5Spec: + def test_no_override(self): + assert split_h5_spec("emb.h5") == (Path("emb.h5"), None) + + def test_name_override(self): + assert split_h5_spec("emb.h5:prot_t5") == (Path("emb.h5"), "prot_t5") + + def test_posix_path_with_override(self): + assert split_h5_spec("/data/emb.h5:prot_t5") == ( + Path("/data/emb.h5"), + "prot_t5", + ) + + def test_windows_drive_path_no_override(self): + # The drive-letter colon must not be mistaken for a name override. + assert split_h5_spec("C:\\data\\emb.h5") == (Path("C:\\data\\emb.h5"), None) + + def test_windows_drive_path_with_override(self): + assert split_h5_spec("C:\\data\\emb.h5:prot_t5") == ( + Path("C:\\data\\emb.h5"), + "prot_t5", + ) class TestParseIdentifier: diff --git a/tests/test_predictions_overlay.py b/tests/test_predictions_overlay.py new file mode 100644 index 00000000..3817ace1 --- /dev/null +++ b/tests/test_predictions_overlay.py @@ -0,0 +1,109 @@ +"""Tests for building the per-cell prediction overlay columns.""" + +import io + +import pyarrow as pa +import pyarrow.parquet as pq + +from protlabel import Prediction +from protspace.data.io.predictions import add_overlay_columns + + +def _table(): + return pa.table( + { + "identifier": ["Q0", "Q1", "R0"], + "protein_category": ["", "", "neurotoxin"], + } + ) + + +def test_adds_value_and_confidence_overlay_columns(): + preds = [ + Prediction("Q0", "neurotoxin", "R0", 0.3, 0.62, 1, "euclidean"), + ] + out = add_overlay_columns(_table(), "protein_category", preds) + assert "protein_category__pred_value" in out.column_names + assert "protein_category__pred_confidence" in out.column_names + # The source column is intentionally not emitted (provenance is not a useful + # per-cell overlay; confidence carries the signal users threshold on). + assert "protein_category__pred_source" not in out.column_names + + +def test_overlay_values_aligned_by_identifier(): + preds = [Prediction("Q1", "enzyme", "R9", 0.5, 0.5, 1, "euclidean")] + out = add_overlay_columns(_table(), "protein_category", preds).to_pylist() + by_id = {r["identifier"]: r for r in out} + assert by_id["Q1"]["protein_category__pred_value"] == "enzyme" + assert by_id["Q1"]["protein_category__pred_confidence"] == 0.5 + # Non-predicted rows are null in the overlay columns. + assert by_id["Q0"]["protein_category__pred_value"] is None + assert by_id["R0"]["protein_category__pred_confidence"] is None + + +def test_curated_column_is_left_untouched(): + preds = [Prediction("Q0", "neurotoxin", "R0", 0.1, 0.8, 1, "euclidean")] + out = add_overlay_columns(_table(), "protein_category", preds).to_pylist() + by_id = {r["identifier"]: r for r in out} + assert by_id["Q0"]["protein_category"] == "" # original column unchanged + assert by_id["R0"]["protein_category"] == "neurotoxin" + + +def test_confidence_column_is_float(): + preds = [Prediction("Q0", "x", "R0", 0.1, 0.83, 1, "euclidean")] + out = add_overlay_columns(_table(), "protein_category", preds) + field = out.schema.field("protein_category__pred_confidence") + assert pa.types.is_floating(field.type) + + +def test_empty_predictions_appends_all_null_columns(): + out = add_overlay_columns(_table(), "protein_category", []) + assert "protein_category__pred_value" in out.column_names + assert out.column("protein_category__pred_value").to_pylist() == [None, None, None] + assert out.column("protein_category__pred_confidence").to_pylist() == [ + None, + None, + None, + ] + + +def test_prediction_for_unknown_identifier_is_ignored(): + preds = [Prediction("NOT_IN_TABLE", "x", "R0", 0.1, 0.9, 1, "euclidean")] + out = add_overlay_columns(_table(), "protein_category", preds).to_pylist() + assert all(r["protein_category__pred_value"] is None for r in out) + + +def test_reapplying_overlay_replaces_not_duplicates(): + # Re-running transfer on an already-overlaid table must replace the overlay + # columns, not append duplicates (which produce an unreadable parquet table). + preds1 = [Prediction("Q0", "old", "R0", 0.3, 0.6, 1, "euclidean")] + once = add_overlay_columns(_table(), "protein_category", preds1) + preds2 = [Prediction("Q0", "new", "R0", 0.1, 0.9, 1, "euclidean")] + twice = add_overlay_columns(once, "protein_category", preds2) + + assert twice.column_names.count("protein_category__pred_value") == 1 + assert twice.column_names.count("protein_category__pred_confidence") == 1 + assert "protein_category__pred_source" not in twice.column_names + + by_id = {r["identifier"]: r for r in twice.to_pylist()} + assert by_id["Q0"]["protein_category__pred_value"] == "new" + + # Duplicate column names would make this round-trip raise ArrowInvalid. + buf = io.BytesIO() + pq.write_table(twice, buf) + reread = pq.read_table(io.BytesIO(buf.getvalue())) + assert reread.column("protein_category__pred_value").to_pylist()[0] == "new" + + +def test_legacy_source_column_is_removed_on_rerun(): + # A bundle written by an older version may carry a __pred_source column; + # re-running must drop it rather than leave it orphaned/stale. + legacy = _table().append_column( + "protein_category__pred_source", pa.array(["R0", None, None], pa.string()) + ) + out = add_overlay_columns( + legacy, + "protein_category", + [Prediction("Q0", "neurotoxin", "R0", 0.1, 0.8, 1, "euclidean")], + ) + assert "protein_category__pred_source" not in out.column_names diff --git a/tests/test_protlabel_backends.py b/tests/test_protlabel_backends.py new file mode 100644 index 00000000..5f7285aa --- /dev/null +++ b/tests/test_protlabel_backends.py @@ -0,0 +1,140 @@ +"""Tests for protlabel.backends.nearest.""" + +import numpy as np +import pytest + +from protlabel.backends import nearest + + +def _toy(): + # 3 references on a line; queries close to ref 0 and ref 2 + refs = np.array([[0.0, 0.0], [10.0, 0.0], [20.0, 0.0]], dtype=np.float32) + queries = np.array([[0.1, 0.0], [19.5, 0.0]], dtype=np.float32) + return queries, refs + + +def test_returns_shapes(): + queries, refs = _toy() + idx, dist = nearest(queries, refs, k=2, metric="euclidean") + assert idx.shape == (2, 2) + assert dist.shape == (2, 2) + + +def test_nearest_index_euclidean(): + queries, refs = _toy() + idx, dist = nearest(queries, refs, k=1, metric="euclidean") + assert idx[0, 0] == 0 # first query nearest to ref 0 + assert idx[1, 0] == 2 # second query nearest to ref 2 + assert dist[0, 0] == pytest.approx(0.1, abs=1e-4) + + +def test_neighbours_sorted_by_distance(): + queries, refs = _toy() + idx, dist = nearest(queries, refs, k=3, metric="euclidean") + assert np.all(np.diff(dist, axis=1) >= -1e-6) # non-decreasing per row + + +def test_cosine_metric_runs_and_orders(): + refs = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32) + queries = np.array([[1.0, 0.1]], dtype=np.float32) # closest in angle to ref 0 + idx, dist = nearest(queries, refs, k=1, metric="cosine") + assert idx[0, 0] == 0 + + +def test_k_capped_to_num_refs(): + queries, refs = _toy() + idx, dist = nearest(queries, refs, k=10, metric="euclidean") + assert idx.shape == (2, 3) # only 3 refs available + assert np.all(np.diff(dist, axis=1) >= -1e-6) + assert idx[0, 0] == 0 and idx[1, 0] == 2 + + +def test_chunking_matches_unchunked(): + rng = np.random.default_rng(0) + refs = rng.standard_normal((50, 8)).astype(np.float32) + queries = rng.standard_normal((7, 8)).astype(np.float32) + a_idx, a_dist = nearest(queries, refs, k=3, metric="euclidean", chunk=1000) + b_idx, b_dist = nearest(queries, refs, k=3, metric="euclidean", chunk=3) + assert np.array_equal(a_idx, b_idx) + assert np.allclose(a_dist, b_dist, atol=1e-5) + + +def test_unknown_metric_raises(): + queries, refs = _toy() + with pytest.raises(ValueError): + nearest(queries, refs, k=1, metric="manhattan") + + +def test_tiny_memory_budget_matches_default(): + rng = np.random.default_rng(1) + refs = rng.standard_normal((40, 6)).astype(np.float32) + queries = rng.standard_normal((9, 6)).astype(np.float32) + a_idx, a_dist = nearest(queries, refs, k=3, metric="euclidean") + b_idx, b_dist = nearest(queries, refs, k=3, metric="euclidean", max_block_bytes=1) + assert np.array_equal(a_idx, b_idx) + assert np.allclose(a_dist, b_dist, atol=1e-5) + + +def test_k_less_than_one_raises(): + queries, refs = _toy() + with pytest.raises(ValueError): + nearest(queries, refs, k=0, metric="euclidean") + + +def test_cosine_zero_vector_query_is_finite(): + # A zero-magnitude query must not produce NaN cosine distances (scipy.cdist + # returns NaN here); the result must stay finite and deterministic. + refs = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32) + queries = np.array([[0.0, 0.0]], dtype=np.float32) + idx, dist = nearest(queries, refs, k=2, metric="cosine") + assert idx.shape == (1, 2) + assert np.all(np.isfinite(dist)) + + +def test_cosine_zero_vector_reference_is_finite(): + # A zero-magnitude reference must not poison the distance block with NaN. + refs = np.array([[0.0, 0.0], [0.0, 1.0]], dtype=np.float32) + queries = np.array([[1.0, 0.1]], dtype=np.float32) + idx, dist = nearest(queries, refs, k=2, metric="cosine") + assert np.all(np.isfinite(dist)) + + +def test_euclidean_matches_reference_distances(): + # The (BLAS) euclidean path must agree with a direct distance computation. + rng = np.random.default_rng(7) + refs = rng.standard_normal((30, 12)).astype(np.float32) + queries = rng.standard_normal((5, 12)).astype(np.float32) + idx, dist = nearest(queries, refs, k=4, metric="euclidean") + for qi in range(queries.shape[0]): + full = np.linalg.norm(refs - queries[qi], axis=1) + expected = np.sort(full)[:4] + assert np.allclose(dist[qi], expected, atol=1e-4) + + +def test_euclidean_precise_for_near_identical_high_dim(): + # Near-identical high-dimensional vectors are the catastrophic-cancellation + # regime for the GEMM expansion; the reported distance must stay precise + # (not collapse to 0) by matching a direct float64 norm. + rng = np.random.default_rng(3) + base = rng.standard_normal(1024).astype(np.float32) + other = rng.standard_normal(1024).astype(np.float32) + refs = np.stack([base, base + np.float32(1e-3), other]) + query = (base + np.float32(2e-4))[None, :] + idx, dist = nearest(query, refs, k=2, metric="euclidean") + assert idx[0, 0] == 0 # nearest is the near-duplicate of base + expected = np.linalg.norm(refs[0].astype(np.float64) - query[0].astype(np.float64)) + assert dist[0, 0] > 0.0 # not clipped to zero + assert dist[0, 0] == pytest.approx(expected, rel=1e-3) + + +def test_cosine_matches_reference_distances(): + rng = np.random.default_rng(8) + refs = rng.standard_normal((25, 10)).astype(np.float32) + queries = rng.standard_normal((4, 10)).astype(np.float32) + idx, dist = nearest(queries, refs, k=3, metric="cosine") + rn = refs / np.linalg.norm(refs, axis=1, keepdims=True) + for qi in range(queries.shape[0]): + qn = queries[qi] / np.linalg.norm(queries[qi]) + full = 1.0 - rn @ qn + expected = np.sort(full)[:3] + assert np.allclose(dist[qi], expected, atol=1e-4) diff --git a/tests/test_protlabel_lookup.py b/tests/test_protlabel_lookup.py new file mode 100644 index 00000000..ff056e28 --- /dev/null +++ b/tests/test_protlabel_lookup.py @@ -0,0 +1,66 @@ +"""Tests for protlabel.lookup.Lookup (the rebuildable sidecar).""" + +import numpy as np + +from protlabel import Lookup, Prediction + + +def _lookup(): + emb = np.array([[0.0, 0.0], [10.0, 0.0]], dtype=np.float32) + return Lookup(embeddings=emb, ids=["R0", "R1"], labels=["a", "b"]) + + +def test_query_returns_predictions(): + lk = _lookup() + preds = lk.query(np.array([[0.2, 0.0]], dtype=np.float32), ["Q0"], k=1) + assert isinstance(preds[0], Prediction) + assert preds[0].label == "a" + assert preds[0].source_id == "R0" + + +def test_save_load_roundtrip(tmp_path): + lk = _lookup() + path = tmp_path / "lookup.npz" + lk.save(path) + assert path.exists() + loaded = Lookup.load(path) + assert loaded.ids == lk.ids + assert loaded.labels == lk.labels + assert loaded.metric == lk.metric + assert np.allclose(loaded.embeddings, lk.embeddings) + + +def test_save_load_preserves_float32_values_exactly(tmp_path): + # Non-power-of-2 values would be corrupted by a float16 round-trip; the + # sidecar must preserve the float32 embeddings exactly. + emb = np.array([[0.123456, 1.9876543], [3.14159, 2.71828]], dtype=np.float32) + lk = Lookup(embeddings=emb, ids=["P1", "P2"], labels=["x", "y"]) + path = tmp_path / "lk.npz" + lk.save(path) + loaded = Lookup.load(path) + assert loaded.embeddings.dtype == np.float32 + assert np.array_equal(loaded.embeddings, emb) + + +def test_save_uses_no_object_arrays(tmp_path): + # ids/labels must be stored as plain unicode (not pickled object arrays), so + # the sidecar can be loaded with allow_pickle=False (no RCE surface). + lk = Lookup(embeddings=_lookup().embeddings, ids=["R0", "R1"], labels=["a", "b"]) + path = tmp_path / "lk.npz" + lk.save(path) + with np.load(path, allow_pickle=False) as data: + assert data["ids"].dtype != object + assert data["labels"].dtype != object + assert list(data["ids"]) == ["R0", "R1"] + assert list(data["labels"]) == ["a", "b"] + + +def test_loaded_lookup_queries_identically(tmp_path): + lk = _lookup() + q = np.array([[9.8, 0.0]], dtype=np.float32) + before = lk.query(q, ["Q"], k=1) + path = tmp_path / "lk.npz" + lk.save(path) + after = Lookup.load(path).query(q, ["Q"], k=1) + assert before[0].label == after[0].label == "b" + assert before[0].reliability == after[0].reliability diff --git a/tests/test_protlabel_reliability.py b/tests/test_protlabel_reliability.py new file mode 100644 index 00000000..d40bec60 --- /dev/null +++ b/tests/test_protlabel_reliability.py @@ -0,0 +1,38 @@ +"""Tests for protlabel.reliability.""" + +import math + +import pytest + +from protlabel.reliability import similarity + + +def test_euclidean_at_zero_distance_is_one(): + assert similarity(0.0, "euclidean") == pytest.approx(1.0) + + +def test_euclidean_at_half_distance_is_half(): + assert similarity(0.5, "euclidean") == pytest.approx(0.5) + + +def test_euclidean_decreases_to_zero(): + assert similarity(1e9, "euclidean") == pytest.approx(0.0, abs=1e-6) + + +def test_cosine_is_one_minus_distance(): + assert similarity(0.2, "cosine") == pytest.approx(0.8) + + +def test_cosine_clamped_to_unit_interval(): + # cosine distance can be up to 2.0 -> 1 - d would go negative; clamp at 0 + assert similarity(1.7, "cosine") == pytest.approx(0.0) + assert similarity(-0.1, "cosine") == pytest.approx(1.0) + + +def test_unknown_metric_raises(): + with pytest.raises(ValueError): + similarity(0.5, "manhattan") + + +def test_smoke(): + assert math.isfinite(similarity(0.5, "euclidean")) diff --git a/tests/test_protlabel_transfer.py b/tests/test_protlabel_transfer.py new file mode 100644 index 00000000..458d56f9 --- /dev/null +++ b/tests/test_protlabel_transfer.py @@ -0,0 +1,102 @@ +"""Tests for protlabel.transfer.""" + +import numpy as np +import pytest + +from protlabel.transfer import Prediction, eat + + +def _setup(): + ref_emb = np.array([[0.0, 0.0], [10.0, 0.0], [20.0, 0.0]], dtype=np.float32) + ref_ids = ["R0", "R1", "R2"] + ref_labels = ["toxin", "enzyme", "toxin"] + query_emb = np.array([[0.0, 0.0], [19.7, 0.0]], dtype=np.float32) + query_ids = ["Q0", "Q1"] + return ref_emb, ref_ids, ref_labels, query_emb, query_ids + + +def test_k1_transfers_nearest_label_and_source(): + ref_emb, ref_ids, ref_labels, query_emb, query_ids = _setup() + preds = eat(query_emb, query_ids, ref_emb, ref_ids, ref_labels, k=1) + assert isinstance(preds[0], Prediction) + assert preds[0].query_id == "Q0" + assert preds[0].label == "toxin" + assert preds[0].source_id == "R0" + assert preds[0].reliability == pytest.approx(1.0) # distance 0 -> RI 1.0 + + +def test_k1_reliability_uses_gopredsim_transform(): + ref_emb, ref_ids, ref_labels, query_emb, query_ids = _setup() + preds = eat(query_emb, query_ids, ref_emb, ref_ids, ref_labels, k=1) + # Q1 distance to R2 is 0.3 -> RI = 0.5/(0.5+0.3) + assert preds[1].label == "toxin" + assert preds[1].source_id == "R2" + assert preds[1].reliability == pytest.approx(0.5 / 0.8, abs=1e-4) + + +def test_k3_vote_picks_majority_label(): + # Query equidistant-ish but two of three nearest are "toxin" + ref_emb = np.array( + [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]], dtype=np.float32 + ) + ref_ids = ["R0", "R1", "R2", "R3"] + ref_labels = ["toxin", "enzyme", "toxin", "toxin"] + query_emb = np.array([[1.4, 0.0]], dtype=np.float32) + preds = eat(query_emb, ["Q"], ref_emb, ref_ids, ref_labels, k=3) + assert preds[0].label == "toxin" # toxin RI sum beats lone enzyme neighbour + assert 0.0 < preds[0].reliability <= 1.0 + + +def test_cosine_metric(): + ref_emb = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32) + preds = eat( + np.array([[1.0, 0.05]], dtype=np.float32), + ["Q"], + ref_emb, + ["R0", "R1"], + ["a", "b"], + k=1, + metric="cosine", + ) + assert preds[0].label == "a" + + +def test_length_mismatch_raises(): + ref_emb, ref_ids, ref_labels, query_emb, query_ids = _setup() + with pytest.raises(ValueError): + eat(query_emb, query_ids, ref_emb, ref_ids, ref_labels[:-1], k=1) + + +def test_empty_references_raises(): + with pytest.raises(ValueError): + eat( + np.zeros((1, 2), dtype=np.float32), + ["Q"], + np.zeros((0, 2), dtype=np.float32), + [], + [], + k=1, + ) + + +def test_equal_distance_tie_break_is_deterministic(): + # Two refs exactly equidistant from the query carry different labels at equal + # source distances; the winner must be deterministic (lexically smallest + # label) rather than dependent on argsort ordering of the equal distances. + ref_emb = np.array([[1.0, 0.0], [-1.0, 0.0]], dtype=np.float32) + query_emb = np.array([[0.0, 0.0]], dtype=np.float32) # distance 1.0 to both + preds = eat(query_emb, ["Q"], ref_emb, ["R_z", "R_a"], ["zebra", "apple"], k=2) + assert preds[0].label == "apple" + assert preds[0].source_id == "R_a" + + +def test_source_is_nearest_neighbour_with_winning_label(): + # Two neighbours share the winning label at distinct distances; the source + # must be the closer one. + ref_emb = np.array([[0.0, 0.0], [0.5, 0.0], [10.0, 0.0]], dtype=np.float32) + ref_ids = ["R_far", "R_near", "R_other"] + ref_labels = ["toxin", "toxin", "enzyme"] + query_emb = np.array([[0.4, 0.0]], dtype=np.float32) # R_near at 0.1, R_far at 0.4 + preds = eat(query_emb, ["Q"], ref_emb, ref_ids, ref_labels, k=2) + assert preds[0].label == "toxin" + assert preds[0].source_id == "R_near" diff --git a/tests/test_protlabel_version.py b/tests/test_protlabel_version.py new file mode 100644 index 00000000..3b436511 --- /dev/null +++ b/tests/test_protlabel_version.py @@ -0,0 +1,11 @@ +"""protlabel.__version__ reports the version it actually ships under.""" + +from importlib.metadata import version + +import protlabel + + +def test_version_matches_shipped_distribution(): + # protlabel currently ships inside the protspace wheel, so its reported + # version must track the installed distribution rather than a stale literal. + assert protlabel.__version__ == version("protspace") diff --git a/tests/test_transfer_cli.py b/tests/test_transfer_cli.py new file mode 100644 index 00000000..a3cacb97 --- /dev/null +++ b/tests/test_transfer_cli.py @@ -0,0 +1,374 @@ +"""Tests for the transfer orchestration core and CLI registration.""" + +import logging + +import numpy as np +import pyarrow as pa +import pytest + +from protspace.analysis.classification import Rule +from protspace.cli.transfer import run_transfer + + +def _three_protein_inputs(extra_columns=None): + cols = { + "identifier": ["TRINITY_1", "P00001", "P00002"], + "protein_category": ["", "neurotoxin", "enzyme"], + } + if extra_columns: + cols.update(extra_columns) + annotations = pa.table(cols) + embeddings = { + "TRINITY_1": np.array([0.0, 0.0], dtype=np.float32), + "P00001": np.array([0.05, 0.0], dtype=np.float32), + "P00002": np.array([9.0, 0.0], dtype=np.float32), + } + return annotations, embeddings + + +def _write_bundle_and_h5(tmp_path, *, id_col="protein_id", extra_columns=None): + import h5py + + from protspace.data.io.bundle import write_bundle + + cols = {id_col: ["TRINITY_1", "P00001"], "protein_category": ["", "neurotoxin"]} + if extra_columns: + cols.update(extra_columns) + annotations = pa.table(cols) + proj_meta = pa.table({"name": ["PCA 2"], "dims": [2]}) + proj_data = pa.table( + {"id": ["TRINITY_1", "P00001"], "x": [0.0, 9.0], "y": [0.0, 0.0]} + ) + bundle_path = tmp_path / "in.parquetbundle" + write_bundle([annotations, proj_meta, proj_data], bundle_path) + + h5_path = tmp_path / "emb.h5" + with h5py.File(h5_path, "w") as f: + f.attrs["model_name"] = "test_model" + f.create_dataset("TRINITY_1", data=np.array([0.0, 0.0], dtype=np.float32)) + f.create_dataset("P00001", data=np.array([0.1, 0.0], dtype=np.float32)) + return bundle_path, h5_path + + +def _inputs(): + annotations = pa.table( + { + "identifier": ["TRINITY_1", "P00001", "P00002"], + "protein_category": ["", "neurotoxin", "enzyme"], + } + ) + # TRINITY_1 sits right on top of the neurotoxin reference P00001. + embeddings = { + "TRINITY_1": np.array([0.0, 0.0], dtype=np.float32), + "P00001": np.array([0.05, 0.0], dtype=np.float32), + "P00002": np.array([9.0, 0.0], dtype=np.float32), + } + return annotations, embeddings + + +def test_run_transfer_predicts_for_query_with_missing_value(): + annotations, embeddings = _inputs() + out = run_transfer( + annotations=annotations, + embeddings=embeddings, + transfer_columns=["protein_category"], + query_rule=Rule(id_prefixes=["TRINITY_"]), + reference_rule=Rule( + where=[("protein_category", "")] + ), # matches all proteins; run_transfer keeps only those with a value + k=1, + metric="euclidean", + ) + by_id = {r["identifier"]: r for r in out.to_pylist()} + assert by_id["TRINITY_1"]["protein_category__pred_value"] == "neurotoxin" + assert by_id["TRINITY_1"]["protein_category__pred_confidence"] > 0.9 + assert "protein_category__pred_source" not in out.column_names + + +def test_run_transfer_skips_proteins_without_embeddings(): + annotations, embeddings = _inputs() + embeddings.pop("TRINITY_1") # no embedding -> cannot be a query + with pytest.raises(ValueError, match="no query"): + run_transfer( + annotations=annotations, + embeddings=embeddings, + transfer_columns=["protein_category"], + query_rule=Rule(id_prefixes=["TRINITY_"]), + reference_rule=Rule(id_prefixes=["P0"]), + k=1, + metric="euclidean", + ) + + +def test_run_transfer_multi_column_skips_column_without_references(caplog): + # Two transfer columns; the second has no references with a value and must be + # skipped while the first still produces its overlay. + annotations, embeddings = _three_protein_inputs( + extra_columns={"other_col": ["", "", ""]} + ) + with caplog.at_level(logging.WARNING): + out = run_transfer( + annotations=annotations, + embeddings=embeddings, + transfer_columns=["protein_category", "other_col"], + query_rule=Rule(id_prefixes=["TRINITY_"]), + reference_rule=Rule(id_prefixes=["P0"]), + k=1, + metric="euclidean", + ) + assert "protein_category__pred_value" in out.column_names + assert "other_col__pred_value" not in out.column_names + assert "other_col" in caplog.text + + +def test_run_transfer_k_greater_than_one(): + annotations, embeddings = _three_protein_inputs() + out = run_transfer( + annotations=annotations, + embeddings=embeddings, + transfer_columns=["protein_category"], + query_rule=Rule(id_prefixes=["TRINITY_"]), + reference_rule=Rule(id_prefixes=["P0"]), + k=2, + metric="euclidean", + ) + by_id = {r["identifier"]: r for r in out.to_pylist()} + assert by_id["TRINITY_1"]["protein_category__pred_value"] == "neurotoxin" + + +def test_run_transfer_cosine_metric(): + annotations, embeddings = _three_protein_inputs() + # Cosine needs non-parallel-to-axis vectors to be meaningful. + embeddings = { + "TRINITY_1": np.array([1.0, 0.1], dtype=np.float32), + "P00001": np.array([1.0, 0.0], dtype=np.float32), + "P00002": np.array([0.0, 1.0], dtype=np.float32), + } + out = run_transfer( + annotations=annotations, + embeddings=embeddings, + transfer_columns=["protein_category"], + query_rule=Rule(id_prefixes=["TRINITY_"]), + reference_rule=Rule(id_prefixes=["P0"]), + k=1, + metric="cosine", + ) + by_id = {r["identifier"]: r for r in out.to_pylist()} + assert by_id["TRINITY_1"]["protein_category__pred_value"] == "neurotoxin" + assert "protein_category__pred_source" not in out.column_names + + +def test_run_transfer_warns_when_nothing_transferred(caplog): + # Query proteins all already have a value -> nothing to transfer -> warning. + annotations = pa.table( + { + "identifier": ["TRINITY_1", "P00001"], + "protein_category": ["already_set", "neurotoxin"], + } + ) + embeddings = { + "TRINITY_1": np.array([0.0, 0.0], dtype=np.float32), + "P00001": np.array([0.1, 0.0], dtype=np.float32), + } + with caplog.at_level(logging.WARNING): + out = run_transfer( + annotations=annotations, + embeddings=embeddings, + transfer_columns=["protein_category"], + query_rule=Rule(id_prefixes=["TRINITY_"]), + reference_rule=Rule(id_prefixes=["P0"]), + k=1, + metric="euclidean", + ) + assert "protein_category__pred_value" not in out.column_names + assert "No annotations were transferred" in caplog.text + + +def test_cli_bad_where_column_is_clean_error(tmp_path): + from typer.testing import CliRunner + + from protspace.cli.app import app + + bundle, h5 = _write_bundle_and_h5(tmp_path) + out = tmp_path / "out.parquetbundle" + result = CliRunner().invoke( + app, + [ + "transfer", + "-b", + str(bundle), + "-e", + str(h5), + "-t", + "protein_category", + "-o", + str(out), + "--query-where", + "nonexistent~x", + "--reference-id-prefix", + "P0", + ], + ) + assert result.exit_code != 0 + assert not isinstance(result.exception, KeyError) + assert "nonexistent" in result.output + + +def test_cli_no_matching_embeddings_is_clean_error(tmp_path): + import h5py + from typer.testing import CliRunner + + from protspace.cli.app import app + + bundle, _ = _write_bundle_and_h5(tmp_path) + bad_h5 = tmp_path / "bad.h5" + with h5py.File(bad_h5, "w") as f: + f.attrs["model_name"] = "m" + f.create_dataset("ZZZ", data=np.array([0.0, 0.0], dtype=np.float32)) + out = tmp_path / "out.parquetbundle" + result = CliRunner().invoke( + app, + [ + "transfer", + "-b", + str(bundle), + "-e", + str(bad_h5), + "-t", + "protein_category", + "-o", + str(out), + "--query-id-prefix", + "TRINITY_", + "--reference-id-prefix", + "P0", + ], + ) + assert result.exit_code != 0 + assert not isinstance(result.exception, ValueError) + + +def test_cli_no_query_match_is_clean_error(tmp_path): + from typer.testing import CliRunner + + from protspace.cli.app import app + + bundle, h5 = _write_bundle_and_h5(tmp_path) + out = tmp_path / "out.parquetbundle" + result = CliRunner().invoke( + app, + [ + "transfer", + "-b", + str(bundle), + "-e", + str(h5), + "-t", + "protein_category", + "-o", + str(out), + "--query-id-prefix", + "NOPE_", + "--reference-id-prefix", + "P0", + ], + ) + assert result.exit_code != 0 + assert not isinstance(result.exception, ValueError) + + +def test_cli_both_id_columns_present_is_clean_error(tmp_path): + from typer.testing import CliRunner + + from protspace.cli.app import app + + bundle, h5 = _write_bundle_and_h5( + tmp_path, extra_columns={"identifier": ["TRINITY_1", "P00001"]} + ) + out = tmp_path / "out.parquetbundle" + result = CliRunner().invoke( + app, + [ + "transfer", + "-b", + str(bundle), + "-e", + str(h5), + "-t", + "protein_category", + "-o", + str(out), + "--query-id-prefix", + "TRINITY_", + "--reference-id-prefix", + "P0", + ], + ) + assert result.exit_code != 0 + assert not isinstance(result.exception, KeyError) + + +def test_transfer_command_is_registered(): + from typer.testing import CliRunner + + from protspace.cli.app import app + + result = CliRunner().invoke(app, ["transfer", "--help"]) + assert result.exit_code == 0 + assert "transfer" in result.output.lower() + + +def test_cli_end_to_end_protein_id_bundle(tmp_path): + """Real bundles key the id column 'protein_id'; the CLI must handle it and + preserve that name on write while adding the overlay columns.""" + import io + + import h5py + import pyarrow.parquet as pq + from typer.testing import CliRunner + + from protspace.cli.app import app + from protspace.data.io.bundle import read_bundle, write_bundle + + annotations = pa.table( + {"protein_id": ["TRINITY_1", "P00001"], "protein_category": ["", "neurotoxin"]} + ) + proj_meta = pa.table({"name": ["PCA 2"], "dims": [2]}) + proj_data = pa.table( + {"id": ["TRINITY_1", "P00001"], "x": [0.0, 9.0], "y": [0.0, 0.0]} + ) + bundle_path = tmp_path / "in.parquetbundle" + write_bundle([annotations, proj_meta, proj_data], bundle_path) + + h5_path = tmp_path / "emb.h5" + with h5py.File(h5_path, "w") as f: + f.attrs["model_name"] = "test_model" + f.create_dataset("TRINITY_1", data=np.array([0.0, 0.0], dtype=np.float32)) + f.create_dataset("P00001", data=np.array([0.1, 0.0], dtype=np.float32)) + + out_path = tmp_path / "out.parquetbundle" + result = CliRunner().invoke( + app, + [ + "transfer", + "-b", + str(bundle_path), + "-e", + str(h5_path), + "-t", + "protein_category", + "-o", + str(out_path), + "--query-id-prefix", + "TRINITY_", + "--reference-id-prefix", + "P0", + ], + ) + assert result.exit_code == 0, result.output + parts, _ = read_bundle(out_path) + table = pq.read_table(io.BytesIO(parts[0])) + assert "protein_id" in table.column_names # id column preserved for the web reader + rows = {r["protein_id"]: r for r in table.to_pylist()} + assert rows["TRINITY_1"]["protein_category__pred_value"] == "neurotoxin" + assert "protein_category__pred_source" not in table.column_names diff --git a/uv.lock b/uv.lock index 9a2d779e..4237dccc 100644 --- a/uv.lock +++ b/uv.lock @@ -2738,6 +2738,7 @@ dependencies = [ { name = "requests" }, { name = "rich" }, { name = "scikit-learn" }, + { name = "scipy" }, { name = "tqdm" }, { name = "typer" }, { name = "umap-learn" }, @@ -2805,6 +2806,7 @@ requires-dist = [ { name = "requests", marker = "extra == 'frontend'", specifier = ">=2.32.4" }, { name = "rich", specifier = ">=14.3.3" }, { name = "scikit-learn", specifier = ">=1.6.1" }, + { name = "scipy", specifier = ">=1.10" }, { name = "tqdm", specifier = ">=4.67.1" }, { name = "typer", specifier = ">=0.24.1" }, { name = "umap-learn", specifier = ">=0.5.10" },