本仓库实现了一套完整的多轮 Text-to-SQL 训练流程,包括:数据合成(Cold-Start SFT 数据)→ 冷启动监督微调(SFT)→ Agentic 强化学习(RL)→ 课程式二阶段 RL,以及完整的多数据集评测链路。
目标:训练一个能够通过多轮 SQL 执行与自我修正,最终给出正确 SQL 答案的 Text-to-SQL Agent。
核心能力边界:
- 输入:自然语言问题 + 数据库 Schema(含字段注释、外部知识)
- 输出:可执行的、结果等价于 gold SQL 的 SQL 语句
- 交互模式:多轮对话(最多 6 轮),每轮可执行
<sql>查询并观察结果,最终输出<solution> - 支持数据集:BIRD、Spider 系列(Spider-DK/Realistic/Syn)、EHRSQL、ScienceBenchmark、Spider2.0
三阶段训练:
Stage 0: Cold-Start SFT(Teacher 采样轨迹 → LLaMA-Factory SFT)
↓
Stage 1: Agentic RL(GRPO + DAPO 动态采样,SkyRL 框架)
↓
Stage 2: Curriculum RL(Stage 2 数据 + KL 正则化,SkyRL 框架)
ypxu_scripts_and_code/
├── cold_sft/ # Stage 0:冷启动 SFT 数据合成与训练
│ ├── cold_sft_synthesis/ # 数据合成引擎
│ │ ├── synthesize.py # ★ 主合成脚本:原始数据 → trajectories.jsonl
│ │ ├── rl2sft.py # ★ 格式转换:bird_train.json → bird_sft_train.json
│ │ ├── episode.py # 多轮 episode 执行(SQL 执行 + EX 校验)
│ │ ├── model.py # TeacherClient(异步 OpenAI-compatible 封装)
│ │ ├── execution.py # SQL 执行工具
│ │ └── sql/
│ │ ├── env.py # SkyRL SQL 环境(RL 训练中使用)
│ │ └── utils.py # 奖励计算
│ └── LLaMA-Factory/ # SFT 训练框架(submodule)
│ └── train.sh # ★ SFT 训练启动脚本
│
├── SkyRL/ # Stage 1 & 2:Agentic RL 训练框架
│ ├── skyrl-train/
│ │ └── examples/text_to_sql/
│ │ ├── dapo_train.sh # ★ Stage 1:DAPO RL 训练
│ │ └── curriculum-learning/
│ │ └── agentic-run-stage2.sh # ★ Stage 2:课程式 RL
│ ├── skyrl-gym/ # 环境定义(SQL 环境)
│ └── skyrl-agent/ # Agent 层
│
├── eval/ # 评测流水线
│ ├── process_dataset.py # Step 1:生成 multi-turn prompt
│ ├── infer.py # Step 2:vLLM 多轮推理
│ ├── evaluate_bird.py # Step 3:BIRD/EHRSQL/ScienceBenchmark 评测
│ ├── evaluate_spider.py # Step 3:Spider 系列评测
│ └── evaluate_spider2.py # Step 3:Spider2.0 评测
│
└── outputs/ # 训练输出(日志、TensorBoard 等)
组件依赖关系:
bird/bird_train.json
│ rl2sft.py
▼
bird_sft_train.json
│ synthesize.py(调用 episode.py + model.py)
▼
trajectories.jsonl
│ LLaMA-Factory SFT
▼
cold_sft/ckpts_4584/checkpoint-864/merged_model ← Stage 0 输出
│ dapo_train.sh(SkyRL GRPO)
▼
SkyRL/ckpts/agentic-sql-8B-ckpt ← Stage 1 输出
│ agentic-run-stage2.sh(SkyRL GRPO + KL)
▼
SkyRL/ckpts/curriculum-ckpts/stage2-ckpt ← Stage 2 输出
原始 BIRD 训练集 (bird/bird_train.json)
│
▼ [Step 1] rl2sft.py
原始数据转 SFT 格式 (bird_sft_train.json)
│
▼ [Step 2] synthesize.py × N samples
Teacher LLM 多轮采样 + SQL 执行 + EX 校验
│
▼ 只保留 EX 通过的轨迹
trajectories.jsonl (最终 SFT 训练数据)
脚本:cold_sft/cold_sft_synthesis/rl2sft.py
输入:bird/bird_train.json
- 格式:RL 训练使用的 BIRD 数据集格式,每条记录包含:
{ "db_id": "concert_singer", "prompt": [ {"role": "system", "content": "... {db_details} ... {external_knowledge} ... {question} ..."}, {"role": "user", "content": "{db_details}: ...\n{external_knowledge}: ...\n{question}: ..."} ], "sql": "SELECT ..." } prompt[0](system):包含完整的 prompt 模板,含{db_details}、{external_knowledge}、{question}占位符prompt[1](user):包含具体的数据库结构、外部知识、问题内容
处理逻辑:
- 从
prompt[1](user content)中用正则提取三个字段:{db_details}:数据库 Schema 详情{external_knowledge}:外部知识(如领域知识、枚举值说明){question}:自然语言问题
- 将三个字段填入
prompt[0](system template),生成完整的单轮 prompt - 默认移除 prompt 中的示例块(
START OF EXAMPLE ... END OF EXAMPLE),可用--keep-example保留
输出:bird_sft_train.json
[
{
"db_id": "concert_singer",
"prompt": "<完整的单轮合成 prompt,含 db_details、external_knowledge、question>",
"sql": "SELECT ..."
}
]运行命令:
cd cold_sft/cold_sft_synthesis
python rl2sft.py # 默认移除示例块
# 或
python rl2sft.py --keep-example # 保留示例块注意:输入路径
bird/bird_train.json和输出路径bird_sft_train.json均为脚本中硬编码的相对路径,需在cold_sft/cold_sft_synthesis/目录下运行,或修改脚本顶部的INPUT_PATH/OUTPUT_PATH。
脚本:cold_sft/cold_sft_synthesis/synthesize.py
输入:bird_sft_train.json(来自 Step 1)
核心参数:
| 参数 | 默认值 | 说明 |
|---|---|---|
--input |
bird_sft_train.json |
输入 JSON 文件 |
--output |
trajectories.jsonl |
输出 JSONL 文件 |
--db-path |
bird/train_databases |
SQLite 数据库根目录 |
--api-base |
必填 | OpenAI-compatible API 地址 |
--api-key |
EMPTY |
API 密钥 |
--model |
必填 | Teacher 模型名称 |
--samples-per-question |
20 |
每道题采样几条轨迹 |
--num-workers |
32 |
并发 worker 数 |
--max-turns |
6 |
每条轨迹最大交互轮数 |
--num-examples |
全量 | 随机抽取 N 道题用于 SFT(剩余用于 RL) |
--random-seed |
42 |
随机种子 |
--no-resume |
关闭 | 默认支持断点续跑 |
运行命令:
cd cold_sft/cold_sft_synthesis
# 全量数据,每题 20 个样本
python synthesize.py \
--input bird_sft_train.json \
--output trajectories.jsonl \
--db-path bird/train_databases \
--api-base http://your-llm-api:8000/v1 \
--api-key EMPTY \
--model Qwen2.5-72B-Instruct \
--samples-per-question 20 \
--num-workers 32 \
--max-turns 6
# 随机抽取 500 道题(同时保存 example ID,用于 RL 训练排除)
python synthesize.py \
--input bird_sft_train.json \
--output trajectories.jsonl \
--db-path bird/train_databases \
--api-base http://your-llm-api:8000/v1 \
--model Qwen2.5-72B-Instruct \
--num-examples 500 \
--random-seed 42合成机制(episode.py):
每条轨迹的生成过程(run_episode):
- 向 Teacher LLM 发送当前 prompt(含历史),等待回复
- 解析回复中的
<think>...</think>、<sql>...</sql>、<solution>...</solution>标签- 若格式错误,最多重试 3 次,并附加格式修正提示(
RETRY_PROMPT)
- 若格式错误,最多重试 3 次,并附加格式修正提示(
- 若当前步是中间步(含
<sql>):- 在本地执行 SQL,超时时间 10 秒
- 结果超过 50 行时截断
- 将执行结果作为
<observation>追加到历史
- 若当前步是终止步(含
<solution>):- 执行等价性校验(EX check):对比 prediction SQL 和 gold SQL 的执行结果集合
- 只保留 EX 通过的轨迹
输出格式(trajectories.jsonl):
每行一条 JSON 记录:
{
"db_id": "concert_singer",
"example_id": 42,
"sample_id": 3,
"base_prompt": "<合成 prompt 文本>",
"gold_sql": "SELECT ...",
"turns": [
{
"turn": 1,
"think": "先看一下表结构...",
"sql": "SELECT * FROM singer LIMIT 5;",
"observation": "\n\n<observation>\nSTATUS: OK\nROW_COUNT: 5\n...\n</observation>\n\n",
"solution": null,
"is_final": false
},
{
"turn": 2,
"think": "从结果来看,答案应该是...",
"sql": null,
"observation": null,
"solution": "SELECT singer_name FROM singer WHERE age > 25;",
"is_final": true
}
]
}字段说明:
example_id:对应bird_sft_train.json中的题目索引sample_id:同一道题的第几次采样(0 到samples_per_question-1)turns:每个元素代表一个交互轮次- 中间轮次:
think+sql+observation(solution=null) - 最终轮次:
think+solution(sql=null,observation=null,is_final=true)
- 中间轮次:
断点续跑:脚本默认开启 resume,重启后会读取已有 trajectories.jsonl,跳过已完成的 (example_id, sample_id) 对。如需从头重跑,加 --no-resume。
SFT 数据拆分:使用 --num-examples N 时,被采样的题目 ID 会保存到 trajectories_example_ids.json,可用于在 RL 训练时排除这部分题目,避免数据泄漏。
Stage 0: Cold-Start SFT
输入:trajectories.jsonl(Teacher 采样的成功轨迹)
工具:LLaMA-Factory(train.sh)
输出:cold_sft/ckpts_4584/checkpoint-864/merged_model
目的:让模型学会多轮 SQL Agent 的基本格式与策略
Stage 1: Agentic RL (DAPO)
输入:Stage 0 的 merged_model + SkyRL/data/bird_train_rl.parquet
工具:SkyRL(dapo_train.sh)
算法:GRPO + DAPO 动态采样
输出:SkyRL/ckpts/agentic-sql-8B-ckpt
目的:通过在线 RL 进一步提升 Agent 的探索与修正能力
Stage 2: Curriculum RL(可选)
输入:Stage 1 的 merged_model + SkyRL/data/curriculum_learning/stage2_train.parquet
工具:SkyRL(agentic-run-stage2.sh)
算法:GRPO + KL 正则化
输出:SkyRL/ckpts/curriculum-ckpts/stage2-ckpt
目的:在更难的课程数据上继续优化,加 KL 约束防止策略偏移
阶段间依赖:
Stage 0 输出 (merged_model)
└──→ Stage 1 的 MODEL_PATH(dapo_train.sh 第 9 行)
Stage 1 输出 (stage1-ckpt/global_step_60/merged_model)
└──→ Stage 2 的 MODEL_PATH(agentic-run-stage2.sh 第 8 行)
LLaMA-Factory 使用 UV 包管理工具管理依赖:
cd cold_sft/LLaMA-Factory
# 配置 UV 路径(根据实际环境修改 train.sh 中的路径)
export UV_PYTHON_INSTALL_DIR=/your/path/.uv/python
export UV_CACHE_DIR=/your/path/.cache/uv
export UV_PROJECT_ENVIRONMENT=/your/path/venvs/llamafactory
# 安装依赖(离线模式)
uv --offline sync --extra torch --extra metrics --prerelease=allow确保已完成 第 3 节 中的数据合成,得到 trajectories.jsonl。
需将 trajectories.jsonl 转换为 LLaMA-Factory 所需的 SFT 训练格式(多轮对话格式),具体格式转换逻辑参见 LLaMA-Factory 文档。
cd cold_sft/LLaMA-Factory
bash train.sh关键超参(在 LLaMA-Factory 的 YAML 配置中设置,train.sh 负责环境初始化):
- 模型:Qwen/Qwen2.5 系列(8B)
- 训练数据:转换后的多轮 SFT 格式
- 方法:全量 SFT(full fine-tuning)或 LoRA
训练完成后,合并模型权重:
# LLaMA-Factory 内置合并命令(LoRA 场景)
llamafactory-cli export \
--model_name_or_path /path/to/base_model \
--adapter_name_or_path /path/to/lora_checkpoint \
--export_dir cold_sft/ckpts_4584/checkpoint-864/merged_model \
--export_size 4输出路径:cold_sft/ckpts_4584/checkpoint-864/merged_model(供 Stage 1 使用)
脚本:SkyRL/skyrl-train/examples/text_to_sql/dapo_train.sh
路径配置(脚本顶部,需根据实际路径修改):
ROOT_DIR="/volume/pt-coder/users/hxyan/agentic-sql"
DATA_DIR="$ROOT_DIR/SkyRL/data"
DB_PATH="$ROOT_DIR/SkyRL/data"
CKPT_PATH="$ROOT_DIR/SkyRL/ckpts/agentic-sql-8B-ckpt" # Stage 1 检查点保存路径
MODEL_PATH="$ROOT_DIR/cold_sft/ckpts_4584/checkpoint-864/merged_model" # ← Stage 0 输出关键超参:
| 参数 | 值 | 说明 |
|---|---|---|
trainer.algorithm.advantage_estimator |
grpo |
使用 GRPO 算法 |
trainer.epochs |
30 |
训练轮数 |
trainer.policy.optimizer_config.lr |
5.0e-6 |
学习率 |
trainer.train_batch_size |
320 |
训练批次大小 |
trainer.micro_train_batch_size_per_gpu |
4 |
每 GPU 微批次 |
trainer.micro_forward_batch_size_per_gpu |
8 |
前向微批次 |
generator.n_samples_per_prompt |
10 |
每个 prompt 采样数 |
generator.max_turns |
6 |
最大交互轮数 |
generator.sampling_params.temperature |
0.7 |
采样温度 |
generator.sampling_params.top_p |
0.95 |
Top-p |
generator.max_input_length |
29000 |
最大输入长度(token) |
generator.sampling_params.max_generate_length |
3000 |
最大生成长度 |
trainer.max_prompt_length |
9000 |
训练时 prompt 最大长度 |
trainer.algorithm.use_kl_loss |
false |
不使用 KL loss |
trainer.algorithm.eps_clip_low/high |
0.2 / 0.28 |
DAPO 裁剪范围 |
trainer.algorithm.dynamic_sampling.type |
filter |
DAPO 动态采样类型 |
trainer.algorithm.dynamic_sampling.max_sample_batches |
30 |
动态采样最大批次 |
trainer.ckpt_interval |
60 |
每 60 step 保存一次 ckpt |
trainer.hf_save_interval |
30 |
每 30 step 保存一次 HF 格式模型 |
trainer.eval_interval |
5 |
每 5 step 评估一次 |
trainer.strategy |
fsdp2 |
分布式训练策略 |
NUM_GPUS |
8 |
训练 GPU 数 |
NUM_INFERENCE_ENGINES |
2 |
推理引擎数(每个 TP=4) |
数据:
- 训练集:
SkyRL/data/bird_train_rl.parquet - 验证集:
SkyRL/data/bird_dev_rl.parquet - 数据库:
SkyRL/data/bird/train/train_databases/和SkyRL/data/bird/dev/dev_databases/
奖励信号:由 SQLEnv._get_reward() 计算,在终止步(含 <solution>)时执行等价性校验,通过得 1.0,失败得 0.0,中间步奖励为 0。
停止词:</sql>、</solution>
启动命令:
cd SkyRL/skyrl-train
bash examples/text_to_sql/dapo_train.sh脚本:SkyRL/skyrl-train/examples/text_to_sql/curriculum-learning/agentic-run-stage2.sh
路径配置:
ROOT_DIR="/volume/pt-coder/users/hxyan/agentic-sql"
DATA_DIR="$ROOT_DIR/SkyRL/data"
CKPT_PATH="$ROOT_DIR/SkyRL/ckpts/curriculum-ckpts/stage2-ckpt"
MODEL_PATH="$ROOT_DIR/SkyRL/ckpts/curriculum-ckpts/stage1-ckpt/global_step_60/merged_model" # ← Stage 1 输出
stage1-ckpt需要手动从 Stage 1 的 DAPO 检查点中选取合适的 step,并确保已保存为 merged HF 格式。
与 Stage 1 的关键差异:
| 参数 | Stage 1 | Stage 2 | 变化原因 |
|---|---|---|---|
| 训练数据 | bird_train_rl.parquet |
curriculum_learning/stage2_train.parquet |
更难的课程数据 |
epochs |
30 |
10 |
防止在较小课程集上过拟合 |
lr |
5.0e-6 |
2.0e-6 |
精调阶段降低学习率 |
train_batch_size |
320 |
256 |
略减批次 |
use_kl_loss |
false |
true |
防止策略偏离 Stage 1 |
kl_loss_coef |
— | 0.001 |
KL 正则化系数 |
temperature |
0.7 |
0.8 |
增加探索多样性 |
ckpt_interval |
60 |
15 |
更频繁保存 |
max_prompt_length |
9000 |
10000 |
课程数据可能更长 |
gpu_memory_utilization |
0.85 |
0.7 |
适配显存占用 |
启动命令:
cd SkyRL/skyrl-train
bash examples/text_to_sql/curriculum-learning/agentic-run-stage2.shStep 1:生成 Multi-Turn Prompt
cd eval
# BIRD(示例)
python process_dataset.py \
--input_data_file /data/bird/dev.json \
--output_data_file ./data/bird_dev_processed.json \
--db_path /data/bird/databases \
--tables /data/bird/tables.json \
--source birdStep 2:多轮推理
# Greedy Search(单 GPU TP=2)
CUDA_VISIBLE_DEVICES=0,1 python infer.py \
--model_path /path/to/model \
--input_path ./data/bird_dev_processed.json \
--output_path ./results/bird_dev_greedy.json \
--db_path /data/bird/databases \
--max_turns 6 \
--max_tokens 3000 \
--n 1 \
--temperature 0.0 \
--tensor_parallel_size 2
# 多样本采样(用于 Major Voting / Pass@k)
CUDA_VISIBLE_DEVICES=0,1 python infer.py \
--model_path /path/to/model \
--input_path ./data/bird_dev_processed.json \
--output_path ./results/bird_dev_sampling.json \
--db_path /data/bird/databases \
--max_turns 6 \
--n 16 \
--temperature 0.6
# 多 GPU 数据并行
CUDA_VISIBLE_DEVICES=0,1,2,3 python infer.py \
--model_path /path/to/model \
--input_path ./data/bird_dev_processed.json \
--output_path ./results/bird_dev_greedy.json \
--db_path /data/bird/databases \
--num_gpus 4 \
--tp_per_instance 1 \
--n 1 \
--temperature 0.0Step 3:评估
# BIRD Greedy Search
python evaluate_bird.py \
--input_path ./results/bird_dev_greedy.json \
--db_path /data/bird/databases \
--mode greedy_search
# BIRD Major Voting
python evaluate_bird.py \
--input_path ./results/bird_dev_sampling.json \
--db_path /data/bird/databases \
--mode major_voting
# BIRD Pass@k
python evaluate_bird.py \
--input_path ./results/bird_dev_sampling.json \
--db_path /data/bird/databases \
--mode pass@k
# Spider 系列
python evaluate_spider.py \
--pred ./results/spider_dev_greedy.json \
--gold /data/spider/dev_gold.sql \
--db_path /data/spider/database \
--mode greedy_search| 数据集 | --source 参数 |
评估脚本 |
|---|---|---|
| BIRD | bird |
evaluate_bird.py |
| EHRSQL | ehrsql |
evaluate_bird.py |
| ScienceBenchmark | sciencebenchmark |
evaluate_bird.py |
| Spider | spider |
evaluate_spider.py |
| Spider-DK | spider_dk |
evaluate_spider.py |
| Spider-Realistic | spider_realistic |
evaluate_spider.py |
| Spider-Syn | spider_syn |
evaluate_spider.py |
| Spider2.0 | spider2.0 |
evaluate_spider2.py |
cd eval
# 评估单个模型
python auto_evaluation.py \
--output_ckpt_dir /path/to/model \
--source bird \
--input_file ./data/bird_dev_processed.json \
--db_path /data/bird/databases \
--eval_name bird_dev \
--visible_devices 0,1 \
--tensor_parallel_size 2 \
--n 16
# 批量评估多个 checkpoint
python auto_evaluation.py \
--output_ckpt_dir /path/to/checkpoints \
--multiple_models \
--source bird \
--input_file ./data/bird_dev_processed.json \
--db_path /data/bird/databases \
--eval_name bird_dev_multi_ckpt \
--visible_devices 0,1,2,3 \
--tensor_parallel_size 2 \
--n 16输出:
eval/results/{eval_name}/:推理结果 JSONeval/evaluation_results/{eval_name}/:评估指标 JSON 和可视化图表
cold_sft/
├── ckpts_4584/checkpoint-864/
│ └── merged_model/ # 合并后的 HF 格式模型(供 Stage 1 使用)
└── tensorboard/ # TensorBoard 日志
SkyRL/ckpts/agentic-sql-8B-ckpt/
├── global_step_XX/ # 分布式 ckpt(每 60 step 保存一次)
│ └── merged_model/ # HF 格式模型(每 30 step 保存一次)
└── ...
- 恢复训练:
trainer.resume_mode=latest自动加载最新检查点 - 评测使用:取
global_step_XX/merged_model/目录
SkyRL/ckpts/curriculum-ckpts/
├── stage1-ckpt/ # Stage 1 选定的起始检查点
│ └── global_step_60/merged_model/
└── stage2-ckpt/ # Stage 2 输出(每 15 step 保存一次)
- TensorBoard 日志:
outputs/目录 - 项目名称:
skyrlsql,运行名称:skyrlsql_repro - 查看命令:
tensorboard --logdir outputs/
按执行顺序排列:
# ============================================================
# 【Stage 0 - 数据合成】
# ============================================================
# Step 0a: 格式转换(需在 cold_sft/cold_sft_synthesis/ 目录下运行)
cd cold_sft/cold_sft_synthesis
python rl2sft.py
# 输出:bird_sft_train.json
# Step 0b: Teacher 采样(替换 api-base 和 model 为实际值)
python synthesize.py \
--input bird_sft_train.json \
--output trajectories.jsonl \
--db-path bird/train_databases \
--api-base http://your-llm-api:8000/v1 \
--model Qwen2.5-72B-Instruct \
--samples-per-question 20 \
--num-workers 32 \
--max-turns 6
# 输出:trajectories.jsonl
# ============================================================
# 【Stage 0 - SFT 训练】
# ============================================================
# Step 0c: 启动 SFT 训练(需先将 trajectories.jsonl 转换为 LLaMA-Factory 格式)
cd cold_sft/LLaMA-Factory
bash train.sh
# 输出:cold_sft/ckpts_4584/checkpoint-864/merged_model
# ============================================================
# 【Stage 1 - Agentic RL (DAPO)】
# ============================================================
# Step 1: 修改 dapo_train.sh 中的 ROOT_DIR 后启动
cd SkyRL/skyrl-train
bash examples/text_to_sql/dapo_train.sh
# 输出:SkyRL/ckpts/agentic-sql-8B-ckpt/
# ============================================================
# 【Stage 2 - Curriculum RL(可选)】
# ============================================================
# Step 2: 修改 agentic-run-stage2.sh 中的 ROOT_DIR 和 MODEL_PATH 后启动
bash examples/text_to_sql/curriculum-learning/agentic-run-stage2.sh
# 输出:SkyRL/ckpts/curriculum-ckpts/stage2-ckpt/
# ============================================================
# 【评测】
# ============================================================
cd eval
# Step E1: 生成评测 prompt
python process_dataset.py \
--input_data_file /data/bird/dev.json \
--output_data_file ./data/bird_dev_processed.json \
--db_path /data/bird/databases \
--tables /data/bird/tables.json \
--source bird
# Step E2: 推理
CUDA_VISIBLE_DEVICES=0,1 python infer.py \
--model_path /path/to/final_model \
--input_path ./data/bird_dev_processed.json \
--output_path ./results/bird_dev_greedy.json \
--db_path /data/bird/databases \
--max_turns 6 --n 1 --temperature 0.0 --tensor_parallel_size 2
# Step E3: 评估
python evaluate_bird.py \
--input_path ./results/bird_dev_greedy.json \
--db_path /data/bird/databases \
--mode greedy_searchrl2sft.py的输入路径bird/bird_train.json是相对于cold_sft/cold_sft_synthesis/的,需在该目录下运行dapo_train.sh和agentic-run-stage2.sh中的ROOT_DIR需修改为实际绝对路径- 数据库路径遵循统一结构:
{db_path}/{db_id}/{db_id}.sqlite SQLEnv中task字段决定数据库子路径:bird_train→bird/train/train_databases/,bird_dev→bird/dev/dev_databases/
- Stage 1:减小
trainer.micro_train_batch_size_per_gpu(当前为 4)或trainer.micro_forward_batch_size_per_gpu(当前为 8) - Stage 2:已将
gpu_memory_utilization从 0.85 降至 0.7,若仍 OOM 可继续降低 - 评测推理:减小
--tensor_parallel_size或减少--n
RL 脚本已配置 trainer.resume_mode=latest,重启后自动加载 CKPT_PATH 下的最新检查点,无需修改命令。
脚本默认支持断点续跑(append 模式写文件),重启后会自动读取已完成的 (example_id, sample_id) 对并跳过,无需任何额外操作。若需强制从头开始,加 --no-resume。
episode.py 内置格式修复逻辑,每轮最多重试 3 次。若仍失败,该条轨迹直接丢弃(返回 None)。Teacher 模型能力越强,成功率越高;建议使用 70B 级别以上的模型。
使用 --num-examples N 运行 synthesize.py 时,被选中的题目 ID 会保存到 trajectories_example_ids.json。在准备 RL 训练数据(bird_train_rl.parquet)时,过滤掉这些 ID 对应的题目即可。
Stage 2 脚本中使用的是 stage1-ckpt/global_step_60/merged_model,即 Stage 1 第 60 步的 merged 模型。实际使用时,建议通过评测分数选择 Stage 1 最佳检查点,而非固定使用 step 60。