-
Notifications
You must be signed in to change notification settings - Fork 7
Mdc/add protpardelle #274
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Mdc/add protpardelle #274
Changes from all commits
58f9d0f
9231374
50fffe9
6774d82
40a7f30
c53f5ad
6713208
4ac3e5b
505f824
fbdd548
15b9ec3
83f2ea6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,18 +3,25 @@ | |
| from __future__ import annotations | ||
|
|
||
| import sys | ||
| from loguru import logger | ||
|
|
||
| from sampleworks.utils.guidance_script_arguments import GuidanceConfig | ||
| from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance | ||
|
|
||
|
|
||
| def main(argv: list[str] | None = None) -> int: | ||
| config = GuidanceConfig.from_cli(argv) | ||
|
|
||
| from loguru import logger | ||
|
|
||
| from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance | ||
|
|
||
| logger.info(f"Running guidance with config: {config}") | ||
| device, model_wrapper = get_model_and_device( | ||
| config.device, | ||
| getattr(config, "model_checkpoint", None), | ||
| config.model, | ||
| method=getattr(config, "method", None), | ||
| protpardelle_config_path=getattr(config, "protpardelle_config_path", None), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a way we can remove this here, or at least make it more general rather than including a model specific arg here? I know there are a lot of models that use hydra YAML configs, so I expect we will probably need to generalize this at some point |
||
| ) | ||
| result = run_guidance(config, config.guidance_type, model_wrapper, device) | ||
| return result.exit_code | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -415,15 +415,17 @@ def step( | |
|
|
||
| # Store eps separately for proper frame transformation | ||
| # eps_scale will be float if check_context didn't raise | ||
| eps = torch.randn_like(maybe_augmented_state) * eps_scale # ty: ignore[unsupported-operator] | ||
| eps = torch.randn_like(maybe_augmented_state) * eps_scale | ||
| noisy_state = maybe_augmented_state + eps | ||
| noisy_state = torch.as_tensor(noisy_state).detach().requires_grad_(allow_gradients) | ||
|
|
||
| # t_hat will be float if check_context didn't raise | ||
| # Use no_grad when gradients aren't needed to avoid memory overhead from | ||
| # gradient checkpointing holding intermediate activations | ||
| # TODO testing adding eps to signature for use with Protpardelle-1c, if successful, | ||
| # I need to modify the Protocol itself. @Michael Anzuoni | ||
| with torch.set_grad_enabled(allow_gradients): | ||
| x_hat_0 = model_wrapper.step(noisy_state, t_hat, features=features) | ||
| x_hat_0 = model_wrapper.step(noisy_state, t_hat, eps, features=features) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm confused here, because this no longer appears to be needed in the model step function that is in this PR. So I think #283 should be closed. |
||
|
|
||
| reconciler = ( | ||
| context.reconciler.to(torch.as_tensor(x_hat_0).device) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| data: | ||
| auto_calc_sigma_data: true | ||
| chain_residx_gap: 200 | ||
| dummy_fill_mode: zero | ||
| fixed_size: 512 | ||
| mixing_ratios: | ||
| - 1.0 | ||
| n_aatype_tokens: 21 | ||
| n_examples_for_sigma_data: 500 | ||
| pdb_paths: | ||
| - /scratch/users/tianyulu/augmented_ingraham_cath_bugfree | ||
| se3_data_augment: true | ||
| sigma_data: 10.3 | ||
| subset: | ||
| - designable_ | ||
| translation_scale: 1.0 | ||
| diffusion: | ||
| sampling: | ||
| function: uniform | ||
| s_max: 80 | ||
| s_min: 0.001 | ||
| training: | ||
| function: lognormal | ||
| psigma_mean: -0.5 | ||
| psigma_std: 1.5 | ||
| model: | ||
| compute_loss_on_all_atoms: false | ||
| conditioning_style: concat | ||
| crop_conditional: true | ||
| dummy_fill_masked_atoms: false | ||
| full_mpnn_model_path: /scratch/users/tianyulu/farfalle/ProteinMPNN/vanilla_model_weights | ||
| mpnn_model: | ||
| label_smoothing: 0.1 | ||
| n_channel: 128 | ||
| n_layers: 3 | ||
| n_neighbors: 32 | ||
| noise_cond_mult: 4 | ||
| use_self_conditioning: true | ||
| mpnn_model_checkpoint: '' | ||
| pretrained_modules: [] | ||
| struct_model: | ||
| arch: dit | ||
| n_atoms: 37 | ||
| n_channel: 256 | ||
| noise_cond_mult: 4 | ||
| uvit: | ||
| cat_pwd_to_conv: false | ||
| conv_skip_connection: false | ||
| dim_head: 32 | ||
| n_blocks_per_layer: 2 | ||
| n_filt_per_layer: [] | ||
| n_heads: 8 | ||
| n_layers: 10 | ||
| patch_size: 1 | ||
| position_embedding_max: 32 | ||
| position_embedding_type: rotary | ||
| struct_model_checkpoint: '' | ||
| task: ai-allatom | ||
| train: | ||
| batch_size: 32 | ||
| checkpoint_freq: 1 | ||
| checkpoints: [] | ||
| ckpt_path: /scratch/users/tianyulu/farfalle/out_dir/farfalle/cc89/checkpoints/epoch206_training_state.pth | ||
| clip_grad_norm: true | ||
| crop_cond: | ||
| contiguous_prob: 0.05 | ||
| discontiguous_prob: 0.9 | ||
| dist_threshold: 45.0 | ||
| max_discontiguous_res: 24 | ||
| max_span_len: 12 | ||
| recenter_coords: true | ||
| sidechain_only_prob: 0.0 | ||
| sidechain_prob: 0.9 | ||
| terms_prob: 0.5 | ||
| crop_conditional: true | ||
| decay_steps: 2000000 | ||
| eval_freq: 8000000 | ||
| eval_loss_t: | ||
| - 0.1 | ||
| - 0.3 | ||
| - 0.5 | ||
| - 0.7 | ||
| - 0.9 | ||
| fpd_length_ranges_per_chain: | ||
| - - 50 | ||
| - 256 | ||
| grad_clip_val: 1.0 | ||
| length_ranges_per_chain: | ||
| - - 166 | ||
| - 188 | ||
| lr: 0.0001 | ||
| max_epochs: 10000 | ||
| n_eval_samples: 10 | ||
| n_fpd_samples: 0 | ||
| sc_num_seqs: 4 | ||
| seed: 0 | ||
| self_cond_train_prob: 0.9 | ||
| shapes_path: /scratch/users/tianyulu/protein_shapes | ||
| subsample_eval_set: 0.05 | ||
| warmup_steps: 1000 | ||
| weight_decay: 0.0 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Protpardelle-1c model wrapper.""" |
Uh oh!
There was an error while loading. Please reload this page.