Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 172 additions & 0 deletions src/sampleworks/metrics/rmsd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""RMSD metric, developed starting from :class:`sampleworks.metrics.lddt.AllAtomLDDT`.

The metric returns a global per-model RMSD and a per-residue RMSD dictionary,
so it can be plugged into the same clustering code as the LDDT metric.
"""
Comment thread
marcuscollins marked this conversation as resolved.

from typing import Any, cast

import numpy as np
from atomworks.io.transforms.atom_array import ensure_atom_array_stack
from atomworks.ml.transforms.atom_array import add_global_token_id_annotation
from biotite.structure import AtomArray, AtomArrayStack, rmsd, superimpose

from sampleworks.metrics.metric import Metric
from sampleworks.utils.atom_array_utils import filter_to_common_atoms


class AllAtomRMSD(Metric):
    """Computes all-atom RMSD from AtomArrays.

    Parameters
    ----------
    superimpose
        If True, superimpose predicted models onto the reference via Kabsch
        before computing RMSD. The default (False) is correct for comparisons in a
        shared crystallographic frame; set True when the predicted and reference
        structures live in different frames.
    log_rmsd_for_every_batch
        If True, include per-model RMSDs as ``all_atom_rmsd_<i>`` in the
        output dict.
    **kwargs
        Forwarded unchanged to ``Metric.__init__``.
    """

    def __init__(
        self,
        superimpose: bool = False,
        log_rmsd_for_every_batch: bool = False,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.superimpose = superimpose
        self.log_rmsd_for_every_batch = log_rmsd_for_every_batch

    @property
    def kwargs_to_compute_args(self) -> dict[str, Any]:
        """Return the identity mapping for the three ``compute`` argument names.

        NOTE(review): presumably consumed by the ``Metric`` base class to route
        inputs into ``compute`` - confirm against ``Metric``.
        """
        return {
            "predicted_atom_array_stack": "predicted_atom_array_stack",
            "ground_truth_atom_array_stack": "ground_truth_atom_array_stack",
            "selection": "selection",
        }

    @property
    def optional_kwargs(self) -> frozenset[str]:
        """Return the argument names treated as optional (only ``selection``)."""
        return frozenset({"selection"})

    def compute(
        self,
        predicted_atom_array_stack: AtomArrayStack | AtomArray,
        ground_truth_atom_array_stack: AtomArrayStack | AtomArray,
        selection: str | None = None,
    ) -> dict[str, Any]:
        """Calculate all-atom RMSD between predicted and ground-truth structures.

        Parameters
        ----------
        predicted_atom_array_stack
            Predicted coordinates as AtomArray(Stack).
        ground_truth_atom_array_stack
            Ground truth coordinates as AtomArray(Stack). Only its first model
            is used as the reference.
        selection
            Optional selection string (AtomArray.mask() syntax) restricting
            which residues appear in ``residue_rmsd_scores``. It does NOT
            restrict the atoms used to compute the global RMSD, nor (when
            ``superimpose=True``) the atoms used for Kabsch superposition -
            both of those always use every atom common to the predicted and
            reference stacks.

        Returns
        -------
        dict[str, Any]
            A dictionary with all-atom RMSDs:

            - ``best_of_1_rmsd``: global RMSD for the first model.
            - ``best_of_{N}_rmsd``: minimum global RMSD across all N models.
            - ``residue_rmsd_scores``: per residue RMSDs, one list per
              residue keyed by ``chain_id + res_id``.

            When ``selection`` is supplied, two additional keys are present:

            - ``segment_rmsd``: list of all-atom RMSDs (one per predicted
              model) restricted to the atoms in ``selection``, computed after
              the same global Kabsch superposition used for ``best_of_1_rmsd``.
            - ``best_of_{N}_segment_rmsd``: minimum of ``segment_rmsd``.

            When ``log_rmsd_for_every_batch`` is set, one ``all_atom_rmsd_<i>``
            key per predicted model is added as well.

        Raises
        ------
        RuntimeError
            If the two structures share no atoms, or if ``selection`` is given
            but the predicted stack does not provide ``mask()``.
        """
        # 1. Annotate token IDs so atoms can be grouped into residues downstream.
        predicted_atom_array_stack = add_global_token_id_annotation(
            predicted_atom_array_stack  # ty: ignore[invalid-argument-type] (accepts AtomArray|AtomArrayStack at runtime; stub is narrower)
        )
        ground_truth_atom_array_stack = add_global_token_id_annotation(
            ground_truth_atom_array_stack  # ty: ignore[invalid-argument-type] (accepts AtomArray|AtomArrayStack at runtime; stub is narrower)
        )

        # 2. Restrict both stacks to atoms present in both structures, in matching order.
        _pred, _gt = filter_to_common_atoms(
            predicted_atom_array_stack, ground_truth_atom_array_stack
        )
        pred_aa_stack = ensure_atom_array_stack(_pred)
        gt_aa_stack = ensure_atom_array_stack(_gt)

        if pred_aa_stack.array_length() == 0:
            raise RuntimeError("No atoms in common between the two structures.")

        gt_ref = gt_aa_stack[0]
        tok_idx = cast(np.ndarray, gt_ref.token_id).astype(np.int64)

        # Resolve the selection mask BEFORE superposition since .mask() exists only on the
        # atomworks loaded array, not the biotite stack superimpose() returned array
        selected_token_ids: set[int] | None = None
        selection_mask: np.ndarray | None = None
        if selection is not None:
            try:
                selection_mask = pred_aa_stack.mask(selection)
            except AttributeError as e:
                raise RuntimeError(
                    "pred_aa_stack does not support mask(). Load atom arrays with "
                    "`atomworks.io.utils.io_utils.load_any()` to access this method."
                ) from e
            # Token IDs were annotated on each input before filtering to common atoms, so
            # nothing guarantees the predicted and reference numbering agree afterwards.
            # Residues are iterated by *reference* token ID below, so the selected tokens
            # must be read from the reference as well. The atom mask carries over because
            # both stacks share atom order (the segment RMSD relies on the same property).
            selected_token_ids = {int(t) for t in np.unique(tok_idx[selection_mask])}

        # 3. Optional Kabsch superposition (always on every common atom, regardless of selection).
        if self.superimpose:
            pred_aa_stack, _ = superimpose(gt_ref, pred_aa_stack)

        global_rmsd = np.asarray(rmsd(gt_ref, pred_aa_stack))

        # 4. Build residue-level output: for each token, call biotite.rmsd on the
        #    atoms that share that token, keyed by "{chain_id}{res_id}".
        chain_id = cast(np.ndarray, gt_ref.chain_id)
        res_id = cast(np.ndarray, gt_ref.res_id)
        residue_keys = np.char.add(chain_id, res_id.astype(str))
        token_to_residue_id_map = {int(k): str(v) for k, v in zip(tok_idx, residue_keys)}

        residue_rmsd_scores: dict[str, list[float]] = {}
        for tk in np.unique(tok_idx):
            tk_int = int(tk)
            if selected_token_ids is not None and tk_int not in selected_token_ids:
                continue
            atom_mask = tok_idx == tk
            residue_rmsd = np.asarray(rmsd(gt_ref[atom_mask], pred_aa_stack[:, atom_mask]))
            residue_rmsd_scores[token_to_residue_id_map[tk_int]] = residue_rmsd.tolist()

        result: dict[str, Any] = {
            "best_of_1_rmsd": float(global_rmsd[0]),
            f"best_of_{len(global_rmsd)}_rmsd": float(global_rmsd.min()),
            "residue_rmsd_scores": residue_rmsd_scores,
        }

        if selection_mask is not None:
            # Segment all atom RMSD
            segment_rmsd = np.asarray(
                rmsd(gt_ref[selection_mask], pred_aa_stack[:, selection_mask])
            )
            result["segment_rmsd"] = [float(r) for r in segment_rmsd]
            result[f"best_of_{len(segment_rmsd)}_segment_rmsd"] = float(segment_rmsd.min())

        if self.log_rmsd_for_every_batch:
            result.update(
                {f"all_atom_rmsd_{i}": float(value) for i, value in enumerate(global_rmsd)}
            )

        return result
23 changes: 23 additions & 0 deletions tests/metrics/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Common fixtures for metrics tests.
"""

import pytest
from atomworks.io.transforms.atom_array import ensure_atom_array_stack
from biotite.structure import AtomArrayStack
from sampleworks.utils.atom_array_utils import (
select_altloc,
select_backbone,
)


@pytest.fixture(scope="session")
def altlocA_backbone(structure_6b8x_with_altlocs) -> AtomArrayStack:
    """Session-scoped backbone-only stack for altloc "A" of the altloc test structure."""
    full_altloc = select_altloc(structure_6b8x_with_altlocs, "A", return_full_array=True)
    backbone_only = select_backbone(full_altloc)
    return ensure_atom_array_stack(backbone_only)


@pytest.fixture(scope="session")
def altlocB_backbone(structure_6b8x_with_altlocs) -> AtomArrayStack:
    """Session-scoped backbone-only stack for altloc "B" of the altloc test structure."""
    full_altloc = select_altloc(structure_6b8x_with_altlocs, "B", return_full_array=True)
    backbone_only = select_backbone(full_altloc)
    return ensure_atom_array_stack(backbone_only)
17 changes: 0 additions & 17 deletions tests/metrics/test_lddt_metrics.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,13 @@
from typing import cast

import pytest
from atomworks.io.transforms.atom_array import ensure_atom_array_stack
from biotite.structure import AtomArrayStack
from sampleworks.metrics.lddt import AllAtomLDDT, SelectedLDDT
from sampleworks.utils.atom_array_utils import select_altloc, select_backbone


# These tests are currently too high level, but they will serve for now to demonstrate
# the expected behavior and make sure nothing gets broken.


@pytest.fixture(scope="module")
def altlocA_backbone(structure_6b8x_with_altlocs) -> AtomArrayStack:
    """Module-scoped backbone-only stack for altloc "A" of the altloc test structure."""
    return ensure_atom_array_stack(
        select_backbone(
            select_altloc(structure_6b8x_with_altlocs, "A", return_full_array=True)
        )
    )


@pytest.fixture(scope="module")
def altlocB_backbone(structure_6b8x_with_altlocs) -> AtomArrayStack:
    """Module-scoped backbone-only stack for altloc "B" of the altloc test structure."""
    return ensure_atom_array_stack(
        select_backbone(
            select_altloc(structure_6b8x_with_altlocs, "B", return_full_array=True)
        )
    )


@pytest.mark.gpu
def test_all_atom_lddt_end_to_end(altlocA_backbone, altlocB_backbone):
selection_string = "res_id > 179 and res_id < 190"
Expand Down
122 changes: 122 additions & 0 deletions tests/metrics/test_rmsd_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
Tests for the RMSD metrics.
"""

import numpy as np
import pytest
from biotite.structure import rmsd, superimpose
from sampleworks.metrics.rmsd import AllAtomRMSD
from sampleworks.utils.atom_array_utils import filter_to_common_atoms


# TODO: make these a bit more rigorous, similarly to tests/metrics/test_lddt_metrics.py


def test_all_atom_rmsd_identity(altlocA_backbone):
    """Self-comparison gives a zero global RMSD and a zero RMSD for every residue."""
    outcome = AllAtomRMSD().compute(altlocA_backbone, altlocA_backbone)
    per_residue = outcome["residue_rmsd_scores"]

    assert outcome["best_of_1_rmsd"] == pytest.approx(0.0, abs=1e-6)
    # Guard against a vacuous pass: the per-residue dict must not be empty.
    assert per_residue
    for residue in per_residue:
        observed = per_residue[residue]
        assert observed == pytest.approx([0.0], abs=1e-6), f"nonzero identity RMSD at {residue}"


def test_all_atom_rmsd_end_to_end(altlocA_backbone, altlocB_backbone):
    """
    Regression check on the altloc A/B backbone fixtures: with a residue selection the metric
    reports the pinned global best_of_1_rmsd and the pinned per-residue scores.
    """
    expected_global = 0.7332
    expected_per_residue = {
        "A180": [2.9206],
        "A181": [4.2650],
        "A182": [5.6639],
        "A183": [2.9397],
        "A184": [2.0595],
        "A185": [2.2885],
        "A186": [3.6929],
        "A187": [1.8138],
        "A188": [0.9547],
        "A189": [0.7596],
    }

    results = AllAtomRMSD().compute(
        altlocA_backbone, altlocB_backbone, "res_id > 179 and res_id < 190"
    )
    observed_per_residue = results["residue_rmsd_scores"]

    assert results["best_of_1_rmsd"] == pytest.approx(expected_global, abs=0.002)

    # Exactly the selected residues are reported - no more, no fewer.
    assert set(observed_per_residue) == set(expected_per_residue)

    for residue_key, expected_scores in expected_per_residue.items():
        result_scores = observed_per_residue[residue_key]
        assert len(result_scores) == len(expected_scores)
        for got, want in zip(result_scores, expected_scores):
            assert got == pytest.approx(want, abs=0.005), (
                f"RMSD mismatch at {residue_key}: got {got}, expected {want}"
            )


def test_all_atom_rmsd_selection_does_not_change_global(altlocA_backbone, altlocB_backbone):
    """A selection narrows the per-residue dict while leaving the global RMSD untouched."""
    metric = AllAtomRMSD()
    unrestricted = metric.compute(altlocA_backbone, altlocB_backbone)
    restricted = metric.compute(
        altlocA_backbone, altlocB_backbone, "res_id > 179 and res_id < 190"
    )

    n_unrestricted = len(unrestricted["residue_rmsd_scores"])
    n_restricted = len(restricted["residue_rmsd_scores"])

    assert restricted["best_of_1_rmsd"] == pytest.approx(unrestricted["best_of_1_rmsd"], abs=1e-6)
    assert n_restricted < n_unrestricted


def test_segment_rmsd_absent_without_selection(altlocA_backbone, altlocB_backbone):
    """Without a selection the result carries no segment-RMSD keys of any kind."""
    output = AllAtomRMSD().compute(altlocA_backbone, altlocB_backbone)

    assert "segment_rmsd" not in output
    # Covers the dynamically named best_of_{N}_segment_rmsd key as well.
    assert all(not key.endswith("_segment_rmsd") for key in output)


def test_segment_rmsd_present_with_selection(altlocA_backbone, altlocB_backbone):
    """segment_rmsd has one entry per predicted model and best_of_N_segment_rmsd is its min."""
    metric = AllAtomRMSD()
    result = metric.compute(altlocA_backbone, altlocB_backbone, "res_id > 179 and res_id < 190")

    # compute() takes (predicted, ground_truth): altlocA_backbone is the predicted stack, and
    # segment_rmsd holds one value per *predicted* model, so the expected length must come
    # from that stack rather than from the ground truth.
    n_member = altlocA_backbone.stack_depth()
    assert "segment_rmsd" in result
    assert len(result["segment_rmsd"]) == n_member

    key = f"best_of_{n_member}_segment_rmsd"
    assert key in result
    assert result[key] == pytest.approx(min(result["segment_rmsd"]))


def test_segment_rmsd_matches_hand_computed(altlocA_backbone, altlocB_backbone):
    """segment_rmsd equals biotite.rmsd on the same superposed, masked coordinates."""
    selection = "res_id > 179 and res_id < 190"
    metric = AllAtomRMSD(superimpose=True)
    result = metric.compute(altlocA_backbone, altlocB_backbone, selection)

    # Recompute by hand in the order compute() uses: filter to common atoms, resolve the
    # selection mask, superpose on every common atom, then biotite.rmsd on the masked atoms.
    pred, gt = filter_to_common_atoms(altlocA_backbone, altlocB_backbone)
    gt_ref = gt[0]
    # Resolve the mask on the filtered input, not on the superimpose() output: the metric does
    # the same on the grounds that the stack returned by biotite may not provide .mask().
    mask = pred.mask(selection)
    pred_aligned, _ = superimpose(gt_ref, pred)
    expected = np.asarray(rmsd(gt_ref[mask], pred_aligned[:, mask]))

    np.testing.assert_allclose(result["segment_rmsd"], expected.tolist(), rtol=1e-5, atol=1e-6)


def test_segment_rmsd_identity_is_zero(altlocA_backbone):
    """Self-comparison under a selection gives a non-empty, all-zero segment_rmsd."""
    output = AllAtomRMSD().compute(
        altlocA_backbone, altlocA_backbone, "res_id > 179 and res_id < 190"
    )
    segment_values = output["segment_rmsd"]

    # Guard against a vacuous pass on an empty list.
    assert segment_values
    for value in segment_values:
        assert value == pytest.approx(0.0, abs=1e-6)
Loading