From 241cbb532a06d278bd43142e5b9a845ad1b04998 Mon Sep 17 00:00:00 2001 From: Abdelsalam Date: Thu, 9 Apr 2026 23:59:52 +0200 Subject: [PATCH 1/9] Add unified CLI for run_guidance and fix missing args bug Add GuidanceConfig.from_cli() classmethod that provides a single entrypoint for CLI argument parsing, replacing the 8 separate parse_*_args() functions. Register sampleworks-guidance console_scripts entrypoint in pyproject.toml. Simplify all 8 standalone scripts to thin wrappers calling from_cli() with pre-set model and guidance_type values. Fixes #181 (protein/model/guidance_type missing from argparse Namespace) Fixes #154 (unified CLI interface for run_guidance) Fixes #153 (CLI documentation in README) --- README.md | 26 ++- pixi.lock | 4 +- pyproject.toml | 3 + scripts/boltz1_fk_steering.py | 19 +- scripts/boltz1_pure_guidance.py | 19 +- scripts/boltz2_fk_steering.py | 20 +- scripts/boltz2_pure_guidance.py | 20 +- scripts/protenix_fk_steering.py | 19 +- scripts/protenix_pure_guidance.py | 19 +- scripts/rf3_fk_steering.py | 21 +- scripts/rf3_pure_guidance.py | 21 +- src/sampleworks/cli/__init__.py | 0 src/sampleworks/cli/guidance.py | 24 +++ .../utils/guidance_script_arguments.py | 201 ++++++++++-------- tests/cli/__init__.py | 0 tests/cli/test_guidance_cli.py | 167 +++++++++++++++ 16 files changed, 412 insertions(+), 171 deletions(-) create mode 100644 src/sampleworks/cli/__init__.py create mode 100644 src/sampleworks/cli/guidance.py create mode 100644 tests/cli/__init__.py create mode 100644 tests/cli/test_guidance_cli.py diff --git a/README.md b/README.md index 7087c139..cc36a862 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,10 @@ download_boltz2(cache) Run Boltz-2 pure guidance on the included 1VME example: ```bash -pixi run -e boltz python scripts/boltz2_pure_guidance.py \ +pixi run -e boltz sampleworks-guidance \ + --model boltz2 \ + --guidance-type pure_guidance \ + --protein 1VME \ --model-checkpoint ~/.boltz/boltz2_conf.ckpt \ --structure tests/resources/1vme/1vme_final_carved_edited_0.5occA_0.5occB.cif \ --density tests/resources/1vme/1vme_final_carved_edited_0.5occA_0.5occB_1.80A.ccp4 \ @@ -78,7 +81,26 @@ pixi run -e boltz python scripts/boltz2_pure_guidance.py \ --align-to-input ``` -Output files appear in `output/boltz2_pure_guidance/`: `refined.cif` (final ensemble), `losses.txt`, `trajectory/`, `run.log`. See [`scripts/README.md`](scripts/README.md) for all scripts and arguments. +Output files appear in `output/boltz2_pure_guidance/`: `refined.cif` (final ensemble), `losses.txt`, `trajectory/`, `run.log`. + +### CLI reference + +`sampleworks-guidance` is the unified command-line interface for running guidance on a single structure. + +**Required arguments:** + +| Argument | Description | +|---|---| +| `--model` | `boltz1`, `boltz2`, `protenix`, or `rf3` | +| `--guidance-type` | `pure_guidance` or `fk_steering` | +| `--protein` | Protein identifier (should match naming used in grid search / evaluation) | +| `--structure` | Path to input structure file (CIF) | +| `--density` | Path to density map (CCP4/MRC/MAP) | +| `--resolution` | Map resolution in Angstroms | + +Model-specific arguments (e.g. `--method` for boltz2, `--msa-path` for rf3) and guidance-type-specific arguments (e.g. `--num-particles` for fk_steering) are included automatically. Run `sampleworks-guidance --model --guidance-type --help` to see all available options. + +The individual scripts in `scripts/` (e.g. `boltz2_pure_guidance.py`) still work as shortcuts with `--model` and `--guidance-type` pre-set. ## Grid Search diff --git a/pixi.lock b/pixi.lock index 39db0edc..56b57c46 100644 --- a/pixi.lock +++ b/pixi.lock @@ -9719,8 +9719,8 @@ packages: timestamp: 1753407970803 - pypi: ./ name: sampleworks - version: 0.4.4 - sha256: 4fb2f096d24cf2ced2c425a465ddeef02add83541db2b457442084fd5cf1269c + version: 0.4.3 + sha256: 86ffa999bf980b906be83b7e94152b34e4e617dc90cb3d37a2d16baa4abce7ba requires_dist: - atomworks[ml]==2.1.1 - python-dotenv diff --git a/pyproject.toml b/pyproject.toml index 73815e16..73e0f4fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,9 @@ name = "sampleworks" requires-python = ">= 3.11, <3.14" version = "0.4.4" +[project.scripts] +sampleworks-guidance = "sampleworks.cli.guidance:main" + [tool.hatch.metadata] allow-direct-references = true diff --git a/scripts/boltz1_fk_steering.py b/scripts/boltz1_fk_steering.py index 3a120aa9..f1fafaff 100644 --- a/scripts/boltz1_fk_steering.py +++ b/scripts/boltz1_fk_steering.py @@ -1,19 +1,18 @@ -""" -Run FK steering with real-space density reward on the Boltz1 model. -""" +"""Run FK steering with real-space density reward on the Boltz1 model.""" -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_boltz1_fk_steering_args +from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance -def main(args): +def main(): + config = GuidanceConfig.from_cli(model="boltz1", guidance_type="fk_steering") device, model_wrapper = get_model_and_device( - "", args.model_checkpoint, StructurePredictor.BOLTZ_1 + config.device, + getattr(config, "model_checkpoint", None), + config.model, ) - run_guidance(args, GuidanceType.FK_STEERING, model_wrapper, device) + run_guidance(config, config.guidance_type, model_wrapper, device) if __name__ == "__main__": - guidance_args = parse_boltz1_fk_steering_args() - main(guidance_args) + main() diff --git a/scripts/boltz1_pure_guidance.py b/scripts/boltz1_pure_guidance.py index c4d48693..3d09cfdf 100644 --- a/scripts/boltz1_pure_guidance.py +++ b/scripts/boltz1_pure_guidance.py @@ -1,19 +1,18 @@ -""" -Run pure guidance with real-space density reward on the Boltz1 model. -""" +"""Run pure guidance with real-space density reward on the Boltz1 model.""" -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_boltz1_pure_guidance_args +from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance -def main(args): +def main(): + config = GuidanceConfig.from_cli(model="boltz1", guidance_type="pure_guidance") device, model_wrapper = get_model_and_device( - "", args.model_checkpoint, StructurePredictor.BOLTZ_1 + config.device, + getattr(config, "model_checkpoint", None), + config.model, ) - run_guidance(args, GuidanceType.PURE_GUIDANCE, model_wrapper, device) + run_guidance(config, config.guidance_type, model_wrapper, device) if __name__ == "__main__": - guidance_args = parse_boltz1_pure_guidance_args() - main(guidance_args) + main() diff --git a/scripts/boltz2_fk_steering.py b/scripts/boltz2_fk_steering.py index 13e768a0..1fad3f47 100644 --- a/scripts/boltz2_fk_steering.py +++ b/scripts/boltz2_fk_steering.py @@ -1,19 +1,19 @@ -""" -Run FK steering with real-space density reward on the Boltz2 model. -""" +"""Run FK steering with real-space density reward on the Boltz2 model.""" -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_boltz2_fk_steering_args +from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance -def main(args): +def main(): + config = GuidanceConfig.from_cli(model="boltz2", guidance_type="fk_steering") device, model_wrapper = get_model_and_device( - "", args.model_checkpoint, StructurePredictor.BOLTZ_2, method=args.method + config.device, + getattr(config, "model_checkpoint", None), + config.model, + method=getattr(config, "method", None), ) - run_guidance(args, GuidanceType.FK_STEERING, model_wrapper, device) + run_guidance(config, config.guidance_type, model_wrapper, device) if __name__ == "__main__": - guidance_args = parse_boltz2_fk_steering_args() - main(guidance_args) + main() diff --git a/scripts/boltz2_pure_guidance.py b/scripts/boltz2_pure_guidance.py index f5dc2550..77c7d9a6 100644 --- a/scripts/boltz2_pure_guidance.py +++ b/scripts/boltz2_pure_guidance.py @@ -1,19 +1,19 @@ -""" -Run pure guidance with real-space density reward on the Boltz2 model. -""" +"""Run pure guidance with real-space density reward on the Boltz2 model.""" -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_boltz2_pure_guidance_args +from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance -def main(args): +def main(): + config = GuidanceConfig.from_cli(model="boltz2", guidance_type="pure_guidance") device, model_wrapper = get_model_and_device( - "", args.model_checkpoint, StructurePredictor.BOLTZ_2, method=args.method + config.device, + getattr(config, "model_checkpoint", None), + config.model, + method=getattr(config, "method", None), ) - run_guidance(args, GuidanceType.PURE_GUIDANCE, model_wrapper, device) + run_guidance(config, config.guidance_type, model_wrapper, device) if __name__ == "__main__": - guidance_args = parse_boltz2_pure_guidance_args() - main(guidance_args) + main() diff --git a/scripts/protenix_fk_steering.py b/scripts/protenix_fk_steering.py index e2ab10ec..a8acb8d3 100644 --- a/scripts/protenix_fk_steering.py +++ b/scripts/protenix_fk_steering.py @@ -1,19 +1,18 @@ -""" -Run FK steering with real-space density reward on the Protenix model. -""" +"""Run FK steering with real-space density reward on the Protenix model.""" -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_protenix_fk_steering_args +from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance -def main(args): +def main(): + config = GuidanceConfig.from_cli(model="protenix", guidance_type="fk_steering") device, model_wrapper = get_model_and_device( - "", args.model_checkpoint, StructurePredictor.PROTENIX + config.device, + getattr(config, "model_checkpoint", None), + config.model, ) - run_guidance(args, GuidanceType.FK_STEERING, model_wrapper, device) + run_guidance(config, config.guidance_type, model_wrapper, device) if __name__ == "__main__": - guidance_args = parse_protenix_fk_steering_args() - main(guidance_args) + main() diff --git a/scripts/protenix_pure_guidance.py b/scripts/protenix_pure_guidance.py index 5aaef3b4..7f46fdfb 100644 --- a/scripts/protenix_pure_guidance.py +++ b/scripts/protenix_pure_guidance.py @@ -1,19 +1,18 @@ -""" -Run pure guidance with real-space density reward on the Protenix model. -""" +"""Run pure guidance with real-space density reward on the Protenix model.""" -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_protenix_pure_guidance_args +from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance -def main(args): +def main(): + config = GuidanceConfig.from_cli(model="protenix", guidance_type="pure_guidance") device, model_wrapper = get_model_and_device( - "", args.model_checkpoint, StructurePredictor.PROTENIX + config.device, + getattr(config, "model_checkpoint", None), + config.model, ) - run_guidance(args, GuidanceType.PURE_GUIDANCE, model_wrapper, device) + run_guidance(config, config.guidance_type, model_wrapper, device) if __name__ == "__main__": - guidance_args = parse_protenix_pure_guidance_args() - main(guidance_args) + main() diff --git a/scripts/rf3_fk_steering.py b/scripts/rf3_fk_steering.py index 595e93f9..fe063e0f 100644 --- a/scripts/rf3_fk_steering.py +++ b/scripts/rf3_fk_steering.py @@ -1,17 +1,18 @@ -""" -Run FK steering with real-space density reward on the RF3 model. -""" +"""Run FK steering with real-space density reward on the RF3 model.""" -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_rf3_fk_steering_args +from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance -def main(args): - device, model_wrapper = get_model_and_device("", args.model_checkpoint, StructurePredictor.RF3) - run_guidance(args, GuidanceType.FK_STEERING, model_wrapper, device) +def main(): + config = GuidanceConfig.from_cli(model="rf3", guidance_type="fk_steering") + device, model_wrapper = get_model_and_device( + config.device, + getattr(config, "model_checkpoint", None), + config.model, + ) + run_guidance(config, config.guidance_type, model_wrapper, device) if __name__ == "__main__": - guidance_args = parse_rf3_fk_steering_args() - main(guidance_args) + main() diff --git a/scripts/rf3_pure_guidance.py b/scripts/rf3_pure_guidance.py index f89f4e0e..160fb444 100644 --- a/scripts/rf3_pure_guidance.py +++ b/scripts/rf3_pure_guidance.py @@ -1,17 +1,18 @@ -""" -Run pure guidance with real-space density reward on the RF3 model. -""" +"""Run pure guidance with real-space density reward on the RF3 model.""" -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_rf3_pure_guidance_args +from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance -def main(args): - device, model_wrapper = get_model_and_device("", args.model_checkpoint, StructurePredictor.RF3) - run_guidance(args, GuidanceType.PURE_GUIDANCE, model_wrapper, device) +def main(): + config = GuidanceConfig.from_cli(model="rf3", guidance_type="pure_guidance") + device, model_wrapper = get_model_and_device( + config.device, + getattr(config, "model_checkpoint", None), + config.model, + ) + run_guidance(config, config.guidance_type, model_wrapper, device) if __name__ == "__main__": - guidance_args = parse_rf3_pure_guidance_args() - main(guidance_args) + main() diff --git a/src/sampleworks/cli/__init__.py b/src/sampleworks/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/sampleworks/cli/guidance.py b/src/sampleworks/cli/guidance.py new file mode 100644 index 00000000..5f71fe22 --- /dev/null +++ b/src/sampleworks/cli/guidance.py @@ -0,0 +1,24 @@ +"""Unified CLI for running diffusion guidance with sampleworks.""" + +from __future__ import annotations + +import sys + +from sampleworks.utils.guidance_script_arguments import GuidanceConfig +from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance + + +def main(argv: list[str] | None = None) -> int: + config = GuidanceConfig.from_cli(argv) + device, model_wrapper = get_model_and_device( + config.device, + getattr(config, "model_checkpoint", None), + config.model, + method=getattr(config, "method", None), + ) + result = run_guidance(config, config.guidance_type, model_wrapper, device) + return result.exit_code + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/sampleworks/utils/guidance_script_arguments.py b/src/sampleworks/utils/guidance_script_arguments.py index e563bc55..9f033c1e 100644 --- a/src/sampleworks/utils/guidance_script_arguments.py +++ b/src/sampleworks/utils/guidance_script_arguments.py @@ -112,9 +112,36 @@ def validate_model_checkpoint( return str(checkpoint_path) +_MODEL_ARG_ADDERS: dict[str, Any] = { + "boltz1": lambda p: add_boltz1_specific_args(p), + "boltz2": lambda p: add_boltz2_specific_args(p), + "protenix": lambda p: add_protenix_specific_args(p), + "rf3": lambda p: add_rf3_specific_args(p), +} + +_GUIDANCE_ARG_ADDERS: dict[str, Any] = { + "pure_guidance": lambda p: add_pure_guidance_args(p), + "fk_steering": lambda p: add_fk_steering_args(p), +} + +# Attributes set dynamically by add_*_args helpers that should be copied +# from a parsed argparse.Namespace onto a GuidanceConfig instance. +_DYNAMIC_ATTRS = [ + # pure guidance + "step_size", "step_scaler_type", + # fk steering + "num_particles", "fk_resampling_interval", "fk_lambda", + "num_gd_steps", "guidance_weight", "guidance_interval", + # model-specific + "model_checkpoint", "method", "msa_path", + "disable_chiral_features", "track_chiral_features", + # generic (overridable) + "ensemble_size", +] + + @dataclass class GuidanceConfig: - # TODO add a class method to set this up completely from args and job config. """ Class to hold guidance config arguments, compatible with argparse, but which also can do some basic validation. @@ -145,6 +172,92 @@ def add_argument(self, name: str, default: Any = None, **kwargs): """Add an argument to the guidance config, in a form compatible with argparse""" setattr(self, name.lstrip("-").replace("-", "_"), default) + @classmethod + def from_cli( + cls, + argv: list[str] | None = None, + model: str | None = None, + guidance_type: str | None = None, + ) -> GuidanceConfig: + """Parse CLI arguments and return a fully populated GuidanceConfig. + + When *model* and *guidance_type* are provided (e.g. from legacy + scripts), they are used directly and ``--model`` / ``--guidance-type`` + are not required on the command line. Otherwise they are parsed as + required CLI arguments. + """ + model_choices = [m.value for m in StructurePredictor] + guidance_choices = [g.value for g in GuidanceType] + + # -- first pass: resolve model & guidance_type if not pre-set -------- + if model is None or guidance_type is None: + pre = argparse.ArgumentParser(add_help=False) + if model is None: + pre.add_argument( + "--model", type=str, required=True, choices=model_choices, + help="Structure prediction model", + ) + if guidance_type is None: + pre.add_argument( + "--guidance-type", type=str, required=True, + choices=guidance_choices, help="Guidance method", + ) + pre_args, _ = pre.parse_known_args(argv) + model = model or pre_args.model + guidance_type = guidance_type or pre_args.guidance_type + + # -- full parser ----------------------------------------------------- + parser = argparse.ArgumentParser( + description=f"Run {guidance_type} guidance with {model}", + ) + if model is not None: + parser.add_argument( + "--model", type=str, default=model, choices=model_choices, + help="Structure prediction model", + ) + if guidance_type is not None: + parser.add_argument( + "--guidance-type", type=str, default=guidance_type, + choices=guidance_choices, help="Guidance method", + ) + parser.add_argument( + "--protein", type=str, required=True, + help="Protein identifier (must match naming used in grid search / evaluation)", + ) + add_generic_args(parser) + _MODEL_ARG_ADDERS[model](parser) + _GUIDANCE_ARG_ADDERS[guidance_type](parser) + + args = parser.parse_args(argv) + + config = cls( + protein=args.protein, + structure=args.structure, + density=args.density, + model=model, + guidance_type=guidance_type, + log_path=getattr(args, "log_path", None) or "", + output_dir=args.output_dir, + partial_diffusion_step=args.partial_diffusion_step, + loss_order=args.loss_order, + resolution=args.resolution, + device=getattr(args, "device", "") or "", + gradient_normalization=args.gradient_normalization, + em=args.em, + guidance_start=args.guidance_start, + augmentation=args.augmentation, + align_to_input=args.align_to_input, + ) + + # __post_init__ already set defaults for model/guidance-specific + # attrs; override with any explicit CLI values. + for attr in _DYNAMIC_ATTRS: + val = getattr(args, attr, None) + if val is not None: + setattr(config, attr, val) + + return config + def __post_init__(self): """Set up guidance config for a given model and guidance type""" if self.guidance_type == GuidanceType.PURE_GUIDANCE: @@ -367,92 +480,6 @@ def add_rf3_specific_args(parser: argparse.ArgumentParser | GuidanceConfig): ) -############## -# Use these methods to parse arguments in scripts which load the model themselves. -############## -def parse_boltz2_pure_guidance_args(): - parser = argparse.ArgumentParser( - description="Pure guidance refinement with Boltz-2 and real-space density" - ) - add_generic_args(parser) - add_boltz2_specific_args(parser) - add_pure_guidance_args(parser) - - return parser.parse_args() - - -def parse_boltz1_pure_guidance_args(): - parser = argparse.ArgumentParser( - description="Pure guidance refinement with Boltz-1 and real-space density" - ) - add_generic_args(parser) - add_boltz1_specific_args(parser) - add_pure_guidance_args(parser) - - return parser.parse_args() - - -def parse_protenix_pure_guidance_args(): - parser = argparse.ArgumentParser( - description="Pure guidance refinement with Protenix and real-space density" - ) - add_generic_args(parser) - add_protenix_specific_args(parser) - add_pure_guidance_args(parser) - - return parser.parse_args() - - -def parse_protenix_fk_steering_args(): - parser = argparse.ArgumentParser( - description="FK steering refinement with Protenix and real-space density" - ) - add_protenix_specific_args(parser) - add_generic_args(parser) - add_fk_steering_args(parser) - return parser.parse_args() - - -def parse_boltz2_fk_steering_args(): - parser = argparse.ArgumentParser( - description="FK steering refinement with Boltz-2 and real-space density" - ) - add_boltz2_specific_args(parser) - add_generic_args(parser) - add_fk_steering_args(parser) - return parser.parse_args() - - -def parse_boltz1_fk_steering_args(): - parser = argparse.ArgumentParser( - description="FK steering refinement with Boltz-1 and real-space density" - ) - add_boltz1_specific_args(parser) - add_generic_args(parser) - add_fk_steering_args(parser) - return parser.parse_args() - - -def parse_rf3_pure_guidance_args(): - parser = argparse.ArgumentParser( - description="Pure guidance refinement with RF3 and real-space density" - ) - add_generic_args(parser) - add_rf3_specific_args(parser) - add_pure_guidance_args(parser) - return parser.parse_args() - - -def parse_rf3_fk_steering_args(): - parser = argparse.ArgumentParser( - description="FK steering refinement with RF3 and real-space density" - ) - add_generic_args(parser) - add_rf3_specific_args(parser) - add_fk_steering_args(parser) - return parser.parse_args() - - @dataclass class JobConfig: protein: str diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/cli/test_guidance_cli.py b/tests/cli/test_guidance_cli.py new file mode 100644 index 00000000..8516f524 --- /dev/null +++ b/tests/cli/test_guidance_cli.py @@ -0,0 +1,167 @@ +"""Tests for the unified guidance CLI and GuidanceConfig.from_cli().""" + +from __future__ import annotations + +import pytest + +from sampleworks.utils.guidance_script_arguments import GuidanceConfig + + +COMMON_ARGS = [ + "--protein", "1VME", + "--structure", "test.cif", + "--density", "test.ccp4", + "--resolution", "1.8", + "--output-dir", "output", +] + + +class TestFromCliUnified: + """Test from_cli() when model and guidance_type come from CLI args.""" + + def test_boltz2_pure_guidance(self): + argv = ["--model", "boltz2", "--guidance-type", "pure_guidance"] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.model == "boltz2" + assert config.guidance_type == "pure_guidance" + assert config.protein == "1VME" + assert config.structure == "test.cif" + assert config.density == "test.ccp4" + assert config.resolution == 1.8 + + def test_rf3_fk_steering(self): + argv = ["--model", "rf3", "--guidance-type", "fk_steering"] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.model == "rf3" + assert config.guidance_type == "fk_steering" + assert hasattr(config, "num_particles") + assert hasattr(config, "fk_lambda") + + def test_model_specific_args_boltz2_method(self): + argv = [ + "--model", "boltz2", "--guidance-type", "pure_guidance", + "--method", "MD", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.method == "MD" + + def test_model_specific_args_rf3_msa(self): + argv = [ + "--model", "rf3", "--guidance-type", "pure_guidance", + "--msa-path", "/data/msa.a3m", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.msa_path == "/data/msa.a3m" + + def test_guidance_specific_args_fk(self): + argv = [ + "--model", "boltz1", "--guidance-type", "fk_steering", + "--num-particles", "5", + "--fk-lambda", "2.0", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.num_particles == 5 + assert config.fk_lambda == 2.0 + + def test_guidance_specific_args_pure(self): + argv = [ + "--model", "boltz1", "--guidance-type", "pure_guidance", + "--step-size", "0.5", + "--step-scaler-type", "dataspace", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.step_size == 0.5 + assert config.step_scaler_type == "dataspace" + + +class TestFromCliLegacyScripts: + """Test from_cli() when model/guidance_type are pre-set (legacy script pattern).""" + + def test_preset_model_and_guidance_type(self): + config = GuidanceConfig.from_cli( + COMMON_ARGS, + model="boltz2", + guidance_type="pure_guidance", + ) + assert config.model == "boltz2" + assert config.guidance_type == "pure_guidance" + assert config.protein == "1VME" + + def test_no_model_guidance_type_on_cli_required(self): + """Legacy scripts should not require --model/--guidance-type on CLI.""" + config = GuidanceConfig.from_cli( + COMMON_ARGS, + model="protenix", + guidance_type="fk_steering", + ) + assert config.model == "protenix" + assert config.guidance_type == "fk_steering" + + @pytest.mark.parametrize("model", ["boltz1", "boltz2", "protenix", "rf3"]) + @pytest.mark.parametrize("guidance_type", ["pure_guidance", "fk_steering"]) + def test_all_model_guidance_combos(self, model, guidance_type): + config = GuidanceConfig.from_cli( + COMMON_ARGS, model=model, guidance_type=guidance_type, + ) + assert config.model == model + assert config.guidance_type == guidance_type + + +class TestFromCliValidation: + """Test error handling for invalid inputs.""" + + def test_missing_protein_errors(self): + argv = [ + "--model", "boltz2", "--guidance-type", "pure_guidance", + "--structure", "test.cif", + "--density", "test.ccp4", + "--resolution", "1.8", + ] + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_invalid_model_errors(self): + argv = ["--model", "invalid_model", "--guidance-type", "pure_guidance"] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_invalid_guidance_type_errors(self): + argv = ["--model", "boltz2", "--guidance-type", "invalid_type"] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_missing_required_structure_errors(self): + argv = [ + "--model", "boltz2", "--guidance-type", "pure_guidance", + "--protein", "1VME", + "--density", "test.ccp4", + "--resolution", "1.8", + ] + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + +class TestFromCliDefaults: + """Test that defaults are applied correctly.""" + + def test_generic_defaults(self): + argv = ["--model", "boltz1", "--guidance-type", "pure_guidance"] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.output_dir == "output" + assert config.partial_diffusion_step == 0 + assert config.loss_order == 2 + assert config.guidance_start == -1 + assert config.augmentation is False + assert config.align_to_input is False + assert config.ensemble_size == 4 + + def test_boolean_flags(self): + argv = [ + "--model", "boltz1", "--guidance-type", "pure_guidance", + "--augmentation", "--align-to-input", "--gradient-normalization", "--em", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.augmentation is True + assert config.align_to_input is True + assert config.gradient_normalization is True + assert config.em is True From 0fb3df59ab685dd68437add3ead43ccebf957e9f Mon Sep 17 00:00:00 2001 From: Abdelsalam Date: Fri, 10 Apr 2026 00:27:16 +0200 Subject: [PATCH 2/9] Fix ruff I001 import sort in test_guidance_cli --- tests/cli/test_guidance_cli.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/cli/test_guidance_cli.py b/tests/cli/test_guidance_cli.py index 8516f524..1952c698 100644 --- a/tests/cli/test_guidance_cli.py +++ b/tests/cli/test_guidance_cli.py @@ -3,7 +3,6 @@ from __future__ import annotations import pytest - from sampleworks.utils.guidance_script_arguments import GuidanceConfig From e2c9112618380d713247ef97d7cf3dfa0eca8dc9 Mon Sep 17 00:00:00 2001 From: Abdelsalam Date: Fri, 10 Apr 2026 00:34:46 +0200 Subject: [PATCH 3/9] Fix ruff format, propagate exit codes, reject conflicting preset args Apply ruff format to guidance_script_arguments.py and test_guidance_cli.py. Propagate run_guidance exit code in all 8 wrapper scripts so failures are visible to CI. Stop registering --model/--guidance-type on the full parser when they are preset by wrapper scripts, preventing silent conflicts. --- scripts/boltz1_fk_steering.py | 7 +- scripts/boltz1_pure_guidance.py | 7 +- scripts/boltz2_fk_steering.py | 7 +- scripts/boltz2_pure_guidance.py | 7 +- scripts/protenix_fk_steering.py | 7 +- scripts/protenix_pure_guidance.py | 7 +- scripts/rf3_fk_steering.py | 7 +- scripts/rf3_pure_guidance.py | 7 +- .../utils/guidance_script_arguments.py | 58 +++++++---- tests/cli/test_guidance_cli.py | 95 ++++++++++++++----- 10 files changed, 150 insertions(+), 59 deletions(-) diff --git a/scripts/boltz1_fk_steering.py b/scripts/boltz1_fk_steering.py index f1fafaff..3ef4316d 100644 --- a/scripts/boltz1_fk_steering.py +++ b/scripts/boltz1_fk_steering.py @@ -1,5 +1,7 @@ """Run FK steering with real-space density reward on the Boltz1 model.""" +import sys + from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance @@ -11,8 +13,9 @@ def main(): getattr(config, "model_checkpoint", None), config.model, ) - run_guidance(config, config.guidance_type, model_wrapper, device) + result = run_guidance(config, config.guidance_type, model_wrapper, device) + return result.exit_code if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/scripts/boltz1_pure_guidance.py b/scripts/boltz1_pure_guidance.py index 3d09cfdf..28aa4de9 100644 --- a/scripts/boltz1_pure_guidance.py +++ b/scripts/boltz1_pure_guidance.py @@ -1,5 +1,7 @@ """Run pure guidance with real-space density reward on the Boltz1 model.""" +import sys + from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance @@ -11,8 +13,9 @@ def main(): getattr(config, "model_checkpoint", None), config.model, ) - run_guidance(config, config.guidance_type, model_wrapper, device) + result = run_guidance(config, config.guidance_type, model_wrapper, device) + return result.exit_code if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/scripts/boltz2_fk_steering.py b/scripts/boltz2_fk_steering.py index 1fad3f47..a65e5eec 100644 --- a/scripts/boltz2_fk_steering.py +++ b/scripts/boltz2_fk_steering.py @@ -1,5 +1,7 @@ """Run FK steering with real-space density reward on the Boltz2 model.""" +import sys + from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance @@ -12,8 +14,9 @@ def main(): config.model, method=getattr(config, "method", None), ) - run_guidance(config, config.guidance_type, model_wrapper, device) + result = run_guidance(config, config.guidance_type, model_wrapper, device) + return result.exit_code if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/scripts/boltz2_pure_guidance.py b/scripts/boltz2_pure_guidance.py index 77c7d9a6..33729516 100644 --- a/scripts/boltz2_pure_guidance.py +++ b/scripts/boltz2_pure_guidance.py @@ -1,5 +1,7 @@ """Run pure guidance with real-space density reward on the Boltz2 model.""" +import sys + from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance @@ -12,8 +14,9 @@ def main(): config.model, method=getattr(config, "method", None), ) - run_guidance(config, config.guidance_type, model_wrapper, device) + result = run_guidance(config, config.guidance_type, model_wrapper, device) + return result.exit_code if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/scripts/protenix_fk_steering.py b/scripts/protenix_fk_steering.py index a8acb8d3..00104c91 100644 --- a/scripts/protenix_fk_steering.py +++ b/scripts/protenix_fk_steering.py @@ -1,5 +1,7 @@ """Run FK steering with real-space density reward on the Protenix model.""" +import sys + from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance @@ -11,8 +13,9 @@ def main(): getattr(config, "model_checkpoint", None), config.model, ) - run_guidance(config, config.guidance_type, model_wrapper, device) + result = run_guidance(config, config.guidance_type, model_wrapper, device) + return result.exit_code if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/scripts/protenix_pure_guidance.py b/scripts/protenix_pure_guidance.py index 7f46fdfb..d64143d0 100644 --- a/scripts/protenix_pure_guidance.py +++ b/scripts/protenix_pure_guidance.py @@ -1,5 +1,7 @@ """Run pure guidance with real-space density reward on the Protenix model.""" +import sys + from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance @@ -11,8 +13,9 @@ def main(): getattr(config, "model_checkpoint", None), config.model, ) - run_guidance(config, config.guidance_type, model_wrapper, device) + result = run_guidance(config, config.guidance_type, model_wrapper, device) + return result.exit_code if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/scripts/rf3_fk_steering.py b/scripts/rf3_fk_steering.py index fe063e0f..c83e8cb2 100644 --- a/scripts/rf3_fk_steering.py +++ b/scripts/rf3_fk_steering.py @@ -1,5 +1,7 @@ """Run FK steering with real-space density reward on the RF3 model.""" +import sys + from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance @@ -11,8 +13,9 @@ def main(): getattr(config, "model_checkpoint", None), config.model, ) - run_guidance(config, config.guidance_type, model_wrapper, device) + result = run_guidance(config, config.guidance_type, model_wrapper, device) + return result.exit_code if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/scripts/rf3_pure_guidance.py b/scripts/rf3_pure_guidance.py index 160fb444..21b41ac8 100644 --- a/scripts/rf3_pure_guidance.py +++ b/scripts/rf3_pure_guidance.py @@ -1,5 +1,7 @@ """Run pure guidance with real-space density reward on the RF3 model.""" +import sys + from sampleworks.utils.guidance_script_arguments import GuidanceConfig from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance @@ -11,8 +13,9 @@ def main(): getattr(config, "model_checkpoint", None), config.model, ) - run_guidance(config, config.guidance_type, model_wrapper, device) + result = run_guidance(config, config.guidance_type, model_wrapper, device) + return result.exit_code if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/src/sampleworks/utils/guidance_script_arguments.py b/src/sampleworks/utils/guidance_script_arguments.py index 9f033c1e..adf5cc87 100644 --- a/src/sampleworks/utils/guidance_script_arguments.py +++ b/src/sampleworks/utils/guidance_script_arguments.py @@ -128,13 +128,21 @@ def validate_model_checkpoint( # from a parsed argparse.Namespace onto a GuidanceConfig instance. _DYNAMIC_ATTRS = [ # pure guidance - "step_size", "step_scaler_type", + "step_size", + "step_scaler_type", # fk steering - "num_particles", "fk_resampling_interval", "fk_lambda", - "num_gd_steps", "guidance_weight", "guidance_interval", + "num_particles", + "fk_resampling_interval", + "fk_lambda", + "num_gd_steps", + "guidance_weight", + "guidance_interval", # model-specific - "model_checkpoint", "method", "msa_path", - "disable_chiral_features", "track_chiral_features", + "model_checkpoint", + "method", + "msa_path", + "disable_chiral_features", + "track_chiral_features", # generic (overridable) "ensemble_size", ] @@ -188,19 +196,27 @@ def from_cli( """ model_choices = [m.value for m in StructurePredictor] guidance_choices = [g.value for g in GuidanceType] + model_preset = model is not None + guidance_preset = guidance_type is not None # -- first pass: resolve model & guidance_type if not pre-set -------- - if model is None or guidance_type is None: + if not model_preset or not guidance_preset: pre = argparse.ArgumentParser(add_help=False) - if model is None: + if not model_preset: pre.add_argument( - "--model", type=str, required=True, choices=model_choices, + "--model", + type=str, + required=True, + choices=model_choices, help="Structure prediction model", ) - if guidance_type is None: + if not guidance_preset: pre.add_argument( - "--guidance-type", type=str, required=True, - choices=guidance_choices, help="Guidance method", + "--guidance-type", + type=str, + required=True, + choices=guidance_choices, + help="Guidance method", ) pre_args, _ = pre.parse_known_args(argv) model = model or pre_args.model @@ -210,18 +226,26 @@ def from_cli( parser = argparse.ArgumentParser( description=f"Run {guidance_type} guidance with {model}", ) - if model is not None: + if not model_preset: parser.add_argument( - "--model", type=str, default=model, choices=model_choices, + "--model", + type=str, + default=model, + choices=model_choices, help="Structure prediction model", ) - if guidance_type is not None: + if not guidance_preset: parser.add_argument( - "--guidance-type", type=str, default=guidance_type, - choices=guidance_choices, help="Guidance method", + "--guidance-type", + type=str, + default=guidance_type, + choices=guidance_choices, + help="Guidance method", ) parser.add_argument( - "--protein", type=str, required=True, + "--protein", + type=str, + required=True, help="Protein identifier (must match naming used in grid search / evaluation)", ) add_generic_args(parser) diff --git a/tests/cli/test_guidance_cli.py b/tests/cli/test_guidance_cli.py index 1952c698..e5ce60bf 100644 --- a/tests/cli/test_guidance_cli.py +++ b/tests/cli/test_guidance_cli.py @@ -7,11 +7,16 @@ COMMON_ARGS = [ - "--protein", "1VME", - "--structure", "test.cif", - "--density", "test.ccp4", - "--resolution", "1.8", - "--output-dir", "output", + "--protein", + "1VME", + "--structure", + "test.cif", + "--density", + "test.ccp4", + "--resolution", + "1.8", + "--output-dir", + "output", ] @@ -38,25 +43,38 @@ def test_rf3_fk_steering(self): def test_model_specific_args_boltz2_method(self): argv = [ - "--model", "boltz2", "--guidance-type", "pure_guidance", - "--method", "MD", + "--model", + "boltz2", + "--guidance-type", + "pure_guidance", + "--method", + "MD", ] + COMMON_ARGS config = GuidanceConfig.from_cli(argv) assert config.method == "MD" def test_model_specific_args_rf3_msa(self): argv = [ - "--model", "rf3", "--guidance-type", "pure_guidance", - "--msa-path", "/data/msa.a3m", + "--model", + "rf3", + "--guidance-type", + "pure_guidance", + "--msa-path", + "/data/msa.a3m", ] + COMMON_ARGS config = GuidanceConfig.from_cli(argv) assert config.msa_path == "/data/msa.a3m" def test_guidance_specific_args_fk(self): argv = [ - "--model", "boltz1", "--guidance-type", "fk_steering", - "--num-particles", "5", - "--fk-lambda", "2.0", + "--model", + "boltz1", + "--guidance-type", + "fk_steering", + "--num-particles", + "5", + "--fk-lambda", + "2.0", ] + COMMON_ARGS config = GuidanceConfig.from_cli(argv) assert config.num_particles == 5 @@ -64,9 +82,14 @@ def test_guidance_specific_args_fk(self): def test_guidance_specific_args_pure(self): argv = [ - "--model", "boltz1", "--guidance-type", "pure_guidance", - "--step-size", "0.5", - "--step-scaler-type", "dataspace", + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--step-size", + "0.5", + "--step-scaler-type", + "dataspace", ] + COMMON_ARGS config = GuidanceConfig.from_cli(argv) assert config.step_size == 0.5 @@ -100,7 +123,9 @@ def test_no_model_guidance_type_on_cli_required(self): @pytest.mark.parametrize("guidance_type", ["pure_guidance", "fk_steering"]) def test_all_model_guidance_combos(self, model, guidance_type): config = GuidanceConfig.from_cli( - COMMON_ARGS, model=model, guidance_type=guidance_type, + COMMON_ARGS, + model=model, + guidance_type=guidance_type, ) assert config.model == model assert config.guidance_type == guidance_type @@ -111,10 +136,16 @@ class TestFromCliValidation: def test_missing_protein_errors(self): argv = [ - "--model", "boltz2", "--guidance-type", "pure_guidance", - "--structure", "test.cif", - "--density", "test.ccp4", - "--resolution", "1.8", + "--model", + "boltz2", + "--guidance-type", + "pure_guidance", + "--structure", + "test.cif", + "--density", + "test.ccp4", + "--resolution", + "1.8", ] with pytest.raises(SystemExit): GuidanceConfig.from_cli(argv) @@ -131,10 +162,16 @@ def test_invalid_guidance_type_errors(self): def test_missing_required_structure_errors(self): argv = [ - "--model", "boltz2", "--guidance-type", "pure_guidance", - "--protein", "1VME", - "--density", "test.ccp4", - "--resolution", "1.8", + "--model", + "boltz2", + "--guidance-type", + "pure_guidance", + "--protein", + "1VME", + "--density", + "test.ccp4", + "--resolution", + "1.8", ] with pytest.raises(SystemExit): GuidanceConfig.from_cli(argv) @@ -156,8 +193,14 @@ def test_generic_defaults(self): def test_boolean_flags(self): argv = [ - "--model", "boltz1", "--guidance-type", "pure_guidance", - "--augmentation", "--align-to-input", "--gradient-normalization", "--em", + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--augmentation", + "--align-to-input", + "--gradient-normalization", + "--em", ] + COMMON_ARGS config = GuidanceConfig.from_cli(argv) assert config.augmentation is True From 9ff2349f47d15b927cd7abb891b931bb9f7f95fa Mon Sep 17 00:00:00 2001 From: Abdelsalam Date: Fri, 10 Apr 2026 18:04:27 +0200 Subject: [PATCH 4/9] Address review: dispatch tables, preset validation, cross-model tests Remove lambda wrappers from _MODEL_ARG_ADDERS and _GUIDANCE_ARG_ADDERS, move dicts below function definitions. Replace __post_init__ if/elif chain with dispatch table lookups. Register --model/--guidance-type as hidden (SUPPRESS) args in preset mode and validate after parsing. Legacy scripts now error clearly on conflicting values instead of silently ignoring them. Add 5 tests for cross-model arg rejection: wrong model-specific args (--method on protenix/rf3, --msa-path on boltz1) and conflicting preset model/guidance_type. --- .../utils/guidance_script_arguments.py | 86 ++++++++++--------- tests/cli/test_guidance_cli.py | 50 +++++++++++ 2 files changed, 94 insertions(+), 42 deletions(-) diff --git a/src/sampleworks/utils/guidance_script_arguments.py b/src/sampleworks/utils/guidance_script_arguments.py index adf5cc87..c5fdc681 100644 --- a/src/sampleworks/utils/guidance_script_arguments.py +++ b/src/sampleworks/utils/guidance_script_arguments.py @@ -112,18 +112,6 @@ def validate_model_checkpoint( return str(checkpoint_path) -_MODEL_ARG_ADDERS: dict[str, Any] = { - "boltz1": lambda p: add_boltz1_specific_args(p), - "boltz2": lambda p: add_boltz2_specific_args(p), - "protenix": lambda p: add_protenix_specific_args(p), - "rf3": lambda p: add_rf3_specific_args(p), -} - -_GUIDANCE_ARG_ADDERS: dict[str, Any] = { - "pure_guidance": lambda p: add_pure_guidance_args(p), - "fk_steering": lambda p: add_fk_steering_args(p), -} - # Attributes set dynamically by add_*_args helpers that should be copied # from a parsed argparse.Namespace onto a GuidanceConfig instance. _DYNAMIC_ATTRS = [ @@ -226,22 +214,20 @@ def from_cli( parser = argparse.ArgumentParser( description=f"Run {guidance_type} guidance with {model}", ) - if not model_preset: - parser.add_argument( - "--model", - type=str, - default=model, - choices=model_choices, - help="Structure prediction model", - ) - if not guidance_preset: - parser.add_argument( - "--guidance-type", - type=str, - default=guidance_type, - choices=guidance_choices, - help="Guidance method", - ) + parser.add_argument( + "--model", + type=str, + default=model, + choices=model_choices, + help=argparse.SUPPRESS if model_preset else "Structure prediction model", + ) + parser.add_argument( + "--guidance-type", + type=str, + default=guidance_type, + choices=guidance_choices, + help=argparse.SUPPRESS if guidance_preset else "Guidance method", + ) parser.add_argument( "--protein", type=str, @@ -254,6 +240,17 @@ def from_cli( args = parser.parse_args(argv) + if model_preset and args.model != model: + parser.error( + f"This script is fixed to --model {model}." + f" Use sampleworks-guidance for other models." + ) + if guidance_preset and args.guidance_type != guidance_type: + parser.error( + f"This script is fixed to --guidance-type {guidance_type}." + f" Use sampleworks-guidance for other guidance types." + ) + config = cls( protein=args.protein, structure=args.structure, @@ -284,22 +281,14 @@ def from_cli( def __post_init__(self): """Set up guidance config for a given model and guidance type""" - if self.guidance_type == GuidanceType.PURE_GUIDANCE: - add_pure_guidance_args(self) - elif self.guidance_type == GuidanceType.FK_STEERING: - add_fk_steering_args(self) - else: + try: + _GUIDANCE_ARG_ADDERS[self.guidance_type](self) + except KeyError: raise ValueError(f"Unknown guidance type: {self.guidance_type}") - if self.model == StructurePredictor.BOLTZ_1: - add_boltz1_specific_args(self) - elif self.model == StructurePredictor.BOLTZ_2: - add_boltz2_specific_args(self) - elif self.model == StructurePredictor.PROTENIX: - add_protenix_specific_args(self) - elif self.model == StructurePredictor.RF3: - add_rf3_specific_args(self) - else: + try: + _MODEL_ARG_ADDERS[self.model](self) + except KeyError: raise ValueError(f"Unknown model type: {self.model}") def populate_config_for_guidance_type(self, job: JobConfig, args: argparse.Namespace): @@ -504,6 +493,19 @@ def add_rf3_specific_args(parser: argparse.ArgumentParser | GuidanceConfig): ) +_MODEL_ARG_ADDERS: dict[str, Any] = { + "boltz1": add_boltz1_specific_args, + "boltz2": add_boltz2_specific_args, + "protenix": add_protenix_specific_args, + "rf3": add_rf3_specific_args, +} + +_GUIDANCE_ARG_ADDERS: dict[str, Any] = { + "pure_guidance": add_pure_guidance_args, + "fk_steering": add_fk_steering_args, +} + + @dataclass class JobConfig: protein: str diff --git a/tests/cli/test_guidance_cli.py b/tests/cli/test_guidance_cli.py index e5ce60bf..581780aa 100644 --- a/tests/cli/test_guidance_cli.py +++ b/tests/cli/test_guidance_cli.py @@ -207,3 +207,53 @@ def test_boolean_flags(self): assert config.align_to_input is True assert config.gradient_normalization is True assert config.em is True + + +class TestCrossModelArgRejection: + """Test that model-specific args are rejected for the wrong model.""" + + def test_method_rejected_for_protenix(self): + argv = [ + "--model", + "protenix", + "--guidance-type", + "pure_guidance", + "--method", + "MD", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_method_rejected_for_rf3(self): + argv = [ + "--model", + "rf3", + "--guidance-type", + "pure_guidance", + "--method", + "MD", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_msa_path_rejected_for_boltz1(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--msa-path", + "/data/msa.a3m", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_conflicting_model_in_preset_mode(self): + argv = ["--model", "boltz2"] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv, model="rf3", guidance_type="fk_steering") + + def test_conflicting_guidance_type_in_preset_mode(self): + argv = ["--guidance-type", "fk_steering"] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv, model="boltz1", guidance_type="pure_guidance") From f2efed2bdf51a483ec1093997c537f8baf385458 Mon Sep 17 00:00:00 2001 From: Abdelsalam Date: Fri, 10 Apr 2026 18:15:31 +0200 Subject: [PATCH 5/9] Add comprehensive CLI test coverage Add 20 new tests covering cross-guidance-type arg rejection, preset value matching, arg passthrough for checkpoint/device/log-path/chiral features, and validation edge cases (empty args, missing resolution, invalid choices). --- tests/cli/test_guidance_cli.py | 246 +++++++++++++++++++++++++++++++++ 1 file changed, 246 insertions(+) diff --git a/tests/cli/test_guidance_cli.py b/tests/cli/test_guidance_cli.py index 581780aa..885b1565 100644 --- a/tests/cli/test_guidance_cli.py +++ b/tests/cli/test_guidance_cli.py @@ -257,3 +257,249 @@ def test_conflicting_guidance_type_in_preset_mode(self): argv = ["--guidance-type", "fk_steering"] + COMMON_ARGS with pytest.raises(SystemExit): GuidanceConfig.from_cli(argv, model="boltz1", guidance_type="pure_guidance") + + def test_fk_args_rejected_for_pure_guidance(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--num-particles", + "5", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_pure_args_rejected_for_fk_steering(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "fk_steering", + "--step-scaler-type", + "dataspace", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_msa_path_rejected_for_boltz2(self): + argv = [ + "--model", + "boltz2", + "--guidance-type", + "pure_guidance", + "--msa-path", + "/data/msa.a3m", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_chiral_flags_rejected_for_boltz1(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--disable-chiral-features", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + +class TestPresetMatchingAccepted: + """Matching preset values should not trigger false-positive rejection.""" + + def test_matching_model_accepted(self): + argv = ["--model", "boltz1"] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv, model="boltz1", guidance_type="fk_steering") + assert config.model == "boltz1" + + def test_matching_guidance_type_accepted(self): + argv = ["--guidance-type", "pure_guidance"] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv, model="boltz2", guidance_type="pure_guidance") + assert config.guidance_type == "pure_guidance" + + +class TestArgPassthrough: + """Test that non-default argument values propagate to GuidanceConfig.""" + + def test_model_checkpoint(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--model-checkpoint", + "/custom/path.ckpt", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.model_checkpoint == "/custom/path.ckpt" + + def test_device(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--device", + "cpu", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.device == "cpu" + + def test_output_dir_override(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--output-dir", + "/custom/output", + "--protein", + "1VME", + "--structure", + "test.cif", + "--density", + "test.ccp4", + "--resolution", + "1.8", + ] + config = GuidanceConfig.from_cli(argv) + assert config.output_dir == "/custom/output" + + def test_partial_diffusion_step(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--partial-diffusion-step", + "50", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.partial_diffusion_step == 50 + + def test_log_path(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--log-path", + "/tmp/run.log", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.log_path == "/tmp/run.log" + + def test_ensemble_size(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "fk_steering", + "--ensemble-size", + "8", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.ensemble_size == 8 + + def test_rf3_chiral_features(self): + argv = [ + "--model", + "rf3", + "--guidance-type", + "pure_guidance", + "--disable-chiral-features", + "--track-chiral-features", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.disable_chiral_features is True + assert config.track_chiral_features is True + + def test_loss_order(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--loss-order", + "1", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.loss_order == 1 + + def test_guidance_start(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--guidance-start", + "10", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.guidance_start == 10 + + +class TestValidationEdgeCases: + """Additional validation edge cases.""" + + def test_no_args_at_all(self): + with pytest.raises(SystemExit): + GuidanceConfig.from_cli([]) + + def test_missing_resolution(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--protein", + "1VME", + "--structure", + "test.cif", + "--density", + "test.ccp4", + ] + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_missing_density(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--protein", + "1VME", + "--structure", + "test.cif", + "--resolution", + "1.8", + ] + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_invalid_loss_order(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--loss-order", + "3", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_invalid_step_scaler_type(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--step-scaler-type", + "invalid", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) From b589947efd2a66964aa155154c02c4974d370a63 Mon Sep 17 00:00:00 2001 From: Abdelsalam Date: Fri, 10 Apr 2026 18:16:47 +0200 Subject: [PATCH 6/9] Update pixi.lock after rebase onto main --- pixi.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pixi.lock b/pixi.lock index 56b57c46..75018f01 100644 --- a/pixi.lock +++ b/pixi.lock @@ -9719,8 +9719,8 @@ packages: timestamp: 1753407970803 - pypi: ./ name: sampleworks - version: 0.4.3 - sha256: 86ffa999bf980b906be83b7e94152b34e4e617dc90cb3d37a2d16baa4abce7ba + version: 0.4.4 + sha256: 8e5bd913ac135e29d9e6008401c2bc54f50c939ffaf090ce3283f6da46035df1 requires_dist: - atomworks[ml]==2.1.1 - python-dotenv From 898d3c716b6574214756351780fe96fa38a65823 Mon Sep 17 00:00:00 2001 From: Abdelsalam Date: Fri, 10 Apr 2026 18:21:55 +0200 Subject: [PATCH 7/9] Remove 8 model-specific scripts in favor of unified CLI Delete boltz1/boltz2/protenix/rf3 x pure_guidance/fk_steering wrapper scripts. The sampleworks-guidance CLI replaces all of them. Update README and AGENTS.md to remove references to deleted scripts. --- AGENTS.md | 24 +++++++----------------- README.md | 1 - scripts/boltz1_fk_steering.py | 21 --------------------- scripts/boltz1_pure_guidance.py | 21 --------------------- scripts/boltz2_fk_steering.py | 22 ---------------------- scripts/boltz2_pure_guidance.py | 22 ---------------------- scripts/protenix_fk_steering.py | 21 --------------------- scripts/protenix_pure_guidance.py | 21 --------------------- scripts/rf3_fk_steering.py | 21 --------------------- scripts/rf3_pure_guidance.py | 21 --------------------- 10 files changed, 7 insertions(+), 188 deletions(-) delete mode 100644 scripts/boltz1_fk_steering.py delete mode 100644 scripts/boltz1_pure_guidance.py delete mode 100644 scripts/boltz2_fk_steering.py delete mode 100644 scripts/boltz2_pure_guidance.py delete mode 100644 scripts/protenix_fk_steering.py delete mode 100644 scripts/protenix_pure_guidance.py delete mode 100644 scripts/rf3_fk_steering.py delete mode 100644 scripts/rf3_pure_guidance.py diff --git a/AGENTS.md b/AGENTS.md index a017e274..b56d51d6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -159,25 +159,13 @@ src/sampleworks/ ## Running Guidance Pipelines -The `scripts/` directory contains ready-to-run guidance scripts for each supported model and trajectory scaler: - -``` -scripts/ -├── boltz1_pure_guidance.py # Boltz-1 + pure guidance -├── boltz2_pure_guidance.py # Boltz-2 + pure guidance -├── boltz2_fk_steering.py # Boltz-2 + Feynman-Kaç steering -├── protenix_pure_guidance.py # Protenix + pure guidance -├── protenix_fk_steering.py # Protenix + FK steering -├── rf3_pure_guidance.py # RF3 + pure guidance -├── rf3_fk_steering.py # RF3 + FK steering -├── run_guidance_pipeline.py # Generic pipeline runner -└── eval/ # Evaluation scripts (RSCC, lDDT, clashscore) -``` - -Each script follows the same pattern: load a model wrapper, then call `run_guidance()` from `utils/guidance_script_utils.py`. Example invocation: +Use the unified `sampleworks-guidance` CLI to run guidance with any supported model and trajectory scaler: ```bash -pixi run -e boltz python scripts/boltz2_pure_guidance.py \ +pixi run -e boltz sampleworks-guidance \ + --model boltz2 \ + --guidance-type pure_guidance \ + --protein 1VME \ --model-checkpoint ~/.boltz/boltz2_conf.ckpt \ --output-dir output/boltz2_pure_guidance \ --structure tests/resources/1vme/1vme_final_carved_edited_0.5occA_0.5occB.cif \ @@ -188,6 +176,8 @@ pixi run -e boltz python scripts/boltz2_pure_guidance.py \ --augmentation --align-to-input ``` +Run `sampleworks-guidance --model --guidance-type --help` to see all available options. + The `run_guidance()` function in `utils/guidance_script_utils.py` is the central orchestrator. It wires together the model wrapper, sampler (`AF3EDMSampler`), step scaler (`DataSpaceDPSScaler` or `NoiseSpaceDPSScaler`), trajectory scaler (`PureGuidance` or `FKSteering`), and reward function. When adding a new model or guidance strategy, this is the best reference for how components compose in practice. ## Development Environment diff --git a/README.md b/README.md index cc36a862..0b123355 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,6 @@ Output files appear in `output/boltz2_pure_guidance/`: `refined.cif` (final ense Model-specific arguments (e.g. `--method` for boltz2, `--msa-path` for rf3) and guidance-type-specific arguments (e.g. `--num-particles` for fk_steering) are included automatically. Run `sampleworks-guidance --model --guidance-type --help` to see all available options. -The individual scripts in `scripts/` (e.g. `boltz2_pure_guidance.py`) still work as shortcuts with `--model` and `--guidance-type` pre-set. ## Grid Search diff --git a/scripts/boltz1_fk_steering.py b/scripts/boltz1_fk_steering.py deleted file mode 100644 index 3ef4316d..00000000 --- a/scripts/boltz1_fk_steering.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Run FK steering with real-space density reward on the Boltz1 model.""" - -import sys - -from sampleworks.utils.guidance_script_arguments import GuidanceConfig -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(): - config = GuidanceConfig.from_cli(model="boltz1", guidance_type="fk_steering") - device, model_wrapper = get_model_and_device( - config.device, - getattr(config, "model_checkpoint", None), - config.model, - ) - result = run_guidance(config, config.guidance_type, model_wrapper, device) - return result.exit_code - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/boltz1_pure_guidance.py b/scripts/boltz1_pure_guidance.py deleted file mode 100644 index 28aa4de9..00000000 --- a/scripts/boltz1_pure_guidance.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Run pure guidance with real-space density reward on the Boltz1 model.""" - -import sys - -from sampleworks.utils.guidance_script_arguments import GuidanceConfig -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(): - config = GuidanceConfig.from_cli(model="boltz1", guidance_type="pure_guidance") - device, model_wrapper = get_model_and_device( - config.device, - getattr(config, "model_checkpoint", None), - config.model, - ) - result = run_guidance(config, config.guidance_type, model_wrapper, device) - return result.exit_code - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/boltz2_fk_steering.py b/scripts/boltz2_fk_steering.py deleted file mode 100644 index a65e5eec..00000000 --- a/scripts/boltz2_fk_steering.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Run FK steering with real-space density reward on the Boltz2 model.""" - -import sys - -from sampleworks.utils.guidance_script_arguments import GuidanceConfig -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(): - config = GuidanceConfig.from_cli(model="boltz2", guidance_type="fk_steering") - device, model_wrapper = get_model_and_device( - config.device, - getattr(config, "model_checkpoint", None), - config.model, - method=getattr(config, "method", None), - ) - result = run_guidance(config, config.guidance_type, model_wrapper, device) - return result.exit_code - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/boltz2_pure_guidance.py b/scripts/boltz2_pure_guidance.py deleted file mode 100644 index 33729516..00000000 --- a/scripts/boltz2_pure_guidance.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Run pure guidance with real-space density reward on the Boltz2 model.""" - -import sys - -from sampleworks.utils.guidance_script_arguments import GuidanceConfig -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(): - config = GuidanceConfig.from_cli(model="boltz2", guidance_type="pure_guidance") - device, model_wrapper = get_model_and_device( - config.device, - getattr(config, "model_checkpoint", None), - config.model, - method=getattr(config, "method", None), - ) - result = run_guidance(config, config.guidance_type, model_wrapper, device) - return result.exit_code - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/protenix_fk_steering.py b/scripts/protenix_fk_steering.py deleted file mode 100644 index 00104c91..00000000 --- a/scripts/protenix_fk_steering.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Run FK steering with real-space density reward on the Protenix model.""" - -import sys - -from sampleworks.utils.guidance_script_arguments import GuidanceConfig -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(): - config = GuidanceConfig.from_cli(model="protenix", guidance_type="fk_steering") - device, model_wrapper = get_model_and_device( - config.device, - getattr(config, "model_checkpoint", None), - config.model, - ) - result = run_guidance(config, config.guidance_type, model_wrapper, device) - return result.exit_code - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/protenix_pure_guidance.py b/scripts/protenix_pure_guidance.py deleted file mode 100644 index d64143d0..00000000 --- a/scripts/protenix_pure_guidance.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Run pure guidance with real-space density reward on the Protenix model.""" - -import sys - -from sampleworks.utils.guidance_script_arguments import GuidanceConfig -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(): - config = GuidanceConfig.from_cli(model="protenix", guidance_type="pure_guidance") - device, model_wrapper = get_model_and_device( - config.device, - getattr(config, "model_checkpoint", None), - config.model, - ) - result = run_guidance(config, config.guidance_type, model_wrapper, device) - return result.exit_code - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/rf3_fk_steering.py b/scripts/rf3_fk_steering.py deleted file mode 100644 index c83e8cb2..00000000 --- a/scripts/rf3_fk_steering.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Run FK steering with real-space density reward on the RF3 model.""" - -import sys - -from sampleworks.utils.guidance_script_arguments import GuidanceConfig -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(): - config = GuidanceConfig.from_cli(model="rf3", guidance_type="fk_steering") - device, model_wrapper = get_model_and_device( - config.device, - getattr(config, "model_checkpoint", None), - config.model, - ) - result = run_guidance(config, config.guidance_type, model_wrapper, device) - return result.exit_code - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/rf3_pure_guidance.py b/scripts/rf3_pure_guidance.py deleted file mode 100644 index 21b41ac8..00000000 --- a/scripts/rf3_pure_guidance.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Run pure guidance with real-space density reward on the RF3 model.""" - -import sys - -from sampleworks.utils.guidance_script_arguments import GuidanceConfig -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(): - config = GuidanceConfig.from_cli(model="rf3", guidance_type="pure_guidance") - device, model_wrapper = get_model_and_device( - config.device, - getattr(config, "model_checkpoint", None), - config.model, - ) - result = run_guidance(config, config.guidance_type, model_wrapper, device) - return result.exit_code - - -if __name__ == "__main__": - sys.exit(main()) From d7888e30dddabea203f67eedef25aec06fe3c3db Mon Sep 17 00:00:00 2001 From: Abdelsalam Date: Fri, 10 Apr 2026 18:56:39 +0200 Subject: [PATCH 8/9] Validate preset args early, add CLI entrypoint smoke test Validate preset model/guidance_type against allowed choices before dispatch map lookup, raising ValueError instead of raw KeyError. Add smoke test that invokes sampleworks.cli.guidance --help as a subprocess to verify the console_scripts wiring and import chain. --- .../utils/guidance_script_arguments.py | 5 +++ tests/cli/test_guidance_cli.py | 33 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/src/sampleworks/utils/guidance_script_arguments.py b/src/sampleworks/utils/guidance_script_arguments.py index c5fdc681..de08ffc2 100644 --- a/src/sampleworks/utils/guidance_script_arguments.py +++ b/src/sampleworks/utils/guidance_script_arguments.py @@ -187,6 +187,11 @@ def from_cli( model_preset = model is not None guidance_preset = guidance_type is not None + if model_preset and model not in model_choices: + raise ValueError(f"Unknown model type: {model}") + if guidance_preset and guidance_type not in guidance_choices: + raise ValueError(f"Unknown guidance type: {guidance_type}") + # -- first pass: resolve model & guidance_type if not pre-set -------- if not model_preset or not guidance_preset: pre = argparse.ArgumentParser(add_help=False) diff --git a/tests/cli/test_guidance_cli.py b/tests/cli/test_guidance_cli.py index 885b1565..4783b073 100644 --- a/tests/cli/test_guidance_cli.py +++ b/tests/cli/test_guidance_cli.py @@ -2,6 +2,9 @@ from __future__ import annotations +import subprocess +import sys + import pytest from sampleworks.utils.guidance_script_arguments import GuidanceConfig @@ -503,3 +506,33 @@ def test_invalid_step_scaler_type(self): ] + COMMON_ARGS with pytest.raises(SystemExit): GuidanceConfig.from_cli(argv) + + +class TestEntrypointSmoke: + """Smoke test the actual CLI entrypoint.""" + + def test_help_exits_zero(self): + result = subprocess.run( + [ + sys.executable, + "-m", + "sampleworks.cli.guidance", + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--help", + ], + capture_output=True, + ) + assert result.returncode == 0 + assert b"--protein" in result.stdout + assert b"--structure" in result.stdout + + def test_invalid_preset_model_raises(self): + with pytest.raises(ValueError, match="Unknown model type"): + GuidanceConfig.from_cli(COMMON_ARGS, model="typo", guidance_type="fk_steering") + + def test_invalid_preset_guidance_type_raises(self): + with pytest.raises(ValueError, match="Unknown guidance type"): + GuidanceConfig.from_cli(COMMON_ARGS, model="boltz1", guidance_type="typo") From 9c0d1bea3d7a1e9ea1c5b68b597f4452ddea4643 Mon Sep 17 00:00:00 2001 From: Abdelsalam Date: Fri, 10 Apr 2026 21:18:38 +0200 Subject: [PATCH 9/9] Expose recycling_steps and num_diffusion_steps to CLI Add --recycling-steps and --num-diffusion-steps to add_generic_args() and _DYNAMIC_ATTRS so they are available via sampleworks-guidance. These were added in #205 but not wired to the CLI. --- .../utils/guidance_script_arguments.py | 14 ++++++++ tests/cli/test_guidance_cli.py | 34 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/src/sampleworks/utils/guidance_script_arguments.py b/src/sampleworks/utils/guidance_script_arguments.py index de08ffc2..b5ac4621 100644 --- a/src/sampleworks/utils/guidance_script_arguments.py +++ b/src/sampleworks/utils/guidance_script_arguments.py @@ -133,6 +133,8 @@ def validate_model_checkpoint( "track_chiral_features", # generic (overridable) "ensemble_size", + "recycling_steps", + "num_diffusion_steps", ] @@ -381,6 +383,18 @@ def add_generic_args(parser: argparse.ArgumentParser | GuidanceConfig): default=4, help="Ensemble size to generate (per particle for FK-steering)", ) + parser.add_argument( + "--recycling-steps", + type=int, + default=None, + help="Number of recycling steps for the model (default: model-specific)", + ) + parser.add_argument( + "--num-diffusion-steps", + type=int, + default=200, + help="Number of diffusion denoising steps (default: 200)", + ) ###################### diff --git a/tests/cli/test_guidance_cli.py b/tests/cli/test_guidance_cli.py index 4783b073..3d4b496e 100644 --- a/tests/cli/test_guidance_cli.py +++ b/tests/cli/test_guidance_cli.py @@ -443,6 +443,40 @@ def test_guidance_start(self): config = GuidanceConfig.from_cli(argv) assert config.guidance_start == 10 + def test_recycling_steps(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--recycling-steps", + "3", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.recycling_steps == 3 + + def test_num_diffusion_steps(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--num-diffusion-steps", + "100", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.num_diffusion_steps == 100 + + def test_recycling_steps_default_none(self): + argv = ["--model", "boltz1", "--guidance-type", "pure_guidance"] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.recycling_steps is None + + def test_num_diffusion_steps_default(self): + argv = ["--model", "boltz1", "--guidance-type", "pure_guidance"] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.num_diffusion_steps == 200 + class TestValidationEdgeCases: """Additional validation edge cases."""