diff --git a/.github/workflows/flake8.yml b/.github/workflows/flake8.yml index 2ef44da..aad7180 100755 --- a/.github/workflows/flake8.yml +++ b/.github/workflows/flake8.yml @@ -16,4 +16,4 @@ jobs: - name: flake8 Lint uses: py-actions/flake8@v2 with: - ignore: "W504" \ No newline at end of file + ignore: "W504,W503" \ No newline at end of file diff --git a/frame/explain.py b/frame/explain.py index fabb25f..9b20d8e 100644 --- a/frame/explain.py +++ b/frame/explain.py @@ -7,98 +7,263 @@ import torch import joblib from tqdm import tqdm -from torch_geometric.explain import Explainer, CaptumExplainer +from torch_geometric.explain import (Explainer, + CaptumExplainer, + GNNExplainer) +from torch_geometric.loader import DataLoader from frame.source import explain, models -from torch_geometric.loader import DataLoader +from frame.source.explain import aggregate +from frame.source.explain import pharmacophores + device = "cuda" if torch.cuda.is_available() else "cpu" +ALGORITHMS = ("ig", "gnnex") +BASELINES = ("native", "aggregated") -def main(): - args_parser = argparse.ArgumentParser() - args_parser.add_argument("-c", "--config", dest="config", required=True) - args = args_parser.parse_args() - with open(args.config) as stream: - params = yaml.safe_load(stream) - config = params["Data"] - tune = models.tune_fixed(params) +def _parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--config", dest="config", required=True) + parser.add_argument("--algorithm", choices=ALGORITHMS, default="ig", + help="Attribution algorithm to use.") + parser.add_argument("--baseline", choices=BASELINES, default="native", + help="`native` explains the loader in metadata; " + "`aggregated` runs the atom-level model and " + "aggregates attributions per BRICS fragment.") + parser.add_argument("--pharmacophore", default=None, + choices=list(pharmacophores.CLASSIFIERS), + help="Case study for fragment hit rate.") + parser.add_argument("--gnnex-epochs", type=int, default=100, + help="Epochs for GNNExplainer mask optimisation.") + return parser.parse_args() - path_checkpoint = config["path_checkpoint"] - model_name = config.get("model", "gat").lower() - batch_size = config.get("batch_size", 64) - task = config.get("task", "classification").lower() - # * Initialize - name = config["name"] +def _resolve_name(config: dict): + name = config.get("name", "none") if name.lower() == "none": name = str(uuid.uuid4()).split("-")[0] config["name"] = name + return name + + +def _build_algorithm(algorithm: str, epochs: int): + if algorithm == "ig": + return CaptumExplainer("IntegratedGradients") + return GNNExplainer(epochs=epochs) + + +def _build_explainer(model, algorithm: str, task: str, epochs: int): + mode = ("multiclass_classification" + if task == "classification" else "regression") + return Explainer(model=model, + algorithm=_build_algorithm(algorithm, epochs), + explanation_type="model", + edge_mask_type="object", + node_mask_type="attributes", + model_config=dict(mode=mode, + task_level="graph", + return_type="raw")) + + +def _load_artefacts(config: dict, joblib_key: str, checkpoint_key: str): + """Load a dataset joblib + trained checkpoint by config key. + + Args: + config: params["Data"] block. + joblib_key: Key into config pointing to the dataset joblib. + checkpoint_key: Key into config pointing to the model + checkpoint. + + Returns: + Dict with keys dataset, metadata, state_dict. + + Raises: + KeyError: If either key is missing. + FileNotFoundError: If either path does not exist. + """ + joblib_path = Path(config[joblib_key]) + ckpt_path = Path(config[checkpoint_key]) + data = joblib.load(joblib_path) + state = torch.load(ckpt_path, map_location=device) + return {"dataset": data["dataset"], + "metadata": data["metadata"], + "state_dict": state} + + +def _model_config(tune: dict, metadata: dict, task: str, params: dict): + """Merge the fixed tune config with dataset metadata for model init.""" + cfg = dict(tune) + cfg["feat_size"] = metadata["feat_size"] + cfg["edge_dim"] = metadata["edge_dim"] + cfg["bce_weight"] = metadata["bce_weight"] + cfg["task"] = task + cfg["regression_loss"] = params["Data"].get("regression_loss", "mse") + return cfg + + +def _read_predictions(task: str, model_out): + """Return (logit_list, pred, pred_lbl) for a forward pass.""" + if task == "classification": + logit = model_out.cpu().detach() + logit_list = list(torch.ravel(logit).numpy()) + detach = torch.sigmoid(logit) + pred = list(torch.ravel(detach).numpy()) + pred_lbl = (detach >= 0.5).int() + else: + detach = model_out.cpu().detach() + pred = list(torch.ravel(detach).numpy()) + logit_list = pred + pred_lbl = [None] * detach.shape[0] + return logit_list, pred, pred_lbl + +def _run_native(model, dataloader, explainer, task: str, + mol_exp: explain.MolExplain): + """Iterate batches, run explainer, hand off to MolExplain.""" + for batch in tqdm(dataloader, ncols=120, desc="Explaining"): + batch.to(device) + model_out = model(x=batch.x.float(), + edge_index=batch.edge_index, + edge_attr=batch.edge_attr.float(), + batch=batch.batch) + logit_list, pred, pred_lbl = _read_predictions(task, model_out) + explanation = explainer(batch.x.float(), batch.edge_index, + edge_attr=batch.edge_attr.float(), + batch=batch.batch) + mol_exp.process_batch(explanation, logit_list, pred, pred_lbl, + batch.to_data_list()) + + +def _run_aggregated(model, dataloader, explainer, task: str, + smiles_index: dict, mol_exp: explain.MolExplain): + """Aggregate atom-level attributions into fragment-level scores.""" + for batch in tqdm(dataloader, ncols=120, desc="Aggregating"): + batch.to(device) + model_out = model(x=batch.x.float(), + edge_index=batch.edge_index, + edge_attr=batch.edge_attr.float(), + batch=batch.batch) + logit_list, pred, pred_lbl = _read_predictions(task, model_out) + explanation = explainer(batch.x.float(), batch.edge_index, + edge_attr=batch.edge_attr.float(), + batch=batch.batch) + graphs = batch.to_data_list() + for graph in graphs: + record = smiles_index.get(getattr(graph, "smiles", None)) + if record is not None: + graph.agg_atom_map = dict(enumerate( + record["atom_to_fragment"])) + agg = aggregate.aggregated_batch_masks(explanation.node_mask, + explanation.batch, + graphs, smiles_index) + mol_exp.process_aggregated_batch(agg, logit_list, pred, + pred_lbl, graphs) + + +def main(): + args = _parse_args() + with open(args.config) as stream: + params = yaml.safe_load(stream) + + config = params["Data"] + name = _resolve_name(config) cwd = Path(os.getcwd()) - project_dir = cwd / "output" / name - out = project_dir / "explain" - os.makedirs(out, exist_ok=True) - - # * Load dataset - path_joblib = Path(config["path_joblib"]) - data = joblib.load(path_joblib) - dataset = data["dataset"] - tune["feat_size"] = data["metadata"]["feat_size"] - tune["edge_dim"] = data["metadata"]["edge_dim"] - tune["bce_weight"] = data["metadata"]["bce_weight"] - loader = data["metadata"]["loader"] + out_root = cwd / "output" / name / "explain" / args.algorithm + out_dir = out_root / args.baseline + os.makedirs(out_dir, exist_ok=True) + + tune = models.tune_fixed(params) + task = config.get("task", "classification").lower() + batch_size = int(config.get("batch_size", 64)) + model_name = config.get("model", "gat").lower() + primary = _load_artefacts(config, "path_joblib", "path_checkpoint") + cfg = _model_config(tune, primary["metadata"], task, params) + model = models.select_model(model_name, cfg) + model.load_state_dict(primary["state_dict"]) + model.eval() + + explainer = _build_explainer(model, args.algorithm, task, + args.gnnex_epochs) + dataset = primary["dataset"] + loader = primary["metadata"]["loader"] dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, persistent_workers=True) - # * Get checkpoint and prepare Explainer - model = models.select_model(model_name, tune) - model.load_state_dict(torch.load(path_checkpoint)) - model.eval() - - if task == "classification": - mode = "multiclass_classification" + if args.baseline == "aggregated": + if loader != "default": + raise ValueError("aggregated baseline requires path_joblib to " + "point at the atom-level dataset (loader=" + "default); got loader=" + loader) + frag_path = config.get("path_joblib_frag") + if frag_path is None: + raise ValueError("aggregated baseline requires " + "Data.path_joblib_frag in the config.") + frag_data = joblib.load(Path(frag_path)) + smiles_index = aggregate.build_smiles_index(frag_data["dataset"]) + mol_exp = explain.MolExplain("decompose", out_dir, + algorithm=args.algorithm) + _run_aggregated(model, dataloader, explainer, task, + smiles_index, mol_exp) else: - mode = "regression" - explainer = Explainer(model=model, - algorithm=CaptumExplainer("IntegratedGradients"), - explanation_type="model", - edge_mask_type="object", - node_mask_type="attributes", - model_config=dict(mode=mode, - task_level="graph", - return_type="raw")) - - for data in tqdm(dataloader, ncols=120, desc="Explaining"): - data.to(device) - - # * Make predictions - model_out = model(x=data.x.float(), - edge_index=data.edge_index, - edge_attr=data.edge_attr.float(), - batch=data.batch) - - # * Read prediction values - if task == "classification": - logit = model_out.cpu().detach() - logit_list = list(torch.ravel(logit).numpy()) - detach = torch.sigmoid(logit) - pred = list(torch.ravel(detach).numpy()) - pred_lbl = (detach >= 0.5).int() - else: - detach = model_out.cpu().detach() - pred = list(torch.ravel(detach).numpy()) - logit_list = pred - pred_lbl = [None] * detach.shape[0] - - # * Explain - explanation = explainer(data.x.float(), data.edge_index, - edge_attr=data.edge_attr.float(), - batch=data.batch) - - mol_exp = explain.MolExplain(explanation, logit_list, pred, pred_lbl, - loader, out) - mol_exp.retrieve_info(data) - mol_exp.plot_explanations(data) + mol_exp = explain.MolExplain(loader, out_dir, args.algorithm) + _run_native(model, dataloader, explainer, task, mol_exp) + + classifier = None + class_names = None + if args.pharmacophore is not None: + classifier = pharmacophores.get_classifier(args.pharmacophore) + class_names = pharmacophores.get_class_names(args.pharmacophore) + mol_exp.finalize(classifier=classifier, class_names=class_names) + + # Cross-explainer Spearman: if the *other* algorithm has already + # produced records under the same baseline directory, write a + # summary JSON next to this run. + other = "gnnex" if args.algorithm == "ig" else "ig" + other_dir = cwd / "output" / name / "explain" / other / args.baseline + other_records = other_dir / "records.npz" + if other_records.exists(): + peer = _load_records_npz(other_records) + spearman_path = out_dir / "cross_explainer_spearman.json" + other_path = other_dir / "cross_explainer_spearman.json" + explain.spearman_between_runs(mol_exp.records, peer, spearman_path) + explain.spearman_between_runs(peer, mol_exp.records, other_path) + _dump_records_npz(mol_exp.records, out_dir / "records.npz") + + +def _dump_records_npz(records, path): + """Persist accumulated per-molecule scores for later Spearman pairing. + + Stores SMILES, fragments, and the 1-D scores vector per molecule in a + single .npz archive (one entry per molecule, plus a manifest list of + SMILES). This is the minimum needed to recompute cross-explainer + Spearman after the fact, without redoing the attribution run. + """ + import numpy as np + smiles = [r["smiles"] for r in records] + payload = {"smiles": np.array(smiles, dtype=object)} + for i, r in enumerate(records): + payload[f"scores_{i}"] = np.asarray(r["scores"], dtype=float) + payload[f"frags_{i}"] = np.array(r["fragments"], dtype=object) + payload[f"top_{i}"] = np.array( + "" if r["top_fragment"] is None else r["top_fragment"], + dtype=object) + np.savez(path, **payload) + + +def _load_records_npz(path): + """Load records saved by `_dump_records_npz`.""" + import numpy as np + arch = np.load(path, allow_pickle=True) + smiles = list(arch["smiles"]) + records = [] + for i, smi in enumerate(smiles): + records.append({"smiles": str(smi), + "scores": arch[f"scores_{i}"], + "fragments": list(arch[f"frags_{i}"]), + "top_fragment": (str(arch[f"top_{i}"]) + if str(arch[f"top_{i}"]) != "" + else None)}) + return records diff --git a/frame/generate.py b/frame/generate.py index 5a3ad50..c61a026 100644 --- a/frame/generate.py +++ b/frame/generate.py @@ -1,4 +1,5 @@ import os +import json import uuid import shutil import argparse @@ -7,12 +8,47 @@ import yaml import torch import joblib +import numpy as np from frame.source import datasets device = "cuda" if torch.cuda.is_available() else "cpu" +def _graph_size_stats(dataset): + """Return mean/std of node and edge counts across a dataset.""" + n_nodes = [] + n_edges = [] + for data in dataset: + n_nodes.append(int(data.x.shape[0])) + n_edges.append(int(data.edge_index.shape[1])) + if not n_nodes: + return {"n_graphs": 0, + "nodes": {"mean": 0.0, "std": 0.0}, + "edges": {"mean": 0.0, "std": 0.0}} + return {"n_graphs": len(n_nodes), + "nodes": {"mean": float(np.mean(n_nodes)), + "std": float(np.std(n_nodes)), + "min": int(np.min(n_nodes)), + "max": int(np.max(n_nodes))}, + "edges": {"mean": float(np.mean(n_edges)), + "std": float(np.std(n_edges)), + "min": int(np.min(n_edges)), + "max": int(np.max(n_edges))}} + + +def _write_dataset_stats(dataset, loader: str, path_csv: str, + project_dir: Path): + """Persist graph-size and BRICS-exclusion stats to dataset_stats.json.""" + stats = {"loader": loader, + "source_csv": str(path_csv), + "graph_size": _graph_size_stats(dataset)} + if loader == "decompose" and hasattr(dataset, "exclusion_summary"): + stats["brics_exclusion"] = dataset.exclusion_summary() + with open(project_dir / "dataset_stats.json", "w") as fh: + json.dump(stats, fh, indent=2) + + def main(): args_parser = argparse.ArgumentParser() args_parser.add_argument("-c", "--config", dest="config", required=True) @@ -56,5 +92,7 @@ def main(): dump_data = {"dataset": dataset, "metadata": metadata} joblib.dump(dump_data, project_dir / "data.joblib") + _write_dataset_stats(dataset, loader, path_csv, project_dir) + if os.path.isdir(cwd / "???"): shutil.rmtree(cwd / "???") diff --git a/frame/scaffold_split.py b/frame/scaffold_split.py index 6b7cf04..8bd4462 100644 --- a/frame/scaffold_split.py +++ b/frame/scaffold_split.py @@ -1,10 +1,44 @@ +import json import argparse +from collections import Counter +from pathlib import Path import pandas as pd +from rdkit import Chem +from rdkit import RDLogger +from rdkit.Chem.Scaffolds import MurckoScaffold from frame.source.datasets import scaffold_split +lg = RDLogger.logger() +lg.setLevel(RDLogger.CRITICAL) + + +def _murcko(smiles: str, include_chirality: bool): + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return "" + return MurckoScaffold.MurckoScaffoldSmiles( + mol=mol, includeChirality=include_chirality) + + +def _scaffold_stats(smiles_list: list, sets: list, include_chirality: bool): + """Summarise scaffold counts and split sizes.""" + scaffolds = [_murcko(s, include_chirality) for s in smiles_list] + group_sizes = Counter(scaffolds) + set_counts = Counter(sets) + return {"n_molecules": len(smiles_list), + "n_scaffolds": len(group_sizes), + "n_singleton_scaffolds": sum(1 for v in group_sizes.values() + if v == 1), + "largest_scaffold_group": max(group_sizes.values(), + default=0), + "split_sizes": {"train": set_counts.get("train", 0), + "valid": set_counts.get("valid", 0), + "test": set_counts.get("test", 0)}} + + def main(): parser = argparse.ArgumentParser( description=("Rewrite the `set` column of a CSV using a Murcko " @@ -25,7 +59,8 @@ def main(): if "smiles" not in df.columns: raise ValueError("Input CSV must have a `smiles` column.") - sets = scaffold_split(df["smiles"].tolist(), + smiles_list = df["smiles"].tolist() + sets = scaffold_split(smiles_list, fracs=tuple(args.fracs), include_chirality=args.chirality) df["set"] = sets @@ -34,3 +69,8 @@ def main(): print(f"Scaffold split: {counts}") df.to_csv(args.output, index=False) + + stats = _scaffold_stats(smiles_list, sets, args.chirality) + stats_path = Path(args.output).with_name("scaffold_stats.json") + with open(stats_path, "w") as fh: + json.dump(stats, fh, indent=2) diff --git a/frame/source/datasets/decompose.py b/frame/source/datasets/decompose.py index 2ba9884..f8347dd 100644 --- a/frame/source/datasets/decompose.py +++ b/frame/source/datasets/decompose.py @@ -35,8 +35,25 @@ class DecomposeDataset(InMemoryDataset): + """BRICS fragment-level molecular dataset. + + Molecules with no BRICS-cleavable bonds cannot be decomposed and + are skipped. The SMILES and ids of skipped molecules are kept on + the dataset so frame.generate can report the exclusion rate + required by the manuscript. + + Attributes: + excluded_smiles: SMILES strings of molecules skipped because + BRICS could not decompose them. + excluded_ids: Aligned list of dataset ids. + n_total: Total number of input rows (incl. excluded). + """ + def __init__(self, path: str, transform=None, pre_transform=None): self.path = path + self.excluded_smiles = [] + self.excluded_ids = [] + self.n_total = 0 super().__init__(None, transform, pre_transform, log=False) data_list = self.process_data() @@ -55,6 +72,7 @@ def process_data(self): col_id = cols.index("id") dataset = dataset[1:-1] + self.n_total = len(dataset) # * Iterate data_list = [] @@ -74,32 +92,46 @@ def process_data(self): # Create graph object frags, frag_map, atom_map = _get_map(mol_smiles) - if frags is not None: - xs = [] - for frag in frags: - xs.append(_gen_features(frag)) - x = torch.stack(xs, dim=0) - - mapping = [list(atom_map.keys()), list(atom_map.values())] - - edges = [] - for u, v in frag_map: - edges.append((u, v)) - edges.append((v, u)) - edge_index = torch.tensor(edges, dtype=torch.long) - edge_index = edge_index.t().contiguous() - edge_attr = torch.ones(edge_index.size(1), 1) - - data = Data(x=x, edge_index=edge_index, - edge_attr=edge_attr, y=y, - idx=mol_idx, set=mol_set, - frag=frags, atom_map=mapping, - smiles=mol_smiles) - - data_list.append(data) + if frags is None: + self.excluded_smiles.append(mol_smiles) + self.excluded_ids.append(mol_idx) + continue + + xs = [] + for frag in frags: + xs.append(_gen_features(frag)) + x = torch.stack(xs, dim=0) + + mapping = [list(atom_map.keys()), list(atom_map.values())] + + edges = [] + for u, v in frag_map: + edges.append((u, v)) + edges.append((v, u)) + edge_index = torch.tensor(edges, dtype=torch.long) + edge_index = edge_index.t().contiguous() + edge_attr = torch.ones(edge_index.size(1), 1) + + data = Data(x=x, edge_index=edge_index, + edge_attr=edge_attr, y=y, + idx=mol_idx, set=mol_set, + frag=frags, atom_map=mapping, + smiles=mol_smiles) + + data_list.append(data) return data_list + def exclusion_summary(self): + """Return {n_total, n_excluded, fraction, excluded} dict.""" + n_excl = len(self.excluded_smiles) + frac = (n_excl / self.n_total) if self.n_total else 0.0 + return {"n_total": self.n_total, + "n_excluded": n_excl, + "fraction_excluded": frac, + "excluded_smiles": list(self.excluded_smiles), + "excluded_ids": list(self.excluded_ids)} + @property def raw_file_names(self): return [] diff --git a/frame/source/explain/__init__.py b/frame/source/explain/__init__.py index b555819..f666da2 100644 --- a/frame/source/explain/__init__.py +++ b/frame/source/explain/__init__.py @@ -1,14 +1,22 @@ +import json +from typing import Callable, Sequence + import torch import numpy as np -from lxml import etree from rdkit import Chem import matplotlib as mpl import matplotlib.pyplot as plt -import svgutils.transform as sg from rdkit.Chem.Draw import rdMolDraw2D from matplotlib.colors import LinearSegmentedColormap from matplotlib.colors import Normalize, TwoSlopeNorm +from frame.source.explain import metrics_explain +from frame.source.explain.metrics_explain import (fragment_scores, + fragment_hit_rate, + mean_gini, + spearman_cross_explainer, + top_fragment) + mpl.use("Agg") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -24,151 +32,213 @@ class MolExplain: - def __init__(self, explanation, logit, pred, pred_lbl, loader, - out_dir, k=10): - self.mask = explanation.node_mask.detach().cpu() - self.batch = explanation.batch.detach().cpu() - self.pred = pred - self.pred_lbl = pred_lbl - self.logit = logit + """Driver for per-molecule explanation artefacts and metrics. + + The constructor wires up output paths and metric accumulators. + `process_batch` is called once per minibatch by the CLI driver + and handles native attributions (atom-level or fragment-level). + `process_aggregated_batch` consumes pre-aggregated fragment + masks from the aggregated atom-level baseline. `finalize` + writes explain_metrics.json. + + Args: + loader: "default" for atom-level data or "decompose" for + fragment-level data. Controls CSV header and visualisation + branch. + out_dir: Directory where predictions.csv, SVGs, and + explain_metrics.json are written. + algorithm: Label included in the metrics JSON for downstream + cross-explainer comparison. Typically "ig" or + "gnnex". + k: Number of top features per fragment for the optional bar + plot. Currently unused in CSV output. + """ + + def __init__(self, loader: str, out_dir, algorithm: str = "ig", + k: int = 10): self.loader = loader self.out = out_dir + self.algorithm = algorithm self.k = k self.cut = k // 2 - self.labels = np.array(V1) + self.records = [] + self._init_predictions_file() - def retrieve_info(self, graphs): + def _init_predictions_file(self): + """Write the predictions CSV header (overwriting any prior file).""" if self.loader == "default": - header = "id,smiles,real,pred_label,pred,fragment" - with open(self.out / "predictions.csv", "w") as f: - f.write(f"{header}\n") - - self._info_atom(graphs) - + header = "id,smiles,real,pred_label,pred,fragment\n" else: labels = ",".join(V1) - header = "id,smiles,real,pred_label,pred" - with open(self.out / "predictions.csv", "w") as f: - f.write(f"{header+labels}\n") - - self._info_fragment(graphs) - - def _info_atom(self, graphs): - batch_num = self.batch.unique() - masks = [self.mask[self.batch == b] for b in batch_num] - pred_label = "" - + header = "id,smiles,real,pred_label,pred," + labels + "\n" + with open(self.out / "predictions.csv", "w") as f: + f.write(header) + + def process_batch(self, explanation, logit, pred, pred_lbl, graphs): + """Consume one minibatch's native explanation. + + Splits the batched attribution by graph, writes one row per + atom-level molecule or one row per fragment per fragment-level + molecule to predictions.csv, and appends one record per + fragment-level molecule to :attr:`records` for later metric + aggregation. Visualisation is delegated to + `plot_explanations`. + + Args: + explanation: A PyG Explanation with attributes + node_mask and batch. + logit: Per-graph raw model output (list). + pred: Per-graph post-sigmoid probability (classification) + or raw output (regression). + pred_lbl: Per-graph predicted hard label + (classification) or list of None (regression). + graphs: Iterable of Data objects in batch order. + """ + mask = explanation.node_mask.detach().cpu() + batch = explanation.batch.detach().cpu() + if self.loader == "default": + self._info_atom(mask, batch, logit, pred, pred_lbl, graphs) + else: + self._info_fragment(mask, batch, logit, pred, pred_lbl, graphs) + self.plot_explanations(mask, batch, logit, pred, pred_lbl, graphs) + + def process_aggregated_batch(self, agg_results, logit, pred, + pred_lbl, graphs): + """Consume one minibatch's aggregated atom-level baseline. + + agg_results is the output of + `frame.source.explain.aggregate.aggregated_batch_masks`, + one dict (or None for skipped graphs) per atom-level graph. + For each non-skipped graph this writes the same predictions row + layout as `_info_fragment`, appends an aggregated record, + and renders the SVG via the fragment-level visualiser. + + Args: + agg_results: List from + `aggregate.aggregated_batch_masks`. + logit: Per-graph raw model output. + pred: Per-graph probability (cls) or raw output (reg). + pred_lbl: Per-graph predicted label or list of None. + graphs: Atom-level Data objects in batch order. + """ + for idx, record in enumerate(agg_results): + if record is None: + continue + data = graphs[idx] + mask = record["mask"] + fragments = record["fragments"] + label_pred = (pred_lbl[idx].numpy()[0] + if pred_lbl[idx] is not None else "") + self._write_fragment_row(data, mask, fragments, + pred[idx], label_pred) + self._append_record(data, mask, fragments) + self._plot_aggregated(data, mask, pred[idx], logit[idx], + pred_lbl[idx]) + + def _info_atom(self, mask, batch, logit, pred, pred_lbl, graphs): + """Write one CSV row per atom-level molecule.""" + batch_num = batch.unique() + masks = [mask[batch == b] for b in batch_num] for idx in range(len(masks)): data = graphs[idx] real_label = int(data.y.cpu().numpy()[0]) - pred = self.pred[idx] - - if self.pred_lbl[idx] is not None: - pred_label = self.pred_lbl[idx].numpy()[0] - + label_pred = (pred_lbl[idx].numpy()[0] + if pred_lbl[idx] is not None else "") text = (f"{data.idx},{data.smiles},{real_label}," - f"{pred_label},{pred:.3f}\n") - - # * Export prediction + f"{label_pred},{pred[idx]:.3f}\n") with open(self.out / "predictions.csv", "a") as f: f.writelines(text) - def _info_fragment(self, graphs): - batch_num = self.batch.unique() - masks = [self.mask[self.batch == b] for b in batch_num] - pred_label = "" - + def _info_fragment(self, mask, batch, logit, pred, pred_lbl, graphs): + """Write one CSV row per fragment and accumulate scores.""" + batch_num = batch.unique() + masks = [mask[batch == b] for b in batch_num] for idx, node_mask in enumerate(masks): data = graphs[idx] - real_label = int(data.y.cpu().numpy()[0]) - pred = self.pred[idx] - fragments = np.array(data.frag) - - if self.pred_lbl[idx] is not None: - pred_label = self.pred_lbl[idx].numpy()[0] - - mask_list = node_mask.cpu().numpy().tolist() - mask_list = [[f"{m:.3f}" for m in mask] for mask in mask_list] - - text = [] - for mask, frag in zip(mask_list, fragments): - contribs = ",".join(mask) - txt = (f"{data.idx},{data.smiles},{real_label},{pred_label}," - f"{pred:.3f},{frag},{contribs}\n") - text.append(txt) - - # * Export prediction - with open(self.out / "predictions.csv", "a") as f: - f.writelines(text) - - def plot_explanations(self, graphs): - batch_num = self.batch.unique() - masks = [self.mask[self.batch == b] for b in batch_num] - + fragments = list(np.array(data.frag)) + label_pred = (pred_lbl[idx].numpy()[0] + if pred_lbl[idx] is not None else "") + self._write_fragment_row(data, node_mask.cpu().numpy(), + fragments, pred[idx], label_pred) + self._append_record(data, node_mask.cpu().numpy(), fragments) + + def _write_fragment_row(self, data, mask_2d, fragments, pred_val, + label_pred): + """Append one row per fragment to predictions.csv.""" + real_label = int(data.y.cpu().numpy()[0]) + mask_list = [[f"{m:.3f}" for m in row] for row in mask_2d] + text = [] + for row, frag in zip(mask_list, fragments): + contribs = ",".join(row) + text.append(f"{data.idx},{data.smiles},{real_label}," + f"{label_pred},{pred_val:.3f},{frag}," + f"{contribs}\n") + with open(self.out / "predictions.csv", "a") as f: + f.writelines(text) + + def _append_record(self, data, mask_2d, fragments): + """Accumulate one molecule's per-fragment scores for metrics.""" + scores = fragment_scores(np.asarray(mask_2d)) + self.records.append({"idx": str(data.idx), + "smiles": data.smiles, + "fragments": list(fragments), + "scores": scores, + "top_fragment": top_fragment(scores, + fragments)}) + + def plot_explanations(self, mask, batch, logit, pred, pred_lbl, + graphs): + """Render per-molecule SVGs for the current minibatch.""" + batch_num = batch.unique() + masks = [mask[batch == b] for b in batch_num] for idx, node_mask in enumerate(masks): data = graphs[idx] name = data.idx - pred = self.pred[idx] - logit = self.logit[idx] - pred_label = self.pred_lbl[idx] - if pred_label is not None: - pred_label = pred_label.numpy()[0] - + label_pred = (pred_lbl[idx].numpy()[0] + if pred_lbl[idx] is not None else None) if self.loader == "default": - self._explain_atom(data, node_mask, pred, logit, - pred_label, name) - + self._explain_atom(data, node_mask, pred[idx], + logit[idx], label_pred, name) else: - # * Feature-level bar plot - # self._bar_plot(node_mask, name) - - # * Fragment-level visualization - # fragments = data.frag - # self._frag_visualization(node_mask, fragments, name) - - # * Molecule-level visualization - self._explain_frag(data, node_mask, pred, logit, - pred_label, name) - - def _rescale_mask(sel, mask_atom, logit): + self._explain_frag(data, node_mask, pred[idx], + logit[idx], label_pred, name) + + def _plot_aggregated(self, data, mask_2d, pred_val, logit_val, + pred_lbl_t): + """Render an SVG for the aggregated baseline on an atom-level graph. + + Uses the same fragment-level visualisation as the native model, + but takes the atom-to-fragment map from the matching fragment + record instead of data.atom_map. + """ + atom_map = getattr(data, "agg_atom_map", None) + if atom_map is None: + return + scores = fragment_scores(np.asarray(mask_2d)) + scores = self._rescale_mask(scores, logit_val) + scores = np.round(scores, 3) + label_pred = (pred_lbl_t.numpy()[0] + if pred_lbl_t is not None else None) + self._draw_fragment_svg(data, scores, atom_map, pred_val, + logit_val, label_pred, data.idx) + + @staticmethod + def _rescale_mask(mask_atom: np.ndarray, logit): """Rescale per-atom attributions so they sum to the raw logit.""" current_sum = mask_atom.sum() if current_sum == 0: return mask_atom - return mask_atom * (logit / current_sum) + return mask_atom * (float(logit) / current_sum) - def _explain_atom(self, data, node_mask, pred, logit, pred_label, name): + def _explain_atom(self, data, node_mask, pred, logit, + pred_label, name): smiles = data.smiles mol = Chem.MolFromSmiles(smiles) - mask_atom = torch.sum(node_mask, dim=1).numpy() mask_atom = self._rescale_mask(mask_atom, logit) mask_atom = np.round(mask_atom, 3) - - min_val = mask_atom.min() - max_val = mask_atom.max() - if min_val > 0: - max_val *= 1.3 - cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=0, - vmax=max_val), - cmap=mpl.cm.Blues) - elif max_val < 0: - min_val *= 1.3 - cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=min_val, - vmax=0), - cmap=mpl.cm.Oranges_r) - else: - min_val *= 1.3 - max_val *= 1.3 - pos_colors = plt.cm.Blues(np.linspace(0, 1, 128)) - neg_colors = plt.cm.Oranges_r(np.linspace(0, 1, 128)) - combined = np.vstack((neg_colors, pos_colors)) - color = LinearSegmentedColormap.from_list("OrBu", combined) - cmap = mpl.cm.ScalarMappable(norm=TwoSlopeNorm(vmin=min_val, - vcenter=0, - vmax=max_val), - cmap=color) + cmap = self._make_cmap(mask_atom) highlight_node = {} for atom in mol.GetAtoms(): @@ -177,75 +247,74 @@ def _explain_atom(self, data, node_mask, pred, logit, pred_label, name): highlight_node[idx] = [rgb] atom.SetProp("atomNote", str(mask_atom[idx])) - legend = (f"Graph ID: {name}\n{smiles}\n" - f"Prediction: {pred:.3f}\tLogits: {logit:.3f}\t|\t" - f"Class: {pred_label}\tTrue: {int(data.y)}") + legend = self._format_legend(data, smiles, pred, logit, + pred_label, name) + self._write_svg(mol, legend, highlight_node, name) - if pred_label is None: # Regression - legend = (f"Graph ID: {name}\n{smiles}\n" - f"Prediction: {pred:.3f}\tTrue: {float(data.y):.3f}") - - drawer = rdMolDraw2D.MolDraw2DSVG(1200, 800) - opts = drawer.drawOptions() - opts.fillHighlights = True - opts.annotationFontScale = 0.5 - opts.legendFontSize = 25 - drawer.DrawMoleculeWithHighlights(mol, legend, highlight_node, - {}, {}, {}) - drawer.FinishDrawing() - - with open(self.out / f"{data.idx}.svg", "w") as f: - f.write(drawer.GetDrawingText()) - - def _explain_frag(self, data, node_mask, pred, logit, pred_label, name): + def _explain_frag(self, data, node_mask, pred, logit, pred_label, + name): + atom_map = dict(zip(data.atom_map[0], data.atom_map[1])) + scores = fragment_scores(node_mask.numpy()) + scores = self._rescale_mask(scores, logit) + scores = np.round(scores, 3) + self._draw_fragment_svg(data, scores, atom_map, pred, logit, + pred_label, name) + + def _draw_fragment_svg(self, data, frag_scores, atom_map, pred, + logit, pred_label, name): + """Draw a 2-D molecule with atoms coloured by parent fragment.""" smiles = data.smiles mol = Chem.MolFromSmiles(smiles) - - atom_map = dict(zip(data.atom_map[0], data.atom_map[1])) - mask_atom = torch.sum(node_mask, dim=1).numpy() - mask_atom = self._rescale_mask(mask_atom, logit) - mask_atom = np.round(mask_atom, 3) - - min_val = mask_atom.min() - max_val = mask_atom.max() - if min_val >= 0: - max_val *= 1.3 - cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=0, - vmax=max_val), - cmap=mpl.cm.Blues) - elif max_val <= 0: - min_val *= 1.3 - cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=min_val, - vmax=0), - cmap=mpl.cm.Oranges_r) - else: - min_val *= 1.3 - max_val *= 1.3 - pos_colors = plt.cm.Blues(np.linspace(0, 1, 128)) - neg_colors = plt.cm.Oranges_r(np.linspace(0, 1, 128)) - combined = np.vstack((neg_colors, pos_colors)) - color = LinearSegmentedColormap.from_list("OrBu", combined) - cmap = mpl.cm.ScalarMappable(norm=TwoSlopeNorm(vmin=min_val, - vcenter=0, - vmax=max_val), - cmap=color) + scores = np.round(np.asarray(frag_scores, dtype=float), 3) + cmap = self._make_cmap(scores) highlight_node = {} for atom in mol.GetAtoms(): idx = atom.GetIdx() - frag_val = mask_atom[atom_map[idx]] + frag_val = scores[atom_map[idx]] rgb = cmap.to_rgba(frag_val)[:-1] highlight_node[idx] = [rgb] atom.SetProp("atomNote", str(frag_val)) - legend = (f"Graph ID: {name}\n{smiles}\n" - f"Prediction: {pred:.3f}\tLogits: {logit:.3f}\t|\t" - f"Class: {pred_label}\tTrue: {int(data.y)}") - - if pred_label is None: # Regression - legend = (f"Graph ID: {name}\n{smiles}\n" - f"Prediction: {pred:.3f}\tTrue: {float(data.y):.3f}") + legend = self._format_legend(data, smiles, pred, logit, + pred_label, name) + self._write_svg(mol, legend, highlight_node, name) + @staticmethod + def _make_cmap(values: np.ndarray): + """Build a matplotlib colormap suited to the sign of values.""" + arr = np.asarray(values).ravel() + if arr.size == 0: + return mpl.cm.ScalarMappable(norm=Normalize(0, 1), + cmap=mpl.cm.Blues) + vmin = float(arr.min()) + vmax = float(arr.max()) + if vmin >= 0: + return mpl.cm.ScalarMappable(norm=Normalize(vmin=0, + vmax=vmax * 1.3 + or 1.0), + cmap=mpl.cm.Blues) + if vmax <= 0: + return mpl.cm.ScalarMappable(norm=Normalize(vmin=vmin * 1.3, + vmax=0), + cmap=mpl.cm.Oranges_r) + pos_colors = plt.cm.Blues(np.linspace(0, 1, 128)) + neg_colors = plt.cm.Oranges_r(np.linspace(0, 1, 128)) + combined = np.vstack((neg_colors, pos_colors)) + color = LinearSegmentedColormap.from_list("OrBu", combined) + norm = TwoSlopeNorm(vmin=vmin * 1.3, vcenter=0, vmax=vmax * 1.3) + return mpl.cm.ScalarMappable(norm=norm, cmap=color) + + @staticmethod + def _format_legend(data, smiles, pred, logit, pred_label, name): + if pred_label is None: + return (f"Graph ID: {name}\n{smiles}\n" + f"Prediction: {pred:.3f}\tTrue: {float(data.y):.3f}") + return (f"Graph ID: {name}\n{smiles}\n" + f"Prediction: {pred:.3f}\tLogits: {logit:.3f}\t|\t" + f"Class: {pred_label}\tTrue: {int(data.y)}") + + def _write_svg(self, mol, legend, highlight_node, name): drawer = rdMolDraw2D.MolDraw2DSVG(1200, 800) opts = drawer.drawOptions() opts.fillHighlights = True @@ -254,135 +323,78 @@ def _explain_frag(self, data, node_mask, pred, logit, pred_label, name): drawer.DrawMoleculeWithHighlights(mol, legend, highlight_node, {}, {}, {}) drawer.FinishDrawing() - with open(self.out / f"{name}.svg", "w") as f: f.write(drawer.GetDrawingText()) - def _bar_plot(self, node_mask, name): - # * Feature-level bar plot - mask_feat = torch.sum(node_mask, dim=0).numpy() - feats = self._get_top(mask_feat) - - fig, ax = plt.subplots(figsize=(10, 6)) - all_lbl = np.append(feats[1]["labels"], feats[0]["labels"]) - all_val = np.append(feats[1]["contrib"], feats[0]["contrib"]) - colors = (["SteelBlue"] * len(feats[1]["labels"]) + - ["DarkOrange"] * len(feats[0]["labels"])) - - ax.barh(all_lbl, all_val, color=colors) - ax.set_title(f"Top {self.k} Features - {name}") - ax.set_xlabel("Contribution") - ax.invert_yaxis() - - plt.xlim(mask_feat.min() * 1.15, mask_feat.max() * 1.15) - for i, v in enumerate(all_val): - x_off = 0.02 if v > 0 else -0.3 - ax.text(v + x_off, i, str(v), va="center", fontsize=8) - - plt.tight_layout() - out_feat = self.out / f"{name}_feat.svg" - fig.savefig(out_feat, format="svg") - plt.close(fig) - - def _get_top(self, mask, fragments=None): - mask = np.round(mask, 3) - labels = np.array(self.labels) - - pos = {"contrib": np.array([]), - "labels": np.array([]), - "fragment": np.array([])} - pos_mask = mask > 0 - if np.any(pos_mask): - idx_pos = np.argsort(mask[pos_mask])[-self.cut:][::-1] - idx_pos = np.where(pos_mask)[0][idx_pos] - pos = {"contrib": mask[idx_pos]} - - if fragments is None: - pos["labels"] = labels[idx_pos] - else: - pos["fragment"] = fragments[idx_pos] - - neg = {"contrib": np.array([]), - "labels": np.array([]), - "fragment": np.array([])} - neg_mask = mask < 0 - if np.any(neg_mask): - idx_neg = np.argsort(mask[neg_mask])[:self.cut] - idx_neg = np.where(neg_mask)[0][idx_neg] - neg = {"contrib": mask[idx_neg]} - - if fragments is None: - neg["labels"] = labels[idx_neg] - else: - neg["fragment"] = fragments[idx_neg] - - cuts = {0: neg, 1: pos} - return cuts - - def _frag_visualization(self, node_mask, fragments, name): - mask_frag = node_mask.numpy().tolist() - frag_imgs = [] - - for i, frag in enumerate(fragments): - top_val = self._get_top(mask_frag[i]) - label = np.append(top_val[1]["labels"], top_val[0]["labels"]) - cntrb = np.append(top_val[1]["contrib"], top_val[0]["contrib"]) - - contrib = np.sum(mask_frag[i]).round(3) - entries = [f"{lbl}: {val}" for lbl, val in zip(label, cntrb)] - frag_imgs.append(self._subplot(frag, entries, contrib)) - - # Create image - fig = self._create_frag_image(frag_imgs, 1600, 300) - fig.save(self.out / f"{name}_frag.svg") - - def _subplot(self, frag, entries, contrib, size=(500, 250)): - mol = Chem.MolFromSmiles(frag) - - pos_1 = ", ".join(entries[:3]) - pos_2 = ", ".join(entries[3: 5]) - neg_1 = ", ".join(entries[5: 8]) - neg_2 = ", ".join(entries[8:]) - legend = (f"{frag}\nTotal: {contrib}" - f"\n{pos_1}\n{pos_2}\n" - f"\n{neg_1}\n{neg_2}") - - drawer = rdMolDraw2D.MolDraw2DSVG(size[0], size[1]) - opts = drawer.drawOptions() - opts.addAtomIndices = False - opts.addStereoAnnotation = True - opts.dummiesAreAttachments = False - opts.padding = 0.0 - opts.legendFraction = 0.5 - opts.legendFontSize = 16 - drawer.DrawMolecule(mol, legend=legend) - drawer.FinishDrawing() - - mol_svg = drawer.GetDrawingText() - mol_fig = sg.fromstring(mol_svg) - mol_plot = mol_fig.getroot() - - return mol_plot - - def _create_frag_image(self, images, width=1600, height=300): - n_rows = int(np.ceil(len(images) / 3)) - h = height * n_rows - bg = f'' - background = etree.fromstring(bg) - - x = 50 - y = 20 - count = 0 - for img in images: - img.moveto(x, y) - x += width / 3.14 - count += 1 - - if count == 3: - count = 0 - x = 50 - y += height - - fig = sg.SVGFigure(str(width), str(h)) - fig.append([background] + images) - return fig + def finalize(self, classifier: Callable = None, + class_names: Sequence[str] = None): + """Write explain_metrics.json summarising accumulated records. + + Computes the mean Gini coefficient of fragment-importance + distributions and, if a pharmacophore classifier and + class_names are supplied, the fragment hit rate with a + bootstrap 95% CI and a per-class breakdown. + + Args: + classifier: Optional callable mapping fragment SMILES to a + class name string (or None). + class_names: Optional iterable of all valid class names for + the per-class breakdown. + + Returns: + The metrics dict that was written to disk. + """ + score_vecs = [r["scores"] for r in self.records] + tops = [r["top_fragment"] for r in self.records + if r["top_fragment"] is not None] + metrics = {"algorithm": self.algorithm, + "loader": self.loader, + "n_molecules": len(self.records), + "mean_gini": mean_gini(score_vecs)} + if classifier is not None and class_names is not None: + metrics["fragment_hit_rate"] = fragment_hit_rate( + tops, classifier, class_names) + with open(self.out / "explain_metrics.json", "w") as fh: + json.dump(metrics, fh, indent=2) + return metrics + + +def spearman_between_runs(records_a: Sequence[dict], + records_b: Sequence[dict], + out_path): + """Compute per-molecule Spearman rho between two explainer runs. + + The two record lists must come from the same dataset (same SMILES + set) so that fragment vectors can be aligned. Molecules present in + only one run are skipped. + + Args: + records_a: MolExplain.records from explainer A. + records_b: MolExplain.records from explainer B. + out_path: Destination JSON path. Parent directory must exist. + + Returns: + The metrics dict that was written to disk. + """ + by_smiles_b = {r["smiles"]: r for r in records_b} + aligned_a, aligned_b = [], [] + for r in records_a: + match = by_smiles_b.get(r["smiles"]) + if match is None: + continue + if len(match["scores"]) != len(r["scores"]): + continue + aligned_a.append(r["scores"]) + aligned_b.append(match["scores"]) + + summary = spearman_cross_explainer(aligned_a, aligned_b) + summary["n_overlap"] = len(aligned_a) + with open(out_path, "w") as fh: + json.dump(summary, fh, indent=2) + return summary + + +__all__ = ["MolExplain", + "spearman_between_runs", + "metrics_explain", + "V1"] diff --git a/frame/source/explain/aggregate.py b/frame/source/explain/aggregate.py new file mode 100644 index 0000000..3aad538 --- /dev/null +++ b/frame/source/explain/aggregate.py @@ -0,0 +1,123 @@ +from typing import Sequence + +import numpy as np +import torch +from torch_geometric.data.data import Data + + +def aggregate_atom_mask(atom_mask: np.ndarray, + atom_to_fragment: Sequence[int], + n_fragments: int): + """Sum a per-atom attribution vector into per-fragment scores. + + Args: + atom_mask: 2-D array of shape (n_atoms, n_features) or 1-D + array of length n_atoms. The feature axis is preserved + so the result is shape-compatible with native fragment-level + attributions. + atom_to_fragment: Sequence of length n_atoms where + atom_to_fragment[i] is the fragment index of atom i. + n_fragments: Total number of BRICS fragments. + + Returns: + Numpy array of shape (n_fragments, n_features) if + atom_mask was 2-D, otherwise shape (n_fragments,). + + Raises: + ValueError: If the lengths of atom_mask and + atom_to_fragment disagree. + """ + arr = np.asarray(atom_mask, dtype=float) + n_atoms = arr.shape[0] + if n_atoms != len(atom_to_fragment): + raise ValueError(f"atom_mask has {n_atoms} rows but " + f"atom_to_fragment has {len(atom_to_fragment)} " + "entries") + + if arr.ndim == 1: + out = np.zeros(n_fragments, dtype=float) + for atom_idx, frag_idx in enumerate(atom_to_fragment): + out[frag_idx] += arr[atom_idx] + return out + + out = np.zeros((n_fragments, arr.shape[1]), dtype=float) + for atom_idx, frag_idx in enumerate(atom_to_fragment): + out[frag_idx] += arr[atom_idx] + return out + + +def build_smiles_index(decompose_dataset): + """Build a SMILES → (atom_map, fragments) lookup over a dataset. + + Used by the aggregated-baseline pipeline to translate an atom-level + Data object into the matching fragment-level structure. Matching is + done on the canonical SMILES string stored on every Data object. + + Args: + decompose_dataset: A :class:`DecomposeDataset` or any iterable + of Data objects exposing smiles, atom_map, and + frag attributes. + + Returns: + Dict mapping SMILES string to + {"atom_to_fragment": list[int], "fragments": list[str], + "n_fragments": int}. + """ + index = {} + for data in decompose_dataset: + smiles = getattr(data, "smiles", None) + if smiles is None: + continue + atom_keys, frag_values = data.atom_map[0], data.atom_map[1] + mapping = [0] * (max(atom_keys) + 1) + for atom_idx, frag_idx in zip(atom_keys, frag_values): + mapping[atom_idx] = int(frag_idx) + index[smiles] = {"atom_to_fragment": mapping, + "fragments": list(data.frag), + "n_fragments": len(data.frag)} + return index + + +def aggregated_batch_masks(node_mask: torch.Tensor, + batch: torch.Tensor, + atom_graphs: Sequence[Data], + smiles_index: dict): + """Aggregate a batched atom-level mask into per-fragment masks. + + Splits node_mask by the PyG batch vector, looks each graph up by + SMILES in smiles_index, and aggregates with `aggregate_atom_mask`. + Graphs whose SMILES is missing from the index (e.g. molecules with + no BRICS bonds, which are filtered out of the fragment dataset) + are returned as None placeholders so the caller can skip them. + + Args: + node_mask: 2-D tensor of shape (total_atoms, n_features). + batch: 1-D long tensor of length total_atoms mapping each + atom to its graph index inside the batch. + atom_graphs: Atom-level Data objects in batch order; used to + look up SMILES and atom counts. + smiles_index: Output of `build_smiles_index`. + + Returns: + List with one entry per graph in the batch. Each entry is a dict + with keys mask (numpy array, shape (n_fragments, + n_features)), fragments (list of SMILES), or None for + skipped graphs. + """ + mask_np = node_mask.detach().cpu().numpy() + batch_np = batch.detach().cpu().numpy() + out = [] + for graph_idx, graph in enumerate(atom_graphs): + smiles = getattr(graph, "smiles", None) + record = smiles_index.get(smiles) + if record is None: + out.append(None) + continue + + atom_rows = mask_np[batch_np == graph_idx] + agg = aggregate_atom_mask(atom_rows, + record["atom_to_fragment"], + record["n_fragments"]) + out.append({"mask": agg, + "fragments": record["fragments"]}) + return out diff --git a/frame/source/explain/metrics_explain.py b/frame/source/explain/metrics_explain.py new file mode 100644 index 0000000..0bc714c --- /dev/null +++ b/frame/source/explain/metrics_explain.py @@ -0,0 +1,182 @@ +from typing import Callable, Sequence + +import numpy as np +from scipy.stats import spearmanr + + +def fragment_scores(node_mask: np.ndarray): + """Reduce a (n_fragments, n_features) mask to a scalar per fragment. + + Sums across the feature axis. Sum aggregation preserves the sign of + contributions, which is important when attributions can be negative + (e.g. Integrated Gradients). + + Args: + node_mask: 2-D numpy array of shape (n_fragments, n_features). + + Returns: + 1-D numpy array of length n_fragments. + """ + arr = np.asarray(node_mask, dtype=float) + if arr.ndim == 1: + return arr + return arr.sum(axis=1) + + +def gini(values: np.ndarray): + """Gini coefficient on absolute importance values. + + A value of 0 means perfectly uniform attribution; 1 means all the + importance is concentrated on a single fragment. The metric uses + absolute values so it is well-defined when attributions can be + negative. + + Args: + values: 1-D numpy array of fragment importance scores. + + Returns: + Float in [0, 1]. Returns 0.0 for arrays with fewer than + two entries or with zero total absolute mass. + """ + arr = np.abs(np.asarray(values, dtype=float)).ravel() + n = arr.size + if n < 2: + return 0.0 + + total = arr.sum() + if total <= 0: + return 0.0 + + sorted_arr = np.sort(arr) + ranks = np.arange(1, n + 1) + numerator = ((2 * ranks - n - 1) * sorted_arr).sum() + return float(numerator / (n * total)) + + +def mean_gini(per_mol_scores: Sequence[np.ndarray]): + """Mean Gini coefficient across a population of molecules. + + Args: + per_mol_scores: Sequence of 1-D numpy arrays, one per molecule. + + Returns: + Float; 0.0 if the input sequence is empty. + """ + if len(per_mol_scores) == 0: + return 0.0 + return float(np.mean([gini(s) for s in per_mol_scores])) + + +def _bootstrap_ci(values: np.ndarray, n_boot: int = 1000, + seed: int = 13): + """95% percentile bootstrap CI of the mean of a 0/1 array.""" + rng = np.random.default_rng(seed) + n = values.size + if n == 0: + return (0.0, 0.0) + boot = rng.choice(values, size=(n_boot, n), replace=True) + means = boot.mean(axis=1) + lo = float(np.percentile(means, 2.5)) + hi = float(np.percentile(means, 97.5)) + return (lo, hi) + + +def fragment_hit_rate(top_fragments: Sequence[str], + classifier: Callable[[str], str], + class_names: Sequence[str], + n_boot: int = 1000): + """Fraction of molecules whose top fragment matches a known class. + + Args: + top_fragments: One fragment SMILES per molecule (the argmax of + its fragment-importance vector). + classifier: Callable that maps a fragment SMILES to a class + name string or None. + class_names: All valid class names. Used to populate a + per-class breakdown even when some classes never matched. + n_boot: Bootstrap resamples for the overall-rate CI. + + Returns: + Dict with keys overall, ci_low, ci_high, + per_class (mapping class name -> fraction of molecules), + and n (population size). + """ + n = len(top_fragments) + if n == 0: + return {"overall": 0.0, "ci_low": 0.0, "ci_high": 0.0, + "per_class": {c: 0.0 for c in class_names}, + "n": 0} + + matches = np.zeros(n, dtype=int) + per_class = {c: 0 for c in class_names} + for i, frag in enumerate(top_fragments): + label = classifier(frag) + if label is None: + continue + matches[i] = 1 + if label in per_class: + per_class[label] += 1 + + overall = float(matches.mean()) + lo, hi = _bootstrap_ci(matches.astype(float), n_boot=n_boot) + per_class_rate = {c: per_class[c] / n for c in class_names} + + return {"overall": overall, + "ci_low": lo, + "ci_high": hi, + "per_class": per_class_rate, + "n": n} + + +def spearman_cross_explainer(scores_a: Sequence[np.ndarray], + scores_b: Sequence[np.ndarray]): + """Mean per-molecule Spearman correlation between two explainers. + + Molecules with fewer than two fragments contribute nothing (Spearman + is undefined). Molecules where one of the score vectors is constant + likewise return NaN from scipy and are skipped. + + Args: + scores_a: One importance vector per molecule from explainer A. + scores_b: Same population, same order, from explainer B. + + Returns: + Dict with mean (float), std (float), and n_used + (count of molecules contributing to the mean). + """ + if len(scores_a) != len(scores_b): + raise ValueError("scores_a and scores_b must have the same length") + + rhos = [] + for a, b in zip(scores_a, scores_b): + a = np.asarray(a, dtype=float).ravel() + b = np.asarray(b, dtype=float).ravel() + if a.size < 2 or a.size != b.size: + continue + rho, _ = spearmanr(a, b) + if rho is None or np.isnan(rho): + continue + rhos.append(float(rho)) + + if not rhos: + return {"mean": 0.0, "std": 0.0, "n_used": 0} + return {"mean": float(np.mean(rhos)), + "std": float(np.std(rhos)), + "n_used": len(rhos)} + + +def top_fragment(score_vec: np.ndarray, fragments: Sequence[str]): + """Return the fragment SMILES with the largest absolute score. + + Args: + score_vec: 1-D fragment-importance vector. + fragments: Aligned list of fragment SMILES. + + Returns: + SMILES of the argmax-|score| fragment, or None if empty. + """ + arr = np.asarray(score_vec, dtype=float).ravel() + if arr.size == 0 or len(fragments) == 0: + return None + idx = int(np.argmax(np.abs(arr))) + return fragments[idx] diff --git a/frame/source/explain/pharmacophores.py b/frame/source/explain/pharmacophores.py new file mode 100644 index 0000000..3a2d9fb --- /dev/null +++ b/frame/source/explain/pharmacophores.py @@ -0,0 +1,173 @@ +from rdkit import Chem +from rdkit.Chem import Descriptors + + +BACE_PATTERNS = (("transition_state_mimic", + "[CX4]([OX2H])[CX4][NX3]"), + ("transition_state_mimic", + "[CX4]([OX2H])[CX4]=[CX3]"), + ("basic_amine", + "[NX3;H2,H1;!$(N-C=[!#6])]"), + ("s2_aromatic_hydrophobic", + "c1ccc(cc1)[CX4,CX3]"), + ("s1_hydrophobic", + "[CX4;H2,H3][CX4;H2,H3][CX4;H2,H3][CX4;H2,H3]")) + +MPRO_PATTERNS = (("warhead", + "C#N"), + ("warhead", + "C=CC(=O)[!O]"), + ("warhead", + "C(=O)C(=O)N"), + ("s1_lactam_pyridone", + "O=C1NCCC1"), + ("s1_lactam_pyridone", + "O=C1NCCCC1"), + ("s1_lactam_pyridone", + "O=c1cccc[nH]1"), + ("s2_hydrophobic", + "C1CC1"), + ("s2_hydrophobic", + "[CX4;H1,H0]([CX4;H3])([CX4;H3])[CX4;H3]")) + +BBBP_TPSA_THRESHOLD = 30.0 + + +def _classify_by_smarts(fragment_smiles: str, patterns: tuple): + """Return the first SMARTS class name matching the fragment. + + Args: + fragment_smiles: Canonical SMILES of a single BRICS fragment. + patterns: Iterable of (class_name, smarts) pairs evaluated + in order; the first match wins. + + Returns: + Class name string, or None if no pattern matches or the + SMILES is invalid. + """ + mol = Chem.MolFromSmiles(fragment_smiles) + if mol is None: + return None + + for class_name, smarts in patterns: + query = Chem.MolFromSmarts(smarts) + if query is None: + continue + if mol.HasSubstructMatch(query): + return class_name + return None + + +def classify_bace(fragment_smiles: str): + """Classify a fragment against BACE-1 inhibitor pharmacophores. + + Classes (in priority order): transition_state_mimic (catalytic + Asp32/Asp228 contact), basic_amine (S3 recognition), + s2_aromatic_hydrophobic (S2 pocket), s1_hydrophobic (S1 + pocket). + + Args: + fragment_smiles: Canonical SMILES of one BRICS fragment. + + Returns: + Class name string or None. + """ + return _classify_by_smarts(fragment_smiles, BACE_PATTERNS) + + +def classify_mpro(fragment_smiles: str): + """Classify a fragment against SARS-CoV-2 MPro inhibitor pharmacophores. + + Classes (in priority order): warhead (nitrile, Michael + acceptor, alpha-ketoamide), s1_lactam_pyridone (gamma-lactam, + delta-lactam, 2-pyridone), s2_hydrophobic (cyclopropyl or + branched leucine-mimetic). + + Args: + fragment_smiles: Canonical SMILES of one BRICS fragment. + + Returns: + Class name string or None. + """ + return _classify_by_smarts(fragment_smiles, MPRO_PATTERNS) + + +def classify_bbbp(fragment_smiles: str, + threshold: float = BBBP_TPSA_THRESHOLD): + """Classify a fragment by topological polar surface area. + + BBB permeation is governed by global physicochemistry rather than + discrete binding-site motifs, so the BBBP registry partitions + fragments by RDKit TPSA. Fragments with TPSA < threshold are + expected to favour BBB+ predictions; fragments above are expected + to favour BBB- predictions. + + Args: + fragment_smiles: Canonical SMILES of one BRICS fragment. + threshold: TPSA cutoff in Angstrom^2. Defaults to 30.0. + + Returns: + "low_tpsa" or "high_tpsa"; None if the SMILES is + invalid. + """ + mol = Chem.MolFromSmiles(fragment_smiles) + if mol is None: + return None + + tpsa = Descriptors.TPSA(mol) + if tpsa < threshold: + return "low_tpsa" + return "high_tpsa" + + +CLASSIFIERS = {"bace": classify_bace, + "mpro": classify_mpro, + "bbbp": classify_bbbp} + +CLASS_NAMES = {"bace": ("transition_state_mimic", + "basic_amine", + "s2_aromatic_hydrophobic", + "s1_hydrophobic"), + "mpro": ("warhead", + "s1_lactam_pyridone", + "s2_hydrophobic"), + "bbbp": ("low_tpsa", + "high_tpsa")} + + +def get_classifier(name: str): + """Return the classify function for a case study by name. + + Args: + name: One of "bace", "mpro", "bbbp" (case-insensitive). + + Returns: + Callable fragment_smiles -> Optional[str]. + + Raises: + ValueError: If name is not a registered case study. + """ + key = name.lower() + if key not in CLASSIFIERS: + raise ValueError(f"Unknown pharmacophore registry: {name}. " + f"Choose from {list(CLASSIFIERS)}.") + return CLASSIFIERS[key] + + +def get_class_names(name: str): + """Return the tuple of class names for a case study. + + Args: + name: One of "bace", "mpro", "bbbp" (case-insensitive). + + Returns: + Tuple of class-name strings. + + Raises: + ValueError: If name is not a registered case study. + """ + key = name.lower() + if key not in CLASS_NAMES: + raise ValueError(f"Unknown pharmacophore registry: {name}. " + f"Choose from {list(CLASS_NAMES)}.") + return CLASS_NAMES[key] diff --git a/frame/source/train/runner.py b/frame/source/train/runner.py index 6b425fd..3f3c8b9 100644 --- a/frame/source/train/runner.py +++ b/frame/source/train/runner.py @@ -17,33 +17,36 @@ def train_one_seed(seed: int, train_data: list, valid_loader: DataLoader, batch_size: int, workers: int): """Train a fresh model under a fixed seed and return its best state. - Reseeds the global RNGs and pins the train ``DataLoader``'s shuffle - generator so that two calls with the same ``seed`` produce the same + Reseeds the global RNGs and pins the train DataLoader's shuffle + generator so that two calls with the same seed produce the same trajectory. The valid loader is reused across seeds since it does not shuffle. Args: seed: Integer used to reseed Python/NumPy/Torch RNGs and the train-loader generator. - train_data: List of ``Data`` objects belonging to the train split. + train_data: List of Data objects belonging to the train split. valid_loader: Pre-built validation loader (no shuffle, seed-free). - model_name: Backbone selector consumed by ``models.select_model``. - config: Hyperparameter dict consumed by ``models.model_setup``. + model_name: Backbone selector consumed by models.select_model. + config: Hyperparameter dict consumed by models.model_setup. epochs: Maximum number of epochs to train. - patience: Early-stopping patience on the ``optim`` metric. - task: ``"classification"`` or ``"regression"``. - grad_clip: Max-norm gradient clipping value (``None``/0 disables). + patience: Early-stopping patience on the optim metric. + task: "classification" or "regression". + grad_clip: Max-norm gradient clipping value (None/0 disables). drop_edge_p: Train-time edge drop probability. mask_feat_p: Train-time node-feature mask probability. batch_size: Train loader batch size. workers: Number of DataLoader workers. Returns: - Tuple ``(best_state, results, fit_time, n_params)`` where - ``best_state`` is the ``state_dict`` of the best epoch on the valid - ``optim`` metric, ``results`` is the validation metric dict at that - state, ``fit_time`` is wall-clock seconds spent training, and - ``n_params`` is the trainable parameter count. + Tuple (best_state, results, fit_time, n_params, epoch_times) + where best_state is the state_dict of the best epoch on + the valid optim metric, results is the validation + metric dict at that state, fit_time is wall-clock seconds + spent training, n_params is the trainable parameter count, + and epoch_times is a list of per-epoch wall-clock seconds + (length is the number of epochs actually executed before + early stopping). """ train_pkg.set_seed(seed) generator = torch.Generator() @@ -58,15 +61,18 @@ def train_one_seed(seed: int, train_data: list, valid_loader: DataLoader, best_metric = -float("inf") patience_counter = 0 best_state = None + epoch_times = [] start = time.time() for _ in tqdm(range(epochs), ncols=120, desc=f"Seed {seed}"): + epoch_start = time.perf_counter() _ = train_pkg.train_epoch(model, optim, lossfn, train_loader, grad_clip_norm=grad_clip, drop_edge_p=drop_edge_p, mask_feat_p=mask_feat_p) val_metrics = train_pkg.valid_epoch(model, task, valid_loader) schdlr.step() + epoch_times.append(time.perf_counter() - epoch_start) if val_metrics["optim"] > best_metric: patience_counter = 0 @@ -85,4 +91,4 @@ def train_one_seed(seed: int, train_data: list, valid_loader: DataLoader, n_params = sum(int(np.prod(p.size())) for p in model.parameters() if p.requires_grad) - return best_state, results, fit_time, n_params + return best_state, results, fit_time, n_params, epoch_times diff --git a/frame/train.py b/frame/train.py index 079ef1b..1fe18f2 100644 --- a/frame/train.py +++ b/frame/train.py @@ -36,6 +36,21 @@ def _report_seed_stats(task, per_seed_results, project_dir): json.dump(summary, fh, indent=2) +def _write_timing(per_seed_timing, project_dir): + all_epochs = [t for entry in per_seed_timing + for t in entry["epoch_times"]] + summary = {"per_seed": per_seed_timing, + "total_fit_time": float(sum(e["fit_time"] + for e in per_seed_timing)), + "mean_seconds_per_epoch": (float(np.mean(all_epochs)) + if all_epochs else 0.0), + "std_seconds_per_epoch": (float(np.std(all_epochs)) + if all_epochs else 0.0), + "n_epochs_total": len(all_epochs)} + with open(project_dir / "timing.json", "w") as fh: + json.dump(summary, fh, indent=2) + + def run(params, dataset): epochs = params["Data"].get("epochs", 10) workers = params["Data"].get("workers", 4) @@ -74,17 +89,20 @@ def run(params, dataset): # * Train one model per seed, keep the best-seed checkpoint per_seed_results = [] + per_seed_timing = [] best_state = None best_optim = -float("inf") for seed in seeds: - state, results, _, _ = runner.train_one_seed(int(seed), train_data, - valid_loader, - model_name, config, - epochs, patience, task, - grad_clip, drop_edge_p, - mask_feat_p, size, - workers) + state, results, fit_time, _, epoch_times = runner.train_one_seed( + int(seed), train_data, valid_loader, model_name, config, + epochs, patience, task, grad_clip, drop_edge_p, + mask_feat_p, size, workers) per_seed_results.append(results) + per_seed_timing.append({"seed": int(seed), + "fit_time": float(fit_time), + "n_epochs": len(epoch_times), + "epoch_times": [float(t) + for t in epoch_times]}) if results["optim"] > best_optim: best_optim = float(results["optim"]) best_state = state @@ -93,6 +111,7 @@ def run(params, dataset): torch.save(best_state, str(project_dir / "best_model.pt")) _report_seed_stats(task, per_seed_results, project_dir) + _write_timing(per_seed_timing, project_dir) def main(): diff --git a/frame/tune.py b/frame/tune.py index d1f6a14..2771e2e 100644 --- a/frame/tune.py +++ b/frame/tune.py @@ -97,10 +97,12 @@ def objective(trial, params, dataset): total_time = 0.0 n_params = 0 for seed in seeds: - state, results, fit_time, n_params = runner.train_one_seed( - int(seed), train_data, valid_loader, model_name, - config, epochs, patience, task, grad_clip, - drop_edge_p, mask_feat_p, size, workers) + state, results, fit_time, n_params, _ = ( + runner.train_one_seed(int(seed), train_data, + valid_loader, model_name, + config, epochs, patience, + task, grad_clip, drop_edge_p, + mask_feat_p, size, workers)) per_seed_optim.append(float(results["optim"])) per_seed_results.append(results) total_time += fit_time diff --git a/parameters.yaml b/parameters.yaml index 7e62f2d..b9972a6 100755 --- a/parameters.yaml +++ b/parameters.yaml @@ -5,13 +5,13 @@ Data: path_joblib: "path/to/joblib_file.joblib" path_checkpoint: "path/to/checkpoint.pt" - trials: 50 - epochs: 50 + trials: 50 # number of Optuna trials + epochs: 50 # max epochs per run model: gat # gcn | gat | gin | sage | attentive - batch_size: 128 - patience: 15 + batch_size: 128 # graphs per mini-batch + patience: 15 # early-stopping patience (epochs) loader: default # default | decompose - grad_clip_norm: 1.0 + grad_clip_norm: 1.0 # max grad-norm (0/None disables) drop_edge_p: 0.0 # train-time edge drop probability mask_feat_p: 0.0 # train-time node feature mask probability regression_loss: mse # mse | huber | smooth_l1 @@ -21,29 +21,33 @@ Data: train_seeds: [13, 42, 73, 101, 202] # seeds for final retraining Tune: - hidden_channels: + hidden_channels: # node embedding width min: 32 max: 512 # value: 16 - num_layers: + num_layers: # message-passing layers min: 1 max: 4 # value: 2 - dropout_rate: + dropout_rate: # dropout rate min: 0.1 max: 0.6 # value: 0.3 - timesteps: + pool: # readout (ignored by attentive) + choices: [mean, add, max] + # value: mean + gcn_improved: # GCN self-loop weighting (A + 2I) + value: true + heads: # GAT attention heads min: 1 max: 4 - # value: 2 - heads: + # value: 3 + att_v2: # use GATv2 conv + value: true + num_timesteps: # AttentiveFP readout refinement steps min: 1 max: 4 - # value: 3 - pool: - choices: [mean, add, max] - # value: mean + # value: 2 learning_rate: # min: 1.0e-4 diff --git a/pyproject.toml b/pyproject.toml index 7cb80e4..e36a4a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = ["torch==2.11.0", "optuna==4.8.0", "pandas==3.0.2", "scikit-learn==1.8.0", + "scipy==1.17.1", "svgutils==0.3.4", "torch-geometric==2.7.0"]