DeepSeek V4 RL support#1045
Conversation
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
There was a problem hiding this comment.
Code Review
This pull request implements comprehensive support for the DeepSeek-V4 model, introducing specialized attention and indexer layers via TileLang kernels, a transformers patching utility, and atomic bucketing for SGLang weight updates. Feedback identifies several critical improvements: the router's proxy response should preserve upstream headers to maintain features like compression, and the attn_sink parameter must be zero-initialized to ensure training stability. Reviewers also noted an inconsistency in padding size calculations between modules and recommended using 64-bit integers for token counts to prevent overflow. Furthermore, the weight update logic should be hardened against orphan groups to avoid SGLang assertion errors, and redundant top-k calculations in the indexer should be eliminated for better efficiency.
| content_type: str = headers.get("content-type", "application/json") | ||
| return Response(content=content, status_code=status_code, media_type=content_type) |
There was a problem hiding this comment.
The _build_proxy_response method is currently dropping all upstream headers except for content-type. This will break critical functionality such as compression (Content-Encoding), caching, and session management (Set-Cookie). While the goal is to avoid re-serialization issues with Content-Length, you should still pass the original headers after removing the content-length entry, allowing the Response object to calculate it correctly for the provided body.
| content_type: str = headers.get("content-type", "application/json") | |
| return Response(content=content, status_code=status_code, media_type=content_type) | |
| content_type: str = headers.get("content-type", "application/json") | |
| headers.pop("content-length", None) | |
| return Response(content=content, status_code=status_code, headers=headers, media_type=content_type) |
| config_no_sp = copy.copy(config) | ||
| config_no_sp.sequence_parallel = False | ||
|
|
||
| self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32)) |
There was a problem hiding this comment.
The attn_sink parameter is initialized using torch.empty, which leaves it with uninitialized (garbage) values. Since this is a learnable parameter that affects the softmax denominator, it should be initialized to a consistent starting value, such as zero, to ensure stable training.
| self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32)) | |
| self.attn_sink = nn.Parameter(torch.zeros(self.n_local_heads, dtype=torch.float32)) |
| pad = (pad_size - replay_data.size(0) % pad_size) % pad_size | ||
| if pad != 0: | ||
| replay_data = pad_func(replay_data, pad) | ||
| pad_size = self.parallel_state.dp_size * self.args.data_pad_size_multiplier |
There was a problem hiding this comment.
There is an inconsistency in the pad_size calculation between actor.py and data.py. In miles/backends/training_utils/data.py, pad_size is calculated using parallel_state.tp_size, whereas here it uses parallel_state.dp_size. This discrepancy can lead to misaligned data shapes during training if the tensor parallel size and data parallel size differ. Ensure that dimensions are derived consistently from the configuration or parallel state.
| pad_size = self.parallel_state.dp_size * self.args.data_pad_size_multiplier | |
| pad_size = self.parallel_state.tp_size * self.args.data_pad_size_multiplier |
References
- Avoid hardcoding model dimensions; derive them from configuration or input tensor shapes instead.
| for items, items_size in pending_groups.values(): | ||
| _commit_atomic(items, items_size) |
There was a problem hiding this comment.
Processing orphan groups at the end of bucketing might lead to a crash in SGLang. The atomic bucketing logic is designed to ensure that fusion pairs (like wq_a and wkv) land in the same update bucket to satisfy SGLang's load_weights assertions. If only one half of a pair is present, committing it alone will likely trigger an AssertionError in SGLang. It is safer to raise an informative error if pending_groups is not empty here.
| return ( | ||
| loss, | ||
| torch.tensor(num_tokens if args.calculate_per_token_loss else 1, device=logits.device), | ||
| torch.tensor(num_tokens if args.calculate_per_token_loss else 1, dtype=torch.int, device=logits.device), |
There was a problem hiding this comment.
Using torch.int (32-bit) for token counts can lead to overflow in large-scale training scenarios when these values are reduced across many ranks. It is safer to use torch.long (64-bit) to ensure correctness for high token counts.
| torch.tensor(num_tokens if args.calculate_per_token_loss else 1, dtype=torch.int, device=logits.device), | |
| torch.tensor(num_tokens if args.calculate_per_token_loss else 1, dtype=torch.long, device=logits.device), |
| index_scores = batched_indexer_fwd(q, k, weights.float(), cu_ks, cu_ke) | ||
|
|
||
| topk_k = min(self.index_topk, index_scores.size(-1)) | ||
| topk_indices = index_scores.topk(topk_k, dim=-1)[1] |
There was a problem hiding this comment.
This manual topk calculation is redundant because the topk_fn (which handles both normal and replay modes) is called immediately after on line 147. Removing this line will improve efficiency by avoiding an unnecessary GPU kernel launch. Ensure that all indexing parameters are retrieved from the model configuration.
References
- Model parameters, such as index_topk, should be retrieved from the model configuration rather than being hardcoded.
|
GOAT |
Avoids _prepare_cp Ray dependency when only --model-dir is overridden: both fields now resolve to the same path unless user explicitly fans out to per-node local NVMe via --model-local-dir.
Pinaster/DeepSeek-V4-Flash-FP8-4layer ships with model_type=deepseek_v4 in config.json, but SGLang's get_config fallback (with SGLANG_APPLY_CONFIG_BACKUP=none) only fires on deepseek_ref. Rewrite the local config.json in-place so SGLang's _load_deepseek_temp_model gets reached. Idempotent; no-op for non-4-layer models.
…epseek-v4 @ 8e1ef3c) Verified single-node 8xH200 training of DeepSeek-V4-Flash-FP8-4layer (Pinaster/...) end-to-end through 12 GRPO rollouts (steps 0-11, all loss=0 since 4-layer prune produces gibberish + reward=0, but pipeline ran cleanly). Job raysubmit_WTh5sDDAaRcVrfTT. Required environment additions on top of the upstream Dockerfile: - sglang from sgl-project/sglang@deepseek_v4 (pip install -e python pulled sgl-kernel==0.3.21 and downgraded fastapi to 0.115.x) - flashinfer-jit-cache==0.6.8+cu129 (matching flashinfer-python 0.6.8) - tilelang==0.1.8 (PyPI release; 0.1.9 dropped wg_wait from T.gemm public API) - flash_mla 1.0.0+71c7379 from deepseek-ai/FlashMLA (CUDA-built) - fast_hadamard_transform 1.1.0 from Dao-AILab/fast-hadamard-transform - Megatron-LM at radixark/Megatron-LM PR radixark#28 (mlm-pr28, commit 8455dbf, in both /workspace/Megatron-LM and /root/Megatron-LM editable install) Run-time tweaks: - bf16 ckpt config.json: model_type "deepseek_v4" -> "deepseek_v3" (to satisfy AutoConfig.from_pretrained on transformers 4.57.1) - bf16 ckpt config.json: drop quantization_config (avoid sglang creating fp8 model params when our weights are bf16) - SGLANG_APPLY_CONFIG_BACKUP=none (otherwise sglang substitutes the packaged 43-layer config and breaks miles' hf_validate_args) dsv4_flash_to_bf16.py was written for the cluster's MXFP4 4-layer prune (DeepSeek-V4-Flash-4layer); the actual 12-step run used the standard FP8 prune (Pinaster/DeepSeek-V4-Flash-FP8-4layer) so the tool was not exercised. Keeping it for the MXFP4 path.
|
could you please share the kl and reward dashboard if possible? thanks |
Drop-in replacement of the in-tree no-grad ``hc_split_sinkhorn`` / ``hc_pre_raw`` / ``hc_post_raw`` / ``hc_head_raw`` paths with calls into ``tile_kernels.modeling.mhc.ops`` (sinkhorn, pre_norm_fn, pre_split_mixes, pre_apply_mix, post, head_compute_mix, plus the fused inference ``pre_big_fuse`` for ``no_grad`` paths). Public class API (``DeepSeekV4HyperConnectionUtil``, ``HCHeadParams``) unchanged so the radixark/Megatron-LM PR radixark#28 call sites in ``transformer_layer.py`` / ``transformer_block.py`` keep working. Net effect: - forward path now matches the canonical TileKernels kernels (which fuse RMS-norm + GEMM split-K, sinkhorn, and the pre-apply mix) - backward path is enabled (the original code asserted ``_HYPER_CONNECTION_MIXER_NO_GRAD = True`` and wrapped everything in ``torch.no_grad``) - ``post = 2 * sigmoid(...)`` reproduced via ``post_mult_value=2.0``; PR's single ``hc_eps`` reused for both ``pre_eps`` and ``sinkhorn_eps`` Tweaks vs upstream TileKernels' high-level wrappers: - inline ``mhc_pre`` body so we can pass ``fuse_grad_acc=False`` to ``mhc_pre_norm_fn``. The default (``True``) requires ``mhc_post`` to have written ``grad_from_mhc_post`` onto the same residual storage during backward, but Megatron's call sites use independent ``s b hc d`` -> ``b s hc d`` einops.rearrange'd tensors for ``layer_pre`` and ``layer_post`` so the storage objects don't match. - inline ``mhc_head`` body to ``.contiguous()`` the ``mixes[..., :mhc_mult]`` slice before feeding it to ``mhc_head_compute_mix_fwd_kernel`` (the kernel asserts ``strides[0] == mhc_mult`` but the slice keeps the ``mhc_mult * (mhc_mult + 2)`` parent stride). - ``.contiguous()`` everywhere we hand bf16 tensors to TileLang kernels so the no-grad fused ``pre_big_fuse`` path doesn't trip the ``view`` stride check. Verified end-to-end: 3-rollout GRPO smoke test on DeepSeek-V4-Flash-FP8-4layer (single-node 8xH200, GRPO with rollout batch 8, 4 samples/prompt, 64 tok cap) reaches steps 0..2 with entropy_loss / logprob_abs_diff in the same band as the original in-tree implementation (rl-smoke-pass tag) — see job raysubmit_B3DgGam9vCYuDAVU.
- ``kernel/act_quant.py`` is now a thin wrapper around ``tile_kernels.quant.per_token_cast(fmt='e4m3', round_sf=True)``; shape/dtype contract preserved (``(y_fp8, s_fp32)`` with ``s.shape == (*x.shape[:-1], N // block_size)``) so callers in ``qat.py`` / ``compressor.py`` / ``v4_indexer.py`` don't change. - ``kernel/sinkhorn.py`` is removed: it was only consumed by the legacy ``hyper_connection.py`` path which now calls ``tile_kernels.modeling.mhc.ops.sinkhorn_normalize`` instead. End-to-end audit of ``miles_plugins/models/deepseek_v4/ops/kernel/``: - act_quant.py -> tile_kernels.quant.per_token_cast (this commit) - sinkhorn.py -> tile_kernels.modeling.mhc.ops (Batch 1, removed) - tilelang_indexer*.py -> no TileKernels equivalent (DSV4 DSA-specific) - tilelang_sparse_mla*.py -> no TileKernels equivalent (DSV4 sparse-MLA) So every ``kernel/`` file that has a TileKernels analogue now routes through TileKernels; the remaining files are V4-specific and have no upstream replacement. Verified end-to-end with the 12-rollout GRPO smoke harness on DeepSeek-V4-Flash-FP8-4layer (single-node 8xH200, run id 260425-054701-579 -> raysubmit_a4YH6wKJ963VrTg9). All 12 steps (0..11) completed cleanly; entropy_loss / logprob_abs_diff land in the same band as the rl-smoke-pass baseline (12-step run on the unmodified PR @ a72ed84): step | baseline ent | TK ent -----+--------------+-------- 0 | 1.6858 | 1.7286 1 | 1.7854 | 1.7321 2 | 1.7360 | 1.7803 3 | 1.7696 | 1.6750 4 | 1.6815 | 1.7532 5 | 1.7172 | 1.6865 6 | 1.7394 | 1.6949 7 | 1.7634 | 1.7217 8 | 1.6674 | 1.6871 9 | 1.7330 | 1.6943 10 | 1.6319 | 1.6998 11 | 1.6962 | 1.7321
- ops/qat.py:fp8_simulate now uses tile_kernels.quant.per_token_cast_back for the FP8→BF16 dequant step (was a manual unflatten/multiply). - tools/fp8_cast_bf16.py:weight_dequant replaces the in-tree Triton kernel with tile_kernels.quant.cast_back (128x128 block FP8 dequant). Both substitutions are bit-exact against the prior implementations on GPU (max_diff = 0.0 across 2D/3D/4D inputs and 256x768 fp8 weights).
…cision
TileKernels' MHC fp32 GEMM uses TF32 tensor cores on H100/H200 (H100+
fp32 GEMM has no full-precision tensor-core path; TF32 is the fastest
fp32 mode). PyTorch's default ``torch.backends.cuda.matmul.allow_tf32 =
False`` forces fp32 F.linear onto the SIMT path, which introduced a
~1e-4 mean-abs gap vs the TileKernels-backed HC mixer.
Setting allow_tf32 = True at deepseek_v4 plugin import time:
HC parity (TileKernels MHC fwd vs legacy in-tree fwd, no-grad):
mean_abs max_abs notes
layer_input 1.05e-5 1.56e-2 bf16 LSB output
pre.post 1.52e-5 7.65e-5
pre.comb 4.61e-6 2.87e-5
hc_post out 1.15e-8 9.77e-4 fp32-equivalent
hc_head out 1.14e-5 1.56e-2 bf16 LSB output
All ops <= 1.5e-5 mean-abs (matches the attention/indexer 1e-5 bar).
The max-abs values are 1 ULP of bf16 at the output magnitude (~6.0,
1 ULP = 2^-5 = 3.13e-2). Other fp32 matmuls in the plugin (compressor,
indexer projections) get a free TF32 speed-up as a side effect.
This reverts commit fda4195.
# Conflicts: # miles_plugins/models/deepseek_v4/ops/hyper_connection.py # miles_plugins/models/deepseek_v4/ops/qat.py
…ware-aware scale dtype
Replace the TileKernels per_token_cast wrapper with a verbatim port of
deepseek-ai/DeepSeek-V4-Pro/inference/kernel.py:act_quant so this code
path is bit-exact with the upstream inference kernel.
The ported act_quant exposes scale_dtype and inplace, matching official:
- scale_dtype=None auto-selects via SGLang should_deepgemm_weight_requant_ue8m0
(Blackwell + DeepGEMM JIT -> float8_e8m0fnu; Hopper -> float32).
- explicit scale_dtype overrides the auto path either way.
- inplace=True wires through to the fused quant+dequant kernel.
Drops the tile_kernels import for this file; the package was not listed
in any requirements manifest. Caller (qat.py:fp8_simulate) is unaffected:
on Hopper the auto path resolves to float32, preserving prior behavior.
Verified on H200:
- 26 kernel cases bit-exact vs official act_quant (4 shapes x 5 modes + 3D).
- 4-layer iterated fp8_simulate stable, ~2.25% mean rel noise as expected for E4M3.
- Hardware auto-resolve matches sglang.deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0.
Hi @hebiao064 We have some initial reasoning task reward curve for the reference! |
…ripts
Brings up the full V4-Flash 284B SFT pipeline on an 8-node × 8-GPU H20
cluster, plus the layered config infrastructure to make it portable
across H20 / H200 / different cluster sizes.
Pipeline scripts (examples/deepseek_v4_sft/):
- prepare_data.py / prepare_megatron_ckpt.sh / verify_chat_template.py
- run_stage_{a_smoke,b0_dryrun,b_full}.sh reference runners
- tools/megablocks_to_hf_bf16.py FP8/megablocks -> unpacked BF16 HF
(the missing piece in PR radixark#1045's
ckpt conversion)
- tools/tilelang_*.py sparse-MLA repro / bisect /
shape-sweep harnesses used to
isolate the V4 NaN bug
Cluster scripts (examples/deepseek_v4_sft/cluster/):
- bring_up_cluster.sh / tear_down.sh / convert_hf_to_megatron.sh
- run_stage_a.sh 1-node 4-layer NaN reproducer
Layered config (so a new cluster only edits one file):
env/base.env project path conventions
env/cluster_h20_8node.env current H20 cluster
env/cluster_h200_template.env H200 example
env.sh backwards-compat shim
hw/{h20,h200}.env GPU-specific defaults
presets/{smoke,validation,cp_smoke,
prod,long_context}.env workload presets
lib/preflight.sh shared pre-flight checks
run.sh unified launcher with CLI overrides
BOOTSTRAP.md new-cluster setup checklist
Integration patches required by V4 on this image / mbridge / sglang stack:
- miles/utils/arguments.py / mask_utils.py / transformers_patch.py
- tools/fp8_cast_bf16.py
.gitignore picks up two `!` exceptions because the root rules ignore
`lib/` and `env/` as Python-build leftovers.
| pass_configs={ | ||
| tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, | ||
| tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, | ||
| tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True, |
There was a problem hiding this comment.
Hello @yueming-yuan, nice to meet you. This is Peng. I’m cherry-picking your code for DeepSeek V4 training. Thank you for the implementation!
I ran into a NaN issue during the backward pass, and after debugging it appears to come from enabling TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE in the sparse MLA/MQA backward kernel. In our setup, aggressive shared-memory merge can alias/corrupt the region used by Q_shared, which matches the failure mode described in this TileLang PR: tile-ai/tilelang#2204.
Would you be okay with removing/disabling this pass config temporarily to unblock training until the upstream TileLang fix lands and we can use a version that includes it?
Thank you!
There was a problem hiding this comment.
Hello @yueming-yuan, nice to meet you. This is Peng. I’m cherry-picking your code for DeepSeek V4 training. Thank you for the implementation!
I ran into a NaN issue during the backward pass, and after debugging it appears to come from enabling
TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGEin the sparse MLA/MQA backward kernel. In our setup, aggressive shared-memory merge can alias/corrupt the region used byQ_shared, which matches the failure mode described in this TileLang PR: tile-ai/tilelang#2204.Would you be okay with removing/disabling this pass config temporarily to unblock training until the upstream TileLang fix lands and we can use a version that includes it?
Thank you!
It works when set TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE False. Thanks.
…ripts
Brings up the full V4-Flash 284B SFT pipeline on an 8-node × 8-GPU H20
cluster, plus the layered config infrastructure to make it portable
across H20 / H200 / different cluster sizes.
Pipeline scripts (examples/deepseek_v4_sft/):
- prepare_data.py / prepare_megatron_ckpt.sh / verify_chat_template.py
- run_stage_{a_smoke,b0_dryrun,b_full}.sh reference runners
- tools/megablocks_to_hf_bf16.py FP8/megablocks -> unpacked BF16 HF
(the missing piece in PR radixark#1045's
ckpt conversion)
- tools/tilelang_*.py sparse-MLA repro / bisect /
shape-sweep harnesses used to
isolate the V4 NaN bug
Cluster scripts (examples/deepseek_v4_sft/cluster/):
- bring_up_cluster.sh / tear_down.sh / convert_hf_to_megatron.sh
- run_stage_a.sh 1-node 4-layer NaN reproducer
Layered config (so a new cluster only edits one file):
env/base.env project path conventions
env/cluster_h20_8node.env current H20 cluster
env/cluster_h200_template.env H200 example
env.sh backwards-compat shim
hw/{h20,h200}.env GPU-specific defaults
presets/{smoke,validation,cp_smoke,
prod,long_context}.env workload presets
lib/preflight.sh shared pre-flight checks
run.sh unified launcher with CLI overrides
BOOTSTRAP.md new-cluster setup checklist
Integration patches required by V4 on this image / mbridge / sglang stack:
- miles/utils/arguments.py / mask_utils.py / transformers_patch.py
- tools/fp8_cast_bf16.py
.gitignore picks up two `!` exceptions because the root rules ignore
`lib/` and `env/` as Python-build leftovers.
…ripts
Brings up the full V4-Flash 284B SFT pipeline on an 8-node × 8-GPU H20
cluster, plus the layered config infrastructure to make it portable
across H20 / H200 / different cluster sizes.
Pipeline scripts (examples/deepseek_v4_sft/):
- prepare_data.py / prepare_megatron_ckpt.sh / verify_chat_template.py
- run_stage_{a_smoke,b0_dryrun,b_full}.sh reference runners
- tools/megablocks_to_hf_bf16.py FP8/megablocks -> unpacked BF16 HF
(the missing piece in PR radixark#1045's
ckpt conversion)
- tools/tilelang_*.py sparse-MLA repro / bisect /
shape-sweep harnesses used to
isolate the V4 NaN bug
Cluster scripts (examples/deepseek_v4_sft/cluster/):
- bring_up_cluster.sh / tear_down.sh / convert_hf_to_megatron.sh
- run_stage_a.sh 1-node 4-layer NaN reproducer
Layered config (so a new cluster only edits one file):
env/base.env project path conventions
env/cluster_h20_8node.env current H20 cluster
env/cluster_h200_template.env H200 example
env.sh backwards-compat shim
hw/{h20,h200}.env GPU-specific defaults
presets/{smoke,validation,cp_smoke,
prod,long_context}.env workload presets
lib/preflight.sh shared pre-flight checks
run.sh unified launcher with CLI overrides
BOOTSTRAP.md new-cluster setup checklist
Integration patches required by V4 on this image / mbridge / sglang stack:
- miles/utils/arguments.py / mask_utils.py / transformers_patch.py
- tools/fp8_cast_bf16.py
.gitignore picks up two `!` exceptions because the root rules ignore
`lib/` and `env/` as Python-build leftovers.
|
Hi @yueming-yuan and folks, thanks for putting together the DeepSeek V4 support here. I hit the same tilelang sparse MLA backward NaN issue while validating Bridge TP8/EP8 e2e training, and opened a small stacked fix here: I also noticed @pengdurice had already reported what looks like the same root cause in this thread: #1045 (comment). The behavior matches what we saw: forward loss stays finite, but backward produces NaN gradients when Validation summary from our side:
This is stacked on top of this PR branch, so the diff should just be the temporary pass-config removal plus the test guard. Thanks again! |

radixark/Megatron-LM#28