feat(jax): Llama-8B l18-26 9-layer chunkwise config + abstract-AOT mem probe#871
Merged
Merged
Conversation
New from-scratch decomposition of Llama-3.1-8B MLP layers 18-26 (27 targets = 9 layers x gate/up/down, C=49152), chunkwise recon (3 chunks of 9 sites, coeff 1.5), remat on. Derived from the C49k single-layer-18 run. AOT-probed per-device HBM at 64 B200 (remat on, no buffer donation — matching the trainer's non-donating @jax.jit step): B=128 (2/device): 210.2 GiB -> over the 180 GiB cap B=64 (1/device): 179.0 GiB -> at the cliff, ~0 headroom B=128 does not fit, so the config carries B=64 with LRs 4th-root-scaled to ci_fn 4.2e-5 / components 1.26e-4. mem_probe.py rewritten to lower+compile on ABSTRACT sharded avals (shapes via jit(init).lower().out_info, GSPMD shardings re-attached per leaf) instead of eagerly materializing the state — the 27-site C=49152 state is ~360 GB and OOMs the eager init's jit__identity_fn (128 GiB replicated V/U) at 64 GPU. Adds --sites_per_chunk / --recon_coeff / --last_layer args and the 64-GPU probe sbatch. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
… seq-512 probe seq 2048→512 is what makes 27 sites fit on 64 GPU: AOT probe (seq 512, remat on) = 144.8 GiB (B=64) .. 150.3 GiB (B=256)/device, all under the 180 GiB B200 cap (seq 2048 was 179 GiB at B=64, over the cliff). Final config: layers 18-26 (27 sites, C=49152), seq 512 / fineweb_llama_tok_512, B=128, 40k steps; comp 1.5e-4 / ci_fn 5e-5 (B=128 precedent); chunkwise recon sites_per_chunk=9 coeff 1.5; imp-min eps 1e-12→1e-6 (fractional-pnorm grad stability), coeff/beta/pnorm unchanged; PGD/faith unchanged; ci max_len→512. mem_probe gains --seq; seq-512 sweep sbatch added. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
A new from-scratch training config decomposing Llama-3.1-8B MLP layers 18–26 inclusive (9 layers × 3 matrices = 27 decomposition targets, C=49152), plus the AOT HBM probe used to size it for 64× B200.
New config files
param_decomp_jax/jax_single_pool/configs/torch/llama8b_l18-26_9layer_chunkwise.yaml(schema)param_decomp_jax/jax_single_pool/configs/llama8b_l18-26_9layer_chunkwise_from_torch.yaml(wrapper)Derived from the C49k single-layer-18 run with these deliberate changes:
model.layers.{18..26}.mlp.{gate,up,down}_proj, canonical order, each C=49152.ChunkwiseSubsetReconLoss(replacesStochasticReconSubsetLoss):sites_per_chunk: 9→ 3 chunks of 9,coeff: 1.5(= base 0.5 × 3 chunks, undoing the chunk-mean dilution).remat_recon_forwards: true) — 9 layers need it.PersistentPGDReconLoss(one_chunk, 0.5, broadcast),ImportanceMinimalityLoss(5e-6, β0.2, p 2.0→0.4),FaithfulnessLoss(1e5), seq 2048, faith warmup, eval set,weights_dtype: bfloat16, 200k steps.The schema validates cleanly: loader asserts 27 targets in canonical order, all C=49152, chunkwise term parses (3 chunk forwards via
build_recon_terms).AOT memory probe
The established probe (
experiments/mem_probe.py) eagerly materializes the state, which OOMs for this size — the 27-site C=49152 state is ~360 GB and the eager sharded init'sjit__identity_fntries to replicate the full ~98 GB V/U per device (128 GiB alloc, RESOURCE_EXHAUSTED at 64 GPU). Rewrote the probe to lower+compile on abstract sharded avals (shapes fromjit(init).lower().out_info, GSPMD shardings re-attached per leaf) — allocation-free, faithful to the trainer's placement. Validated on CPU (8 sim devices) and on a real 64× B200 SLURM job.Per-device peak HBM (
temp + args + out, no buffer donation — matching the trainer's non-donating@jax.jitstep;alias 0.0):Recommendation: B=64
B=128 does not fit. B=64 is the only launchable point but sits AT 179 GiB with essentially no margin — a real launch should re-confirm against live device HBM (
XLA_PYTHON_CLIENT_MEM_FRACTIONleaves <180 GiB usable) and may want a larger mesh or fewer layers. The config carries B=64 with LRs dropped to the 4th root of the 8× batch ratio vs the 512 baseline: ci_fn 4.2e-5 / components 1.26e-4.Probe reproduction
experiments/mem_probe_l18-26.sbatch(8 nodes / 64 GPU) runs both batches. New probe args:--sites_per_chunk,--recon_coeff,--last_layer.No training run was launched.
make format/make type/make check-jaxall clean.🤖 Generated with Claude Code