Skip to content

fix(samplers): reconcile FlowModelWrapper.step call with Protocol#288

Merged
marcuscollins merged 1 commit into
diff-use:mdc/add-protpardellefrom
manzuoni-astera:michaelanzuoni/issue-283-sampler-protocol
Jul 1, 2026
Merged

fix(samplers): reconcile FlowModelWrapper.step call with Protocol#288
marcuscollins merged 1 commit into
diff-use:mdc/add-protpardellefrom
manzuoni-astera:michaelanzuoni/issue-283-sampler-protocol

Conversation

@manzuoni-astera

Copy link
Copy Markdown
Contributor

Summary

Fixes #283 and unbreaks CI on #274 (mdc/add-protpardelle).

The EDM sampler called model_wrapper.step(noisy_state, t_hat, eps, features=features), passing a third positional eps that is not part of the FlowModelWrapper.step(x_t, t, *, features=None) protocol. Because boltz/protenix/rf3 (and the test mocks) implement the protocol exactly — no extra positional, no **kwargs — every sampler-path test failed with:

TypeError: <Wrapper>.step() takes 3 positional arguments but 4 positional arguments (and 1 keyword-only argument) were given
  src/sampleworks/core/samplers/edm.py:428

That single call site accounts for all 164 failing tests across the boltz-dev / protenix-dev / rf3-dev jobs (110 MismatchCaseWrapper + 54 MockFlowModelWrapper).

The extra argument was only ever absorbed by Protpardelle's step(), where it landed in a sigma_float parameter that is never referenced in the method body — so removing it changes no behavior. This matches @k-chrispens's note on #283 (already addressed by #267; eps is the noise tensor, not a noise level) and restores the resolution that was lost when the branch was force-updated.

Changes

Protocol reconciliation (#283)

  • core/samplers/edm.py: drop the extra eps positional from the model_wrapper.step() call. eps is still computed locally and used to build the noisy state and the working-frame guidance math — only the (dead) hand-off to the wrapper is removed. Deleted the temporary "I need to modify the Protocol itself" TODO.
  • models/protpardelle/wrapper.py: remove the unused sigma_float parameter so step() matches FlowModelWrapper.

Lint job (was failing, 8 errors)

  • cli/guidance.py: remove unused top-level from loguru import logger (it's re-imported inside main()) — fixes F401 / F811 / I001.
  • models/protpardelle/wrapper.py: use the repo's existing leading-space single-axis jaxtyping convention (Int[Tensor, " atoms"], as in edm.py/step_scalers.py) to avoid UP037 / F821.
  • tests/runs/test_runner.py: wrap an over-length inline comment (E501).

Validation

  • ruff 0.15.8 check passes on all touched files locally (same version as CI).
  • Test suites can't run on this macOS checkout (pixi envs are linux-64); CI on this PR exercises them. Every failing test traced to the removed edm.py:428 positional and the tests already call step(x, t, features=...), so they should go green.

Closes #283.

The EDM sampler passed a third positional `eps` to `model_wrapper.step()`,
which is not part of the `FlowModelWrapper.step(x_t, t, *, features=None)`
protocol. This raised `TypeError: step() takes 3 positional arguments but 4
were given` for every protocol-compliant wrapper (boltz/protenix/rf3 and the
test mocks), failing all 164 sampler-path tests. The extra arg was only
absorbed by Protpardelle's `step()`, where it landed in an unused
`sigma_float` parameter.

- edm.py: drop the extra `eps` positional (eps is still used locally to build
  the noisy state and for the working-frame guidance math). Removes the
  temporary "modify the Protocol" TODO.
- protpardelle/wrapper.py: remove the dead `sigma_float` parameter so the
  signature matches the protocol.

Also fixes the failing lint job:
- cli/guidance.py: drop unused top-level `loguru.logger` import (F401/F811/I001).
- protpardelle/wrapper.py: use the repo's leading-space single-axis jaxtyping
  convention (`" atoms"`) to avoid UP037/F821.
- tests/runs/test_runner.py: wrap an over-length comment (E501).

Fixes diff-use#283.
@coderabbitai

coderabbitai Bot commented Jul 1, 2026

Copy link
Copy Markdown
Contributor

Important

Review skipped

Auto reviews are disabled on base/target branches other than the default branch.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 39e620db-c893-4bf7-a772-cb0638ccd9f6

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@manzuoni-astera

Copy link
Copy Markdown
Contributor Author

Validation: ran the previously-failing tests on-cluster ✅

Since CI only runs on PRs targeting main, I validated this branch directly on an Astera pod (baked sampleworks image, boltz env — torch 2.7.1+cu126, atomworks, einx, jaxtyping) by running the exact two files that accounted for all 164 failures on #274:

$ pytest tests/integration/test_pipeline_integration.py \
         tests/integration/test_mismatch_integration.py -q -m "not slow"
...
209 passed, 132 deselected, 2 warnings in 16.49s

Before this change these same tests failed with TypeError: <Wrapper>.step() takes 3 positional arguments but 4 were given (edm.py:428) — 110 MismatchCaseWrapper + 54 MockFlowModelWrapper. With the extra eps positional removed they're green.

ruff 0.15.8 check (the CI version) also passes on all four touched files, clearing the previously-failing lint job.

Note (pre-existing, out of scope for this PR): protpardelle/wrapper.py:656 emits a SyntaxWarning: invalid escape sequence '\h' — the :math:\\hat{t}`` docstring should be a raw string. Not introduced here; happy to fix in a follow-up if wanted.

@marcuscollins marcuscollins merged commit b3340ba into diff-use:mdc/add-protpardelle Jul 1, 2026
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants