DeepSeek V4 RL support #1045

Open
yueming-yuan wants to merge 50 commits into radixark:main from yueming-yuan:deepseek-v4

@yueming-yuan
Collaborator

@yueming-yuan yueming-yuan commented Apr 24, 2026

Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements comprehensive support for the DeepSeek-V4 model, introducing specialized attention and indexer layers via TileLang kernels, a transformers patching utility, and atomic bucketing for SGLang weight updates. Feedback identifies several critical improvements: the router's proxy response should preserve upstream headers to maintain features like compression, and the attn_sink parameter must be zero-initialized to ensure training stability. Reviewers also noted an inconsistency in padding size calculations between modules and recommended using 64-bit integers for token counts to prevent overflow. Furthermore, the weight update logic should be hardened against orphan groups to avoid SGLang assertion errors, and redundant top-k calculations in the indexer should be eliminated for better efficiency.

Comment thread miles/router/router.py
Comment on lines +180 to +181
content_type: str = headers.get("content-type", "application/json")
return Response(content=content, status_code=status_code, media_type=content_type)

high

The _build_proxy_response method is currently dropping all upstream headers except for content-type. This will break critical functionality such as compression (Content-Encoding), caching, and session management (Set-Cookie). While the goal is to avoid re-serialization issues with Content-Length, you should still pass the original headers after removing the content-length entry, allowing the Response object to calculate it correctly for the provided body.

Suggested change
- content_type: str = headers.get("content-type", "application/json")
- return Response(content=content, status_code=status_code, media_type=content_type)
+ content_type: str = headers.get("content-type", "application/json")
+ headers.pop("content-length", None)
+ return Response(content=content, status_code=status_code, headers=headers, media_type=content_type)
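
A minimal sketch of the header handling this comment asks for, written against plain dictionaries so it runs without a web framework. The function name `filter_proxy_headers` and the set of dropped headers are assumptions for illustration, not code from this PR. One caveat: if the proxy has already decompressed the upstream body, forwarding `Content-Encoding` unchanged would mislabel the payload, so the sketch exposes that case as a flag.

```python
from typing import Mapping

# Headers that describe the upstream transfer rather than the payload itself.
# This set is an illustrative assumption, not taken from miles/router/router.py.
_DROP_ALWAYS = {"content-length", "transfer-encoding", "connection"}


def filter_proxy_headers(
    upstream: Mapping[str, str], body_was_decoded: bool = False
) -> dict[str, str]:
    """Return upstream headers that are safe to replay on a re-built response."""
    drop = set(_DROP_ALWAYS)
    if body_was_decoded:
        # The body no longer matches the advertised encoding.
        drop.add("content-encoding")
    return {k: v for k, v in upstream.items() if k.lower() not in drop}


upstream_headers = {
    "Content-Type": "application/json",
    "Content-Length": "1234",
    "Set-Cookie": "session=abc",
    "Content-Encoding": "gzip",
}
forwarded = filter_proxy_headers(upstream_headers)
```

The response object can then recompute `Content-Length` from the body it is actually given.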

config_no_sp = copy.copy(config)
config_no_sp.sequence_parallel = False

self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))

high

The attn_sink parameter is initialized using torch.empty, which leaves it with uninitialized (garbage) values. Since this is a learnable parameter that affects the softmax denominator, it should be initialized to a consistent starting value, such as zero, to ensure stable training.

Suggested change
- self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32))
+ self.attn_sink = nn.Parameter(torch.zeros(self.n_local_heads, dtype=torch.float32))
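
To see why the initial value matters, here is a small pure-Python sketch of a softmax with a sink logit in the denominator. It assumes the common attention-sink formulation in which the sink contributes `exp(sink)` to the normalizer but has no value vector; whether the DeepSeek-V4 layer in this PR uses exactly this form is an assumption.

```python
import math


def softmax_with_sink(scores: list[float], sink: float) -> list[float]:
    """Softmax over `scores` with an extra sink logit in the denominator only."""
    m = max(max(scores), sink)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps) + math.exp(sink - m)
    return [e / denom for e in exps]


scores = [1.0, 2.0, 3.0]
zero_init = softmax_with_sink(scores, sink=0.0)      # deterministic starting point
garbage_init = softmax_with_sink(scores, sink=50.0)  # stand-in for an uninitialized value
```

With a zero sink the attention weights keep most of their mass, while a large garbage value drains nearly all of it into the sink.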

pad = (pad_size - replay_data.size(0) % pad_size) % pad_size
if pad != 0:
replay_data = pad_func(replay_data, pad)
pad_size = self.parallel_state.dp_size * self.args.data_pad_size_multiplier

medium

There is an inconsistency in the pad_size calculation between actor.py and data.py. In miles/backends/training_utils/data.py, pad_size is calculated using parallel_state.tp_size, whereas here it uses parallel_state.dp_size. This discrepancy can lead to misaligned data shapes during training if the tensor parallel size and data parallel size differ. Ensure that dimensions are derived consistently from the configuration or parallel state.

Suggested change
- pad_size = self.parallel_state.dp_size * self.args.data_pad_size_multiplier
+ pad_size = self.parallel_state.tp_size * self.args.data_pad_size_multiplier
References
  1. Avoid hardcoding model dimensions; derive them from configuration or input tensor shapes instead.
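
The padding arithmetic in the snippet under review can be checked in isolation. The sketch below uses made-up sizes (the `dp_size`, `tp_size`, and multiplier values are assumptions) to show how the two choices of `pad_size` yield different padded lengths.

```python
def padded_length(n: int, pad_size: int) -> int:
    """Round n up to the next multiple of pad_size, as in the reviewed snippet."""
    pad = (pad_size - n % pad_size) % pad_size
    return n + pad


# Hypothetical configuration, chosen only to make the mismatch visible.
dp_size, tp_size, multiplier = 4, 8, 16
n_samples = 130

len_dp = padded_length(n_samples, dp_size * multiplier)  # pad to a multiple of 64
len_tp = padded_length(n_samples, tp_size * multiplier)  # pad to a multiple of 128
```

The outer `% pad_size` keeps an already aligned length unchanged instead of adding a full extra block.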

Comment on lines +194 to +195
for items, items_size in pending_groups.values():
_commit_atomic(items, items_size)

medium

Processing orphan groups at the end of bucketing might lead to a crash in SGLang. The atomic bucketing logic is designed to ensure that fusion pairs (like wq_a and wkv) land in the same update bucket to satisfy SGLang's load_weights assertions. If only one half of a pair is present, committing it alone will likely trigger an AssertionError in SGLang. It is safer to raise an informative error if pending_groups is not empty here.
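
One way to follow this suggestion is to fail fast instead of committing leftovers. The sketch below is a simplified stand-in for the bucketing loop: the function name, the `pair_of` mapping, and the item representation are all hypothetical, and only the idea of raising on a non-empty `pending_groups` comes from the comment.

```python
def bucket_atomic(names: list[str], pair_of: dict[str, str]) -> list[list[str]]:
    """Group weight names so that both halves of a fusion pair share a bucket.

    `pair_of` maps each paired name to its partner (hypothetical structure).
    Unpaired names get a bucket of their own.
    """
    buckets: list[list[str]] = []
    pending_groups: dict[str, str] = {}  # awaited partner -> name seen first

    for name in names:
        if name in pending_groups:
            # Partner arrived: commit both halves together.
            buckets.append([pending_groups.pop(name), name])
        elif name in pair_of:
            pending_groups[pair_of[name]] = name
        else:
            buckets.append([name])

    if pending_groups:
        # Raise instead of committing half of a pair on its own.
        orphans = sorted(pending_groups.values())
        raise ValueError(f"fusion pairs with a missing half: {orphans}")
    return buckets


pairs = {"layer0.wq_a": "layer0.wkv", "layer0.wkv": "layer0.wq_a"}
ok = bucket_atomic(["layer0.wq_a", "layer0.norm", "layer0.wkv"], pairs)
```

Raising here surfaces the problem on the training side with the offending names, rather than as an assertion inside the inference engine.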

return (
loss,
torch.tensor(num_tokens if args.calculate_per_token_loss else 1, device=logits.device),
torch.tensor(num_tokens if args.calculate_per_token_loss else 1, dtype=torch.int, device=logits.device),

medium

Using torch.int (32-bit) for token counts can lead to overflow in large-scale training scenarios when these values are reduced across many ranks. It is safer to use torch.long (64-bit) to ensure correctness for high token counts.

Suggested change
- torch.tensor(num_tokens if args.calculate_per_token_loss else 1, dtype=torch.int, device=logits.device),
+ torch.tensor(num_tokens if args.calculate_per_token_loss else 1, dtype=torch.long, device=logits.device),
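
A rough illustration of the overflow concern, using `ctypes` to emulate 32-bit wraparound in plain Python. The rank count and per-rank token count are made-up numbers, not figures from this PR.

```python
import ctypes

INT32_MAX = 2**31 - 1

# Hypothetical scale: 1024 ranks, each reporting 4M tokens in a step.
ranks = 1024
tokens_per_rank = 4 * 1024 * 1024

true_total = ranks * tokens_per_rank        # Python ints never overflow
wrapped = ctypes.c_int32(true_total).value  # what a 32-bit accumulator would hold
fits_in_int64 = true_total <= 2**63 - 1
```

In this example the true total is 2^32, which wraps to 0 in a 32-bit integer but fits comfortably in 64 bits.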

index_scores = batched_indexer_fwd(q, k, weights.float(), cu_ks, cu_ke)

topk_k = min(self.index_topk, index_scores.size(-1))
topk_indices = index_scores.topk(topk_k, dim=-1)[1]

medium

This manual topk calculation is redundant because the topk_fn (which handles both normal and replay modes) is called immediately after on line 147. Removing this line will improve efficiency by avoiding an unnecessary GPU kernel launch. Ensure that all indexing parameters are retrieved from the model configuration.

References
  1. Model parameters, such as index_topk, should be retrieved from the model configuration rather than being hardcoded.

Collaborator

@guapisolo guapisolo left a comment

GOAT!!!!!!

@ByronHsu

GOAT

yueming-yuan and others added 7 commits April 24, 2026 14:12
Avoids _prepare_cp Ray dependency when only --model-dir is overridden:
both fields now resolve to the same path unless user explicitly fans out
to per-node local NVMe via --model-local-dir.
Pinaster/DeepSeek-V4-Flash-FP8-4layer ships with model_type=deepseek_v4 in
config.json, but SGLang's get_config fallback (with SGLANG_APPLY_CONFIG_BACKUP=none)
only fires on deepseek_ref. Rewrite the local config.json in-place so SGLang's
_load_deepseek_temp_model gets reached. Idempotent; no-op for non-4-layer models.
…epseek-v4 @ 8e1ef3c)

Verified single-node 8xH200 training of DeepSeek-V4-Flash-FP8-4layer
(Pinaster/...) end-to-end through 12 GRPO rollouts (steps 0-11, all
loss=0 since 4-layer prune produces gibberish + reward=0, but pipeline
ran cleanly). Job raysubmit_WTh5sDDAaRcVrfTT.

Required environment additions on top of the upstream Dockerfile:
- sglang from sgl-project/sglang@deepseek_v4 (pip install -e python pulled
  sgl-kernel==0.3.21 and downgraded fastapi to 0.115.x)
- flashinfer-jit-cache==0.6.8+cu129 (matching flashinfer-python 0.6.8)
- tilelang==0.1.8 (PyPI release; 0.1.9 dropped wg_wait from T.gemm public API)
- flash_mla 1.0.0+71c7379 from deepseek-ai/FlashMLA (CUDA-built)
- fast_hadamard_transform 1.1.0 from Dao-AILab/fast-hadamard-transform
- Megatron-LM at radixark/Megatron-LM PR radixark#28 (mlm-pr28, commit 8455dbf,
  in both /workspace/Megatron-LM and /root/Megatron-LM editable install)

Run-time tweaks:
- bf16 ckpt config.json: model_type "deepseek_v4" -> "deepseek_v3" (to
  satisfy AutoConfig.from_pretrained on transformers 4.57.1)
- bf16 ckpt config.json: drop quantization_config (avoid sglang creating
  fp8 model params when our weights are bf16)
- SGLANG_APPLY_CONFIG_BACKUP=none (otherwise sglang substitutes the
  packaged 43-layer config and breaks miles' hf_validate_args)

dsv4_flash_to_bf16.py was written for the cluster's MXFP4 4-layer prune
(DeepSeek-V4-Flash-4layer); the actual 12-step run used the standard FP8
prune (Pinaster/DeepSeek-V4-Flash-FP8-4layer) so the tool was not
exercised. Keeping it for the MXFP4 path.
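
The two `config.json` edits listed under "Run-time tweaks" above could be scripted roughly as follows. This is a sketch under assumptions: the helper name `patch_bf16_config` is hypothetical, and the commit message does not say how the edit was actually performed.

```python
import json
from pathlib import Path


def patch_bf16_config(config_path: Path) -> bool:
    """Apply the two run-time tweaks in place; return True if the file changed."""
    config = json.loads(config_path.read_text())
    changed = False

    if config.get("model_type") == "deepseek_v4":
        config["model_type"] = "deepseek_v3"
        changed = True
    if "quantization_config" in config:
        del config["quantization_config"]
        changed = True

    if changed:
        config_path.write_text(json.dumps(config, indent=2) + "\n")
    return changed
```

Because the function only writes when something changed, running it twice on the same checkpoint directory is harmless.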
@hebiao064

could you please share the kl and reward dashboard if possible? thanks

@Ying1123 Ying1123 mentioned this pull request Apr 27, 2026
11 tasks
Zhichenzzz and others added 3 commits April 28, 2026 05:48
Drop-in replacement of the in-tree no-grad ``hc_split_sinkhorn`` /
``hc_pre_raw`` / ``hc_post_raw`` / ``hc_head_raw`` paths with calls
into ``tile_kernels.modeling.mhc.ops`` (sinkhorn, pre_norm_fn,
pre_split_mixes, pre_apply_mix, post, head_compute_mix, plus the
fused inference ``pre_big_fuse`` for ``no_grad`` paths).

Public class API (``DeepSeekV4HyperConnectionUtil``, ``HCHeadParams``)
unchanged so the radixark/Megatron-LM PR radixark#28 call sites in
``transformer_layer.py`` / ``transformer_block.py`` keep working.

Net effect:
- forward path now matches the canonical TileKernels kernels (which
  fuse RMS-norm + GEMM split-K, sinkhorn, and the pre-apply mix)
- backward path is enabled (the original code asserted
  ``_HYPER_CONNECTION_MIXER_NO_GRAD = True`` and wrapped everything
  in ``torch.no_grad``)
- ``post = 2 * sigmoid(...)`` reproduced via ``post_mult_value=2.0``;
  PR's single ``hc_eps`` reused for both ``pre_eps`` and
  ``sinkhorn_eps``

Tweaks vs upstream TileKernels' high-level wrappers:
- inline ``mhc_pre`` body so we can pass ``fuse_grad_acc=False`` to
  ``mhc_pre_norm_fn``. The default (``True``) requires ``mhc_post``
  to have written ``grad_from_mhc_post`` onto the same residual
  storage during backward, but Megatron's call sites use independent
  ``s b hc d`` -> ``b s hc d`` einops.rearrange'd tensors for
  ``layer_pre`` and ``layer_post`` so the storage objects don't match.
- inline ``mhc_head`` body to ``.contiguous()`` the
  ``mixes[..., :mhc_mult]`` slice before feeding it to
  ``mhc_head_compute_mix_fwd_kernel`` (the kernel asserts
  ``strides[0] == mhc_mult`` but the slice keeps the
  ``mhc_mult * (mhc_mult + 2)`` parent stride).
- ``.contiguous()`` everywhere we hand bf16 tensors to TileLang
  kernels so the no-grad fused ``pre_big_fuse`` path doesn't trip
  the ``view`` stride check.

Verified end-to-end: 3-rollout GRPO smoke test on
DeepSeek-V4-Flash-FP8-4layer (single-node 8xH200, GRPO with rollout
batch 8, 4 samples/prompt, 64 tok cap) reaches steps 0..2 with
entropy_loss / logprob_abs_diff in the same band as the original
in-tree implementation (rl-smoke-pass tag) — see job
raysubmit_B3DgGam9vCYuDAVU.
- ``kernel/act_quant.py`` is now a thin wrapper around
  ``tile_kernels.quant.per_token_cast(fmt='e4m3', round_sf=True)``;
  shape/dtype contract preserved (``(y_fp8, s_fp32)`` with
  ``s.shape == (*x.shape[:-1], N // block_size)``) so callers in
  ``qat.py`` / ``compressor.py`` / ``v4_indexer.py`` don't change.
- ``kernel/sinkhorn.py`` is removed: it was only consumed by the
  legacy ``hyper_connection.py`` path which now calls
  ``tile_kernels.modeling.mhc.ops.sinkhorn_normalize`` instead.

End-to-end audit of ``miles_plugins/models/deepseek_v4/ops/kernel/``:
- act_quant.py             -> tile_kernels.quant.per_token_cast (this commit)
- sinkhorn.py              -> tile_kernels.modeling.mhc.ops (Batch 1, removed)
- tilelang_indexer*.py     -> no TileKernels equivalent (DSV4 DSA-specific)
- tilelang_sparse_mla*.py  -> no TileKernels equivalent (DSV4 sparse-MLA)

So every ``kernel/`` file that has a TileKernels analogue now routes
through TileKernels; the remaining files are V4-specific and have no
upstream replacement.

Verified end-to-end with the 12-rollout GRPO smoke harness on
DeepSeek-V4-Flash-FP8-4layer (single-node 8xH200, run id
260425-054701-579 -> raysubmit_a4YH6wKJ963VrTg9). All 12 steps
(0..11) completed cleanly; entropy_loss / logprob_abs_diff land in
the same band as the rl-smoke-pass baseline (12-step run on the
unmodified PR @ a72ed84):

  step | baseline ent | TK ent
  -----+--------------+--------
   0   | 1.6858       | 1.7286
   1   | 1.7854       | 1.7321
   2   | 1.7360       | 1.7803
   3   | 1.7696       | 1.6750
   4   | 1.6815       | 1.7532
   5   | 1.7172       | 1.6865
   6   | 1.7394       | 1.6949
   7   | 1.7634       | 1.7217
   8   | 1.6674       | 1.6871
   9   | 1.7330       | 1.6943
  10   | 1.6319       | 1.6998
  11   | 1.6962       | 1.7321
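
As a quick sanity check on "same band", the per-step entropies from the table above can be summarized with the standard library. The numbers are copied from the table; using the mean and the largest per-step gap as the summary is an assumption of this sketch, not the author's acceptance criterion.

```python
from statistics import mean

baseline = [1.6858, 1.7854, 1.7360, 1.7696, 1.6815, 1.7172,
            1.7394, 1.7634, 1.6674, 1.7330, 1.6319, 1.6962]
tk = [1.7286, 1.7321, 1.7803, 1.6750, 1.7532, 1.6865,
      1.6949, 1.7217, 1.6871, 1.6943, 1.6998, 1.7321]

# Difference between the two 12-step averages.
mean_gap = abs(mean(baseline) - mean(tk))
# Largest difference observed at any single step.
max_step_gap = max(abs(b - t) for b, t in zip(baseline, tk))
```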
- ops/qat.py:fp8_simulate now uses tile_kernels.quant.per_token_cast_back
  for the FP8→BF16 dequant step (was a manual unflatten/multiply).
- tools/fp8_cast_bf16.py:weight_dequant replaces the in-tree Triton kernel
  with tile_kernels.quant.cast_back (128x128 block FP8 dequant).

Both substitutions are bit-exact against the prior implementations on
GPU (max_diff = 0.0 across 2D/3D/4D inputs and 256x768 fp8 weights).
yueming-yuan and others added 2 commits April 28, 2026 17:48
…cision

TileKernels' MHC fp32 GEMM uses TF32 tensor cores on H100/H200 (H100+
fp32 GEMM has no full-precision tensor-core path; TF32 is the fastest
fp32 mode).  PyTorch's default ``torch.backends.cuda.matmul.allow_tf32 =
False`` forces fp32 F.linear onto the SIMT path, which introduced a
~1e-4 mean-abs gap vs the TileKernels-backed HC mixer.

Setting allow_tf32 = True at deepseek_v4 plugin import time:

  HC parity (TileKernels MHC fwd vs legacy in-tree fwd, no-grad):
                       mean_abs    max_abs    notes
    layer_input        1.05e-5     1.56e-2    bf16 LSB output
    pre.post           1.52e-5     7.65e-5
    pre.comb           4.61e-6     2.87e-5
    hc_post out        1.15e-8     9.77e-4    fp32-equivalent
    hc_head out        1.14e-5     1.56e-2    bf16 LSB output

All ops <= 1.5e-5 mean-abs (matches the attention/indexer 1e-5 bar).
The max-abs values are 1 ULP of bf16 at the output magnitude (~6.0,
1 ULP = 2^-5 = 3.13e-2).  Other fp32 matmuls in the plugin (compressor,
indexer projections) get a free TF32 speed-up as a side effect.
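
The ULP figure quoted above follows from bf16 having 8 bits of significand precision (7 stored mantissa bits plus the implicit leading one). A small worked computation, with the helper name `bf16_ulp` being hypothetical:

```python
import math


def bf16_ulp(x: float) -> float:
    """Spacing between adjacent bf16 values near a normal, nonzero x."""
    exponent = math.floor(math.log2(abs(x)))
    return 2.0 ** (exponent - 7)  # 7 explicit mantissa bits in bfloat16


ulp_at_6 = bf16_ulp(6.0)      # 2**-5
half_ulp_at_6 = ulp_at_6 / 2  # worst-case round-to-nearest error
```

Half an ULP at that magnitude is 2^-6, about 1.56e-2, which matches the max-abs value the table reports for the two bf16 outputs.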
Comment thread examples/train_infer_mismatch_helper/mis.py
yueming-yuan and others added 22 commits May 14, 2026 16:42
# Conflicts:
#	miles_plugins/models/deepseek_v4/ops/hyper_connection.py
#	miles_plugins/models/deepseek_v4/ops/qat.py
…ware-aware scale dtype

Replace the TileKernels per_token_cast wrapper with a verbatim port of
deepseek-ai/DeepSeek-V4-Pro/inference/kernel.py:act_quant so this code
path is bit-exact with the upstream inference kernel.

The ported act_quant exposes scale_dtype and inplace, matching official:
  - scale_dtype=None auto-selects via SGLang should_deepgemm_weight_requant_ue8m0
    (Blackwell + DeepGEMM JIT -> float8_e8m0fnu; Hopper -> float32).
  - explicit scale_dtype overrides the auto path either way.
  - inplace=True wires through to the fused quant+dequant kernel.

Drops the tile_kernels import for this file; the package was not listed
in any requirements manifest. Caller (qat.py:fp8_simulate) is unaffected:
on Hopper the auto path resolves to float32, preserving prior behavior.

Verified on H200:
  - 26 kernel cases bit-exact vs official act_quant (4 shapes x 5 modes + 3D).
  - 4-layer iterated fp8_simulate stable, ~2.25% mean rel noise as expected for E4M3.
  - Hardware auto-resolve matches sglang.deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0.
@Zhichenzzz
Contributor

could you please share the kl and reward dashboard if possible? thanks

Hi @hebiao064, we have an initial reasoning-task reward curve for reference!
[image]

kakisong added a commit to kakisong/miles that referenced this pull request May 18, 2026
…ripts

Brings up the full V4-Flash 284B SFT pipeline on an 8-node × 8-GPU H20
cluster, plus the layered config infrastructure to make it portable
across H20 / H200 / different cluster sizes.

Pipeline scripts (examples/deepseek_v4_sft/):
  - prepare_data.py / prepare_megatron_ckpt.sh / verify_chat_template.py
  - run_stage_{a_smoke,b0_dryrun,b_full}.sh   reference runners
  - tools/megablocks_to_hf_bf16.py            FP8/megablocks -> unpacked BF16 HF
                                              (the missing piece in PR radixark#1045's
                                              ckpt conversion)
  - tools/tilelang_*.py                       sparse-MLA repro / bisect /
                                              shape-sweep harnesses used to
                                              isolate the V4 NaN bug

Cluster scripts (examples/deepseek_v4_sft/cluster/):
  - bring_up_cluster.sh / tear_down.sh / convert_hf_to_megatron.sh
  - run_stage_a.sh                            1-node 4-layer NaN reproducer

  Layered config (so a new cluster only edits one file):
    env/base.env                              project path conventions
    env/cluster_h20_8node.env                 current H20 cluster
    env/cluster_h200_template.env             H200 example
    env.sh                                    backwards-compat shim
    hw/{h20,h200}.env                         GPU-specific defaults
    presets/{smoke,validation,cp_smoke,
             prod,long_context}.env           workload presets
    lib/preflight.sh                          shared pre-flight checks
    run.sh                                    unified launcher with CLI overrides
    BOOTSTRAP.md                              new-cluster setup checklist

Integration patches required by V4 on this image / mbridge / sglang stack:
  - miles/utils/arguments.py / mask_utils.py / transformers_patch.py
  - tools/fp8_cast_bf16.py

.gitignore picks up two `!` exceptions because the root rules ignore
`lib/` and `env/` as Python-build leftovers.
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,

Hello @yueming-yuan, nice to meet you. This is Peng. I’m cherry-picking your code for DeepSeek V4 training. Thank you for the implementation!

I ran into a NaN issue during the backward pass, and after debugging it appears to come from enabling TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE in the sparse MLA/MQA backward kernel. In our setup, aggressive shared-memory merge can alias/corrupt the region used by Q_shared, which matches the failure mode described in this TileLang PR: tile-ai/tilelang#2204.

Would you be okay with removing/disabling this pass config temporarily to unblock training until the upstream TileLang fix lands and we can use a version that includes it?

Thank you!

In reply to the comment above:

It works when TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE is set to False. Thanks.

kakisong added a commit to kakisong/miles that referenced this pull request May 20, 2026
kakisong added a commit to kakisong/miles that referenced this pull request May 20, 2026
@tang-t21

tang-t21 commented May 21, 2026

Hi @yueming-yuan and folks, thanks for putting together the DeepSeek V4 support here.

I hit the same tilelang sparse MLA backward NaN issue while validating Bridge TP8/EP8 e2e training, and opened a small stacked fix here:

yueming-yuan#4

I also noticed @pengdurice had already reported what looks like the same root cause in this thread: #1045 (comment). The behavior matches what we saw: forward loss stays finite, but backward produces NaN gradients when TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE is enabled in tilelang_sparse_mla_bwd.py. Disabling that pass makes the standalone backward test finite and lets the Bridge e2e run complete.

Validation summary from our side:

  • standalone test_sparse_mla_backward: 5 passed after adding a finite-gradient assertion
  • Bridge tilelang backend 3 iter: losses 0.9047399, 1.2337190, 0.5607164; NaN iterations 0

This is stacked on top of this PR branch, so the diff should just be the temporary pass-config removal plus the test guard. Thanks again!
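
For readers who want a similar guard, a finite-gradient check can be sketched as below. It uses plain Python floats in place of gradient tensors so that it runs anywhere; in a PyTorch test the equivalent check would typically be `torch.isfinite(grad).all()`. The helper name and the sample gradient values are made up for illustration and do not reproduce the test in the stacked fix.

```python
import math
from typing import Iterable, Mapping


def assert_finite_grads(grads: Mapping[str, Iterable[float]]) -> None:
    """Raise AssertionError naming every gradient that contains NaN or inf."""
    bad = [
        name
        for name, values in grads.items()
        if not all(math.isfinite(v) for v in values)
    ]
    assert not bad, f"non-finite gradients in: {bad}"


# Made-up values standing in for dq / dkv from a backward pass.
healthy = {"dq": [0.1, -0.2], "dkv": [0.0, 3.5]}
broken = {"dq": [0.1, float("nan")], "dkv": [0.0, 3.5]}

assert_finite_grads(healthy)  # passes silently
```

Checking gradients explicitly matters here because, as described above, the forward loss stays finite even when the backward pass produces NaNs.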
