-
Notifications
You must be signed in to change notification settings - Fork 0
Add Gemma4 dense and MoE model support #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: codex/empty-colocated-weight-bucket-20260626
Are you sure you want to change the base?
Changes from all commits
199963c
43bd7bc
a137666
29592e9
80928f7
c4c07d6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| # Gemma4 Dense and MoE with GSM8K | ||
|
|
||
| This example is a small model-support validation for the Gemma4 text models. It | ||
| uses GSM8K instead of SWE because the purpose is to verify the Megatron model | ||
| path, SGLang rollout load path, loss masking, backward pass, and live weight | ||
| update without adding sandbox, tool-use, or agent runtime variables. | ||
|
|
||
| Use a downstream SWE recipe only after this validation passes. | ||
|
|
||
| ## What to Run | ||
|
|
||
| Run the dense and MoE variants separately on one 8-GPU node: | ||
|
|
||
| | Model | Script | Megatron topology | SGLang topology | | ||
| | --- | --- | --- | --- | | ||
| | `google/gemma-4-31B-it` | `scripts/run-gemma4-31B-gsm8k.sh` | TP2 PP4 CP1 | TP8 | | ||
| | `google/gemma-4-26B-A4B-it` | `scripts/run-gemma4-26B-A4B-gsm8k.sh` | TP2 PP2 EP2 CP1 | TP8 | | ||
|
|
||
| The scripts default to two rollouts with short responses. They are intended to | ||
| prove that the model can train, not to report a meaningful GSM8K score. A small | ||
| default `--entropy-coef` keeps the optimizer path active even when the tiny | ||
| sample receives zero reward. | ||
|
|
||
| Use a fresh converted checkpoint directory for each model and topology. The | ||
| default paths include TP/PP/EP/CP because Megatron distributed checkpoints are | ||
| sharded by the conversion topology. | ||
|
|
||
| ## Prepare Checkpoints and Data | ||
|
|
||
| ```bash | ||
| cd /root | ||
| git clone https://github.com/THUDM/slime.git | ||
| cd slime | ||
| pip install -e . --no-deps | ||
|
|
||
| hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it | ||
| hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it | ||
| hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k | ||
| ``` | ||
|
|
||
| Convert the dense checkpoint: | ||
|
|
||
| ```bash | ||
| cd /root/slime | ||
| source scripts/models/gemma4-31B.sh | ||
| PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node 8 \ | ||
| tools/convert_hf_to_torch_dist.py \ | ||
| "${MODEL_ARGS[@]}" \ | ||
| --hf-checkpoint /root/gemma-4-31B-it \ | ||
| --tensor-model-parallel-size 2 \ | ||
| --pipeline-model-parallel-size 4 \ | ||
| --context-parallel-size 1 \ | ||
| --save /root/gemma-4-31B-it_tp2_pp4_cp1_torch_dist | ||
| ``` | ||
|
|
||
| Convert the MoE checkpoint: | ||
|
|
||
| ```bash | ||
| cd /root/slime | ||
| source scripts/models/gemma4-26B-A4B.sh | ||
| PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node 8 \ | ||
| tools/convert_hf_to_torch_dist.py \ | ||
| "${MODEL_ARGS[@]}" \ | ||
| --hf-checkpoint /root/gemma-4-26B-A4B-it \ | ||
| --tensor-model-parallel-size 2 \ | ||
| --pipeline-model-parallel-size 2 \ | ||
| --expert-model-parallel-size 2 \ | ||
| --context-parallel-size 1 \ | ||
| --save /root/gemma-4-26B-A4B-it_tp2_pp2_ep2_cp1_torch_dist | ||
| ``` | ||
|
|
||
| ## Run Training | ||
|
|
||
| ```bash | ||
| cd /root/slime | ||
| bash scripts/run-gemma4-31B-gsm8k.sh | ||
| bash scripts/run-gemma4-26B-A4B-gsm8k.sh | ||
| ``` | ||
|
|
||
| To log the validation runs: | ||
|
|
||
| ```bash | ||
| USE_WANDB=1 WANDB_PROJECT=slime-gemma4-gsm8k bash scripts/run-gemma4-31B-gsm8k.sh | ||
| USE_WANDB=1 WANDB_PROJECT=slime-gemma4-gsm8k bash scripts/run-gemma4-26B-A4B-gsm8k.sh | ||
| ``` | ||
|
|
||
| ## Expected Signal | ||
|
|
||
| A successful run should show: | ||
|
|
||
| - SGLang loading `Gemma4ForConditionalGeneration`. | ||
| - At least one completed rollout and train step. | ||
| - `train/loss`, `train/grad_norm`, and entropy metrics in stdout or W&B. | ||
| - Successful raw `update_weights` from Megatron to SGLang. | ||
|
|
||
| For quality training, increase the rollout count, batch sizes, response length, | ||
| and evaluation interval, and set `ENTROPY_COEF=0`. | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,95 @@ | ||||||||||||||||||||
| # Gemma4 Dense 与 MoE 的 GSM8K 示例 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| 这个示例用于验证 Gemma4 text 模型在 slime 中的模型支持。这里使用 | ||||||||||||||||||||
| GSM8K,而不是 SWE,因为目标是验证 Megatron 模型路径、SGLang rollout | ||||||||||||||||||||
| 加载路径、loss mask、反向传播和在线权重更新,不引入 sandbox、工具调用或 | ||||||||||||||||||||
| agent runtime 变量。 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| SWE 类型任务应当在这个验证通过后再接入。 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| ## 运行内容 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| 在单个 8 卡节点上分别运行 dense 和 MoE 版本: | ||||||||||||||||||||
|
|
||||||||||||||||||||
| | 模型 | 脚本 | Megatron 拓扑 | SGLang 拓扑 | | ||||||||||||||||||||
| | --- | --- | --- | --- | | ||||||||||||||||||||
| | `google/gemma-4-31B-it` | `scripts/run-gemma4-31B-gsm8k.sh` | TP2 PP4 CP1 | TP8 | | ||||||||||||||||||||
| | `google/gemma-4-26B-A4B-it` | `scripts/run-gemma4-26B-A4B-gsm8k.sh` | TP2 PP2 EP2 CP1 | TP8 | | ||||||||||||||||||||
|
|
||||||||||||||||||||
| 脚本默认只跑两个 rollout,并使用较短的 response length。它用于证明模型可以 | ||||||||||||||||||||
| 完成训练闭环,不用于报告有意义的 GSM8K 分数。默认的一个很小的 | ||||||||||||||||||||
| `--entropy-coef` 用来确保在小样本全零 reward 时仍然会触发 optimizer 路径。 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| 每种模型和拓扑都应使用新的转换 checkpoint 目录。默认路径包含 TP/PP/EP/CP, | ||||||||||||||||||||
| 因为 Megatron distributed checkpoint 会按转换拓扑切分。 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| ## 准备 Checkpoint 与数据 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| ```bash | ||||||||||||||||||||
| cd /root | ||||||||||||||||||||
| git clone https://github.com/THUDM/slime.git | ||||||||||||||||||||
| cd slime | ||||||||||||||||||||
| pip install -e . --no-deps | ||||||||||||||||||||
|
|
||||||||||||||||||||
| hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it | ||||||||||||||||||||
| hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it | ||||||||||||||||||||
| hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k | ||||||||||||||||||||
|
Comment on lines
+28
to
+36
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win Document the This guide also calls Suggested doc fix cd /root
git clone https://github.com/THUDM/slime.git
cd slime
pip install -e . --no-deps
+pip install -U "huggingface_hub[cli]"
hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it
hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it
hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||
| ``` | ||||||||||||||||||||
|
|
||||||||||||||||||||
| 转换 dense checkpoint: | ||||||||||||||||||||
|
|
||||||||||||||||||||
| ```bash | ||||||||||||||||||||
| cd /root/slime | ||||||||||||||||||||
| source scripts/models/gemma4-31B.sh | ||||||||||||||||||||
| PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node 8 \ | ||||||||||||||||||||
| tools/convert_hf_to_torch_dist.py \ | ||||||||||||||||||||
| "${MODEL_ARGS[@]}" \ | ||||||||||||||||||||
| --hf-checkpoint /root/gemma-4-31B-it \ | ||||||||||||||||||||
| --tensor-model-parallel-size 2 \ | ||||||||||||||||||||
| --pipeline-model-parallel-size 4 \ | ||||||||||||||||||||
| --context-parallel-size 1 \ | ||||||||||||||||||||
| --save /root/gemma-4-31B-it_tp2_pp4_cp1_torch_dist | ||||||||||||||||||||
| ``` | ||||||||||||||||||||
|
|
||||||||||||||||||||
| 转换 MoE checkpoint: | ||||||||||||||||||||
|
|
||||||||||||||||||||
| ```bash | ||||||||||||||||||||
| cd /root/slime | ||||||||||||||||||||
| source scripts/models/gemma4-26B-A4B.sh | ||||||||||||||||||||
| PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node 8 \ | ||||||||||||||||||||
| tools/convert_hf_to_torch_dist.py \ | ||||||||||||||||||||
| "${MODEL_ARGS[@]}" \ | ||||||||||||||||||||
| --hf-checkpoint /root/gemma-4-26B-A4B-it \ | ||||||||||||||||||||
| --tensor-model-parallel-size 2 \ | ||||||||||||||||||||
| --pipeline-model-parallel-size 2 \ | ||||||||||||||||||||
| --expert-model-parallel-size 2 \ | ||||||||||||||||||||
| --context-parallel-size 1 \ | ||||||||||||||||||||
| --save /root/gemma-4-26B-A4B-it_tp2_pp2_ep2_cp1_torch_dist | ||||||||||||||||||||
| ``` | ||||||||||||||||||||
|
|
||||||||||||||||||||
| ## 运行训练 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| ```bash | ||||||||||||||||||||
| cd /root/slime | ||||||||||||||||||||
| bash scripts/run-gemma4-31B-gsm8k.sh | ||||||||||||||||||||
| bash scripts/run-gemma4-26B-A4B-gsm8k.sh | ||||||||||||||||||||
| ``` | ||||||||||||||||||||
|
|
||||||||||||||||||||
| 如果需要记录到 W&B: | ||||||||||||||||||||
|
|
||||||||||||||||||||
| ```bash | ||||||||||||||||||||
| USE_WANDB=1 WANDB_PROJECT=slime-gemma4-gsm8k bash scripts/run-gemma4-31B-gsm8k.sh | ||||||||||||||||||||
| USE_WANDB=1 WANDB_PROJECT=slime-gemma4-gsm8k bash scripts/run-gemma4-26B-A4B-gsm8k.sh | ||||||||||||||||||||
| ``` | ||||||||||||||||||||
|
|
||||||||||||||||||||
| ## 期望信号 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| 成功运行时应当看到: | ||||||||||||||||||||
|
|
||||||||||||||||||||
| - SGLang 加载 `Gemma4ForConditionalGeneration`。 | ||||||||||||||||||||
| - 至少一个 rollout 和 train step 完成。 | ||||||||||||||||||||
| - stdout 或 W&B 中出现 `train/loss`、`train/grad_norm` 和 entropy 指标。 | ||||||||||||||||||||
| - Megatron 到 SGLang 的 raw `update_weights` 成功。 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| 如果要做正式效果训练,应增加 rollout 数量、batch size、response length 和 | ||||||||||||||||||||
| eval interval,并设置 `ENTROPY_COEF=0`。 | ||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| MODEL_ARGS=( | ||
| --spec "slime_plugins.models.gemma4" "get_gemma4_spec" | ||
| --custom-model-provider-path "slime_plugins.models.gemma4_provider.model_provider" | ||
| --num-layers 48 | ||
| --hidden-size 3840 | ||
| --ffn-hidden-size 15360 | ||
| --num-attention-heads 16 | ||
| --group-query-attention | ||
| --num-query-groups 8 | ||
| --kv-channels 256 | ||
| --use-rotary-position-embeddings | ||
| --disable-bias-linear | ||
| --normalization "RMSNorm" | ||
| --norm-epsilon 1e-6 | ||
| --rotary-base 10000 | ||
| --rotary-percent 1.0 | ||
| --vocab-size 262144 | ||
| --qk-layernorm | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| MODEL_ARGS=( | ||
| --spec "slime_plugins.models.gemma4" "get_gemma4_spec" | ||
| --custom-model-provider-path "slime_plugins.models.gemma4_provider.model_provider" | ||
| --num-layers 30 | ||
| --hidden-size 2816 | ||
| --ffn-hidden-size 2112 | ||
| --num-attention-heads 16 | ||
| --group-query-attention | ||
| --num-query-groups 8 | ||
| --kv-channels 256 | ||
| --use-rotary-position-embeddings | ||
| --disable-bias-linear | ||
| --normalization "RMSNorm" | ||
| --norm-epsilon 1e-6 | ||
| --rotary-base 10000 | ||
| --rotary-percent 1.0 | ||
| --vocab-size 262144 | ||
| --qk-layernorm | ||
| --num-experts 128 | ||
| --moe-ffn-hidden-size 704 | ||
| --moe-router-topk 8 | ||
| --moe-router-dtype fp32 | ||
| --moe-router-score-function softmax | ||
| --moe-router-load-balancing-type none | ||
| --moe-aux-loss-coeff 0.0 | ||
| --moe-token-dispatcher-type alltoall | ||
| --moe-grouped-gemm | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| MODEL_ARGS=( | ||
| --spec "slime_plugins.models.gemma4" "get_gemma4_spec" | ||
| --custom-model-provider-path "slime_plugins.models.gemma4_provider.model_provider" | ||
| --num-layers 60 | ||
| --hidden-size 5376 | ||
| --ffn-hidden-size 21504 | ||
| --num-attention-heads 32 | ||
| --group-query-attention | ||
| --num-query-groups 16 | ||
| --kv-channels 256 | ||
| --use-rotary-position-embeddings | ||
| --disable-bias-linear | ||
| --normalization "RMSNorm" | ||
| --norm-epsilon 1e-6 | ||
| --rotary-base 10000 | ||
| --rotary-percent 1.0 | ||
| --vocab-size 262144 | ||
| --qk-layernorm | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win
Document the
hfCLI prerequisite.These steps install slime with
--no-depsand then immediately callhf download, so a clean environment can fail withhf: command not found. Please add the Hugging Face CLI install step before the download commands.Suggested doc fix
cd /root git clone https://github.com/THUDM/slime.git cd slime pip install -e . --no-deps +pip install -U "huggingface_hub[cli]" hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k📝 Committable suggestion
🤖 Prompt for AI Agents