diff --git a/.gitignore b/.gitignore index 901547b..ed6c50a 100644 --- a/.gitignore +++ b/.gitignore @@ -56,6 +56,9 @@ venv.bak/ *.zip *.png *.pdf +*.md +*.json +*.lock # Dirs *temp* @@ -68,3 +71,4 @@ venv.bak/ !*logo.txt !*parameters*.yaml !*requirements*.txt +!README.md diff --git a/README.md b/README.md index 496b206..0065c9d 100644 --- a/README.md +++ b/README.md @@ -4,28 +4,30 @@ This repository introduces **FRAME**, a framework for learning fragment-based molecular representations to enhance the interpretability of graph neural networks in drug discovery. FRAME represents chemically meaningful fragments as graph nodes and is compatible with several GNN architectures, including GCN, GAT, and AttentiveFP. It also integrates Integrated Gradients to generate more transparent and chemically grounded model explanations. -## ⚙️ **Installation** -1. Clone the repo: +## Installation -2. Create and activate your `virtualenv` with Python 3.12, for example as described [here](https://docs.python.org/3/library/venv.html). +FRAME is installed with [`uv`](https://docs.astral.sh/uv/), which picks the right `torch` wheels (CUDA 12.8) for you. -3. Install [PyTorch **2.8.0**](https://pytorch.org/get-started/locally/) using: +1. Install [`uv`](https://docs.astral.sh/uv/getting-started/installation/) if you don't already have it. +2. Clone the repo. +3. From the project root, run: ```console - pip install torch==2.8.0 -f https://download.pytorch.org/whl/cu129 + uv sync ``` -4. Install FRAME using: + That creates a `.venv/` with Python 3.11+ and installs PocketGraph along with everything it depends on. You can use prefix commands with `uv run` (e.g. `uv run frame_tune -c parameters.yaml`). + + + To install the `frame_*` commands globally (isolated in their own environment, available on your `PATH` without having to activate a venv), use `uv tool install`: ```console - python -m pip install . - ``` - or for development: - ```console - python -m pip install -e . + uv tool install . ``` -## 📂 Dataset Requirements +If you'd rather not use `uv`, you can install the dependencies declared in [pyproject.toml](pyproject.toml) directly with `pip` in a Python 3.11+ environment. + +## Dataset Requirements The CSV file used in FRAME **must** include the following columns: - **`id`** – A unique identifier for each entry. @@ -39,7 +41,7 @@ The CSV file used in FRAME **must** include the following columns: Please ensure that all entries follow this structure so the dataset can be correctly loaded and processed by the pipeline. -## 📄 Configuration +## Configuration All model parameters and runtime settings are defined in a YAML configuration file. An example file, [`parameters.yaml`](./parameters.yaml), is provided. @@ -62,7 +64,7 @@ Tune: value: 64 ``` -## 🔎 **Usage** +## **Usage** All entry points accept a `-c/--config` parameter pointing to the YAML config file. - Generate a processed dataset: diff --git a/frame/evaluate.py b/frame/evaluate.py index a680889..f626e79 100644 --- a/frame/evaluate.py +++ b/frame/evaluate.py @@ -21,12 +21,7 @@ def main(): params = yaml.safe_load(stream) config = params["Data"] - tune = {} - for name, bounds in params["Tune"].items(): - if isinstance(bounds["value"], int): - tune[name] = int(bounds["value"]) - else: - tune[name] = float(bounds["value"]) + tune = models.tune_fixed(params) path_checkpoint = config["path_checkpoint"] model_name = config.get("model", "gat").lower() diff --git a/frame/explain.py b/frame/explain.py index f379072..fabb25f 100644 --- a/frame/explain.py +++ b/frame/explain.py @@ -22,12 +22,7 @@ def main(): params = yaml.safe_load(stream) config = params["Data"] - tune = {} - for name, bounds in params["Tune"].items(): - if isinstance(bounds["value"], int): - tune[name] = int(bounds["value"]) - else: - tune[name] = float(bounds["value"]) + tune = models.tune_fixed(params) path_checkpoint = config["path_checkpoint"] model_name = config.get("model", "gat").lower() diff --git a/frame/scaffold_split.py b/frame/scaffold_split.py new file mode 100644 index 0000000..6b7cf04 --- /dev/null +++ b/frame/scaffold_split.py @@ -0,0 +1,36 @@ +import argparse + +import pandas as pd + +from frame.source.datasets import scaffold_split + + +def main(): + parser = argparse.ArgumentParser( + description=("Rewrite the `set` column of a CSV using a Murcko " + "scaffold split. Run before frame_gen.")) + parser.add_argument("-i", "--input", required=True, + help="Path to input CSV with id/smiles/label/set.") + parser.add_argument("-o", "--output", required=True, + help="Path to output CSV.") + parser.add_argument("--fracs", nargs=3, type=float, + default=[0.8, 0.1, 0.1], + metavar=("TRAIN", "VALID", "TEST"), + help="Split fractions (default: 0.8 0.1 0.1).") + parser.add_argument("--chirality", action="store_true", + help="Include chirality in scaffold definition.") + args = parser.parse_args() + + df = pd.read_csv(args.input) + if "smiles" not in df.columns: + raise ValueError("Input CSV must have a `smiles` column.") + + sets = scaffold_split(df["smiles"].tolist(), + fracs=tuple(args.fracs), + include_chirality=args.chirality) + df["set"] = sets + + counts = df["set"].value_counts().to_dict() + print(f"Scaffold split: {counts}") + + df.to_csv(args.output, index=False) diff --git a/frame/source/datasets/__init__.py b/frame/source/datasets/__init__.py index 86c7dde..8d91321 100644 --- a/frame/source/datasets/__init__.py +++ b/frame/source/datasets/__init__.py @@ -1,5 +1,7 @@ from frame.source.datasets.default import MolecularDataset from frame.source.datasets.decompose import DecomposeDataset +from frame.source.datasets.scaffold import scaffold_split __all__ = ["MolecularDataset", - "DecomposeDataset"] + "DecomposeDataset", + "scaffold_split"] diff --git a/frame/source/datasets/scaffold.py b/frame/source/datasets/scaffold.py new file mode 100644 index 0000000..f171baa --- /dev/null +++ b/frame/source/datasets/scaffold.py @@ -0,0 +1,65 @@ +from collections import defaultdict + +from rdkit import Chem +from rdkit import RDLogger +from rdkit.Chem.Scaffolds import MurckoScaffold + + +lg = RDLogger.logger() +lg.setLevel(RDLogger.CRITICAL) + + +def _scaffold(smiles, include_chirality=False): + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return "" + return MurckoScaffold.MurckoScaffoldSmiles( + mol=mol, includeChirality=include_chirality) + + +def scaffold_split(smiles_list, fracs=(0.8, 0.1, 0.1), + include_chirality=False): + """Murcko-scaffold split. + + Largest scaffold groups go to train; smaller ones fill valid then test. + Molecules with the same scaffold never cross splits, which gives a more + realistic generalization signal than a random split for drug discovery. + + Args: + smiles_list: list of SMILES strings. + fracs: (train, valid, test) fractions; must sum to ~1.0. + include_chirality: pass-through to MurckoScaffoldSmiles. + + Returns: + list of "train" / "valid" / "test", aligned with smiles_list. + """ + if abs(sum(fracs) - 1.0) > 1e-6: + raise ValueError(f"fracs must sum to 1.0, got {fracs}") + + n = len(smiles_list) + train_target = int(round(fracs[0] * n)) + valid_target = int(round(fracs[1] * n)) + + groups = defaultdict(list) + for i, smi in enumerate(smiles_list): + groups[_scaffold(smi, include_chirality)].append(i) + + # Largest groups first; tiebreak on the scaffold key for determinism. + sorted_groups = sorted(groups.items(), + key=lambda kv: (-len(kv[1]), kv[0])) + + sets = ["test"] * n + train_count = 0 + valid_count = 0 + for _, indices in sorted_groups: + if train_count + len(indices) <= train_target: + for i in indices: + sets[i] = "train" + train_count += len(indices) + + elif valid_count + len(indices) <= valid_target: + for i in indices: + sets[i] = "valid" + valid_count += len(indices) + + return sets diff --git a/frame/source/models/__init__.py b/frame/source/models/__init__.py index 6b37e97..4c9c073 100644 --- a/frame/source/models/__init__.py +++ b/frame/source/models/__init__.py @@ -1,11 +1,15 @@ import torch +from torch.optim.lr_scheduler import (LinearLR, + SequentialLR, + CosineAnnealingLR) + from frame.source import train from frame.source.models import pyg_models device = "cuda" if torch.cuda.is_available() else "cpu" -def model_setup(model_name, config): +def model_setup(model_name, config, epochs=100): task = config["task"] model = select_model(model_name, config) @@ -16,16 +20,42 @@ def model_setup(model_name, config): eps=config["eps"], weight_decay=config["weight_decay"]) optimizer = train.Lookahead(base_optimizer, k=5, alpha=0.5) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, - T_max=100, - eta_min=1e-6) + + warmup_epochs = int(config.get("warmup_epochs", 0)) + eta_min = float(config.get("lr_min", 1e-6)) + if warmup_epochs > 0 and warmup_epochs < epochs: + warmup = LinearLR(optimizer, start_factor=0.1, + total_iters=warmup_epochs) + cosine = CosineAnnealingLR(optimizer, + T_max=max(1, epochs - warmup_epochs), + eta_min=eta_min) + scheduler = SequentialLR(optimizer, + schedulers=[warmup, cosine], + milestones=[warmup_epochs]) + else: + scheduler = CosineAnnealingLR(optimizer, T_max=max(1, epochs), + eta_min=eta_min) if task == "classification": bce_weight = config["bce_weight"] lossfn = torch.nn.BCEWithLogitsLoss(pos_weight=bce_weight).to(device) else: - lossfn = torch.nn.MSELoss() + reg_loss = str(config.get("regression_loss", "mse")).lower() + delta = float(config.get("huber_delta", 1.0)) + + if reg_loss == "mse": + lossfn = torch.nn.MSELoss() + + elif reg_loss == "huber": + lossfn = torch.nn.HuberLoss(delta=delta) + + elif reg_loss == "smooth_l1": + lossfn = torch.nn.SmoothL1Loss(beta=delta) + + else: + raise ValueError(f"Unknown regression_loss: {reg_loss}. " + "Choose from mse, huber, smooth_l1.") return model, optimizer, scheduler, lossfn @@ -47,21 +77,54 @@ def select_model(model_name, config): return model +def _cast_value(val): + if isinstance(val, bool): + return val + if isinstance(val, int): + return int(val) + if isinstance(val, float): + return float(val) + return str(val) + + +def tune_fixed(params): + out = {} + for name, bounds in params["Tune"].items(): + if "value" not in bounds: + continue + out[name] = _cast_value(bounds["value"]) + return out + + def optuna_suggest(params, trial): configs = {} for name, bounds in params["Tune"].items(): - if "min" in bounds: - if isinstance(bounds["max"], int): + if "choices" in bounds: + configs[name] = trial.suggest_categorical(name, bounds["choices"]) + elif "min" in bounds: + log = bool(bounds.get("log", False)) + if isinstance(bounds["max"], int) and not log: configs[name] = trial.suggest_int(name, bounds["min"], bounds["max"]) else: - configs[name] = trial.suggest_float(name, float(bounds["min"]), - float(bounds["max"])) + configs[name] = trial.suggest_float(name, + float(bounds["min"]), + float(bounds["max"]), + log=log) else: - if isinstance(bounds["value"], int): - configs[name] = int(bounds["value"]) - else: - configs[name] = float(bounds["value"]) + configs[name] = _cast_value(bounds["value"]) + + # Round hidden_channels to match n_heads, and log + model_name = str(params["Data"].get("model", "")).lower() + if (model_name == "gat"): + heads = int(configs["heads"]) + original = int(configs["hidden_channels"]) + rounded = (original // heads) * heads + + if rounded != original: + configs["hidden_channels"] = rounded + trial.set_user_attr("hidden_channels_suggested", original) + trial.set_user_attr("hidden_channels_used", rounded) return configs diff --git a/frame/source/models/pyg_models.py b/frame/source/models/pyg_models.py index f1f7364..aa5e9ae 100644 --- a/frame/source/models/pyg_models.py +++ b/frame/source/models/pyg_models.py @@ -1,8 +1,23 @@ import torch from torch_geometric.nn import (GCN, GraphSAGE, GIN, GAT, - AttentiveFP, global_mean_pool) + AttentiveFP, + global_mean_pool, + global_add_pool, + global_max_pool) -torch.manual_seed(8) + +_POOLS = {"mean": global_mean_pool, + "add": global_add_pool, + "sum": global_add_pool, + "max": global_max_pool} + + +def _resolve_pool(name): + key = (name or "mean").lower() + if key not in _POOLS: + raise ValueError(f"Unknown pool '{name}'. " + f"Choose from {sorted(set(_POOLS))}.") + return _POOLS[key] class GNN_GCN(torch.nn.Module): @@ -14,6 +29,7 @@ def __init__(self, config): dropout = config.get("dropout_rate", 0.4) improved = config.get("gcn_improved", True) + self.pool = _resolve_pool(config.get("pool", "mean")) self.model = GCN(in_channels=in_channels, hidden_channels=hidden_channels, num_layers=num_layers, @@ -23,7 +39,7 @@ def __init__(self, config): def forward(self, x, edge_index, edge_attr, batch): x = self.model(x, edge_index, edge_attr=edge_attr) - x_pool = global_mean_pool(x, batch) + x_pool = self.pool(x, batch) return x_pool @@ -36,6 +52,7 @@ def __init__(self, config): num_layers = config.get("num_layers", 2) dropout = config.get("dropout_rate", 0.4) + self.pool = _resolve_pool(config.get("pool", "mean")) self.model = GraphSAGE(in_channels=in_channels, hidden_channels=hidden_channels, num_layers=num_layers, @@ -44,7 +61,7 @@ def __init__(self, config): def forward(self, x, edge_index, edge_attr, batch): x = self.model(x, edge_index, edge_attr=edge_attr) - x_pool = global_mean_pool(x, batch) + x_pool = self.pool(x, batch) return x_pool @@ -57,6 +74,7 @@ def __init__(self, config): num_layers = config.get("num_layers", 2) dropout = config.get("dropout_rate", 0.4) + self.pool = _resolve_pool(config.get("pool", "mean")) self.model = GIN(in_channels=in_channels, hidden_channels=hidden_channels, num_layers=num_layers, @@ -65,7 +83,7 @@ def __init__(self, config): def forward(self, x, edge_index, edge_attr, batch): x = self.model(x, edge_index, edge_attr=edge_attr) - x_pool = global_mean_pool(x, batch) + x_pool = self.pool(x, batch) return x_pool @@ -79,8 +97,15 @@ def __init__(self, config): dropout = config.get("dropout_rate", 0.4) edge_dim = config.get("edge_dim", None) v2 = config.get("att_v2", True) - heads = config.get("num_heads", 1) + heads = config.get("heads", 1) + + rounded = (hidden_channels // heads) * heads + if rounded == 0: + raise ValueError(f"hidden_channels ({hidden_channels}) must be " + f">= heads ({heads}).") + hidden_channels = rounded + self.pool = _resolve_pool(config.get("pool", "mean")) self.model = GAT(in_channels=in_channels, hidden_channels=hidden_channels, num_layers=num_layers, @@ -92,7 +117,7 @@ def __init__(self, config): def forward(self, x, edge_index, edge_attr, batch): x = self.model(x, edge_index, edge_attr=edge_attr) - x_pool = global_mean_pool(x, batch) + x_pool = self.pool(x, batch) return x_pool diff --git a/frame/source/train/__init__.py b/frame/source/train/__init__.py index 43ed1b5..6908100 100644 --- a/frame/source/train/__init__.py +++ b/frame/source/train/__init__.py @@ -1,5 +1,5 @@ from frame.source.train.optimizer import Lookahead -from frame.source.train.epoch import train_epoch, valid_epoch +from frame.source.train.epoch import train_epoch, valid_epoch, set_seed from frame.source.train.metrics import (reg_through_origin, concordance_correlation, roy_criteria, @@ -8,6 +8,7 @@ __all__ = ["train_epoch", "valid_epoch", + "set_seed", "Lookahead", diff --git a/frame/source/train/epoch.py b/frame/source/train/epoch.py index c90be48..519b207 100644 --- a/frame/source/train/epoch.py +++ b/frame/source/train/epoch.py @@ -4,22 +4,27 @@ import numpy as np from sklearn import metrics import torch.backends.cudnn as cudnn +from torch_geometric.utils import dropout_edge from frame.source.train import metrics as reg_metrics -random.seed(8) -np.random.seed(8) -torch.manual_seed(8) -if torch.cuda.is_available(): - torch.cuda.manual_seed(8) - torch.cuda.manual_seed_all(8) cudnn.deterministic = True cudnn.benchmark = False device = "cuda" if torch.cuda.is_available() else "cpu" -def train_epoch(model, optim, scheduler, lossfn, loader): +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def train_epoch(model, optim, lossfn, loader, grad_clip_norm=None, + drop_edge_p=0.0, mask_feat_p=0.0): step = 1 running_loss = 0.0 @@ -28,10 +33,23 @@ def train_epoch(model, optim, scheduler, lossfn, loader): batch = batch.to(device) optim.zero_grad() + x = batch.x.float() + edge_index = batch.edge_index + edge_attr = batch.edge_attr.float() + + if drop_edge_p and drop_edge_p > 0 and edge_index.numel() > 0: + edge_index, edge_mask = dropout_edge(edge_index, p=drop_edge_p, + force_undirected=True) + edge_attr = edge_attr[edge_mask] + + if mask_feat_p and mask_feat_p > 0: + keep = (torch.rand_like(x) >= mask_feat_p).float() + x = x * keep + # * Make predictions - out = model(x=batch.x.float(), - edge_index=batch.edge_index, - edge_attr=batch.edge_attr.float(), + out = model(x=x, + edge_index=edge_index, + edge_attr=edge_attr, batch=batch.batch) # * Compute loss @@ -39,6 +57,10 @@ def train_epoch(model, optim, scheduler, lossfn, loader): loss = lossfn(torch.squeeze(out), torch.squeeze(true)) loss.backward() + if grad_clip_norm is not None and grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), + max_norm=grad_clip_norm) + # * Update gradients optim.step() @@ -46,7 +68,6 @@ def train_epoch(model, optim, scheduler, lossfn, loader): running_loss += loss.detach().item() step += 1 - scheduler.step() return running_loss / step diff --git a/frame/source/train/metrics.py b/frame/source/train/metrics.py index 010b5b6..4c9dbc0 100644 --- a/frame/source/train/metrics.py +++ b/frame/source/train/metrics.py @@ -25,7 +25,7 @@ def reg_through_origin(y_true, y_pred): rto_r2 = rto.score(pred, true) slope = regression.coef_ - return rto_r2, float(slope) + return rto_r2, float(np.asarray(slope).item()) def concordance_correlation(y_true, y_pred): diff --git a/frame/source/train/runner.py b/frame/source/train/runner.py new file mode 100644 index 0000000..6b425fd --- /dev/null +++ b/frame/source/train/runner.py @@ -0,0 +1,88 @@ +import copy +import time + +import numpy as np +import torch +from tqdm import tqdm +from torch_geometric.loader import DataLoader + +from frame.source import models +from frame.source import train as train_pkg + + +def train_one_seed(seed: int, train_data: list, valid_loader: DataLoader, + model_name: str, config: dict, epochs: int, + patience: int, task: str, grad_clip: float, + drop_edge_p: float, mask_feat_p: float, + batch_size: int, workers: int): + """Train a fresh model under a fixed seed and return its best state. + + Reseeds the global RNGs and pins the train ``DataLoader``'s shuffle + generator so that two calls with the same ``seed`` produce the same + trajectory. The valid loader is reused across seeds since it does not + shuffle. + + Args: + seed: Integer used to reseed Python/NumPy/Torch RNGs and the + train-loader generator. + train_data: List of ``Data`` objects belonging to the train split. + valid_loader: Pre-built validation loader (no shuffle, seed-free). + model_name: Backbone selector consumed by ``models.select_model``. + config: Hyperparameter dict consumed by ``models.model_setup``. + epochs: Maximum number of epochs to train. + patience: Early-stopping patience on the ``optim`` metric. + task: ``"classification"`` or ``"regression"``. + grad_clip: Max-norm gradient clipping value (``None``/0 disables). + drop_edge_p: Train-time edge drop probability. + mask_feat_p: Train-time node-feature mask probability. + batch_size: Train loader batch size. + workers: Number of DataLoader workers. + + Returns: + Tuple ``(best_state, results, fit_time, n_params)`` where + ``best_state`` is the ``state_dict`` of the best epoch on the valid + ``optim`` metric, ``results`` is the validation metric dict at that + state, ``fit_time`` is wall-clock seconds spent training, and + ``n_params`` is the trainable parameter count. + """ + train_pkg.set_seed(seed) + generator = torch.Generator() + generator.manual_seed(seed) + train_loader = DataLoader(train_data, batch_size=batch_size, + shuffle=True, num_workers=workers, + generator=generator) + + model, optim, schdlr, lossfn = models.model_setup(model_name, config, + epochs=epochs) + + best_metric = -float("inf") + patience_counter = 0 + best_state = None + + start = time.time() + for _ in tqdm(range(epochs), ncols=120, desc=f"Seed {seed}"): + _ = train_pkg.train_epoch(model, optim, lossfn, train_loader, + grad_clip_norm=grad_clip, + drop_edge_p=drop_edge_p, + mask_feat_p=mask_feat_p) + val_metrics = train_pkg.valid_epoch(model, task, valid_loader) + schdlr.step() + + if val_metrics["optim"] > best_metric: + patience_counter = 0 + best_metric = val_metrics["optim"] + best_state = copy.deepcopy(model.state_dict()) + else: + patience_counter += 1 + + if patience_counter >= patience: + break + + fit_time = time.time() - start + + model.load_state_dict(best_state) + results = train_pkg.valid_epoch(model, task, valid_loader) + + n_params = sum(int(np.prod(p.size())) for p in model.parameters() + if p.requires_grad) + return best_state, results, fit_time, n_params diff --git a/frame/train.py b/frame/train.py index 4a3295a..079ef1b 100644 --- a/frame/train.py +++ b/frame/train.py @@ -1,5 +1,5 @@ import os -import copy +import json import uuid import shutil import argparse @@ -8,82 +8,91 @@ import yaml import torch import joblib -from tqdm import tqdm +import numpy as np from torch_geometric.loader import DataLoader -from frame.source import models, train +from frame.source import models +from frame.source.train import runner device = "cuda" if torch.cuda.is_available() else "cpu" +def _report_seed_stats(task, per_seed_results, project_dir): + """Print mean ± std across seeds and dump per-seed metrics to JSON.""" + headline = "mcc" if task == "classification" else "ccc" + values = [float(r[headline]) for r in per_seed_results] + mean = float(np.mean(values)) + std = float(np.std(values)) + + print(f"{headline.upper()}: {mean:.3f} ± {std:.3f} " + f"(per-seed: {[round(v, 3) for v in values]})") + + summary = {"headline_metric": headline, + "mean": round(mean, 4), + "std": round(std, 4), + "per_seed": values, + "per_seed_full": per_seed_results} + with open(project_dir / "seed_metrics.json", "w") as fh: + json.dump(summary, fh, indent=2) + + def run(params, dataset): - # Get static data from parameters epochs = params["Data"].get("epochs", 10) workers = params["Data"].get("workers", 4) size = params["Data"].get("batch_size", 32) patience = params["Data"].get("patience", 5) model_name = params["Data"].get("model", "gat").lower() task = params["Data"].get("task", "classification").lower() + grad_clip = params["Data"].get("grad_clip_norm", 1.0) + drop_edge_p = float(params["Data"].get("drop_edge_p", 0.0)) + mask_feat_p = float(params["Data"].get("mask_feat_p", 0.0)) + seeds = params["Data"].get("train_seeds", [8]) + if not seeds: + raise ValueError("train_seeds must contain at least one seed") project_dir = params["Data"]["project_dir"] - # Get values - config = {} - for name, bounds in params["Tune"].items(): - if isinstance(bounds["value"], int): - config[name] = int(bounds["value"]) - else: - config[name] = float(bounds["value"]) + config = models.tune_fixed(params) config["feat_size"] = params["Data"]["feat_size"] config["edge_dim"] = params["Data"]["edge_dim"] config["bce_weight"] = params["Data"]["bce_weight"] config["task"] = task + config["regression_loss"] = params["Data"].get("regression_loss", "mse") + config["huber_delta"] = params["Data"].get("huber_delta", 1.0) + config["warmup_epochs"] = int(params["Data"].get("warmup_epochs", 0)) params["Data"]["trial"] = None - # * Prepare dataloader - train_data = [data for data in dataset if data.set == "train"] - train_loader = DataLoader(train_data, batch_size=size, - shuffle=True, num_workers=workers, - persistent_workers=True) + size = int(config.get("batch_size", size)) + # * Prepare valid loader (shared across seeds; train loader is built + # inside the runner so its shuffle is pinned to each seed) + train_data = [data for data in dataset if data.set == "train"] valid_data = [data for data in dataset if data.set == "valid"] valid_loader = DataLoader(valid_data, batch_size=size, num_workers=workers, persistent_workers=True) - # * Get model - model, optim, schdlr, lossfn = models.model_setup(model_name, config) - - # * Train - best_metric = -1.0 - patience_counter = 0 - best_model_state = None - for epoch in tqdm(range(epochs), ncols=120, desc="Training"): - _ = train.train_epoch(model, optim, schdlr, lossfn, train_loader) - val_metrics = train.valid_epoch(model, task, valid_loader) - - # Early stopping check - if val_metrics["optim"] > best_metric: - patience_counter = 0 - best_metric = val_metrics["optim"] - best_model_state = copy.deepcopy(model.state_dict()) - else: - patience_counter += 1 - - # Check if we should stop early - if patience_counter >= patience: - break - - # Prepare best model - model.load_state_dict(best_model_state) + # * Train one model per seed, keep the best-seed checkpoint + per_seed_results = [] + best_state = None + best_optim = -float("inf") + for seed in seeds: + state, results, _, _ = runner.train_one_seed(int(seed), train_data, + valid_loader, + model_name, config, + epochs, patience, task, + grad_clip, drop_edge_p, + mask_feat_p, size, + workers) + per_seed_results.append(results) + if results["optim"] > best_optim: + best_optim = float(results["optim"]) + best_state = state + os.makedirs(project_dir, exist_ok=True) - torch.save(best_model_state, str(project_dir / "best_model.pt")) - results = train.valid_epoch(model, task, valid_loader) + torch.save(best_state, str(project_dir / "best_model.pt")) - if task == "classification": - print(f"MCC: {results['mcc']}") - else: - print(f"CCC: {results['ccc']}") + _report_seed_stats(task, per_seed_results, project_dir) def main(): diff --git a/frame/tune.py b/frame/tune.py index a8e8afb..d1f6a14 100644 --- a/frame/tune.py +++ b/frame/tune.py @@ -1,6 +1,5 @@ import os import time -import copy import uuid import shutil import argparse @@ -12,16 +11,42 @@ import optuna import numpy as np import pandas as pd -from tqdm import tqdm import plotly.graph_objects as go from torch_geometric.loader import DataLoader -from frame.source import models, train, utils +from frame.source import models, utils +from frame.source.train import runner device = "cuda" if torch.cuda.is_available() else "cpu" logger = utils.get_logger("TUNE", log_level="INFO") +def _aggregate_seed_metrics(per_seed_results): + """Mean each per-seed validation-metric dict, rounded to 3 dp.""" + keys = per_seed_results[0].keys() + return {key: round(float(np.mean([r[key] for r in per_seed_results])), 3) + for key in keys} + + +def _record_trial(trial, project_dir, best_state, per_seed_optim, + per_seed_results, total_time, n_params, best_seed): + """Persist best-seed checkpoint and write seed-aware trial attrs.""" + trial_dir = project_dir / f"trial_{trial.number}" + os.makedirs(trial_dir, exist_ok=True) + torch.save(best_state, str(trial_dir / "best_model.pt")) + + trial.set_user_attr("n_params", int(n_params)) + trial.set_user_attr("fit_time", float(round(total_time, 3))) + trial.set_user_attr("metrics", _aggregate_seed_metrics(per_seed_results)) + trial.set_user_attr("optim_mean", + round(float(np.mean(per_seed_optim)), 4)) + trial.set_user_attr("optim_std", + round(float(np.std(per_seed_optim)), 4)) + trial.set_user_attr("optim_per_seed", + [round(float(v), 4) for v in per_seed_optim]) + trial.set_user_attr("best_seed", int(best_seed)) + + def objective(trial, params, dataset): # Get static data from parameters epochs = params["Data"].get("epochs", 10) @@ -31,6 +56,12 @@ def objective(trial, params, dataset): model_name = params["Data"].get("model", "gat").lower() max_retries = params["Data"].get("max_retries", 5) task = params["Data"].get("task", "classification").lower() + grad_clip = params["Data"].get("grad_clip_norm", 1.0) + drop_edge_p = float(params["Data"].get("drop_edge_p", 0.0)) + mask_feat_p = float(params["Data"].get("mask_feat_p", 0.0)) + seeds = params["Data"].get("tune_seeds", [8]) + if not seeds: + raise ValueError("tune_seeds must contain at least one seed") project_dir = params["Data"]["project_dir"] @@ -40,67 +71,48 @@ def objective(trial, params, dataset): config["edge_dim"] = params["Data"]["edge_dim"] config["bce_weight"] = params["Data"]["bce_weight"] config["task"] = task + config["regression_loss"] = params["Data"].get("regression_loss", "mse") + config["huber_delta"] = params["Data"].get("huber_delta", 1.0) + config["warmup_epochs"] = int(params["Data"].get("warmup_epochs", 0)) params["Data"]["trial"] = trial + size = int(config.get("batch_size", size)) - # * Prepare dataloader + # * Prepare dataloader (valid is shared across seeds; train is rebuilt + # inside the runner so its shuffle is pinned to the seed) train_data = [data for data in dataset if data.set == "train"] - train_loader = DataLoader(train_data, batch_size=size, - shuffle=True, num_workers=workers, - persistent_workers=True) - valid_data = [data for data in dataset if data.set == "valid"] valid_loader = DataLoader(valid_data, batch_size=size, num_workers=workers, persistent_workers=True) - # * Get model - model, optim, schdlr, lossfn = models.model_setup(model_name, config) - - # * Train + # * Train one model per seed and aggregate retries = 0 while retries < max_retries: try: - best_metric = -1000 - patience_counter = 0 - best_model_state = None - - start = time.time() - for epoch in tqdm(range(epochs), ncols=120, desc="Training"): - _ = train.train_epoch(model, optim, schdlr, - lossfn, train_loader) - val_metrics = train.valid_epoch(model, task, valid_loader) - - # Early stopping check - if val_metrics["optim"] > best_metric: - patience_counter = 0 - best_metric = val_metrics["optim"] - best_model_state = copy.deepcopy(model.state_dict()) - else: - patience_counter += 1 - - # Check if we should stop early - if patience_counter >= patience: - break - - fit_time = time.time() - start - - # Prepare best model - model.load_state_dict(best_model_state) - trial_dir = project_dir / f"trial_{trial.number}" - os.makedirs(trial_dir, exist_ok=True) - torch.save(best_model_state, str(trial_dir / "best_model.pt")) - results = train.valid_epoch(model, task, valid_loader) - - # Get model complexity - n_params = filter(lambda p: p.requires_grad, model.parameters()) - sum_params = sum([np.prod(p.size()) for p in n_params]) - trial.set_user_attr("n_params", int(sum_params)) - - # Report time and metrics - trial.set_user_attr("fit_time", float(round(fit_time, 3))) - trial.set_user_attr("metrics", results) - - return results["optim"] + per_seed_optim = [] + per_seed_results = [] + best_state = None + best_optim = -float("inf") + best_seed = int(seeds[0]) + total_time = 0.0 + n_params = 0 + for seed in seeds: + state, results, fit_time, n_params = runner.train_one_seed( + int(seed), train_data, valid_loader, model_name, + config, epochs, patience, task, grad_clip, + drop_edge_p, mask_feat_p, size, workers) + per_seed_optim.append(float(results["optim"])) + per_seed_results.append(results) + total_time += fit_time + if results["optim"] > best_optim: + best_optim = float(results["optim"]) + best_state = state + best_seed = int(seed) + + _record_trial(trial, project_dir, best_state, per_seed_optim, + per_seed_results, total_time, n_params, + best_seed) + return float(np.mean(per_seed_optim)) except torch.cuda.OutOfMemoryError: retries += 1 @@ -132,12 +144,16 @@ def get_dataframe(study, task): val_metrics = trial.user_attrs.get("metrics", dummy) n_params = trial.user_attrs.get("n_params", np.nan) fit_time = trial.user_attrs.get("fit_time", np.nan) + optim_std = trial.user_attrs.get("optim_std", np.nan) + best_seed = trial.user_attrs.get("best_seed", np.nan) # Update dict record.update(val_metrics) record.update(trial.params) record.update({"n_params": n_params}) record.update({"fit_time": fit_time}) + record.update({"optim_std": optim_std}) + record.update({"best_seed": best_seed}) records.append(record) @@ -195,7 +211,8 @@ def main(): header = ["optim", "acc", "acc_bal", "f1", "prec", "rec", "mcc", "avg_prec", "roc_auc", "r2", "rmse", "mae", "rto_r2", "ccc", "roy_c", "roy_c_inv", - "delta", "n_params", "fit_time"] + "delta", "n_params", "fit_time", "optim_std", + "best_seed"] feats = [col for col in list(df.columns) if col not in header] feats = feats + ["optim"] diff --git a/parameters.yaml b/parameters.yaml index b3dc327..7e62f2d 100755 --- a/parameters.yaml +++ b/parameters.yaml @@ -4,33 +4,67 @@ Data: path_csv: "path/to/csv_file.csv" path_joblib: "path/to/joblib_file.joblib" path_checkpoint: "path/to/checkpoint.pt" - trials: 2 - epochs: 10 - model: gat # gcn, gat, attentive + + trials: 50 + epochs: 50 + model: gat # gcn | gat | gin | sage | attentive batch_size: 128 - patience: 5 - loader: default # "default" or "decompose" + patience: 15 + loader: default # default | decompose + grad_clip_norm: 1.0 + drop_edge_p: 0.0 # train-time edge drop probability + mask_feat_p: 0.0 # train-time node feature mask probability + regression_loss: mse # mse | huber | smooth_l1 + huber_delta: 1.0 # threshold for Huber / SmoothL1 + warmup_epochs: 0 # linear LR warmup epochs before cosine decay + tune_seeds: [13, 42, 73] # seeds averaged per Optuna trial + train_seeds: [13, 42, 73, 101, 202] # seeds for final retraining Tune: hidden_channels: min: 32 max: 512 + # value: 16 num_layers: min: 1 - max: 3 - heads: - min: 1 - max: 3 + max: 4 + # value: 2 dropout_rate: min: 0.1 max: 0.6 + # value: 0.3 + timesteps: + min: 1 + max: 4 + # value: 2 + heads: + min: 1 + max: 4 + # value: 3 + pool: + choices: [mean, add, max] + # value: mean + learning_rate: - value: 1e-3 + # min: 1.0e-4 + # max: 5.0e-3 + # log: true + value: 1.0e-3 weight_decay: - value: 1e-5 + # min: 1.0e-6 + # max: 1.0e-3 + # log: true + value: 1.0e-5 beta_1: + # min: 0.5 + # max: 0.9999999 value: 0.9 beta_2: + # min: 0.5 + # max: 0.9999999 value: 0.999 eps: - value: 1e-8 + # min: 1.0e-9 + # max: 1.0e-7 + # log: true + value: 1.0e-8 diff --git a/pyproject.toml b/pyproject.toml index bf60a3b..8963b1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,28 +1,26 @@ [build-system] -requires = ["setuptools==80.9.0", "wheel==0.45.1"] -build-backend = "setuptools.build_meta" - -[tool.setuptools] -packages = ["frame"] +requires = ["hatchling"] +build-backend = "hatchling.build" [project] name = "frame" version = "0.0.0" -requires-python = ">=3.12" +requires-python = ">=3.11" authors = [{name = "Rafael Lopes", email = "rafael.lopes@uef.fi"}] readme = "README.md" -dynamic = ["dependencies"] - -[project.optional-dependencies] -dev = ["bump-my-version==1.1.2", - "gitchangelog==3.0.4", - "pytest==8.3.5", - "pytest-cov==6.1.1", - "pytest-depends==1.0.1", - "pytest-flake8==1.3.0"] - -[tool.setuptools.dynamic] -dependencies = {file = ["requirements.txt"]} +description = "" +dependencies = ["torch==2.11.0", + "captum==0.9.0", + "joblib==1.5.3", + "optuna==4.8.0", + "pandas==3.0.2", + "plotly==6.7.0", + "rdkit==2026.3.1", + "optuna==4.8.0", + "pandas==3.0.2", + "scikit-learn==1.8.0", + "svgutils==0.3.4", + "torch-geometric==2.7.0"] [project.scripts] frame_tune = "frame.tune:main" @@ -30,3 +28,33 @@ frame_train = "frame.train:main" frame_gen = "frame.generate:main" frame_eval = "frame.evaluate:main" frame_explain = "frame.explain:main" +frame_scaffold = "frame.scaffold_split:main" + +[dependency-groups] +dev = ["bump-my-version==1.3.0", + "generate-changelog==0.17.0"] + +[tool.hatch.build.targets.wheel] +packages = ["frame"] + +[tool.uv] +package = true +find-links = ["https://download.pytorch.org/whl/cu128"] + +[tool.uv.sources] +torch = { index = "pytorch-cu128" } + +[[tool.uv.index]] +name = "pytorch-cu128" +url = "https://download.pytorch.org/whl/cu128" +explicit = true + +[tool.bumpversion] +current_version = "0.0.0" +commit = true +tag = true + +[[tool.bumpversion.files]] +filename = "pyproject.toml" +search = 'version = "{current_version}"' +replace = 'version = "{new_version}"' diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 8a7c519..0000000 --- a/requirements.txt +++ /dev/null @@ -1,9 +0,0 @@ -captum==0.8.0 -joblib==1.5.2 -optuna==4.5.0 -pandas==2.3.2 -plotly==6.5.0 -rdkit==2025.3.6 -scikit-learn==1.7.2 -svgutils==0.3.4 -torch-geometric==2.6.1