Skip to content

[V4-Flash] miles framework patches + V4 plugin NaN fixes#4

Open
kakisong wants to merge 11 commits into
v4-rl-basefrom
v4-flash-sft
Open

[V4-Flash] miles framework patches + V4 plugin NaN fixes#4
kakisong wants to merge 11 commits into
v4-rl-basefrom
v4-flash-sft

Conversation

@kakisong
Copy link
Copy Markdown
Owner

@kakisong kakisong commented May 11, 2026

miles framework + V4 model plugin patches needed to run the DeepSeek V4-Flash 284B SFT pipeline. Task-side scripts (cluster bring-up, data/ckpt prep, end-to-end runners) live in a standalone repo: https://github.com/kakisong/deepseek-v4-flash-training

This PR is 9 atomic commits across 13 files (+288/-35 lines), organised by concern:

V4 model plugin fixes (miles_plugins/models/deepseek_v4/)

  1. fbcca59apply_rotary_emb returns a new tensor instead of mutating a view. The in-place op silently produced NaN grads when the upstream call wrapped it in a custom autograd Function. Switched to functional cat.
  2. 88a54b9sparse_attn_torch dtype + fully-masked clamp. q = q.float() overwrote q before the o.to(q.dtype) cast; fully-masked rows produced NaN because scores - (-inf) blew up. Mirrors the dense_attn_torch clamp.
  3. 8df91c1tilelang_sparse_mla_bwd produced 100% NaN dq/dkv on V4-Flash production shapes. TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE aliased acc_dkv_shared (fp32) with Q/KV/dQ_shared (bf16); the split_store atomic_addx4 wrote fp32 bytes that the next iteration read back as bf16 NaN. Disabling the flag (and using a safe in-bounds KV index when masked) restores finite, correct gradients. Same commit adds explicit NaN asserts in compute_diff so the existing unit test stops fake-passing on NaN gradients (NaN > 0 == False was falling into the rel_diff = 0 branch).
  4. 8835a1a — bump V4 rope max_seq_len 65536 → 131072 to match V4-Flash's published context length.

miles framework patches (V4 support)

  1. 9caac1d[feat] deepseek_v4 multi-turn loss mask. Adds deepseek_v4 to --loss-mask-type and implements gen_multi_turn_loss_mask_deepseek_v4 in MultiTurnLossMaskGenerator. The V4 chat template renders messages via the model encoder rather than a single jinja template; the new generator tokenizes the rendered conversation with offset tracking and marks tokens inside [start, end) of each assistant span.
  2. c944fca[fix] load V4 hf config through V3 architecture for mbridge compat. transformers 5.3 V3Config consolidates rope_theta etc. into rope_parameters and drops V4-only fields. Load deepseek_v4 / deepseek_ref / deepseek_v3 via the V3 architecture, then re-attach raw config.json entries and force V4 topology defaults (first_k_dense_replace=0, intermediate_size=moe_intermediate_size) so mbridge does not build phantom dense MLP layers and try to load nonexistent gate_proj/up_proj weights.

miles framework fixes (non-V4-specific)

  1. 32ee40a[fix] megatron backend ckpt iteration off-by-one. Save was triggered with step=rollout_id+1 but Megatron iteration got the raw rollout_id, so ckpts were named iter_49 / iter_99 / … instead of iter_50 / iter_100 / … and --save-retain-interval=100 never matched any checkpoint, deleting all mid-training ckpts.
  2. ca532c7[fix] warmup save after first train step. dist_checkpointing async D2H lazy-initialises on first save; under long runs (64K + 7.5h + 120 steps) the first real save crashed with cudaErrorInvalidValue. Forcing a no-op save after step 1 primes the lazy state and removes the failure mode.
  3. 34684df[feat] profile_utils productisation for multi-rank trace capture. Lets PP/CP profile traces be collected on a configurable rank subset instead of always rank 0.

What's NOT in this PR (moved out)

The cluster scripts, data/ckpt prep tools, runners and SFT-pipeline documentation that were in earlier revisions of this PR have been extracted to a separate repo so this PR cleanly contains only miles framework patches:

https://github.com/kakisong/deepseek-v4-flash-training

The new repo references this fork via $V4_MILES_REPO (default $V4_WORK/miles on CFS); cluster/bring_up_cluster.sh auto-clones the fork on first bring-up.

Validation evidence

Production validation (stageProd-20260510-231521, 64-GPU, 1600-step, V4-Flash 284B / OpenHermes):

metric dense (prior workaround) tilelang (this PR) delta
1600-step wall 14.7h (4 attempts) 9h 27min (single shot)
step_time steady-state ~17.1 s ~13.5 s -21%
actor_train_TFLOPS / GPU ~4.5 ~6.4 +42%
step 1599 loss 0.00096 0.00264 both at SFT plateau

Unit tests: tests/deepseekv4/test_v4_tilelang_sparse_mla.py passes with rel_diff ~1–2e-6 against torch refs (vs the prior NaN-blind PASS).

Test plan

  • pytest tests/deepseekv4/test_v4_tilelang_sparse_mla.py — 22 cases pass, real numerical agreement
  • 64-GPU 2-iter smoke — loss 0.856 → 0.702, no NaN
  • 64-GPU 20-iter SFT validation — loss 0.857 → 0.500, grad_norm 5.37 → 0.65, 0 NaN
  • 64-GPU CP=2 + 8K smoke — Megatron CP × V4 sparse-MLA path verified
  • 64-GPU 1600-step prod run (tilelang) — 9h 27min, ray job SUCCEEDED
  • 64-GPU 1600-step prod run (dense baseline) — completes after neighbor-OOM resume

@kakisong kakisong force-pushed the v4-flash-sft branch 3 times, most recently from b3d7a2a to 50fed8b Compare May 20, 2026 08:21
kakisong added 9 commits May 22, 2026 15:55
…caused NaN grads)

apply_rotary_emb mutated its input in-place via a view on the rotated half.
When upstream callers (deepseek_v4.py, compressor.py, v4_indexer.py) wrap
the function inside a custom autograd.Function, the in-place op silently
corrupted the saved-for-backward tensor and produced NaN gradients deep
into the V4 attention path.

Switch to functional `torch.cat` and update callers to consume the return
value instead of relying on the in-place side effect.
Two latent bugs in the torch reference implementation that surfaced when
the tilelang sparse-MLA bwd kernel was found to NaN on real workloads
and we needed sparse_attn_torch as a sanity baseline:

1. dtype lost — `q = q.float()` overwrites q before the trailing
   `o.to(q.dtype)` cast, so the reference always returned fp32. Save
   orig_dtype before .float() and cast back at the end.

2. fully-masked rows blew up — when a query has zero valid topk
   entries (early causal positions), `scores.max(...)` returns -inf
   and `scores - (-inf) = NaN -> exp(NaN) = NaN`. Clamp scores_max to
   -1e30 to match dense_attn_torch behavior.

Verified on Stage A reproducer + Stage B0 64-GPU smoke against dense_attn_torch.
…NaN asserts

Root cause: bwd kernel decorator enabled tilelang's
`TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE` pass, which aliased
`acc_dkv_shared` (fp32) with `Q_shared`/`KV_shared`/`dQ_shared` (bf16) in
physical shared memory. The split_store atomic_addx4 to dKV then wrote
fp32 bytes that the next loop iteration read back as bf16 Q/KV, producing
NaN columns in dq and 100% NaN dkv on V4-Flash production shapes
(B=1 S=1280 H=64 D=512 topk=640).

The bug shipped from day one because the existing test's compute_diff
fell back to rel_diff=0.0 when denom was NaN (`NaN > 0` is False),
silently passing on completely-NaN gradients. Production was running on
dense_attn_torch (17s/iter) as a workaround.

Fixes:
  - kernel/tilelang_sparse_mla_bwd.py: drop AGGRESSIVE_SHARED_MEMORY_MERGE.
    Use a safe (in-bounds) KV index when mask=False to avoid raw -1
    pointer arithmetic reading NaN-bearing memory.
  - attention_core.py: fold attn_sink into the LSE shift in sparse/dense
    torch refs so fully-masked rows do not produce 0*exp(1e30)=NaN in
    d_attn_sink autograd.
  - tests/test_v4_tilelang_sparse_mla.py: explicit NaN asserts in
    compute_diff so future regressions cannot fake-pass.
CP-packed sequences can spill ~5% past the trained context due to
dynamic-batch padding alignment. Observed on 64K shape:
seqlen_local=8448 × CP=8 = 67584, exceeding the freqs_cis table size
(65536) and triggering a view shape mismatch in get_freqs_cis_for_cp.

YaRN scaling factor (4-16x) keeps this table mathematically valid out
to 4-16x original_max_position_embeddings (~262K-1M), so 2x headroom
is well within range and no other YaRN params need tuning.
Add `deepseek_v4` to --loss-mask-type and implement
gen_multi_turn_loss_mask_deepseek_v4 in MultiTurnLossMaskGenerator.
The V4 chat template renders messages via the model's encoder rather
than a single jinja template; loss mask is computed by tokenizing
the rendered conversation with offset tracking and marking tokens
inside [start, end) of each assistant span.
…mpat

transformers 5.3 V3Config consolidates legacy fields like rope_theta into
rope_parameters dict and drops V4-only fields. Load deepseek_v4 / deepseek_ref /
deepseek_v3 via the V3 architecture, then re-attach raw config.json entries
(rope_theta, etc.) and force V4 topology defaults (first_k_dense_replace=0,
intermediate_size=moe_intermediate_size) so mbridge does not build phantom
dense MLP layers and try to load nonexistent gate_proj/up_proj weights.
Megatron's `iteration` is the 1-indexed count of completed steps
(= rollout_id + 1), but the megatron backend was passing the raw
0-indexed rollout_id directly as iteration. This made `iter_<N>` on
disk lag behind by 1, breaking `--save-retain-interval`:
e.g. with save_interval=50 + retain_interval=100, saves landed at
iter_49/99/149/...; none were divisible by 100, so Megatron's retain
check deleted every mid-train ckpt and only the final survived.

Add the +1 in actor.save_model() to align with the FSDP backend's
`step_id = iteration + 1` convention (fsdp_utils/checkpoint.py:199),
and drop the +1 in the load path since the on-disk iteration is now
already next_rollout_id.

Verified on cp_smoke 8K, 6 steps, --save-interval 2 --save-retain-
interval 4: iter_2/4/6 written, iter_2 retain-deleted (2%4!=0),
iter_4/6 retained, tracker=6, job SUCCEEDED.

Old ckpts saved before this fix load fine; their tracker reads
iteration=N → start_rollout_id=N → process re-does rollout N once
(one-shot, not cumulative).
…inting

In long-context (e.g. 64K seq) configs, the first dist_checkpointing
D2H save in a process can fail with cudaErrorInvalidValue when it
lands at high iteration (~50+). Fresh starts at iter 0-30 succeed and
prime cached pinned-mem pool / IPC handles / CUDA streams that all
later saves reuse.

Trigger one save after the first rollout's train_step (NOT before the
loop -- precision-aware-optimizer's master_param slot is only populated
after the first optimizer.step, so a pre-loop save raises
KeyError: 'master_param' from distrib_optimizer.sharded_state_dict).

Cost: one extra ckpt at iter_<start_rollout_id+1>. Megatron's
--save-retain-interval reaps it on the next regular save, no special
cleanup needed.

Verified on 64K shape (TP=2 PP=4 CP=8 EP=8) for both fresh start
(iter_1 + iter_10 saves succeed) and resume from iter_N (resumed-
process warmup save iter_N+1 succeeds). Matches the prior failure
mode where first-save-at-iter_50 reliably hit cudaErrorInvalidValue.
Three changes from the original profile_utils:

1. Default to a small rank subset {0, 32, 48} rather than every rank.
   64-rank trace capture writes 64 separate ~1 GB files per active
   step; cutting to 3 representative ranks (head / mid / tail PP
   stages) keeps trace volume manageable. Override with the
   MILES_PROFILE_RANKS env var (comma-separated) when a different
   shape needs different ranks.

2. Return None for non-selected ranks instead of constructing a
   torch.profiler.profile(schedule(active=0)). The latter trips
   torch's `assert active > 0`. Caller (_profile_simple_loop) now
   None-checks and yields the iterator unchanged.

3. Disable record_shapes / with_stack / profile_memory by default.
   with_stack=True deadlocks NCCL on 64-rank H20 + RoCE
   (100% reproducible: every rank stalls in `Timer data_preprocess
   start` for >5 minutes); the other two bloat trace size and slow
   steps without adding the kernel-level wall breakdown we want for
   CP/EP/PP collective analysis.
@kakisong kakisong changed the title Enable DeepSeek V4-Flash 64-GPU SFT pipeline [V4-Flash] miles framework patches + V4 plugin NaN fixes May 22, 2026
kakisong added 2 commits June 1, 2026 15:26
Match the DeepSeek V4 KV-QAT path to the official runtime by using an in-place FP8 quant-dequant simulation with FP32 scales instead of returning FP8 values and multiplying UE8M0 scales in Python.

Only quantize non-RoPE KV dimensions in the vanilla and compressed-KV paths, preserving RoPE dimensions as BF16. Rebuild tensors with cat so the compressor path avoids slice overlap writes.

Add gated internal attention trace captures behind MILES_DSV4_TRACE_INTERNALS for replay/debugging. These hooks are no-ops by default, and the numerical changes only apply when MEGATRON_USE_KV_QAT=1; the default BF16 SFT path is unchanged.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant