diff --git a/AGENTS.md b/AGENTS.md index a017e274..b56d51d6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -159,25 +159,13 @@ src/sampleworks/ ## Running Guidance Pipelines -The `scripts/` directory contains ready-to-run guidance scripts for each supported model and trajectory scaler: - -``` -scripts/ -├── boltz1_pure_guidance.py # Boltz-1 + pure guidance -├── boltz2_pure_guidance.py # Boltz-2 + pure guidance -├── boltz2_fk_steering.py # Boltz-2 + Feynman-Kaç steering -├── protenix_pure_guidance.py # Protenix + pure guidance -├── protenix_fk_steering.py # Protenix + FK steering -├── rf3_pure_guidance.py # RF3 + pure guidance -├── rf3_fk_steering.py # RF3 + FK steering -├── run_guidance_pipeline.py # Generic pipeline runner -└── eval/ # Evaluation scripts (RSCC, lDDT, clashscore) -``` - -Each script follows the same pattern: load a model wrapper, then call `run_guidance()` from `utils/guidance_script_utils.py`. Example invocation: +Use the unified `sampleworks-guidance` CLI to run guidance with any supported model and trajectory scaler: ```bash -pixi run -e boltz python scripts/boltz2_pure_guidance.py \ +pixi run -e boltz sampleworks-guidance \ + --model boltz2 \ + --guidance-type pure_guidance \ + --protein 1VME \ --model-checkpoint ~/.boltz/boltz2_conf.ckpt \ --output-dir output/boltz2_pure_guidance \ --structure tests/resources/1vme/1vme_final_carved_edited_0.5occA_0.5occB.cif \ @@ -188,6 +176,8 @@ pixi run -e boltz python scripts/boltz2_pure_guidance.py \ --augmentation --align-to-input ``` +Run `sampleworks-guidance --model --guidance-type --help` to see all available options. + The `run_guidance()` function in `utils/guidance_script_utils.py` is the central orchestrator. It wires together the model wrapper, sampler (`AF3EDMSampler`), step scaler (`DataSpaceDPSScaler` or `NoiseSpaceDPSScaler`), trajectory scaler (`PureGuidance` or `FKSteering`), and reward function. When adding a new model or guidance strategy, this is the best reference for how components compose in practice. ## Development Environment diff --git a/README.md b/README.md index 7087c139..0b123355 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,10 @@ download_boltz2(cache) Run Boltz-2 pure guidance on the included 1VME example: ```bash -pixi run -e boltz python scripts/boltz2_pure_guidance.py \ +pixi run -e boltz sampleworks-guidance \ + --model boltz2 \ + --guidance-type pure_guidance \ + --protein 1VME \ --model-checkpoint ~/.boltz/boltz2_conf.ckpt \ --structure tests/resources/1vme/1vme_final_carved_edited_0.5occA_0.5occB.cif \ --density tests/resources/1vme/1vme_final_carved_edited_0.5occA_0.5occB_1.80A.ccp4 \ @@ -78,7 +81,25 @@ pixi run -e boltz python scripts/boltz2_pure_guidance.py \ --align-to-input ``` -Output files appear in `output/boltz2_pure_guidance/`: `refined.cif` (final ensemble), `losses.txt`, `trajectory/`, `run.log`. See [`scripts/README.md`](scripts/README.md) for all scripts and arguments. +Output files appear in `output/boltz2_pure_guidance/`: `refined.cif` (final ensemble), `losses.txt`, `trajectory/`, `run.log`. + +### CLI reference + +`sampleworks-guidance` is the unified command-line interface for running guidance on a single structure. + +**Required arguments:** + +| Argument | Description | +|---|---| +| `--model` | `boltz1`, `boltz2`, `protenix`, or `rf3` | +| `--guidance-type` | `pure_guidance` or `fk_steering` | +| `--protein` | Protein identifier (should match naming used in grid search / evaluation) | +| `--structure` | Path to input structure file (CIF) | +| `--density` | Path to density map (CCP4/MRC/MAP) | +| `--resolution` | Map resolution in Angstroms | + +Model-specific arguments (e.g. `--method` for boltz2, `--msa-path` for rf3) and guidance-type-specific arguments (e.g. `--num-particles` for fk_steering) are included automatically. Run `sampleworks-guidance --model --guidance-type --help` to see all available options. + ## Grid Search diff --git a/pixi.lock b/pixi.lock index 6e3917cb..d2a57ff9 100644 --- a/pixi.lock +++ b/pixi.lock @@ -9720,7 +9720,7 @@ packages: - pypi: ./ name: sampleworks version: 0.5.0 - sha256: cb4cf99c106ad969bd73884c85871ee5e904e995536cff6ce0c4ccb98255a582 + sha256: 0990d4d0a4685b902656605abb1c0803231b65d12bb405c23e797f78a102ec3b requires_dist: - atomworks[ml]==2.1.1 - python-dotenv diff --git a/pyproject.toml b/pyproject.toml index c6200cf6..68409387 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,9 @@ name = "sampleworks" requires-python = ">= 3.11, <3.14" version = "0.5.0" +[project.scripts] +sampleworks-guidance = "sampleworks.cli.guidance:main" + [tool.hatch.metadata] allow-direct-references = true diff --git a/scripts/boltz1_fk_steering.py b/scripts/boltz1_fk_steering.py deleted file mode 100644 index 3a120aa9..00000000 --- a/scripts/boltz1_fk_steering.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Run FK steering with real-space density reward on the Boltz1 model. -""" - -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_boltz1_fk_steering_args -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(args): - device, model_wrapper = get_model_and_device( - "", args.model_checkpoint, StructurePredictor.BOLTZ_1 - ) - run_guidance(args, GuidanceType.FK_STEERING, model_wrapper, device) - - -if __name__ == "__main__": - guidance_args = parse_boltz1_fk_steering_args() - main(guidance_args) diff --git a/scripts/boltz1_pure_guidance.py b/scripts/boltz1_pure_guidance.py deleted file mode 100644 index c4d48693..00000000 --- a/scripts/boltz1_pure_guidance.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Run pure guidance with real-space density reward on the Boltz1 model. -""" - -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_boltz1_pure_guidance_args -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(args): - device, model_wrapper = get_model_and_device( - "", args.model_checkpoint, StructurePredictor.BOLTZ_1 - ) - run_guidance(args, GuidanceType.PURE_GUIDANCE, model_wrapper, device) - - -if __name__ == "__main__": - guidance_args = parse_boltz1_pure_guidance_args() - main(guidance_args) diff --git a/scripts/boltz2_fk_steering.py b/scripts/boltz2_fk_steering.py deleted file mode 100644 index 13e768a0..00000000 --- a/scripts/boltz2_fk_steering.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Run FK steering with real-space density reward on the Boltz2 model. -""" - -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_boltz2_fk_steering_args -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(args): - device, model_wrapper = get_model_and_device( - "", args.model_checkpoint, StructurePredictor.BOLTZ_2, method=args.method - ) - run_guidance(args, GuidanceType.FK_STEERING, model_wrapper, device) - - -if __name__ == "__main__": - guidance_args = parse_boltz2_fk_steering_args() - main(guidance_args) diff --git a/scripts/boltz2_pure_guidance.py b/scripts/boltz2_pure_guidance.py deleted file mode 100644 index f5dc2550..00000000 --- a/scripts/boltz2_pure_guidance.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Run pure guidance with real-space density reward on the Boltz2 model. -""" - -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_boltz2_pure_guidance_args -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(args): - device, model_wrapper = get_model_and_device( - "", args.model_checkpoint, StructurePredictor.BOLTZ_2, method=args.method - ) - run_guidance(args, GuidanceType.PURE_GUIDANCE, model_wrapper, device) - - -if __name__ == "__main__": - guidance_args = parse_boltz2_pure_guidance_args() - main(guidance_args) diff --git a/scripts/protenix_fk_steering.py b/scripts/protenix_fk_steering.py deleted file mode 100644 index e2ab10ec..00000000 --- a/scripts/protenix_fk_steering.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Run FK steering with real-space density reward on the Protenix model. -""" - -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_protenix_fk_steering_args -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(args): - device, model_wrapper = get_model_and_device( - "", args.model_checkpoint, StructurePredictor.PROTENIX - ) - run_guidance(args, GuidanceType.FK_STEERING, model_wrapper, device) - - -if __name__ == "__main__": - guidance_args = parse_protenix_fk_steering_args() - main(guidance_args) diff --git a/scripts/protenix_pure_guidance.py b/scripts/protenix_pure_guidance.py deleted file mode 100644 index 5aaef3b4..00000000 --- a/scripts/protenix_pure_guidance.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Run pure guidance with real-space density reward on the Protenix model. -""" - -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_protenix_pure_guidance_args -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(args): - device, model_wrapper = get_model_and_device( - "", args.model_checkpoint, StructurePredictor.PROTENIX - ) - run_guidance(args, GuidanceType.PURE_GUIDANCE, model_wrapper, device) - - -if __name__ == "__main__": - guidance_args = parse_protenix_pure_guidance_args() - main(guidance_args) diff --git a/scripts/rf3_fk_steering.py b/scripts/rf3_fk_steering.py deleted file mode 100644 index 595e93f9..00000000 --- a/scripts/rf3_fk_steering.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Run FK steering with real-space density reward on the RF3 model. -""" - -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_rf3_fk_steering_args -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(args): - device, model_wrapper = get_model_and_device("", args.model_checkpoint, StructurePredictor.RF3) - run_guidance(args, GuidanceType.FK_STEERING, model_wrapper, device) - - -if __name__ == "__main__": - guidance_args = parse_rf3_fk_steering_args() - main(guidance_args) diff --git a/scripts/rf3_pure_guidance.py b/scripts/rf3_pure_guidance.py deleted file mode 100644 index f89f4e0e..00000000 --- a/scripts/rf3_pure_guidance.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Run pure guidance with real-space density reward on the RF3 model. -""" - -from sampleworks.utils.guidance_constants import GuidanceType, StructurePredictor -from sampleworks.utils.guidance_script_arguments import parse_rf3_pure_guidance_args -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance - - -def main(args): - device, model_wrapper = get_model_and_device("", args.model_checkpoint, StructurePredictor.RF3) - run_guidance(args, GuidanceType.PURE_GUIDANCE, model_wrapper, device) - - -if __name__ == "__main__": - guidance_args = parse_rf3_pure_guidance_args() - main(guidance_args) diff --git a/src/sampleworks/cli/__init__.py b/src/sampleworks/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/sampleworks/cli/guidance.py b/src/sampleworks/cli/guidance.py new file mode 100644 index 00000000..5f71fe22 --- /dev/null +++ b/src/sampleworks/cli/guidance.py @@ -0,0 +1,24 @@ +"""Unified CLI for running diffusion guidance with sampleworks.""" + +from __future__ import annotations + +import sys + +from sampleworks.utils.guidance_script_arguments import GuidanceConfig +from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance + + +def main(argv: list[str] | None = None) -> int: + config = GuidanceConfig.from_cli(argv) + device, model_wrapper = get_model_and_device( + config.device, + getattr(config, "model_checkpoint", None), + config.model, + method=getattr(config, "method", None), + ) + result = run_guidance(config, config.guidance_type, model_wrapper, device) + return result.exit_code + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/sampleworks/utils/guidance_script_arguments.py b/src/sampleworks/utils/guidance_script_arguments.py index e563bc55..b5ac4621 100644 --- a/src/sampleworks/utils/guidance_script_arguments.py +++ b/src/sampleworks/utils/guidance_script_arguments.py @@ -112,9 +112,34 @@ def validate_model_checkpoint( return str(checkpoint_path) +# Attributes set dynamically by add_*_args helpers that should be copied +# from a parsed argparse.Namespace onto a GuidanceConfig instance. +_DYNAMIC_ATTRS = [ + # pure guidance + "step_size", + "step_scaler_type", + # fk steering + "num_particles", + "fk_resampling_interval", + "fk_lambda", + "num_gd_steps", + "guidance_weight", + "guidance_interval", + # model-specific + "model_checkpoint", + "method", + "msa_path", + "disable_chiral_features", + "track_chiral_features", + # generic (overridable) + "ensemble_size", + "recycling_steps", + "num_diffusion_steps", +] + + @dataclass class GuidanceConfig: - # TODO add a class method to set this up completely from args and job config. """ Class to hold guidance config arguments, compatible with argparse, but which also can do some basic validation. @@ -145,24 +170,132 @@ def add_argument(self, name: str, default: Any = None, **kwargs): """Add an argument to the guidance config, in a form compatible with argparse""" setattr(self, name.lstrip("-").replace("-", "_"), default) + @classmethod + def from_cli( + cls, + argv: list[str] | None = None, + model: str | None = None, + guidance_type: str | None = None, + ) -> GuidanceConfig: + """Parse CLI arguments and return a fully populated GuidanceConfig. + + When *model* and *guidance_type* are provided (e.g. from legacy + scripts), they are used directly and ``--model`` / ``--guidance-type`` + are not required on the command line. Otherwise they are parsed as + required CLI arguments. + """ + model_choices = [m.value for m in StructurePredictor] + guidance_choices = [g.value for g in GuidanceType] + model_preset = model is not None + guidance_preset = guidance_type is not None + + if model_preset and model not in model_choices: + raise ValueError(f"Unknown model type: {model}") + if guidance_preset and guidance_type not in guidance_choices: + raise ValueError(f"Unknown guidance type: {guidance_type}") + + # -- first pass: resolve model & guidance_type if not pre-set -------- + if not model_preset or not guidance_preset: + pre = argparse.ArgumentParser(add_help=False) + if not model_preset: + pre.add_argument( + "--model", + type=str, + required=True, + choices=model_choices, + help="Structure prediction model", + ) + if not guidance_preset: + pre.add_argument( + "--guidance-type", + type=str, + required=True, + choices=guidance_choices, + help="Guidance method", + ) + pre_args, _ = pre.parse_known_args(argv) + model = model or pre_args.model + guidance_type = guidance_type or pre_args.guidance_type + + # -- full parser ----------------------------------------------------- + parser = argparse.ArgumentParser( + description=f"Run {guidance_type} guidance with {model}", + ) + parser.add_argument( + "--model", + type=str, + default=model, + choices=model_choices, + help=argparse.SUPPRESS if model_preset else "Structure prediction model", + ) + parser.add_argument( + "--guidance-type", + type=str, + default=guidance_type, + choices=guidance_choices, + help=argparse.SUPPRESS if guidance_preset else "Guidance method", + ) + parser.add_argument( + "--protein", + type=str, + required=True, + help="Protein identifier (must match naming used in grid search / evaluation)", + ) + add_generic_args(parser) + _MODEL_ARG_ADDERS[model](parser) + _GUIDANCE_ARG_ADDERS[guidance_type](parser) + + args = parser.parse_args(argv) + + if model_preset and args.model != model: + parser.error( + f"This script is fixed to --model {model}." + f" Use sampleworks-guidance for other models." + ) + if guidance_preset and args.guidance_type != guidance_type: + parser.error( + f"This script is fixed to --guidance-type {guidance_type}." + f" Use sampleworks-guidance for other guidance types." + ) + + config = cls( + protein=args.protein, + structure=args.structure, + density=args.density, + model=model, + guidance_type=guidance_type, + log_path=getattr(args, "log_path", None) or "", + output_dir=args.output_dir, + partial_diffusion_step=args.partial_diffusion_step, + loss_order=args.loss_order, + resolution=args.resolution, + device=getattr(args, "device", "") or "", + gradient_normalization=args.gradient_normalization, + em=args.em, + guidance_start=args.guidance_start, + augmentation=args.augmentation, + align_to_input=args.align_to_input, + ) + + # __post_init__ already set defaults for model/guidance-specific + # attrs; override with any explicit CLI values. + for attr in _DYNAMIC_ATTRS: + val = getattr(args, attr, None) + if val is not None: + setattr(config, attr, val) + + return config + def __post_init__(self): """Set up guidance config for a given model and guidance type""" - if self.guidance_type == GuidanceType.PURE_GUIDANCE: - add_pure_guidance_args(self) - elif self.guidance_type == GuidanceType.FK_STEERING: - add_fk_steering_args(self) - else: + try: + _GUIDANCE_ARG_ADDERS[self.guidance_type](self) + except KeyError: raise ValueError(f"Unknown guidance type: {self.guidance_type}") - if self.model == StructurePredictor.BOLTZ_1: - add_boltz1_specific_args(self) - elif self.model == StructurePredictor.BOLTZ_2: - add_boltz2_specific_args(self) - elif self.model == StructurePredictor.PROTENIX: - add_protenix_specific_args(self) - elif self.model == StructurePredictor.RF3: - add_rf3_specific_args(self) - else: + try: + _MODEL_ARG_ADDERS[self.model](self) + except KeyError: raise ValueError(f"Unknown model type: {self.model}") def populate_config_for_guidance_type(self, job: JobConfig, args: argparse.Namespace): @@ -250,6 +383,18 @@ def add_generic_args(parser: argparse.ArgumentParser | GuidanceConfig): default=4, help="Ensemble size to generate (per particle for FK-steering)", ) + parser.add_argument( + "--recycling-steps", + type=int, + default=None, + help="Number of recycling steps for the model (default: model-specific)", + ) + parser.add_argument( + "--num-diffusion-steps", + type=int, + default=200, + help="Number of diffusion denoising steps (default: 200)", + ) ###################### @@ -367,90 +512,17 @@ def add_rf3_specific_args(parser: argparse.ArgumentParser | GuidanceConfig): ) -############## -# Use these methods to parse arguments in scripts which load the model themselves. -############## -def parse_boltz2_pure_guidance_args(): - parser = argparse.ArgumentParser( - description="Pure guidance refinement with Boltz-2 and real-space density" - ) - add_generic_args(parser) - add_boltz2_specific_args(parser) - add_pure_guidance_args(parser) - - return parser.parse_args() - - -def parse_boltz1_pure_guidance_args(): - parser = argparse.ArgumentParser( - description="Pure guidance refinement with Boltz-1 and real-space density" - ) - add_generic_args(parser) - add_boltz1_specific_args(parser) - add_pure_guidance_args(parser) - - return parser.parse_args() - - -def parse_protenix_pure_guidance_args(): - parser = argparse.ArgumentParser( - description="Pure guidance refinement with Protenix and real-space density" - ) - add_generic_args(parser) - add_protenix_specific_args(parser) - add_pure_guidance_args(parser) - - return parser.parse_args() - - -def parse_protenix_fk_steering_args(): - parser = argparse.ArgumentParser( - description="FK steering refinement with Protenix and real-space density" - ) - add_protenix_specific_args(parser) - add_generic_args(parser) - add_fk_steering_args(parser) - return parser.parse_args() - - -def parse_boltz2_fk_steering_args(): - parser = argparse.ArgumentParser( - description="FK steering refinement with Boltz-2 and real-space density" - ) - add_boltz2_specific_args(parser) - add_generic_args(parser) - add_fk_steering_args(parser) - return parser.parse_args() - - -def parse_boltz1_fk_steering_args(): - parser = argparse.ArgumentParser( - description="FK steering refinement with Boltz-1 and real-space density" - ) - add_boltz1_specific_args(parser) - add_generic_args(parser) - add_fk_steering_args(parser) - return parser.parse_args() - - -def parse_rf3_pure_guidance_args(): - parser = argparse.ArgumentParser( - description="Pure guidance refinement with RF3 and real-space density" - ) - add_generic_args(parser) - add_rf3_specific_args(parser) - add_pure_guidance_args(parser) - return parser.parse_args() - +_MODEL_ARG_ADDERS: dict[str, Any] = { + "boltz1": add_boltz1_specific_args, + "boltz2": add_boltz2_specific_args, + "protenix": add_protenix_specific_args, + "rf3": add_rf3_specific_args, +} -def parse_rf3_fk_steering_args(): - parser = argparse.ArgumentParser( - description="FK steering refinement with RF3 and real-space density" - ) - add_generic_args(parser) - add_rf3_specific_args(parser) - add_fk_steering_args(parser) - return parser.parse_args() +_GUIDANCE_ARG_ADDERS: dict[str, Any] = { + "pure_guidance": add_pure_guidance_args, + "fk_steering": add_fk_steering_args, +} @dataclass diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/cli/test_guidance_cli.py b/tests/cli/test_guidance_cli.py new file mode 100644 index 00000000..3d4b496e --- /dev/null +++ b/tests/cli/test_guidance_cli.py @@ -0,0 +1,572 @@ +"""Tests for the unified guidance CLI and GuidanceConfig.from_cli().""" + +from __future__ import annotations + +import subprocess +import sys + +import pytest +from sampleworks.utils.guidance_script_arguments import GuidanceConfig + + +COMMON_ARGS = [ + "--protein", + "1VME", + "--structure", + "test.cif", + "--density", + "test.ccp4", + "--resolution", + "1.8", + "--output-dir", + "output", +] + + +class TestFromCliUnified: + """Test from_cli() when model and guidance_type come from CLI args.""" + + def test_boltz2_pure_guidance(self): + argv = ["--model", "boltz2", "--guidance-type", "pure_guidance"] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.model == "boltz2" + assert config.guidance_type == "pure_guidance" + assert config.protein == "1VME" + assert config.structure == "test.cif" + assert config.density == "test.ccp4" + assert config.resolution == 1.8 + + def test_rf3_fk_steering(self): + argv = ["--model", "rf3", "--guidance-type", "fk_steering"] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.model == "rf3" + assert config.guidance_type == "fk_steering" + assert hasattr(config, "num_particles") + assert hasattr(config, "fk_lambda") + + def test_model_specific_args_boltz2_method(self): + argv = [ + "--model", + "boltz2", + "--guidance-type", + "pure_guidance", + "--method", + "MD", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.method == "MD" + + def test_model_specific_args_rf3_msa(self): + argv = [ + "--model", + "rf3", + "--guidance-type", + "pure_guidance", + "--msa-path", + "/data/msa.a3m", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.msa_path == "/data/msa.a3m" + + def test_guidance_specific_args_fk(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "fk_steering", + "--num-particles", + "5", + "--fk-lambda", + "2.0", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.num_particles == 5 + assert config.fk_lambda == 2.0 + + def test_guidance_specific_args_pure(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--step-size", + "0.5", + "--step-scaler-type", + "dataspace", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.step_size == 0.5 + assert config.step_scaler_type == "dataspace" + + +class TestFromCliLegacyScripts: + """Test from_cli() when model/guidance_type are pre-set (legacy script pattern).""" + + def test_preset_model_and_guidance_type(self): + config = GuidanceConfig.from_cli( + COMMON_ARGS, + model="boltz2", + guidance_type="pure_guidance", + ) + assert config.model == "boltz2" + assert config.guidance_type == "pure_guidance" + assert config.protein == "1VME" + + def test_no_model_guidance_type_on_cli_required(self): + """Legacy scripts should not require --model/--guidance-type on CLI.""" + config = GuidanceConfig.from_cli( + COMMON_ARGS, + model="protenix", + guidance_type="fk_steering", + ) + assert config.model == "protenix" + assert config.guidance_type == "fk_steering" + + @pytest.mark.parametrize("model", ["boltz1", "boltz2", "protenix", "rf3"]) + @pytest.mark.parametrize("guidance_type", ["pure_guidance", "fk_steering"]) + def test_all_model_guidance_combos(self, model, guidance_type): + config = GuidanceConfig.from_cli( + COMMON_ARGS, + model=model, + guidance_type=guidance_type, + ) + assert config.model == model + assert config.guidance_type == guidance_type + + +class TestFromCliValidation: + """Test error handling for invalid inputs.""" + + def test_missing_protein_errors(self): + argv = [ + "--model", + "boltz2", + "--guidance-type", + "pure_guidance", + "--structure", + "test.cif", + "--density", + "test.ccp4", + "--resolution", + "1.8", + ] + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_invalid_model_errors(self): + argv = ["--model", "invalid_model", "--guidance-type", "pure_guidance"] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_invalid_guidance_type_errors(self): + argv = ["--model", "boltz2", "--guidance-type", "invalid_type"] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_missing_required_structure_errors(self): + argv = [ + "--model", + "boltz2", + "--guidance-type", + "pure_guidance", + "--protein", + "1VME", + "--density", + "test.ccp4", + "--resolution", + "1.8", + ] + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + +class TestFromCliDefaults: + """Test that defaults are applied correctly.""" + + def test_generic_defaults(self): + argv = ["--model", "boltz1", "--guidance-type", "pure_guidance"] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.output_dir == "output" + assert config.partial_diffusion_step == 0 + assert config.loss_order == 2 + assert config.guidance_start == -1 + assert config.augmentation is False + assert config.align_to_input is False + assert config.ensemble_size == 4 + + def test_boolean_flags(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--augmentation", + "--align-to-input", + "--gradient-normalization", + "--em", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.augmentation is True + assert config.align_to_input is True + assert config.gradient_normalization is True + assert config.em is True + + +class TestCrossModelArgRejection: + """Test that model-specific args are rejected for the wrong model.""" + + def test_method_rejected_for_protenix(self): + argv = [ + "--model", + "protenix", + "--guidance-type", + "pure_guidance", + "--method", + "MD", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_method_rejected_for_rf3(self): + argv = [ + "--model", + "rf3", + "--guidance-type", + "pure_guidance", + "--method", + "MD", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_msa_path_rejected_for_boltz1(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--msa-path", + "/data/msa.a3m", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_conflicting_model_in_preset_mode(self): + argv = ["--model", "boltz2"] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv, model="rf3", guidance_type="fk_steering") + + def test_conflicting_guidance_type_in_preset_mode(self): + argv = ["--guidance-type", "fk_steering"] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv, model="boltz1", guidance_type="pure_guidance") + + def test_fk_args_rejected_for_pure_guidance(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--num-particles", + "5", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_pure_args_rejected_for_fk_steering(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "fk_steering", + "--step-scaler-type", + "dataspace", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_msa_path_rejected_for_boltz2(self): + argv = [ + "--model", + "boltz2", + "--guidance-type", + "pure_guidance", + "--msa-path", + "/data/msa.a3m", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_chiral_flags_rejected_for_boltz1(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--disable-chiral-features", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + +class TestPresetMatchingAccepted: + """Matching preset values should not trigger false-positive rejection.""" + + def test_matching_model_accepted(self): + argv = ["--model", "boltz1"] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv, model="boltz1", guidance_type="fk_steering") + assert config.model == "boltz1" + + def test_matching_guidance_type_accepted(self): + argv = ["--guidance-type", "pure_guidance"] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv, model="boltz2", guidance_type="pure_guidance") + assert config.guidance_type == "pure_guidance" + + +class TestArgPassthrough: + """Test that non-default argument values propagate to GuidanceConfig.""" + + def test_model_checkpoint(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--model-checkpoint", + "/custom/path.ckpt", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.model_checkpoint == "/custom/path.ckpt" + + def test_device(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--device", + "cpu", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.device == "cpu" + + def test_output_dir_override(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--output-dir", + "/custom/output", + "--protein", + "1VME", + "--structure", + "test.cif", + "--density", + "test.ccp4", + "--resolution", + "1.8", + ] + config = GuidanceConfig.from_cli(argv) + assert config.output_dir == "/custom/output" + + def test_partial_diffusion_step(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--partial-diffusion-step", + "50", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.partial_diffusion_step == 50 + + def test_log_path(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--log-path", + "/tmp/run.log", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.log_path == "/tmp/run.log" + + def test_ensemble_size(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "fk_steering", + "--ensemble-size", + "8", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.ensemble_size == 8 + + def test_rf3_chiral_features(self): + argv = [ + "--model", + "rf3", + "--guidance-type", + "pure_guidance", + "--disable-chiral-features", + "--track-chiral-features", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.disable_chiral_features is True + assert config.track_chiral_features is True + + def test_loss_order(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--loss-order", + "1", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.loss_order == 1 + + def test_guidance_start(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--guidance-start", + "10", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.guidance_start == 10 + + def test_recycling_steps(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--recycling-steps", + "3", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.recycling_steps == 3 + + def test_num_diffusion_steps(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--num-diffusion-steps", + "100", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.num_diffusion_steps == 100 + + def test_recycling_steps_default_none(self): + argv = ["--model", "boltz1", "--guidance-type", "pure_guidance"] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.recycling_steps is None + + def test_num_diffusion_steps_default(self): + argv = ["--model", "boltz1", "--guidance-type", "pure_guidance"] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert config.num_diffusion_steps == 200 + + +class TestValidationEdgeCases: + """Additional validation edge cases.""" + + def test_no_args_at_all(self): + with pytest.raises(SystemExit): + GuidanceConfig.from_cli([]) + + def test_missing_resolution(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--protein", + "1VME", + "--structure", + "test.cif", + "--density", + "test.ccp4", + ] + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_missing_density(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--protein", + "1VME", + "--structure", + "test.cif", + "--resolution", + "1.8", + ] + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_invalid_loss_order(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--loss-order", + "3", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + def test_invalid_step_scaler_type(self): + argv = [ + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--step-scaler-type", + "invalid", + ] + COMMON_ARGS + with pytest.raises(SystemExit): + GuidanceConfig.from_cli(argv) + + +class TestEntrypointSmoke: + """Smoke test the actual CLI entrypoint.""" + + def test_help_exits_zero(self): + result = subprocess.run( + [ + sys.executable, + "-m", + "sampleworks.cli.guidance", + "--model", + "boltz1", + "--guidance-type", + "pure_guidance", + "--help", + ], + capture_output=True, + ) + assert result.returncode == 0 + assert b"--protein" in result.stdout + assert b"--structure" in result.stdout + + def test_invalid_preset_model_raises(self): + with pytest.raises(ValueError, match="Unknown model type"): + GuidanceConfig.from_cli(COMMON_ARGS, model="typo", guidance_type="fk_steering") + + def test_invalid_preset_guidance_type_raises(self): + with pytest.raises(ValueError, match="Unknown guidance type"): + GuidanceConfig.from_cli(COMMON_ARGS, model="boltz1", guidance_type="typo")