diff --git a/src/protspace/cli/app.py b/src/protspace/cli/app.py index c718fbae..ffae85a2 100644 --- a/src/protspace/cli/app.py +++ b/src/protspace/cli/app.py @@ -69,6 +69,7 @@ def _register_commands() -> None: prepare, project, serve, + stats, style, ) diff --git a/src/protspace/cli/bundle.py b/src/protspace/cli/bundle.py index fbb96e5b..cf380381 100644 --- a/src/protspace/cli/bundle.py +++ b/src/protspace/cli/bundle.py @@ -35,6 +35,23 @@ def bundle( Path, typer.Option("-o", "--output", help="Output .parquetbundle file path."), ], + statistics: Annotated[ + Path | None, + typer.Option( + "-s", + "--statistics", + help="Optional projection-statistics parquet file → 5th bundle part.", + exists=True, + ), + ] = None, + settings: Annotated[ + Path | None, + typer.Option( + "--settings", + help="Optional settings JSON (e.g. auto-generated cluster styles) → 4th bundle part.", + exists=True, + ), + ] = None, verbose: Annotated[ int, typer.Option("-v", "--verbose", count=True, help="Increase verbosity."), @@ -49,10 +66,14 @@ def bundle( """ setup_logging(verbose) + import json + import pyarrow.parquet as pq from protspace.data.io.bundle import write_bundle + settings_obj = json.loads(settings.read_text()) if settings is not None else None + metadata_path = projections / "projections_metadata.parquet" data_path = projections / "projections_data.parquet" @@ -72,10 +93,16 @@ def bundle( [("protein_id" if c == "identifier" else c) for c in col_names] ) + statistics_table = ( + pq.read_table(str(statistics)) if statistics is not None else None + ) + output_path = output.with_suffix(".parquetbundle") write_bundle( [annotations_table, metadata_table, data_table], output_path, + settings=settings_obj, + statistics=statistics_table, ) typer.echo(f"Saved: {output_path}") diff --git a/src/protspace/cli/prepare.py b/src/protspace/cli/prepare.py index 4831f6b2..05629993 100644 --- a/src/protspace/cli/prepare.py +++ b/src/protspace/cli/prepare.py @@ -115,6 +115,14 @@ 
rich_help_panel="Annotations", ), ] +Opt_Stats = Annotated[ + bool, + typer.Option( + "--stats/--no-stats", + help="Compute projection statistics (cluster-validity + faithfulness).", + rich_help_panel="Output", + ), +] REFETCH_STAGES = frozenset( { "query", @@ -290,6 +298,7 @@ def prepare( # Annotations annotations: Opt_Annotations = None, scores: Opt_Scores = True, + stats: Opt_Stats = True, refetch: Opt_Refetch = None, # Output output: Opt_Output = Path("."), @@ -505,6 +514,7 @@ def prepare( bundled=bundled, keep_tmp=keep_tmp, no_scores=not scores, + stats=stats, refetch_stages=refetch_stages, annotations=annotation_list, intermediate_dir=cache_dir, diff --git a/src/protspace/cli/project.py b/src/protspace/cli/project.py index 123c3e27..bc0f7931 100644 --- a/src/protspace/cli/project.py +++ b/src/protspace/cli/project.py @@ -153,6 +153,7 @@ def project( dims, disambiguation_suffix(spec, method_counts), ) + reduction["source"] = emb_set.name all_reductions.append(reduction) output.mkdir(parents=True, exist_ok=True) diff --git a/src/protspace/cli/stats.py b/src/protspace/cli/stats.py new file mode 100644 index 00000000..48317e31 --- /dev/null +++ b/src/protspace/cli/stats.py @@ -0,0 +1,265 @@ +"""protspace stats — compute projection statistics for an existing project. + +Loads the embedding H5(s) (for faithfulness) and the projection coordinates from +a project directory, computes the tidy statistics table, and writes it as a +parquet file — the optional fifth ``.parquetbundle`` part. No annotations are +needed. Best-effort: per-statistic failures are isolated by the driver. +""" + +import json +import logging +from pathlib import Path +from typing import Annotated + +import typer + +from protspace.cli.app import app, setup_logging + +logger = logging.getLogger(__name__) + + +def _load_reductions( + projections: Path, default_metric: str = "euclidean" +) -> list[dict]: + """Reconstruct per-projection ``{name, data, ids, info, source}`` from a dir. 
+ + Reads ``projections_data.parquet`` (long table of projection_name/identifier/ + x/y/z) into per-projection coordinate arrays + id order, and the reducer + metric + source-embedding name from ``projections_metadata.parquet``. + """ + import numpy as np + import pyarrow.parquet as pq + + data_path = projections / "projections_data.parquet" + meta_path = projections / "projections_metadata.parquet" + if not data_path.exists(): + raise typer.BadParameter(f"Missing: {data_path}") + + metric_by_name: dict[str, str] = {} + dims_by_name: dict[str, int] = {} + source_by_name: dict[str, str] = {} + if meta_path.exists(): + mt = pq.read_table(str(meta_path)).to_pydict() + names = mt.get("projection_name", []) + infos = mt.get("info_json", []) + dims = mt.get("dimensions", []) + sources = mt.get("source", []) + for i, nm in enumerate(names): + try: + info = json.loads(infos[i]) if i < len(infos) and infos[i] else {} + except (json.JSONDecodeError, TypeError): + info = {} + metric_by_name[nm] = info.get("metric") or default_metric + if i < len(dims): + dims_by_name[nm] = int(dims[i]) + if i < len(sources) and sources[i]: + source_by_name[nm] = sources[i] + + dt = pq.read_table(str(data_path)).to_pydict() + pnames = dt["projection_name"] + idents = dt["identifier"] + xs, ys = dt["x"], dt["y"] + zs = dt.get("z", [None] * len(pnames)) + + grouped: dict[str, dict] = {} + for i in range(len(pnames)): + g = grouped.setdefault(pnames[i], {"ids": [], "x": [], "y": [], "z": []}) + g["ids"].append(idents[i]) + g["x"].append(xs[i]) + g["y"].append(ys[i]) + g["z"].append(zs[i]) + + reductions: list[dict] = [] + for nm, g in grouped.items(): + dims = dims_by_name.get(nm, 2) + if dims == 3 and any(v is not None for v in g["z"]): + coords = np.array([g["x"], g["y"], g["z"]], dtype=float).T + else: + coords = np.array([g["x"], g["y"]], dtype=float).T + red = { + "name": nm, + "data": coords, + "ids": list(g["ids"]), + "info": {"metric": metric_by_name.get(nm, default_metric)}, + } + if nm 
in source_by_name: + red["source"] = source_by_name[nm] + reductions.append(red) + return reductions + + +def _merge_quality_into_metadata(meta_path: Path, quality_by_name: dict) -> None: + """Fold faithfulness ``quality`` objects into ``projections_metadata.parquet``. + + Rewrites the file in place, parsing each row's ``info_json``, injecting the + matching projection's ``quality`` (preserving the reducer's existing info), and + re-serialising — leaving every other column untouched. This is how the + standalone ``stats`` path carries faithfulness into the bundle: a later + ``protspace bundle -p`` reads the enriched metadata as the bundle's 2nd part. + """ + import pyarrow as pa + import pyarrow.parquet as pq + + if not quality_by_name or not meta_path.exists(): + return + table = pq.read_table(str(meta_path)) + if ( + "projection_name" not in table.column_names + or "info_json" not in table.column_names + ): + return + + names = table.column("projection_name").to_pylist() + infos = table.column("info_json").to_pylist() + new_infos: list[str] = [] + for nm, raw in zip(names, infos, strict=False): + try: + info = json.loads(raw) if raw else {} + except (json.JSONDecodeError, TypeError): + info = {} + quality = quality_by_name.get(nm) + if quality is not None: + info["quality"] = quality + new_infos.append(json.dumps(info)) + + idx = table.column_names.index("info_json") + table = table.set_column(idx, "info_json", pa.array(new_infos, type=pa.string())) + pq.write_table(table, str(meta_path)) + + +def _merge_annotations_with_columns(ann_path: Path, report) -> int: + """Merge the report's per-protein ``AnnotationColumn``s into ``ann_path``. + + Rewrites the annotations parquet in place with the computed ``cluster_*`` / + ``silhouette_*`` columns joined by identifier. 
Added columns are stringified + (membership → category labels, silhouette → numeric strings, absent → empty) + so they match the prepare path's all-string annotations and the frontend's + content-based type inference. Returns the number of columns added. + """ + import pyarrow as pa + import pyarrow.parquet as pq + + from protspace.stats.carriage import merge_annotation_columns + + if not report.annotation_columns or not ann_path.exists(): + return 0 + df = pq.read_table(str(ann_path)).to_pandas() + id_col = "identifier" if "identifier" in df.columns else df.columns[0] + added = merge_annotation_columns(report, df, id_col=id_col) + for name in added: + df[name] = df[name].fillna("").astype(str) + pq.write_table(pa.Table.from_pandas(df, preserve_index=False), str(ann_path)) + return len(added) + + +@app.command() +def stats( + input: Annotated[ + list[str], + typer.Option( + "-i", + "--input", + help="HDF5 embedding file(s). Repeat for multi-embedding. Name override: -i file.h5:name", + ), + ], + projections: Annotated[ + Path, + typer.Option( + "-p", + "--projections", + help="Directory with projections_metadata.parquet and projections_data.parquet.", + exists=True, + ), + ], + output: Annotated[ + Path, + typer.Option("-o", "--output", help="Output statistics.parquet path."), + ], + annotations: Annotated[ + Path | None, + typer.Option( + "-a", + "--annotations", + help="Annotations parquet to enrich in place with per-protein " + "cluster-membership + silhouette columns. Omit to skip per-protein outputs.", + ), + ] = None, + settings_out: Annotated[ + Path | None, + typer.Option( + "--settings-out", + help="Write auto-generated cluster-membership legend styles here (JSON) " + "for `protspace bundle --settings`. 
Only with -a/--annotations.", + ), + ] = None, + seed: Annotated[int, typer.Option("--seed", help="Random seed.")] = 42, + metric: Annotated[ + str, + typer.Option( + "--metric", + help="High-dim distance metric for faithfulness when the projection metadata omits one (e.g. PCA/MDS).", + ), + ] = "euclidean", + verbose: Annotated[ + int, typer.Option("-v", "--verbose", count=True, help="Increase verbosity.") + ] = 0, +) -> None: + """Compute cluster-validity + faithfulness statistics for each projection.""" + setup_logging(verbose) + + import pyarrow.parquet as pq + + from protspace.cli.prepare import _parse_input_specs + from protspace.data.loaders import load_h5 + from protspace.stats import compute_statistics + from protspace.stats.carriage import ( + build_cluster_legend_settings, + route_faithfulness_to_metadata, + ) + + embedding_sets = [ + load_h5([path], name_override=name_override) + for path, name_override in _parse_input_specs(list(input)) + ] + + reductions = _load_reductions(projections, default_metric=metric) + # Per-protein outputs (cluster membership + per-point silhouette) are only + # computed when there's an annotations file to land them in — silhouette_samples + # is O(n^2), so we don't pay for it with nowhere to write. + params = {} if annotations is not None else {"cluster_annotations": False} + report = compute_statistics( + embedding_sets, + reductions, + rng_seed=seed, + params=params, + default_metric=metric, + ) + + # Route per-projection faithfulness into projections_metadata.info_json.quality + # (rewritten in place); the aggregate fifth part keeps validity/meta rows only. 
+ route_faithfulness_to_metadata(report, reductions) + quality_by_name = { + r["name"]: r["info"]["quality"] + for r in reductions + if isinstance(r.get("info"), dict) and "quality" in r["info"] + } + _merge_quality_into_metadata( + projections / "projections_metadata.parquet", quality_by_name + ) + + n_cols = 0 + if annotations is not None: + n_cols = _merge_annotations_with_columns(annotations, report) + if settings_out is not None: + cluster_settings = build_cluster_legend_settings(report) + settings_out.parent.mkdir(parents=True, exist_ok=True) + settings_out.write_text(json.dumps(cluster_settings)) + + table = report.to_arrow() + output.parent.mkdir(parents=True, exist_ok=True) + pq.write_table(table, str(output)) + typer.echo( + f"Saved {table.num_rows} statistic row(s): {output}" + f" (faithfulness → {len(quality_by_name)} projection(s);" + f" {n_cols} computed annotation column(s))" + ) diff --git a/src/protspace/data/io/bundle.py b/src/protspace/data/io/bundle.py index ca625a26..a5b90ee1 100644 --- a/src/protspace/data/io/bundle.py +++ b/src/protspace/data/io/bundle.py @@ -2,7 +2,13 @@ A .parquetbundle file concatenates multiple parquet files separated by a delimiter. The first three parts are the core data tables; an optional -fourth part carries settings (annotation colours, shapes, etc.). +fourth part carries settings (annotation colours, shapes, etc.); an optional +fifth part carries projection statistics. + +Positional layout: ``core(3) + settings? + statistics?``. When statistics are +present but settings are absent, the fourth part is written as **zero bytes** so +the statistics part is unambiguously the fifth — readers and writers branch on +the fourth part's emptiness, not on the raw part count. 
""" import io @@ -25,13 +31,15 @@ ] SETTINGS_FILENAME = "settings.parquet" +STATISTICS_FILENAME = "statistics.parquet" def extract_bundle_to_dir(bundle_path: Path, target_dir: Path | None = None) -> str: """Extract a .parquetbundle into separate parquet files on disk. - Supports bundles with 3 parts (core data only) or 4 parts (core data + - settings). + Supports bundles with 3 parts (core data only), 4 parts (core + settings), + or 5 parts (core + settings + statistics, where the settings part may be + zero bytes). Args: bundle_path: Path to the .parquetbundle file. @@ -52,23 +60,31 @@ def extract_bundle_to_dir(bundle_path: Path, target_dir: Path | None = None) -> parts = content.split(PARQUET_BUNDLE_DELIMITER) - if len(parts) < 3 or len(parts) > 4: - raise ValueError(f"Expected 3 or 4 parts in parquetbundle, found {len(parts)}") + if len(parts) < 3 or len(parts) > 5: + raise ValueError(f"Expected 3 to 5 parts in parquetbundle, found {len(parts)}") # Write core parts for part_bytes, filename in zip(parts[:3], CORE_FILENAMES, strict=False): if part_bytes: (target_dir / filename).write_bytes(part_bytes) - # Write optional settings part - if len(parts) == 4 and parts[3]: + # Write optional settings part (branch on emptiness, not part count) + if len(parts) >= 4 and parts[3]: (target_dir / SETTINGS_FILENAME).write_bytes(parts[3]) + # Write optional statistics part + if len(parts) == 5 and parts[4]: + (target_dir / STATISTICS_FILENAME).write_bytes(parts[4]) + return str(target_dir) def read_bundle(bundle_path: Path) -> tuple[list[bytes], dict | None]: - """Read a bundle and return raw part bytes plus parsed settings. + """Read a bundle and return raw core part bytes plus parsed settings. + + The return shape is preserved (``(core_parts, settings)``) so existing + callers keep working; use :func:`read_statistics_from_bundle` for the + optional fifth part. 
Returns: (core_parts_bytes, settings_dict_or_None) @@ -78,28 +94,43 @@ def read_bundle(bundle_path: Path) -> tuple[list[bytes], dict | None]: parts = content.split(PARQUET_BUNDLE_DELIMITER) - if len(parts) < 3 or len(parts) > 4: - raise ValueError(f"Expected 3 or 4 parts in parquetbundle, found {len(parts)}") + if len(parts) < 3 or len(parts) > 5: + raise ValueError(f"Expected 3 to 5 parts in parquetbundle, found {len(parts)}") settings = None - if len(parts) == 4 and parts[3]: + if len(parts) >= 4 and parts[3]: settings = read_settings_from_bytes(parts[3]) return parts[:3], settings +def read_statistics_from_bundle(bundle_path: Path) -> bytes | None: + """Return the raw statistics parquet bytes (fifth part), or None if absent.""" + with open(bundle_path, "rb") as f: + content = f.read() + + parts = content.split(PARQUET_BUNDLE_DELIMITER) + if len(parts) == 5 and parts[4]: + return parts[4] + return None + + def write_bundle( tables: list[pa.Table], bundle_path: Path, settings: dict | None = None, + statistics: "pa.Table | None" = None, ) -> None: - """Write Arrow tables (and optional settings) to a .parquetbundle. + """Write Arrow tables (and optional settings/statistics) to a .parquetbundle. Args: tables: List of 3 Arrow tables (annotations, projections_metadata, projections_data). bundle_path: Output file path. settings: Optional settings dict to include as 4th part. + statistics: Optional projection-statistics Arrow table to include as the + 5th part. When given without ``settings``, a zero-byte settings slot + is written so the statistics part stays at position five. """ bundle_path.parent.mkdir(parents=True, exist_ok=True) @@ -111,9 +142,18 @@ def write_bundle( pq.write_table(table, buf) f.write(buf.getvalue()) - if settings is not None: + # A settings slot must exist whenever statistics follow it. 
+ if settings is not None or statistics is not None: + f.write(PARQUET_BUNDLE_DELIMITER) + if settings is not None: + f.write(create_settings_parquet(settings)) + # else: zero-byte settings slot + + if statistics is not None: f.write(PARQUET_BUNDLE_DELIMITER) - f.write(create_settings_parquet(settings)) + buf = io.BytesIO() + pq.write_table(statistics, buf) + f.write(buf.getvalue()) logger.info(f"Saved bundled output to: {bundle_path}") @@ -125,23 +165,24 @@ def replace_settings_in_bundle( ) -> None: """Append or replace the settings (4th) part in a bundle. - The three core parts are preserved byte-for-byte. + The three core parts are preserved byte-for-byte, and an existing statistics + (5th) part is preserved so styling a statistics-bearing bundle is non-lossy. """ with open(input_path, "rb") as f: content = f.read() parts = content.split(PARQUET_BUNDLE_DELIMITER) - if len(parts) < 3: - raise ValueError( - f"Expected at least 3 parts in parquetbundle, found {len(parts)}" - ) + if len(parts) < 3 or len(parts) > 5: + raise ValueError(f"Expected 3 to 5 parts in parquetbundle, found {len(parts)}") settings_bytes = create_settings_parquet(settings) - # Build new content: first 3 parts + new settings - core = PARQUET_BUNDLE_DELIMITER.join(parts[:3]) - new_content = core + PARQUET_BUNDLE_DELIMITER + settings_bytes + # core(3) + new settings, preserving a trailing statistics part if present. 
+ new_parts = parts[:3] + [settings_bytes] + if len(parts) == 5: + new_parts.append(parts[4]) + new_content = PARQUET_BUNDLE_DELIMITER.join(new_parts) output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, "wb") as f: diff --git a/src/protspace/data/processors/base_processor.py b/src/protspace/data/processors/base_processor.py index 2dbc0359..9008f6df 100644 --- a/src/protspace/data/processors/base_processor.py +++ b/src/protspace/data/processors/base_processor.py @@ -108,7 +108,12 @@ def create_output( } def save_output( - self, data: dict[str, pa.Table], output_path: Path, bundled: bool = True + self, + data: dict[str, pa.Table], + output_path: Path, + bundled: bool = True, + statistics: pa.Table | None = None, + settings: dict | None = None, ): """Save output data to Parquet files using Apache Arrow. @@ -116,6 +121,9 @@ def save_output( data: Dictionary of Apache Arrow tables to save output_path: Path for output (file or directory) bundled: Whether to bundle into single .parquetbundle file + statistics: Optional projection-statistics table → 5th bundle part. + settings: Optional bundle settings (e.g. auto-generated cluster styles) + → 4th bundle part. 
""" # Custom filename mapping for better naming filename_mapping = { @@ -138,7 +146,12 @@ def save_output( output_path.mkdir(parents=True, exist_ok=True) bundle_path = output_path / "data.parquetbundle" - write_bundle(list(data.values()), bundle_path) + write_bundle( + list(data.values()), + bundle_path, + settings=settings, + statistics=statistics, + ) else: # Save as separate parquet files # output_path must be a directory @@ -152,6 +165,9 @@ def save_output( # Overwrite existing files pq.write_table(table, str(table_path)) + if statistics is not None: + pq.write_table(statistics, str(base_path / "statistics.parquet")) + logger.info(f"Saved separate parquet files to: {base_path}") def _create_protein_annotations_table(self, metadata: pd.DataFrame) -> pa.Table: @@ -182,6 +198,9 @@ def _create_projections_metadata_table( "projection_name": reduction["name"], "dimensions": reduction["dimensions"], "info_json": json.dumps(reduction["info"]), + # Raw source-embedding name, so `protspace stats` can map each + # projection back to its embedding in multi-embedding runs. + "source": str(reduction.get("source", "")), } ) diff --git a/src/protspace/data/processors/pipeline.py b/src/protspace/data/processors/pipeline.py index 10ff2a8a..0912f7fc 100644 --- a/src/protspace/data/processors/pipeline.py +++ b/src/protspace/data/processors/pipeline.py @@ -73,6 +73,7 @@ class PipelineConfig: bundled: bool = True keep_tmp: bool = False no_scores: bool = False + stats: bool = False refetch_stages: frozenset[str] = field(default_factory=frozenset) annotations: list[str] | None = None intermediate_dir: Path | None = None @@ -254,10 +255,23 @@ def run(self, embedding_sets: list[EmbeddingSet]) -> Path: # DR: each embedding set × each method all_reductions = self._run_reductions(embedding_sets) + # Projection statistics (best-effort; never fail the run for a secondary + # artifact). Computed here where embeddings and projections coexist. 
+ statistics_table = None + self._stats_settings: dict = {} + if self.config.stats: + statistics_table = self._compute_statistics( + embedding_sets, all_reductions, all_headers, metadata + ) + # Create and save output output = self.base.create_output(metadata, all_reductions, all_headers) self.base.save_output( - output, self.config.output_path, bundled=self.config.bundled + output, + self.config.output_path, + bundled=self.config.bundled, + statistics=statistics_table, + settings=self._stats_settings or None, ) logger.info( @@ -616,6 +630,7 @@ def _run_reductions( emb_set.name, MDS_NAME, 2, global_params ) if cached: + cached["source"] = emb_set.name all_reductions.append(cached) cached_projections.append(f"MDS 2 ({emb_set.name})") continue @@ -625,6 +640,7 @@ def _run_reductions( self.base, effective_params, MDS_NAME, 2, emb_set.data ) reduction["name"] = format_projection_name(emb_set.name, MDS_NAME, 2) + reduction["source"] = emb_set.name all_reductions.append(reduction) self._save_projection_cache( emb_set.name, MDS_NAME, 2, reduction, global_params @@ -649,6 +665,7 @@ def _run_reductions( emb_set.name, method, dims, effective_params, param_suffix ) if cached: + cached["source"] = emb_set.name all_reductions.append(cached) cached_projections.append( f"{method.upper()} {dims} ({emb_set.name})" @@ -663,6 +680,7 @@ def _run_reductions( reduction["name"] = format_projection_name( emb_set.name, method, dims, param_suffix ) + reduction["source"] = emb_set.name all_reductions.append(reduction) self._save_projection_cache( emb_set.name, method, dims, reduction, effective_params @@ -677,3 +695,55 @@ def _run_reductions( ) return all_reductions + + def _compute_statistics( + self, embedding_sets, all_reductions, all_headers, metadata=None + ): + """Compute projection statistics, returning an Arrow table or None. + + Best-effort: any failure is logged and yields ``None`` so the bundle + still ships. 
Each reduction's coordinate rows correspond to + ``all_headers`` (the common header order), which is also the embedding + row order after header validation — so faithfulness aligns cleanly. + + Routes outputs to their parts in place: faithfulness → each projection's + ``info_json.quality``; per-protein cluster membership / silhouette → + columns on ``metadata`` (joined by identifier). The returned table is the + aggregate-validity-only fifth part. + """ + try: + from protspace.stats import compute_statistics + from protspace.stats.carriage import ( + build_cluster_legend_settings, + merge_annotation_columns, + route_faithfulness_to_metadata, + ) + + for red in all_reductions: + red.setdefault("ids", all_headers) + report = compute_statistics( + embedding_sets, + all_reductions, + rng_seed=self.config.reducer_params.random_state, + # Faithfulness high-dim metric: reducers like PCA/MDS/PaCMAP omit + # 'metric' from their params, so fall back to the run's metric + # rather than silently assuming euclidean. + default_metric=self.config.reducer_params.metric, + ) + route_faithfulness_to_metadata(report, all_reductions) + if metadata is not None and report.annotation_columns: + added = merge_annotation_columns(report, metadata) + # Auto-style the membership columns so clusters are colored when + # selected (a full legend envelope → the bundle's settings part). 
+ self._stats_settings = build_cluster_legend_settings(report) + logger.info( + "Routed %d computed annotation column(s); styled %d", + len(added), + len(self._stats_settings), + ) + table = report.to_arrow() + logger.info("Computed %d projection-statistic row(s)", table.num_rows) + return table if table.num_rows else None + except Exception as exc: # noqa: BLE001 - statistics are secondary + logger.warning("Statistics computation failed: %s", exc) + return None diff --git a/src/protspace/stats/__init__.py b/src/protspace/stats/__init__.py new file mode 100644 index 00000000..0c1e550a --- /dev/null +++ b/src/protspace/stats/__init__.py @@ -0,0 +1,31 @@ +"""Projection statistics — registry + entry point. + +Mirrors the lazy ``REDUCERS`` pattern in ``protspace.utils``: statistic classes +(which pull in scikit-learn) are imported on first ``get_statistics()`` call, not +at package import, so ``import protspace`` / CLI startup stays fast. +""" + +from __future__ import annotations + +_STATISTICS: list | None = None + + +def get_statistics() -> list: + """Return the registered Statistic instances (lazy-imported).""" + global _STATISTICS + if _STATISTICS is None: + from protspace.stats.metrics.faithfulness import FaithfulnessStatistic + from protspace.stats.metrics.validity import ClusterValidityStatistic + + _STATISTICS = [ClusterValidityStatistic(), FaithfulnessStatistic()] + return _STATISTICS + + +def compute_statistics(*args, **kwargs): + """Run the statistics driver (lazy import to keep this module light).""" + from protspace.stats.driver import compute_statistics as _compute + + return _compute(*args, **kwargs) + + +__all__ = ["get_statistics", "compute_statistics"] diff --git a/src/protspace/stats/base.py b/src/protspace/stats/base.py new file mode 100644 index 00000000..c8ded23b --- /dev/null +++ b/src/protspace/stats/base.py @@ -0,0 +1,168 @@ +"""Core data structures for projection statistics. 
+ +A ``Statistic`` describes a projection (and optionally its source embedding). It +declares the inputs it needs and returns one or more ``StatRow`` records. The +tidy long-format table produced by ``StatsReport.to_arrow`` (eight columns) is +the bundle-boundary contract consumed downstream. + +Heavy imports (scikit-learn) live inside the metric/cluster modules, function- +local, so importing this package does not pull sklearn into CLI startup. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import Any, Protocol + +import numpy as np +import pyarrow as pa + +# The frozen eight-column schema. New scalar statistics add rows, never columns; +# any per-source attribute (e.g. an annotation column name) goes in ``extra_json``. +STATS_SCHEMA = pa.schema( + [ + ("space_kind", pa.string()), + ("space_name", pa.string()), + ("stat_family", pa.string()), + ("label_kind", pa.string()), + ("metric", pa.string()), + ("metric_kind", pa.string()), + ("value", pa.float64()), + ("extra_json", pa.string()), + ] +) + + +def _json_default(o: Any): + if isinstance(o, np.integer): + return int(o) + if isinstance(o, np.floating): + return float(o) + if isinstance(o, np.ndarray): + return o.tolist() + return str(o) + + +@dataclass +class StatContext: + """Inputs handed to a statistic for one projection space. + + ``coords`` and ``embedding`` are row-aligned to ``ids`` (an id-intersection + join is performed by the driver), so faithfulness can compare them directly. 
+ """ + + space_kind: str + space_name: str + coords: np.ndarray # FULL projection coordinates (cluster_validity uses these) + ids: list[str] # ids for `coords` + rng_seed: int = 42 + embedding: np.ndarray | None = None # source embedding, aligned to embedding_coords + embedding_coords: np.ndarray | None = ( + None # projection coords aligned to `embedding` + ) + embedding_ids: list[str] | None = ( + None # ids for the aligned embedding/embedding_coords + ) + embedding_name: str | None = None + high_dim_metric: str = "euclidean" + params: dict = field(default_factory=dict) + + +@dataclass +class StatRow: + """One statistic value. + + ``destination`` routes the row to a bundle part at carriage time: + ``statistics_part`` (the tidy 8-column table — the default, so every existing + construction is unchanged), ``projection_metadata`` (folded into a projection's + ``info_json``), or ``annotation`` (a per-protein column). It is carriage + metadata, not a tidy-table column, so ``to_record`` never emits it. + """ + + space_kind: str + space_name: str + stat_family: str + label_kind: str + metric: str + metric_kind: str + value: float + extra: dict = field(default_factory=dict) + destination: str = "statistics_part" + + def to_record(self) -> dict: + return { + "space_kind": self.space_kind, + "space_name": self.space_name, + "stat_family": self.stat_family, + "label_kind": self.label_kind, + "metric": self.metric, + "metric_kind": self.metric_kind, + "value": float(self.value), + "extra_json": json.dumps(self.extra, sort_keys=True, default=_json_default), + } + + +@dataclass +class AnnotationColumn: + """A per-protein statistic output destined for the ``protein_annotations`` part. + + ``values`` maps protein identifier → value (a category label string for + ``kind="categorical"``, a float for ``kind="numeric"``); a protein absent from + the mapping has no value for the column. 
``kind`` records the intended frontend + type so the carriage layer can format it for content-based inference. + """ + + name: str + kind: str # "categorical" | "numeric" + values: dict[str, Any] = field(default_factory=dict) + extra: dict = field(default_factory=dict) + destination: str = "annotation" + + +@dataclass +class StatsReport: + """Accumulates statistic outputs: scalar ``StatRow``s (the tidy fifth-part + table) and per-protein ``AnnotationColumn``s (a separate carriage channel).""" + + rows: list[StatRow] = field(default_factory=list) + annotation_columns: list[AnnotationColumn] = field(default_factory=list) + + def add(self, items: list) -> None: + """Accept a mixed list of ``StatRow`` / ``AnnotationColumn`` outputs, + routing each to its channel.""" + for item in items or []: + if isinstance(item, AnnotationColumn): + self.annotation_columns.append(item) + else: + self.rows.append(item) + + def partition(self) -> dict[str, list[StatRow]]: + """Group rows by ``destination`` for the carriage layer to fan out.""" + buckets: dict[str, list[StatRow]] = {} + for row in self.rows: + buckets.setdefault(row.destination, []).append(row) + return buckets + + def to_arrow(self) -> pa.Table: + # Only the statistics-part bucket is the tidy fifth part; rows routed to + # projection metadata / annotations are carried elsewhere by the router. + records = [ + r.to_record() for r in self.rows if r.destination == "statistics_part" + ] + if not records: + return pa.Table.from_pylist([], schema=STATS_SCHEMA) + return pa.Table.from_pylist(records, schema=STATS_SCHEMA) + + +class Statistic(Protocol): + """A unit of computation over a projection space. + + ``requires_embedding`` lets the driver skip statistics when no source + embedding is available for a projection. + """ + + family: str + requires_embedding: bool + + def compute(self, ctx: StatContext) -> list[StatRow]: ... 
diff --git a/src/protspace/stats/carriage.py b/src/protspace/stats/carriage.py new file mode 100644 index 00000000..e3821e96 --- /dev/null +++ b/src/protspace/stats/carriage.py @@ -0,0 +1,133 @@ +"""Carriage: fan routed statistic outputs to their bundle parts. + +A ``StatRow`` declares a ``destination`` (see ``stats.base``); this module moves +the non-default destinations out of the tidy fifth part and into the bundle part +whose existing frontend consumer matches the statistic's granularity. + +Phase 1 routes **faithfulness** (per-projection scalars) into each projection's +``info_json.quality``. Per-protein ``annotation`` routing (Phase 2) will live here +too. +""" + +from __future__ import annotations + +import math + +from protspace.stats.base import StatRow, StatsReport + + +def _json_safe(value: float) -> float | None: + """Map a faithfulness value to a JSON-serialisable one. + + The skip row carries ``NaN``; ``json.dumps`` would emit the non-standard + ``NaN`` token, breaking the requirement that ``info_json`` stay valid JSON, so + ``NaN`` becomes ``None`` (the consumer reads a missing value, with the skip + marker still in the provenance). + """ + if isinstance(value, float) and math.isnan(value): + return None + return value + + +def _quality_from_rows(rows: list[StatRow]) -> dict: + """One projection's faithfulness rows → a ``quality`` dict keyed by metric. + + Each metric carries its value plus its own provenance (``k``, the high-dim + distance metric, sampling and/or skip markers) so a consumer can render + discrete per-metric rows. + """ + quality: dict = {} + for row in rows: + quality[row.metric] = {"value": _json_safe(float(row.value)), **row.extra} + return quality + + +def route_faithfulness_to_metadata(report: StatsReport, reductions: list[dict]) -> None: + """Fold ``projection_metadata``-destined rows into each reduction's + ``info['quality']``, in place. 
def _cluster_label_sort_key(label: str):
    """Order ``cluster N`` labels by their integer N, others alphabetically.

    Returns ``(0, N)`` when the text after the last space is a decimal number
    and ``(1, label)`` otherwise, so numbered labels sort first and numerically.
    """
    head, _, tail = label.rpartition(" ")
    # isdecimal(), not isdigit(): isdigit() also accepts characters such as
    # superscript digits that int() cannot parse, which would raise ValueError.
    if head and tail.isdecimal():
        return (0, int(tail))
    return (1, label)


def build_cluster_legend_settings(report: StatsReport, shape: str = "circle") -> dict:
    """Build a legend-settings map auto-styling each categorical membership column.

    Returns ``{column_name: LegendPersistedSettings}`` (the bundle's settings part
    format) with a full envelope per ``categorical`` ``AnnotationColumn`` — every
    field the frontend's ``sanitizeLegendSettingsEntry`` requires, categories keyed
    by the exact label strings with a Kelly-palette ``color`` + ``zOrder`` + ``shape``
    — so clusters are colored when selected without a manual styling step. Numeric
    columns (per-point silhouette) are left to the default continuous ramp.

    Args:
        report: Report whose ``annotation_columns`` are inspected.
        shape: Marker shape written into every category entry.

    Returns:
        A dict keyed by column name; empty when the report has no categorical
        annotation column.
    """
    from protspace.data.io.settings_converter import KELLYS_COLORS

    settings: dict = {}
    for col in report.annotation_columns:
        if col.kind != "categorical":
            continue
        # Distinct labels in "cluster 1, cluster 2, ..., cluster 10" order; the
        # position in that order is both the zOrder and the palette index.
        labels = sorted(set(col.values.values()), key=_cluster_label_sort_key)
        categories = {
            label: {
                "zOrder": i,
                # The palette is reused cyclically once labels outnumber colors.
                "color": KELLYS_COLORS[i % len(KELLYS_COLORS)],
                "shape": shape,
            }
            for i, label in enumerate(labels)
        }
        settings[col.name] = {
            "maxVisibleValues": max(10, len(labels)),
            "shapeSize": 30,
            "sortMode": "size-desc",
            "hiddenValues": [],
            "enableDuplicateStackUI": False,
            "selectedPaletteId": "kellys",
            "categories": categories,
        }
    return settings
def kmeans_elbow(
    X: np.ndarray,
    *,
    rng_seed: int = 42,
    k_max: int | None = None,
    n_init: int = 10,
    knee_min_deviation: float = 0.05,
    silhouette_sample: int = 2000,
) -> ElbowResult | None:
    """Sweep KMeans over K and pick the elbow via max chord deviation.

    Args:
        X: Points to cluster, one row per point.
        rng_seed: Seed passed to KMeans and to the silhouette subsampling.
        k_max: Largest K to try. Defaults to ``round(sqrt(n))``; in all cases it
            is clamped to ``[2, min(50, n - 1)]``.
        n_init: KMeans ``n_init``.
        knee_min_deviation: Minimum chord deviation for a ``"high"`` confidence
            knee (also requires at least four swept K values).
        silhouette_sample: Above this many points the silhouette-optimal-K
            cross-check scores a sample of this size instead of every point.

    Returns:
        An ``ElbowResult``, or ``None`` when there are too few points to
        cluster (n < 3).
    """
    X = np.asarray(X, dtype=float)
    n = X.shape[0]
    if n < 3:
        # Decided before importing scikit-learn: a degenerate input needs no
        # clustering, so it should not pay for (or depend on) the import.
        return None

    from sklearn.cluster import KMeans
    from sklearn.metrics import silhouette_score

    if k_max is None:
        k_max = int(round(np.sqrt(n)))
    k_max = max(2, min(k_max, 50, n - 1))
    k_range = list(range(2, k_max + 1))

    inertia: list[float] = []
    labels_by_k: dict[int, np.ndarray] = {}
    for k in k_range:
        km = KMeans(n_clusters=k, random_state=rng_seed, n_init=n_init).fit(X)
        inertia.append(float(km.inertia_))
        labels_by_k[k] = km.labels_

    if len(k_range) < 3:
        # Too short to find a chord knee; take the smallest K, flag low confidence.
        k = k_range[0]
        return ElbowResult(k, labels_by_k[k], k_range, inertia, "low", None)

    dev = chord_deviation(np.asarray(inertia, dtype=float))
    k_idx = int(np.argmax(dev))
    k = k_range[k_idx]
    # With only 3 swept points the chord knee is structurally pinned to the middle
    # K; require a wider sweep before claiming high confidence.
    knee_confidence = (
        "high"
        if len(k_range) >= 4 and float(dev.max()) >= knee_min_deviation
        else "low"
    )

    # Silhouette-optimal K over the sweep, for cross-checking (bounded by sampling).
    sil_kwargs = {}
    if n > silhouette_sample:
        sil_kwargs = {"sample_size": silhouette_sample, "random_state": rng_seed}
    silhouette_optimal_k: int | None = None
    best_sil = -np.inf
    for kk in k_range:
        try:
            s = float(silhouette_score(X, labels_by_k[kk], **sil_kwargs))
        except Exception:
            # Best-effort cross-check: a K that cannot be scored is just skipped.
            continue
        if s > best_sil:
            best_sil = s
            silhouette_optimal_k = kk

    return ElbowResult(
        k, labels_by_k[k], k_range, inertia, knee_confidence, silhouette_optimal_k
    )
def _select_embedding(reduction: dict, embedding_sets: list, emb_by_name: dict):
    """Choose the embedding set a projection was computed from.

    Tried in order: the set named by the reduction's ``source`` (or
    ``embedding_name``) key; the only set when exactly one is available; the
    set whose headers overlap the projection's ids the most. A set containing
    every projection id is returned as soon as it is seen. Returns ``None``
    when the projection has no ids or no set shares an id with it.
    """
    declared = reduction.get("source") or reduction.get("embedding_name")
    if declared and declared in emb_by_name:
        return emb_by_name[declared]

    if len(embedding_sets) == 1:
        return embedding_sets[0]

    projection_ids = reduction.get("ids")
    if not projection_ids:
        return None

    wanted = set(projection_ids)
    winner = None
    winner_hits = 0
    for candidate in embedding_sets:
        hits = len(wanted.intersection(candidate.headers))
        if hits == len(wanted):
            # Full cover: stop here rather than tie-break between candidates.
            return candidate
        if hits > winner_hits:
            winner = candidate
            winner_hits = hits
    return winner
def compute_statistics(
    embedding_sets: list,
    reductions: list[dict],
    *,
    rng_seed: int = 42,
    params: dict | None = None,
    statistics: list | None = None,
    default_metric: str = "euclidean",
) -> StatsReport:
    """Compute statistics for each projection.

    Args:
        embedding_sets: ``EmbeddingSet`` objects (``.name``, ``.data``, ``.headers``).
        reductions: dicts with ``name`` and ``data`` (coords); optionally ``source``
            (embedding name), ``ids`` (coords row identifiers), and ``info``
            (reducer params, used for the high-dim ``metric``).
        rng_seed: deterministic seed.
        params: tunables (``k``, ``k_max``, ``sample_threshold``, ``hard_ceiling``).
        statistics: statistics to run; defaults to ``get_statistics()``.
        default_metric: high-dim metric used when a reduction's ``info`` has no
            ``metric`` entry.

    Returns:
        A ``StatsReport`` (may be partial/empty; never raises on a statistic error).
    """
    params = params or {}
    stats = statistics if statistics is not None else get_statistics()
    report = StatsReport()
    emb_by_name = {es.name: es for es in embedding_sets}

    for red in reductions:
        try:
            space_name = red.get("name", "")
            full_coords = np.asarray(red["data"], dtype=float)
            red_ids = red.get("ids")
            full_ids = (
                list(red_ids)
                if red_ids is not None
                else [str(i) for i in range(full_coords.shape[0])]
            )

            info = red.get("info") or {}
            high_dim_metric = (
                (info.get("metric") if isinstance(info, dict) else None)
                or default_metric
                or "euclidean"
            )

            embedding = None
            embedding_coords = None
            embedding_ids = None
            embedding_name = None
            # Selection/alignment only feeds embedding-dependent statistics.
            # A failure here (e.g. an id-less projection whose row count
            # disagrees with the embedding) must not drop the whole projection:
            # statistics that need only the coordinates still run.
            try:
                emb_set = _select_embedding(red, embedding_sets, emb_by_name)
                # Skip faithfulness for precomputed (similarity/distance) matrices —
                # an (n, n) matrix is not a high-dim embedding.
                if emb_set is not None and not getattr(emb_set, "precomputed", False):
                    aligned = _align(emb_set, red_ids, full_coords)
                    if aligned is not None:
                        embedding_coords, embedding, embedding_ids = aligned
                        embedding_name = emb_set.name
            except Exception as exc:  # noqa: BLE001 - embedding is optional
                logger.warning(
                    "embedding alignment failed for projection '%s'; "
                    "embedding-dependent statistics skipped: %s",
                    space_name,
                    exc,
                )
                embedding = None
                embedding_coords = None
                embedding_ids = None
                embedding_name = None

            ctx = StatContext(
                space_kind="projection",
                space_name=space_name,
                coords=full_coords,  # cluster_validity scores the FULL projection
                ids=full_ids,
                rng_seed=rng_seed,
                embedding=embedding,
                embedding_coords=embedding_coords,  # faithfulness scores this aligned subset
                embedding_ids=embedding_ids,
                embedding_name=embedding_name,
                high_dim_metric=high_dim_metric,
                params=params,
            )
        except Exception as exc:  # noqa: BLE001 - one bad reduction must not sink the report
            # A non-dict reduction has no .get(); fall back to its repr so the
            # handler itself cannot raise.
            label = red.get("name") if isinstance(red, dict) else repr(red)
            logger.warning(
                "statistics setup failed for projection '%s': %s", label, exc
            )
            continue

        for stat in stats:
            if getattr(stat, "requires_embedding", False) and ctx.embedding is None:
                continue
            try:
                report.add(stat.compute(ctx))
            except Exception as exc:  # noqa: BLE001 - statistics are secondary
                logger.warning(
                    "statistic %s failed for projection '%s': %s",
                    getattr(stat, "family", stat),
                    ctx.space_name,
                    exc,
                )

    return report
class FaithfulnessStatistic:
    """kNN-overlap + trustworthiness + continuity of a projection vs its embedding."""

    family = "faithfulness"
    requires_embedding = True

    def compute(self, ctx: StatContext) -> list[StatRow]:
        """Score how well the projection preserves its embedding's neighbourhoods.

        Args:
            ctx: Context carrying the embedding, the projection coordinates
                (preferably the embedding-aligned view) and the tunables ``k``,
                ``sample_threshold`` and ``hard_ceiling`` in ``ctx.params``.

        Returns:
            Rows destined for ``projection_metadata``: up to three metric rows,
            a single ``NaN`` skip row when the point count exceeds
            ``hard_ceiling``, or an empty list when there is no embedding,
            fewer than three points, or mismatched row counts. A metric that
            fails is logged and omitted; it never raises.
        """
        if ctx.embedding is None:
            return []
        emb = np.asarray(ctx.embedding, dtype=float)
        # Use the projection coordinates ALIGNED to the embedding (id-intersection
        # join), falling back to full coords only when no aligned view was built.
        coords_src = (
            ctx.embedding_coords if ctx.embedding_coords is not None else ctx.coords
        )
        coords = np.asarray(coords_src, dtype=float)
        ids = ctx.embedding_ids if ctx.embedding_ids is not None else ctx.ids
        n = emb.shape[0]
        if n < 3:
            return []

        # Function-local, like the scikit-learn imports in this module.
        import logging

        log = logging.getLogger(__name__)

        if coords.shape[0] != n:
            # Rows could only be paired positionally, and the counts disagree:
            # any score would compare unrelated points, so report nothing.
            log.warning(
                "faithfulness skipped for projection '%s': %d coordinate rows "
                "vs %d embedding rows",
                ctx.space_name,
                coords.shape[0],
                n,
            )
            return []

        k = int(ctx.params.get("k", DEFAULT_K))
        sample_threshold = int(
            ctx.params.get("sample_threshold", DEFAULT_SAMPLE_THRESHOLD)
        )
        hard_ceiling = int(ctx.params.get("hard_ceiling", DEFAULT_HARD_CEILING))
        hi_metric = ctx.high_dim_metric or "euclidean"

        base = {
            "space_kind": ctx.space_kind,
            "space_name": ctx.space_name,
            "stat_family": self.family,
            "label_kind": "none",
            # Faithfulness is a per-projection scalar: route it into the
            # projection's info_json.quality, not the aggregate fifth part.
            "destination": "projection_metadata",
        }

        if n > hard_ceiling:
            return [
                StatRow(
                    metric="knn_overlap",
                    metric_kind="faithfulness",
                    value=float("nan"),
                    extra={
                        "skipped": "n_too_large",
                        "n": int(n),
                        "hard_ceiling": hard_ceiling,
                        "embedding": ctx.embedding_name,
                    },
                    **base,
                )
            ]

        sampled = False
        if n > sample_threshold:
            rng = np.random.default_rng(_subsample_seed(ctx.rng_seed, ids))
            idx = np.sort(rng.permutation(n)[:sample_threshold])
            emb = emb[idx]
            coords = coords[idx]
            n = len(idx)
            sampled = True

        # sklearn.manifold.trustworthiness requires n_neighbors < n / 2 (strict),
        # else it raises. Clamp accordingly so trustworthiness/continuity are not
        # silently dropped for small n; k+1 <= n keeps the kNN-overlap query valid.
        k = max(1, min(k, (n - 1) // 2))
        common = {
            "k": k,
            "seed": ctx.rng_seed,
            "sampled": sampled,
            "sample_size": int(n),
            "embedding": ctx.embedding_name,
        }

        # Imported only once there is real work to do.
        from sklearn.manifold import trustworthiness

        # (metric name, distance recorded in ``extra``, thunk computing the value).
        # Continuity is trustworthiness with the roles swapped, so its primary
        # input is the low-dimensional projection and the metric is euclidean.
        metric_specs = (
            (
                "knn_overlap",
                hi_metric,
                lambda: _knn_overlap(emb, coords, k, hi_metric),
            ),
            (
                "trustworthiness",
                hi_metric,
                lambda: trustworthiness(emb, coords, n_neighbors=k, metric=hi_metric),
            ),
            (
                "continuity",
                "euclidean",
                lambda: trustworthiness(coords, emb, n_neighbors=k, metric="euclidean"),
            ),
        )

        rows: list[StatRow] = []
        for metric_name, distance, compute_value in metric_specs:
            try:
                rows.append(
                    StatRow(
                        metric=metric_name,
                        metric_kind="faithfulness",
                        value=float(compute_value()),
                        extra={**common, "metric": distance},
                        **base,
                    )
                )
            except Exception as exc:  # noqa: BLE001 - faithfulness is best-effort
                # Still best-effort, but leave a trace of which metric was lost.
                log.warning(
                    "faithfulness metric %s failed for projection '%s': %s",
                    metric_name,
                    ctx.space_name,
                    exc,
                )

        return rows
class ClusterValidityStatistic:
    """Elbow K + silhouette / Davies-Bouldin / Calinski-Harabasz on the coords."""

    family = "cluster_validity"
    requires_embedding = False

    def compute(self, ctx: StatContext) -> list:
        """Cluster the projection with elbow-K KMeans and score the labelling.

        Args:
            ctx: Context carrying ``coords``, ``ids``, ``rng_seed`` and the
                tunables ``sample_threshold``, ``k_max``, ``cluster_annotations``
                and ``silhouette_hard_ceiling`` in ``ctx.params``.

        Returns:
            A mixed list: ``StatRow``s (an ``n_clusters`` meta row plus whichever
            validity metrics could be computed) followed by the optional
            per-protein ``AnnotationColumn``s. Empty for fewer than three
            points. A metric that fails is logged and omitted.
        """
        X = np.asarray(ctx.coords, dtype=float)
        n = X.shape[0]
        if n < 3:
            # Same outcome as kmeans_elbow returning None, without importing
            # scikit-learn for an input that cannot be clustered.
            return []

        # Function-local, like the scikit-learn imports in this module.
        import logging

        from sklearn.metrics import calinski_harabasz_score, davies_bouldin_score

        log = logging.getLogger(__name__)
        rng_seed = ctx.rng_seed
        sample_threshold = int(
            ctx.params.get("sample_threshold", DEFAULT_SAMPLE_THRESHOLD)
        )
        k_max = ctx.params.get("k_max")

        res = kmeans_elbow(
            X, rng_seed=rng_seed, k_max=k_max, silhouette_sample=sample_threshold
        )
        if res is None:  # n < 3
            return []
        labels = res.labels
        k = res.k
        # Report the ACHIEVED number of distinct clusters (KMeans can collapse on
        # coincident points), keeping the elbow's requested K in extra.
        achieved = int(len(np.unique(labels)))
        # Silhouette needs 2 <= (labels actually present) <= n - 1. Gate on the
        # achieved count: a collapsed labelling is then skipped outright rather
        # than attempted and swallowed.
        silhouette_ok = 2 <= achieved <= n - 1

        base = {
            "space_kind": ctx.space_kind,
            "space_name": ctx.space_name,
            "stat_family": self.family,
            "label_kind": "kmeans_elbow",
        }
        rows: list[StatRow] = [
            StatRow(
                metric="n_clusters",
                metric_kind="meta",
                value=float(achieved),
                extra={
                    "requested_k": k,
                    "k_range": [res.k_range[0], res.k_range[-1]],
                    "inertia": res.inertia,
                    "knee_confidence": res.knee_confidence,
                    "silhouette_optimal_k": res.silhouette_optimal_k,
                    "seed": rng_seed,
                },
                **base,
            )
        ]

        if silhouette_ok:
            try:
                sil, sx = _silhouette(
                    X, labels, rng_seed=rng_seed, sample_threshold=sample_threshold
                )
                rows.append(
                    StatRow(
                        metric="silhouette",
                        metric_kind="validity",
                        value=sil,
                        extra={**sx, "seed": rng_seed},
                        **base,
                    )
                )
            except Exception as exc:  # noqa: BLE001 - validity is best-effort
                log.warning(
                    "silhouette failed for projection '%s': %s", ctx.space_name, exc
                )

        # Davies-Bouldin / Calinski-Harabasz are unstable with singleton clusters.
        if not _has_singleton(labels):
            for metric_name, fn in (
                ("davies_bouldin", davies_bouldin_score),
                ("calinski_harabasz", calinski_harabasz_score),
            ):
                try:
                    rows.append(
                        StatRow(
                            metric=metric_name,
                            metric_kind="validity",
                            value=float(fn(X, labels)),
                            extra={"seed": rng_seed},
                            **base,
                        )
                    )
                except Exception as exc:  # noqa: BLE001
                    log.warning(
                        "%s failed for projection '%s': %s",
                        metric_name,
                        ctx.space_name,
                        exc,
                    )

        outputs: list = list(rows)

        # Per-protein outputs (route-projection-statistics Phase 2): the elbow-K
        # labelling becomes a categorical membership column and per-point silhouette
        # a numeric column, both joined by identifier. Gated by the cluster_annotations
        # param; emitted only when there is a genuine (>=2) clustering and the ids
        # line up with the scored points.
        if (
            ctx.params.get("cluster_annotations", True)
            and achieved >= 2
            and len(ctx.ids) == n
        ):
            ann_extra = {
                "projection": ctx.space_name,
                "k": int(k),
                "seed": rng_seed,
                "computed": True,
            }
            # Membership as NON-numeric label strings so the frontend's content-based
            # type inference reads the column as categorical, not a numeric ramp.
            outputs.append(
                AnnotationColumn(
                    name=f"cluster_{ctx.space_name}",
                    kind="categorical",
                    values={
                        pid: f"cluster {int(lbl)}"
                        for pid, lbl in zip(ctx.ids, labels, strict=False)
                    },
                    # Each column gets its own copy so mutating one column's
                    # provenance cannot leak into the other.
                    extra=dict(ann_extra),
                )
            )

            hard_ceiling = int(
                ctx.params.get(
                    "silhouette_hard_ceiling", DEFAULT_SILHOUETTE_HARD_CEILING
                )
            )
            if silhouette_ok and n <= hard_ceiling:
                try:
                    from sklearn.metrics import silhouette_samples

                    samples = silhouette_samples(X, labels)
                    outputs.append(
                        AnnotationColumn(
                            name=f"silhouette_{ctx.space_name}",
                            kind="numeric",
                            values={
                                pid: float(s)
                                for pid, s in zip(ctx.ids, samples, strict=False)
                            },
                            extra=dict(ann_extra),
                        )
                    )
                except Exception as exc:  # noqa: BLE001 - per-point silhouette is best-effort
                    log.warning(
                        "per-point silhouette failed for projection '%s': %s",
                        ctx.space_name,
                        exc,
                    )

        return outputs
class _EmbSet:
    """Test double exposing the four attributes read from an EmbeddingSet.

    ``data`` is stored as a float array and ``headers`` as a fresh list.
    """

    def __init__(self, name, data, headers, precomputed=False):
        self.precomputed = precomputed
        self.headers = [*headers]
        self.data = np.asarray(data, dtype=float)
        self.name = name
def test_registry_returns_two_statistics():
    """The default registry exposes exactly the two expected statistic families."""
    registered_families = {statistic.family for statistic in get_statistics()}
    assert registered_families == {"cluster_validity", "faithfulness"}
def test_chord_deviation_linear_curve_is_flat():
    """A straight line coincides with its own chord, so every deviation is ~0."""
    straight_line = np.linspace(10.0, 1.0, 20)
    largest_deviation = float(chord_deviation(straight_line).max())
    assert largest_deviation < 0.05
def test_faithfulness_skips_without_embedding():
    """With no embedding in the context, faithfulness yields no rows at all."""
    point_ids = [str(i) for i in range(10)]
    context = StatContext(
        "projection",
        "PCA_2",
        coords=np.zeros((10, 2)),
        ids=point_ids,
    )
    produced = FaithfulnessStatistic().compute(context)
    assert produced == []
def test_faithfulness_skip_row_routes_to_projection_metadata():
    """The over-ceiling skip row is routed like any other faithfulness row.

    Beyond the hard ceiling faithfulness emits a single skip row; it must go to
    projection metadata, not the aggregate fifth part.
    """
    point_count = 30
    embedding, _ = _blobs(n=point_count, centers=2, dim=4, seed=22)
    projected = PCA(n_components=2, random_state=0).fit_transform(embedding)
    context = StatContext(
        "projection",
        "PCA_2",
        coords=projected,
        ids=[str(i) for i in range(point_count)],
        embedding=embedding,
        embedding_name="e",
        params={"hard_ceiling": 10},  # force the skip path
    )
    rows = FaithfulnessStatistic().compute(context)

    assert len(rows) == 1
    skip_row = rows[0]
    assert skip_row.extra.get("skipped") == "n_too_large"
    assert skip_row.destination == "projection_metadata"
range(120)] + emb = _EmbSet("e", X, headers) + + base = compute_statistics( + [emb], + [{"name": "P", "data": coords, "ids": headers, "source": "e"}], + rng_seed=42, + ) + # Permute the projection rows + ids together; the id-join must recover pairing. + perm = np.random.default_rng(3).permutation(120) + permuted = compute_statistics( + [emb], + [ + { + "name": "P", + "data": coords[perm], + "ids": [headers[i] for i in perm], + "source": "e", + } + ], + rng_seed=42, + ) + b = {r.metric: r.value for r in base.rows} + p = {r.metric: r.value for r in permuted.rows} + assert b["trustworthiness"] == pytest.approx(p["trustworthiness"], abs=1e-9) + assert b["knn_overlap"] == pytest.approx(p["knn_overlap"], abs=1e-9) + + +def test_driver_maps_each_projection_to_its_embedding(): + Xa, _ = _blobs(n=100, centers=3, dim=5, seed=8) + Xb, _ = _blobs(n=100, centers=3, dim=7, seed=9) + ha = [f"a{i}" for i in range(100)] + hb = [f"b{i}" for i in range(100)] + ea, eb = _EmbSet("A", Xa, ha), _EmbSet("B", Xb, hb) + ca = PCA(n_components=2, random_state=0).fit_transform(Xa) + cb = PCA(n_components=2, random_state=0).fit_transform(Xb) + reductions = [ + {"name": "A — PCA 2", "data": ca, "ids": ha, "source": "A"}, + {"name": "B — PCA 2", "data": cb, "ids": hb, "source": "B"}, + ] + report = compute_statistics([ea, eb], reductions, rng_seed=42) + embs = { + r.space_name: r.extra.get("embedding") + for r in report.rows + if r.stat_family == "faithfulness" + } + assert embs["A — PCA 2"] == "A" + assert embs["B — PCA 2"] == "B" + + +def test_faithfulness_small_n_emits_trustworthiness_and_continuity(): + # Regression: for n in [4,30] the k clamp must satisfy sklearn's n_neighbors < n/2, + # so trustworthiness AND continuity are emitted (previously silently dropped). 
+ for n in (4, 8, 12, 20, 30): + X, _ = _blobs(n=n, centers=2, dim=4, seed=n) + coords = PCA(n_components=2, random_state=0).fit_transform(X) + ids = [str(i) for i in range(n)] + rows = { + r.metric + for r in FaithfulnessStatistic().compute( + StatContext( + "projection", + "P", + coords=coords, + ids=ids, + embedding_coords=coords, + embedding=X, + embedding_ids=ids, + embedding_name="e", + ) + ) + } + assert {"knn_overlap", "trustworthiness", "continuity"} <= rows, ( + f"n={n} dropped metrics: {rows}" + ) + + +def test_cluster_validity_uses_full_projection_not_embedding_subset(): + # Regression: cluster_validity must score the FULL projection; only faithfulness + # uses the embedding-aligned subset. Embedding covers 60 of 100 projected ids. + X, _ = _blobs(n=100, centers=4, dim=5, seed=11) + coords = PCA(n_components=2, random_state=0).fit_transform(X) + headers = [f"p{i}" for i in range(100)] + emb = _EmbSet("e", X[:60], headers[:60]) # strict subset + report = compute_statistics( + [emb], + [{"name": "P", "data": coords, "ids": headers, "source": "e"}], + rng_seed=42, + ) + faith = [r for r in report.rows if r.stat_family == "faithfulness"] + assert all( + r.extra["sample_size"] == 60 for r in faith + ) # faithfulness on the subset + # cluster_validity still runs (on the full 100-point projection) + assert any(r.metric == "silhouette" for r in report.rows) + + +def test_faithfulness_honors_default_metric_when_info_lacks_metric(): + X, _ = _blobs(n=120, centers=3, dim=6, seed=12) + coords = PCA(n_components=2, random_state=0).fit_transform(X) + headers = [f"p{i}" for i in range(120)] + emb = _EmbSet("e", X, headers) + # reduction info has no 'metric' (like PCA); default_metric must be used. 
+ report = compute_statistics( + [emb], + [{"name": "P", "data": coords, "ids": headers, "source": "e", "info": {}}], + rng_seed=42, + default_metric="cosine", + ) + knn = next(r for r in report.rows if r.metric == "knn_overlap") + assert knn.extra["metric"] == "cosine" + + +def test_precomputed_embedding_skips_faithfulness(): + headers = [f"p{i}" for i in range(40)] + sim = _EmbSet("sim", np.eye(40), headers, precomputed=True) # (n,n) similarity + X, _ = _blobs(n=40, centers=3, dim=2, seed=13) + report = compute_statistics( + [sim], + [{"name": "MDS", "data": X, "ids": headers, "source": "sim"}], + rng_seed=42, + ) + assert not any(r.stat_family == "faithfulness" for r in report.rows) + assert any(r.metric == "silhouette" for r in report.rows) + + +def test_source_disambiguates_same_id_embeddings(): + # Two embeddings sharing identical ids but different vectors; explicit source + # must route each projection to its own embedding (overlap tie would pick [0]). + headers = [f"p{i}" for i in range(80)] + Xa, _ = _blobs(n=80, centers=3, dim=5, seed=14) + Xb, _ = _blobs(n=80, centers=3, dim=5, seed=15) + ea, eb = _EmbSet("A", Xa, headers), _EmbSet("B", Xb, headers) + ca = PCA(n_components=2, random_state=0).fit_transform(Xa) + cb = PCA(n_components=2, random_state=0).fit_transform(Xb) + report = compute_statistics( + [ea, eb], + [ + {"name": "A — PCA 2", "data": ca, "ids": headers, "source": "A"}, + {"name": "B — PCA 2", "data": cb, "ids": headers, "source": "B"}, + ], + rng_seed=42, + ) + embs = { + r.space_name: r.extra.get("embedding") + for r in report.rows + if r.stat_family == "faithfulness" + } + assert embs["A — PCA 2"] == "A" + assert embs["B — PCA 2"] == "B" + + +def test_driver_isolates_failures(): + class _Boom: + family = "boom" + requires_embedding = False + + def compute(self, ctx): + raise RuntimeError("boom") + + X, _ = _blobs(n=80, centers=3, dim=2, seed=10) + headers = [str(i) for i in range(80)] + emb = _EmbSet("e", X, headers) + report = 
driver_compute( + [emb], + [{"name": "P", "data": X, "ids": headers, "source": "e"}], + statistics=[_Boom(), ClusterValidityStatistic()], + ) + # Boom is swallowed; cluster validity still produced rows. + assert any(r.metric == "silhouette" for r in report.rows) + + +# --------------------------------------------------------------------------- # +# 5. routing / destination (route-projection-statistics, Phase 1A) +# --------------------------------------------------------------------------- # + + +def _statrow(metric, value, *, destination=None, **extra): + kw = {} if destination is None else {"destination": destination} + return StatRow( + space_kind="projection", + space_name="PCA_2", + stat_family="cluster_validity", + label_kind="kmeans_elbow", + metric=metric, + metric_kind="validity", + value=value, + extra=extra, + **kw, + ) + + +def test_statrow_defaults_to_statistics_part_destination(): + # Every existing construction stays valid and keeps the 5th-part destination. + assert _statrow("silhouette", 0.5).destination == "statistics_part" + + +def test_destination_is_not_a_tidy_table_column(): + # destination is carriage metadata, never a column in the 8-column schema. 
+ rec = _statrow("silhouette", 0.5).to_record() + assert "destination" not in rec + assert set(rec) == set(STATS_SCHEMA.names) + + +def test_partition_groups_rows_by_destination(): + report = StatsReport() + report.add( + [ + _statrow("silhouette", 0.5), # default -> statistics_part + _statrow("trustworthiness", 0.9, destination="projection_metadata"), + _statrow("cluster", 1.0, destination="annotation"), + ] + ) + buckets = report.partition() + assert {r.metric for r in buckets["statistics_part"]} == {"silhouette"} + assert {r.metric for r in buckets["projection_metadata"]} == {"trustworthiness"} + assert {r.metric for r in buckets["annotation"]} == {"cluster"} + + +def test_to_arrow_serializes_only_statistics_part_rows(): + report = StatsReport() + report.add( + [ + _statrow("silhouette", 0.5), # statistics_part + _statrow("trustworthiness", 0.9, destination="projection_metadata"), + _statrow("cluster", 1.0, destination="annotation"), + ] + ) + table = report.to_arrow() + assert table.schema == STATS_SCHEMA + assert table.column("metric").to_pylist() == ["silhouette"] + + +# --------------------------------------------------------------------------- # +# 6. 
per-protein annotation outputs (route-projection-statistics Phase 2A) +# --------------------------------------------------------------------------- # + + +def test_annotation_column_defaults_to_annotation_destination(): + from protspace.stats.base import AnnotationColumn + + col = AnnotationColumn(name="cluster_PCA_2", kind="categorical", values={"a": "c0"}) + assert col.destination == "annotation" + + +def test_report_collects_annotation_columns_separately_from_rows(): + from protspace.stats.base import AnnotationColumn + + report = StatsReport() + report.add( + [ + _statrow("silhouette", 0.5), # statistics_part row + AnnotationColumn( + name="cluster_PCA_2", kind="categorical", values={"a": "c0"} + ), + ] + ) + # the scalar row stays in rows / to_arrow; the column is a separate channel + assert [r.metric for r in report.rows] == ["silhouette"] + assert report.to_arrow().num_rows == 1 + assert [c.name for c in report.annotation_columns] == ["cluster_PCA_2"] + + +def test_cluster_validity_emits_per_protein_annotation_columns(): + from protspace.stats.base import AnnotationColumn + + X, _ = _blobs(n=200, centers=4, dim=2, seed=31) + ids = [f"p{i}" for i in range(200)] + outs = ClusterValidityStatistic().compute( + StatContext("projection", "PCA_2", coords=X, ids=ids) + ) + cols = {o.name: o for o in outs if isinstance(o, AnnotationColumn)} + assert {"cluster_PCA_2", "silhouette_PCA_2"} <= set(cols) + + mem = cols["cluster_PCA_2"] + assert mem.destination == "annotation" and mem.kind == "categorical" + assert set(mem.values) == set(ids) # one value per protein, joined by id + # non-numeric label strings so content-based inference reads categorical + assert all( + isinstance(v, str) and v.startswith("cluster ") for v in mem.values.values() + ) + assert mem.extra["k"] >= 2 and mem.extra.get("computed") is True + + sil = cols["silhouette_PCA_2"] + assert sil.kind == "numeric" + assert set(sil.values) == set(ids) + assert all(isinstance(v, float) for v in 
sil.values.values()) + + +def test_per_point_silhouette_skipped_beyond_hard_ceiling(): + from protspace.stats.base import AnnotationColumn + + X, _ = _blobs(n=200, centers=3, dim=2, seed=32) + ids = [f"p{i}" for i in range(200)] + outs = ClusterValidityStatistic().compute( + StatContext( + "projection", + "PCA_2", + coords=X, + ids=ids, + params={"silhouette_hard_ceiling": 50}, # n=200 > 50 + ) + ) + names = {o.name for o in outs if isinstance(o, AnnotationColumn)} + assert "cluster_PCA_2" in names # membership is cheap, still emitted + assert "silhouette_PCA_2" not in names # O(n^2) silhouette skipped + + +def test_cluster_annotations_can_be_disabled(): + from protspace.stats.base import AnnotationColumn + + X, _ = _blobs(n=150, centers=3, dim=2, seed=33) + ids = [f"p{i}" for i in range(150)] + outs = ClusterValidityStatistic().compute( + StatContext( + "projection", + "PCA_2", + coords=X, + ids=ids, + params={"cluster_annotations": False}, + ) + ) + assert not any(isinstance(o, AnnotationColumn) for o in outs) + # aggregate validity rows still produced + assert any(getattr(o, "metric", None) == "silhouette" for o in outs) diff --git a/tests/test_stats_bundle.py b/tests/test_stats_bundle.py new file mode 100644 index 00000000..e33f0fa2 --- /dev/null +++ b/tests/test_stats_bundle.py @@ -0,0 +1,99 @@ +"""Round-trip tests for the optional fifth (statistics) bundle part.""" + +from __future__ import annotations + +import pyarrow as pa +import pyarrow.parquet as pq + +from protspace.data.io.bundle import ( + PARQUET_BUNDLE_DELIMITER, + extract_bundle_to_dir, + read_bundle, + read_statistics_from_bundle, + replace_settings_in_bundle, + write_bundle, +) + + +def _core() -> list[pa.Table]: + return [ + pa.table({"protein_id": ["a", "b"]}), + pa.table({"projection_name": ["PCA_2"]}), + pa.table({"projection_name": ["PCA_2", "PCA_2"], "identifier": ["a", "b"]}), + ] + + +def _stats() -> pa.Table: + return pa.table({"space_name": ["PCA_2"], "metric": ["silhouette"], "value": 
[0.5]}) + + +def _ndelims(path) -> int: + return path.read_bytes().count(PARQUET_BUNDLE_DELIMITER) + + +def test_three_part_bundle_roundtrips(tmp_path): + p = tmp_path / "b.parquetbundle" + write_bundle(_core(), p) + assert _ndelims(p) == 2 + core, settings = read_bundle(p) + assert len(core) == 3 and settings is None + assert read_statistics_from_bundle(p) is None + + +def test_four_part_settings_only(tmp_path): + p = tmp_path / "b.parquetbundle" + write_bundle(_core(), p, settings={"hello": "world"}) + assert _ndelims(p) == 3 + _, settings = read_bundle(p) + assert settings == {"hello": "world"} + assert read_statistics_from_bundle(p) is None + + +def test_five_part_settings_and_stats(tmp_path): + p = tmp_path / "b.parquetbundle" + write_bundle(_core(), p, settings={"k": 1}, statistics=_stats()) + assert _ndelims(p) == 4 + _, settings = read_bundle(p) + assert settings == {"k": 1} + stats_bytes = read_statistics_from_bundle(p) + assert stats_bytes is not None + table = pq.read_table(pa.BufferReader(stats_bytes)) + assert table.column("metric")[0].as_py() == "silhouette" + + +def test_five_part_stats_only_empty_settings(tmp_path): + p = tmp_path / "b.parquetbundle" + write_bundle(_core(), p, statistics=_stats()) + assert _ndelims(p) == 4 # zero-byte settings slot keeps stats at position 5 + core, settings = read_bundle(p) + assert len(core) == 3 and settings is None + assert read_statistics_from_bundle(p) is not None + + +def test_extract_to_dir_writes_statistics(tmp_path): + p = tmp_path / "b.parquetbundle" + write_bundle(_core(), p, statistics=_stats()) + out = extract_bundle_to_dir(p, tmp_path / "out") + assert (tmp_path / "out" / "statistics.parquet").exists() + assert not (tmp_path / "out" / "settings.parquet").exists() + assert out + + +def test_style_preserves_stats_with_settings(tmp_path): + src = tmp_path / "b.parquetbundle" + write_bundle(_core(), src, settings={"old": 1}, statistics=_stats()) + out = tmp_path / "styled.parquetbundle" + 
replace_settings_in_bundle(src, out, {"new": 2}) + _, settings = read_bundle(out) + assert settings == {"new": 2} + assert read_statistics_from_bundle(out) is not None + + +def test_style_preserves_stats_on_stats_only_input(tmp_path): + src = tmp_path / "b.parquetbundle" + write_bundle(_core(), src, statistics=_stats()) # empty settings slot + out = tmp_path / "styled.parquetbundle" + replace_settings_in_bundle(src, out, {"new": 2}) + _, settings = read_bundle(out) + assert settings == {"new": 2} + assert read_statistics_from_bundle(out) is not None diff --git a/tests/test_stats_carriage.py b/tests/test_stats_carriage.py new file mode 100644 index 00000000..7d3aa33e --- /dev/null +++ b/tests/test_stats_carriage.py @@ -0,0 +1,272 @@ +"""Carriage router: faithfulness → projections_metadata.info_json.quality. + +Phase 1A of route-projection-statistics. Per-projection faithfulness scalars are +folded into each projection's ``info_json`` under a ``quality`` object; the +aggregate fifth part stays validity-only. 
+""" + +from __future__ import annotations + +import json + +import numpy as np +from sklearn.datasets import make_blobs +from sklearn.decomposition import PCA + +from protspace.data.processors.base_processor import BaseProcessor +from protspace.stats import compute_statistics +from protspace.stats.base import StatRow, StatsReport +from protspace.stats.carriage import route_faithfulness_to_metadata + + +class _EmbSet: + def __init__(self, name, data, headers, precomputed=False): + self.name = name + self.data = np.asarray(data, dtype=float) + self.headers = list(headers) + self.precomputed = precomputed + + +def _faith_row(space_name, metric_name, value, **extra): + return StatRow( + space_kind="projection", + space_name=space_name, + stat_family="faithfulness", + label_kind="none", + metric=metric_name, + metric_kind="faithfulness", + value=value, + extra=extra, + destination="projection_metadata", + ) + + +def _faith_report(space_name, **provenance): + report = StatsReport() + report.add( + [ + _faith_row(space_name, "knn_overlap", 0.8, **provenance), + _faith_row(space_name, "trustworthiness", 0.9, **provenance), + _faith_row(space_name, "continuity", 0.85, **provenance), + ] + ) + return report + + +def test_router_injects_quality_per_metric_with_provenance(): + report = _faith_report("PCA_2", k=15, metric="cosine", sampled=False) + reductions = [{"name": "PCA_2", "info": {"metric": "cosine"}}] + route_faithfulness_to_metadata(report, reductions) + + q = reductions[0]["info"]["quality"] + assert set(q) == {"knn_overlap", "trustworthiness", "continuity"} + assert q["trustworthiness"]["value"] == 0.9 + # each value records its own provenance (k + distance metric) + assert q["knn_overlap"]["k"] == 15 + assert q["knn_overlap"]["metric"] == "cosine" + # pre-existing info keys are preserved + assert reductions[0]["info"]["metric"] == "cosine" + + +def test_router_omits_quality_when_no_faithfulness(): + # No faithfulness rows (e.g. 
projection without an available embedding) → no key. + report = StatsReport() + reductions = [{"name": "PCA_2", "info": {"metric": "euclidean"}}] + route_faithfulness_to_metadata(report, reductions) + assert "quality" not in reductions[0]["info"] + + +def test_router_creates_info_dict_when_missing(): + report = _faith_report("PCA_2", k=15, metric="euclidean") + reductions = [{"name": "PCA_2"}] # no info dict yet + route_faithfulness_to_metadata(report, reductions) + assert "quality" in reductions[0]["info"] + + +def test_router_maps_skip_nan_to_null_and_keeps_marker(): + report = StatsReport() + report.add( + [ + _faith_row( + "PCA_2", + "knn_overlap", + float("nan"), + skipped="n_too_large", + n=30000, + hard_ceiling=20000, + ) + ] + ) + reductions = [{"name": "PCA_2", "info": {}}] + route_faithfulness_to_metadata(report, reductions) + + q = reductions[0]["info"]["quality"]["knn_overlap"] + assert q["value"] is None # NaN is not valid JSON → null + assert q["skipped"] == "n_too_large" + # The injected info must serialize to strictly valid JSON (no `NaN` token). 
+ serialized = json.dumps(reductions[0]["info"]) + assert "NaN" not in serialized + assert json.loads(serialized)["quality"]["knn_overlap"]["value"] is None + + +def test_router_round_trips_through_projections_metadata_table(): + X, _ = make_blobs(n_samples=120, centers=3, n_features=6, random_state=5) + coords = PCA(n_components=2, random_state=0).fit_transform(X) + headers = [f"p{i}" for i in range(120)] + emb = _EmbSet("e", X, headers) + reductions = [ + { + "name": "P", + "dimensions": 2, + "info": {"metric": "euclidean"}, + "data": coords, + "ids": headers, + "source": "e", + } + ] + report = compute_statistics([emb], reductions, rng_seed=42) + route_faithfulness_to_metadata(report, reductions) + + table = BaseProcessor({}, {})._create_projections_metadata_table(reductions) + info = json.loads(table.column("info_json")[0].as_py()) + assert {"knn_overlap", "trustworthiness", "continuity"} <= set(info["quality"]) + # faithfulness stayed OUT of the aggregate fifth part + families = set(report.to_arrow().column("stat_family").to_pylist()) + assert "faithfulness" not in families + + +def test_build_cluster_legend_settings_produces_valid_envelope(): + from protspace.stats.base import AnnotationColumn, StatsReport + from protspace.stats.carriage import build_cluster_legend_settings + + report = StatsReport() + report.add( + [ + AnnotationColumn( + name="cluster_P", + kind="categorical", + values={"a": "cluster 0", "b": "cluster 1", "c": "cluster 0"}, + ), + AnnotationColumn(name="silhouette_P", kind="numeric", values={"a": 0.5}), + ] + ) + settings = build_cluster_legend_settings(report) + + # only the categorical membership column is styled (silhouette is a numeric ramp) + assert set(settings) == {"cluster_P"} + env = settings["cluster_P"] + # every field sanitizeLegendSettingsEntry requires, with the right types + assert isinstance(env["maxVisibleValues"], int) + assert isinstance(env["shapeSize"], int | float) + assert env["sortMode"] in { + "size-asc", + 
"size-desc", + "alpha-asc", + "alpha-desc", + "manual", + "manual-reverse", + } + assert env["enableDuplicateStackUI"] is False + assert env["hiddenValues"] == [] + assert env["selectedPaletteId"] == "kellys" + cats = env["categories"] + assert set(cats) == {"cluster 0", "cluster 1"} + colors = set() + for cat in cats.values(): + assert isinstance(cat["zOrder"], int) + assert isinstance(cat["color"], str) and cat["color"].startswith("#") + assert isinstance(cat["shape"], str) + colors.add(cat["color"]) + assert len(colors) == 2 # distinct palette colors per cluster + + +def test_merge_annotation_columns_joins_by_identifier(): + import pandas as pd + + from protspace.stats.base import AnnotationColumn, StatsReport + from protspace.stats.carriage import merge_annotation_columns + + report = StatsReport() + report.add( + [ + AnnotationColumn( + name="cluster_P", + kind="categorical", + values={"a": "cluster 0", "b": "cluster 1"}, + ), + AnnotationColumn( + name="silhouette_P", kind="numeric", values={"a": 0.5, "b": 0.2} + ), + ] + ) + frame = pd.DataFrame({"identifier": ["a", "b", "c"]}) + added = merge_annotation_columns(report, frame) + + assert added == ["cluster_P", "silhouette_P"] + assert frame.loc[frame.identifier == "a", "cluster_P"].item() == "cluster 0" + assert frame.loc[frame.identifier == "b", "silhouette_P"].item() == 0.2 + # a protein absent from the column gets no value (not a fabricated one) + assert pd.isna(frame.loc[frame.identifier == "c", "cluster_P"].item()) + + +def test_annotation_columns_are_typed_in_protein_annotations_table(): + import pandas as pd + + from protspace.data.processors.base_processor import BaseProcessor + from protspace.stats import compute_statistics + from protspace.stats.carriage import merge_annotation_columns + + X, _ = make_blobs(n_samples=120, centers=3, n_features=5, random_state=7) + headers = [f"p{i}" for i in range(120)] + coords = PCA(n_components=2, random_state=0).fit_transform(X) + emb = _EmbSet("e", X, headers) 
+ reductions = [ + { + "name": "P", + "dimensions": 2, + "info": {}, + "data": coords, + "ids": headers, + "source": "e", + } + ] + report = compute_statistics([emb], reductions, rng_seed=42) + metadata = pd.DataFrame({"identifier": headers}) + merge_annotation_columns(report, metadata) + + table = BaseProcessor({}, {})._create_protein_annotations_table(metadata) + cols = table.column_names + assert "cluster_P" in cols and "silhouette_P" in cols + d = table.to_pydict() + # membership: non-numeric category labels → categorical inference + assert all(v.startswith("cluster ") for v in d["cluster_P"]) + # per-point silhouette: clean numeric strings → continuous inference + for v in d["silhouette_P"]: + float(v) # must not raise + + +def test_router_multi_embedding_routes_each_projection_to_its_own_scores(): + Xa, _ = make_blobs(n_samples=100, centers=3, n_features=5, random_state=8) + Xb, _ = make_blobs(n_samples=100, centers=4, n_features=7, random_state=9) + ha = [f"a{i}" for i in range(100)] + hb = [f"b{i}" for i in range(100)] + ea, eb = _EmbSet("A", Xa, ha), _EmbSet("B", Xb, hb) + ca = PCA(n_components=2, random_state=0).fit_transform(Xa) + cb = PCA(n_components=2, random_state=0).fit_transform(Xb) + reductions = [ + {"name": "A — PCA 2", "info": {}, "data": ca, "ids": ha, "source": "A"}, + {"name": "B — PCA 2", "info": {}, "data": cb, "ids": hb, "source": "B"}, + ] + report = compute_statistics([ea, eb], reductions, rng_seed=42) + driver_vals = { + r.space_name: r.value + for r in report.rows + if r.stat_family == "faithfulness" and r.metric == "trustworthiness" + } + route_faithfulness_to_metadata(report, reductions) + + qa = reductions[0]["info"]["quality"]["trustworthiness"]["value"] + qb = reductions[1]["info"]["quality"]["trustworthiness"]["value"] + assert qa == driver_vals["A — PCA 2"] + assert qb == driver_vals["B — PCA 2"] + assert qa != qb # each projection scored against its own embedding diff --git a/tests/test_stats_cli.py b/tests/test_stats_cli.py 
new file mode 100644 index 00000000..8e739b55 --- /dev/null +++ b/tests/test_stats_cli.py @@ -0,0 +1,427 @@ +"""Integration tests for the discrete `protspace stats` path and prepare wiring.""" + +from __future__ import annotations + +import h5py +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +import pytest +from sklearn.datasets import make_blobs +from sklearn.decomposition import PCA + +from protspace.cli.stats import _load_reductions +from protspace.data.loaders import load_h5 +from protspace.stats import compute_statistics + + +def _project_dir(tmp_path, n=150, dim=6, centers=4, seed=0): + X, _ = make_blobs(n_samples=n, centers=centers, n_features=dim, random_state=seed) + headers = [f"p{i}" for i in range(n)] + coords = PCA(n_components=2, random_state=0).fit_transform(X) + + h5_path = tmp_path / "emb.h5" + with h5py.File(h5_path, "w") as f: + for i, h in enumerate(headers): + f.create_dataset(h, data=X[i].astype(np.float32)) + + proj = tmp_path / "project" + proj.mkdir() + pq.write_table( + pa.table( + { + "projection_name": ["E — PCA 2"], + "dimensions": [2], + "info_json": ['{"metric": "euclidean"}'], + } + ), + str(proj / "projections_metadata.parquet"), + ) + pq.write_table( + pa.table( + { + "projection_name": ["E — PCA 2"] * n, + "identifier": headers, + "x": coords[:, 0], + "y": coords[:, 1], + "z": [None] * n, + } + ), + str(proj / "projections_data.parquet"), + ) + return h5_path, proj, headers + + +def test_load_reductions_reconstructs_coords(tmp_path): + _, proj, headers = _project_dir(tmp_path) + reductions = _load_reductions(proj) + assert len(reductions) == 1 + red = reductions[0] + assert red["name"] == "E — PCA 2" + assert red["ids"] == headers + assert red["data"].shape == (len(headers), 2) + assert red["info"]["metric"] == "euclidean" + + +def test_discrete_path_produces_full_matrix(tmp_path): + h5_path, proj, _ = _project_dir(tmp_path) + emb = load_h5([h5_path], name_override="E") + reductions = 
_load_reductions(proj) + report = compute_statistics([emb], reductions, rng_seed=42) + metrics = {r.metric for r in report.rows} + # cluster-validity (coords only) + faithfulness (embedding matched by id-join) + assert {"silhouette", "n_clusters"} <= metrics + assert {"knn_overlap", "trustworthiness", "continuity"} <= metrics + # The fifth part (to_arrow) now carries aggregate validity only — faithfulness + # routes to projection metadata, not this table (route-projection-statistics). + table = report.to_arrow() + assert table.schema.names[0] == "space_kind" + assert set(table.column("stat_family").to_pylist()) == {"cluster_validity"} + assert table.num_rows == len(report.partition()["statistics_part"]) + + +def test_stats_command_writes_aggregate_only_part(tmp_path): + """`protspace stats -o statistics.parquet` writes validity/meta rows only — + faithfulness now rides in projection metadata, not this fifth part + (route-projection-statistics Phase 1A; the prep stats+bundle path stays valid).""" + from typer.testing import CliRunner + + from protspace.cli.app import app + + h5_path, proj, _ = _project_dir(tmp_path) + out = tmp_path / "statistics.parquet" + result = CliRunner().invoke( + app, ["stats", "-i", f"{h5_path}:E", "-p", str(proj), "-o", str(out)] + ) + assert result.exit_code == 0, result.output + assert out.exists() + + table = pq.read_table(str(out)) + families = set(table.column("stat_family").to_pylist()) + assert families == {"cluster_validity"} + metrics = set(table.column("metric").to_pylist()) + assert { + "silhouette", + "davies_bouldin", + "calinski_harabasz", + "n_clusters", + } <= metrics + assert not ({"knn_overlap", "trustworthiness", "continuity"} & metrics) + + +def test_stats_command_writes_faithfulness_into_metadata(tmp_path): + """`protspace stats` folds faithfulness into projections_metadata.info_json.quality + in place, so the prep `protspace bundle -p` carries it through to the bundle + (route-projection-statistics Phase 1B, option 
A).""" + import json + + from typer.testing import CliRunner + + from protspace.cli.app import app + + h5_path, proj, _ = _project_dir(tmp_path) + meta_path = proj / "projections_metadata.parquet" + before = pq.read_table(str(meta_path)) + + out = tmp_path / "statistics.parquet" + result = CliRunner().invoke( + app, ["stats", "-i", f"{h5_path}:E", "-p", str(proj), "-o", str(out)] + ) + assert result.exit_code == 0, result.output + + after = pq.read_table(str(meta_path)) + # All non-info columns and rows preserved; only info_json is enriched. + assert after.num_rows == before.num_rows + assert after.column_names == before.column_names + assert ( + after.column("dimensions").to_pylist() + == before.column("dimensions").to_pylist() + ) + + info_by_name = dict( + zip( + after.column("projection_name").to_pylist(), + after.column("info_json").to_pylist(), + strict=False, + ) + ) + info = json.loads(info_by_name["E — PCA 2"]) + assert {"knn_overlap", "trustworthiness", "continuity"} <= set(info["quality"]) + assert info["quality"]["knn_overlap"]["value"] is not None + assert info["metric"] == "euclidean" # pre-existing reducer info preserved + + +def test_stats_command_enriches_annotations_with_computed_columns(tmp_path): + """`protspace stats -a annotations.parquet` merges per-protein cluster + membership + silhouette columns into the annotations file in place + (route-projection-statistics Phase 2A), so the prep `bundle -a` carries them.""" + from typer.testing import CliRunner + + from protspace.cli.app import app + + h5_path, proj, headers = _project_dir(tmp_path) + ann_path = tmp_path / "annotations.parquet" + pq.write_table( + pa.table({"identifier": headers, "organism": ["x"] * len(headers)}), + str(ann_path), + ) + + out = tmp_path / "statistics.parquet" + result = CliRunner().invoke( + app, + [ + "stats", + "-i", + f"{h5_path}:E", + "-p", + str(proj), + "-a", + str(ann_path), + "-o", + str(out), + ], + ) + assert result.exit_code == 0, result.output + + df = 
pq.read_table(str(ann_path)).to_pandas() + cluster_cols = [c for c in df.columns if c.startswith("cluster_")] + sil_cols = [c for c in df.columns if c.startswith("silhouette_")] + assert cluster_cols and sil_cols + assert "organism" in df.columns # pre-existing annotation preserved + assert "identifier" in df.columns + # membership categorical (non-numeric strings); silhouette numeric strings + assert str(df[cluster_cols[0]].iloc[0]).startswith("cluster ") + float(df[sil_cols[0]].iloc[0]) # must not raise + + +def test_stats_without_annotations_does_not_compute_per_protein(tmp_path): + """Without -a, stats stays aggregate+faithfulness only (the per-protein + computation has nowhere to land, so it is skipped).""" + from typer.testing import CliRunner + + from protspace.cli.app import app + + h5_path, proj, _ = _project_dir(tmp_path) + out = tmp_path / "statistics.parquet" + result = CliRunner().invoke( + app, ["stats", "-i", f"{h5_path}:E", "-p", str(proj), "-o", str(out)] + ) + assert result.exit_code == 0, result.output + table = pq.read_table(str(out)) + assert set(table.column("stat_family").to_pylist()) == {"cluster_validity"} + + +def test_stats_a_then_bundle_carries_computed_columns_into_bundle(tmp_path): + """End-to-end prep path: `stats -a` then `bundle -a` ships a bundle whose + protein_annotations part carries the computed cluster_/silhouette_ columns.""" + from typer.testing import CliRunner + + from protspace.cli.app import app + from protspace.data.io.bundle import read_bundle + + h5_path, proj, headers = _project_dir(tmp_path) + runner = CliRunner() + ann_path = tmp_path / "annotations.parquet" + pq.write_table( + pa.table({"identifier": headers, "organism": ["x"] * len(headers)}), + str(ann_path), + ) + + stats_out = tmp_path / "statistics.parquet" + r1 = runner.invoke( + app, + [ + "stats", + "-i", + f"{h5_path}:E", + "-p", + str(proj), + "-a", + str(ann_path), + "-o", + str(stats_out), + ], + ) + assert r1.exit_code == 0, r1.output + + bundle_out = 
tmp_path / "data.parquetbundle" + r2 = runner.invoke( + app, + [ + "bundle", + "-p", + str(proj), + "-a", + str(ann_path), + "-s", + str(stats_out), + "-o", + str(bundle_out), + ], + ) + assert r2.exit_code == 0, r2.output + + core, _ = read_bundle(bundle_out) + ann_table = pq.read_table( + pa.BufferReader(core[0]) + ) # protein_annotations is 1st part + cols = ann_table.column_names + assert any(c.startswith("cluster_") for c in cols) + assert any(c.startswith("silhouette_") for c in cols) + assert "protein_id" in cols # identifier renamed by bundle + + +def test_stats_settings_out_then_bundle_settings_styles_clusters(tmp_path): + """End-to-end auto-style: `stats -a --settings-out` writes a valid cluster + legend-settings JSON, and `bundle --settings` folds it into the bundle's + settings part (route-projection-statistics Phase 2A.4).""" + import json + + from typer.testing import CliRunner + + from protspace.cli.app import app + from protspace.data.io.bundle import read_bundle + + h5_path, proj, headers = _project_dir(tmp_path) + runner = CliRunner() + ann_path = tmp_path / "annotations.parquet" + pq.write_table(pa.table({"identifier": headers}), str(ann_path)) + + stats_out = tmp_path / "statistics.parquet" + styles_out = tmp_path / "cluster_styles.json" + r1 = runner.invoke( + app, + [ + "stats", + "-i", + f"{h5_path}:E", + "-p", + str(proj), + "-a", + str(ann_path), + "--settings-out", + str(styles_out), + "-o", + str(stats_out), + ], + ) + assert r1.exit_code == 0, r1.output + assert styles_out.exists() + styles = json.loads(styles_out.read_text()) + cluster_keys = [k for k in styles if k.startswith("cluster_")] + assert cluster_keys + env = styles[cluster_keys[0]] + assert env["selectedPaletteId"] == "kellys" and env["categories"] + + bundle_out = tmp_path / "data.parquetbundle" + r2 = runner.invoke( + app, + [ + "bundle", + "-p", + str(proj), + "-a", + str(ann_path), + "-s", + str(stats_out), + "--settings", + str(styles_out), + "-o", + str(bundle_out), + 
], + ) + assert r2.exit_code == 0, r2.output + + _, settings = read_bundle(bundle_out) + assert settings is not None + assert any(k.startswith("cluster_") for k in settings) + + +def test_stats_then_bundle_carries_faithfulness_into_bundle(tmp_path): + """End-to-end prep path: `protspace stats` then `protspace bundle -p` ships a + bundle whose projections_metadata.info_json carries faithfulness quality, and + whose aggregate fifth part stays validity-only.""" + import json + + from typer.testing import CliRunner + + from protspace.cli.app import app + from protspace.data.io.bundle import read_bundle, read_statistics_from_bundle + + h5_path, proj, headers = _project_dir(tmp_path) + runner = CliRunner() + + stats_out = tmp_path / "statistics.parquet" + r1 = runner.invoke( + app, ["stats", "-i", f"{h5_path}:E", "-p", str(proj), "-o", str(stats_out)] + ) + assert r1.exit_code == 0, r1.output + + ann_path = tmp_path / "annotations.parquet" + pq.write_table(pa.table({"identifier": headers}), str(ann_path)) + + bundle_out = tmp_path / "data.parquetbundle" + r2 = runner.invoke( + app, + [ + "bundle", + "-p", + str(proj), + "-a", + str(ann_path), + "-s", + str(stats_out), + "-o", + str(bundle_out), + ], + ) + assert r2.exit_code == 0, r2.output + + core, _ = read_bundle(bundle_out) + # core parts are raw parquet bytes; projections_metadata is the 2nd part. + metadata_table = pq.read_table(pa.BufferReader(core[1])) + info_by_name = dict( + zip( + metadata_table.column("projection_name").to_pylist(), + metadata_table.column("info_json").to_pylist(), + strict=False, + ) + ) + info = json.loads(info_by_name["E — PCA 2"]) + assert {"knn_overlap", "trustworthiness", "continuity"} <= set(info["quality"]) + + # The fifth part still ships, aggregate-only. 
+ stats_bytes = read_statistics_from_bundle(bundle_out) + assert stats_bytes is not None + fifth = pq.read_table(pa.BufferReader(stats_bytes)) + assert set(fifth.column("stat_family").to_pylist()) == {"cluster_validity"} + + +def test_prepare_pipeline_compute_statistics(tmp_path): + from pathlib import Path + + from protspace.data.processors.pipeline import PipelineConfig, ReductionPipeline + + class _EmbSet: + def __init__(self, name, data, headers): + self.name = name + self.data = data + self.headers = headers + + X, _ = make_blobs(n_samples=120, centers=3, n_features=5, random_state=1) + headers = [f"p{i}" for i in range(120)] + coords = PCA(n_components=2, random_state=0).fit_transform(X) + emb = _EmbSet("E", X, headers) + reductions = [{"name": "E — PCA 2", "data": coords, "source": "E"}] + + pipeline = ReductionPipeline(PipelineConfig(methods=[], output_path=Path(tmp_path))) + table = pipeline._compute_statistics([emb], reductions, headers) + assert table is not None + assert table.num_rows > 0 + # Fifth part is aggregate-only now; faithfulness rides in projection metadata. + families = set(table.column("stat_family").to_pylist()) + assert "cluster_validity" in families + assert "faithfulness" not in families + # ...and _compute_statistics routed faithfulness into the reduction's info.quality + quality = reductions[0]["info"]["quality"] + assert {"knn_overlap", "trustworthiness", "continuity"} <= set(quality)