ii-research/ARF-RAG

ARF-RAG: Towards Adaptive and Retriever-friendly Retrieval-augmented Generation via Reinforcement Learning

This repo includes the original implementation of ARF-RAG, a novel retrieval-augmented generation framework that enables a single LLM to perform adaptive retrieval with retriever-friendly query generation through reinforcement learning.

Local Environment Installation

git clone https://github.com/bobbyfyb/ARF-RAG.git

cd ARF-RAG

conda create -n ARF-RAG python=3.8

conda activate ARF-RAG

pip install -r requirements.txt
           
cd rl

pip install -e .

Two-Stage Optimization of ARF-RAG

Warm-up Supervised Fine-tuning

Retriever Setting

In ARF-RAG, we adopt Contriever as the retriever for both training and inference. Please refer to the Contriever repo for details on indexing the wiki corpus. To set up the Contriever retriever for training and inference, follow the steps below:

  1. Clone the Contriever repo into ./retriever:
cd ./retriever
git clone https://github.com/facebookresearch/contriever.git
  2. Follow the instructions in the Contriever repo to index the wiki corpus, and save the index and the passage corpus file.

  3. Modify the paths in retriever_server.py and launch the Contriever retriever server with the following command:

python retriever_server.py
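
For orientation, the ranking step that a dense retriever such as Contriever performs can be sketched as follows. Contriever scores a passage by the inner product between the query embedding and the passage embedding; the toy vectors and the helper names `dot` and `top_k` below are illustrative assumptions, not part of this repo or of retriever_server.py, and a real setup would use the model's embeddings and a prebuilt index.

```python
from typing import List, Sequence, Tuple


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def top_k(query_vec: Sequence[float],
          passage_vecs: Sequence[Sequence[float]],
          k: int = 2) -> List[Tuple[int, float]]:
    """Return (passage_index, score) pairs for the k highest-scoring passages."""
    scored = [(i, dot(query_vec, p)) for i, p in enumerate(passage_vecs)]
    # Higher inner product means the passage is ranked as more relevant.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy 2-dimensional embeddings standing in for real Contriever outputs.
toy_query = [1.0, 0.0]
toy_passages = [[0.2, 0.9], [0.8, 0.1], [0.5, 0.5]]
ranked = top_k(toy_query, toy_passages, k=2)
```

Here `ranked` lists the indices of the two passages most aligned with the query, best first.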

Warm-up Dataset Generating and Annotating

The SFT stage trains the LLM backbone to learn an initial policy for deciding when to retrieve and for generating retriever-friendly queries. To train the SFT model, we first generate and annotate a warm-up dataset through two rounds of annotating, following the steps below:

  1. Collect the self-generated optimized queries for each question in the dataset using the LLM backbone. For this step, please refer to the LeReT repo, convert the generated preference-based dataset into a reward-score-based dataset, and set the BEST_QUERIES_FILE path in sft/data_class.py for the annotating step.

  2. cd ./sft and run the following command to generate answers for each question (set --stage to 1 for the first stage of annotating and to 2 for the second):

python generate_answers.py \
    --base_model_path your_base_model_path \
    --data_path your_data_path \
    --output_path your_output_path \
    --dataset your_dataset_name \
    --split your_dataset_split \
    --generator your_generator_model_name \
    --stage 1
  3. Evaluate the F1 score of the generated answers for annotating:
python evaluate_f1.py \
    --input_dir your_generated_answers_dir \
    --output_dir your_evaluation_results_dir
  4. First round of annotating (set --stage to 1 for the first stage and to 2 for the second):
python annotate_answer.py \
    --data_dir your_scored_answers_dir \
    --dataset your_dataset_name \
    --split your_dataset_split \
    --generator your_generator_model_name \
    --stage 1
  5. Second round of annotating: run the same command as in the first round, but change the stage to 2 and the data path to the questions answered incorrectly in the first round.

  6. Merge the annotated datasets from the two rounds to get the final warm-up dataset for SFT training:

python merge_datasets.py 
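
The F1 evaluation step above scores each generated answer against the reference answer. As a rough illustration, the sketch below implements the common SQuAD-style token-level F1; this is an assumption about the metric, and evaluate_f1.py in this repo may normalize or score answers differently. The function names `normalize_answer` and `token_f1` are chosen here for illustration.

```python
import re
import string
from collections import Counter


def normalize_answer(text: str) -> str:
    """Lowercase, strip punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def token_f1(prediction: str, ground_truth: str) -> float:
    """Harmonic mean of token precision and recall between two answers."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    if not pred_tokens or not gold_tokens:
        # If either side is empty, count a match only when both are empty.
        return float(pred_tokens == gold_tokens)
    # Multiset intersection counts each shared token at most min(count) times.
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)
```

Under this definition, a prediction that contains the reference answer plus extra tokens gets full recall but reduced precision, so its F1 falls below 1.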

SFT Training

After generating and annotating the warm-up dataset, we can fine-tune the Qwen3 model with the following command:

python fine_tuning.py

Joint Training with Reinforcement Learning

For the PPO training of Qwen3, run the following command:

python rl/scripts/training/train_text_generation.py --config_path scripts/training/task_configs/{your_target_dataset_name}/qwen3_ppo_contriever_debug.yml --experiment_name qwen3_contriever_ppo

Inference and Evaluation

cd inference 

python evaluate_retrieval_qwen.py --base_model_path your_sft_model_path --checkpoint your_ppo_checkpoint_path 

Acknowledgement

This repository is built upon the following publicly available codebases:

SmartRAG: Jointly Learn RAG-Related Tasks From the Environment Feedback

LeReT: Learning to Retrieve by Trying
