Add Gemma4 dense and MoE model support #4
📝 Walkthrough
Adds Gemma4 model, bridge, provider, loss-mask, launch, docs, and test support for GSM8K training, checkpoint conversion, attention, routing, rotary embeddings, and rollout masking.
Changes
Gemma4 Support
Sequence Diagram(s)
sequenceDiagram
participant model_provider
participant GPTModel
participant _install_hooks
participant _load_layer_scalars
participant DualRotaryEmbedding
model_provider->>GPTModel: build model from args
model_provider->>_install_hooks: attach Gemma4 hooks
_install_hooks->>DualRotaryEmbedding: replace rotary_pos_emb
_install_hooks->>_load_layer_scalars: load layer_scalar buffers
_load_layer_scalars->>GPTModel: fill layer_scalar buffers
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks | ✅ 4 passed | ❌ 1 failed (1 warning)
@coderabbitai full review
✅ Action performed: Full review finished.
@coderabbitai full review
Actionable comments posted: 8
🧹 Nitpick comments (4)
tests/gemma4/test_gemma4_sft_rollout.py (1)
5-5: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value
Avoid baking an internal absolute path into the default.
The fallback default `/fsx-shopper-intel/dev/jianhfan/gemma-4-31b-it` hardcodes internal infrastructure and a username into the repo, and is meaningless to anyone outside that environment. Prefer a neutral default (or no default, requiring `GEMMA4_CKPT` to be set) so the skip reason stays generic.
♻️ Suggested change
-GEMMA4_CKPT = os.environ.get("GEMMA4_CKPT", "/fsx-shopper-intel/dev/jianhfan/gemma-4-31b-it")
+GEMMA4_CKPT = os.environ.get("GEMMA4_CKPT", "")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/gemma4/test_gemma4_sft_rollout.py` at line 5, Update the GEMMA4_CKPT environment lookup in test_gemma4_sft_rollout so it does not fall back to the internal absolute path; replace the hardcoded default with a neutral value or require GEMMA4_CKPT to be set explicitly. Keep the skip logic generic by ensuring the test setup only references GEMMA4_CKPT and does not embed environment-specific usernames or filesystem paths.
slime/utils/mask_utils.py (1)
198-270: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚖️ Poor tradeoff
Heavy duplication with `gen_multi_turn_loss_mask_qwen3_5`.
This method shares almost all of its structure (render → tokenize-with-offsets → validate → char-mask → prefix-sum → token mask) with `gen_multi_turn_loss_mask_qwen3_5` (Lines 127-196). The only real differences are the marker strings and the thought-block handling: here the entire `<|channel>thought ... <channel|>` block is excluded, whereas qwen3.5 keeps the think content. Consider extracting a shared helper parameterized by markers and a thought-span resolver to avoid drift between the two paths.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@slime/utils/mask_utils.py` around lines 198 - 270, `gen_multi_turn_loss_mask_gemma4` duplicates nearly all of `gen_multi_turn_loss_mask_qwen3_5`, so refactor the shared render/tokenize/validate/char-mask/prefix-sum/token-mask flow into a common helper to prevent the two paths from drifting. Keep the model-specific pieces configurable via parameters or a small resolver for marker strings and thought-block masking behavior, then have both methods call that helper and only supply their unique markers and span rules.
tests/gemma4/test_gemma4_qkv_roundtrip.py (1)
175-175: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Escape the regex metacharacter in the `match=` pattern.
`match=` is treated as a regex, so the `.` in `"linear_fc1.weight expects"` is an unescaped metacharacter (Ruff RUF043). It still matches today, but tighten it to avoid false positives and keep lint clean.
♻️ Proposed fix
- with pytest.raises(AssertionError, match="linear_fc1.weight expects"):
+ with pytest.raises(AssertionError, match=r"linear_fc1\.weight expects"):
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/gemma4/test_gemma4_qkv_roundtrip.py` at line 175, The pytest assertion in test_gemma4_qkv_roundtrip is using a regex match string with an unescaped metacharacter, so update the pytest.raises(AssertionError, match=...) pattern to escape the dot in the linear_fc1.weight message. Keep the change localized to the existing test so the match remains strict and Ruff RUF043 is satisfied.
Source: Linters/SAST tools
tests/gemma4/test_gemma4_cp_attention.py (1)
94-94: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value
Prefer a top-level `import itertools` over inline `__import__`.
`__import__("itertools").accumulate(...)` is harder to read and repeated in two tests. A module-level import is clearer.
♻️ Suggested change
Add near the top imports:
import itertools
Then:
- cu = torch.tensor([0] + list(__import__("itertools").accumulate(lens)), dtype=torch.int32, device=device)
+ cu = torch.tensor([0, *itertools.accumulate(lens)], dtype=torch.int32, device=device)
Also applies to: 124-124
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/gemma4/test_gemma4_cp_attention.py` at line 94, The test code uses inline __import__("itertools") in Gemma4 attention tests, which is harder to read and duplicated. Add a module-level import itertools near the other imports in test_gemma4_cp_attention, then update the affected test cases to call itertools.accumulate directly instead of using __import__.
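For readers unfamiliar with the `cu` tensor in that test, the sketch below shows what the `itertools.accumulate` expression computes, using plain Python lists instead of `torch.tensor`. The `lens` values are hypothetical.

```python
import itertools

# Hypothetical per-sequence token counts for a packed batch.
lens = [3, 5, 2]

# Cumulative sequence lengths: a leading 0 followed by running totals.
cu_seqlens = [0, *itertools.accumulate(lens)]

# Consecutive pairs give the [start, end) token range of each packed sequence.
bounds = list(zip(cu_seqlens[:-1], cu_seqlens[1:]))
```

Each pair in `bounds` marks where one sequence starts and ends inside the packed token dimension.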
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/en/examples/gemma4.md`:
- Around line 30-39: The example in gemma4.md is missing the Hugging Face CLI
prerequisite, so the setup can fail before the first `hf download` call. Update
the installation steps around `git clone`, `pip install -e . --no-deps`, and the
subsequent download commands to explicitly install the Hugging Face CLI first,
then keep the `hf download` usage unchanged so readers can run the example in a
clean environment.
In `@docs/zh/examples/gemma4.md`:
- Around line 28-36: The Gemma 4 example setup is missing the prerequisite for
using the hf command before the download steps. Update the setup instructions in
the gemma4 example to explicitly install or otherwise require the Hugging Face
CLI alongside the existing environment setup near the pip and hf download
commands, so readers can run hf download on a clean machine without failing.
Reference the example’s hf download sequence and the surrounding installation
steps in the same section.
In `@scripts/run-gemma4-26B-A4B-gsm8k.sh`:
- Line 142: The Ray startup command in the script currently exposes the
dashboard and Jobs API on all interfaces via the Ray start invocation, which
makes the cluster remotely reachable by default. Update the ray start call in
run-gemma4-26B-A4B-gsm8k.sh to bind the dashboard host to localhost by default,
and add a clear opt-in override path for remote access if needed. Use the ray
start command and its dashboard-host setting as the main symbols to locate and
adjust this behavior.
- Around line 3-12: The teardown in this run script is too broad and uses
SIGKILL, which can terminate unrelated Python/Ray/Redis workloads on the host
and skip cleanup in the rollout engine shutdown path. Narrow the cleanup to only
processes started by this script (or the current run’s PID/process group) and
avoid blanket pkill targets for python, ray, and redis unless an explicit opt-in
flag is set; keep the existing ray stop --force behavior only for the intended
local runtime.
In `@scripts/run-gemma4-31B-gsm8k.sh`:
- Around line 140-141: The Ray dashboard is currently exposed on all interfaces
in the startup script, which unnecessarily opens the control plane; update the
ray start invocation to bind the dashboard to localhost instead of 0.0.0.0. Use
the existing ray start command in the script and adjust the --dashboard-host
setting so it stays reachable for ray job submit via 127.0.0.1:8265 while not
listening on external network interfaces.
- Around line 3-12: The cleanup in the script is too broad because the repeated
pkill -9 python calls can kill unrelated Python jobs and the launcher itself.
Update the shutdown logic in this script to scope termination only to the
rollout stack processes already started here, following the safer pattern used
by slime.utils.external_utils.command_utils.execute_train(), and keep the
existing sglang/ray/redis cleanup targeted to those components only.
In `@slime_plugins/models/gemma4.py`:
- Around line 789-848: The non-packed attention fallbacks in
Gemma4Attention.forward are ignoring sliding-window limits and using full causal
SDPA. Update the THD path without cu_seqlens and the 4D path to apply the same
left-window masking used by the CP/packed paths, keyed off self._is_sliding and
config.sliding_window, so tokens cannot attend outside the configured window.
Use the existing helpers around _forward_cp_subseq_mask, _forward_thd_flash, and
_forward_thd_sdpa_per_subseq as the reference points for where the masking logic
should be preserved or reused.
In `@slime/backends/megatron_utils/megatron_to_hf/__init__.py`:
- Around line 59-60: The conversion dispatch in the megatron_to_hf path only
checks for "gemma4", so official "gemma-4" model ids fall through to the
unsupported-model path. Update the model-name matching in the dispatch block
around convert_gemma4_to_hf so it accepts the hyphenated Gemma 4 spelling as
well, keeping the existing Gemma4 conversion branch working for both variants.
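One possible shape for the dispatch finding above is to normalize the model name before matching. This is only a sketch: `is_gemma4_model` is a hypothetical helper, and the actual matching code in `megatron_to_hf/__init__.py` is not shown in this review.

```python
def is_gemma4_model(model_name: str) -> bool:
    """Return True for both the 'gemma4' and 'gemma-4' spellings (hypothetical helper)."""
    # Drop separators so 'gemma-4', 'gemma_4', and 'gemma4' compare equal.
    normalized = model_name.lower().replace("-", "").replace("_", "")
    return "gemma4" in normalized


checks = {
    name: is_gemma4_model(name)
    for name in ("gemma4", "google/gemma-4-31B-it", "google/gemma-3-27b-it")
}
```

Stripping separators is deliberately loose, so a stricter pattern may be preferable if other model families could collide with the normalized string.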
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro Plus
Run ID: 79dfa4c0-3af7-4501-a547-a435fee0366d
📒 Files selected for processing (33)
docs/en/examples/gemma4.md
docs/en/index.rst
docs/zh/examples/gemma4.md
docs/zh/index.rst
scripts/models/gemma4-12B.sh
scripts/models/gemma4-26B-A4B.sh
scripts/models/gemma4-31B.sh
scripts/run-gemma4-26B-A4B-gsm8k.sh
scripts/run-gemma4-31B-gsm8k.sh
slime/backends/megatron_utils/megatron_to_hf/__init__.py
slime/backends/megatron_utils/megatron_to_hf/gemma4.py
slime/utils/arguments.py
slime/utils/external_utils/command_utils.py
slime/utils/mask_utils.py
slime_plugins/mbridge/__init__.py
slime_plugins/mbridge/gemma4.py
slime_plugins/models/gemma4.py
slime_plugins/models/gemma4_provider.py
tests/gemma4/_standalone_imports.py
tests/gemma4/test_gemma4_attention.py
tests/gemma4/test_gemma4_bridge.py
tests/gemma4/test_gemma4_cp_attention.py
tests/gemma4/test_gemma4_dual_rope.py
tests/gemma4/test_gemma4_hf_key_contract.py
tests/gemma4/test_gemma4_layer_integration.py
tests/gemma4/test_gemma4_layer_scalar_broadcast.py
tests/gemma4/test_gemma4_provider.py
tests/gemma4/test_gemma4_qkv_roundtrip.py
tests/gemma4/test_gemma4_router.py
tests/gemma4/test_gemma4_sft_rollout.py
tests/test_gemma4_12B_gsm8k_short.py
tests/utils/test_loss_mask_type_gemma4.py
tools/convert_hf_to_torch_dist.py
In docs/en/examples/gemma4.md (lines 30-39):
```bash
cd /root
git clone https://github.com/THUDM/slime.git
cd slime
pip install -e . --no-deps

hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it
hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it
hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k
```
🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win
Document the hf CLI prerequisite.
These steps install slime with --no-deps and then immediately call hf download, so a clean environment can fail with hf: command not found. Please add the Hugging Face CLI install step before the download commands.
Suggested doc fix
cd /root
git clone https://github.com/THUDM/slime.git
cd slime
pip install -e . --no-deps
+pip install -U "huggingface_hub[cli]"
hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it
hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it
 hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k
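A setup script could also fail fast when the CLI is missing, instead of stopping at the first `hf download`. The sketch below is a hypothetical preflight check, not part of slime; it relies only on `shutil.which` from the standard library.

```python
import shutil


def missing_tools(required=("git", "pip", "hf")):
    """Return the names in `required` that are not found on PATH."""
    return [tool for tool in required if shutil.which(tool) is None]


def preflight_message(required=("git", "pip", "hf")):
    """Build a human-readable hint, or an empty string when everything is present."""
    missing = missing_tools(required)
    if not missing:
        return ""
    return "Missing commands: " + ", ".join(missing)
```

If `hf` shows up in the result, installing the Hugging Face CLI as suggested in this comment would be the next step.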
In docs/zh/examples/gemma4.md (lines 28-36):
```bash
cd /root
git clone https://github.com/THUDM/slime.git
cd slime
pip install -e . --no-deps

hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it
hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it
hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k
```
🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win
Document the hf CLI prerequisite here as well.
This guide also calls hf download without first installing the CLI, so the setup can fail on a clean machine before any checkpoint or dataset download starts.
Suggested doc fix
cd /root
git clone https://github.com/THUDM/slime.git
cd slime
pip install -e . --no-deps
+pip install -U "huggingface_hub[cli]"
hf download google/gemma-4-31B-it --local-dir /root/gemma-4-31B-it
hf download google/gemma-4-26B-A4B-it --local-dir /root/gemma-4-26B-A4B-it
 hf download --repo-type dataset zhuzilin/gsm8k --local-dir /root/datasets/gsm8k
In scripts/run-gemma4-26B-A4B-gsm8k.sh (lines 3-12):
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python
pkill -9 redis
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Scope teardown to this run instead of killing every Python/Ray process.
pkill -9 python, pkill -9 ray, and pkill -9 redis will terminate unrelated workloads on the same host, and SIGKILL bypasses the runtime cleanup path in slime/backends/megatron_utils/actor.py:189-205 that disconnects rollout engines and destroys process groups. Keep the cleanup scoped to this script's own processes, or make the destructive sweep an explicit opt-in.
Suggested direction
-pkill -9 sglang
-sleep 3
-ray stop --force
-pkill -9 ray
-pkill -9 python
-sleep 3
-pkill -9 ray
-pkill -9 python
-pkill -9 redis
+ray stop --force || true
+# If extra cleanup is still needed, scope it to PIDs started by this script
+# or gate the broad cleanup behind an explicit env flag.
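To make "scope it to PIDs started by this script" concrete, here is a minimal POSIX-only sketch in Python. It is not how the run script or slime currently works; `start_in_own_group`, `stop_group`, and `demo` are hypothetical names. The idea is to launch each child in its own process group, send SIGTERM to that group only, and escalate to SIGKILL after a grace period.

```python
import os
import signal
import subprocess
import sys


def start_in_own_group(cmd):
    """Launch cmd in a new session, so the child leads its own process group."""
    return subprocess.Popen(cmd, start_new_session=True)


def stop_group(proc, grace_seconds=5.0):
    """SIGTERM the child's process group, escalating to SIGKILL on timeout."""
    for sig in (signal.SIGTERM, signal.SIGKILL):
        try:
            # For a session leader the process group id equals its pid.
            os.killpg(proc.pid, sig)
        except ProcessLookupError:
            break  # the group is already gone
        try:
            return proc.wait(timeout=grace_seconds)
        except subprocess.TimeoutExpired:
            continue  # still alive after the grace period: escalate
    return proc.wait()


def demo():
    """Start a throwaway sleeper, stop it, and return its exit status."""
    proc = start_in_own_group([sys.executable, "-c", "import time; time.sleep(60)"])
    return stop_group(proc)
```

Because only the child's own group is signalled, unrelated Python, Ray, or Redis processes on the host are left alone, and the SIGTERM-first order gives cleanup handlers a chance to run.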
In scripts/run-gemma4-26B-A4B-gsm8k.sh (line 142):
)

export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265
🔒 Security & Privacy | 🟠 Major | ⚡ Quick win
Don't expose the Ray dashboard on all interfaces by default.
--dashboard-host=0.0.0.0 makes the dashboard and Jobs API remotely reachable. Since this script immediately uses that API on port 8265, anyone who can reach the host can inspect the cluster and potentially submit jobs. Default this to localhost and require an explicit override for remote access.
Suggested fix
+DASHBOARD_HOST=${DASHBOARD_HOST:-127.0.0.1}
-ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265
+ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host="${DASHBOARD_HOST}" --dashboard-port=8265
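The same default-to-localhost, opt-in-override pattern can be sketched in Python for launchers that build the command programmatically. `build_ray_start_cmd` is a hypothetical helper; the flags are the ones already used in the script, and `DASHBOARD_HOST` is the override variable proposed in this comment.

```python
import os


def build_ray_start_cmd(master_addr, num_gpus, environ=None):
    """Build the `ray start` argument list with a localhost-only dashboard by default."""
    env = os.environ if environ is None else environ
    # Mirrors ${DASHBOARD_HOST:-127.0.0.1}: unset or empty falls back to localhost.
    dashboard_host = env.get("DASHBOARD_HOST") or "127.0.0.1"
    return [
        "ray", "start", "--head",
        "--node-ip-address", master_addr,
        "--num-gpus", str(num_gpus),
        "--disable-usage-stats",
        f"--dashboard-host={dashboard_host}",
        "--dashboard-port=8265",
    ]


default_cmd = build_ray_start_cmd("127.0.0.1", 8, environ={})
remote_cmd = build_ray_start_cmd("127.0.0.1", 8, environ={"DASHBOARD_HOST": "0.0.0.0"})
```

Remote exposure then requires an explicit `DASHBOARD_HOST=0.0.0.0` in the environment rather than being the default.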
In scripts/run-gemma4-31B-gsm8k.sh (lines 3-12):
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python
pkill -9 redis
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Don't SIGKILL every Python process on the node.
The two pkill -9 python calls will terminate unrelated jobs on the same machine and can kill the harness that launched this script. slime.utils.external_utils.command_utils.execute_train() avoids doing this for the same reason and scopes cleanup to the rollout stack instead.
Safer cleanup shape
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
-pkill -9 python
+pkill -9 slime
sleep 3
pkill -9 ray
-pkill -9 python
+pkill -9 slime
 pkill -9 redis
In scripts/run-gemma4-31B-gsm8k.sh (lines 140-141):
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265
🔒 Security & Privacy | 🟠 Major | ⚡ Quick win
Bind the Ray dashboard to localhost.
ray job submit connects to http://127.0.0.1:8265, so exposing the dashboard on 0.0.0.0 needlessly opens an unauthenticated control plane to the whole network.
Suggested fix
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
-ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265
+ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=127.0.0.1 --dashboard-port=8265
In slime_plugins/models/gemma4.py (lines 789-848):
    def forward(self, query, key, value, attention_mask=None, attn_mask_type=None, packed_seq_params=None, **kwargs):
        cp_size = getattr(self.config, "context_parallel_size", 1) or 1
        is_thd = query.dim() == 3

        force_cp_path = getattr(self.config, "force_cp_subseq_mask", False)

        if is_thd:
            if cp_size > 1 or force_cp_path:
                sw = None
                if self._is_sliding:
                    sw_cfg = getattr(self.config, "sliding_window", None)
                    if sw_cfg and sw_cfg > 0:
                        sw = int(sw_cfg)
                return self._forward_cp_subseq_mask(
                    query,
                    key,
                    value,
                    packed_seq_params,
                    sliding_window=sw,
                )

            cu_seqlens = None
            if packed_seq_params is not None:
                cu_seqlens = packed_seq_params.cu_seqlens_q

            hn = query.shape[2]
            if cu_seqlens is not None:
                if hn <= 256:
                    return self._forward_thd_flash(query, key, value, cu_seqlens)
                return self._forward_thd_sdpa_per_subseq(query, key, value, cu_seqlens)

            q = query.unsqueeze(0).transpose(1, 2)
            k = key.unsqueeze(0).transpose(1, 2)
            v = value.unsqueeze(0).transpose(1, 2)
            nq, nk = q.shape[1], k.shape[1]
            out = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.dropout_p if self.training else 0.0,
                scale=self._resolve_scale(hn),
                is_causal=True,
                enable_gqa=(nq != nk),
            )
            return out.transpose(1, 2).reshape(query.shape[0], -1)

        q = query.permute(1, 2, 0, 3)
        k = key.permute(1, 2, 0, 3)
        v = value.permute(1, 2, 0, 3)
        nq, nk = q.shape[1], k.shape[1]
        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.dropout_p if self.training else 0.0,
            scale=self._resolve_scale(query.shape[3]),
            is_causal=True,
            enable_gqa=(nq != nk),
        )
        return out.permute(2, 0, 1, 3).reshape(out.size(2), out.size(0), -1)
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Preserve sliding-window masking in non-packed attention paths.
Lines 820-848 fall back to full causal SDPA for THD-without-cu_seqlens and 4D attention. For sliding layers, that lets tokens attend beyond config.sliding_window, while the CP and packed-varlen paths correctly enforce the left window.
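To make the difference between full causal attention and a sliding window concrete, here is a small pure-Python sketch of which key positions each query may attend to. It follows the same window convention as the `_maybe_sliding_mask` helper proposed in this comment (a token sees itself plus the `window - 1` tokens before it). `sliding_causal_allowed` is a hypothetical illustration, not code from the PR, and it uses nested lists rather than tensors.

```python
def sliding_causal_allowed(seq_len, window):
    """allowed[i][j] is True when query i may attend to key j."""
    return [
        [(j <= i) and (j > i - window) for j in range(seq_len)]
        for i in range(seq_len)
    ]


# With window=2, each query sees at most itself and one previous token.
windowed = sliding_causal_allowed(seq_len=5, window=2)

# A window at least as long as the sequence reduces to plain causal masking.
causal = sliding_causal_allowed(seq_len=5, window=5)

visible_per_row_windowed = [sum(row) for row in windowed]
visible_per_row_causal = [sum(row) for row in causal]
```

With `is_causal=True` and no extra mask, the non-packed paths behave like the `causal` case even on sliding layers, which is the gap this comment describes.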
🐛 Proposed direction
+    def _maybe_sliding_mask(self, q_len, k_len, device, dtype):
+        sw = getattr(self.config, "sliding_window", None)
+        if not self._is_sliding or not sw or sw <= 0:
+            return None
+        row_idx = torch.arange(q_len, device=device)
+        col_idx = torch.arange(k_len, device=device)
+        forbid = (col_idx[None, :] > row_idx[:, None]) | (
+            col_idx[None, :] < row_idx[:, None] - (int(sw) - 1)
+        )
+        return torch.where(forbid, torch.finfo(dtype).min, 0.0).to(dtype)
+
     def forward(self, query, key, value, attention_mask=None, attn_mask_type=None, packed_seq_params=None, **kwargs):
         ...
+            sliding_mask = self._maybe_sliding_mask(q.size(-2), k.size(-2), q.device, q.dtype)
             out = F.scaled_dot_product_attention(
                 q,
                 k,
                 v,
+                attn_mask=sliding_mask,
                 dropout_p=self.dropout_p if self.training else 0.0,
                 scale=self._resolve_scale(hn),
-                is_causal=True,
+                is_causal=sliding_mask is None,
                 enable_gqa=(nq != nk),
             )
         ...
+        sliding_mask = self._maybe_sliding_mask(q.size(-2), k.size(-2), q.device, q.dtype)
         out = F.scaled_dot_product_attention(
             q,
             k,
             v,
+            attn_mask=sliding_mask,
             dropout_p=self.dropout_p if self.training else 0.0,
             scale=self._resolve_scale(query.shape[3]),
-            is_causal=True,
+            is_causal=sliding_mask is None,
             enable_gqa=(nq != nk),
         )
| query, | |
| key, | |
| value, | |
| packed_seq_params, | |
| sliding_window=sw, | |
| ) | |
| cu_seqlens = None | |
| if packed_seq_params is not None: | |
| cu_seqlens = packed_seq_params.cu_seqlens_q | |
| hn = query.shape[2] | |
| if cu_seqlens is not None: | |
| if hn <= 256: | |
| return self._forward_thd_flash(query, key, value, cu_seqlens) | |
| return self._forward_thd_sdpa_per_subseq(query, key, value, cu_seqlens) | |
| q = query.unsqueeze(0).transpose(1, 2) | |
| k = key.unsqueeze(0).transpose(1, 2) | |
| v = value.unsqueeze(0).transpose(1, 2) | |
| nq, nk = q.shape[1], k.shape[1] | |
| sliding_mask = self._maybe_sliding_mask(q.size(-2), k.size(-2), q.device, q.dtype) | |
| out = F.scaled_dot_product_attention( | |
| q, | |
| k, | |
| v, | |
| attn_mask=sliding_mask, | |
| dropout_p=self.dropout_p if self.training else 0.0, | |
| scale=self._resolve_scale(hn), | |
| is_causal=sliding_mask is None, | |
| enable_gqa=(nq != nk), | |
| ) | |
| return out.transpose(1, 2).reshape(query.shape[0], -1) | |
| q = query.permute(1, 2, 0, 3) | |
| k = key.permute(1, 2, 0, 3) | |
| v = value.permute(1, 2, 0, 3) | |
| nq, nk = q.shape[1], k.shape[1] | |
| sliding_mask = self._maybe_sliding_mask(q.size(-2), k.size(-2), q.device, q.dtype) | |
| out = F.scaled_dot_product_attention( | |
| q, | |
| k, | |
| v, | |
| attn_mask=sliding_mask, | |
| dropout_p=self.dropout_p if self.training else 0.0, | |
| scale=self._resolve_scale(query.shape[3]), | |
| is_causal=sliding_mask is None, | |
| enable_gqa=(nq != nk), | |
| ) | |
| return out.permute(2, 0, 1, 3).reshape(out.size(2), out.size(0), -1) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@slime_plugins/models/gemma4.py` around lines 789 - 848, The non-packed
attention fallbacks in Gemma4Attention.forward are ignoring sliding-window
limits and using full causal SDPA. Update the THD path without cu_seqlens and
the 4D path to apply the same left-window masking used by the CP/packed paths,
keyed off self._is_sliding and config.sliding_window, so tokens cannot attend
outside the configured window. Use the existing helpers around
_forward_cp_subseq_mask, _forward_thd_flash, and _forward_thd_sdpa_per_subseq as
the reference points for where the masking logic should be preserved or reused.
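To make the window rule in this comment concrete, here is a minimal sketch in plain Python (no torch) of which key positions each query row may attend to. It assumes the same predicate as the suggested `_maybe_sliding_mask`: a key column is forbidden when it lies in the future (`col > row`) or more than `window - 1` positions behind the query (`col < row - (window - 1)`). The helper name `allowed_keys` is hypothetical and exists only for illustration; it is not part of the PR.

```python
def allowed_keys(q_len: int, k_len: int, window: int) -> list:
    """For each query row, list the key columns it may attend to.

    Assumed predicate (mirrors the suggested mask): a column is forbidden
    when col > row (future token) or col < row - (window - 1) (too old).
    """
    if window <= 0:
        raise ValueError("window must be positive")
    rows = []
    for row in range(q_len):
        oldest = row - (window - 1)  # oldest key position still inside the window
        rows.append([col for col in range(k_len) if oldest <= col <= row])
    return rows


if __name__ == "__main__":
    # Print the visible key columns per query row for a window of 3.
    for row, cols in enumerate(allowed_keys(5, 5, window=3)):
        print(row, cols)
```

With `window=3`, no row sees more than three keys, and later rows drop the earliest tokens. A window at least as long as the sequence yields the ordinary causal pattern.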
| elif "gemma4" in model_name: | ||
| converted_named_tensors = convert_gemma4_to_hf(args, name, param) |
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Accept the official gemma-4 spelling in conversion dispatch.
The PR targets google/gemma-4-* checkpoints, but this branch only matches "gemma4". If model_name is the HF id, Gemma4 raw conversion falls through to Unsupported model.
🐛 Proposed fix
- elif "gemma4" in model_name:
+ elif "gemma4" in model_name.lower() or "gemma-4" in model_name.lower():
      converted_named_tensors = convert_gemma4_to_hf(args, name, param)
📝 Committable suggestion
| elif "gemma4" in model_name: | |
| converted_named_tensors = convert_gemma4_to_hf(args, name, param) | |
| elif "gemma4" in model_name.lower() or "gemma-4" in model_name.lower(): | |
| converted_named_tensors = convert_gemma4_to_hf(args, name, param) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@slime/backends/megatron_utils/megatron_to_hf/__init__.py` around lines 59 -
60, The conversion dispatch in the megatron_to_hf path only checks for "gemma4",
so official "gemma-4" model ids fall through to the unsupported-model path.
Update the model-name matching in the dispatch block around convert_gemma4_to_hf
so it accepts the hyphenated Gemma 4 spelling as well, keeping the existing
Gemma4 conversion branch working for both variants.
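The matching rule described in this comment can be sketched as a small predicate. This is an illustrative sketch only: `is_gemma4_name` is a hypothetical helper name, and the proposed fix inlines the same check in the dispatch branch rather than calling a function.

```python
def is_gemma4_name(model_name: str) -> bool:
    """Return True for both the 'gemma4' and the hyphenated 'gemma-4' spellings.

    Hypothetical helper for illustration; matching is case-insensitive.
    """
    lowered = model_name.lower()
    return "gemma4" in lowered or "gemma-4" in lowered


if __name__ == "__main__":
    for name in ("google/gemma-4-31B-it", "gemma4", "google/gemma-3-27b-it"):
        print(name, is_gemma4_name(name))
```

Lowercasing once before both substring checks keeps the branch tolerant of mixed-case ids such as `google/gemma-4-26B-A4B-it`.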
Summary
Adds native Slime support for Gemma4 text dense and MoE variants, including:
google/gemma-4-31B-it and google/gemma-4-26B-A4B-it.
Proof target
The proof recipe uses GSM8K rather than SWE. GSM8K isolates the model-support path: SGLang load, Megatron load/provider, raw weight sync, rollout, loss masking, backward, and optimizer metrics. SWE should be a downstream integration/capacity validation because it also tests sandboxing, tools, reward/runtime behavior, and long-context infra.
Validation
Unit/static checks:
uv run --with pytest --with torch --with transformers --with safetensors --with numpy python -m pytest tests/gemma4 tests/utils/test_loss_mask_type_gemma4.py tests/test_empty_colocated_weight_bucket.py -q (58 passed, 14 skipped)
uv run --with ruff ruff check slime_plugins/models/gemma4.py slime_plugins/models/gemma4_provider.py slime_plugins/mbridge/gemma4.py slime/backends/megatron_utils/megatron_to_hf/gemma4.py slime/utils/mask_utils.py tests/gemma4 tests/utils/test_loss_mask_type_gemma4.py tests/test_empty_colocated_weight_bucket.py
uv run --with black black --check slime_plugins/models/gemma4.py slime_plugins/models/gemma4_provider.py slime_plugins/mbridge/gemma4.py slime/backends/megatron_utils/megatron_to_hf/gemma4.py slime/utils/mask_utils.py tests/gemma4 tests/utils/test_loss_mask_type_gemma4.py tests/test_empty_colocated_weight_bucket.py
bash -n scripts/run-gemma4-31B-gsm8k.sh scripts/run-gemma4-26B-A4B-gsm8k.sh && git diff --check
One-node proof runs:
google/gemma-4-31B-it: Slurm job 16145, COMPLETED 0:0, 00:09:26, W&B run https://wandb.ai/augustinevmax-vmax/slime-gemma4-official-proof/runs/51l66o7o
train/loss=-3.646390043257728e-05, train/entropy_loss=0.03646389893659457, train/grad_norm=0.0031007271081220436, rollout/raw_reward=1.0.
google/gemma-4-26B-A4B-it: Slurm job 16143, COMPLETED 0:0, 00:08:12, W&B run https://wandb.ai/augustinevmax-vmax/slime-gemma4-official-proof/runs/eeere4c3
train/loss=-0.00047254859237000346, train/entropy_loss=0.47254857420921326, train/grad_norm=0.12294327446564064.
Visual report: https://wandb.ai/augustinevmax-vmax/slime-gemma4-official-proof/reports/Gemma4-Slime-GSM8K-Proof-Runs--VmlldzoxNzM1MTY2Mg==?accessToken=s43ay8st8n7w19ep73y531l898faxpdtjfvgsqd8etcrr4kszvjw6zlze6ehbmio
The report includes the original proof runs plus small stdout-reconstructed W&B runs for readable metric curves. The reconstructed runs are explicitly labeled and link back to the original Slurm/W&B provenance.
Notes
This PR is prepared in the EazyReal fork for review before any upstream THUDM/slime PR is opened. It is stacked on the empty colocated bucket fix.
Summary by CodeRabbit
New Features
Documentation
Bug Fixes
Tests