Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 38 additions & 19 deletions edit_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,33 +33,52 @@ def remove_checkpoint_keys(checkpoint_path, keys):

torch.save(checkpoint_dict, checkpoint_path)

if __name__ == "__main__":
# ---------------------------------------------------------------------------
# Entry-point helpers: collection → validation → processing
# ---------------------------------------------------------------------------

def build_parser():
import argparse
arg_parser = argparse.ArgumentParser(
description="Inspect or modify a model checkpoint file.")
arg_parser.add_argument("checkpoint_path", help="Path to the model checkpoint (.pth or directory)")
arg_parser.add_argument("--reset-optimizer", action="store_true", help="Remove the optimizer state")
arg_parser.add_argument("--reset-scheduler", action="store_true", help="Remove the scheduler state")
arg_parser.add_argument("--reset-epoch", action="store_true", help="Remove the epoch counter and history")
arg_parser.add_argument("--info", action="store_true", help="Print the current epoch and model config")
return arg_parser

arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("checkpoint_path", help="The model checkpoint path")
arg_parser.add_argument("--reset-optimizer", action="store_true", help="Reset the optimizer state")
arg_parser.add_argument("--reset-scheduler", action="store_true", help="Reset the scheduler state")
arg_parser.add_argument("--reset-epoch", action="store_true", help="Reset the epoch and history")
arg_parser.add_argument("--info", action="store_true", help="Show the current epoch & config")

args = arg_parser.parse_args()
def validate_config(cfg) -> None:
path = cfg.checkpoint_path
is_valid = os.path.isfile(path) or (os.path.isdir(path) and os.path.isfile(os.path.join(path, "checkpoint.pth")))
assert is_valid, f"Checkpoint not found: {path}"

if args.info:
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
print(f"Epoch: {checkpoint['epoch']}")
print("Config:")
config_str = yaml.dump(checkpoint["hyper_parameters"])
print("\n".join([" " + i for i in config_str.split("\n")]))

def process_config(cfg) -> list:
keys_to_remove = []
if args.reset_optimizer:
if cfg.reset_optimizer:
keys_to_remove.append("optimizer")
if args.reset_scheduler:
if cfg.reset_scheduler:
keys_to_remove.append("scheduler")
if args.reset_epoch:
if cfg.reset_epoch:
keys_to_remove.append("epoch")
return keys_to_remove


if __name__ == "__main__":
from module.base_config import BaseConfig
cfg = BaseConfig(build_parser())
validate_config(cfg)

if cfg.info:
checkpoint = torch.load(cfg.checkpoint_path, map_location="cpu")
print(f"Epoch: {checkpoint['epoch']}")
print("Config:")
config_str = yaml.dump(checkpoint["hyper_parameters"])
print("\n".join([" " + i for i in config_str.split("\n")]))

if len(keys_to_remove):
remove_checkpoint_keys(args.checkpoint_path, keys_to_remove)
keys_to_remove = process_config(cfg)
if keys_to_remove:
remove_checkpoint_keys(cfg.checkpoint_path, keys_to_remove)

57 changes: 39 additions & 18 deletions learn_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,29 +556,50 @@ def process_batch(batch, training_mol_list, learnable_ff):

### End Train

if __name__ == "__main__":
# ---------------------------------------------------------------------------
# Entry-point helpers: collection → validation → processing
# ---------------------------------------------------------------------------

def build_parser():
import argparse
parser = argparse.ArgumentParser(description="Train a TorchMD prior forcefield")
parser.add_argument("input", help="Preprocessed prior & data to train on")
parser.add_argument("result", help="Checkpoint save directory")
parser.add_argument("--pdbids", nargs="*", help="List of specific PDB IDs to process")
parser.add_argument("--pdbids", nargs="*", help="Restrict training to these PDB IDs")
parser.add_argument("--batch", type=int, default=10, help="The batch size to use")
parser.add_argument("--epochs", type=int, default=25, help="The total number of epochs to train for")
parser.add_argument("--lr", type=float, default=0.005, help="Learning rate")
parser.add_argument("--gamma", type=float, default=None, help="Learning rate scheduler gamma")
parser.add_argument("--clip-parameters", action='store_true', help="Clip parameters to a valid range after each update")
parser.add_argument("--gamma", type=float, default=None, help="Exponential LR scheduler gamma")
parser.add_argument("--clip-parameters", action='store_true',
help="Clip parameters to a valid range after each update")
parser.add_argument('--initial-loss', default=True, action=argparse.BooleanOptionalAction,
help="Whether to calculate initial loss before training. Default=True")

args = parser.parse_args()
print(args)

train(data_dir = args.input,
pdb_ids = args.pdbids,
checkpoint_dir = args.result,
n_epochs = args.epochs,
batch_size = args.batch,
learning_rate = args.lr,
scheduler_gamma = args.gamma,
clip_parameters = args.clip_parameters,
initial_loss_calc = args.initial_loss)
help="Calculate loss before the first training step (default=True)")
return parser


def validate_config(cfg) -> None:
assert os.path.isdir(cfg.input), f"Input directory does not exist: {cfg.input}"
assert cfg.lr > 0, "--lr must be positive"
assert cfg.epochs > 0, "--epochs must be positive"


def process_config(cfg) -> dict:
return dict(
data_dir=cfg.input,
pdb_ids=cfg.pdbids,
checkpoint_dir=cfg.result,
n_epochs=cfg.epochs,
batch_size=cfg.batch,
learning_rate=cfg.lr,
scheduler_gamma=cfg.gamma,
clip_parameters=cfg.clip_parameters,
initial_loss_calc=cfg.initial_loss,
)


if __name__ == "__main__":
from module.base_config import BaseConfig
cfg = BaseConfig(build_parser())
print(cfg.as_namespace())
validate_config(cfg)
train(**process_config(cfg))
2 changes: 1 addition & 1 deletion modules
Submodule modules updated 1 files
+362 −0 base_config.py
168 changes: 104 additions & 64 deletions preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,100 +1007,140 @@ def gen_input_mapping(conf):
"CA_lj_only":Prior_CA_lj_only,
}

if __name__ == "__main__":
# ---------------------------------------------------------------------------
# Entry-point helpers: collection → validation → processing
# ---------------------------------------------------------------------------

def build_parser():
parser = argparse.ArgumentParser(description="Preprocess data.")
parser.add_argument("input", nargs = "+", help="Input directory path")
parser.add_argument("input", nargs="+",
help="Input directory paths or YAML dataset config files")
parser.add_argument("-o", "--output", required=True, help="Output directory path")
parser.add_argument("--pdbids", nargs="*", help="List of specific PDB IDs to process")
parser.add_argument("--num-frames", "--num_frames", type=int, default=None, help="Number of frames to process")
parser.add_argument("--frame-slice", type=str, default=None, help="Select frames to process using a python slice: start:end:stride")
parser.add_argument("--temp", type=int, default=300, help="Temperature")
parser.add_argument("--prior", type=str, default=None, help="Select the prior forcefield to use, must be one of: " + ", ".join(sorted(prior_types.keys())))
parser.add_argument("--optimize-forces", action="store_true", help="Use statistically optimal force aggregation (Kramer 2023)")
parser.add_argument("--prior-file", default=None, help="Use PRIOR_FILE instead of fitting a prior")
parser.add_argument('--no-box', default=False, action='store_true', help="Don't use periodic box information")
parser.add_argument('--prior-plots', default=True, action='store_true', help="Save plots of the prior fit functions")
parser.add_argument('--no-prior-plots', dest='prior_plots', action='store_false', help="Don't save plots of the prior fit functions")
parser.add_argument('--no-fit-constraints', default=False, action='store_true', help="Disable range constraints when fitting prior functions")
parser.add_argument('--fit-min-cnt', type=int, default=0, help="Only bins with cnt > min_cnt will be considered when fitting the prior (default 0)")
# parser.add_argument('--tag-beta-turns', default=False, action='store_true', help="Give beta turns a different bond type in the prior")
parser.add_argument('--resume', default=False, action='store_true', help="Resume processing rather than overwriting, all settings must be identical between calls")
parser.add_argument('--num-cores', type=int, default=32, help="Number of cores to be used for parallelization of preprocessing")
parser.add_argument('--jobid', type=int, default=None, help="Integer denoting jobid, if not -1, it will only process a subset of the PDBs")
parser.add_argument('--totalNrJobs', type=int, default=None, help="Integer denoting how many jobs are in total.")


args = parser.parse_args()
print(args)

output_dir = args.output
pdbids = args.pdbids
assert not (args.num_frames and args.frame_slice)
if args.num_frames:
frame_slice = slice(0, args.num_frames)
elif args.frame_slice:
# Convert the arg string into a slice
frame_slice = slice(*[int(i) if i != '' else None for i in args.frame_slice.split(":") ])
parser.add_argument("--pdbids", nargs="*", help="Restrict processing to these PDB IDs")
parser.add_argument("--num-frames", "--num_frames", type=int, default=None,
help="Number of frames to process (mutually exclusive with --frame-slice)")
parser.add_argument("--frame-slice", type=str, default=None,
help="Select frames via a Python slice string: start:end:stride "
"(mutually exclusive with --num-frames)")
parser.add_argument("--temp", type=int, default=300, help="Temperature (K)")
parser.add_argument("--prior", type=str, default=None,
help="Prior forcefield to use, one of: " + ", ".join(sorted(prior_types.keys())))
parser.add_argument("--optimize-forces", action="store_true",
help="Use statistically optimal force aggregation (Kramer 2023)")
parser.add_argument("--prior-file", default=None,
help="Load a pre-built prior from PRIOR_FILE instead of fitting one")
parser.add_argument('--no-box', default=False, action='store_true',
help="Don't use periodic box information")
parser.add_argument('--prior-plots', default=True, action='store_true',
help="Save plots of the prior fit functions")
parser.add_argument('--no-prior-plots', dest='prior_plots', action='store_false',
help="Don't save plots of the prior fit functions")
parser.add_argument('--no-fit-constraints', default=False, action='store_true',
help="Disable range constraints when fitting prior functions")
parser.add_argument('--fit-min-cnt', type=int, default=0,
help="Minimum bin count to include when fitting the prior (default 0)")
# parser.add_argument('--tag-beta-turns', default=False, action='store_true',
# help="Give beta turns a different bond type in the prior")
parser.add_argument('--resume', default=False, action='store_true',
help="Resume an interrupted run — all settings must be identical")
parser.add_argument('--num-cores', type=int, default=32,
help="CPU cores for parallel preprocessing (default 32)")
parser.add_argument('--jobid', type=int, default=None,
help="Job index for array jobs; limits processing to a subset of PDBs")
parser.add_argument('--totalNrJobs', type=int, default=None,
help="Total number of array jobs (used with --jobid)")
return parser


def validate_config(cfg) -> None:
assert not (cfg.num_frames and cfg.frame_slice), \
"--num-frames and --frame-slice are mutually exclusive"
if cfg.prior_file:
assert os.path.exists(cfg.prior_file), f"Prior file does not exist: {cfg.prior_file}"
assert cfg.prior or cfg.prior_file, \
"Specify a prior with --prior or --prior-file"


def process_config(cfg):
# Frame selection
if cfg.num_frames:
frame_slice = slice(0, cfg.num_frames)
elif cfg.frame_slice:
frame_slice = slice(*[int(i) if i != '' else None for i in cfg.frame_slice.split(":")])
else:
frame_slice = slice(None)
temp = args.temp
optimize_forces = args.optimize_forces
box = not args.no_box
prior_plots = args.prior_plots
prior_name = args.prior
prior_file = args.prior_file
resume_preprocess = args.resume
num_cores = args.num_cores
jobid = args.jobid
totalNrJobs = args.totalNrJobs

# Resolve prior name from --prior-file if not explicitly set
prior_name = cfg.prior
prior_file = cfg.prior_file
if prior_file:
assert os.path.exists(prior_file), f"Prior file does not exist: {prior_file}"
prior_params_path = get_prior_params_path(prior_file)
with open(prior_params_path, "r", encoding="utf-8") as f:
prior_params = json.load(f)
prior_configuration_name = prior_params["prior_configuration_name"]
if prior_name is None:
prior_name = prior_configuration_name
elif prior_name != prior_configuration_name:
print()
print(f"WARNING: Prior \"{prior_name}\" differs from the one used to build the prior file \"{prior_configuration_name}\"")
print()

assert prior_name, " You must specify the prior to use with either --prior or --prior-file"
prior_configuration_name = prior_params["prior_configuration_name"]
if prior_name is None:
prior_name = prior_configuration_name
elif prior_name != prior_configuration_name:
print()
print(f"WARNING: Prior \"{prior_name}\" differs from the one in the prior file \"{prior_configuration_name}\"")
print()

if prior_name not in prior_types:
raise RuntimeError(f"Unknown prior configuration: {prior_name}")
print(f"Using prior: {prior_name}")
prior_builder = prior_types[prior_name]() # <- () to instantiate the class
prior_builder.enable_fit_constraints(not args.no_fit_constraints)
# prior_builder.enable_bond_tags(args.tag_beta_turns)

# Build and configure the prior
prior_builder = prior_types[prior_name]()
prior_builder.enable_fit_constraints(not cfg.no_fit_constraints)
# prior_builder.enable_bond_tags(cfg.tag_beta_turns)
prior_builder.enable_bond_tags(False)
prior_builder.set_min_cnt(args.fit_min_cnt)
prior_builder.set_min_cnt(cfg.fit_min_cnt)

if 'external' in prior_builder.prior_params.keys():
mp.set_start_method('spawn')

# Set matplotlib to use a thread safe backend (for prior fit plots)
# Thread-safe matplotlib backend required for prior fit plots
import matplotlib
matplotlib.use('Agg')

# Build dataset config from file paths or inline dicts
dataset_conf = []

for i in args.input:
for i in cfg.input:
if os.path.isfile(i):
with open(args.input[0], "r") as f:
with open(i, "r") as f:
dataset_conf += yaml.safe_load(f)
else:
dataset_conf += [{"path": i}]

input_path_map = gen_input_mapping(dataset_conf)
if cfg.pdbids:
input_path_map = {i: input_path_map[i] for i in cfg.pdbids}

return dict(
dataset_conf=dataset_conf,
input_path_map=input_path_map,
output_dir=cfg.output,
prior_builder=prior_builder,
prior_file=prior_file,
prior_name=prior_name,
frame_slice=frame_slice,
temp=cfg.temp,
optimize_forces=cfg.optimize_forces,
box=not cfg.no_box,
prior_plots=cfg.prior_plots,
resume_preprocess=cfg.resume,
num_cores=cfg.num_cores,
jobid=cfg.jobid,
totalNrJobs=cfg.totalNrJobs,
)

if pdbids:
input_path_map = {i: input_path_map[i] for i in pdbids}

preprocessor = Preprocessor(dataset_conf, input_path_map, output_dir, prior_builder, prior_file, prior_name, frame_slice, temp, optimize_forces, box, prior_plots, resume_preprocess, num_cores, jobid, totalNrJobs)

if __name__ == "__main__":
from module.base_config import BaseConfig
cfg = BaseConfig(build_parser())
print(cfg.as_namespace())
validate_config(cfg)
params = process_config(cfg)
preprocessor = Preprocessor(**params)
preprocessor.preprocess()

Loading