diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh index ac6e1a83..6f4b4211 100755 --- a/docker-entrypoint.sh +++ b/docker-entrypoint.sh @@ -42,7 +42,7 @@ OPTIONS: ENVIRONMENTS: boltz For boltz1 and boltz2 models - protenix For protenix model + protenix For protenix model rf3 For RF3 model analysis For scripts/eval analysis jobs diff --git a/scripts/eval/find_altloc_selections.py b/scripts/eval/find_altloc_selections.py index 10435430..e537dadc 100644 --- a/scripts/eval/find_altloc_selections.py +++ b/scripts/eval/find_altloc_selections.py @@ -6,11 +6,19 @@ from sampleworks.utils.cif_utils import find_altloc_selections -def _process_row(row: pd.Series, altloc_label: str, min_span: int) -> pd.Series: +def _process_row( + row: pd.Series, altloc_label: str, min_span: int, include_all_altlocs: bool +) -> pd.Series: cif_file = row["structure"] - selections = ";".join(find_altloc_selections(cif_file, altloc_label, min_span)) + selections = ";".join( + find_altloc_selections(cif_file, altloc_label, min_span, include_all_altlocs) + ) if not selections: - logger.warning(f"No altlocs found for {cif_file}") + logger.warning( + f"No selections emitted for {cif_file} " + f"(min_span={min_span}, include_all_altlocs={include_all_altlocs}) " + f"altlocs may exist but none met these criteria." + ) # The column names here are defined by the input requirements of scripts like # rscc_grid_search_script.py @@ -35,7 +43,11 @@ def main(args): """ input_df = pd.read_csv(args.input_csv) output = input_df.apply( - _process_row, altloc_label=args.altloc_label, min_span=args.min_span, axis=1 + _process_row, + altloc_label=args.altloc_label, + min_span=args.min_span, + include_all_altlocs=args.include_all_altlocs, + axis=1, ) output.to_csv(args.output_file, index=False) @@ -51,5 +63,11 @@ def main(args): parser.add_argument("--output-file", type=Path, required=True) parser.add_argument("--min-span", type=int, default=5) parser.add_argument("--altloc-label", type=str, default="label_alt_id") + parser.add_argument( + "--no-all-altlocs", + dest="include_all_altlocs", + action="store_false", + help="Omit the final per-chain selection string that includes all altloc residues", + ) args = parser.parse_args() main(args) diff --git a/src/sampleworks/utils/cif_utils.py b/src/sampleworks/utils/cif_utils.py index 048b5f90..82989b9a 100644 --- a/src/sampleworks/utils/cif_utils.py +++ b/src/sampleworks/utils/cif_utils.py @@ -19,13 +19,16 @@ def find_altloc_selections( - cif_file: Path | str, altloc_label: str = "label_alt_id", min_span: int = 5 + cif_file: Path | str, + altloc_label: str = "label_alt_id", + min_span: int = 5, + include_all_altlocs: bool = True, ) -> Iterable[str]: """Find alternative location selections in a CIF file. Individual spans at least ``min_span`` residues long are yielded as selection strings. - A final batch of selection strings is also yielded that contains all residues with - altlocs, one selection per chain. + Optionally, a final batch of selection strings is also yielded that contains all residues + with altlocs, one selection per chain. Parameters ---------- @@ -37,7 +40,11 @@ def find_altloc_selections( min_span : int Minimum number of consecutive residues to consider an altloc selection. Spans of altlocs shorter than this are not yielded as selection strings, but ARE - included in the final selections which includes all residues with altlocs in each chain. + included in the final selections which includes all residues with altlocs in each chain when + ``include_all_altlocs=True``. + include_all_altlocs : bool + If True (default), yield a final per-chain selection string containing all residues + with altlocs regardless of span length. Yields ------ @@ -72,12 +79,13 @@ def find_altloc_selections( # FIXME use new style selection https://github.com/diff-use/sampleworks/issues/56 yield f"chain {chain} and resi {start}-{end}" # old style, more compact, selection - if chain not in all_altloc_selections: - all_altloc_selections[chain] = [] - if start == end: - all_altloc_selections[chain].append(f"(res_id == {start})") - else: - all_altloc_selections[chain].append(f"(res_id >= {start} and res_id <= {end})") + if include_all_altlocs: + if chain not in all_altloc_selections: + all_altloc_selections[chain] = [] + if start == end: + all_altloc_selections[chain].append(f"(res_id == {start})") + else: + all_altloc_selections[chain].append(f"(res_id >= {start} and res_id <= {end})") for chain, selections in all_altloc_selections.items(): yield f"chain_id == '{chain}' and ({' or '.join(selections)})" diff --git a/tests/eval/test_find_altloc_selections_script.py b/tests/eval/test_find_altloc_selections_script.py new file mode 100644 index 00000000..25517a31 --- /dev/null +++ b/tests/eval/test_find_altloc_selections_script.py @@ -0,0 +1,170 @@ +""" +Integration tests for ``scripts/eval/find_altloc_selections.py``. +""" + +from __future__ import annotations + +import argparse +import importlib.util +from pathlib import Path + +import pandas as pd +import pytest + + +_SCRIPT_PATH = ( + Path(__file__).resolve().parents[2] / "scripts" / "eval" / "find_altloc_selections.py" +) + + +def _load_script(): + """Import the script module by path so tests don't require it to be + installed on ``sys.path``.""" + spec = importlib.util.spec_from_file_location("find_altloc_selections_script", _SCRIPT_PATH) + assert spec is not None and spec.loader is not None + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +@pytest.fixture +def find_altloc_script(): + return _load_script() + + +def _altloc_input_csv(tmp_path: Path, resources_dir: Path, *, n_rows: int = 1) -> Path: + """Write an ``n_rows``-row input CSV pointing at the 1vme altloc CIF. + + Skips the test if the CIF resource is missing. + """ + cif_path = resources_dir / "1vme" / "1vme_final_carved_edited_0.5occA_0.5occB.cif" + map_path = resources_dir / "1vme" / "1vme_final_carved_edited_0.5occA_0.5occB_1.80A.ccp4" + if not cif_path.exists(): + pytest.skip(f"Test resource not found: {cif_path}") + rows = [ + { + "name": f"1VME_{i}", + "structure": str(cif_path), + "density": str(map_path), + "resolution": 1.8, + } + for i in range(n_rows) + ] + input_csv = tmp_path / "input.csv" + pd.DataFrame(rows).to_csv(input_csv, index=False) + return input_csv + + +def _make_args( + *, + input_csv: Path, + output_file: Path, + min_span: int = 5, + altloc_label: str = "label_alt_id", + include_all_altlocs: bool = True, +) -> argparse.Namespace: + return argparse.Namespace( + input_csv=input_csv, + output_file=output_file, + min_span=min_span, + altloc_label=altloc_label, + include_all_altlocs=include_all_altlocs, + ) + + +_EXPECTED_COLS = { + "protein", + "selection", + "structure_pattern", + "map_pattern", + "base_map_dir", + "resolution", +} + + +@pytest.mark.slow +def test_main_writes_expected_output_columns_and_derived_paths( + tmp_path: Path, resources_dir: Path, find_altloc_script +): + """Two-row happy path: verifies output schema, derived path columns, and + that both old- and new-style selections appear with default flags.""" + input_csv = _altloc_input_csv(tmp_path, resources_dir, n_rows=2) + output_file = tmp_path / "output.csv" + + find_altloc_script.main(_make_args(input_csv=input_csv, output_file=output_file)) + + assert output_file.exists() + df = pd.read_csv(output_file) + + assert set(df.columns) == _EXPECTED_COLS + assert len(df) == 2 + + for i, row in df.iterrows(): + assert row["protein"] == f"1VME_{i}" + assert row["structure_pattern"] == "1vme_final_carved_edited_0.5occA_0.5occB.cif" + assert row["map_pattern"] == "1vme_final_carved_edited_0.5occA_0.5occB_1.80A.ccp4" + assert row["base_map_dir"] == "1vme" + assert row["resolution"] == pytest.approx(1.8) + + entries = row["selection"].split(";") + assert any(s.startswith("chain ") and " and resi " in s for s in entries), ( + f"expected at least one old-style selection in {entries}" + ) + assert any("chain_id == " in s for s in entries), ( + f"expected at least one new-style selection in {entries} " + "(include_all_altlocs default is True)" + ) + + +@pytest.mark.slow +def test_no_all_altlocs_flag_omits_per_chain_selection( + tmp_path: Path, resources_dir: Path, find_altloc_script +): + """``include_all_altlocs=False`` (CLI ``--no-all-altlocs``) suppresses the + per-chain new-style selection; only old-style entries (if any) remain.""" + input_csv = _altloc_input_csv(tmp_path, resources_dir, n_rows=1) + output_file = tmp_path / "output.csv" + + find_altloc_script.main( + _make_args(input_csv=input_csv, output_file=output_file, include_all_altlocs=False) + ) + + df = pd.read_csv(output_file) + selection = df.iloc[0]["selection"] + selection_str = "" if pd.isna(selection) else selection + assert "chain_id == " not in selection_str + if selection_str: + entries = selection_str.split(";") + assert all(s.startswith("chain ") and " and resi " in s for s in entries), ( + f"only old-style entries should remain; got {entries}" + ) + + +@pytest.mark.slow +def test_large_min_span_with_no_all_altlocs_yields_empty_selection( + tmp_path: Path, resources_dir: Path, find_altloc_script +): + """An impossibly large ``min_span`` paired with ``include_all_altlocs=False`` + drops every selection. The row is still written with derived columns + populated and an empty ``selection`` cell.""" + input_csv = _altloc_input_csv(tmp_path, resources_dir, n_rows=1) + output_file = tmp_path / "output.csv" + + find_altloc_script.main( + _make_args( + input_csv=input_csv, + output_file=output_file, + min_span=10**6, + include_all_altlocs=False, + ) + ) + + df = pd.read_csv(output_file) + assert len(df) == 1 + row = df.iloc[0] + + assert pd.isna(row["selection"]) or row["selection"] == "" + assert row["protein"] == "1VME_0" + assert row["structure_pattern"] == "1vme_final_carved_edited_0.5occA_0.5occB.cif" + assert row["map_pattern"] == "1vme_final_carved_edited_0.5occA_0.5occB_1.80A.ccp4" + assert row["base_map_dir"] == "1vme"