diff --git a/examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh b/examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh new file mode 100755 index 0000000000..109fa8a82e --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +source "${SCRIPT_DIR}/env.sh" + +APPTAINER_BIN="${APPTAINER_BIN:-}" +if [[ -z "${APPTAINER_BIN}" ]]; then + if command -v apptainer >/dev/null 2>&1; then + APPTAINER_BIN=apptainer + elif command -v singularity >/dev/null 2>&1; then + APPTAINER_BIN=singularity + else + echo "Neither apptainer nor singularity is available on PATH." >&2 + exit 1 + fi +fi + +mkdir -p "$(dirname "${SLIME_SIF}")" "${APPTAINER_CACHEDIR}" "${APPTAINER_TMPDIR}" +export APPTAINER_CACHEDIR APPTAINER_TMPDIR + +cat <&2 + echo "Move it aside or remove it manually, then rerun this script." >&2 + exit 1 +fi + +if [[ "${SLIME_CONTAINER_FORMAT}" == "sandbox" ]]; then + "${APPTAINER_BIN}" build --sandbox "${SLIME_SIF}" "${SLIME_IMAGE_URI}" +else + "${APPTAINER_BIN}" pull --force "${SLIME_SIF}" "${SLIME_IMAGE_URI}" +fi + +"${APPTAINER_BIN}" exec --cleanenv "${SLIME_SIF}" python -V +echo "Wrote and validated ${SLIME_SIF}" diff --git a/examples/qwen3_8b_opd_tillicum/01_prepare_env.sh b/examples/qwen3_8b_opd_tillicum/01_prepare_env.sh new file mode 100755 index 0000000000..c73e5d4758 --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/01_prepare_env.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +source "${SCRIPT_DIR}/env.sh" + +mkdir -p \ + "${SCRATCH_ROOT}" \ + "${DATA_ROOT}" \ + "${MODEL_ROOT}" \ + "${OUTPUT_ROOT}" \ + "${HF_HOME}" \ + "${HF_DATASETS_CACHE}" \ + "${TRANSFORMERS_CACHE}" \ + "${WANDB_DIR}" \ + "${TMPDIR}" \ + "${RAY_TMPDIR}" \ + "${APPTAINER_CACHEDIR}" \ + "${APPTAINER_TMPDIR}" \ + "$(dirname "${SLIME_SIF}")" \ + "${CONTAINER_HOME}" \ + "${SFT_SAVE_DIR}" \ + "${OPD_SAVE_DIR}" \ + "${SFT_HF_SNAPSHOT_DIR}" \ + "${OPD_HF_SNAPSHOT_DIR}" \ + "${SFT_DETAILS_DIR}" \ + "${OPD_ROLLOUT_LOG_DIR}" \ + "${TEACHER_LOG_DIR}" \ + "${EVAL_OUTPUT_DIR}" \ + "${BASE_EVAL_OUTPUT_DIR}" \ + "${SFT_EVAL_OUTPUT_DIR}" \ + "${OPD_EVAL_OUTPUT_DIR}" \ + "${CHECKPOINT_REPORT_DIR}" \ + "${SLURM_LOG_DIR}" + +cat </dev/null 2>&1 && pwd)" +fi +source "${SCRIPT_DIR}/env.sh" + +mkdir -p "${SLURM_LOG_DIR}" "${DATA_ROOT}" +if [[ -n "${SLURM_JOB_ID:-}" ]]; then + export RAY_TMPDIR="/tmp/${USER:-suryadv}/slime-ray-${SLURM_JOB_ID}" + mkdir -p "${RAY_TMPDIR}" + exec >"${SLURM_LOG_DIR}/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" 2>&1 +fi + +echo "Preparing OpenThoughts3 math data at $(date)" +echo "SFT rows: ${SFT_SIZE} -> ${SFT_PARQUET}" +echo "OPD rows: ${OPD_SIZE} -> ${OPD_JSONL}" +echo "Metadata: ${SPLIT_METADATA}" + +"${SCRIPT_DIR}/container_exec.sh" bash -lc ' +set -euo pipefail +cd "${SLIME_REPO_ROOT}" + +needs_prepare=0 +for path in "${SFT_PARQUET}" "${OPD_JSONL}" "${SPLIT_METADATA}"; do + if [[ ! -s "${path}" ]]; then + needs_prepare=1 + fi +done + +if [[ "${needs_prepare}" == "0" ]]; then + if ! python3 - < int: + with path.open("rb") as f: + return sum(1 for _ in f) + +metadata = json.loads(metadata_path.read_text(encoding="utf-8")) +sft_count = count_jsonl(sft_path) +opd_count = count_jsonl(opd_path) +sft_ids = set(metadata["source_row_ids"]["sft"]) +opd_ids = set(metadata["source_row_ids"]["opd"]) + +if sft_count != sft_expected or opd_count != opd_expected or sft_ids.intersection(opd_ids): + raise SystemExit(1) +PY + then + echo "Existing data files do not match the requested split; regenerating." + needs_prepare=1 + fi +fi + +if [[ "${needs_prepare}" == "1" ]]; then + python3 examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py --force +else + echo "Data files already exist; validating counts." +fi + +python3 - < int: + with path.open("rb") as f: + return sum(1 for _ in f) + +sft_count = count_jsonl(sft_path) +opd_count = count_jsonl(opd_path) +if sft_count != sft_expected: + raise SystemExit(f"SFT row count mismatch: {sft_count} != {sft_expected}") +if opd_count != opd_expected: + raise SystemExit(f"OPD row count mismatch: {opd_count} != {opd_expected}") + +metadata = json.loads(metadata_path.read_text(encoding="utf-8")) +sft_ids = set(metadata["source_row_ids"]["sft"]) +opd_ids = set(metadata["source_row_ids"]["opd"]) +overlap = sft_ids.intersection(opd_ids) +if overlap: + raise SystemExit(f"SFT/OPD metadata overlap detected: {sorted(overlap)[:10]}") + +print(f"DATA_PREP_OK sft_rows={sft_count} opd_rows={opd_count} metadata={metadata_path}") +PY +' + +echo "Finished data prep at $(date)" diff --git a/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py b/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py new file mode 100755 index 0000000000..7dd96959a9 --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py @@ -0,0 +1,371 @@ +#!/usr/bin/env python3 +"""Prepare row-disjoint OpenThoughts3 math splits for slime SFT and OPD.""" + +from __future__ import annotations + +import argparse +import hashlib +import json +import os +import random +import statistics +from pathlib import Path +from typing import Any + +import pyarrow as pa +import pyarrow.parquet as pq +from datasets import Dataset, load_dataset +from transformers import AutoTokenizer + + +ROLE_MAP = { + "human": "user", + "user": "user", + "prompt": "user", + "gpt": "assistant", + "assistant": "assistant", + "model": "assistant", + "system": "system", +} + + +def parse_args() -> argparse.Namespace: + env = os.environ + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--dataset", default=env.get("OT3_DATASET", "open-thoughts/OpenThoughts3-1.2M")) + parser.add_argument("--split", default=env.get("OT3_SPLIT", "train")) + parser.add_argument("--sft-size", type=int, default=int(env.get("SFT_SIZE", "100000"))) + parser.add_argument("--opd-size", type=int, default=int(env.get("OPD_SIZE", "50000"))) + parser.add_argument("--seed", type=int, default=int(env.get("DATA_SEED", "1234"))) + parser.add_argument("--math-field", default=env.get("DATA_MATH_FIELD", "domain")) + parser.add_argument("--math-value", default=env.get("DATA_MATH_VALUE", "math")) + parser.add_argument("--sft-out", default=env.get("SFT_PARQUET")) + parser.add_argument("--opd-out", default=env.get("OPD_JSONL")) + parser.add_argument("--metadata-out", default=env.get("SPLIT_METADATA")) + parser.add_argument("--hf-home", default=env.get("HF_HOME")) + parser.add_argument("--tokenizer", default=env.get("STUDENT_HF_REPO", "Qwen/Qwen3-8B-Base")) + parser.add_argument("--max-source-rows", type=int, default=None) + parser.add_argument("--force", action="store_true") + args = parser.parse_args() + + missing = [name for name in ("sft_out", "opd_out", "metadata_out") if getattr(args, name) in (None, "")] + if missing: + parser.error(f"missing required output env/arg(s): {', '.join(missing)}") + return args + + +def normalize_messages(value: Any, row: dict[str, Any]) -> list[dict[str, str]]: + if isinstance(value, str): + try: + value = json.loads(value) + except json.JSONDecodeError: + return [{"role": "user", "content": value}] + + if isinstance(value, dict) and "messages" in value: + value = value["messages"] + + if isinstance(value, list): + out: list[dict[str, str]] = [] + for item in value: + if not isinstance(item, dict): + continue + raw_role = item.get("role", item.get("from", item.get("speaker", ""))) + role = ROLE_MAP.get(str(raw_role).lower(), str(raw_role).lower()) + content = item.get("content", item.get("value", item.get("text", ""))) + if content is None: + content = "" + if role in {"system", "user", "assistant", "tool"}: + out.append({"role": role, "content": str(content)}) + if out: + return out + + prompt = first_present(row, ["prompt", "problem", "question", "instruction", "input"]) + response = first_present(row, ["response", "completion", "answer", "solution", "output"]) + if prompt is not None and response is not None: + return [ + {"role": "user", "content": str(prompt)}, + {"role": "assistant", "content": str(response)}, + ] + if prompt is not None: + return [{"role": "user", "content": str(prompt)}] + + raise ValueError("Could not infer OpenAI-style messages from row.") + + +def first_present(row: dict[str, Any], keys: list[str]) -> Any | None: + for key in keys: + if key in row and row[key] not in (None, ""): + return row[key] + return None + + +def extract_messages(row: dict[str, Any]) -> list[dict[str, str]]: + for key in ("messages", "conversations", "conversation"): + if key in row and row[key] not in (None, ""): + return normalize_messages(row[key], row) + return normalize_messages(None, row) + + +def extract_prompt(row: dict[str, Any], messages: list[dict[str, str]]) -> str: + prompt = first_present(row, ["prompt", "problem", "question", "instruction", "input"]) + if prompt is not None: + return str(prompt) + for message in messages: + if message.get("role") == "user": + return str(message.get("content", "")) + raise ValueError("Could not infer prompt text from row/messages.") + + +def prompt_hash(text: str) -> str: + return hashlib.sha256(text.encode("utf-8")).hexdigest() + + +def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, ensure_ascii=False) + "\n") + + +def write_parquet_rows(path: Path, rows: list[dict[str, Any]], batch_size: int = 1000) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + writer: pq.ParquetWriter | None = None + try: + for start in range(0, len(rows), batch_size): + end = min(start + batch_size, len(rows)) + table = pa.Table.from_pylist(rows[start:end]) + if writer is None: + writer = pq.ParquetWriter(str(path), table.schema) + writer.write_table(table) + if end % 5000 == 0 or end == len(rows): + print(f"Parquet rows: {end}/{len(rows)}", flush=True) + finally: + if writer is not None: + writer.close() + + +def write_sft_rows(path: Path, rows: list[dict[str, Any]]) -> None: + if path.suffix == ".jsonl": + write_jsonl(path, rows) + else: + write_parquet_rows(path, rows) + + +def iter_dataset_rows(ds: Dataset, indices: list[int], label: str, batch_size: int = 1000): + selected = ds.select(indices) + total = len(indices) + seen = 0 + for batch in selected.iter(batch_size=batch_size): + keys = list(batch.keys()) + if not keys: + continue + for row_idx in range(len(batch[keys[0]])): + seen += 1 + if seen % 5000 == 0 or seen == total: + print(f"{label}: {seen}/{total}", flush=True) + yield {key: batch[key][row_idx] for key in keys} + + +def encode_lengths(tokenizer: Any, texts: list[str]) -> list[int]: + encoded = tokenizer(texts, add_special_tokens=False) + return [len(input_ids) for input_ids in encoded["input_ids"]] + + +def flush_token_batch( + tokenizer: Any, + texts: list[str], + lengths: list[int], + label: str, + processed: int, + total: int, +) -> int: + if not texts: + return processed + previous = processed + batch_size = len(texts) + lengths.extend(encode_lengths(tokenizer, texts)) + processed += batch_size + texts.clear() + if processed // 5000 > previous // 5000 or processed == total: + print(f"{label} token stats: {processed}/{total}", flush=True) + return processed + + +def token_lengths(tokenizer: Any, sft_rows: list[dict[str, Any]], opd_rows: list[dict[str, Any]]) -> dict[str, Any]: + sft_lengths: list[int] = [] + sft_text_batch: list[str] = [] + sft_processed = 0 + for row in sft_rows: + messages = row["messages"] + try: + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) + except Exception: + text = "\n".join(f"{m['role']}: {m['content']}" for m in messages) + sft_text_batch.append(text) + if len(sft_text_batch) >= 256: + sft_processed = flush_token_batch( + tokenizer, sft_text_batch, sft_lengths, "SFT", sft_processed, len(sft_rows) + ) + sft_processed = flush_token_batch(tokenizer, sft_text_batch, sft_lengths, "SFT", sft_processed, len(sft_rows)) + + opd_lengths: list[int] = [] + opd_text_batch: list[str] = [] + opd_processed = 0 + for row in opd_rows: + messages = [{"role": "user", "content": row["prompt"]}] + try: + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + except Exception: + text = row["prompt"] + opd_text_batch.append(text) + if len(opd_text_batch) >= 256: + opd_processed = flush_token_batch( + tokenizer, opd_text_batch, opd_lengths, "OPD", opd_processed, len(opd_rows) + ) + opd_processed = flush_token_batch(tokenizer, opd_text_batch, opd_lengths, "OPD", opd_processed, len(opd_rows)) + + def stats(values: list[int]) -> dict[str, float | int]: + return { + "count": len(values), + "avg": float(statistics.fmean(values)) if values else 0.0, + "max": max(values) if values else 0, + } + + return { + "sft_total_token_lengths": stats(sft_lengths), + "opd_prompt_token_lengths": stats(opd_lengths), + } + + +def main() -> None: + args = parse_args() + sft_out = Path(args.sft_out) + opd_out = Path(args.opd_out) + metadata_out = Path(args.metadata_out) + + existing = [path for path in (sft_out, opd_out, metadata_out) if path.exists()] + if existing and not args.force: + raise SystemExit( + "Output already exists; use --force to replace:\n" + "\n".join(f" {path}" for path in existing) + ) + + print(f"Loading {args.dataset} split={args.split}") + ds = load_dataset(args.dataset, split=args.split, cache_dir=args.hf_home) + if args.max_source_rows is not None: + ds = ds.select(range(min(args.max_source_rows, len(ds)))) + + if args.math_field not in ds.column_names: + raise SystemExit( + f"Dataset does not contain math filter field {args.math_field!r}. " + f"Columns: {', '.join(ds.column_names)}" + ) + + ds = ds.map(lambda _, idx: {"source_row_id": idx}, with_indices=True) + math_value = str(args.math_value).lower() + math_ds = ds.filter(lambda row: str(row.get(args.math_field, "")).lower() == math_value) + + required = args.sft_size + args.opd_size + if len(math_ds) < required: + raise SystemExit(f"Need {required} math rows, found {len(math_ds)} after filtering.") + + rng = random.Random(args.seed) + shuffled = list(range(len(math_ds))) + rng.shuffle(shuffled) + # Split membership is seeded by the shuffled order. Materialize in dataset + # order to avoid very slow random Arrow row reads on GPFS. + sft_indices = sorted(shuffled[: args.sft_size]) + opd_indices = sorted(shuffled[args.sft_size : required]) + + sft_rows: list[dict[str, Any]] = [] + opd_rows: list[dict[str, Any]] = [] + + print(f"Building SFT rows: {len(sft_indices)}") + for row in iter_dataset_rows(math_ds, sft_indices, "SFT rows"): + messages = extract_messages(row) + prompt = extract_prompt(row, messages) + source_row_id = int(row["source_row_id"]) + sft_rows.append( + { + "messages": messages, + "metadata": { + "source_dataset": args.dataset, + "source_split": args.split, + "source_row_id": source_row_id, + "prompt_sha256": prompt_hash(prompt), + "split": "sft", + }, + } + ) + + print(f"Building OPD rows: {len(opd_indices)}") + for row in iter_dataset_rows(math_ds, opd_indices, "OPD rows"): + messages = extract_messages(row) + prompt = extract_prompt(row, messages) + source_row_id = int(row["source_row_id"]) + opd_rows.append( + { + "prompt": prompt, + "metadata": { + "source_dataset": args.dataset, + "source_split": args.split, + "source_row_id": source_row_id, + "prompt_sha256": prompt_hash(prompt), + "split": "opd", + }, + } + ) + + sft_source_ids = [row["metadata"]["source_row_id"] for row in sft_rows] + opd_source_ids = [row["metadata"]["source_row_id"] for row in opd_rows] + overlap = sorted(set(sft_source_ids).intersection(opd_source_ids)) + if overlap: + raise RuntimeError(f"SFT/OPD row split overlap detected: first overlaps {overlap[:10]}") + + write_sft_rows(sft_out, sft_rows) + write_jsonl(opd_out, opd_rows) + print(f"Wrote {sft_out}") + print(f"Wrote {opd_out}") + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, cache_dir=args.hf_home, trust_remote_code=True) + length_stats = token_lengths(tokenizer, sft_rows, opd_rows) + + metadata = { + "seed": args.seed, + "dataset": args.dataset, + "source_split": args.split, + "filtering_criteria": { + "field": args.math_field, + "value": args.math_value, + "operation": "case-insensitive equality", + }, + "sample_materialization": "Seeded split membership, indices sorted within each split before row materialization.", + "counts": { + "source_rows_seen": len(ds), + "math_rows_after_filter": len(math_ds), + "sft": len(sft_rows), + "opd": len(opd_rows), + }, + "outputs": { + "sft_data": str(sft_out), + "opd_jsonl": str(opd_out), + }, + "source_row_ids": { + "sft": sft_source_ids, + "opd": opd_source_ids, + }, + "prompt_hashes": { + "sft": [row["metadata"]["prompt_sha256"] for row in sft_rows], + "opd": [row["metadata"]["prompt_sha256"] for row in opd_rows], + }, + "tokenizer": args.tokenizer, + "token_lengths": length_stats, + } + + metadata_out.parent.mkdir(parents=True, exist_ok=True) + metadata_out.write_text(json.dumps(metadata, indent=2, ensure_ascii=False) + "\n", encoding="utf-8") + + print(f"Wrote {metadata_out}") + + +if __name__ == "__main__": + main() diff --git a/examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch b/examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch new file mode 100755 index 0000000000..5930a8b24f --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch @@ -0,0 +1,64 @@ +#!/bin/bash +#SBATCH --job-name=slime-qwen3-convert +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=64 +#SBATCH --gres=gpu:h200:8 +#SBATCH --mem=0 +#SBATCH --time=02:00:00 +#SBATCH --mail-user=suryadv@cs.washington.edu +#SBATCH --mail-type=END,FAIL + +set -euo pipefail + +if [[ -n "${SLURM_SUBMIT_DIR:-}" && -f "${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum/env.sh" ]]; then + SCRIPT_DIR="${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum" +else + SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +fi +source "${SCRIPT_DIR}/env.sh" + +mkdir -p "${SLURM_LOG_DIR}" +if [[ -n "${SLURM_JOB_ID:-}" ]]; then + exec >"${SLURM_LOG_DIR}/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" 2>&1 +fi + +echo "Starting model prep on $(hostname) at $(date)" +echo "Student HF: ${STUDENT_HF_REPO} -> ${STUDENT_HF_DIR}" +echo "Teacher HF: ${TEACHER_HF_REPO} -> ${TEACHER_HF_DIR}" +echo "Student torch_dist: ${STUDENT_TORCH_DIST_DIR}" + +"${SCRIPT_DIR}/container_exec.sh" bash -lc ' +set -euo pipefail +cd "${SLIME_REPO_ROOT}" +mkdir -p "${MODEL_ROOT}" "${DATA_ROOT}" + +if [[ ! -d "${STUDENT_HF_DIR}" || -z "$(ls -A "${STUDENT_HF_DIR}" 2>/dev/null)" ]]; then + hf download "${STUDENT_HF_REPO}" --local-dir "${STUDENT_HF_DIR}" +else + echo "Student HF snapshot already exists: ${STUDENT_HF_DIR}" +fi + +if [[ ! -d "${TEACHER_HF_DIR}" || -z "$(ls -A "${TEACHER_HF_DIR}" 2>/dev/null)" ]]; then + hf download "${TEACHER_HF_REPO}" --local-dir "${TEACHER_HF_DIR}" +else + echo "Teacher HF snapshot already exists: ${TEACHER_HF_DIR}" +fi + +hf download --repo-type dataset "${MATH500_DATASET}" --local-dir "${DATA_ROOT}/math500_hf_snapshot" || true + +if [[ -f "${STUDENT_TORCH_DIST_DIR}/latest_checkpointed_iteration.txt" ]]; then + echo "Student torch_dist conversion already exists: ${STUDENT_TORCH_DIST_DIR}" + exit 0 +fi + +source scripts/models/qwen3-8B.sh +mkdir -p "${STUDENT_TORCH_DIST_DIR}" +PYTHONPATH="/root/Megatron-LM:${SLIME_REPO_ROOT}:${PYTHONPATH:-}" \ +torchrun --nproc_per_node="${CONVERT_NPROC:-8}" tools/convert_hf_to_torch_dist.py \ + "${MODEL_ARGS[@]}" \ + --hf-checkpoint "${STUDENT_HF_DIR}" \ + --save "${STUDENT_TORCH_DIST_DIR}" +' + +echo "Finished model prep at $(date)" diff --git a/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch b/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch new file mode 100755 index 0000000000..e28f58b849 --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch @@ -0,0 +1,172 @@ +#!/bin/bash +#SBATCH --job-name=slime-qwen3-sft25k +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=64 +#SBATCH --gres=gpu:h200:8 +#SBATCH --mem=0 +#SBATCH --time=08:00:00 +#SBATCH --mail-user=suryadv@cs.washington.edu +#SBATCH --mail-type=END,FAIL + +set -euo pipefail + +if [[ -n "${SLURM_SUBMIT_DIR:-}" && -f "${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum/env.sh" ]]; then + SCRIPT_DIR="${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum" +else + SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +fi +source "${SCRIPT_DIR}/env.sh" +source "${SCRIPT_DIR}/checkpoint_utils.sh" + +mkdir -p "${SLURM_LOG_DIR}" "${SFT_SAVE_DIR}" "${SFT_HF_SNAPSHOT_DIR}" "${SFT_DETAILS_DIR}" "${CHECKPOINT_REPORT_DIR}" +if [[ -n "${SLURM_JOB_ID:-}" ]]; then + export RAY_TMPDIR="/tmp/${USER:-suryadv}/slime-ray-${SLURM_JOB_ID}" + mkdir -p "${RAY_TMPDIR}" + exec >"${SLURM_LOG_DIR}/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" 2>&1 +fi + +for required_path in "${SFT_PARQUET}" "${STUDENT_HF_DIR}" "${STUDENT_TORCH_DIST_DIR}/latest_checkpointed_iteration.txt"; do + if [[ ! -e "${required_path}" ]]; then + echo "Missing required path: ${required_path}" >&2 + exit 1 + fi +done + +echo "Starting SFT on $(hostname) at $(date)" +echo "SFT data: ${SFT_PARQUET}" +echo "Save dir: ${SFT_SAVE_DIR}" +echo "HF eval snapshots: ${SFT_HF_SNAPSHOT_DIR}" +echo "Expected final SFT rollout id: ${SFT_FINAL_ROLLOUT_ID}" + +checkpoint_start_pruner sft "${SFT_SAVE_DIR}" "${CHECKPOINT_PRUNE_INTERVAL_SECONDS}" "${CHECKPOINT_REPORT_DIR}" +cleanup() { + checkpoint_stop_pruner "${CHECKPOINT_PRUNER_PID:-}" +} +trap cleanup EXIT + +"${SCRIPT_DIR}/container_exec.sh" bash -lc ' +set -euo pipefail +cd "${SLIME_REPO_ROOT}" + +ray stop --force >/dev/null 2>&1 || true + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o "NV[0-9][0-9]*" | wc -l || true) +if [[ "${NVLINK_COUNT}" -gt 0 ]]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK=${HAS_NVLINK} (detected ${NVLINK_COUNT} NVLink references)" + +source scripts/models/qwen3-8B.sh + +CKPT_ARGS=( + --hf-checkpoint "${STUDENT_HF_DIR}" + --ref-load "${STUDENT_TORCH_DIST_DIR}" + --load "${SFT_SAVE_DIR}" + --save "${SFT_SAVE_DIR}" + --save-interval "${SFT_SAVE_INTERVAL}" + --save-hf "${SFT_HF_SNAPSHOT_TEMPLATE}" +) + +SFT_ARGS=( + --rollout-function-path slime.rollout.sft_rollout.generate_rollout + --prompt-data "${SFT_PARQUET}" + --input-key messages + --rollout-shuffle + --num-epoch "${SFT_NUM_EPOCH}" + --rollout-batch-size "${SFT_ROLLOUT_BATCH_SIZE}" + --global-batch-size "${SFT_GLOBAL_BATCH_SIZE}" + --loss-type sft_loss + --calculate-per-token-loss + --disable-compute-advantages-and-returns + --debug-train-only + --dump-details "${SFT_DETAILS_DIR}" +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + --use-dynamic-batch-size + --max-tokens-per-gpu "${SFT_MAX_TOKENS_PER_GPU}" +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-5 + --lr-decay-style cosine + --min-lr 1e-6 + --lr-warmup-fraction 0.1 + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.95 +) + +WANDB_ARGS=() +if [[ "${WANDB_MODE}" != "disabled" ]]; then + WANDB_ARGS+=(--wandb-project slime-tillicum --wandb-group qwen3-8b-sft25k) +fi + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +export MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}" +export no_proxy="127.0.0.1,localhost,${MASTER_ADDR}" +ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus 8 \ + --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM:${SLIME_REPO_ROOT}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"PYTORCH_CUDA_ALLOC_CONF\": \"expandable_segments:True\", + \"WANDB_MODE\": \"${WANDB_MODE}\", + \"HF_HOME\": \"${HF_HOME}\" + } +}" + +trap "ray stop --force >/dev/null 2>&1 || true" EXIT + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train_async.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 8 \ + "${MODEL_ARGS[@]}" \ + "${CKPT_ARGS[@]}" \ + "${SFT_ARGS[@]}" \ + "${OPTIMIZER_ARGS[@]}" \ + "${WANDB_ARGS[@]}" \ + "${PERF_ARGS[@]}" \ + "${MISC_ARGS[@]}" +' + +checkpoint_stop_pruner "${CHECKPOINT_PRUNER_PID:-}" +checkpoint_prune_old sft "${SFT_SAVE_DIR}" +checkpoint_report sft "${SFT_SAVE_DIR}" "${CHECKPOINT_REPORT_DIR}" +checkpoint_verify_single_latest sft "${SFT_SAVE_DIR}" + +if [[ -f "${SPLIT_METADATA}" ]]; then + cp -f "${SPLIT_METADATA}" "${SFT_SAVE_DIR}/split_metadata.json" +fi + +if [[ ! -d "${SFT_FINAL_HF_DIR}" ]]; then + echo "Missing final SFT HF eval snapshot: ${SFT_FINAL_HF_DIR}" >&2 + exit 1 +fi + +echo "Finished SFT at $(date)" diff --git a/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch b/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch new file mode 100755 index 0000000000..b3e74353e8 --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch @@ -0,0 +1,242 @@ +#!/bin/bash +#SBATCH --job-name=slime-qwen3-opd10k +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=64 +#SBATCH --gres=gpu:h200:8 +#SBATCH --mem=0 +#SBATCH --time=10:00:00 +#SBATCH --mail-user=suryadv@cs.washington.edu +#SBATCH --mail-type=END,FAIL + +set -euo pipefail + +if [[ -n "${SLURM_SUBMIT_DIR:-}" && -f "${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum/env.sh" ]]; then + SCRIPT_DIR="${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum" +else + SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +fi +source "${SCRIPT_DIR}/env.sh" +source "${SCRIPT_DIR}/checkpoint_utils.sh" + +mkdir -p "${SLURM_LOG_DIR}" "${OPD_SAVE_DIR}" "${OPD_HF_SNAPSHOT_DIR}" "${OPD_ROLLOUT_LOG_DIR}" "${TEACHER_LOG_DIR}" "${CHECKPOINT_REPORT_DIR}" +if [[ -n "${SLURM_JOB_ID:-}" ]]; then + export RAY_TMPDIR="/tmp/${USER:-suryadv}/slime-ray-${SLURM_JOB_ID}" + mkdir -p "${RAY_TMPDIR}" + exec >"${SLURM_LOG_DIR}/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" 2>&1 +fi + +for required_path in \ + "${OPD_JSONL}" \ + "${STUDENT_HF_DIR}" \ + "${STUDENT_TORCH_DIST_DIR}/latest_checkpointed_iteration.txt" \ + "${TEACHER_HF_DIR}" \ + "${SFT_FINAL_HF_DIR}"; do + if [[ ! -e "${required_path}" ]]; then + echo "Missing required path: ${required_path}" >&2 + exit 1 + fi +done + +echo "Starting OPD on $(hostname) at $(date)" +echo "OPD data: ${OPD_JSONL}" +echo "Teacher: ${TEACHER_HF_DIR} on physical GPU ${OPD_TEACHER_GPU}" +echo "Ray GPUs: 0-$((OPD_RAY_GPUS - 1))" +echo "Save dir: ${OPD_SAVE_DIR}" +echo "HF eval snapshots: ${OPD_HF_SNAPSHOT_DIR}" +echo "Expected final OPD rollout id: ${OPD_FINAL_ROLLOUT_ID}" + +checkpoint_start_pruner opd "${OPD_SAVE_DIR}" "${CHECKPOINT_PRUNE_INTERVAL_SECONDS}" "${CHECKPOINT_REPORT_DIR}" +cleanup_pruner() { + checkpoint_stop_pruner "${CHECKPOINT_PRUNER_PID:-}" +} +trap cleanup_pruner EXIT + +"${SCRIPT_DIR}/container_exec.sh" bash -lc ' +set -euo pipefail +cd "${SLIME_REPO_ROOT}" + +ray stop --force >/dev/null 2>&1 || true + +TEACHER_IP="127.0.0.1" +TEACHER_PORT="${OPD_TEACHER_PORT}" +TEACHER_LOG="${TEACHER_LOG_DIR}/sglang_teacher_${SLURM_JOB_ID:-manual}.log" + +CUDA_VISIBLE_DEVICES="${OPD_TEACHER_GPU}" python3 -m sglang.launch_server \ + --model-path "${TEACHER_HF_DIR}" \ + --host 0.0.0.0 \ + --port "${TEACHER_PORT}" \ + --tp 1 \ + --chunked-prefill-size 4096 \ + --mem-fraction-static "${OPD_TEACHER_MEM_FRACTION}" \ + >"${TEACHER_LOG}" 2>&1 & +TEACHER_PID=$! + +cleanup() { + set +e + if [[ -n "${TEACHER_PID:-}" ]]; then + kill "${TEACHER_PID}" >/dev/null 2>&1 || true + fi + ray stop --force >/dev/null 2>&1 || true +} +trap cleanup EXIT + +echo "Waiting for teacher server on ${TEACHER_IP}:${TEACHER_PORT}" +until curl -sf "http://${TEACHER_IP}:${TEACHER_PORT}/health_generate" >/dev/null; do + tail -n 20 "${TEACHER_LOG}" || true + sleep 10 +done +curl "http://${TEACHER_IP}:${TEACHER_PORT}/get_model_info" || true +sleep 10 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o "NV[0-9][0-9]*" | wc -l || true) +if [[ "${NVLINK_COUNT}" -gt 0 ]]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK=${HAS_NVLINK} (detected ${NVLINK_COUNT} NVLink references)" + +source scripts/models/qwen3-8B.sh + +if [[ -f "${OPD_SAVE_DIR}/latest_checkpointed_iteration.txt" ]]; then + OPD_LOAD_DIR="${OPD_SAVE_DIR}" + echo "Resuming OPD from ${OPD_LOAD_DIR}" +else + OPD_LOAD_DIR="${SFT_FINAL_HF_DIR}" + echo "Starting OPD from final SFT HF snapshot ${OPD_LOAD_DIR}" +fi + +CKPT_ARGS=( + --megatron-to-hf-mode bridge + --hf-checkpoint "${STUDENT_HF_DIR}" + --ref-load "${STUDENT_TORCH_DIST_DIR}" + --load "${OPD_LOAD_DIR}" + --save "${OPD_SAVE_DIR}" + --save-interval "${OPD_SAVE_INTERVAL}" + --save-hf "${OPD_HF_SNAPSHOT_TEMPLATE}" +) + +ROLLOUT_ARGS=( + --prompt-data "${OPD_JSONL}" + --input-key prompt + --apply-chat-template + --rollout-shuffle + --num-rollout "${OPD_NUM_ROLLOUT}" + --rollout-batch-size "${OPD_ROLLOUT_BATCH_SIZE}" + --n-samples-per-prompt "${OPD_N_SAMPLES_PER_PROMPT}" + --rollout-max-response-len "${OPD_MAX_RESPONSE_LEN}" + --rollout-temperature 1 + --rollout-top-p 1 + --global-batch-size "${OPD_GLOBAL_BATCH_SIZE}" + --balance-data + --save-debug-rollout-data "${OPD_ROLLOUT_LOG_DIR}/rollout_{rollout_id}.pt" +) + +RM_ARGS=( + --custom-rm-path slime.rollout.on_policy_distillation.reward_func + --custom-reward-post-process-path slime.rollout.on_policy_distillation.post_process_rewards + --rm-url "http://${TEACHER_IP}:${TEACHER_PORT}/generate" +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + --use-dynamic-batch-size + --max-tokens-per-gpu "${OPD_MAX_TOKENS_PER_GPU}" +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-opd + --opd-type sglang + --opd-kl-coef 1.0 + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=() +if [[ "${WANDB_MODE}" != "disabled" ]]; then + WANDB_ARGS+=(--wandb-project slime-tillicum --wandb-group qwen3-8b-opd10k) +fi + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.4 +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +export MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}" +export no_proxy="127.0.0.1,localhost,${MASTER_ADDR}" +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 ray start --head --node-ip-address "${MASTER_ADDR}" \ + --num-gpus "${OPD_RAY_GPUS}" \ + --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM:${SLIME_REPO_ROOT}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"WANDB_MODE\": \"${WANDB_MODE}\", + \"HF_HOME\": \"${HF_HOME}\" + } +}" + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node "${OPD_ACTOR_GPUS}" \ + --rollout-num-gpus "${OPD_ROLLOUT_GPUS}" \ + "${MODEL_ARGS[@]}" \ + "${CKPT_ARGS[@]}" \ + "${ROLLOUT_ARGS[@]}" \ + "${OPTIMIZER_ARGS[@]}" \ + "${GRPO_ARGS[@]}" \ + "${WANDB_ARGS[@]}" \ + "${PERF_ARGS[@]}" \ + "${SGLANG_ARGS[@]}" \ + "${MISC_ARGS[@]}" \ + "${RM_ARGS[@]}" +' + +checkpoint_stop_pruner "${CHECKPOINT_PRUNER_PID:-}" +checkpoint_prune_old opd "${OPD_SAVE_DIR}" +checkpoint_report opd "${OPD_SAVE_DIR}" "${CHECKPOINT_REPORT_DIR}" +checkpoint_verify_single_latest opd "${OPD_SAVE_DIR}" + +if [[ -f "${SPLIT_METADATA}" ]]; then + cp -f "${SPLIT_METADATA}" "${OPD_SAVE_DIR}/split_metadata.json" +fi + +if [[ ! -d "${OPD_FINAL_HF_DIR}" ]]; then + echo "Missing final OPD HF eval snapshot: ${OPD_FINAL_HF_DIR}" >&2 + exit 1 +fi + +echo "Finished OPD at $(date)" diff --git a/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch b/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch new file mode 100755 index 0000000000..38ac7ce32c --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch @@ -0,0 +1,286 @@ +#!/bin/bash +#SBATCH --job-name=slime-qwen3-math500 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=64 +#SBATCH --gres=gpu:h200:8 +#SBATCH --mem=0 +#SBATCH --time=02:00:00 +#SBATCH --mail-user=suryadv@cs.washington.edu +#SBATCH --mail-type=END,FAIL + +set -euo pipefail + +if [[ -n "${SLURM_SUBMIT_DIR:-}" && -f "${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum/env.sh" ]]; then + SCRIPT_DIR="${SLURM_SUBMIT_DIR}/examples/qwen3_8b_opd_tillicum" +else + SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +fi +source "${SCRIPT_DIR}/env.sh" + +mkdir -p "${SLURM_LOG_DIR}" "${EVAL_OUTPUT_DIR}" "${DATA_ROOT}" +if [[ -n "${SLURM_JOB_ID:-}" ]]; then + export RAY_TMPDIR="/tmp/${USER:-suryadv}/slime-ray-${SLURM_JOB_ID}" + mkdir -p "${RAY_TMPDIR}" + exec >"${SLURM_LOG_DIR}/${SLURM_JOB_NAME}_${SLURM_JOB_ID}.log" 2>&1 +fi + +echo "Starting MATH-500 greedy eval on $(hostname) at $(date)" +echo "Eval targets: ${EVAL_TARGETS}" +echo "Eval output: ${EVAL_OUTPUT_DIR}" + +"${SCRIPT_DIR}/container_exec.sh" bash -lc ' +set -euo pipefail +cd "${SLIME_REPO_ROOT}" + +python3 - <&2 + exit 1 + fi + + echo "Evaluating ${stage} from ${load_dir} train_samples=${train_samples}" + ray stop --force >/dev/null 2>&1 || true + + NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o "NV[0-9][0-9]*" | wc -l || true) + if [[ "${NVLINK_COUNT}" -gt 0 ]]; then + HAS_NVLINK=1 + else + HAS_NVLINK=0 + fi + + source scripts/models/qwen3-8B.sh + + CKPT_ARGS=( + --megatron-to-hf-mode bridge + --hf-checkpoint "${load_dir}" + --ref-load "${STUDENT_TORCH_DIST_DIR}" + --load "${load_dir}" + --save "${stage_dir}/unused_save" + --save-interval 100000 + --no-save-optim + ) + + ROLLOUT_ARGS=( + --prompt-data "${MATH500_JSONL}" + --input-key prompt + --apply-chat-template + --num-rollout 0 + --rollout-batch-size "${EVAL_ROLLOUT_BATCH_SIZE}" + --n-samples-per-prompt 1 + --rollout-max-response-len 1024 + --rollout-temperature 0 + --rollout-top-p 1 + --global-batch-size "${EVAL_ROLLOUT_BATCH_SIZE}" + --rm-type math + --save-debug-rollout-data "${stage_dir}/debug_{rollout_id}.pt" + ) + + EVAL_ARGS=( + --eval-interval 1 + --eval-config "${MATH500_CONFIG}" + --n-samples-per-eval-prompt 1 + --eval-temperature 0 + --eval-top-p 1 + --eval-max-response-len "${EVAL_MAX_RESPONSE_LEN}" + --log-passrate + ) + + PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + --use-dynamic-batch-size + --max-tokens-per-gpu "${EVAL_MAX_RESPONSE_LEN}" + ) + + OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + ) + + GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + ) + + SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 + --sglang-server-concurrency "${EVAL_SGLANG_SERVER_CONCURRENCY}" + ) + + MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash + ) + + export MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}" + export no_proxy="127.0.0.1,localhost,${MASTER_ADDR}" + ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${EVAL_ROLLOUT_NUM_GPUS}" \ + --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + + RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM:${SLIME_REPO_ROOT}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"WANDB_MODE\": \"${WANDB_MODE}\", + \"HF_HOME\": \"${HF_HOME}\" + } + }" + + ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --debug-rollout-only \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 8 \ + --rollout-num-gpus "${EVAL_ROLLOUT_NUM_GPUS}" \ + --colocate \ + "${MODEL_ARGS[@]}" \ + "${CKPT_ARGS[@]}" \ + "${ROLLOUT_ARGS[@]}" \ + "${OPTIMIZER_ARGS[@]}" \ + "${GRPO_ARGS[@]}" \ + "${PERF_ARGS[@]}" \ + "${EVAL_ARGS[@]}" \ + "${SGLANG_ARGS[@]}" \ + "${MISC_ARGS[@]}" + + ray stop --force >/dev/null 2>&1 || true + + python3 examples/qwen3_8b_opd_tillicum/summarize_eval.py \ + --stage "${stage}" \ + --debug-file "${stage_dir}/debug_eval_0.pt" \ + --out-json "${stage_dir}/summary.json" \ + --max-response-len "${EVAL_MAX_RESPONSE_LEN}" \ + --train-samples "${train_samples}" +} + +hf_snapshot_dir() { + local root="$1" + local rollout_id="$2" + printf "%s/iter_%07d" "${root}" "${rollout_id}" +} + +sample_count_for_rollout() { + local rollout_id="$1" + local batch_size="$2" + local total_size="$3" + local count=$(((rollout_id + 1) * batch_size)) + if [[ "${count}" -gt "${total_size}" ]]; then + count="${total_size}" + fi + printf "%s" "${count}" +} + +run_sft_curve() { + local rid samples dir stage + for rid in ${SFT_MILESTONE_ROLLOUT_IDS}; do + samples="$(sample_count_for_rollout "${rid}" "${SFT_ROLLOUT_BATCH_SIZE}" "${SFT_SIZE}")" + dir="$(hf_snapshot_dir "${SFT_HF_SNAPSHOT_DIR}" "${rid}")" + stage="$(printf "sft_%06d" "${samples}")" + run_eval "${stage}" "${dir}" "${samples}" + done +} + +run_opd_curve() { + local rid samples dir stage + for rid in ${OPD_MILESTONE_ROLLOUT_IDS}; do + samples="$(sample_count_for_rollout "${rid}" "${OPD_ROLLOUT_BATCH_SIZE}" "${OPD_SIZE}")" + dir="$(hf_snapshot_dir "${OPD_HF_SNAPSHOT_DIR}" "${rid}")" + stage="$(printf "opd_%06d" "${samples}")" + run_eval "${stage}" "${dir}" "${samples}" + done +} + +trap "ray stop --force >/dev/null 2>&1 || true" EXIT + +for target in ${EVAL_TARGETS}; do + case "${target}" in + base) + run_eval base "${STUDENT_HF_DIR}" 0 + ;; + sft) + run_sft_curve + ;; + opd) + run_opd_curve + ;; + *) + echo "Unknown EVAL_TARGETS entry: ${target}" >&2 + exit 2 + ;; + esac +done + +python3 examples/qwen3_8b_opd_tillicum/summarize_eval.py \ + --aggregate-dir "${EVAL_OUTPUT_DIR}" \ + --out-json "${EVAL_OUTPUT_DIR}/summary_all.json" +' + +echo "Finished MATH-500 eval at $(date)" diff --git a/examples/qwen3_8b_opd_tillicum/README.md b/examples/qwen3_8b_opd_tillicum/README.md new file mode 100644 index 0000000000..e55aa8114f --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/README.md @@ -0,0 +1,174 @@ +# Qwen3-8B SFT + OPD on Tillicum + +This directory contains thin Tillicum wrappers for a Qwen3-8B SFT plus +on-policy distillation experiment using OpenThoughts3-1.2M math data. + +The wrappers intentionally reuse slime's existing paths: + +- `docs/en/get_started/quick_start.md` for the container and Megatron + conversion workflow. +- `scripts/models/qwen3-8B.sh` for the student Megatron model args. +- `examples/on_policy_distillation/run-qwen3-8B-opd.sh` for the SGLang + teacher OPD shape. +- `scripts/run-qwen3-4B-base-sft.sh` for the SFT rollout/training shape. +- `examples/eval_multi_task/multi_task.sh` and `multi_task.yaml` for eval. + +No slime core code is modified. + +## Required environment + +Source `env.sh` before running commands: + +```bash +cd /gpfs/scrubbed/suryadv/repos/slime +source examples/qwen3_8b_opd_tillicum/env.sh +``` + +Important variables: + +- `ACCOUNT`: Slurm account. Default: `raivn`. +- `PARTITION`: Slurm partition. Default: `gpu-h200`. +- `QOS`: Slurm QOS. Default: `normal`. +- `SCRATCH_ROOT`: root for all generated data, checkpoints, caches, logs, and + the Apptainer image. Default: + `/gpfs/scrubbed/suryadv/slime-qwen3-8b-opd`. +- `DATA_ROOT`: prepared datasets. Default: `$SCRATCH_ROOT/data`. +- `MODEL_ROOT`: HF model snapshots and Megatron torch_dist conversion. + Default: `$SCRATCH_ROOT/models`. +- `OUTPUT_ROOT`: training/eval outputs. Default: `$SCRATCH_ROOT/outputs`. +- `HF_HOME`: Hugging Face cache under scratch. Default: `$SCRATCH_ROOT/hf_home`. +- `WANDB_MODE`: default `offline`. +- `SLIME_CONTAINER_FORMAT`: Apptainer image format. Default: `sandbox`, + because Tillicum's Apptainer produced invalid SquashFS SIFs for this large + Docker image during testing. Set to `sif` to force SIF output. +- `SLIME_SIF`: Apptainer image/sandbox path. Default: + `$SCRATCH_ROOT/containers/slime_latest.sandbox`. + +The scripts avoid writing caches/checkpoints/data under home. + +## Dry checks + +Dry checks do not pull the container, download data/models, install packages, or +submit real jobs. + +```bash +cd /gpfs/scrubbed/suryadv/repos/slime +source examples/qwen3_8b_opd_tillicum/env.sh +bash examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh +``` + +The dry check runs `bash -n`, Python bytecode compilation, and +`sbatch --test-only` for the four Slurm scripts. If `$SLIME_SIF` already +exists and `RUN_CONTAINER_CHECKS=1` is set, it also verifies imports inside the +container. + +## Setup and launch after approval + +Run the setup steps manually, in this order: + +```bash +bash examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh +bash examples/qwen3_8b_opd_tillicum/01_prepare_env.sh +bash examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh +``` + +Then launch only after explicit approval: + +```bash +bash examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh +``` + +## Reproducing On Another Slurm Cluster + +These wrappers are Tillicum-shaped, but the core workflow is portable to a +Slurm cluster with Apptainer or Singularity, one 8-GPU node per job, outbound +access to Hugging Face and Docker Hub, and enough scratch space for large model +checkpoints. + +Clone the fork and switch to the reproduction branch: + +```bash +git clone https://github.com/suryathecreator/slime.git +cd slime +git checkout opd-reproduction +``` + +Choose cluster-local paths and Slurm settings. Keep all generated state on +scratch or project storage, not home: + +```bash +export ACCOUNT="" +export PARTITION="" +export QOS="" +export SCRATCH_ROOT="/path/to/scratch/${USER}/slime-qwen3-8b-opd" +export CONTAINER_BIND_ROOTS="$(pwd),${SCRATCH_ROOT},/tmp" + +# Match your site's GPU gres. Examples: gpu:8, gpu:a100:8, gpu:h100:8. +export GPU_GRES="gpu:8" + +# Use sif if your Apptainer can build a normal SIF for slimerl/slime:latest. +export SLIME_CONTAINER_FORMAT="sandbox" + +source examples/qwen3_8b_opd_tillicum/env.sh +``` + +Run the same preparation commands: + +```bash +bash examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh +bash examples/qwen3_8b_opd_tillicum/01_prepare_env.sh +bash examples/qwen3_8b_opd_tillicum/container_exec.sh \ + python examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py +``` + +Submit the dependency chain, overriding the Tillicum `h200` gres embedded in +the sbatch files: + +```bash +GPU_GRES=gpu:h200:8 bash examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh +``` + +If your cluster uses `--gpus-per-node` instead of `--gres`, replace the +`--gres "$GPU_GRES"` arguments above with your site's GPU request flag. If your +cluster uses Docker rather than Apptainer/Singularity, run the same Python and +Slurm entrypoints inside `slimerl/slime:latest` and bind the repo plus +`$SCRATCH_ROOT` into the container; `container_exec.sh` is the only +Apptainer-specific layer. + +## Outputs + +- SFT split: `$SFT_PARQUET` (JSONL by default despite the legacy variable + name). +- OPD prompt split: `$OPD_JSONL` +- Data metadata: `$SPLIT_METADATA` +- Student HF snapshot: `$STUDENT_HF_DIR` +- Teacher HF snapshot: `$TEACHER_HF_DIR` +- Student Megatron torch_dist: `$STUDENT_TORCH_DIST_DIR` +- SFT full optimizer checkpoint: `$SFT_SAVE_DIR` +- OPD full optimizer checkpoint: `$OPD_SAVE_DIR` +- SFT model-only eval snapshots: `$SFT_HF_SNAPSHOT_DIR` +- OPD model-only eval snapshots: `$OPD_HF_SNAPSHOT_DIR` +- Eval summaries and curves: `$BASE_EVAL_OUTPUT_DIR`, `$SFT_EVAL_OUTPUT_DIR`, + `$OPD_EVAL_OUTPUT_DIR` +- Checkpoint storage reports: `$CHECKPOINT_REPORT_DIR` + +## Slurm resources + +Each job requests one node with `--gres=gpu:h200:8`, `--ntasks=1`, +`--cpus-per-task=64`, and all node memory. The account, partition, and QOS are +passed at submit time from the environment variables above. + +The intended wall-clock budget after model/data/container preparation is: + +- SFT 25k: 8 hours. +- OPD 10k: 10 hours. +- MATH-500 greedy eval: base 2 hours, SFT curve 5 hours, OPD curve 5 hours. + +The main runtime risk is the Qwen3-32B teacher logprob server throughput during +OPD. The current conservative chain runs SFT, evaluates SFT milestones, runs +OPD, evaluates OPD milestones, then runs the fixed base eval so SFT results are +available before OPD starts. + +MATH-500 summaries report `accuracy` with parse failures counted wrong, +`accuracy_on_parseable` as a diagnostic over parseable responses only, and +`parse_failure_rate` separately. diff --git a/examples/qwen3_8b_opd_tillicum/checkpoint_utils.sh b/examples/qwen3_8b_opd_tillicum/checkpoint_utils.sh new file mode 100755 index 0000000000..f5c35c61ce --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/checkpoint_utils.sh @@ -0,0 +1,130 @@ +#!/usr/bin/env bash + +checkpoint_latest_iteration() { + local save_dir="$1" + local tracker="${save_dir}/latest_checkpointed_iteration.txt" + if [[ ! -f "${tracker}" ]]; then + return 1 + fi + tr -d '[:space:]' <"${tracker}" +} + +checkpoint_iter_name() { + local iteration="$1" + printf "iter_%07d" "${iteration}" +} + +checkpoint_latest_dir() { + local save_dir="$1" + local iteration + iteration="$(checkpoint_latest_iteration "${save_dir}")" || return 1 + printf "%s/%s" "${save_dir}" "$(checkpoint_iter_name "${iteration}")" +} + +checkpoint_prune_old() { + local stage="$1" + local save_dir="$2" + local latest_dir + latest_dir="$(checkpoint_latest_dir "${save_dir}")" || return 0 + + if [[ ! -d "${latest_dir}" ]]; then + echo "CHECKPOINT_PRUNE_WAIT stage=${stage} latest_dir_missing=${latest_dir}" + return 0 + fi + + local dir + find "${save_dir}" -maxdepth 1 -type d -name 'iter_*' -print | while IFS= read -r dir; do + if [[ "${dir}" != "${latest_dir}" ]]; then + echo "CHECKPOINT_PRUNE_REMOVE stage=${stage} dir=${dir}" + rm -rf -- "${dir}" + fi + done +} + +checkpoint_report() { + local stage="$1" + local save_dir="$2" + local report_dir="$3" + mkdir -p "${report_dir}" + + local iteration latest_dir bytes report + iteration="$(checkpoint_latest_iteration "${save_dir}")" + latest_dir="${save_dir}/$(checkpoint_iter_name "${iteration}")" + if [[ ! -d "${latest_dir}" ]]; then + echo "Missing latest checkpoint dir: ${latest_dir}" >&2 + return 1 + fi + + bytes="$(du -sb "${latest_dir}" | awk '{print $1}')" + report="${report_dir}/${stage}_checkpoint_storage.tsv" + if [[ ! -f "${report}" ]]; then + printf "stage\titeration\tcheckpoint_dir\tbytes\testimated_full_optim_bytes\n" >"${report}" + fi + printf "%s\t%s\t%s\t%s\t%s\n" \ + "${stage}" \ + "${iteration}" \ + "${latest_dir}" \ + "${bytes}" \ + "${ESTIMATED_FULL_OPTIM_CKPT_BYTES:-}" >>"${report}" + + echo "CHECKPOINT_SIZE_BYTES stage=${stage} iter=${iteration} bytes=${bytes} dir=${latest_dir}" +} + +checkpoint_verify_single_latest() { + local stage="$1" + local save_dir="$2" + local latest_dir + latest_dir="$(checkpoint_latest_dir "${save_dir}")" + + local count + count="$(find "${save_dir}" -maxdepth 1 -type d -name 'iter_*' | wc -l)" + if [[ "${count}" -ne 1 ]]; then + echo "CHECKPOINT_PRUNE_FAIL stage=${stage} expected_count=1 actual_count=${count}" >&2 + find "${save_dir}" -maxdepth 1 -type d -name 'iter_*' -print >&2 + return 1 + fi + + if [[ ! -d "${latest_dir}" ]]; then + echo "CHECKPOINT_PRUNE_FAIL stage=${stage} missing_latest=${latest_dir}" >&2 + return 1 + fi + + echo "CHECKPOINT_PRUNE_OK stage=${stage} remaining_iter=$(basename "${latest_dir}")" +} + +checkpoint_start_pruner() { + local stage="$1" + local save_dir="$2" + local interval_seconds="$3" + local report_dir="${4:-}" + + ( + set +e + last_reported_iteration="" + while true; do + checkpoint_prune_old "${stage}" "${save_dir}" + if [[ -n "${report_dir}" ]]; then + iteration="$(checkpoint_latest_iteration "${save_dir}" 2>/dev/null)" + if [[ -n "${iteration}" && "${iteration}" != "${last_reported_iteration}" ]]; then + latest_dir="${save_dir}/$(checkpoint_iter_name "${iteration}")" + if [[ -d "${latest_dir}" ]]; then + checkpoint_report "${stage}" "${save_dir}" "${report_dir}" + last_reported_iteration="${iteration}" + fi + fi + fi + sleep "${interval_seconds}" + done + ) & + CHECKPOINT_PRUNER_PID="$!" + export CHECKPOINT_PRUNER_PID + echo "CHECKPOINT_PRUNER_STARTED stage=${stage} pid=${CHECKPOINT_PRUNER_PID}" +} + +checkpoint_stop_pruner() { + local pid="${1:-}" + if [[ -n "${pid}" ]]; then + kill "${pid}" >/dev/null 2>&1 || true + wait "${pid}" >/dev/null 2>&1 || true + fi +} diff --git a/examples/qwen3_8b_opd_tillicum/container_exec.sh b/examples/qwen3_8b_opd_tillicum/container_exec.sh new file mode 100755 index 0000000000..d163b2b570 --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/container_exec.sh @@ -0,0 +1,161 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +source "${SCRIPT_DIR}/env.sh" + +if [[ $# -eq 0 ]]; then + echo "Usage: $0 [args...]" + exit 2 +fi + +if [[ ! -e "${SLIME_SIF}" ]]; then + cat >&2 </dev/null 2>&1; then + APPTAINER_BIN=apptainer + elif command -v singularity >/dev/null 2>&1; then + APPTAINER_BIN=singularity + else + echo "Neither apptainer nor singularity is available on PATH." >&2 + exit 1 + fi +fi + +mkdir -p \ + "${APPTAINER_CACHEDIR}" \ + "${APPTAINER_TMPDIR}" \ + "${CONTAINER_HOME}" \ + "${TMPDIR}" \ + "${RAY_TMPDIR}" + +export APPTAINER_CACHEDIR APPTAINER_TMPDIR + +IFS=',' read -r -a BIND_ROOTS <<< "${CONTAINER_BIND_ROOTS}" +APPTAINER_ARGS=( + exec + --nv + --cleanenv + --ipc + --writable-tmpfs + --cwd "${SLIME_REPO_ROOT}" + --home "${CONTAINER_HOME}:${CONTAINER_HOME_INNER}" +) + +for bind_root in "${BIND_ROOTS[@]}"; do + if [[ -e "${bind_root}" ]]; then + APPTAINER_ARGS+=(--bind "${bind_root}:${bind_root}") + fi +done + +PASS_ENV=( + ACCOUNT + PARTITION + QOS + SCRATCH_ROOT + DATA_ROOT + MODEL_ROOT + OUTPUT_ROOT + HF_HOME + HF_DATASETS_CACHE + TRANSFORMERS_CACHE + WANDB_MODE + WANDB_DIR + TMPDIR + RAY_TMPDIR + SLIME_REPO_ROOT + TILLICUM_EXAMPLE_DIR + STUDENT_HF_REPO + TEACHER_HF_REPO + OT3_DATASET + OT3_SPLIT + MATH500_DATASET + SFT_SIZE + OPD_SIZE + DATA_SEED + DATA_MATH_FIELD + DATA_MATH_VALUE + STUDENT_HF_DIR + TEACHER_HF_DIR + STUDENT_TORCH_DIST_DIR + SFT_PARQUET + OPD_JSONL + SPLIT_METADATA + MATH500_JSONL + MATH500_CONFIG + SFT_SAVE_DIR + OPD_SAVE_DIR + SFT_HF_SNAPSHOT_DIR + OPD_HF_SNAPSHOT_DIR + SFT_HF_SNAPSHOT_TEMPLATE + OPD_HF_SNAPSHOT_TEMPLATE + SFT_FINAL_HF_DIR + OPD_FINAL_HF_DIR + SFT_DETAILS_DIR + OPD_ROLLOUT_LOG_DIR + TEACHER_LOG_DIR + EVAL_OUTPUT_DIR + BASE_EVAL_OUTPUT_DIR + SFT_EVAL_OUTPUT_DIR + OPD_EVAL_OUTPUT_DIR + CHECKPOINT_REPORT_DIR + SLURM_LOG_DIR + SFT_NUM_EPOCH + SFT_NUM_ROLLOUT + SFT_FINAL_ROLLOUT_ID + SFT_MILESTONE_ROLLOUT_IDS + SFT_ROLLOUT_BATCH_SIZE + SFT_GLOBAL_BATCH_SIZE + SFT_MAX_TOKENS_PER_GPU + SFT_SAVE_INTERVAL + OPD_NUM_ROLLOUT + OPD_FINAL_ROLLOUT_ID + OPD_MILESTONE_ROLLOUT_IDS + OPD_ROLLOUT_BATCH_SIZE + OPD_GLOBAL_BATCH_SIZE + OPD_N_SAMPLES_PER_PROMPT + OPD_MAX_RESPONSE_LEN + OPD_MAX_TOKENS_PER_GPU + OPD_SAVE_INTERVAL + OPD_ACTOR_GPUS + OPD_ROLLOUT_GPUS + OPD_RAY_GPUS + OPD_TEACHER_GPU + OPD_TEACHER_PORT + OPD_TEACHER_MEM_FRACTION + CHECKPOINT_PRUNE_INTERVAL_SECONDS + ESTIMATED_FULL_OPTIM_CKPT_BYTES + EVAL_MAX_RESPONSE_LEN + EVAL_ROLLOUT_BATCH_SIZE + EVAL_ROLLOUT_NUM_GPUS + EVAL_NUM_REPEATS + EVAL_SGLANG_SERVER_CONCURRENCY + EVAL_TARGETS + PYTHONUNBUFFERED + TOKENIZERS_PARALLELISM + NCCL_DEBUG +) + +for name in "${PASS_ENV[@]}"; do + if [[ -n "${!name-}" ]]; then + APPTAINER_ARGS+=(--env "${name}=${!name}") + fi +done + +APPTAINER_ARGS+=( + --env "HOME=${CONTAINER_HOME_INNER}" + --env "PYTHONPATH=${SLIME_REPO_ROOT}:/root/Megatron-LM:${PYTHONPATH:-}" + --env "CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS:-1}" +) + +"${APPTAINER_BIN}" "${APPTAINER_ARGS[@]}" "${SLIME_SIF}" "$@" diff --git a/examples/qwen3_8b_opd_tillicum/env.sh b/examples/qwen3_8b_opd_tillicum/env.sh new file mode 100755 index 0000000000..c23ec90a39 --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/env.sh @@ -0,0 +1,131 @@ +#!/usr/bin/env bash + +# Source this file from the repository root or from any script in this +# directory. All generated state is kept under scrubbed/scratch storage. + +if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then + echo "Source this file instead of executing it:" + echo " source ${BASH_SOURCE[0]}" + exit 2 +fi + +export TILLICUM_EXAMPLE_DIR +TILLICUM_EXAMPLE_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" + +export SLIME_REPO_ROOT +SLIME_REPO_ROOT="$(cd -- "${TILLICUM_EXAMPLE_DIR}/../.." >/dev/null 2>&1 && pwd)" + +export ACCOUNT="${ACCOUNT:-raivn}" +export PARTITION="${PARTITION:-gpu-h200}" +export QOS="${QOS:-normal}" + +export SCRATCH_ROOT="${SCRATCH_ROOT:-/gpfs/scrubbed/suryadv/slime-qwen3-8b-opd}" +export DATA_ROOT="${DATA_ROOT:-${SCRATCH_ROOT}/data}" +export MODEL_ROOT="${MODEL_ROOT:-${SCRATCH_ROOT}/models}" +export OUTPUT_ROOT="${OUTPUT_ROOT:-${SCRATCH_ROOT}/outputs}" +export HF_HOME="${HF_HOME:-${SCRATCH_ROOT}/hf_home}" +export HF_DATASETS_CACHE="${HF_DATASETS_CACHE:-${HF_HOME}/datasets}" +export TRANSFORMERS_CACHE="${TRANSFORMERS_CACHE:-${HF_HOME}/transformers}" +export WANDB_MODE="${WANDB_MODE:-offline}" +export WANDB_DIR="${WANDB_DIR:-${OUTPUT_ROOT}/wandb}" +export TMPDIR="${TMPDIR:-${SCRATCH_ROOT}/tmp}" +export RAY_TMPDIR="${RAY_TMPDIR:-${SCRATCH_ROOT}/ray_tmp}" + +export APPTAINER_CACHEDIR="${APPTAINER_CACHEDIR:-${SCRATCH_ROOT}/apptainer_cache}" +export APPTAINER_TMPDIR="${APPTAINER_TMPDIR:-${SCRATCH_ROOT}/apptainer_tmp}" +export SLIME_IMAGE_URI="${SLIME_IMAGE_URI:-docker://slimerl/slime:latest}" +export SLIME_CONTAINER_FORMAT="${SLIME_CONTAINER_FORMAT:-sandbox}" +if [[ -z "${SLIME_SIF:-}" ]]; then + if [[ "${SLIME_CONTAINER_FORMAT}" == "sandbox" ]]; then + export SLIME_SIF="${SCRATCH_ROOT}/containers/slime_latest.sandbox" + else + export SLIME_SIF="${SCRATCH_ROOT}/containers/slime_latest.sif" + fi +else + export SLIME_SIF +fi +export CONTAINER_BIND_ROOTS="${CONTAINER_BIND_ROOTS:-/gpfs/scrubbed/suryadv,/tmp}" +export CONTAINER_HOME="${CONTAINER_HOME:-${SCRATCH_ROOT}/container_home}" +export CONTAINER_HOME_INNER="${CONTAINER_HOME_INNER:-/home/${USER:-slime}}" + +export STUDENT_HF_REPO="${STUDENT_HF_REPO:-Qwen/Qwen3-8B-Base}" +export TEACHER_HF_REPO="${TEACHER_HF_REPO:-Qwen/Qwen3-32B}" +export OT3_DATASET="${OT3_DATASET:-open-thoughts/OpenThoughts3-1.2M}" +export MATH500_DATASET="${MATH500_DATASET:-HuggingFaceH4/MATH-500}" + +export STUDENT_HF_DIR="${STUDENT_HF_DIR:-${MODEL_ROOT}/Qwen3-8B-Base}" +export TEACHER_HF_DIR="${TEACHER_HF_DIR:-${MODEL_ROOT}/Qwen3-32B}" +export STUDENT_TORCH_DIST_DIR="${STUDENT_TORCH_DIST_DIR:-${MODEL_ROOT}/Qwen3-8B-Base_torch_dist}" + +export SFT_SIZE="${SFT_SIZE:-25000}" +export OPD_SIZE="${OPD_SIZE:-10000}" +export DATA_SEED="${DATA_SEED:-1234}" +export DATA_MATH_FIELD="${DATA_MATH_FIELD:-domain}" +export DATA_MATH_VALUE="${DATA_MATH_VALUE:-math}" +export OT3_SPLIT="${OT3_SPLIT:-train}" + +export SFT_PARQUET="${SFT_PARQUET:-${DATA_ROOT}/openthoughts3_math_sft_${SFT_SIZE}.jsonl}" +export OPD_JSONL="${OPD_JSONL:-${DATA_ROOT}/openthoughts3_math_opd_${OPD_SIZE}.jsonl}" +export SPLIT_METADATA="${SPLIT_METADATA:-${DATA_ROOT}/openthoughts3_math_sft_${SFT_SIZE}_opd_${OPD_SIZE}_metadata.json}" +export MATH500_JSONL="${MATH500_JSONL:-${DATA_ROOT}/math500_deepscaler.jsonl}" +export MATH500_CONFIG="${MATH500_CONFIG:-${DATA_ROOT}/math500_eval.yaml}" + +export SFT_SAVE_DIR="${SFT_SAVE_DIR:-${OUTPUT_ROOT}/qwen3_8b_sft_25k_full_optim}" +export OPD_SAVE_DIR="${OPD_SAVE_DIR:-${OUTPUT_ROOT}/qwen3_8b_sft_25k_opd_10k_full_optim}" +export SFT_HF_SNAPSHOT_DIR="${SFT_HF_SNAPSHOT_DIR:-${OUTPUT_ROOT}/qwen3_8b_sft_25k_eval_snapshots}" +export OPD_HF_SNAPSHOT_DIR="${OPD_HF_SNAPSHOT_DIR:-${OUTPUT_ROOT}/qwen3_8b_sft_25k_opd_10k_eval_snapshots}" +export SFT_DETAILS_DIR="${SFT_DETAILS_DIR:-${OUTPUT_ROOT}/sft_25k_details}" +export OPD_ROLLOUT_LOG_DIR="${OPD_ROLLOUT_LOG_DIR:-${OUTPUT_ROOT}/opd_10k_rollout_logs}" +export TEACHER_LOG_DIR="${TEACHER_LOG_DIR:-${OUTPUT_ROOT}/teacher_logs}" +export EVAL_OUTPUT_DIR="${EVAL_OUTPUT_DIR:-${OUTPUT_ROOT}/math500_eval_25k_10k}" +export BASE_EVAL_OUTPUT_DIR="${BASE_EVAL_OUTPUT_DIR:-${OUTPUT_ROOT}/math500_eval_base_25k_10k}" +export SFT_EVAL_OUTPUT_DIR="${SFT_EVAL_OUTPUT_DIR:-${OUTPUT_ROOT}/math500_eval_sft_25k_curve}" +export OPD_EVAL_OUTPUT_DIR="${OPD_EVAL_OUTPUT_DIR:-${OUTPUT_ROOT}/math500_eval_opd_10k_curve}" +export CHECKPOINT_REPORT_DIR="${CHECKPOINT_REPORT_DIR:-${OUTPUT_ROOT}/checkpoint_reports_25k_10k}" +export SLURM_LOG_DIR="${SLURM_LOG_DIR:-${OUTPUT_ROOT}/slurm_logs}" + +export SFT_NUM_EPOCH="${SFT_NUM_EPOCH:-1}" +export SFT_ROLLOUT_BATCH_SIZE="${SFT_ROLLOUT_BATCH_SIZE:-256}" +export SFT_GLOBAL_BATCH_SIZE="${SFT_GLOBAL_BATCH_SIZE:-256}" +export SFT_MAX_TOKENS_PER_GPU="${SFT_MAX_TOKENS_PER_GPU:-16384}" +export SFT_SAVE_INTERVAL="${SFT_SAVE_INTERVAL:-20}" +export SFT_NUM_ROLLOUT="${SFT_NUM_ROLLOUT:-$(((SFT_SIZE + SFT_ROLLOUT_BATCH_SIZE - 1) / SFT_ROLLOUT_BATCH_SIZE))}" +export SFT_FINAL_ROLLOUT_ID="${SFT_FINAL_ROLLOUT_ID:-$((SFT_NUM_ROLLOUT - 1))}" +export SFT_MILESTONE_ROLLOUT_IDS="${SFT_MILESTONE_ROLLOUT_IDS:-19 39 59 79 ${SFT_FINAL_ROLLOUT_ID}}" +export SFT_HF_SNAPSHOT_TEMPLATE="${SFT_HF_SNAPSHOT_TEMPLATE:-${SFT_HF_SNAPSHOT_DIR}/iter_{rollout_id:07d}}" +printf -v _SFT_FINAL_ROLLOUT_TAG "iter_%07d" "${SFT_FINAL_ROLLOUT_ID}" +export SFT_FINAL_HF_DIR="${SFT_FINAL_HF_DIR:-${SFT_HF_SNAPSHOT_DIR}/${_SFT_FINAL_ROLLOUT_TAG}}" +unset _SFT_FINAL_ROLLOUT_TAG + +export OPD_ROLLOUT_BATCH_SIZE="${OPD_ROLLOUT_BATCH_SIZE:-128}" +export OPD_GLOBAL_BATCH_SIZE="${OPD_GLOBAL_BATCH_SIZE:-128}" +export OPD_N_SAMPLES_PER_PROMPT="${OPD_N_SAMPLES_PER_PROMPT:-1}" +export OPD_NUM_ROLLOUT="${OPD_NUM_ROLLOUT:-$(((OPD_SIZE + OPD_ROLLOUT_BATCH_SIZE - 1) / OPD_ROLLOUT_BATCH_SIZE))}" +export OPD_FINAL_ROLLOUT_ID="${OPD_FINAL_ROLLOUT_ID:-$((OPD_NUM_ROLLOUT - 1))}" +export OPD_MILESTONE_ROLLOUT_IDS="${OPD_MILESTONE_ROLLOUT_IDS:-15 31 47 63 ${OPD_FINAL_ROLLOUT_ID}}" +export OPD_MAX_RESPONSE_LEN="${OPD_MAX_RESPONSE_LEN:-16384}" +export OPD_MAX_TOKENS_PER_GPU="${OPD_MAX_TOKENS_PER_GPU:-16384}" +export OPD_SAVE_INTERVAL="${OPD_SAVE_INTERVAL:-16}" +export OPD_HF_SNAPSHOT_TEMPLATE="${OPD_HF_SNAPSHOT_TEMPLATE:-${OPD_HF_SNAPSHOT_DIR}/iter_{rollout_id:07d}}" +printf -v _OPD_FINAL_ROLLOUT_TAG "iter_%07d" "${OPD_FINAL_ROLLOUT_ID}" +export OPD_FINAL_HF_DIR="${OPD_FINAL_HF_DIR:-${OPD_HF_SNAPSHOT_DIR}/${_OPD_FINAL_ROLLOUT_TAG}}" +unset _OPD_FINAL_ROLLOUT_TAG +export OPD_ACTOR_GPUS="${OPD_ACTOR_GPUS:-2}" +export OPD_ROLLOUT_GPUS="${OPD_ROLLOUT_GPUS:-5}" +export OPD_RAY_GPUS="${OPD_RAY_GPUS:-7}" +export OPD_TEACHER_GPU="${OPD_TEACHER_GPU:-7}" +export OPD_TEACHER_PORT="${OPD_TEACHER_PORT:-13141}" +export OPD_TEACHER_MEM_FRACTION="${OPD_TEACHER_MEM_FRACTION:-0.6}" +export CHECKPOINT_PRUNE_INTERVAL_SECONDS="${CHECKPOINT_PRUNE_INTERVAL_SECONDS:-60}" +export ESTIMATED_FULL_OPTIM_CKPT_BYTES="${ESTIMATED_FULL_OPTIM_CKPT_BYTES:-114694164347}" + +export EVAL_MAX_RESPONSE_LEN="${EVAL_MAX_RESPONSE_LEN:-31744}" +export EVAL_ROLLOUT_BATCH_SIZE="${EVAL_ROLLOUT_BATCH_SIZE:-64}" +export EVAL_ROLLOUT_NUM_GPUS="${EVAL_ROLLOUT_NUM_GPUS:-8}" +export EVAL_NUM_REPEATS="${EVAL_NUM_REPEATS:-1}" +export EVAL_SGLANG_SERVER_CONCURRENCY="${EVAL_SGLANG_SERVER_CONCURRENCY:-1}" +export EVAL_TARGETS="${EVAL_TARGETS:-base sft opd}" + +export PYTHONUNBUFFERED=1 +export TOKENIZERS_PARALLELISM="${TOKENIZERS_PARALLELISM:-false}" +export NCCL_DEBUG="${NCCL_DEBUG:-WARN}" diff --git a/examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh b/examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh new file mode 100755 index 0000000000..11029778f1 --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh @@ -0,0 +1,76 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +source "${SCRIPT_DIR}/env.sh" + +cd "${SLIME_REPO_ROOT}" + +echo "Dry check environment" +echo " repo: ${SLIME_REPO_ROOT}" +echo " account/partition/qos: ${ACCOUNT}/${PARTITION}/${QOS}" +echo " scratch: ${SCRATCH_ROOT}" +echo " container (${SLIME_CONTAINER_FORMAT}): ${SLIME_SIF}" + +SHELL_FILES=( + examples/qwen3_8b_opd_tillicum/env.sh + examples/qwen3_8b_opd_tillicum/checkpoint_utils.sh + examples/qwen3_8b_opd_tillicum/container_exec.sh + examples/qwen3_8b_opd_tillicum/00_pull_or_load_container.sh + examples/qwen3_8b_opd_tillicum/01_prepare_env.sh + examples/qwen3_8b_opd_tillicum/run_all_dry_check.sh + examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh + examples/qwen3_8b_opd_tillicum/02_prepare_data_25k_10k.sbatch + examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch + examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch + examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch + examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +) + +PYTHON_FILES=( + examples/qwen3_8b_opd_tillicum/02_prepare_openthoughts3_math_sample.py + examples/qwen3_8b_opd_tillicum/summarize_eval.py +) + +echo "Checking shell syntax" +for file in "${SHELL_FILES[@]}"; do + bash -n "${file}" +done + +echo "Checking Python syntax" +python3 -m py_compile "${PYTHON_FILES[@]}" + +SBATCH_FILES=( + examples/qwen3_8b_opd_tillicum/02_prepare_data_25k_10k.sbatch + examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch + examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch + examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch + examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +) + +echo "Checking Slurm scripts with sbatch --test-only" +for file in "${SBATCH_FILES[@]}"; do + sbatch --test-only -A "${ACCOUNT}" -p "${PARTITION}" --qos "${QOS}" "${file}" +done + +if [[ "${RUN_CONTAINER_CHECKS:-0}" == "1" ]]; then + if [[ -e "${SLIME_SIF}" ]]; then + echo "Checking imports inside container" + "${SCRIPT_DIR}/container_exec.sh" python -c "import slime, sglang, transformers, datasets; print('container imports ok')" + else + echo "RUN_CONTAINER_CHECKS=1 but SLIME_SIF does not exist; skipping import check." + fi +else + echo "Skipping container import check. Set RUN_CONTAINER_CHECKS=1 after the SIF exists." +fi + +echo "git diff --stat" +git diff --stat + +echo "git diff --name-only" +git diff --name-only + +echo "git status --short" +git status --short + +echo "Dry checks completed. No real jobs were submitted." diff --git a/examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh b/examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh new file mode 100755 index 0000000000..d62539ab4d --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/submit_25k_10k_chain.sh @@ -0,0 +1,77 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +source "${SCRIPT_DIR}/env.sh" + +cd "${SLIME_REPO_ROOT}" +mkdir -p "${SLURM_LOG_DIR}" + +SBATCH_COMMON=(-A "${ACCOUNT}" -p "${PARTITION}" --qos "${QOS}") +if [[ -n "${GPU_GRES:-}" ]]; then + SBATCH_COMMON+=(--gres "${GPU_GRES}") +fi + +submit_log="${SLURM_LOG_DIR}/submit_25k_10k_$(date +%Y%m%d_%H%M%S).txt" + +echo "Submitting 25k SFT / 10k OPD chain" +echo "account/partition/qos: ${ACCOUNT}/${PARTITION}/${QOS}" +echo "submit log: ${submit_log}" + +jid_data="$( + sbatch --parsable "${SBATCH_COMMON[@]}" \ + examples/qwen3_8b_opd_tillicum/02_prepare_data_25k_10k.sbatch +)" +jid_convert="$( + sbatch --parsable "${SBATCH_COMMON[@]}" \ + examples/qwen3_8b_opd_tillicum/03_convert_models_if_needed.sbatch +)" +jid_sft="$( + sbatch --parsable "${SBATCH_COMMON[@]}" \ + --dependency=afterok:${jid_data}:${jid_convert} \ + examples/qwen3_8b_opd_tillicum/04_run_sft_100k_8xh200.sbatch +)" +jid_sft_eval="$( + sbatch --parsable "${SBATCH_COMMON[@]}" \ + --dependency=afterok:${jid_sft} \ + --time=05:00:00 \ + --job-name=slime-qwen3-sft-math500 \ + --export=ALL,EVAL_TARGETS=sft,EVAL_OUTPUT_DIR="${SFT_EVAL_OUTPUT_DIR}" \ + examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +)" +jid_opd="$( + sbatch --parsable "${SBATCH_COMMON[@]}" \ + --dependency=afterok:${jid_sft_eval} \ + examples/qwen3_8b_opd_tillicum/05_run_opd_50k_8xh200.sbatch +)" +jid_opd_eval="$( + sbatch --parsable "${SBATCH_COMMON[@]}" \ + --dependency=afterok:${jid_opd} \ + --time=05:00:00 \ + --job-name=slime-qwen3-opd-math500 \ + --export=ALL,EVAL_TARGETS=opd,EVAL_OUTPUT_DIR="${OPD_EVAL_OUTPUT_DIR}" \ + examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +)" +jid_base_eval="$( + sbatch --parsable "${SBATCH_COMMON[@]}" \ + --dependency=afterok:${jid_opd_eval} \ + --time=02:00:00 \ + --job-name=slime-qwen3-base-math500 \ + --export=ALL,EVAL_TARGETS=base,EVAL_OUTPUT_DIR="${BASE_EVAL_OUTPUT_DIR}" \ + examples/qwen3_8b_opd_tillicum/06_eval_math500_greedy_1x.sbatch +)" + +{ + echo "data=${jid_data}" + echo "convert=${jid_convert}" + echo "sft=${jid_sft}" + echo "sft_eval=${jid_sft_eval}" + echo "opd=${jid_opd}" + echo "opd_eval=${jid_opd_eval}" + echo "base_eval=${jid_base_eval}" + echo "SFT_SAVE_DIR=${SFT_SAVE_DIR}" + echo "OPD_SAVE_DIR=${OPD_SAVE_DIR}" + echo "SFT_EVAL_OUTPUT_DIR=${SFT_EVAL_OUTPUT_DIR}" + echo "OPD_EVAL_OUTPUT_DIR=${OPD_EVAL_OUTPUT_DIR}" + echo "CHECKPOINT_REPORT_DIR=${CHECKPOINT_REPORT_DIR}" +} | tee "${submit_log}" diff --git a/examples/qwen3_8b_opd_tillicum/summarize_eval.py b/examples/qwen3_8b_opd_tillicum/summarize_eval.py new file mode 100755 index 0000000000..514b86091b --- /dev/null +++ b/examples/qwen3_8b_opd_tillicum/summarize_eval.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python3 +"""Summarize slime eval debug rollout files.""" + +from __future__ import annotations + +import argparse +import csv +import json +import math +import re +import struct +import zlib +from pathlib import Path +from typing import Any + +import torch + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--stage", default=None) + parser.add_argument("--debug-file", default=None) + parser.add_argument("--out-json", required=True) + parser.add_argument("--max-response-len", type=int, default=None) + parser.add_argument("--aggregate-dir", default=None) + parser.add_argument("--train-samples", type=int, default=None) + return parser.parse_args() + + +FINAL_ANSWER_RE = re.compile(r"(?i)\b(?:final\s+answer|answer)\s*(?:is|:)\s*(?P.+)$") + + +def answer_segment(response: str) -> str: + if "" in response: + return response.rsplit("", 1)[1] + if "###Response" in response: + return response.split("###Response", 1)[1] + return response + + +def clean_final_answer_candidate(candidate: str) -> str: + candidate = candidate.strip() + candidate = re.split(r"<\|endoftext\|>||", candidate, maxsplit=1)[0].strip() + candidate = candidate.split(". ", 1)[0].strip() + candidate = candidate.split("\n", 1)[0].strip() + + for left, right in (("\\(", "\\)"), ("\\[", "\\]"), ("$", "$")): + if candidate.startswith(left) and candidate.endswith(right): + candidate = candidate[len(left) : -len(right)].strip() + + while candidate and candidate[-1] in ".,;:": + candidate = candidate[:-1].strip() + return candidate + + +def extract_prediction(response: str) -> str | None: + from slime.rollout.rm_hub.math_utils import extract_answer + + segment = answer_segment(response) + boxed = extract_answer(segment) + if boxed is not None: + return boxed + + for line in reversed(segment.splitlines()): + match = FINAL_ANSWER_RE.search(line.strip()) + if not match: + continue + candidate = clean_final_answer_candidate(match.group("answer")) + if candidate and len(candidate) <= 120: + return candidate + return None + + +def extract_ground_truth(label: Any) -> str: + from slime.rollout.rm_hub.math_utils import extract_answer + + ground_truth = str(label) + if "\\boxed" in ground_truth: + boxed = extract_answer(ground_truth) + if boxed is not None: + return boxed + return ground_truth + + +def score_sample(sample: dict[str, Any]) -> tuple[float, bool]: + from slime.rollout.rm_hub.math_utils import grade_answer_mathd, grade_answer_sympy + + response = str(sample.get("response", "")) + label = sample.get("label", "") + prediction = extract_prediction(response) + if prediction is None: + return 0.0, False + + ground_truth = extract_ground_truth(label) + is_correct = grade_answer_mathd(prediction, ground_truth) or grade_answer_sympy(prediction, ground_truth) + return float(is_correct), True + + +def summarize_debug_file( + stage: str, + debug_file: Path, + max_response_len: int | None, + train_samples: int | None, +) -> dict[str, Any]: + payload = torch.load(debug_file, map_location="cpu", weights_only=False) + samples = payload.get("samples", []) + if not samples: + raise RuntimeError(f"No samples found in {debug_file}") + + scores_and_parseable = [score_sample(sample) for sample in samples] + rewards = [score for score, _ in scores_and_parseable] + parseable_count = sum(1 for _, parseable in scores_and_parseable if parseable) + response_lengths = [int(sample.get("response_length") or 0) for sample in samples] + statuses = [str(sample.get("status", "")) for sample in samples] + parse_failures = len(samples) - parseable_count + + cap_hits = 0 + for length, status in zip(response_lengths, statuses, strict=True): + if status == "truncated" or (max_response_len is not None and length >= max_response_len): + cap_hits += 1 + + n = len(samples) + accuracy = sum(rewards) / n + accuracy_on_parseable = sum(rewards) / parseable_count if parseable_count else None + return { + "stage": stage, + "train_samples": train_samples, + "debug_file": str(debug_file), + "n": n, + "correct_count": int(sum(rewards)), + "parseable_count": parseable_count, + "accuracy": accuracy, + "accuracy_on_parseable": accuracy_on_parseable, + "mean_accuracy": accuracy, + "std_accuracy": None, + "avg_generated_tokens": sum(response_lengths) / n, + "max_generated_tokens": max(response_lengths), + "parse_failure_rate": parse_failures / n, + "cap_hit_rate": cap_hits / n, + } + + +def aggregate(aggregate_dir: Path) -> dict[str, Any]: + summaries = [] + for path in sorted(aggregate_dir.glob("*/summary.json")): + summaries.append(json.loads(path.read_text(encoding="utf-8"))) + if not summaries: + raise RuntimeError(f"No per-stage summaries found under {aggregate_dir}") + + summaries.sort(key=lambda item: (item.get("train_samples") is None, item.get("train_samples") or 0, item["stage"])) + write_accuracy_curve_csv(aggregate_dir / "accuracy_curve.csv", summaries) + write_accuracy_curve_png(aggregate_dir / "accuracy_curve.png", summaries) + + return { + "num_repeats_per_stage": 1, + "stages": summaries, + "note": "Eval repeat count is 1, so per-stage std_accuracy is N/A.", + } + + +def write_accuracy_curve_csv(path: Path, summaries: list[dict[str, Any]]) -> None: + with path.open("w", encoding="utf-8", newline="") as f: + writer = csv.DictWriter( + f, + fieldnames=[ + "stage", + "train_samples", + "accuracy", + "accuracy_on_parseable", + "avg_generated_tokens", + "parse_failure_rate", + "cap_hit_rate", + ], + ) + writer.writeheader() + for item in summaries: + writer.writerow( + { + "stage": item.get("stage"), + "train_samples": item.get("train_samples"), + "accuracy": item.get("accuracy"), + "accuracy_on_parseable": item.get("accuracy_on_parseable"), + "avg_generated_tokens": item.get("avg_generated_tokens"), + "parse_failure_rate": item.get("parse_failure_rate"), + "cap_hit_rate": item.get("cap_hit_rate"), + } + ) + + +def put_px(image: bytearray, width: int, height: int, x: int, y: int, color: tuple[int, int, int]) -> None: + if 0 <= x < width and 0 <= y < height: + idx = (y * width + x) * 3 + image[idx : idx + 3] = bytes(color) + + +def draw_line( + image: bytearray, + width: int, + height: int, + start: tuple[int, int], + end: tuple[int, int], + color: tuple[int, int, int], +) -> None: + x0, y0 = start + x1, y1 = end + dx = abs(x1 - x0) + sx = 1 if x0 < x1 else -1 + dy = -abs(y1 - y0) + sy = 1 if y0 < y1 else -1 + err = dx + dy + while True: + put_px(image, width, height, x0, y0, color) + if x0 == x1 and y0 == y1: + break + e2 = 2 * err + if e2 >= dy: + err += dy + x0 += sx + if e2 <= dx: + err += dx + y0 += sy + + +def draw_dot( + image: bytearray, + width: int, + height: int, + center: tuple[int, int], + radius: int, + color: tuple[int, int, int], +) -> None: + cx, cy = center + for y in range(cy - radius, cy + radius + 1): + for x in range(cx - radius, cx + radius + 1): + if (x - cx) ** 2 + (y - cy) ** 2 <= radius**2: + put_px(image, width, height, x, y, color) + + +def write_png(path: Path, width: int, height: int, image: bytearray) -> None: + rows = [] + row_bytes = width * 3 + for y in range(height): + rows.append(b"\x00" + bytes(image[y * row_bytes : (y + 1) * row_bytes])) + raw = b"".join(rows) + + def chunk(kind: bytes, payload: bytes) -> bytes: + return ( + struct.pack(">I", len(payload)) + + kind + + payload + + struct.pack(">I", zlib.crc32(kind + payload) & 0xFFFFFFFF) + ) + + png = b"\x89PNG\r\n\x1a\n" + png += chunk("IHDR".encode(), struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)) + png += chunk("IDAT".encode(), zlib.compress(raw, level=9)) + png += chunk("IEND".encode(), b"") + path.write_bytes(png) + + +def write_accuracy_curve_png(path: Path, summaries: list[dict[str, Any]]) -> None: + points = [ + (item.get("train_samples"), float(item.get("accuracy", 0.0))) + for item in summaries + if item.get("train_samples") is not None + ] + if not points: + return + + width, height = 900, 520 + margin_left, margin_right, margin_top, margin_bottom = 70, 35, 35, 70 + plot_w = width - margin_left - margin_right + plot_h = height - margin_top - margin_bottom + image = bytearray([255] * width * height * 3) + + axis = (32, 32, 32) + grid = (225, 225, 225) + line = (34, 102, 190) + dot = (190, 60, 60) + + x_min = min(x for x, _ in points) + x_max = max(x for x, _ in points) + if x_min == x_max: + x_min = 0 + y_min, y_max = 0.0, 1.0 + + def map_x(x: int) -> int: + if x_max == x_min: + return margin_left + plot_w // 2 + return margin_left + round((x - x_min) / (x_max - x_min) * plot_w) + + def map_y(y: float) -> int: + return margin_top + round((y_max - max(y_min, min(y_max, y))) / (y_max - y_min) * plot_h) + + for i in range(6): + y = margin_top + round(i / 5 * plot_h) + draw_line(image, width, height, (margin_left, y), (width - margin_right, y), grid) + draw_line(image, width, height, (margin_left, margin_top), (margin_left, height - margin_bottom), axis) + draw_line( + image, + width, + height, + (margin_left, height - margin_bottom), + (width - margin_right, height - margin_bottom), + axis, + ) + + mapped = [(map_x(x), map_y(y)) for x, y in points] + for a, b in zip(mapped, mapped[1:], strict=False): + draw_line(image, width, height, a, b, line) + for point in mapped: + draw_dot(image, width, height, point, 5, dot) + + write_png(path, width, height, image) + + +def main() -> None: + args = parse_args() + out = Path(args.out_json) + out.parent.mkdir(parents=True, exist_ok=True) + + if args.aggregate_dir: + summary = aggregate(Path(args.aggregate_dir)) + else: + if not args.stage or not args.debug_file: + raise SystemExit("--stage and --debug-file are required unless --aggregate-dir is used") + summary = summarize_debug_file(args.stage, Path(args.debug_file), args.max_response_len, args.train_samples) + + out.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8") + print(json.dumps(summary, indent=2)) + + +if __name__ == "__main__": + main()