diff --git a/news/118-fix-ligand-cofactor-selection.rst b/news/118-fix-ligand-cofactor-selection.rst new file mode 100644 index 0000000..2aa6035 --- /dev/null +++ b/news/118-fix-ligand-cofactor-selection.rst @@ -0,0 +1,6 @@ +**Fixed:** + +* ``gather_rms_data`` now correctly identifies the ligand when cofactors + share the same residue name ``UNK``. The ligand is auto-detected using + hybrid topology tempfactors (b-factors), while custom ``ligand_selection`` + strings are still respected. diff --git a/src/openfe_analysis/rmsd.py b/src/openfe_analysis/rmsd.py index 9c86200..71eba08 100644 --- a/src/openfe_analysis/rmsd.py +++ b/src/openfe_analysis/rmsd.py @@ -256,6 +256,63 @@ def _single_frame(self) -> None: ) +def _select_ligand( + universe: mda.Universe, + ligand_selection: str, +) -> mda.AtomGroup: + """ + Select the ligand AtomGroup, handling cofactors with the same resname. + + In OpenFE hybrid topologies, the alchemical ligand atoms have tempfactors + (b-factors) of 0.25, 0.5, or 0.75, while cofactors typically have + tempfactors of 0.0. When the default selection `"resname UNK"` matches + multiple residues, the ligand is identified as the residue containing + atoms with tempfactors strictly between 0 and 1. + + Parameters + ---------- + universe : mda.Universe + The MDAnalysis universe to select from. + ligand_selection : str + MDAnalysis selection string. If this is the default + `"resname UNK"` and multiple residues match, the ligand is + auto-detected using tempfactors. + + Returns + ------- + mda.AtomGroup + The selected ligand atoms. + """ + ligand = universe.select_atoms(ligand_selection) + + # If user provided a custom selection, only one residue matches, + # or no atoms matched, return the selection as-is. + if ligand_selection != "resname UNK" or len(ligand.residues) <= 1 or ligand.n_atoms == 0: + return ligand + + # Multiple UNK residues: identify the ligand by hybrid tempfactors. + # In OpenFE hybrid topologies, alchemical atoms have tempfactors + # of 0.25 (state A unique), 0.5 (shared), or 0.75 (state B unique). + try: + has_tempfactors = True + universe.atoms.tempfactors + except AttributeError: + has_tempfactors = False + + if has_tempfactors: + hybrid_residues = [ + res for res in ligand.residues if any(0.0 < tf < 1.0 for tf in res.atoms.tempfactors) + ] + + if len(hybrid_residues) == 1: + return hybrid_residues[0].atoms + elif len(hybrid_residues) > 1: + return max(hybrid_residues, key=lambda r: len(r.atoms)).atoms + + # Fallback: no hybrid residues found or no tempfactors available + return max(ligand.residues, key=lambda r: len(r.atoms)).atoms + + def gather_rms_data( pdb_topology: pathlib.Path, dataset: pathlib.Path, @@ -335,7 +392,7 @@ def gather_rms_data( universe = create_universe_single_state(u_top._topology, ds, state_idx) prot = universe.select_atoms(protein_selection) - ligand = universe.select_atoms(ligand_selection) + ligand = _select_ligand(universe, ligand_selection) if prot: apply_complex_alignment_transformations( diff --git a/src/openfe_analysis/tests/test_gather_rms_data.py b/src/openfe_analysis/tests/test_gather_rms_data.py index 9734548..5f31747 100644 --- a/src/openfe_analysis/tests/test_gather_rms_data.py +++ b/src/openfe_analysis/tests/test_gather_rms_data.py @@ -1,7 +1,8 @@ +import MDAnalysis as mda import numpy as np from numpy.testing import assert_allclose -from openfe_analysis.rmsd import gather_rms_data +from openfe_analysis.rmsd import _select_ligand, gather_rms_data def test_gather_rms_data_regression(simulation_nc, hybrid_system_pdb): @@ -31,7 +32,6 @@ def test_gather_rms_data_regression(simulation_nc, hybrid_system_pdb): rtol=1e-3, ) assert len(output["protein_2D_RMSD"]) == 3 - # 15 entries because 6 * 6 frames // 2 assert len(output["protein_2D_RMSD"][0]) == 15 assert_allclose( output["protein_2D_RMSD"][0][:6], @@ -49,7 +49,6 @@ def test_gather_rms_data_regression_skippednc(simulation_skipped_nc, hybrid_syst assert_allclose(output["time(ps)"], np.arange(0, 5001, 100)) assert len(output["protein_RMSD"]) == 11 - # RMSD is low for this multichain protein assert_allclose( output["protein_RMSD"][0][:6], [0, 1.089747, 1.006143, 1.045068, 1.476353, 1.332893], @@ -68,7 +67,6 @@ def test_gather_rms_data_regression_skippednc(simulation_skipped_nc, hybrid_syst rtol=1e-3, ) assert len(output["protein_2D_RMSD"]) == 11 - # 15 entries because 6 * 6 frames // 2 assert len(output["protein_2D_RMSD"][0]) == 1275 assert_allclose( output["protein_2D_RMSD"][0][:6], @@ -82,13 +80,82 @@ def test_gather_rms_data_ligand_only(simulation_skipped_nc, hybrid_system_skippe hybrid_system_skipped_pdb, simulation_skipped_nc, skip=100, - protein_selection="resname DOESNOTEXIST", # no protein + protein_selection="resname DOESNOTEXIST", ) - # No protein results assert len(output["protein_RMSD"]) == 0 assert len(output["protein_2D_RMSD"]) == 0 - - # Ligand results should still be present assert len(output["ligand_RMSD"]) > 0 assert len(output["ligand_wander"]) > 0 + + +def _make_test_universe(tempfactors, resids, resnames=None): + n_atoms = len(tempfactors) + n_residues = len(set(resids)) + unique_resids = list(dict.fromkeys(resids)) + resid_to_idx = {r: i for i, r in enumerate(unique_resids)} + atom_resindex = [resid_to_idx[r] for r in resids] + + if resnames is None: + resnames = ["UNK"] * n_residues + + u = mda.Universe.empty( + n_atoms=n_atoms, + n_residues=n_residues, + atom_resindex=atom_resindex, + ) + u.add_TopologyAttr("names", [f"C{i}" for i in range(1, n_atoms + 1)]) + u.add_TopologyAttr("resnames", resnames) + u.add_TopologyAttr("resids", unique_resids) + u.add_TopologyAttr("tempfactors", tempfactors) + return u + + +def test_select_ligand_single_unk(): + u = _make_test_universe( + tempfactors=[0.25, 0.50, 0.75], + resids=[1, 1, 1], + ) + ligand = _select_ligand(u, "resname UNK") + assert len(ligand) == 3 + assert set(ligand.resids) == {1} + + +def test_select_ligand_cofactor_present(): + u = _make_test_universe( + tempfactors=[0.25, 0.50, 0.75, 0.00], + resids=[1, 1, 1, 2], + ) + ligand = _select_ligand(u, "resname UNK") + assert len(ligand) == 3 + assert set(ligand.resids) == {1} + + +def test_select_ligand_multiple_hybrid(): + u = _make_test_universe( + tempfactors=[0.25, 0.50, 0.75, 0.25, 0.50], + resids=[1, 1, 1, 2, 2], + ) + ligand = _select_ligand(u, "resname UNK") + assert len(ligand) == 3 + assert set(ligand.resids) == {1} + + +def test_select_ligand_no_hybrid_fallback(): + u = _make_test_universe( + tempfactors=[0.00, 0.00, 0.00, 0.00], + resids=[1, 1, 2, 2], + ) + ligand = _select_ligand(u, "resname UNK") + assert len(ligand) == 2 + + +def test_select_ligand_custom_selection_unchanged(): + u = _make_test_universe( + tempfactors=[0.25, 0.50, 0.75, 0.00], + resids=[1, 1, 1, 2], + resnames=["LIG", "LIG", "LIG", "COF"], + ) + ligand = _select_ligand(u, "resname LIG") + assert len(ligand) == 3 + assert set(ligand.resids) == {1}