DeepSeek-v4 Integrate TileKernels #1
Merged
yueming-yuan merged 5 commits on May 15, 2026
…epseek-v4 @ 8e1ef3c)

Verified single-node 8xH200 training of DeepSeek-V4-Flash-FP8-4layer (Pinaster/...) end-to-end through 12 GRPO rollouts (steps 0-11; all loss=0 because the 4-layer prune produces gibberish and reward=0, but the pipeline ran cleanly). Job raysubmit_WTh5sDDAaRcVrfTT.

Required environment additions on top of the upstream Dockerfile:
- sglang from sgl-project/sglang@deepseek_v4 (pip install -e python pulled sgl-kernel==0.3.21 and downgraded fastapi to 0.115.x)
- flashinfer-jit-cache==0.6.8+cu129 (matching flashinfer-python 0.6.8)
- tilelang==0.1.8 (PyPI release; 0.1.9 dropped wg_wait from T.gemm public API)
- flash_mla 1.0.0+71c7379 from deepseek-ai/FlashMLA (CUDA-built)
- fast_hadamard_transform 1.1.0 from Dao-AILab/fast-hadamard-transform
- Megatron-LM at radixark/Megatron-LM PR #28 (mlm-pr28, commit 8455dbf, in both /workspace/Megatron-LM and /root/Megatron-LM editable install)

Run-time tweaks:
- bf16 ckpt config.json: model_type "deepseek_v4" -> "deepseek_v3" (to satisfy AutoConfig.from_pretrained on transformers 4.57.1)
- bf16 ckpt config.json: drop quantization_config (avoid sglang creating fp8 model params when our weights are bf16)
- SGLANG_APPLY_CONFIG_BACKUP=none (otherwise sglang substitutes the packaged 43-layer config and breaks miles' hf_validate_args)

dsv4_flash_to_bf16.py was written for the cluster's MXFP4 4-layer prune (DeepSeek-V4-Flash-4layer); the actual 12-step run used the standard FP8 prune (Pinaster/DeepSeek-V4-Flash-FP8-4layer), so the tool was not exercised. Keeping it for the MXFP4 path.
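The two config.json tweaks listed above could be scripted roughly as follows. This is a minimal sketch, not the tooling used for the run: the helper name `patch_bf16_config` is hypothetical, and it assumes the checkpoint's config.json is a flat JSON object with top-level `model_type` and `quantization_config` keys.

```python
import json
from pathlib import Path


def patch_bf16_config(config_path):
    """Apply the two run-time config tweaks to a bf16 checkpoint's config.json.

    Hypothetical helper: rewrites the file in place and returns the new dict.
    """
    path = Path(config_path)
    cfg = json.loads(path.read_text())

    # Tweak 1: report the model as deepseek_v3 so AutoConfig.from_pretrained
    # accepts the checkpoint.
    if cfg.get("model_type") == "deepseek_v4":
        cfg["model_type"] = "deepseek_v3"

    # Tweak 2: drop quantization_config so fp8 params are not created for
    # bf16 weights. pop() with a default is a no-op if the key is absent.
    cfg.pop("quantization_config", None)

    path.write_text(json.dumps(cfg, indent=2) + "\n")
    return cfg
```

The third tweak is an environment variable rather than a file edit, so it would be set in the job environment (`SGLANG_APPLY_CONFIG_BACKUP=none`) before sglang starts.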
Drop-in replacement of the in-tree no-grad ``hc_split_sinkhorn`` / ``hc_pre_raw`` / ``hc_post_raw`` / ``hc_head_raw`` paths with calls into ``tile_kernels.modeling.mhc.ops`` (sinkhorn, pre_norm_fn, pre_split_mixes, pre_apply_mix, post, head_compute_mix, plus the fused inference ``pre_big_fuse`` for ``no_grad`` paths). Public class API (``DeepSeekV4HyperConnectionUtil``, ``HCHeadParams``) unchanged so the radixark/Megatron-LM PR #28 call sites in ``transformer_layer.py`` / ``transformer_block.py`` keep working.

Net effect:
- forward path now matches the canonical TileKernels kernels (which fuse RMS-norm + GEMM split-K, sinkhorn, and the pre-apply mix)
- backward path is enabled (the original code asserted ``_HYPER_CONNECTION_MIXER_NO_GRAD = True`` and wrapped everything in ``torch.no_grad``)
- ``post = 2 * sigmoid(...)`` reproduced via ``post_mult_value=2.0``; PR's single ``hc_eps`` reused for both ``pre_eps`` and ``sinkhorn_eps``

Tweaks vs upstream TileKernels' high-level wrappers:
- inline ``mhc_pre`` body so we can pass ``fuse_grad_acc=False`` to ``mhc_pre_norm_fn``. The default (``True``) requires ``mhc_post`` to have written ``grad_from_mhc_post`` onto the same residual storage during backward, but Megatron's call sites use independent ``s b hc d`` -> ``b s hc d`` einops.rearrange'd tensors for ``layer_pre`` and ``layer_post``, so the storage objects don't match.
- inline ``mhc_head`` body to ``.contiguous()`` the ``mixes[..., :mhc_mult]`` slice before feeding it to ``mhc_head_compute_mix_fwd_kernel`` (the kernel asserts ``strides[0] == mhc_mult`` but the slice keeps the ``mhc_mult * (mhc_mult + 2)`` parent stride).
- ``.contiguous()`` everywhere we hand bf16 tensors to TileLang kernels so the no-grad fused ``pre_big_fuse`` path doesn't trip the ``view`` stride check.
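The stride mismatch behind the ``mhc_head`` tweak can be illustrated with plain stride arithmetic. This is a sketch in pure Python rather than PyTorch; the values ``mhc_mult = 4`` and ``tokens = 8`` are assumptions chosen only for illustration, and both helper names are hypothetical.

```python
def row_major_strides(shape):
    """Element strides of a densely packed (contiguous) row-major array."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def slice_last_dim(shape, strides, width):
    """Model a basic slice ``x[..., :width]``: the shape narrows, but the
    view still walks the parent's memory, so the strides are unchanged."""
    return shape[:-1] + (width,), strides


mhc_mult = 4   # assumed value, for illustration only
tokens = 8     # assumed number of rows

# Parent ``mixes`` tensor: one row of mhc_mult * (mhc_mult + 2) values per token.
parent_shape = (tokens, mhc_mult * (mhc_mult + 2))
parent_strides = row_major_strides(parent_shape)

# The slice mixes[..., :mhc_mult] keeps the parent's row stride.
view_shape, view_strides = slice_last_dim(parent_shape, parent_strides, mhc_mult)

# A contiguous copy is repacked, so its row stride equals mhc_mult.
copy_strides = row_major_strides(view_shape)

print(view_strides[0], copy_strides[0])
```

With these assumed sizes the sliced view has a row stride of 24 while the repacked copy has a row stride of 4, which is the condition ``strides[0] == mhc_mult`` that the kernel is described as asserting.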
Verified end-to-end: 3-rollout GRPO smoke test on DeepSeek-V4-Flash-FP8-4layer (single-node 8xH200, GRPO with rollout batch 8, 4 samples/prompt, 64 tok cap) reaches steps 0..2 with entropy_loss / logprob_abs_diff in the same band as the original in-tree implementation (rl-smoke-pass tag) — see job raysubmit_B3DgGam9vCYuDAVU.
- ``kernel/act_quant.py`` is now a thin wrapper around ``tile_kernels.quant.per_token_cast(fmt='e4m3', round_sf=True)``; shape/dtype contract preserved (``(y_fp8, s_fp32)`` with ``s.shape == (*x.shape[:-1], N // block_size)``) so callers in ``qat.py`` / ``compressor.py`` / ``v4_indexer.py`` don't change.
- ``kernel/sinkhorn.py`` is removed: it was only consumed by the legacy ``hyper_connection.py`` path, which now calls ``tile_kernels.modeling.mhc.ops.sinkhorn_normalize`` instead.

End-to-end audit of ``miles_plugins/models/deepseek_v4/ops/kernel/``:
- act_quant.py -> tile_kernels.quant.per_token_cast (this commit)
- sinkhorn.py -> tile_kernels.modeling.mhc.ops (Batch 1, removed)
- tilelang_indexer*.py -> no TileKernels equivalent (DSV4 DSA-specific)
- tilelang_sparse_mla*.py -> no TileKernels equivalent (DSV4 sparse-MLA)

So every ``kernel/`` file that has a TileKernels analogue now routes through TileKernels; the remaining files are V4-specific and have no upstream replacement.

Verified end-to-end with the 12-rollout GRPO smoke harness on DeepSeek-V4-Flash-FP8-4layer (single-node 8xH200, run id 260425-054701-579 -> raysubmit_a4YH6wKJ963VrTg9). All 12 steps (0..11) completed cleanly; entropy_loss / logprob_abs_diff land in the same band as the rl-smoke-pass baseline (12-step run on the unmodified PR @ a72ed84):

step | baseline ent | TK ent
-----+--------------+--------
   0 |       1.6858 | 1.7286
   1 |       1.7854 | 1.7321
   2 |       1.7360 | 1.7803
   3 |       1.7696 | 1.6750
   4 |       1.6815 | 1.7532
   5 |       1.7172 | 1.6865
   6 |       1.7394 | 1.6949
   7 |       1.7634 | 1.7217
   8 |       1.6674 | 1.6871
   9 |       1.7330 | 1.6943
  10 |       1.6319 | 1.6998
  11 |       1.6962 | 1.7321
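The scale side of the ``act_quant`` shape contract (one scale per ``block_size`` elements of the last dimension) can be sketched in pure Python. This is not the TileKernels kernel: the helper name ``block_scales`` is hypothetical, the reading of ``round_sf=True`` as "round each scale up to a power of two" is an assumption, and the ``1e-4`` clamp on the block maximum is an assumed guard against all-zero blocks. The E4M3 maximum finite value of 448 is a property of the format.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3


def block_scales(row, block_size=128, round_sf=True):
    """Return one scale per block of ``row`` (a list of floats).

    The result has len(row) // block_size entries, mirroring the
    ``N // block_size`` trailing dimension of the scale tensor.
    """
    n = len(row)
    if n % block_size != 0:
        raise ValueError("row length must be a multiple of block_size")

    scales = []
    for start in range(0, n, block_size):
        block = row[start:start + block_size]
        # Assumed clamp so an all-zero block does not yield a zero scale.
        amax = max(max(abs(v) for v in block), 1e-4)
        scale = amax / FP8_E4M3_MAX
        if round_sf:
            # Assumption: round the scale UP to the next power of two, so the
            # scaled values still fit inside the E4M3 range.
            scale = 2.0 ** math.ceil(math.log2(scale))
        scales.append(scale)
    return scales
```

For a row of 256 elements and the default block size this yields two scales; dividing each block by its scale keeps every magnitude at or below 448 because the rounding only ever increases the scale.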
- ops/qat.py:fp8_simulate now uses tile_kernels.quant.per_token_cast_back for the FP8→BF16 dequant step (was a manual unflatten/multiply).
- tools/fp8_cast_bf16.py:weight_dequant replaces the in-tree Triton kernel with tile_kernels.quant.cast_back (128x128 block FP8 dequant).

Both substitutions are bit-exact against the prior implementations on GPU (max_diff = 0.0 across 2D/3D/4D inputs and 256x768 fp8 weights).
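What "128x128 block dequant" means can be shown with a plain-Python reference: every element is multiplied by the scale of the tile it falls in. This is a sketch of the semantics only, not the TileKernels or Triton implementation; the helper name ``block_dequant`` is hypothetical and the weights are assumed to be already decoded to Python floats.

```python
def block_dequant(weight, scale, block=128):
    """Multiply each weight[i][j] by the scale of its (block x block) tile.

    ``weight`` is a list of rows; ``scale`` has one entry per tile, i.e.
    ceil(rows / block) rows and ceil(cols / block) columns.
    """
    rows, cols = len(weight), len(weight[0])
    return [
        [weight[i][j] * scale[i // block][j // block] for j in range(cols)]
        for i in range(rows)
    ]


# Tiny usage example with block=2 so the tiling is easy to see.
w = [[1.0] * 4 for _ in range(4)]
s = [[1.0, 2.0],
     [3.0, 4.0]]
out = block_dequant(w, s, block=2)
```

In the example, the top-left 2x2 tile is scaled by 1.0, the top-right by 2.0, the bottom-left by 3.0 and the bottom-right by 4.0.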
Author
@yueming-yuan This PR merges TileKernels (MHC and quant) into radixark#1045. Thanks for reviewing!
…cision
TileKernels' MHC fp32 GEMM uses TF32 tensor cores on H100/H200 (H100+
fp32 GEMM has no full-precision tensor-core path; TF32 is the fastest
fp32 mode). PyTorch's default ``torch.backends.cuda.matmul.allow_tf32 =
False`` forces fp32 F.linear onto the SIMT path, which introduced a
~1e-4 mean-abs gap vs the TileKernels-backed HC mixer.
Setting allow_tf32 = True at deepseek_v4 plugin import time:
HC parity (TileKernels MHC fwd vs legacy in-tree fwd, no-grad):
              mean_abs   max_abs   notes
layer_input   1.05e-5    1.56e-2   bf16 LSB output
pre.post      1.52e-5    7.65e-5
pre.comb      4.61e-6    2.87e-5
hc_post out   1.15e-8    9.77e-4   fp32-equivalent
hc_head out   1.14e-5    1.56e-2   bf16 LSB output
All ops are at or below ~1.5e-5 mean-abs (matches the attention/indexer
1e-5 bar). The 1.56e-2 max-abs values are within 1 ULP of bf16 at the
output magnitude (~6.0, where 1 ULP = 2^-5 = 3.13e-2). Other fp32 matmuls
in the plugin (compressor, indexer projections) get a free TF32 speed-up
as a side effect.
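The bf16 ULP figure quoted above can be checked with a few lines of arithmetic. This sketch assumes the standard bfloat16 layout (8 significand bits including the implicit leading one) and only handles normal, non-zero values; the helper name ``bf16_ulp`` is hypothetical.

```python
import math

BF16_SIGNIFICAND_BITS = 8  # 7 stored mantissa bits + 1 implicit leading bit


def bf16_ulp(x):
    """Spacing between adjacent bf16 values at the magnitude of ``x``.

    Valid for normal, non-zero inputs only (subnormals are not modelled).
    """
    if x == 0.0:
        raise ValueError("ULP at zero is not modelled in this sketch")
    # frexp gives x = m * 2**exp with 0.5 <= |m| < 1, so floor(log2|x|) = exp - 1.
    _, exp = math.frexp(x)
    return math.ldexp(1.0, (exp - 1) - (BF16_SIGNIFICAND_BITS - 1))


ulp_at_6 = bf16_ulp(6.0)   # 2**-5 = 0.03125
max_abs = 1.56e-2          # layer_input / hc_head max-abs from the table
print(ulp_at_6, max_abs / ulp_at_6)
```

At magnitude 6.0 the spacing is 2^-5 = 0.03125, so the reported 1.56e-2 max-abs (about 2^-6) is roughly half of one bf16 ULP at that magnitude.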
8828ca6 to
bf655f0
@@ -1,94 +1,46 @@
-"""Tilelang FP8 activation quantization used by the compressor and indexer QAT paths."""
+"""Per-token UE8M0 FP8 activation quantization — backed by deepseek-ai/TileKernels.
@Zhichenzzz Both UE8M0 and FP32 scales should be supported.
See SGLang's implementation:
def _dequant_fp8(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    from einops import rearrange

    assert (
        weight.dtype == torch.float8_e4m3fn
    ), f"expected fp8_e4m3fn, got {weight.dtype}"
    assert scale.dtype in (
        torch.float8_e8m0fnu,
        torch.float32,
    ), f"expected fp8_e8m0fnu or float32, got {scale.dtype}"
    if envs.SGLANG_DEBUG_SANITY_CHECK_CONFIG.get() and not is_large_dummy_model():
        assert weight.shape == (8192, 4096), f"unexpected weight shape {weight.shape}"
        assert scale.shape == (64, 32), f"unexpected scale shape {scale.shape}"
    # Split the weight into 128x128 tiles and broadcast one scale per tile.
    weight_f32 = rearrange(
        weight.float(), "(sn bn) (sk bk) -> sn bn sk bk", bn=128, bk=128
    )
    result = rearrange(
        weight_f32 * scale.float()[:, None, :, None], "sn bn sk bk -> (sn bn) (sk bk)"
    )
    if envs.SGLANG_DEBUG_SANITY_CHECK_CONFIG.get() and not is_large_dummy_model():
        assert result.shape == (8192, 4096)
    return result.to(torch.bfloat16)
@Zhichenzzz Note that DeepSeek uses fp32 scales by default:
https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro/blob/main/inference/kernel.py#L105
Owner
@yiakwy-xpu-ml-framework-team Thank you for your comment! This is a good catch. We will align it with SGLang's actual precision, depending on the runtime kernel choice.
No description provided.