From 3f6faf1ae925c5f624a9a580dfb7e44c7a0bae7d Mon Sep 17 00:00:00 2001 From: "Marcus D. Collins" Date: Fri, 27 Mar 2026 22:40:46 +0000 Subject: [PATCH 1/4] fix: pre-process CIF files to remove mixed ATOM/HETATM altlocs at the same residue position Atomworks infers chain sequences from the atom array when no entity_poly_seq is present. If a residue has both a canonical ATOM record (e.g. CYS, altloc A) and a modified HETATM record (e.g. CSO, altlocs B/C) at the same sequence position, atomworks treats them as two sequential residues and inserts a spurious extra residue into the sequence passed to Boltz2. Adds resolve_mixed_hetatm_atom_altlocs() in cif_utils.py, which loads the CIF with all altlocs, detects positions where ATOM and HETATM records coexist with different residue names, logs a warning per affected position, removes the HETATM records, and returns a corrected temporary CIF for atomworks to parse. Called in get_reward_function_and_structure() before atomworks.parse(). Fixes diff-use/sampleworks#194 Co-Authored-By: Claude Sonnet 4.6 --- src/sampleworks/utils/cif_utils.py | 82 ++++++- .../utils/guidance_script_utils.py | 4 +- tests/utils/test_cif_utils.py | 214 ++++++++++++++++++ 3 files changed, 298 insertions(+), 2 deletions(-) create mode 100644 tests/utils/test_cif_utils.py diff --git a/src/sampleworks/utils/cif_utils.py b/src/sampleworks/utils/cif_utils.py index 3305d48a..664363ec 100644 --- a/src/sampleworks/utils/cif_utils.py +++ b/src/sampleworks/utils/cif_utils.py @@ -1,12 +1,16 @@ import itertools +import os +import tempfile from collections import OrderedDict from collections.abc import Iterable from pathlib import Path +import numpy as np from atomworks.io.utils.io_utils import load_any +from biotite.structure import AtomArray, AtomArrayStack from loguru import logger -from sampleworks.utils.atom_array_utils import find_all_altloc_ids, select_altloc +from sampleworks.utils.atom_array_utils import find_all_altloc_ids, save_structure_to_cif, 
select_altloc def find_altloc_selections( @@ -147,3 +151,79 @@ def find_consecutive_residues( current_membership = res_membership if len(res_membership) > 1 else None if start is not None and next_res_id: yield chain, start, next_res_id, current_membership + + +def resolve_mixed_hetatm_atom_altlocs(cif_path: Path | str) -> Path: + """Pre-process a CIF file where ATOM and HETATM records with different residue names + share the same (chain, residue) position via different altloc IDs. + + This occurs when a residue has a modified form (e.g. CSO, cysteic acid) as some + altlocs and the canonical form (e.g. CYS) as another altloc at the same sequence + position. Atomworks treats these as two sequential residues rather than alternates, + inserting a spurious extra residue into the sequence fed to Boltz2. + + The fix: for each affected position, remove the HETATM (modified) records and keep + only the ATOM (canonical) records. Also cleans up the ``_struct_conn`` covale bonds + referencing the removed residues, since ``save_structure_to_cif`` only writes + ``_atom_site``. + + A warning is logged for every affected (chain, residue) position. + + Parameters + ---------- + cif_path + Path to the input CIF file. + + Returns + ------- + Path + Path to a fixed temporary CIF file if any positions were modified, or the + original ``cif_path`` unchanged if no issues were found. 
+ """ + cif_path = Path(cif_path) + atom_array = load_any(str(cif_path), altloc="all", extra_fields=["occupancy", "b_factor"]) + if isinstance(atom_array, AtomArrayStack): + atom_array = atom_array[0] + + chain_id = atom_array.chain_id + res_id = atom_array.res_id + res_name = atom_array.res_name + hetero = atom_array.hetero + + keep_mask = np.ones(len(atom_array), dtype=bool) + found_any = False + + for chain in np.unique(chain_id): + for rid in np.unique(res_id[chain_id == chain]): + pos_mask = (chain_id == chain) & (res_id == rid) + has_atom = np.any(~hetero[pos_mask]) + has_hetatm = np.any(hetero[pos_mask]) + + if not (has_atom and has_hetatm): + continue + + atom_names = np.unique(res_name[pos_mask & ~hetero]) + hetatm_names = np.unique(res_name[pos_mask & hetero]) + + if set(atom_names) == set(hetatm_names): + continue # Same residue name on both — not the case we're fixing + + logger.warning( + f"Chain {chain}, residue {rid}: found mixed ATOM {list(atom_names)} " + f"and HETATM {list(hetatm_names)} records with different residue names " + f"at the same sequence position. Removing HETATM records to prevent " + f"atomworks from inserting a duplicate residue into the Boltz2 input sequence." 
+ ) + keep_mask[pos_mask & hetero] = False + found_any = True + + if not found_any: + return cif_path + + fixed_array = atom_array[keep_mask] + tmp_fd, tmp_path_str = tempfile.mkstemp(suffix=".cif", prefix="sampleworks_fixed_cif_") + os.close(tmp_fd) + tmp_path = Path(tmp_path_str) + save_structure_to_cif(fixed_array, tmp_path) + logger.info(f"Wrote altloc-fixed CIF to temporary file: {tmp_path}") + return tmp_path diff --git a/src/sampleworks/utils/guidance_script_utils.py b/src/sampleworks/utils/guidance_script_utils.py index 40b633f5..de0563bf 100644 --- a/src/sampleworks/utils/guidance_script_utils.py +++ b/src/sampleworks/utils/guidance_script_utils.py @@ -30,6 +30,7 @@ NoiseSpaceDPSScaler, NoScalingScaler, ) +from sampleworks.utils.cif_utils import resolve_mixed_hetatm_atom_altlocs from sampleworks.utils.guidance_constants import ( GuidanceType, StructurePredictor, @@ -228,8 +229,9 @@ def get_reward_function_and_structure( structure_path: str | Path, ) -> tuple[RealSpaceRewardFunction, dict[str, Any]]: logger.debug(f"Loading structure from {structure_path}") + structure_path = resolve_mixed_hetatm_atom_altlocs(Path(structure_path)) structure = parse( - Path(structure_path), + structure_path, hydrogen_policy="remove", add_missing_atoms=False, ccd_mirror_path=None, diff --git a/tests/utils/test_cif_utils.py b/tests/utils/test_cif_utils.py new file mode 100644 index 00000000..7500c486 --- /dev/null +++ b/tests/utils/test_cif_utils.py @@ -0,0 +1,214 @@ +"""Tests for cif_utils module.""" + +import logging +from pathlib import Path + +import numpy as np +import pytest +from atomworks.io.utils.io_utils import load_any +from biotite.structure import Atom, array, AtomArray, AtomArrayStack + +from sampleworks.utils.atom_array_utils import save_structure_to_cif +from sampleworks.utils.cif_utils import resolve_mixed_hetatm_atom_altlocs + + +# --------------------------------------------------------------------------- +# Helpers +# 
--------------------------------------------------------------------------- + + +def _atom(chain_id: str, res_id: int, res_name: str, hetero: bool, atom_name: str = "CA") -> Atom: + return Atom( + [0.0, 0.0, 0.0], + chain_id=chain_id, + res_id=res_id, + res_name=res_name, + hetero=hetero, + atom_name=atom_name, + element="C", + ) + + +def _write_cif(atoms: list[Atom], path: Path) -> Path: + arr = array(atoms) + arr.set_annotation("occupancy", np.ones(len(atoms), dtype=np.float32)) + arr.set_annotation("b_factor", np.zeros(len(atoms), dtype=np.float32)) + save_structure_to_cif(arr, path) + return path + + +def _load(path: Path) -> AtomArray: + result = load_any(str(path), altloc="all", extra_fields=["occupancy", "b_factor"]) + if isinstance(result, AtomArrayStack): + return result[0] + return result + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def cif_clean(tmp_path) -> Path: + """CIF with only ATOM records — no mixed ATOM/HETATM at the same position.""" + return _write_cif( + [ + _atom("A", 99, "VAL", hetero=False), + _atom("A", 100, "VAL", hetero=False), + _atom("A", 101, "CYS", hetero=False), + _atom("A", 102, "ALA", hetero=False), + ], + tmp_path / "clean.cif", + ) + + +@pytest.fixture +def cif_mixed(tmp_path) -> Path: + """CIF mimicking the 6NI6 bug: CYS (ATOM) and CSO (HETATM) share residue 101.""" + return _write_cif( + [ + _atom("A", 100, "VAL", hetero=False), + _atom("A", 101, "CYS", hetero=False), # canonical — keep + _atom("A", 101, "CSO", hetero=True), # modified altloc — remove + _atom("A", 102, "ALA", hetero=False), + ], + tmp_path / "mixed.cif", + ) + + +@pytest.fixture +def cif_standalone_ligand(tmp_path) -> Path: + """CIF where a HETATM ligand lives at its own res_id with no overlapping ATOM record.""" + return _write_cif( + [ + _atom("A", 100, "VAL", hetero=False), + _atom("A", 101, "CYS", 
hetero=False), + _atom("A", 999, "ATP", hetero=True), # ligand at unique position — untouched + ], + tmp_path / "ligand.cif", + ) + + +@pytest.fixture +def cif_same_resname_hetatm(tmp_path) -> Path: + """CIF where ATOM and HETATM at the same position share the same residue name.""" + return _write_cif( + [ + _atom("A", 101, "CYS", hetero=False, atom_name="N"), + _atom("A", 101, "CYS", hetero=True, atom_name="OG"), # same name — not our bug + ], + tmp_path / "same_resname.cif", + ) + + +@pytest.fixture +def cif_multiple_mixed(tmp_path) -> Path: + """CIF with two separate mixed ATOM/HETATM positions on different chains.""" + return _write_cif( + [ + _atom("A", 101, "CYS", hetero=False), # chain A pos 101: canonical + _atom("A", 101, "CSO", hetero=True), # chain A pos 101: modified + _atom("B", 50, "SER", hetero=False), # chain B pos 50: canonical + _atom("B", 50, "SEP", hetero=True), # chain B pos 50: phosphoserine + ], + tmp_path / "multi_mixed.cif", + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestResolveMixedHetatmAtomAltlocs: + + # --- No-op cases --- + + def test_clean_cif_returns_original_path(self, cif_clean): + assert resolve_mixed_hetatm_atom_altlocs(cif_clean) == cif_clean + + def test_standalone_hetatm_ligand_untouched(self, cif_standalone_ligand): + """HETATM at a unique res_id (no overlapping ATOM) must not be removed.""" + assert resolve_mixed_hetatm_atom_altlocs(cif_standalone_ligand) == cif_standalone_ligand + + def test_same_resname_hetatm_not_removed(self, cif_same_resname_hetatm): + """ATOM and HETATM sharing the same residue name at the same position are not our bug.""" + assert resolve_mixed_hetatm_atom_altlocs(cif_same_resname_hetatm) == cif_same_resname_hetatm + + # --- Fix applied --- + + def test_mixed_returns_new_path(self, cif_mixed): + assert resolve_mixed_hetatm_atom_altlocs(cif_mixed) != cif_mixed + 
+ def test_hetatm_records_removed_at_mixed_position(self, cif_mixed): + result_path = resolve_mixed_hetatm_atom_altlocs(cif_mixed) + arr = _load(result_path) + at_101 = arr[(arr.chain_id == "A") & (arr.res_id == 101)] + assert not np.any(at_101.hetero) + assert list(np.unique(at_101.res_name)) == ["CYS"] + + def test_canonical_atom_count_at_mixed_position(self, cif_mixed): + """Exactly one atom (the CA of CYS) should remain at position 101.""" + result_path = resolve_mixed_hetatm_atom_altlocs(cif_mixed) + arr = _load(result_path) + at_101 = arr[(arr.chain_id == "A") & (arr.res_id == 101)] + assert len(at_101) == 1 + + def test_other_residues_unaffected(self, cif_mixed): + result_path = resolve_mixed_hetatm_atom_altlocs(cif_mixed) + arr = _load(result_path) + for rid, expected_name in [(100, "VAL"), (102, "ALA")]: + res = arr[(arr.chain_id == "A") & (arr.res_id == rid)] + assert len(res) == 1 + assert res.res_name[0] == expected_name + + def test_multiple_mixed_positions_all_fixed(self, cif_multiple_mixed): + result_path = resolve_mixed_hetatm_atom_altlocs(cif_multiple_mixed) + arr = _load(result_path) + for chain, rid, expected_name in [("A", 101, "CYS"), ("B", 50, "SER")]: + at_pos = arr[(arr.chain_id == chain) & (arr.res_id == rid)] + assert not np.any(at_pos.hetero) + assert list(np.unique(at_pos.res_name)) == [expected_name] + + # --- Warning --- + + def test_warning_logged_for_mixed_position(self, cif_mixed, caplog): + with caplog.at_level(logging.WARNING): + resolve_mixed_hetatm_atom_altlocs(cif_mixed) + assert "101" in caplog.text + assert "CSO" in caplog.text + + def test_no_warning_for_clean_cif(self, cif_clean, caplog): + with caplog.at_level(logging.WARNING): + resolve_mixed_hetatm_atom_altlocs(cif_clean) + assert "HETATM" not in caplog.text + + def test_warning_per_position_for_multiple_mixed(self, cif_multiple_mixed, caplog): + with caplog.at_level(logging.WARNING): + resolve_mixed_hetatm_atom_altlocs(cif_multiple_mixed) + assert "CSO" in caplog.text + 
assert "SEP" in caplog.text + + # --- Real CIF --- + + def test_real_6ni6_cif_is_fixed(self, resources_dir): + """The 6NI6 density input CIF has CSO (altlocs B/C) and CYS (altloc A) at residue 101.""" + cif_path = resources_dir / "6NI6" / "6NI6_single_001_density_input.cif" + result_path = resolve_mixed_hetatm_atom_altlocs(cif_path) + assert result_path != cif_path + + def test_real_6ni6_residue_101_only_cys(self, resources_dir): + cif_path = resources_dir / "6NI6" / "6NI6_single_001_density_input.cif" + result_path = resolve_mixed_hetatm_atom_altlocs(cif_path) + arr = _load(result_path) + at_101 = arr[(arr.chain_id == "A") & (arr.res_id == 101)] + assert not np.any(at_101.hetero), "No HETATM records should remain at position 101" + assert all(n == "CYS" for n in at_101.res_name), "Only CYS should remain at position 101" + + def test_real_6ni6_warning_mentions_residue_and_modified_name(self, resources_dir, caplog): + cif_path = resources_dir / "6NI6" / "6NI6_single_001_density_input.cif" + with caplog.at_level(logging.WARNING): + resolve_mixed_hetatm_atom_altlocs(cif_path) + assert "101" in caplog.text + assert "CSO" in caplog.text From 164e63bf03287c92da182cebdfeb4de4b484bc7b Mon Sep 17 00:00:00 2001 From: "Marcus D. 
Collins" Date: Fri, 27 Mar 2026 19:25:44 -0700 Subject: [PATCH 2/4] fix: human cleanup of previous commit; resolving 6NI5/6 bug with HETATM and ATOM mixed in an altloc, resolves https://github.com/diff-use/sampleworks/issues/194 --- src/sampleworks/utils/cif_utils.py | 21 ++++++++++++++------- tests/utils/test_cif_utils.py | 14 ++++++-------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/sampleworks/utils/cif_utils.py b/src/sampleworks/utils/cif_utils.py index 664363ec..a8d0671d 100644 --- a/src/sampleworks/utils/cif_utils.py +++ b/src/sampleworks/utils/cif_utils.py @@ -1,5 +1,4 @@ import itertools -import os import tempfile from collections import OrderedDict from collections.abc import Iterable @@ -7,10 +6,14 @@ import numpy as np from atomworks.io.utils.io_utils import load_any -from biotite.structure import AtomArray, AtomArrayStack +from biotite.structure import AtomArrayStack from loguru import logger -from sampleworks.utils.atom_array_utils import find_all_altloc_ids, save_structure_to_cif, select_altloc +from sampleworks.utils.atom_array_utils import ( + find_all_altloc_ids, + save_structure_to_cif, + select_altloc, +) def find_altloc_selections( @@ -181,7 +184,7 @@ def resolve_mixed_hetatm_atom_altlocs(cif_path: Path | str) -> Path: original ``cif_path`` unchanged if no issues were found. 
""" cif_path = Path(cif_path) - atom_array = load_any(str(cif_path), altloc="all", extra_fields=["occupancy", "b_factor"]) + atom_array = load_any(cif_path, altloc="all", extra_fields=["occupancy", "b_factor"]) if isinstance(atom_array, AtomArrayStack): atom_array = atom_array[0] @@ -221,9 +224,13 @@ def resolve_mixed_hetatm_atom_altlocs(cif_path: Path | str) -> Path: return cif_path fixed_array = atom_array[keep_mask] - tmp_fd, tmp_path_str = tempfile.mkstemp(suffix=".cif", prefix="sampleworks_fixed_cif_") - os.close(tmp_fd) - tmp_path = Path(tmp_path_str) + with tempfile.NamedTemporaryFile( + mode="w", suffix=".cif", prefix="sampleworks_fixed_cif_", delete=False + ) as tmp_file: + tmp_path = Path(tmp_file.name) + + save_structure_to_cif(fixed_array, tmp_path) + logger.info(f"Wrote altloc-fixed CIF to temporary file: {tmp_path}") save_structure_to_cif(fixed_array, tmp_path) logger.info(f"Wrote altloc-fixed CIF to temporary file: {tmp_path}") return tmp_path diff --git a/tests/utils/test_cif_utils.py b/tests/utils/test_cif_utils.py index 7500c486..5c173034 100644 --- a/tests/utils/test_cif_utils.py +++ b/tests/utils/test_cif_utils.py @@ -6,8 +6,7 @@ import numpy as np import pytest from atomworks.io.utils.io_utils import load_any -from biotite.structure import Atom, array, AtomArray, AtomArrayStack - +from biotite.structure import array, Atom, AtomArray, AtomArrayStack from sampleworks.utils.atom_array_utils import save_structure_to_cif from sampleworks.utils.cif_utils import resolve_mixed_hetatm_atom_altlocs @@ -38,7 +37,7 @@ def _write_cif(atoms: list[Atom], path: Path) -> Path: def _load(path: Path) -> AtomArray: - result = load_any(str(path), altloc="all", extra_fields=["occupancy", "b_factor"]) + result = load_any(path, altloc="all", extra_fields=["occupancy", "b_factor"]) if isinstance(result, AtomArrayStack): return result[0] return result @@ -70,7 +69,7 @@ def cif_mixed(tmp_path) -> Path: [ _atom("A", 100, "VAL", hetero=False), _atom("A", 101, "CYS", 
hetero=False), # canonical — keep - _atom("A", 101, "CSO", hetero=True), # modified altloc — remove + _atom("A", 101, "CSO", hetero=True), # modified altloc — remove _atom("A", 102, "ALA", hetero=False), ], tmp_path / "mixed.cif", @@ -108,9 +107,9 @@ def cif_multiple_mixed(tmp_path) -> Path: return _write_cif( [ _atom("A", 101, "CYS", hetero=False), # chain A pos 101: canonical - _atom("A", 101, "CSO", hetero=True), # chain A pos 101: modified - _atom("B", 50, "SER", hetero=False), # chain B pos 50: canonical - _atom("B", 50, "SEP", hetero=True), # chain B pos 50: phosphoserine + _atom("A", 101, "CSO", hetero=True), # chain A pos 101: modified + _atom("B", 50, "SER", hetero=False), # chain B pos 50: canonical + _atom("B", 50, "SEP", hetero=True), # chain B pos 50: phosphoserine ], tmp_path / "multi_mixed.cif", ) @@ -122,7 +121,6 @@ def cif_multiple_mixed(tmp_path) -> Path: class TestResolveMixedHetatmAtomAltlocs: - # --- No-op cases --- def test_clean_cif_returns_original_path(self, cif_clean): From aabb52b133e61f3bdf95152daefe7ec54af0c5df Mon Sep 17 00:00:00 2001 From: "Marcus D. Collins" Date: Sat, 28 Mar 2026 11:03:50 -0700 Subject: [PATCH 3/4] removed duplicated redundant code. 
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- src/sampleworks/utils/cif_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/sampleworks/utils/cif_utils.py b/src/sampleworks/utils/cif_utils.py index a8d0671d..99279d08 100644 --- a/src/sampleworks/utils/cif_utils.py +++ b/src/sampleworks/utils/cif_utils.py @@ -231,6 +231,5 @@ def resolve_mixed_hetatm_atom_altlocs(cif_path: Path | str) -> Path: save_structure_to_cif(fixed_array, tmp_path) logger.info(f"Wrote altloc-fixed CIF to temporary file: {tmp_path}") - save_structure_to_cif(fixed_array, tmp_path) - logger.info(f"Wrote altloc-fixed CIF to temporary file: {tmp_path}") + return tmp_path return tmp_path From caf7b1bd932ddd32446f64d489d97b9817a4f264 Mon Sep 17 00:00:00 2001 From: "Marcus D. Collins" Date: Mon, 30 Mar 2026 21:33:37 +0000 Subject: [PATCH 4/4] fix: final cleanup to preprocess CIF files for mixed HETATM/ATOM altlocs; resolves https://github.com/diff-use/sampleworks/issues/194 --- scripts/eval/run_and_process_tortoize.py | 4 ++-- src/sampleworks/utils/cif_utils.py | 18 ++++++++++-------- src/sampleworks/utils/guidance_script_utils.py | 7 +++++-- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/scripts/eval/run_and_process_tortoize.py b/scripts/eval/run_and_process_tortoize.py index 0b8aa678..c58ca74e 100644 --- a/scripts/eval/run_and_process_tortoize.py +++ b/scripts/eval/run_and_process_tortoize.py @@ -40,11 +40,11 @@ def main(args: argparse.Namespace) -> None: all_residue_results, all_protein_results = tuple(zip(*tortoize_results, strict=True)) - output_file = "tortoize_residues.csv" + output_file = args.grid_search_results_path / "tortoize_residues.csv" pd.concat(all_residue_results).to_csv(output_file, index=False) logger.info(f"Residue results saved to {output_file}") - output_file = "tortoize_protein_stats.csv" + output_file = args.grid_search_results_path / "tortoize_protein_stats.csv" 
pd.concat(all_protein_results).to_csv(output_file, index=False) logger.info(f"Protein-level stats saved to {output_file}") diff --git a/src/sampleworks/utils/cif_utils.py b/src/sampleworks/utils/cif_utils.py index 99279d08..26b01098 100644 --- a/src/sampleworks/utils/cif_utils.py +++ b/src/sampleworks/utils/cif_utils.py @@ -165,6 +165,8 @@ def resolve_mixed_hetatm_atom_altlocs(cif_path: Path | str) -> Path: position. Atomworks treats these as two sequential residues rather than alternates, inserting a spurious extra residue into the sequence fed to Boltz2. + Should Atomworks fix the underlying issue in the future, we should remove this method. + The fix: for each affected position, remove the HETATM (modified) records and keep only the ATOM (canonical) records. Also cleans up the ``_struct_conn`` covale bonds referencing the removed residues, since ``save_structure_to_cif`` only writes @@ -199,21 +201,22 @@ def resolve_mixed_hetatm_atom_altlocs(cif_path: Path | str) -> Path: for chain in np.unique(chain_id): for rid in np.unique(res_id[chain_id == chain]): pos_mask = (chain_id == chain) & (res_id == rid) - has_atom = np.any(~hetero[pos_mask]) + has_no_hetatm = np.any(~hetero[pos_mask]) has_hetatm = np.any(hetero[pos_mask]) - if not (has_atom and has_hetatm): + if not (has_no_hetatm and has_hetatm): + # there are either only HETATM or only ATOM records at this position, or none at all continue - atom_names = np.unique(res_name[pos_mask & ~hetero]) - hetatm_names = np.unique(res_name[pos_mask & hetero]) + atom_res_names = np.unique(res_name[pos_mask & ~hetero]) + hetatm_res_names = np.unique(res_name[pos_mask & hetero]) - if set(atom_names) == set(hetatm_names): + if set(atom_res_names) == set(hetatm_res_names): continue # Same residue name on both — not the case we're fixing logger.warning( - f"Chain {chain}, residue {rid}: found mixed ATOM {list(atom_names)} " - f"and HETATM {list(hetatm_names)} records with different residue names " + f"Chain {chain}, residue 
{rid}: found mixed ATOM {list(atom_res_names)} " + f"and HETATM {list(hetatm_res_names)} records with different residue names " f"at the same sequence position. Removing HETATM records to prevent " f"atomworks from inserting a duplicate residue into the Boltz2 input sequence." ) @@ -232,4 +235,3 @@ def resolve_mixed_hetatm_atom_altlocs(cif_path: Path | str) -> Path: save_structure_to_cif(fixed_array, tmp_path) logger.info(f"Wrote altloc-fixed CIF to temporary file: {tmp_path}") return tmp_path - return tmp_path diff --git a/src/sampleworks/utils/guidance_script_utils.py b/src/sampleworks/utils/guidance_script_utils.py index de0563bf..da9cf31b 100644 --- a/src/sampleworks/utils/guidance_script_utils.py +++ b/src/sampleworks/utils/guidance_script_utils.py @@ -229,14 +229,17 @@ def get_reward_function_and_structure( structure_path: str | Path, ) -> tuple[RealSpaceRewardFunction, dict[str, Any]]: logger.debug(f"Loading structure from {structure_path}") - structure_path = resolve_mixed_hetatm_atom_altlocs(Path(structure_path)) + safe_structure_path = resolve_mixed_hetatm_atom_altlocs(Path(structure_path)) structure = parse( - structure_path, + safe_structure_path, hydrogen_policy="remove", add_missing_atoms=False, ccd_mirror_path=None, ) + if safe_structure_path != structure_path: + safe_structure_path.unlink() # delete the temporary file if it was created + logger.debug(f"Loading density map from {density}") xmap = XMap.fromfile(density, resolution=resolution)