
Reconcile Sampler protocol with model_wrapper.step() arguments required by Protpardelle #283

Description

@marcuscollins

The following change in edm.py in https://github.com/diff-use/sampleworks/pull/274/changes#diff-db2bf8d2d9093b15fb1f471360efc972f4b0ec228b42c6e91d27e5f7081320f7R428

        # TODO testing adding eps to signature for use with Protpardelle-1c, if successful,
        #   I need to modify the Protocol itself. @Michael Anzuoni
        with torch.set_grad_enabled(allow_gradients):
            x_hat_0 = model_wrapper.step(noisy_state, t_hat, eps, features=features)

technically breaks our Sampler protocol by passing an extra positional argument (here, eps, which is a noise level). It was introduced on the branch mdc/add-propardelle to handle frame transformations correctly, since in Protpardelle-1c we should transform before adding noise.
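To make the mismatch concrete, here is a minimal sketch using a hypothetical stand-in for the protocol. The class name ModelWrapper, the simplified argument types, and the keyword-only features parameter are assumptions, not the actual sampleworks definitions. It uses inspect.signature(...).bind() to check whether a call fits the declared signature without running any model.

```python
import inspect
from typing import Any, Protocol


class ModelWrapper(Protocol):
    """Hypothetical stand-in for the step() contract the sampler relies on.

    Assumption: the protocol declares no eps parameter and no *args.
    """

    def step(self, noisy_state: Any, t_hat: float, *, features: Any = None) -> Any: ...


def call_matches(func: Any, *args: Any, **kwargs: Any) -> bool:
    """Return True if args/kwargs can be bound to func's signature (nothing is executed)."""
    try:
        inspect.signature(func).bind(*args, **kwargs)
        return True
    except TypeError:
        return False


# ModelWrapper.step is accessed unbound, so None stands in for `self`.
original_call = call_matches(ModelWrapper.step, None, "state", 1.0, features={})
call_with_eps = call_matches(ModelWrapper.step, None, "state", 1.0, 0.1, features={})
print(original_call, call_with_eps)  # True False
```

A static type checker would flag the second call against such a protocol for the same reason: the declared signature has no slot for a third positional argument.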

I (marcuscollins) don't think this will break any of the actual code, since the model wrappers all include a * in their step() methods to accept extra positional arguments. The open question is how to express this in the Protocol when some arguments are required and others are not.
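Whether the extra eps is actually harmless depends on what the * in each wrapper's step() is. In Python, *args collects extra positional arguments, while a bare * only marks the parameters after it as keyword-only and absorbs nothing. The sketch below uses hypothetical wrapper classes to show both cases, then one possible Protocol shape in which eps is an optional keyword argument, so a model that needs it (such as Protpardelle-1c) can insist on it at runtime while others ignore it. All class and function names here are illustrative assumptions, not the sampleworks API.

```python
from typing import Any, Optional, Protocol


class AbsorbsExtras:
    """Hypothetical wrapper using *args: extra positional arguments are collected and ignored."""

    def step(self, noisy_state: Any, t_hat: float, *args: Any, features: Any = None) -> Any:
        return noisy_state  # placeholder for the real denoising step


class KeywordOnly:
    """Hypothetical wrapper using a bare *: it only makes `features` keyword-only."""

    def step(self, noisy_state: Any, t_hat: float, *, features: Any = None) -> Any:
        return noisy_state  # placeholder for the real denoising step


def accepts_positional_eps(wrapper: Any) -> bool:
    """Mimic the sampler's new call, with eps passed as a third positional argument."""
    try:
        wrapper.step("state", 1.0, 0.1, features={})
        return True
    except TypeError:
        return False


print(accepts_positional_eps(AbsorbsExtras()))  # True
print(accepts_positional_eps(KeywordOnly()))  # False


# One possible protocol shape (an assumption, not an agreed design):
# eps is an optional keyword, so every model can accept it and only some use it.
class EpsAwareModelWrapper(Protocol):
    def step(
        self,
        noisy_state: Any,
        t_hat: float,
        *,
        eps: Optional[float] = None,
        features: Any = None,
    ) -> Any: ...


class NoiseAgnosticWrapper:
    def step(self, noisy_state: Any, t_hat: float, *, eps: Optional[float] = None, features: Any = None) -> Any:
        return noisy_state  # accepts eps for conformance but ignores it


class EpsRequiringWrapper:
    def step(self, noisy_state: Any, t_hat: float, *, eps: Optional[float] = None, features: Any = None) -> Any:
        if eps is None:  # this model cannot run without the noise level
            raise ValueError("eps is required by this model")
        return noisy_state  # placeholder: a real model would use eps here


def sampler_step(model: EpsAwareModelWrapper, noisy_state: Any, t_hat: float, eps: float, features: Any = None) -> Any:
    # Passing eps by keyword avoids depending on positional order.
    return model.step(noisy_state, t_hat, eps=eps, features=features)
```

Under this option the sampler would pass eps by keyword, every wrapper would declare it, and the "required for some models" part would be enforced inside the wrappers that need it rather than in the Protocol itself.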

If PR https://github.com/diff-use/sampleworks/pull/274/changes isn't yet merged, feel free to add any changes there.

Metadata


Assignees

Labels

P0: Highest priority work
engineering: Task that is best suited to software engineers, not research scientists

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
