DeepSeek-v4 Integrate TileKernels #1

Merged
yueming-yuan merged 5 commits into
yueming-yuan:deepseek-v4 from
radixark:dpskv4/tilekernels-integration
May 15, 2026

@Zhichenzzz

No description provided.

Zhichenzzz and others added 4 commits April 25, 2026 05:15
…epseek-v4 @ 8e1ef3c)

Verified single-node 8xH200 training of DeepSeek-V4-Flash-FP8-4layer
(Pinaster/...) end-to-end through 12 GRPO rollouts (steps 0-11). Loss is
0 at every step because the 4-layer prune produces gibberish and hence
reward=0, but the pipeline ran cleanly. Job raysubmit_WTh5sDDAaRcVrfTT.

Required environment additions on top of the upstream Dockerfile:
- sglang from sgl-project/sglang@deepseek_v4 (pip install -e python pulled
  sgl-kernel==0.3.21 and downgraded fastapi to 0.115.x)
- flashinfer-jit-cache==0.6.8+cu129 (matching flashinfer-python 0.6.8)
- tilelang==0.1.8 (PyPI release; 0.1.9 dropped wg_wait from T.gemm public API)
- flash_mla 1.0.0+71c7379 from deepseek-ai/FlashMLA (CUDA-built)
- fast_hadamard_transform 1.1.0 from Dao-AILab/fast-hadamard-transform
- Megatron-LM at radixark/Megatron-LM PR #28 (mlm-pr28, commit 8455dbf,
  in both /workspace/Megatron-LM and /root/Megatron-LM editable install)

Run-time tweaks:
- bf16 ckpt config.json: model_type "deepseek_v4" -> "deepseek_v3" (to
  satisfy AutoConfig.from_pretrained on transformers 4.57.1)
- bf16 ckpt config.json: drop quantization_config (avoid sglang creating
  fp8 model params when our weights are bf16)
- SGLANG_APPLY_CONFIG_BACKUP=none (otherwise sglang substitutes the
  packaged 43-layer config and breaks miles' hf_validate_args)
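
The two config.json tweaks above can be scripted. The following is a minimal sketch, not the procedure actually used for the run: the helper name `patch_bf16_config` and the sample config contents are hypothetical, and it only covers the JSON edits (the `SGLANG_APPLY_CONFIG_BACKUP=none` setting is an environment variable and is not handled here).

```python
import json
import tempfile
from pathlib import Path


def patch_bf16_config(config_path: Path) -> dict:
    """Apply the two config.json tweaks in place and return the new config."""
    config = json.loads(config_path.read_text())
    # Tweak 1: report the V3 model type so the config loader accepts it.
    if config.get("model_type") == "deepseek_v4":
        config["model_type"] = "deepseek_v3"
    # Tweak 2: drop the quantization block, since the weights on disk are bf16.
    config.pop("quantization_config", None)
    config_path.write_text(json.dumps(config, indent=2) + "\n")
    return config


# Hypothetical stand-in for a checkpoint's config.json.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "config.json"
    path.write_text(json.dumps({
        "model_type": "deepseek_v4",
        "quantization_config": {"quant_method": "fp8"},
        "num_hidden_layers": 4,
    }))
    patched = patch_bf16_config(path)
    reloaded = json.loads(path.read_text())
```

All other keys are left untouched, so the patched file differs from the original only in the two fields named in the list.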

dsv4_flash_to_bf16.py was written for the cluster's MXFP4 4-layer prune
(DeepSeek-V4-Flash-4layer); the actual 12-step run used the standard FP8
prune (Pinaster/DeepSeek-V4-Flash-FP8-4layer) so the tool was not
exercised. Keeping it for the MXFP4 path.

Drop-in replacement of the in-tree no-grad ``hc_split_sinkhorn`` /
``hc_pre_raw`` / ``hc_post_raw`` / ``hc_head_raw`` paths with calls
into ``tile_kernels.modeling.mhc.ops`` (sinkhorn, pre_norm_fn,
pre_split_mixes, pre_apply_mix, post, head_compute_mix, plus the
fused inference ``pre_big_fuse`` for ``no_grad`` paths).

Public class API (``DeepSeekV4HyperConnectionUtil``, ``HCHeadParams``)
unchanged so the radixark/Megatron-LM PR #28 call sites in
``transformer_layer.py`` / ``transformer_block.py`` keep working.

Net effect:
- forward path now matches the canonical TileKernels kernels (which
  fuse RMS-norm + GEMM split-K, sinkhorn, and the pre-apply mix)
- backward path is enabled (the original code asserted
  ``_HYPER_CONNECTION_MIXER_NO_GRAD = True`` and wrapped everything
  in ``torch.no_grad``)
- ``post = 2 * sigmoid(...)`` reproduced via ``post_mult_value=2.0``;
  PR's single ``hc_eps`` reused for both ``pre_eps`` and
  ``sinkhorn_eps``
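
For readers unfamiliar with the two ingredients named above, here is a plain-Python sketch of a generic Sinkhorn normalization and of the ``post = 2 * sigmoid(...)`` gate. It is an illustration only: the iteration count, the placement of ``eps``, and the function names are assumptions, and it does not reproduce the TileKernels kernels.

```python
import math


def sinkhorn_normalize(logits, n_iters=20, eps=1e-6):
    """Alternately normalise rows and columns of exp(logits) (generic Sinkhorn)."""
    n = len(logits)
    # Subtract the global max before exponentiating for numerical stability.
    peak = max(max(row) for row in logits)
    mat = [[math.exp(v - peak) for v in row] for row in logits]
    for _ in range(n_iters):
        # Row step: every row sums to roughly 1.
        mat = [[v / (sum(row) + eps) for v in row] for row in mat]
        # Column step: every column sums to roughly 1.
        col_sums = [sum(mat[i][j] for i in range(n)) for j in range(n)]
        mat = [[mat[i][j] / (col_sums[j] + eps) for j in range(n)]
               for i in range(n)]
    return mat


def post_gate(x, post_mult_value=2.0):
    """post = post_mult_value * sigmoid(x); 2.0 gives the 2 * sigmoid form."""
    return post_mult_value / (1.0 + math.exp(-x))


# Hypothetical 4x4 mixing logits.
logits = [
    [0.1, 0.5, -0.3, 0.2],
    [-0.4, 0.0, 0.3, 0.1],
    [0.2, -0.2, 0.4, -0.5],
    [0.0, 0.3, -0.1, 0.5],
]
comb = sinkhorn_normalize(logits)
row_sums = [sum(row) for row in comb]
col_sums = [sum(comb[i][j] for i in range(4)) for j in range(4)]
```

After enough iterations the matrix is approximately doubly stochastic, and a single ``eps`` serves both normalization steps, mirroring how one ``hc_eps`` is reused here.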

Tweaks vs upstream TileKernels' high-level wrappers:
- inline ``mhc_pre`` body so we can pass ``fuse_grad_acc=False`` to
  ``mhc_pre_norm_fn``. The default (``True``) requires ``mhc_post``
  to have written ``grad_from_mhc_post`` onto the same residual
  storage during backward, but Megatron's call sites use independent
  ``s b hc d`` -> ``b s hc d`` einops.rearrange'd tensors for
  ``layer_pre`` and ``layer_post`` so the storage objects don't match.
- inline ``mhc_head`` body to ``.contiguous()`` the
  ``mixes[..., :mhc_mult]`` slice before feeding it to
  ``mhc_head_compute_mix_fwd_kernel`` (the kernel asserts
  ``strides[0] == mhc_mult`` but the slice keeps the
  ``mhc_mult * (mhc_mult + 2)`` parent stride).
- ``.contiguous()`` everywhere we hand bf16 tensors to TileLang
  kernels so the no-grad fused ``pre_big_fuse`` path doesn't trip
  the ``view`` stride check.
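
The stride mismatch behind the ``mhc_head`` tweak can be reproduced with plain arithmetic. This sketch assumes a 2D row-major ``mixes`` tensor and uses hypothetical sizes (``mhc_mult = 4``, 8 tokens); it models strides by hand rather than calling any tensor library.

```python
def row_major_strides(shape):
    """Element strides of a contiguous row-major tensor with this shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def slice_last_dim(shape, strides, width):
    """Shape and strides of x[..., :width]; a view keeps the parent's strides."""
    return (*shape[:-1], width), strides


mhc_mult = 4      # hypothetical value, for illustration only
num_tokens = 8    # hypothetical value, for illustration only

parent_shape = (num_tokens, mhc_mult * (mhc_mult + 2))
parent_strides = row_major_strides(parent_shape)

# The mixes[..., :mhc_mult] slice: same storage, same strides as the parent.
view_shape, view_strides = slice_last_dim(parent_shape, parent_strides, mhc_mult)

# What a .contiguous() copy of that slice would have.
contiguous_strides = row_major_strides(view_shape)
```

With these sizes the slice has a leading stride of 24 while a contiguous copy has 4, which is the difference a ``strides[0] == mhc_mult`` assertion would catch.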

Verified end-to-end: 3-rollout GRPO smoke test on
DeepSeek-V4-Flash-FP8-4layer (single-node 8xH200, GRPO with rollout
batch 8, 4 samples/prompt, 64 tok cap) reaches steps 0..2 with
entropy_loss / logprob_abs_diff in the same band as the original
in-tree implementation (rl-smoke-pass tag) — see job
raysubmit_B3DgGam9vCYuDAVU.

- ``kernel/act_quant.py`` is now a thin wrapper around
  ``tile_kernels.quant.per_token_cast(fmt='e4m3', round_sf=True)``;
  shape/dtype contract preserved (``(y_fp8, s_fp32)`` with
  ``s.shape == (*x.shape[:-1], N // block_size)``) so callers in
  ``qat.py`` / ``compressor.py`` / ``v4_indexer.py`` don't change.
- ``kernel/sinkhorn.py`` is removed: it was only consumed by the
  legacy ``hyper_connection.py`` path which now calls
  ``tile_kernels.modeling.mhc.ops.sinkhorn_normalize`` instead.
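
The shape contract for the scales, one value per ``block_size`` elements of the last dimension, can be sketched on a single row in plain Python. This is an illustration under assumptions, not the ``per_token_cast`` implementation: rounding the scale up to a power of two is only a guess at what ``round_sf=True`` means, the zero-block floor is made up, and ``block_scales`` is a hypothetical name.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def block_scales(row, block_size, round_sf=True):
    """Return one scale per block of `block_size` elements in `row`."""
    assert len(row) % block_size == 0
    scales = []
    for start in range(0, len(row), block_size):
        amax = max(abs(v) for v in row[start:start + block_size])
        # Assumed floor so an all-zero block does not produce a zero scale.
        scale = max(amax, 1e-4) / FP8_E4M3_MAX
        if round_sf:
            # Assumption: round up to a power of two so scaled values still fit.
            scale = 2.0 ** math.ceil(math.log2(scale))
        scales.append(scale)
    return scales


# Hypothetical row with N = 8 and block_size = 4, giving N // block_size = 2 scales.
row = [1.0, -2.0, 0.5, 448.0, 0.1, 0.2, -0.3, 0.05]
scales = block_scales(row, block_size=4)
scaled = [v / scales[i // 4] for i, v in enumerate(row)]
```

A power-of-two scale has no mantissa bits to store, which is what makes an exponent-only format such as UE8M0 sufficient for it.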

End-to-end audit of ``miles_plugins/models/deepseek_v4/ops/kernel/``:
- act_quant.py             -> tile_kernels.quant.per_token_cast (this commit)
- sinkhorn.py              -> tile_kernels.modeling.mhc.ops (Batch 1, removed)
- tilelang_indexer*.py     -> no TileKernels equivalent (DSV4 DSA-specific)
- tilelang_sparse_mla*.py  -> no TileKernels equivalent (DSV4 sparse-MLA)

So every ``kernel/`` file that has a TileKernels analogue now routes
through TileKernels; the remaining files are V4-specific and have no
upstream replacement.

Verified end-to-end with the 12-rollout GRPO smoke harness on
DeepSeek-V4-Flash-FP8-4layer (single-node 8xH200, run id
260425-054701-579 -> raysubmit_a4YH6wKJ963VrTg9). All 12 steps
(0..11) completed cleanly; entropy_loss / logprob_abs_diff land in
the same band as the rl-smoke-pass baseline (12-step run on the
unmodified PR @ a72ed84):

  step | baseline ent | TK ent
  -----+--------------+--------
   0   | 1.6858       | 1.7286
   1   | 1.7854       | 1.7321
   2   | 1.7360       | 1.7803
   3   | 1.7696       | 1.6750
   4   | 1.6815       | 1.7532
   5   | 1.7172       | 1.6865
   6   | 1.7394       | 1.6949
   7   | 1.7634       | 1.7217
   8   | 1.6674       | 1.6871
   9   | 1.7330       | 1.6943
  10   | 1.6319       | 1.6998
  11   | 1.6962       | 1.7321

- ops/qat.py:fp8_simulate now uses tile_kernels.quant.per_token_cast_back
  for the FP8→BF16 dequant step (was a manual unflatten/multiply).
- tools/fp8_cast_bf16.py:weight_dequant replaces the in-tree Triton kernel
  with tile_kernels.quant.cast_back (128x128 block FP8 dequant).

Both substitutions are bit-exact against the prior implementations on
GPU (max_diff = 0.0 across 2D/3D/4D inputs and 256x768 fp8 weights).
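
The block-wise dequant and the ``max_diff = 0.0`` style of check can be illustrated at toy scale. This sketch uses nested lists and a 2x2 block instead of 128x128; both function names are hypothetical and neither is a TileKernels or Triton API.

```python
def block_dequant(weight, scale, block):
    """Multiply each element by the scale of the block it falls in."""
    return [
        [weight[r][c] * scale[r // block][c // block]
         for c in range(len(weight[0]))]
        for r in range(len(weight))
    ]


def block_dequant_blockwise(weight, scale, block):
    """Same result, written as an explicit loop over blocks."""
    out = [[0.0] * len(weight[0]) for _ in weight]
    for br, scale_row in enumerate(scale):
        for bc, s in enumerate(scale_row):
            for r in range(br * block, (br + 1) * block):
                for c in range(bc * block, (bc + 1) * block):
                    out[r][c] = weight[r][c] * s
    return out


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two 2D lists."""
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


# Hypothetical 4x4 weight with one scale per 2x2 block.
weight = [
    [1.0, -2.0, 3.0, 0.5],
    [0.25, 4.0, -1.5, 2.0],
    [-3.0, 1.0, 0.75, -0.5],
    [2.0, -4.0, 1.25, 3.5],
]
scale = [[0.5, 2.0], [1.0, 0.25]]

reference = block_dequant(weight, scale, block=2)
candidate = block_dequant_blockwise(weight, scale, block=2)
max_diff = max_abs_diff(reference, candidate)
```

Because both versions perform the same single multiplication per element, the comparison is exact rather than merely within a tolerance.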
@Zhichenzzz
Author

@yueming-yuan This PR merges TileKernels (MHC and quant) into radixark#1045. Thanks for reviewing!

…cision

TileKernels' MHC fp32 GEMM uses TF32 tensor cores on H100/H200 (H100+
fp32 GEMM has no full-precision tensor-core path; TF32 is the fastest
fp32 mode).  PyTorch's default ``torch.backends.cuda.matmul.allow_tf32 =
False`` forces fp32 F.linear onto the SIMT path, which introduced a
~1e-4 mean-abs gap vs the TileKernels-backed HC mixer.

With allow_tf32 = True set at deepseek_v4 plugin import time, parity is:

  HC parity (TileKernels MHC fwd vs legacy in-tree fwd, no-grad):
                       mean_abs    max_abs    notes
    layer_input        1.05e-5     1.56e-2    bf16 LSB output
    pre.post           1.52e-5     7.65e-5
    pre.comb           4.61e-6     2.87e-5
    hc_post out        1.15e-8     9.77e-4    fp32-equivalent
    hc_head out        1.14e-5     1.56e-2    bf16 LSB output

All ops are at roughly 1.5e-5 mean-abs or below (in line with the
attention/indexer 1e-5 bar).
The max-abs values are 1 ULP of bf16 at the output magnitude (~6.0,
1 ULP = 2^-5 = 3.13e-2).  Other fp32 matmuls in the plugin (compressor,
indexer projections) get a free TF32 speed-up as a side effect.
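
The ULP arithmetic quoted above can be checked with a short stdlib sketch. It assumes normal (non-subnormal) values and takes bf16 as having 7 explicit mantissa bits and TF32 as having 10; the helper name ``ulp`` is illustrative. Note that the tabulated max-abs of 1.56e-2 is approximately 2^-6, which is one bf16 ULP for magnitudes in [2, 4) or half a ULP at ~6.0; the commit message does not say which reading is meant.

```python
import math

BF16_MANTISSA_BITS = 7   # explicit mantissa bits in bfloat16
TF32_MANTISSA_BITS = 10  # explicit mantissa bits in TF32


def ulp(x, explicit_mantissa_bits):
    """Spacing between adjacent representable values around a normal x."""
    exponent = math.floor(math.log2(abs(x)))
    return 2.0 ** (exponent - explicit_mantissa_bits)


bf16_ulp_at_6 = ulp(6.0, BF16_MANTISSA_BITS)   # magnitudes in [4, 8)
bf16_ulp_at_3 = ulp(3.0, BF16_MANTISSA_BITS)   # magnitudes in [2, 4)
tf32_ulp_at_6 = ulp(6.0, TF32_MANTISSA_BITS)
```

The same helper shows why TF32 inputs are much finer-grained than bf16 outputs at the same magnitude: three extra mantissa bits make the spacing 8 times smaller.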
@Zhichenzzz Zhichenzzz force-pushed the dpskv4/tilekernels-integration branch from 8828ca6 to bf655f0 Compare April 29, 2026 20:50
@@ -1,94 +1,46 @@
"""Tilelang FP8 activation quantization used by the compressor and indexer QAT paths."""
"""Per-token UE8M0 FP8 activation quantization — backed by deepseek-ai/TileKernels.

@Zhichenzzz Both UE8M0 and FP32 scales should be supported.

See SGLang's implementation:

import torch

# Quoted from sglang. `envs` and `is_large_dummy_model` are sglang-internal
# helpers that are not shown in this excerpt.
def _dequant_fp8(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    from einops import rearrange

    assert (
        weight.dtype == torch.float8_e4m3fn
    ), f"expected fp8_e4m3fn, got {weight.dtype}"
    # Block scales may be stored as UE8M0 (float8_e8m0fnu) or plain FP32.
    assert scale.dtype in (
        torch.float8_e8m0fnu,
        torch.float32,
    ), f"expected fp8_e8m0fnu or float32, got {scale.dtype}"
    if envs.SGLANG_DEBUG_SANITY_CHECK_CONFIG.get() and not is_large_dummy_model():
        assert weight.shape == (8192, 4096), f"unexpected weight shape {weight.shape}"
        assert scale.shape == (64, 32), f"unexpected scale shape {scale.shape}"

    # Split the weight into 128x128 blocks so each block lines up with one scale.
    weight_f32 = rearrange(
        weight.float(), "(sn bn) (sk bk) -> sn bn sk bk", bn=128, bk=128
    )
    # Broadcast one scale per block, then merge the blocks back into 2D.
    result = rearrange(
        weight_f32 * scale.float()[:, None, :, None], "sn bn sk bk -> (sn bn) (sk bk)"
    )
    if envs.SGLANG_DEBUG_SANITY_CHECK_CONFIG.get() and not is_large_dummy_model():
        assert result.shape == (8192, 4096)

    return result.to(torch.bfloat16)

cc @yueming-yuan

Owner


@yiakwy-xpu-ml-framework-team Thank you for your comment! This is a good catch. We will make it aligned with SGLang's actual precision, depending on the runtime kernel choice.
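
As background for the UE8M0-versus-FP32 point in this thread, the sketch below shows how a scale stored in either format maps to an ordinary float. It assumes the OCP Microscaling E8M0 encoding (8 exponent bits, bias 127, 0xFF reserved for NaN); the function names and the ``fmt`` strings are hypothetical and not part of SGLang or TileKernels.

```python
import math

E8M0_BIAS = 127
E8M0_NAN = 0xFF


def decode_e8m0(byte):
    """Decode one E8M0 byte: an unsigned power-of-two exponent, no mantissa."""
    if not 0 <= byte <= 0xFF:
        raise ValueError(f"not a byte: {byte}")
    if byte == E8M0_NAN:
        return math.nan
    return 2.0 ** (byte - E8M0_BIAS)


def scale_to_float(scale, fmt):
    """Return the multiplier for a block scale stored as 'ue8m0' or 'fp32'."""
    if fmt == "ue8m0":
        return decode_e8m0(scale)
    if fmt == "fp32":
        return float(scale)
    raise ValueError(f"unsupported scale format: {fmt}")


unit_scale = scale_to_float(127, "ue8m0")     # exponent 0
small_scale = scale_to_float(117, "ue8m0")    # exponent -10
fp32_scale = scale_to_float(0.0009765625, "fp32")
```

A power-of-two FP32 scale and its UE8M0 encoding decode to the same multiplier, whereas an arbitrary FP32 scale has no exact UE8M0 equivalent, so the two cases need separate handling.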

@yueming-yuan yueming-yuan merged commit 840a33c into yueming-yuan:deepseek-v4 May 15, 2026
@yueming-yuan yueming-yuan deleted the dpskv4/tilekernels-integration branch May 15, 2026 23:21
3 participants