ds4 (deepseek_v4): DSA indexer aux-loss scale hook + lift forced CP=1 (MTP/dense-CSA support CP>1)#59
Open
ISEEKYAN wants to merge 1 commit into
Open
ds4 (deepseek_v4): DSA indexer aux-loss scale hook + lift forced CP=1 (MTP/dense-CSA support CP>1)#59ISEEKYAN wants to merge 1 commit into
ISEEKYAN wants to merge 1 commit into
Conversation
…scale Two DeepSeek-V4 DSA training-correctness fixes that touch the same protocol pre_forward_hook and MTP path. 1. MTP under CP>1. DS4 previously forced cp_size==1 for MTP because its next-token roll used a plain torch.roll, which wraps each CP rank's last token onto its own first token instead of the first token of the next rank's contiguous slice. Add a shared roll_contiguous_left_for_cp primitive (gather full sequence -> roll once globally -> re-slice contiguously, mirroring how the THD roll_packed_thd_left handles the CP boundary) and route the ds4 MTP input_ids/labels/loss_mask rolls through it, then drop the cp_size==1 gate. The local token sum is kept local to match the model's per-rank loss.mean() normalization (the CP reduction happens in the gradient all-reduce). num_tokens stays consistent with the main cross-entropy loss. 2. Aux-loss hook. The hook now also sets DSAIndexerLossAutoScaler's backward scale, mirroring the GLM-5 hook. DS4's MoE router is aux-loss-free and its CSA indexer currently runs sparse_loss=False (so the indexer scaler is a no-op today); setting it defensively keeps the DSA-family hook uniform and guards against a future indexer-loss enable being mis-weighted under num_microbatches>1. Tests: a deterministic gloo unit test for the CP-boundary roll (gathered roll == global roll), a GPU ds4 CP2+MTP forward/backward smoke (mtp_loss present => the gate is open), a cp2-vs-cp1 dense-path per-token log-prob match, and a hook test asserting both MTP and indexer backward scales follow 1/nmb. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Two coupled deepseek_v4 DSA training-correctness fixes (one PR):
1. DSA indexer aux-loss backward-scale in the pre-forward hook
The ds4 pre-forward hook now also sets
DSAIndexerLossAutoScaler.set_loss_scale(1/num_microbatches)(mirroring the glm5/deepseek_v3_2 hook), alongside the MTP scaler. This is defensive: ds4's MoE is aux-loss-free and the CSA indexer currently runs withsparse_loss=False(indexer loss is a no-op today, honestly noted), but wiring the scale prevents an nmb-amplification bug if/when the indexer aux-loss is enabled.2. Lift the forced context-parallel=1 for the dense path
Removed the
cp_size==1gate in deepseek_v4/model.py that forced CP=1, and added aroll_contiguous_left_for_cpprimitive (CP-boundary roll handled in primitive layer, not inlined in the model forward) so num_tokens stays locally matched to the main loss. ds4's CSA structurally takes a dense rebuild path for cp>1 (fused is cp1-only), so cp2-vs-cp1 is a dense-vs-dense same-algorithm comparison.Validation (GPU, Slurm job 12932122, 2xGPU, DSA sm90 overlay, non-skip, ALL_STEPS_RC=0)
Note: ds4 THD-packed path keeps enable_mtp=False (orthogonal to the CP gate; real SFT/DAPO uses THD and RL does not train MTP) — out of scope here; the dense MTP+CP path is what's unlocked and verified.