From 3c909715af0394079eb474acba41d99940f2f9d2 Mon Sep 17 00:00:00 2001 From: xraymemory Date: Wed, 1 Jul 2026 10:48:08 -0400 Subject: [PATCH] fix(samplers): reconcile FlowModelWrapper.step call with Protocol The EDM sampler passed a third positional `eps` to `model_wrapper.step()`, which is not part of the `FlowModelWrapper.step(x_t, t, *, features=None)` protocol. This raised `TypeError: step() takes 3 positional arguments but 4 were given` for every protocol-compliant wrapper (boltz/protenix/rf3 and the test mocks), failing all 164 sampler-path tests. The extra arg was only absorbed by Protpardelle's `step()`, where it landed in an unused `sigma_float` parameter. - edm.py: drop the extra `eps` positional (eps is still used locally to build the noisy state and for the working-frame guidance math). Removes the temporary "modify the Protocol" TODO. - protpardelle/wrapper.py: remove the dead `sigma_float` parameter so the signature matches the protocol. Also fixes the failing lint job: - cli/guidance.py: drop unused top-level `loguru.logger` import (F401/F811/I001). - protpardelle/wrapper.py: use the repo's leading-space single-axis jaxtyping convention (`" atoms"`) to avoid UP037/F821. - tests/runs/test_runner.py: wrap an over-length comment (E501). Fixes #283. --- src/sampleworks/cli/guidance.py | 1 - src/sampleworks/core/samplers/edm.py | 4 +--- src/sampleworks/models/protpardelle/wrapper.py | 3 +-- tests/runs/test_runner.py | 3 ++- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/sampleworks/cli/guidance.py b/src/sampleworks/cli/guidance.py index 3e95885e..1365bdfc 100644 --- a/src/sampleworks/cli/guidance.py +++ b/src/sampleworks/cli/guidance.py @@ -3,7 +3,6 @@ from __future__ import annotations import sys -from loguru import logger from sampleworks.utils.guidance_script_arguments import GuidanceConfig diff --git a/src/sampleworks/core/samplers/edm.py b/src/sampleworks/core/samplers/edm.py index 698a0346..1f81c551 100644 --- a/src/sampleworks/core/samplers/edm.py +++ b/src/sampleworks/core/samplers/edm.py @@ -422,10 +422,8 @@ def step( # t_hat will be float if check_context didn't raise # Use no_grad when gradients aren't needed to avoid memory overhead from # gradient checkpointing holding intermediate activations - # TODO testing adding eps to signature for use with Protpardelle-1c, if successful, - # I need to modify the Protocol itself. @Michael Anzuoni with torch.set_grad_enabled(allow_gradients): - x_hat_0 = model_wrapper.step(noisy_state, t_hat, eps, features=features) + x_hat_0 = model_wrapper.step(noisy_state, t_hat, features=features) reconciler = ( context.reconciler.to(torch.as_tensor(x_hat_0).device) diff --git a/src/sampleworks/models/protpardelle/wrapper.py b/src/sampleworks/models/protpardelle/wrapper.py index c6c0ce86..e993b3da 100644 --- a/src/sampleworks/models/protpardelle/wrapper.py +++ b/src/sampleworks/models/protpardelle/wrapper.py @@ -439,7 +439,7 @@ def featurize(self, structure: dict) -> GenerativeModelInput[ProtpardelleConditi def _atom37_indices_from_atom_array( self, atom_array - ) -> tuple[Int[Tensor, "atoms"], Int[Tensor, "atoms"]]: + ) -> tuple[Int[Tensor, " atoms"], Int[Tensor, " atoms"]]: """Derive per-atom atom37 destination indices from an Atomworks atom array. For each atom in ``atom_array`` (the order the sampler's flat ``x_t`` @@ -636,7 +636,6 @@ def step( self, x_t: Float[Tensor, "batch atoms 3"], t: Float[Tensor, "*batch"] | float, - sigma_float: float, *, features: GenerativeModelInput[ProtpardelleConditioning] | None = None, ) -> Float[Tensor, "batch atoms 3"]: diff --git a/tests/runs/test_runner.py b/tests/runs/test_runner.py index 442eda53..4635cd4e 100644 --- a/tests/runs/test_runner.py +++ b/tests/runs/test_runner.py @@ -265,7 +265,8 @@ def test_dry_run_does_not_create_directories( """--dry-run prints commands but never touches the filesystem.""" monkeypatch.setenv("HOME", str(tmp_path)) results_dir = tmp_path / "results" - preset = loader.load_preset("rf3_partial", overrides=["jobs.0.gpu_count=1"]) # use 1 gpu so we don't need big nodes to test + # use 1 gpu so we don't need big nodes to test + preset = loader.load_preset("rf3_partial", overrides=["jobs.0.gpu_count=1"]) runner.run(preset, results_dir=results_dir, dry_run=True) # results_dir gets created by run() (for log file location) but per-job # output subdirs must NOT exist after dry-run.