Skip to content

feat(jax): persistent XLA compilation cache for jsp-train#875

Merged
ocg-goodfire merged 1 commit into
feature/jaxfrom
worktree-agent-a220cbe2e9323975a
Jun 17, 2026
Merged

feat(jax): persistent XLA compilation cache for jsp-train#875
ocg-goodfire merged 1 commit into
feature/jaxfrom
worktree-agent-a220cbe2e9323975a

Conversation

@ocg-goodfire

Copy link
Copy Markdown
Collaborator

What

Enables JAX's persistent compilation cache for jsp-train, so the expensive XLA compile of the chunkwise step (~24 min for the 27-site Llama-8B step on 64×B200) is cached to disk and reused across runs, SLURM requeues/resumes, and future R&D launches at the same config + topology. A matching re-compile loads the executable from disk in seconds.

Where it's set + why

run.py::main calls a new _enable_persistent_compilation_cache(cfg.out_dir) after init_distributed() (the cache's write gate reads the distributed state) and before the first compile (the make_train_step / harvest jits inside train()). Putting it in run.py rather than the pd-jax-lm launcher means it covers direct jsp-train too, not just launcher-submitted runs, and the path is resolvable from config (always-on). It uses the established jax.config.update(...) pattern (cf. slow_eval.py:395).

Flags set:

  • jax_compilation_cache_dir → the shared cache path
  • jax_persistent_cache_min_compile_time_secs = 60.0 — only the big compiles cache, no thrash on trivial ones
  • jax_persistent_cache_min_entry_size_bytes = 0 — jax default (allows its own override heuristic)

Cache dir path

$PARAM_DECOMP_OUT_DIR/xla_compilation_cache — a sibling of runs/ (derived as cfg.out_dir.parent / "xla_compilation_cache", since out_dir == $PARAM_DECOMP_OUT_DIR/runs). It is deliberately NOT per-run and NOT inside the per-run immutable workspace — it must be shared across runs for any cross-run reuse. All 8N ranks of a multi-host run point at the same shared-FS dir. Created by rank 0 in main.

Multi-host correctness (jax 0.10.1)

Safe. Confirmed from the installed jax/_src/compiler.py: the cache write is gated on distributed.global_state.process_id == 0 (verbatim comment: "Only write cache entries from the first process. Otherwise we create problems with contention for writes on some filesystems, e.g., GCS."). The XLA autotune subcache is likewise gated (AutotuneCacheMode.UPDATE on rank 0, READ elsewhere). So all ranks read, only rank 0 writes — no shared-FS write race across the 64 ranks. The official JAX persistent-cache doc states the same rank-0 write design and requires a shared FS (NFS/GFS); $PARAM_DECOMP_OUT_DIR is already shared-FS cluster storage. No multi-host caveat outstanding for our version.

Verification — cache HIT proven (no GPU needed)

Compiled a non-trivial jit'd fn (8× tanh(x@xᵀ)@x chain) on CPU with the cache dir set to a temp path, threshold lowered so the CPU compile actually writes, in two separate processes:

Phase 1 (cold)PERSISTENT COMPILATION CACHE MISSWriting jit_f to persistent compilation cache with key 'jit_f-a610e383...'; 3 entries written to disk (jit_f-...-cache, etc).

Phase 2 (fresh process)Persistent compilation cache hit for 'jit_f' with key 'jit_f-a610e383...' for the identical key; wall time 0.044s → 0.018s, identical output. Mechanism confirmed on jax 0.10.1.

(The 64-GPU production gain will confirm naturally on the next requeue.)

Validation

  • basedpyright jax_single_pool/0 errors, 0 warnings
  • pytest jax_single_pool/tests/189 passed, 2 skipped (incl. equivalence goldens — bit-identical), at both default device count and XLA_FLAGS=--xla_force_host_platform_device_count=4
  • make type (torch side) — 0 errors
  • Pre-commit (Ruff, BasedPyright, BasedPyright-JAX) — all Passed

No training-semantics change — caching only affects compile, not numerics; goldens stay green.

🤖 Generated with Claude Code

Cache compiled XLA executables to a shared-FS dir reused across runs,
requeues, and future launches at the same config+topology. The ~24-min
chunkwise-step compile is keyed by HLO + backend + topology + jax/xla
version, so a matching re-compile loads from disk in seconds.

Set in run.py::main after init_distributed (the write gate reads the
distributed state) and before the first compile, so it covers direct
jsp-train too — not just pd-jax-lm. Cache dir is a SIBLING of runs/
($PARAM_DECOMP_OUT_DIR/xla_compilation_cache), shared across all runs and
all 8N ranks. Threshold 60s so only the big compiles cache.

Multi-host safe on jax 0.10.1: jax gates the cache WRITE on process_id == 0
(compiler.py), so all ranks read but only rank 0 writes — no shared-FS
race. Verified a cross-process cache HIT on CPU.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@ocg-goodfire ocg-goodfire merged commit 7cb9a6f into feature/jax Jun 17, 2026
1 check failed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant