From 3d3af7d39d89a4bedb816f42e44767c8b1592721 Mon Sep 17 00:00:00 2001 From: Surya Duraivenkatesh Date: Fri, 26 Jun 2026 16:18:12 -0700 Subject: [PATCH 01/13] Add Tillicum Qwen3 8B OPD example --- .../00_pull_or_load_container.sh | 36 +++ .../qwen3_8b_opd_tillicum/01_prepare_env.sh | 68 ++++ .../02_prepare_openthoughts3_math_sample.py | 292 ++++++++++++++++++ .../03_convert_models_if_needed.sbatch | 60 ++++ .../04_run_sft_100k_8xh200.sbatch | 143 +++++++++ .../05_run_opd_50k_8xh200.sbatch | 205 ++++++++++++ .../06_eval_math500_greedy_1x.sbatch | 230 ++++++++++++++ examples/qwen3_8b_opd_tillicum/README.md | 112 +++++++ .../qwen3_8b_opd_tillicum/container_exec.sh | 114 +++++++ examples/qwen3_8b_opd_tillicum/env.sh | 99 ++++++ .../run_all_dry_check.sh | 72 +++++ .../qwen3_8b_opd_tillicum/summarize_eval.py | 116 +++++++ 12 files changed, 1547 insertions(+) create mode 100755 examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh create mode 100755 examples/qwen3_8b_opd_tillicum/01_prepare_env.sh create mode 100755 examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py create mode 100755 examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch create mode 100755 examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch create mode 100755 examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch create mode 100755 examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch create mode 100644 examples/qwen3_8b_opd_tillicum/README.md create mode 100755 examples/qwen3_8b_opd_tillicum/container_exec.sh create mode 100755 examples/qwen3_8b_opd_tillicum/env.sh create mode 100755 examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh create mode 100755 examples/qwen3_8b_opd_tillicum/summarize_eval.py diff --git a/examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh b/examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh new file mode 100755 index 000000000..a13e1e266 --- /dev/null +++ 
b/examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +source "${SCRIPT_DIR}/env.sh" + +APPTAINER_BIN="${APPTAINER_BIN:-}" +if [[ -z "${APPTAINER_BIN}" ]]; then + if command -v apptainer >/dev/null 2>&1; then + APPTAINER_BIN=apptainer + elif command -v singularity >/dev/null 2>&1; then + APPTAINER_BIN=singularity + else + echo "Neither apptainer nor singularity is available on PATH." >&2 + exit 1 + fi +fi + +mkdir -p "$(dirname "${SLIME_SIF}")" "${APPTAINER_CACHEDIR}" "${APPTAINER_TMPDIR}" +export APPTAINER_CACHEDIR APPTAINER_TMPDIR + +cat </dev/null 2>&1 && pwd)" +source "${SCRIPT_DIR}/env.sh" + +mkdir -p \ + "${SCRATCH_ROOT}" \ + "${DATA_ROOT}" \ + "${MODEL_ROOT}" \ + "${OUTPUT_ROOT}" \ + "${HF_HOME}" \ + "${HF_DATASETS_CACHE}" \ + "${TRANSFORMERS_CACHE}" \ + "${WANDB_DIR}" \ + "${TMPDIR}" \ + "${RAY_TMPDIR}" \ + "${APPTAINER_CACHEDIR}" \ + "${APPTAINER_TMPDIR}" \ + "$(dirname "${SLIME_SIF}")" \ + "${CONTAINER_HOME}" \ + "${SFT_SAVE_DIR}" \ + "${OPD_SAVE_DIR}" \ + "${SFT_DETAILS_DIR}" \ + "${OPD_ROLLOUT_LOG_DIR}" \ + "${TEACHER_LOG_DIR}" \ + "${EVAL_OUTPUT_DIR}" \ + "${SLURM_LOG_DIR}" + +cat < argparse.Namespace: + env = os.environ + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--dataset", default=env.get("OT3_DATASET", "open-thoughts/OpenThoughts3-1.2M")) + parser.add_argument("--split", default=env.get("OT3_SPLIT", "train")) + parser.add_argument("--sft-size", type=int, default=int(env.get("SFT_SIZE", "100000"))) + parser.add_argument("--opd-size", type=int, default=int(env.get("OPD_SIZE", "50000"))) + parser.add_argument("--seed", type=int, default=int(env.get("DATA_SEED", "1234"))) + parser.add_argument("--math-field", default=env.get("DATA_MATH_FIELD", "domain")) + parser.add_argument("--math-value", default=env.get("DATA_MATH_VALUE", "math")) + parser.add_argument("--sft-out", 
default=env.get("SFT_PARQUET")) + parser.add_argument("--opd-out", default=env.get("OPD_JSONL")) + parser.add_argument("--metadata-out", default=env.get("SPLIT_METADATA")) + parser.add_argument("--hf-home", default=env.get("HF_HOME")) + parser.add_argument("--tokenizer", default=env.get("STUDENT_HF_REPO", "Qwen/Qwen3-8B-Base")) + parser.add_argument("--max-source-rows", type=int, default=None) + parser.add_argument("--force", action="store_true") + args = parser.parse_args() + + missing = [name for name in ("sft_out", "opd_out", "metadata_out") if getattr(args, name) in (None, "")] + if missing: + parser.error(f"missing required output env/arg(s): {', '.join(missing)}") + return args + + +def normalize_messages(value: Any, row: dict[str, Any]) -> list[dict[str, str]]: + if isinstance(value, str): + try: + value = json.loads(value) + except json.JSONDecodeError: + return [{"role": "user", "content": value}] + + if isinstance(value, dict) and "messages" in value: + value = value["messages"] + + if isinstance(value, list): + out: list[dict[str, str]] = [] + for item in value: + if not isinstance(item, dict): + continue + raw_role = item.get("role", item.get("from", item.get("speaker", ""))) + role = ROLE_MAP.get(str(raw_role).lower(), str(raw_role).lower()) + content = item.get("content", item.get("value", item.get("text", ""))) + if content is None: + content = "" + if role in {"system", "user", "assistant", "tool"}: + out.append({"role": role, "content": str(content)}) + if out: + return out + + prompt = first_present(row, ["prompt", "problem", "question", "instruction", "input"]) + response = first_present(row, ["response", "completion", "answer", "solution", "output"]) + if prompt is not None and response is not None: + return [ + {"role": "user", "content": str(prompt)}, + {"role": "assistant", "content": str(response)}, + ] + if prompt is not None: + return [{"role": "user", "content": str(prompt)}] + + raise ValueError("Could not infer OpenAI-style messages 
from row.") + + +def first_present(row: dict[str, Any], keys: list[str]) -> Any | None: + for key in keys: + if key in row and row[key] not in (None, ""): + return row[key] + return None + + +def extract_messages(row: dict[str, Any]) -> list[dict[str, str]]: + for key in ("messages", "conversations", "conversation"): + if key in row and row[key] not in (None, ""): + return normalize_messages(row[key], row) + return normalize_messages(None, row) + + +def extract_prompt(row: dict[str, Any], messages: list[dict[str, str]]) -> str: + prompt = first_present(row, ["prompt", "problem", "question", "instruction", "input"]) + if prompt is not None: + return str(prompt) + for message in messages: + if message.get("role") == "user": + return str(message.get("content", "")) + raise ValueError("Could not infer prompt text from row/messages.") + + +def prompt_hash(text: str) -> str: + return hashlib.sha256(text.encode("utf-8")).hexdigest() + + +def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, ensure_ascii=False) + "\n") + + +def token_lengths(tokenizer: Any, sft_rows: list[dict[str, Any]], opd_rows: list[dict[str, Any]]) -> dict[str, Any]: + sft_lengths: list[int] = [] + for row in sft_rows: + messages = row["messages"] + try: + ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False) + except Exception: + text = "\n".join(f"{m['role']}: {m['content']}" for m in messages) + ids = tokenizer(text, add_special_tokens=False)["input_ids"] + sft_lengths.append(len(ids)) + + opd_lengths: list[int] = [] + for row in opd_rows: + messages = [{"role": "user", "content": row["prompt"]}] + try: + ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True) + except Exception: + ids = tokenizer(row["prompt"], add_special_tokens=False)["input_ids"] + opd_lengths.append(len(ids)) + 
+ def stats(values: list[int]) -> dict[str, float | int]: + return { + "count": len(values), + "avg": float(statistics.fmean(values)) if values else 0.0, + "max": max(values) if values else 0, + } + + return { + "sft_total_token_lengths": stats(sft_lengths), + "opd_prompt_token_lengths": stats(opd_lengths), + } + + +def main() -> None: + args = parse_args() + sft_out = Path(args.sft_out) + opd_out = Path(args.opd_out) + metadata_out = Path(args.metadata_out) + + existing = [path for path in (sft_out, opd_out, metadata_out) if path.exists()] + if existing and not args.force: + raise SystemExit( + "Output already exists; use --force to replace:\n" + "\n".join(f" {path}" for path in existing) + ) + + print(f"Loading {args.dataset} split={args.split}") + ds = load_dataset(args.dataset, split=args.split, cache_dir=args.hf_home) + if args.max_source_rows is not None: + ds = ds.select(range(min(args.max_source_rows, len(ds)))) + + if args.math_field not in ds.column_names: + raise SystemExit( + f"Dataset does not contain math filter field {args.math_field!r}. 
" + f"Columns: {', '.join(ds.column_names)}" + ) + + ds = ds.map(lambda _, idx: {"source_row_id": idx}, with_indices=True) + math_value = str(args.math_value).lower() + math_ds = ds.filter(lambda row: str(row.get(args.math_field, "")).lower() == math_value) + + required = args.sft_size + args.opd_size + if len(math_ds) < required: + raise SystemExit(f"Need {required} math rows, found {len(math_ds)} after filtering.") + + rng = random.Random(args.seed) + shuffled = list(range(len(math_ds))) + rng.shuffle(shuffled) + sft_indices = shuffled[: args.sft_size] + opd_indices = shuffled[args.sft_size : required] + + sft_rows: list[dict[str, Any]] = [] + opd_rows: list[dict[str, Any]] = [] + + print(f"Building SFT rows: {len(sft_indices)}") + for idx in sft_indices: + row = math_ds[int(idx)] + messages = extract_messages(row) + prompt = extract_prompt(row, messages) + source_row_id = int(row["source_row_id"]) + sft_rows.append( + { + "messages": messages, + "metadata": { + "source_dataset": args.dataset, + "source_split": args.split, + "source_row_id": source_row_id, + "prompt_sha256": prompt_hash(prompt), + "split": "sft", + }, + } + ) + + print(f"Building OPD rows: {len(opd_indices)}") + for idx in opd_indices: + row = math_ds[int(idx)] + messages = extract_messages(row) + prompt = extract_prompt(row, messages) + source_row_id = int(row["source_row_id"]) + opd_rows.append( + { + "prompt": prompt, + "metadata": { + "source_dataset": args.dataset, + "source_split": args.split, + "source_row_id": source_row_id, + "prompt_sha256": prompt_hash(prompt), + "split": "opd", + }, + } + ) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, cache_dir=args.hf_home, trust_remote_code=True) + length_stats = token_lengths(tokenizer, sft_rows, opd_rows) + + sft_out.parent.mkdir(parents=True, exist_ok=True) + Dataset.from_list(sft_rows).to_parquet(str(sft_out)) + write_jsonl(opd_out, opd_rows) + + sft_source_ids = [row["metadata"]["source_row_id"] for row in sft_rows] + 
opd_source_ids = [row["metadata"]["source_row_id"] for row in opd_rows] + overlap = sorted(set(sft_source_ids).intersection(opd_source_ids)) + if overlap: + raise RuntimeError(f"SFT/OPD row split overlap detected: first overlaps {overlap[:10]}") + + metadata = { + "seed": args.seed, + "dataset": args.dataset, + "source_split": args.split, + "filtering_criteria": { + "field": args.math_field, + "value": args.math_value, + "operation": "case-insensitive equality", + }, + "counts": { + "source_rows_seen": len(ds), + "math_rows_after_filter": len(math_ds), + "sft": len(sft_rows), + "opd": len(opd_rows), + }, + "outputs": { + "sft_parquet": str(sft_out), + "opd_jsonl": str(opd_out), + }, + "source_row_ids": { + "sft": sft_source_ids, + "opd": opd_source_ids, + }, + "prompt_hashes": { + "sft": [row["metadata"]["prompt_sha256"] for row in sft_rows], + "opd": [row["metadata"]["prompt_sha256"] for row in opd_rows], + }, + "tokenizer": args.tokenizer, + "token_lengths": length_stats, + } + + metadata_out.parent.mkdir(parents=True, exist_ok=True) + metadata_out.write_text(json.dumps(metadata, indent=2, ensure_ascii=False) + "\n", encoding="utf-8") + + print(f"Wrote {sft_out}") + print(f"Wrote {opd_out}") + print(f"Wrote {metadata_out}") + + +if __name__ == "__main__": + main() diff --git a/examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch b/examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch new file mode 100755 index 000000000..441f0e00a --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch @@ -0,0 +1,60 @@ +#!/bin/bash +#SBATCH --job-name=slime-qwen3-convert +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=64 +#SBATCH --gres=gpu:h200:8 +#SBATCH --mem=0 +#SBATCH --time=02:00:00 +#SBATCH --mail-user=suryadv@cs.washington.edu +#SBATCH --mail-type=END,FAIL + +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +source "${SCRIPT_DIR}/env.sh" + +mkdir -p 
"${SLURM_LOG_DIR}" +if [[ -n "${SLURM_JOB_ID:-}" ]]; then + exec >"${SLURM_LOG_DIR}/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" 2>&1 +fi + +echo "Starting model prep on $(hostname) at $(date)" +echo "Student HF: ${STUDENT_HF_REPO} -> ${STUDENT_HF_DIR}" +echo "Teacher HF: ${TEACHER_HF_REPO} -> ${TEACHER_HF_DIR}" +echo "Student torch_dist: ${STUDENT_TORCH_DIST_DIR}" + +"${SCRIPT_DIR}/container_exec.sh" bash -lc ' +set -euo pipefail +cd "${SLIME_REPO_ROOT}" +mkdir -p "${MODEL_ROOT}" "${DATA_ROOT}" + +if [[ ! -d "${STUDENT_HF_DIR}" || -z "$(ls -A "${STUDENT_HF_DIR}" 2>/dev/null)" ]]; then + hf download "${STUDENT_HF_REPO}" --local-dir "${STUDENT_HF_DIR}" +else + echo "Student HF snapshot already exists: ${STUDENT_HF_DIR}" +fi + +if [[ ! -d "${TEACHER_HF_DIR}" || -z "$(ls -A "${TEACHER_HF_DIR}" 2>/dev/null)" ]]; then + hf download "${TEACHER_HF_REPO}" --local-dir "${TEACHER_HF_DIR}" +else + echo "Teacher HF snapshot already exists: ${TEACHER_HF_DIR}" +fi + +hf download --repo-type dataset "${MATH500_DATASET}" --local-dir "${DATA_ROOT}/math500_hf_snapshot" || true + +if [[ -f "${STUDENT_TORCH_DIST_DIR}/latest_checkpointed_iteration.txt" ]]; then + echo "Student torch_dist conversion already exists: ${STUDENT_TORCH_DIST_DIR}" + exit 0 +fi + +source scripts/models/qwen3-8B.sh +mkdir -p "${STUDENT_TORCH_DIST_DIR}" +PYTHONPATH="/root/Megatron-LM:${SLIME_REPO_ROOT}:${PYTHONPATH:-}" \ +torchrun --nproc_per_node="${CONVERT_NPROC:-8}" tools/convert_hf_to_torch_dist.py \ + "${MODEL_ARGS[@]}" \ + --hf-checkpoint "${STUDENT_HF_DIR}" \ + --save "${STUDENT_TORCH_DIST_DIR}" +' + +echo "Finished model prep at $(date)" diff --git a/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch b/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch new file mode 100755 index 000000000..57c05d6d6 --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch @@ -0,0 +1,143 @@ +#!/bin/bash +#SBATCH --job-name=slime-qwen3-sft100k +#SBATCH --nodes=1 +#SBATCH --ntasks=1 
+#SBATCH --cpus-per-task=64 +#SBATCH --gres=gpu:h200:8 +#SBATCH --mem=0 +#SBATCH --time=04:00:00 +#SBATCH --mail-user=suryadv@cs.washington.edu +#SBATCH --mail-type=END,FAIL + +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +source "${SCRIPT_DIR}/env.sh" + +mkdir -p "${SLURM_LOG_DIR}" "${SFT_SAVE_DIR}" "${SFT_DETAILS_DIR}" +if [[ -n "${SLURM_JOB_ID:-}" ]]; then + exec >"${SLURM_LOG_DIR}/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" 2>&1 +fi + +for required_path in "${SFT_PARQUET}" "${STUDENT_HF_DIR}" "${STUDENT_TORCH_DIST_DIR}/latest_checkpointed_iteration.txt"; do + if [[ ! -e "${required_path}" ]]; then + echo "Missing required path: ${required_path}" >&2 + exit 1 + fi +done + +echo "Starting SFT on $(hostname) at $(date)" +echo "SFT data: ${SFT_PARQUET}" +echo "Save dir: ${SFT_SAVE_DIR}" + +"${SCRIPT_DIR}/container_exec.sh" bash -lc ' +set -euo pipefail +cd "${SLIME_REPO_ROOT}" + +ray stop --force >/dev/null 2>&1 || true + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o "NV[0-9][0-9]*" | wc -l || true) +if [[ "${NVLINK_COUNT}" -gt 0 ]]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK=${HAS_NVLINK} (detected ${NVLINK_COUNT} NVLink references)" + +source scripts/models/qwen3-8B.sh + +CKPT_ARGS=( + --hf-checkpoint "${STUDENT_HF_DIR}" + --ref-load "${STUDENT_TORCH_DIST_DIR}" + --load "${SFT_SAVE_DIR}" + --save "${SFT_SAVE_DIR}" + --save-interval "${SFT_SAVE_INTERVAL}" + --no-save-optim +) + +SFT_ARGS=( + --rollout-function-path slime.rollout.sft_rollout.generate_rollout + --prompt-data "${SFT_PARQUET}" + --input-key messages + --rollout-shuffle + --num-epoch "${SFT_NUM_EPOCH}" + --rollout-batch-size "${SFT_ROLLOUT_BATCH_SIZE}" + --global-batch-size "${SFT_GLOBAL_BATCH_SIZE}" + --loss-type sft_loss + --calculate-per-token-loss + --disable-compute-advantages-and-returns + --debug-train-only + --dump-details "${SFT_DETAILS_DIR}" +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + 
--sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + --use-dynamic-batch-size + --max-tokens-per-gpu "${SFT_MAX_TOKENS_PER_GPU}" +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-5 + --lr-decay-style cosine + --min-lr 1e-6 + --lr-warmup-fraction 0.1 + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.95 +) + +WANDB_ARGS=() +if [[ "${WANDB_MODE}" != "disabled" ]]; then + WANDB_ARGS+=(--wandb-project slime-tillicum --wandb-group qwen3-8b-sft100k) +fi + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +export MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}" +export no_proxy="127.0.0.1,localhost,${MASTER_ADDR}" +ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus 8 \ + --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM:${SLIME_REPO_ROOT}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"PYTORCH_CUDA_ALLOC_CONF\": \"expandable_segments:True\", + \"WANDB_MODE\": \"${WANDB_MODE}\", + \"HF_HOME\": \"${HF_HOME}\" + } +}" + +trap "ray stop --force >/dev/null 2>&1 || true" EXIT + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train_async.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 8 \ + "${MODEL_ARGS[@]}" \ + "${CKPT_ARGS[@]}" \ + "${SFT_ARGS[@]}" \ + "${OPTIMIZER_ARGS[@]}" \ + "${WANDB_ARGS[@]}" \ + "${PERF_ARGS[@]}" \ + "${MISC_ARGS[@]}" +' + +echo "Finished SFT at $(date)" diff --git a/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch b/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch new file mode 100755 index 000000000..06ef0a2ba --- 
/dev/null +++ b/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch @@ -0,0 +1,205 @@ +#!/bin/bash +#SBATCH --job-name=slime-qwen3-opd50k +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=64 +#SBATCH --gres=gpu:h200:8 +#SBATCH --mem=0 +#SBATCH --time=03:00:00 +#SBATCH --mail-user=suryadv@cs.washington.edu +#SBATCH --mail-type=END,FAIL + +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +source "${SCRIPT_DIR}/env.sh" + +mkdir -p "${SLURM_LOG_DIR}" "${OPD_SAVE_DIR}" "${OPD_ROLLOUT_LOG_DIR}" "${TEACHER_LOG_DIR}" +if [[ -n "${SLURM_JOB_ID:-}" ]]; then + exec >"${SLURM_LOG_DIR}/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" 2>&1 +fi + +for required_path in \ + "${OPD_JSONL}" \ + "${STUDENT_HF_DIR}" \ + "${STUDENT_TORCH_DIST_DIR}/latest_checkpointed_iteration.txt" \ + "${TEACHER_HF_DIR}" \ + "${SFT_SAVE_DIR}/latest_checkpointed_iteration.txt"; do + if [[ ! -e "${required_path}" ]]; then + echo "Missing required path: ${required_path}" >&2 + exit 1 + fi +done + +echo "Starting OPD on $(hostname) at $(date)" +echo "OPD data: ${OPD_JSONL}" +echo "Teacher: ${TEACHER_HF_DIR} on physical GPU ${OPD_TEACHER_GPU}" +echo "Ray GPUs: 0-$((OPD_RAY_GPUS - 1))" +echo "Save dir: ${OPD_SAVE_DIR}" + +"${SCRIPT_DIR}/container_exec.sh" bash -lc ' +set -euo pipefail +cd "${SLIME_REPO_ROOT}" + +ray stop --force >/dev/null 2>&1 || true + +TEACHER_IP="127.0.0.1" +TEACHER_PORT="${OPD_TEACHER_PORT}" +TEACHER_LOG="${TEACHER_LOG_DIR}/sglang_teacher_${SLURM_JOB_ID:-manual}.log" + +CUDA_VISIBLE_DEVICES="${OPD_TEACHER_GPU}" python3 -m sglang.launch_server \ + --model-path "${TEACHER_HF_DIR}" \ + --host 0.0.0.0 \ + --port "${TEACHER_PORT}" \ + --tp 1 \ + --chunked-prefill-size 4096 \ + --mem-fraction-static "${OPD_TEACHER_MEM_FRACTION}" \ + >"${TEACHER_LOG}" 2>&1 & +TEACHER_PID=$! 
+ +cleanup() { + set +e + if [[ -n "${TEACHER_PID:-}" ]]; then + kill "${TEACHER_PID}" >/dev/null 2>&1 || true + fi + ray stop --force >/dev/null 2>&1 || true +} +trap cleanup EXIT + +echo "Waiting for teacher server on ${TEACHER_IP}:${TEACHER_PORT}" +until curl -sf "http://${TEACHER_IP}:${TEACHER_PORT}/health_generate" >/dev/null; do + tail -n 20 "${TEACHER_LOG}" || true + sleep 10 +done +curl "http://${TEACHER_IP}:${TEACHER_PORT}/get_model_info" || true +sleep 10 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o "NV[0-9][0-9]*" | wc -l || true) +if [[ "${NVLINK_COUNT}" -gt 0 ]]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK=${HAS_NVLINK} (detected ${NVLINK_COUNT} NVLink references)" + +source scripts/models/qwen3-8B.sh + +CKPT_ARGS=( + --hf-checkpoint "${STUDENT_HF_DIR}" + --ref-load "${STUDENT_TORCH_DIST_DIR}" + --load "${SFT_SAVE_DIR}" + --save "${OPD_SAVE_DIR}" + --save-interval "${OPD_SAVE_INTERVAL}" + --no-save-optim +) + +ROLLOUT_ARGS=( + --prompt-data "${OPD_JSONL}" + --input-key prompt + --apply-chat-template + --rollout-shuffle + --num-rollout "${OPD_NUM_ROLLOUT}" + --rollout-batch-size "${OPD_ROLLOUT_BATCH_SIZE}" + --n-samples-per-prompt "${OPD_N_SAMPLES_PER_PROMPT}" + --rollout-max-response-len "${OPD_MAX_RESPONSE_LEN}" + --rollout-temperature 1 + --rollout-top-p 1 + --global-batch-size "${OPD_GLOBAL_BATCH_SIZE}" + --balance-data + --save-debug-rollout-data "${OPD_ROLLOUT_LOG_DIR}/rollout_{rollout_id}.pt" +) + +RM_ARGS=( + --custom-rm-path slime.rollout.on_policy_distillation.reward_func + --custom-reward-post-process-path slime.rollout.on_policy_distillation.post_process_rewards + --rm-url "http://${TEACHER_IP}:${TEACHER_PORT}/generate" +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + 
--use-dynamic-batch-size + --max-tokens-per-gpu "${OPD_MAX_TOKENS_PER_GPU}" +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-opd + --opd-type sglang + --opd-kl-coef 1.0 + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=() +if [[ "${WANDB_MODE}" != "disabled" ]]; then + WANDB_ARGS+=(--wandb-project slime-tillicum --wandb-group qwen3-8b-opd50k) +fi + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.4 +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +export MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}" +export no_proxy="127.0.0.1,localhost,${MASTER_ADDR}" +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 ray start --head --node-ip-address "${MASTER_ADDR}" \ + --num-gpus "${OPD_RAY_GPUS}" \ + --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM:${SLIME_REPO_ROOT}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"PYTORCH_CUDA_ALLOC_CONF\": \"expandable_segments:True\", + \"WANDB_MODE\": \"${WANDB_MODE}\", + \"HF_HOME\": \"${HF_HOME}\" + } +}" + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node "${OPD_ACTOR_GPUS}" \ + --rollout-num-gpus "${OPD_ROLLOUT_GPUS}" \ + "${MODEL_ARGS[@]}" \ + "${CKPT_ARGS[@]}" \ + "${ROLLOUT_ARGS[@]}" \ + "${OPTIMIZER_ARGS[@]}" \ + "${GRPO_ARGS[@]}" \ + "${WANDB_ARGS[@]}" \ + "${PERF_ARGS[@]}" \ + "${SGLANG_ARGS[@]}" \ + "${MISC_ARGS[@]}" \ + "${RM_ARGS[@]}" +' + +echo "Finished OPD at $(date)" diff --git 
a/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch b/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch new file mode 100755 index 000000000..c5619c7a5 --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch @@ -0,0 +1,230 @@ +#!/bin/bash +#SBATCH --job-name=slime-qwen3-math500 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=64 +#SBATCH --gres=gpu:h200:8 +#SBATCH --mem=0 +#SBATCH --time=01:00:00 +#SBATCH --mail-user=suryadv@cs.washington.edu +#SBATCH --mail-type=END,FAIL + +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +source "${SCRIPT_DIR}/env.sh" + +mkdir -p "${SLURM_LOG_DIR}" "${EVAL_OUTPUT_DIR}" "${DATA_ROOT}" +if [[ -n "${SLURM_JOB_ID:-}" ]]; then + exec >"${SLURM_LOG_DIR}/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" 2>&1 +fi + +for required_path in \ + "${STUDENT_HF_DIR}" \ + "${STUDENT_TORCH_DIST_DIR}/latest_checkpointed_iteration.txt" \ + "${SFT_SAVE_DIR}/latest_checkpointed_iteration.txt" \ + "${OPD_SAVE_DIR}/latest_checkpointed_iteration.txt"; do + if [[ ! 
-e "${required_path}" ]]; then + echo "Missing required path: ${required_path}" >&2 + exit 1 + fi +done + +echo "Starting MATH-500 greedy eval on $(hostname) at $(date)" +echo "Eval output: ${EVAL_OUTPUT_DIR}" + +"${SCRIPT_DIR}/container_exec.sh" bash -lc ' +set -euo pipefail +cd "${SLIME_REPO_ROOT}" + +python3 - </dev/null 2>&1 || true + + NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o "NV[0-9][0-9]*" | wc -l || true) + if [[ "${NVLINK_COUNT}" -gt 0 ]]; then + HAS_NVLINK=1 + else + HAS_NVLINK=0 + fi + + source scripts/models/qwen3-8B.sh + + CKPT_ARGS=( + --hf-checkpoint "${STUDENT_HF_DIR}" + --ref-load "${STUDENT_TORCH_DIST_DIR}" + --load "${load_dir}" + --save "${stage_dir}/unused_save" + --save-interval 100000 + --no-save-optim + ) + + ROLLOUT_ARGS=( + --prompt-data "${OPD_JSONL}" + --input-key prompt + --apply-chat-template + --num-rollout 0 + --rollout-batch-size "${EVAL_ROLLOUT_BATCH_SIZE}" + --n-samples-per-prompt 1 + --rollout-max-response-len 1024 + --rollout-temperature 0 + --rollout-top-p 1 + --global-batch-size "${EVAL_ROLLOUT_BATCH_SIZE}" + --rm-type deepscaler + --save-debug-rollout-data "${stage_dir}/debug_{rollout_id}.pt" + ) + + EVAL_ARGS=( + --eval-interval 1 + --eval-config "${MATH500_CONFIG}" + --n-samples-per-eval-prompt 1 + --eval-temperature 0 + --eval-top-p 1 + --eval-max-response-len "${EVAL_MAX_RESPONSE_LEN}" + --log-passrate + ) + + PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + --use-dynamic-batch-size + --max-tokens-per-gpu "${EVAL_MAX_RESPONSE_LEN}" + ) + + OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + ) + + GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type 
low_var_kl + --entropy-coef 0.00 + ) + + SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 + ) + + MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash + ) + + export MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}" + export no_proxy="127.0.0.1,localhost,${MASTER_ADDR}" + ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${EVAL_ROLLOUT_NUM_GPUS}" \ + --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + + RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM:${SLIME_REPO_ROOT}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"PYTORCH_CUDA_ALLOC_CONF\": \"expandable_segments:True\", + \"WANDB_MODE\": \"${WANDB_MODE}\", + \"HF_HOME\": \"${HF_HOME}\" + } + }" + + ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 8 \ + --rollout-num-gpus "${EVAL_ROLLOUT_NUM_GPUS}" \ + --colocate \ + "${MODEL_ARGS[@]}" \ + "${CKPT_ARGS[@]}" \ + "${ROLLOUT_ARGS[@]}" \ + "${OPTIMIZER_ARGS[@]}" \ + "${GRPO_ARGS[@]}" \ + "${PERF_ARGS[@]}" \ + "${EVAL_ARGS[@]}" \ + "${SGLANG_ARGS[@]}" \ + "${MISC_ARGS[@]}" + + ray stop --force >/dev/null 2>&1 || true + + python3 examples/qwen3_8b_opd_tillicum/summarize_eval.py \ + --stage "${stage}" \ + --debug-file "${stage_dir}/debug_eval_0.pt" \ + --out-json "${stage_dir}/summary.json" \ + --max-response-len "${EVAL_MAX_RESPONSE_LEN}" +} + +trap "ray stop --force >/dev/null 2>&1 || true" EXIT + +run_eval base "${STUDENT_TORCH_DIST_DIR}" +run_eval sft "${SFT_SAVE_DIR}" +run_eval opd "${OPD_SAVE_DIR}" + +python3 examples/qwen3_8b_opd_tillicum/summarize_eval.py \ + --aggregate-dir "${EVAL_OUTPUT_DIR}" \ + --out-json "${EVAL_OUTPUT_DIR}/summary_all.json" +' + +echo "Finished MATH-500 eval at $(date)" diff --git 
a/examples/qwen3_8b_opd_tillicum/README.md b/examples/qwen3_8b_opd_tillicum/README.md new file mode 100644 index 000000000..e2c6c3965 --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/README.md @@ -0,0 +1,112 @@ +# Qwen3-8B SFT + OPD on Tillicum + +This directory contains thin Tillicum wrappers for a Qwen3-8B SFT plus +on-policy distillation experiment using OpenThoughts3-1.2M math data. + +The wrappers intentionally reuse slime's existing paths: + +- `docs/en/get_started/quick_start.md` for the container and Megatron + conversion workflow. +- `scripts/models/qwen3-8B.sh` for the student Megatron model args. +- `examples/on_policy_distillation/run-qwen3-8B-opd.sh` for the SGLang + teacher OPD shape. +- `scripts/run-qwen3-4B-base-sft.sh` for the SFT rollout/training shape. +- `examples/eval_multi_task/multi_task.sh` and `multi_task.yaml` for eval. + +No slime core code is modified. + +## Required environment + +Source `env.sh` before running commands: + +```bash +cd /gpfs/scrubbed/suryadv/repos/slime +source examples/qwen3_8b_opd_tillicum/env.sh +``` + +Important variables: + +- `ACCOUNT`: Slurm account. Default: `raivn`. +- `PARTITION`: Slurm partition. Default: `gpu-h200`. +- `QOS`: Slurm QOS. Default: `normal`. +- `SCRATCH_ROOT`: root for all generated data, checkpoints, caches, logs, and + the Apptainer image. Default: + `/gpfs/scrubbed/suryadv/slime-qwen3-8b-opd`. +- `DATA_ROOT`: prepared datasets. Default: `$SCRATCH_ROOT/data`. +- `MODEL_ROOT`: HF model snapshots and Megatron torch_dist conversion. + Default: `$SCRATCH_ROOT/models`. +- `OUTPUT_ROOT`: training/eval outputs. Default: `$SCRATCH_ROOT/outputs`. +- `HF_HOME`: Hugging Face cache under scratch. Default: `$SCRATCH_ROOT/hf_home`. +- `WANDB_MODE`: default `offline`. +- `SLIME_SIF`: Apptainer SIF path. Default: + `$SCRATCH_ROOT/containers/slime_latest.sif`. + +The scripts avoid writing caches/checkpoints/data under home. 
+ +## Dry checks + +Dry checks do not pull the container, download data/models, install packages, or +submit real jobs. + +```bash +cd /gpfs/scrubbed/suryadv/repos/slime +source examples/qwen3_8b_opd_tillicum/env.sh +bash examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh +``` + +The dry check runs `bash -n`, Python bytecode compilation, and +`sbatch --test-only` for the four Slurm scripts. If `$SLIME_SIF` already +exists and `RUN_CONTAINER_CHECKS=1` is set, it also verifies imports inside the +container. + +## Setup and launch after approval + +Run the setup steps manually, in this order: + +```bash +bash examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh +bash examples/qwen3_8b_opd_tillicum/01_prepare_env.sh +bash examples/qwen3_8b_opd_tillicum/container_exec.sh \ + python examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py +``` + +Then launch only after explicit approval: + +```bash +jid0=$(sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" \ + examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch | awk '{print $4}') +jid1=$(sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" --dependency=afterok:$jid0 \ + examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch | awk '{print $4}') +jid2=$(sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" --dependency=afterok:$jid1 \ + examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch | awk '{print $4}') +sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" --dependency=afterok:$jid2 \ + examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +``` + +## Outputs + +- SFT split: `$SFT_PARQUET` +- OPD prompt split: `$OPD_JSONL` +- Data metadata: `$SPLIT_METADATA` +- Student HF snapshot: `$STUDENT_HF_DIR` +- Teacher HF snapshot: `$TEACHER_HF_DIR` +- Student Megatron torch_dist: `$STUDENT_TORCH_DIST_DIR` +- SFT checkpoint: `$SFT_SAVE_DIR` +- OPD checkpoint: `$OPD_SAVE_DIR` +- Eval summaries: `$EVAL_OUTPUT_DIR` + +## Slurm resources + +Each job requests one node with 
`--gres=gpu:h200:8`, `--ntasks=1`, +`--cpus-per-task=64`, and all node memory. The account, partition, and QOS are +passed at submit time from the environment variables above. + +The intended wall-clock budget after model/data/container preparation is: + +- SFT 100k: 2-4 hours. +- OPD 50k: 2-3 hours. +- MATH-500 greedy eval once for base, SFT, and OPD: <=1 hour. + +The main runtime risk is the Qwen3-32B teacher logprob server throughput during +OPD. If dry measurements show the full OPD run will exceed the budget, the +first reduced run to try is 100k SFT plus 25k OPD. diff --git a/examples/qwen3_8b_opd_tillicum/container_exec.sh b/examples/qwen3_8b_opd_tillicum/container_exec.sh new file mode 100755 index 000000000..515875473 --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/container_exec.sh @@ -0,0 +1,114 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +source "${SCRIPT_DIR}/env.sh" + +if [[ $# -eq 0 ]]; then + echo "Usage: $0 [args...]" + exit 2 +fi + +if [[ ! -f "${SLIME_SIF}" ]]; then + cat >&2 </dev/null 2>&1; then + APPTAINER_BIN=apptainer + elif command -v singularity >/dev/null 2>&1; then + APPTAINER_BIN=singularity + else + echo "Neither apptainer nor singularity is available on PATH." 
>&2 + exit 1 + fi +fi + +mkdir -p \ + "${APPTAINER_CACHEDIR}" \ + "${APPTAINER_TMPDIR}" \ + "${CONTAINER_HOME}" \ + "${TMPDIR}" \ + "${RAY_TMPDIR}" + +export APPTAINER_CACHEDIR APPTAINER_TMPDIR + +IFS=',' read -r -a BIND_ROOTS <<< "${CONTAINER_BIND_ROOTS}" +APPTAINER_ARGS=( + exec + --nv + --cleanenv + --ipc + --writable-tmpfs + --cwd "${SLIME_REPO_ROOT}" + --home "${CONTAINER_HOME}:${CONTAINER_HOME_INNER}" +) + +for bind_root in "${BIND_ROOTS[@]}"; do + if [[ -e "${bind_root}" ]]; then + APPTAINER_ARGS+=(--bind "${bind_root}:${bind_root}") + fi +done + +PASS_ENV=( + ACCOUNT + PARTITION + QOS + SCRATCH_ROOT + DATA_ROOT + MODEL_ROOT + OUTPUT_ROOT + HF_HOME + HF_DATASETS_CACHE + TRANSFORMERS_CACHE + WANDB_MODE + WANDB_DIR + TMPDIR + RAY_TMPDIR + SLIME_REPO_ROOT + TILLICUM_EXAMPLE_DIR + STUDENT_HF_REPO + TEACHER_HF_REPO + OT3_DATASET + MATH500_DATASET + STUDENT_HF_DIR + TEACHER_HF_DIR + STUDENT_TORCH_DIST_DIR + SFT_PARQUET + OPD_JSONL + SPLIT_METADATA + MATH500_JSONL + MATH500_CONFIG + SFT_SAVE_DIR + OPD_SAVE_DIR + SFT_DETAILS_DIR + OPD_ROLLOUT_LOG_DIR + TEACHER_LOG_DIR + EVAL_OUTPUT_DIR + SLURM_LOG_DIR + PYTHONUNBUFFERED + TOKENIZERS_PARALLELISM + NCCL_DEBUG +) + +for name in "${PASS_ENV[@]}"; do + if [[ -n "${!name-}" ]]; then + APPTAINER_ARGS+=(--env "${name}=${!name}") + fi +done + +APPTAINER_ARGS+=( + --env "HOME=${CONTAINER_HOME_INNER}" + --env "PYTHONPATH=${SLIME_REPO_ROOT}:/root/Megatron-LM:${PYTHONPATH:-}" + --env "CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS:-1}" +) + +"${APPTAINER_BIN}" "${APPTAINER_ARGS[@]}" "${SLIME_SIF}" "$@" diff --git a/examples/qwen3_8b_opd_tillicum/env.sh b/examples/qwen3_8b_opd_tillicum/env.sh new file mode 100755 index 000000000..d2e47ca3d --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/env.sh @@ -0,0 +1,99 @@ +#!/usr/bin/env bash + +# Source this file from the repository root or from any script in this +# directory. All generated state is kept under scrubbed/scratch storage. 
+ +if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then + echo "Source this file instead of executing it:" + echo " source ${BASH_SOURCE[0]}" + exit 2 +fi + +export TILLICUM_EXAMPLE_DIR +TILLICUM_EXAMPLE_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" + +export SLIME_REPO_ROOT +SLIME_REPO_ROOT="$(cd -- "${TILLICUM_EXAMPLE_DIR}/../.." >/dev/null 2>&1 && pwd)" + +export ACCOUNT="${ACCOUNT:-raivn}" +export PARTITION="${PARTITION:-gpu-h200}" +export QOS="${QOS:-normal}" + +export SCRATCH_ROOT="${SCRATCH_ROOT:-/gpfs/scrubbed/suryadv/slime-qwen3-8b-opd}" +export DATA_ROOT="${DATA_ROOT:-${SCRATCH_ROOT}/data}" +export MODEL_ROOT="${MODEL_ROOT:-${SCRATCH_ROOT}/models}" +export OUTPUT_ROOT="${OUTPUT_ROOT:-${SCRATCH_ROOT}/outputs}" +export HF_HOME="${HF_HOME:-${SCRATCH_ROOT}/hf_home}" +export HF_DATASETS_CACHE="${HF_DATASETS_CACHE:-${HF_HOME}/datasets}" +export TRANSFORMERS_CACHE="${TRANSFORMERS_CACHE:-${HF_HOME}/transformers}" +export WANDB_MODE="${WANDB_MODE:-offline}" +export WANDB_DIR="${WANDB_DIR:-${OUTPUT_ROOT}/wandb}" +export TMPDIR="${TMPDIR:-${SCRATCH_ROOT}/tmp}" +export RAY_TMPDIR="${RAY_TMPDIR:-${SCRATCH_ROOT}/ray_tmp}" + +export APPTAINER_CACHEDIR="${APPTAINER_CACHEDIR:-${SCRATCH_ROOT}/apptainer_cache}" +export APPTAINER_TMPDIR="${APPTAINER_TMPDIR:-${SCRATCH_ROOT}/apptainer_tmp}" +export SLIME_IMAGE_URI="${SLIME_IMAGE_URI:-docker://slimerl/slime:latest}" +export SLIME_SIF="${SLIME_SIF:-${SCRATCH_ROOT}/containers/slime_latest.sif}" +export CONTAINER_BIND_ROOTS="${CONTAINER_BIND_ROOTS:-/gpfs/scrubbed/suryadv,/tmp}" +export CONTAINER_HOME="${CONTAINER_HOME:-${SCRATCH_ROOT}/container_home}" +export CONTAINER_HOME_INNER="${CONTAINER_HOME_INNER:-/home/${USER:-slime}}" + +export STUDENT_HF_REPO="${STUDENT_HF_REPO:-Qwen/Qwen3-8B-Base}" +export TEACHER_HF_REPO="${TEACHER_HF_REPO:-Qwen/Qwen3-32B}" +export OT3_DATASET="${OT3_DATASET:-open-thoughts/OpenThoughts3-1.2M}" +export MATH500_DATASET="${MATH500_DATASET:-HuggingFaceH4/MATH-500}" + +export 
STUDENT_HF_DIR="${STUDENT_HF_DIR:-${MODEL_ROOT}/Qwen3-8B-Base}" +export TEACHER_HF_DIR="${TEACHER_HF_DIR:-${MODEL_ROOT}/Qwen3-32B}" +export STUDENT_TORCH_DIST_DIR="${STUDENT_TORCH_DIST_DIR:-${MODEL_ROOT}/Qwen3-8B-Base_torch_dist}" + +export SFT_SIZE="${SFT_SIZE:-100000}" +export OPD_SIZE="${OPD_SIZE:-50000}" +export DATA_SEED="${DATA_SEED:-1234}" +export DATA_MATH_FIELD="${DATA_MATH_FIELD:-domain}" +export DATA_MATH_VALUE="${DATA_MATH_VALUE:-math}" +export OT3_SPLIT="${OT3_SPLIT:-train}" + +export SFT_PARQUET="${SFT_PARQUET:-${DATA_ROOT}/openthoughts3_math_sft_${SFT_SIZE}.parquet}" +export OPD_JSONL="${OPD_JSONL:-${DATA_ROOT}/openthoughts3_math_opd_${OPD_SIZE}.jsonl}" +export SPLIT_METADATA="${SPLIT_METADATA:-${DATA_ROOT}/openthoughts3_math_split_metadata.json}" +export MATH500_JSONL="${MATH500_JSONL:-${DATA_ROOT}/math500_deepscaler.jsonl}" +export MATH500_CONFIG="${MATH500_CONFIG:-${DATA_ROOT}/math500_eval.yaml}" + +export SFT_SAVE_DIR="${SFT_SAVE_DIR:-${OUTPUT_ROOT}/qwen3_8b_sft_100k}" +export OPD_SAVE_DIR="${OPD_SAVE_DIR:-${OUTPUT_ROOT}/qwen3_8b_sft_100k_opd_50k}" +export SFT_DETAILS_DIR="${SFT_DETAILS_DIR:-${OUTPUT_ROOT}/sft_details}" +export OPD_ROLLOUT_LOG_DIR="${OPD_ROLLOUT_LOG_DIR:-${OUTPUT_ROOT}/opd_rollout_logs}" +export TEACHER_LOG_DIR="${TEACHER_LOG_DIR:-${OUTPUT_ROOT}/teacher_logs}" +export EVAL_OUTPUT_DIR="${EVAL_OUTPUT_DIR:-${OUTPUT_ROOT}/math500_eval_1x}" +export SLURM_LOG_DIR="${SLURM_LOG_DIR:-${OUTPUT_ROOT}/slurm_logs}" + +export SFT_NUM_EPOCH="${SFT_NUM_EPOCH:-1}" +export SFT_ROLLOUT_BATCH_SIZE="${SFT_ROLLOUT_BATCH_SIZE:-256}" +export SFT_GLOBAL_BATCH_SIZE="${SFT_GLOBAL_BATCH_SIZE:-256}" +export SFT_MAX_TOKENS_PER_GPU="${SFT_MAX_TOKENS_PER_GPU:-16384}" +export SFT_SAVE_INTERVAL="${SFT_SAVE_INTERVAL:-50}" + +export OPD_NUM_ROLLOUT="${OPD_NUM_ROLLOUT:-391}" +export OPD_ROLLOUT_BATCH_SIZE="${OPD_ROLLOUT_BATCH_SIZE:-128}" +export OPD_GLOBAL_BATCH_SIZE="${OPD_GLOBAL_BATCH_SIZE:-128}" +export OPD_N_SAMPLES_PER_PROMPT="${OPD_N_SAMPLES_PER_PROMPT:-1}" 
+export OPD_MAX_RESPONSE_LEN="${OPD_MAX_RESPONSE_LEN:-16384}" +export OPD_MAX_TOKENS_PER_GPU="${OPD_MAX_TOKENS_PER_GPU:-16384}" +export OPD_SAVE_INTERVAL="${OPD_SAVE_INTERVAL:-50}" +export OPD_ACTOR_GPUS="${OPD_ACTOR_GPUS:-2}" +export OPD_ROLLOUT_GPUS="${OPD_ROLLOUT_GPUS:-5}" +export OPD_RAY_GPUS="${OPD_RAY_GPUS:-7}" +export OPD_TEACHER_GPU="${OPD_TEACHER_GPU:-7}" +export OPD_TEACHER_PORT="${OPD_TEACHER_PORT:-13141}" +export OPD_TEACHER_MEM_FRACTION="${OPD_TEACHER_MEM_FRACTION:-0.6}" + +export EVAL_MAX_RESPONSE_LEN="${EVAL_MAX_RESPONSE_LEN:-32768}" +export EVAL_ROLLOUT_BATCH_SIZE="${EVAL_ROLLOUT_BATCH_SIZE:-64}" +export EVAL_ROLLOUT_NUM_GPUS="${EVAL_ROLLOUT_NUM_GPUS:-8}" +export EVAL_NUM_REPEATS="${EVAL_NUM_REPEATS:-1}" + +export PYTHONUNBUFFERED=1 +export TOKENIZERS_PARALLELISM="${TOKENIZERS_PARALLELISM:-false}" +export NCCL_DEBUG="${NCCL_DEBUG:-WARN}" diff --git a/examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh b/examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh new file mode 100755 index 000000000..f99db7d73 --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh @@ -0,0 +1,72 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +source "${SCRIPT_DIR}/env.sh" + +cd "${SLIME_REPO_ROOT}" + +echo "Dry check environment" +echo " repo: ${SLIME_REPO_ROOT}" +echo " account/partition/qos: ${ACCOUNT}/${PARTITION}/${QOS}" +echo " scratch: ${SCRATCH_ROOT}" +echo " container: ${SLIME_SIF}" + +SHELL_FILES=( + examples/qwen3_8b_opd_tillicum/env.sh + examples/qwen3_8b_opd_tillicum/container_exec.sh + examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh + examples/qwen3_8b_opd_tillicum/01_prepare_env.sh + examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh + examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch + examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch + examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch + 
examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +) + +PYTHON_FILES=( + examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py + examples/qwen3_8b_opd_tillicum/summarize_eval.py +) + +echo "Checking shell syntax" +for file in "${SHELL_FILES[@]}"; do + bash -n "${file}" +done + +echo "Checking Python syntax" +python3 -m py_compile "${PYTHON_FILES[@]}" + +SBATCH_FILES=( + examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch + examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch + examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch + examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +) + +echo "Checking Slurm scripts with sbatch --test-only" +for file in "${SBATCH_FILES[@]}"; do + sbatch --test-only -A "${ACCOUNT}" -p "${PARTITION}" --qos "${QOS}" "${file}" +done + +if [[ "${RUN_CONTAINER_CHECKS:-0}" == "1" ]]; then + if [[ -f "${SLIME_SIF}" ]]; then + echo "Checking imports inside container" + "${SCRIPT_DIR}/container_exec.sh" python -c "import slime, sglang, transformers, datasets; print('container imports ok')" + else + echo "RUN_CONTAINER_CHECKS=1 but SLIME_SIF does not exist; skipping import check." + fi +else + echo "Skipping container import check. Set RUN_CONTAINER_CHECKS=1 after the SIF exists." +fi + +echo "git diff --stat" +git diff --stat + +echo "git diff --name-only" +git diff --name-only + +echo "git status --short" +git status --short + +echo "Dry checks completed. No real jobs were submitted." 
diff --git a/examples/qwen3_8b_opd_tillicum/summarize_eval.py b/examples/qwen3_8b_opd_tillicum/summarize_eval.py new file mode 100755 index 000000000..8ac6132e1 --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/summarize_eval.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +"""Summarize slime eval debug rollout files.""" + +from __future__ import annotations + +import argparse +import json +import math +from pathlib import Path +from typing import Any + +import torch + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--stage", default=None) + parser.add_argument("--debug-file", default=None) + parser.add_argument("--out-json", required=True) + parser.add_argument("--max-response-len", type=int, default=None) + parser.add_argument("--aggregate-dir", default=None) + return parser.parse_args() + + +def reward_value(sample: dict[str, Any]) -> float: + reward = sample.get("reward", 0.0) + if isinstance(reward, dict): + for key in ("reward", "score", "acc", "accuracy"): + if key in reward: + return float(reward[key]) + return 0.0 + return float(reward or 0.0) + + +def extract_answer_or_none(response: str) -> str | None: + try: + from slime.rollout.rm_hub.math_utils import extract_answer + except Exception: + return None + + if "" in response: + response = response.split("")[-1] + elif "###Response" in response: + response = response.split("###Response", 1)[1] + return extract_answer(response) + + +def summarize_debug_file(stage: str, debug_file: Path, max_response_len: int | None) -> dict[str, Any]: + payload = torch.load(debug_file, map_location="cpu", weights_only=False) + samples = payload.get("samples", []) + if not samples: + raise RuntimeError(f"No samples found in {debug_file}") + + rewards = [reward_value(sample) for sample in samples] + response_lengths = [int(sample.get("response_length") or 0) for sample in samples] + statuses = [str(sample.get("status", "")) for sample in samples] + + 
parse_failures = 0 + for sample in samples: + response = str(sample.get("response", "")) + if extract_answer_or_none(response) is None: + parse_failures += 1 + + cap_hits = 0 + for length, status in zip(response_lengths, statuses, strict=True): + if status == "truncated" or (max_response_len is not None and length >= max_response_len): + cap_hits += 1 + + n = len(samples) + accuracy = sum(rewards) / n + return { + "stage": stage, + "debug_file": str(debug_file), + "n": n, + "accuracy": accuracy, + "mean_accuracy": accuracy, + "std_accuracy": None, + "avg_generated_tokens": sum(response_lengths) / n, + "max_generated_tokens": max(response_lengths), + "parse_failure_rate": parse_failures / n, + "cap_hit_rate": cap_hits / n, + } + + +def aggregate(aggregate_dir: Path) -> dict[str, Any]: + summaries = [] + for path in sorted(aggregate_dir.glob("*/summary.json")): + summaries.append(json.loads(path.read_text(encoding="utf-8"))) + if not summaries: + raise RuntimeError(f"No per-stage summaries found under {aggregate_dir}") + + return { + "num_repeats_per_stage": 1, + "stages": summaries, + "note": "Eval repeat count is 1, so per-stage std_accuracy is N/A.", + } + + +def main() -> None: + args = parse_args() + out = Path(args.out_json) + out.parent.mkdir(parents=True, exist_ok=True) + + if args.aggregate_dir: + summary = aggregate(Path(args.aggregate_dir)) + else: + if not args.stage or not args.debug_file: + raise SystemExit("--stage and --debug-file are required unless --aggregate-dir is used") + summary = summarize_debug_file(args.stage, Path(args.debug_file), args.max_response_len) + + out.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8") + print(json.dumps(summary, indent=2)) + + +if __name__ == "__main__": + main() From 08a87a1a2eece843e3e539c80caffebbd28ea82e Mon Sep 17 00:00:00 2001 From: Surya Duraivenkatesh Date: Fri, 26 Jun 2026 17:56:00 -0700 Subject: [PATCH 02/13] Use Apptainer sandbox for Tillicum example --- .../00_pull_or_load_container.sh 
| 23 +++++++++++++++---- .../qwen3_8b_opd_tillicum/01_prepare_env.sh | 1 + examples/qwen3_8b_opd_tillicum/README.md | 7 ++++-- .../qwen3_8b_opd_tillicum/container_exec.sh | 4 ++-- examples/qwen3_8b_opd_tillicum/env.sh | 11 ++++++++- .../run_all_dry_check.sh | 4 ++-- 6 files changed, 38 insertions(+), 12 deletions(-) diff --git a/examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh b/examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh index a13e1e266..109fa8a82 100755 --- a/examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh +++ b/examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh @@ -22,15 +22,28 @@ export APPTAINER_CACHEDIR APPTAINER_TMPDIR cat <&2 + echo "Move it aside or remove it manually, then rerun this script." >&2 + exit 1 +fi + +if [[ "${SLIME_CONTAINER_FORMAT}" == "sandbox" ]]; then + "${APPTAINER_BIN}" build --sandbox "${SLIME_SIF}" "${SLIME_IMAGE_URI}" +else + "${APPTAINER_BIN}" pull --force "${SLIME_SIF}" "${SLIME_IMAGE_URI}" +fi + +"${APPTAINER_BIN}" exec --cleanenv "${SLIME_SIF}" python -V +echo "Wrote and validated ${SLIME_SIF}" diff --git a/examples/qwen3_8b_opd_tillicum/01_prepare_env.sh b/examples/qwen3_8b_opd_tillicum/01_prepare_env.sh index 2d6eb44e0..a9a880a7d 100755 --- a/examples/qwen3_8b_opd_tillicum/01_prepare_env.sh +++ b/examples/qwen3_8b_opd_tillicum/01_prepare_env.sh @@ -38,6 +38,7 @@ Slurm Container SLIME_IMAGE_URI=${SLIME_IMAGE_URI} + SLIME_CONTAINER_FORMAT=${SLIME_CONTAINER_FORMAT} SLIME_SIF=${SLIME_SIF} bind roots=${CONTAINER_BIND_ROOTS} diff --git a/examples/qwen3_8b_opd_tillicum/README.md b/examples/qwen3_8b_opd_tillicum/README.md index e2c6c3965..a31a8404e 100644 --- a/examples/qwen3_8b_opd_tillicum/README.md +++ b/examples/qwen3_8b_opd_tillicum/README.md @@ -38,8 +38,11 @@ Important variables: - `OUTPUT_ROOT`: training/eval outputs. Default: `$SCRATCH_ROOT/outputs`. - `HF_HOME`: Hugging Face cache under scratch. Default: `$SCRATCH_ROOT/hf_home`. - `WANDB_MODE`: default `offline`. 
-- `SLIME_SIF`: Apptainer SIF path. Default: - `$SCRATCH_ROOT/containers/slime_latest.sif`. +- `SLIME_CONTAINER_FORMAT`: Apptainer image format. Default: `sandbox`, + because Tillicum's Apptainer produced invalid SquashFS SIFs for this large + Docker image during testing. Set to `sif` to force SIF output. +- `SLIME_SIF`: Apptainer image/sandbox path. Default: + `$SCRATCH_ROOT/containers/slime_latest.sandbox`. The scripts avoid writing caches/checkpoints/data under home. diff --git a/examples/qwen3_8b_opd_tillicum/container_exec.sh b/examples/qwen3_8b_opd_tillicum/container_exec.sh index 515875473..6ee541895 100755 --- a/examples/qwen3_8b_opd_tillicum/container_exec.sh +++ b/examples/qwen3_8b_opd_tillicum/container_exec.sh @@ -9,9 +9,9 @@ if [[ $# -eq 0 ]]; then exit 2 fi -if [[ ! -f "${SLIME_SIF}" ]]; then +if [[ ! -e "${SLIME_SIF}" ]]; then cat >&2 < Date: Fri, 26 Jun 2026 18:21:47 -0700 Subject: [PATCH 03/13] Speed up Tillicum data prep sampling --- .../02_prepare_openthoughts3_math_sample.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py b/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py index e9ac493a8..a24a9fd3a 100755 --- a/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py +++ b/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py @@ -125,6 +125,21 @@ def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None: f.write(json.dumps(row, ensure_ascii=False) + "\n") +def iter_dataset_rows(ds: Dataset, indices: list[int], label: str, batch_size: int = 1000): + selected = ds.select(indices) + total = len(indices) + seen = 0 + for batch in selected.iter(batch_size=batch_size): + keys = list(batch.keys()) + if not keys: + continue + for row_idx in range(len(batch[keys[0]])): + seen += 1 + if seen % 5000 == 0 or seen == total: + print(f"{label}: {seen}/{total}", flush=True) + yield {key: 
batch[key][row_idx] for key in keys} + + def token_lengths(tokenizer: Any, sft_rows: list[dict[str, Any]], opd_rows: list[dict[str, Any]]) -> dict[str, Any]: sft_lengths: list[int] = [] for row in sft_rows: @@ -192,15 +207,16 @@ def main() -> None: rng = random.Random(args.seed) shuffled = list(range(len(math_ds))) rng.shuffle(shuffled) - sft_indices = shuffled[: args.sft_size] - opd_indices = shuffled[args.sft_size : required] + # Split membership is seeded by the shuffled order. Materialize in dataset + # order to avoid very slow random Arrow row reads on GPFS. + sft_indices = sorted(shuffled[: args.sft_size]) + opd_indices = sorted(shuffled[args.sft_size : required]) sft_rows: list[dict[str, Any]] = [] opd_rows: list[dict[str, Any]] = [] print(f"Building SFT rows: {len(sft_indices)}") - for idx in sft_indices: - row = math_ds[int(idx)] + for row in iter_dataset_rows(math_ds, sft_indices, "SFT rows"): messages = extract_messages(row) prompt = extract_prompt(row, messages) source_row_id = int(row["source_row_id"]) @@ -218,8 +234,7 @@ def main() -> None: ) print(f"Building OPD rows: {len(opd_indices)}") - for idx in opd_indices: - row = math_ds[int(idx)] + for row in iter_dataset_rows(math_ds, opd_indices, "OPD rows"): messages = extract_messages(row) prompt = extract_prompt(row, messages) source_row_id = int(row["source_row_id"]) @@ -258,6 +273,7 @@ def main() -> None: "value": args.math_value, "operation": "case-insensitive equality", }, + "sample_materialization": "Seeded split membership, indices sorted within each split before row materialization.", "counts": { "source_rows_seen": len(ds), "math_rows_after_filter": len(math_ds), From 0a9195f8d92c00c1bb407ef685131dbcea85688f Mon Sep 17 00:00:00 2001 From: Surya Duraivenkatesh Date: Fri, 26 Jun 2026 18:44:19 -0700 Subject: [PATCH 04/13] Fix Tillicum sbatch script path resolution --- .../03_convert_models_if_needed.sbatch | 6 +++++- .../qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch | 6 +++++- 
examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch | 6 +++++- .../qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch | 6 +++++- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch b/examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch index 441f0e00a..5930a8b24 100755 --- a/examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch +++ b/examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch @@ -11,7 +11,11 @@ set -euo pipefail -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +if [[ -n "${SLURM_SUBMIT_DIR:-}" && -f "${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum/env.sh" ]]; then + SCRIPT_DIR="${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum" +else + SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +fi source "${SCRIPT_DIR}/env.sh" mkdir -p "${SLURM_LOG_DIR}" diff --git a/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch b/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch index 57c05d6d6..19c079294 100755 --- a/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch +++ b/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch @@ -11,7 +11,11 @@ set -euo pipefail -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +if [[ -n "${SLURM_SUBMIT_DIR:-}" && -f "${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum/env.sh" ]]; then + SCRIPT_DIR="${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum" +else + SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +fi source "${SCRIPT_DIR}/env.sh" mkdir -p "${SLURM_LOG_DIR}" "${SFT_SAVE_DIR}" "${SFT_DETAILS_DIR}" diff --git a/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch b/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch index 06ef0a2ba..38b1ef15e 100755 --- a/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch +++ 
b/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch @@ -11,7 +11,11 @@ set -euo pipefail -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +if [[ -n "${SLURM_SUBMIT_DIR:-}" && -f "${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum/env.sh" ]]; then + SCRIPT_DIR="${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum" +else + SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +fi source "${SCRIPT_DIR}/env.sh" mkdir -p "${SLURM_LOG_DIR}" "${OPD_SAVE_DIR}" "${OPD_ROLLOUT_LOG_DIR}" "${TEACHER_LOG_DIR}" diff --git a/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch b/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch index c5619c7a5..8858633f4 100755 --- a/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +++ b/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch @@ -11,7 +11,11 @@ set -euo pipefail -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +if [[ -n "${SLURM_SUBMIT_DIR:-}" && -f "${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum/env.sh" ]]; then + SCRIPT_DIR="${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum" +else + SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +fi source "${SCRIPT_DIR}/env.sh" mkdir -p "${SLURM_LOG_DIR}" "${EVAL_OUTPUT_DIR}" "${DATA_ROOT}" From 1fedba588fe7e1718d4b785051bd89f5ff1ff9c8 Mon Sep 17 00:00:00 2001 From: Surya Duraivenkatesh Date: Fri, 26 Jun 2026 18:48:43 -0700 Subject: [PATCH 05/13] Batch Tillicum data prep token stats --- .../02_prepare_openthoughts3_math_sample.py | 48 ++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py b/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py index a24a9fd3a..25d23ed2d 100755 --- a/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py +++ 
b/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py @@ -140,25 +140,61 @@ def iter_dataset_rows(ds: Dataset, indices: list[int], label: str, batch_size: i yield {key: batch[key][row_idx] for key in keys} +def encode_lengths(tokenizer: Any, texts: list[str]) -> list[int]: + encoded = tokenizer(texts, add_special_tokens=False) + return [len(input_ids) for input_ids in encoded["input_ids"]] + + +def flush_token_batch( + tokenizer: Any, + texts: list[str], + lengths: list[int], + label: str, + processed: int, + total: int, +) -> int: + if not texts: + return processed + lengths.extend(encode_lengths(tokenizer, texts)) + processed += len(texts) + texts.clear() + if processed % 5000 == 0 or processed == total: + print(f"{label} token stats: {processed}/{total}", flush=True) + return processed + + def token_lengths(tokenizer: Any, sft_rows: list[dict[str, Any]], opd_rows: list[dict[str, Any]]) -> dict[str, Any]: sft_lengths: list[int] = [] + sft_text_batch: list[str] = [] + sft_processed = 0 for row in sft_rows: messages = row["messages"] try: - ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False) + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) except Exception: text = "\n".join(f"{m['role']}: {m['content']}" for m in messages) - ids = tokenizer(text, add_special_tokens=False)["input_ids"] - sft_lengths.append(len(ids)) + sft_text_batch.append(text) + if len(sft_text_batch) >= 256: + sft_processed = flush_token_batch( + tokenizer, sft_text_batch, sft_lengths, "SFT", sft_processed, len(sft_rows) + ) + sft_processed = flush_token_batch(tokenizer, sft_text_batch, sft_lengths, "SFT", sft_processed, len(sft_rows)) opd_lengths: list[int] = [] + opd_text_batch: list[str] = [] + opd_processed = 0 for row in opd_rows: messages = [{"role": "user", "content": row["prompt"]}] try: - ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True) + text = 
tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) except Exception: - ids = tokenizer(row["prompt"], add_special_tokens=False)["input_ids"] - opd_lengths.append(len(ids)) + text = row["prompt"] + opd_text_batch.append(text) + if len(opd_text_batch) >= 256: + opd_processed = flush_token_batch( + tokenizer, opd_text_batch, opd_lengths, "OPD", opd_processed, len(opd_rows) + ) + opd_processed = flush_token_batch(tokenizer, opd_text_batch, opd_lengths, "OPD", opd_processed, len(opd_rows)) def stats(values: list[int]) -> dict[str, float | int]: return { From ed0c9e2c16a8d871961b22f48bf74a9024d94426 Mon Sep 17 00:00:00 2001 From: Surya Duraivenkatesh Date: Fri, 26 Jun 2026 18:56:50 -0700 Subject: [PATCH 06/13] Write Tillicum split files before token stats --- .../02_prepare_openthoughts3_math_sample.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py b/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py index 25d23ed2d..933b5e126 100755 --- a/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py +++ b/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py @@ -155,10 +155,12 @@ def flush_token_batch( ) -> int: if not texts: return processed + previous = processed + batch_size = len(texts) lengths.extend(encode_lengths(tokenizer, texts)) - processed += len(texts) + processed += batch_size texts.clear() - if processed % 5000 == 0 or processed == total: + if processed // 5000 > previous // 5000 or processed == total: print(f"{label} token stats: {processed}/{total}", flush=True) return processed @@ -287,19 +289,21 @@ def main() -> None: } ) - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, cache_dir=args.hf_home, trust_remote_code=True) - length_stats = token_lengths(tokenizer, sft_rows, opd_rows) - - sft_out.parent.mkdir(parents=True, exist_ok=True) - 
Dataset.from_list(sft_rows).to_parquet(str(sft_out)) - write_jsonl(opd_out, opd_rows) - sft_source_ids = [row["metadata"]["source_row_id"] for row in sft_rows] opd_source_ids = [row["metadata"]["source_row_id"] for row in opd_rows] overlap = sorted(set(sft_source_ids).intersection(opd_source_ids)) if overlap: raise RuntimeError(f"SFT/OPD row split overlap detected: first overlaps {overlap[:10]}") + sft_out.parent.mkdir(parents=True, exist_ok=True) + Dataset.from_list(sft_rows).to_parquet(str(sft_out)) + write_jsonl(opd_out, opd_rows) + print(f"Wrote {sft_out}") + print(f"Wrote {opd_out}") + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, cache_dir=args.hf_home, trust_remote_code=True) + length_stats = token_lengths(tokenizer, sft_rows, opd_rows) + metadata = { "seed": args.seed, "dataset": args.dataset, @@ -335,8 +339,6 @@ def main() -> None: metadata_out.parent.mkdir(parents=True, exist_ok=True) metadata_out.write_text(json.dumps(metadata, indent=2, ensure_ascii=False) + "\n", encoding="utf-8") - print(f"Wrote {sft_out}") - print(f"Wrote {opd_out}") print(f"Wrote {metadata_out}") From c833b1943bc3ca9248573f45c6931d8329542fe1 Mon Sep 17 00:00:00 2001 From: Surya Duraivenkatesh Date: Fri, 26 Jun 2026 19:00:45 -0700 Subject: [PATCH 07/13] Stream Tillicum SFT parquet writing --- .../02_prepare_openthoughts3_math_sample.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py b/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py index 933b5e126..a87b58675 100755 --- a/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py +++ b/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py @@ -12,6 +12,8 @@ from pathlib import Path from typing import Any +import pyarrow as pa +import pyarrow.parquet as pq from datasets import Dataset, load_dataset from transformers import AutoTokenizer @@ -125,6 +127,23 @@ def 
write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None: f.write(json.dumps(row, ensure_ascii=False) + "\n") +def write_parquet_rows(path: Path, rows: list[dict[str, Any]], batch_size: int = 1000) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + writer: pq.ParquetWriter | None = None + try: + for start in range(0, len(rows), batch_size): + end = min(start + batch_size, len(rows)) + table = pa.Table.from_pylist(rows[start:end]) + if writer is None: + writer = pq.ParquetWriter(str(path), table.schema) + writer.write_table(table) + if end % 5000 == 0 or end == len(rows): + print(f"Parquet rows: {end}/{len(rows)}", flush=True) + finally: + if writer is not None: + writer.close() + + def iter_dataset_rows(ds: Dataset, indices: list[int], label: str, batch_size: int = 1000): selected = ds.select(indices) total = len(indices) @@ -295,8 +314,7 @@ def main() -> None: if overlap: raise RuntimeError(f"SFT/OPD row split overlap detected: first overlaps {overlap[:10]}") - sft_out.parent.mkdir(parents=True, exist_ok=True) - Dataset.from_list(sft_rows).to_parquet(str(sft_out)) + write_parquet_rows(sft_out, sft_rows) write_jsonl(opd_out, opd_rows) print(f"Wrote {sft_out}") print(f"Wrote {opd_out}") From b0999b5ec8878f95c58d6d1e07d81954b72ab064 Mon Sep 17 00:00:00 2001 From: Surya Duraivenkatesh Date: Fri, 26 Jun 2026 19:15:43 -0700 Subject: [PATCH 08/13] Forward Tillicum runtime knobs into container --- .../qwen3_8b_opd_tillicum/container_exec.sh | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/examples/qwen3_8b_opd_tillicum/container_exec.sh b/examples/qwen3_8b_opd_tillicum/container_exec.sh index 6ee541895..2d27e44e8 100755 --- a/examples/qwen3_8b_opd_tillicum/container_exec.sh +++ b/examples/qwen3_8b_opd_tillicum/container_exec.sh @@ -94,6 +94,28 @@ PASS_ENV=( TEACHER_LOG_DIR EVAL_OUTPUT_DIR SLURM_LOG_DIR + SFT_NUM_EPOCH + SFT_ROLLOUT_BATCH_SIZE + SFT_GLOBAL_BATCH_SIZE + SFT_MAX_TOKENS_PER_GPU + SFT_SAVE_INTERVAL + OPD_NUM_ROLLOUT + 
OPD_ROLLOUT_BATCH_SIZE + OPD_GLOBAL_BATCH_SIZE + OPD_N_SAMPLES_PER_PROMPT + OPD_MAX_RESPONSE_LEN + OPD_MAX_TOKENS_PER_GPU + OPD_SAVE_INTERVAL + OPD_ACTOR_GPUS + OPD_ROLLOUT_GPUS + OPD_RAY_GPUS + OPD_TEACHER_GPU + OPD_TEACHER_PORT + OPD_TEACHER_MEM_FRACTION + EVAL_MAX_RESPONSE_LEN + EVAL_ROLLOUT_BATCH_SIZE + EVAL_ROLLOUT_NUM_GPUS + EVAL_NUM_REPEATS PYTHONUNBUFFERED TOKENIZERS_PARALLELISM NCCL_DEBUG From c37460a659a02b0990212f7b15ccc1ebb43212cb Mon Sep 17 00:00:00 2001 From: Surya Duraivenkatesh Date: Fri, 26 Jun 2026 19:24:03 -0700 Subject: [PATCH 09/13] Use short Ray temp dirs for Tillicum jobs --- examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch | 2 ++ examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch | 2 ++ examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch | 2 ++ 3 files changed, 6 insertions(+) diff --git a/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch b/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch index 19c079294..6dbc76879 100755 --- a/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch +++ b/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch @@ -20,6 +20,8 @@ source "${SCRIPT_DIR}/env.sh" mkdir -p "${SLURM_LOG_DIR}" "${SFT_SAVE_DIR}" "${SFT_DETAILS_DIR}" if [[ -n "${SLURM_JOB_ID:-}" ]]; then + export RAY_TMPDIR="/tmp/${USER:-suryadv}/slime-ray-${SLURM_JOB_ID}" + mkdir -p "${RAY_TMPDIR}" exec >"${SLURM_LOG_DIR}/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" 2>&1 fi diff --git a/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch b/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch index 38b1ef15e..165105a5f 100755 --- a/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch +++ b/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch @@ -20,6 +20,8 @@ source "${SCRIPT_DIR}/env.sh" mkdir -p "${SLURM_LOG_DIR}" "${OPD_SAVE_DIR}" "${OPD_ROLLOUT_LOG_DIR}" "${TEACHER_LOG_DIR}" if [[ -n "${SLURM_JOB_ID:-}" ]]; then + export 
RAY_TMPDIR="/tmp/${USER:-suryadv}/slime-ray-${SLURM_JOB_ID}" + mkdir -p "${RAY_TMPDIR}" exec >"${SLURM_LOG_DIR}/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" 2>&1 fi diff --git a/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch b/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch index 8858633f4..8835f3e46 100755 --- a/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +++ b/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch @@ -20,6 +20,8 @@ source "${SCRIPT_DIR}/env.sh" mkdir -p "${SLURM_LOG_DIR}" "${EVAL_OUTPUT_DIR}" "${DATA_ROOT}" if [[ -n "${SLURM_JOB_ID:-}" ]]; then + export RAY_TMPDIR="/tmp/${USER:-suryadv}/slime-ray-${SLURM_JOB_ID}" + mkdir -p "${RAY_TMPDIR}" exec >"${SLURM_LOG_DIR}/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" 2>&1 fi From 029d81424b2648779588c0bb612c9cbce98b026b Mon Sep 17 00:00:00 2001 From: Surya Duraivenkatesh Date: Fri, 26 Jun 2026 19:30:07 -0700 Subject: [PATCH 10/13] Use JSONL for Tillicum SFT data --- .../02_prepare_openthoughts3_math_sample.py | 11 +++++++++-- examples/qwen3_8b_opd_tillicum/env.sh | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py b/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py index a87b58675..7dd96959a 100755 --- a/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py +++ b/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py @@ -144,6 +144,13 @@ def write_parquet_rows(path: Path, rows: list[dict[str, Any]], batch_size: int = writer.close() +def write_sft_rows(path: Path, rows: list[dict[str, Any]]) -> None: + if path.suffix == ".jsonl": + write_jsonl(path, rows) + else: + write_parquet_rows(path, rows) + + def iter_dataset_rows(ds: Dataset, indices: list[int], label: str, batch_size: int = 1000): selected = ds.select(indices) total = len(indices) @@ -314,7 +321,7 @@ def main() -> None: if overlap: raise 
RuntimeError(f"SFT/OPD row split overlap detected: first overlaps {overlap[:10]}") - write_parquet_rows(sft_out, sft_rows) + write_sft_rows(sft_out, sft_rows) write_jsonl(opd_out, opd_rows) print(f"Wrote {sft_out}") print(f"Wrote {opd_out}") @@ -339,7 +346,7 @@ def main() -> None: "opd": len(opd_rows), }, "outputs": { - "sft_parquet": str(sft_out), + "sft_data": str(sft_out), "opd_jsonl": str(opd_out), }, "source_row_ids": { diff --git a/examples/qwen3_8b_opd_tillicum/env.sh b/examples/qwen3_8b_opd_tillicum/env.sh index 823f552da..68808683d 100755 --- a/examples/qwen3_8b_opd_tillicum/env.sh +++ b/examples/qwen3_8b_opd_tillicum/env.sh @@ -64,7 +64,7 @@ export DATA_MATH_FIELD="${DATA_MATH_FIELD:-domain}" export DATA_MATH_VALUE="${DATA_MATH_VALUE:-math}" export OT3_SPLIT="${OT3_SPLIT:-train}" -export SFT_PARQUET="${SFT_PARQUET:-${DATA_ROOT}/openthoughts3_math_sft_${SFT_SIZE}.parquet}" +export SFT_PARQUET="${SFT_PARQUET:-${DATA_ROOT}/openthoughts3_math_sft_${SFT_SIZE}.jsonl}" export OPD_JSONL="${OPD_JSONL:-${DATA_ROOT}/openthoughts3_math_opd_${OPD_SIZE}.jsonl}" export SPLIT_METADATA="${SPLIT_METADATA:-${DATA_ROOT}/openthoughts3_math_split_metadata.json}" export MATH500_JSONL="${MATH500_JSONL:-${DATA_ROOT}/math500_deepscaler.jsonl}" From 066c7e88a681539e89cc5ccce86a7452fc8361bf Mon Sep 17 00:00:00 2001 From: Surya Duraivenkatesh Date: Fri, 26 Jun 2026 20:01:15 -0700 Subject: [PATCH 11/13] Document non-Tillicum reproduction steps --- examples/qwen3_8b_opd_tillicum/README.md | 70 +++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/examples/qwen3_8b_opd_tillicum/README.md b/examples/qwen3_8b_opd_tillicum/README.md index a31a8404e..8a9695c9d 100644 --- a/examples/qwen3_8b_opd_tillicum/README.md +++ b/examples/qwen3_8b_opd_tillicum/README.md @@ -86,9 +86,77 @@ sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" --dependency=afterok:$jid2 \ examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch ``` +## Reproducing On Another Slurm 
Cluster + +These wrappers are Tillicum-shaped, but the core workflow is portable to a +Slurm cluster with Apptainer or Singularity, one 8-GPU node per job, outbound +access to Hugging Face and Docker Hub, and enough scratch space for large model +checkpoints. + +Clone the fork and switch to the reproduction branch: + +```bash +git clone https://github.com/suryathecreator/slime.git +cd slime +git checkout opd-reproduction +``` + +Choose cluster-local paths and Slurm settings. Keep all generated state on +scratch or project storage, not home: + +```bash +export ACCOUNT="" +export PARTITION="" +export QOS="" +export SCRATCH_ROOT="/path/to/scratch/${USER}/slime-qwen3-8b-opd" +export CONTAINER_BIND_ROOTS="$(pwd),${SCRATCH_ROOT},/tmp" + +# Match your site's GPU gres. Examples: gpu:8, gpu:a100:8, gpu:h100:8. +export GPU_GRES="gpu:8" + +# Use sif if your Apptainer can build a normal SIF for slimerl/slime:latest. +export SLIME_CONTAINER_FORMAT="sandbox" + +source examples/qwen3_8b_opd_tillicum/env.sh +``` + +Run the same preparation commands: + +```bash +bash examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh +bash examples/qwen3_8b_opd_tillicum/01_prepare_env.sh +bash examples/qwen3_8b_opd_tillicum/container_exec.sh \ + python examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py +``` + +Submit the dependency chain, overriding the Tillicum `h200` gres embedded in +the sbatch files: + +```bash +jid0=$(sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" --gres "$GPU_GRES" \ + examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch | awk '{print $4}') +jid1=$(sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" --gres "$GPU_GRES" \ + --dependency=afterok:$jid0 \ + examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch | awk '{print $4}') +jid2=$(sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" --gres "$GPU_GRES" \ + --dependency=afterok:$jid1 \ + examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch | awk '{print $4}') +sbatch -A 
"$ACCOUNT" -p "$PARTITION" --qos "$QOS" --gres "$GPU_GRES" \ + --dependency=afterok:$jid2 \ + examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +``` + +If your cluster uses `--gpus-per-node` instead of `--gres`, replace the +`--gres "$GPU_GRES"` arguments above with your site's GPU request flag. If your +cluster uses Docker rather than Apptainer/Singularity, run the same Python and +Slurm entrypoints inside `slimerl/slime:latest` and bind the repo plus +`$SCRATCH_ROOT` into the container; `container_exec.sh` is the only +Apptainer-specific layer. + ## Outputs -- SFT split: `$SFT_PARQUET` +- SFT split: `$SFT_PARQUET` (JSONL by default despite the legacy variable + name). - OPD prompt split: `$OPD_JSONL` - Data metadata: `$SPLIT_METADATA` - Student HF snapshot: `$STUDENT_HF_DIR` From b8ec32a5d3d1aa63610c29df607863f0390edefb Mon Sep 17 00:00:00 2001 From: Surya Duraivenkatesh Date: Sat, 27 Jun 2026 00:15:44 -0700 Subject: [PATCH 12/13] Add Tillicum 25k SFT 10k OPD chain --- .../qwen3_8b_opd_tillicum/01_prepare_env.sh | 12 ++ .../02_prepare_data_25k_10k.sbatch | 114 ++++++++++++ .../04_run_sft_100k_8xh200.sbatch | 33 +++- .../05_run_opd_50k_8xh200.sbatch | 47 ++++- .../06_eval_math500_greedy_1x.sbatch | 90 +++++++-- examples/qwen3_8b_opd_tillicum/README.md | 44 ++--- .../qwen3_8b_opd_tillicum/checkpoint_utils.sh | 130 +++++++++++++ .../qwen3_8b_opd_tillicum/container_exec.sh | 25 +++ examples/qwen3_8b_opd_tillicum/env.sh | 47 +++-- .../run_all_dry_check.sh | 4 + .../submit_25k_10k_chain.sh | 77 ++++++++ .../qwen3_8b_opd_tillicum/summarize_eval.py | 172 +++++++++++++++++- 12 files changed, 719 insertions(+), 76 deletions(-) create mode 100755 examples/qwen3_8b_opd_tillicum/02_prepare_data_25k_10k.sbatch create mode 100755 examples/qwen3_8b_opd_tillicum/checkpoint_utils.sh create mode 100755 examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh diff --git a/examples/qwen3_8b_opd_tillicum/01_prepare_env.sh b/examples/qwen3_8b_opd_tillicum/01_prepare_env.sh 
index a9a880a7d..c73e5d475 100755 --- a/examples/qwen3_8b_opd_tillicum/01_prepare_env.sh +++ b/examples/qwen3_8b_opd_tillicum/01_prepare_env.sh @@ -21,10 +21,16 @@ mkdir -p \ "${CONTAINER_HOME}" \ "${SFT_SAVE_DIR}" \ "${OPD_SAVE_DIR}" \ + "${SFT_HF_SNAPSHOT_DIR}" \ + "${OPD_HF_SNAPSHOT_DIR}" \ "${SFT_DETAILS_DIR}" \ "${OPD_ROLLOUT_LOG_DIR}" \ "${TEACHER_LOG_DIR}" \ "${EVAL_OUTPUT_DIR}" \ + "${BASE_EVAL_OUTPUT_DIR}" \ + "${SFT_EVAL_OUTPUT_DIR}" \ + "${OPD_EVAL_OUTPUT_DIR}" \ + "${CHECKPOINT_REPORT_DIR}" \ "${SLURM_LOG_DIR}" cat </dev/null 2>&1 && pwd)" +fi +source "${SCRIPT_DIR}/env.sh" + +mkdir -p "${SLURM_LOG_DIR}" "${DATA_ROOT}" +if [[ -n "${SLURM_JOB_ID:-}" ]]; then + export RAY_TMPDIR="/tmp/${USER:-suryadv}/slime-ray-${SLURM_JOB_ID}" + mkdir -p "${RAY_TMPDIR}" + exec >"${SLURM_LOG_DIR}/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" 2>&1 +fi + +echo "Preparing OpenThoughts3 math data at $(date)" +echo "SFT rows: ${SFT_SIZE} -> ${SFT_PARQUET}" +echo "OPD rows: ${OPD_SIZE} -> ${OPD_JSONL}" +echo "Metadata: ${SPLIT_METADATA}" + +"${SCRIPT_DIR}/container_exec.sh" bash -lc ' +set -euo pipefail +cd "${SLIME_REPO_ROOT}" + +needs_prepare=0 +for path in "${SFT_PARQUET}" "${OPD_JSONL}" "${SPLIT_METADATA}"; do + if [[ ! -s "${path}" ]]; then + needs_prepare=1 + fi +done + +if [[ "${needs_prepare}" == "0" ]]; then + if ! python3 - < int: + with path.open("rb") as f: + return sum(1 for _ in f) + +metadata = json.loads(metadata_path.read_text(encoding="utf-8")) +sft_count = count_jsonl(sft_path) +opd_count = count_jsonl(opd_path) +sft_ids = set(metadata["source_row_ids"]["sft"]) +opd_ids = set(metadata["source_row_ids"]["opd"]) + +if sft_count != sft_expected or opd_count != opd_expected or sft_ids.intersection(opd_ids): + raise SystemExit(1) +PY + then + echo "Existing data files do not match the requested split; regenerating." 
+ needs_prepare=1 + fi +fi + +if [[ "${needs_prepare}" == "1" ]]; then + python3 examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py --force +else + echo "Data files already exist; validating counts." +fi + +python3 - < int: + with path.open("rb") as f: + return sum(1 for _ in f) + +sft_count = count_jsonl(sft_path) +opd_count = count_jsonl(opd_path) +if sft_count != sft_expected: + raise SystemExit(f"SFT row count mismatch: {sft_count} != {sft_expected}") +if opd_count != opd_expected: + raise SystemExit(f"OPD row count mismatch: {opd_count} != {opd_expected}") + +metadata = json.loads(metadata_path.read_text(encoding="utf-8")) +sft_ids = set(metadata["source_row_ids"]["sft"]) +opd_ids = set(metadata["source_row_ids"]["opd"]) +overlap = sft_ids.intersection(opd_ids) +if overlap: + raise SystemExit(f"SFT/OPD metadata overlap detected: {sorted(overlap)[:10]}") + +print(f"DATA_PREP_OK sft_rows={sft_count} opd_rows={opd_count} metadata={metadata_path}") +PY +' + +echo "Finished data prep at $(date)" diff --git a/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch b/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch index 6dbc76879..77368b0b8 100755 --- a/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch +++ b/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch @@ -1,11 +1,11 @@ #!/bin/bash -#SBATCH --job-name=slime-qwen3-sft100k +#SBATCH --job-name=slime-qwen3-sft25k #SBATCH --nodes=1 #SBATCH --ntasks=1 #SBATCH --cpus-per-task=64 #SBATCH --gres=gpu:h200:8 #SBATCH --mem=0 -#SBATCH --time=04:00:00 +#SBATCH --time=03:00:00 #SBATCH --mail-user=suryadv@cs.washington.edu #SBATCH --mail-type=END,FAIL @@ -17,8 +17,9 @@ else SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" fi source "${SCRIPT_DIR}/env.sh" +source "${SCRIPT_DIR}/checkpoint_utils.sh" -mkdir -p "${SLURM_LOG_DIR}" "${SFT_SAVE_DIR}" "${SFT_DETAILS_DIR}" +mkdir -p "${SLURM_LOG_DIR}" "${SFT_SAVE_DIR}" "${SFT_HF_SNAPSHOT_DIR}" 
"${SFT_DETAILS_DIR}" "${CHECKPOINT_REPORT_DIR}" if [[ -n "${SLURM_JOB_ID:-}" ]]; then export RAY_TMPDIR="/tmp/${USER:-suryadv}/slime-ray-${SLURM_JOB_ID}" mkdir -p "${RAY_TMPDIR}" @@ -35,6 +36,14 @@ done echo "Starting SFT on $(hostname) at $(date)" echo "SFT data: ${SFT_PARQUET}" echo "Save dir: ${SFT_SAVE_DIR}" +echo "HF eval snapshots: ${SFT_HF_SNAPSHOT_DIR}" +echo "Expected final SFT rollout id: ${SFT_FINAL_ROLLOUT_ID}" + +checkpoint_start_pruner sft "${SFT_SAVE_DIR}" "${CHECKPOINT_PRUNE_INTERVAL_SECONDS}" "${CHECKPOINT_REPORT_DIR}" +cleanup() { + checkpoint_stop_pruner "${CHECKPOINT_PRUNER_PID:-}" +} +trap cleanup EXIT "${SCRIPT_DIR}/container_exec.sh" bash -lc ' set -euo pipefail @@ -58,7 +67,7 @@ CKPT_ARGS=( --load "${SFT_SAVE_DIR}" --save "${SFT_SAVE_DIR}" --save-interval "${SFT_SAVE_INTERVAL}" - --no-save-optim + --save-hf "${SFT_HF_SNAPSHOT_TEMPLATE}" ) SFT_ARGS=( @@ -103,7 +112,7 @@ OPTIMIZER_ARGS=( WANDB_ARGS=() if [[ "${WANDB_MODE}" != "disabled" ]]; then - WANDB_ARGS+=(--wandb-project slime-tillicum --wandb-group qwen3-8b-sft100k) + WANDB_ARGS+=(--wandb-project slime-tillicum --wandb-group qwen3-8b-sft25k) fi MISC_ARGS=( @@ -146,4 +155,18 @@ ray job submit --address="http://127.0.0.1:8265" \ "${MISC_ARGS[@]}" ' +checkpoint_stop_pruner "${CHECKPOINT_PRUNER_PID:-}" +checkpoint_prune_old sft "${SFT_SAVE_DIR}" +checkpoint_report sft "${SFT_SAVE_DIR}" "${CHECKPOINT_REPORT_DIR}" +checkpoint_verify_single_latest sft "${SFT_SAVE_DIR}" + +if [[ -f "${SPLIT_METADATA}" ]]; then + cp -f "${SPLIT_METADATA}" "${SFT_SAVE_DIR}/split_metadata.json" +fi + +if [[ ! 
-d "${SFT_FINAL_HF_DIR}" ]]; then + echo "Missing final SFT HF eval snapshot: ${SFT_FINAL_HF_DIR}" >&2 + exit 1 +fi + echo "Finished SFT at $(date)" diff --git a/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch b/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch index 165105a5f..788ecc67f 100755 --- a/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch +++ b/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch @@ -1,11 +1,11 @@ #!/bin/bash -#SBATCH --job-name=slime-qwen3-opd50k +#SBATCH --job-name=slime-qwen3-opd10k #SBATCH --nodes=1 #SBATCH --ntasks=1 #SBATCH --cpus-per-task=64 #SBATCH --gres=gpu:h200:8 #SBATCH --mem=0 -#SBATCH --time=03:00:00 +#SBATCH --time=05:00:00 #SBATCH --mail-user=suryadv@cs.washington.edu #SBATCH --mail-type=END,FAIL @@ -17,8 +17,9 @@ else SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" fi source "${SCRIPT_DIR}/env.sh" +source "${SCRIPT_DIR}/checkpoint_utils.sh" -mkdir -p "${SLURM_LOG_DIR}" "${OPD_SAVE_DIR}" "${OPD_ROLLOUT_LOG_DIR}" "${TEACHER_LOG_DIR}" +mkdir -p "${SLURM_LOG_DIR}" "${OPD_SAVE_DIR}" "${OPD_HF_SNAPSHOT_DIR}" "${OPD_ROLLOUT_LOG_DIR}" "${TEACHER_LOG_DIR}" "${CHECKPOINT_REPORT_DIR}" if [[ -n "${SLURM_JOB_ID:-}" ]]; then export RAY_TMPDIR="/tmp/${USER:-suryadv}/slime-ray-${SLURM_JOB_ID}" mkdir -p "${RAY_TMPDIR}" @@ -30,7 +31,7 @@ for required_path in \ "${STUDENT_HF_DIR}" \ "${STUDENT_TORCH_DIST_DIR}/latest_checkpointed_iteration.txt" \ "${TEACHER_HF_DIR}" \ - "${SFT_SAVE_DIR}/latest_checkpointed_iteration.txt"; do + "${SFT_FINAL_HF_DIR}"; do if [[ ! 
-e "${required_path}" ]]; then echo "Missing required path: ${required_path}" >&2 exit 1 @@ -42,6 +43,14 @@ echo "OPD data: ${OPD_JSONL}" echo "Teacher: ${TEACHER_HF_DIR} on physical GPU ${OPD_TEACHER_GPU}" echo "Ray GPUs: 0-$((OPD_RAY_GPUS - 1))" echo "Save dir: ${OPD_SAVE_DIR}" +echo "HF eval snapshots: ${OPD_HF_SNAPSHOT_DIR}" +echo "Expected final OPD rollout id: ${OPD_FINAL_ROLLOUT_ID}" + +checkpoint_start_pruner opd "${OPD_SAVE_DIR}" "${CHECKPOINT_PRUNE_INTERVAL_SECONDS}" "${CHECKPOINT_REPORT_DIR}" +cleanup_pruner() { + checkpoint_stop_pruner "${CHECKPOINT_PRUNER_PID:-}" +} +trap cleanup_pruner EXIT "${SCRIPT_DIR}/container_exec.sh" bash -lc ' set -euo pipefail @@ -90,13 +99,22 @@ echo "HAS_NVLINK=${HAS_NVLINK} (detected ${NVLINK_COUNT} NVLink references)" source scripts/models/qwen3-8B.sh +if [[ -f "${OPD_SAVE_DIR}/latest_checkpointed_iteration.txt" ]]; then + OPD_LOAD_DIR="${OPD_SAVE_DIR}" + echo "Resuming OPD from ${OPD_LOAD_DIR}" +else + OPD_LOAD_DIR="${SFT_FINAL_HF_DIR}" + echo "Starting OPD from final SFT HF snapshot ${OPD_LOAD_DIR}" +fi + CKPT_ARGS=( + --megatron-to-hf-mode bridge --hf-checkpoint "${STUDENT_HF_DIR}" --ref-load "${STUDENT_TORCH_DIST_DIR}" - --load "${SFT_SAVE_DIR}" + --load "${OPD_LOAD_DIR}" --save "${OPD_SAVE_DIR}" --save-interval "${OPD_SAVE_INTERVAL}" - --no-save-optim + --save-hf "${OPD_HF_SNAPSHOT_TEMPLATE}" ) ROLLOUT_ARGS=( @@ -157,7 +175,7 @@ OPTIMIZER_ARGS=( WANDB_ARGS=() if [[ "${WANDB_MODE}" != "disabled" ]]; then - WANDB_ARGS+=(--wandb-project slime-tillicum --wandb-group qwen3-8b-opd50k) + WANDB_ARGS+=(--wandb-project slime-tillicum --wandb-group qwen3-8b-opd10k) fi SGLANG_ARGS=( @@ -184,7 +202,6 @@ RUNTIME_ENV_JSON="{ \"PYTHONPATH\": \"/root/Megatron-LM:${SLIME_REPO_ROOT}\", \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", - \"PYTORCH_CUDA_ALLOC_CONF\": \"expandable_segments:True\", \"WANDB_MODE\": \"${WANDB_MODE}\", \"HF_HOME\": \"${HF_HOME}\" } @@ -208,4 +225,18 @@ 
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 ray job submit --address="http://127.0.0.1:82 "${RM_ARGS[@]}" ' +checkpoint_stop_pruner "${CHECKPOINT_PRUNER_PID:-}" +checkpoint_prune_old opd "${OPD_SAVE_DIR}" +checkpoint_report opd "${OPD_SAVE_DIR}" "${CHECKPOINT_REPORT_DIR}" +checkpoint_verify_single_latest opd "${OPD_SAVE_DIR}" + +if [[ -f "${SPLIT_METADATA}" ]]; then + cp -f "${SPLIT_METADATA}" "${OPD_SAVE_DIR}/split_metadata.json" +fi + +if [[ ! -d "${OPD_FINAL_HF_DIR}" ]]; then + echo "Missing final OPD HF eval snapshot: ${OPD_FINAL_HF_DIR}" >&2 + exit 1 +fi + echo "Finished OPD at $(date)" diff --git a/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch b/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch index 8835f3e46..9aa783270 100755 --- a/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +++ b/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch @@ -5,7 +5,7 @@ #SBATCH --cpus-per-task=64 #SBATCH --gres=gpu:h200:8 #SBATCH --mem=0 -#SBATCH --time=01:00:00 +#SBATCH --time=03:00:00 #SBATCH --mail-user=suryadv@cs.washington.edu #SBATCH --mail-type=END,FAIL @@ -25,18 +25,8 @@ if [[ -n "${SLURM_JOB_ID:-}" ]]; then exec >"${SLURM_LOG_DIR}/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" 2>&1 fi -for required_path in \ - "${STUDENT_HF_DIR}" \ - "${STUDENT_TORCH_DIST_DIR}/latest_checkpointed_iteration.txt" \ - "${SFT_SAVE_DIR}/latest_checkpointed_iteration.txt" \ - "${OPD_SAVE_DIR}/latest_checkpointed_iteration.txt"; do - if [[ ! -e "${required_path}" ]]; then - echo "Missing required path: ${required_path}" >&2 - exit 1 - fi -done - echo "Starting MATH-500 greedy eval on $(hostname) at $(date)" +echo "Eval targets: ${EVAL_TARGETS}" echo "Eval output: ${EVAL_OUTPUT_DIR}" "${SCRIPT_DIR}/container_exec.sh" bash -lc ' @@ -87,10 +77,16 @@ PY run_eval() { local stage="$1" local load_dir="$2" + local train_samples="$3" local stage_dir="${EVAL_OUTPUT_DIR}/${stage}" mkdir -p "${stage_dir}" - echo "Evaluating ${stage} from ${load_dir}" + if [[ ! 
-e "${load_dir}" ]]; then + echo "Missing eval load dir for ${stage}: ${load_dir}" >&2 + exit 1 + fi + + echo "Evaluating ${stage} from ${load_dir} train_samples=${train_samples}" ray stop --force >/dev/null 2>&1 || true NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o "NV[0-9][0-9]*" | wc -l || true) @@ -103,7 +99,8 @@ run_eval() { source scripts/models/qwen3-8B.sh CKPT_ARGS=( - --hf-checkpoint "${STUDENT_HF_DIR}" + --megatron-to-hf-mode bridge + --hf-checkpoint "${load_dir}" --ref-load "${STUDENT_TORCH_DIST_DIR}" --load "${load_dir}" --save "${stage_dir}/unused_save" @@ -112,7 +109,7 @@ run_eval() { ) ROLLOUT_ARGS=( - --prompt-data "${OPD_JSONL}" + --prompt-data "${MATH500_JSONL}" --input-key prompt --apply-chat-template --num-rollout 0 @@ -170,6 +167,7 @@ run_eval() { SGLANG_ARGS=( --rollout-num-gpus-per-engine 1 --sglang-mem-fraction-static 0.7 + --sglang-server-concurrency "${EVAL_SGLANG_SERVER_CONCURRENCY}" ) MISC_ARGS=( @@ -190,7 +188,6 @@ run_eval() { \"PYTHONPATH\": \"/root/Megatron-LM:${SLIME_REPO_ROOT}\", \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", - \"PYTORCH_CUDA_ALLOC_CONF\": \"expandable_segments:True\", \"WANDB_MODE\": \"${WANDB_MODE}\", \"HF_HOME\": \"${HF_HOME}\" } @@ -199,6 +196,7 @@ run_eval() { ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- python3 train.py \ + --debug-rollout-only \ --actor-num-nodes 1 \ --actor-num-gpus-per-node 8 \ --rollout-num-gpus "${EVAL_ROLLOUT_NUM_GPUS}" \ @@ -219,14 +217,66 @@ run_eval() { --stage "${stage}" \ --debug-file "${stage_dir}/debug_eval_0.pt" \ --out-json "${stage_dir}/summary.json" \ - --max-response-len "${EVAL_MAX_RESPONSE_LEN}" + --max-response-len "${EVAL_MAX_RESPONSE_LEN}" \ + --train-samples "${train_samples}" +} + +hf_snapshot_dir() { + local root="$1" + local rollout_id="$2" + printf "%s/iter_%07d" "${root}" "${rollout_id}" +} + +sample_count_for_rollout() { + local rollout_id="$1" + local batch_size="$2" 
+ local total_size="$3" + local count=$(((rollout_id + 1) * batch_size)) + if [[ "${count}" -gt "${total_size}" ]]; then + count="${total_size}" + fi + printf "%s" "${count}" +} + +run_sft_curve() { + local rid samples dir stage + for rid in ${SFT_MILESTONE_ROLLOUT_IDS}; do + samples="$(sample_count_for_rollout "${rid}" "${SFT_ROLLOUT_BATCH_SIZE}" "${SFT_SIZE}")" + dir="$(hf_snapshot_dir "${SFT_HF_SNAPSHOT_DIR}" "${rid}")" + stage="$(printf "sft_%06d" "${samples}")" + run_eval "${stage}" "${dir}" "${samples}" + done +} + +run_opd_curve() { + local rid samples dir stage + for rid in ${OPD_MILESTONE_ROLLOUT_IDS}; do + samples="$(sample_count_for_rollout "${rid}" "${OPD_ROLLOUT_BATCH_SIZE}" "${OPD_SIZE}")" + dir="$(hf_snapshot_dir "${OPD_HF_SNAPSHOT_DIR}" "${rid}")" + stage="$(printf "opd_%06d" "${samples}")" + run_eval "${stage}" "${dir}" "${samples}" + done } trap "ray stop --force >/dev/null 2>&1 || true" EXIT -run_eval base "${STUDENT_TORCH_DIST_DIR}" -run_eval sft "${SFT_SAVE_DIR}" -run_eval opd "${OPD_SAVE_DIR}" +for target in ${EVAL_TARGETS}; do + case "${target}" in + base) + run_eval base "${STUDENT_HF_DIR}" 0 + ;; + sft) + run_sft_curve + ;; + opd) + run_opd_curve + ;; + *) + echo "Unknown EVAL_TARGETS entry: ${target}" >&2 + exit 2 + ;; + esac +done python3 examples/qwen3_8b_opd_tillicum/summarize_eval.py \ --aggregate-dir "${EVAL_OUTPUT_DIR}" \ diff --git a/examples/qwen3_8b_opd_tillicum/README.md b/examples/qwen3_8b_opd_tillicum/README.md index 8a9695c9d..f0d34fbf7 100644 --- a/examples/qwen3_8b_opd_tillicum/README.md +++ b/examples/qwen3_8b_opd_tillicum/README.md @@ -69,21 +69,13 @@ Run the setup steps manually, in this order: ```bash bash examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh bash examples/qwen3_8b_opd_tillicum/01_prepare_env.sh -bash examples/qwen3_8b_opd_tillicum/container_exec.sh \ - python examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py +bash examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh ``` Then 
launch only after explicit approval: ```bash -jid0=$(sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" \ - examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch | awk '{print $4}') -jid1=$(sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" --dependency=afterok:$jid0 \ - examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch | awk '{print $4}') -jid2=$(sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" --dependency=afterok:$jid1 \ - examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch | awk '{print $4}') -sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" --dependency=afterok:$jid2 \ - examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +bash examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh ``` ## Reproducing On Another Slurm Cluster @@ -133,17 +125,7 @@ Submit the dependency chain, overriding the Tillicum `h200` gres embedded in the sbatch files: ```bash -jid0=$(sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" --gres "$GPU_GRES" \ - examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch | awk '{print $4}') -jid1=$(sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" --gres "$GPU_GRES" \ - --dependency=afterok:$jid0 \ - examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch | awk '{print $4}') -jid2=$(sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" --gres "$GPU_GRES" \ - --dependency=afterok:$jid1 \ - examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch | awk '{print $4}') -sbatch -A "$ACCOUNT" -p "$PARTITION" --qos "$QOS" --gres "$GPU_GRES" \ - --dependency=afterok:$jid2 \ - examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +GPU_GRES=gpu:h200:8 bash examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh ``` If your cluster uses `--gpus-per-node` instead of `--gres`, replace the @@ -162,9 +144,13 @@ Apptainer-specific layer. 
- Student HF snapshot: `$STUDENT_HF_DIR` - Teacher HF snapshot: `$TEACHER_HF_DIR` - Student Megatron torch_dist: `$STUDENT_TORCH_DIST_DIR` -- SFT checkpoint: `$SFT_SAVE_DIR` -- OPD checkpoint: `$OPD_SAVE_DIR` -- Eval summaries: `$EVAL_OUTPUT_DIR` +- SFT full optimizer checkpoint: `$SFT_SAVE_DIR` +- OPD full optimizer checkpoint: `$OPD_SAVE_DIR` +- SFT model-only eval snapshots: `$SFT_HF_SNAPSHOT_DIR` +- OPD model-only eval snapshots: `$OPD_HF_SNAPSHOT_DIR` +- Eval summaries and curves: `$BASE_EVAL_OUTPUT_DIR`, `$SFT_EVAL_OUTPUT_DIR`, + `$OPD_EVAL_OUTPUT_DIR` +- Checkpoint storage reports: `$CHECKPOINT_REPORT_DIR` ## Slurm resources @@ -174,10 +160,10 @@ passed at submit time from the environment variables above. The intended wall-clock budget after model/data/container preparation is: -- SFT 100k: 2-4 hours. -- OPD 50k: 2-3 hours. -- MATH-500 greedy eval once for base, SFT, and OPD: <=1 hour. +- SFT 25k: 3 hours. +- OPD 10k: 5 hours. +- MATH-500 greedy eval: base 1 hour, SFT curve 3 hours, OPD curve 3 hours. The main runtime risk is the Qwen3-32B teacher logprob server throughput during -OPD. If dry measurements show the full OPD run will exceed the budget, the -first reduced run to try is 100k SFT plus 25k OPD. +OPD. The current conservative chain evaluates base first, then SFT milestones, +then OPD milestones so SFT results are available before OPD starts. diff --git a/examples/qwen3_8b_opd_tillicum/checkpoint_utils.sh b/examples/qwen3_8b_opd_tillicum/checkpoint_utils.sh new file mode 100755 index 000000000..f5c35c61c --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/checkpoint_utils.sh @@ -0,0 +1,130 @@ +#!/usr/bin/env bash + +checkpoint_latest_iteration() { + local save_dir="$1" + local tracker="${save_dir}/latest_checkpointed_iteration.txt" + if [[ ! 
-f "${tracker}" ]]; then + return 1 + fi + tr -d '[:space:]' <"${tracker}" +} + +checkpoint_iter_name() { + local iteration="$1" + printf "iter_%07d" "${iteration}" +} + +checkpoint_latest_dir() { + local save_dir="$1" + local iteration + iteration="$(checkpoint_latest_iteration "${save_dir}")" || return 1 + printf "%s/%s" "${save_dir}" "$(checkpoint_iter_name "${iteration}")" +} + +checkpoint_prune_old() { + local stage="$1" + local save_dir="$2" + local latest_dir + latest_dir="$(checkpoint_latest_dir "${save_dir}")" || return 0 + + if [[ ! -d "${latest_dir}" ]]; then + echo "CHECKPOINT_PRUNE_WAIT stage=${stage} latest_dir_missing=${latest_dir}" + return 0 + fi + + local dir + find "${save_dir}" -maxdepth 1 -type d -name 'iter_*' -print | while IFS= read -r dir; do + if [[ "${dir}" != "${latest_dir}" ]]; then + echo "CHECKPOINT_PRUNE_REMOVE stage=${stage} dir=${dir}" + rm -rf -- "${dir}" + fi + done +} + +checkpoint_report() { + local stage="$1" + local save_dir="$2" + local report_dir="$3" + mkdir -p "${report_dir}" + + local iteration latest_dir bytes report + iteration="$(checkpoint_latest_iteration "${save_dir}")" + latest_dir="${save_dir}/$(checkpoint_iter_name "${iteration}")" + if [[ ! -d "${latest_dir}" ]]; then + echo "Missing latest checkpoint dir: ${latest_dir}" >&2 + return 1 + fi + + bytes="$(du -sb "${latest_dir}" | awk '{print $1}')" + report="${report_dir}/${stage}_checkpoint_storage.tsv" + if [[ ! 
-f "${report}" ]]; then + printf "stage\titeration\tcheckpoint_dir\tbytes\testimated_full_optim_bytes\n" >"${report}" + fi + printf "%s\t%s\t%s\t%s\t%s\n" \ + "${stage}" \ + "${iteration}" \ + "${latest_dir}" \ + "${bytes}" \ + "${ESTIMATED_FULL_OPTIM_CKPT_BYTES:-}" >>"${report}" + + echo "CHECKPOINT_SIZE_BYTES stage=${stage} iter=${iteration} bytes=${bytes} dir=${latest_dir}" +} + +checkpoint_verify_single_latest() { + local stage="$1" + local save_dir="$2" + local latest_dir + latest_dir="$(checkpoint_latest_dir "${save_dir}")" + + local count + count="$(find "${save_dir}" -maxdepth 1 -type d -name 'iter_*' | wc -l)" + if [[ "${count}" -ne 1 ]]; then + echo "CHECKPOINT_PRUNE_FAIL stage=${stage} expected_count=1 actual_count=${count}" >&2 + find "${save_dir}" -maxdepth 1 -type d -name 'iter_*' -print >&2 + return 1 + fi + + if [[ ! -d "${latest_dir}" ]]; then + echo "CHECKPOINT_PRUNE_FAIL stage=${stage} missing_latest=${latest_dir}" >&2 + return 1 + fi + + echo "CHECKPOINT_PRUNE_OK stage=${stage} remaining_iter=$(basename "${latest_dir}")" +} + +checkpoint_start_pruner() { + local stage="$1" + local save_dir="$2" + local interval_seconds="$3" + local report_dir="${4:-}" + + ( + set +e + last_reported_iteration="" + while true; do + checkpoint_prune_old "${stage}" "${save_dir}" + if [[ -n "${report_dir}" ]]; then + iteration="$(checkpoint_latest_iteration "${save_dir}" 2>/dev/null)" + if [[ -n "${iteration}" && "${iteration}" != "${last_reported_iteration}" ]]; then + latest_dir="${save_dir}/$(checkpoint_iter_name "${iteration}")" + if [[ -d "${latest_dir}" ]]; then + checkpoint_report "${stage}" "${save_dir}" "${report_dir}" + last_reported_iteration="${iteration}" + fi + fi + fi + sleep "${interval_seconds}" + done + ) & + CHECKPOINT_PRUNER_PID="$!" 
+ export CHECKPOINT_PRUNER_PID + echo "CHECKPOINT_PRUNER_STARTED stage=${stage} pid=${CHECKPOINT_PRUNER_PID}" +} + +checkpoint_stop_pruner() { + local pid="${1:-}" + if [[ -n "${pid}" ]]; then + kill "${pid}" >/dev/null 2>&1 || true + wait "${pid}" >/dev/null 2>&1 || true + fi +} diff --git a/examples/qwen3_8b_opd_tillicum/container_exec.sh b/examples/qwen3_8b_opd_tillicum/container_exec.sh index 2d27e44e8..d163b2b57 100755 --- a/examples/qwen3_8b_opd_tillicum/container_exec.sh +++ b/examples/qwen3_8b_opd_tillicum/container_exec.sh @@ -78,7 +78,13 @@ PASS_ENV=( STUDENT_HF_REPO TEACHER_HF_REPO OT3_DATASET + OT3_SPLIT MATH500_DATASET + SFT_SIZE + OPD_SIZE + DATA_SEED + DATA_MATH_FIELD + DATA_MATH_VALUE STUDENT_HF_DIR TEACHER_HF_DIR STUDENT_TORCH_DIST_DIR @@ -89,17 +95,32 @@ PASS_ENV=( MATH500_CONFIG SFT_SAVE_DIR OPD_SAVE_DIR + SFT_HF_SNAPSHOT_DIR + OPD_HF_SNAPSHOT_DIR + SFT_HF_SNAPSHOT_TEMPLATE + OPD_HF_SNAPSHOT_TEMPLATE + SFT_FINAL_HF_DIR + OPD_FINAL_HF_DIR SFT_DETAILS_DIR OPD_ROLLOUT_LOG_DIR TEACHER_LOG_DIR EVAL_OUTPUT_DIR + BASE_EVAL_OUTPUT_DIR + SFT_EVAL_OUTPUT_DIR + OPD_EVAL_OUTPUT_DIR + CHECKPOINT_REPORT_DIR SLURM_LOG_DIR SFT_NUM_EPOCH + SFT_NUM_ROLLOUT + SFT_FINAL_ROLLOUT_ID + SFT_MILESTONE_ROLLOUT_IDS SFT_ROLLOUT_BATCH_SIZE SFT_GLOBAL_BATCH_SIZE SFT_MAX_TOKENS_PER_GPU SFT_SAVE_INTERVAL OPD_NUM_ROLLOUT + OPD_FINAL_ROLLOUT_ID + OPD_MILESTONE_ROLLOUT_IDS OPD_ROLLOUT_BATCH_SIZE OPD_GLOBAL_BATCH_SIZE OPD_N_SAMPLES_PER_PROMPT @@ -112,10 +133,14 @@ PASS_ENV=( OPD_TEACHER_GPU OPD_TEACHER_PORT OPD_TEACHER_MEM_FRACTION + CHECKPOINT_PRUNE_INTERVAL_SECONDS + ESTIMATED_FULL_OPTIM_CKPT_BYTES EVAL_MAX_RESPONSE_LEN EVAL_ROLLOUT_BATCH_SIZE EVAL_ROLLOUT_NUM_GPUS EVAL_NUM_REPEATS + EVAL_SGLANG_SERVER_CONCURRENCY + EVAL_TARGETS PYTHONUNBUFFERED TOKENIZERS_PARALLELISM NCCL_DEBUG diff --git a/examples/qwen3_8b_opd_tillicum/env.sh b/examples/qwen3_8b_opd_tillicum/env.sh index 68808683d..c23ec90a3 100755 --- a/examples/qwen3_8b_opd_tillicum/env.sh +++ 
b/examples/qwen3_8b_opd_tillicum/env.sh @@ -57,8 +57,8 @@ export STUDENT_HF_DIR="${STUDENT_HF_DIR:-${MODEL_ROOT}/Qwen3-8B-Base}" export TEACHER_HF_DIR="${TEACHER_HF_DIR:-${MODEL_ROOT}/Qwen3-32B}" export STUDENT_TORCH_DIST_DIR="${STUDENT_TORCH_DIST_DIR:-${MODEL_ROOT}/Qwen3-8B-Base_torch_dist}" -export SFT_SIZE="${SFT_SIZE:-100000}" -export OPD_SIZE="${OPD_SIZE:-50000}" +export SFT_SIZE="${SFT_SIZE:-25000}" +export OPD_SIZE="${OPD_SIZE:-10000}" export DATA_SEED="${DATA_SEED:-1234}" export DATA_MATH_FIELD="${DATA_MATH_FIELD:-domain}" export DATA_MATH_VALUE="${DATA_MATH_VALUE:-math}" @@ -66,42 +66,65 @@ export OT3_SPLIT="${OT3_SPLIT:-train}" export SFT_PARQUET="${SFT_PARQUET:-${DATA_ROOT}/openthoughts3_math_sft_${SFT_SIZE}.jsonl}" export OPD_JSONL="${OPD_JSONL:-${DATA_ROOT}/openthoughts3_math_opd_${OPD_SIZE}.jsonl}" -export SPLIT_METADATA="${SPLIT_METADATA:-${DATA_ROOT}/openthoughts3_math_split_metadata.json}" +export SPLIT_METADATA="${SPLIT_METADATA:-${DATA_ROOT}/openthoughts3_math_sft_${SFT_SIZE}_opd_${OPD_SIZE}_metadata.json}" export MATH500_JSONL="${MATH500_JSONL:-${DATA_ROOT}/math500_deepscaler.jsonl}" export MATH500_CONFIG="${MATH500_CONFIG:-${DATA_ROOT}/math500_eval.yaml}" -export SFT_SAVE_DIR="${SFT_SAVE_DIR:-${OUTPUT_ROOT}/qwen3_8b_sft_100k}" -export OPD_SAVE_DIR="${OPD_SAVE_DIR:-${OUTPUT_ROOT}/qwen3_8b_sft_100k_opd_50k}" -export SFT_DETAILS_DIR="${SFT_DETAILS_DIR:-${OUTPUT_ROOT}/sft_details}" -export OPD_ROLLOUT_LOG_DIR="${OPD_ROLLOUT_LOG_DIR:-${OUTPUT_ROOT}/opd_rollout_logs}" +export SFT_SAVE_DIR="${SFT_SAVE_DIR:-${OUTPUT_ROOT}/qwen3_8b_sft_25k_full_optim}" +export OPD_SAVE_DIR="${OPD_SAVE_DIR:-${OUTPUT_ROOT}/qwen3_8b_sft_25k_opd_10k_full_optim}" +export SFT_HF_SNAPSHOT_DIR="${SFT_HF_SNAPSHOT_DIR:-${OUTPUT_ROOT}/qwen3_8b_sft_25k_eval_snapshots}" +export OPD_HF_SNAPSHOT_DIR="${OPD_HF_SNAPSHOT_DIR:-${OUTPUT_ROOT}/qwen3_8b_sft_25k_opd_10k_eval_snapshots}" +export SFT_DETAILS_DIR="${SFT_DETAILS_DIR:-${OUTPUT_ROOT}/sft_25k_details}" +export 
OPD_ROLLOUT_LOG_DIR="${OPD_ROLLOUT_LOG_DIR:-${OUTPUT_ROOT}/opd_10k_rollout_logs}" export TEACHER_LOG_DIR="${TEACHER_LOG_DIR:-${OUTPUT_ROOT}/teacher_logs}" -export EVAL_OUTPUT_DIR="${EVAL_OUTPUT_DIR:-${OUTPUT_ROOT}/math500_eval_1x}" +export EVAL_OUTPUT_DIR="${EVAL_OUTPUT_DIR:-${OUTPUT_ROOT}/math500_eval_25k_10k}" +export BASE_EVAL_OUTPUT_DIR="${BASE_EVAL_OUTPUT_DIR:-${OUTPUT_ROOT}/math500_eval_base_25k_10k}" +export SFT_EVAL_OUTPUT_DIR="${SFT_EVAL_OUTPUT_DIR:-${OUTPUT_ROOT}/math500_eval_sft_25k_curve}" +export OPD_EVAL_OUTPUT_DIR="${OPD_EVAL_OUTPUT_DIR:-${OUTPUT_ROOT}/math500_eval_opd_10k_curve}" +export CHECKPOINT_REPORT_DIR="${CHECKPOINT_REPORT_DIR:-${OUTPUT_ROOT}/checkpoint_reports_25k_10k}" export SLURM_LOG_DIR="${SLURM_LOG_DIR:-${OUTPUT_ROOT}/slurm_logs}" export SFT_NUM_EPOCH="${SFT_NUM_EPOCH:-1}" export SFT_ROLLOUT_BATCH_SIZE="${SFT_ROLLOUT_BATCH_SIZE:-256}" export SFT_GLOBAL_BATCH_SIZE="${SFT_GLOBAL_BATCH_SIZE:-256}" export SFT_MAX_TOKENS_PER_GPU="${SFT_MAX_TOKENS_PER_GPU:-16384}" -export SFT_SAVE_INTERVAL="${SFT_SAVE_INTERVAL:-50}" +export SFT_SAVE_INTERVAL="${SFT_SAVE_INTERVAL:-20}" +export SFT_NUM_ROLLOUT="${SFT_NUM_ROLLOUT:-$(((SFT_SIZE + SFT_ROLLOUT_BATCH_SIZE - 1) / SFT_ROLLOUT_BATCH_SIZE))}" +export SFT_FINAL_ROLLOUT_ID="${SFT_FINAL_ROLLOUT_ID:-$((SFT_NUM_ROLLOUT - 1))}" +export SFT_MILESTONE_ROLLOUT_IDS="${SFT_MILESTONE_ROLLOUT_IDS:-19 39 59 79 ${SFT_FINAL_ROLLOUT_ID}}" +export SFT_HF_SNAPSHOT_TEMPLATE="${SFT_HF_SNAPSHOT_TEMPLATE:-${SFT_HF_SNAPSHOT_DIR}/iter_{rollout_id:07d}}" +printf -v _SFT_FINAL_ROLLOUT_TAG "iter_%07d" "${SFT_FINAL_ROLLOUT_ID}" +export SFT_FINAL_HF_DIR="${SFT_FINAL_HF_DIR:-${SFT_HF_SNAPSHOT_DIR}/${_SFT_FINAL_ROLLOUT_TAG}}" +unset _SFT_FINAL_ROLLOUT_TAG -export OPD_NUM_ROLLOUT="${OPD_NUM_ROLLOUT:-391}" export OPD_ROLLOUT_BATCH_SIZE="${OPD_ROLLOUT_BATCH_SIZE:-128}" export OPD_GLOBAL_BATCH_SIZE="${OPD_GLOBAL_BATCH_SIZE:-128}" export OPD_N_SAMPLES_PER_PROMPT="${OPD_N_SAMPLES_PER_PROMPT:-1}" +export 
OPD_NUM_ROLLOUT="${OPD_NUM_ROLLOUT:-$(((OPD_SIZE + OPD_ROLLOUT_BATCH_SIZE - 1) / OPD_ROLLOUT_BATCH_SIZE))}" +export OPD_FINAL_ROLLOUT_ID="${OPD_FINAL_ROLLOUT_ID:-$((OPD_NUM_ROLLOUT - 1))}" +export OPD_MILESTONE_ROLLOUT_IDS="${OPD_MILESTONE_ROLLOUT_IDS:-15 31 47 63 ${OPD_FINAL_ROLLOUT_ID}}" export OPD_MAX_RESPONSE_LEN="${OPD_MAX_RESPONSE_LEN:-16384}" export OPD_MAX_TOKENS_PER_GPU="${OPD_MAX_TOKENS_PER_GPU:-16384}" -export OPD_SAVE_INTERVAL="${OPD_SAVE_INTERVAL:-50}" +export OPD_SAVE_INTERVAL="${OPD_SAVE_INTERVAL:-16}" +export OPD_HF_SNAPSHOT_TEMPLATE="${OPD_HF_SNAPSHOT_TEMPLATE:-${OPD_HF_SNAPSHOT_DIR}/iter_{rollout_id:07d}}" +printf -v _OPD_FINAL_ROLLOUT_TAG "iter_%07d" "${OPD_FINAL_ROLLOUT_ID}" +export OPD_FINAL_HF_DIR="${OPD_FINAL_HF_DIR:-${OPD_HF_SNAPSHOT_DIR}/${_OPD_FINAL_ROLLOUT_TAG}}" +unset _OPD_FINAL_ROLLOUT_TAG export OPD_ACTOR_GPUS="${OPD_ACTOR_GPUS:-2}" export OPD_ROLLOUT_GPUS="${OPD_ROLLOUT_GPUS:-5}" export OPD_RAY_GPUS="${OPD_RAY_GPUS:-7}" export OPD_TEACHER_GPU="${OPD_TEACHER_GPU:-7}" export OPD_TEACHER_PORT="${OPD_TEACHER_PORT:-13141}" export OPD_TEACHER_MEM_FRACTION="${OPD_TEACHER_MEM_FRACTION:-0.6}" +export CHECKPOINT_PRUNE_INTERVAL_SECONDS="${CHECKPOINT_PRUNE_INTERVAL_SECONDS:-60}" +export ESTIMATED_FULL_OPTIM_CKPT_BYTES="${ESTIMATED_FULL_OPTIM_CKPT_BYTES:-114694164347}" -export EVAL_MAX_RESPONSE_LEN="${EVAL_MAX_RESPONSE_LEN:-32768}" +export EVAL_MAX_RESPONSE_LEN="${EVAL_MAX_RESPONSE_LEN:-31744}" export EVAL_ROLLOUT_BATCH_SIZE="${EVAL_ROLLOUT_BATCH_SIZE:-64}" export EVAL_ROLLOUT_NUM_GPUS="${EVAL_ROLLOUT_NUM_GPUS:-8}" export EVAL_NUM_REPEATS="${EVAL_NUM_REPEATS:-1}" +export EVAL_SGLANG_SERVER_CONCURRENCY="${EVAL_SGLANG_SERVER_CONCURRENCY:-1}" +export EVAL_TARGETS="${EVAL_TARGETS:-base sft opd}" export PYTHONUNBUFFERED=1 export TOKENIZERS_PARALLELISM="${TOKENIZERS_PARALLELISM:-false}" diff --git a/examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh b/examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh index d56fe671b..11029778f 100755 --- 
a/examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh +++ b/examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh @@ -14,10 +14,13 @@ echo " container (${SLIME_CONTAINER_FORMAT}): ${SLIME_SIF}" SHELL_FILES=( examples/qwen3_8b_opd_tillicum/env.sh + examples/qwen3_8b_opd_tillicum/checkpoint_utils.sh examples/qwen3_8b_opd_tillicum/container_exec.sh examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh examples/qwen3_8b_opd_tillicum/01_prepare_env.sh examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh + examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh + examples/qwen3_8b_opd_tillicum/02_prepare_data_25k_10k.sbatch examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch @@ -38,6 +41,7 @@ echo "Checking Python syntax" python3 -m py_compile "${PYTHON_FILES[@]}" SBATCH_FILES=( + examples/qwen3_8b_opd_tillicum/02_prepare_data_25k_10k.sbatch examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch diff --git a/examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh b/examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh new file mode 100755 index 000000000..64ff54fdf --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh @@ -0,0 +1,77 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +source "${SCRIPT_DIR}/env.sh" + +cd "${SLIME_REPO_ROOT}" +mkdir -p "${SLURM_LOG_DIR}" + +SBATCH_COMMON=(-A "${ACCOUNT}" -p "${PARTITION}" --qos "${QOS}") +if [[ -n "${GPU_GRES:-}" ]]; then + SBATCH_COMMON+=(--gres "${GPU_GRES}") +fi + +submit_log="${SLURM_LOG_DIR}/submit_25k_10k_$(date +%Y%m%d_%H%M%S).txt" + +echo "Submitting 25k SFT / 10k OPD chain" +echo "account/partition/qos: ${ACCOUNT}/${PARTITION}/${QOS}" +echo "submit log: 
${submit_log}" + +jid_data="$( + sbatch --parsable "${SBATCH_COMMON[@]}" \ + examples/qwen3_8b_opd_tillicum/02_prepare_data_25k_10k.sbatch +)" +jid_convert="$( + sbatch --parsable "${SBATCH_COMMON[@]}" \ + examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch +)" +jid_base_eval="$( + sbatch --parsable "${SBATCH_COMMON[@]}" \ + --dependency=afterok:${jid_convert} \ + --time=01:00:00 \ + --job-name=slime-qwen3-base-math500 \ + --export=ALL,EVAL_TARGETS=base,EVAL_OUTPUT_DIR="${BASE_EVAL_OUTPUT_DIR}" \ + examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +)" +jid_sft="$( + sbatch --parsable "${SBATCH_COMMON[@]}" \ + --dependency=afterok:${jid_data}:${jid_convert}:${jid_base_eval} \ + examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch +)" +jid_sft_eval="$( + sbatch --parsable "${SBATCH_COMMON[@]}" \ + --dependency=afterok:${jid_sft} \ + --time=03:00:00 \ + --job-name=slime-qwen3-sft-math500 \ + --export=ALL,EVAL_TARGETS=sft,EVAL_OUTPUT_DIR="${SFT_EVAL_OUTPUT_DIR}" \ + examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +)" +jid_opd="$( + sbatch --parsable "${SBATCH_COMMON[@]}" \ + --dependency=afterok:${jid_sft_eval} \ + examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch +)" +jid_opd_eval="$( + sbatch --parsable "${SBATCH_COMMON[@]}" \ + --dependency=afterok:${jid_opd} \ + --time=03:00:00 \ + --job-name=slime-qwen3-opd-math500 \ + --export=ALL,EVAL_TARGETS=opd,EVAL_OUTPUT_DIR="${OPD_EVAL_OUTPUT_DIR}" \ + examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +)" + +{ + echo "data=${jid_data}" + echo "convert=${jid_convert}" + echo "base_eval=${jid_base_eval}" + echo "sft=${jid_sft}" + echo "sft_eval=${jid_sft_eval}" + echo "opd=${jid_opd}" + echo "opd_eval=${jid_opd_eval}" + echo "SFT_SAVE_DIR=${SFT_SAVE_DIR}" + echo "OPD_SAVE_DIR=${OPD_SAVE_DIR}" + echo "SFT_EVAL_OUTPUT_DIR=${SFT_EVAL_OUTPUT_DIR}" + echo "OPD_EVAL_OUTPUT_DIR=${OPD_EVAL_OUTPUT_DIR}" + echo "CHECKPOINT_REPORT_DIR=${CHECKPOINT_REPORT_DIR}" 
+} | tee "${submit_log}" diff --git a/examples/qwen3_8b_opd_tillicum/summarize_eval.py b/examples/qwen3_8b_opd_tillicum/summarize_eval.py index 8ac6132e1..f086a5fe0 100755 --- a/examples/qwen3_8b_opd_tillicum/summarize_eval.py +++ b/examples/qwen3_8b_opd_tillicum/summarize_eval.py @@ -4,8 +4,11 @@ from __future__ import annotations import argparse +import csv import json import math +import struct +import zlib from pathlib import Path from typing import Any @@ -19,6 +22,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--out-json", required=True) parser.add_argument("--max-response-len", type=int, default=None) parser.add_argument("--aggregate-dir", default=None) + parser.add_argument("--train-samples", type=int, default=None) return parser.parse_args() @@ -45,7 +49,12 @@ def extract_answer_or_none(response: str) -> str | None: return extract_answer(response) -def summarize_debug_file(stage: str, debug_file: Path, max_response_len: int | None) -> dict[str, Any]: +def summarize_debug_file( + stage: str, + debug_file: Path, + max_response_len: int | None, + train_samples: int | None, +) -> dict[str, Any]: payload = torch.load(debug_file, map_location="cpu", weights_only=False) samples = payload.get("samples", []) if not samples: @@ -70,6 +79,7 @@ def summarize_debug_file(stage: str, debug_file: Path, max_response_len: int | N accuracy = sum(rewards) / n return { "stage": stage, + "train_samples": train_samples, "debug_file": str(debug_file), "n": n, "accuracy": accuracy, @@ -89,6 +99,10 @@ def aggregate(aggregate_dir: Path) -> dict[str, Any]: if not summaries: raise RuntimeError(f"No per-stage summaries found under {aggregate_dir}") + summaries.sort(key=lambda item: (item.get("train_samples") is None, item.get("train_samples") or 0, item["stage"])) + write_accuracy_curve_csv(aggregate_dir / "accuracy_curve.csv", summaries) + write_accuracy_curve_png(aggregate_dir / "accuracy_curve.png", summaries) + return { "num_repeats_per_stage": 1, "stages": 
summaries, @@ -96,6 +110,160 @@ def aggregate(aggregate_dir: Path) -> dict[str, Any]: } +def write_accuracy_curve_csv(path: Path, summaries: list[dict[str, Any]]) -> None: + with path.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter( + f, + fieldnames=[ + "stage", + "train_samples", + "accuracy", + "avg_generated_tokens", + "parse_failure_rate", + "cap_hit_rate", + ], + ) + writer.writeheader() + for item in summaries: + writer.writerow( + { + "stage": item.get("stage"), + "train_samples": item.get("train_samples"), + "accuracy": item.get("accuracy"), + "avg_generated_tokens": item.get("avg_generated_tokens"), + "parse_failure_rate": item.get("parse_failure_rate"), + "cap_hit_rate": item.get("cap_hit_rate"), + } + ) + + +def put_px(image: bytearray, width: int, height: int, x: int, y: int, color: tuple[int, int, int]) -> None: + if 0 <= x < width and 0 <= y < height: + idx = (y * width + x) * 3 + image[idx : idx + 3] = bytes(color) + + +def draw_line( + image: bytearray, + width: int, + height: int, + start: tuple[int, int], + end: tuple[int, int], + color: tuple[int, int, int], +) -> None: + x0, y0 = start + x1, y1 = end + dx = abs(x1 - x0) + sx = 1 if x0 < x1 else -1 + dy = -abs(y1 - y0) + sy = 1 if y0 < y1 else -1 + err = dx + dy + while True: + put_px(image, width, height, x0, y0, color) + if x0 == x1 and y0 == y1: + break + e2 = 2 * err + if e2 >= dy: + err += dy + x0 += sx + if e2 <= dx: + err += dx + y0 += sy + + +def draw_dot( + image: bytearray, + width: int, + height: int, + center: tuple[int, int], + radius: int, + color: tuple[int, int, int], +) -> None: + cx, cy = center + for y in range(cy - radius, cy + radius + 1): + for x in range(cx - radius, cx + radius + 1): + if (x - cx) ** 2 + (y - cy) ** 2 <= radius**2: + put_px(image, width, height, x, y, color) + + +def write_png(path: Path, width: int, height: int, image: bytearray) -> None: + rows = [] + row_bytes = width * 3 + for y in range(height): + rows.append(b"\x00" + 
bytes(image[y * row_bytes : (y + 1) * row_bytes])) + raw = b"".join(rows) + + def chunk(kind: bytes, payload: bytes) -> bytes: + return ( + struct.pack(">I", len(payload)) + + kind + + payload + + struct.pack(">I", zlib.crc32(kind + payload) & 0xFFFFFFFF) + ) + + png = b"\x89PNG\r\n\x1a\n" + png += chunk("IHDR".encode(), struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)) + png += chunk("IDAT".encode(), zlib.compress(raw, level=9)) + png += chunk("IEND".encode(), b"") + path.write_bytes(png) + + +def write_accuracy_curve_png(path: Path, summaries: list[dict[str, Any]]) -> None: + points = [ + (item.get("train_samples"), float(item.get("accuracy", 0.0))) + for item in summaries + if item.get("train_samples") is not None + ] + if not points: + return + + width, height = 900, 520 + margin_left, margin_right, margin_top, margin_bottom = 70, 35, 35, 70 + plot_w = width - margin_left - margin_right + plot_h = height - margin_top - margin_bottom + image = bytearray([255] * width * height * 3) + + axis = (32, 32, 32) + grid = (225, 225, 225) + line = (34, 102, 190) + dot = (190, 60, 60) + + x_min = min(x for x, _ in points) + x_max = max(x for x, _ in points) + if x_min == x_max: + x_min = 0 + y_min, y_max = 0.0, 1.0 + + def map_x(x: int) -> int: + if x_max == x_min: + return margin_left + plot_w // 2 + return margin_left + round((x - x_min) / (x_max - x_min) * plot_w) + + def map_y(y: float) -> int: + return margin_top + round((y_max - max(y_min, min(y_max, y))) / (y_max - y_min) * plot_h) + + for i in range(6): + y = margin_top + round(i / 5 * plot_h) + draw_line(image, width, height, (margin_left, y), (width - margin_right, y), grid) + draw_line(image, width, height, (margin_left, margin_top), (margin_left, height - margin_bottom), axis) + draw_line( + image, + width, + height, + (margin_left, height - margin_bottom), + (width - margin_right, height - margin_bottom), + axis, + ) + + mapped = [(map_x(x), map_y(y)) for x, y in points] + for a, b in zip(mapped, 
mapped[1:], strict=False): + draw_line(image, width, height, a, b, line) + for point in mapped: + draw_dot(image, width, height, point, 5, dot) + + write_png(path, width, height, image) + + def main() -> None: args = parse_args() out = Path(args.out_json) @@ -106,7 +274,7 @@ def main() -> None: else: if not args.stage or not args.debug_file: raise SystemExit("--stage and --debug-file are required unless --aggregate-dir is used") - summary = summarize_debug_file(args.stage, Path(args.debug_file), args.max_response_len) + summary = summarize_debug_file(args.stage, Path(args.debug_file), args.max_response_len, args.train_samples) out.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8") print(json.dumps(summary, indent=2)) From 73f7e0bfb1ab43e8d6b50f7bc847528045560ea8 Mon Sep 17 00:00:00 2001 From: Surya Duraivenkatesh Date: Sat, 27 Jun 2026 01:02:26 -0700 Subject: [PATCH 13/13] Fix Tillicum MATH eval scoring and walltimes --- .../04_run_sft_100k_8xh200.sbatch | 2 +- .../05_run_opd_50k_8xh200.sbatch | 2 +- .../06_eval_math500_greedy_1x.sbatch | 6 +- examples/qwen3_8b_opd_tillicum/README.md | 15 ++- .../submit_25k_10k_chain.sh | 24 ++--- .../qwen3_8b_opd_tillicum/summarize_eval.py | 100 +++++++++++++----- 6 files changed, 102 insertions(+), 47 deletions(-) diff --git a/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch b/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch index 77368b0b8..e28f58b84 100755 --- a/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch +++ b/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch @@ -5,7 +5,7 @@ #SBATCH --cpus-per-task=64 #SBATCH --gres=gpu:h200:8 #SBATCH --mem=0 -#SBATCH --time=03:00:00 +#SBATCH --time=08:00:00 #SBATCH --mail-user=suryadv@cs.washington.edu #SBATCH --mail-type=END,FAIL diff --git a/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch b/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch index 788ecc67f..b3e74353e 100755 --- 
a/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch +++ b/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch @@ -5,7 +5,7 @@ #SBATCH --cpus-per-task=64 #SBATCH --gres=gpu:h200:8 #SBATCH --mem=0 -#SBATCH --time=05:00:00 +#SBATCH --time=10:00:00 #SBATCH --mail-user=suryadv@cs.washington.edu #SBATCH --mail-type=END,FAIL diff --git a/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch b/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch index 9aa783270..38ac7ce32 100755 --- a/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +++ b/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch @@ -5,7 +5,7 @@ #SBATCH --cpus-per-task=64 #SBATCH --gres=gpu:h200:8 #SBATCH --mem=0 -#SBATCH --time=03:00:00 +#SBATCH --time=02:00:00 #SBATCH --mail-user=suryadv@cs.washington.edu #SBATCH --mail-type=END,FAIL @@ -67,7 +67,7 @@ cfg.write_text( " datasets:\n" " - name: math500\n" f" path: {out}\n" - " rm_type: deepscaler\n", + " rm_type: math\n", encoding="utf-8", ) print(f"Wrote {out}") @@ -119,7 +119,7 @@ run_eval() { --rollout-temperature 0 --rollout-top-p 1 --global-batch-size "${EVAL_ROLLOUT_BATCH_SIZE}" - --rm-type deepscaler + --rm-type math --save-debug-rollout-data "${stage_dir}/debug_{rollout_id}.pt" ) diff --git a/examples/qwen3_8b_opd_tillicum/README.md b/examples/qwen3_8b_opd_tillicum/README.md index f0d34fbf7..e55aa8114 100644 --- a/examples/qwen3_8b_opd_tillicum/README.md +++ b/examples/qwen3_8b_opd_tillicum/README.md @@ -160,10 +160,15 @@ passed at submit time from the environment variables above. The intended wall-clock budget after model/data/container preparation is: -- SFT 25k: 3 hours. -- OPD 10k: 5 hours. -- MATH-500 greedy eval: base 1 hour, SFT curve 3 hours, OPD curve 3 hours. +- SFT 25k: 8 hours. +- OPD 10k: 10 hours. +- MATH-500 greedy eval: base 2 hours, SFT curve 5 hours, OPD curve 5 hours. The main runtime risk is the Qwen3-32B teacher logprob server throughput during -OPD. 
The current conservative chain evaluates base first, then SFT milestones, -then OPD milestones so SFT results are available before OPD starts. +OPD. The current conservative chain runs SFT, evaluates SFT milestones, runs +OPD, evaluates OPD milestones, and runs the fixed base eval last; the OPD job +depends on the SFT eval job, so SFT results are available before OPD starts. + +MATH-500 summaries report `accuracy` with parse failures counted wrong, +`accuracy_on_parseable` as a diagnostic over parseable responses only, and +`parse_failure_rate` separately. diff --git a/examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh b/examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh index 64ff54fdf..d62539ab4 100755 --- a/examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh +++ b/examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh @@ -26,23 +26,15 @@ jid_convert="$( sbatch --parsable "${SBATCH_COMMON[@]}" \ examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch )" -jid_base_eval="$( - sbatch --parsable "${SBATCH_COMMON[@]}" \ - --dependency=afterok:${jid_convert} \ - --time=01:00:00 \ - --job-name=slime-qwen3-base-math500 \ - --export=ALL,EVAL_TARGETS=base,EVAL_OUTPUT_DIR="${BASE_EVAL_OUTPUT_DIR}" \ - examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch -)" jid_sft="$( sbatch --parsable "${SBATCH_COMMON[@]}" \ - --dependency=afterok:${jid_data}:${jid_convert}:${jid_base_eval} \ + --dependency=afterok:${jid_data}:${jid_convert} \ examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch )" jid_sft_eval="$( sbatch --parsable "${SBATCH_COMMON[@]}" \ --dependency=afterok:${jid_sft} \ - --time=03:00:00 \ + --time=05:00:00 \ --job-name=slime-qwen3-sft-math500 \ --export=ALL,EVAL_TARGETS=sft,EVAL_OUTPUT_DIR="${SFT_EVAL_OUTPUT_DIR}" \ examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch @@ -55,20 +47,28 @@ jid_opd="$( jid_opd_eval="$( sbatch --parsable "${SBATCH_COMMON[@]}" \ --dependency=afterok:${jid_opd} \ - --time=03:00:00 \ + --time=05:00:00 \
--job-name=slime-qwen3-opd-math500 \ --export=ALL,EVAL_TARGETS=opd,EVAL_OUTPUT_DIR="${OPD_EVAL_OUTPUT_DIR}" \ examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch )" +jid_base_eval="$( + sbatch --parsable "${SBATCH_COMMON[@]}" \ + --dependency=afterok:${jid_opd_eval} \ + --time=02:00:00 \ + --job-name=slime-qwen3-base-math500 \ + --export=ALL,EVAL_TARGETS=base,EVAL_OUTPUT_DIR="${BASE_EVAL_OUTPUT_DIR}" \ + examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +)" { echo "data=${jid_data}" echo "convert=${jid_convert}" - echo "base_eval=${jid_base_eval}" echo "sft=${jid_sft}" echo "sft_eval=${jid_sft_eval}" echo "opd=${jid_opd}" echo "opd_eval=${jid_opd_eval}" + echo "base_eval=${jid_base_eval}" echo "SFT_SAVE_DIR=${SFT_SAVE_DIR}" echo "OPD_SAVE_DIR=${OPD_SAVE_DIR}" echo "SFT_EVAL_OUTPUT_DIR=${SFT_EVAL_OUTPUT_DIR}" diff --git a/examples/qwen3_8b_opd_tillicum/summarize_eval.py b/examples/qwen3_8b_opd_tillicum/summarize_eval.py index f086a5fe0..514b86091 100755 --- a/examples/qwen3_8b_opd_tillicum/summarize_eval.py +++ b/examples/qwen3_8b_opd_tillicum/summarize_eval.py @@ -7,6 +7,7 @@ import csv import json import math +import re import struct import zlib from pathlib import Path @@ -26,27 +27,73 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def reward_value(sample: dict[str, Any]) -> float: - reward = sample.get("reward", 0.0) - if isinstance(reward, dict): - for key in ("reward", "score", "acc", "accuracy"): - if key in reward: - return float(reward[key]) - return 0.0 - return float(reward or 0.0) +FINAL_ANSWER_RE = re.compile(r"(?i)\b(?:final\s+answer|answer)\s*(?:is|:)\s*(?P.+)$") -def extract_answer_or_none(response: str) -> str | None: - try: - from slime.rollout.rm_hub.math_utils import extract_answer - except Exception: - return None - +def answer_segment(response: str) -> str: if "" in response: - response = response.split("")[-1] - elif "###Response" in response: - response = response.split("###Response", 1)[1] - 
return extract_answer(response) + return response.rsplit("", 1)[1] + if "###Response" in response: + return response.split("###Response", 1)[1] + return response + + +def clean_final_answer_candidate(candidate: str) -> str: + candidate = candidate.strip() + candidate = re.split(r"<\|endoftext\|>||", candidate, maxsplit=1)[0].strip() + candidate = candidate.split(". ", 1)[0].strip() + candidate = candidate.split("\n", 1)[0].strip() + + for left, right in (("\\(", "\\)"), ("\\[", "\\]"), ("$", "$")): + if candidate.startswith(left) and candidate.endswith(right): + candidate = candidate[len(left) : -len(right)].strip() + + while candidate and candidate[-1] in ".,;:": + candidate = candidate[:-1].strip() + return candidate + + +def extract_prediction(response: str) -> str | None: + from slime.rollout.rm_hub.math_utils import extract_answer + + segment = answer_segment(response) + boxed = extract_answer(segment) + if boxed is not None: + return boxed + + for line in reversed(segment.splitlines()): + match = FINAL_ANSWER_RE.search(line.strip()) + if not match: + continue + candidate = clean_final_answer_candidate(match.group("answer")) + if candidate and len(candidate) <= 120: + return candidate + return None + + +def extract_ground_truth(label: Any) -> str: + from slime.rollout.rm_hub.math_utils import extract_answer + + ground_truth = str(label) + if "\\boxed" in ground_truth: + boxed = extract_answer(ground_truth) + if boxed is not None: + return boxed + return ground_truth + + +def score_sample(sample: dict[str, Any]) -> tuple[float, bool]: + from slime.rollout.rm_hub.math_utils import grade_answer_mathd, grade_answer_sympy + + response = str(sample.get("response", "")) + label = sample.get("label", "") + prediction = extract_prediction(response) + if prediction is None: + return 0.0, False + + ground_truth = extract_ground_truth(label) + is_correct = grade_answer_mathd(prediction, ground_truth) or grade_answer_sympy(prediction, ground_truth) + return 
float(is_correct), True def summarize_debug_file( @@ -60,15 +107,12 @@ def summarize_debug_file( if not samples: raise RuntimeError(f"No samples found in {debug_file}") - rewards = [reward_value(sample) for sample in samples] + scores_and_parseable = [score_sample(sample) for sample in samples] + rewards = [score for score, _ in scores_and_parseable] + parseable_count = sum(1 for _, parseable in scores_and_parseable if parseable) response_lengths = [int(sample.get("response_length") or 0) for sample in samples] statuses = [str(sample.get("status", "")) for sample in samples] - - parse_failures = 0 - for sample in samples: - response = str(sample.get("response", "")) - if extract_answer_or_none(response) is None: - parse_failures += 1 + parse_failures = len(samples) - parseable_count cap_hits = 0 for length, status in zip(response_lengths, statuses, strict=True): @@ -77,12 +121,16 @@ def summarize_debug_file( n = len(samples) accuracy = sum(rewards) / n + accuracy_on_parseable = sum(rewards) / parseable_count if parseable_count else None return { "stage": stage, "train_samples": train_samples, "debug_file": str(debug_file), "n": n, + "correct_count": int(sum(rewards)), + "parseable_count": parseable_count, "accuracy": accuracy, + "accuracy_on_parseable": accuracy_on_parseable, "mean_accuracy": accuracy, "std_accuracy": None, "avg_generated_tokens": sum(response_lengths) / n, @@ -118,6 +166,7 @@ def write_accuracy_curve_csv(path: Path, summaries: list[dict[str, Any]]) -> Non "stage", "train_samples", "accuracy", + "accuracy_on_parseable", "avg_generated_tokens", "parse_failure_rate", "cap_hit_rate", @@ -130,6 +179,7 @@ def write_accuracy_curve_csv(path: Path, summaries: list[dict[str, Any]]) -> Non "stage": item.get("stage"), "train_samples": item.get("train_samples"), "accuracy": item.get("accuracy"), + "accuracy_on_parseable": item.get("accuracy_on_parseable"), "avg_generated_tokens": item.get("avg_generated_tokens"), "parse_failure_rate": 
item.get("parse_failure_rate"), "cap_hit_rate": item.get("cap_hit_rate"),