
Add Gemma4 dense and MoE model support#4

Draft
EazyReal wants to merge 6 commits into codex/empty-colocated-weight-bucket-20260626 from codex/gemma4-official-stacked-20260626

EazyReal commented Jun 26, 2026

Owner

Summary

Adds native Slime support for Gemma4 text dense and MoE variants, including:

  • Megatron model provider/spec for Gemma4 heterogeneous attention, dual RoPE, K=V global layers, v_norm, layer scalars, and the Gemma4 MoE router.
  • HF <-> Megatron Bridge mapping for dense and 26B-A4B MoE checkpoints.
  • Raw Megatron-to-HF weight conversion used by SGLang weight updates.
  • Gemma4 chat loss masking.
  • One-node GSM8K proof scripts and docs for google/gemma-4-31B-it and google/gemma-4-26B-A4B-it.

Proof target

The proof recipe uses GSM8K rather than SWE. GSM8K isolates the model-support path: SGLang load, Megatron load/provider, raw weight sync, rollout, loss masking, backward, and optimizer metrics. SWE should be a downstream integration/capacity validation because it also tests sandboxing, tools, reward/runtime behavior, and long-context infra.

Validation

Unit/static checks:

  • uv run --with pytest --with torch --with transformers --with safetensors --with numpy python -m pytest tests/gemma4 tests/utils/test_loss_mask_type_gemma4.py tests/test_empty_colocated_weight_bucket.py -q (58 passed, 14 skipped)
  • uv run --with ruff ruff check slime_plugins/models/gemma4.py slime_plugins/models/gemma4_provider.py slime_plugins/mbridge/gemma4.py slime/backends/megatron_utils/megatron_to_hf/gemma4.py slime/utils/mask_utils.py tests/gemma4 tests/utils/test_loss_mask_type_gemma4.py tests/test_empty_colocated_weight_bucket.py
  • uv run --with black black --check slime_plugins/models/gemma4.py slime_plugins/models/gemma4_provider.py slime_plugins/mbridge/gemma4.py slime/backends/megatron_utils/megatron_to_hf/gemma4.py slime/utils/mask_utils.py tests/gemma4 tests/utils/test_loss_mask_type_gemma4.py tests/test_empty_colocated_weight_bucket.py
  • bash -n scripts/run-gemma4-31B-gsm8k.sh scripts/run-gemma4-26B-A4B-gsm8k.sh && git diff --check

One-node proof runs:

Visual report: https://wandb.ai/augustinevmax-vmax/slime-gemma4-official-proof/reports/Gemma4-Slime-GSM8K-Proof-Runs--VmlldzoxNzM1MTY2Mg==?accessToken=s43ay8st8n7w19ep73y531l898faxpdtjfvgsqd8etcrr4kszvjw6zlze6ehbmio

The report includes the original proof runs plus small stdout-reconstructed W&B runs for readable metric curves. The reconstructed runs are explicitly labeled and link back to the original Slurm/W&B provenance.

Notes

This PR is prepared in the EazyReal fork for review before any upstream THUDM/slime PR is opened. It is stacked on the empty colocated bucket fix.

Summary by CodeRabbit

  • New Features

    • Added Gemma4 support for dense and MoE text models, including new training/run scripts and model configurations.
    • Added Gemma4 model conversion, checkpoint loading, and loss-masking support for end-to-end workflows.
  • Documentation

    • Added English and Chinese examples for running Gemma4 validation jobs, with setup, conversion, and tuning guidance.
  • Bug Fixes

    • Improved distributed runtime handling and model export/import compatibility for Gemma4 workflows.
  • Tests

    • Added broad coverage for attention, routing, rotary embeddings, checkpoint conversion, rollout masking, and integration behavior.

coderabbitai Bot commented Jun 26, 2026


Walkthrough

Adds Gemma4 model, bridge, provider, loss-mask, launch, docs, and test support for GSM8K training, checkpoint conversion, attention, routing, rotary embeddings, and rollout masking.

Changes

Gemma4 Support

| Layer | File(s) | Summary |
|---|---|---|
| Gemma4 transformer core | `slime_plugins/models/gemma4.py` | Adds Gemma4 transformer config, VNorm, router/MoE layers, heterogeneous attention paths, SDPA core attention, self-attention special cases, TE layer specs, and final Gemma4 spec assembly. |
| Provider hooks and checkpoint loading | `slime_plugins/models/gemma4_provider.py` | Builds GPTModel with Gemma4 hooks for embedding scaling, final-logit softcapping, dual RoPE, and layer-scalar loading/broadcasting. |
| Bridge mappings and HF conversion | `slime_plugins/mbridge/__init__.py`, `slime_plugins/mbridge/gemma4.py`, `slime/backends/megatron_utils/megatron_to_hf/__init__.py`, `slime/backends/megatron_utils/megatron_to_hf/gemma4.py`, `slime/utils/external_utils/command_utils.py`, `tools/convert_hf_to_torch_dist.py` | Registers Gemma4 bridge support, maps HF/MCore attention and MLP/MoE weights, routes gemma4 through HF conversion, and adds conversion CLI/path wiring. |
| Loss masking and rollout validation | `slime/utils/arguments.py`, `slime/utils/mask_utils.py`, `tests/utils/test_loss_mask_type_gemma4.py`, `tests/gemma4/test_gemma4_sft_rollout.py` | Adds Gemma4 chat-template loss masking, exposes gemma4 as a loss-mask choice, and validates assistant-span masking plus rollout response-length behavior in new tests. |
| Docs and launch scripts | `docs/*/examples/gemma4.md`, `docs/*/index.rst`, `scripts/models/gemma4-*.sh`, `scripts/run-gemma4-*.sh`, `tests/test_gemma4_12B_gsm8k_short.py` | Adds Gemma4 example docs, model-argument presets, GSM8K launch scripts, and the short 12B runner. |
| Conversion contract tests | `tests/gemma4/_standalone_imports.py`, `tests/gemma4/test_gemma4_bridge.py`, `tests/gemma4/test_gemma4_qkv_roundtrip.py`, `tests/gemma4/test_gemma4_hf_key_contract.py` | Adds standalone import stubs and conversion tests covering bridge config building, QKV packing, MoE expert stacking, and HF key coverage. |
| Runtime, routing, and provider tests | `tests/gemma4/test_gemma4_attention.py`, `tests/gemma4/test_gemma4_cp_attention.py`, `tests/gemma4/test_gemma4_dual_rope.py`, `tests/gemma4/test_gemma4_layer_integration.py`, `tests/gemma4/test_gemma4_provider.py`, `tests/gemma4/test_gemma4_layer_scalar_broadcast.py`, `tests/gemma4/test_gemma4_router.py` | Runtime tests cover attention, CP SDPA, dual RoPE, router behavior, provider hooks, layer integration, and layer-scalar broadcast. |

Sequence Diagram(s)

sequenceDiagram
  participant model_provider
  participant GPTModel
  participant _install_hooks
  participant _load_layer_scalars
  participant DualRotaryEmbedding
  model_provider->>GPTModel: build model from args
  model_provider->>_install_hooks: attach Gemma4 hooks
  _install_hooks->>DualRotaryEmbedding: replace rotary_pos_emb
  _install_hooks->>_load_layer_scalars: load layer_scalar buffers
  _load_layer_scalars->>GPTModel: fill layer_scalar buffers

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Poem

I hopped through Gemma fields at dawn,
With masks and MoE sprouts laid on.
My whiskers twitched at RoPE light,
And every rollout thumped just right. 🐰

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 13.88% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately summarizes the main change: native Gemma4 dense and MoE model support. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

Comment @coderabbitai help to get the list of available commands.

@EazyReal EazyReal force-pushed the codex/empty-colocated-weight-bucket-20260626 branch from ceeb0b5 to 3309275 Compare June 26, 2026 17:09
@EazyReal EazyReal force-pushed the codex/gemma4-official-stacked-20260626 branch from 1a87c5d to 0284b88 Compare June 26, 2026 17:10
@EazyReal EazyReal force-pushed the codex/empty-colocated-weight-bucket-20260626 branch from 3309275 to c30e877 Compare June 26, 2026 17:30
@EazyReal EazyReal force-pushed the codex/gemma4-official-stacked-20260626 branch from 0284b88 to f3da8a8 Compare June 26, 2026 17:30
@EazyReal

Owner Author

@coderabbitai full review

coderabbitai Bot commented Jun 26, 2026

✅ Action performed

Full review finished.

@EazyReal EazyReal force-pushed the codex/empty-colocated-weight-bucket-20260626 branch from c30e877 to 7117b7a Compare June 26, 2026 17:56
@EazyReal EazyReal force-pushed the codex/gemma4-official-stacked-20260626 branch from f3da8a8 to c4c07d6 Compare June 26, 2026 17:56
@EazyReal

Owner Author

@coderabbitai full review

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 8

🧹 Nitpick comments (4)
tests/gemma4/test_gemma4_sft_rollout.py (1)

5-5: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

Avoid baking an internal absolute path into the default.

The fallback default /fsx-shopper-intel/dev/jianhfan/gemma-4-31b-it hardcodes internal infrastructure and a username into the repo, and is meaningless to anyone outside that environment. Prefer a neutral default (or no default, requiring GEMMA4_CKPT to be set) so the skip reason stays generic.

♻️ Suggested change
-GEMMA4_CKPT = os.environ.get("GEMMA4_CKPT", "/fsx-shopper-intel/dev/jianhfan/gemma-4-31b-it")
+GEMMA4_CKPT = os.environ.get("GEMMA4_CKPT", "")
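
A minimal sketch of the env-gated lookup this suggestion points toward, assuming the test only needs to know whether `GEMMA4_CKPT` names an existing directory. The helper `resolve_ckpt` is hypothetical and not part of the PR.

```python
import os


def resolve_ckpt(env=os.environ):
    """Return the path from GEMMA4_CKPT, or None when it is unset or not a directory."""
    path = env.get("GEMMA4_CKPT", "")
    # The empty default keeps internal paths and usernames out of the repo;
    # callers treat None as "skip this test".
    return path if path and os.path.isdir(path) else None
```

In a test module the result could feed something like `pytest.mark.skipif(resolve_ckpt() is None, reason="GEMMA4_CKPT not set")`, which keeps the skip reason generic.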
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gemma4/test_gemma4_sft_rollout.py` at line 5, Update the GEMMA4_CKPT
environment lookup in test_gemma4_sft_rollout so it does not fall back to the
internal absolute path; replace the hardcoded default with a neutral value or
require GEMMA4_CKPT to be set explicitly. Keep the skip logic generic by
ensuring the test setup only references GEMMA4_CKPT and does not embed
environment-specific usernames or filesystem paths.
slime/utils/mask_utils.py (1)

198-270: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚖️ Poor tradeoff

Heavy duplication with gen_multi_turn_loss_mask_qwen3_5.

This method shares almost all of its structure (render → tokenize-with-offsets → validate → char-mask → prefix-sum → token mask) with gen_multi_turn_loss_mask_qwen3_5 (Lines 127-196). The only real differences are the marker strings and the thought-block handling (here the entire <|channel>thought ... <channel|> block is excluded, whereas qwen3.5 keeps the think content). Consider extracting a shared helper parameterized by markers and a thought-span resolver to avoid drift between the two paths.
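
The shared tail of that pipeline (char-mask → prefix-sum → token mask) might look roughly like the sketch below. It assumes each model-specific method has already resolved its trainable character spans and has token offsets from the tokenizer, and it assumes a token counts as trainable when any of its characters falls in a span; the actual rule in `mask_utils.py` may differ. The name `token_mask_from_spans` is hypothetical.

```python
from itertools import accumulate


def token_mask_from_spans(text_len, spans, offsets):
    """Map character-level (start, end) spans to a per-token 0/1 loss mask."""
    char_mask = [0] * text_len
    for start, end in spans:
        for i in range(start, end):
            char_mask[i] = 1
    # prefix[i] counts masked characters in text[:i], so each token's
    # overlap with the spans is a constant-time difference.
    prefix = [0, *accumulate(char_mask)]
    # Zero-width offsets (e.g. special tokens) are never trainable.
    return [1 if e > s and prefix[e] - prefix[s] > 0 else 0 for s, e in offsets]


mask = token_mask_from_spans(
    text_len=10,
    spans=[(4, 8)],
    offsets=[(0, 2), (2, 5), (5, 8), (8, 10), (0, 0)],
)
```

With a helper of this shape, the Gemma4 and Qwen3.5 methods would differ only in how they compute `spans` (marker strings and thought-block handling).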

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@slime/utils/mask_utils.py` around lines 198 - 270,
`gen_multi_turn_loss_mask_gemma4` duplicates nearly all of
`gen_multi_turn_loss_mask_qwen3_5`, so refactor the shared
render/tokenize/validate/char-mask/prefix-sum/token-mask flow into a common
helper to prevent the two paths from drifting. Keep the model-specific pieces
configurable via parameters or a small resolver for marker strings and
thought-block masking behavior, then have both methods call that helper and only
supply their unique markers and span rules.
tests/gemma4/test_gemma4_qkv_roundtrip.py (1)

175-175: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Escape the regex metacharacter in the match= pattern.

match= is treated as a regex, so the . in "linear_fc1.weight expects" is an unescaped metacharacter (Ruff RUF043). It still matches today, but tighten it to avoid false positives and keep lint clean.

♻️ Proposed fix
-    with pytest.raises(AssertionError, match="linear_fc1.weight expects"):
+    with pytest.raises(AssertionError, match=r"linear_fc1\.weight expects"):
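
To illustrate why the escape matters: `pytest.raises(match=...)` applies the pattern with `re.search`, so the difference can be shown with the `re` module alone. The message with `X` in place of the dot is a made-up counterexample.

```python
import re

loose = "linear_fc1.weight expects"
strict = r"linear_fc1\.weight expects"

# A message that should NOT satisfy the assertion: "X" stands where the dot belongs.
wrong_message = "linear_fc1Xweight expects 2 dims"
real_message = "linear_fc1.weight expects 2 dims"

loose_hit = re.search(loose, wrong_message) is not None    # unescaped "." matches "X"
strict_hit = re.search(strict, wrong_message) is not None  # escaped "\." does not
real_hit = re.search(strict, real_message) is not None     # the real message still matches
```

`re.escape("linear_fc1.weight expects")` is an alternative when the expected text is built dynamically.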
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gemma4/test_gemma4_qkv_roundtrip.py` at line 175, The pytest assertion
in test_gemma4_qkv_roundtrip is using a regex match string with an unescaped
metacharacter, so update the pytest.raises(AssertionError, match=...) pattern to
escape the dot in the linear_fc1.weight message. Keep the change localized to
the existing test so the match remains strict and Ruff RUF043 is satisfied.

Source: Linters/SAST tools

tests/gemma4/test_gemma4_cp_attention.py (1)

94-94: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

Prefer a top-level import itertools over inline __import__.

__import__("itertools").accumulate(...) is harder to read and repeated in two tests. A module-level import is clearer.

♻️ Suggested change

Add near the top imports:

import itertools

Then:

-    cu = torch.tensor([0] + list(__import__("itertools").accumulate(lens)), dtype=torch.int32, device=device)
+    cu = torch.tensor([0, *itertools.accumulate(lens)], dtype=torch.int32, device=device)

Also applies to: 124-124
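
For reference, the cumulative-length construction in plain Python, leaving out the `torch.tensor` wrapper so the snippet runs without a GPU. The `lens` values are illustrative only.

```python
import itertools

lens = [3, 5, 2]

# Cumulative sequence boundaries: a leading 0 followed by the running sums.
cu = [0, *itertools.accumulate(lens)]

# Adjacent pairs give the [start, end) range of each packed sub-sequence.
bounds = list(zip(cu[:-1], cu[1:]))
```

Here `cu` is `[0, 3, 8, 10]` and `bounds` is `[(0, 3), (3, 8), (8, 10)]`.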

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gemma4/test_gemma4_cp_attention.py` at line 94, The test code uses
inline __import__("itertools") in Gemma4 attention tests, which is harder to
read and duplicated. Add a module-level import itertools near the other imports
in test_gemma4_cp_attention, then update the affected test cases to call
itertools.accumulate directly instead of using __import__.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/en/examples/gemma4.md`:
- Around line 30-39: The example in gemma4.md is missing the Hugging Face CLI
prerequisite, so the setup can fail before the first `hf download` call. Update
the installation steps around `git clone`, `pip install -e . --no-deps`, and the
subsequent download commands to explicitly install the Hugging Face CLI first,
then keep the `hf download` usage unchanged so readers can run the example in a
clean environment.

In `@docs/zh/examples/gemma4.md`:
- Around line 28-36: The Gemma 4 example setup is missing the prerequisite for
using the hf command before the download steps. Update the setup instructions in
the gemma4 example to explicitly install or otherwise require the Hugging Face
CLI alongside the existing environment setup near the pip and hf download
commands, so readers can run hf download on a clean machine without failing.
Reference the example’s hf download sequence and the surrounding installation
steps in the same section.

In `@scripts/run-gemma4-26B-A4B-gsm8k.sh`:
- Line 142: The Ray startup command in the script currently exposes the
dashboard and Jobs API on all interfaces via the Ray start invocation, which
makes the cluster remotely reachable by default. Update the ray start call in
run-gemma4-26B-A4B-gsm8k.sh to bind the dashboard host to localhost by default,
and add a clear opt-in override path for remote access if needed. Use the ray
start command and its dashboard-host setting as the main symbols to locate and
adjust this behavior.
- Around line 3-12: The teardown in this run script is too broad and uses
SIGKILL, which can terminate unrelated Python/Ray/Redis workloads on the host
and skip cleanup in the rollout engine shutdown path. Narrow the cleanup to only
processes started by this script (or the current run’s PID/process group) and
avoid blanket pkill targets for python, ray, and redis unless an explicit opt-in
flag is set; keep the existing ray stop --force behavior only for the intended
local runtime.

In `@scripts/run-gemma4-31B-gsm8k.sh`:
- Around line 140-141: The Ray dashboard is currently exposed on all interfaces
in the startup script, which unnecessarily opens the control plane; update the
ray start invocation to bind the dashboard to localhost instead of 0.0.0.0. Use
the existing ray start command in the script and adjust the --dashboard-host
setting so it stays reachable for ray job submit via 127.0.0.1:8265 while not
listening on external network interfaces.
- Around line 3-12: The cleanup in the script is too broad because the repeated
pkill -9 python calls can kill unrelated Python jobs and the launcher itself.
Update the shutdown logic in this script to scope termination only to the
rollout stack processes already started here, following the safer pattern used
by slime.utils.external_utils.command_utils.execute_train(), and keep the
existing sglang/ray/redis cleanup targeted to those components only.

In `@slime_plugins/models/gemma4.py`:
- Around line 789-848: The non-packed attention fallbacks in
Gemma4Attention.forward are ignoring sliding-window limits and using full causal
SDPA. Update the THD path without cu_seqlens and the 4D path to apply the same
left-window masking used by the CP/packed paths, keyed off self._is_sliding and
config.sliding_window, so tokens cannot attend outside the configured window.
Use the existing helpers around _forward_cp_subseq_mask, _forward_thd_flash, and
_forward_thd_sdpa_per_subseq as the reference points for where the masking logic
should be preserved or reused.

In `@slime/backends/megatron_utils/megatron_to_hf/__init__.py`:
- Around line 59-60: The conversion dispatch in the megatron_to_hf path only
checks for "gemma4", so official "gemma-4" model ids fall through to the
unsupported-model path. Update the model-name matching in the dispatch block
around convert_gemma4_to_hf so it accepts the hyphenated Gemma 4 spelling as
well, keeping the existing Gemma4 conversion branch working for both variants.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro Plus

Run ID: 79dfa4c0-3af7-4501-a547-a435fee0366d

📥 Commits

Reviewing files that changed from the base of the PR and between 7117b7a and c4c07d6.

📒 Files selected for processing (33)
  • docs/en/examples/gemma4.md
  • docs/en/index.rst
  • docs/zh/examples/gemma4.md
  • docs/zh/index.rst
  • scripts/models/gemma4-12B.sh
  • scripts/models/gemma4-26B-A4B.sh
  • scripts/models/gemma4-31B.sh
  • scripts/run-gemma4-26B-A4B-gsm8k.sh
  • scripts/run-gemma4-31B-gsm8k.sh
  • slime/backends/megatron_utils/megatron_to_hf/__init__.py
  • slime/backends/megatron_utils/megatron_to_hf/gemma4.py
  • slime/utils/arguments.py
  • slime/utils/external_utils/command_utils.py
  • slime/utils/mask_utils.py
  • slime_plugins/mbridge/__init__.py
  • slime_plugins/mbridge/gemma4.py
  • slime_plugins/models/gemma4.py
  • slime_plugins/models/gemma4_provider.py
  • tests/gemma4/_standalone_imports.py
  • tests/gemma4/test_gemma4_attention.py
  • tests/gemma4/test_gemma4_bridge.py
  • tests/gemma4/test_gemma4_cp_attention.py
  • tests/gemma4/test_gemma4_dual_rope.py
  • tests/gemma4/test_gemma4_hf_key_contract.py
  • tests/gemma4/test_gemma4_layer_integration.py
  • tests/gemma4/test_gemma4_layer_scalar_broadcast.py
  • tests/gemma4/test_gemma4_provider.py
  • tests/gemma4/test_gemma4_qkv_roundtrip.py
  • tests/gemma4/test_gemma4_router.py
  • tests/gemma4/test_gemma4_sft_rollout.py
  • tests/test_gemma4_12B_gsm8k_short.py
  • tests/utils/test_loss_mask_type_gemma4.py
  • tools/convert_hf_to_torch_dist.py

Comment on lines +30 to +39
```bash
cd /root
git clone https://github.com/THUDM/slime.git
cd slime
pip install -e . --no-deps

hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it
hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it
hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k
```


🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

Document the hf CLI prerequisite.

These steps install slime with --no-deps and then immediately call hf download, so a clean environment can fail with hf: command not found. Please add the Hugging Face CLI install step before the download commands.

Suggested doc fix
 cd /root
 git clone https://github.com/THUDM/slime.git
 cd slime
 pip install -e . --no-deps
+pip install -U "huggingface_hub[cli]"
 
 hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it
 hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it
 hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k
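
If a setup script wanted to fail early instead of hitting `hf: command not found`, a preflight check along these lines is one option. This is a sketch only; `missing_tools` is a hypothetical helper, not something the PR or slime provides.

```python
import shutil


def missing_tools(names, which=shutil.which):
    """Return the command names from `names` that cannot be found on PATH."""
    return [name for name in names if which(name) is None]


# Example: report whether the Hugging Face CLI is available before downloading.
absent = missing_tools(["hf"])
```

A non-empty `absent` list would be the cue to run the `pip install -U "huggingface_hub[cli]"` step first.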
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
```bash
cd /root
git clone https://github.com/THUDM/slime.git
cd slime
pip install -e . --no-deps
pip install -U "huggingface_hub[cli]"

hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it
hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it
hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/en/examples/gemma4.md` around lines 30 - 39, The example in gemma4.md is
missing the Hugging Face CLI prerequisite, so the setup can fail before the
first `hf download` call. Update the installation steps around `git clone`, `pip
install -e . --no-deps`, and the subsequent download commands to explicitly
install the Hugging Face CLI first, then keep the `hf download` usage unchanged
so readers can run the example in a clean environment.

Comment on lines +28 to +36
```bash
cd /root
git clone https://github.com/THUDM/slime.git
cd slime
pip install -e . --no-deps

hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it
hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it
hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k
```


🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

Document the hf CLI prerequisite here as well.

This guide also calls hf download without first installing the CLI, so the setup can fail on a clean machine before any checkpoint or dataset download starts.

Suggested doc fix
 cd /root
 git clone https://github.com/THUDM/slime.git
 cd slime
 pip install -e . --no-deps
+pip install -U "huggingface_hub[cli]"
 
 hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it
 hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it
 hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
```bash
cd /root
git clone https://github.com/THUDM/slime.git
cd slime
pip install -e . --no-deps
pip install -U "huggingface_hub[cli]"

hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it
hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it
hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/zh/examples/gemma4.md` around lines 28 - 36, The Gemma 4 example setup
is missing the prerequisite for using the hf command before the download steps.
Update the setup instructions in the gemma4 example to explicitly install or
otherwise require the Hugging Face CLI alongside the existing environment setup
near the pip and hf download commands, so readers can run hf download on a clean
machine without failing. Reference the example’s hf download sequence and the
surrounding installation steps in the same section.

Comment on lines +3 to +12
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python
pkill -9 redis


🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Scope teardown to this run instead of killing every Python/Ray process.

pkill -9 python, pkill -9 ray, and pkill -9 redis will terminate unrelated workloads on the same host, and SIGKILL bypasses the runtime cleanup path in slime/backends/megatron_utils/actor.py:189-205 that disconnects rollout engines and destroys process groups. Keep the cleanup scoped to this script's own processes, or make the destructive sweep an explicit opt-in.

Suggested direction
-pkill -9 sglang
-sleep 3
-ray stop --force
-pkill -9 ray
-pkill -9 python
-sleep 3
-pkill -9 ray
-pkill -9 python
-pkill -9 redis
+ray stop --force || true
+# If extra cleanup is still needed, scope it to PIDs started by this script
+# or gate the broad cleanup behind an explicit env flag.
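
One way to scope cleanup to "PIDs started by this script" is to launch children in their own process group and signal only that group. The sketch below assumes a POSIX host and a launcher that spawns its workers directly; it says nothing about daemons that detach themselves, and `stop_group` is a hypothetical helper rather than anything in slime.

```python
import os
import signal
import subprocess
import sys


def stop_group(proc, grace=5.0):
    """Send SIGTERM to proc's process group, escalating to SIGKILL after `grace` seconds."""
    pgid = os.getpgid(proc.pid)
    os.killpg(pgid, signal.SIGTERM)  # lets the runtime run its shutdown path first
    try:
        return proc.wait(timeout=grace)
    except subprocess.TimeoutExpired:
        os.killpg(pgid, signal.SIGKILL)  # last resort, still limited to this group
        return proc.wait()


# start_new_session=True puts the child in its own session and process group,
# so stopping the group cannot reach unrelated Python, Ray, or Redis processes.
child = subprocess.Popen(
    [sys.executable, "-c", "import time; time.sleep(60)"],
    start_new_session=True,
)
exit_code = stop_group(child)
```

Trying SIGTERM before SIGKILL gives the process a chance to run its own cleanup, which is the concern raised above about bypassing the actor shutdown path.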
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
ray stop --force || true
# If extra cleanup is still needed, scope it to PIDs started by this script
# or gate the broad cleanup behind an explicit env flag.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@scripts/run-gemma4-26B-A4B-gsm8k.sh` around lines 3 - 12, The teardown in
this run script is too broad and uses SIGKILL, which can terminate unrelated
Python/Ray/Redis workloads on the host and skip cleanup in the rollout engine
shutdown path. Narrow the cleanup to only processes started by this script (or
the current run’s PID/process group) and avoid blanket pkill targets for python,
ray, and redis unless an explicit opt-in flag is set; keep the existing ray stop
--force behavior only for the intended local runtime.

)

export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265


🔒 Security & Privacy | 🟠 Major | ⚡ Quick win

Don't expose the Ray dashboard on all interfaces by default.

--dashboard-host=0.0.0.0 makes the dashboard and Jobs API remotely reachable. Since this script immediately uses that API on port 8265, anyone who can reach the host can inspect the cluster and potentially submit jobs. Default this to localhost and require an explicit override for remote access.

Suggested fix
+DASHBOARD_HOST=${DASHBOARD_HOST:-127.0.0.1}
-ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265
+ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host="${DASHBOARD_HOST}" --dashboard-port=8265
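
The same default-to-localhost pattern, sketched for a Python launcher that assembles the command line. The flags are the ones already used in the script; the `ray_head_args` function and its `env` parameter are hypothetical.

```python
import os


def ray_head_args(master_addr, num_gpus, env=os.environ):
    """Build `ray start` arguments with the dashboard bound to localhost unless overridden."""
    # Remote access becomes an explicit opt-in through DASHBOARD_HOST.
    host = env.get("DASHBOARD_HOST", "127.0.0.1")
    return [
        "ray", "start", "--head",
        "--node-ip-address", master_addr,
        "--num-gpus", str(num_gpus),
        "--disable-usage-stats",
        f"--dashboard-host={host}",
        "--dashboard-port=8265",
    ]


default_args = ray_head_args("127.0.0.1", 8, env={})
remote_args = ray_head_args("127.0.0.1", 8, env={"DASHBOARD_HOST": "0.0.0.0"})
```

With the default binding, a job submission aimed at `127.0.0.1:8265` on the same node would still reach the dashboard.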
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
DASHBOARD_HOST=${DASHBOARD_HOST:-127.0.0.1}
ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host="${DASHBOARD_HOST}" --dashboard-port=8265
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@scripts/run-gemma4-26B-A4B-gsm8k.sh` at line 142, The Ray startup command in
the script currently exposes the dashboard and Jobs API on all interfaces via
the Ray start invocation, which makes the cluster remotely reachable by default.
Update the ray start call in run-gemma4-26B-A4B-gsm8k.sh to bind the dashboard
host to localhost by default, and add a clear opt-in override path for remote
access if needed. Use the ray start command and its dashboard-host setting as
the main symbols to locate and adjust this behavior.

Comment on lines +3 to +12
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python
pkill -9 redis


🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Don't SIGKILL every Python process on the node.

The two pkill -9 python calls will terminate unrelated jobs on the same machine and can kill the harness that launched this script. slime.utils.external_utils.command_utils.execute_train() avoids doing this for the same reason and scopes cleanup to the rollout stack instead.

Safer cleanup shape
 pkill -9 sglang
 sleep 3
 ray stop --force
 pkill -9 ray
-pkill -9 python
+pkill -9 slime
 sleep 3
 pkill -9 ray
-pkill -9 python
+pkill -9 slime
 pkill -9 redis
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@scripts/run-gemma4-31B-gsm8k.sh` around lines 3 - 12, the cleanup in the
script is too broad because the repeated pkill -9 python calls can kill
unrelated Python jobs and the launcher itself. Update the shutdown logic in this
script to scope termination only to the rollout stack processes already started
here, following the safer pattern used by
slime.utils.external_utils.command_utils.execute_train(), and keep the existing
sglang/ray/redis cleanup targeted to those components only.
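
One way to reason about scoped cleanup is to select targets by command line and explicitly protect the launcher. The sketch below is an assumption-laden illustration, not the logic in `execute_train()`: `select_cleanup_targets` and the marker strings are hypothetical, and the process list is passed in rather than read from the system. Note that `pkill` without `-f` matches only the process name, so matching against the full command line (as `pkill -f` does) is what this models.

```python
import os

# Hypothetical markers for the rollout stack; adjust to the real launch commands.
ROLLOUT_MARKERS = ("sglang", "raylet", "redis-server")


def select_cleanup_targets(processes, markers=ROLLOUT_MARKERS, protected_pids=None):
    """Return PIDs whose command line mentions a rollout-stack marker.

    `processes` is an iterable of (pid, cmdline) pairs. By default the current
    process and its parent are protected so the harness is never selected.
    """
    if protected_pids is None:
        protected = {os.getpid(), os.getppid()}
    else:
        protected = set(protected_pids)

    targets = []
    for pid, cmdline in processes:
        if pid in protected:
            continue  # never kill the launcher or anything explicitly protected
        # Plain substring matching is crude; a real script may want stricter patterns.
        if any(marker in cmdline for marker in markers):
            targets.append(pid)
    return targets
```

Unrelated Python jobs are left alone because nothing here matches on the interpreter name itself.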

Comment on lines +140 to +141
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265

🔒 Security & Privacy | 🟠 Major | ⚡ Quick win

Bind the Ray dashboard to localhost.

ray job submit connects to http://127.0.0.1:8265, so exposing the dashboard on 0.0.0.0 needlessly opens an unauthenticated control plane to the whole network.

Suggested fix
 export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
-ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265
+ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=127.0.0.1 --dashboard-port=8265
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@scripts/run-gemma4-31B-gsm8k.sh` around lines 140 - 141, the Ray dashboard is
currently exposed on all interfaces in the startup script, which unnecessarily
opens the control plane; update the ray start invocation to bind the dashboard
to localhost instead of 0.0.0.0. Use the existing ray start command in the
script and adjust the --dashboard-host setting so it stays reachable for ray job
submit via 127.0.0.1:8265 while not listening on external network interfaces.

Comment on lines +789 to +848
    def forward(self, query, key, value, attention_mask=None, attn_mask_type=None, packed_seq_params=None, **kwargs):
        cp_size = getattr(self.config, "context_parallel_size", 1) or 1
        is_thd = query.dim() == 3

        force_cp_path = getattr(self.config, "force_cp_subseq_mask", False)

        if is_thd:
            if cp_size > 1 or force_cp_path:
                sw = None
                if self._is_sliding:
                    sw_cfg = getattr(self.config, "sliding_window", None)
                    if sw_cfg and sw_cfg > 0:
                        sw = int(sw_cfg)
                return self._forward_cp_subseq_mask(
                    query,
                    key,
                    value,
                    packed_seq_params,
                    sliding_window=sw,
                )

            cu_seqlens = None
            if packed_seq_params is not None:
                cu_seqlens = packed_seq_params.cu_seqlens_q

            hn = query.shape[2]
            if cu_seqlens is not None:
                if hn <= 256:
                    return self._forward_thd_flash(query, key, value, cu_seqlens)
                return self._forward_thd_sdpa_per_subseq(query, key, value, cu_seqlens)

            q = query.unsqueeze(0).transpose(1, 2)
            k = key.unsqueeze(0).transpose(1, 2)
            v = value.unsqueeze(0).transpose(1, 2)
            nq, nk = q.shape[1], k.shape[1]
            out = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.dropout_p if self.training else 0.0,
                scale=self._resolve_scale(hn),
                is_causal=True,
                enable_gqa=(nq != nk),
            )
            return out.transpose(1, 2).reshape(query.shape[0], -1)

        q = query.permute(1, 2, 0, 3)
        k = key.permute(1, 2, 0, 3)
        v = value.permute(1, 2, 0, 3)
        nq, nk = q.shape[1], k.shape[1]
        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.dropout_p if self.training else 0.0,
            scale=self._resolve_scale(query.shape[3]),
            is_causal=True,
            enable_gqa=(nq != nk),
        )
        return out.permute(2, 0, 1, 3).reshape(out.size(2), out.size(0), -1)

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Preserve sliding-window masking in non-packed attention paths.

Lines 820-848 fall back to full causal SDPA for THD-without-cu_seqlens and 4D attention. For sliding layers, that lets tokens attend beyond config.sliding_window, while the CP and packed-varlen paths correctly enforce the left window.

🐛 Proposed direction
+    def _maybe_sliding_mask(self, q_len, k_len, device, dtype):
+        sw = getattr(self.config, "sliding_window", None)
+        if not self._is_sliding or not sw or sw <= 0:
+            return None
+        row_idx = torch.arange(q_len, device=device)
+        col_idx = torch.arange(k_len, device=device)
+        forbid = (col_idx[None, :] > row_idx[:, None]) | (
+            col_idx[None, :] < row_idx[:, None] - (int(sw) - 1)
+        )
+        return torch.where(forbid, torch.finfo(dtype).min, 0.0).to(dtype)
+
     def forward(self, query, key, value, attention_mask=None, attn_mask_type=None, packed_seq_params=None, **kwargs):
...
+            sliding_mask = self._maybe_sliding_mask(q.size(-2), k.size(-2), q.device, q.dtype)
             out = F.scaled_dot_product_attention(
                 q,
                 k,
                 v,
+                attn_mask=sliding_mask,
                 dropout_p=self.dropout_p if self.training else 0.0,
                 scale=self._resolve_scale(hn),
-                is_causal=True,
+                is_causal=sliding_mask is None,
                 enable_gqa=(nq != nk),
             )
...
+        sliding_mask = self._maybe_sliding_mask(q.size(-2), k.size(-2), q.device, q.dtype)
         out = F.scaled_dot_product_attention(
             q,
             k,
             v,
+            attn_mask=sliding_mask,
             dropout_p=self.dropout_p if self.training else 0.0,
             scale=self._resolve_scale(query.shape[3]),
-            is_causal=True,
+            is_causal=sliding_mask is None,
             enable_gqa=(nq != nk),
         )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@slime_plugins/models/gemma4.py` around lines 789 - 848, the non-packed
attention fallbacks in Gemma4Attention.forward are ignoring sliding-window
limits and using full causal SDPA. Update the THD path without cu_seqlens and
the 4D path to apply the same left-window masking used by the CP/packed paths,
keyed off self._is_sliding and config.sliding_window, so tokens cannot attend
outside the configured window. Use the existing helpers around
_forward_cp_subseq_mask, _forward_thd_flash, and _forward_thd_sdpa_per_subseq as
the reference points for where the masking logic should be preserved or reused.
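
To make the difference between full causal attention and sliding-window attention concrete, the mask pattern from the proposed direction can be written in plain Python. This is a dependency-free sketch rather than the torch code from the review: `sliding_causal_allowed` is a hypothetical name, and it assumes query row `i` and key column `i` refer to the same absolute position, as the proposed `_maybe_sliding_mask` does. A key is visible when it is not in the future and is at most `window - 1` positions behind the query.

```python
def sliding_causal_allowed(q_len, k_len, window=None):
    """Return a q_len x k_len grid of booleans.

    True means the query at `row` may attend to the key at `col`.
    With window=None the grid is the ordinary causal (lower-triangular) pattern.
    """
    grid = []
    for row in range(q_len):
        allowed_row = []
        for col in range(k_len):
            causal = col <= row  # no attention to future positions
            # Left window: keep only the most recent `window` keys, including self.
            in_window = window is None or col >= row - (window - 1)
            allowed_row.append(causal and in_window)
        grid.append(allowed_row)
    return grid


if __name__ == "__main__":
    for line in sliding_causal_allowed(4, 4, window=2):
        print("".join("x" if ok else "." for ok in line))
```

With `window=2`, the last query of a 4-token sequence sees only positions 2 and 3, whereas full causal attention would let it see all four. That gap is what the non-packed fallbacks currently allow on sliding layers.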

Comment on lines +59 to +60
    elif "gemma4" in model_name:
        converted_named_tensors = convert_gemma4_to_hf(args, name, param)

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Accept the official gemma-4 spelling in conversion dispatch.

The PR targets google/gemma-4-* checkpoints, but this branch only matches "gemma4". If model_name is the HF id, Gemma4 raw conversion falls through to Unsupported model.

🐛 Proposed fix
-    elif "gemma4" in model_name:
+    elif "gemma4" in model_name.lower() or "gemma-4" in model_name.lower():
         converted_named_tensors = convert_gemma4_to_hf(args, name, param)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@slime/backends/megatron_utils/megatron_to_hf/__init__.py` around lines 59 -
60, the conversion dispatch in the megatron_to_hf path only checks for "gemma4",
so official "gemma-4" model ids fall through to the unsupported-model path.
Update the model-name matching in the dispatch block around convert_gemma4_to_hf
so it accepts the hyphenated Gemma 4 spelling as well, keeping the existing
Gemma4 conversion branch working for both variants.
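
The name matching described above could also be isolated in a small predicate so that both spellings are covered in one place. This is a sketch under assumptions: `is_gemma4_model` is a hypothetical helper, not an existing function in the dispatch module, and the pattern is deliberately loose (it would also match a name such as `gemma-40b`, if one existed).

```python
import re

# Matches "gemma4" and "gemma-4", case-insensitively, anywhere in the name.
_GEMMA4_PATTERN = re.compile(r"gemma-?4", re.IGNORECASE)


def is_gemma4_model(model_name):
    """Return True when the name looks like a Gemma 4 model id or class name."""
    return bool(_GEMMA4_PATTERN.search(model_name))


if __name__ == "__main__":
    for name in ("google/gemma-4-31B-it", "Gemma4ForCausalLM", "gemma-3-27b"):
        print(name, is_gemma4_model(name))
```

A predicate like this keeps the dispatch branch readable and makes the accepted spellings easy to unit test.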
