From 6b11747505689f6aa3d7715ee0175d2721a9fbd4 Mon Sep 17 00:00:00 2001 From: xraymemory Date: Mon, 22 Jun 2026 12:46:21 -0400 Subject: [PATCH 1/2] fix(eval): align synthetic density occupancy handling --- .../eval/generate_synthetic_density.py | 63 +++++++++++-------- tests/eval/test_generate_synthetic_density.py | 33 ++++++++++ 2 files changed, 71 insertions(+), 25 deletions(-) create mode 100644 tests/eval/test_generate_synthetic_density.py diff --git a/src/sampleworks/eval/generate_synthetic_density.py b/src/sampleworks/eval/generate_synthetic_density.py index 755ee2d7..c0415154 100644 --- a/src/sampleworks/eval/generate_synthetic_density.py +++ b/src/sampleworks/eval/generate_synthetic_density.py @@ -10,7 +10,10 @@ from joblib import delayed, Parallel from loguru import logger from sampleworks.core.forward_models.xray.real_space_density import XMap_torch -from sampleworks.eval.synthetic_utils import load_structure_for_synthetic_reward +from sampleworks.eval.synthetic_utils import ( + load_structure_for_synthetic_reward, + validate_occupancy_values, +) from sampleworks.utils.atom_array_utils import save_structure_to_cif from sampleworks.utils.density_utils import compute_density_from_atomarray from sampleworks.utils.torch_utils import try_gpu @@ -26,29 +29,34 @@ class BatchRow: Path to the structure file (relative to base_dir) selection Optional atom selection string in pyMOL-like syntax (e.g., 'chain A and resi 10-50') - occ_values + occupancy_values Custom occupancy values for altlocs, must be in range [0.0, 1.0] mapfile Optional custom output filename for the density map """ - VALID_EXTENSIONS: ClassVar[frozenset[str]] = frozenset({".pdb", ".cif", ".mmcif", ".ent"}) + VALID_EXTENSIONS: ClassVar[frozenset[str]] = frozenset({".cif", ".mmcif"}) + LEGACY_EXTENSIONS: ClassVar[frozenset[str]] = frozenset({".pdb", ".ent"}) filename: str selection: str | None = None - occ_values: list[float] = field(default_factory=list) + occupancy_values: list[float] = field(default_factory=list) mapfile: str | None = None def __post_init__(self) -> None: ext = Path(self.filename).suffix.lower() - if ext not in self.VALID_EXTENSIONS: + all_supported = self.VALID_EXTENSIONS | self.LEGACY_EXTENSIONS + if ext not in all_supported: raise ValueError( f"Invalid file extension '{ext}' for '{self.filename}'. " - f"Expected one of: {', '.join(sorted(self.VALID_EXTENSIONS))}" + f"Expected one of: {', '.join(sorted(all_supported))}" ) - for i, v in enumerate(self.occ_values): - if not 0.0 <= v <= 1.0: - raise ValueError(f"Occupancy value {v} at index {i} is out of range [0.0, 1.0]") + if ext in self.LEGACY_EXTENSIONS: + logger.warning( + f"'{ext}' is a legacy PDB format and support may be removed in a future version. " + "Prefer .cif or .mmcif (mmCIF format)." + ) + validate_occupancy_values(self.occupancy_values) @classmethod def from_dict(cls, row: dict[str, str]) -> "BatchRow": @@ -58,7 +66,7 @@ def from_dict(cls, row: dict[str, str]) -> "BatchRow": ---------- row Dictionary with keys 'filename' (required), and optionally - 'selection', 'occ_values' (colon-separated), and 'mapfile' + 'selection', 'occupancy_values' (colon-separated), and 'mapfile' Returns ------- @@ -75,14 +83,15 @@ def from_dict(cls, row: dict[str, str]) -> "BatchRow": if "filename" not in row: raise KeyError("CSV is missing required 'filename' column") - occ_values: list[float] = [] - if row.get("occ_values"): - occ_values = [float(v.strip()) for v in row["occ_values"].split(":")] + occupancy_values: list[float] = [] + occupancy_values_csv = row.get("occupancy_values") or row.get("occ_values") + if occupancy_values_csv: + occupancy_values = [float(v.strip()) for v in occupancy_values_csv.split(":")] return cls( filename=row["filename"], selection=row.get("selection") or None, - occ_values=occ_values, + occupancy_values=occupancy_values, mapfile=row.get("mapfile") or None, ) @@ -111,7 +120,7 @@ def load_batch_csv(csv_path: Path) -> list[BatchRow]: ---------- csv_path Path to CSV file with columns: filename (required), selection (optional), - occ_values (optional), mapfile (optional) + occupancy_values (optional), mapfile (optional) Returns ------- @@ -135,7 +144,7 @@ def load_batch_csv(csv_path: Path) -> list[BatchRow]: def _process_single_row( row: BatchRow, - occ_mode: str, + occupancy_mode: str, base_dir: Path, output_dir: Path, resolution: float, @@ -152,7 +161,7 @@ def _process_single_row( ---------- row BatchRow containing structure information - occ_mode + occupancy_mode Occupancy assignment mode: 'default', 'uniform', or 'custom' base_dir Base directory for resolving relative structure file paths @@ -180,8 +189,8 @@ def _process_single_row( structure_path = base_dir / row.filename atom_array = load_structure_for_synthetic_reward( structure_path, - occupancy_mode=occ_mode, - occupancy_values=row.occ_values, + occupancy_mode=occupancy_mode, + occupancy_values=row.occupancy_values, strip_hydrogens=strip_hydrogens, strip_waters=strip_waters, strip_ligands=strip_ligands, @@ -237,7 +246,7 @@ def process_batch( base_dir: Path, output_dir: Path, resolution: float, - occ_mode: str, + occupancy_mode: str, em_mode: bool, device: torch.device, n_jobs: int = -1, @@ -282,7 +291,7 @@ def process_batch( Parallel(n_jobs=n_jobs, backend="loky")( delayed(_process_single_row)( row=row, - occ_mode=occ_mode, + occupancy_mode=occupancy_mode, base_dir=base_dir, output_dir=output_dir, resolution=resolution, @@ -323,13 +332,17 @@ def parse_args() -> argparse.Namespace: occ_group = parser.add_argument_group("Occupancy Options") occ_group.add_argument( + "--occupancy-mode", "--occ-mode", + dest="occupancy_mode", choices=["default", "uniform", "custom"], default="default", help="Occupancy assignment mode", ) occ_group.add_argument( + "--occupancy-values", "--occ-values", + dest="occupancy_values", type=str, help="Colon-separated occupancy values for custom mode (e.g., '0.3:0.7')", ) @@ -389,7 +402,7 @@ def main() -> None: base_dir=args.base_dir, output_dir=args.output_dir, resolution=args.resolution, - occ_mode=args.occ_mode, + occupancy_mode=args.occupancy_mode, em_mode=args.em_mode, device=device, n_jobs=args.n_jobs, @@ -402,14 +415,14 @@ def main() -> None: row = BatchRow( filename=str(args.structure), selection=args.selection, - occ_values=[float(v.strip()) for v in args.occ_values.split(":")] - if args.occ_values + occupancy_values=[float(v.strip()) for v in args.occupancy_values.split(":")] + if args.occupancy_values else [], mapfile=args.output.name if args.output else None, ) _process_single_row( row=row, - occ_mode=args.occ_mode, + occupancy_mode=args.occupancy_mode, base_dir=args.structure.parent, output_dir=args.output.parent if args.output else Path("."), resolution=args.resolution, diff --git a/tests/eval/test_generate_synthetic_density.py b/tests/eval/test_generate_synthetic_density.py new file mode 100644 index 00000000..207ae282 --- /dev/null +++ b/tests/eval/test_generate_synthetic_density.py @@ -0,0 +1,33 @@ +"""Tests for synthetic density batch-row argument handling.""" + +import pytest +from sampleworks.eval.generate_synthetic_density import BatchRow + + +def test_batch_row_accepts_occupancy_values_column() -> None: + """Batch CSV parsing uses the canonical occupancy_values column name.""" + row = BatchRow.from_dict( + {"filename": "input.cif", "selection": "chain A", "occupancy_values": "0.25:0.75"} + ) + + assert row.occupancy_values == [0.25, 0.75] + assert row.selection == "chain A" + + +def test_batch_row_accepts_legacy_occ_values_column() -> None: + """The old occ_values column remains accepted for existing batch CSVs.""" + row = BatchRow.from_dict({"filename": "input.cif", "occ_values": "0.4:0.6"}) + + assert row.occupancy_values == [0.4, 0.6] + + +def test_batch_row_rejects_occupancy_values_that_do_not_sum_to_one() -> None: + """Density generation now uses the shared occupancy-value validation helper.""" + with pytest.raises(ValueError, match="must sum to 1.0"): + BatchRow(filename="input.cif", occupancy_values=[0.2, 0.3]) + + +def test_batch_row_rejects_unsupported_extension() -> None: + """Only mmCIF and legacy PDB-like structure extensions are supported.""" + with pytest.raises(ValueError, match="Invalid file extension"): + BatchRow(filename="input.txt") From ba9ab4900a285170695510d7628a9920b3195a19 Mon Sep 17 00:00:00 2001 From: xraymemory Date: Mon, 22 Jun 2026 14:15:09 -0400 Subject: [PATCH 2/2] fix(eval): lazy-load synthetic density batch deps --- src/sampleworks/eval/generate_synthetic_density.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/sampleworks/eval/generate_synthetic_density.py b/src/sampleworks/eval/generate_synthetic_density.py index c0415154..8b8521f8 100644 --- a/src/sampleworks/eval/generate_synthetic_density.py +++ b/src/sampleworks/eval/generate_synthetic_density.py @@ -7,7 +7,6 @@ from typing import ClassVar import torch -from joblib import delayed, Parallel from loguru import logger from sampleworks.core.forward_models.xray.real_space_density import XMap_torch from sampleworks.eval.synthetic_utils import ( @@ -282,6 +281,8 @@ def process_batch( save_structure If True, save the processed structure to a CIF file in the input directory. """ + from joblib import delayed, Parallel + rows = load_batch_csv(csv_path) logger.info(f"Processing {len(rows)} structures from {csv_path} using {n_jobs} jobs")