bump jax to 0.10, drop distrax to unblock the persistent compile cache#21
Merged
Conversation
JAX 0.4.38 silently no-ops the persistent compilation cache when any jax.Array is materialized before the cache dir is set (jax-ml/jax#25768), so every process paid full cold compile (~7 min GPU / ~20 min/cell CPU). JAX 0.5.x has a worse bug: cache hits silently return stale .at[].set() updates inside lax.fori_loop / lax.cond (jax-ml/jax#31733), corrupting the env step — reproduced here with 17/35 blue-Remove tests failing on a warm cache under 0.5.1. Bug is fixed from 0.7.1 onward; we land on 0.10. Distrax is dropped because its transitive tensorflow-probability<=0.25 imports jax.interpreters.xla.pytype_aval_mappings, removed in JAX 0.7. The 5 callsites only used distrax.Categorical for sampling / log_prob / entropy, all easily expressed on jax.random.categorical and jax.nn.log_softmax — see src/jaxborg/policies/categorical.py (31 lines, flax struct so it's jit/vmap/scan-compatible). Dropping distrax also removes TFP, gast, and decorator from the dep graph. Knock-on bumps from JAX 0.7+ requiring ml-dtypes>=0.5 (numpy 2 C ABI): * numpy 1.26.4 -> 2.3.5 (overrides cyborg's pin; cyborg runtime is numpy-2-compatible empirically) * scipy 1.12 -> 1.17 (overrides jaxmarl's `scipy<=1.12` precautionary upper bound; cf. jaxmarl commit 3ffa5b8f and issue #175) * torch 2.2 -> 2.10 (overrides cyborg's pin; needed for numpy-2 init) Also retires a flaky test that relied on accidental RNG alignment between CybORG's numpy and JAX's threefry: tests/subsystems/test_fsm_red_agent.py::test_fsm_hidden_state_applies_after_completion_step walked both pipelines forward on seed=0 and asserted they hit the same FSM sequence — a parity-by-coincidence that breaks any time JAX's PRNG layout changes. Restructured as a pure-state check of the two-stage delayed-update mechanism (fsm_red_schedule_post_step_update stages red_fsm_delayed_states; fsm_red_apply_delayed_update commits it on the next step), which is the actual invariant the test name promises. Verification: * fast suite: 772/772, cold 154s -> warm 115s (cache hit, correct) * full suite incl. slow: 1214 passed, 100 skipped, 3 xfailed (xfails are pre-existing, unrelated) * cache miscompile probe: blue_remove.py 35/35 cold and 35/35 warm against the same cache dir — under 0.5.1 the same probe gave 17/35 failures on warm
b4cd079 to
b221685
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
jax.Arrayis materialized before the cache dir is set (jax-ml/jax#25768). Every process pays full cold compile (~7 min GPU / ~20 min/cell CPU). JAX 0.5.x has a worse bug: cache hits silently return stale.at[].set()updates insidelax.fori_loop/lax.cond(jax-ml/jax#31733) — corrupts the env step. Reproduced here with 17/35 blue-Remove tests failing on a warm cache under 0.5.1. Fixed from 0.7.1 onward; this PR lands on 0.10.0.distraxbecause its transitivetensorflow-probability<=0.25importsjax.interpreters.xla.pytype_aval_mappings, removed in JAX 0.7. All 5 callsites useddistrax.Categoricalfor.sample()/.log_prob()/.entropy(), replaced by a local 31-linesrc/jaxborg/policies/categorical.py(flax struct, jit/vmap/scan-compatible). Also removes TFP, gast, decorator from the dep graph.ml-dtypes>=0.5(numpy 2 C ABI):scipy<=1.12; cf. jaxmarl commit3ffa5b8fand issue #175)tests/subsystems/test_fsm_red_agent.py::test_fsm_hidden_state_applies_after_completion_step's reliance on accidental RNG alignment between CybORG's numpy and JAX's threefry. It was walking both pipelines forward onseed=0and asserting they hit the same FSM sequence — parity-by-coincidence that breaks any time JAX's PRNG layout changes. Restructured as a pure-state check of the two-stage delayed-update mechanism that the test name actually promises.Test plan
tests/subsystems/test_blue_remove.py35/35 cold and 35/35 warm against the same cache dir — under 0.5.1 the same probe gave 17/35 failures on warmuv run ruff check . && uv run ruff format .clean