The following change to `edm.py` in https://github.com/diff-use/sampleworks/pull/274/changes#diff-db2bf8d2d9093b15fb1f471360efc972f4b0ec228b42c6e91d27e5f7081320f7R428 technically breaks our Sampler protocol. It passes an extra argument to `model_wrapper.step()`: `eps`, which is a noise level.

```python
# TODO testing adding eps to signature for use with Protpardelle-1c, if successful,
# I need to modify the Protocol itself. @Michael Anzuoni
with torch.set_grad_enabled(allow_gradients):
    x_hat_0 = model_wrapper.step(noisy_state, t_hat, eps, features=features)
```

It was introduced in the branch `mdc/add-propardelle` to handle frame transformations correctly. In Protpardelle-1c, we should apply the transformation before adding noise.
I (marcuscollins) don't think this will break any existing code, because all the model wrappers include a `*` for extra positional arguments in their `step()` methods. The open question is how to express this in the Protocol when some arguments are required and others are optional.
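One possible approach is sketched below. It assumes the wrappers' `*` is a `*args` catch-all. A bare `*` only makes the following parameters keyword-only, so it would reject the extra positional `eps` with a `TypeError`.

Under that assumption, the Protocol could declare `eps` as an optional, positional-only parameter with a `None` default. The sampler keeps passing it positionally, as `edm.py` does. Wrappers that need it can name it, and older `*args` wrappers simply absorb it.

`ModelWrapper`, `LegacyWrapper`, `EpsAwareWrapper`, `BareStarWrapper` and `sample_step` are hypothetical stand-ins, not the sampleworks API. Plain Python values replace tensors.

```python
from typing import Any, Optional, Protocol


class ModelWrapper(Protocol):
    """Hypothetical protocol: eps is optional and positional-only (note the '/')."""

    def step(
        self,
        x_t: Any,
        t: float,
        eps: Optional[float] = None,
        /,
        *,
        features: Any = None,
    ) -> Any: ...


class LegacyWrapper:
    """Existing-style wrapper (hypothetical): *args absorbs the extra positional eps."""

    def step(self, x_t, t, *args, features=None):
        return ("legacy", x_t, t, args, features)


class EpsAwareWrapper:
    """Protpardelle-style wrapper (hypothetical) that actually receives eps."""

    def step(self, x_t, t, eps=None, /, *, features=None):
        # A real wrapper would use eps here, e.g. when transforming frames.
        return ("eps", x_t, t, eps, features)


class BareStarWrapper:
    """Counterexample: a bare '*' adds no catch-all for extra positional arguments."""

    def step(self, x_t, t, *, features=None):
        return ("bare", x_t, t, features)


def sample_step(wrapper: ModelWrapper, x_t, t, eps, features=None):
    # Mirrors the edm.py call: eps is passed as the third positional argument.
    return wrapper.step(x_t, t, eps, features=features)


print(sample_step(LegacyWrapper(), "x", 1.0, 0.5))
print(sample_step(EpsAwareWrapper(), "x", 1.0, 0.5, features={"a": 1}))
try:
    sample_step(BareStarWrapper(), "x", 1.0, 0.5)
except TypeError as err:
    print("TypeError:", err)
```

Making `eps` positional-only matters for backward compatibility. If the sampler called `step(..., eps=eps)` by keyword instead, `*args`-only wrappers would fail, because `*args` does not absorb keyword arguments.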
If PR https://github.com/diff-use/sampleworks/pull/274/changes isn't yet merged, feel free to add any changes there.