[V4-Flash] miles framework patches + V4 plugin NaN fixes#4
Open
kakisong wants to merge 11 commits into
Open
Conversation
b3d7a2a to
50fed8b
Compare
…caused NaN grads) apply_rotary_emb mutated its input in-place via a view on the rotated half. When upstream callers (deepseek_v4.py, compressor.py, v4_indexer.py) wrap the function inside a custom autograd.Function, the in-place op silently corrupted the saved-for-backward tensor and produced NaN gradients deep into the V4 attention path. Switch to functional `torch.cat` and update callers to consume the return value instead of relying on the in-place side effect.
Two latent bugs in the torch reference implementation that surfaced when the tilelang sparse-MLA bwd kernel was found to NaN on real workloads and we needed sparse_attn_torch as a sanity baseline: 1. dtype lost — `q = q.float()` overwrites q before the trailing `o.to(q.dtype)` cast, so the reference always returned fp32. Save orig_dtype before .float() and cast back at the end. 2. fully-masked rows blew up — when a query has zero valid topk entries (early causal positions), `scores.max(...)` returns -inf and `scores - (-inf) = NaN -> exp(NaN) = NaN`. Clamp scores_max to -1e30 to match dense_attn_torch behavior. Verified on Stage A reproducer + Stage B0 64-GPU smoke against dense_attn_torch.
…NaN asserts
Root cause: bwd kernel decorator enabled tilelang's
`TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE` pass, which aliased
`acc_dkv_shared` (fp32) with `Q_shared`/`KV_shared`/`dQ_shared` (bf16) in
physical shared memory. The split_store atomic_addx4 to dKV then wrote
fp32 bytes that the next loop iteration read back as bf16 Q/KV, producing
NaN columns in dq and 100% NaN dkv on V4-Flash production shapes
(B=1 S=1280 H=64 D=512 topk=640).
The bug shipped from day one because the existing test's compute_diff
fell back to rel_diff=0.0 when denom was NaN (`NaN > 0` is False),
silently passing on completely-NaN gradients. Production was running on
dense_attn_torch (17s/iter) as a workaround.
Fixes:
- kernel/tilelang_sparse_mla_bwd.py: drop AGGRESSIVE_SHARED_MEMORY_MERGE.
Use a safe (in-bounds) KV index when mask=False to avoid raw -1
pointer arithmetic reading NaN-bearing memory.
- attention_core.py: fold attn_sink into the LSE shift in sparse/dense
torch refs so fully-masked rows do not produce 0*exp(1e30)=NaN in
d_attn_sink autograd.
- tests/test_v4_tilelang_sparse_mla.py: explicit NaN asserts in
compute_diff so future regressions cannot fake-pass.
CP-packed sequences can spill ~5% past the trained context due to dynamic-batch padding alignment. Observed on 64K shape: seqlen_local=8448 × CP=8 = 67584, exceeding the freqs_cis table size (65536) and triggering a view shape mismatch in get_freqs_cis_for_cp. YaRN scaling factor (4-16x) keeps this table mathematically valid out to 4-16x original_max_position_embeddings (~262K-1M), so 2x headroom is well within range and no other YaRN params need tuning.
Add `deepseek_v4` to --loss-mask-type and implement gen_multi_turn_loss_mask_deepseek_v4 in MultiTurnLossMaskGenerator. The V4 chat template renders messages via the model's encoder rather than a single jinja template; loss mask is computed by tokenizing the rendered conversation with offset tracking and marking tokens inside [start, end) of each assistant span.
…mpat transformers 5.3 V3Config consolidates legacy fields like rope_theta into rope_parameters dict and drops V4-only fields. Load deepseek_v4 / deepseek_ref / deepseek_v3 via the V3 architecture, then re-attach raw config.json entries (rope_theta, etc.) and force V4 topology defaults (first_k_dense_replace=0, intermediate_size=moe_intermediate_size) so mbridge does not build phantom dense MLP layers and try to load nonexistent gate_proj/up_proj weights.
Megatron's `iteration` is the 1-indexed count of completed steps (= rollout_id + 1), but the megatron backend was passing the raw 0-indexed rollout_id directly as iteration. This made `iter_<N>` on disk lag behind by 1, breaking `--save-retain-interval`: e.g. with save_interval=50 + retain_interval=100, saves landed at iter_49/99/149/...; none were divisible by 100, so Megatron's retain check deleted every mid-train ckpt and only the final survived. Add the +1 in actor.save_model() to align with the FSDP backend's `step_id = iteration + 1` convention (fsdp_utils/checkpoint.py:199), and drop the +1 in the load path since the on-disk iteration is now already next_rollout_id. Verified on cp_smoke 8K, 6 steps, --save-interval 2 --save-retain- interval 4: iter_2/4/6 written, iter_2 retain-deleted (2%4!=0), iter_4/6 retained, tracker=6, job SUCCEEDED. Old ckpts saved before this fix load fine; their tracker reads iteration=N → start_rollout_id=N → process re-does rollout N once (one-shot, not cumulative).
…inting In long-context (e.g. 64K seq) configs, the first dist_checkpointing D2H save in a process can fail with cudaErrorInvalidValue when it lands at high iteration (~50+). Fresh starts at iter 0-30 succeed and prime cached pinned-mem pool / IPC handles / CUDA streams that all later saves reuse. Trigger one save after the first rollout's train_step (NOT before the loop -- precision-aware-optimizer's master_param slot is only populated after the first optimizer.step, so a pre-loop save raises KeyError: 'master_param' from distrib_optimizer.sharded_state_dict). Cost: one extra ckpt at iter_<start_rollout_id+1>. Megatron's --save-retain-interval reaps it on the next regular save, no special cleanup needed. Verified on 64K shape (TP=2 PP=4 CP=8 EP=8) for both fresh start (iter_1 + iter_10 saves succeed) and resume from iter_N (resumed- process warmup save iter_N+1 succeeds). Matches the prior failure mode where first-save-at-iter_50 reliably hit cudaErrorInvalidValue.
Three changes from the original profile_utils:
1. Default to a small rank subset {0, 32, 48} rather than every rank.
64-rank trace capture writes 64 separate ~1 GB files per active
step; cutting to 3 representative ranks (head / mid / tail PP
stages) keeps trace volume manageable. Override with the
MILES_PROFILE_RANKS env var (comma-separated) when a different
shape needs different ranks.
2. Return None for non-selected ranks instead of constructing a
torch.profiler.profile(schedule(active=0)). The latter trips
torch's `assert active > 0`. Caller (_profile_simple_loop) now
None-checks and yields the iterator unchanged.
3. Disable record_shapes / with_stack / profile_memory by default.
with_stack=True deadlocks NCCL on 64-rank H20 + RoCE
(100% reproducible: every rank stalls in `Timer data_preprocess
start` for >5 minutes); the other two bloat trace size and slow
steps without adding the kernel-level wall breakdown we want for
CP/EP/PP collective analysis.
Match the DeepSeek V4 KV-QAT path to the official runtime by using an in-place FP8 quant-dequant simulation with FP32 scales instead of returning FP8 values and multiplying UE8M0 scales in Python. Only quantize non-RoPE KV dimensions in the vanilla and compressed-KV paths, preserving RoPE dimensions as BF16. Rebuild tensors with cat so the compressor path avoids slice overlap writes. Add gated internal attention trace captures behind MILES_DSV4_TRACE_INTERNALS for replay/debugging. These hooks are no-ops by default, and the numerical changes only apply when MEGATRON_USE_KV_QAT=1; the default BF16 SFT path is unchanged.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
miles framework + V4 model plugin patches needed to run the DeepSeek V4-Flash 284B SFT pipeline. Task-side scripts (cluster bring-up, data/ckpt prep, end-to-end runners) live in a standalone repo: https://github.com/kakisong/deepseek-v4-flash-training
This PR is 9 atomic commits across 13 files (+288/-35 lines), organised by concern:
V4 model plugin fixes (
miles_plugins/models/deepseek_v4/)fbcca59—apply_rotary_embreturns a new tensor instead of mutating a view. The in-place op silently produced NaN grads when the upstream call wrapped it in a custom autograd Function. Switched to functionalcat.88a54b9—sparse_attn_torchdtype + fully-masked clamp.q = q.float()overwrote q before theo.to(q.dtype)cast; fully-masked rows produced NaN becausescores - (-inf)blew up. Mirrors thedense_attn_torchclamp.8df91c1—tilelang_sparse_mla_bwdproduced 100% NaN dq/dkv on V4-Flash production shapes.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGEaliasedacc_dkv_shared(fp32) withQ/KV/dQ_shared(bf16); the split_store atomic_addx4 wrote fp32 bytes that the next iteration read back as bf16 NaN. Disabling the flag (and using a safe in-bounds KV index when masked) restores finite, correct gradients. Same commit adds explicit NaN asserts incompute_diffso the existing unit test stops fake-passing on NaN gradients (NaN > 0 == Falsewas falling into therel_diff = 0branch).8835a1a— bump V4 ropemax_seq_len65536 → 131072 to match V4-Flash's published context length.miles framework patches (V4 support)
9caac1d—[feat]deepseek_v4 multi-turn loss mask. Addsdeepseek_v4to--loss-mask-typeand implementsgen_multi_turn_loss_mask_deepseek_v4inMultiTurnLossMaskGenerator. The V4 chat template renders messages via the model encoder rather than a single jinja template; the new generator tokenizes the rendered conversation with offset tracking and marks tokens inside[start, end)of each assistant span.c944fca—[fix]load V4 hf config through V3 architecture for mbridge compat.transformers5.3 V3Config consolidatesrope_thetaetc. intorope_parametersand drops V4-only fields. Loaddeepseek_v4/deepseek_ref/deepseek_v3via the V3 architecture, then re-attach rawconfig.jsonentries and force V4 topology defaults (first_k_dense_replace=0,intermediate_size=moe_intermediate_size) so mbridge does not build phantom dense MLP layers and try to load nonexistentgate_proj/up_projweights.miles framework fixes (non-V4-specific)
32ee40a—[fix]megatron backend ckpt iteration off-by-one. Save was triggered withstep=rollout_id+1but Megatroniterationgot the rawrollout_id, so ckpts were namediter_49 / iter_99 / …instead ofiter_50 / iter_100 / …and--save-retain-interval=100never matched any checkpoint, deleting all mid-training ckpts.ca532c7—[fix]warmup save after first train step. dist_checkpointing async D2H lazy-initialises on first save; under long runs (64K + 7.5h + 120 steps) the first real save crashed withcudaErrorInvalidValue. Forcing a no-op save after step 1 primes the lazy state and removes the failure mode.34684df—[feat]profile_utilsproductisation for multi-rank trace capture. Lets PP/CP profile traces be collected on a configurable rank subset instead of always rank 0.What's NOT in this PR (moved out)
The cluster scripts, data/ckpt prep tools, runners and SFT-pipeline documentation that were in earlier revisions of this PR have been extracted to a separate repo so this PR cleanly contains only miles framework patches:
→ https://github.com/kakisong/deepseek-v4-flash-training
The new repo references this fork via
$V4_MILES_REPO(default$V4_WORK/mileson CFS);cluster/bring_up_cluster.shauto-clones the fork on first bring-up.Validation evidence
Production validation (
stageProd-20260510-231521, 64-GPU, 1600-step, V4-Flash 284B / OpenHermes):Unit tests:
tests/deepseekv4/test_v4_tilelang_sparse_mla.pypasses withrel_diff ~1–2e-6against torch refs (vs the prior NaN-blind PASS).Test plan
pytest tests/deepseekv4/test_v4_tilelang_sparse_mla.py— 22 cases pass, real numerical agreement