A multilabel protein function prediction model for Viridiplantae (green plants) using Graph Neural Networks with ProtBERT embeddings.
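Each protein is modelled as a residue-level graph whose edges come from a contact map. The sketch below shows one plausible way to derive those edges from C-alpha coordinates. The 10 Å cutoff, the pLDDT threshold of 70, and dropping low-confidence residues entirely are illustrative assumptions, not necessarily what `create_cmaps.py` does. `contact_edges` is a hypothetical helper.

```python
import math
from typing import List, Sequence, Tuple

Coord = Tuple[float, float, float]


def contact_edges(
    ca_coords: Sequence[Coord],
    plddt: Sequence[float],
    cutoff: float = 10.0,     # assumed C-alpha distance cutoff in angstroms
    min_plddt: float = 70.0,  # assumed per-residue confidence threshold
) -> List[Tuple[int, int]]:
    """Return undirected residue pairs (i, j) with i < j whose C-alpha atoms
    lie within `cutoff` angstroms. Residues with pLDDT below `min_plddt`
    are excluded from every pair."""
    if len(ca_coords) != len(plddt):
        raise ValueError("expected one pLDDT value per residue")
    # Indices of residues confident enough to take part in contacts.
    keep = [i for i, score in enumerate(plddt) if score >= min_plddt]
    edges = []
    for pos, i in enumerate(keep):
        for j in keep[pos + 1:]:
            if math.dist(ca_coords[i], ca_coords[j]) <= cutoff:
                edges.append((i, j))
    return edges
```

PyTorch Geometric stores undirected graphs with both directions in `edge_index`, so each pair (i, j) would typically be added as (i, j) and (j, i) before building a `Data` object.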
conda env create -f environment.yml
conda activate deepgreengo

Note on PyTorch Geometric extras: after activating the environment, install the C++ extension wheels that match your exact PyTorch and CUDA versions from https://data.pyg.org/whl/:
# Example for torch 2.1.0 + CUDA 12.1:
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv \
-f https://data.pyg.org/whl/torch-2.1.0+cu121.html
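Because the wheel index must match the installed build, it can help to derive the URL from the versions PyTorch itself reports. Below is a minimal sketch. It assumes the `torch-<version>+<cuXYZ|cpu>.html` naming seen in the example URL. `pyg_wheel_index` is a hypothetical helper and is not part of this repository.

```python
from typing import Optional


def pyg_wheel_index(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build a data.pyg.org wheel index URL for a torch/CUDA pair.

    torch_version may carry a local suffix such as "2.1.0+cu121"; it is
    dropped. cuda_version is e.g. "12.1", or None for a CPU-only build.
    """
    base = torch_version.split("+")[0]
    tag = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return f"https://data.pyg.org/whl/torch-{base}+{tag}.html"


if __name__ == "__main__":
    try:
        import torch  # read versions from the torch build actually installed

        # torch.version.cuda is None on CPU-only builds.
        print(pyg_wheel_index(torch.__version__, torch.version.cuda))
    except ImportError:
        print("PyTorch is not installed in this environment.")
```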
Alternatively, install with pip instead of conda:
# 1. Install PyTorch first (choose your CUDA version at https://pytorch.org):
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# 2. Install PyTorch Geometric:
pip install torch-geometric
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv \
-f https://data.pyg.org/whl/torch-2.1.0+cu121.html
# 3. Install remaining dependencies:
pip install -r requirements.txt

Install the external bioinformatics tools:
conda install -c conda-forge -c bioconda mmseqs2 # Homology clustering
conda install -c bioconda blast # BLAST baseline (optional)
conda install -c bioconda diamond # DIAMOND baseline (optional)

Place your downloaded Viridiplantae PDB structures (.cif.gz) in:
preprocessing/data/structure_files/
You also need the SIFTS annotation file and the GO OBO file in preprocessing/data/.
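The three ontologies used in training (e.g. `biological_process`) correspond to GO namespaces in the OBO file. As a quick sanity check on the downloaded file, a small parser can count the non-obsolete terms in each namespace. This is a hypothetical helper written for illustration, not the parser used by the preprocessing scripts. It reads only the `id`, `namespace`, and `is_obsolete` tags of `[Term]` stanzas.

```python
from typing import Dict, Iterable, Optional


def _store(stanza: Optional[Dict[str, str]], terms: Dict[str, str]) -> None:
    # Keep a finished [Term] stanza only if it is complete and not obsolete.
    if (
        stanza
        and "id" in stanza
        and "namespace" in stanza
        and stanza.get("is_obsolete") != "true"
    ):
        terms[stanza["id"]] = stanza["namespace"]


def parse_obo_namespaces(lines: Iterable[str]) -> Dict[str, str]:
    """Map each non-obsolete GO term id to its namespace."""
    terms: Dict[str, str] = {}
    stanza: Optional[Dict[str, str]] = None  # tags of the current [Term], if any
    for raw in lines:
        line = raw.strip()
        if line.startswith("[") and line.endswith("]"):
            _store(stanza, terms)  # close the previous stanza
            stanza = {} if line == "[Term]" else None  # ignore [Typedef] etc.
        elif stanza is not None and ": " in line:
            tag, value = line.split(": ", 1)
            stanza.setdefault(tag, value)  # repeated tags such as is_a: keep the first
    _store(stanza, terms)  # the last stanza has no following header
    return terms
```

For example, `collections.Counter(parse_obo_namespaces(f).values())`, with `f` an open handle on your OBO file, gives term counts per namespace.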
Before running, set your Hugging Face token as an environment variable to avoid rate limits and unauthenticated download errors when fetching ProtBERT:
export HF_TOKEN="your_hf_token_here"

Then run the full pipeline:
bash run_all.sh

The script will:
- Extract sequences and build GO annotations from CIF files
- Cluster sequences at 30% identity (MMseqs2) and split into Train/Valid/Test
- Compute pLDDT-filtered contact maps and build PyG graph datasets
- Run BLAST / DIAMOND / Naive baselines
- Train all model ablations (MLP / GCN / GAT / Hybrid × BCE / Focal, 3 seeds, 3 ontologies)
- Run per-cluster generalisation evaluation
- Aggregate results and generate figures
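The cluster-aware split keeps every member of a 30%-identity cluster in the same partition, so test proteins never share a cluster with training proteins. Below is a minimal sketch. It assumes the two-column representative/member TSV that `mmseqs easy-cluster` writes. The 80/10/10 fractions and both function names are illustrative assumptions, not taken from `cluster_and_split.py`.

```python
import random
from collections import defaultdict
from typing import Dict, Iterable, List, Sequence


def read_cluster_tsv(lines: Iterable[str]) -> Dict[str, List[str]]:
    """Group members by cluster representative from 'rep<TAB>member' lines."""
    clusters: Dict[str, List[str]] = defaultdict(list)
    for line in lines:
        if line.strip():
            rep, member = line.rstrip("\n").split("\t")
            clusters[rep].append(member)
    return dict(clusters)


def cluster_split(
    clusters: Dict[str, List[str]],
    fractions: Sequence[float] = (0.8, 0.1, 0.1),  # assumed train/valid/test
    seed: int = 42,
) -> Dict[str, List[str]]:
    """Assign whole clusters to train/valid/test so that no cluster is split."""
    reps = sorted(clusters)  # fixed order first, so the shuffle is reproducible
    random.Random(seed).shuffle(reps)
    total = sum(len(members) for members in clusters.values())
    train_end = fractions[0] * total
    valid_end = (fractions[0] + fractions[1]) * total
    splits: Dict[str, List[str]] = {"train": [], "valid": [], "test": []}
    assigned = 0
    for rep in reps:
        # Fill train, then valid, then test, by cumulative protein count.
        if assigned < train_end:
            name = "train"
        elif assigned < valid_end:
            name = "valid"
        else:
            name = "test"
        splits[name].extend(clusters[rep])
        assigned += len(clusters[rep])
    return splits
```

Because whole clusters are assigned, the realised fractions only approximate the targets when clusters are large.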
Optional flags:
bash run_all.sh --skip-preprocess # Preprocessing already done
bash run_all.sh --skip-ablations # Only run preprocessing + baselines
bash run_all.sh --skip-plots # Skip figure generation

Default settings can be overridden with environment variables:
EPOCHS=50 BATCH_SIZE=16 MAIN_MODEL=GAT MAIN_LOSS=BCE bash run_all.sh

To run the automated hyperparameter sensitivity sweeps for the Hybrid model:
bash run_hyperparam_ablations.sh

To train a single configuration manually:
python3 train.py \
--model Hybrid \
--loss Focal \
--seed 42 \
--ontology biological_process \
--epochs 200

To predict GO terms for new structures:
python3 predictions.py \
-struc_dir examples/structure_files \
-model_path runs/bp_Hybrid_Focal_s42/best_model.pth \
-output examples/my_predictions.csv

Project structure:
deep-green-GO/
├── preprocessing/
│ ├── extract_seqs_from_cif.py # Sequence extraction + GO annotation
│ ├── cluster_and_split.py # MMseqs2 clustering + cluster-aware split
│ ├── create_cmaps.py # pLDDT-filtered contact maps
│ └── create_batch_dataset.py # PyG graph dataset builder (ProtBERT)
├── baselines/
│ ├── blast/ # BLASTp nearest-neighbour baseline
│ ├── diamond/ # DIAMOND nearest-neighbour baseline
│ ├── naive_frequency/ # GO term frequency prior baseline
│ └── deepfri_comparison/ # Comparison notes vs DeepFRI
├── model.py # GCN / GAT / Hybrid / MLP architectures
├── train.py # Training script with early stopping
├── evals.py # Micro/Macro Fmax, Smin, AUROC, AUPRC
├── focal_loss.py # Focal loss implementation
├── per_cluster_eval.py # Per homology-cluster generalisation eval
├── aggregate_results.py # Aggregate runs into mean±std tables
├── plot_results.py # Publication-quality figure generation
├── predictions.py # Inference on new structures
├── run_all.sh # ONE-CLICK full pipeline
├── run_ablations.sh # Ablation sweep helper
├── run_hyperparam_ablations.sh # Hyperparameter sensitivity helper
├── generate_supp_tables.py # LaTeX config table generator
├── environment.yml # Conda environment
└── requirements.txt # pip requirements
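`evals.py` reports Micro/Macro Fmax alongside other metrics. For reference, here is a minimal sketch of one common definition of micro-averaged Fmax. It pools all (protein, GO term) pairs, computes precision and recall at each score threshold, and keeps the best F1. The threshold grid (0.01 to 0.99) and the function name are assumptions. The repository's implementation may differ in details such as threshold spacing or tie handling.

```python
from typing import List, Optional, Sequence, Tuple


def micro_fmax(
    y_true: Sequence[Sequence[int]],
    y_score: Sequence[Sequence[float]],
    thresholds: Optional[List[float]] = None,
) -> Tuple[float, Optional[float]]:
    """Micro-averaged Fmax over a proteins x GO-terms label matrix.

    Returns (best F1, threshold that achieved it), or (0.0, None) if no
    threshold yields a true positive.
    """
    if thresholds is None:
        thresholds = [i / 100 for i in range(1, 100)]  # assumed grid 0.01..0.99
    # Flatten to (label, score) pairs: micro-averaging pools every pair.
    pairs = [
        (label, score)
        for labels, scores in zip(y_true, y_score)
        for label, score in zip(labels, scores)
    ]
    best_f, best_t = 0.0, None
    for thr in thresholds:
        tp = sum(1 for label, score in pairs if label and score >= thr)
        fp = sum(1 for label, score in pairs if not label and score >= thr)
        fn = sum(1 for label, score in pairs if label and score < thr)
        if tp == 0:
            continue  # recall is zero, so F1 is zero at this threshold
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
        if f1 > best_f:
            best_f, best_t = f1, thr
    return best_f, best_t
```

Macro variants average per-term or per-protein scores instead of pooling pairs, so check `evals.py` for the exact definitions behind the reported numbers.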