Skip to content

OLIVER-XYP/agro_sql

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Agentic Text-to-SQL:冷启动 SFT → Agentic RL 训练全流程

本仓库实现了一套完整的多轮 Text-to-SQL 训练流程,包括:数据合成(Cold-Start SFT 数据)→ 冷启动监督微调(SFT)→ Agentic 强化学习(RL)→ 课程式二阶段 RL,以及完整的多数据集评测链路。


目录

  1. 项目概述
  2. 项目目录与组件关系
  3. 数据合成流程
  4. 训练流程总览
  5. 冷启动 SFT 详细步骤
  6. Agentic RL 详细步骤
  7. 评测方式
  8. 产物与检查点说明
  9. 一键最短路径命令清单
  10. 常见问题

1. 项目概述

目标:训练一个能够通过多轮 SQL 执行与自我修正,最终给出正确 SQL 答案的 Text-to-SQL Agent。

核心能力边界

  • 输入:自然语言问题 + 数据库 Schema(含字段注释、外部知识)
  • 输出:可执行的、结果等价于 gold SQL 的 SQL 语句
  • 交互模式:多轮对话(最多 6 轮),每轮可执行 <sql> 查询并观察结果,最终输出 <solution>
  • 支持数据集:BIRD、Spider 系列(Spider-DK/Realistic/Syn)、EHRSQL、ScienceBenchmark、Spider2.0

三阶段训练

Stage 0: Cold-Start SFT(Teacher 采样轨迹 → LLaMA-Factory SFT)
    ↓
Stage 1: Agentic RL(GRPO + DAPO 动态采样,SkyRL 框架)
    ↓
Stage 2: Curriculum RL(Stage 2 数据 + KL 正则化,SkyRL 框架)

2. 项目目录与组件关系

ypxu_scripts_and_code/
├── cold_sft/                        # Stage 0:冷启动 SFT 数据合成与训练
│   ├── cold_sft_synthesis/          # 数据合成引擎
│   │   ├── synthesize.py            # ★ 主合成脚本:原始数据 → trajectories.jsonl
│   │   ├── rl2sft.py                # ★ 格式转换:bird_train.json → bird_sft_train.json
│   │   ├── episode.py               # 多轮 episode 执行(SQL 执行 + EX 校验)
│   │   ├── model.py                 # TeacherClient(异步 OpenAI-compatible 封装)
│   │   ├── execution.py             # SQL 执行工具
│   │   └── sql/
│   │       ├── env.py               # SkyRL SQL 环境(RL 训练中使用)
│   │       └── utils.py             # 奖励计算
│   └── LLaMA-Factory/               # SFT 训练框架(submodule)
│       └── train.sh                 # ★ SFT 训练启动脚本
│
├── SkyRL/                           # Stage 1 & 2:Agentic RL 训练框架
│   ├── skyrl-train/
│   │   └── examples/text_to_sql/
│   │       ├── dapo_train.sh        # ★ Stage 1:DAPO RL 训练
│   │       └── curriculum-learning/
│   │           └── agentic-run-stage2.sh  # ★ Stage 2:课程式 RL
│   ├── skyrl-gym/                   # 环境定义(SQL 环境)
│   └── skyrl-agent/                 # Agent 层
│
├── eval/                            # 评测流水线
│   ├── process_dataset.py           # Step 1:生成 multi-turn prompt
│   ├── infer.py                     # Step 2:vLLM 多轮推理
│   ├── evaluate_bird.py             # Step 3:BIRD/EHRSQL/ScienceBenchmark 评测
│   ├── evaluate_spider.py           # Step 3:Spider 系列评测
│   └── evaluate_spider2.py          # Step 3:Spider2.0 评测
│
└── outputs/                         # 训练输出(日志、TensorBoard 等)

组件依赖关系

bird/bird_train.json
    │ rl2sft.py
    ▼
bird_sft_train.json
    │ synthesize.py(调用 episode.py + model.py)
    ▼
trajectories.jsonl
    │ LLaMA-Factory SFT
    ▼
cold_sft/ckpts_4584/checkpoint-864/merged_model  ← Stage 0 输出
    │ dapo_train.sh(SkyRL GRPO)
    ▼
SkyRL/ckpts/agentic-sql-8B-ckpt  ← Stage 1 输出
    │ agentic-run-stage2.sh(SkyRL GRPO + KL)
    ▼
SkyRL/ckpts/curriculum-ckpts/stage2-ckpt  ← Stage 2 输出

3. 数据合成流程

3.1 整体流程

原始 BIRD 训练集 (bird/bird_train.json)
    │
    ▼  [Step 1] rl2sft.py
原始数据转 SFT 格式 (bird_sft_train.json)
    │
    ▼  [Step 2] synthesize.py × N samples
Teacher LLM 多轮采样 + SQL 执行 + EX 校验
    │
    ▼  只保留 EX 通过的轨迹
trajectories.jsonl  (最终 SFT 训练数据)

3.2 Step 1:原始数据格式转换(rl2sft.py)

脚本cold_sft/cold_sft_synthesis/rl2sft.py

输入bird/bird_train.json

  • 格式:RL 训练使用的 BIRD 数据集格式,每条记录包含:
    {
      "db_id": "concert_singer",
      "prompt": [
        {"role": "system", "content": "... {db_details} ... {external_knowledge} ... {question} ..."},
        {"role": "user",   "content": "{db_details}: ...\n{external_knowledge}: ...\n{question}: ..."}
      ],
      "sql": "SELECT ..."
    }
  • prompt[0](system):包含完整的 prompt 模板,含 {db_details}{external_knowledge}{question} 占位符
  • prompt[1](user):包含具体的数据库结构、外部知识、问题内容

处理逻辑

  1. prompt[1](user content)中用正则提取三个字段:
    • {db_details}:数据库 Schema 详情
    • {external_knowledge}:外部知识(如领域知识、枚举值说明)
    • {question}:自然语言问题
  2. 将三个字段填入 prompt[0](system template),生成完整的单轮 prompt
  3. 默认移除 prompt 中的示例块(START OF EXAMPLE ... END OF EXAMPLE),可用 --keep-example 保留

输出bird_sft_train.json

[
  {
    "db_id": "concert_singer",
    "prompt": "<完整的单轮合成 prompt,含 db_details、external_knowledge、question>",
    "sql": "SELECT ..."
  }
]

运行命令

cd cold_sft/cold_sft_synthesis
python rl2sft.py                # 默认移除示例块
#
python rl2sft.py --keep-example # 保留示例块

注意:输入路径 bird/bird_train.json 和输出路径 bird_sft_train.json 均为脚本中硬编码的相对路径,需在 cold_sft/cold_sft_synthesis/ 目录下运行,或修改脚本顶部的 INPUT_PATH / OUTPUT_PATH


3.3 Step 2:多轮轨迹合成(synthesize.py)

脚本cold_sft/cold_sft_synthesis/synthesize.py

输入bird_sft_train.json(来自 Step 1)

核心参数

参数 默认值 说明
--input bird_sft_train.json 输入 JSON 文件
--output trajectories.jsonl 输出 JSONL 文件
--db-path bird/train_databases SQLite 数据库根目录
--api-base 必填 OpenAI-compatible API 地址
--api-key EMPTY API 密钥
--model 必填 Teacher 模型名称
--samples-per-question 20 每道题采样几条轨迹
--num-workers 32 并发 worker 数
--max-turns 6 每条轨迹最大交互轮数
--num-examples 全量 随机抽取 N 道题用于 SFT(剩余用于 RL)
--random-seed 42 随机种子
--no-resume 关闭 默认支持断点续跑

运行命令

cd cold_sft/cold_sft_synthesis

# 全量数据,每题 20 个样本
python synthesize.py \
    --input bird_sft_train.json \
    --output trajectories.jsonl \
    --db-path bird/train_databases \
    --api-base http://your-llm-api:8000/v1 \
    --api-key EMPTY \
    --model Qwen2.5-72B-Instruct \
    --samples-per-question 20 \
    --num-workers 32 \
    --max-turns 6

# 随机抽取 500 道题(同时保存 example ID,用于 RL 训练排除)
python synthesize.py \
    --input bird_sft_train.json \
    --output trajectories.jsonl \
    --db-path bird/train_databases \
    --api-base http://your-llm-api:8000/v1 \
    --model Qwen2.5-72B-Instruct \
    --num-examples 500 \
    --random-seed 42

合成机制(episode.py)

每条轨迹的生成过程(run_episode):

  1. 向 Teacher LLM 发送当前 prompt(含历史),等待回复
  2. 解析回复中的 <think>...</think><sql>...</sql><solution>...</solution> 标签
    • 若格式错误,最多重试 3 次,并附加格式修正提示(RETRY_PROMPT
  3. 若当前步是中间步(含 <sql>):
    • 在本地执行 SQL,超时时间 10 秒
    • 结果超过 50 行时截断
    • 将执行结果作为 <observation> 追加到历史
  4. 若当前步是终止步(含 <solution>):
    • 执行等价性校验(EX check):对比 prediction SQL 和 gold SQL 的执行结果集合
    • 只保留 EX 通过的轨迹

输出格式(trajectories.jsonl)

每行一条 JSON 记录:

{
  "db_id": "concert_singer",
  "example_id": 42,
  "sample_id": 3,
  "base_prompt": "<合成 prompt 文本>",
  "gold_sql": "SELECT ...",
  "turns": [
    {
      "turn": 1,
      "think": "先看一下表结构...",
      "sql": "SELECT * FROM singer LIMIT 5;",
      "observation": "\n\n<observation>\nSTATUS: OK\nROW_COUNT: 5\n...\n</observation>\n\n",
      "solution": null,
      "is_final": false
    },
    {
      "turn": 2,
      "think": "从结果来看,答案应该是...",
      "sql": null,
      "observation": null,
      "solution": "SELECT singer_name FROM singer WHERE age > 25;",
      "is_final": true
    }
  ]
}

字段说明:

  • example_id:对应 bird_sft_train.json 中的题目索引
  • sample_id:同一道题的第几次采样(0 到 samples_per_question-1
  • turns:每个元素代表一个交互轮次
    • 中间轮次:think + sql + observationsolution=null
    • 最终轮次:think + solutionsql=nullobservation=nullis_final=true

断点续跑:脚本默认开启 resume,重启后会读取已有 trajectories.jsonl,跳过已完成的 (example_id, sample_id) 对。如需从头重跑,加 --no-resume

SFT 数据拆分:使用 --num-examples N 时,被采样的题目 ID 会保存到 trajectories_example_ids.json,可用于在 RL 训练时排除这部分题目,避免数据泄漏。


4. 训练流程总览

Stage 0: Cold-Start SFT
  输入:trajectories.jsonl(Teacher 采样的成功轨迹)
  工具:LLaMA-Factory(train.sh)
  输出:cold_sft/ckpts_4584/checkpoint-864/merged_model
  目的:让模型学会多轮 SQL Agent 的基本格式与策略

Stage 1: Agentic RL (DAPO)
  输入:Stage 0 的 merged_model + SkyRL/data/bird_train_rl.parquet
  工具:SkyRL(dapo_train.sh)
  算法:GRPO + DAPO 动态采样
  输出:SkyRL/ckpts/agentic-sql-8B-ckpt
  目的:通过在线 RL 进一步提升 Agent 的探索与修正能力

Stage 2: Curriculum RL(可选)
  输入:Stage 1 的 merged_model + SkyRL/data/curriculum_learning/stage2_train.parquet
  工具:SkyRL(agentic-run-stage2.sh)
  算法:GRPO + KL 正则化
  输出:SkyRL/ckpts/curriculum-ckpts/stage2-ckpt
  目的:在更难的课程数据上继续优化,加 KL 约束防止策略偏移

阶段间依赖

Stage 0 输出 (merged_model)
    └──→ Stage 1 的 MODEL_PATH(dapo_train.sh 第 9 行)

Stage 1 输出 (stage1-ckpt/global_step_60/merged_model)
    └──→ Stage 2 的 MODEL_PATH(agentic-run-stage2.sh 第 8 行)

5. 冷启动 SFT 详细步骤

5.1 环境准备

LLaMA-Factory 使用 UV 包管理工具管理依赖:

cd cold_sft/LLaMA-Factory

# 配置 UV 路径(根据实际环境修改 train.sh 中的路径)
export UV_PYTHON_INSTALL_DIR=/your/path/.uv/python
export UV_CACHE_DIR=/your/path/.cache/uv
export UV_PROJECT_ENVIRONMENT=/your/path/venvs/llamafactory

# 安装依赖(离线模式)
uv --offline sync --extra torch --extra metrics --prerelease=allow

5.2 数据准备

确保已完成 第 3 节 中的数据合成,得到 trajectories.jsonl

需将 trajectories.jsonl 转换为 LLaMA-Factory 所需的 SFT 训练格式(多轮对话格式),具体格式转换逻辑参见 LLaMA-Factory 文档。

5.3 启动训练

cd cold_sft/LLaMA-Factory
bash train.sh

关键超参(在 LLaMA-Factory 的 YAML 配置中设置,train.sh 负责环境初始化):

  • 模型:Qwen/Qwen2.5 系列(8B)
  • 训练数据:转换后的多轮 SFT 格式
  • 方法:全量 SFT(full fine-tuning)或 LoRA

5.4 输出

训练完成后,合并模型权重:

# LLaMA-Factory 内置合并命令(LoRA 场景)
llamafactory-cli export \
    --model_name_or_path /path/to/base_model \
    --adapter_name_or_path /path/to/lora_checkpoint \
    --export_dir cold_sft/ckpts_4584/checkpoint-864/merged_model \
    --export_size 4

输出路径:cold_sft/ckpts_4584/checkpoint-864/merged_model(供 Stage 1 使用)


6. Agentic RL 详细步骤

6.1 Stage 1:DAPO 训练

脚本SkyRL/skyrl-train/examples/text_to_sql/dapo_train.sh

路径配置(脚本顶部,需根据实际路径修改):

ROOT_DIR="/volume/pt-coder/users/hxyan/agentic-sql"
DATA_DIR="$ROOT_DIR/SkyRL/data"
DB_PATH="$ROOT_DIR/SkyRL/data"
CKPT_PATH="$ROOT_DIR/SkyRL/ckpts/agentic-sql-8B-ckpt"      # Stage 1 检查点保存路径
MODEL_PATH="$ROOT_DIR/cold_sft/ckpts_4584/checkpoint-864/merged_model"  # ← Stage 0 输出

关键超参

参数 说明
trainer.algorithm.advantage_estimator grpo 使用 GRPO 算法
trainer.epochs 30 训练轮数
trainer.policy.optimizer_config.lr 5.0e-6 学习率
trainer.train_batch_size 320 训练批次大小
trainer.micro_train_batch_size_per_gpu 4 每 GPU 微批次
trainer.micro_forward_batch_size_per_gpu 8 前向微批次
generator.n_samples_per_prompt 10 每个 prompt 采样数
generator.max_turns 6 最大交互轮数
generator.sampling_params.temperature 0.7 采样温度
generator.sampling_params.top_p 0.95 Top-p
generator.max_input_length 29000 最大输入长度(token)
generator.sampling_params.max_generate_length 3000 最大生成长度
trainer.max_prompt_length 9000 训练时 prompt 最大长度
trainer.algorithm.use_kl_loss false 不使用 KL loss
trainer.algorithm.eps_clip_low/high 0.2 / 0.28 DAPO 裁剪范围
trainer.algorithm.dynamic_sampling.type filter DAPO 动态采样类型
trainer.algorithm.dynamic_sampling.max_sample_batches 30 动态采样最大批次
trainer.ckpt_interval 60 每 60 step 保存一次 ckpt
trainer.hf_save_interval 30 每 30 step 保存一次 HF 格式模型
trainer.eval_interval 5 每 5 step 评估一次
trainer.strategy fsdp2 分布式训练策略
NUM_GPUS 8 训练 GPU 数
NUM_INFERENCE_ENGINES 2 推理引擎数(每个 TP=4)

数据

  • 训练集:SkyRL/data/bird_train_rl.parquet
  • 验证集:SkyRL/data/bird_dev_rl.parquet
  • 数据库:SkyRL/data/bird/train/train_databases/SkyRL/data/bird/dev/dev_databases/

奖励信号:由 SQLEnv._get_reward() 计算,在终止步(含 <solution>)时执行等价性校验,通过得 1.0,失败得 0.0,中间步奖励为 0。

停止词</sql></solution>

启动命令

cd SkyRL/skyrl-train
bash examples/text_to_sql/dapo_train.sh

6.2 Stage 2:课程式 RL(可选)

脚本SkyRL/skyrl-train/examples/text_to_sql/curriculum-learning/agentic-run-stage2.sh

路径配置

ROOT_DIR="/volume/pt-coder/users/hxyan/agentic-sql"
DATA_DIR="$ROOT_DIR/SkyRL/data"
CKPT_PATH="$ROOT_DIR/SkyRL/ckpts/curriculum-ckpts/stage2-ckpt"
MODEL_PATH="$ROOT_DIR/SkyRL/ckpts/curriculum-ckpts/stage1-ckpt/global_step_60/merged_model"  # ← Stage 1 输出

stage1-ckpt 需要手动从 Stage 1 的 DAPO 检查点中选取合适的 step,并确保已保存为 merged HF 格式。

与 Stage 1 的关键差异

参数 Stage 1 Stage 2 变化原因
训练数据 bird_train_rl.parquet curriculum_learning/stage2_train.parquet 更难的课程数据
epochs 30 10 防止在较小课程集上过拟合
lr 5.0e-6 2.0e-6 精调阶段降低学习率
train_batch_size 320 256 略减批次
use_kl_loss false true 防止策略偏离 Stage 1
kl_loss_coef 0.001 KL 正则化系数
temperature 0.7 0.8 增加探索多样性
ckpt_interval 60 15 更频繁保存
max_prompt_length 9000 10000 课程数据可能更长
gpu_memory_utilization 0.85 0.7 适配显存占用

启动命令

cd SkyRL/skyrl-train
bash examples/text_to_sql/curriculum-learning/agentic-run-stage2.sh

7. 评测方式

7.1 三步评测流程

Step 1:生成 Multi-Turn Prompt

cd eval

# BIRD(示例)
python process_dataset.py \
    --input_data_file /data/bird/dev.json \
    --output_data_file ./data/bird_dev_processed.json \
    --db_path /data/bird/databases \
    --tables /data/bird/tables.json \
    --source bird

Step 2:多轮推理

# Greedy Search(单 GPU TP=2)
CUDA_VISIBLE_DEVICES=0,1 python infer.py \
    --model_path /path/to/model \
    --input_path ./data/bird_dev_processed.json \
    --output_path ./results/bird_dev_greedy.json \
    --db_path /data/bird/databases \
    --max_turns 6 \
    --max_tokens 3000 \
    --n 1 \
    --temperature 0.0 \
    --tensor_parallel_size 2

# 多样本采样(用于 Major Voting / Pass@k)
CUDA_VISIBLE_DEVICES=0,1 python infer.py \
    --model_path /path/to/model \
    --input_path ./data/bird_dev_processed.json \
    --output_path ./results/bird_dev_sampling.json \
    --db_path /data/bird/databases \
    --max_turns 6 \
    --n 16 \
    --temperature 0.6

# 多 GPU 数据并行
CUDA_VISIBLE_DEVICES=0,1,2,3 python infer.py \
    --model_path /path/to/model \
    --input_path ./data/bird_dev_processed.json \
    --output_path ./results/bird_dev_greedy.json \
    --db_path /data/bird/databases \
    --num_gpus 4 \
    --tp_per_instance 1 \
    --n 1 \
    --temperature 0.0

Step 3:评估

# BIRD Greedy Search
python evaluate_bird.py \
    --input_path ./results/bird_dev_greedy.json \
    --db_path /data/bird/databases \
    --mode greedy_search

# BIRD Major Voting
python evaluate_bird.py \
    --input_path ./results/bird_dev_sampling.json \
    --db_path /data/bird/databases \
    --mode major_voting

# BIRD Pass@k
python evaluate_bird.py \
    --input_path ./results/bird_dev_sampling.json \
    --db_path /data/bird/databases \
    --mode pass@k

# Spider 系列
python evaluate_spider.py \
    --pred ./results/spider_dev_greedy.json \
    --gold /data/spider/dev_gold.sql \
    --db_path /data/spider/database \
    --mode greedy_search

7.2 支持的数据集

数据集 --source 参数 评估脚本
BIRD bird evaluate_bird.py
EHRSQL ehrsql evaluate_bird.py
ScienceBenchmark sciencebenchmark evaluate_bird.py
Spider spider evaluate_spider.py
Spider-DK spider_dk evaluate_spider.py
Spider-Realistic spider_realistic evaluate_spider.py
Spider-Syn spider_syn evaluate_spider.py
Spider2.0 spider2.0 evaluate_spider2.py

7.3 自动化多 Checkpoint 评估

cd eval

# 评估单个模型
python auto_evaluation.py \
    --output_ckpt_dir /path/to/model \
    --source bird \
    --input_file ./data/bird_dev_processed.json \
    --db_path /data/bird/databases \
    --eval_name bird_dev \
    --visible_devices 0,1 \
    --tensor_parallel_size 2 \
    --n 16

# 批量评估多个 checkpoint
python auto_evaluation.py \
    --output_ckpt_dir /path/to/checkpoints \
    --multiple_models \
    --source bird \
    --input_file ./data/bird_dev_processed.json \
    --db_path /data/bird/databases \
    --eval_name bird_dev_multi_ckpt \
    --visible_devices 0,1,2,3 \
    --tensor_parallel_size 2 \
    --n 16

输出:

  • eval/results/{eval_name}/:推理结果 JSON
  • eval/evaluation_results/{eval_name}/:评估指标 JSON 和可视化图表

8. 产物与检查点说明

SFT 检查点(Stage 0)

cold_sft/
├── ckpts_4584/checkpoint-864/
│   └── merged_model/          # 合并后的 HF 格式模型(供 Stage 1 使用)
└── tensorboard/               # TensorBoard 日志

RL 检查点(Stage 1)

SkyRL/ckpts/agentic-sql-8B-ckpt/
├── global_step_XX/            # 分布式 ckpt(每 60 step 保存一次)
│   └── merged_model/          # HF 格式模型(每 30 step 保存一次)
└── ...
  • 恢复训练:trainer.resume_mode=latest 自动加载最新检查点
  • 评测使用:取 global_step_XX/merged_model/ 目录

RL 检查点(Stage 2)

SkyRL/ckpts/curriculum-ckpts/
├── stage1-ckpt/               # Stage 1 选定的起始检查点
│   └── global_step_60/merged_model/
└── stage2-ckpt/               # Stage 2 输出(每 15 step 保存一次)

训练监控

  • TensorBoard 日志:outputs/ 目录
  • 项目名称:skyrlsql,运行名称:skyrlsql_repro
  • 查看命令:
    tensorboard --logdir outputs/

9. 一键最短路径命令清单

按执行顺序排列:

# ============================================================
# 【Stage 0 - 数据合成】
# ============================================================

# Step 0a: 格式转换(需在 cold_sft/cold_sft_synthesis/ 目录下运行)
cd cold_sft/cold_sft_synthesis
python rl2sft.py
# 输出:bird_sft_train.json

# Step 0b: Teacher 采样(替换 api-base 和 model 为实际值)
python synthesize.py \
    --input bird_sft_train.json \
    --output trajectories.jsonl \
    --db-path bird/train_databases \
    --api-base http://your-llm-api:8000/v1 \
    --model Qwen2.5-72B-Instruct \
    --samples-per-question 20 \
    --num-workers 32 \
    --max-turns 6
# 输出:trajectories.jsonl

# ============================================================
# 【Stage 0 - SFT 训练】
# ============================================================

# Step 0c: 启动 SFT 训练(需先将 trajectories.jsonl 转换为 LLaMA-Factory 格式)
cd cold_sft/LLaMA-Factory
bash train.sh
# 输出:cold_sft/ckpts_4584/checkpoint-864/merged_model

# ============================================================
# 【Stage 1 - Agentic RL (DAPO)】
# ============================================================

# Step 1: 修改 dapo_train.sh 中的 ROOT_DIR 后启动
cd SkyRL/skyrl-train
bash examples/text_to_sql/dapo_train.sh
# 输出:SkyRL/ckpts/agentic-sql-8B-ckpt/

# ============================================================
# 【Stage 2 - Curriculum RL(可选)】
# ============================================================

# Step 2: 修改 agentic-run-stage2.sh 中的 ROOT_DIR 和 MODEL_PATH 后启动
bash examples/text_to_sql/curriculum-learning/agentic-run-stage2.sh
# 输出:SkyRL/ckpts/curriculum-ckpts/stage2-ckpt/

# ============================================================
# 【评测】
# ============================================================

cd eval

# Step E1: 生成评测 prompt
python process_dataset.py \
    --input_data_file /data/bird/dev.json \
    --output_data_file ./data/bird_dev_processed.json \
    --db_path /data/bird/databases \
    --tables /data/bird/tables.json \
    --source bird

# Step E2: 推理
CUDA_VISIBLE_DEVICES=0,1 python infer.py \
    --model_path /path/to/final_model \
    --input_path ./data/bird_dev_processed.json \
    --output_path ./results/bird_dev_greedy.json \
    --db_path /data/bird/databases \
    --max_turns 6 --n 1 --temperature 0.0 --tensor_parallel_size 2

# Step E3: 评估
python evaluate_bird.py \
    --input_path ./results/bird_dev_greedy.json \
    --db_path /data/bird/databases \
    --mode greedy_search

10. 常见问题

Q1:路径错误 / FileNotFoundError

  • rl2sft.py 的输入路径 bird/bird_train.json 是相对于 cold_sft/cold_sft_synthesis/ 的,需在该目录下运行
  • dapo_train.shagentic-run-stage2.sh 中的 ROOT_DIR 需修改为实际绝对路径
  • 数据库路径遵循统一结构:{db_path}/{db_id}/{db_id}.sqlite
  • SQLEnvtask 字段决定数据库子路径:bird_trainbird/train/train_databases/bird_devbird/dev/dev_databases/

Q2:显存不足(OOM)

  • Stage 1:减小 trainer.micro_train_batch_size_per_gpu(当前为 4)或 trainer.micro_forward_batch_size_per_gpu(当前为 8)
  • Stage 2:已将 gpu_memory_utilization 从 0.85 降至 0.7,若仍 OOM 可继续降低
  • 评测推理:减小 --tensor_parallel_size 或减少 --n

Q3:恢复 RL 训练

RL 脚本已配置 trainer.resume_mode=latest,重启后自动加载 CKPT_PATH 下的最新检查点,无需修改命令。

Q4:synthesize.py 中断后如何续跑

脚本默认支持断点续跑(append 模式写文件),重启后会自动读取已完成的 (example_id, sample_id) 对并跳过,无需任何额外操作。若需强制从头开始,加 --no-resume

Q5:数据格式错误(Teacher 输出格式不合规)

episode.py 内置格式修复逻辑,每轮最多重试 3 次。若仍失败,该条轨迹直接丢弃(返回 None)。Teacher 模型能力越强,成功率越高;建议使用 70B 级别以上的模型。

Q6:如何排除 SFT 数据在 RL 训练中使用

使用 --num-examples N 运行 synthesize.py 时,被选中的题目 ID 会保存到 trajectories_example_ids.json。在准备 RL 训练数据(bird_train_rl.parquet)时,过滤掉这些 ID 对应的题目即可。

Q7:Stage 1 到 Stage 2 如何选择起始检查点

Stage 2 脚本中使用的是 stage1-ckpt/global_step_60/merged_model,即 Stage 1 第 60 步的 merged 模型。实际使用时,建议通过评测分数选择 Stage 1 最佳检查点,而非固定使用 step 60。

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors