From cf4b3dc604e6447e38493b09a141808f5ee71f9c Mon Sep 17 00:00:00 2001 From: Alexander Aghili Date: Tue, 16 Dec 2025 22:47:05 -0800 Subject: [PATCH 1/4] Train Distributed + Refactor --- train.py | 1907 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 1229 insertions(+), 678 deletions(-) diff --git a/train.py b/train.py index 40d8f55..1f9218d 100755 --- a/train.py +++ b/train.py @@ -1,73 +1,205 @@ #!/usr/bin/env python3 - -import torch -import torch.nn as nn -import torch.optim as optim -import torch.utils.data -from torch.utils.data import DataLoader -import yaml -import numpy as np -# from torchmdnet.models.model import create_model -from module.torchmdnet.model import create_model -from module import dataset -from module import model_util -from module.lr_scheduler_wrappers import * +from __future__ import annotations import os +import sys import json import time -from tqdm import tqdm -import datetime +import yaml import shutil -import resource -import sys -import traceback import itertools +import datetime +import traceback +import resource +from dataclasses import dataclass +from typing import Dict, Any, Optional, Tuple, Iterable, List -# Type hinting... -from typing import Tuple +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim from torch import Tensor +from torch.utils.data import DataLoader +from tqdm import tqdm -# Useful for debugging pytorch CUDA crashes -# os.environ["CUDA_LAUNCH_BLOCKING"]="1" +# Distributed +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data.distributed import DistributedSampler -def flatten_first(t): - """Flatten the first two dimentions of tensor t""" - if t is None: +# Local modules (unchanged) +from module.torchmdnet.model import create_model +from module import dataset +from module import model_util +from module.lr_scheduler_wrappers import ( + SchedulerWrapper_CosineAnnealingWarmRestarts, + SchedulerWrapper_CosineAnnealingLR, + SchedulerWrapper_ExponentialLR, + SchedulerWrapper_ReduceLROnPlateau, +) + +# ----------------------------- Small utilities ----------------------------- + +def flatten_first(t: Optional[Tensor]) -> Optional[Tensor]: + """Flatten first two dims, preserving remaining dims.""" + if t is None or getattr(t, "shape", None) is None: return t if len(t.shape) < 2: return t - return t.reshape(t.shape[0]*t.shape[1], *t.shape[2:]) + return t.reshape(t.shape[0] * t.shape[1], *t.shape[2:]) + +def deterministic_shuffle(items: List[str], seed: int) -> List[str]: + """Deterministically shuffle a list.""" + g = torch.Generator().manual_seed(seed) + idx = torch.randperm(n=len(items), generator=g, device="cpu") + return [items[i] for i in idx] + +def check_early_stopping(val_list: List[float], patience: int = 1) -> bool: + """True if val loss increased (patience+1) consecutive times. patience<0 disables.""" + if patience < 0: + return False + if len(val_list) < patience + 2: + return False + window = np.array(val_list[-(patience + 2):]) + if np.all((window[1:] - window[:-1]) > 0): + print(f"Validation loss increased {patience+1} times, stopping...") + return True + return False -def make_term_offsets(lengths, term_lengths): +def make_term_offsets(lengths: List[int], term_lengths: Tensor) -> Tensor: + """Offset per-term indices across concatenated variable-length batches.""" result = [] count = 0 - - repeats = len(term_lengths)//len(lengths) - lengths = np.tile(lengths, repeats) + repeats = len(term_lengths) // len(lengths) + lengths = np.tile(np.array(lengths), repeats).tolist() assert len(lengths) == len(term_lengths) - - # For each batch we want to offset the indicies used by the terms by the number of atoms in the prior batches for off, nterms in zip(lengths, term_lengths): - result.append(torch.full((nterms, 1), count, dtype=torch.long)) - count += off - return torch.cat(result) + result.append(torch.full((int(nterms), 1), count, dtype=torch.long)) + count += int(off) + return torch.cat(result, dim=0) + +# ----------------------------- Domain config ------------------------------ + +class TermDef: + def __init__(self, path: Optional[str] = None, conf: Optional[dict] = None): + self.scales: Dict[str, float] = {} + self.angle_wrap: Dict[str, bool] = {} + if path: + with open(path, "r") as f: + conf = yaml.safe_load(f) + if conf: + for k, v in conf.items(): + self.scales[k] = float(v["scale"]) if (v is not None and "scale" in v) else 1.0 + self.angle_wrap[k] = bool(v["angle_wrap"]) if (v is not None and "angle_wrap" in v) else False + + def names(self) -> List[str]: + return list(self.scales.keys()) + + def scale(self, name: str) -> float: + return self.scales[name] + + def wrap(self, name: str) -> bool: + return self.angle_wrap[name] + +# ----------------------------- Distributed manager ------------------------- + +@dataclass +class DistInfo: + rank: Optional[int] = None + world_size: Optional[int] = None + local_rank: Optional[int] = None + + @property + def enabled(self) -> bool: + return self.rank is not None and self.world_size is not None and self.world_size > 1 + + @property + def is_main(self) -> bool: + return self.rank is None or self.rank == 0 + +class DistributedManager: + def __init__(self, enable: bool): + self.info = DistInfo() + self._requested = enable + + def setup(self) -> DistInfo: + """Initialize distributed if env vars indicate multi-proc or user requested.""" + if not self._requested and not any(k in os.environ for k in ("RANK", "WORLD_SIZE", "SLURM_PROCID")): + return self.info + + rank, world_size, local_rank = self._read_env_ranks() + if rank is None: + return self.info + + torch.cuda.set_device(local_rank) + dist.init_process_group( + backend="nccl", + init_method="env://", + world_size=world_size, + rank=rank, + ) + self.info = DistInfo(rank=rank, world_size=world_size, local_rank=local_rank) + print(f"Initialized process {rank}/{world_size} (local_rank={local_rank})") + return self.info + + def cleanup(self) -> None: + """Destroy process group if initialized.""" + if dist.is_initialized(): + dist.destroy_process_group() + + def barrier(self) -> None: + """Sync all processes if distributed.""" + if self.info.enabled: + dist.barrier() + + def broadcast_bool(self, value: bool, src: int = 0, device: Optional[torch.device] = None) -> bool: + """Broadcast a boolean from src to all processes.""" + if not self.info.enabled: + return value + assert device is not None + t = torch.tensor(1 if value else 0, device=device) + dist.broadcast(t, src=src) + return bool(t.item()) + + def all_reduce_sum(self, x: float, device: torch.device) -> float: + """All-reduce a scalar float via SUM.""" + if not self.info.enabled: + return x + t = torch.tensor(x, device=device) + dist.all_reduce(t, op=dist.ReduceOp.SUM) + return float(t.item()) + + def _read_env_ranks(self) -> Tuple[Optional[int], Optional[int], Optional[int]]: + """Read rank/world_size/local_rank from torchrun or SLURM env vars.""" + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + return rank, world_size, local_rank + + if "SLURM_PROCID" in os.environ: + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NTASKS"]) + local_rank = int(os.environ.get("SLURM_LOCALID", 0)) + return rank, world_size, local_rank + + return None, None, None + +# ----------------------------- Model wrappers ------------------------------ class BatchWrapper(nn.Module): - def __init__(self, model): + def __init__(self, model: nn.Module): super().__init__() self.model = model - def forward(self, pos, lengths, **kwargs) -> Tuple[Tensor, Tensor]: - batch_nums = dataset.make_batch_nums(len(pos), lengths) - batch_nums = batch_nums.to(pos.device) - assert batch_nums.device == pos.device - + def forward(self, pos: Tensor, lengths: List[int], **kwargs) -> Tuple[Tensor, Tensor, Dict[str, Tensor]]: + """Prepare batch indices/terms and forward into model.""" + batch_nums = dataset.make_batch_nums(len(pos), lengths).to(pos.device) kwargs["pos"] = pos - for k, v in kwargs.items(): + + for k, v in list(kwargs.items()): kwargs[k] = flatten_first(v) - #TODO: It would be better if the term lengths were also python lists like the batch lengths to avoid the round trip to the GPU and back if "bonds" in kwargs: kwargs["bonds"] = kwargs["bonds"] + make_term_offsets(lengths, kwargs.pop("len_bonds").cpu()).to(pos.device) if "angles" in kwargs: @@ -77,725 +209,1144 @@ def forward(self, pos, lengths, **kwargs) -> Tuple[Tensor, Tensor]: kwargs["batch"] = batch_nums result = self.model(**kwargs) - if len(result) == 2: - result = [*result, {}] - return result #pyright: ignore[reportReturnType] - -class TermDef(): - def __init__(self, path=None, conf=None): - self.scales = {} - self.angle_wrap = {} - if path: - with open(path, 'r') as file: - conf = yaml.safe_load(file) - - if conf: - for k, v in conf.items(): - if v is not None and "scale" in v: - self.scales[k] = float(v["scale"]) - else: - self.scales[k] = 1.0 - - if v is not None and "angle_wrap" in v: - self.angle_wrap[k] = bool(v["angle_wrap"]) - else: - self.angle_wrap[k] = False + if len(result) == 2: + out_e, out_f = result + return out_e, out_f, {} + out_e, out_f, extra = result + return out_e, out_f, extra + +class ModelManager: + def __init__(self, conf: Dict[str, Any], dist_info: DistInfo, gpu_ids, local_rank: Optional[int]): + self.conf = conf + self.dist = dist_info + self.gpu_ids = gpu_ids + self.local_rank = local_rank + + self.device = self._select_device() + self.model = create_model(args=conf).to(self.device) + self.wrapped = BatchWrapper(self.model) + self.parallel, self.device_output = self._wrap_parallel() + + def _select_device(self) -> torch.device: + """Pick a torch device for the current process.""" + if self.local_rank is not None: + return torch.device(f"cuda:{self.local_rank}") + if self.gpu_ids != "cpu": + return torch.device("cuda:0") + return torch.device("cpu") + + def _wrap_parallel(self) -> Tuple[nn.Module, Any]: + """Wrap model in DDP or DataParallel or no-wrap.""" + if self.dist.enabled: + parallel = DDP( + self.wrapped, + device_ids=[self.local_rank], + output_device=self.local_rank, + find_unused_parameters=True, + ) + if self.dist.is_main: + print(f"DDP: Training on {self.dist.world_size} GPUs across nodes (local_rank={self.local_rank})") + return parallel, self.device + + if self.gpu_ids == "cpu": + if self.dist.is_main: + print("Training on CPU") + return self.wrapped, "cpu" + + parallel = nn.DataParallel(self.wrapped, device_ids=self.gpu_ids) + if self.dist.is_main: + print(f"DataParallel: Training on {len(parallel.device_ids)} GPU(s)") + return parallel, parallel.output_device + + def train(self) -> None: + self.model.train() + + def eval(self) -> None: + self.parallel.eval() + + def state_target_for_loading(self) -> nn.Module: + """Return the object to pass to model_util.load_state_dict_with_rename.""" + if isinstance(self.parallel, DDP): + return self.parallel.module.model + if hasattr(self.parallel, "module"): + return self.parallel.module.model + return self.model + + def state_dict_for_saving(self) -> Dict[str, Tensor]: + """Extract the underlying model state dict in a wrapper-agnostic way.""" + if isinstance(self.parallel, DDP): + return self.parallel.module.model.state_dict() + if hasattr(self.parallel, "module"): + return self.parallel.module.model.state_dict() + return self.model.state_dict() + +# ----------------------------- Data module --------------------------------- - def get_names(self): - return list(self.scales.keys()) +class RoundRobinDataWrapper: + def __init__(self, *iterables: Iterable): + self.iterables = iterables - def get_scale(self, name): - return self.scales[name] + def __len__(self) -> int: + return sum(map(len, self.iterables)) - def get_angle_wrap(self, name): - return self.angle_wrap[name] + def __iter__(self): + iters = map(iter, self.iterables) + for num_active in range(len(self.iterables), 0, -1): + iters = itertools.cycle(itertools.islice(iters, num_active)) + yield from map(next, iters) + +@dataclass +class DataBundle: + datasets: List[Any] + train: RoundRobinDataWrapper + val: RoundRobinDataWrapper + train_samplers: List[Optional[DistributedSampler]] + pdb_list: List[str] + +class DataModule: + def __init__( + self, + directory_path: str, + subsetpdbs: str, + val_ratio: float, + batch_size: int, + atoms_per_call: Optional[int], + enable_shuffle: bool, + dataset_chunk_size: Optional[int], + use_npfile: bool, + embedding_filename: Optional[str], + energy_matching: bool, + dist_info: DistInfo, + ): + self.dir = directory_path + self.subsetpdbs = subsetpdbs + self.val_ratio = val_ratio + self.batch_size = batch_size + self.atoms_per_call = atoms_per_call + self.enable_shuffle = enable_shuffle + self.chunk = dataset_chunk_size + self.use_npfile = use_npfile + self.embedding_filename = embedding_filename or "embeddings.npy" + self.energy_filename = "tica_delta_energies.npy" if energy_matching else None + self.dist = dist_info + + def build(self) -> DataBundle: + """Create datasets and round-robin dataloaders (optionally chunked).""" + pdb_list = self._load_pdb_list() + pdb_lists = [pdb_list[i:i + self.chunk] for i in range(0, len(pdb_list), self.chunk)] if self.chunk else [pdb_list] + + datasets, train_loaders, val_loaders, samplers = [], [], [], [] + for chunk_list in pdb_lists: + ds, tr, va, sampler = self._make_loaders(chunk_list) + datasets.append(ds) + train_loaders.append(tr) + val_loaders.append(va) + samplers.append(sampler) + + return DataBundle( + datasets=datasets, + train=RoundRobinDataWrapper(*train_loaders), + val=RoundRobinDataWrapper(*val_loaders), + train_samplers=samplers, + pdb_list=pdb_list, + ) + + def _load_pdb_list(self) -> List[str]: + """Read, dedupe, sort, and deterministically shuffle PDB IDs.""" + with open(os.path.join(self.dir, "result", self.subsetpdbs), "r") as f: + pdb_list = [x for x in f.read().split("\n") if x] + pdb_list = sorted(set(pdb_list)) + return deterministic_shuffle(pdb_list, seed=47563537) + + def _make_loaders(self, pdb_list: List[str]): + """Build torch DataLoaders and (optionally) DistributedSamplers.""" + print("Dataset:", " ".join(pdb_list)) + all_data = dataset.ProteinDataset( + self.dir, + pdb_list, + energy_file=self.energy_filename, + embeddings_file=self.embedding_filename, + use_npfile=self.use_npfile, + ) + + assert 0.0 < self.val_ratio < 1.0 + val_size = int(self.val_ratio * len(all_data)) + train_size = len(all_data) - val_size + + if self.enable_shuffle: + g = torch.Generator().manual_seed(12341234) + val_idx, train_idx = torch.utils.data.random_split( + torch.arange(len(all_data)), + [val_size, train_size], + generator=g, + ) + else: + train_idx = range(train_size) + val_idx = range(train_size, train_size + val_size) + + train = torch.utils.data.Subset(all_data, train_idx) + val = torch.utils.data.Subset(all_data, val_idx) + + collate_fn = dataset.ProteinBatchCollate(self.atoms_per_call) + + train_sampler = None + val_sampler = None + if self.dist.enabled: + train_sampler = DistributedSampler(train, num_replicas=self.dist.world_size, rank=self.dist.rank, shuffle=False) + val_sampler = DistributedSampler(val, num_replicas=self.dist.world_size, rank=self.dist.rank, shuffle=False) + + train_loader = DataLoader( + train, + batch_size=self.batch_size, + shuffle=False if train_sampler is None else False, + num_workers=4, + persistent_workers=True, + pin_memory=True, + collate_fn=collate_fn, + sampler=train_sampler, + ) + val_loader = DataLoader( + val, + batch_size=self.batch_size, + shuffle=False, + num_workers=4, + persistent_workers=True, + pin_memory=True, + collate_fn=collate_fn, + sampler=val_sampler, + ) + return all_data, train_loader, val_loader, train_sampler + +# ----------------------------- Checkpointing ------------------------------- + +class Checkpointer: + def __init__(self, result_dir: str, dist_info: DistInfo): + self.result_dir = result_dir + self.dist = dist_info + + def find_resume_checkpoint(self) -> Optional[str]: + """Pick mini checkpoint first, else normal checkpoint.""" + mini = os.path.join(self.result_dir, "checkpoint-mini.pth") + main = os.path.join(self.result_dir, "checkpoint.pth") + if os.path.exists(mini): + return mini + if os.path.exists(main): + return main + return None + + def save( + self, + path: str, + epoch: int, + model_state: Dict[str, Tensor], + optimizer: optim.Optimizer, + model_conf: Dict[str, Any], + scheduler: Optional[Any], + extra: Optional[dict] = None, + ) -> None: + """Save checkpoint only on main process.""" + if not self.dist.is_main: + return + ckpt = { + "epoch": epoch, + "optimizer": optimizer.state_dict(), + "state_dict": model_state, + "hyper_parameters": model_conf, + } + if scheduler: + ckpt["scheduler"] = scheduler.state_dict() + if extra: + ckpt["extra"] = extra + torch.save(ckpt, path) + + def write_training_info( + self, + epoch: int, + directory_path: str, + pdb_list: List[str], + params: Dict[str, Any], + ) -> None: + """Write/update training_info.json and copy priors on main process.""" + if not self.dist.is_main: + return + + training_info_path = os.path.join(self.result_dir, "training_info.json") + info = {} + + if os.path.exists(training_info_path): + with open(training_info_path, "r") as f: + info = json.load(f) + if "input_directory" in info: + info = {"0": info} + else: + print("Path", training_info_path, "does not exist") + + info[str(epoch)] = { + "weight_decay": params["weight_decay"], + "learning_rate": params["learning_rate"], + "epochs": params["epochs"], + "batch_size": params["batch_size"], + "input_directory": directory_path, + "pdbs": pdb_list, + "energy_weight": params["energy_weight"], + "force_weight": params["force_weight"], + "embedding_filename": params["embedding_filename"], + "world_size": params.get("world_size", 1), + } + if params.get("lr_scheduler_repr") is not None: + info[str(epoch)]["lr_scheduler"] = params["lr_scheduler_repr"] + if not params["dry_run"]: + with open(training_info_path, "w") as f: + json.dump(info, f, indent=2) -def deterministic_shuffle(target, seed): - generator = torch.Generator().manual_seed(seed) - indices = torch.randperm(n=len(target), generator=generator, device="cpu") - return [target[i] for i in indices] + prior_path = os.path.join(directory_path, "priors.yaml") + if os.path.exists(prior_path): + prior_params_path = os.path.join(directory_path, "prior_params.json") + shutil.copy(prior_path, self.result_dir) + shutil.copy(prior_params_path, self.result_dir) + +# ----------------------------- Training bookkeeping ------------------------ + +@dataclass +class History: + train: List[float] + val: List[float] + energy: List[float] + force: List[float] + + @classmethod + def empty(cls) -> "History": + return cls(train=[], val=[], energy=[], force=[]) + + @classmethod + def load(cls, result_dir: str) -> "History": + path = os.path.join(result_dir, "history.npy") + if not os.path.exists(path): + return cls.empty() + data = np.load(path, allow_pickle=True).item() + return cls(train=data["train"], val=data["val"], energy=data["energy"], force=data["force"]) + + def save(self, result_dir: str) -> None: + np.save(os.path.join(result_dir, "history.npy"), {"train": self.train, "val": self.val, "energy": self.energy, "force": self.force}) + +class EpochHistory: + def __init__(self, result_dir: str): + self.result_dir = result_dir + self.path = os.path.join(result_dir, "epoch_history.json") + self.data: Dict[str, Any] = {} + if os.path.exists(self.path): + with open(self.path, "r") as f: + self.data = json.load(f) + + def update(self, key: str, value: Dict[str, Any]) -> None: + self.data[key] = value + with open(self.path, "w") as f: + json.dump(self.data, f, indent=2) + +# ----------------------------- Optim & scheduler --------------------------- -def check_early_stopping(val_list, patience=1): - """Return True if the number of epochs with increasing val_loss > patience. If patience < 0 always return False.""" - if patience < 0: +def should_decay(param_name: str) -> bool: + """Decide if a parameter should receive weight decay.""" + parts = param_name.split(".") + assert parts + if parts[-1] == "bias": return False - if len(val_list) < patience+2: + if len(parts) >= 2 and parts[-2] == "embedding": return False - check_range = np.array(val_list[-(patience+2):]) - if np.all((check_range[1:]-check_range[:-1])>0): - print(f"Validation loss increased {patience+1} times, stopping...") - return True - -def save_checkpoint(checkpoint_path, epoch, model, optimizer, model_conf, scheduler, extra=None): - checkpoint_dict = { - "epoch":epoch, - "optimizer":optimizer.state_dict(), - "state_dict":model.state_dict(), - "hyper_parameters":model_conf, - } - - if scheduler: - checkpoint_dict["scheduler"] = scheduler.state_dict() + if len(parts) >= 2 and parts[-2] == "distance_expansion": + return False + assert parts[-1] == "weight" + return True - if extra: - checkpoint_dict["extra"] = extra +class OptimFactory: + @staticmethod + def adamw(model: nn.Module, lr: float, weight_decay: float) -> optim.Optimizer: + """Create AdamW with split decay groups.""" + do_decay, dont_decay = [], [] + for name, p in model.named_parameters(): + (do_decay if should_decay(name) else dont_decay).append(p) + return optim.AdamW( + [{"params": do_decay, "weight_decay": weight_decay}, {"params": dont_decay}], + lr=lr, + ) + +class SchedulerFactory: + @staticmethod + def from_args(args) -> Optional[Any]: + """Build one of the scheduler wrappers from argparse args.""" + lr_scheduler = None + if args.cos_anneal: + T_0, T_mult = [int(i) for i in args.cos_anneal.split(",")] + lr_scheduler = SchedulerWrapper_CosineAnnealingWarmRestarts(T_0, T_mult) + if args.cos_lr: + assert lr_scheduler is None + T_max, eta_min = args.cos_lr.split(",") + lr_scheduler = SchedulerWrapper_CosineAnnealingLR(int(T_max), float(eta_min)) + if args.exp_lr: + assert lr_scheduler is None + lr_scheduler = SchedulerWrapper_ExponentialLR(float(args.exp_lr)) + if args.plateau_lr: + assert lr_scheduler is None + factor, patience, threshold, min_lr = args.plateau_lr.split(",") + lr_scheduler = SchedulerWrapper_ReduceLROnPlateau(float(factor), int(patience), float(threshold), float(min_lr)) + return lr_scheduler + +# ----------------------------- Core trainer -------------------------------- + +@dataclass +class TrainArgs: + directory_path: str + result_directory: Optional[str] + conf_path: str + gpu_ids: Any + weight_decay: float + learning_rate: float + epochs: int + batch_size: int + val_ratio: float + atoms_per_call: Optional[int] + scheduler: Optional[Any] + dry_run: bool + reset_early_stopping: bool + enable_shuffle: bool + mini_epoch_size: Optional[int] + early_stopping: int + checkpoint_save: int + subsetpdbs: str + energy_weight: float + force_weight: float + energy_matching: bool + train_term_def: TermDef + embedding_filename: Optional[str] + dataset_chunk_size: Optional[int] + use_npfile: bool + use_force_weights: bool + +class Trainer: + def __init__(self, args: TrainArgs, dist_mgr: DistributedManager): + self.args = args + self.dist_mgr = dist_mgr + self.dist = dist_mgr.info + + self.result_dir = self._ensure_result_dir(args.result_directory, args.dry_run) + self.checkpointer = Checkpointer(self.result_dir, self.dist) + + self.conf = self._load_conf(args.conf_path) + self._apply_conf_overrides() + + self.data_bundle = self._build_data() + self.model_mgr = ModelManager(self.conf, self.dist, args.gpu_ids, self.dist.local_rank) + + self.optimizer = OptimFactory.adamw(self.model_mgr.model, args.learning_rate, args.weight_decay) + if args.scheduler: + args.scheduler.initialize(self.optimizer) + self.scheduler = args.scheduler + + self.history = History.load(self.result_dir) + self.epoch_history = EpochHistory(self.result_dir) + + self.epoch = 0 + self.epoch_resume_extra: Optional[dict] = None + + self._maybe_resume() + self.dist_mgr.barrier() + self._write_training_info() + + def _load_conf(self, path: str) -> Dict[str, Any]: + """Load YAML config.""" + with open(path, "r") as f: + conf = yaml.safe_load(f) + if self.dist.is_main: + print("Config:\n", conf, "\n") + return conf + + def _apply_conf_overrides(self) -> None: + """Set conf flags needed for training terms.""" + if "harmonic_net" in self.conf and self.args.train_term_def.names(): + self.conf["harmonic_net_return_terms"] = True + + if self.conf.get("external_embedding_channels") is None and (self.args.embedding_filename and self.args.embedding_filename != "embeddings.npy"): + if self.dist.is_main: + print("WARNING: external embeddings usually should use graph-network-ext network") + + def _build_data(self) -> DataBundle: + """Construct datasets/loaders and add optional dataset features.""" + dm = DataModule( + directory_path=self.args.directory_path, + subsetpdbs=self.args.subsetpdbs, + val_ratio=self.args.val_ratio, + batch_size=self.args.batch_size, + atoms_per_call=self.args.atoms_per_call, + enable_shuffle=self.args.enable_shuffle, + dataset_chunk_size=self.args.dataset_chunk_size, + use_npfile=self.args.use_npfile, + embedding_filename=self.args.embedding_filename, + energy_matching=self.args.energy_matching, + dist_info=self.dist, + ) + bundle = dm.build() + + if "sequence_basis_radius" in self.conf: + if self.dist.is_main: + print(f"Adding sequences to dataset... (sequence_basis_radius={self.conf['sequence_basis_radius']})") + for d in bundle.datasets: + d.build_sequences() + + extra_terms = [] + if "harmonic_net" in self.conf: + if self.dist.is_main: + print(f"Adding classical terms to dataset... (harmonic_net={self.conf['harmonic_net']})") + for d in bundle.datasets: + d.build_classical_terms() + + if self.args.train_term_def.names(): + extra_terms = self.args.train_term_def.names() + if self.dist.is_main: + print(f"Loading additional trained terms: {extra_terms}") + print(f" Term Scales: {[self.args.train_term_def.scale(i) for i in extra_terms]}") + print(f" Term Angle Wrap: {[self.args.train_term_def.wrap(i) for i in extra_terms]}") + for d in bundle.datasets: + d.load_frame_terms(extra_terms) + + if self.args.use_force_weights: + if self.dist.is_main: + print("Loading forces weights...") + for d in bundle.datasets: + d.load_frame_terms(["forces_weights"]) + + self.extra_train_terms = extra_terms + if self.dist.is_main: + print() + return bundle + + def _ensure_result_dir(self, result_dir: Optional[str], dry_run: bool) -> str: + """Create or validate result directory (main process only).""" + if result_dir and os.path.exists(os.path.join(result_dir, "checkpoint.pth")): + return result_dir + if result_dir and os.path.exists(os.path.join(result_dir, "checkpoint-mini.pth")): + return result_dir + + if result_dir is None: + result_dir = "../data/result-" + datetime.datetime.now().strftime("%Y.%m.%d-%H.%M.%S") + + if os.path.exists(result_dir): + info_path = os.path.join(result_dir, "training_info.json") + if os.path.exists(info_path): + if self.dist.is_main: + print("Re-initializing:", result_dir) + return result_dir + raise RuntimeError("Model directory exists but doesn't contain a checkpoint.pth or training_info.json file") - torch.save(checkpoint_dict, checkpoint_path) + if self.dist.is_main and not dry_run: + os.makedirs(result_dir, exist_ok=False) + print("Created:", result_dir) + return result_dir -def gen_dataloaders(directory_path, pdb_list, energy_filename, embedding_filename, use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call): - print("Dataset:", " ".join(pdb_list)) + def _maybe_resume(self) -> None: + """Resume from checkpoint if present.""" + ckpt_path = self.checkpointer.find_resume_checkpoint() + if self.dist.is_main: + print("checkpoint_path", ckpt_path) - all_data = dataset.ProteinDataset(directory_path, pdb_list, energy_file=energy_filename, embeddings_file=embedding_filename, use_npfile=use_npfile) - # num_proteins = all_data.num_proteins() + if not ckpt_path: + self.epoch = 0 + if self.dist.is_main: + print("Saving to:", self.result_dir) + return - assert val_ratio > 0.0 and val_ratio < 1.0 - val_size = int(val_ratio * len(all_data)) - train_size = len(all_data) - val_size + if self.dist.is_main: + print("Resuming:", self.result_dir) - if enable_shuffle: - # Generate the test and validation split with deterministic indices - generator1 = torch.Generator().manual_seed(12341234) - val_idx, train_idx = torch.utils.data.random_split(torch.arange(len(all_data)), [val_size, train_size], generator=generator1) #pyright: ignore[reportArgumentType] - else: - # Data was pre-shuffled during preprocess and can be read sequentially - train_idx = range(train_size) - val_idx = range(train_size, train_size+val_size) - train = torch.utils.data.Subset(all_data, train_idx) #pyright: ignore[reportArgumentType] - val = torch.utils.data.Subset(all_data, val_idx) #pyright: ignore[reportArgumentType] + device = self.model_mgr.device + ckpt = torch.load(ckpt_path, weights_only=False, map_location=device) - collate_fn = dataset.ProteinBatchCollate(atoms_per_call) + model_target = self.model_mgr.state_target_for_loading() + model_util.load_state_dict_with_rename(model_target, ckpt["state_dict"]) - train_data = DataLoader(train, batch_size=batch_size, shuffle=False, num_workers=4, - persistent_workers=True, pin_memory=True, collate_fn=collate_fn) - val_data = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=4, - persistent_workers=True, pin_memory=True, collate_fn=collate_fn) + if ckpt.get("optimizer") is not None: + self.optimizer.load_state_dict(ckpt["optimizer"]) + else: + if self.dist.is_main: + print(" No optimizer in checkpoint, resetting...") - # print(f"Number of proteins in the dataset: {num_proteins}") - # print(f"Using periodic box: {all_data.has_box()}") + if self.scheduler and ckpt.get("scheduler") is not None: + self.scheduler.load_state_dict(ckpt["scheduler"]) - return all_data, train_data, val_data + self.epoch_resume_extra = ckpt.get("extra") + self.epoch = int(ckpt.get("epoch", 0)) -class RoundRobinDataWrapper: - def __init__(self, *iterables): - self.iterables = iterables + if self.epoch > 0: + self.history = History.load(self.result_dir) + self.epoch_history = EpochHistory(self.result_dir) - def __len__(self): - return sum(map(len, self.iterables)) + if self.dist.is_main: + print("Saving to:", self.result_dir) - def __iter__(self): - # From https://docs.python.org/3/library/itertools.html#itertools-recipes - iterators = map(iter, self.iterables) - for num_active in range(len(self.iterables), 0, -1): - iterators = itertools.cycle(itertools.islice(iterators, num_active)) - yield from map(next, iterators) - -def train_model(directory_path, conf_path, result_directory, dry_run, gpu_ids, - weight_decay, learning_rate, epochs, batch_size, val_ratio, atoms_per_call, - scheduler, reset_early_stopping, enable_shuffle, mini_epoch_size, early_stopping, - checkpoint_save, subsetpdbs, energy_weight, force_weight, energy_matching, train_term_def, - embedding_filename, dataset_chunk_size, use_npfile): - - with open(os.path.join(directory_path, "result", subsetpdbs), 'r') as file: - pdb_list = file.read().split('\n') - - # Remove duplicates and empty strings - pdb_list = sorted(list(set([i for i in pdb_list if i]))) - pdb_list = deterministic_shuffle(pdb_list, seed=47563537) - - if dataset_chunk_size is not None: - pdb_lists = [pdb_list[i:i + dataset_chunk_size] for i in range(0, len(pdb_list), dataset_chunk_size)] - else: - pdb_lists = [pdb_list] - - # Load all proteins into a datasets - energy_filename = None - if energy_matching: - energy_filename = "tica_delta_energies.npy" - if embedding_filename is None: - embedding_filename = "embeddings.npy" - - datasets = [] - train_dataloaders = [] - val_dataloaders = [] - - for pdb_chunk in pdb_lists: - ds, train_loader, val_loader = gen_dataloaders(directory_path, pdb_chunk, energy_filename, embedding_filename, use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call) - datasets.append(ds) - train_dataloaders.append(train_loader) - val_dataloaders.append(val_loader) - - train_data = RoundRobinDataWrapper(*train_dataloaders) - val_data = RoundRobinDataWrapper(*val_dataloaders) - - # Create the model - if conf_path is None: - conf_path = "../configs/config.yaml" - with open(conf_path, 'r') as file: - conf = yaml.safe_load(file) - print("Config:\n", conf, "\n") - - if conf.get("external_embedding_channels") == None and embedding_filename != "embeddings.npy": - print("WARNING: external embeddings usually should use graph-network-ext network") - - # Set the network to return the harmonic term info if we're training them - if "harmonic_net" in conf and train_term_def.get_names(): - conf["harmonic_net_return_terms"] = True - - model = create_model(args=conf) - # We need to construct DataParallel and move the model to CUDA before - # initializing the optimizer or we get "Expected all tensors to be on the same device" - # errors. When exactly this error happens depends on how many GPUs are used and whether - # we're loading a checkpoint or not. - if gpu_ids == "cpu": - parallel_model = BatchWrapper(model) - device_src = "cpu" - device_output = "cpu" - print("Training on CPU") - else: - parallel_model = nn.DataParallel(BatchWrapper(model), device_ids=gpu_ids) - device_src = parallel_model.src_device_obj - device_output = parallel_model.output_device - print(f"DataParallel: Training on {len(parallel_model.device_ids)} GPU(s)") - - model.to(device_src) - print("Model:\n", model, "\n") - - extra_train_terms = [] - - # Add additional features to the dataset if the model requires them - if "sequence_basis_radius" in conf: - print(f"Adding sequences to dataset... (sequence_basis_radius={conf['sequence_basis_radius']})") - for d in datasets: - d.build_sequences() - - if "harmonic_net" in conf: - print(f"Adding classical terms to dataset... (harmonic_net={conf['harmonic_net']})") - for d in datasets: - d.build_classical_terms() - - if train_term_def.get_names(): - # FIXME: Rename this to more generic - harmonic_trained_terms = train_term_def.get_names() - print(f"Loading additional trained terms: {harmonic_trained_terms}") - for d in datasets: - d.load_frame_terms(harmonic_trained_terms) - extra_train_terms.extend(harmonic_trained_terms) - print(f" Term Scales: {[train_term_def.get_scale(i) for i in harmonic_trained_terms]}") - print(f" Term Angle Wrap: {[train_term_def.get_angle_wrap(i) for i in harmonic_trained_terms]}") - - print() - - criterion = nn.MSELoss() - term_criterion = nn.MSELoss(reduction="none") - - do_decay = [] - dont_decay = [] - for name, param in model.named_parameters(): - if should_decay(name): - do_decay.append(param) - else: - dont_decay.append(param) - - optimizer = optim.AdamW( - [ - {"params": do_decay, "weight_decay": weight_decay}, - {"params": dont_decay} - ], - lr=learning_rate) - - if scheduler: - scheduler.initialize(optimizer) - - epoch_resume = None - checkpoint_path = None - if os.path.exists(f'{result_directory}/checkpoint-mini.pth'): - checkpoint_path = f'{result_directory}/checkpoint-mini.pth' - elif os.path.exists(f'{result_directory}/checkpoint.pth'): - checkpoint_path = f'{result_directory}/checkpoint.pth' - - print("checkpoint_path", checkpoint_path) - if checkpoint_path: - print("Resuming:", result_directory) - checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device_src) - model_util.load_state_dict_with_rename(model, checkpoint["state_dict"]) - - if "optimizer" in checkpoint and checkpoint["optimizer"] is not None: - optimizer.load_state_dict(checkpoint["optimizer"]) + def _write_training_info(self) -> None: + """Write training info and priors (main process only).""" + if self.scheduler: + lr_sched_repr = repr(self.scheduler) else: - print(" No optimizer in checkpoint, resetting...") + lr_sched_repr = None + for g in self.optimizer.param_groups: + g["lr"] = self.args.learning_rate + + self.checkpointer.write_training_info( + epoch=self.epoch, + directory_path=self.args.directory_path, + pdb_list=self.data_bundle.pdb_list, + params={ + "weight_decay": self.args.weight_decay, + "learning_rate": self.args.learning_rate, + "epochs": self.args.epochs, + "batch_size": self.args.batch_size, + "energy_weight": self.args.energy_weight, + "force_weight": self.args.force_weight, + "embedding_filename": self.args.embedding_filename or "embeddings.npy", + "world_size": self.dist.world_size if self.dist.world_size else 1, + "lr_scheduler_repr": lr_sched_repr, + "dry_run": self.args.dry_run, + }, + ) + + def run(self) -> None: + """Main training loop.""" + if self.scheduler and self.scheduler.is_annealing(): + self.args.early_stopping = -1 + + first_es_epoch = self.epoch if self.args.reset_early_stopping else 0 + verbose_loss_report = sys.stdout.isatty() + + criterion = nn.MSELoss() + term_criterion = nn.MSELoss(reduction="none") + + while self.epoch < self.args.epochs: + self._set_sampler_epoch(self.epoch) + + t0 = time.time() + self.model_mgr.train() + + train_metrics = self._train_one_epoch( + criterion=criterion, + term_criterion=term_criterion, + verbose=verbose_loss_report, + ) + + self._append_train_metrics(train_metrics) + self.model_mgr.eval() + + val_metrics = self._validate_one_epoch(criterion=criterion, term_criterion=term_criterion) + self.history.val.append(val_metrics["val_loss"]) + + if self.scheduler: + self.scheduler.step(val_metrics["val_loss"]) + + if self.dist.is_main: + self._log_epoch(train_metrics, val_metrics, t0) + + should_stop = False + if self.dist.is_main: + should_stop = check_early_stopping(self.history.val[first_es_epoch:], patience=self.args.early_stopping) + should_stop = self.dist_mgr.broadcast_bool(should_stop, src=0, device=self.model_mgr.device) + + self._save_epoch(train_metrics, val_metrics) + self.dist_mgr.barrier() + + if should_stop: + if self.dist.is_main: + print("Early stopping triggered.") + break + + self.epoch += 1 + + def _set_sampler_epoch(self, epoch: int) -> None: + """Set epoch on DistributedSamplers for determinism.""" + if not self.dist.enabled: + return + for s in self.data_bundle.train_samplers: + if s is not None: + s.set_epoch(epoch) + + def _train_one_epoch(self, criterion, term_criterion, verbose: bool) -> Dict[str, Any]: + """Train for one epoch, returning aggregated metrics.""" + args = self.args + model = self.model_mgr.parallel + device_out = self.model_mgr.device_output + + train_loss = 0.0 + train_energy_loss = 0.0 + train_force_loss = 0.0 + num_cal = 0.0 - if scheduler and "scheduler" in checkpoint and checkpoint["scheduler"] is not None: - scheduler.load_state_dict(checkpoint["scheduler"]) - - if "extra" in checkpoint: # This was a mini-checkpoint - epoch_resume = checkpoint["extra"] - - if "epoch" in checkpoint: - epoch = checkpoint["epoch"] - else: - epoch = 0 - else: - if not result_directory or not os.path.exists(result_directory): - if not result_directory: - result_directory = "../data/result-" + datetime.datetime.now().strftime("%Y.%m.%d-%H.%M.%S") - if not dry_run: - os.makedirs(result_directory, exist_ok=False) - else: - assert os.path.exists(result_directory) == False, "Result directory exists but is invalid" - print("Created:", result_directory) - elif os.path.exists(f'{result_directory}/training_info.json'): - # Most likely the training started but was canceled/crashed before the first epoch finished - print("Re-initializing:", result_directory) - else: - raise RuntimeError("Model directory exists but doesn't contain a checkpoint.pth or training_info.json file") - epoch = 0 - - epoch_history = {} - train_loss_list = [] - val_loss_list = [] - energy_loss_list = [] - force_loss_list = [] - - if epoch > 0: - # Load the numpy history files - history = np.load(f'{result_directory}/history.npy', allow_pickle=True).item() - train_loss_list = history['train'] - val_loss_list = history['val'] - energy_loss_list = history['energy'] - force_loss_list = history['force'] - - # Might exist before epoch 1 if mini-checkpoints were saved - epoch_history_path = os.path.join(result_directory, "epoch_history.json") - if os.path.exists(epoch_history_path): - with open(epoch_history_path, "r") as f: - epoch_history = json.load(f) - - print("Saving to:", result_directory) - - # Document training parameters and input data - training_info_path = os.path.join(result_directory, "training_info.json") - training_info_dict = {} - - if os.path.exists(training_info_path): - with open(training_info_path, "r") as f: - training_info_dict = json.load(f) - - # Check for the old dict format and update it - if "input_directory" in training_info_dict.keys(): - training_info_dict = {"0": training_info_dict} - else: - print("Path", training_info_path, "does not exist") - - # TODO: Only add a new entry if the parameters have changed? - training_info_dict[str(epoch)] = { - "weight_decay" : weight_decay, - "learning_rate" : learning_rate, - "epochs" : epochs, - "batch_size" : batch_size, - "input_directory" : directory_path, - "pdbs" : pdb_list, - "energy_weight": energy_weight, - "force_weight": force_weight, - "embedding_filename" : embedding_filename, - } - if scheduler: - training_info_dict[str(epoch)]["lr_scheduler"] = repr(scheduler) - else: - # If there's no scheduler reset the learning of the optimizer to the passed value - for g in optimizer.param_groups: - g['lr'] = learning_rate - - - if not dry_run: - with open(training_info_path, "w") as f: - json.dump(training_info_dict, f, indent=2) - - # Save the validation frame indices - # FIXME: This isn't compatible with chunking - #np.save(os.path.join(result_directory, "validation_frames.npy"), np.array(val_idx)) - - # Save the prior with the model - prior_path = os.path.join(directory_path, "priors.yaml") - if os.path.exists(prior_path): - prior_params_path = os.path.join(directory_path, "prior_params.json") - shutil.copy(prior_path, result_directory) - shutil.copy(prior_params_path, result_directory) - - # Disable earlly stopping when using an annealing (cycling) schedualer - if scheduler and scheduler.is_annealing(): - early_stopping = -1 - - first_early_stopping_epoch = 0 - if reset_early_stopping == True: - first_early_stopping_epoch = epoch - - verbose_loss_report = sys.stdout.isatty() - - while epoch < epochs: - t0 = time.time() - model.train() - train_loss = 0 - train_energy_loss = 0 - train_force_loss = 0 - num_cal = 0 # The total number of elements trained on epoch_offset = 0 - mini_train_loss = 0 - mini_num_cal = 0 - - train_term_losses = {k: 0.0 for k in extra_train_terms} - train_term_num_cal = {k: 0 for k in extra_train_terms} - - if epoch_resume: - print("Resuming epoch...") - train_loss = float(epoch_resume["train_loss"]) - num_cal = int(epoch_resume["num_cal"]) - epoch_offset = int(epoch_resume["i"]) - epoch_resume = None - + mini_train_loss = 0.0 + mini_num_cal = 0.0 + + train_term_losses = {k: 0.0 for k in self.extra_train_terms} + train_term_num_cal = {k: 0 for k in self.extra_train_terms} + + if self.epoch_resume_extra: + if self.dist.is_main: + print("Resuming epoch...") + train_loss = float(self.epoch_resume_extra["train_loss"]) + num_cal = float(self.epoch_resume_extra["num_cal"]) + epoch_offset = int(self.epoch_resume_extra["i"]) + self.epoch_resume_extra = None + + if self.dist.is_main: + it = tqdm( + enumerate(self.data_bundle.train), + desc=f"Training ({self.epoch}/{args.epochs})", + total=len(self.data_bundle.train), + dynamic_ncols=True, + miniters=1, + ) + else: + it = enumerate(self.data_bundle.train) - # Setting miniters is required to keep the bar from stalling after skipping ahead while resuming a batch - tqdm_iter = tqdm(enumerate(train_data), desc=f"Training ({epoch}/{epochs})", total=len(train_data), dynamic_ncols=True, miniters=1) - for i, batch in tqdm_iter: - # Handle mini-batches + for i, batch in it: if i < epoch_offset: - # TODO: It's very wasteful to load everything then discard it, but the alternative requires making a 2nd dataset object... continue - elif epoch_offset and i == epoch_offset: - tqdm_iter.write(f"Resumed epoch at batch {i}") - elif mini_epoch_size and 0 == i % mini_epoch_size and i > 0: - tmp_checkpoint_path = f'{result_directory}/checkpoint-{epoch}-{i}.pth' - save_checkpoint(tmp_checkpoint_path, epoch, model, optimizer, conf, scheduler, extra = {"train_loss":train_loss, "num_cal":num_cal, "i":i}) - os.replace(tmp_checkpoint_path, f'{result_directory}/checkpoint-mini.pth') - - epoch_history[f"{epoch}-{i}"] = { - "train_loss":train_loss/num_cal, - "mini_train_loss":mini_train_loss/mini_num_cal, - "epoch_len":len(train_data), - "lr":[g['lr'] for g in optimizer.param_groups], - } - with open(epoch_history_path, "w") as f: - json.dump(epoch_history, f, indent=2) - tqdm_iter.write(f"Mini-epoch {epoch}-{i}: Train Loss: {train_loss/num_cal}") - - mini_train_loss = 0 - mini_num_cal = 0 - total_batch_size = sum([i["force"].numel() for i in batch]) + if epoch_offset and i == epoch_offset and self.dist.is_main and hasattr(it, "write"): + it.write(f"Resumed epoch at batch {i}") + + if args.mini_epoch_size and i > 0 and (i % args.mini_epoch_size) == 0: + self._save_mini_checkpoint(i, train_loss, num_cal, mini_train_loss, mini_num_cal, len(self.data_bundle.train)) + mini_train_loss, mini_num_cal = 0.0, 0.0 + + total_batch_size = sum([sb["force"].numel() for sb in batch]) num_cal += total_batch_size mini_num_cal += total_batch_size - total_term_batch_size = {k: sum([i[k].numel() for i in batch]) for k in train_term_losses} - for k in train_term_num_cal.keys(): + total_term_batch_size = {k: sum([sb[k].numel() for sb in batch]) for k in train_term_losses} + for k in train_term_num_cal: train_term_num_cal[k] += total_term_batch_size[k] - optimizer.zero_grad() + self.optimizer.zero_grad() + for sub_batch in batch: - force = sub_batch.pop("force") - force = force.reshape(-1, force.shape[-1]).to(device_output) - energy = None - if energy_matching: - energy = sub_batch.pop("energy") - energy = energy.reshape(-1, energy.shape[-1]).to(device_output) + loss, energy_loss, force_loss, term_loss_dict = self._compute_loss_for_sub_batch( + sub_batch=sub_batch, + criterion=criterion, + term_criterion=term_criterion, + total_batch_size=total_batch_size, + total_term_batch_size=total_term_batch_size, + device_out=device_out, + model=model, + ) + + train_force_loss += float(force_loss) * total_batch_size + if args.energy_matching: + train_energy_loss += float(energy_loss) * total_batch_size + + delta_loss = float(loss) * total_batch_size + train_loss += delta_loss + mini_train_loss += delta_loss - term_targets = {} - for k in train_term_losses.keys(): - term_targets[k] = sub_batch.pop(k).flatten().to(device_output) + for k, v in term_loss_dict.items(): + train_term_losses[k] += float(v) * total_term_batch_size[k] - out_energy, out_force, extra = parallel_model(**sub_batch) + loss.backward() - # Scale the sub_batch to be a term in the overall mean of the batch - sub_batch_size = force.numel() - energy_loss: torch.Tensor = torch.tensor(0.0) - if energy_matching: - energy_loss = criterion(out_energy, energy) * (sub_batch_size / total_batch_size) - force_loss = criterion(out_force, force) * (sub_batch_size / total_batch_size) - loss = energy_weight * energy_loss + force_weight * force_loss + self.optimizer.step() - train_force_loss += force_loss.item() * total_batch_size - if energy_matching: - train_energy_loss += (energy_loss.item() * total_batch_size) + if self.scheduler: + self.scheduler.step_batch(self.epoch + i / len(self.data_bundle.train)) - delta_loss = loss.item() * total_batch_size - train_loss += delta_loss - mini_train_loss += delta_loss + if args.dry_run: + if self.dist.is_main: + print("\nDry run OK!") + sys.exit(0) - for k in train_term_losses.keys(): - # TODO: Find a more generic way of doing this - if train_term_def.get_angle_wrap(k): - train_term_loss = (extra[k] - term_targets[k] + torch.pi) % (2*torch.pi) - torch.pi - train_term_loss = train_term_loss**2 - else: - train_term_loss = term_criterion(extra[k], term_targets[k]) + if verbose and self.dist.is_main and hasattr(it, "set_description"): + desc = [f"Training ({self.epoch}/{args.epochs}) (T={train_loss/num_cal:.4f}"] + for tname in train_term_losses: + desc.append(f"{tname}={train_term_losses[tname]/train_term_num_cal[tname]:.4f}") + it.set_description(", ".join(desc) + ")") + + metrics = { + "train_loss_sum": train_loss, + "train_energy_loss_sum": train_energy_loss, + "train_force_loss_sum": train_force_loss, + "num_cal": num_cal, + "train_term_losses_sum": train_term_losses, + "train_term_num_cal": train_term_num_cal, + } + return self._aggregate_train_metrics(metrics) + + def _compute_loss_for_sub_batch( + self, + sub_batch: Dict[str, Any], + criterion, + term_criterion, + total_batch_size: int, + total_term_batch_size: Dict[str, int], + device_out, + model: nn.Module, + ): + """Compute loss (and side losses) for a single sub-batch.""" + args = self.args + + force = sub_batch.pop("force").reshape(-1, sub_batch["force"].shape[-1]).to(device_out) + + force_weights = None + if args.use_force_weights: + force_weights = sub_batch.pop("forces_weights").reshape(-1).to(device_out) + + energy = None + if args.energy_matching: + energy = sub_batch.pop("energy").reshape(-1, sub_batch["energy"].shape[-1]).to(device_out) + + term_targets = {} + for k in self.extra_train_terms: + term_targets[k] = sub_batch.pop(k).flatten().to(device_out) + + out_energy, out_force, extra = model(**sub_batch) + + sub_batch_size = force.numel() + energy_loss = torch.tensor(0.0, device=out_force.device) + if args.energy_matching: + energy_loss = criterion(out_energy, energy) * (sub_batch_size / total_batch_size) + + force_loss = term_criterion(out_force, force) + if force_weights is not None: + force_loss = force_loss * force_weights[:, None] + force_loss = force_loss.mean() * (sub_batch_size / total_batch_size) + + loss = args.energy_weight * energy_loss + args.force_weight * force_loss + + term_loss_dict: Dict[str, Tensor] = {} + for k in self.extra_train_terms: + if self.args.train_term_def.wrap(k): + tl = (extra[k] - term_targets[k] + torch.pi) % (2 * torch.pi) - torch.pi + tl = tl ** 2 + else: + tl = term_criterion(extra[k], term_targets[k]) - # We don't multiply by numel here because the term loss criterion doesn't do a mean reduction - train_term_loss = train_term_loss / total_term_batch_size[k] - # Mask out undefined values - # TODO: Ensure this is a good threshold value - train_term_loss = train_term_loss * (term_targets[k] >= -10).float() - train_term_loss = torch.sum(train_term_loss) + tl = tl / total_term_batch_size[k] + tl = tl * (term_targets[k] >= -10).float() + tl = torch.sum(tl) - loss = loss + train_term_loss*train_term_def.get_scale(k) + loss = loss + tl * self.args.train_term_def.scale(k) + term_loss_dict[k] = tl - train_term_losses[k] += train_term_loss.item() * total_term_batch_size[k] + return loss, energy_loss, force_loss, term_loss_dict - # Accumulate gradient - loss.backward() - - - - optimizer.step() - if scheduler: - scheduler.step_batch(epoch + i/len(train_data)) - if dry_run: - print("\nDry run OK!") - sys.exit(0) + def _aggregate_train_metrics(self, metrics: Dict[str, Any]) -> Dict[str, Any]: + """All-reduce training metrics across processes.""" + if not self.dist.enabled: + return metrics + + device = self.model_mgr.device + metrics["train_loss_sum"] = self.dist_mgr.all_reduce_sum(metrics["train_loss_sum"], device) + metrics["num_cal"] = self.dist_mgr.all_reduce_sum(metrics["num_cal"], device) + metrics["train_energy_loss_sum"] = self.dist_mgr.all_reduce_sum(metrics["train_energy_loss_sum"], device) + metrics["train_force_loss_sum"] = self.dist_mgr.all_reduce_sum(metrics["train_force_loss_sum"], device) + + for k in metrics["train_term_losses_sum"]: + metrics["train_term_losses_sum"][k] = self.dist_mgr.all_reduce_sum(metrics["train_term_losses_sum"][k], device) + metrics["train_term_num_cal"][k] = int(self.dist_mgr.all_reduce_sum(float(metrics["train_term_num_cal"][k]), device)) - if verbose_loss_report: - desc = [f"Training ({epoch}/{epochs}) (T={train_loss/num_cal:.4f}"] - for t in train_term_losses: - desc.append(f"{t}={train_term_losses[t]/train_term_num_cal[t]:.4f}") - desc = ", ".join(desc) + ")" - tqdm_iter.set_description(desc) + return metrics + def _append_train_metrics(self, train_metrics: Dict[str, Any]) -> None: + """Append normalized train/energy/force losses to history.""" + num_cal = train_metrics["num_cal"] + self.history.train.append(train_metrics["train_loss_sum"] / num_cal) + self.history.energy.append(train_metrics["train_energy_loss_sum"] / num_cal) + self.history.force.append(train_metrics["train_force_loss_sum"] / num_cal) - train_loss_list.append(train_loss/num_cal) - energy_loss_list.append(train_energy_loss/num_cal) - force_loss_list.append(train_force_loss/num_cal) + def _validate_one_epoch(self, criterion, term_criterion) -> Dict[str, Any]: + """Validate for one epoch, returning aggregated metrics.""" + args = self.args + model = self.model_mgr.parallel + device_out = self.model_mgr.device_output - parallel_model.eval() - val_loss = 0 - num_cal = 0 + val_loss = 0.0 + num_cal = 0.0 - val_term_losses = {k: 0.0 for k in train_term_losses} - val_term_num_cal = {k: 0 for k in train_term_num_cal} + val_term_losses = {k: 0.0 for k in self.extra_train_terms} + val_term_num_cal = {k: 0 for k in self.extra_train_terms} - for batch in tqdm(val_data, desc=f"Validation ({epoch}/{epochs})", total=len(val_data), dynamic_ncols=True): - total_batch_size = sum([i["force"].numel() for i in batch]) + if self.dist.is_main: + it = tqdm(self.data_bundle.val, desc=f"Validation ({self.epoch}/{args.epochs})", total=len(self.data_bundle.val), dynamic_ncols=True) + else: + it = self.data_bundle.val + + for batch in it: + total_batch_size = sum([sb["force"].numel() for sb in batch]) num_cal += total_batch_size - total_term_batch_size = {k: sum([i[k].numel() for i in batch]) for k in val_term_losses} - for k in val_term_num_cal.keys(): + total_term_batch_size = {k: sum([sb[k].numel() for sb in batch]) for k in val_term_losses} + for k in val_term_num_cal: val_term_num_cal[k] += total_term_batch_size[k] for sub_batch in batch: - force = sub_batch.pop("force") - force = force.reshape(-1, force.shape[-1]).to(device_output) + force = sub_batch.pop("force").reshape(-1, sub_batch["force"].shape[-1]).to(device_out) + + force_weights = None + if args.use_force_weights: + force_weights = sub_batch.pop("forces_weights").reshape(-1).to(device_out) + energy = None - if energy_matching: - energy = sub_batch.pop("energy") - energy = energy.reshape(-1, energy.shape[-1]).to(device_output) + if args.energy_matching: + energy = sub_batch.pop("energy").reshape(-1, sub_batch["energy"].shape[-1]).to(device_out) term_targets = {} - for k in val_term_losses.keys(): - term_targets[k] = sub_batch.pop(k).flatten().to(device_output) + for k in self.extra_train_terms: + term_targets[k] = sub_batch.pop(k).flatten().to(device_out) - out_energy, out_force, extra = parallel_model(**sub_batch) + out_energy, out_force, extra = model(**sub_batch) sub_batch_size = force.numel() - energy_loss: torch.Tensor = torch.tensor(0.0) - if energy_matching: + energy_loss = torch.tensor(0.0, device=out_force.device) + if args.energy_matching: energy_loss = criterion(out_energy, energy) * (sub_batch_size / total_batch_size) - force_loss = criterion(out_force, force) * (sub_batch_size / total_batch_size) - loss = energy_weight * energy_loss + force_weight * force_loss - - val_loss += loss.item() * total_batch_size - - for k in val_term_losses.keys(): - if train_term_def.get_angle_wrap(k): - val_term_loss = (extra[k] - term_targets[k] + torch.pi) % (2*torch.pi) - torch.pi - val_term_loss = val_term_loss**2 - else: - val_term_loss = term_criterion(extra[k], term_targets[k]) - val_term_loss = val_term_loss / total_term_batch_size[k] - val_term_loss = val_term_loss * (term_targets[k] >= -10).float() - val_term_loss = torch.sum(val_term_loss) - - # loss = loss + val_term_loss*term_val_weight - val_term_losses[k] += val_term_loss.item() * total_term_batch_size[k] + force_loss = term_criterion(out_force, force) + if force_weights is not None: + force_loss = force_loss * force_weights[:, None] + force_loss = force_loss.mean() * (sub_batch_size / total_batch_size) - val_loss_list.append(val_loss/num_cal) + loss = args.energy_weight * energy_loss + args.force_weight * force_loss + val_loss += float(loss) * total_batch_size - if scheduler: - scheduler.step(val_loss/num_cal) - - epoch_history[f"{epoch}"] = { - "train_loss":train_loss_list[-1], - "val_loss":val_loss_list[-1], - "energy_loss":energy_loss_list[-1], - "force_loss":force_loss_list[-1], - "epoch_len":len(train_data), - "lr":[g['lr'] for g in optimizer.param_groups], - } - - for k in extra_train_terms: - epoch_history[f"{epoch}"][f"train_loss_{k}"] = train_term_losses[k]/train_term_num_cal[k] - epoch_history[f"{epoch}"][f"val_loss_{k}"] = val_term_losses[k]/val_term_num_cal[k] - - with open(epoch_history_path, "w") as f: - json.dump(epoch_history, f, indent=2) - - print(f"Epoch {epoch} - Train Loss: {train_loss_list[-1]} - Val Loss: {val_loss_list[-1]} - time: {round(time.time() - t0,2)}s") - if epoch > 0: - print(f" ∆Train: {train_loss_list[-1]-train_loss_list[-2]} - ∆Val: {val_loss_list[-1] - val_loss_list[-2]}") - # print(f" ∆Energy: {energy_loss_list[-1]-energy_loss_list[-2]} - ∆Force: {force_loss_list[-1] - force_loss_list[-2]}") + for k in self.extra_train_terms: + if self.args.train_term_def.wrap(k): + tl = (extra[k] - term_targets[k] + torch.pi) % (2 * torch.pi) - torch.pi + tl = tl ** 2 + else: + tl = term_criterion(extra[k], term_targets[k]) + + tl = tl / total_term_batch_size[k] + tl = tl * (term_targets[k] >= -10).float() + tl = torch.sum(tl) + val_term_losses[k] += float(tl) * total_term_batch_size[k] + + if self.dist.enabled: + device = self.model_mgr.device + val_loss = self.dist_mgr.all_reduce_sum(val_loss, device) + num_cal = self.dist_mgr.all_reduce_sum(num_cal, device) + for k in val_term_losses: + val_term_losses[k] = self.dist_mgr.all_reduce_sum(val_term_losses[k], device) + val_term_num_cal[k] = int(self.dist_mgr.all_reduce_sum(float(val_term_num_cal[k]), device)) + + return { + "val_loss_sum": val_loss, + "num_cal": num_cal, + "val_loss": val_loss / num_cal, + "val_term_losses_sum": val_term_losses, + "val_term_num_cal": val_term_num_cal, + } - for k in val_term_losses.keys(): - print(f" Train {k} loss={train_term_losses[k]/train_term_num_cal[k]:.4f}") - print(f" Val {k} loss={val_term_losses[k]/val_term_num_cal[k]:.4f}") + def _save_mini_checkpoint(self, batch_i: int, train_loss: float, num_cal: float, mini_train_loss: float, mini_num_cal: float, epoch_len: int) -> None: + """Save and rotate mini-checkpoint (main process only).""" + tmp = os.path.join(self.result_dir, f"checkpoint-{self.epoch}-{batch_i}.pth") + self.checkpointer.save( + path=tmp, + epoch=self.epoch, + model_state=self.model_mgr.state_dict_for_saving(), + optimizer=self.optimizer, + model_conf=self.conf, + scheduler=self.scheduler, + extra={"train_loss": train_loss, "num_cal": num_cal, "i": batch_i}, + ) + + if not self.dist.is_main: + return + + os.replace(tmp, os.path.join(self.result_dir, "checkpoint-mini.pth")) + + self.epoch_history.update( + f"{self.epoch}-{batch_i}", + { + "train_loss": train_loss / num_cal, + "mini_train_loss": (mini_train_loss / mini_num_cal) if mini_num_cal else 0.0, + "epoch_len": epoch_len, + "lr": [g["lr"] for g in self.optimizer.param_groups], + }, + ) + + def _log_epoch(self, train_metrics: Dict[str, Any], val_metrics: Dict[str, Any], t0: float) -> None: + """Print epoch summary and write epoch_history.json (main process only).""" + entry = { + "train_loss": self.history.train[-1], + "val_loss": self.history.val[-1], + "energy_loss": self.history.energy[-1], + "force_loss": self.history.force[-1], + "epoch_len": len(self.data_bundle.train), + "lr": [g["lr"] for g in self.optimizer.param_groups], + } - if check_early_stopping(val_loss_list[first_early_stopping_epoch:], patience=early_stopping): - print("Early stopping triggered.") - break + for k in self.extra_train_terms: + entry[f"train_loss_{k}"] = train_metrics["train_term_losses_sum"][k] / max(1, train_metrics["train_term_num_cal"][k]) + entry[f"val_loss_{k}"] = val_metrics["val_term_losses_sum"][k] / max(1, val_metrics["val_term_num_cal"][k]) + + self.epoch_history.update(str(self.epoch), entry) + + print( + f"Epoch {self.epoch} - Train Loss: {self.history.train[-1]} - Val Loss: {self.history.val[-1]} - time: {round(time.time() - t0, 2)}s" + ) + if self.epoch > 0: + print(f" ∆Train: {self.history.train[-1]-self.history.train[-2]} - ∆Val: {self.history.val[-1]-self.history.val[-2]}") + + for k in self.extra_train_terms: + print(f" Train {k} loss={entry[f'train_loss_{k}']:.4f}") + print(f" Val {k} loss={entry[f'val_loss_{k}']:.4f}") + + def _save_epoch(self, train_metrics: Dict[str, Any], val_metrics: Dict[str, Any]) -> None: + """Save checkpoint/history (main process only), with barriers handled by caller.""" + tmp = os.path.join(self.result_dir, f"checkpoint-{self.epoch}.pth") + + self.checkpointer.save( + path=tmp, + epoch=self.epoch + 1, + model_state=self.model_mgr.state_dict_for_saving(), + optimizer=self.optimizer, + model_conf=self.conf, + scheduler=self.scheduler, + ) + + if not self.dist.is_main: + return + + mini = os.path.join(self.result_dir, "checkpoint-mini.pth") + if os.path.exists(mini): + os.unlink(mini) + + main = os.path.join(self.result_dir, "checkpoint.pth") + if self.args.checkpoint_save and (self.epoch % self.args.checkpoint_save == 0): + shutil.copyfile(tmp, main) + else: + os.replace(tmp, main) - history = {"train": train_loss_list, "val": val_loss_list, "energy": energy_loss_list, "force": force_loss_list} + best = os.path.join(self.result_dir, "checkpoint-best.pth") + if self.history.val[-1] <= float(np.min(self.history.val)): + shutil.copyfile(main, best) - # Save the model - # I've attempted to make this compatible with the TorchMD calculators.External class, but I'm not sure how well the keys match - Daniel - tmp_checkpoint_path = f'{result_directory}/checkpoint-{epoch}.pth' - save_checkpoint(tmp_checkpoint_path, epoch + 1, model, optimizer, conf, scheduler) + self.history.save(self.result_dir) + print(" Checkpoint saved.") - if os.path.exists(f'{result_directory}/checkpoint-mini.pth'): - os.unlink(f'{result_directory}/checkpoint-mini.pth') +# ----------------------------- CLI glue ------------------------------------ - if checkpoint_save and (epoch % checkpoint_save == 0): - shutil.copyfile(tmp_checkpoint_path, f'{result_directory}/checkpoint.pth') - else: - os.replace(tmp_checkpoint_path, f'{result_directory}/checkpoint.pth') +def parse_args(): + import argparse - # If this is <= to the lowest validation loss seen so far also save it to checkpoint-best.pth - if val_loss_list[-1] <= np.min(val_loss_list): - shutil.copyfile(f'{result_directory}/checkpoint.pth', f'{result_directory}/checkpoint-best.pth') + p = argparse.ArgumentParser(description="Train a CGSchNet network with distributed support (refactored)") + p.add_argument("input", help="Processed data to train on ") + p.add_argument("result", default=None, nargs="?", help="Checkpoint directory to continue") + p.add_argument("-c", "--config", default="../configs/config.yaml", type=str) + + p.add_argument("--gpus", default=None, type=str, help='List of GPUs (e.g. "0,1,2") or "cpu"') + p.add_argument("--batch", type=int, default=50) + p.add_argument("--epochs", type=int, default=25) + p.add_argument("--lr", type=float, default=1e-4) + p.add_argument("--wd", type=float, default=0.0) + p.add_argument("--val-ratio", type=float, default=0.1) + p.add_argument("--apc", "--atoms-per-call", dest="apc", type=int, default=None) + + p.add_argument("--cos-anneal", default=None, help='Cosine anneal: "T_0,T_mult"') + p.add_argument("--cos-lr", default=None, help='Cosine LR: "T_max,eta_min"') + p.add_argument("--exp-lr", default=None, help='Exponential LR: "gamma"') + p.add_argument("--plateau-lr", default=None, help='Plateau LR: "factor,patience,threshold,min_lr"') + + p.add_argument("--dry-run", action="store_true") + p.add_argument("--reset-early-stopping", action="store_true") + p.add_argument("--no-shuffle", action="store_true") + p.add_argument("--mini-epoch", type=int, default=None) + p.add_argument("--early-stopping", type=int, default=1) + p.add_argument("--checkpoint-save", type=int, default=10) + + p.add_argument("--subsetpdbs", default="ok_list.txt", type=str) + p.add_argument("--energy-weight", default=0.0, type=float) + p.add_argument("--force-weight", default=1.0, type=float) + + p.add_argument("--term-def", default=None, type=str) + p.add_argument("--embedding", type=str, default=None) + p.add_argument("--chunk-dataset", type=int, default=None) + p.add_argument("--npfile", action="store_true") + p.add_argument("--use-force-weights", action="store_true") + + p.add_argument("--distributed", action="store_true") + return p.parse_args() + +def relax_open_file_limit() -> None: + """Raise RLIMIT_NOFILE as much as possible.""" + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) + resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) - # Save the loss history - np.save(f'{result_directory}/history.npy', history)#pyright: ignore[reportArgumentType] - print(" Checkpoint saved.") +def parse_gpu_ids(gpus_arg: Optional[str]): + """Parse --gpus argument into "cpu" or list[int].""" + if not gpus_arg: + return "cpu" + if gpus_arg == "cpu": + return "cpu" + return [int(i) for i in gpus_arg.strip().split(",")] - epoch += 1 +def build_train_args(args) -> TrainArgs: + """Translate argparse args into TrainArgs dataclass.""" + lr_scheduler = SchedulerFactory.from_args(args) + train_term_def = TermDef(path=args.term_def) if args.term_def else TermDef() -def should_decay(param_name: str) -> bool: - #usually something like "representation_model.distance_expansion.means" - #want to not decay the embeddings and the biases - parts = param_name.split('.') - assert len(parts) > 0 - if parts[-1] == "bias": - return False - if parts[-2] == "embedding": - return False - if parts[-2] == "distance_expansion": - #not sure for this - return False - assert parts[-1] == "weight" - return True + gpu_ids = parse_gpu_ids(args.gpus) + energy_matching = args.energy_weight != 0.0 -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser(description="Train a CGSchNet network") - parser.add_argument("input", help="Processed data to train on ") - parser.add_argument("result", default=None, nargs="?", help="Checkpoint directory to continue") - parser.add_argument("-c", "--config", default="../configs/config.yaml", type=str, help="") - parser.add_argument("--gpus", default=None, type=str, help="List of GPUs to train on (e.g. \"0,1,2\")") - parser.add_argument("--batch", type=int, default=50, help="The batch size to use") - parser.add_argument("--epochs", type=int, default=25, help="The total number of epochs to train for") - parser.add_argument("--lr", type=float, default="1e-4", help="Learning rate") - parser.add_argument("--wd", type=float, default=0, help="Weight decay") - parser.add_argument("--val-ratio", type=float, default=0.1, help="Validation set ratio, should be between 0.0 and 1.0") - parser.add_argument("--apc", "--atoms-per-call", type=int, default=None, help="Number of atoms to include in each sub-batch") - parser.add_argument("--cos-anneal", default=None, help="Train using cosine annealing, parameters are \"T_0,T_mult\"") - parser.add_argument("--cos-lr", default=None, help="Train using a cosine learning rate, parameters are \"T_max,eta_min\"") - parser.add_argument("--exp-lr", default=None, help="Train using a exponential learning rate, parameters are \"gamma\"") - parser.add_argument("--plateau-lr", default=None, help="Train using a plateau learning rate, parameters are \"factor\" \"patience\" \"min_lr\"") - parser.add_argument("--dry-run", action="store_true", help="Do a dry run of the training loop but produce no output") - parser.add_argument("--reset-early-stopping", action="store_true", help="Reset the early stopping check to start from the current epoch") - parser.add_argument("--no-shuffle", action="store_true", help="Do not shuffle the training dataset") - parser.add_argument("--mini-epoch", type=int, default=None, help="Save a mini epoch after every n batches") - parser.add_argument("--early-stopping", type=int, default=1, help="The number of epochs validation loss can increase before triggering early stopping or -1 to disable early stopping (default=1)") - parser.add_argument("--checkpoint-save", type=int, default=10, help="Save a backup checkpoint every n epochs, 0 to disable (default=10)") - parser.add_argument("--subsetpdbs", default='ok_list.txt', type=str, help="Change the pdbid list used when reading in the dataset (default=ok_list.txt)") - parser.add_argument("--energy-weight", default=0.0, type=float, help="Energy Weighting for Loss Function") - parser.add_argument("--force-weight", default=1.0, type=float, help="Force Weighting for Loss Function") - parser.add_argument("--term-def", default=None, type=str, help="The path to a term definition yaml file, which can additional loss terms used during training.") - parser.add_argument("--embedding", type=str, default=None, help="Specify an alternate file to load embeddings from (default: embeddings.npy).") - parser.add_argument("--chunk-dataset", type=int, default=None, help="Break the dataset into chunks of n proteins per batch") - parser.add_argument("--npfile", action="store_true", help="Use file loader instead of mmap to load dataset") + return TrainArgs( + directory_path=args.input, + result_directory=args.result, + conf_path=args.config, + gpu_ids=gpu_ids, + weight_decay=args.wd, + learning_rate=args.lr, + epochs=args.epochs, + batch_size=args.batch, + val_ratio=args.val_ratio, + atoms_per_call=args.apc, + scheduler=lr_scheduler, + dry_run=args.dry_run, + reset_early_stopping=args.reset_early_stopping, + enable_shuffle=not args.no_shuffle, + mini_epoch_size=args.mini_epoch, + early_stopping=args.early_stopping, + checkpoint_save=args.checkpoint_save, + subsetpdbs=args.subsetpdbs, + energy_weight=args.energy_weight, + force_weight=args.force_weight, + energy_matching=energy_matching, + train_term_def=train_term_def, + embedding_filename=args.embedding, + dataset_chunk_size=args.chunk_dataset, + use_npfile=args.npfile, + use_force_weights=args.use_force_weights, + ) + +def main(): + args = parse_args() assert torch.cuda.is_available(), "CUDA is not available, please run on a machine with CUDA or use --gpus cpu" + assert os.path.isdir(args.input), f"Input directory does not exist: {args.input}" + assert os.path.isfile(args.config), f"Config file does not exist: {args.config}" + assert args.checkpoint_save >= 0 - args = parser.parse_args() - - directory_path = args.input - assert os.path.isdir(directory_path), f"Input directory does not exist: {directory_path}" - result_directory = args.result - conf_path = args.config - assert os.path.isfile(conf_path), f"Config file does not exist: {conf_path}" - weight_decay = args.wd - learning_rate = args.lr - if args.gpus: - if args.gpus == "cpu": - gpu_ids = "cpu" - else: - gpu_ids = [int(i) for i in args.gpus.strip().split(",")] - else: - gpu_ids = "cpu" - - epochs = args.epochs - batch_size = args.batch - val_ratio = args.val_ratio - atoms_per_call = args.apc - dry_run = args.dry_run - reset_early_stopping = args.reset_early_stopping - enable_shuffle = not args.no_shuffle - mini_epoch_size = args.mini_epoch - early_stopping = args.early_stopping - checkpoint_save = args.checkpoint_save - assert checkpoint_save >= 0 - - subsetpdbs = args.subsetpdbs - energy_weight = args.energy_weight - force_weight = args.force_weight - energy_matching = args.energy_weight != 0.0 - embedding_filename = args.embedding - dataset_chunk_size = args.chunk_dataset - use_npfile = args.npfile + relax_open_file_limit() - # Relax the maximum number of open files as much as possible - # We will potentially open a lot of files (~4 per molecule per ProteinDataset object) - soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) - resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) + dist_mgr = DistributedManager(enable=args.distributed) + dist_mgr.setup() - lr_scheduler = None - if args.cos_anneal: - T_0, T_mult = [int(i) for i in args.cos_anneal.split(",")] - lr_scheduler = SchedulerWrapper_CosineAnnealingWarmRestarts(T_0, T_mult) - if args.cos_lr: - assert lr_scheduler is None - T_max, eta_min = args.cos_lr.split(",") - T_max, eta_min = int(T_max), float(eta_min) - lr_scheduler = SchedulerWrapper_CosineAnnealingLR(T_max, eta_min) - if args.exp_lr: - assert lr_scheduler is None - lr_scheduler = SchedulerWrapper_ExponentialLR(float(args.exp_lr)) - if args.plateau_lr: - assert lr_scheduler is None - factor, patience, threshold, min_lr = args.plateau_lr.split(",") - factor, patience, threshold, min_lr = float(factor), int(patience), float(threshold), float(min_lr) - lr_scheduler = SchedulerWrapper_ReduceLROnPlateau(factor, patience, threshold, min_lr) - - if args.term_def is not None: - train_term_def = TermDef(path=args.term_def) - else: - train_term_def = TermDef() + train_args = build_train_args(args) try: - train_model(directory_path, result_directory=result_directory, conf_path=conf_path, dry_run=dry_run, weight_decay=weight_decay, - learning_rate=learning_rate, gpu_ids=gpu_ids, epochs=epochs, batch_size=batch_size, val_ratio=val_ratio, scheduler=lr_scheduler, - atoms_per_call=atoms_per_call, reset_early_stopping=reset_early_stopping, enable_shuffle=enable_shuffle, - mini_epoch_size=mini_epoch_size, early_stopping=early_stopping, checkpoint_save=checkpoint_save, subsetpdbs=subsetpdbs, energy_weight=energy_weight, - force_weight=force_weight, energy_matching=energy_matching, train_term_def=train_term_def, embedding_filename=embedding_filename, - dataset_chunk_size=dataset_chunk_size, use_npfile=use_npfile) + trainer = Trainer(train_args, dist_mgr) + trainer.run() except Exception as e: - # Uncaught exceptions cause pytorch to hang for quite a while before exiting traceback.print_tb(e.__traceback__) print(e) sys.exit(1) + finally: + dist_mgr.cleanup() + +if __name__ == "__main__": + main() + From 534d5f4edf86d56d4928e241891d6c4cf233ea84 Mon Sep 17 00:00:00 2001 From: Alexander Aghili Date: Tue, 16 Dec 2025 22:59:36 -0800 Subject: [PATCH 2/4] Preprocess upgrade --- preprocess.py | 1811 ++++++++++++++++++++++--------------------------- 1 file changed, 820 insertions(+), 991 deletions(-) diff --git a/preprocess.py b/preprocess.py index d7657e8..83a2921 100755 --- a/preprocess.py +++ b/preprocess.py @@ -1,1106 +1,935 @@ #!/usr/bin/env python3 +from __future__ import annotations + import os -import numpy as np -import yaml import json -import h5py -import mdtraj -from module.torchmd_cg_mappings import CACB_MAP -from module import prior -from module import prior_flex -from module import psfwriter -from module.make_deltaforces import DeltaForces -from module.cg_mapping import CGMapping +import yaml +import glob +import shutil +import pickle import argparse import traceback -import shutil import multiprocessing as mp -from tqdm import tqdm -import glob -import pickle +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple + +import numpy as np +import h5py +import mdtraj import torch -# this can have a small performance hit as it uses the HDD file_system to share data across different processes, but it doesn't lead to "Too many files open" error, which was limiting the max number of parallel processes to 16. -torch.multiprocessing.set_sharing_strategy('file_system') +from tqdm import tqdm + +# Your project imports +from module import prior, prior_flex, psfwriter +from module.make_deltaforces import DeltaForces +from module.cg_mapping import CGMapping +from module.torchmd_cg_mappings import CACB_MAP + +torch.multiprocessing.set_sharing_strategy("file_system") -# Raz Dec 30 2024: turns out that when preprocessing the 6.5k dataset, the last 20 pdbs are taking forever to process due to being very big proteins. In addition, some were giving hdf5 errors since the batch_generate job didn't properly finish, and this means we don't have all the frames, and they will be removed later on before training. So here's the best workflow I found to solve it: -# 1) first run just step 1 on all proteins, and set FILTER_NOT_PROCESSED_STEP_ONE = False. if any pdbs are taking forever (generally the last 20), just kill the job -# 2) re-run the pre-processing for a second time with FILTER_NOT_PROCESSED_STEP_ONE = True -FILTER_NOT_PROCESSED_STEP_ONE = False -# Controls whether to re-generate the priors in step 2 for each of the terms, terms in the list will be loaded from the cache instead of refit -#USE_CACHED_FITS = ['dihedrals', 'angles', 'bonds', 'lj'] -USE_CACHED_FITS = [] +# ----------------------------- +# MPI (optional) +# ----------------------------- +try: + from mpi4py import MPI # type: ignore + MPI_AVAILABLE = True +except Exception: + MPI_AVAILABLE = False + MPI = None # type: ignore -DEVICE_STEP_3 = 'cpu' -#DEVICE_STEP_3 = 'cuda' -DO_STEP_1 = True # whether to do step 1. if you got errors in steps 2-3 and want to resume, set this to False -REGEN_CACHE_FILES = True # whether to re-generate cache files +# ----------------------------- +# Small utilities (pure funcs) +# ----------------------------- +def slice_to_str(s: slice) -> str: + parts = [s.start, s.stop, s.step] + return ":".join("" if p is None else str(p) for p in parts) +def parse_slice(s: str) -> slice: + # "start:stop:step" with empty parts allowed + parts = [p.strip() for p in s.split(":")] + parts += [""] * (3 - len(parts)) + vals = [int(p) if p != "" else None for p in parts[:3]] + return slice(*vals) -def process_init(counter): - """This function sets the worker names such that we can use them to position the tqdm bars""" - with counter.get_lock(): - idx = int(counter.value) - counter.value += 1 - mp.current_process().name = f"PreprocessWorker-{idx}" +def get_prior_params_path(prior_yaml_path: str) -> str: + d, fn = os.path.split(prior_yaml_path) + return os.path.join(d, fn.replace("priors.yaml", "prior_params.json")) +def ensure_dir(path: str) -> None: + os.makedirs(path, exist_ok=True) +def load_h5_traj_slice(path: str, s: slice) -> Tuple[mdtraj.Trajectory, Optional[np.ndarray]]: + base_traj = mdtraj.load_frame(path, 0) + weights = None + with h5py.File(path) as f: + xyz = f["coordinates"][s][:] + time = f["time"][s][:] + unitcell_lengths = unitcell_angles = None + if "cell_lengths" in f: + unitcell_lengths = f["cell_lengths"][s][:] + unitcell_angles = f["cell_angles"][s][:] + if "weight" in f: + weights = f["weight"][s][:] + traj = mdtraj.Trajectory( + xyz, + base_traj.topology, + time=time, + unitcell_lengths=unitcell_lengths, + unitcell_angles=unitcell_angles, + ) + return traj, weights + + +# ----------------------------- +# Config objects +# ----------------------------- +@dataclass(frozen=True) +class RuntimeFlags: + filter_not_processed_step_one: bool = False + do_step_1: bool = True + regen_cache_files: bool = True + device_step_3: str = "cpu" + use_cached_fits: Tuple[str, ...] = field(default_factory=tuple) + +@dataclass(frozen=True) +class DistributedConfig: + # If MPI is available and user passes --mpi, we will split pdbids by rank + use_mpi: bool = False + # If using MPI, you can still have per-rank multiprocessing + per_rank_workers: int = 1 + +@dataclass +class PreprocessConfig: + output_dir: str + temp: int = 300 + frame_slice: slice = slice(None) + optimize_forces: bool = False + use_box: bool = True + prior_plots: bool = True + resume: bool = False + fit_constraints: bool = True + fit_min_cnt: int = 0 + + prior_name: str = "" + prior_file: Optional[str] = None + + +# ----------------------------- +# Mapping definitions (kept minimal) +# ----------------------------- class CGMappingDef_CA: def __init__(self): - residues = ["ALA", "CYS", "ASP", "GLU", "PHE", "GLY", "HIS", "ILE", "LYS", "LEU", "MET", "ASN", "PRO", "HYP", "GLN", "ARG", "SER", "THR", "VAL", "TRP", "TYR"] - # For legacy reasons we have a couple extra ambiguous residues (ASX & GLX) in the embedding map but we do not accept these for parsing - embedding_residues = ["ALA", "ARG", "ASN", "ASP", "ASX", "CYS", "GLU", "GLN", "GLX", "GLY", "HIS", "ILE", "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL"] + residues = ["ALA","CYS","ASP","GLU","PHE","GLY","HIS","ILE","LYS","LEU","MET","ASN","PRO","HYP","GLN","ARG","SER","THR","VAL","TRP","TYR"] + embedding_residues = ["ALA","ARG","ASN","ASP","ASX","CYS","GLU","GLN","GLX","GLY","HIS","ILE","LEU","LYS","MET","PHE","PRO","SER","THR","TRP","TYR","VAL"] self.bead_embeddings = {name: [index + 1] for index, name in enumerate(sorted(embedding_residues))} - - # bead_atom_selection: A list of lists, where each inner list is the names of the atoms that will be combined to form the bead self.bead_atom_selection = {k: [["CA"]] for k in residues} - # The type names of beads (will become the atom type/element in the cg topology) self.bead_types = { - "ALA": ["CAA"], - "ARG": ["CAR"], - "ASN": ["CAN"], - "ASP": ["CAD"], - "CYS": ["CAC"], - "GLN": ["CAQ"], - "GLU": ["CAE"], - "GLY": ["CAG"], - "HIS": ["CAH"], - "HSD": ["CAH"], - "ILE": ["CAI"], - "LEU": ["CAL"], - "LYS": ["CAK"], - "MET": ["CAM"], - "PHE": ["CAF"], - "PRO": ["CAP"], - "SER": ["CAS"], - "THR": ["CAT"], - "TRP": ["CAW"], - "TYR": ["CAY"], - "VAL": ["CAV"], + "ALA":["CAA"],"ARG":["CAR"],"ASN":["CAN"],"ASP":["CAD"],"CYS":["CAC"],"GLN":["CAQ"],"GLU":["CAE"], + "GLY":["CAG"],"HIS":["CAH"],"HSD":["CAH"],"ILE":["CAI"],"LEU":["CAL"],"LYS":["CAK"],"MET":["CAM"], + "PHE":["CAF"],"PRO":["CAP"],"SER":["CAS"],"THR":["CAT"],"TRP":["CAW"],"TYR":["CAY"],"VAL":["CAV"], } - # The "atom name" assigned to the beads self.bead_atom_names = {k: ["CA"] for k in residues} self.bead_masses = {k: [12.01] for k in residues} self.bead_backbone_idx = {k: 0 for k in residues} class CGMappingDef_CACB: def __init__(self): - residues = ["ALA", "CYS", "ASP", "GLU", "PHE", "GLY", "HIS", "ILE", "LYS", "LEU", "MET", "ASN", "PRO", "HYP", "GLN", "ARG", "SER", "THR", "VAL", "TRP", "TYR"] - - # bead_atom_selection: A list of lists, where each inner list is the names of the atoms that will be combined to form the bead + residues = ["ALA","CYS","ASP","GLU","PHE","GLY","HIS","ILE","LYS","LEU","MET","ASN","PRO","HYP","GLN","ARG","SER","THR","VAL","TRP","TYR"] self.bead_atom_selection = {k: [["CA"], ["CB"]] for k in residues} self.bead_atom_selection["GLY"] = [["CA"]] - # The type names of beads (will become the atom type/element in the cg topology) self.bead_types = { - "ALA": ["CA", "CBA"], - "ARG": ["CA", "CBR"], - "ASN": ["CA", "CBN"], - "ASP": ["CA", "CBD"], - "CYS": ["CA", "CBC"], - "GLN": ["CA", "CBQ"], - "GLU": ["CA", "CBE"], - "GLY": ["CAG"], - "HIS": ["CA", "CBH"], - "HSD": ["CA", "CBH"], - "ILE": ["CA", "CBI"], - "LEU": ["CA", "CBL"], - "LYS": ["CA", "CBK"], - "MET": ["CA", "CBM"], - "PHE": ["CA", "CBF"], - "PRO": ["CA", "CBP"], - "SER": ["CA", "CBS"], - "THR": ["CA", "CBT"], - "TRP": ["CA", "CBW"], - "TYR": ["CA", "CBY"], - "VAL": ["CA", "CBV"], + "ALA":["CA","CBA"],"ARG":["CA","CBR"],"ASN":["CA","CBN"],"ASP":["CA","CBD"],"CYS":["CA","CBC"], + "GLN":["CA","CBQ"],"GLU":["CA","CBE"],"GLY":["CAG"],"HIS":["CA","CBH"],"HSD":["CA","CBH"], + "ILE":["CA","CBI"],"LEU":["CA","CBL"],"LYS":["CA","CBK"],"MET":["CA","CBM"],"PHE":["CA","CBF"], + "PRO":["CA","CBP"],"SER":["CA","CBS"],"THR":["CA","CBT"],"TRP":["CA","CBW"],"TYR":["CA","CBY"], + "VAL":["CA","CBV"], } - - embedding_map = {k:i for i,k in enumerate(sorted(set.union(*[set(i) for i in self.bead_types.values()])))} - self.bead_embeddings = {k:[embedding_map[i] for i in v] for k, v in self.bead_types.items()} - - # The "atom name" assigned to the beads - self.bead_atom_names = {k: ["CA", "CB"] for k in residues} + embedding_map = {k: i for i, k in enumerate(sorted(set.union(*[set(v) for v in self.bead_types.values()])))} + self.bead_embeddings = {k: [embedding_map[i] for i in v] for k, v in self.bead_types.items()} + self.bead_atom_names = {k: ["CA","CB"] for k in residues} self.bead_atom_names["GLY"] = ["CA"] - self.bead_masses = {k: [12.01]*len(v) for k,v in self.bead_types.items()} + self.bead_masses = {k: [12.01] * len(v) for k, v in self.bead_types.items()} self.bead_backbone_idx = {k: 0 for k in residues} + +# ----------------------------- +# PriorBuilder + PriorFactory +# ----------------------------- class PriorBuilder: - def __init__(self): - self.prior_params = dict() - self.priors = None - self.terms = dict() - self.atom_types = set() - self.fit_constraints = True - self.tag_beta_turns = False - self.min_cnt = 0 - - def select_atoms(self, topology): - """Returns tha atom index to be saved for this prior""" - raise NotImplementedError() - - def map_embeddings(self, selected_atoms, trajectory): - """Generates the embeddings array for the selected atoms""" - raise NotImplementedError() - - def write_psf(self, pdb_file, psf_file): - """Write the .psf file describing the course grain geometry""" - raise NotImplementedError() - - def add_molecule(self, mol, traj, cache_dir): - fit_ok_path = os.path.join(cache_dir, "fit_ok.txt") + """ + Base prior builder. Concrete behavior comes from: + - mapping_def (CA vs CACB) + - a list of term factories (bonds/angles/dihedrals/lj etc) + - prior_params metadata (for downstream delta-forces logic) + """ + def __init__(self, *, mapping_def: Any, prior_params: Dict[str, Any], term_factories: Dict[str, Callable[[], Any]], runtime: RuntimeFlags): + self.mapping_def = mapping_def + self.prior_params: Dict[str, Any] = dict(prior_params) + self.term_factories = dict(term_factories) + self.runtime = runtime + + self.terms: Dict[str, Any] = {k: f() for k, f in self.term_factories.items()} + self.atom_types: set[str] = set() + self.priors: Optional[Dict[str, Any]] = None - if cache_dir and os.path.exists(fit_ok_path): + # mapping + def build_mapping(self, topology) -> CGMapping: + return CGMapping(topology, self.mapping_def) + + def make_mol(self, cg_map: CGMapping): + bonds = "bonds" in self.terms + angles = "angles" in self.terms + dihedrals = "dihedrals" in self.terms + return cg_map.to_mol(bonds=bonds, angles=angles, dihedrals=dihedrals) + + # cache aggregation + def add_molecule(self, mol, traj, cache_dir: str) -> None: + fit_ok_path = os.path.join(cache_dir, "fit_ok.txt") + if os.path.exists(fit_ok_path): os.unlink(fit_ok_path) for term in self.terms.values(): term.add_molecule(mol, traj, cache_dir) - self.atom_types = self.atom_types.union(mol.atomtype) + self.atom_types |= set(mol.atomtype) - if cache_dir: - np.save(os.path.join(cache_dir, "atomtype.npy"), mol.atomtype) - with open(fit_ok_path, "wt", encoding="utf-8") as f: - f.write("ok") + np.save(os.path.join(cache_dir, "atomtype.npy"), mol.atomtype) + with open(fit_ok_path, "wt", encoding="utf-8") as f: + f.write("ok") - def load_molecule_cache(self, cache_dir): + def load_molecule_cache(self, cache_dir: str) -> None: assert os.path.exists(os.path.join(cache_dir, "fit_ok.txt")) atomtype = np.load(os.path.join(cache_dir, "atomtype.npy"), allow_pickle=True) - self.atom_types = self.atom_types.union(atomtype) - + self.atom_types |= set(atomtype) for term in self.terms.values(): term.load_molecule_cache(cache_dir) - def enable_fit_constraints(self, use_constraints): - self.fit_constraints = use_constraints - self.prior_params["fit_constraints"] = self.fit_constraints - - def enable_bond_tags(self, use_tags): - self.tag_beta_turns = use_tags - self.prior_params["tag_beta_turns"] = self.tag_beta_turns - - def set_min_cnt(self, min_cnt): - assert min_cnt >= 0 - self.min_cnt = min_cnt - self.prior_params["min_cnt"] = self.min_cnt + # fitting + def _init_prior_dict(self) -> None: + priors: Dict[str, Any] = {} + priors["atomtypes"] = sorted(self.atom_types) + priors["bonds"] = {} + priors["angles"] = {} + priors["dihedrals"] = {} + priors["lj"] = {} + priors["electrostatics"] = {at: {"charge": 0.0} for at in priors["atomtypes"]} + priors["masses"] = {at: 12.01 for at in priors["atomtypes"]} + self.priors = priors - def fit(self, temperature, plot_dir=None): - self.init_prior_dict() + def fit(self, temperature: int, plot_dir: Optional[str]) -> None: + self._init_prior_dict() assert self.priors is not None + for key, term in self.terms.items(): - if os.path.exists(f"{plot_dir}/prior_{key}.pkl") and (key in USE_CACHED_FITS): - print(f"Used cached fit for {key}...") + cached = plot_dir and os.path.exists(f"{plot_dir}/prior_{key}.pkl") and (key in self.runtime.use_cached_fits) + if cached: with open(f"{plot_dir}/prior_{key}.pkl", "rb") as f: self.priors[key] = pickle.load(f) - else: - print(f"Fitting {key}...") - self.priors[key] = term.get_param(temperature, plot_dir, self.fit_constraints, self.min_cnt) - # pickle the prior for this term + continue + + self.priors[key] = term.get_param( + temperature, + plot_dir, + self.prior_params.get("fit_constraints", True), + self.prior_params.get("min_cnt", 0), + ) + if plot_dir: with open(f"{plot_dir}/prior_{key}.pkl", "wb") as f: pickle.dump(self.priors[key], f) - def init_prior_dict(self): - # Define the force field dict - priors = {} - priors['atomtypes'] = sorted(self.atom_types) - priors['bonds'] = {} - priors['angles'] = {} - priors['dihedrals'] = {} - priors['lj'] = {} - # For mass and charge assume everything is a carbon atom - priors['electrostatics'] = {at: {'charge': 0.0} for at in priors['atomtypes']} - # The mass of carbon used here is the from OpenMM/AMBER-14 value - priors['masses'] = {at: 12.01 for at in priors['atomtypes']} - self.priors = priors - - def save_prior(self, output_path, pdbid): - prefix = "" - if pdbid: - prefix = f"{pdbid}_" - with open(os.path.join(output_path, f"{prefix}priors.yaml"), "w") as f: + def save_prior(self, output_dir: str) -> None: + assert self.priors is not None + with open(os.path.join(output_dir, "priors.yaml"), "w") as f: yaml.dump(self.priors, f) - with open(os.path.join(output_path, f"{prefix}prior_params.json"),"w") as f: - json.dump(self.prior_params, f) + with open(os.path.join(output_dir, "prior_params.json"), "w") as f: + json.dump(self.prior_params, f, indent=2, sort_keys=True) + + # flex special-case: override by composition instead of subclass if you want + def save_prior_flex(self, output_dir: str) -> None: + """ + Mirrors your flex behavior: yaml gets only classical priors, pickle saves nets. + """ + assert self.priors is not None - def make_mol(self, cg_map): - bonds = "bonds" in self.terms - angles = "angles" in self.terms - dihedrals = "dihedrals" in self.terms - return cg_map.to_mol(bonds = bonds, angles = angles, dihedrals = dihedrals) + # Write params + with open(os.path.join(output_dir, "prior_params.json"), "w") as f: + json.dump(self.prior_params, f, indent=2, sort_keys=True) -class Prior_CA(PriorBuilder): - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA", - "exclusions" : ['bonds'], - "forceterms" : ["bonds"], - }) - self.terms["bonds"] = prior.ParamBondedCalculator() - - def build_mapping(self, topology): - return CGMapping(topology, CGMappingDef_CA()) - - def select_atoms(self, topology): - #TODO: Remove this function (replaced by build_mapping) - return topology.select('name CA and protein') - - def map_embeddings(self, selected_atoms, topology): #pyright: ignore[reportIncompatibleMethodOverride] - #TODO: Remove this function (replaced by build_mapping) - standardResidues = {"ALA", "ARG", "ASN", "ASP", "ASX", "CYS", "GLU", "GLN", "GLX", "GLY", "HIS", "ILE", "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL"} - amino_acid_mapping = {name: index + 1 for index, name in enumerate(sorted(standardResidues))} - - result = [] - for a_idx in selected_atoms: - r_name = topology.atom(a_idx).residue.name - result.append(amino_acid_mapping[r_name]) - return np.array(result, dtype=int) - - def write_psf(self, pdb_file, psf_file): - #TODO: Remove this function (replaced by build_mapping) - bonds = "bonds" in self.terms - angles = "angles" in self.terms - dihedrals = "dihedrals" in self.terms - return psfwriter.pdb2psf_CA(pdb_file, psf_file, bonds = bonds, angles = angles, dihedrals = dihedrals, - tag_beta_turns = self.tag_beta_turns) + # Truncate bonds/angles/dihedrals from YAML (classical only) + truncated = dict(self.priors) + truncated.pop("bonds", None) + truncated.pop("angles", None) + truncated.pop("dihedrals", None) -class Prior_CACB(PriorBuilder): - """Implements the torchmd-cg CACB prior""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CACB", - "exclusions" : ['bonds'], - "forceterms" : ["bonds"], - }) - self.terms["bonds"] = prior.ParamBondedCalculator() - - def build_mapping(self, topology): - return CGMapping(topology, CGMappingDef_CACB()) - - def select_atoms(self, topology): - #TODO: Remove this function (replaced by build_mapping) - return topology.select('(name CA or name CB) and protein') - - def map_embeddings(self, selected_atoms, topology):#pyright: ignore[reportIncompatibleMethodOverride] - #TODO: Remove this function (replaced by build_mapping) - - # Make a map from embedding name to embedding name number - # e.g. {"CAA":0, "CAC":1, ...} - embedding_map = CACB_MAP - embedding_nums = dict([(k, i) for i, k in enumerate(sorted(set(embedding_map.values())))]) - - result = [] - for a_idx in selected_atoms: - r_name = topology.atom(a_idx).residue.name - a_name = topology.atom(a_idx).name - emb_name = embedding_map[(r_name, a_name)] - result.append(embedding_nums[emb_name]) - return np.array(result, dtype=int) - - def write_psf(self, pdb_file, psf_file): - #TODO: Remove this function (replaced by build_mapping) - bonds = "bonds" in self.terms - angles = "angles" in self.terms - dihedrals = "dihedrals" in self.terms - return psfwriter.pdb2psf_CACB(pdb_file, psf_file, bonds = bonds, angles = angles, dihedrals = dihedrals) + with open(os.path.join(output_dir, "priors.yaml"), "w") as f: + yaml.dump(truncated, f) + + payload = dict(self.priors) + payload["terms"] = self.terms + payload["prior_params"] = self.prior_params + with open(os.path.join(output_dir, "priors.pkl"), "wb") as f: + pickle.dump(payload, f) -class Prior_CACB_lj(Prior_CACB): - """torchmd-cg CACB prior with Bonded & RepulsionCG terms""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CACB_lj", - "exclusions" : ['bonds'], - "forceterms" : ['bonds', 'repulsioncg'], - }) - self.terms["bonds"] = prior.ParamBondedCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - -class Prior_CACB_lj_angle_dihedral(Prior_CACB): - """torchmd-cg CACB prior with Bonded, Angle, Dihedral & RepulsionCG terms""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CACB_lj_angle_dihedral", - "exclusions" : ['bonds', 'angles', 'dihedrals'], - "forceterms" : ['bonds', 'angles', 'dihedrals', 'repulsioncg'], - }) - self.terms["bonds"] = prior.ParamBondedCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - self.terms["angles"] = prior.ParamAngleCalculator() - self.terms["dihedrals"] = prior.ParamDihedralCalculator() - -class Prior_CA_lj(Prior_CA): - """CA prior with Bonded & RepulsionCG terms""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_lj", - "exclusions" : ['bonds'], - "forceterms" : ['bonds', 'repulsioncg'], - }) - self.terms["bonds"] = prior.ParamBondedCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - -class Prior_CA_lj_angle(Prior_CA): - """CA prior with Bonded, Angle, and RepulsionCG terms""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_lj_angle", - "exclusions" : ['bonds', 'angles'], - "forceterms" : ['bonds', 'angles', 'repulsioncg'], - }) - self.terms["bonds"] = prior.ParamBondedCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - self.terms['angles'] = prior.ParamAngleCalculator() - -class Prior_CA_lj_angle_dihedral(Prior_CA): - """torchmd-cg CA prior with Bonded, Angle, Dihedral & RepulsionCG terms""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_lj_angle_dihedral", - "exclusions" : ['bonds', 'angles', 'dihedrals'], - "forceterms" : ['bonds', 'angles', 'dihedrals', 'repulsioncg'], - }) - self.terms["bonds"] = prior.ParamBondedCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - self.terms["angles"] = prior.ParamAngleCalculator() - self.terms["dihedrals"] = prior.ParamDihedralCalculator() - -class Prior_CA_lj_angle_dihedralX(Prior_CA): - """torchmd-cg CA prior with Bonded, Angle, DihedralX & RepulsionCG terms""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_lj_angle_dihedralX", - "exclusions" : ['bonds', 'angles', 'dihedrals'], - "forceterms" : ['bonds', 'angles', 'dihedrals', 'repulsioncg'], - }) - self.terms["bonds"] = prior.ParamBondedCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - self.terms["angles"] = prior.ParamAngleCalculator() - self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True) - -class Prior_CA_lj_angleXCX_dihedralX(Prior_CA): - """torchmd-cg CA prior with Bonded, Angle, DihedralX & RepulsionCG terms""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_lj_angleXCX_dihedralX", - "exclusions" : ['bonds', 'angles', 'dihedrals'], - "forceterms" : ['bonds', 'angles', 'dihedrals', 'repulsioncg'], - }) - self.terms["bonds"] = prior.ParamBondedCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - self.terms["angles"] = prior.ParamAngleCalculator(center=True) - self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True) - -class Prior_CA_lj_angleXCX_dihedralX_flex(Prior_CA): - """torchmd-cg CA prior with highly flexible Bonded, Angle, DihedralX & RepulsionCG terms that fit the data. +class PriorFactory: """ - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_lj_angleXCX_dihedralX_flex", - "exclusions" : ['bonds', 'angles', 'dihedrals'], - "forceterms_nn" : ['bonds', 'angles', 'dihedrals'], - "forceterms_classical": ['repulsioncg'], # changed from lj, would need to re-generated the dataset (Jan 10 2025). repulsioncg is using just the repulsion term from lj. it uses the same parameters as lj, so need to make sure the right function is evaluated. - "external" : True - }) - self.prior_params['forceterms'] = self.prior_params['forceterms_classical'] + self.prior_params['forceterms_nn'] - - self.terms["bonds"] = prior_flex.ParamBondedFlexCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - self.terms["angles"] = prior_flex.ParamAngleFlexCalculator(center=True) - self.terms["dihedrals"] = prior_flex.ParamDihedralFlexCalculator(unified=True) - - # have to override this method since we're saving neural nets as priors - def save_prior(self, output_path, pdbid): - prefix = "" - # if pdbid: - # prefix = f"{pdbid}_" - with open(os.path.join(output_path, f"{prefix}prior_params.json"),"w") as f: - json.dump(self.prior_params, f) - - # print('self.priors', self.priors.keys()) - # remove the dihedrals and bonds from the priors - priorsTruncated = self.priors.copy() - priorsTruncated.pop('dihedrals') - priorsTruncated.pop('bonds') - priorsTruncated.pop('angles') - # print('priorsTruncated', priorsTruncated.keys()) - - # save the classical priors using yaml. this is requires because the classical priors are built from the yaml files - with open(os.path.join(output_path, f"{prefix}priors.yaml"), "w") as f: - yaml.dump(priorsTruncated, f) - - self.priors['terms'] = self.terms - self.priors['prior_params'] = self.prior_params - - # also save with pickle - with open(os.path.join(output_path, f"{prefix}priors.pkl"), "wb") as f: - pickle.dump(self.priors, f) - - - def load_prior_nnets(self, output_path): - # load the prior with pickle - with open(os.path.join(output_path, "priors.pkl"), "rb") as f: - self.priors = pickle.load(f) - - # return self.priors - - # with open(os.path.join(output_path, f"{prefix}priors.pkl"), "wb") as f: - # pickle.dump(self.priors, f) - - - -class Prior_CA_lj_angleXCX_dihedralX_V1(Prior_CA): - """torchmd-cg CA prior with Bonded, Angle, DihedralX & RepulsionCG terms""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_lj_angleXCX_dihedralX_V1", - "exclusions" : ['bonds', 'angles', '1-4'], - "forceterms" : ['Bonds', 'angles', 'dihedrals', 'RepulsionCG'], - }) - self.terms["bonds"] = prior.ParamBondedCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - self.terms["angles"] = prior.ParamAngleCalculator(center=True) - self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True) - -class Prior_CA_lj_bondNull_angleXCX_dihedralX(Prior_CA): - """torchmd-cg CA prior with Angle, DihedralX & RepulsionCG terms (+ bond exclusions)""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_lj_bondNull_angleXCX_dihedralX", - "exclusions" : ['bonds', 'angles', '1-4'], - "forceterms" : ['Bonds', 'angles', 'dihedrals', 'RepulsionCG'], - }) - self.terms["bonds"] = prior.NullParamBondedCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - self.terms["angles"] = prior.ParamAngleCalculator(center=True) - self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True) - -class Prior_CA_lj_bondNull_angleNull_dihedralX(Prior_CA): - """torchmd-cg CA prior with DihedralX & RepulsionCG terms (+ bond & angle exclusions)""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_lj_bondNull_angleNull_dihedralX", - "exclusions" : ['bonds', 'angles', '1-4'], - "forceterms" : ['Bonds', 'angles', 'dihedrals', 'RepulsionCG'], - }) - self.terms["bonds"] = prior.NullParamBondedCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - self.terms["angles"] = prior.NullParamAngleCalculator() - self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True) - -class Prior_CA_lj_bondNull_angleNull_dihedralNull(Prior_CA): - """torchmd-cg CA prior with RepulsionCG terms (+ bond, angle, & dihedral exclusions)""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_lj_bondNull_angleNull_dihedralNull", - "exclusions" : ['bonds', 'angles', '1-4'], - "forceterms" : ['Bonds', 'angles', 'dihedrals', 'RepulsionCG'], - }) - self.terms["bonds"] = prior.NullParamBondedCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - self.terms["angles"] = prior.NullParamAngleCalculator() - self.terms["dihedrals"] = prior.NullParamDihedralCalculator() - -class Prior_CA_lj_angleNull_dihedralX(Prior_CA): - """torchmd-cg CA prior with Bonded, DihedralX & RepulsionCG terms (+ angle exclusions)""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_lj_angleNull_dihedralX", - "exclusions" : ['bonds', 'angles', '1-4'], - "forceterms" : ['Bonds', 'angles', 'dihedrals', 'RepulsionCG'], - }) - self.terms["bonds"] = prior.ParamBondedCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - self.terms["angles"] = prior.NullParamAngleCalculator() - self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True) - -class Prior_CA_lj_angleNull_dihedralNull(Prior_CA): - """torchmd-cg CA prior with Bonded & RepulsionCG terms (+ angle & dihedral exclusions)""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_lj_angleNull_dihedralNull", - "exclusions" : ['bonds', 'angles', '1-4'], - "forceterms" : ['Bonds', 'angles', 'dihedrals', 'RepulsionCG'], - }) - self.terms["bonds"] = prior.ParamBondedCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - self.terms["angles"] = prior.NullParamAngleCalculator() - self.terms["dihedrals"] = prior.NullParamDihedralCalculator() - -class Prior_CA_Majewski2022_v0(Prior_CA): - """torchmd-cg CA prior based on the parameters used in (Majewski 2022) - Note this version (v0) has different lj exclusions than the one used in the paper. + Defines priors declaratively (no giant subclass tree). """ - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_Majewski2022_v0", - "exclusions" : ['bonds', 'dihedrals'], - "forceterms" : ['bonds', 'dihedrals', 'repulsioncg'], - }) - self.terms["bonds"] = prior.ParamBondedCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True, scale=0.5) - -class Prior_CA_Majewski2022_v1(Prior_CA): - """torchmd-cg CA prior based on the parameters used in (Majewski 2022)""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_Majewski2022_v1", - "exclusions" : ['bonds'], - "forceterms" : ['bonds', 'dihedrals', 'repulsioncg'], - }) - self.terms["bonds"] = prior.ParamBondedCalculator() - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6], exclusion_terms={"bonds"}) - self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True, scale=0.5) - -class Prior_CA_null(Prior_CA): - """CA prior with no terms""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_null", - "exclusions" : [], - "forceterms" : [], - }) - self.terms = {} - -class Prior_CA_lj_only(Prior_CA): - """CA prior with just a RepulsionCG term""" - def __init__(self): - super().__init__() - self.prior_params.update({ - "prior_configuration_name": "CA_lj_only", - "exclusions" : [], - "forceterms" : ['RepulsionCG'], - }) - self.terms = {} - self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) - -def slice_to_str(s): - result = [s.start, s.stop, s.step] - result = [str(i) if i is not None else '' for i in result] - return ":".join(result) - -def get_prior_params_path(prior_path): - dir_path, file_name = os.path.split(prior_path) - file_name = file_name.replace("priors.yaml", "prior_params.json") - return os.path.join(dir_path, file_name) - -def load_h5_traj_slice(path, slice): - """Load a slice from a h5 trajectory without reading the entire file into memory""" - base_traj = mdtraj.load_frame(path, 0) - with h5py.File(path) as f: - t_xyz = f["coordinates"][slice][:] #pyright: ignore[reportIndexIssue] - t_time = f["time"][slice][:] #pyright: ignore[reportIndexIssue] + @staticmethod + def make(prior_name: str, cfg: PreprocessConfig, runtime: RuntimeFlags) -> PriorBuilder: + # Shared term constructors + def bonds(): return prior.ParamBondedCalculator() + def bonds_null(): return prior.NullParamBondedCalculator() + def angles(center=False): return prior.ParamAngleCalculator(center=center) + def angles_null(): return prior.NullParamAngleCalculator() + def dihedrals(unified=False, scale=1.0): return prior.ParamDihedralCalculator(unified=unified, scale=scale) + def dihedrals_null(): return prior.NullParamDihedralCalculator() + def lj(exclusions=None): return prior.ParamNonbondedCalculator(fit_range=[3, 6], exclusion_terms=exclusions or set()) + + def flex_bonds(): return prior_flex.ParamBondedFlexCalculator() + def flex_angles(center=False): return prior_flex.ParamAngleFlexCalculator(center=center) + def flex_dihedrals(unified=False): return prior_flex.ParamDihedralFlexCalculator(unified=unified) + + # Declarative catalog + CATALOG: Dict[str, Dict[str, Any]] = { + # CA + "CA": { + "mapping_def": CGMappingDef_CA(), + "prior_params": {"prior_configuration_name":"CA","exclusions":["bonds"],"forceterms":["bonds"]}, + "terms": {"bonds": bonds}, + }, + "CA_lj": { + "mapping_def": CGMappingDef_CA(), + "prior_params": {"prior_configuration_name":"CA_lj","exclusions":["bonds"],"forceterms":["bonds","repulsioncg"]}, + "terms": {"bonds": bonds, "lj": lj}, + }, + "CA_lj_angleXCX_dihedralX": { + "mapping_def": CGMappingDef_CA(), + "prior_params": {"prior_configuration_name":"CA_lj_angleXCX_dihedralX","exclusions":["bonds","angles","dihedrals"],"forceterms":["bonds","angles","dihedrals","repulsioncg"]}, + "terms": {"bonds": bonds, "lj": lj, "angles": (lambda: angles(center=True)), "dihedrals": (lambda: dihedrals(unified=True))}, + }, + "CA_lj_bondNull_angleNull_dihedralNull": { + "mapping_def": CGMappingDef_CA(), + "prior_params": {"prior_configuration_name":"CA_lj_bondNull_angleNull_dihedralNull","exclusions":["bonds","angles","1-4"],"forceterms":["Bonds","angles","dihedrals","RepulsionCG"]}, + "terms": {"bonds": bonds_null, "lj": lj, "angles": angles_null, "dihedrals": dihedrals_null}, + }, + "CA_Majewski2022_v1": { + "mapping_def": CGMappingDef_CA(), + "prior_params": {"prior_configuration_name":"CA_Majewski2022_v1","exclusions":["bonds"],"forceterms":["bonds","dihedrals","repulsioncg"]}, + "terms": {"bonds": bonds, "lj": (lambda: lj(exclusions={"bonds"})), "dihedrals": (lambda: dihedrals(unified=True, scale=0.5))}, + }, + "CA_null": { + "mapping_def": CGMappingDef_CA(), + "prior_params": {"prior_configuration_name":"CA_null","exclusions":[],"forceterms":[]}, + "terms": {}, + }, + + # CACB + "CACB": { + "mapping_def": CGMappingDef_CACB(), + "prior_params": {"prior_configuration_name":"CACB","exclusions":["bonds"],"forceterms":["bonds"]}, + "terms": {"bonds": bonds}, + }, + "CACB_lj_angle_dihedral": { + "mapping_def": CGMappingDef_CACB(), + "prior_params": {"prior_configuration_name":"CACB_lj_angle_dihedral","exclusions":["bonds","angles","dihedrals"],"forceterms":["bonds","angles","dihedrals","repulsioncg"]}, + "terms": {"bonds": bonds, "lj": lj, "angles": angles, "dihedrals": dihedrals}, + }, + + # FLEX + "CA_lj_angleXCX_dihedralX_flex": { + "mapping_def": CGMappingDef_CA(), + "prior_params": { + "prior_configuration_name":"CA_lj_angleXCX_dihedralX_flex", + "exclusions":["bonds","angles","dihedrals"], + "forceterms_nn":["bonds","angles","dihedrals"], + "forceterms_classical":["repulsioncg"], + "external": True, + }, + "terms": { + "bonds": flex_bonds, + "angles": (lambda: flex_angles(center=True)), + "dihedrals": (lambda: flex_dihedrals(unified=True)), + "lj": lj, + }, + }, + } - t_unitcell_lengths = None - t_unitcell_angles = None - if "cell_lengths" in f.keys(): - t_unitcell_lengths = f["cell_lengths"][slice][:] #pyright: ignore[reportIndexIssue] - t_unitcell_angles = f["cell_angles"][slice][:] #pyright: ignore[reportIndexIssue] + if prior_name not in CATALOG: + raise RuntimeError(f"Unknown prior configuration: {prior_name}") - result = mdtraj.Trajectory(t_xyz, base_traj.topology, time=t_time, unitcell_lengths=t_unitcell_lengths, unitcell_angles=t_unitcell_angles) - return result + spec = CATALOG[prior_name] -class Preprocessor: - def __init__(self, dataset_conf, input_path_map, save_path, prior_builder, prior_file, prior_name, frame_slice, temp, optimize_forces, box, prior_plots, resume_preprocess, num_cores, jobid=None, totalNrJobs=None): - self.dataset_conf = dataset_conf - self.save_path = save_path - self.prior_builder = prior_builder - self.prior_file = prior_file - self.frame_slice = frame_slice - self.temp = temp - self.jobid = jobid - self.totalNrJobs = totalNrJobs - - self.pdbid_list = input_path_map - - if FILTER_NOT_PROCESSED_STEP_ONE: - pdbs_processed_step1 = [f.split('/')[-3] for f in glob.glob(os.path.join(save_path, "*/fit/fit_ok.txt"))] - # remove keys that are not in pdbs_processed_step1 - self.pdbid_list = {k: v for k, v in self.pdbid_list.items() if k in pdbs_processed_step1} - print('%d pdbs left after removing pdbs not processed in step 1' % len(self.pdbid_list)) - - # if os.path.exists(os.path.join(save_path, 'pdb_list.pkl')): - # with open(os.path.join(save_path, 'pdb_list.pkl'), 'rb') as f: - # self.pdbid_list = pickle.load(f) - # else: - # self.pdbid_list = self.get_pdbid_list() - - # if FILTER_NOT_PROCESSED_STEP_ONE: - # pdbs_processed_step1 = [f.split('/')[-3] for f in glob.glob(os.path.join(save_path, "*/fit/fit_ok.txt"))] - # # remove keys that are not in pdbs_processed_step1 - # self.pdbid_list = {k: v for k, v in self.pdbid_list.items() if k in pdbs_processed_step1} - # print('%d pdbs left after removing pdbs not processed in step 1' % len(self.pdbid_list)) - - # # pickle pdb_list - # os.makedirs(save_path, exist_ok=True) - # with open(os.path.join(save_path, 'pdb_list.pkl'), 'wb') as f: - # pickle.dump(self.pdbid_list, f) - - - self.optimize_forces = optimize_forces - self.box = box - self.prior_plots = prior_plots - self.resume_preprocess = resume_preprocess - self.num_cores = num_cores - - print("Input directory paths:", [i["path"] for i in self.dataset_conf]) - print("Save directory path:", self.save_path) - print(f"Temperature: {self.temp}") - print("Frame slice:", slice_to_str(self.frame_slice)) - print("Number of cores used for parallelization:", self.num_cores) - # print("PDB ID list:", self.pdbid_list) - - def step1_threading(self, pdbid): - try: - cache_dir = os.path.join(self.save_path, pdbid, "fit") - if not(self.resume_preprocess and os.path.exists(os.path.join(cache_dir, "fit_ok.txt"))): - # This assumes we've named the processes during initialization - bar_pos = int(mp.current_process().name.split("-")[1]) + 1 - self.process_step1(pdbid, bar_pos) - return [] - except Exception as e: - traceback.print_tb(e.__traceback__) - print(f"{pdbid}:", e) - raise - - def step3_threading(self, pdbid): - try: - # This assumes we've named the processes during initialization - bar_pos = int(mp.current_process().name.split("-")[1]) + 1 - self.process_step3(pdbid, bar_pos) - except Exception as e: - traceback.print_tb(e.__traceback__) - print(f"{pdbid}:", e) - raise - - def preprocess(self): - os.makedirs(os.path.join(self.save_path, "result"), exist_ok=True) - - info_dict = { - "input_paths": [i["path"] for i in self.dataset_conf], - "frame_slice": slice_to_str(self.frame_slice), - "pdbids": list(self.pdbid_list.keys()), - "optimize_forces": self.optimize_forces, - "box": self.box, - "prior_name": prior_name - } + # inject fit settings + prior_params = dict(spec["prior_params"]) + prior_params["fit_constraints"] = cfg.fit_constraints + prior_params["min_cnt"] = cfg.fit_min_cnt + + # flex convenience: build forceterms union + if prior_params.get("external"): + prior_params["forceterms"] = prior_params["forceterms_classical"] + prior_params["forceterms_nn"] - # If resuming, validate that no paramters that would invalidate the fit cache object have changed - # FIXME: This should also check "--tag-beta-turns" - if self.resume_preprocess: - if os.path.exists(os.path.join(self.save_path, "result/info.json")): - with open(os.path.join(self.save_path, "result/info.json"), "rt", encoding="utf-8") as f: - prevous_info = json.load(f) - for k in ["box", "frame_slice", "optimize_forces", "prior_name"]: - assert info_dict[k] == prevous_info[k], \ - f"Can't resume with different parameters: {k}: {info_dict[k]} != {prevous_info[k]}" + return PriorBuilder( + mapping_def=spec["mapping_def"], + prior_params=prior_params, + term_factories=spec["terms"], + runtime=runtime, + ) - with open(os.path.join(self.save_path, "result/info.json"), "wt", encoding="utf-8") as f: - json.dump(info_dict, f) - pdbids = self.pdbid_list +# ----------------------------- +# Execution backends +# ----------------------------- +class ExecContext: + def __init__(self, dist: DistributedConfig): + self.dist = dist + self.rank = 0 + self.world = 1 + self.comm = None - # Ensure all jobs have the same tqdm lock : https://github.com/tqdm/tqdm/issues/982 + if dist.use_mpi: + if not MPI_AVAILABLE: + raise RuntimeError("MPI requested but mpi4py is not installed.") + self.comm = MPI.COMM_WORLD + self.rank = self.comm.Get_rank() + self.world = self.comm.Get_size() + + def barrier(self) -> None: + if self.comm: + self.comm.Barrier() + + def is_root(self) -> bool: + return self.rank == 0 + + def split_items(self, items: List[str]) -> List[str]: + # deterministic split by rank (round-robin keeps load more even with variable protein sizes) + if self.world == 1: + return items + return [x for i, x in enumerate(items) if (i % self.world) == self.rank] + + +class ParallelRunner: + """ + One DRY runner for: + - serial + - multiprocessing per rank + """ + def __init__(self, workers: int): + self.workers = max(1, int(workers)) + + @staticmethod + def _proc_init(counter: mp.Value) -> None: + with counter.get_lock(): + idx = int(counter.value) + counter.value += 1 + mp.current_process().name = f"PreprocessWorker-{idx}" + + @staticmethod + def worker_bar_position(default: int = 1) -> int: + try: + return int(mp.current_process().name.split("-")[1]) + 1 + except Exception: + return default + + def run( + self, + items: List[str], + fn: Callable[[str], None], + *, + desc: str, + show_progress: bool = True, + ) -> Dict[str, str]: + if self.workers == 1: + errors: Dict[str, str] = {} + it = tqdm(items, desc=desc, dynamic_ncols=True) if show_progress else items + for x in it: + try: + fn(x) + except Exception as e: + traceback.print_tb(e.__traceback__) + errors[x] = str(e) + return errors + + # multiprocessing tqdm.get_lock() + errors: Dict[str, str] = {} + counter = mp.Value("i", 0, lock=True) + + with tqdm(total=len(items), desc=desc, dynamic_ncols=True, disable=not show_progress) as pbar: + with mp.Pool(self.workers, initializer=self._proc_init, initargs=(counter,)) as pool: + pending: Dict[mp.pool.ApplyResult, str] = { + pool.apply_async(fn, args=(x,)): x for x in items + } + while pending: + for res in list(pending.keys()): + if not res.ready(): + continue + x = pending.pop(res) + try: + res.get() + except Exception as e: + traceback.print_tb(e.__traceback__) + errors[x] = str(e) + pbar.n = len(items) - len(pending) + pbar.refresh() + if pending: + next(iter(pending)).wait(1) + return errors - # TODO: Print exceptions in the main thread for legibility - # TODO: Abstract the loop logic instead of repeating it twice - # Truncate any existing ok_list.txt - with open(os.path.join(self.save_path, "result/ok_list.txt"), "wt", encoding="utf-8") as ok_list: - pass +# ----------------------------- +# Dataset discovery +# ----------------------------- +class DatasetIndex: + def __init__(self, dataset_conf: List[Dict[str, Any]]): + self.dataset_conf = dataset_conf + def build_input_map(self) -> Dict[str, str]: + pdbid_mapping: Dict[str, str] = {} + for entry in self.dataset_conf: + input_path = entry["path"] + prefix = entry.get("prefix", "") + suffix = entry.get("suffix", "") + assert os.path.isdir(input_path), f"Input path does not exist: {input_path}" + + if "pdbids" in entry: + for dir_name in entry["pdbids"]: + input_h5 = os.path.join(input_path, dir_name, "result", f"output_{dir_name}.h5") + assert os.path.exists(input_h5), f"Requested path missing: {input_h5}" + pdbid_mapping[prefix + dir_name + suffix] = input_h5 + continue + for dir_name in sorted(os.listdir(input_path)): + input_h5 = os.path.join(input_path, dir_name, "result", f"output_{dir_name}.h5") + if os.path.exists(input_h5): + pdbid_mapping[prefix + dir_name + suffix] = input_h5 + return pdbid_mapping - # Run step 1 in parallel, saving the results to the cache - - if DO_STEP_1: - errorList = {} - thread_counter = mp.Value('i', 0, lock=True) - with tqdm(total=len(pdbids), desc="Processing Step 1", dynamic_ncols=True) as pbar: - with mp.Pool(self.num_cores, initializer=process_init, initargs=(thread_counter,)) as pool: - pending_results = {} - - # Submit tasks and map results to pdbids - for pdbid in pdbids: - result = pool.apply_async(self.step1_threading, args=(pdbid,)) - pending_results[result] = pdbid - - while pending_results: - # Check completed tasks - for result in list(pending_results.keys()): # Iterate over a copy to allow removal - if result.ready(): - try: - result.get() # Retrieve result or raise exception - except Exception as e: - pdbid = pending_results[result] # Get the corresponding pdbid - errorList[pdbid] = str(e) - finally: - del pending_results[result] # Remove completed task - - # Update the progress bar - pbar.n = len(pdbids) - len(pending_results) - pbar.refresh() - - # Wait for 1 second or for the last job to finish - if pending_results: - next(iter(pending_results)).wait(1) - - if len(errorList): - print('errorList', errorList) - print('errorList keys', errorList.keys()) - - if not self.prior_file: - # pickle prior builder - if not os.path.exists(self.save_path + '/prior_builder.pkl') or REGEN_CACHE_FILES: - # Merge cache files back into prior builder - for pdbid in tqdm(pdbids, desc="Merging cache files together"): - cache_dir = os.path.join(self.save_path, pdbid, "fit") - self.prior_builder.load_molecule_cache(cache_dir) - - with open(self.save_path + '/prior_builder.pkl', 'wb') as f: - pickle.dump(self.prior_builder, f) - else: - print("Using cached prior_builder object... ") - with open(self.save_path + '/prior_builder.pkl', 'rb') as f: - self.prior_builder = pickle.load(f) +# ----------------------------- +# Core pipeline (OO, modular) +# ----------------------------- +class Preprocessor: + def __init__( + self, + cfg: PreprocessConfig, + runtime: RuntimeFlags, + dist: DistributedConfig, + dataset_conf: List[Dict[str, Any]], + input_map: Dict[str, str], + prior_builder: PriorBuilder, + ): + self.cfg = cfg + self.runtime = runtime + self.exec = ExecContext(dist) + self.runner = ParallelRunner(workers=dist.per_rank_workers) - self.process_step2() + self.dataset_conf = dataset_conf + self.input_map = dict(input_map) # pdbid -> h5 path + self.prior_builder = prior_builder - self.prior_builder.save_prior(self.save_path, None) + # optional filter: only keep proteins that completed step 1 + if runtime.filter_not_processed_step_one: + processed = {p.split("/")[-3] for p in glob.glob(os.path.join(cfg.output_dir, "*/fit/fit_ok.txt"))} + self.input_map = {k: v for k, v in self.input_map.items() if k in processed} + + # ---------- high level ---------- + def run(self) -> None: + ensure_dir(os.path.join(self.cfg.output_dir, "result")) + all_pdbids = sorted(self.input_map.keys()) + + # MPI split (each rank gets a subset) + pdbids = self.exec.split_items(all_pdbids) + + # only root writes global info.json (avoids collisions) + if self.exec.is_root(): + self._write_run_info(all_pdbids) + + self.exec.barrier() + + # Step 1 (per-rank) + if self.runtime.do_step_1: + errs = self.runner.run( + pdbids, + self._step1_one, + desc=f"[rank {self.exec.rank}] Step 1", + show_progress=True, + ) + if errs and self.exec.is_root(): + print("Step 1 errors (sample):", dict(list(errs.items())[:10])) + + self.exec.barrier() + + # Step 2 (ONLY root fits priors once, then broadcast/filesystem share) + if not self.cfg.prior_file: + if self.exec.is_root(): + self._fit_prior_once(all_pdbids) + self.exec.barrier() else: - prior_params_path = get_prior_params_path(prior_file) - shutil.copy(self.prior_file, os.path.join(self.save_path, "priors.yaml")) - shutil.copy(prior_params_path, os.path.join(self.save_path, "prior_params.json")) - - if self.totalNrJobs: - pdblist = list(pdbids.keys()) - pdbidsPerJob = len(pdblist) // self.totalNrJobs + 1 - jobid = self.jobid - assert jobid is not None - if jobid < self.totalNrJobs - 1: - pdbids_c = [pdblist[i] for i in range(jobid * pdbidsPerJob, (jobid + 1) * pdbidsPerJob)] - else: - pdbids_c = [pdblist[i] for i in range(jobid * pdbidsPerJob, len(pdblist))] - - # filter pdbids for this job only - pdbids = {k: v for k, v in pdbids.items() if k in pdbids_c} - - print(f"Step 3: Processing {len(pdbids)} pdbids") - - # Run step 3 in parallel - thread_counter = mp.Value('i', 0, lock=True) - with tqdm(total=len(pdbids), desc="Processing Step 3", dynamic_ncols=True) as pbar: - with mp.Pool(self.num_cores, initializer=process_init, initargs=(thread_counter,)) as pool: - pending_results = [] - - for pdbid in pdbids: - pending_results += [pool.apply_async(self.step3_threading, args=(pdbid,))] - - while pending_results: - # Check for exceptions - [i.get() for i in pending_results if i.ready()] - # Remove finished jobs from the list - pending_results = [i for i in pending_results if not i.ready()] - pbar.n = len(pdbids) - len(pending_results) - pbar.refresh() - # Wait for 1 second or for the last job to finish - if pending_results: - pending_results[0].wait(1) - - # alternatively, cd to the preprocessed_data directory and run this cmd: - # ls */raw/deltaforces.npy | awk '{print substr($1, 1, 4)}' > result/ok_list.txt - with open(os.path.join(self.save_path, "result/ok_list.txt"), "wt", encoding="utf-8") as ok_list: - ok_list.write("\n".join(pdbids)) - - print("Done!") - - def save_data(self, output_path, trajectory, embeddings, forces, pdbid): - # print(f" {pdbid} (coordinates, forces): {trajectory.xyz.shape}, {forces.shape}") - np.save(f"{output_path}/raw/embeddings.npy", embeddings) - np.save(f"{output_path}/raw/forces.npy", forces) - np.save(f"{output_path}/raw/coordinates.npy", trajectory.xyz) - box_path = f"{output_path}/raw/box.npy" - if self.box: - np.save(box_path, trajectory.unitcell_vectors) + if self.exec.is_root(): + self._copy_prior_files() + self.exec.barrier() + + # Step 3 (per-rank) + errs3 = self.runner.run( + pdbids, + self._step3_one, + desc=f"[rank {self.exec.rank}] Step 3", + show_progress=True, + ) + if errs3 and self.exec.is_root(): + print("Step 3 errors (sample):", dict(list(errs3.items())[:10])) + + self.exec.barrier() + + # root writes ok_list + if self.exec.is_root(): + ok_list_path = os.path.join(self.cfg.output_dir, "result/ok_list.txt") + with open(ok_list_path, "wt", encoding="utf-8") as f: + f.write("\n".join(all_pdbids)) + print("Done!") + + # ---------- info / resume ---------- + def _write_run_info(self, all_pdbids: List[str]) -> None: + info_path = os.path.join(self.cfg.output_dir, "result/info.json") + payload = { + "input_paths": [i["path"] for i in self.dataset_conf], + "frame_slice": slice_to_str(self.cfg.frame_slice), + "pdbids": all_pdbids, + "optimize_forces": self.cfg.optimize_forces, + "box": self.cfg.use_box, + "prior_name": self.cfg.prior_name, + } + + if self.cfg.resume and os.path.exists(info_path): + with open(info_path, "rt", encoding="utf-8") as f: + prev = json.load(f) + for k in ["box", "frame_slice", "optimize_forces", "prior_name"]: + if payload[k] != prev.get(k): + raise AssertionError(f"Can't resume with different parameters: {k}: {payload[k]} != {prev.get(k)}") + + with open(info_path, "wt", encoding="utf-8") as f: + json.dump(payload, f, indent=2, sort_keys=True) + + # truncate ok_list + with open(os.path.join(self.cfg.output_dir, "result/ok_list.txt"), "wt", encoding="utf-8") as _: + pass + + # ---------- step 1 ---------- + def _step1_one(self, pdbid: str) -> None: + cache_dir = os.path.join(self.cfg.output_dir, pdbid, "fit") + if self.cfg.resume and os.path.exists(os.path.join(cache_dir, "fit_ok.txt")): + return + + bar_pos = self.runner.worker_bar_position() + self._process_step1(pdbid, bar_pos) + + def _save_raw( + self, + out_dir: str, + *, + embeddings: np.ndarray, + forces: np.ndarray, + coords: np.ndarray, + box_vectors: Optional[np.ndarray], + weights: Optional[np.ndarray], + ) -> None: + raw = os.path.join(out_dir, "raw") + ensure_dir(raw) + np.save(os.path.join(raw, "embeddings.npy"), embeddings) + np.save(os.path.join(raw, "forces.npy"), forces) + np.save(os.path.join(raw, "coordinates.npy"), coords) + + box_path = os.path.join(raw, "box.npy") + if box_vectors is not None: + np.save(box_path, box_vectors) elif os.path.exists(box_path): os.unlink(box_path) - def process_step1(self, pdbid, bar_position=0): - """Generate the course grained data and topology for the protein, then add it to the prior builder""" + w_path = os.path.join(raw, "forces_weights.npy") + if weights is not None: + np.save(w_path, weights) + elif os.path.exists(w_path): + os.unlink(w_path) + + def _process_step1(self, pdbid: str, bar_position: int) -> None: + input_h5 = self.input_map[pdbid] + out_dir = os.path.join(self.cfg.output_dir, pdbid) + ensure_dir(os.path.join(out_dir, "raw")) + ensure_dir(os.path.join(out_dir, "processed")) - with tqdm(total=7, position=bar_position, desc=f"{pdbid}: File path setup", dynamic_ncols=True, leave=False) as pbar: - def progress_bar_step(msg): + with tqdm(total=7, position=bar_position, desc=f"{pdbid}: step1", dynamic_ncols=True, leave=False) as pbar: + def step(msg: str) -> None: pbar.update(1) pbar.set_description_str(f"{pdbid}: {msg}") - # Set up paths and create directories - output_path = os.path.join(self.save_path, pdbid) + step("Loading trajectory") + aa_traj, aa_weights = load_h5_traj_slice(input_h5, self.cfg.frame_slice) + assert aa_traj.xyz is not None + aa_traj.xyz *= 10.0 # nm -> Å - # TODO: Get rit of the of subdirectories? - os.makedirs(f"{output_path}/raw", exist_ok=True) - os.makedirs(f"{output_path}/processed", exist_ok=True) + step("Building CG mapping") + cg_map = self.prior_builder.build_mapping(aa_traj.topology) + mol = self.prior_builder.make_mol(cg_map) - # Find which path this ID belongs to - input_file_path = self.pdbid_list[pdbid] + # optional debug topology save + mol.write(os.path.join(out_dir, "processed", f"{pdbid}_processed.psf")) + cg_topology_mol = cg_map.to_mol(bonds=True, angles=True, dihedrals=True) + cg_topology_mol.write(os.path.join(out_dir, "processed", "topology.psf")) - progress_bar_step("Loading trajectory") - AAtraj = load_h5_traj_slice(input_file_path, self.frame_slice) - assert AAtraj.xyz is not None # for pyright - AAtraj.xyz *= 10 # convert to angstroms + step("Mapping CG forces") + with h5py.File(input_h5, "r") as f: + forces = f["forces"][self.cfg.frame_slice, :, :] + forces = cg_map.cg_optimal_forces(aa_traj, forces) if self.cfg.optimize_forces else cg_map.cg_forces(forces) + forces = forces * 0.02390057361376673 # kJ/mol/nm -> kcal/mol/Å - progress_bar_step("Building CG mapping") - cg_map = self.prior_builder.build_mapping(AAtraj.topology) - mol = self.prior_builder.make_mol(cg_map) - topology = cg_map.to_mol(bonds=True, angles=True, dihedrals=True) - mol.write(f'{output_path}/processed/{pdbid}_processed.psf') - topology.write(f'{output_path}/processed/topology.psf') # Save the topology for the CG mapping, this is optional but useful for debugging - - # Get the forces - progress_bar_step("Mapping CG forces") - with h5py.File(input_file_path, "r") as f: - forces = f["forces"][self.frame_slice, :, :] #pyright: ignore[reportIndexIssue] - - if self.optimize_forces: - forces = cg_map.cg_optimal_forces(AAtraj, forces) - else: - forces = cg_map.cg_forces(forces) - - assert len(forces) == len(AAtraj) - # Convert from kilojoules/mole/nanometer to kilocalories/mole/angstrom - forces = forces*0.02390057361376673 - - progress_bar_step("Mapping CG coordinates") - xyz = cg_map.cg_positions(AAtraj.xyz) + step("Mapping CG coordinates") + xyz = cg_map.cg_positions(aa_traj.xyz) cg_traj = mdtraj.Trajectory(xyz, topology=cg_map.to_mdtraj()) - if self.box and AAtraj.unitcell_lengths is not None: - cg_traj.unitcell_lengths = AAtraj.unitcell_lengths * 10 - cg_traj.unitcell_angles = AAtraj.unitcell_angles + + if self.cfg.use_box and aa_traj.unitcell_lengths is not None: + cg_traj.unitcell_lengths = aa_traj.unitcell_lengths * 10.0 + cg_traj.unitcell_angles = aa_traj.unitcell_angles else: cg_traj.unitcell_lengths = None - cg_traj.unitcell_angles = None - - # Save the data - progress_bar_step("Saving Data") - self.save_data(output_path, cg_traj, cg_map.embeddings, forces, pdbid) - - # Note: moveaxis creates a view, the original trajectory.xyz is unmodified - assert cg_traj.xyz is not None - mol.coords = np.moveaxis(cg_traj.xyz, 0, -1) - - progress_bar_step("Generating prior fit data") - if not self.prior_file: - cache_dir = os.path.join(output_path, "fit") - os.makedirs(cache_dir, exist_ok=True) + cg_traj.unitcell_angles = None + + cg_weights = None + if aa_weights is not None: + if aa_weights.ndim == 2: + cg_weights = cg_map.cg_weights(aa_weights) + elif aa_weights.ndim == 1: + n_beads = len(cg_map.embeddings) + cg_weights = np.repeat(aa_weights[:, None], n_beads, axis=1) + + step("Saving raw arrays") + self._save_raw( + out_dir, + embeddings=cg_map.embeddings, + forces=forces, + coords=cg_traj.xyz, + box_vectors=(cg_traj.unitcell_vectors if self.cfg.use_box else None), + weights=cg_weights, + ) + + step("Generating prior fit data") + if not self.cfg.prior_file: + cache_dir = os.path.join(out_dir, "fit") + ensure_dir(cache_dir) + mol.coords = np.moveaxis(cg_traj.xyz, 0, -1) self.prior_builder.add_molecule(mol, cg_traj, cache_dir) - def process_step2(self): - """Fit the prior forcefield based on accumulated data""" - if self.prior_plots: - plot_dir = os.path.join(self.save_path, "prior_fit_plots") - os.makedirs(plot_dir, exist_ok=True) - else: - plot_dir = None - - self.prior_builder.fit(self.temp, plot_dir=plot_dir) - - def process_step3(self, pdbid, bar_position=0): - """Save prior focefield and generate delta forces data for each protein""" - output_path = os.path.join(self.save_path, pdbid) - - # Remove legacy prior files if they exist - if os.path.exists(f"{output_path}/raw/{pdbid}_priors.yaml"): - os.unlink(f"{output_path}/raw/{pdbid}_priors.yaml") - if os.path.exists(f"{output_path}/raw/{pdbid}_prior_params.json"): - os.unlink(f"{output_path}/raw/{pdbid}_prior_params.json") - - # Generate delta forces for all atom simulation vs. prior FF - coords_npz = f'{output_path}/raw/coordinates.npy' - forces_npz = f'{output_path}/raw/forces.npy' - delta_forces_npz = f'{output_path}/raw/deltaforces.npy' - prior_energy_npz = f'{output_path}/raw/prior_energy.npy' - box_npz = None - if self.box: - box_npz = f"{output_path}/raw/box.npy" - forcefield = os.path.join(self.save_path, "priors.yaml") - psf_file = f'{output_path}/processed/{pdbid}_processed.psf' - prior_params = self.prior_builder.prior_params - - deltaForcesObj = DeltaForces(DEVICE_STEP_3, psf_file, coords_npz, box_npz) - if 'external' in self.prior_builder.prior_params.keys(): - # forceterms = ['bonds', 'angles', 'dihedrals'] - deltaForcesObj.addExternalForces(forcefield, self.prior_builder.priors['bonds'], self.prior_builder.priors['angles'], self.prior_builder.priors['dihedrals'], forceterms=prior_params["forceterms_nn"], bar_position=bar_position) - - # forceterms = ['repulsioncg'] # update them properly in preprocess.py in the _flex class - deltaForcesObj.computePriorForces(forcefield, exclusions=prior_params["exclusions"], - forceterms=prior_params["forceterms_classical"], bar_position=bar_position) + # ---------- step 2 ---------- + def _fit_prior_once(self, all_pdbids: List[str]) -> None: + # cache prior_builder aggregation + pb_cache = os.path.join(self.cfg.output_dir, "prior_builder.pkl") + if os.path.exists(pb_cache) and not self.runtime.regen_cache_files: + with open(pb_cache, "rb") as f: + self.prior_builder = pickle.load(f) else: - deltaForcesObj.computePriorForces(forcefield, exclusions=prior_params["exclusions"], - forceterms=prior_params["forceterms"], bar_position=bar_position) - - # load MD forces from forces_npz, compute delta forces, and save them in delta_forces_npz - deltaForcesObj.makeAndSaveDeltaForces(forces_npz, delta_forces_npz, prior_energy_npz) - -def gen_input_mapping(conf): - """Find the list of input files for the passed dataset config""" - pdbid_mapping = dict() - for entry in conf: - input_path = entry["path"] - prefix = entry.get("prefix", "") - suffix = entry.get("suffix", "") - assert os.path.isdir(input_path), f"Input path does not exist: {input_path}" - if "pdbids" in entry: - for dir_name in entry["pdbids"]: - input_h5 = os.path.join(input_path, dir_name, "result", f"output_{dir_name}.h5") - assert os.path.exists(input_h5), "Requested path {input_path}/{dir_name} does not exist" - pdbid_mapping[prefix + dir_name + suffix] = input_h5 + for pdbid in tqdm(all_pdbids, desc="Merging fit caches", dynamic_ncols=True): + cache_dir = os.path.join(self.cfg.output_dir, pdbid, "fit") + self.prior_builder.load_molecule_cache(cache_dir) + with open(pb_cache, "wb") as f: + pickle.dump(self.prior_builder, f) + + plot_dir = None + if self.cfg.prior_plots: + plot_dir = os.path.join(self.cfg.output_dir, "prior_fit_plots") + ensure_dir(plot_dir) + + self.prior_builder.fit(self.cfg.temp, plot_dir=plot_dir) + + # Save priors + if self.prior_builder.prior_params.get("external"): + self.prior_builder.save_prior_flex(self.cfg.output_dir) else: - dir_names = os.listdir(input_path) - for dir_name in sorted(dir_names): - input_h5 = os.path.join(input_path, dir_name, "result", f"output_{dir_name}.h5") - if os.path.exists(input_h5): - pdbid_mapping[prefix + dir_name + suffix] = input_h5 - else: - print(f" Skipping \"{dir_name}\" (directory contains no output)") - return pdbid_mapping - -prior_types = { - "CA":Prior_CA, - "CACB":Prior_CACB, - "CACB_lj":Prior_CACB_lj, - "CACB_lj_angle_dihedral":Prior_CACB_lj_angle_dihedral, - "CA_lj":Prior_CA_lj, - "CA_lj_angle":Prior_CA_lj_angle, - "CA_lj_angle_dihedral":Prior_CA_lj_angle_dihedral, - "CA_lj_angle_dihedralX":Prior_CA_lj_angle_dihedralX, - "CA_lj_angleXCX_dihedralX":Prior_CA_lj_angleXCX_dihedralX, - "CA_lj_angleXCX_dihedralX_flex":Prior_CA_lj_angleXCX_dihedralX_flex, - "CA_lj_angleXCX_dihedralX_V1":Prior_CA_lj_angleXCX_dihedralX_V1, - "CA_Majewski2022_v0":Prior_CA_Majewski2022_v0, - "CA_Majewski2022_v1":Prior_CA_Majewski2022_v1, - "CA_lj_bondNull_angleXCX_dihedralX":Prior_CA_lj_bondNull_angleXCX_dihedralX, - "CA_lj_bondNull_angleNull_dihedralX":Prior_CA_lj_bondNull_angleNull_dihedralX, - "CA_lj_bondNull_angleNull_dihedralNull":Prior_CA_lj_bondNull_angleNull_dihedralNull, - "CA_lj_angleNull_dihedralX":Prior_CA_lj_angleNull_dihedralX, - "CA_lj_angleNull_dihedralNull":Prior_CA_lj_angleNull_dihedralNull, - "CA_null":Prior_CA_null, - "CA_lj_only":Prior_CA_lj_only, -} + self.prior_builder.save_prior(self.cfg.output_dir) + + def _copy_prior_files(self) -> None: + assert self.cfg.prior_file is not None + prior_params_path = get_prior_params_path(self.cfg.prior_file) + shutil.copy(self.cfg.prior_file, os.path.join(self.cfg.output_dir, "priors.yaml")) + shutil.copy(prior_params_path, os.path.join(self.cfg.output_dir, "prior_params.json")) + + # ---------- step 3 ---------- + def _step3_one(self, pdbid: str) -> None: + bar_pos = self.runner.worker_bar_position() + self._process_step3(pdbid, bar_pos) + + def _process_step3(self, pdbid: str, bar_position: int) -> None: + out_dir = os.path.join(self.cfg.output_dir, pdbid) + raw_dir = os.path.join(out_dir, "raw") + proc_dir = os.path.join(out_dir, "processed") + + # legacy cleanup + for legacy in [f"{pdbid}_priors.yaml", f"{pdbid}_prior_params.json"]: + p = os.path.join(raw_dir, legacy) + if os.path.exists(p): + os.unlink(p) + + coords_npz = os.path.join(raw_dir, "coordinates.npy") + forces_npz = os.path.join(raw_dir, "forces.npy") + delta_forces_npz = os.path.join(raw_dir, "deltaforces.npy") + prior_energy_npz = os.path.join(raw_dir, "prior_energy.npy") + box_npz = os.path.join(raw_dir, "box.npy") if self.cfg.use_box else None + + forcefield = os.path.join(self.cfg.output_dir, "priors.yaml") + psf_file = os.path.join(proc_dir, f"{pdbid}_processed.psf") -if __name__ == "__main__": + prior_params = self.prior_builder.prior_params + df = DeltaForces(self.runtime.device_step_3, psf_file, coords_npz, box_npz) + + if prior_params.get("external"): + df.addExternalForces( + forcefield, + self.prior_builder.priors["bonds"], + self.prior_builder.priors["angles"], + self.prior_builder.priors["dihedrals"], + forceterms=prior_params["forceterms_nn"], + bar_position=bar_position, + ) + df.computePriorForces( + forcefield, + exclusions=prior_params["exclusions"], + forceterms=prior_params["forceterms_classical"], + bar_position=bar_position, + ) + else: + df.computePriorForces( + forcefield, + exclusions=prior_params["exclusions"], + forceterms=prior_params["forceterms"], + bar_position=bar_position, + ) + + df.makeAndSaveDeltaForces(forces_npz, delta_forces_npz, prior_energy_npz) + + +# ----------------------------- +# CLI +# ----------------------------- +def load_dataset_conf(inputs: List[str]) -> List[Dict[str, Any]]: + conf: List[Dict[str, Any]] = [] + for i in inputs: + if os.path.isfile(i): + with open(i, "r") as f: + conf += yaml.safe_load(f) + else: + conf += [{"path": i}] + return conf - parser = argparse.ArgumentParser(description="Preprocess data.") - parser.add_argument("input", nargs = "+", help="Input directory path") +def main() -> None: + parser = argparse.ArgumentParser(description="Preprocess data (OO, DRY, MPI-enabled).") + parser.add_argument("input", nargs="+", help="Input directories or a YAML config file") parser.add_argument("-o", "--output", required=True, help="Output directory path") - parser.add_argument("--pdbids", nargs="*", help="List of specific PDB IDs to process") - parser.add_argument("--num-frames", "--num_frames", type=int, default=None, help="Number of frames to process") - parser.add_argument("--frame-slice", type=str, default=None, help="Select frames to process using a python slice: start:end:stride") - parser.add_argument("--temp", type=int, default=300, help="Temperature") - parser.add_argument("--prior", type=str, default=None, help="Select the prior forcefield to use, must be one of: " + ", ".join(sorted(prior_types.keys()))) - parser.add_argument("--optimize-forces", action="store_true", help="Use statistically optimal force aggregation (Kramer 2023)") - parser.add_argument("--prior-file", default=None, help="Use PRIOR_FILE instead of fitting a prior") - parser.add_argument('--no-box', default=False, action='store_true', help="Don't use periodic box information") - parser.add_argument('--prior-plots', default=True, action='store_true', help="Save plots of the prior fit functions") - parser.add_argument('--no-prior-plots', dest='prior_plots', action='store_false', help="Don't save plots of the prior fit functions") - parser.add_argument('--no-fit-constraints', default=False, action='store_true', help="Disable range constraints when fitting prior functions") - parser.add_argument('--fit-min-cnt', type=int, default=0, help="Only bins with cnt > min_cnt will be considered when fitting the prior (default 0)") - # parser.add_argument('--tag-beta-turns', default=False, action='store_true', help="Give beta turns a different bond type in the prior") - parser.add_argument('--resume', default=False, action='store_true', help="Resume processing rather than overwriting, all settings must be identical between calls") - parser.add_argument('--num-cores', type=int, default=32, help="Number of cores to be used for parallelization of preprocessing") - parser.add_argument('--jobid', type=int, default=None, help="Integer denoting jobid, if not -1, it will only process a subset of the PDBs") - parser.add_argument('--totalNrJobs', type=int, default=None, help="Integer denoting how many jobs are in total.") - + parser.add_argument("--pdbids", nargs="*", help="Specific PDB IDs to process") + parser.add_argument("--num-frames", type=int, default=None, help="Number of frames to process") + parser.add_argument("--frame-slice", type=str, default=None, help="Frame slice: start:end:stride") + parser.add_argument("--temp", type=int, default=300) + parser.add_argument("--prior", type=str, default=None, help="Prior configuration name") + parser.add_argument("--prior-file", default=None, help="Use a pre-fit priors.yaml (and matching prior_params.json)") + parser.add_argument("--optimize-forces", action="store_true") + parser.add_argument("--no-box", action="store_true") + parser.add_argument("--prior-plots", dest="prior_plots", action="store_true", default=True) + parser.add_argument("--no-prior-plots", dest="prior_plots", action="store_false") + parser.add_argument("--no-fit-constraints", action="store_true") + parser.add_argument("--fit-min-cnt", type=int, default=0) + parser.add_argument("--resume", action="store_true") + parser.add_argument("--use-cached-fits", nargs="*", default=[], help="e.g. bonds angles dihedrals lj") + + # distributed + parser.add_argument("--mpi", action="store_true", help="Enable MPI splitting (requires mpi4py, launch with mpirun)") + parser.add_argument("--workers", type=int, default=32, help="Workers per rank (multiprocessing). Use 1 to disable.") + + # runtime flags (formerly globals) + parser.add_argument("--filter-not-processed-step-one", action="store_true") + parser.add_argument("--skip-step-1", action="store_true") + parser.add_argument("--no-regen-cache-files", action="store_true") + parser.add_argument("--device-step-3", type=str, default="cpu") args = parser.parse_args() print(args) - output_dir = args.output - pdbids = args.pdbids assert not (args.num_frames and args.frame_slice) - if args.num_frames: + if args.num_frames is not None: frame_slice = slice(0, args.num_frames) - elif args.frame_slice: - # Convert the arg string into a slice - frame_slice = slice(*[int(i) if i != '' else None for i in args.frame_slice.split(":") ]) + elif args.frame_slice is not None: + frame_slice = parse_slice(args.frame_slice) else: frame_slice = slice(None) - temp = args.temp - optimize_forces = args.optimize_forces - box = not args.no_box - prior_plots = args.prior_plots - prior_name = args.prior - prior_file = args.prior_file - resume_preprocess = args.resume - num_cores = args.num_cores - jobid = args.jobid - totalNrJobs = args.totalNrJobs - - if prior_file: - assert os.path.exists(prior_file), f"Prior file does not exist: {prior_file}" - prior_params_path = get_prior_params_path(prior_file) - with open(prior_params_path, "r", encoding="utf-8") as f: - prior_params = json.load(f) - prior_configuration_name = prior_params["prior_configuration_name"] - if prior_name is None: - prior_name = prior_configuration_name - elif prior_name != prior_configuration_name: - print() - print(f"WARNING: Prior \"{prior_name}\" differs from the one used to build the prior file \"{prior_configuration_name}\"") - print() - - assert prior_name, " You must specify the prior to use with either --prior or --prior-file" - - if prior_name not in prior_types: - raise RuntimeError(f"Unknown prior configuration: {prior_name}") - print(f"Using prior: {prior_name}") - prior_builder = prior_types[prior_name]() # <- () to instantiate the class - prior_builder.enable_fit_constraints(not args.no_fit_constraints) - # prior_builder.enable_bond_tags(args.tag_beta_turns) - prior_builder.enable_bond_tags(False) - prior_builder.set_min_cnt(args.fit_min_cnt) - - if 'external' in prior_builder.prior_params.keys(): - mp.set_start_method('spawn') - - # Set matplotlib to use a thread safe backend (for prior fit plots) - import matplotlib - matplotlib.use('Agg') - - dataset_conf = [] - - for i in args.input: - if os.path.isfile(i): - with open(args.input[0], "r") as f: - dataset_conf += yaml.safe_load(f) - else: - dataset_conf += [{"path": i}] - input_path_map = gen_input_mapping(dataset_conf) + cfg = PreprocessConfig( + output_dir=args.output, + temp=args.temp, + frame_slice=frame_slice, + optimize_forces=args.optimize_forces, + use_box=not args.no_box, + prior_plots=args.prior_plots, + resume=args.resume, + fit_constraints=not args.no_fit_constraints, + fit_min_cnt=args.fit_min_cnt, + prior_name=args.prior or "", + prior_file=args.prior_file, + ) + + runtime = RuntimeFlags( + filter_not_processed_step_one=args.filter_not_processed_step_one, + do_step_1=not args.skip_step_1, + regen_cache_files=not args.no_regen_cache_files, + device_step_3=args.device_step_3, + use_cached_fits=tuple(args.use_cached_fits), + ) + + dist = DistributedConfig( + use_mpi=args.mpi, + per_rank_workers=max(1, args.workers), + ) + + # If flex prior needs spawn, do it early + # We'll detect it after we build the prior + # (safe to call set_start_method once, guarded) + torch.set_num_threads(1) + + # thread-safe matplotlib backend for plots + import matplotlib + matplotlib.use("Agg") + + # dataset + dataset_conf = load_dataset_conf(args.input) + index = DatasetIndex(dataset_conf) + input_map = index.build_input_map() + if args.pdbids: + input_map = {p: input_map[p] for p in args.pdbids} + + # prior file sanity (if used) + if cfg.prior_file: + assert os.path.exists(cfg.prior_file), f"Prior file does not exist: {cfg.prior_file}" + params_path = get_prior_params_path(cfg.prior_file) + with open(params_path, "r", encoding="utf-8") as f: + params = json.load(f) + if not cfg.prior_name: + cfg.prior_name = params["prior_configuration_name"] + elif cfg.prior_name != params["prior_configuration_name"]: + print(f'WARNING: --prior "{cfg.prior_name}" differs from prior file "{params["prior_configuration_name"]}"') + + assert cfg.prior_name, "You must specify the prior via --prior or --prior-file" + + prior_builder = PriorFactory.make(cfg.prior_name, cfg, runtime) + + if prior_builder.prior_params.get("external"): + # flex uses torch nets; spawn is safer in multiprocess contexts + try: + mp.set_start_method("spawn", force=True) + except Exception: + pass - if pdbids: - input_path_map = {i: input_path_map[i] for i in pdbids} + pre = Preprocessor(cfg, runtime, dist, dataset_conf, input_map, prior_builder) + pre.run() - preprocessor = Preprocessor(dataset_conf, input_path_map, output_dir, prior_builder, prior_file, prior_name, frame_slice, temp, optimize_forces, box, prior_plots, resume_preprocess, num_cores, jobid, totalNrJobs) - preprocessor.preprocess() +if __name__ == "__main__": + main() From e9f1569ec2c5012916227206ba3e8bf101944904 Mon Sep 17 00:00:00 2001 From: Alexander Aghili Date: Wed, 6 May 2026 11:36:48 -0700 Subject: [PATCH 3/4] Updated train.py plus config options --- edit_checkpoint.py | 57 +- learn_prior.py | 57 +- preprocess.py | 1859 ++++++++++++++++++++---------------- simulate.py | 76 +- train.py | 2237 ++++++++++++++++++++------------------------ 5 files changed, 2177 insertions(+), 2109 deletions(-) diff --git a/edit_checkpoint.py b/edit_checkpoint.py index a104b8b..c0e87dc 100755 --- a/edit_checkpoint.py +++ b/edit_checkpoint.py @@ -33,33 +33,52 @@ def remove_checkpoint_keys(checkpoint_path, keys): torch.save(checkpoint_dict, checkpoint_path) -if __name__ == "__main__": +# --------------------------------------------------------------------------- +# Entry-point helpers: collection → validation → processing +# --------------------------------------------------------------------------- + +def build_parser(): import argparse + arg_parser = argparse.ArgumentParser( + description="Inspect or modify a model checkpoint file.") + arg_parser.add_argument("checkpoint_path", help="Path to the model checkpoint (.pth or directory)") + arg_parser.add_argument("--reset-optimizer", action="store_true", help="Remove the optimizer state") + arg_parser.add_argument("--reset-scheduler", action="store_true", help="Remove the scheduler state") + arg_parser.add_argument("--reset-epoch", action="store_true", help="Remove the epoch counter and history") + arg_parser.add_argument("--info", action="store_true", help="Print the current epoch and model config") + return arg_parser - arg_parser = argparse.ArgumentParser() - arg_parser.add_argument("checkpoint_path", help="The model checkpoint path") - arg_parser.add_argument("--reset-optimizer", action="store_true", help="Reset the optimizer state") - arg_parser.add_argument("--reset-scheduler", action="store_true", help="Reset the scheduler state") - arg_parser.add_argument("--reset-epoch", action="store_true", help="Reset the epoch and history") - arg_parser.add_argument("--info", action="store_true", help="Show the current epoch & config") - args = arg_parser.parse_args() +def validate_config(cfg) -> None: + path = cfg.checkpoint_path + is_valid = os.path.isfile(path) or (os.path.isdir(path) and os.path.isfile(os.path.join(path, "checkpoint.pth"))) + assert is_valid, f"Checkpoint not found: {path}" - if args.info: - checkpoint = torch.load(args.checkpoint_path, map_location="cpu") - print(f"Epoch: {checkpoint['epoch']}") - print("Config:") - config_str = yaml.dump(checkpoint["hyper_parameters"]) - print("\n".join([" " + i for i in config_str.split("\n")])) +def process_config(cfg) -> list: keys_to_remove = [] - if args.reset_optimizer: + if cfg.reset_optimizer: keys_to_remove.append("optimizer") - if args.reset_scheduler: + if cfg.reset_scheduler: keys_to_remove.append("scheduler") - if args.reset_epoch: + if cfg.reset_epoch: keys_to_remove.append("epoch") + return keys_to_remove + + +if __name__ == "__main__": + from module.base_config import BaseConfig + cfg = BaseConfig(build_parser()) + validate_config(cfg) + + if cfg.info: + checkpoint = torch.load(cfg.checkpoint_path, map_location="cpu") + print(f"Epoch: {checkpoint['epoch']}") + print("Config:") + config_str = yaml.dump(checkpoint["hyper_parameters"]) + print("\n".join([" " + i for i in config_str.split("\n")])) - if len(keys_to_remove): - remove_checkpoint_keys(args.checkpoint_path, keys_to_remove) + keys_to_remove = process_config(cfg) + if keys_to_remove: + remove_checkpoint_keys(cfg.checkpoint_path, keys_to_remove) diff --git a/learn_prior.py b/learn_prior.py index 08b7cfa..333c911 100755 --- a/learn_prior.py +++ b/learn_prior.py @@ -556,29 +556,50 @@ def process_batch(batch, training_mol_list, learnable_ff): ### End Train -if __name__ == "__main__": +# --------------------------------------------------------------------------- +# Entry-point helpers: collection → validation → processing +# --------------------------------------------------------------------------- + +def build_parser(): import argparse parser = argparse.ArgumentParser(description="Train a TorchMD prior forcefield") parser.add_argument("input", help="Preprocessed prior & data to train on") parser.add_argument("result", help="Checkpoint save directory") - parser.add_argument("--pdbids", nargs="*", help="List of specific PDB IDs to process") + parser.add_argument("--pdbids", nargs="*", help="Restrict training to these PDB IDs") parser.add_argument("--batch", type=int, default=10, help="The batch size to use") parser.add_argument("--epochs", type=int, default=25, help="The total number of epochs to train for") parser.add_argument("--lr", type=float, default=0.005, help="Learning rate") - parser.add_argument("--gamma", type=float, default=None, help="Learning rate scheduler gamma") - parser.add_argument("--clip-parameters", action='store_true', help="Clip parameters to a valid range after each update") + parser.add_argument("--gamma", type=float, default=None, help="Exponential LR scheduler gamma") + parser.add_argument("--clip-parameters", action='store_true', + help="Clip parameters to a valid range after each update") parser.add_argument('--initial-loss', default=True, action=argparse.BooleanOptionalAction, - help="Whether to calculate initial loss before training. Default=True") - - args = parser.parse_args() - print(args) - - train(data_dir = args.input, - pdb_ids = args.pdbids, - checkpoint_dir = args.result, - n_epochs = args.epochs, - batch_size = args.batch, - learning_rate = args.lr, - scheduler_gamma = args.gamma, - clip_parameters = args.clip_parameters, - initial_loss_calc = args.initial_loss) + help="Calculate loss before the first training step (default=True)") + return parser + + +def validate_config(cfg) -> None: + assert os.path.isdir(cfg.input), f"Input directory does not exist: {cfg.input}" + assert cfg.lr > 0, "--lr must be positive" + assert cfg.epochs > 0, "--epochs must be positive" + + +def process_config(cfg) -> dict: + return dict( + data_dir=cfg.input, + pdb_ids=cfg.pdbids, + checkpoint_dir=cfg.result, + n_epochs=cfg.epochs, + batch_size=cfg.batch, + learning_rate=cfg.lr, + scheduler_gamma=cfg.gamma, + clip_parameters=cfg.clip_parameters, + initial_loss_calc=cfg.initial_loss, + ) + + +if __name__ == "__main__": + from module.base_config import BaseConfig + cfg = BaseConfig(build_parser()) + print(cfg.as_namespace()) + validate_config(cfg) + train(**process_config(cfg)) diff --git a/preprocess.py b/preprocess.py index 83a2921..a8b02e8 100755 --- a/preprocess.py +++ b/preprocess.py @@ -1,935 +1,1146 @@ #!/usr/bin/env python3 -from __future__ import annotations - import os -import json -import yaml -import glob -import shutil -import pickle -import argparse -import traceback -import multiprocessing as mp -from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple - import numpy as np +import yaml +import json import h5py import mdtraj -import torch -from tqdm import tqdm - -# Your project imports -from module import prior, prior_flex, psfwriter +from module.torchmd_cg_mappings import CACB_MAP +from module import prior +from module import prior_flex +from module import psfwriter from module.make_deltaforces import DeltaForces from module.cg_mapping import CGMapping -from module.torchmd_cg_mappings import CACB_MAP - -torch.multiprocessing.set_sharing_strategy("file_system") +import argparse +import traceback +import shutil +import multiprocessing as mp +from tqdm import tqdm +import glob +import pickle +import torch +# this can have a small performance hit as it uses the HDD file_system to share data across different processes, but it doesn't lead to "Too many files open" error, which was limiting the max number of parallel processes to 16. +torch.multiprocessing.set_sharing_strategy('file_system') +# Raz Dec 30 2024: turns out that when preprocessing the 6.5k dataset, the last 20 pdbs are taking forever to process due to being very big proteins. In addition, some were giving hdf5 errors since the batch_generate job didn't properly finish, and this means we don't have all the frames, and they will be removed later on before training. So here's the best workflow I found to solve it: +# 1) first run just step 1 on all proteins, and set FILTER_NOT_PROCESSED_STEP_ONE = False. if any pdbs are taking forever (generally the last 20), just kill the job +# 2) re-run the pre-processing for a second time with FILTER_NOT_PROCESSED_STEP_ONE = True +FILTER_NOT_PROCESSED_STEP_ONE = False -# ----------------------------- -# MPI (optional) -# ----------------------------- -try: - from mpi4py import MPI # type: ignore - MPI_AVAILABLE = True -except Exception: - MPI_AVAILABLE = False - MPI = None # type: ignore +# Controls whether to re-generate the priors in step 2 for each of the terms, terms in the list will be loaded from the cache instead of refit +#USE_CACHED_FITS = ['dihedrals', 'angles', 'bonds', 'lj'] +USE_CACHED_FITS = [] +DEVICE_STEP_3 = 'cpu' +#DEVICE_STEP_3 = 'cuda' -# ----------------------------- -# Small utilities (pure funcs) -# ----------------------------- -def slice_to_str(s: slice) -> str: - parts = [s.start, s.stop, s.step] - return ":".join("" if p is None else str(p) for p in parts) +DO_STEP_1 = True # whether to do step 1. if you got errors in steps 2-3 and want to resume, set this to False +REGEN_CACHE_FILES = True # whether to re-generate cache files -def parse_slice(s: str) -> slice: - # "start:stop:step" with empty parts allowed - parts = [p.strip() for p in s.split(":")] - parts += [""] * (3 - len(parts)) - vals = [int(p) if p != "" else None for p in parts[:3]] - return slice(*vals) -def get_prior_params_path(prior_yaml_path: str) -> str: - d, fn = os.path.split(prior_yaml_path) - return os.path.join(d, fn.replace("priors.yaml", "prior_params.json")) +def process_init(counter): + """This function sets the worker names such that we can use them to position the tqdm bars""" + with counter.get_lock(): + idx = int(counter.value) + counter.value += 1 + mp.current_process().name = f"PreprocessWorker-{idx}" -def ensure_dir(path: str) -> None: - os.makedirs(path, exist_ok=True) -def load_h5_traj_slice(path: str, s: slice) -> Tuple[mdtraj.Trajectory, Optional[np.ndarray]]: - base_traj = mdtraj.load_frame(path, 0) - weights = None - with h5py.File(path) as f: - xyz = f["coordinates"][s][:] - time = f["time"][s][:] - unitcell_lengths = unitcell_angles = None - if "cell_lengths" in f: - unitcell_lengths = f["cell_lengths"][s][:] - unitcell_angles = f["cell_angles"][s][:] - if "weight" in f: - weights = f["weight"][s][:] - traj = mdtraj.Trajectory( - xyz, - base_traj.topology, - time=time, - unitcell_lengths=unitcell_lengths, - unitcell_angles=unitcell_angles, - ) - return traj, weights - - -# ----------------------------- -# Config objects -# ----------------------------- -@dataclass(frozen=True) -class RuntimeFlags: - filter_not_processed_step_one: bool = False - do_step_1: bool = True - regen_cache_files: bool = True - device_step_3: str = "cpu" - use_cached_fits: Tuple[str, ...] = field(default_factory=tuple) - -@dataclass(frozen=True) -class DistributedConfig: - # If MPI is available and user passes --mpi, we will split pdbids by rank - use_mpi: bool = False - # If using MPI, you can still have per-rank multiprocessing - per_rank_workers: int = 1 - -@dataclass -class PreprocessConfig: - output_dir: str - temp: int = 300 - frame_slice: slice = slice(None) - optimize_forces: bool = False - use_box: bool = True - prior_plots: bool = True - resume: bool = False - fit_constraints: bool = True - fit_min_cnt: int = 0 - - prior_name: str = "" - prior_file: Optional[str] = None - - -# ----------------------------- -# Mapping definitions (kept minimal) -# ----------------------------- class CGMappingDef_CA: def __init__(self): - residues = ["ALA","CYS","ASP","GLU","PHE","GLY","HIS","ILE","LYS","LEU","MET","ASN","PRO","HYP","GLN","ARG","SER","THR","VAL","TRP","TYR"] - embedding_residues = ["ALA","ARG","ASN","ASP","ASX","CYS","GLU","GLN","GLX","GLY","HIS","ILE","LEU","LYS","MET","PHE","PRO","SER","THR","TRP","TYR","VAL"] + residues = ["ALA", "CYS", "ASP", "GLU", "PHE", "GLY", "HIS", "ILE", "LYS", "LEU", "MET", "ASN", "PRO", "HYP", "GLN", "ARG", "SER", "THR", "VAL", "TRP", "TYR"] + # For legacy reasons we have a couple extra ambiguous residues (ASX & GLX) in the embedding map but we do not accept these for parsing + embedding_residues = ["ALA", "ARG", "ASN", "ASP", "ASX", "CYS", "GLU", "GLN", "GLX", "GLY", "HIS", "ILE", "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL"] self.bead_embeddings = {name: [index + 1] for index, name in enumerate(sorted(embedding_residues))} + + # bead_atom_selection: A list of lists, where each inner list is the names of the atoms that will be combined to form the bead self.bead_atom_selection = {k: [["CA"]] for k in residues} + # The type names of beads (will become the atom type/element in the cg topology) self.bead_types = { - "ALA":["CAA"],"ARG":["CAR"],"ASN":["CAN"],"ASP":["CAD"],"CYS":["CAC"],"GLN":["CAQ"],"GLU":["CAE"], - "GLY":["CAG"],"HIS":["CAH"],"HSD":["CAH"],"ILE":["CAI"],"LEU":["CAL"],"LYS":["CAK"],"MET":["CAM"], - "PHE":["CAF"],"PRO":["CAP"],"SER":["CAS"],"THR":["CAT"],"TRP":["CAW"],"TYR":["CAY"],"VAL":["CAV"], + "ALA": ["CAA"], + "ARG": ["CAR"], + "ASN": ["CAN"], + "ASP": ["CAD"], + "CYS": ["CAC"], + "GLN": ["CAQ"], + "GLU": ["CAE"], + "GLY": ["CAG"], + "HIS": ["CAH"], + "HSD": ["CAH"], + "ILE": ["CAI"], + "LEU": ["CAL"], + "LYS": ["CAK"], + "MET": ["CAM"], + "PHE": ["CAF"], + "PRO": ["CAP"], + "SER": ["CAS"], + "THR": ["CAT"], + "TRP": ["CAW"], + "TYR": ["CAY"], + "VAL": ["CAV"], } + # The "atom name" assigned to the beads self.bead_atom_names = {k: ["CA"] for k in residues} self.bead_masses = {k: [12.01] for k in residues} self.bead_backbone_idx = {k: 0 for k in residues} class CGMappingDef_CACB: def __init__(self): - residues = ["ALA","CYS","ASP","GLU","PHE","GLY","HIS","ILE","LYS","LEU","MET","ASN","PRO","HYP","GLN","ARG","SER","THR","VAL","TRP","TYR"] + residues = ["ALA", "CYS", "ASP", "GLU", "PHE", "GLY", "HIS", "ILE", "LYS", "LEU", "MET", "ASN", "PRO", "HYP", "GLN", "ARG", "SER", "THR", "VAL", "TRP", "TYR"] + + # bead_atom_selection: A list of lists, where each inner list is the names of the atoms that will be combined to form the bead self.bead_atom_selection = {k: [["CA"], ["CB"]] for k in residues} self.bead_atom_selection["GLY"] = [["CA"]] + # The type names of beads (will become the atom type/element in the cg topology) self.bead_types = { - "ALA":["CA","CBA"],"ARG":["CA","CBR"],"ASN":["CA","CBN"],"ASP":["CA","CBD"],"CYS":["CA","CBC"], - "GLN":["CA","CBQ"],"GLU":["CA","CBE"],"GLY":["CAG"],"HIS":["CA","CBH"],"HSD":["CA","CBH"], - "ILE":["CA","CBI"],"LEU":["CA","CBL"],"LYS":["CA","CBK"],"MET":["CA","CBM"],"PHE":["CA","CBF"], - "PRO":["CA","CBP"],"SER":["CA","CBS"],"THR":["CA","CBT"],"TRP":["CA","CBW"],"TYR":["CA","CBY"], - "VAL":["CA","CBV"], + "ALA": ["CA", "CBA"], + "ARG": ["CA", "CBR"], + "ASN": ["CA", "CBN"], + "ASP": ["CA", "CBD"], + "CYS": ["CA", "CBC"], + "GLN": ["CA", "CBQ"], + "GLU": ["CA", "CBE"], + "GLY": ["CAG"], + "HIS": ["CA", "CBH"], + "HSD": ["CA", "CBH"], + "ILE": ["CA", "CBI"], + "LEU": ["CA", "CBL"], + "LYS": ["CA", "CBK"], + "MET": ["CA", "CBM"], + "PHE": ["CA", "CBF"], + "PRO": ["CA", "CBP"], + "SER": ["CA", "CBS"], + "THR": ["CA", "CBT"], + "TRP": ["CA", "CBW"], + "TYR": ["CA", "CBY"], + "VAL": ["CA", "CBV"], } - embedding_map = {k: i for i, k in enumerate(sorted(set.union(*[set(v) for v in self.bead_types.values()])))} - self.bead_embeddings = {k: [embedding_map[i] for i in v] for k, v in self.bead_types.items()} - self.bead_atom_names = {k: ["CA","CB"] for k in residues} + + embedding_map = {k:i for i,k in enumerate(sorted(set.union(*[set(i) for i in self.bead_types.values()])))} + self.bead_embeddings = {k:[embedding_map[i] for i in v] for k, v in self.bead_types.items()} + + # The "atom name" assigned to the beads + self.bead_atom_names = {k: ["CA", "CB"] for k in residues} self.bead_atom_names["GLY"] = ["CA"] - self.bead_masses = {k: [12.01] * len(v) for k, v in self.bead_types.items()} + self.bead_masses = {k: [12.01]*len(v) for k,v in self.bead_types.items()} self.bead_backbone_idx = {k: 0 for k in residues} - -# ----------------------------- -# PriorBuilder + PriorFactory -# ----------------------------- class PriorBuilder: - """ - Base prior builder. Concrete behavior comes from: - - mapping_def (CA vs CACB) - - a list of term factories (bonds/angles/dihedrals/lj etc) - - prior_params metadata (for downstream delta-forces logic) - """ - def __init__(self, *, mapping_def: Any, prior_params: Dict[str, Any], term_factories: Dict[str, Callable[[], Any]], runtime: RuntimeFlags): - self.mapping_def = mapping_def - self.prior_params: Dict[str, Any] = dict(prior_params) - self.term_factories = dict(term_factories) - self.runtime = runtime - - self.terms: Dict[str, Any] = {k: f() for k, f in self.term_factories.items()} - self.atom_types: set[str] = set() - self.priors: Optional[Dict[str, Any]] = None - - # mapping - def build_mapping(self, topology) -> CGMapping: - return CGMapping(topology, self.mapping_def) - - def make_mol(self, cg_map: CGMapping): - bonds = "bonds" in self.terms - angles = "angles" in self.terms - dihedrals = "dihedrals" in self.terms - return cg_map.to_mol(bonds=bonds, angles=angles, dihedrals=dihedrals) - - # cache aggregation - def add_molecule(self, mol, traj, cache_dir: str) -> None: + def __init__(self): + self.prior_params = dict() + self.priors = None + self.terms = dict() + self.atom_types = set() + self.fit_constraints = True + self.tag_beta_turns = False + self.min_cnt = 0 + + def select_atoms(self, topology): + """Returns tha atom index to be saved for this prior""" + raise NotImplementedError() + + def map_embeddings(self, selected_atoms, trajectory): + """Generates the embeddings array for the selected atoms""" + raise NotImplementedError() + + def write_psf(self, pdb_file, psf_file): + """Write the .psf file describing the course grain geometry""" + raise NotImplementedError() + + def add_molecule(self, mol, traj, cache_dir): fit_ok_path = os.path.join(cache_dir, "fit_ok.txt") - if os.path.exists(fit_ok_path): + + if cache_dir and os.path.exists(fit_ok_path): os.unlink(fit_ok_path) for term in self.terms.values(): term.add_molecule(mol, traj, cache_dir) - self.atom_types |= set(mol.atomtype) + self.atom_types = self.atom_types.union(mol.atomtype) - np.save(os.path.join(cache_dir, "atomtype.npy"), mol.atomtype) - with open(fit_ok_path, "wt", encoding="utf-8") as f: - f.write("ok") + if cache_dir: + np.save(os.path.join(cache_dir, "atomtype.npy"), mol.atomtype) + with open(fit_ok_path, "wt", encoding="utf-8") as f: + f.write("ok") - def load_molecule_cache(self, cache_dir: str) -> None: + def load_molecule_cache(self, cache_dir): assert os.path.exists(os.path.join(cache_dir, "fit_ok.txt")) atomtype = np.load(os.path.join(cache_dir, "atomtype.npy"), allow_pickle=True) - self.atom_types |= set(atomtype) + self.atom_types = self.atom_types.union(atomtype) + for term in self.terms.values(): term.load_molecule_cache(cache_dir) - # fitting - def _init_prior_dict(self) -> None: - priors: Dict[str, Any] = {} - priors["atomtypes"] = sorted(self.atom_types) - priors["bonds"] = {} - priors["angles"] = {} - priors["dihedrals"] = {} - priors["lj"] = {} - priors["electrostatics"] = {at: {"charge": 0.0} for at in priors["atomtypes"]} - priors["masses"] = {at: 12.01 for at in priors["atomtypes"]} - self.priors = priors + def enable_fit_constraints(self, use_constraints): + self.fit_constraints = use_constraints + self.prior_params["fit_constraints"] = self.fit_constraints - def fit(self, temperature: int, plot_dir: Optional[str]) -> None: - self._init_prior_dict() - assert self.priors is not None + def enable_bond_tags(self, use_tags): + self.tag_beta_turns = use_tags + self.prior_params["tag_beta_turns"] = self.tag_beta_turns + def set_min_cnt(self, min_cnt): + assert min_cnt >= 0 + self.min_cnt = min_cnt + self.prior_params["min_cnt"] = self.min_cnt + + def fit(self, temperature, plot_dir=None): + self.init_prior_dict() + assert self.priors is not None for key, term in self.terms.items(): - cached = plot_dir and os.path.exists(f"{plot_dir}/prior_{key}.pkl") and (key in self.runtime.use_cached_fits) - if cached: + if os.path.exists(f"{plot_dir}/prior_{key}.pkl") and (key in USE_CACHED_FITS): + print(f"Used cached fit for {key}...") with open(f"{plot_dir}/prior_{key}.pkl", "rb") as f: self.priors[key] = pickle.load(f) - continue - - self.priors[key] = term.get_param( - temperature, - plot_dir, - self.prior_params.get("fit_constraints", True), - self.prior_params.get("min_cnt", 0), - ) - if plot_dir: + else: + print(f"Fitting {key}...") + self.priors[key] = term.get_param(temperature, plot_dir, self.fit_constraints, self.min_cnt) + # pickle the prior for this term with open(f"{plot_dir}/prior_{key}.pkl", "wb") as f: pickle.dump(self.priors[key], f) - def save_prior(self, output_dir: str) -> None: - assert self.priors is not None - with open(os.path.join(output_dir, "priors.yaml"), "w") as f: - yaml.dump(self.priors, f) - with open(os.path.join(output_dir, "prior_params.json"), "w") as f: - json.dump(self.prior_params, f, indent=2, sort_keys=True) - - # flex special-case: override by composition instead of subclass if you want - def save_prior_flex(self, output_dir: str) -> None: - """ - Mirrors your flex behavior: yaml gets only classical priors, pickle saves nets. - """ - assert self.priors is not None + def init_prior_dict(self): + # Define the force field dict + priors = {} + priors['atomtypes'] = sorted(self.atom_types) + priors['bonds'] = {} + priors['angles'] = {} + priors['dihedrals'] = {} + priors['lj'] = {} + # For mass and charge assume everything is a carbon atom + priors['electrostatics'] = {at: {'charge': 0.0} for at in priors['atomtypes']} + # The mass of carbon used here is the from OpenMM/AMBER-14 value + priors['masses'] = {at: 12.01 for at in priors['atomtypes']} + self.priors = priors - # Write params - with open(os.path.join(output_dir, "prior_params.json"), "w") as f: - json.dump(self.prior_params, f, indent=2, sort_keys=True) + def save_prior(self, output_path, pdbid): + prefix = "" + if pdbid: + prefix = f"{pdbid}_" + with open(os.path.join(output_path, f"{prefix}priors.yaml"), "w") as f: + yaml.dump(self.priors, f) + with open(os.path.join(output_path, f"{prefix}prior_params.json"),"w") as f: + json.dump(self.prior_params, f) - # Truncate bonds/angles/dihedrals from YAML (classical only) - truncated = dict(self.priors) - truncated.pop("bonds", None) - truncated.pop("angles", None) - truncated.pop("dihedrals", None) + def make_mol(self, cg_map): + bonds = "bonds" in self.terms + angles = "angles" in self.terms + dihedrals = "dihedrals" in self.terms + return cg_map.to_mol(bonds = bonds, angles = angles, dihedrals = dihedrals) - with open(os.path.join(output_dir, "priors.yaml"), "w") as f: - yaml.dump(truncated, f) +class Prior_CA(PriorBuilder): + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA", + "exclusions" : ['bonds'], + "forceterms" : ["bonds"], + }) + self.terms["bonds"] = prior.ParamBondedCalculator() + + def build_mapping(self, topology): + return CGMapping(topology, CGMappingDef_CA()) + + def select_atoms(self, topology): + #TODO: Remove this function (replaced by build_mapping) + return topology.select('name CA and protein') + + def map_embeddings(self, selected_atoms, topology): #pyright: ignore[reportIncompatibleMethodOverride] + #TODO: Remove this function (replaced by build_mapping) + standardResidues = {"ALA", "ARG", "ASN", "ASP", "ASX", "CYS", "GLU", "GLN", "GLX", "GLY", "HIS", "ILE", "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL"} + amino_acid_mapping = {name: index + 1 for index, name in enumerate(sorted(standardResidues))} + + result = [] + for a_idx in selected_atoms: + r_name = topology.atom(a_idx).residue.name + result.append(amino_acid_mapping[r_name]) + return np.array(result, dtype=int) + + def write_psf(self, pdb_file, psf_file): + #TODO: Remove this function (replaced by build_mapping) + bonds = "bonds" in self.terms + angles = "angles" in self.terms + dihedrals = "dihedrals" in self.terms + return psfwriter.pdb2psf_CA(pdb_file, psf_file, bonds = bonds, angles = angles, dihedrals = dihedrals, + tag_beta_turns = self.tag_beta_turns) - payload = dict(self.priors) - payload["terms"] = self.terms - payload["prior_params"] = self.prior_params - with open(os.path.join(output_dir, "priors.pkl"), "wb") as f: - pickle.dump(payload, f) +class Prior_CACB(PriorBuilder): + """Implements the torchmd-cg CACB prior""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CACB", + "exclusions" : ['bonds'], + "forceterms" : ["bonds"], + }) + self.terms["bonds"] = prior.ParamBondedCalculator() + + def build_mapping(self, topology): + return CGMapping(topology, CGMappingDef_CACB()) + + def select_atoms(self, topology): + #TODO: Remove this function (replaced by build_mapping) + return topology.select('(name CA or name CB) and protein') + + def map_embeddings(self, selected_atoms, topology):#pyright: ignore[reportIncompatibleMethodOverride] + #TODO: Remove this function (replaced by build_mapping) + + # Make a map from embedding name to embedding name number + # e.g. {"CAA":0, "CAC":1, ...} + embedding_map = CACB_MAP + embedding_nums = dict([(k, i) for i, k in enumerate(sorted(set(embedding_map.values())))]) + + result = [] + for a_idx in selected_atoms: + r_name = topology.atom(a_idx).residue.name + a_name = topology.atom(a_idx).name + emb_name = embedding_map[(r_name, a_name)] + result.append(embedding_nums[emb_name]) + return np.array(result, dtype=int) + + def write_psf(self, pdb_file, psf_file): + #TODO: Remove this function (replaced by build_mapping) + bonds = "bonds" in self.terms + angles = "angles" in self.terms + dihedrals = "dihedrals" in self.terms + return psfwriter.pdb2psf_CACB(pdb_file, psf_file, bonds = bonds, angles = angles, dihedrals = dihedrals) +class Prior_CACB_lj(Prior_CACB): + """torchmd-cg CACB prior with Bonded & RepulsionCG terms""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CACB_lj", + "exclusions" : ['bonds'], + "forceterms" : ['bonds', 'repulsioncg'], + }) + self.terms["bonds"] = prior.ParamBondedCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + +class Prior_CACB_lj_angle_dihedral(Prior_CACB): + """torchmd-cg CACB prior with Bonded, Angle, Dihedral & RepulsionCG terms""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CACB_lj_angle_dihedral", + "exclusions" : ['bonds', 'angles', 'dihedrals'], + "forceterms" : ['bonds', 'angles', 'dihedrals', 'repulsioncg'], + }) + self.terms["bonds"] = prior.ParamBondedCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + self.terms["angles"] = prior.ParamAngleCalculator() + self.terms["dihedrals"] = prior.ParamDihedralCalculator() + +class Prior_CA_lj(Prior_CA): + """CA prior with Bonded & RepulsionCG terms""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_lj", + "exclusions" : ['bonds'], + "forceterms" : ['bonds', 'repulsioncg'], + }) + self.terms["bonds"] = prior.ParamBondedCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + +class Prior_CA_lj_angle(Prior_CA): + """CA prior with Bonded, Angle, and RepulsionCG terms""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_lj_angle", + "exclusions" : ['bonds', 'angles'], + "forceterms" : ['bonds', 'angles', 'repulsioncg'], + }) + self.terms["bonds"] = prior.ParamBondedCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + self.terms['angles'] = prior.ParamAngleCalculator() + +class Prior_CA_lj_angle_dihedral(Prior_CA): + """torchmd-cg CA prior with Bonded, Angle, Dihedral & RepulsionCG terms""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_lj_angle_dihedral", + "exclusions" : ['bonds', 'angles', 'dihedrals'], + "forceterms" : ['bonds', 'angles', 'dihedrals', 'repulsioncg'], + }) + self.terms["bonds"] = prior.ParamBondedCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + self.terms["angles"] = prior.ParamAngleCalculator() + self.terms["dihedrals"] = prior.ParamDihedralCalculator() + +class Prior_CA_lj_angle_dihedralX(Prior_CA): + """torchmd-cg CA prior with Bonded, Angle, DihedralX & RepulsionCG terms""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_lj_angle_dihedralX", + "exclusions" : ['bonds', 'angles', 'dihedrals'], + "forceterms" : ['bonds', 'angles', 'dihedrals', 'repulsioncg'], + }) + self.terms["bonds"] = prior.ParamBondedCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + self.terms["angles"] = prior.ParamAngleCalculator() + self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True) + +class Prior_CA_lj_angleXCX_dihedralX(Prior_CA): + """torchmd-cg CA prior with Bonded, Angle, DihedralX & RepulsionCG terms""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_lj_angleXCX_dihedralX", + "exclusions" : ['bonds', 'angles', 'dihedrals'], + "forceterms" : ['bonds', 'angles', 'dihedrals', 'repulsioncg'], + }) + self.terms["bonds"] = prior.ParamBondedCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + self.terms["angles"] = prior.ParamAngleCalculator(center=True) + self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True) + +class Prior_CA_lj_angleXCX_dihedralX_flex(Prior_CA): + """torchmd-cg CA prior with highly flexible Bonded, Angle, DihedralX & RepulsionCG terms that fit the data. -class PriorFactory: """ - Defines priors declaratively (no giant subclass tree). + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_lj_angleXCX_dihedralX_flex", + "exclusions" : ['bonds', 'angles', 'dihedrals'], + "forceterms_nn" : ['bonds', 'angles', 'dihedrals'], + "forceterms_classical": ['repulsioncg'], # changed from lj, would need to re-generated the dataset (Jan 10 2025). repulsioncg is using just the repulsion term from lj. it uses the same parameters as lj, so need to make sure the right function is evaluated. + "external" : True + }) + self.prior_params['forceterms'] = self.prior_params['forceterms_classical'] + self.prior_params['forceterms_nn'] + + self.terms["bonds"] = prior_flex.ParamBondedFlexCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + self.terms["angles"] = prior_flex.ParamAngleFlexCalculator(center=True) + self.terms["dihedrals"] = prior_flex.ParamDihedralFlexCalculator(unified=True) + + # have to override this method since we're saving neural nets as priors + def save_prior(self, output_path, pdbid): + prefix = "" + # if pdbid: + # prefix = f"{pdbid}_" + with open(os.path.join(output_path, f"{prefix}prior_params.json"),"w") as f: + json.dump(self.prior_params, f) + + # print('self.priors', self.priors.keys()) + # remove the dihedrals and bonds from the priors + priorsTruncated = self.priors.copy() + priorsTruncated.pop('dihedrals') + priorsTruncated.pop('bonds') + priorsTruncated.pop('angles') + # print('priorsTruncated', priorsTruncated.keys()) + + # save the classical priors using yaml. this is requires because the classical priors are built from the yaml files + with open(os.path.join(output_path, f"{prefix}priors.yaml"), "w") as f: + yaml.dump(priorsTruncated, f) + + self.priors['terms'] = self.terms + self.priors['prior_params'] = self.prior_params + + # also save with pickle + with open(os.path.join(output_path, f"{prefix}priors.pkl"), "wb") as f: + pickle.dump(self.priors, f) + + + def load_prior_nnets(self, output_path): + # load the prior with pickle + with open(os.path.join(output_path, "priors.pkl"), "rb") as f: + self.priors = pickle.load(f) + + # return self.priors + + # with open(os.path.join(output_path, f"{prefix}priors.pkl"), "wb") as f: + # pickle.dump(self.priors, f) + + + +class Prior_CA_lj_angleXCX_dihedralX_V1(Prior_CA): + """torchmd-cg CA prior with Bonded, Angle, DihedralX & RepulsionCG terms""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_lj_angleXCX_dihedralX_V1", + "exclusions" : ['bonds', 'angles', '1-4'], + "forceterms" : ['Bonds', 'angles', 'dihedrals', 'RepulsionCG'], + }) + self.terms["bonds"] = prior.ParamBondedCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + self.terms["angles"] = prior.ParamAngleCalculator(center=True) + self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True) + +class Prior_CA_lj_bondNull_angleXCX_dihedralX(Prior_CA): + """torchmd-cg CA prior with Angle, DihedralX & RepulsionCG terms (+ bond exclusions)""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_lj_bondNull_angleXCX_dihedralX", + "exclusions" : ['bonds', 'angles', '1-4'], + "forceterms" : ['Bonds', 'angles', 'dihedrals', 'RepulsionCG'], + }) + self.terms["bonds"] = prior.NullParamBondedCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + self.terms["angles"] = prior.ParamAngleCalculator(center=True) + self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True) + +class Prior_CA_lj_bondNull_angleNull_dihedralX(Prior_CA): + """torchmd-cg CA prior with DihedralX & RepulsionCG terms (+ bond & angle exclusions)""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_lj_bondNull_angleNull_dihedralX", + "exclusions" : ['bonds', 'angles', '1-4'], + "forceterms" : ['Bonds', 'angles', 'dihedrals', 'RepulsionCG'], + }) + self.terms["bonds"] = prior.NullParamBondedCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + self.terms["angles"] = prior.NullParamAngleCalculator() + self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True) + +class Prior_CA_lj_bondNull_angleNull_dihedralNull(Prior_CA): + """torchmd-cg CA prior with RepulsionCG terms (+ bond, angle, & dihedral exclusions)""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_lj_bondNull_angleNull_dihedralNull", + "exclusions" : ['bonds', 'angles', '1-4'], + "forceterms" : ['Bonds', 'angles', 'dihedrals', 'RepulsionCG'], + }) + self.terms["bonds"] = prior.NullParamBondedCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + self.terms["angles"] = prior.NullParamAngleCalculator() + self.terms["dihedrals"] = prior.NullParamDihedralCalculator() + +class Prior_CA_lj_angleNull_dihedralX(Prior_CA): + """torchmd-cg CA prior with Bonded, DihedralX & RepulsionCG terms (+ angle exclusions)""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_lj_angleNull_dihedralX", + "exclusions" : ['bonds', 'angles', '1-4'], + "forceterms" : ['Bonds', 'angles', 'dihedrals', 'RepulsionCG'], + }) + self.terms["bonds"] = prior.ParamBondedCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + self.terms["angles"] = prior.NullParamAngleCalculator() + self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True) + +class Prior_CA_lj_angleNull_dihedralNull(Prior_CA): + """torchmd-cg CA prior with Bonded & RepulsionCG terms (+ angle & dihedral exclusions)""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_lj_angleNull_dihedralNull", + "exclusions" : ['bonds', 'angles', '1-4'], + "forceterms" : ['Bonds', 'angles', 'dihedrals', 'RepulsionCG'], + }) + self.terms["bonds"] = prior.ParamBondedCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + self.terms["angles"] = prior.NullParamAngleCalculator() + self.terms["dihedrals"] = prior.NullParamDihedralCalculator() + +class Prior_CA_Majewski2022_v0(Prior_CA): + """torchmd-cg CA prior based on the parameters used in (Majewski 2022) + Note this version (v0) has different lj exclusions than the one used in the paper. """ - @staticmethod - def make(prior_name: str, cfg: PreprocessConfig, runtime: RuntimeFlags) -> PriorBuilder: - # Shared term constructors - def bonds(): return prior.ParamBondedCalculator() - def bonds_null(): return prior.NullParamBondedCalculator() - def angles(center=False): return prior.ParamAngleCalculator(center=center) - def angles_null(): return prior.NullParamAngleCalculator() - def dihedrals(unified=False, scale=1.0): return prior.ParamDihedralCalculator(unified=unified, scale=scale) - def dihedrals_null(): return prior.NullParamDihedralCalculator() - def lj(exclusions=None): return prior.ParamNonbondedCalculator(fit_range=[3, 6], exclusion_terms=exclusions or set()) - - def flex_bonds(): return prior_flex.ParamBondedFlexCalculator() - def flex_angles(center=False): return prior_flex.ParamAngleFlexCalculator(center=center) - def flex_dihedrals(unified=False): return prior_flex.ParamDihedralFlexCalculator(unified=unified) - - # Declarative catalog - CATALOG: Dict[str, Dict[str, Any]] = { - # CA - "CA": { - "mapping_def": CGMappingDef_CA(), - "prior_params": {"prior_configuration_name":"CA","exclusions":["bonds"],"forceterms":["bonds"]}, - "terms": {"bonds": bonds}, - }, - "CA_lj": { - "mapping_def": CGMappingDef_CA(), - "prior_params": {"prior_configuration_name":"CA_lj","exclusions":["bonds"],"forceterms":["bonds","repulsioncg"]}, - "terms": {"bonds": bonds, "lj": lj}, - }, - "CA_lj_angleXCX_dihedralX": { - "mapping_def": CGMappingDef_CA(), - "prior_params": {"prior_configuration_name":"CA_lj_angleXCX_dihedralX","exclusions":["bonds","angles","dihedrals"],"forceterms":["bonds","angles","dihedrals","repulsioncg"]}, - "terms": {"bonds": bonds, "lj": lj, "angles": (lambda: angles(center=True)), "dihedrals": (lambda: dihedrals(unified=True))}, - }, - "CA_lj_bondNull_angleNull_dihedralNull": { - "mapping_def": CGMappingDef_CA(), - "prior_params": {"prior_configuration_name":"CA_lj_bondNull_angleNull_dihedralNull","exclusions":["bonds","angles","1-4"],"forceterms":["Bonds","angles","dihedrals","RepulsionCG"]}, - "terms": {"bonds": bonds_null, "lj": lj, "angles": angles_null, "dihedrals": dihedrals_null}, - }, - "CA_Majewski2022_v1": { - "mapping_def": CGMappingDef_CA(), - "prior_params": {"prior_configuration_name":"CA_Majewski2022_v1","exclusions":["bonds"],"forceterms":["bonds","dihedrals","repulsioncg"]}, - "terms": {"bonds": bonds, "lj": (lambda: lj(exclusions={"bonds"})), "dihedrals": (lambda: dihedrals(unified=True, scale=0.5))}, - }, - "CA_null": { - "mapping_def": CGMappingDef_CA(), - "prior_params": {"prior_configuration_name":"CA_null","exclusions":[],"forceterms":[]}, - "terms": {}, - }, - - # CACB - "CACB": { - "mapping_def": CGMappingDef_CACB(), - "prior_params": {"prior_configuration_name":"CACB","exclusions":["bonds"],"forceterms":["bonds"]}, - "terms": {"bonds": bonds}, - }, - "CACB_lj_angle_dihedral": { - "mapping_def": CGMappingDef_CACB(), - "prior_params": {"prior_configuration_name":"CACB_lj_angle_dihedral","exclusions":["bonds","angles","dihedrals"],"forceterms":["bonds","angles","dihedrals","repulsioncg"]}, - "terms": {"bonds": bonds, "lj": lj, "angles": angles, "dihedrals": dihedrals}, - }, - - # FLEX - "CA_lj_angleXCX_dihedralX_flex": { - "mapping_def": CGMappingDef_CA(), - "prior_params": { - "prior_configuration_name":"CA_lj_angleXCX_dihedralX_flex", - "exclusions":["bonds","angles","dihedrals"], - "forceterms_nn":["bonds","angles","dihedrals"], - "forceterms_classical":["repulsioncg"], - "external": True, - }, - "terms": { - "bonds": flex_bonds, - "angles": (lambda: flex_angles(center=True)), - "dihedrals": (lambda: flex_dihedrals(unified=True)), - "lj": lj, - }, - }, - } - - if prior_name not in CATALOG: - raise RuntimeError(f"Unknown prior configuration: {prior_name}") - - spec = CATALOG[prior_name] - - # inject fit settings - prior_params = dict(spec["prior_params"]) - prior_params["fit_constraints"] = cfg.fit_constraints - prior_params["min_cnt"] = cfg.fit_min_cnt - - # flex convenience: build forceterms union - if prior_params.get("external"): - prior_params["forceterms"] = prior_params["forceterms_classical"] + prior_params["forceterms_nn"] - - return PriorBuilder( - mapping_def=spec["mapping_def"], - prior_params=prior_params, - term_factories=spec["terms"], - runtime=runtime, - ) - + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_Majewski2022_v0", + "exclusions" : ['bonds', 'dihedrals'], + "forceterms" : ['bonds', 'dihedrals', 'repulsioncg'], + }) + self.terms["bonds"] = prior.ParamBondedCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True, scale=0.5) + +class Prior_CA_Majewski2022_v1(Prior_CA): + """torchmd-cg CA prior based on the parameters used in (Majewski 2022)""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_Majewski2022_v1", + "exclusions" : ['bonds'], + "forceterms" : ['bonds', 'dihedrals', 'repulsioncg'], + }) + self.terms["bonds"] = prior.ParamBondedCalculator() + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6], exclusion_terms={"bonds"}) + self.terms["dihedrals"] = prior.ParamDihedralCalculator(unified=True, scale=0.5) + +class Prior_CA_null(Prior_CA): + """CA prior with no terms""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_null", + "exclusions" : [], + "forceterms" : [], + }) + self.terms = {} + +class Prior_CA_lj_only(Prior_CA): + """CA prior with just a RepulsionCG term""" + def __init__(self): + super().__init__() + self.prior_params.update({ + "prior_configuration_name": "CA_lj_only", + "exclusions" : [], + "forceterms" : ['RepulsionCG'], + }) + self.terms = {} + self.terms["lj"] = prior.ParamNonbondedCalculator(fit_range=[3, 6]) + +def slice_to_str(s): + result = [s.start, s.stop, s.step] + result = [str(i) if i is not None else '' for i in result] + return ":".join(result) + +def get_prior_params_path(prior_path): + dir_path, file_name = os.path.split(prior_path) + file_name = file_name.replace("priors.yaml", "prior_params.json") + return os.path.join(dir_path, file_name) + +def load_h5_traj_slice(path, slice): + """Load a slice from a h5 trajectory without reading the entire file into memory""" + base_traj = mdtraj.load_frame(path, 0) + with h5py.File(path) as f: + t_xyz = f["coordinates"][slice][:] #pyright: ignore[reportIndexIssue] + t_time = f["time"][slice][:] #pyright: ignore[reportIndexIssue] -# ----------------------------- -# Execution backends -# ----------------------------- -class ExecContext: - def __init__(self, dist: DistributedConfig): - self.dist = dist - self.rank = 0 - self.world = 1 - self.comm = None + t_unitcell_lengths = None + t_unitcell_angles = None + if "cell_lengths" in f.keys(): + t_unitcell_lengths = f["cell_lengths"][slice][:] #pyright: ignore[reportIndexIssue] + t_unitcell_angles = f["cell_angles"][slice][:] #pyright: ignore[reportIndexIssue] - if dist.use_mpi: - if not MPI_AVAILABLE: - raise RuntimeError("MPI requested but mpi4py is not installed.") - self.comm = MPI.COMM_WORLD - self.rank = self.comm.Get_rank() - self.world = self.comm.Get_size() + result = mdtraj.Trajectory(t_xyz, base_traj.topology, time=t_time, unitcell_lengths=t_unitcell_lengths, unitcell_angles=t_unitcell_angles) + return result - def barrier(self) -> None: - if self.comm: - self.comm.Barrier() +class Preprocessor: + def __init__(self, dataset_conf, input_path_map, save_path, prior_builder, prior_file, prior_name, frame_slice, temp, optimize_forces, box, prior_plots, resume_preprocess, num_cores, jobid=None, totalNrJobs=None): + self.dataset_conf = dataset_conf + self.save_path = save_path + self.prior_builder = prior_builder + self.prior_file = prior_file + self.frame_slice = frame_slice + self.temp = temp + self.jobid = jobid + self.totalNrJobs = totalNrJobs + + self.pdbid_list = input_path_map + + if FILTER_NOT_PROCESSED_STEP_ONE: + pdbs_processed_step1 = [f.split('/')[-3] for f in glob.glob(os.path.join(save_path, "*/fit/fit_ok.txt"))] + # remove keys that are not in pdbs_processed_step1 + self.pdbid_list = {k: v for k, v in self.pdbid_list.items() if k in pdbs_processed_step1} + print('%d pdbs left after removing pdbs not processed in step 1' % len(self.pdbid_list)) + + # if os.path.exists(os.path.join(save_path, 'pdb_list.pkl')): + # with open(os.path.join(save_path, 'pdb_list.pkl'), 'rb') as f: + # self.pdbid_list = pickle.load(f) + # else: + # self.pdbid_list = self.get_pdbid_list() + + # if FILTER_NOT_PROCESSED_STEP_ONE: + # pdbs_processed_step1 = [f.split('/')[-3] for f in glob.glob(os.path.join(save_path, "*/fit/fit_ok.txt"))] + # # remove keys that are not in pdbs_processed_step1 + # self.pdbid_list = {k: v for k, v in self.pdbid_list.items() if k in pdbs_processed_step1} + # print('%d pdbs left after removing pdbs not processed in step 1' % len(self.pdbid_list)) + + # # pickle pdb_list + # os.makedirs(save_path, exist_ok=True) + # with open(os.path.join(save_path, 'pdb_list.pkl'), 'wb') as f: + # pickle.dump(self.pdbid_list, f) + + + self.optimize_forces = optimize_forces + self.box = box + self.prior_plots = prior_plots + self.resume_preprocess = resume_preprocess + self.num_cores = num_cores + + print("Input directory paths:", [i["path"] for i in self.dataset_conf]) + print("Save directory path:", self.save_path) + print(f"Temperature: {self.temp}") + print("Frame slice:", slice_to_str(self.frame_slice)) + print("Number of cores used for parallelization:", self.num_cores) + # print("PDB ID list:", self.pdbid_list) + + def step1_threading(self, pdbid): + try: + cache_dir = os.path.join(self.save_path, pdbid, "fit") + if not(self.resume_preprocess and os.path.exists(os.path.join(cache_dir, "fit_ok.txt"))): + # This assumes we've named the processes during initialization + bar_pos = int(mp.current_process().name.split("-")[1]) + 1 + self.process_step1(pdbid, bar_pos) + return [] + except Exception as e: + traceback.print_tb(e.__traceback__) + print(f"{pdbid}:", e) + raise + + def step3_threading(self, pdbid): + try: + # This assumes we've named the processes during initialization + bar_pos = int(mp.current_process().name.split("-")[1]) + 1 + self.process_step3(pdbid, bar_pos) + except Exception as e: + traceback.print_tb(e.__traceback__) + print(f"{pdbid}:", e) + raise + + def preprocess(self): + os.makedirs(os.path.join(self.save_path, "result"), exist_ok=True) + + info_dict = { + "input_paths": [i["path"] for i in self.dataset_conf], + "frame_slice": slice_to_str(self.frame_slice), + "pdbids": list(self.pdbid_list.keys()), + "optimize_forces": self.optimize_forces, + "box": self.box, + "prior_name": prior_name + } - def is_root(self) -> bool: - return self.rank == 0 + # If resuming, validate that no paramters that would invalidate the fit cache object have changed + # FIXME: This should also check "--tag-beta-turns" + if self.resume_preprocess: + if os.path.exists(os.path.join(self.save_path, "result/info.json")): + with open(os.path.join(self.save_path, "result/info.json"), "rt", encoding="utf-8") as f: + prevous_info = json.load(f) + for k in ["box", "frame_slice", "optimize_forces", "prior_name"]: + assert info_dict[k] == prevous_info[k], \ + f"Can't resume with different parameters: {k}: {info_dict[k]} != {prevous_info[k]}" - def split_items(self, items: List[str]) -> List[str]: - # deterministic split by rank (round-robin keeps load more even with variable protein sizes) - if self.world == 1: - return items - return [x for i, x in enumerate(items) if (i % self.world) == self.rank] + with open(os.path.join(self.save_path, "result/info.json"), "wt", encoding="utf-8") as f: + json.dump(info_dict, f) + pdbids = self.pdbid_list -class ParallelRunner: - """ - One DRY runner for: - - serial - - multiprocessing per rank - """ - def __init__(self, workers: int): - self.workers = max(1, int(workers)) - - @staticmethod - def _proc_init(counter: mp.Value) -> None: - with counter.get_lock(): - idx = int(counter.value) - counter.value += 1 - mp.current_process().name = f"PreprocessWorker-{idx}" - - @staticmethod - def worker_bar_position(default: int = 1) -> int: - try: - return int(mp.current_process().name.split("-")[1]) + 1 - except Exception: - return default - - def run( - self, - items: List[str], - fn: Callable[[str], None], - *, - desc: str, - show_progress: bool = True, - ) -> Dict[str, str]: - if self.workers == 1: - errors: Dict[str, str] = {} - it = tqdm(items, desc=desc, dynamic_ncols=True) if show_progress else items - for x in it: - try: - fn(x) - except Exception as e: - traceback.print_tb(e.__traceback__) - errors[x] = str(e) - return errors - - # multiprocessing + # Ensure all jobs have the same tqdm lock : https://github.com/tqdm/tqdm/issues/982 tqdm.get_lock() - errors: Dict[str, str] = {} - counter = mp.Value("i", 0, lock=True) - - with tqdm(total=len(items), desc=desc, dynamic_ncols=True, disable=not show_progress) as pbar: - with mp.Pool(self.workers, initializer=self._proc_init, initargs=(counter,)) as pool: - pending: Dict[mp.pool.ApplyResult, str] = { - pool.apply_async(fn, args=(x,)): x for x in items - } - while pending: - for res in list(pending.keys()): - if not res.ready(): - continue - x = pending.pop(res) - try: - res.get() - except Exception as e: - traceback.print_tb(e.__traceback__) - errors[x] = str(e) - pbar.n = len(items) - len(pending) - pbar.refresh() - if pending: - next(iter(pending)).wait(1) - return errors + # TODO: Print exceptions in the main thread for legibility + # TODO: Abstract the loop logic instead of repeating it twice -# ----------------------------- -# Dataset discovery -# ----------------------------- -class DatasetIndex: - def __init__(self, dataset_conf: List[Dict[str, Any]]): - self.dataset_conf = dataset_conf + # Truncate any existing ok_list.txt + with open(os.path.join(self.save_path, "result/ok_list.txt"), "wt", encoding="utf-8") as ok_list: + pass - def build_input_map(self) -> Dict[str, str]: - pdbid_mapping: Dict[str, str] = {} - for entry in self.dataset_conf: - input_path = entry["path"] - prefix = entry.get("prefix", "") - suffix = entry.get("suffix", "") - assert os.path.isdir(input_path), f"Input path does not exist: {input_path}" - - if "pdbids" in entry: - for dir_name in entry["pdbids"]: - input_h5 = os.path.join(input_path, dir_name, "result", f"output_{dir_name}.h5") - assert os.path.exists(input_h5), f"Requested path missing: {input_h5}" - pdbid_mapping[prefix + dir_name + suffix] = input_h5 - continue - for dir_name in sorted(os.listdir(input_path)): - input_h5 = os.path.join(input_path, dir_name, "result", f"output_{dir_name}.h5") - if os.path.exists(input_h5): - pdbid_mapping[prefix + dir_name + suffix] = input_h5 - return pdbid_mapping + # Run step 1 in parallel, saving the results to the cache + + if DO_STEP_1: + errorList = {} + thread_counter = mp.Value('i', 0, lock=True) + with tqdm(total=len(pdbids), desc="Processing Step 1", dynamic_ncols=True) as pbar: + with mp.Pool(self.num_cores, initializer=process_init, initargs=(thread_counter,)) as pool: + pending_results = {} + + # Submit tasks and map results to pdbids + for pdbid in pdbids: + result = pool.apply_async(self.step1_threading, args=(pdbid,)) + pending_results[result] = pdbid + + while pending_results: + # Check completed tasks + for result in list(pending_results.keys()): # Iterate over a copy to allow removal + if result.ready(): + try: + result.get() # Retrieve result or raise exception + except Exception as e: + pdbid = pending_results[result] # Get the corresponding pdbid + errorList[pdbid] = str(e) + finally: + del pending_results[result] # Remove completed task + + # Update the progress bar + pbar.n = len(pdbids) - len(pending_results) + pbar.refresh() + + # Wait for 1 second or for the last job to finish + if pending_results: + next(iter(pending_results)).wait(1) + + if len(errorList): + print('errorList', errorList) + print('errorList keys', errorList.keys()) + + if not self.prior_file: + # pickle prior builder + if not os.path.exists(self.save_path + '/prior_builder.pkl') or REGEN_CACHE_FILES: + # Merge cache files back into prior builder + for pdbid in tqdm(pdbids, desc="Merging cache files together"): + cache_dir = os.path.join(self.save_path, pdbid, "fit") + self.prior_builder.load_molecule_cache(cache_dir) + + with open(self.save_path + '/prior_builder.pkl', 'wb') as f: + pickle.dump(self.prior_builder, f) + else: + print("Using cached prior_builder object... ") -# ----------------------------- -# Core pipeline (OO, modular) -# ----------------------------- -class Preprocessor: - def __init__( - self, - cfg: PreprocessConfig, - runtime: RuntimeFlags, - dist: DistributedConfig, - dataset_conf: List[Dict[str, Any]], - input_map: Dict[str, str], - prior_builder: PriorBuilder, - ): - self.cfg = cfg - self.runtime = runtime - self.exec = ExecContext(dist) - self.runner = ParallelRunner(workers=dist.per_rank_workers) + with open(self.save_path + '/prior_builder.pkl', 'rb') as f: + self.prior_builder = pickle.load(f) - self.dataset_conf = dataset_conf - self.input_map = dict(input_map) # pdbid -> h5 path - self.prior_builder = prior_builder + self.process_step2() - # optional filter: only keep proteins that completed step 1 - if runtime.filter_not_processed_step_one: - processed = {p.split("/")[-3] for p in glob.glob(os.path.join(cfg.output_dir, "*/fit/fit_ok.txt"))} - self.input_map = {k: v for k, v in self.input_map.items() if k in processed} - - # ---------- high level ---------- - def run(self) -> None: - ensure_dir(os.path.join(self.cfg.output_dir, "result")) - all_pdbids = sorted(self.input_map.keys()) - - # MPI split (each rank gets a subset) - pdbids = self.exec.split_items(all_pdbids) - - # only root writes global info.json (avoids collisions) - if self.exec.is_root(): - self._write_run_info(all_pdbids) - - self.exec.barrier() - - # Step 1 (per-rank) - if self.runtime.do_step_1: - errs = self.runner.run( - pdbids, - self._step1_one, - desc=f"[rank {self.exec.rank}] Step 1", - show_progress=True, - ) - if errs and self.exec.is_root(): - print("Step 1 errors (sample):", dict(list(errs.items())[:10])) - - self.exec.barrier() - - # Step 2 (ONLY root fits priors once, then broadcast/filesystem share) - if not self.cfg.prior_file: - if self.exec.is_root(): - self._fit_prior_once(all_pdbids) - self.exec.barrier() + self.prior_builder.save_prior(self.save_path, None) else: - if self.exec.is_root(): - self._copy_prior_files() - self.exec.barrier() - - # Step 3 (per-rank) - errs3 = self.runner.run( - pdbids, - self._step3_one, - desc=f"[rank {self.exec.rank}] Step 3", - show_progress=True, - ) - if errs3 and self.exec.is_root(): - print("Step 3 errors (sample):", dict(list(errs3.items())[:10])) - - self.exec.barrier() - - # root writes ok_list - if self.exec.is_root(): - ok_list_path = os.path.join(self.cfg.output_dir, "result/ok_list.txt") - with open(ok_list_path, "wt", encoding="utf-8") as f: - f.write("\n".join(all_pdbids)) - print("Done!") - - # ---------- info / resume ---------- - def _write_run_info(self, all_pdbids: List[str]) -> None: - info_path = os.path.join(self.cfg.output_dir, "result/info.json") - payload = { - "input_paths": [i["path"] for i in self.dataset_conf], - "frame_slice": slice_to_str(self.cfg.frame_slice), - "pdbids": all_pdbids, - "optimize_forces": self.cfg.optimize_forces, - "box": self.cfg.use_box, - "prior_name": self.cfg.prior_name, - } - - if self.cfg.resume and os.path.exists(info_path): - with open(info_path, "rt", encoding="utf-8") as f: - prev = json.load(f) - for k in ["box", "frame_slice", "optimize_forces", "prior_name"]: - if payload[k] != prev.get(k): - raise AssertionError(f"Can't resume with different parameters: {k}: {payload[k]} != {prev.get(k)}") - - with open(info_path, "wt", encoding="utf-8") as f: - json.dump(payload, f, indent=2, sort_keys=True) - - # truncate ok_list - with open(os.path.join(self.cfg.output_dir, "result/ok_list.txt"), "wt", encoding="utf-8") as _: - pass - - # ---------- step 1 ---------- - def _step1_one(self, pdbid: str) -> None: - cache_dir = os.path.join(self.cfg.output_dir, pdbid, "fit") - if self.cfg.resume and os.path.exists(os.path.join(cache_dir, "fit_ok.txt")): - return - - bar_pos = self.runner.worker_bar_position() - self._process_step1(pdbid, bar_pos) - - def _save_raw( - self, - out_dir: str, - *, - embeddings: np.ndarray, - forces: np.ndarray, - coords: np.ndarray, - box_vectors: Optional[np.ndarray], - weights: Optional[np.ndarray], - ) -> None: - raw = os.path.join(out_dir, "raw") - ensure_dir(raw) - np.save(os.path.join(raw, "embeddings.npy"), embeddings) - np.save(os.path.join(raw, "forces.npy"), forces) - np.save(os.path.join(raw, "coordinates.npy"), coords) - - box_path = os.path.join(raw, "box.npy") - if box_vectors is not None: - np.save(box_path, box_vectors) + prior_params_path = get_prior_params_path(prior_file) + shutil.copy(self.prior_file, os.path.join(self.save_path, "priors.yaml")) + shutil.copy(prior_params_path, os.path.join(self.save_path, "prior_params.json")) + + if self.totalNrJobs: + pdblist = list(pdbids.keys()) + pdbidsPerJob = len(pdblist) // self.totalNrJobs + 1 + jobid = self.jobid + assert jobid is not None + if jobid < self.totalNrJobs - 1: + pdbids_c = [pdblist[i] for i in range(jobid * pdbidsPerJob, (jobid + 1) * pdbidsPerJob)] + else: + pdbids_c = [pdblist[i] for i in range(jobid * pdbidsPerJob, len(pdblist))] + + # filter pdbids for this job only + pdbids = {k: v for k, v in pdbids.items() if k in pdbids_c} + + print(f"Step 3: Processing {len(pdbids)} pdbids") + + # Run step 3 in parallel + thread_counter = mp.Value('i', 0, lock=True) + with tqdm(total=len(pdbids), desc="Processing Step 3", dynamic_ncols=True) as pbar: + with mp.Pool(self.num_cores, initializer=process_init, initargs=(thread_counter,)) as pool: + pending_results = [] + + for pdbid in pdbids: + pending_results += [pool.apply_async(self.step3_threading, args=(pdbid,))] + + while pending_results: + # Check for exceptions + [i.get() for i in pending_results if i.ready()] + # Remove finished jobs from the list + pending_results = [i for i in pending_results if not i.ready()] + pbar.n = len(pdbids) - len(pending_results) + pbar.refresh() + # Wait for 1 second or for the last job to finish + if pending_results: + pending_results[0].wait(1) + + # alternatively, cd to the preprocessed_data directory and run this cmd: + # ls */raw/deltaforces.npy | awk '{print substr($1, 1, 4)}' > result/ok_list.txt + with open(os.path.join(self.save_path, "result/ok_list.txt"), "wt", encoding="utf-8") as ok_list: + ok_list.write("\n".join(pdbids)) + + print("Done!") + + def save_data(self, output_path, trajectory, embeddings, forces, pdbid): + # print(f" {pdbid} (coordinates, forces): {trajectory.xyz.shape}, {forces.shape}") + np.save(f"{output_path}/raw/embeddings.npy", embeddings) + np.save(f"{output_path}/raw/forces.npy", forces) + np.save(f"{output_path}/raw/coordinates.npy", trajectory.xyz) + box_path = f"{output_path}/raw/box.npy" + if self.box: + np.save(box_path, trajectory.unitcell_vectors) elif os.path.exists(box_path): os.unlink(box_path) - w_path = os.path.join(raw, "forces_weights.npy") - if weights is not None: - np.save(w_path, weights) - elif os.path.exists(w_path): - os.unlink(w_path) + def process_step1(self, pdbid, bar_position=0): + """Generate the course grained data and topology for the protein, then add it to the prior builder""" - def _process_step1(self, pdbid: str, bar_position: int) -> None: - input_h5 = self.input_map[pdbid] - out_dir = os.path.join(self.cfg.output_dir, pdbid) - ensure_dir(os.path.join(out_dir, "raw")) - ensure_dir(os.path.join(out_dir, "processed")) - - with tqdm(total=7, position=bar_position, desc=f"{pdbid}: step1", dynamic_ncols=True, leave=False) as pbar: - def step(msg: str) -> None: + with tqdm(total=7, position=bar_position, desc=f"{pdbid}: File path setup", dynamic_ncols=True, leave=False) as pbar: + def progress_bar_step(msg): pbar.update(1) pbar.set_description_str(f"{pdbid}: {msg}") - step("Loading trajectory") - aa_traj, aa_weights = load_h5_traj_slice(input_h5, self.cfg.frame_slice) - assert aa_traj.xyz is not None - aa_traj.xyz *= 10.0 # nm -> Å + # Set up paths and create directories + output_path = os.path.join(self.save_path, pdbid) - step("Building CG mapping") - cg_map = self.prior_builder.build_mapping(aa_traj.topology) - mol = self.prior_builder.make_mol(cg_map) + # TODO: Get rit of the of subdirectories? + os.makedirs(f"{output_path}/raw", exist_ok=True) + os.makedirs(f"{output_path}/processed", exist_ok=True) - # optional debug topology save - mol.write(os.path.join(out_dir, "processed", f"{pdbid}_processed.psf")) - cg_topology_mol = cg_map.to_mol(bonds=True, angles=True, dihedrals=True) - cg_topology_mol.write(os.path.join(out_dir, "processed", "topology.psf")) + # Find which path this ID belongs to + input_file_path = self.pdbid_list[pdbid] - step("Mapping CG forces") - with h5py.File(input_h5, "r") as f: - forces = f["forces"][self.cfg.frame_slice, :, :] - forces = cg_map.cg_optimal_forces(aa_traj, forces) if self.cfg.optimize_forces else cg_map.cg_forces(forces) - forces = forces * 0.02390057361376673 # kJ/mol/nm -> kcal/mol/Å + progress_bar_step("Loading trajectory") + AAtraj = load_h5_traj_slice(input_file_path, self.frame_slice) + assert AAtraj.xyz is not None # for pyright + AAtraj.xyz *= 10 # convert to angstroms - step("Mapping CG coordinates") - xyz = cg_map.cg_positions(aa_traj.xyz) + progress_bar_step("Building CG mapping") + cg_map = self.prior_builder.build_mapping(AAtraj.topology) + mol = self.prior_builder.make_mol(cg_map) + topology = cg_map.to_mol(bonds=True, angles=True, dihedrals=True) + mol.write(f'{output_path}/processed/{pdbid}_processed.psf') + topology.write(f'{output_path}/processed/topology.psf') # Save the topology for the CG mapping, this is optional but useful for debugging + + # Get the forces + progress_bar_step("Mapping CG forces") + with h5py.File(input_file_path, "r") as f: + forces = f["forces"][self.frame_slice, :, :] #pyright: ignore[reportIndexIssue] + + if self.optimize_forces: + forces = cg_map.cg_optimal_forces(AAtraj, forces) + else: + forces = cg_map.cg_forces(forces) + + assert len(forces) == len(AAtraj) + # Convert from kilojoules/mole/nanometer to kilocalories/mole/angstrom + forces = forces*0.02390057361376673 + + progress_bar_step("Mapping CG coordinates") + xyz = cg_map.cg_positions(AAtraj.xyz) cg_traj = mdtraj.Trajectory(xyz, topology=cg_map.to_mdtraj()) - - if self.cfg.use_box and aa_traj.unitcell_lengths is not None: - cg_traj.unitcell_lengths = aa_traj.unitcell_lengths * 10.0 - cg_traj.unitcell_angles = aa_traj.unitcell_angles + if self.box and AAtraj.unitcell_lengths is not None: + cg_traj.unitcell_lengths = AAtraj.unitcell_lengths * 10 + cg_traj.unitcell_angles = AAtraj.unitcell_angles else: cg_traj.unitcell_lengths = None - cg_traj.unitcell_angles = None - - cg_weights = None - if aa_weights is not None: - if aa_weights.ndim == 2: - cg_weights = cg_map.cg_weights(aa_weights) - elif aa_weights.ndim == 1: - n_beads = len(cg_map.embeddings) - cg_weights = np.repeat(aa_weights[:, None], n_beads, axis=1) - - step("Saving raw arrays") - self._save_raw( - out_dir, - embeddings=cg_map.embeddings, - forces=forces, - coords=cg_traj.xyz, - box_vectors=(cg_traj.unitcell_vectors if self.cfg.use_box else None), - weights=cg_weights, - ) - - step("Generating prior fit data") - if not self.cfg.prior_file: - cache_dir = os.path.join(out_dir, "fit") - ensure_dir(cache_dir) - mol.coords = np.moveaxis(cg_traj.xyz, 0, -1) - self.prior_builder.add_molecule(mol, cg_traj, cache_dir) + cg_traj.unitcell_angles = None - # ---------- step 2 ---------- - def _fit_prior_once(self, all_pdbids: List[str]) -> None: - # cache prior_builder aggregation - pb_cache = os.path.join(self.cfg.output_dir, "prior_builder.pkl") + # Save the data + progress_bar_step("Saving Data") + self.save_data(output_path, cg_traj, cg_map.embeddings, forces, pdbid) - if os.path.exists(pb_cache) and not self.runtime.regen_cache_files: - with open(pb_cache, "rb") as f: - self.prior_builder = pickle.load(f) - else: - for pdbid in tqdm(all_pdbids, desc="Merging fit caches", dynamic_ncols=True): - cache_dir = os.path.join(self.cfg.output_dir, pdbid, "fit") - self.prior_builder.load_molecule_cache(cache_dir) - with open(pb_cache, "wb") as f: - pickle.dump(self.prior_builder, f) - - plot_dir = None - if self.cfg.prior_plots: - plot_dir = os.path.join(self.cfg.output_dir, "prior_fit_plots") - ensure_dir(plot_dir) - - self.prior_builder.fit(self.cfg.temp, plot_dir=plot_dir) - - # Save priors - if self.prior_builder.prior_params.get("external"): - self.prior_builder.save_prior_flex(self.cfg.output_dir) - else: - self.prior_builder.save_prior(self.cfg.output_dir) - - def _copy_prior_files(self) -> None: - assert self.cfg.prior_file is not None - prior_params_path = get_prior_params_path(self.cfg.prior_file) - shutil.copy(self.cfg.prior_file, os.path.join(self.cfg.output_dir, "priors.yaml")) - shutil.copy(prior_params_path, os.path.join(self.cfg.output_dir, "prior_params.json")) - - # ---------- step 3 ---------- - def _step3_one(self, pdbid: str) -> None: - bar_pos = self.runner.worker_bar_position() - self._process_step3(pdbid, bar_pos) - - def _process_step3(self, pdbid: str, bar_position: int) -> None: - out_dir = os.path.join(self.cfg.output_dir, pdbid) - raw_dir = os.path.join(out_dir, "raw") - proc_dir = os.path.join(out_dir, "processed") - - # legacy cleanup - for legacy in [f"{pdbid}_priors.yaml", f"{pdbid}_prior_params.json"]: - p = os.path.join(raw_dir, legacy) - if os.path.exists(p): - os.unlink(p) - - coords_npz = os.path.join(raw_dir, "coordinates.npy") - forces_npz = os.path.join(raw_dir, "forces.npy") - delta_forces_npz = os.path.join(raw_dir, "deltaforces.npy") - prior_energy_npz = os.path.join(raw_dir, "prior_energy.npy") - box_npz = os.path.join(raw_dir, "box.npy") if self.cfg.use_box else None - - forcefield = os.path.join(self.cfg.output_dir, "priors.yaml") - psf_file = os.path.join(proc_dir, f"{pdbid}_processed.psf") + # Note: moveaxis creates a view, the original trajectory.xyz is unmodified + assert cg_traj.xyz is not None + mol.coords = np.moveaxis(cg_traj.xyz, 0, -1) + progress_bar_step("Generating prior fit data") + if not self.prior_file: + cache_dir = os.path.join(output_path, "fit") + os.makedirs(cache_dir, exist_ok=True) + self.prior_builder.add_molecule(mol, cg_traj, cache_dir) + + def process_step2(self): + """Fit the prior forcefield based on accumulated data""" + if self.prior_plots: + plot_dir = os.path.join(self.save_path, "prior_fit_plots") + os.makedirs(plot_dir, exist_ok=True) + else: + plot_dir = None + + self.prior_builder.fit(self.temp, plot_dir=plot_dir) + + def process_step3(self, pdbid, bar_position=0): + """Save prior focefield and generate delta forces data for each protein""" + output_path = os.path.join(self.save_path, pdbid) + + # Remove legacy prior files if they exist + if os.path.exists(f"{output_path}/raw/{pdbid}_priors.yaml"): + os.unlink(f"{output_path}/raw/{pdbid}_priors.yaml") + if os.path.exists(f"{output_path}/raw/{pdbid}_prior_params.json"): + os.unlink(f"{output_path}/raw/{pdbid}_prior_params.json") + + # Generate delta forces for all atom simulation vs. prior FF + coords_npz = f'{output_path}/raw/coordinates.npy' + forces_npz = f'{output_path}/raw/forces.npy' + delta_forces_npz = f'{output_path}/raw/deltaforces.npy' + prior_energy_npz = f'{output_path}/raw/prior_energy.npy' + box_npz = None + if self.box: + box_npz = f"{output_path}/raw/box.npy" + forcefield = os.path.join(self.save_path, "priors.yaml") + psf_file = f'{output_path}/processed/{pdbid}_processed.psf' prior_params = self.prior_builder.prior_params - df = DeltaForces(self.runtime.device_step_3, psf_file, coords_npz, box_npz) - - if prior_params.get("external"): - df.addExternalForces( - forcefield, - self.prior_builder.priors["bonds"], - self.prior_builder.priors["angles"], - self.prior_builder.priors["dihedrals"], - forceterms=prior_params["forceterms_nn"], - bar_position=bar_position, - ) - df.computePriorForces( - forcefield, - exclusions=prior_params["exclusions"], - forceterms=prior_params["forceterms_classical"], - bar_position=bar_position, - ) + + deltaForcesObj = DeltaForces(DEVICE_STEP_3, psf_file, coords_npz, box_npz) + if 'external' in self.prior_builder.prior_params.keys(): + # forceterms = ['bonds', 'angles', 'dihedrals'] + deltaForcesObj.addExternalForces(forcefield, self.prior_builder.priors['bonds'], self.prior_builder.priors['angles'], self.prior_builder.priors['dihedrals'], forceterms=prior_params["forceterms_nn"], bar_position=bar_position) + + # forceterms = ['repulsioncg'] # update them properly in preprocess.py in the _flex class + deltaForcesObj.computePriorForces(forcefield, exclusions=prior_params["exclusions"], + forceterms=prior_params["forceterms_classical"], bar_position=bar_position) + else: - df.computePriorForces( - forcefield, - exclusions=prior_params["exclusions"], - forceterms=prior_params["forceterms"], - bar_position=bar_position, - ) - - df.makeAndSaveDeltaForces(forces_npz, delta_forces_npz, prior_energy_npz) - - -# ----------------------------- -# CLI -# ----------------------------- -def load_dataset_conf(inputs: List[str]) -> List[Dict[str, Any]]: - conf: List[Dict[str, Any]] = [] - for i in inputs: - if os.path.isfile(i): - with open(i, "r") as f: - conf += yaml.safe_load(f) + deltaForcesObj.computePriorForces(forcefield, exclusions=prior_params["exclusions"], + forceterms=prior_params["forceterms"], bar_position=bar_position) + + # load MD forces from forces_npz, compute delta forces, and save them in delta_forces_npz + deltaForcesObj.makeAndSaveDeltaForces(forces_npz, delta_forces_npz, prior_energy_npz) + +def gen_input_mapping(conf): + """Find the list of input files for the passed dataset config""" + pdbid_mapping = dict() + for entry in conf: + input_path = entry["path"] + prefix = entry.get("prefix", "") + suffix = entry.get("suffix", "") + assert os.path.isdir(input_path), f"Input path does not exist: {input_path}" + if "pdbids" in entry: + for dir_name in entry["pdbids"]: + input_h5 = os.path.join(input_path, dir_name, "result", f"output_{dir_name}.h5") + assert os.path.exists(input_h5), "Requested path {input_path}/{dir_name} does not exist" + pdbid_mapping[prefix + dir_name + suffix] = input_h5 else: - conf += [{"path": i}] - return conf - -def main() -> None: - parser = argparse.ArgumentParser(description="Preprocess data (OO, DRY, MPI-enabled).") - parser.add_argument("input", nargs="+", help="Input directories or a YAML config file") + dir_names = os.listdir(input_path) + for dir_name in sorted(dir_names): + input_h5 = os.path.join(input_path, dir_name, "result", f"output_{dir_name}.h5") + if os.path.exists(input_h5): + pdbid_mapping[prefix + dir_name + suffix] = input_h5 + else: + print(f" Skipping \"{dir_name}\" (directory contains no output)") + return pdbid_mapping + +prior_types = { + "CA":Prior_CA, + "CACB":Prior_CACB, + "CACB_lj":Prior_CACB_lj, + "CACB_lj_angle_dihedral":Prior_CACB_lj_angle_dihedral, + "CA_lj":Prior_CA_lj, + "CA_lj_angle":Prior_CA_lj_angle, + "CA_lj_angle_dihedral":Prior_CA_lj_angle_dihedral, + "CA_lj_angle_dihedralX":Prior_CA_lj_angle_dihedralX, + "CA_lj_angleXCX_dihedralX":Prior_CA_lj_angleXCX_dihedralX, + "CA_lj_angleXCX_dihedralX_flex":Prior_CA_lj_angleXCX_dihedralX_flex, + "CA_lj_angleXCX_dihedralX_V1":Prior_CA_lj_angleXCX_dihedralX_V1, + "CA_Majewski2022_v0":Prior_CA_Majewski2022_v0, + "CA_Majewski2022_v1":Prior_CA_Majewski2022_v1, + "CA_lj_bondNull_angleXCX_dihedralX":Prior_CA_lj_bondNull_angleXCX_dihedralX, + "CA_lj_bondNull_angleNull_dihedralX":Prior_CA_lj_bondNull_angleNull_dihedralX, + "CA_lj_bondNull_angleNull_dihedralNull":Prior_CA_lj_bondNull_angleNull_dihedralNull, + "CA_lj_angleNull_dihedralX":Prior_CA_lj_angleNull_dihedralX, + "CA_lj_angleNull_dihedralNull":Prior_CA_lj_angleNull_dihedralNull, + "CA_null":Prior_CA_null, + "CA_lj_only":Prior_CA_lj_only, +} + +# --------------------------------------------------------------------------- +# Entry-point helpers: collection → validation → processing +# --------------------------------------------------------------------------- + +def build_parser(): + parser = argparse.ArgumentParser(description="Preprocess data.") + parser.add_argument("input", nargs="+", + help="Input directory paths or YAML dataset config files") parser.add_argument("-o", "--output", required=True, help="Output directory path") - parser.add_argument("--pdbids", nargs="*", help="Specific PDB IDs to process") - parser.add_argument("--num-frames", type=int, default=None, help="Number of frames to process") - parser.add_argument("--frame-slice", type=str, default=None, help="Frame slice: start:end:stride") - parser.add_argument("--temp", type=int, default=300) - parser.add_argument("--prior", type=str, default=None, help="Prior configuration name") - parser.add_argument("--prior-file", default=None, help="Use a pre-fit priors.yaml (and matching prior_params.json)") - parser.add_argument("--optimize-forces", action="store_true") - parser.add_argument("--no-box", action="store_true") - parser.add_argument("--prior-plots", dest="prior_plots", action="store_true", default=True) - parser.add_argument("--no-prior-plots", dest="prior_plots", action="store_false") - parser.add_argument("--no-fit-constraints", action="store_true") - parser.add_argument("--fit-min-cnt", type=int, default=0) - parser.add_argument("--resume", action="store_true") - parser.add_argument("--use-cached-fits", nargs="*", default=[], help="e.g. bonds angles dihedrals lj") - - # distributed - parser.add_argument("--mpi", action="store_true", help="Enable MPI splitting (requires mpi4py, launch with mpirun)") - parser.add_argument("--workers", type=int, default=32, help="Workers per rank (multiprocessing). Use 1 to disable.") - - # runtime flags (formerly globals) - parser.add_argument("--filter-not-processed-step-one", action="store_true") - parser.add_argument("--skip-step-1", action="store_true") - parser.add_argument("--no-regen-cache-files", action="store_true") - parser.add_argument("--device-step-3", type=str, default="cpu") - - args = parser.parse_args() - print(args) - - assert not (args.num_frames and args.frame_slice) - if args.num_frames is not None: - frame_slice = slice(0, args.num_frames) - elif args.frame_slice is not None: - frame_slice = parse_slice(args.frame_slice) - else: - frame_slice = slice(None) - - cfg = PreprocessConfig( - output_dir=args.output, - temp=args.temp, - frame_slice=frame_slice, - optimize_forces=args.optimize_forces, - use_box=not args.no_box, - prior_plots=args.prior_plots, - resume=args.resume, - fit_constraints=not args.no_fit_constraints, - fit_min_cnt=args.fit_min_cnt, - prior_name=args.prior or "", - prior_file=args.prior_file, - ) - - runtime = RuntimeFlags( - filter_not_processed_step_one=args.filter_not_processed_step_one, - do_step_1=not args.skip_step_1, - regen_cache_files=not args.no_regen_cache_files, - device_step_3=args.device_step_3, - use_cached_fits=tuple(args.use_cached_fits), - ) - - dist = DistributedConfig( - use_mpi=args.mpi, - per_rank_workers=max(1, args.workers), - ) - - # If flex prior needs spawn, do it early - # We'll detect it after we build the prior - # (safe to call set_start_method once, guarded) - torch.set_num_threads(1) - - # thread-safe matplotlib backend for plots - import matplotlib - matplotlib.use("Agg") - - # dataset - dataset_conf = load_dataset_conf(args.input) - index = DatasetIndex(dataset_conf) - input_map = index.build_input_map() - if args.pdbids: - input_map = {p: input_map[p] for p in args.pdbids} - - # prior file sanity (if used) + parser.add_argument("--pdbids", nargs="*", help="Restrict processing to these PDB IDs") + parser.add_argument("--num-frames", "--num_frames", type=int, default=None, + help="Number of frames to process (mutually exclusive with --frame-slice)") + parser.add_argument("--frame-slice", type=str, default=None, + help="Select frames via a Python slice string: start:end:stride " + "(mutually exclusive with --num-frames)") + parser.add_argument("--temp", type=int, default=300, help="Temperature (K)") + parser.add_argument("--prior", type=str, default=None, + help="Prior forcefield to use, one of: " + ", ".join(sorted(prior_types.keys()))) + parser.add_argument("--optimize-forces", action="store_true", + help="Use statistically optimal force aggregation (Kramer 2023)") + parser.add_argument("--prior-file", default=None, + help="Load a pre-built prior from PRIOR_FILE instead of fitting one") + parser.add_argument('--no-box', default=False, action='store_true', + help="Don't use periodic box information") + parser.add_argument('--prior-plots', default=True, action='store_true', + help="Save plots of the prior fit functions") + parser.add_argument('--no-prior-plots', dest='prior_plots', action='store_false', + help="Don't save plots of the prior fit functions") + parser.add_argument('--no-fit-constraints', default=False, action='store_true', + help="Disable range constraints when fitting prior functions") + parser.add_argument('--fit-min-cnt', type=int, default=0, + help="Minimum bin count to include when fitting the prior (default 0)") + # parser.add_argument('--tag-beta-turns', default=False, action='store_true', + # help="Give beta turns a different bond type in the prior") + parser.add_argument('--resume', default=False, action='store_true', + help="Resume an interrupted run — all settings must be identical") + parser.add_argument('--num-cores', type=int, default=32, + help="CPU cores for parallel preprocessing (default 32)") + parser.add_argument('--jobid', type=int, default=None, + help="Job index for array jobs; limits processing to a subset of PDBs") + parser.add_argument('--totalNrJobs', type=int, default=None, + help="Total number of array jobs (used with --jobid)") + return parser + + +def validate_config(cfg) -> None: + assert not (cfg.num_frames and cfg.frame_slice), \ + "--num-frames and --frame-slice are mutually exclusive" if cfg.prior_file: assert os.path.exists(cfg.prior_file), f"Prior file does not exist: {cfg.prior_file}" - params_path = get_prior_params_path(cfg.prior_file) - with open(params_path, "r", encoding="utf-8") as f: - params = json.load(f) - if not cfg.prior_name: - cfg.prior_name = params["prior_configuration_name"] - elif cfg.prior_name != params["prior_configuration_name"]: - print(f'WARNING: --prior "{cfg.prior_name}" differs from prior file "{params["prior_configuration_name"]}"') + assert cfg.prior or cfg.prior_file, \ + "Specify a prior with --prior or --prior-file" - assert cfg.prior_name, "You must specify the prior via --prior or --prior-file" - prior_builder = PriorFactory.make(cfg.prior_name, cfg, runtime) +def process_config(cfg): + # Frame selection + if cfg.num_frames: + frame_slice = slice(0, cfg.num_frames) + elif cfg.frame_slice: + frame_slice = slice(*[int(i) if i != '' else None for i in cfg.frame_slice.split(":")]) + else: + frame_slice = slice(None) - if prior_builder.prior_params.get("external"): - # flex uses torch nets; spawn is safer in multiprocess contexts - try: - mp.set_start_method("spawn", force=True) - except Exception: - pass + # Resolve prior name from --prior-file if not explicitly set + prior_name = cfg.prior + prior_file = cfg.prior_file + if prior_file: + prior_params_path = get_prior_params_path(prior_file) + with open(prior_params_path, "r", encoding="utf-8") as f: + prior_params = json.load(f) + prior_configuration_name = prior_params["prior_configuration_name"] + if prior_name is None: + prior_name = prior_configuration_name + elif prior_name != prior_configuration_name: + print() + print(f"WARNING: Prior \"{prior_name}\" differs from the one in the prior file \"{prior_configuration_name}\"") + print() + + if prior_name not in prior_types: + raise RuntimeError(f"Unknown prior configuration: {prior_name}") + print(f"Using prior: {prior_name}") + + # Build and configure the prior + prior_builder = prior_types[prior_name]() + prior_builder.enable_fit_constraints(not cfg.no_fit_constraints) + # prior_builder.enable_bond_tags(cfg.tag_beta_turns) + prior_builder.enable_bond_tags(False) + prior_builder.set_min_cnt(cfg.fit_min_cnt) + + if 'external' in prior_builder.prior_params.keys(): + mp.set_start_method('spawn') + + # Thread-safe matplotlib backend required for prior fit plots + import matplotlib + matplotlib.use('Agg') - pre = Preprocessor(cfg, runtime, dist, dataset_conf, input_map, prior_builder) - pre.run() + # Build dataset config from file paths or inline dicts + dataset_conf = [] + for i in cfg.input: + if os.path.isfile(i): + with open(i, "r") as f: + dataset_conf += yaml.safe_load(f) + else: + dataset_conf += [{"path": i}] + + input_path_map = gen_input_mapping(dataset_conf) + if cfg.pdbids: + input_path_map = {i: input_path_map[i] for i in cfg.pdbids} + + return dict( + dataset_conf=dataset_conf, + input_path_map=input_path_map, + output_dir=cfg.output, + prior_builder=prior_builder, + prior_file=prior_file, + prior_name=prior_name, + frame_slice=frame_slice, + temp=cfg.temp, + optimize_forces=cfg.optimize_forces, + box=not cfg.no_box, + prior_plots=cfg.prior_plots, + resume_preprocess=cfg.resume, + num_cores=cfg.num_cores, + jobid=cfg.jobid, + totalNrJobs=cfg.totalNrJobs, + ) if __name__ == "__main__": - main() + from module.base_config import BaseConfig + cfg = BaseConfig(build_parser()) + print(cfg.as_namespace()) + validate_config(cfg) + params = process_config(cfg) + preprocessor = Preprocessor(**params) + preprocessor.preprocess() diff --git a/simulate.py b/simulate.py index 02426b3..3eb8fab 100755 --- a/simulate.py +++ b/simulate.py @@ -531,33 +531,75 @@ def build_calc(model, mol, embeddings, use_box=False, replicas=1, temperature=30 class ArgsMock(): pass -def main(): +# --------------------------------------------------------------------------- +# Entry-point helpers: collection → validation → processing +# --------------------------------------------------------------------------- + +def build_parser(): import argparse arg_parser = argparse.ArgumentParser() arg_parser.add_argument("checkpoint_path", help="The model checkpoint to use") - arg_parser.add_argument("processed_path", help="One or more input files from which the model simulations will start. Each different input file will be processed as a different simulation replica.", nargs="+") - arg_parser.add_argument("--conf", default=None, help="The hyperparameters to use instead of those contained in the checkpoint") - arg_parser.add_argument("--max-num-neighbors", default=None, type=int, help="Override the 'max_num_neighbors' parameter of the model") + arg_parser.add_argument("processed_path", nargs="+", + help="One or more input files to start simulations from. " + "Each file is treated as a separate replica.") + arg_parser.add_argument("--conf", default=None, + help="Hyperparameter file to use instead of those stored in the checkpoint") + arg_parser.add_argument("--max-num-neighbors", default=None, type=int, + help="Override the 'max_num_neighbors' parameter of the model") arg_parser.add_argument("--temperature", default=300, type=int, help="Simulation temperature (Kelvin)") arg_parser.add_argument("--timestep", default=1, type=int, help="Simulation timestep (femtoseconds)") arg_parser.add_argument("--steps", default=10000, type=int, help="The number of frames to simulate") - arg_parser.add_argument("-o", "--output", default="sim.pdb", help="The output file name (may be a .pdb or .h5 file)") + arg_parser.add_argument("-o", "--output", default="sim.pdb", + help="Output file name (.pdb or .h5)") arg_parser.add_argument("--save-steps", default=100, type=int, help="Save frames every n steps") - arg_parser.add_argument("--prior-only", default=False, action='store_true', help="Disable the model and use only the prior forcefield") + arg_parser.add_argument("--prior-only", default=False, action='store_true', + help="Disable the model and run with the prior forcefield only") arg_parser.add_argument('--no-box', action='store_true', help='Do not use box information') - arg_parser.add_argument("--replicas", default=1, type=int, help="The number of simulations running in parallel") - arg_parser.add_argument("--torchmd", default=False, action='store_true', help="Use TorchMD for the prior instead of TorchForceField") - arg_parser.add_argument("--verbose", default=True, action='store_true', help="Prints detailed logs") - arg_parser.add_argument("--prior-nn", default=None, type=str, help="Path to the folder of a neural network prior.") - + arg_parser.add_argument("--replicas", default=1, type=int, + help="Number of parallel simulations to run") + arg_parser.add_argument("--torchmd", default=False, action='store_true', + help="Use TorchMD for the prior instead of TorchForceField") + arg_parser.add_argument("--verbose", default=True, action='store_true', help="Print detailed logs") + arg_parser.add_argument("--prior-nn", default=None, type=str, + help="Path to a neural network prior folder") + return arg_parser + + +def validate_config(cfg) -> None: + assert os.path.exists(cfg.checkpoint_path) or os.path.isdir(cfg.checkpoint_path), \ + f"Checkpoint not found: {cfg.checkpoint_path}" + output_dir = os.path.dirname(cfg.output) + if output_dir: + assert os.path.exists(output_dir), f"Output directory does not exist: {output_dir}" - args = arg_parser.parse_args() - print(args) - run_simulation(args.checkpoint_path, args.processed_path, conf=args.conf, max_num_neighbors=args.max_num_neighbors, - temperature=args.temperature, timestep=args.timestep, steps=args.steps, output=args.output, save_steps=args.save_steps, - prior_only=args.prior_only, no_box=args.no_box, replicas=args.replicas, torchmd=args.torchmd, verbose=args.verbose, - prior_nn=args.prior_nn, gpu=0) +def process_config(cfg) -> dict: + return dict( + checkpoint_path=cfg.checkpoint_path, + processed_path=cfg.processed_path, + conf=cfg.conf, + max_num_neighbors=cfg.max_num_neighbors, + temperature=cfg.temperature, + timestep=cfg.timestep, + steps=cfg.steps, + output=cfg.output, + save_steps=cfg.save_steps, + prior_only=cfg.prior_only, + no_box=cfg.no_box, + replicas=cfg.replicas, + torchmd=cfg.torchmd, + verbose=cfg.verbose, + prior_nn=cfg.prior_nn, + gpu=0, + ) + + +def main(): + from module.base_config import BaseConfig + cfg = BaseConfig(build_parser()) + validate_config(cfg) + params = process_config(cfg) + run_simulation(**params) def run_simulation(checkpoint_path, processed_path, conf=None, max_num_neighbors=None, temperature=300, timestep=1, steps=10000, output='sim.pdb', save_steps=100, diff --git a/train.py b/train.py index 1f9218d..2495d05 100755 --- a/train.py +++ b/train.py @@ -1,1352 +1,1127 @@ #!/usr/bin/env python3 -from __future__ import annotations -import os -import sys -import json -import time -import yaml -import shutil -import itertools -import datetime -import traceback -import resource -from dataclasses import dataclass -from typing import Dict, Any, Optional, Tuple, Iterable, List - -import numpy as np import torch import torch.nn as nn import torch.optim as optim -from torch import Tensor -from torch.utils.data import DataLoader -from tqdm import tqdm - -# Distributed +import torch.utils.data import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler - -# Local modules (unchanged) +import yaml +import numpy as np from module.torchmdnet.model import create_model from module import dataset from module import model_util -from module.lr_scheduler_wrappers import ( - SchedulerWrapper_CosineAnnealingWarmRestarts, - SchedulerWrapper_CosineAnnealingLR, - SchedulerWrapper_ExponentialLR, - SchedulerWrapper_ReduceLROnPlateau, -) - -# ----------------------------- Small utilities ----------------------------- - -def flatten_first(t: Optional[Tensor]) -> Optional[Tensor]: - """Flatten first two dims, preserving remaining dims.""" - if t is None or getattr(t, "shape", None) is None: - return t - if len(t.shape) < 2: - return t - return t.reshape(t.shape[0] * t.shape[1], *t.shape[2:]) - -def deterministic_shuffle(items: List[str], seed: int) -> List[str]: - """Deterministically shuffle a list.""" - g = torch.Generator().manual_seed(seed) - idx = torch.randperm(n=len(items), generator=g, device="cpu") - return [items[i] for i in idx] - -def check_early_stopping(val_list: List[float], patience: int = 1) -> bool: - """True if val loss increased (patience+1) consecutive times. patience<0 disables.""" - if patience < 0: - return False - if len(val_list) < patience + 2: - return False - window = np.array(val_list[-(patience + 2):]) - if np.all((window[1:] - window[:-1]) > 0): - print(f"Validation loss increased {patience+1} times, stopping...") - return True - return False - -def make_term_offsets(lengths: List[int], term_lengths: Tensor) -> Tensor: - """Offset per-term indices across concatenated variable-length batches.""" - result = [] - count = 0 - repeats = len(term_lengths) // len(lengths) - lengths = np.tile(np.array(lengths), repeats).tolist() - assert len(lengths) == len(term_lengths) - for off, nterms in zip(lengths, term_lengths): - result.append(torch.full((int(nterms), 1), count, dtype=torch.long)) - count += int(off) - return torch.cat(result, dim=0) - -# ----------------------------- Domain config ------------------------------ +from module.lr_scheduler_wrappers import * -class TermDef: - def __init__(self, path: Optional[str] = None, conf: Optional[dict] = None): - self.scales: Dict[str, float] = {} - self.angle_wrap: Dict[str, bool] = {} - if path: - with open(path, "r") as f: - conf = yaml.safe_load(f) - if conf: - for k, v in conf.items(): - self.scales[k] = float(v["scale"]) if (v is not None and "scale" in v) else 1.0 - self.angle_wrap[k] = bool(v["angle_wrap"]) if (v is not None and "angle_wrap" in v) else False +import os +import json +import time +from tqdm import tqdm +import datetime +import shutil +import resource +import sys +import traceback +import itertools +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple - def names(self) -> List[str]: - return list(self.scales.keys()) +from torch import Tensor - def scale(self, name: str) -> float: - return self.scales[name] +# os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - def wrap(self, name: str) -> bool: - return self.angle_wrap[name] -# ----------------------------- Distributed manager ------------------------- +# =========================================================================== +# Distributed context +# =========================================================================== @dataclass -class DistInfo: +class DistributedContext: rank: Optional[int] = None world_size: Optional[int] = None local_rank: Optional[int] = None - @property - def enabled(self) -> bool: - return self.rank is not None and self.world_size is not None and self.world_size > 1 - @property def is_main(self) -> bool: return self.rank is None or self.rank == 0 -class DistributedManager: - def __init__(self, enable: bool): - self.info = DistInfo() - self._requested = enable - - def setup(self) -> DistInfo: - """Initialize distributed if env vars indicate multi-proc or user requested.""" - if not self._requested and not any(k in os.environ for k in ("RANK", "WORLD_SIZE", "SLURM_PROCID")): - return self.info - - rank, world_size, local_rank = self._read_env_ranks() - if rank is None: - return self.info - - torch.cuda.set_device(local_rank) - dist.init_process_group( - backend="nccl", - init_method="env://", - world_size=world_size, - rank=rank, - ) - self.info = DistInfo(rank=rank, world_size=world_size, local_rank=local_rank) - print(f"Initialized process {rank}/{world_size} (local_rank={local_rank})") - return self.info - - def cleanup(self) -> None: - """Destroy process group if initialized.""" - if dist.is_initialized(): - dist.destroy_process_group() - - def barrier(self) -> None: - """Sync all processes if distributed.""" - if self.info.enabled: + @property + def is_distributed(self) -> bool: + return self.world_size is not None and self.world_size > 1 + + def barrier(self): + if self.is_distributed: dist.barrier() - def broadcast_bool(self, value: bool, src: int = 0, device: Optional[torch.device] = None) -> bool: - """Broadcast a boolean from src to all processes.""" - if not self.info.enabled: + def all_reduce_scalar(self, value: float, device) -> float: + if not self.is_distributed: + return value + t = torch.tensor(value, device=device) + dist.all_reduce(t, op=dist.ReduceOp.SUM) + return t.item() + + def broadcast_bool(self, value: bool, device) -> bool: + if not self.is_distributed: return value - assert device is not None t = torch.tensor(1 if value else 0, device=device) - dist.broadcast(t, src=src) + dist.broadcast(t, src=0) return bool(t.item()) - def all_reduce_sum(self, x: float, device: torch.device) -> float: - """All-reduce a scalar float via SUM.""" - if not self.info.enabled: - return x - t = torch.tensor(x, device=device) - dist.all_reduce(t, op=dist.ReduceOp.SUM) - return float(t.item()) - def _read_env_ranks(self) -> Tuple[Optional[int], Optional[int], Optional[int]]: - """Read rank/world_size/local_rank from torchrun or SLURM env vars.""" - if "RANK" in os.environ and "WORLD_SIZE" in os.environ: - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - local_rank = int(os.environ.get("LOCAL_RANK", 0)) - return rank, world_size, local_rank +def setup_distributed() -> DistributedContext: + """Auto-detect torchrun or SLURM environment and initialize NCCL process group.""" + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + rank = int(os.environ['RANK']) + world_size = int(os.environ['WORLD_SIZE']) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + elif 'SLURM_PROCID' in os.environ: + rank = int(os.environ['SLURM_PROCID']) + world_size = int(os.environ['SLURM_NTASKS']) + local_rank = int(os.environ.get('SLURM_LOCALID', 0)) + else: + return DistributedContext() - if "SLURM_PROCID" in os.environ: - rank = int(os.environ["SLURM_PROCID"]) - world_size = int(os.environ["SLURM_NTASKS"]) - local_rank = int(os.environ.get("SLURM_LOCALID", 0)) - return rank, world_size, local_rank + torch.cuda.set_device(local_rank) + dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) + return DistributedContext(rank=rank, world_size=world_size, local_rank=local_rank) - return None, None, None -# ----------------------------- Model wrappers ------------------------------ +def cleanup_distributed(): + if dist.is_initialized(): + dist.destroy_process_group() -class BatchWrapper(nn.Module): - def __init__(self, model: nn.Module): - super().__init__() - self.model = model - def forward(self, pos: Tensor, lengths: List[int], **kwargs) -> Tuple[Tensor, Tensor, Dict[str, Tensor]]: - """Prepare batch indices/terms and forward into model.""" - batch_nums = dataset.make_batch_nums(len(pos), lengths).to(pos.device) - kwargs["pos"] = pos +# =========================================================================== +# Configuration dataclasses +# =========================================================================== - for k, v in list(kwargs.items()): - kwargs[k] = flatten_first(v) +@dataclass +class LossConfig: + energy_matching: bool + energy_weight: float + force_weight: float + train_term_def: 'TermDef' + extra_term_keys: List[str] + use_force_weights: bool - if "bonds" in kwargs: - kwargs["bonds"] = kwargs["bonds"] + make_term_offsets(lengths, kwargs.pop("len_bonds").cpu()).to(pos.device) - if "angles" in kwargs: - kwargs["angles"] = kwargs["angles"] + make_term_offsets(lengths, kwargs.pop("len_angles").cpu()).to(pos.device) - if "dihedrals" in kwargs: - kwargs["dihedrals"] = kwargs["dihedrals"] + make_term_offsets(lengths, kwargs.pop("len_dihedrals").cpu()).to(pos.device) - kwargs["batch"] = batch_nums - result = self.model(**kwargs) +@dataclass +class EpochMetrics: + """Aggregated (summed) metrics for one epoch. Divide by num_cal to get means.""" + total_loss: float = 0.0 + energy_loss: float = 0.0 + force_loss: float = 0.0 + num_cal: int = 0 + term_losses: Dict[str, float] = field(default_factory=dict) + term_num_cal: Dict[str, int] = field(default_factory=dict) - if len(result) == 2: - out_e, out_f = result - return out_e, out_f, {} - out_e, out_f, extra = result - return out_e, out_f, extra - -class ModelManager: - def __init__(self, conf: Dict[str, Any], dist_info: DistInfo, gpu_ids, local_rank: Optional[int]): - self.conf = conf - self.dist = dist_info - self.gpu_ids = gpu_ids - self.local_rank = local_rank - - self.device = self._select_device() - self.model = create_model(args=conf).to(self.device) - self.wrapped = BatchWrapper(self.model) - self.parallel, self.device_output = self._wrap_parallel() - - def _select_device(self) -> torch.device: - """Pick a torch device for the current process.""" - if self.local_rank is not None: - return torch.device(f"cuda:{self.local_rank}") - if self.gpu_ids != "cpu": - return torch.device("cuda:0") - return torch.device("cpu") - - def _wrap_parallel(self) -> Tuple[nn.Module, Any]: - """Wrap model in DDP or DataParallel or no-wrap.""" - if self.dist.enabled: - parallel = DDP( - self.wrapped, - device_ids=[self.local_rank], - output_device=self.local_rank, - find_unused_parameters=True, - ) - if self.dist.is_main: - print(f"DDP: Training on {self.dist.world_size} GPUs across nodes (local_rank={self.local_rank})") - return parallel, self.device - - if self.gpu_ids == "cpu": - if self.dist.is_main: - print("Training on CPU") - return self.wrapped, "cpu" - - parallel = nn.DataParallel(self.wrapped, device_ids=self.gpu_ids) - if self.dist.is_main: - print(f"DataParallel: Training on {len(parallel.device_ids)} GPU(s)") - return parallel, parallel.output_device - - def train(self) -> None: - self.model.train() - - def eval(self) -> None: - self.parallel.eval() - - def state_target_for_loading(self) -> nn.Module: - """Return the object to pass to model_util.load_state_dict_with_rename.""" - if isinstance(self.parallel, DDP): - return self.parallel.module.model - if hasattr(self.parallel, "module"): - return self.parallel.module.model - return self.model - - def state_dict_for_saving(self) -> Dict[str, Tensor]: - """Extract the underlying model state dict in a wrapper-agnostic way.""" - if isinstance(self.parallel, DDP): - return self.parallel.module.model.state_dict() - if hasattr(self.parallel, "module"): - return self.parallel.module.model.state_dict() - return self.model.state_dict() - -# ----------------------------- Data module --------------------------------- + def mean_loss(self) -> float: + return self.total_loss / self.num_cal if self.num_cal > 0 else 0.0 -class RoundRobinDataWrapper: - def __init__(self, *iterables: Iterable): - self.iterables = iterables + def mean_energy_loss(self) -> float: + return self.energy_loss / self.num_cal if self.num_cal > 0 else 0.0 - def __len__(self) -> int: - return sum(map(len, self.iterables)) + def mean_force_loss(self) -> float: + return self.force_loss / self.num_cal if self.num_cal > 0 else 0.0 - def __iter__(self): - iters = map(iter, self.iterables) - for num_active in range(len(self.iterables), 0, -1): - iters = itertools.cycle(itertools.islice(iters, num_active)) - yield from map(next, iters) + def mean_term_loss(self, k: str) -> float: + n = self.term_num_cal.get(k, 0) + return self.term_losses.get(k, 0.0) / n if n > 0 else 0.0 -@dataclass -class DataBundle: - datasets: List[Any] - train: RoundRobinDataWrapper - val: RoundRobinDataWrapper - train_samplers: List[Optional[DistributedSampler]] - pdb_list: List[str] - -class DataModule: - def __init__( - self, - directory_path: str, - subsetpdbs: str, - val_ratio: float, - batch_size: int, - atoms_per_call: Optional[int], - enable_shuffle: bool, - dataset_chunk_size: Optional[int], - use_npfile: bool, - embedding_filename: Optional[str], - energy_matching: bool, - dist_info: DistInfo, - ): - self.dir = directory_path - self.subsetpdbs = subsetpdbs - self.val_ratio = val_ratio - self.batch_size = batch_size - self.atoms_per_call = atoms_per_call - self.enable_shuffle = enable_shuffle - self.chunk = dataset_chunk_size - self.use_npfile = use_npfile - self.embedding_filename = embedding_filename or "embeddings.npy" - self.energy_filename = "tica_delta_energies.npy" if energy_matching else None - self.dist = dist_info - - def build(self) -> DataBundle: - """Create datasets and round-robin dataloaders (optionally chunked).""" - pdb_list = self._load_pdb_list() - pdb_lists = [pdb_list[i:i + self.chunk] for i in range(0, len(pdb_list), self.chunk)] if self.chunk else [pdb_list] - - datasets, train_loaders, val_loaders, samplers = [], [], [], [] - for chunk_list in pdb_lists: - ds, tr, va, sampler = self._make_loaders(chunk_list) - datasets.append(ds) - train_loaders.append(tr) - val_loaders.append(va) - samplers.append(sampler) - - return DataBundle( - datasets=datasets, - train=RoundRobinDataWrapper(*train_loaders), - val=RoundRobinDataWrapper(*val_loaders), - train_samplers=samplers, - pdb_list=pdb_list, - ) - def _load_pdb_list(self) -> List[str]: - """Read, dedupe, sort, and deterministically shuffle PDB IDs.""" - with open(os.path.join(self.dir, "result", self.subsetpdbs), "r") as f: - pdb_list = [x for x in f.read().split("\n") if x] - pdb_list = sorted(set(pdb_list)) - return deterministic_shuffle(pdb_list, seed=47563537) - - def _make_loaders(self, pdb_list: List[str]): - """Build torch DataLoaders and (optionally) DistributedSamplers.""" - print("Dataset:", " ".join(pdb_list)) - all_data = dataset.ProteinDataset( - self.dir, - pdb_list, - energy_file=self.energy_filename, - embeddings_file=self.embedding_filename, - use_npfile=self.use_npfile, - ) +# =========================================================================== +# Pure utilities +# =========================================================================== - assert 0.0 < self.val_ratio < 1.0 - val_size = int(self.val_ratio * len(all_data)) - train_size = len(all_data) - val_size +def flatten_first(t): + """Flatten the first two dimensions of a tensor.""" + if t is None: + return t + if len(t.shape) < 2: + return t + return t.reshape(t.shape[0] * t.shape[1], *t.shape[2:]) - if self.enable_shuffle: - g = torch.Generator().manual_seed(12341234) - val_idx, train_idx = torch.utils.data.random_split( - torch.arange(len(all_data)), - [val_size, train_size], - generator=g, - ) - else: - train_idx = range(train_size) - val_idx = range(train_size, train_size + val_size) - - train = torch.utils.data.Subset(all_data, train_idx) - val = torch.utils.data.Subset(all_data, val_idx) - - collate_fn = dataset.ProteinBatchCollate(self.atoms_per_call) - - train_sampler = None - val_sampler = None - if self.dist.enabled: - train_sampler = DistributedSampler(train, num_replicas=self.dist.world_size, rank=self.dist.rank, shuffle=False) - val_sampler = DistributedSampler(val, num_replicas=self.dist.world_size, rank=self.dist.rank, shuffle=False) - - train_loader = DataLoader( - train, - batch_size=self.batch_size, - shuffle=False if train_sampler is None else False, - num_workers=4, - persistent_workers=True, - pin_memory=True, - collate_fn=collate_fn, - sampler=train_sampler, - ) - val_loader = DataLoader( - val, - batch_size=self.batch_size, - shuffle=False, - num_workers=4, - persistent_workers=True, - pin_memory=True, - collate_fn=collate_fn, - sampler=val_sampler, - ) - return all_data, train_loader, val_loader, train_sampler - -# ----------------------------- Checkpointing ------------------------------- - -class Checkpointer: - def __init__(self, result_dir: str, dist_info: DistInfo): - self.result_dir = result_dir - self.dist = dist_info - - def find_resume_checkpoint(self) -> Optional[str]: - """Pick mini checkpoint first, else normal checkpoint.""" - mini = os.path.join(self.result_dir, "checkpoint-mini.pth") - main = os.path.join(self.result_dir, "checkpoint.pth") - if os.path.exists(mini): - return mini - if os.path.exists(main): - return main - return None - - def save( - self, - path: str, - epoch: int, - model_state: Dict[str, Tensor], - optimizer: optim.Optimizer, - model_conf: Dict[str, Any], - scheduler: Optional[Any], - extra: Optional[dict] = None, - ) -> None: - """Save checkpoint only on main process.""" - if not self.dist.is_main: - return - ckpt = { - "epoch": epoch, - "optimizer": optimizer.state_dict(), - "state_dict": model_state, - "hyper_parameters": model_conf, - } - if scheduler: - ckpt["scheduler"] = scheduler.state_dict() - if extra: - ckpt["extra"] = extra - torch.save(ckpt, path) - - def write_training_info( - self, - epoch: int, - directory_path: str, - pdb_list: List[str], - params: Dict[str, Any], - ) -> None: - """Write/update training_info.json and copy priors on main process.""" - if not self.dist.is_main: - return - - training_info_path = os.path.join(self.result_dir, "training_info.json") - info = {} - - if os.path.exists(training_info_path): - with open(training_info_path, "r") as f: - info = json.load(f) - if "input_directory" in info: - info = {"0": info} - else: - print("Path", training_info_path, "does not exist") - - info[str(epoch)] = { - "weight_decay": params["weight_decay"], - "learning_rate": params["learning_rate"], - "epochs": params["epochs"], - "batch_size": params["batch_size"], - "input_directory": directory_path, - "pdbs": pdb_list, - "energy_weight": params["energy_weight"], - "force_weight": params["force_weight"], - "embedding_filename": params["embedding_filename"], - "world_size": params.get("world_size", 1), - } - if params.get("lr_scheduler_repr") is not None: - info[str(epoch)]["lr_scheduler"] = params["lr_scheduler_repr"] - - if not params["dry_run"]: - with open(training_info_path, "w") as f: - json.dump(info, f, indent=2) - prior_path = os.path.join(directory_path, "priors.yaml") - if os.path.exists(prior_path): - prior_params_path = os.path.join(directory_path, "prior_params.json") - shutil.copy(prior_path, self.result_dir) - shutil.copy(prior_params_path, self.result_dir) +def make_term_offsets(lengths, term_lengths): + result = [] + count = 0 + repeats = len(term_lengths) // len(lengths) + lengths = np.tile(lengths, repeats) + assert len(lengths) == len(term_lengths) + for off, nterms in zip(lengths, term_lengths): + result.append(torch.full((nterms, 1), count, dtype=torch.long)) + count += off + return torch.cat(result) -# ----------------------------- Training bookkeeping ------------------------ -@dataclass -class History: - train: List[float] - val: List[float] - energy: List[float] - force: List[float] - - @classmethod - def empty(cls) -> "History": - return cls(train=[], val=[], energy=[], force=[]) - - @classmethod - def load(cls, result_dir: str) -> "History": - path = os.path.join(result_dir, "history.npy") - if not os.path.exists(path): - return cls.empty() - data = np.load(path, allow_pickle=True).item() - return cls(train=data["train"], val=data["val"], energy=data["energy"], force=data["force"]) - - def save(self, result_dir: str) -> None: - np.save(os.path.join(result_dir, "history.npy"), {"train": self.train, "val": self.val, "energy": self.energy, "force": self.force}) - -class EpochHistory: - def __init__(self, result_dir: str): - self.result_dir = result_dir - self.path = os.path.join(result_dir, "epoch_history.json") - self.data: Dict[str, Any] = {} - if os.path.exists(self.path): - with open(self.path, "r") as f: - self.data = json.load(f) - - def update(self, key: str, value: Dict[str, Any]) -> None: - self.data[key] = value - with open(self.path, "w") as f: - json.dump(self.data, f, indent=2) - -# ----------------------------- Optim & scheduler --------------------------- +def deterministic_shuffle(target, seed): + generator = torch.Generator().manual_seed(seed) + indices = torch.randperm(n=len(target), generator=generator, device="cpu") + return [target[i] for i in indices] + + +def check_early_stopping(val_list, patience=1) -> bool: + """Return True if val loss increased patience+1 consecutive epochs. patience<0 disables.""" + if patience < 0: + return False + if len(val_list) < patience + 2: + return False + check_range = np.array(val_list[-(patience + 2):]) + if np.all((check_range[1:] - check_range[:-1]) > 0): + print(f"Validation loss increased {patience + 1} times, stopping...") + return True + return False + def should_decay(param_name: str) -> bool: - """Decide if a parameter should receive weight decay.""" - parts = param_name.split(".") - assert parts + parts = param_name.split('.') + assert len(parts) > 0 if parts[-1] == "bias": return False - if len(parts) >= 2 and parts[-2] == "embedding": + if parts[-2] == "embedding": return False - if len(parts) >= 2 and parts[-2] == "distance_expansion": + if parts[-2] == "distance_expansion": return False assert parts[-1] == "weight" return True -class OptimFactory: - @staticmethod - def adamw(model: nn.Module, lr: float, weight_decay: float) -> optim.Optimizer: - """Create AdamW with split decay groups.""" - do_decay, dont_decay = [], [] - for name, p in model.named_parameters(): - (do_decay if should_decay(name) else dont_decay).append(p) - return optim.AdamW( - [{"params": do_decay, "weight_decay": weight_decay}, {"params": dont_decay}], - lr=lr, - ) -class SchedulerFactory: - @staticmethod - def from_args(args) -> Optional[Any]: - """Build one of the scheduler wrappers from argparse args.""" - lr_scheduler = None - if args.cos_anneal: - T_0, T_mult = [int(i) for i in args.cos_anneal.split(",")] - lr_scheduler = SchedulerWrapper_CosineAnnealingWarmRestarts(T_0, T_mult) - if args.cos_lr: - assert lr_scheduler is None - T_max, eta_min = args.cos_lr.split(",") - lr_scheduler = SchedulerWrapper_CosineAnnealingLR(int(T_max), float(eta_min)) - if args.exp_lr: - assert lr_scheduler is None - lr_scheduler = SchedulerWrapper_ExponentialLR(float(args.exp_lr)) - if args.plateau_lr: - assert lr_scheduler is None - factor, patience, threshold, min_lr = args.plateau_lr.split(",") - lr_scheduler = SchedulerWrapper_ReduceLROnPlateau(float(factor), int(patience), float(threshold), float(min_lr)) - return lr_scheduler - -# ----------------------------- Core trainer -------------------------------- +# =========================================================================== +# Model classes +# =========================================================================== -@dataclass -class TrainArgs: - directory_path: str - result_directory: Optional[str] - conf_path: str - gpu_ids: Any - weight_decay: float - learning_rate: float - epochs: int - batch_size: int - val_ratio: float - atoms_per_call: Optional[int] - scheduler: Optional[Any] - dry_run: bool - reset_early_stopping: bool - enable_shuffle: bool - mini_epoch_size: Optional[int] - early_stopping: int - checkpoint_save: int - subsetpdbs: str - energy_weight: float - force_weight: float - energy_matching: bool - train_term_def: TermDef - embedding_filename: Optional[str] - dataset_chunk_size: Optional[int] - use_npfile: bool - use_force_weights: bool +class BatchWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model -class Trainer: - def __init__(self, args: TrainArgs, dist_mgr: DistributedManager): - self.args = args - self.dist_mgr = dist_mgr - self.dist = dist_mgr.info - - self.result_dir = self._ensure_result_dir(args.result_directory, args.dry_run) - self.checkpointer = Checkpointer(self.result_dir, self.dist) - - self.conf = self._load_conf(args.conf_path) - self._apply_conf_overrides() - - self.data_bundle = self._build_data() - self.model_mgr = ModelManager(self.conf, self.dist, args.gpu_ids, self.dist.local_rank) - - self.optimizer = OptimFactory.adamw(self.model_mgr.model, args.learning_rate, args.weight_decay) - if args.scheduler: - args.scheduler.initialize(self.optimizer) - self.scheduler = args.scheduler - - self.history = History.load(self.result_dir) - self.epoch_history = EpochHistory(self.result_dir) - - self.epoch = 0 - self.epoch_resume_extra: Optional[dict] = None - - self._maybe_resume() - self.dist_mgr.barrier() - self._write_training_info() - - def _load_conf(self, path: str) -> Dict[str, Any]: - """Load YAML config.""" - with open(path, "r") as f: - conf = yaml.safe_load(f) - if self.dist.is_main: - print("Config:\n", conf, "\n") - return conf - - def _apply_conf_overrides(self) -> None: - """Set conf flags needed for training terms.""" - if "harmonic_net" in self.conf and self.args.train_term_def.names(): - self.conf["harmonic_net_return_terms"] = True - - if self.conf.get("external_embedding_channels") is None and (self.args.embedding_filename and self.args.embedding_filename != "embeddings.npy"): - if self.dist.is_main: - print("WARNING: external embeddings usually should use graph-network-ext network") - - def _build_data(self) -> DataBundle: - """Construct datasets/loaders and add optional dataset features.""" - dm = DataModule( - directory_path=self.args.directory_path, - subsetpdbs=self.args.subsetpdbs, - val_ratio=self.args.val_ratio, - batch_size=self.args.batch_size, - atoms_per_call=self.args.atoms_per_call, - enable_shuffle=self.args.enable_shuffle, - dataset_chunk_size=self.args.dataset_chunk_size, - use_npfile=self.args.use_npfile, - embedding_filename=self.args.embedding_filename, - energy_matching=self.args.energy_matching, - dist_info=self.dist, - ) - bundle = dm.build() - - if "sequence_basis_radius" in self.conf: - if self.dist.is_main: - print(f"Adding sequences to dataset... (sequence_basis_radius={self.conf['sequence_basis_radius']})") - for d in bundle.datasets: - d.build_sequences() - - extra_terms = [] - if "harmonic_net" in self.conf: - if self.dist.is_main: - print(f"Adding classical terms to dataset... (harmonic_net={self.conf['harmonic_net']})") - for d in bundle.datasets: - d.build_classical_terms() - - if self.args.train_term_def.names(): - extra_terms = self.args.train_term_def.names() - if self.dist.is_main: - print(f"Loading additional trained terms: {extra_terms}") - print(f" Term Scales: {[self.args.train_term_def.scale(i) for i in extra_terms]}") - print(f" Term Angle Wrap: {[self.args.train_term_def.wrap(i) for i in extra_terms]}") - for d in bundle.datasets: - d.load_frame_terms(extra_terms) - - if self.args.use_force_weights: - if self.dist.is_main: - print("Loading forces weights...") - for d in bundle.datasets: - d.load_frame_terms(["forces_weights"]) - - self.extra_train_terms = extra_terms - if self.dist.is_main: - print() - return bundle - - def _ensure_result_dir(self, result_dir: Optional[str], dry_run: bool) -> str: - """Create or validate result directory (main process only).""" - if result_dir and os.path.exists(os.path.join(result_dir, "checkpoint.pth")): - return result_dir - if result_dir and os.path.exists(os.path.join(result_dir, "checkpoint-mini.pth")): - return result_dir - - if result_dir is None: - result_dir = "../data/result-" + datetime.datetime.now().strftime("%Y.%m.%d-%H.%M.%S") - - if os.path.exists(result_dir): - info_path = os.path.join(result_dir, "training_info.json") - if os.path.exists(info_path): - if self.dist.is_main: - print("Re-initializing:", result_dir) - return result_dir - raise RuntimeError("Model directory exists but doesn't contain a checkpoint.pth or training_info.json file") - - if self.dist.is_main and not dry_run: - os.makedirs(result_dir, exist_ok=False) - print("Created:", result_dir) - return result_dir - - def _maybe_resume(self) -> None: - """Resume from checkpoint if present.""" - ckpt_path = self.checkpointer.find_resume_checkpoint() - if self.dist.is_main: - print("checkpoint_path", ckpt_path) - - if not ckpt_path: - self.epoch = 0 - if self.dist.is_main: - print("Saving to:", self.result_dir) - return - - if self.dist.is_main: - print("Resuming:", self.result_dir) - - device = self.model_mgr.device - ckpt = torch.load(ckpt_path, weights_only=False, map_location=device) - - model_target = self.model_mgr.state_target_for_loading() - model_util.load_state_dict_with_rename(model_target, ckpt["state_dict"]) - - if ckpt.get("optimizer") is not None: - self.optimizer.load_state_dict(ckpt["optimizer"]) - else: - if self.dist.is_main: - print(" No optimizer in checkpoint, resetting...") + def forward(self, pos, lengths, **kwargs) -> Tuple[Tensor, Tensor]: + batch_nums = dataset.make_batch_nums(len(pos), lengths) + batch_nums = batch_nums.to(pos.device) + assert batch_nums.device == pos.device + kwargs["pos"] = pos + for k, v in kwargs.items(): + kwargs[k] = flatten_first(v) + if "bonds" in kwargs: + kwargs["bonds"] = kwargs["bonds"] + make_term_offsets(lengths, kwargs.pop("len_bonds").cpu()).to(pos.device) + if "angles" in kwargs: + kwargs["angles"] = kwargs["angles"] + make_term_offsets(lengths, kwargs.pop("len_angles").cpu()).to(pos.device) + if "dihedrals" in kwargs: + kwargs["dihedrals"] = kwargs["dihedrals"] + make_term_offsets(lengths, kwargs.pop("len_dihedrals").cpu()).to(pos.device) + kwargs["batch"] = batch_nums + result = self.model(**kwargs) + if len(result) == 2: + result = [*result, {}] + return result # pyright: ignore[reportReturnType] + + +class TermDef: + def __init__(self, path=None, conf=None): + self.scales: Dict[str, float] = {} + self.angle_wrap: Dict[str, bool] = {} + if path: + with open(path, 'r') as f: + conf = yaml.safe_load(f) + if conf: + for k, v in conf.items(): + self.scales[k] = float(v["scale"]) if v is not None and "scale" in v else 1.0 + self.angle_wrap[k] = bool(v["angle_wrap"]) if v is not None and "angle_wrap" in v else False - if self.scheduler and ckpt.get("scheduler") is not None: - self.scheduler.load_state_dict(ckpt["scheduler"]) + def get_names(self) -> List[str]: + return list(self.scales.keys()) - self.epoch_resume_extra = ckpt.get("extra") - self.epoch = int(ckpt.get("epoch", 0)) + def get_scale(self, name: str) -> float: + return self.scales[name] - if self.epoch > 0: - self.history = History.load(self.result_dir) - self.epoch_history = EpochHistory(self.result_dir) + def get_angle_wrap(self, name: str) -> bool: + return self.angle_wrap[name] - if self.dist.is_main: - print("Saving to:", self.result_dir) - def _write_training_info(self) -> None: - """Write training info and priors (main process only).""" - if self.scheduler: - lr_sched_repr = repr(self.scheduler) - else: - lr_sched_repr = None - for g in self.optimizer.param_groups: - g["lr"] = self.args.learning_rate - - self.checkpointer.write_training_info( - epoch=self.epoch, - directory_path=self.args.directory_path, - pdb_list=self.data_bundle.pdb_list, - params={ - "weight_decay": self.args.weight_decay, - "learning_rate": self.args.learning_rate, - "epochs": self.args.epochs, - "batch_size": self.args.batch_size, - "energy_weight": self.args.energy_weight, - "force_weight": self.args.force_weight, - "embedding_filename": self.args.embedding_filename or "embeddings.npy", - "world_size": self.dist.world_size if self.dist.world_size else 1, - "lr_scheduler_repr": lr_sched_repr, - "dry_run": self.args.dry_run, - }, - ) +class RoundRobinDataWrapper: + def __init__(self, *iterables): + self.iterables = iterables + + def __len__(self): + return sum(map(len, self.iterables)) - def run(self) -> None: - """Main training loop.""" - if self.scheduler and self.scheduler.is_annealing(): - self.args.early_stopping = -1 + def __iter__(self): + iterators = map(iter, self.iterables) + for num_active in range(len(self.iterables), 0, -1): + iterators = itertools.cycle(itertools.islice(iterators, num_active)) + yield from map(next, iterators) + + +# =========================================================================== +# Data loading +# =========================================================================== + +def gen_dataloaders(directory_path, pdb_list, energy_filename, embedding_filename, + use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call, + ddp: DistributedContext): + """Build train/val DataLoaders for one PDB chunk, with optional DistributedSampler.""" + print("Dataset:", " ".join(pdb_list)) + all_data = dataset.ProteinDataset( + directory_path, pdb_list, + energy_file=energy_filename, embeddings_file=embedding_filename, use_npfile=use_npfile, + ) + assert 0.0 < val_ratio < 1.0 + val_size = int(val_ratio * len(all_data)) + train_size = len(all_data) - val_size + + if enable_shuffle: + gen = torch.Generator().manual_seed(12341234) + val_idx, train_idx = torch.utils.data.random_split( + torch.arange(len(all_data)), [val_size, train_size], generator=gen + ) # pyright: ignore[reportArgumentType] + else: + train_idx = range(train_size) + val_idx = range(train_size, train_size + val_size) + + train_set = torch.utils.data.Subset(all_data, train_idx) # pyright: ignore[reportArgumentType] + val_set = torch.utils.data.Subset(all_data, val_idx) # pyright: ignore[reportArgumentType] + collate_fn = dataset.ProteinBatchCollate(atoms_per_call) + + train_sampler = ( + DistributedSampler(train_set, num_replicas=ddp.world_size, rank=ddp.rank, shuffle=False) + if ddp.is_distributed else None + ) + val_sampler = ( + DistributedSampler(val_set, num_replicas=ddp.world_size, rank=ddp.rank, shuffle=False) + if ddp.is_distributed else None + ) - first_es_epoch = self.epoch if self.args.reset_early_stopping else 0 - verbose_loss_report = sys.stdout.isatty() + train_loader = DataLoader( + train_set, batch_size=batch_size, shuffle=False, num_workers=4, + persistent_workers=True, pin_memory=True, collate_fn=collate_fn, sampler=train_sampler, + ) + val_loader = DataLoader( + val_set, batch_size=batch_size, shuffle=False, num_workers=4, + persistent_workers=True, pin_memory=True, collate_fn=collate_fn, sampler=val_sampler, + ) + return all_data, train_loader, val_loader, train_sampler + + +def _load_pdb_list(directory_path, subsetpdbs, dataset_chunk_size): + """Read, deduplicate, shuffle, and optionally chunk the PDB ID list.""" + with open(os.path.join(directory_path, "result", subsetpdbs), 'r') as f: + pdb_list = f.read().split('\n') + pdb_list = sorted(set(p for p in pdb_list if p)) + pdb_list = deterministic_shuffle(pdb_list, seed=47563537) + if dataset_chunk_size is not None: + pdb_lists = [pdb_list[i:i + dataset_chunk_size] for i in range(0, len(pdb_list), dataset_chunk_size)] + else: + pdb_lists = [pdb_list] + return pdb_lists, pdb_list + + +def _build_all_dataloaders(pdb_lists, directory_path, energy_filename, embedding_filename, + use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call, + ddp: DistributedContext): + """Build datasets and merged RoundRobin loaders for all PDB chunks.""" + datasets, train_loaders, val_loaders, train_samplers = [], [], [], [] + for pdb_chunk in pdb_lists: + ds, train_loader, val_loader, sampler = gen_dataloaders( + directory_path, pdb_chunk, energy_filename, embedding_filename, + use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call, ddp, + ) + datasets.append(ds) + train_loaders.append(train_loader) + val_loaders.append(val_loader) + train_samplers.append(sampler) + train_data = RoundRobinDataWrapper(*train_loaders) + val_data = RoundRobinDataWrapper(*val_loaders) + return datasets, train_data, val_data, train_samplers + + +# =========================================================================== +# Model setup +# =========================================================================== + +def _load_model_conf(conf_path, ddp: DistributedContext) -> dict: + """Load and print the model architecture YAML config.""" + if conf_path is None: + conf_path = "../configs/config.yaml" + with open(conf_path, 'r') as f: + conf = yaml.safe_load(f) + if ddp.is_main: + print("Config:\n", conf, "\n") + return conf + + +def _augment_datasets(conf, datasets, train_term_def: TermDef, use_force_weights: bool, + embedding_filename: str, ddp: DistributedContext) -> List[str]: + """Add optional features (sequences, classical terms, force weights) to every dataset.""" + if conf.get("external_embedding_channels") is None and embedding_filename != "embeddings.npy": + if ddp.is_main: + print("WARNING: external embeddings usually should use graph-network-ext network") + + extra_train_terms: List[str] = [] + + if "sequence_basis_radius" in conf: + if ddp.is_main: + print(f"Adding sequences to dataset... (sequence_basis_radius={conf['sequence_basis_radius']})") + for d in datasets: + d.build_sequences() + + if "harmonic_net" in conf: + if ddp.is_main: + print(f"Adding classical terms to dataset... (harmonic_net={conf['harmonic_net']})") + for d in datasets: + d.build_classical_terms() + if train_term_def.get_names(): + harmonic_trained_terms = train_term_def.get_names() + if ddp.is_main: + print(f"Loading additional trained terms: {harmonic_trained_terms}") + print(f" Term Scales: {[train_term_def.get_scale(t) for t in harmonic_trained_terms]}") + print(f" Term Angle Wrap: {[train_term_def.get_angle_wrap(t) for t in harmonic_trained_terms]}") + for d in datasets: + d.load_frame_terms(harmonic_trained_terms) + extra_train_terms.extend(harmonic_trained_terms) + + if use_force_weights: + if ddp.is_main: + print("Loading force weights...") + for d in datasets: + d.load_frame_terms(["forces_weights"]) + + return extra_train_terms + + +def build_parallel_model(model, gpu_ids, ddp: DistributedContext): + """ + Wrap model for distributed/parallel execution. + Returns (parallel_model, device, device_output). + device — where weights live; use for checkpoints and all_reduce + device_output — where model outputs land; use for moving targets + """ + wrapped = BatchWrapper(model) + if ddp.is_distributed: + device = torch.device(f'cuda:{ddp.local_rank}') + model.to(device) + parallel = DDP(wrapped, device_ids=[ddp.local_rank], output_device=ddp.local_rank, + find_unused_parameters=True) + device_output = device + if ddp.is_main: + print(f"DDP: {ddp.world_size} GPUs (local_rank={ddp.local_rank})") + elif gpu_ids == "cpu": + device = torch.device("cpu") + model.to(device) + parallel = wrapped + device_output = device + print("Training on CPU") + else: + device = torch.device(f'cuda:{gpu_ids[0]}') + model.to(device) + parallel = nn.DataParallel(wrapped, device_ids=gpu_ids) + device_output = parallel.output_device + print(f"DataParallel: {len(parallel.device_ids)} GPU(s)") + return parallel, device, device_output + + +# =========================================================================== +# Optimizer +# =========================================================================== + +def _build_optimizer(model, weight_decay: float, learning_rate: float): + """AdamW with weight decay applied only to non-embedding, non-bias parameters.""" + do_decay, dont_decay = [], [] + for name, param in model.named_parameters(): + (do_decay if should_decay(name) else dont_decay).append(param) + return optim.AdamW( + [{"params": do_decay, "weight_decay": weight_decay}, {"params": dont_decay}], + lr=learning_rate, + ) - criterion = nn.MSELoss() - term_criterion = nn.MSELoss(reduction="none") - while self.epoch < self.args.epochs: - self._set_sampler_epoch(self.epoch) +# =========================================================================== +# Checkpoint I/O +# =========================================================================== + +def save_checkpoint(checkpoint_path, epoch, model, optimizer, model_conf, scheduler, + extra=None, ddp: Optional[DistributedContext] = None): + """Save checkpoint. In distributed mode only rank 0 writes.""" + if ddp is not None and not ddp.is_main: + return + d = { + "epoch": epoch, + "optimizer": optimizer.state_dict(), + "state_dict": model.state_dict(), + "hyper_parameters": model_conf, + } + if scheduler: + d["scheduler"] = scheduler.state_dict() + if extra: + d["extra"] = extra + torch.save(d, checkpoint_path) + + +def _load_checkpoint_state(result_directory, model, optimizer, scheduler, device, + ddp: DistributedContext): + """Find and load the best available checkpoint. Returns (epoch, epoch_resume).""" + checkpoint_path = None + if result_directory and os.path.exists(f'{result_directory}/checkpoint-mini.pth'): + checkpoint_path = f'{result_directory}/checkpoint-mini.pth' + elif result_directory and os.path.exists(f'{result_directory}/checkpoint.pth'): + checkpoint_path = f'{result_directory}/checkpoint.pth' + + if ddp.is_main: + print("checkpoint_path", checkpoint_path) + + if not checkpoint_path: + return 0, None + + if ddp.is_main: + print("Resuming:", result_directory) + + checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device) + model_util.load_state_dict_with_rename(model, checkpoint["state_dict"]) + + if "optimizer" in checkpoint and checkpoint["optimizer"] is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + elif ddp.is_main: + print(" No optimizer in checkpoint, resetting...") + + if scheduler and "scheduler" in checkpoint and checkpoint["scheduler"] is not None: + scheduler.load_state_dict(checkpoint["scheduler"]) + + return checkpoint.get("epoch", 0), checkpoint.get("extra") + + +def _commit_epoch_checkpoint(epoch, model, optimizer, conf, scheduler, result_directory, + val_loss_list, history, checkpoint_save, ddp: DistributedContext): + """Atomically promote the epoch checkpoint and optionally save a best/backup copy.""" + tmp_path = f'{result_directory}/checkpoint-{epoch}.pth' + save_checkpoint(tmp_path, epoch + 1, model, optimizer, conf, scheduler, ddp=ddp) + + if not ddp.is_main: + return + + if os.path.exists(f'{result_directory}/checkpoint-mini.pth'): + os.unlink(f'{result_directory}/checkpoint-mini.pth') + + if checkpoint_save and (epoch % checkpoint_save == 0): + shutil.copyfile(tmp_path, f'{result_directory}/checkpoint.pth') + else: + os.replace(tmp_path, f'{result_directory}/checkpoint.pth') + + if val_loss_list[-1] <= np.min(val_loss_list): + shutil.copyfile(f'{result_directory}/checkpoint.pth', f'{result_directory}/checkpoint-best.pth') + + np.save(f'{result_directory}/history.npy', history) # pyright: ignore[reportArgumentType] + print(" Checkpoint saved.") + + +def _save_mini_epoch_checkpoint(epoch, i, model, optimizer, conf, scheduler, + total_loss, num_cal, mini_train_loss, mini_num_cal, + train_data, result_directory, epoch_history, + ddp: DistributedContext): + """Save a mid-epoch checkpoint and record the mini-epoch in epoch_history.""" + tmp_path = f'{result_directory}/checkpoint-{epoch}-{i}.pth' + save_checkpoint(tmp_path, epoch, model, optimizer, conf, scheduler, + extra={"train_loss": total_loss, "num_cal": num_cal, "i": i}, ddp=ddp) + if not ddp.is_main: + return + os.replace(tmp_path, f'{result_directory}/checkpoint-mini.pth') + epoch_history[f"{epoch}-{i}"] = { + "train_loss": total_loss / num_cal, + "mini_train_loss": mini_train_loss / mini_num_cal, + "epoch_len": len(train_data), + "lr": [g['lr'] for g in optimizer.param_groups], + } + with open(os.path.join(result_directory, "epoch_history.json"), "w") as f: + json.dump(epoch_history, f, indent=2) + + +# =========================================================================== +# Result directory and history +# =========================================================================== + +def _setup_result_directory(result_directory, epoch, dry_run, ddp: DistributedContext) -> str: + """Create the result directory for a fresh run, or validate it for resumption.""" + if epoch > 0: + pass # Resuming — directory already exists + elif not result_directory or not os.path.exists(result_directory): + if not result_directory: + result_directory = "../data/result-" + datetime.datetime.now().strftime("%Y.%m.%d-%H.%M.%S") + if not dry_run and ddp.is_main: + os.makedirs(result_directory, exist_ok=False) + if ddp.is_main: + print("Created:", result_directory) + elif os.path.exists(f'{result_directory}/training_info.json'): + if ddp.is_main: + print("Re-initializing:", result_directory) + else: + raise RuntimeError("Model directory exists but doesn't contain a checkpoint.pth or training_info.json file") + + if ddp.is_main: + print("Saving to:", result_directory) + return result_directory + + +def _load_history(result_directory, epoch): + """Load loss history arrays and epoch_history JSON from a previous run.""" + train_loss_list, val_loss_list, energy_loss_list, force_loss_list = [], [], [], [] + epoch_history: dict = {} + + if epoch > 0: + history = np.load(f'{result_directory}/history.npy', allow_pickle=True).item() + train_loss_list = history['train'] + val_loss_list = history['val'] + energy_loss_list = history['energy'] + force_loss_list = history['force'] + + epoch_history_path = os.path.join(result_directory, "epoch_history.json") + if os.path.exists(epoch_history_path): + with open(epoch_history_path, "r") as f: + epoch_history = json.load(f) + + return train_loss_list, val_loss_list, energy_loss_list, force_loss_list, epoch_history + + +def _write_training_info(result_directory, epoch, weight_decay, learning_rate, epochs, batch_size, + directory_path, pdb_list, energy_weight, force_weight, embedding_filename, + ddp: DistributedContext, scheduler, optimizer, dry_run): + """Write training_info.json and copy prior files. Only runs on rank 0.""" + if not ddp.is_main: + return + + training_info_path = os.path.join(result_directory, "training_info.json") + training_info_dict: dict = {} + + if os.path.exists(training_info_path): + with open(training_info_path, "r") as f: + training_info_dict = json.load(f) + if "input_directory" in training_info_dict: + training_info_dict = {"0": training_info_dict} + else: + print("Path", training_info_path, "does not exist") + + training_info_dict[str(epoch)] = { + "weight_decay": weight_decay, + "learning_rate": learning_rate, + "epochs": epochs, + "batch_size": batch_size, + "input_directory": directory_path, + "pdbs": pdb_list, + "energy_weight": energy_weight, + "force_weight": force_weight, + "embedding_filename": embedding_filename, + "world_size": ddp.world_size if ddp.world_size else 1, + } + + if scheduler: + training_info_dict[str(epoch)]["lr_scheduler"] = repr(scheduler) + else: + for g in optimizer.param_groups: + g['lr'] = learning_rate + + if not dry_run: + with open(training_info_path, "w") as f: + json.dump(training_info_dict, f, indent=2) + prior_path = os.path.join(directory_path, "priors.yaml") + if os.path.exists(prior_path): + shutil.copy(prior_path, result_directory) + shutil.copy(os.path.join(directory_path, "prior_params.json"), result_directory) + + +# =========================================================================== +# Per-epoch helpers +# =========================================================================== + +def _set_sampler_epochs(train_samplers, epoch): + """Notify DistributedSamplers of the current epoch so shuffling stays deterministic.""" + for sampler in train_samplers: + if sampler is not None: + sampler.set_epoch(epoch) + + +def _update_loss_history(train_loss_list, val_loss_list, energy_loss_list, force_loss_list, + train_metrics: EpochMetrics, val_metrics: EpochMetrics): + """Append per-epoch means to the running history lists.""" + train_loss_list.append(train_metrics.mean_loss()) + energy_loss_list.append(train_metrics.mean_energy_loss()) + force_loss_list.append(train_metrics.mean_force_loss()) + val_loss_list.append(val_metrics.mean_loss()) + + +def _log_epoch_results(ddp: DistributedContext, epoch, train_loss_list, val_loss_list, + train_metrics: EpochMetrics, val_metrics: EpochMetrics, + extra_train_terms, epoch_history, optimizer, t0, train_data, + result_directory): + """Record epoch stats in epoch_history, write JSON, and print a summary.""" + if not ddp.is_main: + return + + epoch_history[f"{epoch}"] = { + "train_loss": train_loss_list[-1], + "val_loss": val_loss_list[-1], + "energy_loss": train_metrics.mean_energy_loss(), + "force_loss": train_metrics.mean_force_loss(), + "epoch_len": len(train_data), + "lr": [g['lr'] for g in optimizer.param_groups], + } + for k in extra_train_terms: + epoch_history[f"{epoch}"][f"train_loss_{k}"] = train_metrics.mean_term_loss(k) + epoch_history[f"{epoch}"][f"val_loss_{k}"] = val_metrics.mean_term_loss(k) + + with open(os.path.join(result_directory, "epoch_history.json"), "w") as f: + json.dump(epoch_history, f, indent=2) + + print(f"Epoch {epoch} - Train Loss: {train_loss_list[-1]} - Val Loss: {val_loss_list[-1]} - time: {round(time.time() - t0, 2)}s") + if epoch > 0: + print(f" ∆Train: {train_loss_list[-1] - train_loss_list[-2]} - ∆Val: {val_loss_list[-1] - val_loss_list[-2]}") + for k in extra_train_terms: + print(f" Train {k} loss={train_metrics.mean_term_loss(k):.4f}") + print(f" Val {k} loss={val_metrics.mean_term_loss(k):.4f}") + + +def _check_early_stop(ddp: DistributedContext, val_loss_list, first_early_stopping_epoch, + early_stopping, device) -> bool: + """Evaluate early stopping on rank 0 then broadcast the decision to all ranks.""" + should_stop = False + if ddp.is_main: + should_stop = check_early_stopping(val_loss_list[first_early_stopping_epoch:], patience=early_stopping) + should_stop = ddp.broadcast_bool(should_stop, device) + if should_stop and ddp.is_main: + print("Early stopping triggered.") + return should_stop + + +# =========================================================================== +# Forward pass (shared between train and val) +# =========================================================================== + +def _process_sub_batch(sub_batch, parallel_model, criterion, term_criterion, + device_output, loss_cfg: LossConfig, + total_batch_size: int, total_term_batch_size: Dict[str, int]): + """ + Forward one sub-batch through the model and compute weighted losses. + Returns (loss_tensor, loss_sum, energy_loss_sum, force_loss_sum, term_loss_sums). + *_sum values are pre-multiplied by total_batch_size for direct accumulation. + Call loss_tensor.backward() in the training loop. + """ + force = sub_batch.pop("force") + force = force.reshape(-1, force.shape[-1]).to(device_output) + + force_weights = None + if loss_cfg.use_force_weights: + force_weights = sub_batch.pop("forces_weights").reshape(-1).to(device_output) + + energy = None + if loss_cfg.energy_matching: + energy = sub_batch.pop("energy") + energy = energy.reshape(-1, energy.shape[-1]).to(device_output) + + term_targets = {k: sub_batch.pop(k).flatten().to(device_output) for k in loss_cfg.extra_term_keys} + + out_energy, out_force, extra = parallel_model(**sub_batch) + + sub_batch_size = force.numel() + weight = sub_batch_size / total_batch_size + + energy_loss = torch.tensor(0.0) + if loss_cfg.energy_matching: + energy_loss = criterion(out_energy, energy) * weight + + fl = term_criterion(out_force, force) + if force_weights is not None: + fl = fl * force_weights[:, None] + force_loss = fl.mean() * weight + + loss = loss_cfg.energy_weight * energy_loss + loss_cfg.force_weight * force_loss + + term_loss_sums: Dict[str, float] = {} + for k in loss_cfg.extra_term_keys: + td = loss_cfg.train_term_def + if td.get_angle_wrap(k): + tl = (extra[k] - term_targets[k] + torch.pi) % (2 * torch.pi) - torch.pi + tl = tl ** 2 + else: + tl = term_criterion(extra[k], term_targets[k]) + tl = tl / total_term_batch_size[k] + tl = tl * (term_targets[k] >= -10).float() + tl = torch.sum(tl) + loss = loss + tl * td.get_scale(k) + term_loss_sums[k] = tl.item() * total_term_batch_size[k] + + return loss, loss.item() * total_batch_size, energy_loss.item() * total_batch_size, force_loss.item() * total_batch_size, term_loss_sums + + +def _reduce_to_epoch_metrics(total_loss, energy_loss, force_loss, num_cal, + term_losses, term_num_cal, extra_term_keys, + ddp: DistributedContext, device) -> EpochMetrics: + """All-reduce raw accumulators across ranks and package into EpochMetrics.""" + return EpochMetrics( + total_loss=ddp.all_reduce_scalar(total_loss, device), + energy_loss=ddp.all_reduce_scalar(energy_loss, device), + force_loss=ddp.all_reduce_scalar(force_loss, device), + num_cal=int(ddp.all_reduce_scalar(num_cal, device)), + term_losses={k: ddp.all_reduce_scalar(term_losses[k], device) for k in extra_term_keys}, + term_num_cal={k: int(ddp.all_reduce_scalar(term_num_cal[k], device)) for k in extra_term_keys}, + ) - t0 = time.time() - self.model_mgr.train() - train_metrics = self._train_one_epoch( - criterion=criterion, - term_criterion=term_criterion, - verbose=verbose_loss_report, - ) +# =========================================================================== +# Epoch runners +# =========================================================================== + +def run_train_epoch( + epoch, epochs, train_data, parallel_model, model, + optimizer, scheduler, criterion, term_criterion, + device, device_output, loss_cfg: LossConfig, ddp: DistributedContext, + dry_run: bool, verbose_loss_report: bool, + mini_epoch_size, epoch_resume, result_directory, conf, epoch_history, +) -> EpochMetrics: + parallel_model.train() + + total_loss = 0.0 + energy_loss = 0.0 + force_loss = 0.0 + num_cal = 0 + mini_train_loss = 0.0 + mini_num_cal = 0 + epoch_offset = 0 + term_losses: Dict[str, float] = {k: 0.0 for k in loss_cfg.extra_term_keys} + term_num_cal: Dict[str, int] = {k: 0 for k in loss_cfg.extra_term_keys} + + if epoch_resume: + if ddp.is_main: + print("Resuming epoch...") + total_loss = float(epoch_resume["train_loss"]) + num_cal = int(epoch_resume["num_cal"]) + epoch_offset = int(epoch_resume["i"]) + + # Setting miniters is required to keep the bar from stalling when skipping ahead on resume + tqdm_iter = ( + tqdm(enumerate(train_data), desc=f"Training ({epoch}/{epochs})", + total=len(train_data), dynamic_ncols=True, miniters=1) + if ddp.is_main else enumerate(train_data) + ) - self._append_train_metrics(train_metrics) - self.model_mgr.eval() - - val_metrics = self._validate_one_epoch(criterion=criterion, term_criterion=term_criterion) - self.history.val.append(val_metrics["val_loss"]) - - if self.scheduler: - self.scheduler.step(val_metrics["val_loss"]) - - if self.dist.is_main: - self._log_epoch(train_metrics, val_metrics, t0) - - should_stop = False - if self.dist.is_main: - should_stop = check_early_stopping(self.history.val[first_es_epoch:], patience=self.args.early_stopping) - should_stop = self.dist_mgr.broadcast_bool(should_stop, src=0, device=self.model_mgr.device) - - self._save_epoch(train_metrics, val_metrics) - self.dist_mgr.barrier() - - if should_stop: - if self.dist.is_main: - print("Early stopping triggered.") - break - - self.epoch += 1 - - def _set_sampler_epoch(self, epoch: int) -> None: - """Set epoch on DistributedSamplers for determinism.""" - if not self.dist.enabled: - return - for s in self.data_bundle.train_samplers: - if s is not None: - s.set_epoch(epoch) - - def _train_one_epoch(self, criterion, term_criterion, verbose: bool) -> Dict[str, Any]: - """Train for one epoch, returning aggregated metrics.""" - args = self.args - model = self.model_mgr.parallel - device_out = self.model_mgr.device_output - - train_loss = 0.0 - train_energy_loss = 0.0 - train_force_loss = 0.0 - num_cal = 0.0 - - epoch_offset = 0 - mini_train_loss = 0.0 - mini_num_cal = 0.0 - - train_term_losses = {k: 0.0 for k in self.extra_train_terms} - train_term_num_cal = {k: 0 for k in self.extra_train_terms} - - if self.epoch_resume_extra: - if self.dist.is_main: - print("Resuming epoch...") - train_loss = float(self.epoch_resume_extra["train_loss"]) - num_cal = float(self.epoch_resume_extra["num_cal"]) - epoch_offset = int(self.epoch_resume_extra["i"]) - self.epoch_resume_extra = None - - if self.dist.is_main: - it = tqdm( - enumerate(self.data_bundle.train), - desc=f"Training ({self.epoch}/{args.epochs})", - total=len(self.data_bundle.train), - dynamic_ncols=True, - miniters=1, + for i, batch in tqdm_iter: + if i < epoch_offset: + continue + elif epoch_offset and i == epoch_offset: + if ddp.is_main and hasattr(tqdm_iter, 'write'): + tqdm_iter.write(f"Resumed epoch at batch {i}") + elif mini_epoch_size and i > 0 and i % mini_epoch_size == 0: + _save_mini_epoch_checkpoint( + epoch, i, model, optimizer, conf, scheduler, + total_loss, num_cal, mini_train_loss, mini_num_cal, + train_data, result_directory, epoch_history, ddp, ) - else: - it = enumerate(self.data_bundle.train) + if ddp.is_main and hasattr(tqdm_iter, 'write'): + tqdm_iter.write(f"Mini-epoch {epoch}-{i}: Train Loss: {total_loss / num_cal}") + mini_train_loss = 0.0 + mini_num_cal = 0 + + batch_size_total = sum(sb["force"].numel() for sb in batch) + batch_term_sizes = {k: sum(sb[k].numel() for sb in batch) for k in loss_cfg.extra_term_keys} + num_cal += batch_size_total + mini_num_cal += batch_size_total + for k in term_num_cal: + term_num_cal[k] += batch_term_sizes[k] + + optimizer.zero_grad() + for sub_batch in batch: + loss, loss_sum, e_sum, f_sum, term_sums = _process_sub_batch( + sub_batch, parallel_model, criterion, term_criterion, + device_output, loss_cfg, batch_size_total, batch_term_sizes, + ) + total_loss += loss_sum + mini_train_loss += loss_sum + energy_loss += e_sum + force_loss += f_sum + for k, v in term_sums.items(): + term_losses[k] += v + loss.backward() + optimizer.step() - for i, batch in it: - if i < epoch_offset: - continue - if epoch_offset and i == epoch_offset and self.dist.is_main and hasattr(it, "write"): - it.write(f"Resumed epoch at batch {i}") + if scheduler: + scheduler.step_batch(epoch + i / len(train_data)) + + if dry_run: + if ddp.is_main: + print("\nDry run OK!") + sys.exit(0) + + if verbose_loss_report and ddp.is_main and hasattr(tqdm_iter, 'set_description'): + desc = [f"Training ({epoch}/{epochs}) (T={total_loss / num_cal:.4f}"] + for t in loss_cfg.extra_term_keys: + if term_num_cal[t] > 0: + desc.append(f"{t}={term_losses[t] / term_num_cal[t]:.4f}") + tqdm_iter.set_description(", ".join(desc) + ")") + + return _reduce_to_epoch_metrics( + total_loss, energy_loss, force_loss, num_cal, term_losses, term_num_cal, + loss_cfg.extra_term_keys, ddp, device, + ) - if args.mini_epoch_size and i > 0 and (i % args.mini_epoch_size) == 0: - self._save_mini_checkpoint(i, train_loss, num_cal, mini_train_loss, mini_num_cal, len(self.data_bundle.train)) - mini_train_loss, mini_num_cal = 0.0, 0.0 - total_batch_size = sum([sb["force"].numel() for sb in batch]) - num_cal += total_batch_size - mini_num_cal += total_batch_size +def run_val_epoch( + epoch, epochs, val_data, parallel_model, + criterion, term_criterion, device, device_output, + loss_cfg: LossConfig, ddp: DistributedContext, +) -> EpochMetrics: + parallel_model.eval() - total_term_batch_size = {k: sum([sb[k].numel() for sb in batch]) for k in train_term_losses} - for k in train_term_num_cal: - train_term_num_cal[k] += total_term_batch_size[k] + total_loss = 0.0 + num_cal = 0 + term_losses: Dict[str, float] = {k: 0.0 for k in loss_cfg.extra_term_keys} + term_num_cal: Dict[str, int] = {k: 0 for k in loss_cfg.extra_term_keys} - self.optimizer.zero_grad() + val_iter = ( + tqdm(val_data, desc=f"Validation ({epoch}/{epochs})", total=len(val_data), dynamic_ncols=True) + if ddp.is_main else val_data + ) + + with torch.no_grad(): + for batch in val_iter: + batch_size_total = sum(sb["force"].numel() for sb in batch) + batch_term_sizes = {k: sum(sb[k].numel() for sb in batch) for k in loss_cfg.extra_term_keys} + num_cal += batch_size_total + for k in term_num_cal: + term_num_cal[k] += batch_term_sizes[k] for sub_batch in batch: - loss, energy_loss, force_loss, term_loss_dict = self._compute_loss_for_sub_batch( - sub_batch=sub_batch, - criterion=criterion, - term_criterion=term_criterion, - total_batch_size=total_batch_size, - total_term_batch_size=total_term_batch_size, - device_out=device_out, - model=model, + _, loss_sum, _, _, term_sums = _process_sub_batch( + sub_batch, parallel_model, criterion, term_criterion, + device_output, loss_cfg, batch_size_total, batch_term_sizes, ) + total_loss += loss_sum + for k, v in term_sums.items(): + term_losses[k] += v - train_force_loss += float(force_loss) * total_batch_size - if args.energy_matching: - train_energy_loss += float(energy_loss) * total_batch_size - - delta_loss = float(loss) * total_batch_size - train_loss += delta_loss - mini_train_loss += delta_loss - - for k, v in term_loss_dict.items(): - train_term_losses[k] += float(v) * total_term_batch_size[k] - - loss.backward() - - self.optimizer.step() - - if self.scheduler: - self.scheduler.step_batch(self.epoch + i / len(self.data_bundle.train)) - - if args.dry_run: - if self.dist.is_main: - print("\nDry run OK!") - sys.exit(0) - - if verbose and self.dist.is_main and hasattr(it, "set_description"): - desc = [f"Training ({self.epoch}/{args.epochs}) (T={train_loss/num_cal:.4f}"] - for tname in train_term_losses: - desc.append(f"{tname}={train_term_losses[tname]/train_term_num_cal[tname]:.4f}") - it.set_description(", ".join(desc) + ")") - - metrics = { - "train_loss_sum": train_loss, - "train_energy_loss_sum": train_energy_loss, - "train_force_loss_sum": train_force_loss, - "num_cal": num_cal, - "train_term_losses_sum": train_term_losses, - "train_term_num_cal": train_term_num_cal, - } - return self._aggregate_train_metrics(metrics) - - def _compute_loss_for_sub_batch( - self, - sub_batch: Dict[str, Any], - criterion, - term_criterion, - total_batch_size: int, - total_term_batch_size: Dict[str, int], - device_out, - model: nn.Module, - ): - """Compute loss (and side losses) for a single sub-batch.""" - args = self.args - - force = sub_batch.pop("force").reshape(-1, sub_batch["force"].shape[-1]).to(device_out) - - force_weights = None - if args.use_force_weights: - force_weights = sub_batch.pop("forces_weights").reshape(-1).to(device_out) - - energy = None - if args.energy_matching: - energy = sub_batch.pop("energy").reshape(-1, sub_batch["energy"].shape[-1]).to(device_out) - - term_targets = {} - for k in self.extra_train_terms: - term_targets[k] = sub_batch.pop(k).flatten().to(device_out) - - out_energy, out_force, extra = model(**sub_batch) - - sub_batch_size = force.numel() - energy_loss = torch.tensor(0.0, device=out_force.device) - if args.energy_matching: - energy_loss = criterion(out_energy, energy) * (sub_batch_size / total_batch_size) - - force_loss = term_criterion(out_force, force) - if force_weights is not None: - force_loss = force_loss * force_weights[:, None] - force_loss = force_loss.mean() * (sub_batch_size / total_batch_size) - - loss = args.energy_weight * energy_loss + args.force_weight * force_loss - - term_loss_dict: Dict[str, Tensor] = {} - for k in self.extra_train_terms: - if self.args.train_term_def.wrap(k): - tl = (extra[k] - term_targets[k] + torch.pi) % (2 * torch.pi) - torch.pi - tl = tl ** 2 - else: - tl = term_criterion(extra[k], term_targets[k]) - - tl = tl / total_term_batch_size[k] - tl = tl * (term_targets[k] >= -10).float() - tl = torch.sum(tl) - - loss = loss + tl * self.args.train_term_def.scale(k) - term_loss_dict[k] = tl - - return loss, energy_loss, force_loss, term_loss_dict - - def _aggregate_train_metrics(self, metrics: Dict[str, Any]) -> Dict[str, Any]: - """All-reduce training metrics across processes.""" - if not self.dist.enabled: - return metrics - - device = self.model_mgr.device - metrics["train_loss_sum"] = self.dist_mgr.all_reduce_sum(metrics["train_loss_sum"], device) - metrics["num_cal"] = self.dist_mgr.all_reduce_sum(metrics["num_cal"], device) - metrics["train_energy_loss_sum"] = self.dist_mgr.all_reduce_sum(metrics["train_energy_loss_sum"], device) - metrics["train_force_loss_sum"] = self.dist_mgr.all_reduce_sum(metrics["train_force_loss_sum"], device) - - for k in metrics["train_term_losses_sum"]: - metrics["train_term_losses_sum"][k] = self.dist_mgr.all_reduce_sum(metrics["train_term_losses_sum"][k], device) - metrics["train_term_num_cal"][k] = int(self.dist_mgr.all_reduce_sum(float(metrics["train_term_num_cal"][k]), device)) - - return metrics - - def _append_train_metrics(self, train_metrics: Dict[str, Any]) -> None: - """Append normalized train/energy/force losses to history.""" - num_cal = train_metrics["num_cal"] - self.history.train.append(train_metrics["train_loss_sum"] / num_cal) - self.history.energy.append(train_metrics["train_energy_loss_sum"] / num_cal) - self.history.force.append(train_metrics["train_force_loss_sum"] / num_cal) - - def _validate_one_epoch(self, criterion, term_criterion) -> Dict[str, Any]: - """Validate for one epoch, returning aggregated metrics.""" - args = self.args - model = self.model_mgr.parallel - device_out = self.model_mgr.device_output - - val_loss = 0.0 - num_cal = 0.0 - - val_term_losses = {k: 0.0 for k in self.extra_train_terms} - val_term_num_cal = {k: 0 for k in self.extra_train_terms} - - if self.dist.is_main: - it = tqdm(self.data_bundle.val, desc=f"Validation ({self.epoch}/{args.epochs})", total=len(self.data_bundle.val), dynamic_ncols=True) - else: - it = self.data_bundle.val + return _reduce_to_epoch_metrics( + total_loss, 0.0, 0.0, num_cal, term_losses, term_num_cal, + loss_cfg.extra_term_keys, ddp, device, + ) - for batch in it: - total_batch_size = sum([sb["force"].numel() for sb in batch]) - num_cal += total_batch_size - total_term_batch_size = {k: sum([sb[k].numel() for sb in batch]) for k in val_term_losses} - for k in val_term_num_cal: - val_term_num_cal[k] += total_term_batch_size[k] +# =========================================================================== +# Training orchestration +# =========================================================================== + +def train_model( + directory_path, conf_path, result_directory, dry_run, gpu_ids, + weight_decay, learning_rate, epochs, batch_size, val_ratio, atoms_per_call, + scheduler, reset_early_stopping, enable_shuffle, mini_epoch_size, early_stopping, + checkpoint_save, subsetpdbs, energy_weight, force_weight, energy_matching, + train_term_def, embedding_filename, dataset_chunk_size, use_npfile, use_force_weights, + ddp: Optional[DistributedContext] = None, +): + if ddp is None: + ddp = DistributedContext() + + energy_filename = "tica_delta_energies.npy" if energy_matching else None + if embedding_filename is None: + embedding_filename = "embeddings.npy" + + pdb_lists, pdb_list = _load_pdb_list(directory_path, subsetpdbs, dataset_chunk_size) + datasets, train_data, val_data, train_samplers = _build_all_dataloaders( + pdb_lists, directory_path, energy_filename, embedding_filename, + use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call, ddp, + ) - for sub_batch in batch: - force = sub_batch.pop("force").reshape(-1, sub_batch["force"].shape[-1]).to(device_out) - - force_weights = None - if args.use_force_weights: - force_weights = sub_batch.pop("forces_weights").reshape(-1).to(device_out) - - energy = None - if args.energy_matching: - energy = sub_batch.pop("energy").reshape(-1, sub_batch["energy"].shape[-1]).to(device_out) - - term_targets = {} - for k in self.extra_train_terms: - term_targets[k] = sub_batch.pop(k).flatten().to(device_out) - - out_energy, out_force, extra = model(**sub_batch) - - sub_batch_size = force.numel() - energy_loss = torch.tensor(0.0, device=out_force.device) - if args.energy_matching: - energy_loss = criterion(out_energy, energy) * (sub_batch_size / total_batch_size) - - force_loss = term_criterion(out_force, force) - if force_weights is not None: - force_loss = force_loss * force_weights[:, None] - force_loss = force_loss.mean() * (sub_batch_size / total_batch_size) - - loss = args.energy_weight * energy_loss + args.force_weight * force_loss - val_loss += float(loss) * total_batch_size - - for k in self.extra_train_terms: - if self.args.train_term_def.wrap(k): - tl = (extra[k] - term_targets[k] + torch.pi) % (2 * torch.pi) - torch.pi - tl = tl ** 2 - else: - tl = term_criterion(extra[k], term_targets[k]) - - tl = tl / total_term_batch_size[k] - tl = tl * (term_targets[k] >= -10).float() - tl = torch.sum(tl) - val_term_losses[k] += float(tl) * total_term_batch_size[k] - - if self.dist.enabled: - device = self.model_mgr.device - val_loss = self.dist_mgr.all_reduce_sum(val_loss, device) - num_cal = self.dist_mgr.all_reduce_sum(num_cal, device) - for k in val_term_losses: - val_term_losses[k] = self.dist_mgr.all_reduce_sum(val_term_losses[k], device) - val_term_num_cal[k] = int(self.dist_mgr.all_reduce_sum(float(val_term_num_cal[k]), device)) - - return { - "val_loss_sum": val_loss, - "num_cal": num_cal, - "val_loss": val_loss / num_cal, - "val_term_losses_sum": val_term_losses, - "val_term_num_cal": val_term_num_cal, - } - - def _save_mini_checkpoint(self, batch_i: int, train_loss: float, num_cal: float, mini_train_loss: float, mini_num_cal: float, epoch_len: int) -> None: - """Save and rotate mini-checkpoint (main process only).""" - tmp = os.path.join(self.result_dir, f"checkpoint-{self.epoch}-{batch_i}.pth") - self.checkpointer.save( - path=tmp, - epoch=self.epoch, - model_state=self.model_mgr.state_dict_for_saving(), - optimizer=self.optimizer, - model_conf=self.conf, - scheduler=self.scheduler, - extra={"train_loss": train_loss, "num_cal": num_cal, "i": batch_i}, - ) + conf = _load_model_conf(conf_path, ddp) + if "harmonic_net" in conf and train_term_def.get_names(): + conf["harmonic_net_return_terms"] = True - if not self.dist.is_main: - return + model = create_model(args=conf) + parallel_model, device, device_output = build_parallel_model(model, gpu_ids, ddp) + if ddp.is_main: + print("Model:\n", model, "\n") - os.replace(tmp, os.path.join(self.result_dir, "checkpoint-mini.pth")) + extra_train_terms = _augment_datasets(conf, datasets, train_term_def, use_force_weights, embedding_filename, ddp) + loss_cfg = LossConfig( + energy_matching=energy_matching, energy_weight=energy_weight, force_weight=force_weight, + train_term_def=train_term_def, extra_term_keys=extra_train_terms, use_force_weights=use_force_weights, + ) + criterion = nn.MSELoss() + term_criterion = nn.MSELoss(reduction="none") + + optimizer = _build_optimizer(model, weight_decay, learning_rate) + if scheduler: + scheduler.initialize(optimizer) + + epoch, epoch_resume = _load_checkpoint_state(result_directory, model, optimizer, scheduler, device, ddp) + result_directory = _setup_result_directory(result_directory, epoch, dry_run, ddp) + ddp.barrier() + + train_loss_list, val_loss_list, energy_loss_list, force_loss_list, epoch_history = _load_history(result_directory, epoch) + _write_training_info( + result_directory, epoch, weight_decay, learning_rate, epochs, batch_size, + directory_path, pdb_list, energy_weight, force_weight, embedding_filename, + ddp, scheduler, optimizer, dry_run, + ) - self.epoch_history.update( - f"{self.epoch}-{batch_i}", - { - "train_loss": train_loss / num_cal, - "mini_train_loss": (mini_train_loss / mini_num_cal) if mini_num_cal else 0.0, - "epoch_len": epoch_len, - "lr": [g["lr"] for g in self.optimizer.param_groups], - }, + if scheduler and scheduler.is_annealing(): + early_stopping = -1 + first_early_stopping_epoch = epoch if reset_early_stopping else 0 + verbose_loss_report = sys.stdout.isatty() + + while epoch < epochs: + _set_sampler_epochs(train_samplers, epoch) + t0 = time.time() + + train_metrics = run_train_epoch( + epoch, epochs, train_data, parallel_model, model, + optimizer, scheduler, criterion, term_criterion, + device, device_output, loss_cfg, ddp, + dry_run, verbose_loss_report, + mini_epoch_size, epoch_resume, result_directory, conf, epoch_history, ) + epoch_resume = None - def _log_epoch(self, train_metrics: Dict[str, Any], val_metrics: Dict[str, Any], t0: float) -> None: - """Print epoch summary and write epoch_history.json (main process only).""" - entry = { - "train_loss": self.history.train[-1], - "val_loss": self.history.val[-1], - "energy_loss": self.history.energy[-1], - "force_loss": self.history.force[-1], - "epoch_len": len(self.data_bundle.train), - "lr": [g["lr"] for g in self.optimizer.param_groups], - } - - for k in self.extra_train_terms: - entry[f"train_loss_{k}"] = train_metrics["train_term_losses_sum"][k] / max(1, train_metrics["train_term_num_cal"][k]) - entry[f"val_loss_{k}"] = val_metrics["val_term_losses_sum"][k] / max(1, val_metrics["val_term_num_cal"][k]) - - self.epoch_history.update(str(self.epoch), entry) - - print( - f"Epoch {self.epoch} - Train Loss: {self.history.train[-1]} - Val Loss: {self.history.val[-1]} - time: {round(time.time() - t0, 2)}s" - ) - if self.epoch > 0: - print(f" ∆Train: {self.history.train[-1]-self.history.train[-2]} - ∆Val: {self.history.val[-1]-self.history.val[-2]}") - - for k in self.extra_train_terms: - print(f" Train {k} loss={entry[f'train_loss_{k}']:.4f}") - print(f" Val {k} loss={entry[f'val_loss_{k}']:.4f}") - - def _save_epoch(self, train_metrics: Dict[str, Any], val_metrics: Dict[str, Any]) -> None: - """Save checkpoint/history (main process only), with barriers handled by caller.""" - tmp = os.path.join(self.result_dir, f"checkpoint-{self.epoch}.pth") - - self.checkpointer.save( - path=tmp, - epoch=self.epoch + 1, - model_state=self.model_mgr.state_dict_for_saving(), - optimizer=self.optimizer, - model_conf=self.conf, - scheduler=self.scheduler, + val_metrics = run_val_epoch( + epoch, epochs, val_data, parallel_model, + criterion, term_criterion, device, device_output, loss_cfg, ddp, ) - if not self.dist.is_main: - return + _update_loss_history(train_loss_list, val_loss_list, energy_loss_list, force_loss_list, + train_metrics, val_metrics) - mini = os.path.join(self.result_dir, "checkpoint-mini.pth") - if os.path.exists(mini): - os.unlink(mini) + if scheduler: + scheduler.step(val_loss_list[-1]) - main = os.path.join(self.result_dir, "checkpoint.pth") - if self.args.checkpoint_save and (self.epoch % self.args.checkpoint_save == 0): - shutil.copyfile(tmp, main) - else: - os.replace(tmp, main) + _log_epoch_results(ddp, epoch, train_loss_list, val_loss_list, train_metrics, val_metrics, + extra_train_terms, epoch_history, optimizer, t0, train_data, result_directory) - best = os.path.join(self.result_dir, "checkpoint-best.pth") - if self.history.val[-1] <= float(np.min(self.history.val)): - shutil.copyfile(main, best) + if _check_early_stop(ddp, val_loss_list, first_early_stopping_epoch, early_stopping, device): + break - self.history.save(self.result_dir) - print(" Checkpoint saved.") + history = {"train": train_loss_list, "val": val_loss_list, "energy": energy_loss_list, "force": force_loss_list} + _commit_epoch_checkpoint(epoch, model, optimizer, conf, scheduler, result_directory, + val_loss_list, history, checkpoint_save, ddp) + ddp.barrier() + epoch += 1 -# ----------------------------- CLI glue ------------------------------------ -def parse_args(): - import argparse +# =========================================================================== +# Entry-point helpers: collection → validation → processing +# =========================================================================== - p = argparse.ArgumentParser(description="Train a CGSchNet network with distributed support (refactored)") - p.add_argument("input", help="Processed data to train on ") - p.add_argument("result", default=None, nargs="?", help="Checkpoint directory to continue") - p.add_argument("-c", "--config", default="../configs/config.yaml", type=str) - - p.add_argument("--gpus", default=None, type=str, help='List of GPUs (e.g. "0,1,2") or "cpu"') - p.add_argument("--batch", type=int, default=50) - p.add_argument("--epochs", type=int, default=25) - p.add_argument("--lr", type=float, default=1e-4) - p.add_argument("--wd", type=float, default=0.0) - p.add_argument("--val-ratio", type=float, default=0.1) - p.add_argument("--apc", "--atoms-per-call", dest="apc", type=int, default=None) - - p.add_argument("--cos-anneal", default=None, help='Cosine anneal: "T_0,T_mult"') - p.add_argument("--cos-lr", default=None, help='Cosine LR: "T_max,eta_min"') - p.add_argument("--exp-lr", default=None, help='Exponential LR: "gamma"') - p.add_argument("--plateau-lr", default=None, help='Plateau LR: "factor,patience,threshold,min_lr"') - - p.add_argument("--dry-run", action="store_true") - p.add_argument("--reset-early-stopping", action="store_true") - p.add_argument("--no-shuffle", action="store_true") - p.add_argument("--mini-epoch", type=int, default=None) - p.add_argument("--early-stopping", type=int, default=1) - p.add_argument("--checkpoint-save", type=int, default=10) - - p.add_argument("--subsetpdbs", default="ok_list.txt", type=str) - p.add_argument("--energy-weight", default=0.0, type=float) - p.add_argument("--force-weight", default=1.0, type=float) - - p.add_argument("--term-def", default=None, type=str) - p.add_argument("--embedding", type=str, default=None) - p.add_argument("--chunk-dataset", type=int, default=None) - p.add_argument("--npfile", action="store_true") - p.add_argument("--use-force-weights", action="store_true") - - p.add_argument("--distributed", action="store_true") - return p.parse_args() - -def relax_open_file_limit() -> None: - """Raise RLIMIT_NOFILE as much as possible.""" +def build_parser(): + import argparse + parser = argparse.ArgumentParser(description="Train a CGSchNet network") + parser.add_argument("input", help="Processed data to train on") + parser.add_argument("result", default=None, nargs="?", help="Checkpoint directory to continue") + parser.add_argument("-c", "--config", default="../configs/config.yaml", type=str, help="Path to model architecture config YAML") + parser.add_argument("--gpus", default=None, type=str, help="List of GPUs to train on (e.g. \"0,1,2\") or \"cpu\"") + parser.add_argument("--batch", type=int, default=50, help="The batch size to use") + parser.add_argument("--epochs", type=int, default=25, help="The total number of epochs to train for") + parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") + parser.add_argument("--wd", type=float, default=0, help="Weight decay") + parser.add_argument("--val-ratio", type=float, default=0.1, help="Validation set ratio, should be between 0.0 and 1.0") + parser.add_argument("--apc", "--atoms-per-call", type=int, default=None, help="Number of atoms to include in each sub-batch") + parser.add_argument("--cos-anneal", default=None, help="Train using cosine annealing, parameters are \"T_0,T_mult\"") + parser.add_argument("--cos-lr", default=None, help="Train using a cosine learning rate, parameters are \"T_max,eta_min\"") + parser.add_argument("--exp-lr", default=None, help="Train using an exponential learning rate, parameter is \"gamma\"") + parser.add_argument("--plateau-lr", default=None, help="Train using a plateau learning rate, parameters are \"factor,patience,threshold,min_lr\"") + parser.add_argument("--dry-run", action="store_true", help="Do a dry run of the training loop but produce no output") + parser.add_argument("--reset-early-stopping", action="store_true", help="Reset the early stopping check to start from the current epoch") + parser.add_argument("--no-shuffle", action="store_true", help="Do not shuffle the training dataset") + parser.add_argument("--mini-epoch", type=int, default=None, help="Save a mini epoch after every n batches") + parser.add_argument("--early-stopping", type=int, default=1, help="Epochs validation loss can increase before early stopping, or -1 to disable (default=1)") + parser.add_argument("--checkpoint-save", type=int, default=10, help="Save a backup checkpoint every n epochs, 0 to disable (default=10)") + parser.add_argument("--subsetpdbs", default='ok_list.txt', type=str, help="Change the pdbid list used when reading in the dataset (default=ok_list.txt)") + parser.add_argument("--energy-weight", default=0.0, type=float, help="Energy weighting for loss function") + parser.add_argument("--force-weight", default=1.0, type=float, help="Force weighting for loss function") + parser.add_argument("--term-def", default=None, type=str, help="Path to a term definition yaml file for additional loss terms during training") + parser.add_argument("--embedding", type=str, default=None, help="Alternate file to load embeddings from (default: embeddings.npy)") + parser.add_argument("--chunk-dataset", type=int, default=None, help="Break the dataset into chunks of n proteins per batch") + parser.add_argument("--npfile", action="store_true", help="Use file loader instead of mmap to load dataset") + parser.add_argument("--use-force-weights", action="store_true", default=False, help="Use per-bead force weights in training") + parser.add_argument("--distributed", action="store_true", help="Enable distributed training (also auto-detected from RANK/SLURM_PROCID env vars)") + return parser + + +def validate_config(cfg) -> None: + assert torch.cuda.is_available(), "CUDA is not available, please run on a machine with CUDA or use --gpus cpu" + assert os.path.isdir(cfg.input), f"Input directory does not exist: {cfg.input}" + assert os.path.isfile(cfg.config), f"Config file does not exist: {cfg.config}" + assert cfg.checkpoint_save >= 0, "--checkpoint-save must be >= 0" + n_schedulers = sum(s is not None for s in [cfg.cos_anneal, cfg.cos_lr, cfg.exp_lr, cfg.plateau_lr]) + assert n_schedulers <= 1, "At most one LR scheduler may be specified (--cos-anneal, --cos-lr, --exp-lr, --plateau-lr)" + + +def process_config(cfg) -> dict: + if cfg.gpus == "cpu": + gpu_ids = "cpu" + elif cfg.gpus: + gpu_ids = [int(i) for i in cfg.gpus.strip().split(",")] + else: + gpu_ids = "cpu" + + lr_scheduler = None + if cfg.cos_anneal: + T_0, T_mult = [int(i) for i in cfg.cos_anneal.split(",")] + lr_scheduler = SchedulerWrapper_CosineAnnealingWarmRestarts(T_0, T_mult) + elif cfg.cos_lr: + T_max, eta_min = cfg.cos_lr.split(",") + lr_scheduler = SchedulerWrapper_CosineAnnealingLR(int(T_max), float(eta_min)) + elif cfg.exp_lr: + lr_scheduler = SchedulerWrapper_ExponentialLR(float(cfg.exp_lr)) + elif cfg.plateau_lr: + factor, patience, threshold, min_lr = cfg.plateau_lr.split(",") + lr_scheduler = SchedulerWrapper_ReduceLROnPlateau(float(factor), int(patience), float(threshold), float(min_lr)) + + train_term_def = TermDef(path=cfg.term_def) if cfg.term_def is not None else TermDef() + + # Raise the OS file-descriptor limit — large datasets open ~4 files per molecule soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) -def parse_gpu_ids(gpus_arg: Optional[str]): - """Parse --gpus argument into "cpu" or list[int].""" - if not gpus_arg: - return "cpu" - if gpus_arg == "cpu": - return "cpu" - return [int(i) for i in gpus_arg.strip().split(",")] - -def build_train_args(args) -> TrainArgs: - """Translate argparse args into TrainArgs dataclass.""" - lr_scheduler = SchedulerFactory.from_args(args) - train_term_def = TermDef(path=args.term_def) if args.term_def else TermDef() - - gpu_ids = parse_gpu_ids(args.gpus) - energy_matching = args.energy_weight != 0.0 - - return TrainArgs( - directory_path=args.input, - result_directory=args.result, - conf_path=args.config, + return dict( + directory_path=cfg.input, + result_directory=cfg.result, + conf_path=cfg.config, + dry_run=cfg.dry_run, + weight_decay=cfg.wd, + learning_rate=cfg.lr, gpu_ids=gpu_ids, - weight_decay=args.wd, - learning_rate=args.lr, - epochs=args.epochs, - batch_size=args.batch, - val_ratio=args.val_ratio, - atoms_per_call=args.apc, + epochs=cfg.epochs, + batch_size=cfg.batch, + val_ratio=cfg.val_ratio, + atoms_per_call=cfg.apc, scheduler=lr_scheduler, - dry_run=args.dry_run, - reset_early_stopping=args.reset_early_stopping, - enable_shuffle=not args.no_shuffle, - mini_epoch_size=args.mini_epoch, - early_stopping=args.early_stopping, - checkpoint_save=args.checkpoint_save, - subsetpdbs=args.subsetpdbs, - energy_weight=args.energy_weight, - force_weight=args.force_weight, - energy_matching=energy_matching, + reset_early_stopping=cfg.reset_early_stopping, + enable_shuffle=not cfg.no_shuffle, + mini_epoch_size=cfg.mini_epoch, + early_stopping=cfg.early_stopping, + checkpoint_save=cfg.checkpoint_save, + subsetpdbs=cfg.subsetpdbs, + energy_weight=cfg.energy_weight, + force_weight=cfg.force_weight, + energy_matching=cfg.energy_weight != 0.0, train_term_def=train_term_def, - embedding_filename=args.embedding, - dataset_chunk_size=args.chunk_dataset, - use_npfile=args.npfile, - use_force_weights=args.use_force_weights, + embedding_filename=cfg.embedding, + dataset_chunk_size=cfg.chunk_dataset, + use_npfile=cfg.npfile, + use_force_weights=cfg.use_force_weights, ) -def main(): - args = parse_args() - assert torch.cuda.is_available(), "CUDA is not available, please run on a machine with CUDA or use --gpus cpu" - assert os.path.isdir(args.input), f"Input directory does not exist: {args.input}" - assert os.path.isfile(args.config), f"Config file does not exist: {args.config}" - assert args.checkpoint_save >= 0 - - relax_open_file_limit() +if __name__ == "__main__": + from module.base_config import BaseConfig - dist_mgr = DistributedManager(enable=args.distributed) - dist_mgr.setup() + cfg = BaseConfig(build_parser()) + validate_config(cfg) + params = process_config(cfg) - train_args = build_train_args(args) + ddp = DistributedContext() + if cfg.distributed or 'RANK' in os.environ or 'SLURM_PROCID' in os.environ: + ddp = setup_distributed() + if ddp.is_distributed and ddp.is_main: + print(f"Initialized process {ddp.rank}/{ddp.world_size} (local_rank={ddp.local_rank})") try: - trainer = Trainer(train_args, dist_mgr) - trainer.run() + train_model(**params, ddp=ddp) except Exception as e: traceback.print_tb(e.__traceback__) print(e) sys.exit(1) finally: - dist_mgr.cleanup() - -if __name__ == "__main__": - main() - + cleanup_distributed() From 4ca608496ca35c65c2b0618cb50ec473f21d9802 Mon Sep 17 00:00:00 2001 From: bamboozle-jpg Date: Wed, 6 May 2026 17:19:01 -0700 Subject: [PATCH 4/4] Kevin changes pre-PR --- modules | 2 +- train.py | 29 +++++++++++++++++------------ 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/modules b/modules index 78e5ec5..7dd1aec 160000 --- a/modules +++ b/modules @@ -1 +1 @@ -Subproject commit 78e5ec5d0743420b0eb8e526d80c89a4e734b704 +Subproject commit 7dd1aec113b0fe14d318db3e57e2b714ab0b269d diff --git a/train.py b/train.py index 2495d05..39b4375 100755 --- a/train.py +++ b/train.py @@ -259,7 +259,7 @@ def __iter__(self): def gen_dataloaders(directory_path, pdb_list, energy_filename, embedding_filename, use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call, - ddp: DistributedContext): + num_workers, ddp: DistributedContext): """Build train/val DataLoaders for one PDB chunk, with optional DistributedSampler.""" print("Dataset:", " ".join(pdb_list)) all_data = dataset.ProteinDataset( @@ -293,11 +293,11 @@ def gen_dataloaders(directory_path, pdb_list, energy_filename, embedding_filenam ) train_loader = DataLoader( - train_set, batch_size=batch_size, shuffle=False, num_workers=4, + train_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, persistent_workers=True, pin_memory=True, collate_fn=collate_fn, sampler=train_sampler, ) val_loader = DataLoader( - val_set, batch_size=batch_size, shuffle=False, num_workers=4, + val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, persistent_workers=True, pin_memory=True, collate_fn=collate_fn, sampler=val_sampler, ) return all_data, train_loader, val_loader, train_sampler @@ -305,7 +305,9 @@ def gen_dataloaders(directory_path, pdb_list, energy_filename, embedding_filenam def _load_pdb_list(directory_path, subsetpdbs, dataset_chunk_size): """Read, deduplicate, shuffle, and optionally chunk the PDB ID list.""" - with open(os.path.join(directory_path, "result", subsetpdbs), 'r') as f: + if subsetpdbs is None: + subsetpdbs = os.path.join(directory_path, "result", "ok_list.txt") + with open(subsetpdbs, 'r') as f: pdb_list = f.read().split('\n') pdb_list = sorted(set(p for p in pdb_list if p)) pdb_list = deterministic_shuffle(pdb_list, seed=47563537) @@ -318,13 +320,14 @@ def _load_pdb_list(directory_path, subsetpdbs, dataset_chunk_size): def _build_all_dataloaders(pdb_lists, directory_path, energy_filename, embedding_filename, use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call, + num_workers, ddp: DistributedContext): """Build datasets and merged RoundRobin loaders for all PDB chunks.""" datasets, train_loaders, val_loaders, train_samplers = [], [], [], [] for pdb_chunk in pdb_lists: ds, train_loader, val_loader, sampler = gen_dataloaders( directory_path, pdb_chunk, energy_filename, embedding_filename, - use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call, ddp, + use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call, num_workers, ddp, ) datasets.append(ds) train_loaders.append(train_loader) @@ -542,9 +545,7 @@ def _setup_result_directory(result_directory, epoch, dry_run, ddp: DistributedCo """Create the result directory for a fresh run, or validate it for resumption.""" if epoch > 0: pass # Resuming — directory already exists - elif not result_directory or not os.path.exists(result_directory): - if not result_directory: - result_directory = "../data/result-" + datetime.datetime.now().strftime("%Y.%m.%d-%H.%M.%S") + elif not os.path.exists(result_directory): if not dry_run and ddp.is_main: os.makedirs(result_directory, exist_ok=False) if ddp.is_main: @@ -910,6 +911,7 @@ def train_model( scheduler, reset_early_stopping, enable_shuffle, mini_epoch_size, early_stopping, checkpoint_save, subsetpdbs, energy_weight, force_weight, energy_matching, train_term_def, embedding_filename, dataset_chunk_size, use_npfile, use_force_weights, + num_workers, ddp: Optional[DistributedContext] = None, ): if ddp is None: @@ -922,7 +924,7 @@ def train_model( pdb_lists, pdb_list = _load_pdb_list(directory_path, subsetpdbs, dataset_chunk_size) datasets, train_data, val_data, train_samplers = _build_all_dataloaders( pdb_lists, directory_path, energy_filename, embedding_filename, - use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call, ddp, + use_npfile, enable_shuffle, val_ratio, batch_size, atoms_per_call, num_workers, ddp, ) conf = _load_model_conf(conf_path, ddp) @@ -1007,7 +1009,7 @@ def build_parser(): import argparse parser = argparse.ArgumentParser(description="Train a CGSchNet network") parser.add_argument("input", help="Processed data to train on") - parser.add_argument("result", default=None, nargs="?", help="Checkpoint directory to continue") + parser.add_argument("result", help="Checkpoint directory to continue") parser.add_argument("-c", "--config", default="../configs/config.yaml", type=str, help="Path to model architecture config YAML") parser.add_argument("--gpus", default=None, type=str, help="List of GPUs to train on (e.g. \"0,1,2\") or \"cpu\"") parser.add_argument("--batch", type=int, default=50, help="The batch size to use") @@ -1023,10 +1025,11 @@ def build_parser(): parser.add_argument("--dry-run", action="store_true", help="Do a dry run of the training loop but produce no output") parser.add_argument("--reset-early-stopping", action="store_true", help="Reset the early stopping check to start from the current epoch") parser.add_argument("--no-shuffle", action="store_true", help="Do not shuffle the training dataset") + parser.add_argument("--num-workers", type=int, default=0, help="Number of workers to use for data loading") parser.add_argument("--mini-epoch", type=int, default=None, help="Save a mini epoch after every n batches") parser.add_argument("--early-stopping", type=int, default=1, help="Epochs validation loss can increase before early stopping, or -1 to disable (default=1)") parser.add_argument("--checkpoint-save", type=int, default=10, help="Save a backup checkpoint every n epochs, 0 to disable (default=10)") - parser.add_argument("--subsetpdbs", default='ok_list.txt', type=str, help="Change the pdbid list used when reading in the dataset (default=ok_list.txt)") + parser.add_argument("--subsetpdbs", default=None, type=str, help="Change the pdbid list used when reading in the dataset (default=input_dir/result/ok_list.txt)") parser.add_argument("--energy-weight", default=0.0, type=float, help="Energy weighting for loss function") parser.add_argument("--force-weight", default=1.0, type=float, help="Force weighting for loss function") parser.add_argument("--term-def", default=None, type=str, help="Path to a term definition yaml file for additional loss terms during training") @@ -1041,6 +1044,7 @@ def build_parser(): def validate_config(cfg) -> None: assert torch.cuda.is_available(), "CUDA is not available, please run on a machine with CUDA or use --gpus cpu" assert os.path.isdir(cfg.input), f"Input directory does not exist: {cfg.input}" + assert os.path.isdir(cfg.result), f"Result directory does not exist: {cfg.result}" assert os.path.isfile(cfg.config), f"Config file does not exist: {cfg.config}" assert cfg.checkpoint_save >= 0, "--checkpoint-save must be >= 0" n_schedulers = sum(s is not None for s in [cfg.cos_anneal, cfg.cos_lr, cfg.exp_lr, cfg.plateau_lr]) @@ -1072,7 +1076,7 @@ def process_config(cfg) -> dict: # Raise the OS file-descriptor limit — large datasets open ~4 files per molecule soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) - resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) + resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard)) return dict( directory_path=cfg.input, @@ -1101,6 +1105,7 @@ def process_config(cfg) -> dict: dataset_chunk_size=cfg.chunk_dataset, use_npfile=cfg.npfile, use_force_weights=cfg.use_force_weights, + num_workers=cfg.num_workers, )