fix(samplers): reconcile FlowModelWrapper.step call with Protocol#288
Conversation
The EDM sampler passed a third positional `eps` to `model_wrapper.step()`, which is not part of the `FlowModelWrapper.step(x_t, t, *, features=None)` protocol. This raised `TypeError: step() takes 3 positional arguments but 4 were given` for every protocol-compliant wrapper (boltz/protenix/rf3 and the test mocks), failing all 164 sampler-path tests. The extra arg was only absorbed by Protpardelle's `step()`, where it landed in an unused `sigma_float` parameter. - edm.py: drop the extra `eps` positional (eps is still used locally to build the noisy state and for the working-frame guidance math). Removes the temporary "modify the Protocol" TODO. - protpardelle/wrapper.py: remove the dead `sigma_float` parameter so the signature matches the protocol. Also fixes the failing lint job: - cli/guidance.py: drop unused top-level `loguru.logger` import (F401/F811/I001). - protpardelle/wrapper.py: use the repo's leading-space single-axis jaxtyping convention (`" atoms"`) to avoid UP037/F821. - tests/runs/test_runner.py: wrap an over-length comment (E501). Fixes diff-use#283.
|
Important Review skippedAuto reviews are disabled on base/target branches other than the default branch. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Validation: ran the previously-failing tests on-cluster ✅Since CI only runs on PRs targeting Before this change these same tests failed with
Note (pre-existing, out of scope for this PR): |
Summary
Fixes #283 and unbreaks CI on #274 (
mdc/add-protpardelle).The EDM sampler called
model_wrapper.step(noisy_state, t_hat, eps, features=features), passing a third positionalepsthat is not part of theFlowModelWrapper.step(x_t, t, *, features=None)protocol. Because boltz/protenix/rf3 (and the test mocks) implement the protocol exactly — no extra positional, no**kwargs— every sampler-path test failed with:That single call site accounts for all 164 failing tests across the
boltz-dev/protenix-dev/rf3-devjobs (110MismatchCaseWrapper+ 54MockFlowModelWrapper).The extra argument was only ever absorbed by Protpardelle's
step(), where it landed in asigma_floatparameter that is never referenced in the method body — so removing it changes no behavior. This matches @k-chrispens's note on #283 (already addressed by #267;epsis the noise tensor, not a noise level) and restores the resolution that was lost when the branch was force-updated.Changes
Protocol reconciliation (#283)
core/samplers/edm.py: drop the extraepspositional from themodel_wrapper.step()call.epsis still computed locally and used to build the noisy state and the working-frame guidance math — only the (dead) hand-off to the wrapper is removed. Deleted the temporary "I need to modify the Protocol itself" TODO.models/protpardelle/wrapper.py: remove the unusedsigma_floatparameter sostep()matchesFlowModelWrapper.Lint job (was failing, 8 errors)
cli/guidance.py: remove unused top-levelfrom loguru import logger(it's re-imported insidemain()) — fixes F401 / F811 / I001.models/protpardelle/wrapper.py: use the repo's existing leading-space single-axis jaxtyping convention (Int[Tensor, " atoms"], as inedm.py/step_scalers.py) to avoid UP037 / F821.tests/runs/test_runner.py: wrap an over-length inline comment (E501).Validation
ruff 0.15.8 checkpasses on all touched files locally (same version as CI).edm.py:428positional and the tests already callstep(x, t, features=...), so they should go green.Closes #283.