diff --git a/edit_checkpoint.py b/edit_checkpoint.py index a104b8b..c0e87dc 100755 --- a/edit_checkpoint.py +++ b/edit_checkpoint.py @@ -33,33 +33,52 @@ def remove_checkpoint_keys(checkpoint_path, keys): torch.save(checkpoint_dict, checkpoint_path) -if __name__ == "__main__": +# --------------------------------------------------------------------------- +# Entry-point helpers: collection → validation → processing +# --------------------------------------------------------------------------- + +def build_parser(): import argparse + arg_parser = argparse.ArgumentParser( + description="Inspect or modify a model checkpoint file.") + arg_parser.add_argument("checkpoint_path", help="Path to the model checkpoint (.pth or directory)") + arg_parser.add_argument("--reset-optimizer", action="store_true", help="Remove the optimizer state") + arg_parser.add_argument("--reset-scheduler", action="store_true", help="Remove the scheduler state") + arg_parser.add_argument("--reset-epoch", action="store_true", help="Remove the epoch counter and history") + arg_parser.add_argument("--info", action="store_true", help="Print the current epoch and model config") + return arg_parser - arg_parser = argparse.ArgumentParser() - arg_parser.add_argument("checkpoint_path", help="The model checkpoint path") - arg_parser.add_argument("--reset-optimizer", action="store_true", help="Reset the optimizer state") - arg_parser.add_argument("--reset-scheduler", action="store_true", help="Reset the scheduler state") - arg_parser.add_argument("--reset-epoch", action="store_true", help="Reset the epoch and history") - arg_parser.add_argument("--info", action="store_true", help="Show the current epoch & config") - args = arg_parser.parse_args() +def validate_config(cfg) -> None: + path = cfg.checkpoint_path + is_valid = os.path.isfile(path) or (os.path.isdir(path) and os.path.isfile(os.path.join(path, "checkpoint.pth"))) + assert is_valid, f"Checkpoint not found: {path}" - if args.info: - checkpoint = torch.load(args.checkpoint_path, map_location="cpu") - print(f"Epoch: {checkpoint['epoch']}") - print("Config:") - config_str = yaml.dump(checkpoint["hyper_parameters"]) - print("\n".join([" " + i for i in config_str.split("\n")])) +def process_config(cfg) -> list: keys_to_remove = [] - if args.reset_optimizer: + if cfg.reset_optimizer: keys_to_remove.append("optimizer") - if args.reset_scheduler: + if cfg.reset_scheduler: keys_to_remove.append("scheduler") - if args.reset_epoch: + if cfg.reset_epoch: keys_to_remove.append("epoch") + return keys_to_remove + + +if __name__ == "__main__": + from module.base_config import BaseConfig + cfg = BaseConfig(build_parser()) + validate_config(cfg) + + if cfg.info: + checkpoint = torch.load(cfg.checkpoint_path, map_location="cpu") + print(f"Epoch: {checkpoint['epoch']}") + print("Config:") + config_str = yaml.dump(checkpoint["hyper_parameters"]) + print("\n".join([" " + i for i in config_str.split("\n")])) - if len(keys_to_remove): - remove_checkpoint_keys(args.checkpoint_path, keys_to_remove) + keys_to_remove = process_config(cfg) + if keys_to_remove: + remove_checkpoint_keys(cfg.checkpoint_path, keys_to_remove) diff --git a/learn_prior.py b/learn_prior.py index 08b7cfa..333c911 100755 --- a/learn_prior.py +++ b/learn_prior.py @@ -556,29 +556,50 @@ def process_batch(batch, training_mol_list, learnable_ff): ### End Train -if __name__ == "__main__": +# --------------------------------------------------------------------------- +# Entry-point helpers: collection → validation → processing +# --------------------------------------------------------------------------- + +def build_parser(): import argparse parser = argparse.ArgumentParser(description="Train a TorchMD prior forcefield") parser.add_argument("input", help="Preprocessed prior & data to train on") parser.add_argument("result", help="Checkpoint save directory") - parser.add_argument("--pdbids", nargs="*", help="List of specific PDB IDs to process") + parser.add_argument("--pdbids", nargs="*", help="Restrict training to these PDB IDs") parser.add_argument("--batch", type=int, default=10, help="The batch size to use") parser.add_argument("--epochs", type=int, default=25, help="The total number of epochs to train for") parser.add_argument("--lr", type=float, default=0.005, help="Learning rate") - parser.add_argument("--gamma", type=float, default=None, help="Learning rate scheduler gamma") - parser.add_argument("--clip-parameters", action='store_true', help="Clip parameters to a valid range after each update") + parser.add_argument("--gamma", type=float, default=None, help="Exponential LR scheduler gamma") + parser.add_argument("--clip-parameters", action='store_true', + help="Clip parameters to a valid range after each update") parser.add_argument('--initial-loss', default=True, action=argparse.BooleanOptionalAction, - help="Whether to calculate initial loss before training. Default=True") - - args = parser.parse_args() - print(args) - - train(data_dir = args.input, - pdb_ids = args.pdbids, - checkpoint_dir = args.result, - n_epochs = args.epochs, - batch_size = args.batch, - learning_rate = args.lr, - scheduler_gamma = args.gamma, - clip_parameters = args.clip_parameters, - initial_loss_calc = args.initial_loss) + help="Calculate loss before the first training step (default=True)") + return parser + + +def validate_config(cfg) -> None: + assert os.path.isdir(cfg.input), f"Input directory does not exist: {cfg.input}" + assert cfg.lr > 0, "--lr must be positive" + assert cfg.epochs > 0, "--epochs must be positive" + + +def process_config(cfg) -> dict: + return dict( + data_dir=cfg.input, + pdb_ids=cfg.pdbids, + checkpoint_dir=cfg.result, + n_epochs=cfg.epochs, + batch_size=cfg.batch, + learning_rate=cfg.lr, + scheduler_gamma=cfg.gamma, + clip_parameters=cfg.clip_parameters, + initial_loss_calc=cfg.initial_loss, + ) + + +if __name__ == "__main__": + from module.base_config import BaseConfig + cfg = BaseConfig(build_parser()) + print(cfg.as_namespace()) + validate_config(cfg) + train(**process_config(cfg)) diff --git a/modules b/modules index 78e5ec5..7dd1aec 160000 --- a/modules +++ b/modules @@ -1 +1 @@ -Subproject commit 78e5ec5d0743420b0eb8e526d80c89a4e734b704 +Subproject commit 7dd1aec113b0fe14d318db3e57e2b714ab0b269d diff --git a/preprocess.py b/preprocess.py index d7657e8..a8b02e8 100755 --- a/preprocess.py +++ b/preprocess.py @@ -1007,100 +1007,140 @@ def gen_input_mapping(conf): "CA_lj_only":Prior_CA_lj_only, } -if __name__ == "__main__": +# --------------------------------------------------------------------------- +# Entry-point helpers: collection → validation → processing +# --------------------------------------------------------------------------- +def build_parser(): parser = argparse.ArgumentParser(description="Preprocess data.") - parser.add_argument("input", nargs = "+", help="Input directory path") + parser.add_argument("input", nargs="+", + help="Input directory paths or YAML dataset config files") parser.add_argument("-o", "--output", required=True, help="Output directory path") - parser.add_argument("--pdbids", nargs="*", help="List of specific PDB IDs to process") - parser.add_argument("--num-frames", "--num_frames", type=int, default=None, help="Number of frames to process") - parser.add_argument("--frame-slice", type=str, default=None, help="Select frames to process using a python slice: start:end:stride") - parser.add_argument("--temp", type=int, default=300, help="Temperature") - parser.add_argument("--prior", type=str, default=None, help="Select the prior forcefield to use, must be one of: " + ", ".join(sorted(prior_types.keys()))) - parser.add_argument("--optimize-forces", action="store_true", help="Use statistically optimal force aggregation (Kramer 2023)") - parser.add_argument("--prior-file", default=None, help="Use PRIOR_FILE instead of fitting a prior") - parser.add_argument('--no-box', default=False, action='store_true', help="Don't use periodic box information") - parser.add_argument('--prior-plots', default=True, action='store_true', help="Save plots of the prior fit functions") - parser.add_argument('--no-prior-plots', dest='prior_plots', action='store_false', help="Don't save plots of the prior fit functions") - parser.add_argument('--no-fit-constraints', default=False, action='store_true', help="Disable range constraints when fitting prior functions") - parser.add_argument('--fit-min-cnt', type=int, default=0, help="Only bins with cnt > min_cnt will be considered when fitting the prior (default 0)") - # parser.add_argument('--tag-beta-turns', default=False, action='store_true', help="Give beta turns a different bond type in the prior") - parser.add_argument('--resume', default=False, action='store_true', help="Resume processing rather than overwriting, all settings must be identical between calls") - parser.add_argument('--num-cores', type=int, default=32, help="Number of cores to be used for parallelization of preprocessing") - parser.add_argument('--jobid', type=int, default=None, help="Integer denoting jobid, if not -1, it will only process a subset of the PDBs") - parser.add_argument('--totalNrJobs', type=int, default=None, help="Integer denoting how many jobs are in total.") - - - args = parser.parse_args() - print(args) - - output_dir = args.output - pdbids = args.pdbids - assert not (args.num_frames and args.frame_slice) - if args.num_frames: - frame_slice = slice(0, args.num_frames) - elif args.frame_slice: - # Convert the arg string into a slice - frame_slice = slice(*[int(i) if i != '' else None for i in args.frame_slice.split(":") ]) + parser.add_argument("--pdbids", nargs="*", help="Restrict processing to these PDB IDs") + parser.add_argument("--num-frames", "--num_frames", type=int, default=None, + help="Number of frames to process (mutually exclusive with --frame-slice)") + parser.add_argument("--frame-slice", type=str, default=None, + help="Select frames via a Python slice string: start:end:stride " + "(mutually exclusive with --num-frames)") + parser.add_argument("--temp", type=int, default=300, help="Temperature (K)") + parser.add_argument("--prior", type=str, default=None, + help="Prior forcefield to use, one of: " + ", ".join(sorted(prior_types.keys()))) + parser.add_argument("--optimize-forces", action="store_true", + help="Use statistically optimal force aggregation (Kramer 2023)") + parser.add_argument("--prior-file", default=None, + help="Load a pre-built prior from PRIOR_FILE instead of fitting one") + parser.add_argument('--no-box', default=False, action='store_true', + help="Don't use periodic box information") + parser.add_argument('--prior-plots', default=True, action='store_true', + help="Save plots of the prior fit functions") + parser.add_argument('--no-prior-plots', dest='prior_plots', action='store_false', + help="Don't save plots of the prior fit functions") + parser.add_argument('--no-fit-constraints', default=False, action='store_true', + help="Disable range constraints when fitting prior functions") + parser.add_argument('--fit-min-cnt', type=int, default=0, + help="Minimum bin count to include when fitting the prior (default 0)") + # parser.add_argument('--tag-beta-turns', default=False, action='store_true', + # help="Give beta turns a different bond type in the prior") + parser.add_argument('--resume', default=False, action='store_true', + help="Resume an interrupted run — all settings must be identical") + parser.add_argument('--num-cores', type=int, default=32, + help="CPU cores for parallel preprocessing (default 32)") + parser.add_argument('--jobid', type=int, default=None, + help="Job index for array jobs; limits processing to a subset of PDBs") + parser.add_argument('--totalNrJobs', type=int, default=None, + help="Total number of array jobs (used with --jobid)") + return parser + + +def validate_config(cfg) -> None: + assert not (cfg.num_frames and cfg.frame_slice), \ + "--num-frames and --frame-slice are mutually exclusive" + if cfg.prior_file: + assert os.path.exists(cfg.prior_file), f"Prior file does not exist: {cfg.prior_file}" + assert cfg.prior or cfg.prior_file, \ + "Specify a prior with --prior or --prior-file" + + +def process_config(cfg): + # Frame selection + if cfg.num_frames: + frame_slice = slice(0, cfg.num_frames) + elif cfg.frame_slice: + frame_slice = slice(*[int(i) if i != '' else None for i in cfg.frame_slice.split(":")]) else: frame_slice = slice(None) - temp = args.temp - optimize_forces = args.optimize_forces - box = not args.no_box - prior_plots = args.prior_plots - prior_name = args.prior - prior_file = args.prior_file - resume_preprocess = args.resume - num_cores = args.num_cores - jobid = args.jobid - totalNrJobs = args.totalNrJobs + # Resolve prior name from --prior-file if not explicitly set + prior_name = cfg.prior + prior_file = cfg.prior_file if prior_file: - assert os.path.exists(prior_file), f"Prior file does not exist: {prior_file}" prior_params_path = get_prior_params_path(prior_file) with open(prior_params_path, "r", encoding="utf-8") as f: prior_params = json.load(f) - prior_configuration_name = prior_params["prior_configuration_name"] - if prior_name is None: - prior_name = prior_configuration_name - elif prior_name != prior_configuration_name: - print() - print(f"WARNING: Prior \"{prior_name}\" differs from the one used to build the prior file \"{prior_configuration_name}\"") - print() - - assert prior_name, " You must specify the prior to use with either --prior or --prior-file" + prior_configuration_name = prior_params["prior_configuration_name"] + if prior_name is None: + prior_name = prior_configuration_name + elif prior_name != prior_configuration_name: + print() + print(f"WARNING: Prior \"{prior_name}\" differs from the one in the prior file \"{prior_configuration_name}\"") + print() if prior_name not in prior_types: raise RuntimeError(f"Unknown prior configuration: {prior_name}") print(f"Using prior: {prior_name}") - prior_builder = prior_types[prior_name]() # <- () to instantiate the class - prior_builder.enable_fit_constraints(not args.no_fit_constraints) - # prior_builder.enable_bond_tags(args.tag_beta_turns) + + # Build and configure the prior + prior_builder = prior_types[prior_name]() + prior_builder.enable_fit_constraints(not cfg.no_fit_constraints) + # prior_builder.enable_bond_tags(cfg.tag_beta_turns) prior_builder.enable_bond_tags(False) - prior_builder.set_min_cnt(args.fit_min_cnt) + prior_builder.set_min_cnt(cfg.fit_min_cnt) if 'external' in prior_builder.prior_params.keys(): mp.set_start_method('spawn') - # Set matplotlib to use a thread safe backend (for prior fit plots) + # Thread-safe matplotlib backend required for prior fit plots import matplotlib matplotlib.use('Agg') + # Build dataset config from file paths or inline dicts dataset_conf = [] - - for i in args.input: + for i in cfg.input: if os.path.isfile(i): - with open(args.input[0], "r") as f: + with open(i, "r") as f: dataset_conf += yaml.safe_load(f) else: dataset_conf += [{"path": i}] input_path_map = gen_input_mapping(dataset_conf) + if cfg.pdbids: + input_path_map = {i: input_path_map[i] for i in cfg.pdbids} + + return dict( + dataset_conf=dataset_conf, + input_path_map=input_path_map, + output_dir=cfg.output, + prior_builder=prior_builder, + prior_file=prior_file, + prior_name=prior_name, + frame_slice=frame_slice, + temp=cfg.temp, + optimize_forces=cfg.optimize_forces, + box=not cfg.no_box, + prior_plots=cfg.prior_plots, + resume_preprocess=cfg.resume, + num_cores=cfg.num_cores, + jobid=cfg.jobid, + totalNrJobs=cfg.totalNrJobs, + ) - if pdbids: - input_path_map = {i: input_path_map[i] for i in pdbids} - - preprocessor = Preprocessor(dataset_conf, input_path_map, output_dir, prior_builder, prior_file, prior_name, frame_slice, temp, optimize_forces, box, prior_plots, resume_preprocess, num_cores, jobid, totalNrJobs) +if __name__ == "__main__": + from module.base_config import BaseConfig + cfg = BaseConfig(build_parser()) + print(cfg.as_namespace()) + validate_config(cfg) + params = process_config(cfg) + preprocessor = Preprocessor(**params) preprocessor.preprocess() diff --git a/simulate.py b/simulate.py index 02426b3..3eb8fab 100755 --- a/simulate.py +++ b/simulate.py @@ -531,33 +531,75 @@ def build_calc(model, mol, embeddings, use_box=False, replicas=1, temperature=30 class ArgsMock(): pass -def main(): +# --------------------------------------------------------------------------- +# Entry-point helpers: collection → validation → processing +# --------------------------------------------------------------------------- + +def build_parser(): import argparse arg_parser = argparse.ArgumentParser() arg_parser.add_argument("checkpoint_path", help="The model checkpoint to use") - arg_parser.add_argument("processed_path", help="One or more input files from which the model simulations will start. Each different input file will be processed as a different simulation replica.", nargs="+") - arg_parser.add_argument("--conf", default=None, help="The hyperparameters to use instead of those contained in the checkpoint") - arg_parser.add_argument("--max-num-neighbors", default=None, type=int, help="Override the 'max_num_neighbors' parameter of the model") + arg_parser.add_argument("processed_path", nargs="+", + help="One or more input files to start simulations from. " + "Each file is treated as a separate replica.") + arg_parser.add_argument("--conf", default=None, + help="Hyperparameter file to use instead of those stored in the checkpoint") + arg_parser.add_argument("--max-num-neighbors", default=None, type=int, + help="Override the 'max_num_neighbors' parameter of the model") arg_parser.add_argument("--temperature", default=300, type=int, help="Simulation temperature (Kelvin)") arg_parser.add_argument("--timestep", default=1, type=int, help="Simulation timestep (femtoseconds)") arg_parser.add_argument("--steps", default=10000, type=int, help="The number of frames to simulate") - arg_parser.add_argument("-o", "--output", default="sim.pdb", help="The output file name (may be a .pdb or .h5 file)") + arg_parser.add_argument("-o", "--output", default="sim.pdb", + help="Output file name (.pdb or .h5)") arg_parser.add_argument("--save-steps", default=100, type=int, help="Save frames every n steps") - arg_parser.add_argument("--prior-only", default=False, action='store_true', help="Disable the model and use only the prior forcefield") + arg_parser.add_argument("--prior-only", default=False, action='store_true', + help="Disable the model and run with the prior forcefield only") arg_parser.add_argument('--no-box', action='store_true', help='Do not use box information') - arg_parser.add_argument("--replicas", default=1, type=int, help="The number of simulations running in parallel") - arg_parser.add_argument("--torchmd", default=False, action='store_true', help="Use TorchMD for the prior instead of TorchForceField") - arg_parser.add_argument("--verbose", default=True, action='store_true', help="Prints detailed logs") - arg_parser.add_argument("--prior-nn", default=None, type=str, help="Path to the folder of a neural network prior.") - + arg_parser.add_argument("--replicas", default=1, type=int, + help="Number of parallel simulations to run") + arg_parser.add_argument("--torchmd", default=False, action='store_true', + help="Use TorchMD for the prior instead of TorchForceField") + arg_parser.add_argument("--verbose", default=True, action='store_true', help="Print detailed logs") + arg_parser.add_argument("--prior-nn", default=None, type=str, + help="Path to a neural network prior folder") + return arg_parser + + +def validate_config(cfg) -> None: + assert os.path.exists(cfg.checkpoint_path) or os.path.isdir(cfg.checkpoint_path), \ + f"Checkpoint not found: {cfg.checkpoint_path}" + output_dir = os.path.dirname(cfg.output) + if output_dir: + assert os.path.exists(output_dir), f"Output directory does not exist: {output_dir}" - args = arg_parser.parse_args() - print(args) - run_simulation(args.checkpoint_path, args.processed_path, conf=args.conf, max_num_neighbors=args.max_num_neighbors, - temperature=args.temperature, timestep=args.timestep, steps=args.steps, output=args.output, save_steps=args.save_steps, - prior_only=args.prior_only, no_box=args.no_box, replicas=args.replicas, torchmd=args.torchmd, verbose=args.verbose, - prior_nn=args.prior_nn, gpu=0) +def process_config(cfg) -> dict: + return dict( + checkpoint_path=cfg.checkpoint_path, + processed_path=cfg.processed_path, + conf=cfg.conf, + max_num_neighbors=cfg.max_num_neighbors, + temperature=cfg.temperature, + timestep=cfg.timestep, + steps=cfg.steps, + output=cfg.output, + save_steps=cfg.save_steps, + prior_only=cfg.prior_only, + no_box=cfg.no_box, + replicas=cfg.replicas, + torchmd=cfg.torchmd, + verbose=cfg.verbose, + prior_nn=cfg.prior_nn, + gpu=0, + ) + + +def main(): + from module.base_config import BaseConfig + cfg = BaseConfig(build_parser()) + validate_config(cfg) + params = process_config(cfg) + run_simulation(**params) def run_simulation(checkpoint_path, processed_path, conf=None, max_num_neighbors=None, temperature=300, timestep=1, steps=10000, output='sim.pdb', save_steps=100, diff --git a/train.py b/train.py index 40d8f55..39b4375 100755 --- a/train.py +++ b/train.py @@ -4,10 +4,12 @@ import torch.nn as nn import torch.optim as optim import torch.utils.data +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler import yaml import numpy as np -# from torchmdnet.models.model import create_model from module.torchmdnet.model import create_model from module import dataset from module import model_util @@ -17,42 +19,179 @@ import json import time from tqdm import tqdm -import datetime +import datetime import shutil import resource import sys import traceback import itertools +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple -# Type hinting... -from typing import Tuple from torch import Tensor -# Useful for debugging pytorch CUDA crashes -# os.environ["CUDA_LAUNCH_BLOCKING"]="1" +# os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + + +# =========================================================================== +# Distributed context +# =========================================================================== + +@dataclass +class DistributedContext: + rank: Optional[int] = None + world_size: Optional[int] = None + local_rank: Optional[int] = None + + @property + def is_main(self) -> bool: + return self.rank is None or self.rank == 0 + + @property + def is_distributed(self) -> bool: + return self.world_size is not None and self.world_size > 1 + + def barrier(self): + if self.is_distributed: + dist.barrier() + + def all_reduce_scalar(self, value: float, device) -> float: + if not self.is_distributed: + return value + t = torch.tensor(value, device=device) + dist.all_reduce(t, op=dist.ReduceOp.SUM) + return t.item() + + def broadcast_bool(self, value: bool, device) -> bool: + if not self.is_distributed: + return value + t = torch.tensor(1 if value else 0, device=device) + dist.broadcast(t, src=0) + return bool(t.item()) + + +def setup_distributed() -> DistributedContext: + """Auto-detect torchrun or SLURM environment and initialize NCCL process group.""" + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + rank = int(os.environ['RANK']) + world_size = int(os.environ['WORLD_SIZE']) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + elif 'SLURM_PROCID' in os.environ: + rank = int(os.environ['SLURM_PROCID']) + world_size = int(os.environ['SLURM_NTASKS']) + local_rank = int(os.environ.get('SLURM_LOCALID', 0)) + else: + return DistributedContext() + + torch.cuda.set_device(local_rank) + dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) + return DistributedContext(rank=rank, world_size=world_size, local_rank=local_rank) + + +def cleanup_distributed(): + if dist.is_initialized(): + dist.destroy_process_group() + + +# =========================================================================== +# Configuration dataclasses +# =========================================================================== + +@dataclass +class LossConfig: + energy_matching: bool + energy_weight: float + force_weight: float + train_term_def: 'TermDef' + extra_term_keys: List[str] + use_force_weights: bool + + +@dataclass +class EpochMetrics: + """Aggregated (summed) metrics for one epoch. Divide by num_cal to get means.""" + total_loss: float = 0.0 + energy_loss: float = 0.0 + force_loss: float = 0.0 + num_cal: int = 0 + term_losses: Dict[str, float] = field(default_factory=dict) + term_num_cal: Dict[str, int] = field(default_factory=dict) + + def mean_loss(self) -> float: + return self.total_loss / self.num_cal if self.num_cal > 0 else 0.0 + + def mean_energy_loss(self) -> float: + return self.energy_loss / self.num_cal if self.num_cal > 0 else 0.0 + + def mean_force_loss(self) -> float: + return self.force_loss / self.num_cal if self.num_cal > 0 else 0.0 + + def mean_term_loss(self, k: str) -> float: + n = self.term_num_cal.get(k, 0) + return self.term_losses.get(k, 0.0) / n if n > 0 else 0.0 + + +# =========================================================================== +# Pure utilities +# =========================================================================== def flatten_first(t): - """Flatten the first two dimentions of tensor t""" + """Flatten the first two dimensions of a tensor.""" if t is None: return t if len(t.shape) < 2: return t - return t.reshape(t.shape[0]*t.shape[1], *t.shape[2:]) + return t.reshape(t.shape[0] * t.shape[1], *t.shape[2:]) + def make_term_offsets(lengths, term_lengths): result = [] count = 0 - - repeats = len(term_lengths)//len(lengths) + repeats = len(term_lengths) // len(lengths) lengths = np.tile(lengths, repeats) assert len(lengths) == len(term_lengths) - - # For each batch we want to offset the indicies used by the terms by the number of atoms in the prior batches for off, nterms in zip(lengths, term_lengths): result.append(torch.full((nterms, 1), count, dtype=torch.long)) count += off return torch.cat(result) + +def deterministic_shuffle(target, seed): + generator = torch.Generator().manual_seed(seed) + indices = torch.randperm(n=len(target), generator=generator, device="cpu") + return [target[i] for i in indices] + + +def check_early_stopping(val_list, patience=1) -> bool: + """Return True if val loss increased patience+1 consecutive epochs. patience<0 disables.""" + if patience < 0: + return False + if len(val_list) < patience + 2: + return False + check_range = np.array(val_list[-(patience + 2):]) + if np.all((check_range[1:] - check_range[:-1]) > 0): + print(f"Validation loss increased {patience + 1} times, stopping...") + return True + return False + + +def should_decay(param_name: str) -> bool: + parts = param_name.split('.') + assert len(parts) > 0 + if parts[-1] == "bias": + return False + if parts[-2] == "embedding": + return False + if parts[-2] == "distance_expansion": + return False + assert parts[-1] == "weight" + return True + + +# =========================================================================== +# Model classes +# =========================================================================== + class BatchWrapper(nn.Module): def __init__(self, model): super().__init__() @@ -62,121 +201,44 @@ def forward(self, pos, lengths, **kwargs) -> Tuple[Tensor, Tensor]: batch_nums = dataset.make_batch_nums(len(pos), lengths) batch_nums = batch_nums.to(pos.device) assert batch_nums.device == pos.device - kwargs["pos"] = pos for k, v in kwargs.items(): kwargs[k] = flatten_first(v) - - #TODO: It would be better if the term lengths were also python lists like the batch lengths to avoid the round trip to the GPU and back if "bonds" in kwargs: kwargs["bonds"] = kwargs["bonds"] + make_term_offsets(lengths, kwargs.pop("len_bonds").cpu()).to(pos.device) if "angles" in kwargs: kwargs["angles"] = kwargs["angles"] + make_term_offsets(lengths, kwargs.pop("len_angles").cpu()).to(pos.device) if "dihedrals" in kwargs: kwargs["dihedrals"] = kwargs["dihedrals"] + make_term_offsets(lengths, kwargs.pop("len_dihedrals").cpu()).to(pos.device) - kwargs["batch"] = batch_nums result = self.model(**kwargs) if len(result) == 2: result = [*result, {}] - return result #pyright: ignore[reportReturnType] + return result # pyright: ignore[reportReturnType] -class TermDef(): - def __init__(self, path=None, conf=None): - self.scales = {} - self.angle_wrap = {} +class TermDef: + def __init__(self, path=None, conf=None): + self.scales: Dict[str, float] = {} + self.angle_wrap: Dict[str, bool] = {} if path: - with open(path, 'r') as file: - conf = yaml.safe_load(file) - + with open(path, 'r') as f: + conf = yaml.safe_load(f) if conf: for k, v in conf.items(): - if v is not None and "scale" in v: - self.scales[k] = float(v["scale"]) - else: - self.scales[k] = 1.0 - - if v is not None and "angle_wrap" in v: - self.angle_wrap[k] = bool(v["angle_wrap"]) - else: - self.angle_wrap[k] = False + self.scales[k] = float(v["scale"]) if v is not None and "scale" in v else 1.0 + self.angle_wrap[k] = bool(v["angle_wrap"]) if v is not None and "angle_wrap" in v else False - def get_names(self): + def get_names(self) -> List[str]: return list(self.scales.keys()) - def get_scale(self, name): + def get_scale(self, name: str) -> float: return self.scales[name] - def get_angle_wrap(self, name): + def get_angle_wrap(self, name: str) -> bool: return self.angle_wrap[name] -def deterministic_shuffle(target, seed): - generator = torch.Generator().manual_seed(seed) - indices = torch.randperm(n=len(target), generator=generator, device="cpu") - return [target[i] for i in indices] - -def check_early_stopping(val_list, patience=1): - """Return True if the number of epochs with increasing val_loss > patience. If patience < 0 always return False.""" - if patience < 0: - return False - if len(val_list) < patience+2: - return False - check_range = np.array(val_list[-(patience+2):]) - if np.all((check_range[1:]-check_range[:-1])>0): - print(f"Validation loss increased {patience+1} times, stopping...") - return True - -def save_checkpoint(checkpoint_path, epoch, model, optimizer, model_conf, scheduler, extra=None): - checkpoint_dict = { - "epoch":epoch, - "optimizer":optimizer.state_dict(), - "state_dict":model.state_dict(), - "hyper_parameters":model_conf, - } - - if scheduler: - checkpoint_dict["scheduler"] = scheduler.state_dict() - - if extra: - checkpoint_dict["extra"] = extra - - torch.save(checkpoint_dict, checkpoint_path) - -def gen_dataloaders(directory_path, pdb_list, energy_filename, embedding_filename, use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call): - print("Dataset:", " ".join(pdb_list)) - - all_data = dataset.ProteinDataset(directory_path, pdb_list, energy_file=energy_filename, embeddings_file=embedding_filename, use_npfile=use_npfile) - # num_proteins = all_data.num_proteins() - - assert val_ratio > 0.0 and val_ratio < 1.0 - val_size = int(val_ratio * len(all_data)) - train_size = len(all_data) - val_size - - if enable_shuffle: - # Generate the test and validation split with deterministic indices - generator1 = torch.Generator().manual_seed(12341234) - val_idx, train_idx = torch.utils.data.random_split(torch.arange(len(all_data)), [val_size, train_size], generator=generator1) #pyright: ignore[reportArgumentType] - else: - # Data was pre-shuffled during preprocess and can be read sequentially - train_idx = range(train_size) - val_idx = range(train_size, train_size+val_size) - train = torch.utils.data.Subset(all_data, train_idx) #pyright: ignore[reportArgumentType] - val = torch.utils.data.Subset(all_data, val_idx) #pyright: ignore[reportArgumentType] - - collate_fn = dataset.ProteinBatchCollate(atoms_per_call) - - train_data = DataLoader(train, batch_size=batch_size, shuffle=False, num_workers=4, - persistent_workers=True, pin_memory=True, collate_fn=collate_fn) - val_data = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=4, - persistent_workers=True, pin_memory=True, collate_fn=collate_fn) - - # print(f"Number of proteins in the dataset: {num_proteins}") - # print(f"Using periodic box: {all_data.has_box()}") - - return all_data, train_data, val_data - class RoundRobinDataWrapper: def __init__(self, *iterables): self.iterables = iterables @@ -185,617 +247,886 @@ def __len__(self): return sum(map(len, self.iterables)) def __iter__(self): - # From https://docs.python.org/3/library/itertools.html#itertools-recipes iterators = map(iter, self.iterables) for num_active in range(len(self.iterables), 0, -1): iterators = itertools.cycle(itertools.islice(iterators, num_active)) yield from map(next, iterators) -def train_model(directory_path, conf_path, result_directory, dry_run, gpu_ids, - weight_decay, learning_rate, epochs, batch_size, val_ratio, atoms_per_call, - scheduler, reset_early_stopping, enable_shuffle, mini_epoch_size, early_stopping, - checkpoint_save, subsetpdbs, energy_weight, force_weight, energy_matching, train_term_def, - embedding_filename, dataset_chunk_size, use_npfile): - with open(os.path.join(directory_path, "result", subsetpdbs), 'r') as file: - pdb_list = file.read().split('\n') +# =========================================================================== +# Data loading +# =========================================================================== - # Remove duplicates and empty strings - pdb_list = sorted(list(set([i for i in pdb_list if i]))) - pdb_list = deterministic_shuffle(pdb_list, seed=47563537) +def gen_dataloaders(directory_path, pdb_list, energy_filename, embedding_filename, + use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call, + num_workers, ddp: DistributedContext): + """Build train/val DataLoaders for one PDB chunk, with optional DistributedSampler.""" + print("Dataset:", " ".join(pdb_list)) + all_data = dataset.ProteinDataset( + directory_path, pdb_list, + energy_file=energy_filename, embeddings_file=embedding_filename, use_npfile=use_npfile, + ) + assert 0.0 < val_ratio < 1.0 + val_size = int(val_ratio * len(all_data)) + train_size = len(all_data) - val_size + + if enable_shuffle: + gen = torch.Generator().manual_seed(12341234) + val_idx, train_idx = torch.utils.data.random_split( + torch.arange(len(all_data)), [val_size, train_size], generator=gen + ) # pyright: ignore[reportArgumentType] + else: + train_idx = range(train_size) + val_idx = range(train_size, train_size + val_size) + train_set = torch.utils.data.Subset(all_data, train_idx) # pyright: ignore[reportArgumentType] + val_set = torch.utils.data.Subset(all_data, val_idx) # pyright: ignore[reportArgumentType] + collate_fn = dataset.ProteinBatchCollate(atoms_per_call) + + train_sampler = ( + DistributedSampler(train_set, num_replicas=ddp.world_size, rank=ddp.rank, shuffle=False) + if ddp.is_distributed else None + ) + val_sampler = ( + DistributedSampler(val_set, num_replicas=ddp.world_size, rank=ddp.rank, shuffle=False) + if ddp.is_distributed else None + ) + + train_loader = DataLoader( + train_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, + persistent_workers=True, pin_memory=True, collate_fn=collate_fn, sampler=train_sampler, + ) + val_loader = DataLoader( + val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, + persistent_workers=True, pin_memory=True, collate_fn=collate_fn, sampler=val_sampler, + ) + return all_data, train_loader, val_loader, train_sampler + + +def _load_pdb_list(directory_path, subsetpdbs, dataset_chunk_size): + """Read, deduplicate, shuffle, and optionally chunk the PDB ID list.""" + if subsetpdbs is None: + subsetpdbs = os.path.join(directory_path, "result", "ok_list.txt") + with open(subsetpdbs, 'r') as f: + pdb_list = f.read().split('\n') + pdb_list = sorted(set(p for p in pdb_list if p)) + pdb_list = deterministic_shuffle(pdb_list, seed=47563537) if dataset_chunk_size is not None: pdb_lists = [pdb_list[i:i + dataset_chunk_size] for i in range(0, len(pdb_list), dataset_chunk_size)] else: pdb_lists = [pdb_list] + return pdb_lists, pdb_list - # Load all proteins into a datasets - energy_filename = None - if energy_matching: - energy_filename = "tica_delta_energies.npy" - if embedding_filename is None: - embedding_filename = "embeddings.npy" - - datasets = [] - train_dataloaders = [] - val_dataloaders = [] +def _build_all_dataloaders(pdb_lists, directory_path, energy_filename, embedding_filename, + use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call, + num_workers, + ddp: DistributedContext): + """Build datasets and merged RoundRobin loaders for all PDB chunks.""" + datasets, train_loaders, val_loaders, train_samplers = [], [], [], [] for pdb_chunk in pdb_lists: - ds, train_loader, val_loader = gen_dataloaders(directory_path, pdb_chunk, energy_filename, embedding_filename, use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call) + ds, train_loader, val_loader, sampler = gen_dataloaders( + directory_path, pdb_chunk, energy_filename, embedding_filename, + use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call, num_workers, ddp, + ) datasets.append(ds) - train_dataloaders.append(train_loader) - val_dataloaders.append(val_loader) + train_loaders.append(train_loader) + val_loaders.append(val_loader) + train_samplers.append(sampler) + train_data = RoundRobinDataWrapper(*train_loaders) + val_data = RoundRobinDataWrapper(*val_loaders) + return datasets, train_data, val_data, train_samplers + - train_data = RoundRobinDataWrapper(*train_dataloaders) - val_data = RoundRobinDataWrapper(*val_dataloaders) +# =========================================================================== +# Model setup +# =========================================================================== - # Create the model +def _load_model_conf(conf_path, ddp: DistributedContext) -> dict: + """Load and print the model architecture YAML config.""" if conf_path is None: conf_path = "../configs/config.yaml" - with open(conf_path, 'r') as file: - conf = yaml.safe_load(file) - print("Config:\n", conf, "\n") + with open(conf_path, 'r') as f: + conf = yaml.safe_load(f) + if ddp.is_main: + print("Config:\n", conf, "\n") + return conf - if conf.get("external_embedding_channels") == None and embedding_filename != "embeddings.npy": - print("WARNING: external embeddings usually should use graph-network-ext network") - - # Set the network to return the harmonic term info if we're training them - if "harmonic_net" in conf and train_term_def.get_names(): - conf["harmonic_net_return_terms"] = True - - model = create_model(args=conf) - # We need to construct DataParallel and move the model to CUDA before - # initializing the optimizer or we get "Expected all tensors to be on the same device" - # errors. When exactly this error happens depends on how many GPUs are used and whether - # we're loading a checkpoint or not. - if gpu_ids == "cpu": - parallel_model = BatchWrapper(model) - device_src = "cpu" - device_output = "cpu" - print("Training on CPU") - else: - parallel_model = nn.DataParallel(BatchWrapper(model), device_ids=gpu_ids) - device_src = parallel_model.src_device_obj - device_output = parallel_model.output_device - print(f"DataParallel: Training on {len(parallel_model.device_ids)} GPU(s)") - model.to(device_src) - print("Model:\n", model, "\n") +def _augment_datasets(conf, datasets, train_term_def: TermDef, use_force_weights: bool, + embedding_filename: str, ddp: DistributedContext) -> List[str]: + """Add optional features (sequences, classical terms, force weights) to every dataset.""" + if conf.get("external_embedding_channels") is None and embedding_filename != "embeddings.npy": + if ddp.is_main: + print("WARNING: external embeddings usually should use graph-network-ext network") - extra_train_terms = [] + extra_train_terms: List[str] = [] - # Add additional features to the dataset if the model requires them if "sequence_basis_radius" in conf: - print(f"Adding sequences to dataset... (sequence_basis_radius={conf['sequence_basis_radius']})") + if ddp.is_main: + print(f"Adding sequences to dataset... (sequence_basis_radius={conf['sequence_basis_radius']})") for d in datasets: d.build_sequences() if "harmonic_net" in conf: - print(f"Adding classical terms to dataset... (harmonic_net={conf['harmonic_net']})") + if ddp.is_main: + print(f"Adding classical terms to dataset... (harmonic_net={conf['harmonic_net']})") for d in datasets: d.build_classical_terms() - - if train_term_def.get_names(): - # FIXME: Rename this to more generic - harmonic_trained_terms = train_term_def.get_names() - print(f"Loading additional trained terms: {harmonic_trained_terms}") + if train_term_def.get_names(): + harmonic_trained_terms = train_term_def.get_names() + if ddp.is_main: + print(f"Loading additional trained terms: {harmonic_trained_terms}") + print(f" Term Scales: {[train_term_def.get_scale(t) for t in harmonic_trained_terms]}") + print(f" Term Angle Wrap: {[train_term_def.get_angle_wrap(t) for t in harmonic_trained_terms]}") + for d in datasets: + d.load_frame_terms(harmonic_trained_terms) + extra_train_terms.extend(harmonic_trained_terms) + + if use_force_weights: + if ddp.is_main: + print("Loading force weights...") for d in datasets: - d.load_frame_terms(harmonic_trained_terms) - extra_train_terms.extend(harmonic_trained_terms) - print(f" Term Scales: {[train_term_def.get_scale(i) for i in harmonic_trained_terms]}") - print(f" Term Angle Wrap: {[train_term_def.get_angle_wrap(i) for i in harmonic_trained_terms]}") + d.load_frame_terms(["forces_weights"]) + + return extra_train_terms + + +def build_parallel_model(model, gpu_ids, ddp: DistributedContext): + """ + Wrap model for distributed/parallel execution. + Returns (parallel_model, device, device_output). + device — where weights live; use for checkpoints and all_reduce + device_output — where model outputs land; use for moving targets + """ + wrapped = BatchWrapper(model) + if ddp.is_distributed: + device = torch.device(f'cuda:{ddp.local_rank}') + model.to(device) + parallel = DDP(wrapped, device_ids=[ddp.local_rank], output_device=ddp.local_rank, + find_unused_parameters=True) + device_output = device + if ddp.is_main: + print(f"DDP: {ddp.world_size} GPUs (local_rank={ddp.local_rank})") + elif gpu_ids == "cpu": + device = torch.device("cpu") + model.to(device) + parallel = wrapped + device_output = device + print("Training on CPU") + else: + device = torch.device(f'cuda:{gpu_ids[0]}') + model.to(device) + parallel = nn.DataParallel(wrapped, device_ids=gpu_ids) + device_output = parallel.output_device + print(f"DataParallel: {len(parallel.device_ids)} GPU(s)") + return parallel, device, device_output - print() - criterion = nn.MSELoss() - term_criterion = nn.MSELoss(reduction="none") +# =========================================================================== +# Optimizer +# =========================================================================== - do_decay = [] - dont_decay = [] +def _build_optimizer(model, weight_decay: float, learning_rate: float): + """AdamW with weight decay applied only to non-embedding, non-bias parameters.""" + do_decay, dont_decay = [], [] for name, param in model.named_parameters(): - if should_decay(name): - do_decay.append(param) - else: - dont_decay.append(param) - - optimizer = optim.AdamW( - [ - {"params": do_decay, "weight_decay": weight_decay}, - {"params": dont_decay} - ], - lr=learning_rate) - + (do_decay if should_decay(name) else dont_decay).append(param) + return optim.AdamW( + [{"params": do_decay, "weight_decay": weight_decay}, {"params": dont_decay}], + lr=learning_rate, + ) + + +# =========================================================================== +# Checkpoint I/O +# =========================================================================== + +def save_checkpoint(checkpoint_path, epoch, model, optimizer, model_conf, scheduler, + extra=None, ddp: Optional[DistributedContext] = None): + """Save checkpoint. In distributed mode only rank 0 writes.""" + if ddp is not None and not ddp.is_main: + return + d = { + "epoch": epoch, + "optimizer": optimizer.state_dict(), + "state_dict": model.state_dict(), + "hyper_parameters": model_conf, + } if scheduler: - scheduler.initialize(optimizer) + d["scheduler"] = scheduler.state_dict() + if extra: + d["extra"] = extra + torch.save(d, checkpoint_path) + - epoch_resume = None +def _load_checkpoint_state(result_directory, model, optimizer, scheduler, device, + ddp: DistributedContext): + """Find and load the best available checkpoint. Returns (epoch, epoch_resume).""" checkpoint_path = None - if os.path.exists(f'{result_directory}/checkpoint-mini.pth'): + if result_directory and os.path.exists(f'{result_directory}/checkpoint-mini.pth'): checkpoint_path = f'{result_directory}/checkpoint-mini.pth' - elif os.path.exists(f'{result_directory}/checkpoint.pth'): + elif result_directory and os.path.exists(f'{result_directory}/checkpoint.pth'): checkpoint_path = f'{result_directory}/checkpoint.pth' - print("checkpoint_path", checkpoint_path) - if checkpoint_path: + if ddp.is_main: + print("checkpoint_path", checkpoint_path) + + if not checkpoint_path: + return 0, None + + if ddp.is_main: print("Resuming:", result_directory) - checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device_src) - model_util.load_state_dict_with_rename(model, checkpoint["state_dict"]) - if "optimizer" in checkpoint and checkpoint["optimizer"] is not None: - optimizer.load_state_dict(checkpoint["optimizer"]) - else: - print(" No optimizer in checkpoint, resetting...") + checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device) + model_util.load_state_dict_with_rename(model, checkpoint["state_dict"]) - if scheduler and "scheduler" in checkpoint and checkpoint["scheduler"] is not None: - scheduler.load_state_dict(checkpoint["scheduler"]) + if "optimizer" in checkpoint and checkpoint["optimizer"] is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + elif ddp.is_main: + print(" No optimizer in checkpoint, resetting...") - if "extra" in checkpoint: # This was a mini-checkpoint - epoch_resume = checkpoint["extra"] + if scheduler and "scheduler" in checkpoint and checkpoint["scheduler"] is not None: + scheduler.load_state_dict(checkpoint["scheduler"]) - if "epoch" in checkpoint: - epoch = checkpoint["epoch"] - else: - epoch = 0 + return checkpoint.get("epoch", 0), checkpoint.get("extra") + + +def _commit_epoch_checkpoint(epoch, model, optimizer, conf, scheduler, result_directory, + val_loss_list, history, checkpoint_save, ddp: DistributedContext): + """Atomically promote the epoch checkpoint and optionally save a best/backup copy.""" + tmp_path = f'{result_directory}/checkpoint-{epoch}.pth' + save_checkpoint(tmp_path, epoch + 1, model, optimizer, conf, scheduler, ddp=ddp) + + if not ddp.is_main: + return + + if os.path.exists(f'{result_directory}/checkpoint-mini.pth'): + os.unlink(f'{result_directory}/checkpoint-mini.pth') + + if checkpoint_save and (epoch % checkpoint_save == 0): + shutil.copyfile(tmp_path, f'{result_directory}/checkpoint.pth') else: - if not result_directory or not os.path.exists(result_directory): - if not result_directory: - result_directory = "../data/result-" + datetime.datetime.now().strftime("%Y.%m.%d-%H.%M.%S") - if not dry_run: - os.makedirs(result_directory, exist_ok=False) - else: - assert os.path.exists(result_directory) == False, "Result directory exists but is invalid" + os.replace(tmp_path, f'{result_directory}/checkpoint.pth') + + if val_loss_list[-1] <= np.min(val_loss_list): + shutil.copyfile(f'{result_directory}/checkpoint.pth', f'{result_directory}/checkpoint-best.pth') + + np.save(f'{result_directory}/history.npy', history) # pyright: ignore[reportArgumentType] + print(" Checkpoint saved.") + + +def _save_mini_epoch_checkpoint(epoch, i, model, optimizer, conf, scheduler, + total_loss, num_cal, mini_train_loss, mini_num_cal, + train_data, result_directory, epoch_history, + ddp: DistributedContext): + """Save a mid-epoch checkpoint and record the mini-epoch in epoch_history.""" + tmp_path = f'{result_directory}/checkpoint-{epoch}-{i}.pth' + save_checkpoint(tmp_path, epoch, model, optimizer, conf, scheduler, + extra={"train_loss": total_loss, "num_cal": num_cal, "i": i}, ddp=ddp) + if not ddp.is_main: + return + os.replace(tmp_path, f'{result_directory}/checkpoint-mini.pth') + epoch_history[f"{epoch}-{i}"] = { + "train_loss": total_loss / num_cal, + "mini_train_loss": mini_train_loss / mini_num_cal, + "epoch_len": len(train_data), + "lr": [g['lr'] for g in optimizer.param_groups], + } + with open(os.path.join(result_directory, "epoch_history.json"), "w") as f: + json.dump(epoch_history, f, indent=2) + + +# =========================================================================== +# Result directory and history +# =========================================================================== + +def _setup_result_directory(result_directory, epoch, dry_run, ddp: DistributedContext) -> str: + """Create the result directory for a fresh run, or validate it for resumption.""" + if epoch > 0: + pass # Resuming — directory already exists + elif not os.path.exists(result_directory): + if not dry_run and ddp.is_main: + os.makedirs(result_directory, exist_ok=False) + if ddp.is_main: print("Created:", result_directory) - elif os.path.exists(f'{result_directory}/training_info.json'): - # Most likely the training started but was canceled/crashed before the first epoch finished + elif os.path.exists(f'{result_directory}/training_info.json'): + if ddp.is_main: print("Re-initializing:", result_directory) - else: - raise RuntimeError("Model directory exists but doesn't contain a checkpoint.pth or training_info.json file") - epoch = 0 + else: + raise RuntimeError("Model directory exists but doesn't contain a checkpoint.pth or training_info.json file") - epoch_history = {} - train_loss_list = [] - val_loss_list = [] - energy_loss_list = [] - force_loss_list = [] + if ddp.is_main: + print("Saving to:", result_directory) + return result_directory + + +def _load_history(result_directory, epoch): + """Load loss history arrays and epoch_history JSON from a previous run.""" + train_loss_list, val_loss_list, energy_loss_list, force_loss_list = [], [], [], [] + epoch_history: dict = {} if epoch > 0: - # Load the numpy history files history = np.load(f'{result_directory}/history.npy', allow_pickle=True).item() train_loss_list = history['train'] val_loss_list = history['val'] energy_loss_list = history['energy'] force_loss_list = history['force'] - # Might exist before epoch 1 if mini-checkpoints were saved epoch_history_path = os.path.join(result_directory, "epoch_history.json") if os.path.exists(epoch_history_path): with open(epoch_history_path, "r") as f: epoch_history = json.load(f) - print("Saving to:", result_directory) + return train_loss_list, val_loss_list, energy_loss_list, force_loss_list, epoch_history + + +def _write_training_info(result_directory, epoch, weight_decay, learning_rate, epochs, batch_size, + directory_path, pdb_list, energy_weight, force_weight, embedding_filename, + ddp: DistributedContext, scheduler, optimizer, dry_run): + """Write training_info.json and copy prior files. Only runs on rank 0.""" + if not ddp.is_main: + return - # Document training parameters and input data training_info_path = os.path.join(result_directory, "training_info.json") - training_info_dict = {} + training_info_dict: dict = {} if os.path.exists(training_info_path): with open(training_info_path, "r") as f: training_info_dict = json.load(f) - - # Check for the old dict format and update it - if "input_directory" in training_info_dict.keys(): + if "input_directory" in training_info_dict: training_info_dict = {"0": training_info_dict} else: print("Path", training_info_path, "does not exist") - # TODO: Only add a new entry if the parameters have changed? training_info_dict[str(epoch)] = { - "weight_decay" : weight_decay, - "learning_rate" : learning_rate, - "epochs" : epochs, - "batch_size" : batch_size, - "input_directory" : directory_path, - "pdbs" : pdb_list, + "weight_decay": weight_decay, + "learning_rate": learning_rate, + "epochs": epochs, + "batch_size": batch_size, + "input_directory": directory_path, + "pdbs": pdb_list, "energy_weight": energy_weight, "force_weight": force_weight, - "embedding_filename" : embedding_filename, + "embedding_filename": embedding_filename, + "world_size": ddp.world_size if ddp.world_size else 1, } + if scheduler: training_info_dict[str(epoch)]["lr_scheduler"] = repr(scheduler) else: - # If there's no scheduler reset the learning of the optimizer to the passed value for g in optimizer.param_groups: g['lr'] = learning_rate - if not dry_run: with open(training_info_path, "w") as f: json.dump(training_info_dict, f, indent=2) - - # Save the validation frame indices - # FIXME: This isn't compatible with chunking - #np.save(os.path.join(result_directory, "validation_frames.npy"), np.array(val_idx)) - - # Save the prior with the model prior_path = os.path.join(directory_path, "priors.yaml") if os.path.exists(prior_path): - prior_params_path = os.path.join(directory_path, "prior_params.json") shutil.copy(prior_path, result_directory) - shutil.copy(prior_params_path, result_directory) - - # Disable earlly stopping when using an annealing (cycling) schedualer - if scheduler and scheduler.is_annealing(): - early_stopping = -1 - - first_early_stopping_epoch = 0 - if reset_early_stopping == True: - first_early_stopping_epoch = epoch + shutil.copy(os.path.join(directory_path, "prior_params.json"), result_directory) + + +# =========================================================================== +# Per-epoch helpers +# =========================================================================== + +def _set_sampler_epochs(train_samplers, epoch): + """Notify DistributedSamplers of the current epoch so shuffling stays deterministic.""" + for sampler in train_samplers: + if sampler is not None: + sampler.set_epoch(epoch) + + +def _update_loss_history(train_loss_list, val_loss_list, energy_loss_list, force_loss_list, + train_metrics: EpochMetrics, val_metrics: EpochMetrics): + """Append per-epoch means to the running history lists.""" + train_loss_list.append(train_metrics.mean_loss()) + energy_loss_list.append(train_metrics.mean_energy_loss()) + force_loss_list.append(train_metrics.mean_force_loss()) + val_loss_list.append(val_metrics.mean_loss()) + + +def _log_epoch_results(ddp: DistributedContext, epoch, train_loss_list, val_loss_list, + train_metrics: EpochMetrics, val_metrics: EpochMetrics, + extra_train_terms, epoch_history, optimizer, t0, train_data, + result_directory): + """Record epoch stats in epoch_history, write JSON, and print a summary.""" + if not ddp.is_main: + return + + epoch_history[f"{epoch}"] = { + "train_loss": train_loss_list[-1], + "val_loss": val_loss_list[-1], + "energy_loss": train_metrics.mean_energy_loss(), + "force_loss": train_metrics.mean_force_loss(), + "epoch_len": len(train_data), + "lr": [g['lr'] for g in optimizer.param_groups], + } + for k in extra_train_terms: + epoch_history[f"{epoch}"][f"train_loss_{k}"] = train_metrics.mean_term_loss(k) + epoch_history[f"{epoch}"][f"val_loss_{k}"] = val_metrics.mean_term_loss(k) - verbose_loss_report = sys.stdout.isatty() + with open(os.path.join(result_directory, "epoch_history.json"), "w") as f: + json.dump(epoch_history, f, indent=2) - while epoch < epochs: - t0 = time.time() - model.train() - train_loss = 0 - train_energy_loss = 0 - train_force_loss = 0 - num_cal = 0 # The total number of elements trained on - epoch_offset = 0 - mini_train_loss = 0 - mini_num_cal = 0 - - train_term_losses = {k: 0.0 for k in extra_train_terms} - train_term_num_cal = {k: 0 for k in extra_train_terms} - - if epoch_resume: + print(f"Epoch {epoch} - Train Loss: {train_loss_list[-1]} - Val Loss: {val_loss_list[-1]} - time: {round(time.time() - t0, 2)}s") + if epoch > 0: + print(f" ∆Train: {train_loss_list[-1] - train_loss_list[-2]} - ∆Val: {val_loss_list[-1] - val_loss_list[-2]}") + for k in extra_train_terms: + print(f" Train {k} loss={train_metrics.mean_term_loss(k):.4f}") + print(f" Val {k} loss={val_metrics.mean_term_loss(k):.4f}") + + +def _check_early_stop(ddp: DistributedContext, val_loss_list, first_early_stopping_epoch, + early_stopping, device) -> bool: + """Evaluate early stopping on rank 0 then broadcast the decision to all ranks.""" + should_stop = False + if ddp.is_main: + should_stop = check_early_stopping(val_loss_list[first_early_stopping_epoch:], patience=early_stopping) + should_stop = ddp.broadcast_bool(should_stop, device) + if should_stop and ddp.is_main: + print("Early stopping triggered.") + return should_stop + + +# =========================================================================== +# Forward pass (shared between train and val) +# =========================================================================== + +def _process_sub_batch(sub_batch, parallel_model, criterion, term_criterion, + device_output, loss_cfg: LossConfig, + total_batch_size: int, total_term_batch_size: Dict[str, int]): + """ + Forward one sub-batch through the model and compute weighted losses. + Returns (loss_tensor, loss_sum, energy_loss_sum, force_loss_sum, term_loss_sums). + *_sum values are pre-multiplied by total_batch_size for direct accumulation. + Call loss_tensor.backward() in the training loop. + """ + force = sub_batch.pop("force") + force = force.reshape(-1, force.shape[-1]).to(device_output) + + force_weights = None + if loss_cfg.use_force_weights: + force_weights = sub_batch.pop("forces_weights").reshape(-1).to(device_output) + + energy = None + if loss_cfg.energy_matching: + energy = sub_batch.pop("energy") + energy = energy.reshape(-1, energy.shape[-1]).to(device_output) + + term_targets = {k: sub_batch.pop(k).flatten().to(device_output) for k in loss_cfg.extra_term_keys} + + out_energy, out_force, extra = parallel_model(**sub_batch) + + sub_batch_size = force.numel() + weight = sub_batch_size / total_batch_size + + energy_loss = torch.tensor(0.0) + if loss_cfg.energy_matching: + energy_loss = criterion(out_energy, energy) * weight + + fl = term_criterion(out_force, force) + if force_weights is not None: + fl = fl * force_weights[:, None] + force_loss = fl.mean() * weight + + loss = loss_cfg.energy_weight * energy_loss + loss_cfg.force_weight * force_loss + + term_loss_sums: Dict[str, float] = {} + for k in loss_cfg.extra_term_keys: + td = loss_cfg.train_term_def + if td.get_angle_wrap(k): + tl = (extra[k] - term_targets[k] + torch.pi) % (2 * torch.pi) - torch.pi + tl = tl ** 2 + else: + tl = term_criterion(extra[k], term_targets[k]) + tl = tl / total_term_batch_size[k] + tl = tl * (term_targets[k] >= -10).float() + tl = torch.sum(tl) + loss = loss + tl * td.get_scale(k) + term_loss_sums[k] = tl.item() * total_term_batch_size[k] + + return loss, loss.item() * total_batch_size, energy_loss.item() * total_batch_size, force_loss.item() * total_batch_size, term_loss_sums + + +def _reduce_to_epoch_metrics(total_loss, energy_loss, force_loss, num_cal, + term_losses, term_num_cal, extra_term_keys, + ddp: DistributedContext, device) -> EpochMetrics: + """All-reduce raw accumulators across ranks and package into EpochMetrics.""" + return EpochMetrics( + total_loss=ddp.all_reduce_scalar(total_loss, device), + energy_loss=ddp.all_reduce_scalar(energy_loss, device), + force_loss=ddp.all_reduce_scalar(force_loss, device), + num_cal=int(ddp.all_reduce_scalar(num_cal, device)), + term_losses={k: ddp.all_reduce_scalar(term_losses[k], device) for k in extra_term_keys}, + term_num_cal={k: int(ddp.all_reduce_scalar(term_num_cal[k], device)) for k in extra_term_keys}, + ) + + +# =========================================================================== +# Epoch runners +# =========================================================================== + +def run_train_epoch( + epoch, epochs, train_data, parallel_model, model, + optimizer, scheduler, criterion, term_criterion, + device, device_output, loss_cfg: LossConfig, ddp: DistributedContext, + dry_run: bool, verbose_loss_report: bool, + mini_epoch_size, epoch_resume, result_directory, conf, epoch_history, +) -> EpochMetrics: + parallel_model.train() + + total_loss = 0.0 + energy_loss = 0.0 + force_loss = 0.0 + num_cal = 0 + mini_train_loss = 0.0 + mini_num_cal = 0 + epoch_offset = 0 + term_losses: Dict[str, float] = {k: 0.0 for k in loss_cfg.extra_term_keys} + term_num_cal: Dict[str, int] = {k: 0 for k in loss_cfg.extra_term_keys} + + if epoch_resume: + if ddp.is_main: print("Resuming epoch...") - train_loss = float(epoch_resume["train_loss"]) - num_cal = int(epoch_resume["num_cal"]) - epoch_offset = int(epoch_resume["i"]) - epoch_resume = None - - - # Setting miniters is required to keep the bar from stalling after skipping ahead while resuming a batch - tqdm_iter = tqdm(enumerate(train_data), desc=f"Training ({epoch}/{epochs})", total=len(train_data), dynamic_ncols=True, miniters=1) - for i, batch in tqdm_iter: - # Handle mini-batches - if i < epoch_offset: - # TODO: It's very wasteful to load everything then discard it, but the alternative requires making a 2nd dataset object... - continue - elif epoch_offset and i == epoch_offset: + total_loss = float(epoch_resume["train_loss"]) + num_cal = int(epoch_resume["num_cal"]) + epoch_offset = int(epoch_resume["i"]) + + # Setting miniters is required to keep the bar from stalling when skipping ahead on resume + tqdm_iter = ( + tqdm(enumerate(train_data), desc=f"Training ({epoch}/{epochs})", + total=len(train_data), dynamic_ncols=True, miniters=1) + if ddp.is_main else enumerate(train_data) + ) + + for i, batch in tqdm_iter: + if i < epoch_offset: + continue + elif epoch_offset and i == epoch_offset: + if ddp.is_main and hasattr(tqdm_iter, 'write'): tqdm_iter.write(f"Resumed epoch at batch {i}") - elif mini_epoch_size and 0 == i % mini_epoch_size and i > 0: - tmp_checkpoint_path = f'{result_directory}/checkpoint-{epoch}-{i}.pth' - save_checkpoint(tmp_checkpoint_path, epoch, model, optimizer, conf, scheduler, extra = {"train_loss":train_loss, "num_cal":num_cal, "i":i}) - os.replace(tmp_checkpoint_path, f'{result_directory}/checkpoint-mini.pth') - - epoch_history[f"{epoch}-{i}"] = { - "train_loss":train_loss/num_cal, - "mini_train_loss":mini_train_loss/mini_num_cal, - "epoch_len":len(train_data), - "lr":[g['lr'] for g in optimizer.param_groups], - } - with open(epoch_history_path, "w") as f: - json.dump(epoch_history, f, indent=2) - tqdm_iter.write(f"Mini-epoch {epoch}-{i}: Train Loss: {train_loss/num_cal}") - - mini_train_loss = 0 - mini_num_cal = 0 - total_batch_size = sum([i["force"].numel() for i in batch]) - num_cal += total_batch_size - mini_num_cal += total_batch_size - - total_term_batch_size = {k: sum([i[k].numel() for i in batch]) for k in train_term_losses} - for k in train_term_num_cal.keys(): - train_term_num_cal[k] += total_term_batch_size[k] - - optimizer.zero_grad() - for sub_batch in batch: - force = sub_batch.pop("force") - force = force.reshape(-1, force.shape[-1]).to(device_output) - energy = None - if energy_matching: - energy = sub_batch.pop("energy") - energy = energy.reshape(-1, energy.shape[-1]).to(device_output) - - term_targets = {} - for k in train_term_losses.keys(): - term_targets[k] = sub_batch.pop(k).flatten().to(device_output) - - out_energy, out_force, extra = parallel_model(**sub_batch) - - # Scale the sub_batch to be a term in the overall mean of the batch - sub_batch_size = force.numel() - energy_loss: torch.Tensor = torch.tensor(0.0) - if energy_matching: - energy_loss = criterion(out_energy, energy) * (sub_batch_size / total_batch_size) - force_loss = criterion(out_force, force) * (sub_batch_size / total_batch_size) - loss = energy_weight * energy_loss + force_weight * force_loss - - train_force_loss += force_loss.item() * total_batch_size - if energy_matching: - train_energy_loss += (energy_loss.item() * total_batch_size) - - delta_loss = loss.item() * total_batch_size - train_loss += delta_loss - mini_train_loss += delta_loss - - for k in train_term_losses.keys(): - # TODO: Find a more generic way of doing this - if train_term_def.get_angle_wrap(k): - train_term_loss = (extra[k] - term_targets[k] + torch.pi) % (2*torch.pi) - torch.pi - train_term_loss = train_term_loss**2 - else: - train_term_loss = term_criterion(extra[k], term_targets[k]) - - # We don't multiply by numel here because the term loss criterion doesn't do a mean reduction - train_term_loss = train_term_loss / total_term_batch_size[k] - # Mask out undefined values - # TODO: Ensure this is a good threshold value - train_term_loss = train_term_loss * (term_targets[k] >= -10).float() - train_term_loss = torch.sum(train_term_loss) - - loss = loss + train_term_loss*train_term_def.get_scale(k) - - train_term_losses[k] += train_term_loss.item() * total_term_batch_size[k] - - # Accumulate gradient - loss.backward() - - - - optimizer.step() - if scheduler: - scheduler.step_batch(epoch + i/len(train_data)) - if dry_run: - print("\nDry run OK!") - sys.exit(0) + elif mini_epoch_size and i > 0 and i % mini_epoch_size == 0: + _save_mini_epoch_checkpoint( + epoch, i, model, optimizer, conf, scheduler, + total_loss, num_cal, mini_train_loss, mini_num_cal, + train_data, result_directory, epoch_history, ddp, + ) + if ddp.is_main and hasattr(tqdm_iter, 'write'): + tqdm_iter.write(f"Mini-epoch {epoch}-{i}: Train Loss: {total_loss / num_cal}") + mini_train_loss = 0.0 + mini_num_cal = 0 + + batch_size_total = sum(sb["force"].numel() for sb in batch) + batch_term_sizes = {k: sum(sb[k].numel() for sb in batch) for k in loss_cfg.extra_term_keys} + num_cal += batch_size_total + mini_num_cal += batch_size_total + for k in term_num_cal: + term_num_cal[k] += batch_term_sizes[k] + + optimizer.zero_grad() + for sub_batch in batch: + loss, loss_sum, e_sum, f_sum, term_sums = _process_sub_batch( + sub_batch, parallel_model, criterion, term_criterion, + device_output, loss_cfg, batch_size_total, batch_term_sizes, + ) + total_loss += loss_sum + mini_train_loss += loss_sum + energy_loss += e_sum + force_loss += f_sum + for k, v in term_sums.items(): + term_losses[k] += v + loss.backward() + optimizer.step() - if verbose_loss_report: - desc = [f"Training ({epoch}/{epochs}) (T={train_loss/num_cal:.4f}"] - for t in train_term_losses: - desc.append(f"{t}={train_term_losses[t]/train_term_num_cal[t]:.4f}") - desc = ", ".join(desc) + ")" - tqdm_iter.set_description(desc) + if scheduler: + scheduler.step_batch(epoch + i / len(train_data)) + if dry_run: + if ddp.is_main: + print("\nDry run OK!") + sys.exit(0) + + if verbose_loss_report and ddp.is_main and hasattr(tqdm_iter, 'set_description'): + desc = [f"Training ({epoch}/{epochs}) (T={total_loss / num_cal:.4f}"] + for t in loss_cfg.extra_term_keys: + if term_num_cal[t] > 0: + desc.append(f"{t}={term_losses[t] / term_num_cal[t]:.4f}") + tqdm_iter.set_description(", ".join(desc) + ")") + + return _reduce_to_epoch_metrics( + total_loss, energy_loss, force_loss, num_cal, term_losses, term_num_cal, + loss_cfg.extra_term_keys, ddp, device, + ) + + +def run_val_epoch( + epoch, epochs, val_data, parallel_model, + criterion, term_criterion, device, device_output, + loss_cfg: LossConfig, ddp: DistributedContext, +) -> EpochMetrics: + parallel_model.eval() + + total_loss = 0.0 + num_cal = 0 + term_losses: Dict[str, float] = {k: 0.0 for k in loss_cfg.extra_term_keys} + term_num_cal: Dict[str, int] = {k: 0 for k in loss_cfg.extra_term_keys} + + val_iter = ( + tqdm(val_data, desc=f"Validation ({epoch}/{epochs})", total=len(val_data), dynamic_ncols=True) + if ddp.is_main else val_data + ) + + with torch.no_grad(): + for batch in val_iter: + batch_size_total = sum(sb["force"].numel() for sb in batch) + batch_term_sizes = {k: sum(sb[k].numel() for sb in batch) for k in loss_cfg.extra_term_keys} + num_cal += batch_size_total + for k in term_num_cal: + term_num_cal[k] += batch_term_sizes[k] - train_loss_list.append(train_loss/num_cal) - energy_loss_list.append(train_energy_loss/num_cal) - force_loss_list.append(train_force_loss/num_cal) + for sub_batch in batch: + _, loss_sum, _, _, term_sums = _process_sub_batch( + sub_batch, parallel_model, criterion, term_criterion, + device_output, loss_cfg, batch_size_total, batch_term_sizes, + ) + total_loss += loss_sum + for k, v in term_sums.items(): + term_losses[k] += v + + return _reduce_to_epoch_metrics( + total_loss, 0.0, 0.0, num_cal, term_losses, term_num_cal, + loss_cfg.extra_term_keys, ddp, device, + ) + + +# =========================================================================== +# Training orchestration +# =========================================================================== + +def train_model( + directory_path, conf_path, result_directory, dry_run, gpu_ids, + weight_decay, learning_rate, epochs, batch_size, val_ratio, atoms_per_call, + scheduler, reset_early_stopping, enable_shuffle, mini_epoch_size, early_stopping, + checkpoint_save, subsetpdbs, energy_weight, force_weight, energy_matching, + train_term_def, embedding_filename, dataset_chunk_size, use_npfile, use_force_weights, + num_workers, + ddp: Optional[DistributedContext] = None, +): + if ddp is None: + ddp = DistributedContext() + + energy_filename = "tica_delta_energies.npy" if energy_matching else None + if embedding_filename is None: + embedding_filename = "embeddings.npy" - parallel_model.eval() - val_loss = 0 - num_cal = 0 + pdb_lists, pdb_list = _load_pdb_list(directory_path, subsetpdbs, dataset_chunk_size) + datasets, train_data, val_data, train_samplers = _build_all_dataloaders( + pdb_lists, directory_path, energy_filename, embedding_filename, + use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call, num_workers, ddp, + ) - val_term_losses = {k: 0.0 for k in train_term_losses} - val_term_num_cal = {k: 0 for k in train_term_num_cal} + conf = _load_model_conf(conf_path, ddp) + if "harmonic_net" in conf and train_term_def.get_names(): + conf["harmonic_net_return_terms"] = True - for batch in tqdm(val_data, desc=f"Validation ({epoch}/{epochs})", total=len(val_data), dynamic_ncols=True): - total_batch_size = sum([i["force"].numel() for i in batch]) - num_cal += total_batch_size + model = create_model(args=conf) + parallel_model, device, device_output = build_parallel_model(model, gpu_ids, ddp) + if ddp.is_main: + print("Model:\n", model, "\n") + + extra_train_terms = _augment_datasets(conf, datasets, train_term_def, use_force_weights, embedding_filename, ddp) + loss_cfg = LossConfig( + energy_matching=energy_matching, energy_weight=energy_weight, force_weight=force_weight, + train_term_def=train_term_def, extra_term_keys=extra_train_terms, use_force_weights=use_force_weights, + ) + criterion = nn.MSELoss() + term_criterion = nn.MSELoss(reduction="none") - total_term_batch_size = {k: sum([i[k].numel() for i in batch]) for k in val_term_losses} - for k in val_term_num_cal.keys(): - val_term_num_cal[k] += total_term_batch_size[k] + optimizer = _build_optimizer(model, weight_decay, learning_rate) + if scheduler: + scheduler.initialize(optimizer) - for sub_batch in batch: - force = sub_batch.pop("force") - force = force.reshape(-1, force.shape[-1]).to(device_output) - energy = None - if energy_matching: - energy = sub_batch.pop("energy") - energy = energy.reshape(-1, energy.shape[-1]).to(device_output) + epoch, epoch_resume = _load_checkpoint_state(result_directory, model, optimizer, scheduler, device, ddp) + result_directory = _setup_result_directory(result_directory, epoch, dry_run, ddp) + ddp.barrier() - term_targets = {} - for k in val_term_losses.keys(): - term_targets[k] = sub_batch.pop(k).flatten().to(device_output) + train_loss_list, val_loss_list, energy_loss_list, force_loss_list, epoch_history = _load_history(result_directory, epoch) + _write_training_info( + result_directory, epoch, weight_decay, learning_rate, epochs, batch_size, + directory_path, pdb_list, energy_weight, force_weight, embedding_filename, + ddp, scheduler, optimizer, dry_run, + ) - out_energy, out_force, extra = parallel_model(**sub_batch) + if scheduler and scheduler.is_annealing(): + early_stopping = -1 + first_early_stopping_epoch = epoch if reset_early_stopping else 0 + verbose_loss_report = sys.stdout.isatty() - sub_batch_size = force.numel() - energy_loss: torch.Tensor = torch.tensor(0.0) - if energy_matching: - energy_loss = criterion(out_energy, energy) * (sub_batch_size / total_batch_size) - force_loss = criterion(out_force, force) * (sub_batch_size / total_batch_size) - loss = energy_weight * energy_loss + force_weight * force_loss + while epoch < epochs: + _set_sampler_epochs(train_samplers, epoch) + t0 = time.time() - val_loss += loss.item() * total_batch_size + train_metrics = run_train_epoch( + epoch, epochs, train_data, parallel_model, model, + optimizer, scheduler, criterion, term_criterion, + device, device_output, loss_cfg, ddp, + dry_run, verbose_loss_report, + mini_epoch_size, epoch_resume, result_directory, conf, epoch_history, + ) + epoch_resume = None - for k in val_term_losses.keys(): - if train_term_def.get_angle_wrap(k): - val_term_loss = (extra[k] - term_targets[k] + torch.pi) % (2*torch.pi) - torch.pi - val_term_loss = val_term_loss**2 - else: - val_term_loss = term_criterion(extra[k], term_targets[k]) - val_term_loss = val_term_loss / total_term_batch_size[k] - val_term_loss = val_term_loss * (term_targets[k] >= -10).float() - val_term_loss = torch.sum(val_term_loss) + val_metrics = run_val_epoch( + epoch, epochs, val_data, parallel_model, + criterion, term_criterion, device, device_output, loss_cfg, ddp, + ) - # loss = loss + val_term_loss*term_val_weight + _update_loss_history(train_loss_list, val_loss_list, energy_loss_list, force_loss_list, + train_metrics, val_metrics) - val_term_losses[k] += val_term_loss.item() * total_term_batch_size[k] + if scheduler: + scheduler.step(val_loss_list[-1]) - val_loss_list.append(val_loss/num_cal) + _log_epoch_results(ddp, epoch, train_loss_list, val_loss_list, train_metrics, val_metrics, + extra_train_terms, epoch_history, optimizer, t0, train_data, result_directory) - if scheduler: - scheduler.step(val_loss/num_cal) - - epoch_history[f"{epoch}"] = { - "train_loss":train_loss_list[-1], - "val_loss":val_loss_list[-1], - "energy_loss":energy_loss_list[-1], - "force_loss":force_loss_list[-1], - "epoch_len":len(train_data), - "lr":[g['lr'] for g in optimizer.param_groups], - } - - for k in extra_train_terms: - epoch_history[f"{epoch}"][f"train_loss_{k}"] = train_term_losses[k]/train_term_num_cal[k] - epoch_history[f"{epoch}"][f"val_loss_{k}"] = val_term_losses[k]/val_term_num_cal[k] - - with open(epoch_history_path, "w") as f: - json.dump(epoch_history, f, indent=2) - - print(f"Epoch {epoch} - Train Loss: {train_loss_list[-1]} - Val Loss: {val_loss_list[-1]} - time: {round(time.time() - t0,2)}s") - if epoch > 0: - print(f" ∆Train: {train_loss_list[-1]-train_loss_list[-2]} - ∆Val: {val_loss_list[-1] - val_loss_list[-2]}") - # print(f" ∆Energy: {energy_loss_list[-1]-energy_loss_list[-2]} - ∆Force: {force_loss_list[-1] - force_loss_list[-2]}") - - for k in val_term_losses.keys(): - print(f" Train {k} loss={train_term_losses[k]/train_term_num_cal[k]:.4f}") - print(f" Val {k} loss={val_term_losses[k]/val_term_num_cal[k]:.4f}") - - if check_early_stopping(val_loss_list[first_early_stopping_epoch:], patience=early_stopping): - print("Early stopping triggered.") + if _check_early_stop(ddp, val_loss_list, first_early_stopping_epoch, early_stopping, device): break history = {"train": train_loss_list, "val": val_loss_list, "energy": energy_loss_list, "force": force_loss_list} - - # Save the model - # I've attempted to make this compatible with the TorchMD calculators.External class, but I'm not sure how well the keys match - Daniel - tmp_checkpoint_path = f'{result_directory}/checkpoint-{epoch}.pth' - save_checkpoint(tmp_checkpoint_path, epoch + 1, model, optimizer, conf, scheduler) - - if os.path.exists(f'{result_directory}/checkpoint-mini.pth'): - os.unlink(f'{result_directory}/checkpoint-mini.pth') - - if checkpoint_save and (epoch % checkpoint_save == 0): - shutil.copyfile(tmp_checkpoint_path, f'{result_directory}/checkpoint.pth') - else: - os.replace(tmp_checkpoint_path, f'{result_directory}/checkpoint.pth') - - # If this is <= to the lowest validation loss seen so far also save it to checkpoint-best.pth - if val_loss_list[-1] <= np.min(val_loss_list): - shutil.copyfile(f'{result_directory}/checkpoint.pth', f'{result_directory}/checkpoint-best.pth') - - # Save the loss history - np.save(f'{result_directory}/history.npy', history)#pyright: ignore[reportArgumentType] - print(" Checkpoint saved.") - + _commit_epoch_checkpoint(epoch, model, optimizer, conf, scheduler, result_directory, + val_loss_list, history, checkpoint_save, ddp) + ddp.barrier() epoch += 1 -def should_decay(param_name: str) -> bool: - #usually something like "representation_model.distance_expansion.means" - #want to not decay the embeddings and the biases - parts = param_name.split('.') - assert len(parts) > 0 - if parts[-1] == "bias": - return False - if parts[-2] == "embedding": - return False - if parts[-2] == "distance_expansion": - #not sure for this - return False - assert parts[-1] == "weight" - return True -if __name__ == "__main__": +# =========================================================================== +# Entry-point helpers: collection → validation → processing +# =========================================================================== + +def build_parser(): import argparse parser = argparse.ArgumentParser(description="Train a CGSchNet network") - parser.add_argument("input", help="Processed data to train on ") - parser.add_argument("result", default=None, nargs="?", help="Checkpoint directory to continue") - parser.add_argument("-c", "--config", default="../configs/config.yaml", type=str, help="") - parser.add_argument("--gpus", default=None, type=str, help="List of GPUs to train on (e.g. \"0,1,2\")") + parser.add_argument("input", help="Processed data to train on") + parser.add_argument("result", help="Checkpoint directory to continue") + parser.add_argument("-c", "--config", default="../configs/config.yaml", type=str, help="Path to model architecture config YAML") + parser.add_argument("--gpus", default=None, type=str, help="List of GPUs to train on (e.g. \"0,1,2\") or \"cpu\"") parser.add_argument("--batch", type=int, default=50, help="The batch size to use") parser.add_argument("--epochs", type=int, default=25, help="The total number of epochs to train for") - parser.add_argument("--lr", type=float, default="1e-4", help="Learning rate") + parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") parser.add_argument("--wd", type=float, default=0, help="Weight decay") parser.add_argument("--val-ratio", type=float, default=0.1, help="Validation set ratio, should be between 0.0 and 1.0") parser.add_argument("--apc", "--atoms-per-call", type=int, default=None, help="Number of atoms to include in each sub-batch") parser.add_argument("--cos-anneal", default=None, help="Train using cosine annealing, parameters are \"T_0,T_mult\"") parser.add_argument("--cos-lr", default=None, help="Train using a cosine learning rate, parameters are \"T_max,eta_min\"") - parser.add_argument("--exp-lr", default=None, help="Train using a exponential learning rate, parameters are \"gamma\"") - parser.add_argument("--plateau-lr", default=None, help="Train using a plateau learning rate, parameters are \"factor\" \"patience\" \"min_lr\"") + parser.add_argument("--exp-lr", default=None, help="Train using an exponential learning rate, parameter is \"gamma\"") + parser.add_argument("--plateau-lr", default=None, help="Train using a plateau learning rate, parameters are \"factor,patience,threshold,min_lr\"") parser.add_argument("--dry-run", action="store_true", help="Do a dry run of the training loop but produce no output") parser.add_argument("--reset-early-stopping", action="store_true", help="Reset the early stopping check to start from the current epoch") parser.add_argument("--no-shuffle", action="store_true", help="Do not shuffle the training dataset") + parser.add_argument("--num-workers", type=int, default=0, help="Number of workers to use for data loading") parser.add_argument("--mini-epoch", type=int, default=None, help="Save a mini epoch after every n batches") - parser.add_argument("--early-stopping", type=int, default=1, help="The number of epochs validation loss can increase before triggering early stopping or -1 to disable early stopping (default=1)") + parser.add_argument("--early-stopping", type=int, default=1, help="Epochs validation loss can increase before early stopping, or -1 to disable (default=1)") parser.add_argument("--checkpoint-save", type=int, default=10, help="Save a backup checkpoint every n epochs, 0 to disable (default=10)") - parser.add_argument("--subsetpdbs", default='ok_list.txt', type=str, help="Change the pdbid list used when reading in the dataset (default=ok_list.txt)") - parser.add_argument("--energy-weight", default=0.0, type=float, help="Energy Weighting for Loss Function") - parser.add_argument("--force-weight", default=1.0, type=float, help="Force Weighting for Loss Function") - parser.add_argument("--term-def", default=None, type=str, help="The path to a term definition yaml file, which can additional loss terms used during training.") - parser.add_argument("--embedding", type=str, default=None, help="Specify an alternate file to load embeddings from (default: embeddings.npy).") + parser.add_argument("--subsetpdbs", default=None, type=str, help="Change the pdbid list used when reading in the dataset (default=input_dir/result/ok_list.txt)") + parser.add_argument("--energy-weight", default=0.0, type=float, help="Energy weighting for loss function") + parser.add_argument("--force-weight", default=1.0, type=float, help="Force weighting for loss function") + parser.add_argument("--term-def", default=None, type=str, help="Path to a term definition yaml file for additional loss terms during training") + parser.add_argument("--embedding", type=str, default=None, help="Alternate file to load embeddings from (default: embeddings.npy)") parser.add_argument("--chunk-dataset", type=int, default=None, help="Break the dataset into chunks of n proteins per batch") parser.add_argument("--npfile", action="store_true", help="Use file loader instead of mmap to load dataset") + parser.add_argument("--use-force-weights", action="store_true", default=False, help="Use per-bead force weights in training") + parser.add_argument("--distributed", action="store_true", help="Enable distributed training (also auto-detected from RANK/SLURM_PROCID env vars)") + return parser + +def validate_config(cfg) -> None: assert torch.cuda.is_available(), "CUDA is not available, please run on a machine with CUDA or use --gpus cpu" + assert os.path.isdir(cfg.input), f"Input directory does not exist: {cfg.input}" + assert os.path.isdir(cfg.result), f"Result directory does not exist: {cfg.result}" + assert os.path.isfile(cfg.config), f"Config file does not exist: {cfg.config}" + assert cfg.checkpoint_save >= 0, "--checkpoint-save must be >= 0" + n_schedulers = sum(s is not None for s in [cfg.cos_anneal, cfg.cos_lr, cfg.exp_lr, cfg.plateau_lr]) + assert n_schedulers <= 1, "At most one LR scheduler may be specified (--cos-anneal, --cos-lr, --exp-lr, --plateau-lr)" - args = parser.parse_args() - - directory_path = args.input - assert os.path.isdir(directory_path), f"Input directory does not exist: {directory_path}" - result_directory = args.result - conf_path = args.config - assert os.path.isfile(conf_path), f"Config file does not exist: {conf_path}" - weight_decay = args.wd - learning_rate = args.lr - if args.gpus: - if args.gpus == "cpu": - gpu_ids = "cpu" - else: - gpu_ids = [int(i) for i in args.gpus.strip().split(",")] + +def process_config(cfg) -> dict: + if cfg.gpus == "cpu": + gpu_ids = "cpu" + elif cfg.gpus: + gpu_ids = [int(i) for i in cfg.gpus.strip().split(",")] else: gpu_ids = "cpu" - epochs = args.epochs - batch_size = args.batch - val_ratio = args.val_ratio - atoms_per_call = args.apc - dry_run = args.dry_run - reset_early_stopping = args.reset_early_stopping - enable_shuffle = not args.no_shuffle - mini_epoch_size = args.mini_epoch - early_stopping = args.early_stopping - checkpoint_save = args.checkpoint_save - assert checkpoint_save >= 0 - - subsetpdbs = args.subsetpdbs - energy_weight = args.energy_weight - force_weight = args.force_weight - energy_matching = args.energy_weight != 0.0 - embedding_filename = args.embedding - dataset_chunk_size = args.chunk_dataset - use_npfile = args.npfile - - # Relax the maximum number of open files as much as possible - # We will potentially open a lot of files (~4 per molecule per ProteinDataset object) - soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) - resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) - lr_scheduler = None - if args.cos_anneal: - T_0, T_mult = [int(i) for i in args.cos_anneal.split(",")] + if cfg.cos_anneal: + T_0, T_mult = [int(i) for i in cfg.cos_anneal.split(",")] lr_scheduler = SchedulerWrapper_CosineAnnealingWarmRestarts(T_0, T_mult) - if args.cos_lr: - assert lr_scheduler is None - T_max, eta_min = args.cos_lr.split(",") - T_max, eta_min = int(T_max), float(eta_min) - lr_scheduler = SchedulerWrapper_CosineAnnealingLR(T_max, eta_min) - if args.exp_lr: - assert lr_scheduler is None - lr_scheduler = SchedulerWrapper_ExponentialLR(float(args.exp_lr)) - if args.plateau_lr: - assert lr_scheduler is None - factor, patience, threshold, min_lr = args.plateau_lr.split(",") - factor, patience, threshold, min_lr = float(factor), int(patience), float(threshold), float(min_lr) - lr_scheduler = SchedulerWrapper_ReduceLROnPlateau(factor, patience, threshold, min_lr) - - if args.term_def is not None: - train_term_def = TermDef(path=args.term_def) - else: - train_term_def = TermDef() + elif cfg.cos_lr: + T_max, eta_min = cfg.cos_lr.split(",") + lr_scheduler = SchedulerWrapper_CosineAnnealingLR(int(T_max), float(eta_min)) + elif cfg.exp_lr: + lr_scheduler = SchedulerWrapper_ExponentialLR(float(cfg.exp_lr)) + elif cfg.plateau_lr: + factor, patience, threshold, min_lr = cfg.plateau_lr.split(",") + lr_scheduler = SchedulerWrapper_ReduceLROnPlateau(float(factor), int(patience), float(threshold), float(min_lr)) + + train_term_def = TermDef(path=cfg.term_def) if cfg.term_def is not None else TermDef() + + # Raise the OS file-descriptor limit — large datasets open ~4 files per molecule + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) + resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard)) + + return dict( + directory_path=cfg.input, + result_directory=cfg.result, + conf_path=cfg.config, + dry_run=cfg.dry_run, + weight_decay=cfg.wd, + learning_rate=cfg.lr, + gpu_ids=gpu_ids, + epochs=cfg.epochs, + batch_size=cfg.batch, + val_ratio=cfg.val_ratio, + atoms_per_call=cfg.apc, + scheduler=lr_scheduler, + reset_early_stopping=cfg.reset_early_stopping, + enable_shuffle=not cfg.no_shuffle, + mini_epoch_size=cfg.mini_epoch, + early_stopping=cfg.early_stopping, + checkpoint_save=cfg.checkpoint_save, + subsetpdbs=cfg.subsetpdbs, + energy_weight=cfg.energy_weight, + force_weight=cfg.force_weight, + energy_matching=cfg.energy_weight != 0.0, + train_term_def=train_term_def, + embedding_filename=cfg.embedding, + dataset_chunk_size=cfg.chunk_dataset, + use_npfile=cfg.npfile, + use_force_weights=cfg.use_force_weights, + num_workers=cfg.num_workers, + ) + + +if __name__ == "__main__": + from module.base_config import BaseConfig + + cfg = BaseConfig(build_parser()) + validate_config(cfg) + params = process_config(cfg) + + ddp = DistributedContext() + if cfg.distributed or 'RANK' in os.environ or 'SLURM_PROCID' in os.environ: + ddp = setup_distributed() + if ddp.is_distributed and ddp.is_main: + print(f"Initialized process {ddp.rank}/{ddp.world_size} (local_rank={ddp.local_rank})") try: - train_model(directory_path, result_directory=result_directory, conf_path=conf_path, dry_run=dry_run, weight_decay=weight_decay, - learning_rate=learning_rate, gpu_ids=gpu_ids, epochs=epochs, batch_size=batch_size, val_ratio=val_ratio, scheduler=lr_scheduler, - atoms_per_call=atoms_per_call, reset_early_stopping=reset_early_stopping, enable_shuffle=enable_shuffle, - mini_epoch_size=mini_epoch_size, early_stopping=early_stopping, checkpoint_save=checkpoint_save, subsetpdbs=subsetpdbs, energy_weight=energy_weight, - force_weight=force_weight, energy_matching=energy_matching, train_term_def=train_term_def, embedding_filename=embedding_filename, - dataset_chunk_size=dataset_chunk_size, use_npfile=use_npfile) + train_model(**params, ddp=ddp) except Exception as e: - # Uncaught exceptions cause pytorch to hang for quite a while before exiting traceback.print_tb(e.__traceback__) print(e) sys.exit(1) + finally: + cleanup_distributed()