Skip to content

feat(jax): Llama-8B l18-26 9-layer chunkwise config + abstract-AOT mem probe#871

Merged
ocg-goodfire merged 2 commits into
feature/jaxfrom
worktree-agent-a26fb5916afa66a4d
Jun 17, 2026
Merged

feat(jax): Llama-8B l18-26 9-layer chunkwise config + abstract-AOT mem probe#871
ocg-goodfire merged 2 commits into
feature/jaxfrom
worktree-agent-a26fb5916afa66a4d

Conversation

@ocg-goodfire

Copy link
Copy Markdown
Collaborator

What

A new from-scratch training config decomposing Llama-3.1-8B MLP layers 18–26 inclusive (9 layers × 3 matrices = 27 decomposition targets, C=49152), plus the AOT HBM probe used to size it for 64× B200.

New config files

  • param_decomp_jax/jax_single_pool/configs/torch/llama8b_l18-26_9layer_chunkwise.yaml (schema)
  • param_decomp_jax/jax_single_pool/configs/llama8b_l18-26_9layer_chunkwise_from_torch.yaml (wrapper)

Derived from the C49k single-layer-18 run with these deliberate changes:

  • 27 targets: model.layers.{18..26}.mlp.{gate,up,down}_proj, canonical order, each C=49152.
  • Recon = ChunkwiseSubsetReconLoss (replaces StochasticReconSubsetLoss): sites_per_chunk: 9 → 3 chunks of 9, coeff: 1.5 (= base 0.5 × 3 chunks, undoing the chunk-mean dilution).
  • remat ON (remat_recon_forwards: true) — 9 layers need it.
  • Unchanged: PersistentPGDReconLoss (one_chunk, 0.5, broadcast), ImportanceMinimalityLoss (5e-6, β0.2, p 2.0→0.4), FaithfulnessLoss (1e5), seq 2048, faith warmup, eval set, weights_dtype: bfloat16, 200k steps.

The schema validates cleanly: loader asserts 27 targets in canonical order, all C=49152, chunkwise term parses (3 chunk forwards via build_recon_terms).

AOT memory probe

The established probe (experiments/mem_probe.py) eagerly materializes the state, which OOMs for this size — the 27-site C=49152 state is ~360 GB and the eager sharded init's jit__identity_fn tries to replicate the full ~98 GB V/U per device (128 GiB alloc, RESOURCE_EXHAUSTED at 64 GPU). Rewrote the probe to lower+compile on abstract sharded avals (shapes from jit(init).lower().out_info, GSPMD shardings re-attached per leaf) — allocation-free, faithful to the trainer's placement. Validated on CPU (8 sim devices) and on a real 64× B200 SLURM job.

Per-device peak HBM (temp + args + out, no buffer donation — matching the trainer's non-donating @jax.jit step; alias 0.0):

batch per-device peak HBM fits 180 GiB w/ ~15% headroom (≤153 GiB)?
128 (2/device) 210.2 GiB No — over the 180 GiB cap outright
64 (1/device) 179.0 GiB No — at the cliff, ~0 headroom

Recommendation: B=64

B=128 does not fit. B=64 is the only launchable point but sits AT 179 GiB with essentially no margin — a real launch should re-confirm against live device HBM (XLA_PYTHON_CLIENT_MEM_FRACTION leaves <180 GiB usable) and may want a larger mesh or fewer layers. The config carries B=64 with LRs dropped to the 4th root of the 8× batch ratio vs the 512 baseline: ci_fn 4.2e-5 / components 1.26e-4.

Probe reproduction

experiments/mem_probe_l18-26.sbatch (8 nodes / 64 GPU) runs both batches. New probe args: --sites_per_chunk, --recon_coeff, --last_layer.

No training run was launched. make format / make type / make check-jax all clean.

🤖 Generated with Claude Code

ocg-goodfire and others added 2 commits June 17, 2026 15:16
New from-scratch decomposition of Llama-3.1-8B MLP layers 18-26 (27 targets =
9 layers x gate/up/down, C=49152), chunkwise recon (3 chunks of 9 sites,
coeff 1.5), remat on. Derived from the C49k single-layer-18 run.

AOT-probed per-device HBM at 64 B200 (remat on, no buffer donation — matching
the trainer's non-donating @jax.jit step):
  B=128 (2/device): 210.2 GiB  -> over the 180 GiB cap
  B=64  (1/device): 179.0 GiB  -> at the cliff, ~0 headroom
B=128 does not fit, so the config carries B=64 with LRs 4th-root-scaled to
ci_fn 4.2e-5 / components 1.26e-4.

mem_probe.py rewritten to lower+compile on ABSTRACT sharded avals (shapes via
jit(init).lower().out_info, GSPMD shardings re-attached per leaf) instead of
eagerly materializing the state — the 27-site C=49152 state is ~360 GB and
OOMs the eager init's jit__identity_fn (128 GiB replicated V/U) at 64 GPU.
Adds --sites_per_chunk / --recon_coeff / --last_layer args and the 64-GPU
probe sbatch.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
… seq-512 probe

seq 2048→512 is what makes 27 sites fit on 64 GPU: AOT probe (seq 512, remat on) =
144.8 GiB (B=64) .. 150.3 GiB (B=256)/device, all under the 180 GiB B200 cap (seq 2048
was 179 GiB at B=64, over the cliff). Final config: layers 18-26 (27 sites, C=49152),
seq 512 / fineweb_llama_tok_512, B=128, 40k steps; comp 1.5e-4 / ci_fn 5e-5 (B=128
precedent); chunkwise recon sites_per_chunk=9 coeff 1.5; imp-min eps 1e-12→1e-6
(fractional-pnorm grad stability), coeff/beta/pnorm unchanged; PGD/faith unchanged;
ci max_len→512. mem_probe gains --seq; seq-512 sweep sbatch added.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@ocg-goodfire ocg-goodfire merged commit cb515bc into feature/jax Jun 17, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant