Make Othello starting-player tests RNG-version-agnostic #1325

Open
gweber wants to merge 1 commit into sotetsuk:main from gweber:othello-rng-robust-tests

@gweber gweber commented Jun 11, 2026

Make Othello's starting-player-dependent tests RNG-version-agnostic

Four tests in tests/test_othello.py (test_init, test_terminated,
test_legal_action, test_observe) hard-coded the outcome of the starting-player
coin flip in Othello._init (current_player = bernoulli(key)). They picked a
specific PRNGKey — via a double jax.random.split chosen to land on
current_player == 0 — and then made absolute board/reward/observation assertions
keyed to that player.

The exact bernoulli result for a given key is not stable across jax.random
versions, so on newer JAX (e.g. 0.10.x) the flip returns 1 and all four tests
fail, even though the environment logic is unchanged and correct.

This replaces the hard-coded key with a small helper that tries seeds 0..999 at
test time and returns the first initial state with the desired starting player:

def _init_with_current_player(player: int):
    for seed in range(1000):
        state = init(jax.random.PRNGKey(seed))
        if int(state.current_player) == player:
            return state
    raise AssertionError(...)

Every downstream assertion is unchanged and fully deterministic once the starting
player is fixed (board indices and outcomes don't depend on the RNG). No
environment/runtime code is touched — tests only.
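
For readers without the repository at hand, the search pattern can be sketched in isolation. This is a minimal illustration, not the code in this PR: ToyState and toy_init are hypothetical stand-ins for the Othello state and init (a seeded stdlib coin flip replaces bernoulli(key)), and the init function is passed in as a parameter so the sketch runs without JAX or pgx.

```python
import random
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for the environment state; only the field we need."""
    current_player: int


def toy_init(seed: int) -> ToyState:
    # Stand-in for the real init: a seeded coin flip picks the starting player.
    # Which seed maps to which player is an implementation detail of the RNG,
    # which is exactly why tests should not hard-code it.
    return ToyState(current_player=random.Random(seed).randrange(2))


def init_with_current_player(
    init_fn: Callable[[int], ToyState], player: int, max_seeds: int = 1000
) -> ToyState:
    """Return the first initial state whose starting player equals `player`."""
    for seed in range(max_seeds):
        state = init_fn(seed)
        if int(state.current_player) == player:
            return state
    # Reaching this point means the flip never produced `player`,
    # which would indicate a problem in init itself.
    raise AssertionError(
        f"no seed in range({max_seeds}) produced current_player == {player}"
    )


# Usage: fix the starting player first, then make absolute assertions.
state = init_with_current_player(toy_init, player=0)
assert state.current_player == 0
```

With a fair coin flip, the chance that 1000 seeds all miss the requested player is negligible, so the trailing AssertionError should only fire if the starting-player logic itself is broken.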

Validation

Ran tests/test_othello.py on jax 0.4.30 and jax 0.10.1: the four tests
pass on both. (On 0.10.x a separate, unrelated test_api failure remains, caused
by an IndexError inside pgx.api_test/baseline infra — out of scope here.)

The bernoulli coin flip that picks the starting player in Othello._init is not
stable across jax.random versions, so tests that hard-coded the outcome of a
specific PRNGKey (via a double split chosen to land on current_player==0) broke
on newer JAX. Replace that with a small _init_with_current_player(player) helper
that searches keys for the desired starting player, keeping every downstream
board/reward/observation assertion deterministic on any JAX version.

Verified on jax 0.4.30 and jax 0.10.1.