thirtysix/biomedical-RAFT-toolkit

RAFT Fine-tuning for Biomedical RAG

RAFT Fine-tuning Pipeline Overview

A complete pipeline for RAFT (Retrieval-Augmented Fine-Tuning) of language models on biomedical question-answering datasets (PubMedQA, BioASQ).

RAFT teaches models to:

  • Cite sources with document IDs
  • Quote evidence verbatim from documents
  • Identify relevant documents among distractors
  • Acknowledge limitations when information is insufficient

Overview

What is RAFT?

RAFT (Retrieval-Augmented Fine-Tuning) fine-tunes language models on examples that mix oracle documents (containing the answer) with distractor documents (topically related but not answering the question). This teaches models to:

  1. Discriminate between relevant and irrelevant documents
  2. Extract and cite evidence properly
  3. Say "I don't know" when documents don't contain the answer

Paper: Zhang et al., "RAFT: Adapting Language Model to Domain Specific RAG" (arXiv:2403.10131)

Key Features

  • Unsloth optimization: 2-5x training speedup, as reported by the Unsloth project
  • Completions-only training: Loss computed only on model responses
  • 80/20 oracle split: 20% of training examples omit the oracle document, so the model learns to abstain when the evidence is missing
  • HPC-ready: SLURM templates for cluster deployment
  • GGUF export: Convert to Ollama/llama.cpp format

Example: What RAFT Produces

Input prompt (with retrieved documents):

Question: Does coffee consumption increase the risk of heart disease?

Document [1]: A meta-analysis of 36 studies found that moderate coffee consumption
(3-4 cups/day) was associated with a reduced risk of cardiovascular disease...

Document [2]: Caffeine can temporarily raise blood pressure in some individuals...

Document [3]: The antioxidants in coffee, including chlorogenic acid, may provide
cardioprotective effects...

RAFT-trained model output (with citations):

Based on the available evidence, moderate coffee consumption does not appear to
increase heart disease risk and may actually be protective. ##begin_quote## A
meta-analysis of 36 studies found that moderate coffee consumption (3-4 cups/day)
was associated with a reduced risk of cardiovascular disease ##end_quote## [1].
This may be due to ##begin_quote## antioxidants in coffee, including chlorogenic
acid ##end_quote## [3] which provide cardioprotective effects.

When the oracle document is absent (20% of training examples):

I don't know. The provided documents discuss general nutrition but do not contain
specific information about coffee and heart disease risk.

Quick Start

Choose your setup path:

  • Local Setup - For workstations with GPU (uses venv)
  • HPC Setup - For SLURM clusters (uses system modules)

Local Quick Start

# 1. Clone and setup
git clone https://github.com/thirtysix/biomedical-RAFT-toolkit.git
cd biomedical-RAFT-toolkit

# 2. Run setup script (creates venv, installs dependencies)
./setup.sh              # Basic setup
# OR
./setup.sh --unsloth    # With Unsloth for 2-5x faster training

# 3. Activate the environment
source venv/bin/activate

# 4. Download and convert PubMedQA
python scripts/data_prep/01_download_pubmedqa.py
python scripts/data_prep/02_convert_pubmedqa_to_raft.py

# 5. (Optional) Add BioASQ data - requires free registration at http://bioasq.org/
#    After downloading BioASQ JSON files to ./bioasq_data/:
python scripts/data_prep/03_process_bioasq.py \
    --input_dir ./bioasq_data \
    --output_dir ./data/raw/bioasq
python scripts/data_prep/04_convert_bioasq_to_raft.py \
    --input_file ./data/raw/bioasq/bioasq_processed.json \
    --output_dir ./data/processed/raft_bioasq

# 6. Combine datasets
#    If using BioASQ, add ./data/processed/raft_bioasq as a third --input_dirs entry
python scripts/data_prep/05_combine_and_split.py \
    --input_dirs ./data/processed/raft_pubmedqa/pqa_artificial \
                 ./data/processed/raft_pubmedqa/pqa_labeled

# 7. Train (single GPU)
python scripts/training/train_raft.py \
    --model_name Qwen/Qwen3-4B-Instruct-2507 \
    --dataset_path ./data/processed/raft_final \
    --output_dir ./checkpoints/raft_pubmed \
    --max_samples 10000  # Small subset for testing

HPC Quick Start

# 1. Clone repository
git clone https://github.com/thirtysix/biomedical-RAFT-toolkit.git
cd biomedical-RAFT-toolkit

# 2. Load system PyTorch module (names vary by cluster)
module load pytorch        # Generic
# module load pytorch/2.9  # Or version-specific

# 3. Set up user package directory (persistent across jobs)
export PYTHONUSERBASE=$SCRATCH/.local
export PATH=$PYTHONUSERBASE/bin:$PATH

# 4. Install dependencies to user directory
pip install --user -r requirements.txt

# 5. (Optional) Install Unsloth for faster training
pip install --user "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

# 6. Configure SLURM scripts with your project account
nano scripts/slurm/train.sh
# Update: --account=your_project

# 7. Download and convert PubMedQA
python scripts/data_prep/01_download_pubmedqa.py
python scripts/data_prep/02_convert_pubmedqa_to_raft.py

# 8. (Optional) Add BioASQ data - requires free registration at http://bioasq.org/
#    After downloading BioASQ JSON files to ./bioasq_data/:
python scripts/data_prep/03_process_bioasq.py \
    --input_dir ./bioasq_data \
    --output_dir ./data/raw/bioasq
python scripts/data_prep/04_convert_bioasq_to_raft.py \
    --input_file ./data/raw/bioasq/bioasq_processed.json \
    --output_dir ./data/processed/raft_bioasq

# 9. Combine datasets
#    If using BioASQ, add ./data/processed/raft_bioasq as a third --input_dirs entry
python scripts/data_prep/05_combine_and_split.py \
    --input_dirs ./data/processed/raft_pubmedqa/pqa_artificial \
                 ./data/processed/raft_pubmedqa/pqa_labeled

# 10. Submit training job
sbatch scripts/slurm/train.sh

See docs/HPC_SETUP.md for detailed HPC instructions.

Requirements

Hardware

| Configuration | GPU Memory | Notes |
| --- | --- | --- |
| Minimum | 16GB | 4-bit quantization required |
| Recommended | 24-32GB | V100 or RTX 3090/4090 |
| Optimal | 40-80GB | A100 |
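
A rough lower bound on GPU memory is the size of the model weights alone. The sketch below estimates that from parameter count and precision. `weight_memory_gb` is a hypothetical helper, not part of this repository, and it ignores activations, LoRA parameters, optimizer state, and framework overhead, so real usage will be higher.

```python
# Hypothetical helper for back-of-the-envelope estimates; not part of this repository.
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate memory (in GB, 1e9 bytes) needed to store the weights alone."""
    return num_params * bits_per_param / 8 / 1e9


# A 4B-parameter model in 16-bit vs 4-bit precision (weights only).
print(weight_memory_gb(4e9, 16))  # 8.0
print(weight_memory_gb(4e9, 4))   # 2.0
```

The gap between the two numbers is why 4-bit quantization is listed as required for the minimum configuration.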

Disk Space

| Component | Size | Notes |
| --- | --- | --- |
| Model cache | 8-16GB | Per model downloaded |
| PubMedQA data | ~500MB | Raw + processed |
| BioASQ data | ~200MB | Optional |
| Checkpoints | ~2GB each | Default keeps 3 |
| Total recommended | 50GB+ | For comfortable experimentation |

Software

  • Python 3.10+
  • PyTorch 2.0+
  • CUDA 11.8+ (12.1 recommended)

HuggingFace Setup

Models are downloaded from HuggingFace Hub. Some models require authentication:

# Install HuggingFace CLI (included in requirements.txt)
pip install huggingface_hub

# Login to HuggingFace (required for gated models like Llama, Gemma)
huggingface-cli login
# Enter your token from https://huggingface.co/settings/tokens

Gated models require accepting license terms on their HuggingFace page before download.

For HPC/non-interactive environments, set the token as an environment variable:

export HF_TOKEN=hf_your_token_here
export HF_HOME=$SCRATCH/.hf_cache  # Optional: custom cache location

Installation

Choose your environment: Local (workstations) or HPC (SLURM clusters).

Local Installation

For workstations with GPU, use a virtual environment:

Automated Setup (Recommended):

# Run the setup script
./setup.sh                  # Basic installation
./setup.sh --unsloth        # Include Unsloth for 2-5x faster training
./setup.sh --dev            # Include development tools (pytest, black, etc.)
./setup.sh --unsloth --dev  # Both Unsloth and dev tools

# Activate the environment
source venv/bin/activate

The setup script will:

  • Check Python version (3.10+ required)
  • Detect CUDA and install appropriate PyTorch
  • Create a virtual environment in ./venv
  • Install all dependencies
  • Verify the installation

Manual Setup:

# Create virtual environment
python3 -m venv venv
source venv/bin/activate

# Upgrade pip
pip install --upgrade pip wheel setuptools

# Install PyTorch (adjust for your CUDA version)
# For CUDA 12.x:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# For CUDA 11.x:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# For CPU only:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

# Install dependencies
pip install -r requirements.txt

# Optional: Install Unsloth for faster training
pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

HPC Installation

For SLURM clusters, use system modules with user-level pip packages:

# 1. Load system PyTorch module (optimized for cluster hardware)
module load pytorch        # Generic name
# module load pytorch/2.9  # Or version-specific (check: module avail pytorch)

# 2. Set up user package directory (persistent across jobs)
export PYTHONUSERBASE=$SCRATCH/.local
export PATH=$PYTHONUSERBASE/bin:$PATH

# 3. Install dependencies to user directory
pip install --user -r requirements.txt

# 4. (Optional) Install Unsloth for 2-5x faster training
pip install --user "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

Why use system modules on HPC?

  • System PyTorch is optimized for cluster GPU hardware (CUDA, cuDNN, NCCL)
  • Avoids conda/venv overhead on shared filesystems
  • pip install --user packages persist across sessions in $PYTHONUSERBASE

Add these exports to your ~/.bashrc or SLURM scripts for persistence:

export PYTHONUSERBASE=$SCRATCH/.local
export PATH=$PYTHONUSERBASE/bin:$PATH

See docs/HPC_SETUP.md for detailed cluster-specific instructions.

Data Preparation

Dataset Overview

| Dataset | Examples | Description | Access |
| --- | --- | --- | --- |
| PubMedQA | ~211k artificial, ~1k labeled | Yes/no/maybe questions from PubMed abstracts | Free download |
| BioASQ | ~4k questions | Biomedical QA challenge data | Free registration required |

RAFT Data Format: Each training example contains:

{
  "messages": [
    {"role": "system", "content": "You are a biomedical assistant..."},
    {"role": "user", "content": "Question: ... Document [1]: ... Document [2]: ..."},
    {"role": "assistant", "content": "Answer with ##begin_quote##...##end_quote## [1]..."}
  ]
}
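
Before training, it can help to confirm that each record has this shape. The following is a minimal sketch of such a check, assuming exactly one system, one user, and one assistant message in that order, as in the example above. `validate_raft_example` is a hypothetical helper, not a function from this repository.

```python
import json

# Assumed message order, taken from the example record above.
EXPECTED_ROLES = ["system", "user", "assistant"]


# Hypothetical helper; not part of this repository.
def validate_raft_example(example: dict) -> list[str]:
    """Return a list of problems found in one training example (empty if none)."""
    messages = example.get("messages")
    if not isinstance(messages, list):
        return ["'messages' must be a list"]
    problems = []
    roles = [m.get("role") for m in messages if isinstance(m, dict)]
    if roles != EXPECTED_ROLES:
        problems.append(f"expected roles {EXPECTED_ROLES}, got {roles}")
    for m in messages:
        content = m.get("content") if isinstance(m, dict) else None
        if not isinstance(content, str) or not content.strip():
            problems.append("every message needs a non-empty string 'content'")
            break
    return problems


record = json.loads("""
{"messages": [
  {"role": "system", "content": "You are a biomedical assistant..."},
  {"role": "user", "content": "Question: ... Document [1]: ... Document [2]: ..."},
  {"role": "assistant", "content": "Answer with ##begin_quote##...##end_quote## [1]..."}
]}
""")
print(validate_raft_example(record))  # []
```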

The conversion scripts handle:

  • Oracle ratio (80/20): 80% of examples include the answer document, 20% contain only distractors
  • Distractor selection: Semantically similar but non-answering documents
  • Citation formatting: [1], [2], etc. with quote markers
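
The oracle ratio and document numbering can be illustrated with a short sketch. This is not the code in the conversion scripts: `build_raft_prompt` is a hypothetical function, distractors are passed in rather than selected by semantic similarity, and the prompt layout is copied from the examples in this README.

```python
import random


# Hypothetical sketch; the repository's conversion scripts may work differently.
def build_raft_prompt(question, oracle_doc, distractors, oracle_ratio=0.8, rng=None):
    """Assemble the user prompt for one training example.

    Returns (prompt, oracle_index). oracle_index is the 1-based document
    number of the oracle, or None when the oracle was left out.
    """
    if rng is None:
        rng = random.Random()
    include_oracle = rng.random() < oracle_ratio
    docs = list(distractors)
    if include_oracle:
        docs.append(oracle_doc)
    rng.shuffle(docs)  # so the oracle's position carries no signal
    oracle_index = docs.index(oracle_doc) + 1 if include_oracle else None
    body = "\n\n".join(f"Document [{i}]: {d}" for i, d in enumerate(docs, start=1))
    return f"Question: {question}\n\n{body}", oracle_index


rng = random.Random(42)
results = [
    build_raft_prompt("Q?", "oracle text", ["d1", "d2", "d3"], rng=rng)
    for _ in range(1000)
]
share = sum(idx is not None for _, idx in results) / len(results)
print(f"oracle present in {share:.0%} of examples")
```

In this simplified version, examples without the oracle contain one fewer document. A real converter might add an extra distractor so that document count does not reveal whether the oracle is present.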

1. Download and Convert PubMedQA

# Download PubMedQA dataset
python scripts/data_prep/01_download_pubmedqa.py \
    --output_dir ./data/raw/pubmedqa \
    --explore  # Show sample data

# Convert to RAFT format
python scripts/data_prep/02_convert_pubmedqa_to_raft.py \
    --input_dir ./data/raw/pubmedqa \
    --output_dir ./data/processed/raft_pubmedqa \
    --oracle_ratio 0.8 \
    --num_workers 4

2. (Optional) Prepare and Convert BioASQ

BioASQ requires free registration. Download the training data JSON files after registering:

# Process BioASQ JSON files
python scripts/data_prep/03_process_bioasq.py \
    --input_dir /path/to/bioasq/json \
    --output_dir ./data/raw/bioasq

# Convert to RAFT format
python scripts/data_prep/04_convert_bioasq_to_raft.py \
    --input_file ./data/raw/bioasq/bioasq_processed.json \
    --output_dir ./data/processed/raft_bioasq

3. Combine Datasets

python scripts/data_prep/05_combine_and_split.py \
    --input_dirs ./data/processed/raft_pubmedqa/pqa_artificial \
                 ./data/processed/raft_pubmedqa/pqa_labeled \
    --output_dir ./data/processed/raft_final \
    --train_ratio 0.9 \
    --val_ratio 0.05 \
    --test_ratio 0.05

Add ./data/processed/raft_bioasq to --input_dirs if using BioASQ.

Expected output: After running all steps, you should have:

data/processed/raft_final/
├── train/          # ~190k examples (90%)
├── validation/     # ~10k examples (5%)
└── test/           # ~10k examples (5%)
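
The split sizes above follow directly from the ratios. As a minimal sketch, assuming a seeded shuffle followed by simple slicing (the actual logic in 05_combine_and_split.py may differ, and `split_examples` is a hypothetical name):

```python
import random


# Hypothetical sketch of a 90/5/5 split; not the repository's implementation.
def split_examples(examples, train_ratio=0.9, val_ratio=0.05, test_ratio=0.05, seed=42):
    """Shuffle deterministically, then cut into train/validation/test lists."""
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-9:
        raise ValueError("ratios must sum to 1")
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    n_train = int(len(shuffled) * train_ratio)
    n_val = int(len(shuffled) * val_ratio)
    return {
        "train": shuffled[:n_train],
        "validation": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],  # remainder, so nothing is dropped
    }


# Roughly the size of PubMedQA (artificial + labeled) without BioASQ.
splits = split_examples(range(212_000))
print({name: len(rows) for name, rows in splits.items()})
```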

Training

Supported Models

Any causal language model from Hugging Face Models can be used. Recommended models for biomedical QA:

| Model | Size | GPU Memory | License | Notes |
| --- | --- | --- | --- | --- |
| Qwen/Qwen3-4B-Instruct-2507 | 4B | ~16GB | Apache 2.0 | Default - Good balance of quality/speed |
| Qwen/Qwen3-8B-Instruct-2507 | 8B | ~24GB | Apache 2.0 | Better reasoning capability |
| meta-llama/Llama-3.1-8B-Instruct | 8B | ~24GB | Llama 3.1 ⚠️ | Strong general performance |
| mistralai/Mistral-7B-Instruct-v0.3 | 7B | ~20GB | Apache 2.0 | Fast inference |
| google/gemma-2-9b-it | 9B | ~24GB | Gemma ⚠️ | Good instruction following |
| BioMistral/BioMistral-7B | 7B | ~20GB | Apache 2.0 | Pre-trained on biomedical text |

⚠️ Gated models: Llama and Gemma require accepting license terms on HuggingFace before download. See HuggingFace Setup.

To use a different model:

python scripts/training/train_raft.py --model_name meta-llama/Llama-3.1-8B-Instruct ...

Local Training (Single GPU)

python scripts/training/train_raft.py \
    --model_name Qwen/Qwen3-4B-Instruct-2507 \
    --dataset_path ./data/processed/raft_final \
    --output_dir ./checkpoints/raft_pubmed \
    --max_seq_length 1024 \
    --num_train_epochs 2 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --learning_rate 2e-4 \
    --lora_r 16 \
    --lora_alpha 32

HPC Training (SLURM)

  1. Edit SLURM scripts with your project account:

    nano scripts/slurm/train.sh
    # Update: --account=your_project
  2. Run test job:

    sbatch scripts/slurm/train_test.sh
  3. Run full training:

    sbatch scripts/slurm/train.sh
  4. Resume if interrupted:

    sbatch scripts/slurm/resume.sh

See docs/HPC_SETUP.md for detailed HPC instructions.

Multi-GPU Training (DDP)

For multi-GPU distributed training, disable Unsloth (which is single-GPU optimized):

torchrun --nproc_per_node=4 scripts/training/train_raft.py \
    --model_name Qwen/Qwen3-4B-Instruct-2507 \
    --dataset_path ./data/processed/raft_final \
    --output_dir ./checkpoints/raft_pubmed \
    --no_unsloth  # Required for multi-GPU DDP

Monitoring Training

TensorBoard logs are saved automatically:

# Start TensorBoard
tensorboard --logdir ./checkpoints/raft_pubmed/logs --port 6006

# View at http://localhost:6006

What to look for:

  • Training loss: Should decrease steadily; expect ~1.5-2.5 initially, dropping to ~0.8-1.2
  • Learning rate: Should warm up then decay (cosine schedule)
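  • Gradient norm: Gradients are clipped to max_grad_norm (0.3 default). The logged value is typically the norm before clipping, so watch for sustained spikes rather than expecting it to stay below the threshold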
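
To know what the learning-rate curve should look like in TensorBoard, the schedule can be sketched as linear warmup followed by cosine decay to zero. This is an illustration using the documented defaults (peak 2e-4, warmup ratio 0.03). `lr_at_step` is a hypothetical function, and the scheduler the training script actually uses may round warmup steps differently.

```python
import math


# Hypothetical sketch of warmup + cosine decay; not the scheduler used by train_raft.py.
def lr_at_step(step, total_steps, peak_lr=2e-4, warmup_ratio=0.03):
    """Linear warmup to peak_lr, then cosine decay to zero at total_steps."""
    warmup_steps = max(1, math.ceil(total_steps * warmup_ratio))
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (0, 15, 30, 500, 1000):
    print(step, f"{lr_at_step(step, total_steps=1000):.2e}")
```

A curve that stays flat, or never returns toward zero by the end of training, suggests the schedule settings are not being applied as expected.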

Quick sanity check during training:

# Check latest checkpoint
ls -la ./checkpoints/raft_pubmed/checkpoint-*/

# List TensorBoard event files (binary; open them with TensorBoard rather than tail)
ls -la ./checkpoints/raft_pubmed/logs/*/events.out.*

Training Modes: Completions-Only vs Packing

| Mode | Best For | Trade-off |
| --- | --- | --- |
| Completions-only (default) | RAFT training | Trains only on model responses, not context. Better for learning citation format. |
| Packing | Maximum throughput | Packs multiple examples per sequence. Faster but incompatible with completions-only. |

These modes are mutually exclusive. The default (--train_on_completions=True) is recommended for RAFT as it focuses learning on the citation/response format rather than memorizing document content.
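
Conceptually, completions-only training masks the prompt tokens out of the loss. The sketch below shows the idea on plain Python lists, assuming the common convention of marking ignored positions with -100, which is the default `ignore_index` of PyTorch's cross-entropy loss. `mask_prompt_tokens` is a hypothetical helper; the training script relies on its libraries to do this.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


# Hypothetical helper illustrating the masking idea; not part of this repository.
def mask_prompt_tokens(input_ids, response_start):
    """Copy input_ids into labels, hiding everything before the assistant response."""
    return [
        IGNORE_INDEX if position < response_start else token
        for position, token in enumerate(input_ids)
    ]


# Toy token IDs: the first four stand for the prompt (system text, question,
# documents), the last three for the assistant's answer.
input_ids = [101, 7, 8, 9, 42, 43, 102]
print(mask_prompt_tokens(input_ids, response_start=4))
# [-100, -100, -100, -100, 42, 43, 102]
```

Packing concatenates several examples into one sequence, which is why locating a single response boundary per sequence no longer works without extra bookkeeping.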

To use packing instead (for faster training on general tasks):

python scripts/training/train_raft.py \
    --train_on_completions=False \
    --packing

Using Config Files

Instead of command-line arguments, use configs/raft_config.yaml:

# Edit config file
nano configs/raft_config.yaml

# The training script reads from this file by default
# Override specific values via command line as needed

Default Training Parameters

All parameters with their defaults (override via command line):

Model & LoRA:

| Parameter | Default | Description |
| --- | --- | --- |
| --model_name | Qwen/Qwen3-4B-Instruct-2507 | Base model from HuggingFace |
| --max_seq_length | 2048 | Maximum context length |
| --load_in_4bit | False | 4-bit quantization (single GPU only) |
| --lora_r | 32 | LoRA rank (16-32 recommended) |
| --lora_alpha | 32 | LoRA alpha scaling |
| --lora_dropout | 0.0 | LoRA dropout (0 recommended) |
| --use_rslora | True | Rank-Stabilized LoRA |
| --no_unsloth | False | Disable Unsloth (required for multi-GPU DDP) |

Training:

| Parameter | Default | Description |
| --- | --- | --- |
| --num_train_epochs | 2 | Number of training epochs |
| --per_device_train_batch_size | 8 | Batch size per GPU |
| --gradient_accumulation_steps | 4 | Gradient accumulation steps |
| --learning_rate | 2e-4 | Peak learning rate |
| --warmup_ratio | 0.03 | Warmup proportion of training |
| --weight_decay | 0.01 | Weight decay for regularization |
| --max_grad_norm | 0.3 | Gradient clipping threshold |
| --train_on_completions | True | Train only on assistant responses |
| --packing | False | Sequence packing (incompatible with completions-only) |
| --max_samples | None | Limit training examples (useful for testing) |
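
Batch size and gradient accumulation together set the effective batch size, and with it the number of optimizer steps. The arithmetic below is a sketch: `training_steps` is a hypothetical helper, and the step count a trainer reports may differ slightly depending on how it handles the last partial batch.

```python
import math


# Hypothetical helper for planning runs; not part of this repository.
def training_steps(num_examples, per_device_batch_size=8, grad_accum_steps=4,
                   num_gpus=1, num_epochs=2):
    """Return (effective_batch_size, optimizer_steps_per_epoch, total_steps)."""
    effective = per_device_batch_size * grad_accum_steps * num_gpus
    per_epoch = math.ceil(num_examples / effective)
    return effective, per_epoch, per_epoch * num_epochs


# Defaults from the table above with ~190k training examples.
print(training_steps(190_000))  # (32, 5938, 11876)
```

Halving the per-device batch size while doubling gradient accumulation (for example 4 and 8) keeps the effective batch size at 32, which is the usual way to trade memory for time.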

Checkpoints & Logging:

| Parameter | Default | Description |
| --- | --- | --- |
| --output_dir | ./checkpoints/raft_pubmed | Output directory |
| --logging_steps | 10 | Log every N steps |
| --save_steps | 500 | Save checkpoint every N steps |
| --eval_steps | 500 | Evaluate every N steps |
| --save_total_limit | 3 | Max checkpoints to keep |
| --seed | 42 | Random seed |
| --run_name | raft-pubmed | Name for experiment tracking |
| --resume_from_checkpoint | None | Path to checkpoint or auto to find latest |
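
To see which checkpoint a resume would most likely pick up, the latest checkpoint directory can be located by step number. This sketch assumes checkpoints are saved as `checkpoint-<step>` directories under the output directory. `find_latest_checkpoint` is a hypothetical helper and may not match how the auto option is implemented in the training script.

```python
import re
from pathlib import Path


# Hypothetical helper; the training script's "auto" resume logic may differ.
def find_latest_checkpoint(output_dir):
    """Return the checkpoint-<step> directory with the highest step, or None."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    pattern = re.compile(r"^checkpoint-(\d+)$")
    candidates = []
    for path in root.iterdir():
        match = pattern.match(path.name)
        if match and path.is_dir():
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    # Compare by numeric step so checkpoint-1000 beats checkpoint-500.
    return max(candidates, key=lambda item: item[0])[1]


print(find_latest_checkpoint("./checkpoints/raft_pubmed"))
```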

Evaluation

Quick Sanity Check

Before full evaluation, verify the model works:

# Interactive test with a few examples
python scripts/deployment/inference.py \
    --adapter_path ./checkpoints/raft_pubmed \
    --interactive

Ask a question with accompanying documents. The model should cite sources with [1], [2] notation and wrap verbatim text in ##begin_quote##...##end_quote## markers.

Full Evaluation

python scripts/evaluation/evaluate_raft.py \
    --model_path ./checkpoints/raft_pubmed \
    --test_data ./data/processed/raft_final \
    --output_dir ./evaluation_results \
    --max_examples 500  # For quick evaluation; remove for full test set

Metrics and Interpretation

| Metric | Description | Target | Notes |
| --- | --- | --- | --- |
| Citation Validity | Are cited document IDs real? | >95% | Should be near-perfect after training |
| Quote Accuracy | Are quotes found in cited documents? | >85% | Measures faithful extraction |
| Oracle Identification | Does model cite oracle when present? | >80% | Key RAFT metric |
| IDK Precision | Does model abstain when oracle absent? | >70% | Tests epistemic humility |
| ROUGE-1 | Unigram overlap with reference | >0.4 | General answer quality |
| ROUGE-L | Longest common subsequence | >0.35 | Fluency and structure |

Interpreting results:

  • Low Citation Validity (<90%): Model may be hallucinating document IDs. Check training data format.
  • Low Oracle Identification (<70%): Model isn't learning to identify relevant documents. Try longer training or more data.
  • Low IDK Precision (<50%): Model answers when it shouldn't. Verify 80/20 oracle split in training data.
  • Low ROUGE scores: May indicate answer quality issues, but RAFT models often have lower ROUGE due to citation overhead.
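
The two oracle-related metrics can be read as simple ratios over per-example outcomes. The sketch below follows the table descriptions literally; the record fields and the `summarize` function are hypothetical, and evaluate_raft.py may define or name these metrics differently.

```python
# Hypothetical aggregation following the metric descriptions; not evaluate_raft.py.
def summarize(results):
    """Aggregate per-example outcomes.

    Each result is a dict with 'oracle_index' (int, or None when the oracle
    was absent), 'cited' (set of cited document numbers), and 'abstained' (bool).
    """
    with_oracle = [r for r in results if r["oracle_index"] is not None]
    without_oracle = [r for r in results if r["oracle_index"] is None]
    oracle_hits = sum(r["oracle_index"] in r["cited"] for r in with_oracle)
    abstentions = sum(r["abstained"] for r in without_oracle)
    return {
        "oracle_identification": oracle_hits / len(with_oracle) if with_oracle else None,
        "abstention_when_oracle_absent": abstentions / len(without_oracle) if without_oracle else None,
    }


results = [
    {"oracle_index": 1, "cited": {1, 3}, "abstained": False},    # cites the oracle
    {"oracle_index": 2, "cited": {3}, "abstained": False},       # misses the oracle
    {"oracle_index": None, "cited": set(), "abstained": True},   # correctly abstains
    {"oracle_index": None, "cited": {1}, "abstained": False},    # answers anyway
]
print(summarize(results))
# {'oracle_identification': 0.5, 'abstention_when_oracle_absent': 0.5}
```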

Baseline comparison (approximate, varies by model/data):

| Model | Citation Valid | Oracle ID | IDK Precision |
| --- | --- | --- | --- |
| Base model (no RAFT) | ~20% | ~40% | ~10% |
| After RAFT training | >95% | >80% | >70% |

Deployment

Merge LoRA Adapters

python scripts/deployment/merge_lora.py \
    --base_model Qwen/Qwen3-4B-Instruct-2507 \
    --adapter_path ./checkpoints/raft_pubmed \
    --output_path ./checkpoints/raft_merged

Run Inference

# Interactive mode with LoRA adapter
python scripts/deployment/inference.py \
    --adapter_path ./checkpoints/raft_pubmed \
    --interactive

# With merged model (faster loading)
python scripts/deployment/inference.py \
    --model_path ./checkpoints/raft_merged \
    --interactive

Example Inference Usage

When using the model, format your input with documents:

Question: What is the role of BRCA1 in cancer?

Document [1]: BRCA1 is a tumor suppressor gene that plays a critical role in
DNA repair through homologous recombination. Mutations in BRCA1 significantly
increase the risk of breast and ovarian cancer.

Document [2]: The cell cycle is regulated by cyclins and cyclin-dependent
kinases, which control progression through different phases.

Document [3]: BRCA1 also functions in cell cycle checkpoint control,
activating the G1/S and G2/M checkpoints in response to DNA damage.

Expected model output:

BRCA1 is a tumor suppressor gene with multiple roles in cancer prevention.
##begin_quote## BRCA1 is a tumor suppressor gene that plays a critical role
in DNA repair through homologous recombination ##end_quote## [1]. Additionally,
##begin_quote## BRCA1 also functions in cell cycle checkpoint control,
activating the G1/S and G2/M checkpoints in response to DNA damage ##end_quote##
[3]. Mutations in this gene ##begin_quote## significantly increase the risk of
breast and ovarian cancer ##end_quote## [1].

Note how the model:

  • Cites relevant documents [1] and [3]
  • Ignores irrelevant document [2] (about cell cycle, not BRCA1)
  • Uses ##begin_quote##...##end_quote## for verbatim text
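
Because the output format is regular, quotes and citations can be checked programmatically. The sketch below extracts each quoted span with its document number and tests whether the quote appears verbatim in that document after whitespace normalization. The regular expression and `check_quotes` are assumptions based on the examples in this README, not the repository's evaluation code.

```python
import re

# Assumed pattern: a quoted span immediately followed by a bracketed document number.
QUOTE_RE = re.compile(r"##begin_quote##\s*(.*?)\s*##end_quote##\s*\[(\d+)\]", re.DOTALL)


def normalize(text):
    """Collapse whitespace so line wrapping does not break matching."""
    return " ".join(text.split())


# Hypothetical checker; not part of this repository.
def check_quotes(answer, documents):
    """Return (doc_id, quote, ok) per quote; ok means the ID exists and the quote is in that document."""
    rows = []
    for quote, doc_id in QUOTE_RE.findall(answer):
        source = documents.get(int(doc_id))
        ok = source is not None and normalize(quote) in normalize(source)
        rows.append((int(doc_id), normalize(quote), ok))
    return rows


documents = {
    1: "BRCA1 is a tumor suppressor gene that plays a critical role in DNA repair "
       "through homologous recombination. Mutations in BRCA1 significantly increase "
       "the risk of breast and ovarian cancer.",
    3: "BRCA1 also functions in cell cycle checkpoint control, activating the G1/S "
       "and G2/M checkpoints in response to DNA damage.",
}
answer = (
    "##begin_quote## BRCA1 is a tumor suppressor gene that plays a critical role\n"
    "in DNA repair through homologous recombination ##end_quote## [1]. Mutations in\n"
    "this gene ##begin_quote## significantly increase the risk of\n"
    "breast and ovarian cancer ##end_quote## [1]."
)
for doc_id, quote, ok in check_quotes(answer, documents):
    print(doc_id, ok, quote[:40])
```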

Convert to GGUF (for Ollama)

# Requires llama.cpp
export LLAMA_CPP_PATH=/path/to/llama.cpp

./scripts/deployment/convert_to_gguf.sh \
    ./checkpoints/raft_merged \
    raft-pubmed \
    Q4_K_M

Then use with Ollama:

ollama create raft-pubmed -f Modelfile
ollama run raft-pubmed

Project Structure

biomedical-RAFT-toolkit/
├── README.md
├── setup.sh                      # Automated setup script
├── requirements.txt
├── .gitignore
├── .env.example
├── configs/
│   └── raft_config.yaml
├── docs/
│   ├── DATA_SOURCES.md
│   ├── RAFT_EXPLANATION.md
│   └── HPC_SETUP.md
├── src/
│   ├── __init__.py
│   └── prompts.py              # Consolidated prompts
├── scripts/
│   ├── data_prep/
│   │   ├── 01_download_pubmedqa.py
│   │   ├── 02_convert_pubmedqa_to_raft.py
│   │   ├── 03_process_bioasq.py
│   │   ├── 04_convert_bioasq_to_raft.py
│   │   └── 05_combine_and_split.py
│   ├── training/
│   │   └── train_raft.py
│   ├── slurm/
│   │   ├── train.sh
│   │   ├── train_test.sh
│   │   └── resume.sh
│   ├── evaluation/
│   │   └── evaluate_raft.py
│   └── deployment/
│       ├── merge_lora.py
│       ├── inference.py
│       └── convert_to_gguf.sh
├── venv/                       # Git-ignored (created by setup.sh)
├── data/                       # Git-ignored
├── checkpoints/                # Git-ignored
└── logs/                       # Git-ignored

Troubleshooting

Out of Memory

# Reduce batch size
--per_device_train_batch_size 2

# Increase gradient accumulation
--gradient_accumulation_steps 16

# Use 4-bit quantization
--load_in_4bit

# Reduce sequence length
--max_seq_length 1024

Slow Training

# Install Unsloth
pip install unsloth

# Enable packing (if not using completions-only)
--packing

Import Errors

# Ensure correct Python environment
which python
pip list | grep -E "torch|transformers|peft"

Citation

If you use this code, please cite:

@software{raft_pubmed_finetune,
  title={RAFT Fine-tuning for Biomedical RAG},
  author={Harlan Barker},
  year={2024},
  url={https://github.com/thirtysix/biomedical-RAFT-toolkit}
}

Related Work

@article{zhang2024raft,
  title={RAFT: Adapting Language Model to Domain Specific RAG},
  author={Zhang, Tianjun and others},
  journal={arXiv preprint arXiv:2403.10131},
  year={2024}
}

@inproceedings{jin2019pubmedqa,
  title={PubMedQA: A Dataset for Biomedical Research Question Answering},
  author={Jin, Qiao and others},
  booktitle={EMNLP},
  year={2019}
}

Acknowledgments

Documentation developed with assistance from Claude.

License

This project is licensed under the MIT License - see the LICENSE file for details.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.
