A complete pipeline for RAFT (Retrieval-Augmented Fine-Tuning) of language models on biomedical question-answering datasets (PubMedQA, BioASQ).
RAFT teaches models to:
- Cite sources with document IDs
- Quote evidence verbatim from documents
- Identify relevant documents among distractors
- Acknowledge limitations when information is insufficient
- Overview
- Quick Start
- Requirements
- Data Preparation
- Training
- Evaluation
- Deployment
- Project Structure
- Troubleshooting
- Citation
RAFT (Retrieval-Augmented Fine-Tuning) fine-tunes language models on examples that mix oracle documents (containing the answer) with distractor documents (topically related but not answering the question). This teaches models to:
- Discriminate between relevant and irrelevant documents
- Extract and cite evidence properly
- Say "I don't know" when documents don't contain the answer
Paper: Zhang et al., "RAFT: Adapting Language Model to Domain Specific RAG" (arXiv:2403.10131)
- Unsloth optimization: 2-5x training speedup
- Completions-only training: Loss computed only on model responses
- 80/20 oracle split: 20% of examples omit the answer document, teaching the model to abstain when the evidence is missing
- HPC-ready: SLURM templates for cluster deployment
- GGUF export: Convert to Ollama/llama.cpp format
Input prompt (with retrieved documents):
Question: Does coffee consumption increase the risk of heart disease?
Document [1]: A meta-analysis of 36 studies found that moderate coffee consumption
(3-4 cups/day) was associated with a reduced risk of cardiovascular disease...
Document [2]: Caffeine can temporarily raise blood pressure in some individuals...
Document [3]: The antioxidants in coffee, including chlorogenic acid, may provide
cardioprotective effects...
RAFT-trained model output (with citations):
Based on the available evidence, moderate coffee consumption does not appear to
increase heart disease risk and may actually be protective. ##begin_quote## A
meta-analysis of 36 studies found that moderate coffee consumption (3-4 cups/day)
was associated with a reduced risk of cardiovascular disease ##end_quote## [1].
This may be due to ##begin_quote## antioxidants in coffee, including chlorogenic
acid ##end_quote## [3], which may provide cardioprotective effects.
When the oracle document is absent (20% of training examples):
I don't know. The provided documents discuss general nutrition but do not contain
specific information about coffee and heart disease risk.
Choose your setup path:
- Local Setup - For workstations with GPU (uses venv)
- HPC Setup - For SLURM clusters (uses system modules)
Local Setup:
# 1. Clone and setup
git clone https://github.com/thirtysix/biomedical-RAFT-toolkit.git pubmed_finetune_RAFT
cd pubmed_finetune_RAFT
# 2. Run setup script (creates venv, installs dependencies)
./setup.sh # Basic setup
# OR
./setup.sh --unsloth # With Unsloth for 2-5x faster training
# 3. Activate the environment
source venv/bin/activate
# 4. Download and convert PubMedQA
python scripts/data_prep/01_download_pubmedqa.py
python scripts/data_prep/02_convert_pubmedqa_to_raft.py
# 5. (Optional) Add BioASQ data - requires free registration at http://bioasq.org/
# After downloading BioASQ JSON files to ./bioasq_data/:
python scripts/data_prep/03_process_bioasq.py \
--input_dir ./bioasq_data \
--output_dir ./data/raw/bioasq
python scripts/data_prep/04_convert_bioasq_to_raft.py \
--input_file ./data/raw/bioasq/bioasq_processed.json \
--output_dir ./data/processed/raft_bioasq
# 6. Combine datasets (to include BioASQ, also list ./data/processed/raft_bioasq under --input_dirs)
python scripts/data_prep/05_combine_and_split.py \
--input_dirs ./data/processed/raft_pubmedqa/pqa_artificial \
./data/processed/raft_pubmedqa/pqa_labeled \
--output_dir ./data/processed/raft_final
# 7. Train (single GPU)
python scripts/training/train_raft.py \
--model_name Qwen/Qwen3-4B-Instruct-2507 \
--dataset_path ./data/processed/raft_final \
--output_dir ./checkpoints/raft_pubmed \
--max_samples 10000 # Small subset for testing
HPC Setup:
# 1. Clone repository
git clone https://github.com/thirtysix/biomedical-RAFT-toolkit.git pubmed_finetune_RAFT
cd pubmed_finetune_RAFT
# 2. Load system PyTorch module (names vary by cluster)
module load pytorch # Generic
# module load pytorch/2.9 # Or version-specific
# 3. Set up user package directory (persistent across jobs)
export PYTHONUSERBASE=$SCRATCH/.local
export PATH=$PYTHONUSERBASE/bin:$PATH
# 4. Install dependencies to user directory
pip install --user -r requirements.txt
# 5. (Optional) Install Unsloth for faster training
pip install --user "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# 6. Configure SLURM scripts with your project account
nano scripts/slurm/train.sh
# Update: --account=your_project
# 7. Download and convert PubMedQA
python scripts/data_prep/01_download_pubmedqa.py
python scripts/data_prep/02_convert_pubmedqa_to_raft.py
# 8. (Optional) Add BioASQ data - requires free registration at http://bioasq.org/
# After downloading BioASQ JSON files to ./bioasq_data/:
python scripts/data_prep/03_process_bioasq.py \
--input_dir ./bioasq_data \
--output_dir ./data/raw/bioasq
python scripts/data_prep/04_convert_bioasq_to_raft.py \
--input_file ./data/raw/bioasq/bioasq_processed.json \
--output_dir ./data/processed/raft_bioasq
# 9. Combine datasets (to include BioASQ, also list ./data/processed/raft_bioasq under --input_dirs)
python scripts/data_prep/05_combine_and_split.py \
--input_dirs ./data/processed/raft_pubmedqa/pqa_artificial \
./data/processed/raft_pubmedqa/pqa_labeled \
--output_dir ./data/processed/raft_final
# 10. Submit training job
sbatch scripts/slurm/train.sh
See docs/HPC_SETUP.md for detailed HPC instructions.
GPU memory:
| Configuration | GPU Memory | Notes |
|---|---|---|
| Minimum | 16GB | 4-bit quantization required |
| Recommended | 24-32GB | V100 or RTX 3090/4090 |
| Optimal | 40-80GB | A100 |
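As a rough cross-check on these figures, the base model's weights alone need about (number of parameters) × (bytes per parameter). The sketch below assumes 2 bytes per parameter for bf16 and about 0.5 bytes for 4-bit weights; it ignores activations, LoRA gradients and optimizer state, quantization metadata, and CUDA overhead, which is why the table lists more memory than this lower bound.

```python
# Weights-only memory estimate (a rough lower bound, not a measurement).
# Assumed bytes per parameter: bf16 = 2, 4-bit = 0.5 (quantization scales
# and all training-time overhead are ignored).
BYTES_PER_PARAM = {"bf16": 2.0, "4bit": 0.5}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Approximate gigabytes (1e9 bytes) needed just to hold the base weights."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


if __name__ == "__main__":
    for params in (4e9, 8e9):
        for dtype in ("bf16", "4bit"):
            gb = weight_memory_gb(params, dtype)
            print(f"{params / 1e9:.0f}B params in {dtype}: ~{gb:.1f} GB of weights")
```

A 4B model in bf16 therefore needs about 8 GB for weights before any training state is allocated; 4-bit loading cuts that to roughly 2 GB.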
Disk space:
| Component | Size | Notes |
|---|---|---|
| Model cache | 8-16GB | Per model downloaded |
| PubMedQA data | ~500MB | Raw + processed |
| BioASQ data | ~200MB | Optional |
| Checkpoints | ~2GB each | Default keeps 3 |
| Total recommended | 50GB+ | For comfortable experimentation |
Software:
- Python 3.10+
- PyTorch 2.0+
- CUDA 11.8+ (12.1 recommended)
Models are downloaded from HuggingFace Hub. Some models require authentication:
# Install HuggingFace CLI (included in requirements.txt)
pip install huggingface_hub
# Login to HuggingFace (required for gated models like Llama, Gemma)
huggingface-cli login
# Enter your token from https://huggingface.co/settings/tokens
Gated models require accepting license terms on their HuggingFace page before download:
- Llama models - Requires Meta license acceptance
- Gemma models - Requires Google license acceptance
For HPC/non-interactive environments, set the token as an environment variable:
export HF_TOKEN=hf_your_token_here
export HF_HOME=$SCRATCH/.hf_cache # Optional: custom cache location
Choose your environment: Local (workstations) or HPC (SLURM clusters).
For workstations with GPU, use a virtual environment:
Automated Setup (Recommended):
# Run the setup script
./setup.sh # Basic installation
./setup.sh --unsloth # Include Unsloth for 2-5x faster training
./setup.sh --dev # Include development tools (pytest, black, etc.)
./setup.sh --unsloth --dev # Both Unsloth and dev tools
# Activate the environment
source venv/bin/activate
The setup script will:
- Check Python version (3.10+ required)
- Detect CUDA and install appropriate PyTorch
- Create a virtual environment in ./venv
- Install all dependencies
- Verify the installation
Manual Setup:
# Create virtual environment
python3 -m venv venv
source venv/bin/activate
# Upgrade pip
pip install --upgrade pip wheel setuptools
# Install PyTorch (adjust for your CUDA version)
# For CUDA 12.x:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# For CUDA 11.x:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# For CPU only:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
# Install dependencies
pip install -r requirements.txt
# Optional: Install Unsloth for faster training
pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
For SLURM clusters, use system modules with user-level pip packages:
# 1. Load system PyTorch module (optimized for cluster hardware)
module load pytorch # Generic name
# module load pytorch/2.9 # Or version-specific (check: module avail pytorch)
# 2. Set up user package directory (persistent across jobs)
export PYTHONUSERBASE=$SCRATCH/.local
export PATH=$PYTHONUSERBASE/bin:$PATH
# 3. Install dependencies to user directory
pip install --user -r requirements.txt
# 4. (Optional) Install Unsloth for 2-5x faster training
pip install --user "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
Why use system modules on HPC?
- System PyTorch is optimized for cluster GPU hardware (CUDA, cuDNN, NCCL)
- Avoids conda/venv overhead on shared filesystems
- Packages installed with pip install --user persist across sessions in $PYTHONUSERBASE
Add these exports to your ~/.bashrc or SLURM scripts for persistence:
export PYTHONUSERBASE=$SCRATCH/.local
export PATH=$PYTHONUSERBASE/bin:$PATH
See docs/HPC_SETUP.md for detailed cluster-specific instructions.
| Dataset | Examples | Description | Access |
|---|---|---|---|
| PubMedQA | ~211k artificial, ~1k labeled | Yes/no/maybe questions from PubMed abstracts | Free download |
| BioASQ | ~4k questions | Biomedical QA challenge data | Free registration required |
RAFT Data Format: Each training example contains:
{
"messages": [
{"role": "system", "content": "You are a biomedical assistant..."},
{"role": "user", "content": "Question: ... Document [1]: ... Document [2]: ..."},
{"role": "assistant", "content": "Answer with ##begin_quote##...##end_quote## [1]..."}
]
}
The conversion scripts handle:
- Oracle ratio (80/20): 80% of examples include the answer document, 20% contain only distractors
- Distractor selection: Semantically similar but non-answering documents
- Citation formatting: [1], [2], etc., with ##begin_quote##...##end_quote## markers around verbatim quotes
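The conversion steps above can be sketched as follows. This is a minimal illustration, not the code in 02_convert_pubmedqa_to_raft.py: the helper name build_raft_example, the system prompt text, the "I don't know" wording, and the {oracle_id} placeholder in the answer template are all assumptions, and distractor selection is assumed to have happened already. The sketch keeps the oracle with probability oracle_ratio (0.8 by default), shuffles the documents so the oracle's position carries no signal, and fills in the citation with the oracle's final position.

```python
import random

# Illustrative assumptions, not the toolkit's actual strings.
SYSTEM_PROMPT = "You are a biomedical assistant. Answer using only the provided documents."
IDK_ANSWER = ("I don't know. The provided documents do not contain enough "
              "information to answer this question.")


def build_raft_example(question, oracle_doc, distractors, answer_template,
                       oracle_ratio=0.8, rng=None):
    """Assemble one chat-format RAFT example (hypothetical helper).

    answer_template must contain an {oracle_id} placeholder where the
    citation to the oracle document goes, e.g. "... ##end_quote## [{oracle_id}]".
    """
    if rng is None:
        rng = random.Random()
    include_oracle = rng.random() < oracle_ratio
    docs = list(distractors) + ([oracle_doc] if include_oracle else [])
    rng.shuffle(docs)  # avoid a fixed oracle position the model could exploit

    context = "\n".join(f"Document [{i}]: {doc}" for i, doc in enumerate(docs, start=1))
    if include_oracle:
        answer = answer_template.format(oracle_id=docs.index(oracle_doc) + 1)
    else:
        answer = IDK_ANSWER  # oracle withheld: teach the model to abstain

    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Question: {question}\n{context}"},
            {"role": "assistant", "content": answer},
        ]
    }


if __name__ == "__main__":
    rng = random.Random(0)
    examples = [
        build_raft_example("Q?", "oracle text", ["distractor A", "distractor B"],
                           "Answer ##begin_quote## oracle text ##end_quote## [{oracle_id}]",
                           rng=rng)
        for _ in range(1000)
    ]
    share = sum(ex["messages"][2]["content"] != IDK_ANSWER for ex in examples) / len(examples)
    print(f"share of examples with the oracle: {share:.2f}")
```

Over many examples the printed share should land near 0.8; passing a seeded random.Random makes the split reproducible.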
# Download PubMedQA dataset
python scripts/data_prep/01_download_pubmedqa.py \
--output_dir ./data/raw/pubmedqa \
--explore # Show sample data
# Convert to RAFT format
python scripts/data_prep/02_convert_pubmedqa_to_raft.py \
--input_dir ./data/raw/pubmedqa \
--output_dir ./data/processed/raft_pubmedqa \
--oracle_ratio 0.8 \
--num_workers 4
BioASQ requires free registration. Download the training data JSON files after registering:
# Process BioASQ JSON files
python scripts/data_prep/03_process_bioasq.py \
--input_dir /path/to/bioasq/json \
--output_dir ./data/raw/bioasq
# Convert to RAFT format
python scripts/data_prep/04_convert_bioasq_to_raft.py \
--input_file ./data/raw/bioasq/bioasq_processed.json \
--output_dir ./data/processed/raft_bioasq
Combine and split the datasets:
python scripts/data_prep/05_combine_and_split.py \
--input_dirs ./data/processed/raft_pubmedqa/pqa_artificial \
./data/processed/raft_pubmedqa/pqa_labeled \
--output_dir ./data/processed/raft_final \
--train_ratio 0.9 \
--val_ratio 0.05 \
--test_ratio 0.05
Add ./data/processed/raft_bioasq to --input_dirs if using BioASQ.
Expected output: After running all steps, you should have:
data/processed/raft_final/
├── train/ # ~190k examples (90%)
├── validation/ # ~10k examples (5%)
└── test/ # ~10k examples (5%)
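The 90/5/5 split behind this layout can be sketched as a seeded shuffle followed by two cuts. This is an assumption about how 05_combine_and_split.py works, written with the standard library only; the function name and the seed of 42 (borrowed from the training --seed default) are illustrative.

```python
import random


def split_examples(examples, train_ratio=0.9, val_ratio=0.05, test_ratio=0.05, seed=42):
    """Deterministically shuffle examples, then cut them into three splits."""
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-9:
        raise ValueError("train/val/test ratios must sum to 1")
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)  # same seed -> same split every run
    n_train = int(len(shuffled) * train_ratio)
    n_val = int(len(shuffled) * val_ratio)
    return {
        "train": shuffled[:n_train],
        "validation": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],  # remainder, so no example is dropped
    }


if __name__ == "__main__":
    splits = split_examples(range(1000))
    print({name: len(items) for name, items in splits.items()})
```

Giving the test split the remainder means rounding never silently drops examples, and fixing the seed keeps the test set stable across reruns.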
Any causal language model on the Hugging Face Hub can be used. Recommended models for biomedical QA:
| Model | Size | GPU Memory | License | Notes |
|---|---|---|---|---|
| Qwen/Qwen3-4B-Instruct-2507 | 4B | ~16GB | Apache 2.0 | Default - Good balance of quality/speed |
| Qwen/Qwen3-8B | 8B | ~24GB | Apache 2.0 | Better reasoning capability |
| meta-llama/Llama-3.1-8B-Instruct | 8B | ~24GB | Llama 3.1 | Strong general performance |
| mistralai/Mistral-7B-Instruct-v0.3 | 7B | ~20GB | Apache 2.0 | Fast inference |
| google/gemma-2-9b-it | 9B | ~24GB | Gemma | Good instruction following |
| BioMistral/BioMistral-7B | 7B | ~20GB | Apache 2.0 | Pre-trained on biomedical text |
To use a different model:
python scripts/training/train_raft.py --model_name meta-llama/Llama-3.1-8B-Instruct ...
Single-GPU training with explicit hyperparameters:
python scripts/training/train_raft.py \
--model_name Qwen/Qwen3-4B-Instruct-2507 \
--dataset_path ./data/processed/raft_final \
--output_dir ./checkpoints/raft_pubmed \
--max_seq_length 1024 \
--num_train_epochs 2 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 8 \
--learning_rate 2e-4 \
--lora_r 16 \
--lora_alpha 32
For HPC clusters with SLURM:
- Edit the SLURM scripts with your project account:
nano scripts/slurm/train.sh # Update: --account=your_project
- Run a test job:
sbatch scripts/slurm/train_test.sh
- Run full training:
sbatch scripts/slurm/train.sh
- Resume if interrupted:
sbatch scripts/slurm/resume.sh
See docs/HPC_SETUP.md for detailed HPC instructions.
For multi-GPU distributed training, disable Unsloth (which is single-GPU optimized):
torchrun --nproc_per_node=4 scripts/training/train_raft.py \
--model_name Qwen/Qwen3-4B-Instruct-2507 \
--dataset_path ./data/processed/raft_final \
--output_dir ./checkpoints/raft_pubmed \
--no_unsloth # Required for multi-GPU DDP
TensorBoard logs are saved automatically:
# Start TensorBoard
tensorboard --logdir ./checkpoints/raft_pubmed/logs --port 6006
# View at http://localhost:6006
What to look for:
- Training loss: Should decrease steadily; expect ~1.5-2.5 initially, dropping to ~0.8-1.2
- Learning rate: Should warm up then decay (cosine schedule)
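- Gradient norm: Watch for sudden spikes or steady growth. Gradients are clipped to max_grad_norm (0.3 by default), and the logged norm is typically measured before clipping, so values above 0.3 are not by themselves a problem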
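Global-norm clipping, which max_grad_norm controls, can be sketched in plain Python. The rule mirrors torch.nn.utils.clip_grad_norm_: compute the L2 norm over all gradients and, if it exceeds max_norm, scale every gradient by max_norm / (norm + eps). That PyTorch function returns the norm measured before scaling, which is the value Hugging Face Trainer typically logs as grad_norm. The flat list of floats here is a stand-in for real parameter tensors.

```python
import math


def clip_by_global_norm(grads, max_norm=0.3, eps=1e-6):
    """Scale gradients so their global L2 norm is at most max_norm.

    Returns (clipped_grads, total_norm), where total_norm is measured
    before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:  # only ever shrink, never amplify
        grads = [g * clip_coef for g in grads]
    return grads, total_norm


if __name__ == "__main__":
    clipped, norm_before = clip_by_global_norm([0.3, 0.4])
    norm_after = math.sqrt(sum(g * g for g in clipped))
    print(f"before: {norm_before:.3f}, after: {norm_after:.3f}")
```

Here the pre-clipping norm is 0.5 while the update actually applied has norm about 0.3, which is why a logged value above the threshold usually just means clipping was active.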
Quick sanity check during training:
# Check latest checkpoint
ls -la ./checkpoints/raft_pubmed/checkpoint-*/
# TensorBoard event files are binary; confirm they are being updated
ls -lt ./checkpoints/raft_pubmed/logs/*/
| Mode | Best For | Trade-off |
|---|---|---|
| Completions-only (default) | RAFT training | Trains only on model responses, not context. Better for learning citation format. |
| Packing | Maximum throughput | Packs multiple examples per sequence. Faster but incompatible with completions-only. |
These modes are mutually exclusive. The default (--train_on_completions=True) is recommended for RAFT as it focuses learning on the citation/response format rather than memorizing document content.
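Completions-only training is commonly implemented by setting the label of every prompt token (system message, question, documents) to -100, the default ignore_index of PyTorch's cross-entropy loss, so only the assistant's tokens contribute to the loss. The sketch below shows that masking on toy token IDs; it is an assumption about the mechanism rather than a copy of train_raft.py, and real pipelines work on tokenizer output and let the model shift labels for next-token prediction.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def completion_only_labels(prompt_ids, response_ids):
    """Concatenate prompt and response; mask prompt positions out of the labels."""
    input_ids = list(prompt_ids) + list(response_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return input_ids, labels


def masked_mean_loss(token_losses, labels):
    """Average per-token losses over positions whose label is not masked."""
    kept = [loss for loss, label in zip(token_losses, labels) if label != IGNORE_INDEX]
    return sum(kept) / len(kept)


if __name__ == "__main__":
    # Toy IDs: 101-103 stand for prompt tokens, 7-8 for the assistant answer.
    input_ids, labels = completion_only_labels([101, 102, 103], [7, 8])
    print(input_ids)  # [101, 102, 103, 7, 8]
    print(labels)     # [-100, -100, -100, 7, 8]
    # High losses on the (long) document context no longer enter the average.
    print(masked_mean_loss([5.0, 5.0, 5.0, 1.0, 3.0], labels))  # 2.0
```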
To use packing instead (for faster training on general tasks):
python scripts/training/train_raft.py \
--train_on_completions=False \
--packing
Instead of command-line arguments, use configs/raft_config.yaml:
# Edit config file
nano configs/raft_config.yaml
# The training script reads from this file by default
# Override specific values via command line as needed
All parameters with their defaults (override via command line):
Model & LoRA:
| Parameter | Default | Description |
|---|---|---|
| --model_name | Qwen/Qwen3-4B-Instruct-2507 | Base model from HuggingFace |
| --max_seq_length | 2048 | Maximum context length |
| --load_in_4bit | False | 4-bit quantization (single GPU only) |
| --lora_r | 32 | LoRA rank (16-32 recommended) |
| --lora_alpha | 32 | LoRA alpha scaling |
| --lora_dropout | 0.0 | LoRA dropout (0 recommended) |
| --use_rslora | True | Rank-Stabilized LoRA |
| --no_unsloth | False | Disable Unsloth (required for multi-GPU DDP) |
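How --lora_r, --lora_alpha and --use_rslora interact comes down to the multiplier applied to the low-rank update. Standard LoRA scales the update by alpha / r; rank-stabilized LoRA (rsLoRA) scales it by alpha / sqrt(r), which is also how Hugging Face PEFT computes it when use_rslora=True. Assuming train_raft.py forwards these flags to PEFT unchanged (not confirmed here), the defaults give a multiplier of about 5.66 rather than 1.0.

```python
import math


def lora_scaling(lora_alpha: float, lora_r: int, use_rslora: bool) -> float:
    """Multiplier applied to the LoRA update (B @ A) before it is added to W."""
    if use_rslora:
        return lora_alpha / math.sqrt(lora_r)  # rsLoRA: alpha / sqrt(r)
    return lora_alpha / lora_r  # standard LoRA: alpha / r


if __name__ == "__main__":
    for r, alpha in ((32, 32), (16, 32)):  # table defaults; full-training example
        print(f"r={r}, alpha={alpha}: "
              f"LoRA {lora_scaling(alpha, r, False):.2f}, "
              f"rsLoRA {lora_scaling(alpha, r, True):.2f}")
```

Keep this in mind when copying learning rates or alpha values from setups that use plain LoRA: the same numbers produce a larger update once rsLoRA is on.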
Training:
| Parameter | Default | Description |
|---|---|---|
| --num_train_epochs | 2 | Number of training epochs |
| --per_device_train_batch_size | 8 | Batch size per GPU |
| --gradient_accumulation_steps | 4 | Gradient accumulation steps |
| --learning_rate | 2e-4 | Peak learning rate |
| --warmup_ratio | 0.03 | Warmup proportion of training |
| --weight_decay | 0.01 | Weight decay for regularization |
| --max_grad_norm | 0.3 | Gradient clipping threshold |
| --train_on_completions | True | Train only on assistant responses |
| --packing | False | Sequence packing (incompatible with completions-only) |
| --max_samples | None | Limit training examples (useful for testing) |
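Per-device batch size, gradient accumulation and GPU count multiply into the effective batch size, which together with the dataset size determines the number of optimizer steps. The helper below is an illustrative calculation: rounding the last partial batch up is an assumption (trainers differ slightly), and the ~190k figure is the approximate training split from the data preparation section.

```python
import math


def effective_batch_size(per_device: int, grad_accum: int, num_gpus: int = 1) -> int:
    """Examples consumed per optimizer step."""
    return per_device * grad_accum * num_gpus


def total_optimizer_steps(num_examples: int, per_device: int, grad_accum: int,
                          num_gpus: int = 1, epochs: int = 1) -> int:
    """Approximate optimizer steps; the last partial batch is rounded up."""
    batch = effective_batch_size(per_device, grad_accum, num_gpus)
    return math.ceil(num_examples / batch) * epochs


if __name__ == "__main__":
    # Table defaults (8 x 4) and the full-training example (4 x 8) both give 32.
    print(effective_batch_size(8, 4), effective_batch_size(4, 8))
    steps = total_optimizer_steps(190_000, 8, 4, epochs=2)
    print(f"~{steps} optimizer steps; checkpoints every 500 steps -> ~{steps // 500} saves")
```

With torchrun --nproc_per_node=4, the same per-device settings give an effective batch of 128, so the step count falls by roughly a factor of four.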
Checkpoints & Logging:
| Parameter | Default | Description |
|---|---|---|
| --output_dir | ./checkpoints/raft_pubmed | Output directory |
| --logging_steps | 10 | Log every N steps |
| --save_steps | 500 | Save checkpoint every N steps |
| --eval_steps | 500 | Evaluate every N steps |
| --save_total_limit | 3 | Max checkpoints to keep |
| --seed | 42 | Random seed |
| --run_name | raft-pubmed | Name for experiment tracking |
| --resume_from_checkpoint | None | Path to checkpoint, or auto to find the latest |
Before full evaluation, verify the model works:
# Interactive test with a few examples
python scripts/deployment/inference.py \
--adapter_path ./checkpoints/raft_pubmed \
--interactive
Try asking a question with documents: the model should cite sources with [1], [2] notation and use ##begin_quote##...##end_quote## markers.
python scripts/evaluation/evaluate_raft.py \
--model_path ./checkpoints/raft_pubmed \
--test_data ./data/processed/raft_final \
--output_dir ./evaluation_results \
--max_examples 500 # For quick evaluation; remove for full test set
| Metric | Description | Target | Notes |
|---|---|---|---|
| Citation Validity | Are cited document IDs real? | >95% | Should be near-perfect after training |
| Quote Accuracy | Are quotes found in cited documents? | >85% | Measures faithful extraction |
| Oracle Identification | Does model cite oracle when present? | >80% | Key RAFT metric |
| IDK Precision | Does model abstain when oracle absent? | >70% | Tests epistemic humility |
| ROUGE-1 | Unigram overlap with reference | >0.4 | General answer quality |
| ROUGE-L | Longest common subsequence | >0.35 | Fluency and structure |
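The first two metrics can be approximated with regular expressions over the model output. This is a hedged sketch rather than the logic of evaluate_raft.py: the function names, the whitespace and case normalization, and returning None when an answer has no citations or quotes are all assumptions.

```python
import re

QUOTE_RE = re.compile(r"##begin_quote##\s*(.*?)\s*##end_quote##\s*\[(\d+)\]", re.DOTALL)
CITE_RE = re.compile(r"\[(\d+)\]")


def _normalize(text: str) -> str:
    return " ".join(text.split()).lower()


def citation_validity(answer: str, num_docs: int):
    """Fraction of [n] citations whose n refers to an actual document (1..num_docs)."""
    ids = [int(n) for n in CITE_RE.findall(answer)]
    if not ids:
        return None  # nothing cited, metric undefined for this example
    return sum(1 <= i <= num_docs for i in ids) / len(ids)


def quote_accuracy(answer: str, docs: list[str]):
    """Fraction of quotes found verbatim (after normalization) in the document they cite."""
    pairs = QUOTE_RE.findall(answer)
    if not pairs:
        return None
    hits = 0
    for quote, doc_id in pairs:
        i = int(doc_id)
        if 1 <= i <= len(docs) and _normalize(quote) in _normalize(docs[i - 1]):
            hits += 1
    return hits / len(pairs)


if __name__ == "__main__":
    docs = [
        "Moderate coffee consumption was associated with a reduced risk of cardiovascular disease.",
        "Caffeine can temporarily raise blood pressure in some individuals.",
    ]
    answer = ("##begin_quote## a reduced risk of cardiovascular disease ##end_quote## [1]. "
              "##begin_quote## caffeine lowers blood pressure ##end_quote## [2]. See also [5].")
    print(citation_validity(answer, len(docs)))  # 3 citations, [5] does not exist
    print(quote_accuracy(answer, docs))  # the second quote is not verbatim
```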
Interpreting results:
- Low Citation Validity (<90%): Model may be hallucinating document IDs. Check training data format.
- Low Oracle Identification (<70%): Model isn't learning to identify relevant documents. Try longer training or more data.
- Low IDK Precision (<50%): Model answers when it shouldn't. Verify 80/20 oracle split in training data.
- Low ROUGE scores: May indicate answer quality issues, but RAFT models often have lower ROUGE due to citation overhead.
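Oracle Identification and IDK Precision, as the metrics table describes them, can be sketched as two rates over a batch of outputs: how often the oracle's ID is cited when it was present, and how often the model abstains when it was withheld. How evaluate_raft.py actually computes these, and how it detects abstentions, are assumptions here; the sketch treats an answer starting with "I don't know" or "I do not know" as an abstention.

```python
import re

CITE_RE = re.compile(r"\[(\d+)\]")
IDK_PREFIXES = ("i don't know", "i do not know")  # assumed abstention phrasing


def is_abstention(answer: str) -> bool:
    return answer.strip().lower().startswith(IDK_PREFIXES)


def oracle_and_idk_rates(records):
    """records: dicts with "answer" (str) and "oracle_id" (int, or None if withheld).

    Returns (oracle_identification, idk_rate); either is None if no records apply.
    """
    with_oracle = [r for r in records if r["oracle_id"] is not None]
    without_oracle = [r for r in records if r["oracle_id"] is None]

    oracle_hits = sum(
        r["oracle_id"] in {int(n) for n in CITE_RE.findall(r["answer"])}
        for r in with_oracle
    )
    abstentions = sum(is_abstention(r["answer"]) for r in without_oracle)

    oracle_rate = oracle_hits / len(with_oracle) if with_oracle else None
    idk_rate = abstentions / len(without_oracle) if without_oracle else None
    return oracle_rate, idk_rate


if __name__ == "__main__":
    records = [
        {"answer": "Yes ##begin_quote## evidence ##end_quote## [2].", "oracle_id": 2},
        {"answer": "No, according to [1].", "oracle_id": 3},
        {"answer": "I don't know. The documents do not cover this.", "oracle_id": None},
        {"answer": "Probably yes [1].", "oracle_id": None},
    ]
    print(oracle_and_idk_rates(records))  # (0.5, 0.5)
```

A low second rate alongside a high first one is the pattern the "Low IDK Precision" bullet describes: the model finds the oracle when it is there but keeps answering when it is not.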
Baseline comparison (approximate, varies by model/data):
| Model | Citation Valid | Oracle ID | IDK Precision |
|---|---|---|---|
| Base model (no RAFT) | ~20% | ~40% | ~10% |
| After RAFT training | >95% | >80% | >70% |
python scripts/deployment/merge_lora.py \
--base_model Qwen/Qwen3-4B-Instruct-2507 \
--adapter_path ./checkpoints/raft_pubmed \
--output_path ./checkpoints/raft_merged
# Interactive mode with LoRA adapter
python scripts/deployment/inference.py \
--adapter_path ./checkpoints/raft_pubmed \
--interactive
# With merged model (faster loading)
python scripts/deployment/inference.py \
--model_path ./checkpoints/raft_merged \
--interactive
When using the model, format your input with documents:
Question: What is the role of BRCA1 in cancer?
Document [1]: BRCA1 is a tumor suppressor gene that plays a critical role in
DNA repair through homologous recombination. Mutations in BRCA1 significantly
increase the risk of breast and ovarian cancer.
Document [2]: The cell cycle is regulated by cyclins and cyclin-dependent
kinases, which control progression through different phases.
Document [3]: BRCA1 also functions in cell cycle checkpoint control,
activating the G1/S and G2/M checkpoints in response to DNA damage.
Expected model output:
BRCA1 is a tumor suppressor gene with multiple roles in cancer prevention.
##begin_quote## BRCA1 is a tumor suppressor gene that plays a critical role
in DNA repair through homologous recombination ##end_quote## [1]. Additionally,
##begin_quote## BRCA1 also functions in cell cycle checkpoint control,
activating the G1/S and G2/M checkpoints in response to DNA damage ##end_quote##
[3]. Mutations in this gene ##begin_quote## significantly increase the risk of
breast and ovarian cancer ##end_quote## [1].
Note how the model:
- Cites the relevant documents [1] and [3]
- Ignores the irrelevant document [2] (about the cell cycle, not BRCA1)
- Uses ##begin_quote##...##end_quote## for verbatim text
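If you assemble inputs programmatically, a small helper can reproduce the Question / Document [n] layout shown above. This is a hypothetical convenience function, not part of inference.py; it collapses internal whitespace in each document and numbers documents from 1, matching the citation IDs the model is trained to emit.

```python
def format_raft_prompt(question: str, documents: list[str]) -> str:
    """Render a question and retrieved documents in the Question/Document [n] layout."""
    lines = [f"Question: {question.strip()}"]
    for i, doc in enumerate(documents, start=1):
        lines.append(f"Document [{i}]: {' '.join(doc.split())}")  # IDs start at 1
    return "\n".join(lines)


if __name__ == "__main__":
    print(format_raft_prompt(
        "What is the role of BRCA1 in cancer?",
        [
            "BRCA1 is a tumor suppressor gene that plays a critical role in DNA repair "
            "through homologous recombination.",
            "The cell cycle is regulated by cyclins and cyclin-dependent kinases.",
        ],
    ))
```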
# Requires llama.cpp
export LLAMA_CPP_PATH=/path/to/llama.cpp
./scripts/deployment/convert_to_gguf.sh \
./checkpoints/raft_merged \
raft-pubmed \
Q4_K_M
Then use with Ollama:
ollama create raft-pubmed -f Modelfile
ollama run raft-pubmed
pubmed_finetune_RAFT/
├── README.md
├── setup.sh # Automated setup script
├── requirements.txt
├── .gitignore
├── .env.example
├── configs/
│ └── raft_config.yaml
├── docs/
│ ├── DATA_SOURCES.md
│ ├── RAFT_EXPLANATION.md
│ └── HPC_SETUP.md
├── src/
│ ├── __init__.py
│ └── prompts.py # Consolidated prompts
├── scripts/
│ ├── data_prep/
│ │ ├── 01_download_pubmedqa.py
│ │ ├── 02_convert_pubmedqa_to_raft.py
│ │ ├── 03_process_bioasq.py
│ │ ├── 04_convert_bioasq_to_raft.py
│ │ └── 05_combine_and_split.py
│ ├── training/
│ │ └── train_raft.py
│ ├── slurm/
│ │ ├── train.sh
│ │ ├── train_test.sh
│ │ └── resume.sh
│ ├── evaluation/
│ │ └── evaluate_raft.py
│ └── deployment/
│ ├── merge_lora.py
│ ├── inference.py
│ └── convert_to_gguf.sh
├── venv/ # Git-ignored (created by setup.sh)
├── data/ # Git-ignored
├── checkpoints/ # Git-ignored
└── logs/ # Git-ignored
Out of memory (OOM) errors:
# Reduce batch size
--per_device_train_batch_size 2
# Increase gradient accumulation
--gradient_accumulation_steps 16
# Use 4-bit quantization
--load_in_4bit
# Reduce sequence length
--max_seq_length 1024
Slow training:
# Install Unsloth
pip install unsloth
# Enable packing (requires turning off completions-only training)
--train_on_completions=False --packing
Import or environment errors:
# Ensure correct Python environment
which python
pip list | grep -E "torch|transformers|peft"
If you use this code, please cite:
@software{raft_pubmed_finetune,
title={RAFT Fine-tuning for Biomedical RAG},
author={Harlan Barker},
year={2024},
url={https://github.com/thirtysix/biomedical-RAFT-toolkit}
}
@article{zhang2024raft,
title={RAFT: Adapting Language Model to Domain Specific RAG},
author={Zhang, Tianjun and others},
journal={arXiv preprint arXiv:2403.10131},
year={2024}
}
@inproceedings{jin2019pubmedqa,
title={PubMedQA: A Dataset for Biomedical Research Question Answering},
author={Jin, Qiao and others},
booktitle={EMNLP},
year={2019}
}
Documentation developed with assistance from Claude.
This project is licensed under the MIT License - see the LICENSE file for details.
Contributions are welcome! Please feel free to submit a Pull Request.
