Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions news/118-fix-ligand-cofactor-selection.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
**Fixed:**

* ``gather_rms_data`` now correctly identifies the ligand when cofactors
share the same residue name ``UNK``. The ligand is auto-detected using
hybrid topology tempfactors (b-factors), while custom ``ligand_selection``
strings are still respected.
59 changes: 58 additions & 1 deletion src/openfe_analysis/rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,63 @@ def _single_frame(self) -> None:
)


def _select_ligand(
universe: mda.Universe,
ligand_selection: str,
) -> mda.AtomGroup:
"""
Select the ligand AtomGroup, handling cofactors with the same resname.

In OpenFE hybrid topologies, the alchemical ligand atoms have tempfactors
(b-factors) of 0.25, 0.5, or 0.75, while cofactors typically have
tempfactors of 0.0. When the default selection `"resname UNK"` matches
multiple residues, the ligand is identified as the residue containing
atoms with tempfactors strictly between 0 and 1.

Parameters
----------
universe : mda.Universe
The MDAnalysis universe to select from.
ligand_selection : str
MDAnalysis selection string. If this is the default
`"resname UNK"` and multiple residues match, the ligand is
auto-detected using tempfactors.

Returns
-------
mda.AtomGroup
The selected ligand atoms.
"""
ligand = universe.select_atoms(ligand_selection)

# If user provided a custom selection, only one residue matches,
# or no atoms matched, return the selection as-is.
if ligand_selection != "resname UNK" or len(ligand.residues) <= 1 or ligand.n_atoms == 0:
return ligand

# Multiple UNK residues: identify the ligand by hybrid tempfactors.
# In OpenFE hybrid topologies, alchemical atoms have tempfactors
# of 0.25 (state A unique), 0.5 (shared), or 0.75 (state B unique).
try:
has_tempfactors = True
universe.atoms.tempfactors
except AttributeError:
has_tempfactors = False

if has_tempfactors:
hybrid_residues = [
res for res in ligand.residues if any(0.0 < tf < 1.0 for tf in res.atoms.tempfactors)
]

if len(hybrid_residues) == 1:
return hybrid_residues[0].atoms
elif len(hybrid_residues) > 1:
return max(hybrid_residues, key=lambda r: len(r.atoms)).atoms

# Fallback: no hybrid residues found or no tempfactors available
return max(ligand.residues, key=lambda r: len(r.atoms)).atoms


def gather_rms_data(
pdb_topology: pathlib.Path,
dataset: pathlib.Path,
Expand Down Expand Up @@ -335,7 +392,7 @@ def gather_rms_data(
universe = create_universe_single_state(u_top._topology, ds, state_idx)

prot = universe.select_atoms(protein_selection)
ligand = universe.select_atoms(ligand_selection)
ligand = _select_ligand(universe, ligand_selection)

if prot:
apply_complex_alignment_transformations(
Expand Down
83 changes: 75 additions & 8 deletions src/openfe_analysis/tests/test_gather_rms_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import MDAnalysis as mda
import numpy as np
from numpy.testing import assert_allclose

from openfe_analysis.rmsd import gather_rms_data
from openfe_analysis.rmsd import _select_ligand, gather_rms_data


def test_gather_rms_data_regression(simulation_nc, hybrid_system_pdb):
Expand Down Expand Up @@ -31,7 +32,6 @@ def test_gather_rms_data_regression(simulation_nc, hybrid_system_pdb):
rtol=1e-3,
)
assert len(output["protein_2D_RMSD"]) == 3
# 15 entries because 6 * 6 frames // 2
assert len(output["protein_2D_RMSD"][0]) == 15
assert_allclose(
output["protein_2D_RMSD"][0][:6],
Expand All @@ -49,7 +49,6 @@ def test_gather_rms_data_regression_skippednc(simulation_skipped_nc, hybrid_syst

assert_allclose(output["time(ps)"], np.arange(0, 5001, 100))
assert len(output["protein_RMSD"]) == 11
# RMSD is low for this multichain protein
assert_allclose(
output["protein_RMSD"][0][:6],
[0, 1.089747, 1.006143, 1.045068, 1.476353, 1.332893],
Expand All @@ -68,7 +67,6 @@ def test_gather_rms_data_regression_skippednc(simulation_skipped_nc, hybrid_syst
rtol=1e-3,
)
assert len(output["protein_2D_RMSD"]) == 11
# 15 entries because 6 * 6 frames // 2
assert len(output["protein_2D_RMSD"][0]) == 1275
assert_allclose(
output["protein_2D_RMSD"][0][:6],
Expand All @@ -82,13 +80,82 @@ def test_gather_rms_data_ligand_only(simulation_skipped_nc, hybrid_system_skippe
hybrid_system_skipped_pdb,
simulation_skipped_nc,
skip=100,
protein_selection="resname DOESNOTEXIST", # no protein
protein_selection="resname DOESNOTEXIST",
)

# No protein results
assert len(output["protein_RMSD"]) == 0
assert len(output["protein_2D_RMSD"]) == 0

# Ligand results should still be present
assert len(output["ligand_RMSD"]) > 0
assert len(output["ligand_wander"]) > 0


def _make_test_universe(tempfactors, resids, resnames=None):
n_atoms = len(tempfactors)
n_residues = len(set(resids))
unique_resids = list(dict.fromkeys(resids))
resid_to_idx = {r: i for i, r in enumerate(unique_resids)}
atom_resindex = [resid_to_idx[r] for r in resids]

if resnames is None:
resnames = ["UNK"] * n_residues

u = mda.Universe.empty(
n_atoms=n_atoms,
n_residues=n_residues,
atom_resindex=atom_resindex,
)
u.add_TopologyAttr("names", [f"C{i}" for i in range(1, n_atoms + 1)])
u.add_TopologyAttr("resnames", resnames)
u.add_TopologyAttr("resids", unique_resids)
u.add_TopologyAttr("tempfactors", tempfactors)
return u


def test_select_ligand_single_unk():
u = _make_test_universe(
tempfactors=[0.25, 0.50, 0.75],
resids=[1, 1, 1],
)
ligand = _select_ligand(u, "resname UNK")
assert len(ligand) == 3
assert set(ligand.resids) == {1}


def test_select_ligand_cofactor_present():
u = _make_test_universe(
tempfactors=[0.25, 0.50, 0.75, 0.00],
resids=[1, 1, 1, 2],
)
ligand = _select_ligand(u, "resname UNK")
assert len(ligand) == 3
assert set(ligand.resids) == {1}


def test_select_ligand_multiple_hybrid():
u = _make_test_universe(
tempfactors=[0.25, 0.50, 0.75, 0.25, 0.50],
resids=[1, 1, 1, 2, 2],
)
ligand = _select_ligand(u, "resname UNK")
assert len(ligand) == 3
assert set(ligand.resids) == {1}


def test_select_ligand_no_hybrid_fallback():
u = _make_test_universe(
tempfactors=[0.00, 0.00, 0.00, 0.00],
resids=[1, 1, 2, 2],
)
ligand = _select_ligand(u, "resname UNK")
assert len(ligand) == 2


def test_select_ligand_custom_selection_unchanged():
u = _make_test_universe(
tempfactors=[0.25, 0.50, 0.75, 0.00],
resids=[1, 1, 1, 2],
resnames=["LIG", "LIG", "LIG", "COF"],
)
ligand = _select_ligand(u, "resname LIG")
assert len(ligand) == 3
assert set(ligand.resids) == {1}