Skip to content

ds4 (deepseek_v4): DSA indexer aux-loss scale hook + lift forced CP=1 (MTP/dense-CSA support CP>1)#59

Open
ISEEKYAN wants to merge 1 commit into
mainfrom
mlite-ds4-indexer-scale-and-cp
Open

ds4 (deepseek_v4): DSA indexer aux-loss scale hook + lift forced CP=1 (MTP/dense-CSA support CP>1)#59
ISEEKYAN wants to merge 1 commit into
mainfrom
mlite-ds4-indexer-scale-and-cp

Conversation

@ISEEKYAN

Copy link
Copy Markdown
Owner

Summary

Two coupled deepseek_v4 DSA training-correctness fixes (one PR):

1. DSA indexer aux-loss backward-scale in the pre-forward hook

The ds4 pre-forward hook now also sets DSAIndexerLossAutoScaler.set_loss_scale(1/num_microbatches) (mirroring the glm5/deepseek_v3_2 hook), alongside the MTP scaler. This is defensive: ds4's MoE is aux-loss-free and the CSA indexer currently runs with sparse_loss=False (indexer loss is a no-op today, honestly noted), but wiring the scale prevents an nmb-amplification bug if/when the indexer aux-loss is enabled.

2. Lift the forced context-parallel=1 for the dense path

Removed the cp_size==1 gate in deepseek_v4/model.py that forced CP=1, and added a roll_contiguous_left_for_cp primitive (CP-boundary roll handled in primitive layer, not inlined in the model forward) so num_tokens stays locally matched to the main loss. ds4's CSA structurally takes a dense rebuild path for cp>1 (fused is cp1-only), so cp2-vs-cp1 is a dense-vs-dense same-algorithm comparison.

Validation (GPU, Slurm job 12932122, 2xGPU, DSA sm90 overlay, non-skip, ALL_STEPS_RC=0)

  • gloo 2-rank CP-boundary roll == global roll (core structural fix)
  • hook: MTP and DSA-indexer backward scale both resolve to 1/num_microbatches
  • ds4 CP2 + MTP real fwd/bwd: mtp_loss present (MTP genuinely runs under CP>1, gate unlocked), loss/grad finite
  • cp2-vs-cp1 per-token logprob max_abs_diff = 0.0 (dense bitwise consistent)

Note: ds4 THD-packed path keeps enable_mtp=False (orthogonal to the CP gate; real SFT/DAPO uses THD and RL does not train MTP) — out of scope here; the dense MTP+CP path is what's unlocked and verified.

…scale

Two DeepSeek-V4 DSA training-correctness fixes that touch the same protocol
pre_forward_hook and MTP path.

1. MTP under CP>1. DS4 previously forced cp_size==1 for MTP because its
   next-token roll used a plain torch.roll, which wraps each CP rank's last
   token onto its own first token instead of the first token of the next
   rank's contiguous slice. Add a shared roll_contiguous_left_for_cp primitive
   (gather full sequence -> roll once globally -> re-slice contiguously,
   mirroring how the THD roll_packed_thd_left handles the CP boundary) and
   route the ds4 MTP input_ids/labels/loss_mask rolls through it, then drop the
   cp_size==1 gate. The local token sum is kept local to match the model's
   per-rank loss.mean() normalization (the CP reduction happens in the gradient
   all-reduce). num_tokens stays consistent with the main cross-entropy loss.

2. Aux-loss hook. The hook now also sets DSAIndexerLossAutoScaler's backward
   scale, mirroring the GLM-5 hook. DS4's MoE router is aux-loss-free and its
   CSA indexer currently runs sparse_loss=False (so the indexer scaler is a
   no-op today); setting it defensively keeps the DSA-family hook uniform and
   guards against a future indexer-loss enable being mis-weighted under
   num_microbatches>1.

Tests: a deterministic gloo unit test for the CP-boundary roll (gathered roll
== global roll), a GPU ds4 CP2+MTP forward/backward smoke (mtp_loss present =>
the gate is open), a cp2-vs-cp1 dense-path per-token log-prob match, and a
hook test asserting both MTP and indexer backward scales follow 1/nmb.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant