feat(jax): persistent XLA compilation cache for jsp-train#875
Merged
Conversation
Cache compiled XLA executables to a shared-FS dir reused across runs, requeues, and future launches at the same config+topology. The ~24-min chunkwise-step compile is keyed by HLO + backend + topology + jax/xla version, so a matching re-compile loads from disk in seconds. Set in run.py::main after init_distributed (the write gate reads the distributed state) and before the first compile, so it covers direct jsp-train too — not just pd-jax-lm. Cache dir is a SIBLING of runs/ ($PARAM_DECOMP_OUT_DIR/xla_compilation_cache), shared across all runs and all 8N ranks. Threshold 60s so only the big compiles cache. Multi-host safe on jax 0.10.1: jax gates the cache WRITE on process_id == 0 (compiler.py), so all ranks read but only rank 0 writes — no shared-FS race. Verified a cross-process cache HIT on CPU. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
Enables JAX's persistent compilation cache for
jsp-train, so the expensive XLA compile of the chunkwise step (~24 min for the 27-site Llama-8B step on 64×B200) is cached to disk and reused across runs, SLURM requeues/resumes, and future R&D launches at the same config + topology. A matching re-compile loads the executable from disk in seconds.Where it's set + why
run.py::maincalls a new_enable_persistent_compilation_cache(cfg.out_dir)afterinit_distributed()(the cache's write gate reads the distributed state) and before the first compile (themake_train_step/harvestjits insidetrain()). Putting it inrun.pyrather than thepd-jax-lmlauncher means it covers directjsp-traintoo, not just launcher-submitted runs, and the path is resolvable from config (always-on). It uses the establishedjax.config.update(...)pattern (cf.slow_eval.py:395).Flags set:
jax_compilation_cache_dir→ the shared cache pathjax_persistent_cache_min_compile_time_secs = 60.0— only the big compiles cache, no thrash on trivial onesjax_persistent_cache_min_entry_size_bytes = 0— jax default (allows its own override heuristic)Cache dir path
$PARAM_DECOMP_OUT_DIR/xla_compilation_cache— a sibling ofruns/(derived ascfg.out_dir.parent / "xla_compilation_cache", sinceout_dir == $PARAM_DECOMP_OUT_DIR/runs). It is deliberately NOT per-run and NOT inside the per-run immutable workspace — it must be shared across runs for any cross-run reuse. All 8N ranks of a multi-host run point at the same shared-FS dir. Created by rank 0 inmain.Multi-host correctness (jax 0.10.1)
Safe. Confirmed from the installed
jax/_src/compiler.py: the cache write is gated ondistributed.global_state.process_id == 0(verbatim comment: "Only write cache entries from the first process. Otherwise we create problems with contention for writes on some filesystems, e.g., GCS."). The XLA autotune subcache is likewise gated (AutotuneCacheMode.UPDATEon rank 0,READelsewhere). So all ranks read, only rank 0 writes — no shared-FS write race across the 64 ranks. The official JAX persistent-cache doc states the same rank-0 write design and requires a shared FS (NFS/GFS);$PARAM_DECOMP_OUT_DIRis already shared-FS cluster storage. No multi-host caveat outstanding for our version.Verification — cache HIT proven (no GPU needed)
Compiled a non-trivial jit'd fn (8×
tanh(x@xᵀ)@xchain) on CPU with the cache dir set to a temp path, threshold lowered so the CPU compile actually writes, in two separate processes:Phase 1 (cold) —
PERSISTENT COMPILATION CACHE MISS→Writing jit_f to persistent compilation cache with key 'jit_f-a610e383...'; 3 entries written to disk (jit_f-...-cache, etc).Phase 2 (fresh process) —
Persistent compilation cache hit for 'jit_f' with key 'jit_f-a610e383...'for the identical key; wall time 0.044s → 0.018s, identical output. Mechanism confirmed on jax 0.10.1.(The 64-GPU production gain will confirm naturally on the next requeue.)
Validation
basedpyright jax_single_pool/— 0 errors, 0 warningspytest jax_single_pool/tests/— 189 passed, 2 skipped (incl. equivalence goldens — bit-identical), at both default device count andXLA_FLAGS=--xla_force_host_platform_device_count=4make type(torch side) — 0 errorsNo training-semantics change — caching only affects compile, not numerics; goldens stay green.
🤖 Generated with Claude Code