-
Notifications
You must be signed in to change notification settings - Fork 7
feat: add RMSD metric #227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,172 @@ | ||
| """RMSD metric, developed starting from :class:`sampleworks.metrics.lddt.AllAtomLDDT`. | ||
|
|
||
| The metric returns a global per model RMSD and a per residue RMSD dictionary, | ||
| so it can be plugged into the same clustering code as the LDDT | ||
| metric. | ||
| """ | ||
|
|
||
| from typing import Any, cast | ||
|
|
||
| import numpy as np | ||
| from atomworks.io.transforms.atom_array import ensure_atom_array_stack | ||
| from atomworks.ml.transforms.atom_array import add_global_token_id_annotation | ||
| from biotite.structure import AtomArray, AtomArrayStack, rmsd, superimpose | ||
|
|
||
| from sampleworks.metrics.metric import Metric | ||
| from sampleworks.utils.atom_array_utils import filter_to_common_atoms | ||
|
|
||
|
|
||
| class AllAtomRMSD(Metric): | ||
| """Computes all atom RMSD from AtomArrays. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| superimpose | ||
| If True, superimpose predicted models onto the reference via Kabsch | ||
| before computing RMSD. Default False is correct for comparisons in a shared crystallographic | ||
| frame, but set True when the predicted and reference structures live in different frames. | ||
| log_rmsd_for_every_batch | ||
| If True, include per model RMSDs as ``all_atom_rmsd_<i>`` in the | ||
| output dict. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| superimpose: bool = False, | ||
| log_rmsd_for_every_batch: bool = False, | ||
| **kwargs: Any, | ||
| ): | ||
| super().__init__(**kwargs) | ||
| self.superimpose = superimpose | ||
| self.log_rmsd_for_every_batch = log_rmsd_for_every_batch | ||
|
|
||
| @property | ||
| def kwargs_to_compute_args(self) -> dict[str, Any]: | ||
| return { | ||
| "predicted_atom_array_stack": "predicted_atom_array_stack", | ||
| "ground_truth_atom_array_stack": "ground_truth_atom_array_stack", | ||
| "selection": "selection", | ||
| } | ||
|
|
||
| @property | ||
| def optional_kwargs(self) -> frozenset[str]: | ||
| return frozenset({"selection"}) | ||
|
|
||
| def compute( | ||
| self, | ||
| predicted_atom_array_stack: AtomArrayStack | AtomArray, | ||
| ground_truth_atom_array_stack: AtomArrayStack | AtomArray, | ||
| selection: str | None = None, | ||
| ) -> dict[str, Any]: | ||
| """Calculate all-atom RMSD between predicted and ground-truth structures. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| predicted_atom_array_stack | ||
| Predicted coordinates as AtomArray(Stack). | ||
| ground_truth_atom_array_stack | ||
| Ground truth coordinates as AtomArray(Stack). | ||
| selection | ||
| Optional selection string (AtomArray.mask() syntax) restricting | ||
| which residues appear in ``residue_rmsd_scores``. It does NOT | ||
| restrict the atoms used to compute the global RMSD, nor (when | ||
| ``superimpose=True``) the atoms used for Kabsch superposition - | ||
| both of those always use every atom common to the predicted and | ||
| reference stacks. | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| Returns | ||
| ------- | ||
| dict[str, Any] | ||
| A dictionary with all-atom RMSDs: | ||
|
|
||
| - ``best_of_1_rmsd``: global RMSD for the first model. | ||
| - ``best_of_{N}_rmsd``: minimum global RMSD across all N models. | ||
| - ``residue_rmsd_scores``: per residue RMSDs, one list per | ||
| residue keyed by ``chain_id + res_id``. | ||
|
|
||
| When ``selection`` is supplied, two additional keys are present: | ||
|
|
||
| - ``segment_rmsd``: list of all-atom RMSDs (one per predicted | ||
| model) restricted to the atoms in ``selection``, computed after | ||
| the same global Kabsch superposition used for ``best_of_1_rmsd``. | ||
| - ``best_of_{N}_segment_rmsd``: minimum of ``segment_rmsd``. | ||
| """ | ||
| # 1. Annotate token IDs so atoms can be grouped into residues downstream. | ||
| predicted_atom_array_stack = add_global_token_id_annotation( | ||
| predicted_atom_array_stack # ty: ignore[invalid-argument-type] (accepts AtomArray|AtomArrayStack at runtime; stub is narrower) | ||
| ) | ||
| ground_truth_atom_array_stack = add_global_token_id_annotation( | ||
| ground_truth_atom_array_stack # ty: ignore[invalid-argument-type] (accepts AtomArray|AtomArrayStack at runtime; stub is narrower) | ||
| ) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| # 2. Restrict both stacks to atoms present in both structures, in matching order. | ||
| _pred, _gt = filter_to_common_atoms( | ||
| predicted_atom_array_stack, ground_truth_atom_array_stack | ||
| ) | ||
| pred_aa_stack = ensure_atom_array_stack(_pred) | ||
| gt_aa_stack = ensure_atom_array_stack(_gt) | ||
|
|
||
| if pred_aa_stack.array_length() == 0: | ||
| raise RuntimeError("No atoms in common between the two structures.") | ||
|
|
||
| gt_ref = gt_aa_stack[0] | ||
| tok_idx = cast(np.ndarray, gt_ref.token_id).astype(np.int64) | ||
|
|
||
| # Resolve the selection mask BEFORE superposition since .mask() exists only on the | ||
| # atomworks loaded array, not the biotite stack superimpose() returned array | ||
| selected_token_ids: set[int] | None = None | ||
| selection_mask: np.ndarray | None = None | ||
| if selection is not None: | ||
| try: | ||
| selection_mask = pred_aa_stack.mask(selection) | ||
| except AttributeError as e: | ||
| raise RuntimeError( | ||
| "pred_aa_stack does not support mask(). Load atom arrays with " | ||
| "`atomworks.io.utils.io_utils.load_any()` to access this method." | ||
| ) from e | ||
| selected_arr = cast(AtomArray, pred_aa_stack[0, selection_mask]) | ||
| if selected_arr.token_id is not None: | ||
| selected_token_ids = {int(t) for t in np.unique(selected_arr.token_id)} | ||
|
|
||
| # 3. Optional Kabsch superposition (always on every common atom, regardless of selection). | ||
| if self.superimpose: | ||
| pred_aa_stack, _ = superimpose(gt_ref, pred_aa_stack) | ||
|
|
||
| global_rmsd = np.asarray(rmsd(gt_ref, pred_aa_stack)) | ||
|
|
||
| # 4. Build residue-level output: for each token, call biotite.rmsd on the | ||
| # atoms that share that token, keyed by "{chain_id}{res_id}". | ||
| chain_id = cast(np.ndarray, gt_ref.chain_id) | ||
| res_id = cast(np.ndarray, gt_ref.res_id) | ||
| residue_keys = np.char.add(chain_id, res_id.astype(str)) | ||
| token_to_residue_id_map = {int(k): str(v) for k, v in zip(tok_idx, residue_keys)} | ||
|
|
||
| residue_rmsd_scores: dict[str, list[float]] = {} | ||
| for tk in np.unique(tok_idx): | ||
| tk_int = int(tk) | ||
| if selected_token_ids is not None and tk_int not in selected_token_ids: | ||
| continue | ||
| atom_mask = tok_idx == tk | ||
| residue_rmsd = np.asarray(rmsd(gt_ref[atom_mask], pred_aa_stack[:, atom_mask])) | ||
| residue_rmsd_scores[token_to_residue_id_map[tk_int]] = residue_rmsd.tolist() | ||
|
|
||
| result: dict[str, Any] = { | ||
| "best_of_1_rmsd": float(global_rmsd[0]), | ||
| f"best_of_{len(global_rmsd)}_rmsd": float(global_rmsd.min()), | ||
| "residue_rmsd_scores": residue_rmsd_scores, | ||
| } | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| if selection_mask is not None: | ||
| # Segment all atom RMSD | ||
| segment_rmsd = np.asarray( | ||
| rmsd(gt_ref[selection_mask], pred_aa_stack[:, selection_mask]) | ||
| ) | ||
| result["segment_rmsd"] = [float(r) for r in segment_rmsd] | ||
| result[f"best_of_{len(segment_rmsd)}_segment_rmsd"] = float(segment_rmsd.min()) | ||
|
|
||
| if self.log_rmsd_for_every_batch: | ||
| result.update( | ||
| {f"all_atom_rmsd_{i}": float(global_rmsd[i]) for i in range(len(global_rmsd))} | ||
| ) | ||
|
|
||
| return result | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| """ | ||
| Common fixtures for metrics tests. | ||
| """ | ||
|
|
||
| import pytest | ||
| from atomworks.io.transforms.atom_array import ensure_atom_array_stack | ||
| from biotite.structure import AtomArrayStack | ||
| from sampleworks.utils.atom_array_utils import ( | ||
| select_altloc, | ||
| select_backbone, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.fixture(scope="session") | ||
| def altlocA_backbone(structure_6b8x_with_altlocs) -> AtomArrayStack: | ||
| altlocA = select_altloc(structure_6b8x_with_altlocs, "A", return_full_array=True) | ||
| return ensure_atom_array_stack(select_backbone(altlocA)) | ||
|
|
||
|
|
||
| @pytest.fixture(scope="session") | ||
| def altlocB_backbone(structure_6b8x_with_altlocs) -> AtomArrayStack: | ||
| altlocB = select_altloc(structure_6b8x_with_altlocs, "B", return_full_array=True) | ||
| return ensure_atom_array_stack(select_backbone(altlocB)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| """ | ||
| Tests for the RMSD metrics. | ||
| """ | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
| from biotite.structure import rmsd, superimpose | ||
| from sampleworks.metrics.rmsd import AllAtomRMSD | ||
| from sampleworks.utils.atom_array_utils import filter_to_common_atoms | ||
|
|
||
|
|
||
| # TODO: make these a bit more rigorous, similarly to tests/metrics/test_lddt_metrics.py | ||
|
|
||
|
|
||
| def test_all_atom_rmsd_identity(altlocA_backbone): | ||
| """A stack compared to itself must have zero RMSD on every residue.""" | ||
| metric = AllAtomRMSD() | ||
| results = metric.compute(altlocA_backbone, altlocA_backbone) | ||
|
|
||
| assert results["best_of_1_rmsd"] == pytest.approx(0.0, abs=1e-6) | ||
| assert results["residue_rmsd_scores"] | ||
| for residue, scores in results["residue_rmsd_scores"].items(): | ||
| assert scores == pytest.approx([0.0], abs=1e-6), f"nonzero identity RMSD at {residue}" | ||
|
|
||
|
|
||
| def test_all_atom_rmsd_end_to_end(altlocA_backbone, altlocB_backbone): | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| """ | ||
| A residue selection on the 6b8x altloc A/B backbone yields the expected global best_of_1_rmsd | ||
| and per-residue scores. | ||
| """ | ||
| selection_string = "res_id > 179 and res_id < 190" | ||
| metric = AllAtomRMSD() | ||
| results = metric.compute(altlocA_backbone, altlocB_backbone, selection_string) | ||
|
|
||
| expected_results = { | ||
| "best_of_1_rmsd": 0.7332, | ||
| "residue_rmsd_scores": { | ||
| "A180": [2.9206], | ||
| "A181": [4.2650], | ||
| "A182": [5.6639], | ||
| "A183": [2.9397], | ||
| "A184": [2.0595], | ||
| "A185": [2.2885], | ||
| "A186": [3.6929], | ||
| "A187": [1.8138], | ||
| "A188": [0.9547], | ||
| "A189": [0.7596], | ||
| }, | ||
| } | ||
|
|
||
| assert results["best_of_1_rmsd"] == pytest.approx(expected_results["best_of_1_rmsd"], abs=0.002) | ||
|
|
||
| assert set(results["residue_rmsd_scores"].keys()) == set( | ||
| expected_results["residue_rmsd_scores"].keys() | ||
| ) | ||
|
|
||
| for residue_key, expected_scores in expected_results["residue_rmsd_scores"].items(): | ||
| result_scores = results["residue_rmsd_scores"][residue_key] | ||
| assert len(result_scores) == len(expected_scores) | ||
| for got, want in zip(result_scores, expected_scores): | ||
| assert got == pytest.approx(want, abs=0.005), ( | ||
| f"RMSD mismatch at {residue_key}: got {got}, expected {want}" | ||
| ) | ||
|
|
||
|
|
||
| def test_all_atom_rmsd_selection_does_not_change_global(altlocA_backbone, altlocB_backbone): | ||
| """Selection filters the residue-level dict but must not change the global RMSD.""" | ||
| metric = AllAtomRMSD() | ||
| all_residues = metric.compute(altlocA_backbone, altlocB_backbone) | ||
| selected = metric.compute(altlocA_backbone, altlocB_backbone, "res_id > 179 and res_id < 190") | ||
|
|
||
| assert selected["best_of_1_rmsd"] == pytest.approx(all_residues["best_of_1_rmsd"], abs=1e-6) | ||
| assert len(selected["residue_rmsd_scores"]) < len(all_residues["residue_rmsd_scores"]) | ||
|
|
||
|
|
||
| def test_segment_rmsd_absent_without_selection(altlocA_backbone, altlocB_backbone): | ||
| """No segment_rmsd keys are emitted when selection is None.""" | ||
| metric = AllAtomRMSD() | ||
| result = metric.compute(altlocA_backbone, altlocB_backbone) | ||
|
|
||
| assert "segment_rmsd" not in result | ||
| assert not any(k.endswith("_segment_rmsd") for k in result) | ||
|
|
||
|
|
||
| def test_segment_rmsd_present_with_selection(altlocA_backbone, altlocB_backbone): | ||
| """segment_rmsd is a list of length N_member and best_of_N_segment_rmsd is its min.""" | ||
| metric = AllAtomRMSD() | ||
| result = metric.compute(altlocA_backbone, altlocB_backbone, "res_id > 179 and res_id < 190") | ||
|
|
||
| n_member = altlocB_backbone.stack_depth() | ||
| assert "segment_rmsd" in result | ||
| assert len(result["segment_rmsd"]) == n_member | ||
|
|
||
| key = f"best_of_{n_member}_segment_rmsd" | ||
| assert key in result | ||
| assert result[key] == pytest.approx(min(result["segment_rmsd"])) | ||
|
|
||
|
|
||
| def test_segment_rmsd_matches_hand_computed(altlocA_backbone, altlocB_backbone): | ||
| """segment_rmsd equals biotite.rmsd on the same coordinates.""" | ||
| selection = "res_id > 179 and res_id < 190" | ||
| metric = AllAtomRMSD(superimpose=True) | ||
| result = metric.compute(altlocA_backbone, altlocB_backbone, selection) | ||
|
|
||
| # Recompute by hand: filter to common atoms, superpose, mask, then biotite.rmsd. | ||
| pred, gt = filter_to_common_atoms(altlocA_backbone, altlocB_backbone) | ||
| gt_ref = gt[0] | ||
| pred_aligned, _ = superimpose(gt_ref, pred) | ||
| mask = pred_aligned.mask(selection) | ||
| expected = np.asarray(rmsd(gt_ref[mask], pred_aligned[:, mask])) | ||
|
|
||
| np.testing.assert_allclose(result["segment_rmsd"], expected.tolist(), rtol=1e-5, atol=1e-6) | ||
|
|
||
|
|
||
| def test_segment_rmsd_identity_is_zero(altlocA_backbone): | ||
| """Comparing a stack to itself yields zero segment_rmsd.""" | ||
| metric = AllAtomRMSD() | ||
| result = metric.compute(altlocA_backbone, altlocA_backbone, "res_id > 179 and res_id < 190") | ||
|
|
||
| assert result["segment_rmsd"] | ||
| for v in result["segment_rmsd"]: | ||
| assert v == pytest.approx(0.0, abs=1e-6) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.