diff --git a/docs/en/examples/gemma4.md b/docs/en/examples/gemma4.md new file mode 100644 index 0000000000..3a35e14389 --- /dev/null +++ b/docs/en/examples/gemma4.md @@ -0,0 +1,97 @@ +# Gemma4 Dense and MoE with GSM8K + +This example is a small model-support validation for the Gemma4 text models. It +uses GSM8K because the purpose is to verify the Megatron model path, SGLang +rollout load path, loss masking, backward pass, and live weight update without +adding task-specific runtime variables. + +Larger task-specific recipes should be layered on after this validation passes. + +## What to Run + +Run the dense and MoE variants separately on one 8-GPU node: + +| Model | Script | Megatron topology | SGLang topology | +| --- | --- | --- | --- | +| `google/gemma-4-31B-it` | `scripts/run-gemma4-31B-gsm8k.sh` | TP2 PP4 CP1 | TP8 | +| `google/gemma-4-26B-A4B-it` | `scripts/run-gemma4-26B-A4B-gsm8k.sh` | TP2 PP2 EP2 CP1 | TP8 | + +The scripts default to two rollouts with short responses. They are intended to +prove that the model can train, not to report a meaningful GSM8K score. A small +default `--entropy-coef` keeps the optimizer path active even when the tiny +sample receives zero reward. + +Use a fresh converted checkpoint directory for each model and topology. The +default paths include TP/PP/EP/CP because Megatron distributed checkpoints are +sharded by the conversion topology. + +## Prepare Checkpoints and Data + +```bash +cd /root +git clone https://github.com/THUDM/slime.git +cd slime +pip install -e . 
--no-deps + +hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it +hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it +hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k +``` + +Convert the dense checkpoint: + +```bash +cd /root/slime +source scripts/models/gemma4-31B.sh +PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node 8 \ + tools/convert_hf_to_torch_dist.py \ + "${MODEL_ARGS[@]}" \ + --hf-checkpoint /root/gemma-4-31B-it \ + --tensor-model-parallel-size 2 \ + --pipeline-model-parallel-size 4 \ + --context-parallel-size 1 \ + --save /root/gemma-4-31B-it_tp2_pp4_cp1_torch_dist +``` + +Convert the MoE checkpoint: + +```bash +cd /root/slime +source scripts/models/gemma4-26B-A4B.sh +PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node 8 \ + tools/convert_hf_to_torch_dist.py \ + "${MODEL_ARGS[@]}" \ + --hf-checkpoint /root/gemma-4-26B-A4B-it \ + --tensor-model-parallel-size 2 \ + --pipeline-model-parallel-size 2 \ + --expert-model-parallel-size 2 \ + --context-parallel-size 1 \ + --save /root/gemma-4-26B-A4B-it_tp2_pp2_ep2_cp1_torch_dist +``` + +## Run Training + +```bash +cd /root/slime +bash scripts/run-gemma4-31B-gsm8k.sh +bash scripts/run-gemma4-26B-A4B-gsm8k.sh +``` + +To log the validation runs: + +```bash +USE_WANDB=1 WANDB_PROJECT=slime-gemma4-gsm8k bash scripts/run-gemma4-31B-gsm8k.sh +USE_WANDB=1 WANDB_PROJECT=slime-gemma4-gsm8k bash scripts/run-gemma4-26B-A4B-gsm8k.sh +``` + +## Expected Signal + +A successful run should show: + +- SGLang loading `Gemma4ForConditionalGeneration`. +- At least one completed rollout and train step. +- `train/loss`, `train/grad_norm`, and entropy metrics in stdout or W&B. +- Successful raw `update_weights` from Megatron to SGLang. + +For quality training, increase the rollout count, batch sizes, response length, +and evaluation interval, and set `ENTROPY_COEF=0`. 
diff --git a/docs/en/index.rst b/docs/en/index.rst index f3401ce3a7..d066a36837 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -59,6 +59,7 @@ Start by Use Case :caption: Dense examples/qwen3-4B.md + examples/gemma4.md examples/glm4-9B.md .. toctree:: diff --git a/docs/zh/examples/gemma4.md b/docs/zh/examples/gemma4.md new file mode 100644 index 0000000000..10f45406de --- /dev/null +++ b/docs/zh/examples/gemma4.md @@ -0,0 +1,94 @@ +# Gemma4 Dense 与 MoE 的 GSM8K 示例 + +这个示例用于验证 Gemma4 text 模型在 slime 中的模型支持。这里使用 +GSM8K,因为目标是验证 Megatron 模型路径、SGLang rollout 加载路径、loss +mask、反向传播和在线权重更新,不引入任务特定的 runtime 变量。 + +更大的任务特定 recipe 应当在这个验证通过后再接入。 + +## 运行内容 + +在单个 8 卡节点上分别运行 dense 和 MoE 版本: + +| 模型 | 脚本 | Megatron 拓扑 | SGLang 拓扑 | +| --- | --- | --- | --- | +| `google/gemma-4-31B-it` | `scripts/run-gemma4-31B-gsm8k.sh` | TP2 PP4 CP1 | TP8 | +| `google/gemma-4-26B-A4B-it` | `scripts/run-gemma4-26B-A4B-gsm8k.sh` | TP2 PP2 EP2 CP1 | TP8 | + +脚本默认只跑两个 rollout,并使用较短的 response length。它用于证明模型可以 +完成训练闭环,不用于报告有意义的 GSM8K 分数。默认的一个很小的 +`--entropy-coef` 用来确保在小样本全零 reward 时仍然会触发 optimizer 路径。 + +每种模型和拓扑都应使用新的转换 checkpoint 目录。默认路径包含 TP/PP/EP/CP, +因为 Megatron distributed checkpoint 会按转换拓扑切分。 + +## 准备 Checkpoint 与数据 + +```bash +cd /root +git clone https://github.com/THUDM/slime.git +cd slime +pip install -e . 
--no-deps + +hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it +hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it +hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k +``` + +转换 dense checkpoint: + +```bash +cd /root/slime +source scripts/models/gemma4-31B.sh +PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node 8 \ + tools/convert_hf_to_torch_dist.py \ + "${MODEL_ARGS[@]}" \ + --hf-checkpoint /root/gemma-4-31B-it \ + --tensor-model-parallel-size 2 \ + --pipeline-model-parallel-size 4 \ + --context-parallel-size 1 \ + --save /root/gemma-4-31B-it_tp2_pp4_cp1_torch_dist +``` + +转换 MoE checkpoint: + +```bash +cd /root/slime +source scripts/models/gemma4-26B-A4B.sh +PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node 8 \ + tools/convert_hf_to_torch_dist.py \ + "${MODEL_ARGS[@]}" \ + --hf-checkpoint /root/gemma-4-26B-A4B-it \ + --tensor-model-parallel-size 2 \ + --pipeline-model-parallel-size 2 \ + --expert-model-parallel-size 2 \ + --context-parallel-size 1 \ + --save /root/gemma-4-26B-A4B-it_tp2_pp2_ep2_cp1_torch_dist +``` + +## 运行训练 + +```bash +cd /root/slime +bash scripts/run-gemma4-31B-gsm8k.sh +bash scripts/run-gemma4-26B-A4B-gsm8k.sh +``` + +如果需要记录到 W&B: + +```bash +USE_WANDB=1 WANDB_PROJECT=slime-gemma4-gsm8k bash scripts/run-gemma4-31B-gsm8k.sh +USE_WANDB=1 WANDB_PROJECT=slime-gemma4-gsm8k bash scripts/run-gemma4-26B-A4B-gsm8k.sh +``` + +## 期望信号 + +成功运行时应当看到: + +- SGLang 加载 `Gemma4ForConditionalGeneration`。 +- 至少一个 rollout 和 train step 完成。 +- stdout 或 W&B 中出现 `train/loss`、`train/grad_norm` 和 entropy 指标。 +- Megatron 到 SGLang 的 raw `update_weights` 成功。 + +如果要做正式效果训练,应增加 rollout 数量、batch size、response length 和 +eval interval,并设置 `ENTROPY_COEF=0`。 diff --git a/docs/zh/index.rst b/docs/zh/index.rst index 7075a28b84..c062cde2a2 100644 --- a/docs/zh/index.rst +++ b/docs/zh/index.rst @@ -59,6 +59,7 @@ slime 的设计目标,是让这两大能力彼此强化,同时避免把系 :caption: Dense examples/qwen3-4B.md + examples/gemma4.md 
examples/glm4-9B.md .. toctree:: diff --git a/scripts/models/gemma4-12B.sh b/scripts/models/gemma4-12B.sh new file mode 100644 index 0000000000..dd2d629db1 --- /dev/null +++ b/scripts/models/gemma4-12B.sh @@ -0,0 +1,19 @@ +MODEL_ARGS=( + --spec "slime_plugins.models.gemma4" "get_gemma4_spec" + --custom-model-provider-path "slime_plugins.models.gemma4_provider.model_provider" + --num-layers 48 + --hidden-size 3840 + --ffn-hidden-size 15360 + --num-attention-heads 16 + --group-query-attention + --num-query-groups 8 + --kv-channels 256 + --use-rotary-position-embeddings + --disable-bias-linear + --normalization "RMSNorm" + --norm-epsilon 1e-6 + --rotary-base 10000 + --rotary-percent 1.0 + --vocab-size 262144 + --qk-layernorm +) diff --git a/scripts/models/gemma4-26B-A4B.sh b/scripts/models/gemma4-26B-A4B.sh new file mode 100644 index 0000000000..c6d4d029a2 --- /dev/null +++ b/scripts/models/gemma4-26B-A4B.sh @@ -0,0 +1,28 @@ +MODEL_ARGS=( + --spec "slime_plugins.models.gemma4" "get_gemma4_spec" + --custom-model-provider-path "slime_plugins.models.gemma4_provider.model_provider" + --num-layers 30 + --hidden-size 2816 + --ffn-hidden-size 2112 + --num-attention-heads 16 + --group-query-attention + --num-query-groups 8 + --kv-channels 256 + --use-rotary-position-embeddings + --disable-bias-linear + --normalization "RMSNorm" + --norm-epsilon 1e-6 + --rotary-base 10000 + --rotary-percent 1.0 + --vocab-size 262144 + --qk-layernorm + --num-experts 128 + --moe-ffn-hidden-size 704 + --moe-router-topk 8 + --moe-router-dtype fp32 + --moe-router-score-function softmax + --moe-router-load-balancing-type none + --moe-aux-loss-coeff 0.0 + --moe-token-dispatcher-type alltoall + --moe-grouped-gemm +) diff --git a/scripts/models/gemma4-31B.sh b/scripts/models/gemma4-31B.sh new file mode 100644 index 0000000000..c9d832159c --- /dev/null +++ b/scripts/models/gemma4-31B.sh @@ -0,0 +1,19 @@ +MODEL_ARGS=( + --spec "slime_plugins.models.gemma4" "get_gemma4_spec" + --custom-model-provider-path 
"slime_plugins.models.gemma4_provider.model_provider" + --num-layers 60 + --hidden-size 5376 + --ffn-hidden-size 21504 + --num-attention-heads 32 + --group-query-attention + --num-query-groups 16 + --kv-channels 256 + --use-rotary-position-embeddings + --disable-bias-linear + --normalization "RMSNorm" + --norm-epsilon 1e-6 + --rotary-base 10000 + --rotary-percent 1.0 + --vocab-size 262144 + --qk-layernorm +) diff --git a/scripts/run-gemma4-26B-A4B-gsm8k.sh b/scripts/run-gemma4-26B-A4B-gsm8k.sh new file mode 100644 index 0000000000..ad41749105 --- /dev/null +++ b/scripts/run-gemma4-26B-A4B-gsm8k.sh @@ -0,0 +1,167 @@ +#!/bin/bash + +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python +pkill -9 redis + +set -ex + +export PYTHONUNBUFFERED=1 +unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY + +BASE_DIR=${BASE_DIR:-/root} +MODEL_NAME=${MODEL_NAME:-gemma-4-26B-A4B-it} +MODEL_DIR=${MODEL_DIR:-${BASE_DIR}/${MODEL_NAME}} +GSM8K_DIR=${GSM8K_DIR:-${BASE_DIR}/datasets/gsm8k} +NUM_GPUS=${NUM_GPUS:-8} +TP_SIZE=${TP_SIZE:-2} +PP_SIZE=${PP_SIZE:-2} +EP_SIZE=${EP_SIZE:-2} +CP_SIZE=${CP_SIZE:-1} +TORCH_DIST_CKPT=${TORCH_DIST_CKPT:-${BASE_DIR}/${MODEL_NAME}_tp${TP_SIZE}_pp${PP_SIZE}_ep${EP_SIZE}_cp${CP_SIZE}_torch_dist} +SLIME_CKPT=${SLIME_CKPT:-${BASE_DIR}/${MODEL_NAME}_tp${TP_SIZE}_pp${PP_SIZE}_ep${EP_SIZE}_cp${CP_SIZE}_slime} + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/models/gemma4-26B-A4B.sh" + +CKPT_ARGS=( + --hf-checkpoint "${MODEL_DIR}" + --ref-load "${TORCH_DIST_CKPT}" + --load "${SLIME_CKPT}" + --save "${SLIME_CKPT}" + --save-interval 20 +) + +ROLLOUT_ARGS=( + --prompt-data "${GSM8K_DIR}/train.parquet" + --input-key 
#!/bin/bash
# Small end-to-end validation run for the Gemma4 26B-A4B (MoE) model on GSM8K.
#
# Every knob below can be overridden from the environment, e.g.
#   NUM_ROLLOUT=8 USE_WANDB=1 bash scripts/run-gemma4-26B-A4B-gsm8k.sh
# MEGATRON_DIR (new) selects the Megatron-LM checkout put on PYTHONPATH for the
# Ray job; its default keeps the previously hard-coded /root/Megatron-LM/.

# Best-effort cleanup of a previous run. These run before `set -e`, so a
# "no process found" failure does not abort the script.
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python
pkill -9 redis

set -ex

export PYTHONUNBUFFERED=1
unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY

BASE_DIR=${BASE_DIR:-/root}
MODEL_NAME=${MODEL_NAME:-gemma-4-26B-A4B-it}
MODEL_DIR=${MODEL_DIR:-${BASE_DIR}/${MODEL_NAME}}
GSM8K_DIR=${GSM8K_DIR:-${BASE_DIR}/datasets/gsm8k}
MEGATRON_DIR=${MEGATRON_DIR:-/root/Megatron-LM/}
NUM_GPUS=${NUM_GPUS:-8}
TP_SIZE=${TP_SIZE:-2}
PP_SIZE=${PP_SIZE:-2}
EP_SIZE=${EP_SIZE:-2}
CP_SIZE=${CP_SIZE:-1}
# The parallel topology is part of the default checkpoint directory names, so
# changing TP/PP/EP/CP selects a different converted checkpoint.
TORCH_DIST_CKPT=${TORCH_DIST_CKPT:-${BASE_DIR}/${MODEL_NAME}_tp${TP_SIZE}_pp${PP_SIZE}_ep${EP_SIZE}_cp${CP_SIZE}_torch_dist}
SLIME_CKPT=${SLIME_CKPT:-${BASE_DIR}/${MODEL_NAME}_tp${TP_SIZE}_pp${PP_SIZE}_ep${EP_SIZE}_cp${CP_SIZE}_slime}

# Count "NV<n>" entries in the GPU topology matrix; any hit is treated as
# NVLink being present and is forwarded to NCCL_NVLS_ENABLE below.
NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
    HAS_NVLINK=1
else
    HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"

# Load MODEL_ARGS from the model definition that sits next to this script.
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
source "${SCRIPT_DIR}/models/gemma4-26B-A4B.sh"

CKPT_ARGS=(
   --hf-checkpoint "${MODEL_DIR}"
   --ref-load "${TORCH_DIST_CKPT}"
   --load "${SLIME_CKPT}"
   --save "${SLIME_CKPT}"
   --save-interval 20
)

ROLLOUT_ARGS=(
   --prompt-data "${GSM8K_DIR}/train.parquet"
   --input-key messages
   --label-key label
   --apply-chat-template
   --rollout-shuffle
   --rm-type math
   --num-rollout "${NUM_ROLLOUT:-2}"
   --rollout-batch-size "${ROLLOUT_BATCH_SIZE:-4}"
   --n-samples-per-prompt "${N_SAMPLES_PER_PROMPT:-4}"
   --rollout-max-response-len "${ROLLOUT_MAX_RESPONSE_LEN:-512}"
   --rollout-temperature "${ROLLOUT_TEMPERATURE:-0.8}"
   --rollout-top-p "${ROLLOUT_TOP_P:-1.0}"
   --global-batch-size "${GLOBAL_BATCH_SIZE:-16}"
   --num-steps-per-rollout 1
   --balance-data
)

# Evaluation is opt-in (ENABLE_EVAL=1); by default no eval flags are passed.
EVAL_ARGS=()
if [ "${ENABLE_EVAL:-0}" = "1" ]; then
   EVAL_ARGS=(
      --eval-interval "${EVAL_INTERVAL:-20}"
      --eval-prompt-data gsm8k "${GSM8K_DIR}/test.parquet"
      --n-samples-per-eval-prompt "${N_SAMPLES_PER_EVAL_PROMPT:-1}"
      --eval-max-response-len "${EVAL_MAX_RESPONSE_LEN:-512}"
      --eval-top-p 1
   )
fi

PERF_ARGS=(
   --tensor-model-parallel-size "${TP_SIZE}"
   --sequence-parallel
   --pipeline-model-parallel-size "${PP_SIZE}"
   --context-parallel-size "${CP_SIZE}"
   --expert-model-parallel-size "${EP_SIZE}"
   --expert-tensor-parallel-size 1
   --recompute-granularity full
   --recompute-method uniform
   --recompute-num-layers 1
   --use-dynamic-batch-size
   --calculate-per-token-loss
   --max-tokens-per-gpu "${MAX_TOKENS_PER_GPU:-2048}"
)

GRPO_ARGS=(
   --advantage-estimator grpo
   --entropy-coef "${ENTROPY_COEF:-0.001}"
   --eps-clip 0.2
   --eps-clip-high 0.28
)

OPTIMIZER_ARGS=(
   --optimizer adam
   --lr "${LR:-1e-6}"
   --lr-decay-style constant
   --weight-decay 0.1
   --adam-beta1 0.9
   --adam-beta2 0.98
   --optimizer-cpu-offload
   --overlap-cpu-optimizer-d2h-h2d
   --use-precision-aware-optimizer
)

# W&B logging is opt-in (USE_WANDB=1).
# NOTE: `set -x` is active, so a WANDB_KEY passed here is echoed in the shell
# trace when the arrays are expanded; avoid sharing such logs.
WANDB_ARGS=()
if [ "${USE_WANDB:-0}" = "1" ]; then
   WANDB_ARGS=(
      --use-wandb
      --wandb-project "${WANDB_PROJECT:-slime-gemma4-gsm8k}"
      --wandb-group "${WANDB_GROUP:-gemma4-26B-A4B-gsm8k}"
   )
   if [ -n "${WANDB_KEY:-}" ]; then
      WANDB_ARGS+=(--wandb-key "${WANDB_KEY}")
   fi
fi

SGLANG_ARGS=(
   --rollout-num-gpus-per-engine "${ROLLOUT_TP_SIZE:-8}"
   --sglang-mem-fraction-static "${SGLANG_MEM_FRACTION_STATIC:-0.20}"
   --sglang-cuda-graph-max-bs "${SGLANG_CUDA_GRAPH_MAX_BS:-1}"
   --sglang-max-running-requests "${SGLANG_MAX_RUNNING_REQUESTS:-4}"
)

MISC_ARGS=(
   --attention-dropout 0.0
   --hidden-dropout 0.0
   --accumulate-allreduce-grads-in-fp32
   --attention-softmax-in-fp32
   --attention-backend flash
   --loss-mask-type gemma4
   --megatron-to-hf-mode raw
)

export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265

# Environment handed to the Ray job; PYTHONPATH points at the Megatron-LM checkout.
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"${MEGATRON_DIR}\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
    \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
  }
}"

ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json="${RUNTIME_ENV_JSON}" \
   -- python3 train.py \
   --actor-num-nodes 1 \
   --actor-num-gpus-per-node "${NUM_GPUS}" \
   --colocate \
   "${MODEL_ARGS[@]}" \
   "${CKPT_ARGS[@]}" \
   "${ROLLOUT_ARGS[@]}" \
   "${OPTIMIZER_ARGS[@]}" \
   "${GRPO_ARGS[@]}" \
   "${WANDB_ARGS[@]}" \
   "${PERF_ARGS[@]}" \
   "${EVAL_ARGS[@]}" \
   "${SGLANG_ARGS[@]}" \
   "${MISC_ARGS[@]}"
#!/bin/bash
# Small end-to-end validation run for the dense Gemma4 31B model on GSM8K.
#
# Every knob below can be overridden from the environment, e.g.
#   NUM_ROLLOUT=8 USE_WANDB=1 bash scripts/run-gemma4-31B-gsm8k.sh
# MEGATRON_DIR (new) selects the Megatron-LM checkout put on PYTHONPATH for the
# Ray job; its default keeps the previously hard-coded /root/Megatron-LM/.

# Best-effort cleanup of a previous run. These run before `set -e`, so a
# "no process found" failure does not abort the script.
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python
pkill -9 redis

set -ex

export PYTHONUNBUFFERED=1
unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY

BASE_DIR=${BASE_DIR:-/root}
MODEL_NAME=${MODEL_NAME:-gemma-4-31B-it}
MODEL_DIR=${MODEL_DIR:-${BASE_DIR}/${MODEL_NAME}}
GSM8K_DIR=${GSM8K_DIR:-${BASE_DIR}/datasets/gsm8k}
MEGATRON_DIR=${MEGATRON_DIR:-/root/Megatron-LM/}
NUM_GPUS=${NUM_GPUS:-8}
TP_SIZE=${TP_SIZE:-2}
PP_SIZE=${PP_SIZE:-4}
CP_SIZE=${CP_SIZE:-1}
# The parallel topology is part of the default checkpoint directory names, so
# changing TP/PP/CP selects a different converted checkpoint.
TORCH_DIST_CKPT=${TORCH_DIST_CKPT:-${BASE_DIR}/${MODEL_NAME}_tp${TP_SIZE}_pp${PP_SIZE}_cp${CP_SIZE}_torch_dist}
SLIME_CKPT=${SLIME_CKPT:-${BASE_DIR}/${MODEL_NAME}_tp${TP_SIZE}_pp${PP_SIZE}_cp${CP_SIZE}_slime}

# Count "NV<n>" entries in the GPU topology matrix; any hit is treated as
# NVLink being present and is forwarded to NCCL_NVLS_ENABLE below.
NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
    HAS_NVLINK=1
else
    HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"

# Load MODEL_ARGS from the model definition that sits next to this script.
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
source "${SCRIPT_DIR}/models/gemma4-31B.sh"

CKPT_ARGS=(
   --hf-checkpoint "${MODEL_DIR}"
   --ref-load "${TORCH_DIST_CKPT}"
   --load "${SLIME_CKPT}"
   --save "${SLIME_CKPT}"
   --save-interval 20
)

ROLLOUT_ARGS=(
   --prompt-data "${GSM8K_DIR}/train.parquet"
   --input-key messages
   --label-key label
   --apply-chat-template
   --rollout-shuffle
   --rm-type math
   --num-rollout "${NUM_ROLLOUT:-2}"
   --rollout-batch-size "${ROLLOUT_BATCH_SIZE:-4}"
   --n-samples-per-prompt "${N_SAMPLES_PER_PROMPT:-4}"
   --rollout-max-response-len "${ROLLOUT_MAX_RESPONSE_LEN:-512}"
   --rollout-temperature "${ROLLOUT_TEMPERATURE:-0.8}"
   --rollout-top-p "${ROLLOUT_TOP_P:-1.0}"
   --global-batch-size "${GLOBAL_BATCH_SIZE:-16}"
   --num-steps-per-rollout 1
   --balance-data
)

# Evaluation is opt-in (ENABLE_EVAL=1); by default no eval flags are passed.
EVAL_ARGS=()
if [ "${ENABLE_EVAL:-0}" = "1" ]; then
   EVAL_ARGS=(
      --eval-interval "${EVAL_INTERVAL:-20}"
      --eval-prompt-data gsm8k "${GSM8K_DIR}/test.parquet"
      --n-samples-per-eval-prompt "${N_SAMPLES_PER_EVAL_PROMPT:-1}"
      --eval-max-response-len "${EVAL_MAX_RESPONSE_LEN:-512}"
      --eval-top-p 1
   )
fi

PERF_ARGS=(
   --tensor-model-parallel-size "${TP_SIZE}"
   --sequence-parallel
   --pipeline-model-parallel-size "${PP_SIZE}"
   --context-parallel-size "${CP_SIZE}"
   --expert-model-parallel-size 1
   --expert-tensor-parallel-size 1
   --recompute-granularity full
   --recompute-method uniform
   --recompute-num-layers 1
   --use-dynamic-batch-size
   --calculate-per-token-loss
   --max-tokens-per-gpu "${MAX_TOKENS_PER_GPU:-2048}"
)

GRPO_ARGS=(
   --advantage-estimator grpo
   --entropy-coef "${ENTROPY_COEF:-0.001}"
   --eps-clip 0.2
   --eps-clip-high 0.28
)

OPTIMIZER_ARGS=(
   --optimizer adam
   --lr "${LR:-1e-6}"
   --lr-decay-style constant
   --weight-decay 0.1
   --adam-beta1 0.9
   --adam-beta2 0.98
   --optimizer-cpu-offload
   --overlap-cpu-optimizer-d2h-h2d
   --use-precision-aware-optimizer
)

# W&B logging is opt-in (USE_WANDB=1).
# NOTE: `set -x` is active, so a WANDB_KEY passed here is echoed in the shell
# trace when the arrays are expanded; avoid sharing such logs.
WANDB_ARGS=()
if [ "${USE_WANDB:-0}" = "1" ]; then
   WANDB_ARGS=(
      --use-wandb
      --wandb-project "${WANDB_PROJECT:-slime-gemma4-gsm8k}"
      --wandb-group "${WANDB_GROUP:-gemma4-31B-gsm8k}"
   )
   if [ -n "${WANDB_KEY:-}" ]; then
      WANDB_ARGS+=(--wandb-key "${WANDB_KEY}")
   fi
fi

SGLANG_ARGS=(
   --rollout-num-gpus-per-engine "${ROLLOUT_TP_SIZE:-8}"
   --sglang-mem-fraction-static "${SGLANG_MEM_FRACTION_STATIC:-0.20}"
   --sglang-cuda-graph-max-bs "${SGLANG_CUDA_GRAPH_MAX_BS:-1}"
   --sglang-max-running-requests "${SGLANG_MAX_RUNNING_REQUESTS:-4}"
)

MISC_ARGS=(
   --attention-dropout 0.0
   --hidden-dropout 0.0
   --accumulate-allreduce-grads-in-fp32
   --attention-softmax-in-fp32
   --attention-backend flash
   --loss-mask-type gemma4
   --megatron-to-hf-mode raw
)

export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265

# Environment handed to the Ray job; PYTHONPATH points at the Megatron-LM checkout.
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"${MEGATRON_DIR}\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
    \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
  }
}"

ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json="${RUNTIME_ENV_JSON}" \
   -- python3 train.py \
   --actor-num-nodes 1 \
   --actor-num-gpus-per-node "${NUM_GPUS}" \
   --colocate \
   "${MODEL_ARGS[@]}" \
   "${CKPT_ARGS[@]}" \
   "${ROLLOUT_ARGS[@]}" \
   "${OPTIMIZER_ARGS[@]}" \
   "${GRPO_ARGS[@]}" \
   "${WANDB_ARGS[@]}" \
   "${PERF_ARGS[@]}" \
   "${EVAL_ARGS[@]}" \
   "${SGLANG_ARGS[@]}" \
   "${MISC_ARGS[@]}"
def reset_expert_buffers() -> None:
    """Drop any partially filled expert buckets.

    Callers that drive the converter from a long-lived process (tests, repeated
    conversions) should invoke this between runs so an interrupted prior
    conversion doesn't leak its partial state.
    """
    _expert_buffers.clear()


# All emitted HF names live under the text tower of Gemma4ForConditionalGeneration.
_HF_PREFIX = "model.language_model."

# Compiled once: this converter is called for every parameter of the model.
_LAYER_PARAM_RE = re.compile(r"module\.module\.decoder\.layers\.(\d+)\.(.+)")
_EXPERT_PARAM_RE = re.compile(r"mlp\.experts\.linear_fc([12])\.weight(\d+)")

# Megatron name -> HF suffix for parameters outside the decoder layers.
_TOP_LEVEL_RENAMES = {
    "module.module.embedding.word_embeddings.weight": "embed_tokens.weight",
    "module.module.output_layer.weight": "embed_tokens.weight",  # tied embeddings
    "module.module.decoder.final_layernorm.weight": "norm.weight",
}

# Per-layer Megatron suffix -> HF suffix for tensors that are passed through unchanged.
_LAYER_RENAMES = {
    "self_attention.linear_proj.weight": "self_attn.o_proj.weight",
    "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight",
    "self_attention.q_layernorm.weight": "self_attn.q_norm.weight",
    "self_attention.k_layernorm.weight": "self_attn.k_norm.weight",
    "mlp.linear_fc2.weight": "mlp.down_proj.weight",
    "dense_mlp.linear_fc2.weight": "mlp.down_proj.weight",
    "mlp.linear_fc1.layer_norm_weight": "pre_feedforward_layernorm.weight",
    "dense_mlp.linear_fc1.layer_norm_weight": "pre_feedforward_layernorm.weight",
    "pre_mlp_layernorm.weight": "pre_feedforward_layernorm.weight",
    "post_attention_layernorm.weight": "post_attention_layernorm.weight",
    "post_feedforward_layernorm.weight": "post_feedforward_layernorm.weight",
    "layer_scalar": "layer_scalar",
    "mlp.router.proj.weight": "router.proj.weight",
    "mlp.router.scale": "router.scale",
    "mlp.router.per_expert_scale": "router.per_expert_scale",
    "pre_feedforward_layernorm_2.weight": "pre_feedforward_layernorm_2.weight",
    "mlp.pre_feedforward_layernorm_2.weight": "pre_feedforward_layernorm_2.weight",
    "post_feedforward_layernorm_2.weight": "post_feedforward_layernorm_2.weight",
    "post_feedforward_layernorm_1.weight": "post_feedforward_layernorm_1.weight",
}

_FUSED_FC1_NAMES = ("mlp.linear_fc1.weight", "dense_mlp.linear_fc1.weight")


def _text_settings_from_hf_config(hf_config) -> dict:
    """Extract the handful of text-model settings the converter needs.

    The global-attention head settings fall back to the local ones when the
    config does not define them (attribute missing or ``None``), mirroring the
    fallback used by ``Gemma4Bridge``. ``num_experts`` falls back to ``0``.
    """
    hf_text = hf_config.text_config if hasattr(hf_config, "text_config") else hf_config
    local_head_dim = hf_text.head_dim
    local_num_kv_heads = hf_text.num_key_value_heads
    return {
        "global_attn_layers": {i for i, t in enumerate(hf_text.layer_types) if t == "full_attention"},
        "local_head_dim": local_head_dim,
        "global_head_dim": getattr(hf_text, "global_head_dim", None) or local_head_dim,
        "num_attention_heads": hf_text.num_attention_heads,
        "local_num_kv_heads": local_num_kv_heads,
        "global_num_kv_heads": getattr(hf_text, "num_global_key_value_heads", None) or local_num_kv_heads,
        "hidden_size": hf_text.hidden_size,
        "num_experts": getattr(hf_text, "num_experts", None) or 0,
    }


def _get_config(args):
    """Return the cached converter settings for ``args.hf_checkpoint``.

    The HF config is loaded at most once per checkpoint path; later calls are
    served from ``_config_cache``.
    """
    checkpoint = args.hf_checkpoint
    if checkpoint not in _config_cache:
        # Imported lazily so importing this module does not require transformers.
        from transformers import AutoConfig

        hf_config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
        _config_cache[checkpoint] = _text_settings_from_hf_config(hf_config)
    return _config_cache[checkpoint]


def _split_fused_qkv(param, cfg, layer_prefix, is_global):
    """Split a fused ``linear_qkv`` weight into per-projection HF tensors.

    The fused weight is viewed as ``[num_kv_heads, (q_per_kv + 2) * head_dim,
    hidden]``: per KV group, the query rows come first, then one key head, then
    one value head. Global-attention layers emit only ``q_proj`` and ``k_proj``;
    the value slice is not exported for them.
    """
    if is_global:
        head_dim = cfg["global_head_dim"]
        num_kv_heads = cfg["global_num_kv_heads"]
    else:
        head_dim = cfg["local_head_dim"]
        num_kv_heads = cfg["local_num_kv_heads"]

    q_heads_per_kv = cfg["num_attention_heads"] // num_kv_heads
    hidden_size = cfg["hidden_size"]
    grouped = param.view(num_kv_heads, (q_heads_per_kv + 2) * head_dim, hidden_size)
    q_dim = q_heads_per_kv * head_dim

    converted = [
        (f"{layer_prefix}.self_attn.q_proj.weight", grouped[:, :q_dim, :].reshape(-1, hidden_size)),
        (f"{layer_prefix}.self_attn.k_proj.weight", grouped[:, q_dim : q_dim + head_dim, :].reshape(-1, hidden_size)),
    ]
    if not is_global:
        converted.append(
            (f"{layer_prefix}.self_attn.v_proj.weight", grouped[:, q_dim + head_dim :, :].reshape(-1, hidden_size))
        )
    return converted


def convert_gemma4_to_hf(args, name, param):
    """Convert one Megatron Gemma4 parameter into HF-named tensors.

    Args:
        args: Namespace providing ``hf_checkpoint`` (used to look up the HF config).
        name: Megatron parameter name, starting with ``module.module.``.
        param: The parameter tensor.

    Returns:
        A list of ``(hf_name, tensor)`` pairs. The list is empty while per-expert
        MoE weights are still being buffered for their layer.

    Raises:
        ValueError: If ``name`` is not a known Gemma4 parameter, or if an
            expert weight cannot be buffered (see
            ``_buffer_expert_and_maybe_flush``).
    """
    cfg = _get_config(args)

    top_level = _TOP_LEVEL_RENAMES.get(name)
    if top_level is not None:
        return [(f"{_HF_PREFIX}{top_level}", param)]

    layer_match = _LAYER_PARAM_RE.match(name)
    if layer_match is None:
        raise ValueError(f"Unknown Gemma4 parameter name: {name}")

    layer_idx = int(layer_match.group(1))
    rest = layer_match.group(2)
    layer_prefix = f"{_HF_PREFIX}layers.{layer_idx}"

    renamed = _LAYER_RENAMES.get(rest)
    if renamed is not None:
        return [(f"{layer_prefix}.{renamed}", param)]

    if rest == "self_attention.linear_qkv.weight":
        return _split_fused_qkv(param, cfg, layer_prefix, layer_idx in cfg["global_attn_layers"])

    if rest in _FUSED_FC1_NAMES:
        # linear_fc1 stores gate and up stacked along dim 0, gate first.
        gate_weight, up_weight = param.chunk(2, dim=0)
        return [
            (f"{layer_prefix}.mlp.gate_proj.weight", gate_weight),
            (f"{layer_prefix}.mlp.up_proj.weight", up_weight),
        ]

    expert_match = _EXPERT_PARAM_RE.match(rest)
    if expert_match:
        return _buffer_expert_and_maybe_flush(
            layer_idx,
            expert_match.group(1),
            int(expert_match.group(2)),
            param,
            layer_prefix,
            num_experts=cfg["num_experts"],
        )

    raise ValueError(f"Unknown Gemma4 parameter name: {name}")


def _buffer_expert_and_maybe_flush(layer_idx, fc, expert_idx, param, L_prefix, num_experts):
    """Buffer one per-expert tensor and emit the stacked tensor once complete.

    Tensors are collected per ``(layer_idx, fc)``. When all ``num_experts``
    experts have arrived they are stacked along a new leading dimension in
    expert order, the bucket is released, and the result is returned as
    ``experts.gate_up_proj`` (``fc == "1"``) or ``experts.down_proj``.

    Returns:
        ``[]`` while the bucket is incomplete, otherwise a one-element list of
        ``(hf_name, stacked_tensor)``.

    Raises:
        ValueError: If ``num_experts`` is not a positive number or
            ``expert_idx`` is outside ``[0, num_experts)``.
    """
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if not num_experts or num_experts <= 0:
        raise ValueError(f"num_experts must be known for MoE layer expert conversion, got {num_experts}")
    if not 0 <= expert_idx < num_experts:
        # Rejecting it here avoids a confusing KeyError when the bucket is flushed.
        raise ValueError(
            f"Expert index {expert_idx} out of range for num_experts={num_experts} "
            f"(layer {layer_idx}, linear_fc{fc})"
        )

    key = (layer_idx, fc)
    bucket = _expert_buffers.setdefault(key, {})
    bucket[expert_idx] = param

    if len(bucket) < num_experts:
        return []

    stacked = torch.stack([bucket[i] for i in range(num_experts)], dim=0).contiguous()
    del _expert_buffers[key]

    target = "gate_up_proj" if fc == "1" else "down_proj"
    return [(f"{L_prefix}.experts.{target}", stacked)]
def gen_multi_turn_loss_mask_gemma4(
    self, messages: list[dict], tools: list[dict] = None
) -> tuple[list[int], list[int]]:
    """Build a token-level loss mask for a Gemma4 chat conversation.

    For every assistant message, the mask covers the assistant content plus the
    closing ``<turn|>`` marker (and the newline right after it, if present).
    A leading ``<|channel>thought ... <channel|>`` block is excluded from the
    mask. Assistant messages whose ``step_loss_mask`` is not ``1`` stay unmasked.

    Args:
        messages: Chat messages; each has a ``role`` and may carry
            ``step_loss_mask``.
        tools: Optional tool definitions forwarded to the chat template.

    Returns:
        ``(token_ids, loss_mask)`` of equal length; ``loss_mask[i]`` is ``1``
        when token ``i`` overlaps a masked character span.

    Raises:
        ValueError: If the tokenizer returns no offset mapping, if tokenizing
            the rendered text disagrees with ``apply_chat_template(...,
            tokenize=True)``, or if the expected turn/channel markers cannot be
            located in the rendered text.
    """
    from itertools import accumulate

    rendered_text = self.tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, return_dict=False)
    tokenized = self.tokenizer(rendered_text, add_special_tokens=False, return_offsets_mapping=True)
    token_ids = tokenized["input_ids"]
    offset_mapping = tokenized.get("offset_mapping")

    if offset_mapping is None:
        raise ValueError(
            "Gemma4 loss mask generation requires a fast tokenizer with `return_offsets_mapping` support."
        )

    # The character offsets are only meaningful if re-tokenizing the rendered
    # text yields exactly the ids the chat template would produce.
    expected_token_ids = self.tokenizer.apply_chat_template(
        messages, tokenize=True, tools=tools, return_dict=False
    )
    if token_ids != expected_token_ids:
        raise ValueError(
            "Gemma4 rendered text tokenization does not match " "`apply_chat_template(..., tokenize=True)` output."
        )

    assistant_header = "<|turn>model\n"
    think_open = "<|channel>thought\n"
    # The closing markers must be non-empty: `str.find("")` matches immediately
    # at the search start, which would collapse every span to zero length.
    think_close = "<channel|>"
    end_marker = "<turn|>"

    char_mask = [0] * len(rendered_text)
    cursor = 0

    for message in messages:
        if message["role"] != "assistant":
            continue

        header_pos = rendered_text.find(assistant_header, cursor)
        if header_pos < 0:
            raise ValueError("Failed to locate assistant (model) turn in rendered Gemma4 chat template output.")

        content_start = header_pos + len(assistant_header)
        end_pos = rendered_text.find(end_marker, content_start)
        if end_pos < 0:
            raise ValueError("Failed to locate <turn|> for assistant message in rendered Gemma4 text.")

        span_end = end_pos + len(end_marker)
        if span_end < len(rendered_text) and rendered_text[span_end] == "\n":
            span_end += 1
        # Advance past this turn even when it is skipped below, so the next
        # assistant message is matched against the next rendered turn.
        cursor = span_end

        if message.get("step_loss_mask", 1) != 1:
            continue

        mask_start = content_start
        if rendered_text.startswith(think_open, content_start):
            # Search only inside this turn so a close marker from a later turn
            # can never be picked up by mistake.
            close_pos = rendered_text.find(think_close, content_start, end_pos)
            if close_pos < 0:
                raise ValueError("Found <|channel>thought open without matching <channel|> close.")
            mask_start = close_pos + len(think_close)

        char_mask[mask_start:span_end] = [1] * (span_end - mask_start)

    # prefix[i] = number of masked characters in rendered_text[:i], so any
    # token's overlap with the mask is a single subtraction.
    prefix = list(accumulate(char_mask, initial=0))

    loss_mask = [
        1 if end > start and prefix[end] - prefix[start] > 0 else 0
        for start, end in offset_mapping
    ]

    return token_ids, loss_mask
@register_model(["gemma4", "gemma4_text", "gemma4_unified_text"])
class Gemma4Bridge(Gemma3Bridge):
    """
    Bridge for Gemma4 text dense and MoE variants.

    Megatron-side keys have NO language_model. prefix (text-only model).
    HF-side values have model.language_model. prefix (Gemma4ForConditionalGeneration).

    The ``_*_MAPPING`` tables map a Megatron parameter name (with the layer
    index replaced by ``{layer_number}``) to the list of HF tensors that are
    merged into it. Routed-expert weights are matched by ``_RE_MOE_EXPERT``
    instead of a table because the expert index is part of the Megatron name.
    """

    _ATTENTION_MAPPING = {
        "decoder.layers.{layer_number}.self_attention.linear_qkv.weight": [
            "model.language_model.layers.{layer_number}.self_attn.q_proj.weight",
            "model.language_model.layers.{layer_number}.self_attn.k_proj.weight",
            "model.language_model.layers.{layer_number}.self_attn.v_proj.weight",
        ],
        "decoder.layers.{layer_number}.self_attention.linear_proj.weight": [
            "model.language_model.layers.{layer_number}.self_attn.o_proj.weight",
        ],
        "decoder.layers.{layer_number}.self_attention.linear_qkv.layer_norm_weight": [
            "model.language_model.layers.{layer_number}.input_layernorm.weight",
        ],
        "decoder.layers.{layer_number}.self_attention.q_layernorm.weight": [
            "model.language_model.layers.{layer_number}.self_attn.q_norm.weight",
        ],
        "decoder.layers.{layer_number}.self_attention.k_layernorm.weight": [
            "model.language_model.layers.{layer_number}.self_attn.k_norm.weight",
        ],
    }

    _MLP_MAPPING = {
        "decoder.layers.{layer_number}.mlp.linear_fc1.weight": [
            "model.language_model.layers.{layer_number}.mlp.gate_proj.weight",
            "model.language_model.layers.{layer_number}.mlp.up_proj.weight",
        ],
        "decoder.layers.{layer_number}.mlp.linear_fc2.weight": [
            "model.language_model.layers.{layer_number}.mlp.down_proj.weight",
        ],
        "decoder.layers.{layer_number}.mlp.linear_fc1.layer_norm_weight": [
            "model.language_model.layers.{layer_number}.pre_feedforward_layernorm.weight",
        ],
        "decoder.layers.{layer_number}.pre_mlp_layernorm.weight": [
            "model.language_model.layers.{layer_number}.pre_feedforward_layernorm.weight",
        ],
        "decoder.layers.{layer_number}.dense_mlp.linear_fc1.weight": [
            "model.language_model.layers.{layer_number}.mlp.gate_proj.weight",
            "model.language_model.layers.{layer_number}.mlp.up_proj.weight",
        ],
        "decoder.layers.{layer_number}.dense_mlp.linear_fc2.weight": [
            "model.language_model.layers.{layer_number}.mlp.down_proj.weight",
        ],
        "decoder.layers.{layer_number}.dense_mlp.linear_fc1.layer_norm_weight": [
            "model.language_model.layers.{layer_number}.pre_feedforward_layernorm.weight",
        ],
        "decoder.layers.{layer_number}.mlp.router.proj.weight": [
            "model.language_model.layers.{layer_number}.router.proj.weight",
        ],
        "decoder.layers.{layer_number}.mlp.router.scale": [
            "model.language_model.layers.{layer_number}.router.scale",
        ],
        "decoder.layers.{layer_number}.mlp.router.per_expert_scale": [
            "model.language_model.layers.{layer_number}.router.per_expert_scale",
        ],
        "decoder.layers.{layer_number}.mlp.pre_feedforward_layernorm_2.weight": [
            "model.language_model.layers.{layer_number}.pre_feedforward_layernorm_2.weight",
        ],
    }

    _OTHER_MAPPING = {
        "decoder.layers.{layer_number}.post_attention_layernorm.weight": [
            "model.language_model.layers.{layer_number}.post_attention_layernorm.weight",
        ],
        "decoder.layers.{layer_number}.post_feedforward_layernorm.weight": [
            "model.language_model.layers.{layer_number}.post_feedforward_layernorm.weight",
        ],
        "decoder.layers.{layer_number}.layer_scalar": [
            "model.language_model.layers.{layer_number}.layer_scalar",
        ],
        "decoder.layers.{layer_number}.post_feedforward_layernorm_2.weight": [
            "model.language_model.layers.{layer_number}.post_feedforward_layernorm_2.weight",
        ],
        "decoder.layers.{layer_number}.post_feedforward_layernorm_1.weight": [
            "model.language_model.layers.{layer_number}.post_feedforward_layernorm_1.weight",
        ],
    }

    # Groups: (1) layer index, (2) "1" for linear_fc1 / "2" for linear_fc2,
    # (3) local expert index appended to ``weight``.
    _RE_MOE_EXPERT = re.compile(r"^decoder\.layers\.(\d+)\.mlp\.experts\.linear_fc([12])\.weight(\d+)$")

    _DIRECT_MAPPING = {
        "embedding.word_embeddings.weight": "model.language_model.embed_tokens.weight",
        "decoder.final_layernorm.weight": "model.language_model.norm.weight",
        "output_layer.weight": "model.language_model.embed_tokens.weight",
    }

    _BUFFER_NAMES = [
        "model.language_model.layers.{layer_number}.layer_scalar",
    ]

    # Populated per instance in __init__ with the 0-indexed "full_attention" layers.
    _GLOBAL_ATTN_LAYERS = None

    @property
    def _gemma4_text_config(self):
        """Return the HF text sub-config, or ``hf_config`` itself when it is already text-only."""
        hf_config = self.hf_config
        return hf_config.text_config if hasattr(hf_config, "text_config") else hf_config

    def __init__(self, *args, **kwargs):
        """Initialise the base bridge, then record which layers use full (global) attention."""
        super().__init__(*args, **kwargs)
        # ``or []`` also covers configs that carry ``layer_types = None``.
        layer_types = getattr(self._gemma4_text_config, "layer_types", None) or []
        self._GLOBAL_ATTN_LAYERS = {i for i, t in enumerate(layer_types) if t == "full_attention"}

    def _attention_shape_for_hf_weights(self, hf_weights: list[torch.Tensor]) -> tuple[int, int]:
        """Return ``(num_kv_heads, head_dim)`` for the attention layer the HF tensors belong to.

        Two tensors (q, k) mean a global layer without ``v_proj`` and select
        the ``global_*`` config values; three tensors (q, k, v) select the
        regular values.

        Raises:
            ValueError: if ``hf_weights`` has neither 2 nor 3 entries.
        """
        hf_text = self._gemma4_text_config
        if len(hf_weights) == 2:
            return (
                int(getattr(hf_text, "num_global_key_value_heads", hf_text.num_key_value_heads)),
                int(getattr(hf_text, "global_head_dim", hf_text.head_dim)),
            )
        if len(hf_weights) == 3:
            return (
                int(hf_text.num_key_value_heads),
                int(getattr(hf_text, "head_dim", hf_text.hidden_size // hf_text.num_attention_heads)),
            )
        raise ValueError(f"Gemma4 linear_qkv expects 2 or 3 HF tensors, got {len(hf_weights)}.")

    def _weight_name_mapping_attention(self, name: str) -> list[str]:
        """Map a Megatron attention parameter name to its HF tensor names.

        For ``linear_qkv.weight`` on a global-attention layer only q_proj and
        k_proj are returned (no v_proj); every other key goes through
        ``_ATTENTION_MAPPING``.
        """
        split_name = name.split(".")
        layer_number = int(split_name[2])
        split_name[2] = "{layer_number}"
        key = ".".join(split_name)

        if key == "decoder.layers.{layer_number}.self_attention.linear_qkv.weight":
            if layer_number in self._GLOBAL_ATTN_LAYERS:
                return [
                    f"model.language_model.layers.{layer_number}.self_attn.q_proj.weight",
                    f"model.language_model.layers.{layer_number}.self_attn.k_proj.weight",
                ]

        return [x.format(layer_number=layer_number) for x in self._ATTENTION_MAPPING[key]]

    def _weight_name_mapping_mcore_local_to_global(self, model, consider_ep: bool = True):
        """Restore the GPT-style local->global mapping for text-only Gemma4.

        Gemma3Bridge (our base class) assumes a VLM structure where
        ``model.language_model.decoder.layers`` exists, and only applies the
        PP layer-offset remap when that attribute is present. Our Gemma4
        model provider builds a plain ``GPTModel`` (text-only) with
        ``model.decoder.layers``, so the Gemma3 check fails silently and all
        PP ranks end up mapping their local layer index i -> global index i -
        which means every PP rank loads HF layers ``0..N/PP-1`` into its
        local slots. The result is that, post-conversion, the torch_dist
        checkpoint has layer weights cyclically duplicated with period
        (num_layers / pp_size).

        We override to delegate to ``Bridge._weight_name_mapping_mcore_local_to_global``
        from the top-level mbridge base class, which walks ``model.decoder.layers``
        directly - matching our GPT-style layout.
        """
        from mbridge.core.bridge import Bridge

        return Bridge._weight_name_mapping_mcore_local_to_global(self, model, consider_ep=consider_ep)

    def _weight_name_mapping_mlp(self, name: str) -> list[str]:
        """Map a Megatron MLP / router / expert parameter name to its HF tensor names.

        Routed-expert weights map to the single HF ``experts.gate_up_proj`` /
        ``experts.down_proj`` tensor of the layer; everything else goes
        through ``_MLP_MAPPING``.
        """
        m = self._RE_MOE_EXPERT.match(name)
        if m:
            layer_number, fc = m.group(1), m.group(2)
            hf_tensor = "gate_up_proj" if fc == "1" else "down_proj"
            return [
                f"model.language_model.layers.{layer_number}.experts.{hf_tensor}",
            ]

        split_name = name.split(".")
        layer_number = split_name[2]
        split_name[2] = "{layer_number}"
        key = ".".join(split_name)
        return [x.format(layer_number=layer_number) for x in self._MLP_MAPPING[key]]

    def _weight_name_mapping_other(self, name: str) -> list[str]:
        """Map the remaining per-layer parameters (post-norms, layer_scalar) via ``_OTHER_MAPPING``."""
        split_name = name.split(".")
        layer_number = split_name[2]
        split_name[2] = "{layer_number}"
        key = ".".join(split_name)
        return [x.format(layer_number=layer_number) for x in self._OTHER_MAPPING[key]]

    def _weight_to_mcore_format(self, mcore_weights_name, hf_weights):
        """Merge the HF tensors for one Megatron parameter into Megatron layout.

        - Routed expert weight: slice expert ``expert_idx`` out of the single
          stacked HF tensor.
        - ``linear_qkv`` weight: interleave q/k/v per KV group; with only
          (q, k) supplied, K is reused as V.
        - ``linear_fc1`` weight: concatenate ``[gate, up]`` along dim 0.
        - Anything else with exactly one HF tensor is passed through.

        Raises:
            ValueError: if a ``linear_qkv`` weight gets neither 2 nor 3 tensors.
            NotImplementedError: if no rule matches the parameter name.
        """
        m = self._RE_MOE_EXPERT.match(mcore_weights_name)
        if m:
            expert_idx = int(m.group(3))
            assert len(hf_weights) == 1, f"expected exactly one HF tensor for expert weight, got {len(hf_weights)}"
            return hf_weights[0][expert_idx].contiguous()

        if "self_attention.linear_qkv." in mcore_weights_name and "layer_norm" not in mcore_weights_name:
            m = re.search(r"layers\.(\d+)\.", mcore_weights_name)
            layer_num = int(m.group(1)) if m else -1  # only used in assertion messages

            hf_text = self._gemma4_text_config
            num_attention_heads = hf_text.num_attention_heads
            # Raises ValueError unless len(hf_weights) is 2 or 3.
            num_kv_heads, head_dim = self._attention_shape_for_hf_weights(hf_weights)

            if len(hf_weights) == 2:
                # K doubles as V. No clone is needed: the tensors are only
                # read (view + cat), and torch.cat allocates the result.
                q, k = hf_weights
                v = k
            else:
                q, k, v = hf_weights

            group_dim = head_dim * num_attention_heads // num_kv_heads
            assert q.shape[0] == num_kv_heads * group_dim, (
                f"layer {layer_num}: q_proj rows ({q.shape[0]}) must equal "
                f"num_kv_heads ({num_kv_heads}) * group_dim ({group_dim}); "
                f"check head_dim/num_attention_heads/num_kv_heads consistency"
            )
            assert k.shape[0] == num_kv_heads * head_dim, (
                f"layer {layer_num}: k_proj rows ({k.shape[0]}) must equal "
                f"num_kv_heads ({num_kv_heads}) * head_dim ({head_dim})"
            )
            assert v.shape[0] == num_kv_heads * head_dim, (
                f"layer {layer_num}: v_proj rows ({v.shape[0]}) must equal "
                f"num_kv_heads ({num_kv_heads}) * head_dim ({head_dim})"
            )
            q = q.view(num_kv_heads, group_dim, -1)
            k = k.view(num_kv_heads, head_dim, -1)
            v = v.view(num_kv_heads, head_dim, -1)
            return torch.cat([q, k, v], dim=1).view(-1, hf_text.hidden_size).contiguous()

        if "linear_fc1.weight" in mcore_weights_name:
            assert len(hf_weights) == 2, (
                f"MLP linear_fc1.weight expects [gate_proj, up_proj] from HF (2 tensors); got {len(hf_weights)}"
            )
            gate, up = hf_weights
            return torch.cat([gate, up], dim=0)

        if len(hf_weights) == 1:
            return hf_weights[0]

        raise NotImplementedError(f"Unsupported parameter name: {mcore_weights_name}")

    def _build_config(self):
        """Build the Megatron transformer config from the HF (text) config.

        MoE-specific settings are added only when the HF text config has
        ``enable_moe_block`` set.
        """
        text_config_key = "text_config" if hasattr(self.hf_config, "text_config") else None
        hf_text = self._gemma4_text_config

        base_kwargs = dict(
            text_config_key=text_config_key,
            use_cpu_initialization=False,
            add_qkv_bias=False,
            qk_layernorm=True,
            layernorm_zero_centered_gamma=False,
            normalization="RMSNorm",
            persist_layer_norm=True,
            activation_func=_gelu_tanh,
            bias_activation_fusion=False,
            bias_dropout_fusion=True,
            rope_local_base_freq=_rope_local_base_freq(hf_text),
        )
        if getattr(hf_text, "enable_moe_block", False):
            base_kwargs.update(
                num_moe_experts=hf_text.num_experts,
                moe_router_topk=hf_text.top_k_experts,
                moe_ffn_hidden_size=hf_text.moe_intermediate_size,
                moe_token_dispatcher_type="alltoall",
                moe_grouped_gemm=True,
                moe_aux_loss_coeff=0.0,
                moe_router_load_balancing_type="none",
                moe_router_score_function="softmax",
                moe_router_topk_scaling_factor=1.0,
                moe_router_pre_softmax=False,
                moe_router_dtype="fp32",
            )

        return self._build_base_config(**base_kwargs)
+- MoE block (26B-A4B): Gemma4's custom router (with per-expert scale) plugged + into Megatron's MoE infrastructure for proper expert-parallel sharding. + The router is still custom (see Gemma4Router); dispatching + grouped-GEMM + come from Megatron's MoELayer + TEGroupedMLP. +""" + +import functools +import logging +from dataclasses import dataclass +from dataclasses import replace as dc_replace + +import torch +import torch.nn as nn +import torch.nn.functional as F +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.moe.moe_layer import BaseMoELayer, MoELayer +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules +from megatron.core.utils import make_viewless_tensor + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TENorm, + TERowParallelLinear, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +from mbridge.models.gemma3.transformer_config import Gemma3TransformerConfig + +# Gemma uses GeGLU, not SwiGLU. 
class VNorm(nn.Module):
    """Scale-free RMS normalisation over the last dimension (Gemma4 ``v_norm``).

    The module owns no parameters. The computation runs in float32 and the
    result is cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x / sqrt(mean(x**2, dim=-1) + eps)`` in the dtype of ``x``."""
        input_dtype = x.dtype
        x_fp32 = x.float()
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.pow(mean_square + self.eps, -0.5)
        return (x_fp32 * inv_rms).to(input_dtype)
+ + The router equation (mirroring HF ``Gemma4TextTopkRouter``) is: + + h_norm = RMSNorm_no_scale(h) # VNorm: no learnable scale + h_scaled = h_norm * scale / sqrt(H) # learnable per-hidden scale + logits = proj(h_scaled) # [T, E] + probs = softmax(logits, dim=-1) + top_w, top_i = topk(probs, k=top_k) + top_w = top_w / top_w.sum(dim=-1, keepdim=True) # renormalize + top_w = top_w * per_expert_scale[top_i] # per-expert scale + + The renormalise-then-scale order is load-bearing and must match HF: it + produces ``top_w.sum() == per_expert_scale.mean_over_selected`` rather + than a renormalised-back-to-1 distribution. Reversing the order (scale + first, then renormalise) would cancel ``per_expert_scale``. + ``test_router_matches_hf_reference_equation`` guards this. + """ + + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.num_experts = config.num_moe_experts + self.top_k = config.moe_router_topk + self.scalar_root_size = self.hidden_size**-0.5 + self.norm = VNorm(self.hidden_size, eps=config.layernorm_epsilon) + self.proj = nn.Linear(self.hidden_size, self.num_experts, bias=False) + self.scale = nn.Parameter(torch.ones(self.hidden_size)) + self.per_expert_scale = nn.Parameter(torch.ones(self.num_experts)) + + def forward(self, hidden_states): + h = self.norm(hidden_states) + h = h * self.scale * self.scalar_root_size + logits = self.proj(h) + probs = torch.softmax(logits, dim=-1) + top_k_weights, top_k_index = torch.topk(probs, k=self.top_k, dim=-1) + top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True) + top_k_weights = top_k_weights * self.per_expert_scale[top_k_index] + return top_k_weights, top_k_index + + def set_layer_number(self, layer_number): + pass + + +class Gemma4MoELayer(MoELayer): + """Gemma4 MoE block: Megatron's MoELayer with Gemma4's custom router. + + Megatron's MoELayer hardcodes its own ``TopKRouter`` which uses a + softmax-with-expert-bias scheme. 
Gemma4 has its own router semantics + (no-scale RMSNorm -> learnable per-hidden scale -> proj -> softmax -> topk -> + per-expert scale multiplier). We reuse all of Megatron's infrastructure + for dispatching (alltoall), expert parallelism, and grouped-GEMM expert + computation - but swap in our ``Gemma4Router`` and convert its compact + (top_k_weights [T, K], top_k_index [T, K]) output into Megatron's + expected (probs [T, E], routing_map [T, E]) format inside ``route()``. + """ + + def __init__(self, config, submodules=None, layer_number=None, pg_collection=None): + # Fall back to Megatron's global parallel_state when pg_collection isn't + # explicitly passed. TransformerLayer only forwards pg_collection when + # submodules.mlp.module is *exactly* one of + # (MoELayer, GroupedMLP, TEGroupedMLP, SequentialMLP) - an identity check + # via `in`, so Gemma4MoELayer (a MoELayer subclass) slips through and + # receives None. BaseMoELayer.__init__ then crashes on `pg_collection.ep`. + # Same fallback MoELayer.__init__ uses when invoked directly. 
+ if pg_collection is None: + from megatron.core.transformer.moe.moe_utils import get_default_pg_collection + + pg_collection = get_default_pg_collection() + BaseMoELayer.__init__(self, config=config, layer_number=layer_number, pg_collection=pg_collection) + self.moe_layer_recompute = False + self.shared_experts_recompute = False + self.submodules = submodules + + self.router = Gemma4Router(config) + + from megatron.core.transformer.moe.token_dispatcher import ( + MoEAllGatherTokenDispatcher, + MoEAlltoAllTokenDispatcher, + MoEFlexTokenDispatcher, + ) + + if config.moe_token_dispatcher_type == "allgather": + self.token_dispatcher = MoEAllGatherTokenDispatcher( + self.num_local_experts, + self.local_expert_indices, + config=self.config, + pg_collection=pg_collection, + ) + elif config.moe_token_dispatcher_type == "alltoall": + self.token_dispatcher = MoEAlltoAllTokenDispatcher( + self.num_local_experts, + self.local_expert_indices, + config=self.config, + pg_collection=pg_collection, + ) + elif config.moe_token_dispatcher_type == "flex": + self.token_dispatcher = MoEFlexTokenDispatcher( + self.num_local_experts, + self.local_expert_indices, + config=self.config, + pg_collection=pg_collection, + ) + else: + raise ValueError(f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}") + + self.experts = build_module( + self.submodules.experts, + self.num_local_experts, + self.config, + pg_collection=pg_collection, + ) + + self.shared_experts = None + + from megatron.core.transformer.moe.moe_utils import MoECudaGraphTensorStore + + self.cudagraph_tensor_store = MoECudaGraphTensorStore() + + # pre_feedforward_layernorm_2: applied to experts' input ONLY (router + # input stays un-normed). 
Matches HF Gemma4TextDecoderLayer: + # hidden_states_flat = residual # router input (un-normed) + # hidden_states_2 = pre_feedforward_layernorm_2(hidden_states_flat) + # hidden_states_2 = experts(hidden_states_2, top_k_index, top_k_weights) + self.pre_feedforward_layernorm_2 = TENorm( + config=config, + hidden_size=config.hidden_size, + eps=config.layernorm_epsilon, + ) + + def route(self, hidden_states: torch.Tensor): + """Call ``Gemma4Router`` and pack its output into Megatron's + ``(probs, routing_map)`` format. + + ``Gemma4Router`` emits compact top-k tensors: + top_k_weights: [T, K] - routing weights (already scaled by per_expert_scale) + top_k_index: [T, K] - which experts each token routes to + Megatron's dispatcher wants: + probs: [T, E] - weight per (token, expert), 0 where not routed + routing_map: [T, E] - boolean mask + """ + flat = hidden_states.reshape(-1, hidden_states.shape[-1]) + top_k_weights, top_k_index = self.router(flat) + + num_tokens = flat.shape[0] + num_experts = self.config.num_moe_experts + probs = torch.zeros( + num_tokens, + num_experts, + dtype=top_k_weights.dtype, + device=top_k_weights.device, + ) + probs.scatter_(1, top_k_index, top_k_weights) + routing_map = probs != 0 + return probs, routing_map + + def forward( + self, + hidden_states: torch.Tensor, + router_input: torch.Tensor | None = None, + ): + """Gemma4 MoE forward with split router / experts inputs. + + HF's ``Gemma4TextDecoderLayer`` routes based on the *un-normed* residual + but feeds the experts the *pre-ff-norm-2'd* residual: + + hidden_states_flat = residual # un-normed + _, tk_w, tk_i = self.router(hidden_states_flat) + experts_input = self.pre_feedforward_layernorm_2(hidden_states_flat) + output = self.experts(experts_input, tk_i, tk_w) + + We take the un-normed residual in ``hidden_states`` and apply + ``pre_feedforward_layernorm_2`` internally to obtain the experts + input. The router path uses the un-normed residual directly. 
Callers + may pass a different ``router_input`` for tests or ablations; when + ``router_input is None`` (the normal case) the router sees the same + un-normed residual the layer was called with. + + We inline the Megatron parent's ``forward`` body here - rather than + calling ``super().forward`` with a side-channel stash - so the + router input is passed explicitly end-to-end and the code is safe + under activation checkpointing / recomputation. + """ + if self.training and self.attn_tp_group.size() > 1 and not self.config.sequence_parallel: + raise ValueError( + "During training, performance may degrade if MoE and tensor " + "parallelism are enabled without also enabling sequence parallelism." + ) + + router_in = router_input if router_input is not None else hidden_states + experts_in = self.pre_feedforward_layernorm_2(hidden_states) + + def custom_forward(experts_in, router_in): + # Gemma4 has no shared experts; shared_experts_compute returns None. + shared_expert_output = self.shared_experts_compute(experts_in) + probs, routing_map = self.route(router_in) + experts_in2, probs = self.preprocess(experts_in, probs, routing_map) + dispatched_input, probs = self.dispatch(experts_in2, probs) + output, mlp_bias = self.routed_experts_compute(dispatched_input, probs) + output = self.combine(output) + output = self.postprocess(output, shared_expert_output) + return output, mlp_bias + + # moe_layer_recompute is forced to False in __init__; call directly. 
class Gemma4TransformerLayer(TransformerLayer):
    """Gemma4 transformer layer with heterogeneous attention and layer_scalar.

    On top of the base ``TransformerLayer`` this layer:
    - decides per layer whether it is a sliding-window or a global-attention
      layer and builds global layers with ``global_kv_channels`` /
      ``global_num_query_groups``;
    - replaces the core attention with ``SDPACoreAttention``;
    - adds post-attention and post-feedforward norms;
    - multiplies the layer output by the ``layer_scalar`` buffer;
    - for MoE variants, runs a dense MLP and the MoE block on the same
      residual and sums their (individually normalised) outputs.
    """

    def __init__(
        self,
        config: Gemma4TransformerConfig,
        submodules: Gemma4TransformerLayerSubmodules,
        layer_number: int = 1,
        hidden_dropout: float | None = None,
        **kwargs,
    ):
        """Build the layer.

        Args:
            config: shared transformer config; never mutated here.
            submodules: layer spec, including the extra Gemma4 norms and
                ``dense_mlp``.
            layer_number: layer number local to this pipeline stage; the
                pipeline offset is added to obtain the global number.
            hidden_dropout: forwarded unchanged to ``TransformerLayer``.
            **kwargs: forwarded unchanged to ``TransformerLayer``.
        """
        from megatron.core.transformer.transformer_layer import get_transformer_layer_offset

        global_layer_number = layer_number + get_transformer_layer_offset(config)
        # Megatron passes `layer_number` as 1-indexed (default 1), so in 0-indexed
        # HF space a global layer is `(i+1) % pattern == 0` -> `i % pattern == pattern-1`.
        # Equivalently: `is_sliding` when `global_layer_number % pattern != 0`.
        # NOTE(review): this derives the layer type from `sliding_window_pattern`,
        # while Gemma4Bridge reads HF `layer_types`; confirm the two always agree.
        self.is_sliding = bool(global_layer_number % config.sliding_window_pattern)
        self._is_global = not self.is_sliding

        # Global layers have different head_dim (kv_channels) and num_kv_heads
        # (num_query_groups). Build the layer against a *cloned* config with
        # those overrides so we never mutate the shared transformer config.
        # Mutation would be reentrant-unsafe under concurrent layer
        # construction and leak global-layer shapes into sibling sliding
        # layers if an exception were raised during super().__init__.
        layer_config = (
            dc_replace(
                config,
                kv_channels=config.global_kv_channels,
                num_query_groups=config.global_num_query_groups,
            )
            if self._is_global
            else config
        )
        super().__init__(
            config=layer_config,
            submodules=submodules,
            layer_number=layer_number,
            hidden_dropout=hidden_dropout,
            **kwargs,
        )

        self.self_attention._is_global = self._is_global

        # Global layers require this because head_dim=512 exceeds flash attention's limit (256).
        # Local layers also use SDPA for consistency.
        # The core attention is handed the shared `config`, not `layer_config`.
        self.self_attention.core_attention = SDPACoreAttention(
            config=config,
            layer_number=self.layer_number,
            attn_mask_type=AttnMaskType.causal,
            softmax_scale=config.softmax_scale,
        )
        self.self_attention.core_attention._is_sliding = self.is_sliding

        self.post_attention_layernorm = build_module(
            submodules.post_attention_layernorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )
        self.post_feedforward_layernorm = build_module(
            submodules.post_feedforward_layernorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )

        # Layer scalar (buffer, not learned). Kept in fp32 intentionally -
        # HF stores this scalar in fp32 and relies on the implicit upcast of
        # ``bf16_hidden * fp32_scalar`` at multiply time (see HF Gemma4
        # ``Gemma4TextDecoderLayer.__init__`` at modeling_gemma4.py:1331).
        # Don't switch to ``dtype=self.config.params_dtype``; that would
        # silently change the arithmetic.
        self.register_buffer("layer_scalar", torch.ones(1))

        # MoE block (26B-A4B): super().__init__ already built self.mlp from the
        # layer spec, which when enable_moe_block=True is a Gemma4MoELayer (not
        # a dense MLP). We also build a parallel `dense_mlp` for Gemma4's
        # dense + MoE combined-FFN pattern. The two outputs are summed in
        # forward().
        self.enable_moe_block = getattr(config, "enable_moe_block", False)
        if self.enable_moe_block:
            self.dense_mlp = build_module(
                submodules.dense_mlp,
                config=config,
            )
            # Norm applied to the dense branch output before the sum.
            self.post_feedforward_layernorm_1 = TENorm(
                config=config,
                hidden_size=config.hidden_size,
                eps=config.layernorm_epsilon,
            )
            # `pre_feedforward_layernorm_2` is NOT built here: it is owned by
            # Gemma4MoELayer (``self.mlp``), matching HF Gemma4TextDecoderLayer
            # semantics where the router sees the un-normed residual and the
            # experts see pre_feedforward_layernorm_2(residual). Only the
            # post-FFN norm of the MoE branch lives on this layer.
            self.post_feedforward_layernorm_2 = TENorm(
                config=config,
                hidden_size=config.hidden_size,
                eps=config.layernorm_epsilon,
            )

    def _forward_dense_ffn(self, pre_mlp_ln):
        """Run the dense MLP. ``self.mlp`` is the dense MLP directly for the
        31B variant. Any bias returned by the MLP is added to its output."""
        out, bias = self.mlp(pre_mlp_ln)
        return out + bias if bias is not None else out

    def _forward_moe_ffn(self, residual, pre_mlp_ln):
        """Run dense + MoE in parallel and sum (26B-A4B variant).

        Mirrors HF ``Gemma4TextDecoderLayer.forward`` (transformers
        modeling_gemma4.py:1376-1391): dense branch goes through
        ``post_feedforward_layernorm_1``, MoE branch through
        ``post_feedforward_layernorm_2``, the two are summed, and the outer
        ``Gemma4TransformerLayer.forward`` applies ``post_feedforward_layernorm``
        to the sum - 3 post-FFN LNs total for MoE layers is correct.

        HF routes on the un-normed residual but feeds experts the
        ``pre_feedforward_layernorm_2``'d residual; Gemma4MoELayer applies
        that norm internally, so we pass the un-normed residual directly.
        """
        dense_out, dense_bias = self.dense_mlp(pre_mlp_ln)
        if dense_bias is not None:
            dense_out = dense_out + dense_bias
        mlp_output = self.post_feedforward_layernorm_1(dense_out)

        # The bias slot returned by the MoE block is ignored.
        moe_output, _ = self.mlp(residual)
        moe_output = self.post_feedforward_layernorm_2(moe_output)

        return mlp_output + moe_output

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        context=None,
        context_mask=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        attention_bias=None,
        inference_context=None,
        inference_params=None,
        packed_seq_params=None,
        sequence_len_offset=None,
        **kwargs,
    ):
        """Run attention and FFN with Gemma4's extra norms and the layer scalar.

        ``rotary_pos_emb`` and ``attention_mask`` may be tuples holding one
        entry per layer type; index 0 is used by global layers and index 1 by
        sliding layers. ``context_mask`` and ``**kwargs`` are accepted but not
        used in this method.

        Returns:
            ``output`` alone when ``config.external_cuda_graph`` is set and
            the module is in training mode, otherwise ``(output, context)``
            with ``context`` passed through unchanged.
        """
        # Pick this layer's RoPE table. A single tensor wider than
        # `dual_rope_global_dim` is split along the last dim: the leading
        # `global_dim` entries go to global layers, the rest to sliding layers.
        if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
            global_dim = getattr(self.config, "dual_rope_global_dim", 0)
            if global_dim > 0 and rotary_pos_emb.shape[-1] > global_dim:
                if self.is_sliding:
                    rotary_pos_emb = rotary_pos_emb[..., global_dim:]
                else:
                    rotary_pos_emb = rotary_pos_emb[..., :global_dim]
        elif isinstance(rotary_pos_emb, tuple):
            rotary_pos_emb = rotary_pos_emb[1] if self.is_sliding else rotary_pos_emb[0]
        if isinstance(attention_mask, tuple):
            attention_mask = attention_mask[1] if self.is_sliding else attention_mask[0]

        # Global layers use partial RoPE (25% of head_dim=512 = 128 dims)
        # Local layers use full RoPE (100% of head_dim=256 = 256 dims)
        # With DualRotaryEmbedding, global RoPE is full-size (512 dims) with zero-padded
        # non-rotated dims, so no truncation needed.
        # With single RoPE (local only, 256 dims), truncate for global layers.
        if not self.is_sliding and rotary_pos_emb is not None:
            global_rope_dim = int(self.config.global_kv_channels * self.config.global_partial_rotary_factor)
            if (
                rotary_pos_emb.shape[-1] != self.config.global_kv_channels
                and rotary_pos_emb.shape[-1] > global_rope_dim
            ):
                rotary_pos_emb = rotary_pos_emb[..., :global_rope_dim]

        residual = hidden_states

        # Forward only whichever inference handle the caller supplied;
        # `inference_context` wins when both are given.
        extra_kwargs = {}
        if inference_context is not None:
            extra_kwargs["inference_context"] = inference_context
        elif inference_params is not None:
            extra_kwargs["inference_params"] = inference_params

        input_layernorm_output = self.input_layernorm(hidden_states)

        hidden_states, hidden_states_bias = self.self_attention(
            input_layernorm_output,
            attention_mask=attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            attention_bias=attention_bias,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            **extra_kwargs,
        )

        # Attention sub-block: bias add -> post-attention norm -> residual add.
        if hidden_states_bias is not None:
            hidden_states = hidden_states + hidden_states_bias
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # FFN sub-block: the MoE path also receives the un-normed residual.
        residual = hidden_states
        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
        if self.enable_moe_block:
            hidden_states = self._forward_moe_ffn(residual, pre_mlp_layernorm_output)
        else:
            hidden_states = self._forward_dense_ffn(pre_mlp_layernorm_output)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # The layer scalar is applied after the residual add, to the whole output.
        hidden_states = hidden_states * self.layer_scalar

        output = make_viewless_tensor(
            inp=hidden_states,
            requires_grad=hidden_states.requires_grad,
            keep_graph=True,
        )

        if self.config.external_cuda_graph and self.training:
            return output
        return output, context
+ + Replaces TE's DotProductAttention because: + - Global layers have head_dim=512, which flash-attn 2.x doesn't support. + - Sliding-window layers need an explicit left-window mask (HF behavior). + - Context-parallelism on the global layers needs an all-gather+full-attn + path with a differentiable K/V gather. + + Dispatch at call time (packed / thd shape): + - CP > 1 (any layer) : all-gather K/V, apply causal + optional + sliding-window mask computed from slime zig-zag global indices. + - global + CP == 1 : sub-sequence causal SDPA (no O(T^2) mask alloc). + - sliding + CP == 1 : flash_attn_varlen_func with (sw-1, 0) window. + """ + + def __init__( + self, + config, + layer_number, + attn_mask_type, + attention_type="self", + attention_dropout=None, + softmax_scale=None, + **kwargs, + ): + super().__init__() + # Megatron's SelfAttention.__init__ passes a few kwargs (e.g. cp_comm_type, + # model_comm_pgs) intended for TE's DotProductAttention. We accept-and-ignore + # by name rather than asserting empty; a strict assert breaks whenever + # Megatron/TE add a new kwarg. If a kwarg shows up here that we *should* + # honor (e.g. a new softmax dtype), it will surface as a behavioral bug + # in parity, which is what the test suite covers. + del kwargs + self.config = config + self.softmax_scale = softmax_scale + self.dropout_p = config.attention_dropout if attention_dropout is None else attention_dropout + self._is_sliding = False # set by Gemma4TransformerLayer + + def _resolve_scale(self, hn: int) -> float: + return self.softmax_scale if self.softmax_scale is not None else (hn**-0.5) + + @staticmethod + def _zigzag_global_indices(local_len, cp_rank, cp_size, device): + """Global positions of this rank's local Q tokens under slime's + zig-zag CP layout (matches cp_utils.slice_with_cp). + + Local tokens on rank r occupy two global sub-ranges: + [r*cs, (r+1)*cs) and [(2*cp-r-1)*cs, (2*cp-r)*cs) + where cs = local_len / 2 = seq_len / (2*cp_size). 
+ """ + cs = local_len // 2 + first = torch.arange(cp_rank * cs, (cp_rank + 1) * cs, device=device) + second = torch.arange( + (2 * cp_size - cp_rank - 1) * cs, + (2 * cp_size - cp_rank) * cs, + device=device, + ) + return torch.cat([first, second]) + + @staticmethod + def _cp_unzigzag_permutation(cu_seqlens_list, cp_size, device): + """Map rank-major CP-gathered K/V tokens back to packed global order.""" + total_local_len = sum( + (cu_seqlens_list[i + 1] - cu_seqlens_list[i]) // cp_size for i in range(len(cu_seqlens_list) - 1) + ) + local_prefix = 0 + perm_parts = [] + for s_idx in range(len(cu_seqlens_list) - 1): + seq_len_global = cu_seqlens_list[s_idx + 1] - cu_seqlens_list[s_idx] + cs = seq_len_global // (2 * cp_size) + g = torch.arange(seq_len_global, device=device) + chunk = g // cs + owner = torch.where(chunk < cp_size, chunk, 2 * cp_size - 1 - chunk) + local_in_rank = torch.where( + chunk < cp_size, + g - owner * cs, + cs + (g - (2 * cp_size - 1 - owner) * cs), + ) + perm_parts.append(owner * total_local_len + local_prefix + local_in_rank) + local_prefix += seq_len_global // cp_size + return torch.cat(perm_parts) + + def _forward_cp_subseq_mask(self, query, key, value, packed_seq_params, sliding_window=None): + """CP>1 path for any layer: all-gather K/V, then loop over sub-seqs + and apply a per-sub-seq attention mask built from zig-zag global + positions. Supports causal-only (global layers) and causal + + sliding-window (sliding layers). + + Under slime's CP convention, ``packed_seq_params.cu_seqlens_q`` holds + GLOBAL boundaries: each packed sub-sequence on this rank represents + ``(cu[i+1] - cu[i])`` tokens globally but only ``(cu[i+1] - cu[i]) // + cp_size`` tokens locally (the zig-zag slice of this rank's two + chunks, concatenated as [first, second]). 
+ """ + from megatron.core import parallel_state + from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region + + cp_group = parallel_state.get_context_parallel_group() + cp_size = parallel_state.get_context_parallel_world_size() + cp_rank = parallel_state.get_context_parallel_rank() + + t_local = query.shape[0] + np_q, hn = query.shape[1], query.shape[2] + nk = key.shape[1] + scale = self._resolve_scale(hn) + + # Differentiable all-gather along the token dim. forward: AG, + # backward: RS - so K/V grads on non-owning ranks flow back to the + # originating rank. The raw `dist.all_gather_into_tensor` has no + # autograd rule and PyTorch prints a "silently incorrect behavior" + # warning + drops those grads. + k_full = gather_from_sequence_parallel_region(key.contiguous(), group=cp_group) + v_full = gather_from_sequence_parallel_region(value.contiguous(), group=cp_group) + # gather_from_sequence_parallel_region stacks each rank's chunk + # consecutively in rank order. Under zig-zag, each rank's [2*cs] + # local tokens are [chunk_r_first, chunk_r_second]. So the gathered + # tensor layout is [r0_first, r0_second, r1_first, r1_second, ...]. + # We need to un-zig-zag into pure global order so mask indices line + # up. Build a permutation that maps gathered index -> global index. + device = query.device + dtype = query.dtype + cu_seqlens = packed_seq_params.cu_seqlens_q if packed_seq_params is not None else None + + # Sanity: for each packed sub-seq, the GLOBAL length must be + # divisible by 2*cp_size so chunk_size is integer. With cp_size=1 this + # reduces to even-length, which the CP=1 parity-test harness may + # violate (no zig-zag pre-slicing). Skip the check there; permutation + # is identity under cp_size=1 so odd length is harmless. 
+ if cu_seqlens is not None and cp_size > 1: + expected_t_local = 0 + for s_idx in range(len(cu_seqlens) - 1): + s_len = (cu_seqlens[s_idx + 1] - cu_seqlens[s_idx]).item() + assert s_len % (2 * cp_size) == 0, ( + f"sub-sequence {s_idx} global length ({s_len}) is not " + f"divisible by 2*cp_size ({2 * cp_size}); `slice_with_cp` " + "should pad before packing" + ) + expected_t_local += s_len // cp_size + assert expected_t_local == t_local, ( + f"packed-seq local length mismatch: sum(seq_len // cp_size) = " + f"{expected_t_local}, but query.shape[0] = {t_local}" + ) + + if cu_seqlens is None: + t_full_total = k_full.shape[0] + cu_seqlens_list = [0, t_full_total] + else: + cu_seqlens_list = cu_seqlens.tolist() + + # With cp_size=1 the zigzag degenerates to identity and all-gather is + # a no-op; skip the permutation (and the floor-div that would drop the + # trailing odd token for seq_len_global % 2 == 1). + if cp_size > 1: + perm = self._cp_unzigzag_permutation(cu_seqlens_list, cp_size, device) + k_full = k_full.index_select(0, perm) + v_full = v_full.index_select(0, perm) + + out = torch.empty(t_local, np_q * hn, dtype=dtype, device=device) + + local_offset = 0 + for s_idx in range(len(cu_seqlens_list) - 1): + seq_start = cu_seqlens_list[s_idx] + seq_len_global = cu_seqlens_list[s_idx + 1] - seq_start + local_len = seq_len_global // cp_size # this sub-seq's local Q count + + q_seq = query[local_offset : local_offset + local_len] + k_seq = k_full[seq_start : seq_start + seq_len_global] + v_seq = v_full[seq_start : seq_start + seq_len_global] + + q4 = q_seq.unsqueeze(0).transpose(1, 2) # [1, np, local_len, hn] + k4 = k_seq.unsqueeze(0).transpose(1, 2) # [1, nk, seq_len, hn] + v4 = v_seq.unsqueeze(0).transpose(1, 2) + + # Global positions of local Q tokens. cp_size=1 degenerates to + # identity; use arange to preserve odd-length seqs (zigzag helper + # floor-divides, dropping the trailing token). 
+ if cp_size > 1: + row_idx = self._zigzag_global_indices(local_len, cp_rank, cp_size, device) + else: + row_idx = torch.arange(local_len, device=device) + col_idx = torch.arange(seq_len_global, device=device) + forbid_future = col_idx[None, :] > row_idx[:, None] + if sliding_window is not None and sliding_window > 0: + forbid_past = col_idx[None, :] < (row_idx[:, None] - (sliding_window - 1)) + forbid = forbid_future | forbid_past + else: + forbid = forbid_future + mask = torch.where( + forbid, + torch.finfo(dtype).min, + 0.0, + ).to(dtype=dtype) + + o = F.scaled_dot_product_attention( + q4, + k4, + v4, + attn_mask=mask[None, None, :, :], + dropout_p=self.dropout_p if self.training else 0.0, + scale=scale, + enable_gqa=(np_q != nk), + ) + out[local_offset : local_offset + local_len] = o.transpose(1, 2).reshape(local_len, -1) + local_offset += local_len + + return out + + def _forward_thd_flash(self, query, key, value, cu_seqlens): + """Sliding-window or head_dim<=256 path via flash_attn_varlen_func. + + CP==1 only. For CP>1, `_forward_cp_subseq_mask` handles zig-zag. + + Sliding-window layers must pass `window_size=(sliding_window-1, 0)` so + only tokens within `sliding_window` positions back are attended to - + this matches HF's `sliding_window_mask_function`. Global layers and + dense-attention sliding layers use the default full-causal window. 
+ """ + from flash_attn import flash_attn_varlen_func + + window_size = (-1, -1) # full causal when causal=True + if self._is_sliding: + sw = getattr(self.config, "sliding_window", None) + if sw and sw > 0: + window_size = (int(sw) - 1, 0) + + cu = cu_seqlens.to(torch.int32) + max_seqlen = (cu[1:] - cu[:-1]).max().item() + out = flash_attn_varlen_func( + query.contiguous(), + key.contiguous(), + value.contiguous(), + cu_seqlens_q=cu, + cu_seqlens_k=cu, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=self.dropout_p if self.training else 0.0, + softmax_scale=self._resolve_scale(query.shape[2]), + causal=True, + window_size=window_size, + ) + return out.reshape(query.shape[0], -1) + + def _forward_thd_sdpa_per_subseq(self, query, key, value, cu_seqlens): + """Per-sub-sequence causal SDPA - used when flash-attn can't handle + head_dim (global layer w/o CP). Avoids materializing a [T, T] mask. + """ + np_q, hn = query.shape[1], query.shape[2] + nk = key.shape[1] + scale = self._resolve_scale(hn) + out = torch.empty(query.shape[0], np_q * hn, dtype=query.dtype, device=query.device) + for i in range(len(cu_seqlens) - 1): + s = cu_seqlens[i].item() + e = cu_seqlens[i + 1].item() + q4 = query[s:e].unsqueeze(0).transpose(1, 2) # [1, np, L, hn] + k4 = key[s:e].unsqueeze(0).transpose(1, 2) + v4 = value[s:e].unsqueeze(0).transpose(1, 2) + o = F.scaled_dot_product_attention( + q4, + k4, + v4, + dropout_p=self.dropout_p if self.training else 0.0, + scale=scale, + is_causal=True, + enable_gqa=(np_q != nk), + ) + out[s:e] = o.transpose(1, 2).reshape(e - s, -1) + return out + + def forward(self, query, key, value, attention_mask=None, attn_mask_type=None, packed_seq_params=None, **kwargs): + cp_size = getattr(self.config, "context_parallel_size", 1) or 1 + is_thd = query.dim() == 3 + + force_cp_path = getattr(self.config, "force_cp_subseq_mask", False) + + if is_thd: + if cp_size > 1 or force_cp_path: + sw = None + if self._is_sliding: + sw_cfg = 
getattr(self.config, "sliding_window", None) + if sw_cfg and sw_cfg > 0: + sw = int(sw_cfg) + return self._forward_cp_subseq_mask( + query, + key, + value, + packed_seq_params, + sliding_window=sw, + ) + + cu_seqlens = None + if packed_seq_params is not None: + cu_seqlens = packed_seq_params.cu_seqlens_q + + hn = query.shape[2] + if cu_seqlens is not None: + if hn <= 256: + return self._forward_thd_flash(query, key, value, cu_seqlens) + return self._forward_thd_sdpa_per_subseq(query, key, value, cu_seqlens) + + q = query.unsqueeze(0).transpose(1, 2) + k = key.unsqueeze(0).transpose(1, 2) + v = value.unsqueeze(0).transpose(1, 2) + nq, nk = q.shape[1], k.shape[1] + out = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.dropout_p if self.training else 0.0, + scale=self._resolve_scale(hn), + is_causal=True, + enable_gqa=(nq != nk), + ) + return out.transpose(1, 2).reshape(query.shape[0], -1) + + q = query.permute(1, 2, 0, 3) + k = key.permute(1, 2, 0, 3) + v = value.permute(1, 2, 0, 3) + nq, nk = q.shape[1], k.shape[1] + out = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.dropout_p if self.training else 0.0, + scale=self._resolve_scale(query.shape[3]), + is_causal=True, + enable_gqa=(nq != nk), + ) + return out.permute(2, 0, 1, 3).reshape(out.size(2), out.size(0), -1) + + +class Gemma4SelfAttention(SelfAttention): + """SelfAttention with Gemma4-specific modifications: + - v_norm: RMSNorm without learnable scale applied to value states. + - attention_k_eq_v: on global layers the linear_qkv projection emits + ``[q, k]`` only (no v_proj) and V is derived from K - specifically + ``V = v_norm(raw_k)`` while ``K = k_norm(raw_k)``. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._is_global = False # set by Gemma4TransformerLayer after construction + self.v_norm = VNorm(self.hidden_size_per_attention_head, eps=self.config.layernorm_epsilon) + + def _split_qkv_global_k_eq_v(self, hidden_states): + """Split linear_qkv output for global K=V layers. + + The Mcore linear_qkv weight for a K=V global layer is built with + ``v_proj_weight == k_proj_weight`` (see Gemma4Bridge + convert_gemma4_to_hf), + so ``linear_qkv(h)`` emits Q/K/V with ``raw_k == raw_v``. Gemma4's + per-head norms then apply as ``key = k_norm(raw_k)`` and + ``value = v_norm(raw_k)`` - *not* ``v_norm(k_norm(raw_k))``. We + reimplement the split here rather than calling the parent so we + don't have to mutate ``self.k_layernorm`` mid-forward. + + Returns (query[sq,b,np,hn], key[sq,b,ng,hn], value[sq,b,ng,hn]). + """ + mixed_qkv, _ = self.linear_qkv(hidden_states) + num_query_heads_per_group = self.num_attention_heads_per_partition // self.num_query_groups_per_partition + new_shape = mixed_qkv.size()[:-1] + ( + self.num_query_groups_per_partition, + (num_query_heads_per_group + 2) * self.hidden_size_per_attention_head, + ) + mixed_qkv = mixed_qkv.view(*new_shape) + + q_width = num_query_heads_per_group * self.hidden_size_per_attention_head + hn = self.hidden_size_per_attention_head + query, raw_key, _raw_value = torch.split(mixed_qkv, [q_width, hn, hn], dim=3) + query = query.reshape(query.size(0), query.size(1), -1, hn) + + if self.q_layernorm is not None: + query = self.q_layernorm(query) + + value = self.v_norm(raw_key) + key = self.k_layernorm(raw_key) if self.k_layernorm is not None else raw_key + return query, key, value + + def get_query_key_value_tensors(self, hidden_states, key_value_states=None, output_gate=False, split_qkv=True): + if self._is_global and self.config.attention_k_eq_v and split_qkv: + if output_gate: + raise NotImplementedError("output_gate is not supported together 
def _build_moe_submodule_spec(config):
    """Return a ``ModuleSpec`` for ``Gemma4MoELayer`` that reuses Megatron's
    TE-backed MoE submodules.

    The stock MoE spec is built from ``config.num_moe_experts`` and
    ``config.moe_grouped_gemm``; only its ``submodules`` and ``metainfo`` are
    kept, with the module class swapped for ``Gemma4MoELayer``.
    """
    from megatron.core.extensions.transformer_engine_spec_provider import TESpecProvider
    from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec_for_backend

    backend_kwargs = {
        "backend": TESpecProvider(),
        "num_experts": config.num_moe_experts,
        "moe_grouped_gemm": config.moe_grouped_gemm,
        # use plain F.gelu(approximate='tanh') from config.activation_func
        "use_te_activation_func": False,
    }
    stock_spec = get_moe_module_spec_for_backend(**backend_kwargs)
    return ModuleSpec(
        module=Gemma4MoELayer,
        submodules=stock_spec.submodules,
        metainfo=stock_spec.metainfo,
    )
@functools.lru_cache(maxsize=4)
def _load_hf_text_config(hf_checkpoint):
    """Load the HF config for ``hf_checkpoint`` and return its text part.

    When the loaded config carries a ``text_config`` attribute (multimodal
    wrapper) that sub-config is returned; otherwise the config itself is.
    Results are memoized per checkpoint path via ``lru_cache`` so repeated
    callers share one parsed object.
    """
    from transformers import AutoConfig

    hf_config = AutoConfig.from_pretrained(hf_checkpoint, trust_remote_code=True)
    return getattr(hf_config, "text_config", hf_config)
That branch is correct for us + - Gemma4MoELayer.__init__ fetches its own pg_collection via + get_default_pg_collection - but the warning spams 30 lines per layer at + init and confuses log readers. See gemma4_provider.py install hook. + """ + + def filter(self, record: logging.LogRecord) -> bool: + msg = record.getMessage() + return not ("Unknown MLP type" in msg and "Gemma4MoELayer" in msg) + + +def _install_moe_warning_filter(): + """Silence the per-layer "Unknown MLP type: Gemma4MoELayer" warning. + + Megatron's TransformerLayer compares MLP class identity via ``==``, so + MoELayer subclasses hit the default-kwargs branch and log a warning. + The default-kwargs branch is correct for us (Gemma4MoELayer fetches + pg_collection itself); filter the noise. + """ + tl_logger = logging.getLogger("megatron.core.transformer.transformer_layer") + if getattr(tl_logger, "_gemma4_moe_filter_installed", False): + return + tl_logger.addFilter(_Gemma4MoELayerWarningFilter()) + tl_logger._gemma4_moe_filter_installed = True + + +def _assert_hf_features_supported(hf_text): + """Fail loudly on Gemma4 HF features this plugin doesn't implement.""" + if getattr(hf_text, "hidden_size_per_layer_input", 0): + raise NotImplementedError( + "Gemma4 per-layer input mechanism " + f"(hidden_size_per_layer_input={hf_text.hidden_size_per_layer_input}) " + "is not implemented. See Gemma4TextDecoderLayer.per_layer_input_gate in HF." + ) + if getattr(hf_text, "num_kv_shared_layers", 0): + raise NotImplementedError( + "Gemma4 KV-sharing across the last N layers " + f"(num_kv_shared_layers={hf_text.num_kv_shared_layers}) is not implemented." + ) + if getattr(hf_text, "use_double_wide_mlp", False): + raise NotImplementedError("Gemma4 use_double_wide_mlp is not implemented.") + # Text-only training assumes causal attention; HF's "all" mode disables it. 
+ if getattr(hf_text, "use_bidirectional_attention", "vision") == "all": + raise NotImplementedError("Gemma4 use_bidirectional_attention='all' disables causal masking; not supported.") + + +def _apply_core_config(config, hf_text): + """Set Gemma4's non-MoE, non-RoPE config fields. + + Mutates ``config`` in place. Promotes its ``__class__`` to + ``Gemma4TransformerConfig`` so the new dataclass fields are reachable + from downstream Megatron code. + """ + # Gemma uses GeGLU (gated gelu-tanh), not SwiGLU. + config.gated_linear_unit = True + config.activation_func = _gelu_tanh + config.bias_activation_fusion = False + + # No MoE-vs-dense layer scheduling: every layer is our Gemma4TransformerLayer + # and the MoE block lives inside its forward. An all-zero list keeps + # transformer_block's non_homogeneous_layers=True branch active (correct for + # 26B's differing global vs sliding head_dim / num_kv_heads). + # Rationale for using moe_layer_freq as the flag: Megatron's + # TransformerBlock.__init__ sets ``non_homogeneous_layers = True`` iff + # ``config.moe_layer_freq is not None``. We only need that flag on - + # the actual dense/MoE dispatch happens inside + # Gemma4TransformerLayer.forward, so the list contents are never + # consulted by TransformerBlock itself. If a future Megatron refactor + # starts reading the list per-layer, we need a Gemma4-specific schedule + # instead. + config.moe_layer_freq = [0] * config.num_layers + + # Mirror Megatron's own misspelling (`hetereogenous_*`) - correcting it + # would silently no-op on Megatron's read path. 
+ config.hetereogenous_dist_checkpoint = True + + config.__class__ = Gemma4TransformerConfig + config.global_kv_channels = hf_text.global_head_dim + config.global_num_query_groups = hf_text.num_global_key_value_heads + config.attention_k_eq_v = getattr(hf_text, "attention_k_eq_v", True) + config.final_logit_softcapping = getattr(hf_text, "final_logit_softcapping", 30.0) + config.sliding_window = hf_text.sliding_window + + # `sliding_window_pattern` isn't in Gemma4 HF configs - infer from + # layer_types (first full_attention layer's 1-indexed position). + layer_types = list(getattr(hf_text, "layer_types", [])) + try: + config.sliding_window_pattern = layer_types.index("full_attention") + 1 + except ValueError: + config.sliding_window_pattern = 6 + + # Q/K norms handle softmax scaling; Megatron's default of 1/sqrt(hn) is wrong. + config.softmax_scale = 1.0 + # Fused RoPE ignores zeroed inv_freq tails; we need unfused for partial-rotary. + config.apply_rope_fusion = False + + +def _apply_moe_config(config, hf_text): + """Set MoE fields if this is a MoE variant (26B-A4B).""" + config.enable_moe_block = getattr(hf_text, "enable_moe_block", False) + if not config.enable_moe_block: + return + + config.num_moe_experts = hf_text.num_experts + config.moe_router_topk = hf_text.top_k_experts + config.moe_ffn_hidden_size = hf_text.moe_intermediate_size + # Megatron MoE infrastructure reads these even though our custom router + # bypasses its scoring logic; defaults mirror a working Qwen3.5-A3B config. 
def get_rope_local_base_freq(hf_text) -> float:
    """Extract sliding-attention RoPE theta from an HF Gemma4 text config.

    Single source of truth for both the model provider and the mbridge
    config builder - otherwise the 10000.0 default would drift between
    call sites.

    Returns ``rope_parameters["sliding_attention"]["rope_theta"]`` when
    present, else ``10000.0``. A missing or ``None`` ``rope_parameters``
    attribute, or a missing or ``None`` ``"sliding_attention"`` entry, all
    fall back to the default instead of raising.
    """
    rope_params = getattr(hf_text, "rope_parameters", None) or {}
    # `or {}` (not a `.get` default) so an explicit ``None`` entry is also
    # treated as "not configured".
    sliding = rope_params.get("sliding_attention") or {}
    return sliding.get("rope_theta", 10000.0)
def get_gemma4_spec(args, config, vp_stage):
    """Return the native Gemma4 layer spec with proper config overrides.

    Loads the HF text config named by ``args.hf_checkpoint``, mutates
    ``config`` in place through the ``_apply_*`` helpers, validates the
    CP / sliding-window combination, then builds the layer spec.
    ``vp_stage`` is accepted for the spec-factory signature but not read here.
    """
    hf_text = _load_hf_text_config(args.hf_checkpoint)

    _install_moe_warning_filter()
    _assert_hf_features_supported(hf_text)
    for configure in (_apply_core_config, _apply_moe_config, _apply_rope_config):
        configure(config, hf_text)
    _guard_cp_sliding_window(args, config)

    spec = get_gemma4_layer_spec_te(config)
    from megatron.core.extensions.transformer_engine_spec_provider import TESpecProvider

    if getattr(config, "enable_moe_block", False):
        return spec

    # Dense variant: keep fc1 a plain column-parallel linear, mark the
    # pre-MLP layernorm as not fused, and give the layer an explicit norm.
    dense_mlp = spec.submodules.mlp
    dense_mlp.submodules.linear_fc1 = TEColumnParallelLinear
    dense_mlp.metainfo = {"fuse_pre_mlp_layernorm": False}
    spec.submodules.pre_mlp_layernorm = TESpecProvider().layer_norm()
    return spec
+ +Installs Gemma4-specific behaviors that sit outside the transformer layer: +- embedding scaling (multiply embeddings by sqrt(hidden_size)) +- logit softcapping (`final_logit_softcapping`) +- dual-RoPE (different rope_theta + partial-rotary for global vs sliding layers) +- layer_scalar buffers loaded from the HF checkpoint +""" + +import json +import logging +import os + +import torch +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.transformer.spec_utils import import_module +from megatron.training import get_args +from megatron.training.arguments import core_transformer_config_from_args + +from slime_plugins.models.gemma4 import _load_hf_text_config + +logger = logging.getLogger(__name__) + + +def _is_rank_zero() -> bool: + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return True + return torch.distributed.get_rank() == 0 + + +def model_provider(pre_process=True, post_process=True, vp_stage=None): + args = get_args() + config = core_transformer_config_from_args(args) + + transformer_layer_spec = import_module(args.spec) + if callable(transformer_layer_spec): + transformer_layer_spec = transformer_layer_spec(args, config, vp_stage) + + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, + rope_scaling=args.use_rope_scaling, + ) + + _install_hooks(model, args, config, pre_process, post_process) + return model + + +class DualRotaryEmbedding(torch.nn.Module): + """Wraps a (global, local) pair of RotaryEmbedding modules and emits a + single 
concatenated tensor (global part first). ``Gemma4TransformerLayer`` + slices it per-layer based on ``is_sliding``. Concat (not tuple) because + Megatron's ``SelfAttention.forward`` reads a 2-tuple as + ``(self_attn, cross_attn)`` RoPE and would misread our pair. + """ + + def __init__(self, local_rope, global_rope, global_dim: int): + super().__init__() + self.local_rope = local_rope + self.global_rope = global_rope + self.global_dim = global_dim + + def get_rotary_seq_len(self, *args, **kwargs): + return self.local_rope.get_rotary_seq_len(*args, **kwargs) + + def forward(self, seq_len, **kwargs): + global_emb = self.global_rope(seq_len, **kwargs) + local_emb = self.local_rope(seq_len, **kwargs) + return torch.cat([global_emb, local_emb], dim=-1) + + +class _Gemma4LogitSoftcap(torch.autograd.Function): + """Apply Gemma4 final logit softcapping without allocating new logits.""" + + @staticmethod + def forward(ctx, logits: torch.Tensor, scale: float) -> torch.Tensor: + ctx.scale = scale + ctx.mark_dirty(logits) + logits.div_(scale) + logits.tanh_() + logits.mul_(scale) + ctx.save_for_backward(logits) + return logits + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: + (softcapped,) = ctx.saved_tensors + scale = ctx.scale + grad_logits = softcapped / scale + grad_logits.pow_(2) + grad_logits.neg_() + grad_logits.add_(1.0) + grad_logits.mul_(grad_output) + return grad_logits, None + + +def _logit_softcapping(logits: torch.Tensor, scale: float) -> torch.Tensor: + if scale <= 0: + return logits + return _Gemma4LogitSoftcap.apply(logits, float(scale)) + + +def _install_hooks(model, args, config, pre_process, post_process): + """Install Gemma4-specific pre/post-process hooks on a built GPTModel. + + We use ``register_forward_hook`` rather than subclassing GPTModel + because: + - Two independent behaviors (embed scale, softcap) on two different + submodules. 
Subclassing would require overriding + ``GPTModel.forward`` and branching on pp/vp stage. + - The hooks are shape- and dtype-preserving, so they compose cleanly + with PP (only first-stage runs embedding, only last-stage runs + output_layer) - we gate registration on ``pre_process`` / + ``post_process`` accordingly. + - Keeps the diff local to this plugin: we don't need to shadow any + Megatron-maintained class. + """ + hf_text = _load_hf_text_config(args.hf_checkpoint) + hidden_size = config.hidden_size + + inner = model.module if hasattr(model, "module") else model + + # Embedding scaling - HF applies this inside the embedding module. + # See ``Gemma4TextScaledWordEmbedding``: the scale is stored as an fp32 + # tensor and cast to the embedding weight's dtype at forward time, so + # the scale-as-applied depends on the current weight dtype (bf16 during + # training, fp32 during some eval paths). We match that behavior here. + if pre_process and hasattr(inner, "embedding"): + embed_scale = torch.tensor(hidden_size**0.5) # fp32 + + def _embed_hook(module, inp, output): + return output * embed_scale.to(output.dtype) + + inner.embedding.register_forward_hook(_embed_hook) + + # Final logit softcapping - HF applies tanh(logits / cap) * cap. + # Some Megatron output_layer variants (parallel_output paths) return + # ``(logits, bias)``; we pass the non-logit tail through unchanged. + softcap = getattr(hf_text, "final_logit_softcapping", None) + if post_process and softcap and hasattr(inner, "output_layer"): + + def _softcap_hook(module, inp, output): + if isinstance(output, tuple): + return (_logit_softcapping(output[0], softcap),) + output[1:] + return _logit_softcapping(output, softcap) + + inner.output_layer.register_forward_hook(_softcap_hook) + + # Dual RoPE: replace Megatron's single rotary_pos_emb with a wrapper that + # produces (global, local) RoPE side-by-side. Gemma4 uses partial-rotary + # on global layers (implemented here by zeroing the tail of inv_freq). 
+ if hasattr(inner, "rotary_pos_emb"): + from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding + + rope_params = getattr(hf_text, "rope_parameters", {}) or {} + full = rope_params.get("full_attention", {}) or {} + sliding = rope_params.get("sliding_attention", {}) or {} + global_theta = full.get("rope_theta", 1_000_000.0) + local_theta = sliding.get("rope_theta", 10_000.0) + global_head_dim = hf_text.global_head_dim + global_partial = full.get("partial_rotary_factor", 0.25) + + local_rope = inner.rotary_pos_emb # already built with args.rotary_base + + global_rope = RotaryEmbedding( + kv_channels=global_head_dim, + rotary_percent=1.0, + rotary_base=global_theta, + ) + # HF "proportional" RoPE: first (partial * head_dim // 2) inv_freq + # entries are live, the rest are zero (no rotation on those dims). + # Writing this to the existing buffer keeps device/dtype correct. + rope_angles = int(global_partial * global_head_dim // 2) + half = global_head_dim // 2 + # Guard the RoPE geometry: 0 means "no rotation" (nonsensical here); + # > half would produce nope<0 and a shape-mismatched copy_. Both + # should fail loudly rather than silently writing garbage. + assert 0 < rope_angles <= half, ( + f"global_partial_rotary_factor={global_partial} with " + f"global_head_dim={global_head_dim} produced rope_angles=" + f"{rope_angles}; must be in (0, {half}]." + ) + inv_freq_live = 1.0 / ( + global_theta ** (torch.arange(0, 2 * rope_angles, 2, dtype=torch.float) / global_head_dim) + ) + nope = half - rope_angles + inv_freq = torch.cat([inv_freq_live, torch.zeros(nope)]) if nope > 0 else inv_freq_live + assert inv_freq.shape == global_rope.inv_freq.shape, ( + f"inv_freq shape {tuple(inv_freq.shape)} doesn't match " + f"global_rope.inv_freq shape {tuple(global_rope.inv_freq.shape)}; " + "Megatron RotaryEmbedding layout may have changed." 
def _read_layer_scalars_from_safetensors(hf_checkpoint: str) -> dict[int, float] | None:
    """Read all ``layer_scalar`` values from the HF safetensors checkpoint.

    Returns ``{global_layer_idx: scalar}`` or ``None`` if the checkpoint has
    no safetensors index (older HF layouts) or no layer_scalar weights. Only
    called on rank 0 - results are broadcast to the other ranks.

    The layer index is parsed from the path segment following ``.layers.``
    in each weight key. Each shard file is opened once, however many
    ``layer_scalar`` tensors it holds.
    """
    index_path = os.path.join(hf_checkpoint, "model.safetensors.index.json")
    if not os.path.exists(index_path):
        logger.warning("No safetensors index at %s; skipping layer scalars", index_path)
        return None

    from safetensors import safe_open

    # JSON is UTF-8 by specification; don't depend on the locale default.
    with open(index_path, encoding="utf-8") as f:
        index = json.load(f)

    # Group the wanted keys by shard so every shard is opened exactly once.
    keys_by_shard: dict[str, list[str]] = {}
    for key, filename in index["weight_map"].items():
        if "layer_scalar" in key:
            keys_by_shard.setdefault(filename, []).append(key)

    scalars: dict[int, float] = {}
    for filename, keys in keys_by_shard.items():
        with safe_open(os.path.join(hf_checkpoint, filename), framework="pt", device="cpu") as sf:
            for key in keys:
                layer_idx = int(key.split(".layers.")[1].split(".")[0])
                scalars[layer_idx] = sf.get_tensor(key).item()

    if not scalars:
        logger.warning("No layer_scalar weights found in checkpoint %s", hf_checkpoint)
        return None
    return scalars
def _load_layer_scalars(inner, hf_checkpoint, config):
    """Fill each local decoder layer's ``layer_scalar`` from the HF checkpoint.

    Rank 0 reads the scalars and broadcasts them; every rank then writes the
    value for each layer in ``inner.decoder.layers`` that has a
    ``layer_scalar`` attribute, translating the local index to the global
    (HF 0-indexed) layer index via the pipeline-parallel layer offset.

    Raises:
        RuntimeError: No scalars were found and the override is not set.
        KeyError: A local layer has no scalar in the checkpoint and the
            override is not set.
        Exception: Whatever the rank-0 read raised, re-raised on rank 0 after
            the broadcast (missing file / bad JSON are downgraded to a warning
            when the override is set).
    """
    # Wrong layer_scalars materially change activations vs HF (they're per-
    # layer multiplicative gains on the residual stream, not decorative), so
    # by default we fail hard if the load breaks. Set
    # GEMMA4_ALLOW_MISSING_LAYER_SCALARS=1 to downgrade to a warning and
    # train with the default value of 1.0 - only useful for debug runs
    # against a checkpoint that genuinely lacks these buffers.
    allow_missing = os.environ.get("GEMMA4_ALLOW_MISSING_LAYER_SCALARS") == "1"

    scalars = None
    read_error = None
    if _is_rank_zero():
        try:
            scalars = _read_layer_scalars_from_safetensors(hf_checkpoint)
        except Exception as e:  # deferred: re-raised below, after the collective
            read_error = e

    # Every rank must reach the broadcast. If rank 0 bailed out before it, the
    # other ranks would block in the collective instead of failing.
    scalars = _broadcast_layer_scalars(scalars)

    if read_error is not None:
        if allow_missing and isinstance(read_error, (FileNotFoundError, json.JSONDecodeError)):
            logger.warning(
                "layer scalars unavailable (%s: %s); using default 1.0",
                type(read_error).__name__,
                read_error,
            )
            return
        raise read_error

    if not scalars:
        if allow_missing:
            return
        raise RuntimeError(
            "No layer_scalar weights found in checkpoint; set "
            "GEMMA4_ALLOW_MISSING_LAYER_SCALARS=1 to proceed with "
            "default values (not numerically equivalent to HF)."
        )

    # Under pipeline-parallelism, inner.decoder.layers holds only this
    # rank's local subset. Translate the local index back to the global
    # (HF 0-indexed) layer index so we apply the right scalar per layer.
    from megatron.core.transformer.transformer_layer import get_transformer_layer_offset

    pp_offset = get_transformer_layer_offset(config)

    loaded = 0
    for i, layer in enumerate(inner.decoder.layers):
        if not hasattr(layer, "layer_scalar"):
            continue
        global_idx = i + pp_offset
        if global_idx not in scalars:
            if not allow_missing:
                raise KeyError(
                    f"layer_scalar for global layer {global_idx} "
                    f"missing in checkpoint (have: {sorted(scalars)[:10]}...); "
                    "checkpoint may be truncated."
                )
            logger.warning(
                "layer_scalar for global layer %d missing; using default 1.0",
                global_idx,
            )
        layer.layer_scalar.fill_(scalars.get(global_idx, 1.0))
        loaded += 1

    if _is_rank_zero():
        logger.info(
            "Applied %d/%d layer scalars (pp_offset=%d, range=%.4f..%.4f)",
            loaded,
            len(inner.decoder.layers),
            pp_offset,
            min(scalars.values()),
            max(scalars.values()),
        )
in name: + parent_name, attr = name.rsplit(".", 1) + parent = _ensure_module(parent_name) + setattr(parent, attr, module) + + return module + + +def install_megatron_stubs() -> None: + import torch + + class _SelfAttentionStub(torch.nn.Module): + def get_query_key_value_tensors(self, *_args, **_kwargs): + raise NotImplementedError + + _ensure_module("megatron") + _ensure_module("megatron.core") + fusions = _ensure_module("megatron.core.fusions") + del fusions + fused_bias_dropout = _ensure_module("megatron.core.fusions.fused_bias_dropout") + fused_bias_dropout.get_bias_dropout_add = lambda *args, **kwargs: None + + _ensure_module("megatron.core.models") + _ensure_module("megatron.core.models.gpt") + gpt_model = _ensure_module("megatron.core.models.gpt.gpt_model") + gpt_model.GPTModel = object + + _ensure_module("megatron.core.transformer") + attention = _ensure_module("megatron.core.transformer.attention") + attention.SelfAttention = _SelfAttentionStub + attention.SelfAttentionSubmodules = type("SelfAttentionSubmodules", (), {}) + enums = _ensure_module("megatron.core.transformer.enums") + enums.AttnMaskType = type("AttnMaskType", (), {"causal": "causal"}) + identity_op = _ensure_module("megatron.core.transformer.identity_op") + identity_op.IdentityOp = type("IdentityOp", (), {}) + mlp = _ensure_module("megatron.core.transformer.mlp") + mlp.MLP = type("MLP", (), {}) + mlp.MLPSubmodules = type("MLPSubmodules", (), {}) + moe_layer = _ensure_module("megatron.core.transformer.moe.moe_layer") + moe_layer.BaseMoELayer = torch.nn.Module + moe_layer.MoELayer = torch.nn.Module + spec_utils = _ensure_module("megatron.core.transformer.spec_utils") + spec_utils.import_module = lambda *args, **kwargs: None + spec_utils.ModuleSpec = type("ModuleSpec", (), {}) + spec_utils.build_module = lambda *args, **kwargs: None + transformer_layer = _ensure_module("megatron.core.transformer.transformer_layer") + transformer_layer.TransformerLayer = object + 
@contextmanager
def _temporary_module(name: str, module: types.ModuleType) -> Iterator[None]:
    """Install ``module`` as ``sys.modules[name]`` for the duration of the block.

    If ``name`` is dotted and its parent package is already in ``sys.modules``,
    the module is also bound as an attribute on that parent. On exit both the
    ``sys.modules`` entry and the parent attribute are put back the way they
    were found (removed if they did not exist before).
    """
    absent = object()
    package_name, dot, leaf = name.rpartition(".")

    previous_entry = sys.modules.get(name, absent)
    package = sys.modules.get(package_name) if dot else None
    patch_package = bool(package and leaf)
    previous_attr = getattr(package, leaf, absent) if patch_package else absent

    sys.modules[name] = module
    if patch_package:
        setattr(package, leaf, module)
    try:
        yield
    finally:
        if previous_entry is absent:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous_entry

        if patch_package:
            if previous_attr is not absent:
                setattr(package, leaf, previous_attr)
            elif getattr(package, leaf, None) is module:
                # Only remove the attribute if it is still the one we put there.
                delattr(package, leaf)
b/tests/gemma4/test_gemma4_attention.py @@ -0,0 +1,119 @@ +from types import SimpleNamespace + +import pytest +import torch + +try: + from slime_plugins.models.gemma4 import Gemma4SelfAttention, VNorm +except ModuleNotFoundError as exc: + missing = exc.name or "" + if not (missing == "megatron" or missing.startswith("megatron.") or missing == "mbridge"): + raise + from tests.gemma4._standalone_imports import load_gemma4_model_module + + _gemma4 = load_gemma4_model_module() + Gemma4SelfAttention = _gemma4.Gemma4SelfAttention + VNorm = _gemma4.VNorm + + +def _stub_attention(num_attention_heads, num_kv_heads, head_dim, hidden_size): + attn = object.__new__(Gemma4SelfAttention) + torch.nn.Module.__init__(attn) + + q_per_kv = num_attention_heads // num_kv_heads + out_width = num_kv_heads * (q_per_kv + 2) * head_dim + linear_qkv = torch.nn.Linear(hidden_size, out_width, bias=False) + torch.nn.init.normal_(linear_qkv.weight, std=0.02) + + def _linear_qkv(h): + return linear_qkv(h), None + + attn.linear_qkv = _linear_qkv + attn.num_attention_heads_per_partition = num_attention_heads + attn.num_query_groups_per_partition = num_kv_heads + attn.hidden_size_per_attention_head = head_dim + attn.q_layernorm = torch.nn.LayerNorm(head_dim) + attn.k_layernorm = torch.nn.LayerNorm(head_dim) + attn.v_norm = VNorm(head_dim, eps=1e-6) + attn.config = SimpleNamespace( + layernorm_epsilon=1e-6, + attention_k_eq_v=True, + ) + attn._is_global = False # flipped per-test + return attn, linear_qkv + + +def test_global_k_eq_v_produces_k_norm_and_v_norm_of_raw_k(): + torch.manual_seed(0) + num_attention_heads, num_kv_heads, head_dim, hidden_size = 8, 2, 512, 256 + attn, linear_qkv = _stub_attention(num_attention_heads, num_kv_heads, head_dim, hidden_size) + attn._is_global = True + + seq_len, batch = 4, 1 + hidden = torch.randn(seq_len, batch, hidden_size) + + query, key, value = attn.get_query_key_value_tensors(hidden) + + assert query.shape == (seq_len, batch, num_attention_heads, head_dim) + 
assert key.shape == (seq_len, batch, num_kv_heads, head_dim) + assert value.shape == (seq_len, batch, num_kv_heads, head_dim) + + mixed, _ = attn.linear_qkv(hidden) + q_per_kv = num_attention_heads // num_kv_heads + mixed = mixed.view(seq_len, batch, num_kv_heads, (q_per_kv + 2) * head_dim) + q_width = q_per_kv * head_dim + raw_q, raw_k, _raw_v = torch.split(mixed, [q_width, head_dim, head_dim], dim=3) + raw_q = raw_q.reshape(seq_len, batch, -1, head_dim) + + expected_query = attn.q_layernorm(raw_q) + expected_key = attn.k_layernorm(raw_k) + expected_value = attn.v_norm(raw_k) + + assert torch.allclose(query, expected_query), "query mismatch" + assert torch.allclose(key, expected_key), "key must be k_norm(raw_k)" + assert torch.allclose(value, expected_value), ( + "value must be v_norm(raw_k); if this fails, v is being derived from " "k_norm(raw_k) instead of raw_k" + ) + + +def test_global_k_eq_v_does_not_mutate_k_layernorm(): + torch.manual_seed(1) + attn, _ = _stub_attention(8, 2, 512, 256) + attn._is_global = True + + k_layernorm_before = attn.k_layernorm + hidden = torch.randn(3, 1, 256) + _ = attn.get_query_key_value_tensors(hidden) + assert attn.k_layernorm is k_layernorm_before + + +def test_global_k_eq_v_rejects_output_gate(): + attn, _ = _stub_attention(8, 2, 512, 256) + attn._is_global = True + with pytest.raises(NotImplementedError): + attn.get_query_key_value_tensors(torch.randn(3, 1, 256), output_gate=True) + + +def test_sliding_layer_applies_v_norm_to_value(): + torch.manual_seed(2) + num_attention_heads, num_kv_heads, head_dim, hidden_size = 8, 2, 256, 256 + attn, linear_qkv = _stub_attention(num_attention_heads, num_kv_heads, head_dim, hidden_size) + attn._is_global = False + + seq_len, batch = 3, 1 + raw_q = torch.randn(seq_len, batch, num_attention_heads, head_dim) + raw_k = torch.randn(seq_len, batch, num_kv_heads, head_dim) + raw_v = torch.randn(seq_len, batch, num_kv_heads, head_dim) + + def _fake_parent(*_a, **_k): + return raw_q, raw_k, 
raw_v + + import unittest.mock as mock + + _Base = Gemma4SelfAttention.__mro__[1] + with mock.patch.object(_Base, "get_query_key_value_tensors", _fake_parent): + query, key, value = attn.get_query_key_value_tensors(torch.randn(seq_len, batch, hidden_size)) + + assert torch.equal(query, raw_q) + assert torch.equal(key, raw_k) + assert torch.allclose(value, attn.v_norm(raw_v)) diff --git a/tests/gemma4/test_gemma4_bridge.py b/tests/gemma4/test_gemma4_bridge.py new file mode 100644 index 0000000000..6e70a5b823 --- /dev/null +++ b/tests/gemma4/test_gemma4_bridge.py @@ -0,0 +1,310 @@ +import importlib +import importlib.util +import pathlib +from types import SimpleNamespace + +import pytest +import torch + +from tests.gemma4._standalone_imports import load_gemma4_bridge_class + + +def _load_convert_module(): + try: + return importlib.import_module("slime.backends.megatron_utils.megatron_to_hf.gemma4") + except ImportError: + pass + repo_path = pathlib.Path(__file__).resolve().parents[2] / ( + "slime/backends/megatron_utils/megatron_to_hf/gemma4.py" + ) + if not repo_path.exists(): + pytest.skip(f"convert_gemma4_to_hf source not found at {repo_path}") + spec = importlib.util.spec_from_file_location("_gemma4_conv_under_test", repo_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +CFG_31B = SimpleNamespace( + hidden_size=5376, + num_attention_heads=32, + head_dim=256, + num_key_value_heads=16, + global_head_dim=512, + num_global_key_value_heads=4, + num_hidden_layers=60, + attention_k_eq_v=True, + layer_types=(["sliding_attention"] * 5 + ["full_attention"]) * 10, +) + + +def test_gemma4_bridge_dense_config_does_not_set_moe_kwargs(): + bridge = object.__new__(load_gemma4_bridge_class()) + bridge.hf_config = CFG_31B + bridge._build_base_config = lambda **kwargs: kwargs + + cfg = bridge._build_config() + + assert cfg["text_config_key"] is None + assert "num_moe_experts" not in cfg + assert "moe_router_topk" not in cfg + 
def test_convert_gemma4_to_hf_local_layer_roundtrip(monkeypatch):
    """A packed sliding-layer QKV weight must split back into the original q/k/v."""
    conv = _load_convert_module()

    # Register the fake config through monkeypatch so the entry is removed at
    # teardown and cannot leak into other tests via the module-global cache.
    monkeypatch.setitem(
        conv._config_cache,
        "/nonexistent",
        {
            "global_attn_layers": {i for i, t in enumerate(CFG_31B.layer_types) if t == "full_attention"},
            "local_head_dim": CFG_31B.head_dim,
            "global_head_dim": CFG_31B.global_head_dim,
            "num_attention_heads": CFG_31B.num_attention_heads,
            "local_num_kv_heads": CFG_31B.num_key_value_heads,
            "global_num_kv_heads": CFG_31B.num_global_key_value_heads,
            "hidden_size": CFG_31B.hidden_size,
        },
    )

    q = torch.randn(CFG_31B.num_attention_heads * CFG_31B.head_dim, CFG_31B.hidden_size)
    k = torch.randn(CFG_31B.num_key_value_heads * CFG_31B.head_dim, CFG_31B.hidden_size)
    v = torch.randn(CFG_31B.num_key_value_heads * CFG_31B.head_dim, CFG_31B.hidden_size)
    packed = _pack_local_qkv(q, k, v)

    args = SimpleNamespace(hf_checkpoint="/nonexistent")
    emitted = conv.convert_gemma4_to_hf(
        args,
        "module.module.decoder.layers.0.self_attention.linear_qkv.weight",
        packed,
    )
    names = {n for n, _ in emitted}
    assert names == {
        "model.language_model.layers.0.self_attn.q_proj.weight",
        "model.language_model.layers.0.self_attn.k_proj.weight",
        "model.language_model.layers.0.self_attn.v_proj.weight",
    }
    out = dict(emitted)
    assert torch.allclose(out["model.language_model.layers.0.self_attn.q_proj.weight"], q)
    assert torch.allclose(out["model.language_model.layers.0.self_attn.k_proj.weight"], k)
    assert torch.allclose(out["model.language_model.layers.0.self_attn.v_proj.weight"], v)
CFG_31B.hidden_size) + packed = _pack_global_qkv(q, k) + + args = SimpleNamespace(hf_checkpoint="/nonexistent") + emitted = conv.convert_gemma4_to_hf( + args, + "module.module.decoder.layers.5.self_attention.linear_qkv.weight", + packed, + ) + names = {n for n, _ in emitted} + assert names == { + "model.language_model.layers.5.self_attn.q_proj.weight", + "model.language_model.layers.5.self_attn.k_proj.weight", + } + + +def test_convert_config_cache_is_checkpoint_scoped(monkeypatch): + conv = _load_convert_module() + conv._config_cache.clear() + + def fake_from_pretrained(path, trust_remote_code): + hidden_size = 128 if path == "/ckpt-a" else 256 + text_config = SimpleNamespace( + layer_types=["sliding_attention", "full_attention"], + head_dim=16, + global_head_dim=32, + num_attention_heads=4, + num_key_value_heads=2, + num_global_key_value_heads=1, + hidden_size=hidden_size, + ) + return SimpleNamespace(text_config=text_config) + + import transformers + + monkeypatch.setattr(transformers.AutoConfig, "from_pretrained", fake_from_pretrained) + + cfg_a = conv._get_config(SimpleNamespace(hf_checkpoint="/ckpt-a")) + cfg_b = conv._get_config(SimpleNamespace(hf_checkpoint="/ckpt-b")) + + assert cfg_a["hidden_size"] == 128 + assert cfg_b["hidden_size"] == 256 + assert conv._get_config(SimpleNamespace(hf_checkpoint="/ckpt-a")) is cfg_a + + +def test_convert_gemma4_to_hf_moe_expert_weights_stacked(): + conv = _load_convert_module() + num_experts = 4 # keep test fast + conv._config_cache["/nonexistent"] = { + "global_attn_layers": {5}, + "local_head_dim": 256, + "global_head_dim": 512, + "num_attention_heads": 16, + "local_num_kv_heads": 8, + "global_num_kv_heads": 2, + "hidden_size": 2816, + "num_experts": num_experts, + } + conv._expert_buffers.clear() + args = SimpleNamespace(hf_checkpoint="/nonexistent") + + fc1_tensors = [torch.randn(2 * 704, 2816) for _ in range(num_experts)] + emitted_total = [] + for e, t in enumerate(fc1_tensors): + out = conv.convert_gemma4_to_hf( + 
args, + f"module.module.decoder.layers.3.mlp.experts.linear_fc1.weight{e}", + t, + ) + emitted_total.append(out) + assert all(len(out) == 0 for out in emitted_total[:-1]) + last = emitted_total[-1] + assert len(last) == 1 + name, stacked = last[0] + assert name == "model.language_model.layers.3.experts.gate_up_proj" + assert stacked.shape == (num_experts, 2 * 704, 2816) + for e, t in enumerate(fc1_tensors): + assert torch.equal(stacked[e], t) + + fc2_tensors = [torch.randn(2816, 704) for _ in range(num_experts)] + emitted_total = [] + for e, t in enumerate(fc2_tensors): + out = conv.convert_gemma4_to_hf( + args, + f"module.module.decoder.layers.3.mlp.experts.linear_fc2.weight{e}", + t, + ) + emitted_total.append(out) + assert all(len(out) == 0 for out in emitted_total[:-1]) + last = emitted_total[-1] + assert len(last) == 1 + name, stacked = last[0] + assert name == "model.language_model.layers.3.experts.down_proj" + assert stacked.shape == (num_experts, 2816, 704) + for e, t in enumerate(fc2_tensors): + assert torch.equal(stacked[e], t) + + +def test_convert_gemma4_to_hf_moe_router_weights(): + conv = _load_convert_module() + conv._config_cache["/nonexistent"] = { + "global_attn_layers": {5}, + "local_head_dim": 256, + "global_head_dim": 512, + "num_attention_heads": 16, + "local_num_kv_heads": 8, + "global_num_kv_heads": 2, + "hidden_size": 2816, + } + args = SimpleNamespace(hf_checkpoint="/nonexistent") + for mcore_rest, hf_tail in [ + ("mlp.router.proj.weight", "router.proj.weight"), + ("mlp.router.scale", "router.scale"), + ("mlp.router.per_expert_scale", "router.per_expert_scale"), + ]: + param = torch.randn(4) + emitted = conv.convert_gemma4_to_hf( + args, + f"module.module.decoder.layers.3.{mcore_rest}", + param, + ) + assert len(emitted) == 1 + assert emitted[0][0] == f"model.language_model.layers.3.{hf_tail}" + + +def test_convert_gemma4_to_hf_dense_mlp_sibling(): + conv = _load_convert_module() + conv._config_cache["/nonexistent"] = { + 
"global_attn_layers": set(), + "local_head_dim": 256, + "global_head_dim": 512, + "num_attention_heads": 16, + "local_num_kv_heads": 8, + "global_num_kv_heads": 2, + "hidden_size": 2816, + } + args = SimpleNamespace(hf_checkpoint="/nonexistent") + + gate = torch.randn(2112, 2816) + up = torch.randn(2112, 2816) + fused = torch.cat([gate, up], dim=0) + + emitted = conv.convert_gemma4_to_hf( + args, + "module.module.decoder.layers.0.dense_mlp.linear_fc1.weight", + fused, + ) + names = {n for n, _ in emitted} + assert names == { + "model.language_model.layers.0.mlp.gate_proj.weight", + "model.language_model.layers.0.mlp.up_proj.weight", + } + + down = torch.randn(2816, 2112) + emitted = conv.convert_gemma4_to_hf( + args, + "module.module.decoder.layers.0.dense_mlp.linear_fc2.weight", + down, + ) + assert emitted == [("model.language_model.layers.0.mlp.down_proj.weight", down)] diff --git a/tests/gemma4/test_gemma4_cp_attention.py b/tests/gemma4/test_gemma4_cp_attention.py new file mode 100644 index 0000000000..9910e97059 --- /dev/null +++ b/tests/gemma4/test_gemma4_cp_attention.py @@ -0,0 +1,281 @@ +import os + +import pytest +import torch +import torch.distributed as dist +import torch.nn.functional as F + + +@pytest.fixture(scope="module", autouse=True) +def _init_dist(): + if dist.is_initialized(): + yield + return + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "29555") + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, rank=0, world_size=1) + try: + try: + from megatron.core import parallel_state as mpu + + mpu.initialize_model_parallel(context_parallel_size=1) + except Exception: + pass + yield + finally: + dist.destroy_process_group() + + +def _ref_attention(query, key, value, cu_seqlens, scale, sliding_window=None): + t = query.shape[0] + nq, nk = query.shape[1], key.shape[1] + q = 
query.unsqueeze(0).transpose(1, 2).float() # [1, n, T, h] + k = key.unsqueeze(0).transpose(1, 2).float() + v = value.unsqueeze(0).transpose(1, 2).float() + if nq != nk: + k = k.repeat_interleave(nq // nk, dim=1) + v = v.repeat_interleave(nq // nk, dim=1) + + mask = torch.full((t, t), float("-inf"), device=query.device, dtype=torch.float32) + for i in range(len(cu_seqlens) - 1): + s, e = int(cu_seqlens[i]), int(cu_seqlens[i + 1]) + for qi in range(s, e): + lo = s if sliding_window is None else max(s, qi - sliding_window + 1) + mask[qi, lo : qi + 1] = 0.0 + + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask[None, None, :, :], scale=scale) + return out.transpose(1, 2).reshape(t, -1).to(query.dtype) + + +def _make_core_attention(sliding_window: int | None, softmax_scale: float): + from types import SimpleNamespace + from slime_plugins.models.gemma4 import SDPACoreAttention + + config = SimpleNamespace( + attention_dropout=0.0, + sliding_window=sliding_window or 1024, + context_parallel_size=1, + ) + core = SDPACoreAttention( + config=config, + layer_number=1, + attn_mask_type=None, + softmax_scale=softmax_scale, + ) + core._is_sliding = sliding_window is not None + return core + + +def _load_core_attention_static_methods(): + try: + from slime_plugins.models.gemma4 import SDPACoreAttention + except ModuleNotFoundError as exc: + missing = exc.name or "" + if not (missing == "megatron" or missing.startswith("megatron.") or missing == "mbridge"): + raise + from tests.gemma4._standalone_imports import load_gemma4_model_module + + return load_gemma4_model_module().SDPACoreAttention + return SDPACoreAttention + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") +def test_global_thd_sdpa_per_subseq_matches_reference(): + torch.manual_seed(0) + device = "cuda" + dtype = torch.float32 + + nq, nk, hn = 8, 2, 512 + scale = 1.0 / (hn**0.5) + lens = [13, 20, 7] + cu = torch.tensor([0] + list(__import__("itertools").accumulate(lens)), 
@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA")
def test_flash_thd_with_sliding_window():
    """The flash THD path with a sliding window must match the dense reference."""
    try:
        import flash_attn  # noqa
    except ImportError:
        pytest.skip("flash_attn not installed")

    from itertools import accumulate

    torch.manual_seed(1)
    device = "cuda"
    dtype = torch.bfloat16

    nq, nk, hn = 16, 8, 256
    scale = 1.0 / (hn**0.5)
    lens = [1200, 800]  # > sliding_window on the first sequence
    cu = torch.tensor([0, *accumulate(lens)], dtype=torch.int32, device=device)
    t = int(cu[-1])
    q = torch.randn(t, nq, hn, device=device, dtype=dtype)
    k = torch.randn(t, nk, hn, device=device, dtype=dtype)
    v = torch.randn(t, nk, hn, device=device, dtype=dtype)

    core = _make_core_attention(sliding_window=1024, softmax_scale=scale)
    out = core._forward_thd_flash(q, k, v, cu)
    assert out.shape == (t, nq * hn)
    assert not torch.isnan(out).any()

    # Compare direction rather than exact values: the output is bf16 while the
    # reference is computed in fp32.
    ref = _ref_attention(q.float(), k.float(), v.float(), cu, scale=scale, sliding_window=1024)
    cos = F.cosine_similarity(ref.flatten().unsqueeze(0), out.float().flatten().unsqueeze(0)).item()
    assert cos > 0.999, f"flash+sliding mismatch, cosine={cos}"
+ + cu = torch.tensor([0, 64, 192], dtype=torch.int32, device=device) + packed = SimpleNamespace(cu_seqlens_q=cu) + + core = _make_core_attention(sliding_window=1024, softmax_scale=1.0 / (256**0.5)) + q = torch.randn(192, 8, 256, device=device, dtype=dtype) + k = torch.randn(192, 4, 256, device=device, dtype=dtype) + v = torch.randn(192, 4, 256, device=device, dtype=dtype) + out = core.forward(q, k, v, packed_seq_params=packed) + assert out.shape == (192, 8 * 256) + assert not torch.isnan(out).any() + + core_g = _make_core_attention(sliding_window=None, softmax_scale=1.0 / (512**0.5)) + qg = torch.randn(192, 8, 512, device=device, dtype=dtype) + kg = torch.randn(192, 2, 512, device=device, dtype=dtype) + vg = torch.randn(192, 2, 512, device=device, dtype=dtype) + out = core_g.forward(qg, kg, vg, packed_seq_params=packed) + assert out.shape == (192, 8 * 512) + assert not torch.isnan(out).any() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA") +def test_cp_global_gradient_flow_end_to_end(): + torch.manual_seed(3) + device = "cuda" + dtype = torch.float32 + + nq, nk, hn = 8, 2, 512 + scale = 1.0 / (hn**0.5) + cu = torch.tensor([0, 32, 96], dtype=torch.int32, device=device) + t = int(cu[-1]) + from types import SimpleNamespace + + packed = SimpleNamespace(cu_seqlens_q=cu) + q = torch.randn(t, nq, hn, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(t, nk, hn, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(t, nk, hn, device=device, dtype=dtype, requires_grad=True) + + core = _make_core_attention(sliding_window=None, softmax_scale=scale) + core.config.context_parallel_size = 2 + try: + out = core._forward_cp_subseq_mask(q, k, v, packed, sliding_window=None) + except Exception: + pytest.skip("Megatron parallel_state not initialized; skipping CP path smoke test") + + assert out.shape == (t, nq * hn) + assert not torch.isnan(out).any() + out.sum().backward() + assert q.grad is not None and not 
torch.isnan(q.grad).any() + assert k.grad is not None and not torch.isnan(k.grad).any() + assert v.grad is not None and not torch.isnan(v.grad).any() + assert (k.grad.abs() > 0).any() + assert (v.grad.abs() > 0).any() + + +def test_zigzag_global_indices_cp1_is_identity(): + SDPACoreAttention = _load_core_attention_static_methods() + + device = torch.device("cpu") + idx = SDPACoreAttention._zigzag_global_indices( + local_len=8, + cp_rank=0, + cp_size=1, + device=device, + ) + assert idx.tolist() == list(range(8)) + + +def test_zigzag_global_indices_cp2_matches_slime_slice(): + SDPACoreAttention = _load_core_attention_static_methods() + + device = torch.device("cpu") + idx_r0 = SDPACoreAttention._zigzag_global_indices( + local_len=8, + cp_rank=0, + cp_size=2, + device=device, + ) + idx_r1 = SDPACoreAttention._zigzag_global_indices( + local_len=8, + cp_rank=1, + cp_size=2, + device=device, + ) + assert idx_r0.tolist() == [0, 1, 2, 3, 12, 13, 14, 15] + assert idx_r1.tolist() == [4, 5, 6, 7, 8, 9, 10, 11] + + +def test_cp_unzigzag_permutation_handles_multiple_packed_subseqs(): + SDPACoreAttention = _load_core_attention_static_methods() + + device = torch.device("cpu") + cu = [0, 16, 32] + perm = SDPACoreAttention._cp_unzigzag_permutation(cu, cp_size=2, device=device) + + gathered = torch.tensor( + [ + # rank 0: seq0 chunks 0,3; seq1 chunks 0,3 + 0, + 1, + 2, + 3, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 28, + 29, + 30, + 31, + # rank 1: seq0 chunks 1,2; seq1 chunks 1,2 + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + ], + device=device, + ) + assert gathered.index_select(0, perm).tolist() == list(range(32)) diff --git a/tests/gemma4/test_gemma4_dual_rope.py b/tests/gemma4/test_gemma4_dual_rope.py new file mode 100644 index 0000000000..e72f25ec77 --- /dev/null +++ b/tests/gemma4/test_gemma4_dual_rope.py @@ -0,0 +1,94 @@ +import pytest +import torch + +from tests.gemma4._standalone_imports import 
load_gemma4_provider_module + +DualRotaryEmbedding = load_gemma4_provider_module().DualRotaryEmbedding + + +class _FakeRope: + def __init__(self, dim: int, tag: float): + self.dim = dim + self.tag = tag + self.calls = [] + + def __call__(self, seq_len, **kwargs): + self.calls.append((seq_len, kwargs)) + s = torch.arange(seq_len, dtype=torch.float).view(seq_len, 1, 1, 1) + d = torch.arange(self.dim, dtype=torch.float).view(1, 1, 1, self.dim) + return s * 100.0 + d + self.tag + + def get_rotary_seq_len(self, *args, **kwargs): + return ("fake_seq_len_result", args, kwargs) + + +def test_dual_rope_concat_shape_global_first(): + local = _FakeRope(dim=256, tag=0.1) + glob = _FakeRope(dim=512, tag=0.9) + dual = DualRotaryEmbedding(local, glob, global_dim=512) + + seq_len = 16 + combined = dual(seq_len) + assert combined.shape == (seq_len, 1, 1, 512 + 256) + + global_slice = combined[..., :512] + local_slice = combined[..., 512:] + assert torch.equal(global_slice, glob(seq_len)) + assert torch.equal(local_slice, local(seq_len)) + + +def test_dual_rope_split_matches_layer_convention(): + global_dim, local_dim = 384, 192 + local = _FakeRope(dim=local_dim, tag=11.0) + glob = _FakeRope(dim=global_dim, tag=22.0) + dual = DualRotaryEmbedding(local, glob, global_dim=global_dim) + + seq_len = 8 + combined = dual(seq_len) + + for is_sliding, expected_rope in [(False, glob), (True, local)]: + if is_sliding: + sliced = combined[..., global_dim:] + else: + sliced = combined[..., :global_dim] + assert torch.equal( + sliced, expected_rope(seq_len) + ), f"split for is_sliding={is_sliding} did not recover the right rope" + + +def test_dual_rope_delegates_get_rotary_seq_len_to_local(): + local = _FakeRope(dim=256, tag=0.0) + glob = _FakeRope(dim=512, tag=0.0) + dual = DualRotaryEmbedding(local, glob, global_dim=512) + + result = dual.get_rotary_seq_len("a", b=2) + assert result[0] == "fake_seq_len_result" + assert result[1] == ("a",) + assert result[2] == {"b": 2} + + +def 
test_dual_rope_forwards_packed_seq_params_to_both_ropes(): + local = _FakeRope(dim=4, tag=0.0) + glob = _FakeRope(dim=8, tag=0.0) + dual = DualRotaryEmbedding(local, glob, global_dim=8) + packed_seq_params = object() + + combined = dual(12, offset=3, packed_seq_params=packed_seq_params) + + assert combined.shape == (12, 1, 1, 12) + assert glob.calls == [(12, {"offset": 3, "packed_seq_params": packed_seq_params})] + assert local.calls == [(12, {"offset": 3, "packed_seq_params": packed_seq_params})] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Megatron RotaryEmbedding.forward requires CUDA") +def test_dual_rope_end_to_end_with_real_megatron_rope(): + from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding + + local = RotaryEmbedding(kv_channels=256, rotary_percent=1.0, rotary_base=10_000.0) + glob = RotaryEmbedding(kv_channels=512, rotary_percent=1.0, rotary_base=1_000_000.0) + dual = DualRotaryEmbedding(local, glob, global_dim=512) + + combined = dual(64) + assert combined.shape[-1] == 512 + 256 + assert torch.equal(combined[..., :512], glob(64)) + assert torch.equal(combined[..., 512:], local(64)) diff --git a/tests/gemma4/test_gemma4_hf_key_contract.py b/tests/gemma4/test_gemma4_hf_key_contract.py new file mode 100644 index 0000000000..08c324f83f --- /dev/null +++ b/tests/gemma4/test_gemma4_hf_key_contract.py @@ -0,0 +1,151 @@ +import importlib.util +import pathlib +from types import SimpleNamespace + +import pytest +import torch + + +def _load_convert_module(): + repo_path = pathlib.Path(__file__).resolve().parents[2] / ( + "slime/backends/megatron_utils/megatron_to_hf/gemma4.py" + ) + spec = importlib.util.spec_from_file_location("_gemma4_key_contract_converter", repo_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +def _mcore_keys_tiny_moe(num_experts: int = 2) -> list[str]: + base = [ + "module.module.embedding.word_embeddings.weight", + 
"module.module.decoder.final_layernorm.weight", + ] + base.append("module.module.output_layer.weight") + for layer_idx in (0, 1): + prefix = f"module.module.decoder.layers.{layer_idx}" + base.extend( + [ + f"{prefix}.self_attention.linear_qkv.weight", + f"{prefix}.self_attention.linear_qkv.layer_norm_weight", + f"{prefix}.self_attention.linear_proj.weight", + f"{prefix}.self_attention.q_layernorm.weight", + f"{prefix}.self_attention.k_layernorm.weight", + f"{prefix}.post_attention_layernorm.weight", + f"{prefix}.layer_scalar", + f"{prefix}.dense_mlp.linear_fc1.weight", + f"{prefix}.dense_mlp.linear_fc1.layer_norm_weight", + f"{prefix}.dense_mlp.linear_fc2.weight", + f"{prefix}.pre_mlp_layernorm.weight", + f"{prefix}.post_feedforward_layernorm.weight", + f"{prefix}.post_feedforward_layernorm_1.weight", + f"{prefix}.post_feedforward_layernorm_2.weight", + f"{prefix}.mlp.pre_feedforward_layernorm_2.weight", + f"{prefix}.mlp.router.proj.weight", + f"{prefix}.mlp.router.scale", + f"{prefix}.mlp.router.per_expert_scale", + ] + ) + for e in range(num_experts): + base.extend( + [ + f"{prefix}.mlp.experts.linear_fc1.weight{e}", + f"{prefix}.mlp.experts.linear_fc2.weight{e}", + ] + ) + return base + + +def _build_tiny_hf_model(): + from transformers.models.gemma4 import configuration_gemma4 as C + from transformers.models.gemma4 import modeling_gemma4 as M + + text_cfg = C.Gemma4TextConfig( + vocab_size=64, + hidden_size=32, + intermediate_size=64, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + num_global_key_value_heads=2, + head_dim=16, + global_head_dim=32, + sliding_window=64, + rope_theta=10000.0, + layer_types=["sliding_attention", "full_attention"], + enable_moe_block=True, + num_experts=2, + moe_intermediate_size=48, + top_k_experts=2, + hidden_size_per_layer_input=0, + attention_k_eq_v=True, + ) + full_cfg = C.Gemma4Config( + text_config=text_cfg.to_dict(), + vision_config=None, + audio_config=None, + ) + hf_model = 
M.Gemma4ForConditionalGeneration(full_cfg) + return set(k for k in hf_model.state_dict().keys() if "language_model" in k) + + +def test_converter_emits_every_hf_key(): + transformers_gemma4 = pytest.importorskip("transformers.models.gemma4") + del transformers_gemma4 # only needed to gate + + conv = _load_convert_module() + + conv._config_cache["/nonexistent"] = { + "global_attn_layers": {1}, # layer 1 is full_attention + "local_head_dim": 16, + "global_head_dim": 32, + "num_attention_heads": 4, + "local_num_kv_heads": 2, + "global_num_kv_heads": 2, + "hidden_size": 32, + "num_experts": 2, + } + conv.reset_expert_buffers() + + args = SimpleNamespace(hf_checkpoint="/nonexistent") + + def _fake_tensor_for(name: str) -> torch.Tensor: + if name.endswith("self_attention.linear_qkv.weight"): + if "layers.1" in name: + return torch.zeros(256, 32) + return torch.zeros(128, 32) + if name.endswith("self_attention.linear_proj.weight"): + return torch.zeros(32, 64) + if "dense_mlp.linear_fc1.weight" in name: + return torch.zeros(128, 32) + if "dense_mlp.linear_fc2.weight" in name: + return torch.zeros(32, 64) + if "mlp.router.proj.weight" in name: + return torch.zeros(2, 32) + if "mlp.router.scale" in name or "mlp.router.per_expert_scale" in name: + return torch.zeros(2) + if "experts.linear_fc1.weight" in name: + return torch.zeros(96, 32) + if "experts.linear_fc2.weight" in name: + return torch.zeros(32, 48) + if "embedding.word_embeddings" in name or "output_layer" in name: + return torch.zeros(64, 32) + if "layer_scalar" in name: + return torch.tensor([1.0]) + return torch.zeros(32) + + emitted: set[str] = set() + for mcore_name in _mcore_keys_tiny_moe(num_experts=2): + t = _fake_tensor_for(mcore_name) + out = conv.convert_gemma4_to_hf(args, mcore_name, t) + for hf_name, _hf_param in out: + emitted.add(hf_name) + + expected = _build_tiny_hf_model() + + missing = expected - emitted + assert not missing, ( + f"HF expects {len(missing)} key(s) the converter never emits; this 
" + f"would surface as a weight-load crash or silently-random weights in " + f"sglang. Missing:\n " + "\n ".join(sorted(missing)) + ) diff --git a/tests/gemma4/test_gemma4_layer_integration.py b/tests/gemma4/test_gemma4_layer_integration.py new file mode 100644 index 0000000000..0b7d355ad9 --- /dev/null +++ b/tests/gemma4/test_gemma4_layer_integration.py @@ -0,0 +1,219 @@ +import os + +import pytest +import torch + +requires_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="Gemma4TransformerLayer requires CUDA + TE kernels", +) + + +def _init_single_rank_dist(): + import torch.distributed as dist + + try: + from megatron.core import parallel_state as mpu + except ImportError: + pytest.skip("Megatron-LM parallel_state is not installed") + + if mpu.model_parallel_is_initialized(): + mpu.destroy_model_parallel() + if not dist.is_initialized(): + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "29566") + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend, rank=0, world_size=1) + mpu.initialize_model_parallel() + + +@pytest.fixture(scope="module", autouse=True) +def _dist(): + _init_single_rank_dist() + yield + + +def _build_layer_config( + num_layers=6, + hidden_size=128, + ffn_hidden_size=256, + num_heads=8, + num_kv_heads=4, + head_dim=128, + global_head_dim=256, + num_global_kv_heads=2, + sliding_window=64, +): + from slime_plugins.models.gemma4 import Gemma4TransformerConfig + + cfg = Gemma4TransformerConfig( + num_layers=num_layers, + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_heads, + num_query_groups=num_kv_heads, + kv_channels=head_dim, + hidden_dropout=0.0, + attention_dropout=0.0, + bf16=True, + pipeline_dtype=torch.bfloat16, + params_dtype=torch.bfloat16, + add_bias_linear=False, + add_qkv_bias=False, + gated_linear_unit=True, + 
activation_func=torch.nn.functional.gelu, # placeholder + normalization="RMSNorm", + layernorm_epsilon=1e-6, + attention_softmax_in_fp32=True, + persist_layer_norm=True, + bias_activation_fusion=False, + bias_dropout_fusion=True, + apply_rope_fusion=False, + qk_layernorm=True, + sequence_parallel=False, + tensor_model_parallel_size=1, + ) + cfg.global_kv_channels = global_head_dim + cfg.global_num_query_groups = num_global_kv_heads + cfg.global_partial_rotary_factor = 0.25 + cfg.attention_k_eq_v = True + cfg.final_logit_softcapping = 30.0 + cfg.enable_moe_block = False + cfg.sliding_window = sliding_window + cfg.sliding_window_pattern = 6 + cfg.softmax_scale = 1.0 + return cfg + + +@requires_cuda +def test_layer_builds_and_forwards_sliding(): + from functools import partial + + import torch.nn.functional as F + from megatron.core.transformer.spec_utils import build_module + + from slime_plugins.models.gemma4 import get_gemma4_layer_spec_te + + cfg = _build_layer_config() + cfg.activation_func = partial(F.gelu, approximate="tanh") + spec = get_gemma4_layer_spec_te(cfg) + + layer = build_module(spec, config=cfg, layer_number=1) + layer = layer.cuda().to(torch.bfloat16) + assert layer.is_sliding is True + assert layer._is_global is False + + seq, batch = 16, 1 + h = torch.randn(seq, batch, cfg.hidden_size, device="cuda", dtype=torch.bfloat16) + + from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding + + rope = RotaryEmbedding(kv_channels=cfg.kv_channels, rotary_percent=1.0) + rotary = rope(seq).cuda() + + out, _ctx = layer(h, rotary_pos_emb=rotary, attention_mask=None) + assert out.shape == h.shape + assert torch.isfinite(out).all() + + +@requires_cuda +def test_layer_global_path_builds_and_forwards(): + from functools import partial + + import torch.nn.functional as F + from megatron.core.transformer.spec_utils import build_module + + from slime_plugins.models.gemma4 import get_gemma4_layer_spec_te + + cfg = _build_layer_config() + 
cfg.activation_func = partial(F.gelu, approximate="tanh") + spec = get_gemma4_layer_spec_te(cfg) + + layer = build_module(spec, config=cfg, layer_number=6) + layer = layer.cuda().to(torch.bfloat16) + assert layer.is_sliding is False + assert layer._is_global is True + + seq, batch = 16, 1 + h = torch.randn(seq, batch, cfg.hidden_size, device="cuda", dtype=torch.bfloat16) + + from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding + + rope = RotaryEmbedding(kv_channels=cfg.global_kv_channels, rotary_percent=1.0) + rotary = rope(seq).cuda() + + out, _ctx = layer(h, rotary_pos_emb=rotary, attention_mask=None) + assert out.shape == h.shape + assert torch.isfinite(out).all() + + +@requires_cuda +def test_layer_does_not_mutate_shared_config(): + from functools import partial + + import torch.nn.functional as F + from megatron.core.transformer.spec_utils import build_module + + from slime_plugins.models.gemma4 import get_gemma4_layer_spec_te + + cfg = _build_layer_config() + cfg.activation_func = partial(F.gelu, approximate="tanh") + orig_kv = cfg.kv_channels + orig_nqg = cfg.num_query_groups + + spec = get_gemma4_layer_spec_te(cfg) + build_module(spec, config=cfg, layer_number=6).cuda() + assert cfg.kv_channels == orig_kv, ( + f"building a global layer mutated shared config.kv_channels: " f"{orig_kv} -> {cfg.kv_channels}" + ) + assert cfg.num_query_groups == orig_nqg, ( + f"building a global layer mutated shared config.num_query_groups: " f"{orig_nqg} -> {cfg.num_query_groups}" + ) + + +def test_layer_spec_builds_without_cuda(): + from functools import partial + + import torch.nn.functional as F + + from slime_plugins.models.gemma4 import Gemma4SelfAttention, Gemma4TransformerLayer, get_gemma4_layer_spec_te + + cfg = _build_layer_config() + cfg.activation_func = partial(F.gelu, approximate="tanh") + spec = get_gemma4_layer_spec_te(cfg) + + assert spec.module is Gemma4TransformerLayer + assert spec.submodules.self_attention.module is 
Gemma4SelfAttention + from megatron.core.transformer.identity_op import IdentityOp + + assert spec.submodules.post_attention_layernorm is not IdentityOp + assert spec.submodules.post_feedforward_layernorm is not IdentityOp + + +def test_layer_spec_moe_variant_includes_dense_mlp_spec(): + from functools import partial + + import torch.nn.functional as F + from megatron.core.transformer.identity_op import IdentityOp + + from slime_plugins.models.gemma4 import Gemma4MoELayer, get_gemma4_layer_spec_te + + cfg = _build_layer_config() + cfg.activation_func = partial(F.gelu, approximate="tanh") + cfg.enable_moe_block = True + cfg.num_moe_experts = 8 + cfg.moe_router_topk = 2 + cfg.moe_ffn_hidden_size = 128 + cfg.moe_token_dispatcher_type = "alltoall" + cfg.moe_grouped_gemm = True + cfg.moe_aux_loss_coeff = 0.0 + cfg.moe_router_load_balancing_type = "none" + cfg.moe_router_score_function = "softmax" + cfg.moe_router_topk_scaling_factor = 1.0 + cfg.moe_router_pre_softmax = False + + spec = get_gemma4_layer_spec_te(cfg) + assert spec.submodules.mlp.module is Gemma4MoELayer + assert spec.submodules.dense_mlp is not IdentityOp, "dense_mlp must be a concrete spec when enable_moe_block=True" diff --git a/tests/gemma4/test_gemma4_layer_scalar_broadcast.py b/tests/gemma4/test_gemma4_layer_scalar_broadcast.py new file mode 100644 index 0000000000..552f3f56a0 --- /dev/null +++ b/tests/gemma4/test_gemma4_layer_scalar_broadcast.py @@ -0,0 +1,100 @@ +import json +import os +import tempfile + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def _worker(rank: int, world_size: int, master_port: int, ckpt_dir: str, out_dir: str): + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(master_port) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + try: + import megatron.core.transformer.transformer_layer 
as tl + except ModuleNotFoundError: + from tests.gemma4._standalone_imports import install_mbridge_stubs, install_megatron_stubs + + install_megatron_stubs() + install_mbridge_stubs() + import megatron.core.transformer.transformer_layer as tl + + from slime_plugins.models import gemma4_provider as _provider + + inner = torch.nn.Module() + inner.decoder = torch.nn.Module() + layers = [] + for _ in range(3): + layer = torch.nn.Module() + layer.register_buffer("layer_scalar", torch.ones(1)) + layers.append(layer) + inner.decoder.layers = torch.nn.ModuleList(layers) + + orig_offset = tl.get_transformer_layer_offset + tl.get_transformer_layer_offset = lambda _cfg: 0 + try: + _provider._load_layer_scalars(inner, ckpt_dir, config=type("C", (), {})()) + finally: + tl.get_transformer_layer_offset = orig_offset + + loaded = [layer.layer_scalar.item() for layer in inner.decoder.layers] + out_path = os.path.join(out_dir, f"rank{rank}.json") + with open(out_path, "w") as fp: + json.dump({"rank": rank, "scalars": loaded}, fp) + finally: + dist.destroy_process_group() + + +def _write_fake_checkpoint(ckpt_dir: str, scalars: dict[int, float]) -> None: + from safetensors.torch import save_file + + weight_map = {} + for layer_idx, value in scalars.items(): + tensor_name = f"model.language_model.layers.{layer_idx}.layer_scalar" + fname = f"layer_{layer_idx}.safetensors" + save_file( + {tensor_name: torch.tensor([value], dtype=torch.float32)}, + os.path.join(ckpt_dir, fname), + ) + weight_map[tensor_name] = fname + + with open(os.path.join(ckpt_dir, "model.safetensors.index.json"), "w") as fp: + json.dump({"metadata": {}, "weight_map": weight_map}, fp) + + +def test_layer_scalars_broadcast_to_all_ranks(): + expected = {0: 0.5, 1: 1.25, 2: 2.0} + + with tempfile.TemporaryDirectory() as tmp: + ckpt_dir = os.path.join(tmp, "ckpt") + os.makedirs(ckpt_dir) + _write_fake_checkpoint(ckpt_dir, expected) + + out_dir = os.path.join(tmp, "out") + os.makedirs(out_dir) + master_port = 29577 + + 
mp.spawn( + _worker, + args=(2, master_port, ckpt_dir, out_dir), + nprocs=2, + join=True, + ) + + with open(os.path.join(out_dir, "rank0.json")) as fp: + r0 = json.load(fp) + with open(os.path.join(out_dir, "rank1.json")) as fp: + r1 = json.load(fp) + + assert r0["rank"] == 0 + assert r1["rank"] == 1 + assert r0["scalars"] == pytest.approx([0.5, 1.25, 2.0]) + assert r1["scalars"] == pytest.approx([0.5, 1.25, 2.0]), ( + "rank 1 did not receive the broadcast scalars; check " "_broadcast_layer_scalars" + ) diff --git a/tests/gemma4/test_gemma4_provider.py b/tests/gemma4/test_gemma4_provider.py new file mode 100644 index 0000000000..0b782f925f --- /dev/null +++ b/tests/gemma4/test_gemma4_provider.py @@ -0,0 +1,332 @@ +import json +from types import SimpleNamespace + +import pytest +import torch + +from tests.gemma4._standalone_imports import load_gemma4_provider_module + +_provider = load_gemma4_provider_module() + + +def test_install_hooks_softcap_wraps_tensor_output(): + inner = torch.nn.Module() + inner.output_layer = torch.nn.Linear(4, 8, bias=False) + + hf_text = SimpleNamespace(final_logit_softcapping=30.0) + orig = _provider._load_hf_text_config + _provider._load_hf_text_config = lambda _path: hf_text + try: + args = SimpleNamespace(hf_checkpoint="/nonexistent") + config = SimpleNamespace(hidden_size=4) + _provider._install_hooks( + model=inner, + args=args, + config=config, + pre_process=False, + post_process=True, + ) + finally: + _provider._load_hf_text_config = orig + + x = torch.randn(2, 4) + raw = x @ inner.output_layer.weight.T + hooked = inner.output_layer(x) + expected = torch.tanh(raw / 30.0) * 30.0 + assert torch.allclose(hooked, expected, atol=1e-6) + assert hooked.abs().max().item() <= 30.0 + + +def test_install_hooks_softcap_reuses_storage_with_correct_gradient(): + class _CaptureOutput(torch.nn.Module): + def __init__(self): + super().__init__() + self.raw = None + self.raw_before = None + + def forward(self, x): + self.raw = x * 1.0 + 
self.raw_before = self.raw.detach().clone() + return self.raw + + inner = torch.nn.Module() + inner.output_layer = _CaptureOutput() + + hf_text = SimpleNamespace(final_logit_softcapping=30.0) + orig = _provider._load_hf_text_config + _provider._load_hf_text_config = lambda _path: hf_text + try: + args = SimpleNamespace(hf_checkpoint="/nonexistent") + config = SimpleNamespace(hidden_size=4) + _provider._install_hooks( + model=inner, + args=args, + config=config, + pre_process=False, + post_process=True, + ) + finally: + _provider._load_hf_text_config = orig + + base = torch.linspace(-3.0, 3.0, steps=12, dtype=torch.float64).view(3, 4) + base.requires_grad_(True) + weights = torch.linspace(0.1, 1.2, steps=12, dtype=torch.float64).view(3, 4) + + hooked = inner.output_layer(base) + (hooked * weights).sum().backward() + + expected = 30.0 * torch.tanh(inner.output_layer.raw_before / 30.0) + expected_grad = weights * (1.0 - torch.tanh(inner.output_layer.raw_before / 30.0).pow(2)) + assert hooked.data_ptr() == inner.output_layer.raw.data_ptr() + assert torch.allclose(hooked, expected) + assert torch.allclose(base.grad, expected_grad) + + +def test_install_hooks_softcap_wraps_tuple_output(): + inner = torch.nn.Module() + + class _TupleOutLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.w = torch.nn.Parameter(torch.randn(8, 4)) + + def forward(self, x): + return x @ self.w.T, None # (output, bias) + + inner.output_layer = _TupleOutLayer() + hf_text = SimpleNamespace(final_logit_softcapping=30.0) + orig = _provider._load_hf_text_config + _provider._load_hf_text_config = lambda _path: hf_text + try: + args = SimpleNamespace(hf_checkpoint="/nonexistent") + config = SimpleNamespace(hidden_size=4) + _provider._install_hooks( + model=inner, + args=args, + config=config, + pre_process=False, + post_process=True, + ) + finally: + _provider._load_hf_text_config = orig + + x = torch.randn(3, 4) + hooked, bias = inner.output_layer(x) + raw = x @ 
inner.output_layer.w.T + expected = torch.tanh(raw / 30.0) * 30.0 + assert torch.allclose(hooked, expected, atol=1e-6) + assert bias is None # tuple tail preserved + + +def test_install_hooks_no_softcap_when_disabled(): + inner = torch.nn.Module() + inner.output_layer = torch.nn.Linear(4, 8, bias=False) + + for cap_value in (None, 0, 0.0): + for h in list(inner.output_layer._forward_hooks.keys()): + inner.output_layer._forward_hooks.pop(h) + + hf_text = SimpleNamespace(final_logit_softcapping=cap_value) + orig = _provider._load_hf_text_config + _provider._load_hf_text_config = lambda _p, _t=hf_text: _t + try: + args = SimpleNamespace(hf_checkpoint="/nonexistent") + config = SimpleNamespace(hidden_size=4) + _provider._install_hooks( + model=inner, + args=args, + config=config, + pre_process=False, + post_process=True, + ) + finally: + _provider._load_hf_text_config = orig + assert len(inner.output_layer._forward_hooks) == 0, f"softcap hook should not register when cap={cap_value!r}" + + +def _install_embed_hook(inner, hidden): + hf_text = SimpleNamespace(final_logit_softcapping=None) + orig = _provider._load_hf_text_config + _provider._load_hf_text_config = lambda _path: hf_text + try: + args = SimpleNamespace(hf_checkpoint="/nonexistent") + config = SimpleNamespace(hidden_size=hidden) + _provider._install_hooks( + model=inner, + args=args, + config=config, + pre_process=True, + post_process=False, + ) + finally: + _provider._load_hf_text_config = orig + + +def test_install_hooks_embedding_scale_fp32_weight(): + hidden = 1024 + inner = torch.nn.Module() + inner.embedding = torch.nn.Embedding(100, hidden) # fp32 by default + _install_embed_hook(inner, hidden) + + ids = torch.tensor([[1, 2, 3]]) + hooked = inner.embedding(ids) + raw = inner.embedding.weight[ids] + expected_scale = torch.tensor(hidden**0.5) + assert torch.allclose(hooked, raw * expected_scale, atol=1e-6) + + +def test_install_hooks_embedding_scale_bf16_weight(): + hidden = 1024 + inner = 
torch.nn.Module() + inner.embedding = torch.nn.Embedding(100, hidden).to(torch.bfloat16) + _install_embed_hook(inner, hidden) + + ids = torch.tensor([[1, 2, 3]]) + hooked = inner.embedding(ids) + raw = inner.embedding.weight[ids] + expected_scale = torch.tensor(hidden**0.5).to(torch.bfloat16) + assert torch.allclose(hooked, raw * expected_scale, atol=1e-2) + + +def _write_fake_safetensors_layer_scalars(ckpt_dir, scalars): + from safetensors.torch import save_file + + weight_map = {} + for layer_idx, value in scalars.items(): + tensor_name = f"model.language_model.layers.{layer_idx}.layer_scalar" + fname = f"layer_{layer_idx}.safetensors" + save_file({tensor_name: torch.tensor(value)}, str(ckpt_dir / fname)) + weight_map[tensor_name] = fname + index = {"metadata": {}, "weight_map": weight_map} + (ckpt_dir / "model.safetensors.index.json").write_text(json.dumps(index)) + + +def test_load_layer_scalars_applies_values_to_layers(tmp_path): + scalars = {0: 0.5, 1: 1.5, 2: 2.5} + _write_fake_safetensors_layer_scalars(tmp_path, scalars) + + inner = torch.nn.Module() + inner.decoder = torch.nn.Module() + layers = [] + for _ in range(3): + layer = torch.nn.Module() + layer.register_buffer("layer_scalar", torch.ones(1)) + layers.append(layer) + inner.decoder.layers = torch.nn.ModuleList(layers) + + import megatron.core.transformer.transformer_layer as tl + + orig_offset = tl.get_transformer_layer_offset + tl.get_transformer_layer_offset = lambda _cfg: 0 + try: + _provider._load_layer_scalars(inner, str(tmp_path), config=SimpleNamespace()) + finally: + tl.get_transformer_layer_offset = orig_offset + + for i, expected in scalars.items(): + assert inner.decoder.layers[i].layer_scalar.item() == pytest.approx(expected) + + +def test_load_layer_scalars_respects_pp_offset(tmp_path): + scalars = {10: 0.7, 11: 0.8, 12: 0.9} + _write_fake_safetensors_layer_scalars(tmp_path, scalars) + + inner = torch.nn.Module() + inner.decoder = torch.nn.Module() + layers = [] + for _ in range(3): + 
layer = torch.nn.Module() + layer.register_buffer("layer_scalar", torch.ones(1)) + layers.append(layer) + inner.decoder.layers = torch.nn.ModuleList(layers) + + import megatron.core.transformer.transformer_layer as tl + + orig_offset = tl.get_transformer_layer_offset + tl.get_transformer_layer_offset = lambda _cfg: 10 # PP offset + try: + _provider._load_layer_scalars(inner, str(tmp_path), config=SimpleNamespace()) + finally: + tl.get_transformer_layer_offset = orig_offset + + assert inner.decoder.layers[0].layer_scalar.item() == pytest.approx(0.7) + assert inner.decoder.layers[1].layer_scalar.item() == pytest.approx(0.8) + assert inner.decoder.layers[2].layer_scalar.item() == pytest.approx(0.9) + + +def test_load_layer_scalars_raises_by_default_when_missing(tmp_path, monkeypatch): + monkeypatch.delenv("GEMMA4_ALLOW_MISSING_LAYER_SCALARS", raising=False) + scalars = {0: 0.5} + _write_fake_safetensors_layer_scalars(tmp_path, scalars) + + inner = torch.nn.Module() + inner.decoder = torch.nn.Module() + layers = [] + for _ in range(2): + layer = torch.nn.Module() + layer.register_buffer("layer_scalar", torch.ones(1)) + layers.append(layer) + inner.decoder.layers = torch.nn.ModuleList(layers) + + import megatron.core.transformer.transformer_layer as tl + + orig_offset = tl.get_transformer_layer_offset + tl.get_transformer_layer_offset = lambda _cfg: 0 + try: + with pytest.raises(KeyError, match="missing in checkpoint"): + _provider._load_layer_scalars(inner, str(tmp_path), config=SimpleNamespace()) + finally: + tl.get_transformer_layer_offset = orig_offset + + +def test_load_layer_scalars_defaults_to_one_when_missing_with_opt_in(tmp_path, monkeypatch): + monkeypatch.setenv("GEMMA4_ALLOW_MISSING_LAYER_SCALARS", "1") + scalars = {0: 0.5} + _write_fake_safetensors_layer_scalars(tmp_path, scalars) + + inner = torch.nn.Module() + inner.decoder = torch.nn.Module() + layers = [] + for _ in range(2): + layer = torch.nn.Module() + layer.register_buffer("layer_scalar", 
torch.ones(1)) + layers.append(layer) + inner.decoder.layers = torch.nn.ModuleList(layers) + + import megatron.core.transformer.transformer_layer as tl + + orig_offset = tl.get_transformer_layer_offset + tl.get_transformer_layer_offset = lambda _cfg: 0 + try: + _provider._load_layer_scalars(inner, str(tmp_path), config=SimpleNamespace()) + finally: + tl.get_transformer_layer_offset = orig_offset + + assert inner.decoder.layers[0].layer_scalar.item() == pytest.approx(0.5) + assert inner.decoder.layers[1].layer_scalar.item() == pytest.approx(1.0) + + +def test_load_layer_scalars_raises_when_no_index_file(tmp_path, monkeypatch): + monkeypatch.delenv("GEMMA4_ALLOW_MISSING_LAYER_SCALARS", raising=False) + inner = torch.nn.Module() + inner.decoder = torch.nn.Module() + inner.decoder.layers = torch.nn.ModuleList([torch.nn.Module()]) + inner.decoder.layers[0].register_buffer("layer_scalar", torch.ones(1)) + + with pytest.raises(RuntimeError, match="No layer_scalar weights found"): + _provider._load_layer_scalars(inner, str(tmp_path), config=SimpleNamespace()) + + +def test_load_layer_scalars_skips_when_no_index_file_with_opt_in(tmp_path, monkeypatch, caplog): + import logging + + monkeypatch.setenv("GEMMA4_ALLOW_MISSING_LAYER_SCALARS", "1") + inner = torch.nn.Module() + inner.decoder = torch.nn.Module() + inner.decoder.layers = torch.nn.ModuleList([torch.nn.Module()]) + inner.decoder.layers[0].register_buffer("layer_scalar", torch.ones(1)) + + with caplog.at_level(logging.WARNING, logger=_provider.__name__): + _provider._load_layer_scalars(inner, str(tmp_path), config=SimpleNamespace()) + assert inner.decoder.layers[0].layer_scalar.item() == 1.0 + assert any("No safetensors index" in r.message for r in caplog.records) diff --git a/tests/gemma4/test_gemma4_qkv_roundtrip.py b/tests/gemma4/test_gemma4_qkv_roundtrip.py new file mode 100644 index 0000000000..3d953ba93e --- /dev/null +++ b/tests/gemma4/test_gemma4_qkv_roundtrip.py @@ -0,0 +1,192 @@ +import importlib +import 
importlib.util +import pathlib +from types import SimpleNamespace + +import pytest +import torch + +from tests.gemma4._standalone_imports import load_gemma4_bridge_class + +Gemma4Bridge = load_gemma4_bridge_class() + + +def _load_convert_module(): + try: + return importlib.import_module("slime.backends.megatron_utils.megatron_to_hf.gemma4") + except ImportError: + pass + repo_path = pathlib.Path(__file__).resolve().parents[2] / ( + "slime/backends/megatron_utils/megatron_to_hf/gemma4.py" + ) + if not repo_path.exists(): + pytest.skip(f"convert module not found at {repo_path}") + spec = importlib.util.spec_from_file_location("_gemma4_conv_rt", repo_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +CFG_31B = SimpleNamespace( + hidden_size=5376, + num_attention_heads=32, + head_dim=256, + num_key_value_heads=16, + global_head_dim=512, + num_global_key_value_heads=4, + num_hidden_layers=60, + attention_k_eq_v=True, + layer_types=(["sliding_attention"] * 5 + ["full_attention"]) * 10, +) +_GLOBAL_LAYERS_31B = {i for i, t in enumerate(CFG_31B.layer_types) if t == "full_attention"} + + +def _build_bridge_stub(cfg): + b = object.__new__(Gemma4Bridge) + b._GLOBAL_ATTN_LAYERS = {i for i, t in enumerate(cfg.layer_types) if t == "full_attention"} + b.hf_config = SimpleNamespace(text_config=cfg) + return b + + +def _prime_convert_config(conv): + conv._config_cache["/nonexistent"] = { + "global_attn_layers": _GLOBAL_LAYERS_31B, + "local_head_dim": CFG_31B.head_dim, + "global_head_dim": CFG_31B.global_head_dim, + "num_attention_heads": CFG_31B.num_attention_heads, + "local_num_kv_heads": CFG_31B.num_key_value_heads, + "global_num_kv_heads": CFG_31B.num_global_key_value_heads, + "hidden_size": CFG_31B.hidden_size, + } + + +def test_sliding_layer_qkv_roundtrip(): + torch.manual_seed(0) + conv = _load_convert_module() + _prime_convert_config(conv) + bridge = _build_bridge_stub(CFG_31B) + + layer_idx = 0 + q = 
def test_global_k_eq_v_layer_qkv_roundtrip():
    """Pack Q/K of a full-attention layer and convert back; only q_proj and k_proj must come out."""
    torch.manual_seed(1)
    conv = _load_convert_module()
    _prime_convert_config(conv)
    bridge = _build_bridge_stub(CFG_31B)

    layer_idx = 5
    assert layer_idx in _GLOBAL_LAYERS_31B

    hidden = CFG_31B.hidden_size
    head_dim = CFG_31B.global_head_dim
    num_q_heads = CFG_31B.num_attention_heads
    num_kv_heads = CFG_31B.num_global_key_value_heads

    # Draw order (q first, then k) is kept so the seeded values are unchanged.
    q = torch.randn(num_q_heads * head_dim, hidden)
    k = torch.randn(num_kv_heads * head_dim, hidden)

    mcore_name = f"decoder.layers.{layer_idx}.self_attention.linear_qkv.weight"
    packed = bridge._weight_to_mcore_format(mcore_name, [q, k])

    # Each KV group holds its query heads plus one K and one V slot.
    q_per_kv = num_q_heads // num_kv_heads
    expected_rows = num_kv_heads * (q_per_kv + 2) * head_dim
    assert packed.shape == (expected_rows, hidden)

    args = SimpleNamespace(hf_checkpoint="/nonexistent")
    out = dict(conv.convert_gemma4_to_hf(args, f"module.module.{mcore_name}", packed))

    q_key = f"model.language_model.layers.{layer_idx}.self_attn.q_proj.weight"
    k_key = f"model.language_model.layers.{layer_idx}.self_attn.k_proj.weight"
    assert set(out) == {q_key, k_key}
    assert torch.allclose(out[q_key], q)
    assert torch.allclose(out[k_key], k)
def _make_router_config(hidden_size=16, num_experts=8, top_k=2, eps=1e-6):
    """Build the minimal config namespace the router tests pass to ``Gemma4Router``."""
    fields = {
        "hidden_size": hidden_size,
        "num_moe_experts": num_experts,
        "moe_router_topk": top_k,
        "layernorm_epsilon": eps,
    }
    return SimpleNamespace(**fields)
def _make_moe_route_stub():
    """Return ``(layer, cfg)``: a ``Gemma4MoELayer`` built without its ``__init__``, carrying only a router and config."""
    cfg = _make_router_config(num_experts=6, top_k=2)

    layer = object.__new__(Gemma4MoELayer)
    # nn.Module bookkeeping must exist before a sub-module can be attached.
    torch.nn.Module.__init__(layer)
    layer.router = Gemma4Router(cfg)
    layer.config = cfg
    return layer, cfg
def _hf_reference_router(h, proj_w, scale, per_expert_scale, top_k, eps=1e-6):
    """Reference implementation of the HF Gemma4 router equation, computed in float32.

        h_norm  = rmsnorm_noscale(h)                 # RMSNorm without a learnable scale
        h_norm2 = h_norm * scale / sqrt(H)           # per-hidden learnable scale
        logits  = h_norm2 @ proj_w.T                 # [T, E]
        probs   = softmax(logits)
        top_w, top_i = topk(probs, k=top_k)
        top_w   = top_w / sum(top_w)                 # renormalize
        top_w   = top_w * per_expert_scale[top_i]    # per-expert scale multiplier

    This closes the loop on what Gemma4Router computes: it exercises every
    step (RMSNorm without scale, per-hidden scale, proj, softmax, topk,
    renormalise, per-expert scale) and guards against silent reordering of
    those ops in future refactors.

    Args:
        h: hidden states; the last dimension is the hidden size ``H``.
        proj_w: projection weight, applied as ``F.linear(h_norm2, proj_w)``.
        scale: per-hidden scale, broadcast against ``h``.
        per_expert_scale: tensor indexed by the selected expert ids.
        top_k: number of experts kept per token.
        eps: value added to the mean square before the inverse square root.

    Returns:
        ``(top_w, top_i)``: the renormalized, per-expert-scaled top-k weights
        (float32) and the matching expert indices.
    """
    # Upcast every operand, not just ``h``: mixing a float32 activation with
    # half-precision parameters makes ``F.linear`` fail on a dtype mismatch.
    h = h.float()
    proj_w = proj_w.float()
    scale = scale.float()
    per_expert_scale = per_expert_scale.float()

    norm = h * torch.pow(h.pow(2).mean(-1, keepdim=True) + eps, -0.5)
    h_norm2 = norm * scale * (h.shape[-1] ** -0.5)
    logits = torch.nn.functional.linear(h_norm2, proj_w)
    probs = torch.softmax(logits, dim=-1)
    top_w, top_i = torch.topk(probs, k=top_k, dim=-1)
    top_w = top_w / top_w.sum(dim=-1, keepdim=True)
    top_w = top_w * per_expert_scale[top_i]
    return top_w, top_i
"/fsx-shopper-intel/dev/jianhfan/gemma-4-31b-it") + +pytestmark = pytest.mark.skipif( + not os.path.exists(os.path.join(GEMMA4_CKPT, "tokenizer_config.json")), + reason=f"Gemma4 checkpoint tokenizer not found at {GEMMA4_CKPT}", +) + + +class _FakeArgs: + def __init__(self, ckpt, batch_size): + self.hf_checkpoint = ckpt + self.loss_mask_type = "gemma4" + self.rollout_batch_size = batch_size + self.rollout_global_dataset = True + + +class _FakeDataBuffer: + def __init__(self, samples): + self._samples = samples + + def get_samples(self, n): + return [(s,) for s in self._samples[:n]] + + +def _reset_sft_module_globals(): + import slime.rollout.sft_rollout as sft + + sft.TOKENIZER = None + sft.PROCESSOR = None + sft.MASK_GENERATOR = None + sft.SAMPLE_PRINTED = False + + +def _run_rollout(messages_list): + import slime.rollout.sft_rollout as sft + from slime.utils.types import Sample + + _reset_sft_module_globals() + samples = [Sample(prompt=msgs) for msgs in messages_list] + args = _FakeArgs(GEMMA4_CKPT, batch_size=len(samples)) + buf = _FakeDataBuffer(samples) + out = sft.generate_rollout(args, rollout_id=0, data_buffer=buf, evaluation=False) + unwrapped = [item[0] if isinstance(item, tuple) else item for item in out] + return unwrapped, sft.TOKENIZER + + +def test_tokens_full_mask_is_tail(): + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "It is 4."}, + ] + samples, tok = _run_rollout([messages]) + sample = samples[0] + + assert len(sample.tokens) > 0 + assert sample.response_length > 0 + assert len(sample.loss_mask) == sample.response_length + assert len(sample.loss_mask) <= len(sample.tokens) + + tail_tokens = sample.tokens[-sample.response_length :] + masked = [tail_tokens[i] for i in range(len(tail_tokens)) if sample.loss_mask[i] == 1] + decoded = tok.decode(masked) + assert "It is 4." in decoded + assert "" in decoded + assert "What is 2+2?" 
def test_multi_turn_response_length_spans_from_first_assistant():
    """Two-turn chat: both answers are trained on, the second question is not."""
    turns = [("user", "Q1"), ("assistant", "A1"), ("user", "Q2"), ("assistant", "A2")]
    messages = [{"role": role, "content": text} for role, text in turns]
    rollout_samples, tokenizer = _run_rollout([messages])
    first = rollout_samples[0]

    # loss_mask is indexed by position within the trailing response span.
    response_tokens = first.tokens[-first.response_length :]
    trained_ids = [tok_id for pos, tok_id in enumerate(response_tokens) if first.loss_mask[pos] == 1]
    trained_text = tokenizer.decode(trained_ids)
    for answer in ("A1", "A2"):
        assert answer in trained_text
    assert "Q2" not in trained_text

    assert first.effective_response_length == sum(first.loss_mask)
    assert first.effective_response_length < first.response_length
+ U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download {MODEL_ID} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + U.convert_checkpoint( + model_name=MODEL_NAME, + megatron_model_type=MODEL_TYPE, + num_gpus_per_node=NUM_GPUS, + dir_dst="/root/models", + ) + + +def execute(): + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " f"--ref-load {TORCH_DIST_CKPT} " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 2 " + "--rollout-batch-size 4 " + "--n-samples-per-prompt 4 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 0.8 " + "--rollout-top-p 1.0 " + "--global-batch-size 16 " + ) + + eval_args = ( + f"{'--eval-interval 20 ' if ENABLE_EVAL else ''}" + "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet " + "--n-samples-per-eval-prompt 1 " + "--eval-max-response-len 1024 " + "--eval-top-k 1 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 4 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 4096 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--kl-coef 0.00 " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + "--optimizer-cpu-offload " + "--overlap-cpu-optimizer-d2h-h2d " + "--use-precision-aware-optimizer " + ) + + sglang_args = ( + "--rollout-num-gpus-per-engine 2 " + 
"--sglang-mem-fraction-static 0.75 " + "--sglang-cuda-graph-max-bs 16 " + "--sglang-enable-metrics " + ) + + misc_args = ( + "--ci-test " + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--loss-mask-type gemma4 " + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 8 " + "--colocate " + "--megatron-to-hf-mode raw " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{eval_args} " + f"{sglang_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_GPUS, + megatron_model_type=MODEL_TYPE, + ) + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/utils/test_loss_mask_type_gemma4.py b/tests/utils/test_loss_mask_type_gemma4.py new file mode 100644 index 0000000000..2597e92b04 --- /dev/null +++ b/tests/utils/test_loss_mask_type_gemma4.py @@ -0,0 +1,171 @@ +import ast +import pathlib + +from slime.utils.mask_utils import MultiTurnLossMaskGenerator + + +class FakeGemma4Tokenizer: + is_fast = True + + def __call__(self, text, add_special_tokens=False, return_offsets_mapping=False): + encoded = {"input_ids": [ord(ch) for ch in text]} + if return_offsets_mapping: + encoded["offset_mapping"] = [(i, i + 1) for i in range(len(text))] + return encoded + + def decode(self, token_ids): + return "".join(chr(t) for t in token_ids) + + def get_added_vocab(self): + return {} + + def apply_chat_template( + self, + messages, + tokenize=True, + tools=None, + add_generation_prompt=False, + return_dict=False, + add_special_tokens=False, + **kwargs, + ): + rendered = self.render(messages, add_generation_prompt=add_generation_prompt) + if tokenize: + return [ord(ch) for ch in rendered] + 
def _decode_with_mask_value(gen, messages, wanted):
    """Decode the tokens of *messages* whose loss-mask entry equals *wanted*."""
    token_ids, mask = gen.get_loss_mask(messages)
    # A mask that is not aligned one-to-one with the tokens would make every
    # downstream assertion meaningless, so fail loudly for both helpers.
    assert len(token_ids) == len(mask)
    return gen.tokenizer.decode([tok for tok, flag in zip(token_ids, mask) if flag == wanted])


def _masked_text(gen, messages):
    """Return the decoded text of the tokens whose loss mask is 1."""
    return _decode_with_mask_value(gen, messages, 1)


def _unmasked_text(gen, messages):
    """Return the decoded text of the tokens whose loss mask is 0."""
    return _decode_with_mask_value(gen, messages, 0)
def test_gemma4_is_an_accepted_argparse_choice():
    """Check that ``--loss-mask-type`` lists ``gemma4`` among its ``choices``.

    ``slime/utils/arguments.py`` is parsed with ``ast`` rather than imported.
    The first call that has the literal ``"--loss-mask-type"`` among its
    positional arguments is inspected.
    """
    arguments_py = pathlib.Path(__file__).resolve().parents[2] / "slime/utils/arguments.py"
    # Decode explicitly as UTF-8 so the result does not depend on the locale's
    # default encoding.
    tree = ast.parse(arguments_py.read_text(encoding="utf-8"))

    for node in ast.walk(tree):
        if not isinstance(node, ast.Call):
            continue
        if not any(isinstance(arg, ast.Constant) and arg.value == "--loss-mask-type" for arg in node.args):
            continue

        choices = next((kw.value for kw in node.keywords if kw.arg == "choices"), None)
        assert choices is not None, "no choices=[...] found for --loss-mask-type"
        assert "gemma4" in ast.literal_eval(choices)
        break
    else:
        # The loop finished without a matching call.
        raise AssertionError("could not locate --loss-mask-type in arguments.py")