Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ venv.bak/
*.zip
*.png
*.pdf
*.md
*.json
*.lock

# Dirs
*temp*
Expand All @@ -68,3 +71,4 @@ venv.bak/
!*logo.txt
!*parameters*.yaml
!*requirements*.txt
!README.md
30 changes: 16 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,30 @@

This repository introduces **FRAME**, a framework for learning fragment-based molecular representations to enhance the interpretability of graph neural networks in drug discovery. FRAME represents chemically meaningful fragments as graph nodes and is compatible with several GNN architectures, including GCN, GAT, and AttentiveFP. It also integrates Integrated Gradients to generate more transparent and chemically grounded model explanations.

## ⚙️ **Installation**
1. Clone the repo:
## Installation

2. Create and activate your `virtualenv` with Python 3.12, for example as described [here](https://docs.python.org/3/library/venv.html).
FRAME is installed with [`uv`](https://docs.astral.sh/uv/), which picks the right `torch` wheels (CUDA 12.8) for you.

3. Install [PyTorch **2.8.0**](https://pytorch.org/get-started/locally/) using:
1. Install [`uv`](https://docs.astral.sh/uv/getting-started/installation/) if you don't already have it.
2. Clone the repo.
3. From the project root, run:

```console
pip install torch==2.8.0 -f https://download.pytorch.org/whl/cu129
uv sync
```

4. Install FRAME using:
That creates a `.venv/` with Python 3.11+ and installs PocketGraph along with everything it depends on. You can use prefix commands with `uv run` (e.g. `uv run frame_tune -c parameters.yaml`).


To install the `frame_*` commands globally (isolated in their own environment, available on your `PATH` without having to activate a venv), use `uv tool install`:

```console
python -m pip install .
```
or for development:
```console
python -m pip install -e .
uv tool install .
```

## 📂 Dataset Requirements
If you'd rather not use `uv`, you can install the dependencies declared in [pyproject.toml](pyproject.toml) directly with `pip` in a Python 3.11+ environment.

## Dataset Requirements
The CSV file used in FRAME **must** include the following columns:

- **`id`** – A unique identifier for each entry.
Expand All @@ -39,7 +41,7 @@ The CSV file used in FRAME **must** include the following columns:
Please ensure that all entries follow this structure so the dataset can be correctly loaded and processed by the pipeline.


## 📄 Configuration
## Configuration
All model parameters and runtime settings are defined in a YAML configuration file.
An example file, [`parameters.yaml`](./parameters.yaml), is provided.

Expand All @@ -62,7 +64,7 @@ Tune:
value: 64
```

## 🔎 **Usage**
## **Usage**
All entry points accept a `-c/--config` parameter pointing to the YAML config file.

- Generate a processed dataset:
Expand Down
7 changes: 1 addition & 6 deletions frame/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,7 @@ def main():
params = yaml.safe_load(stream)

config = params["Data"]
tune = {}
for name, bounds in params["Tune"].items():
if isinstance(bounds["value"], int):
tune[name] = int(bounds["value"])
else:
tune[name] = float(bounds["value"])
tune = models.tune_fixed(params)

path_checkpoint = config["path_checkpoint"]
model_name = config.get("model", "gat").lower()
Expand Down
7 changes: 1 addition & 6 deletions frame/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,7 @@ def main():
params = yaml.safe_load(stream)

config = params["Data"]
tune = {}
for name, bounds in params["Tune"].items():
if isinstance(bounds["value"], int):
tune[name] = int(bounds["value"])
else:
tune[name] = float(bounds["value"])
tune = models.tune_fixed(params)

path_checkpoint = config["path_checkpoint"]
model_name = config.get("model", "gat").lower()
Expand Down
36 changes: 36 additions & 0 deletions frame/scaffold_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import argparse

import pandas as pd

from frame.source.datasets import scaffold_split


def main():
parser = argparse.ArgumentParser(
description=("Rewrite the `set` column of a CSV using a Murcko "
"scaffold split. Run before frame_gen."))
parser.add_argument("-i", "--input", required=True,
help="Path to input CSV with id/smiles/label/set.")
parser.add_argument("-o", "--output", required=True,
help="Path to output CSV.")
parser.add_argument("--fracs", nargs=3, type=float,
default=[0.8, 0.1, 0.1],
metavar=("TRAIN", "VALID", "TEST"),
help="Split fractions (default: 0.8 0.1 0.1).")
parser.add_argument("--chirality", action="store_true",
help="Include chirality in scaffold definition.")
args = parser.parse_args()

df = pd.read_csv(args.input)
if "smiles" not in df.columns:
raise ValueError("Input CSV must have a `smiles` column.")

sets = scaffold_split(df["smiles"].tolist(),
fracs=tuple(args.fracs),
include_chirality=args.chirality)
df["set"] = sets

counts = df["set"].value_counts().to_dict()
print(f"Scaffold split: {counts}")

df.to_csv(args.output, index=False)
4 changes: 3 additions & 1 deletion frame/source/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from frame.source.datasets.default import MolecularDataset
from frame.source.datasets.decompose import DecomposeDataset
from frame.source.datasets.scaffold import scaffold_split

__all__ = ["MolecularDataset",
"DecomposeDataset"]
"DecomposeDataset",
"scaffold_split"]
65 changes: 65 additions & 0 deletions frame/source/datasets/scaffold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from collections import defaultdict

from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem.Scaffolds import MurckoScaffold


lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)


def _scaffold(smiles, include_chirality=False):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return ""
return MurckoScaffold.MurckoScaffoldSmiles(
mol=mol, includeChirality=include_chirality)


def scaffold_split(smiles_list, fracs=(0.8, 0.1, 0.1),
include_chirality=False):
"""Murcko-scaffold split.

Largest scaffold groups go to train; smaller ones fill valid then test.
Molecules with the same scaffold never cross splits, which gives a more
realistic generalization signal than a random split for drug discovery.

Args:
smiles_list: list of SMILES strings.
fracs: (train, valid, test) fractions; must sum to ~1.0.
include_chirality: pass-through to MurckoScaffoldSmiles.

Returns:
list of "train" / "valid" / "test", aligned with smiles_list.
"""
if abs(sum(fracs) - 1.0) > 1e-6:
raise ValueError(f"fracs must sum to 1.0, got {fracs}")

n = len(smiles_list)
train_target = int(round(fracs[0] * n))
valid_target = int(round(fracs[1] * n))

groups = defaultdict(list)
for i, smi in enumerate(smiles_list):
groups[_scaffold(smi, include_chirality)].append(i)

# Largest groups first; tiebreak on the scaffold key for determinism.
sorted_groups = sorted(groups.items(),
key=lambda kv: (-len(kv[1]), kv[0]))

sets = ["test"] * n
train_count = 0
valid_count = 0
for _, indices in sorted_groups:
if train_count + len(indices) <= train_target:
for i in indices:
sets[i] = "train"
train_count += len(indices)

elif valid_count + len(indices) <= valid_target:
for i in indices:
sets[i] = "valid"
valid_count += len(indices)

return sets
89 changes: 76 additions & 13 deletions frame/source/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import torch
from torch.optim.lr_scheduler import (LinearLR,
SequentialLR,
CosineAnnealingLR)

from frame.source import train
from frame.source.models import pyg_models

device = "cuda" if torch.cuda.is_available() else "cpu"


def model_setup(model_name, config):
def model_setup(model_name, config, epochs=100):
task = config["task"]
model = select_model(model_name, config)

Expand All @@ -16,16 +20,42 @@ def model_setup(model_name, config):
eps=config["eps"],
weight_decay=config["weight_decay"])
optimizer = train.Lookahead(base_optimizer, k=5, alpha=0.5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
T_max=100,
eta_min=1e-6)

warmup_epochs = int(config.get("warmup_epochs", 0))
eta_min = float(config.get("lr_min", 1e-6))
if warmup_epochs > 0 and warmup_epochs < epochs:
warmup = LinearLR(optimizer, start_factor=0.1,
total_iters=warmup_epochs)
cosine = CosineAnnealingLR(optimizer,
T_max=max(1, epochs - warmup_epochs),
eta_min=eta_min)
scheduler = SequentialLR(optimizer,
schedulers=[warmup, cosine],
milestones=[warmup_epochs])
else:
scheduler = CosineAnnealingLR(optimizer, T_max=max(1, epochs),
eta_min=eta_min)

if task == "classification":
bce_weight = config["bce_weight"]
lossfn = torch.nn.BCEWithLogitsLoss(pos_weight=bce_weight).to(device)

else:
lossfn = torch.nn.MSELoss()
reg_loss = str(config.get("regression_loss", "mse")).lower()
delta = float(config.get("huber_delta", 1.0))

if reg_loss == "mse":
lossfn = torch.nn.MSELoss()

elif reg_loss == "huber":
lossfn = torch.nn.HuberLoss(delta=delta)

elif reg_loss == "smooth_l1":
lossfn = torch.nn.SmoothL1Loss(beta=delta)

else:
raise ValueError(f"Unknown regression_loss: {reg_loss}. "
"Choose from mse, huber, smooth_l1.")

return model, optimizer, scheduler, lossfn

Expand All @@ -47,21 +77,54 @@ def select_model(model_name, config):
return model


def _cast_value(val):
if isinstance(val, bool):
return val
if isinstance(val, int):
return int(val)
if isinstance(val, float):
return float(val)
return str(val)


def tune_fixed(params):
out = {}
for name, bounds in params["Tune"].items():
if "value" not in bounds:
continue
out[name] = _cast_value(bounds["value"])
return out


def optuna_suggest(params, trial):
configs = {}

for name, bounds in params["Tune"].items():
if "min" in bounds:
if isinstance(bounds["max"], int):
if "choices" in bounds:
configs[name] = trial.suggest_categorical(name, bounds["choices"])
elif "min" in bounds:
log = bool(bounds.get("log", False))
if isinstance(bounds["max"], int) and not log:
configs[name] = trial.suggest_int(name, bounds["min"],
bounds["max"])
else:
configs[name] = trial.suggest_float(name, float(bounds["min"]),
float(bounds["max"]))
configs[name] = trial.suggest_float(name,
float(bounds["min"]),
float(bounds["max"]),
log=log)
else:
if isinstance(bounds["value"], int):
configs[name] = int(bounds["value"])
else:
configs[name] = float(bounds["value"])
configs[name] = _cast_value(bounds["value"])

# Round hidden_channels to match n_heads, and log
model_name = str(params["Data"].get("model", "")).lower()
if (model_name == "gat"):
heads = int(configs["heads"])
original = int(configs["hidden_channels"])
rounded = (original // heads) * heads

if rounded != original:
configs["hidden_channels"] = rounded
trial.set_user_attr("hidden_channels_suggested", original)
trial.set_user_attr("hidden_channels_used", rounded)

return configs
Loading
Loading