Skip to content

[codex] Fix DSv4 DCP checkpoint placements for DTensor-like params#61

Open
Meirtz wants to merge 2 commits into
ISEEKYAN:mainfrom
Meirtz:codex/mlite-dsv4-dcp-ep-shard-fix
Open

[codex] Fix DSv4 DCP checkpoint placements for DTensor-like params#61
Meirtz wants to merge 2 commits into
ISEEKYAN:mainfrom
Meirtz:codex/mlite-dsv4-dcp-ep-shard-fix

Conversation

@Meirtz

@Meirtz Meirtz commented Jun 21, 2026

Copy link
Copy Markdown

Summary

This PR fixes the MLite DCP fallback path for DTensor-like parameters whose
checkpoint placement differs from the wrapped parameter placement.

For DSv4 expert weights under EP, the checkpoint protocol supplies the expert
mesh and placements, but the previous helper rebuilt the DCP tensor from the
FSDP2 parameter mesh/placements. In the reproduced failure this let rank 0 load
rank 1's expert shard for layers.3.mlp.experts.fc1.weight0.

Changes

  • Use checkpoint mesh/placements when wrapping DTensor-like local shards for
    real sharded checkpoint placements.
  • Preserve the parameter mesh/placements for unsharded checkpoint placements,
    so empty local DTensor shards do not force full materialization.
  • Compute checkpoint global shape and contiguous stride from checkpoint
    placements when needed.
  • Treat hc_head.* as replicated before the generic DSv4 head placement
    rule.
  • Add targeted unit coverage with fake DTensor-like params for checkpoint
    placements, multi-axis sharding, unsharded checkpoint behavior, and empty
    local save behavior.

Validation

Local checks:

git diff --check
PYTHONDONTWRITEBYTECODE=1 /Users/lmei/anaconda3/envs/cua/bin/python -m py_compile \
  experimental/lite/megatron/lite/primitive/ckpt/dcp.py \
  experimental/lite/megatron/lite/model/deepseek_v4/lite/checkpoint.py \
  experimental/lite/tests/unit/primitive/test_training_checkpoint.py
  • Direct no-pytest helper harness passed in the same cua Python/Torch
    environment for checkpoint-placement DTensor wrapping, multi-axis checkpoint
    shard shape expansion, and unsharded checkpoint behavior that preserves the
    parameter mesh.
  • 2026-06-22 refresh: targeted pytest collection in the cua environment was
    attempted, but the file skipped before collection because local
    megatron.core.dist_checkpointing imports require triton, which is not
    installed. The no-pytest helper harness was rerun and passed for checkpoint
    placements, multi-axis shape expansion, unsharded empty-local save behavior,
    and empty local copy no-op.

GPU evidence from existing run artifacts:

  • Pre-fix peer-shard mismatch reproduced on GB300, H100x2, and B100x2.
  • Post-fix B100 ep2 SAVE_LOAD=1: analysis-save-load-ep2-b100-fix5.json
    reports overall=smoke_pass with all rank max deltas at 0.0.
  • Post-fix B100 ep4 SAVE_LOAD=1: analysis-save-load-ep4-b100-fix5.json
    reports overall=smoke_pass with all rank max deltas at 0.0.
  • Post-fix H100x2 ep2 SAVE_LOAD=1: job 1346071,
    analysis-save-load-ep2-h100-fix6.json reports overall=smoke_pass, and
    rank 0/rank 1 save_load.comparison.max_delta=0.0.
  • Supplemental post-fix H100x2 pp2 SAVE_LOAD=1: job 1346142 passed with
    finite train metrics and rank 0/rank 1 save_load.comparison.max_delta=0.0.
    Local evidence is recorded from the launcher terminal stream at
    runs/20260620-mlite-next-gates-pr-packaging/remote-results-dsv4_dcp_pp2_save_load-terminal/results/metrics.json.
  • Scoped DCP continuity, B100x2 ep2, MTP_ENABLE=0: job 1346239 passed
    save-load-continue versus uninterrupted training. loaded_step=1; rank 0
    compared 112 local tensors and rank 1 compared 111 local tensors; both ranks
    reported comparison.max_delta=0.0, no missing keys, and no shape
    mismatches. Evidence:
    runs/20260621-mlite-dsv4-dcp-continuity/remote-results-attempt2/results/metrics.json.
  • Scoped DCP continuity, H100x2 ep2, MTP_ENABLE=0: job 1346254 passed the
    same save-load-continue gate. loaded_step=1; rank 0 compared 112 local
    tensors and rank 1 compared 111 local tensors; both ranks reported
    comparison.max_delta=0.0, no missing keys, and no shape mismatches.
    Evidence:
    runs/20260621-mlite-dsv4-dcp-continuity/remote-results-h100-attempt3/results/metrics.json.

Boundaries

This PR does not claim:

  • full DSv4 Flash support,
  • default FlashMLA/fused attention support,
  • DSv4 MTP train continuity; the continuity evidence above disables MTP to
    isolate DCP from a known MTP/mHC DTensor/Tensor blocker outside this PR,
  • GLM5.2, VERL, Qwen, or LoRA support.

The full local pytest suite was not run locally. Targeted pytest collection now
starts in the cua environment, but the local environment is still missing the
triton dependency needed by megatron.core.dist_checkpointing.

Meirtz added 2 commits June 20, 2026 23:17
Merge DTensor/FSDP2 parameter shard placement into the matching checkpoint mesh axis when model-parallel checkpoint placements also shard expert tensors. Keep all-replicate checkpoint placements on the original parameter mesh so dense and mHC tensors are not treated as model-parallel shards.

Validation: OCI-HSG 2 nodes x 4 GB200 DSv4 tiny DCP continuity attempt oci-hsg-2node-8gpu-dsv4-dcp-mergefix-6, job 3488765, analyzer status=pass, all 8 ranks max_delta=0.0.
@Meirtz Meirtz marked this pull request as ready for review June 23, 2026 06:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant