Skip to content

Add pure JAX backend with nutpie sampler#109

Merged
ScottClaessens merged 55 commits into
mainfrom
jax
May 15, 2026
Merged

Add pure JAX backend with nutpie sampler#109
ScottClaessens merged 55 commits into
mainfrom
jax

Conversation

@ErikRingen

@ErikRingen ErikRingen commented Apr 22, 2026

Copy link
Copy Markdown
Collaborator

Feedback welcome on design, scope, and open questions below.

Motivation

JAX traces the entire log-density function into a single XLA computation
graph, which is then compiled holistically — enabling operation fusion,
memory optimization, and efficient vectorization that Stan's
per-operation C++ compilation doesn't support. This also makes it
natural to express batched operations that Stan's language requires loops
for:

  • Batched matrix operations. Stan loops over unique branch lengths
    one at a time. JAX computes matrix exponentials, Cholesky
    decompositions, and drift caches for all branch lengths in single
    batched calls.
  • Level-batched tree traversal. Stan propagates states node-by-node.
    JAX groups nodes by tree depth and processes all nodes at the same
    level in one einsum call. Levels are still sequential (parent-child
    dependency), but within-level work is batched.
  • Vmap across trees. For multiphylo models, jax.vmap runs the
    tree traversal, derived quantities, and likelihood in parallel across
    trees — something Stan's language has no mechanism for.

Benchmarks

Clean benchmarks (compilation excluded, 3 replicates, fixed seeds) on
two datasets show consistent ~3–4x speedups:

Authority (97 taxa, 2 ordered-logistic variables, 5 trees):

Stan (CmdStan) JAX (nutpie)
Wall-clock (median) 322 s 101 s
Min pop-param ESS/s 2.7 10.3
Median pop-param ESS/s 9.0 28.7

Primates (143 taxa, 2 gamma-log variables, missing data):

Stan (CmdStan) JAX (nutpie)
Wall-clock (median) 1031 s 278 s
Min pop-param ESS/s 0.0 0.1
Median pop-param ESS/s 0.2 0.7

JAX is ~3–4x faster wall-clock and ~3x better ESS/s than Stan
on the population parameters of interest. The primates model is a harder
geometry (Stan hits max treedepth on 99%+ of transitions), but the
relative advantage holds. These are local CPU benchmarks; GPU
acceleration is supported but not yet tested.

How this replaces the existing nutpie-via-BridgeStan path

On main, backend = "nutpie" uses BridgeStan: the Stan model is
compiled to a C++ shared library, and nutpie's Rust NUTS sampler calls
into it via C FFI for log-density and gradient evaluations (using
Stan's autodiff). This replaces CmdStan's sampler but not the model
evaluation — gradients are still computed by Stan Math in C++.

This branch replaces that entire path. Instead of compiling Stan to
C++ and calling it through BridgeStan:

  1. coev_make_model_config() generates a Python dict describing the
    model structure (variable types, tree structure, priors, GP config).
  2. inst/python/coev_jax_model.py (CoevJaxModel) builds a pure JAX
    log-density function from that config — no Stan code is generated or
    compiled.
  3. jax.value_and_grad provides gradients; nutpie samples using its
    Rust NUTS with those JAX-computed gradients.

The old BridgeStan helpers (nutpie_compile_stan_model,
nutpie_sample, convert_nutpie_draws) are replaced by the new jax_*
modules. The backend argument is deprecated in favor of
nuts_sampler ("stan" or "nutpie").

Correctness verification

Since the JAX backend reimplements the Stan log-density in a different
language, correctness is verified by evaluating both models at the
same unconstrained parameter vector and comparing log-density and
gradient. This is analogous to the approach used by https://github.com/pymc-labs/alchemize.

compare_stan_jax_logprob() does this:

  1. Compile both models for the same data/config.
  2. Draw n_points random unconstrained vectors with a fixed seed.
  3. At each point, compute log_prob() and grad_log_prob() in both
    backends.
  4. Check that the Stan/JAX log-density difference is constant across
    points (any constant offset comes from normalization terms Stan
    drops) and that gradients agree to a tolerance.

This comparison is wired into the test suite as
tests/testthat/test-logp_stan_jax.R, which runs 12 prior-only
configurations (all response distributions, GP/HSGP, repeated measures,
multiphylo, measurement error, effects matrices, correlated drift on/off)
plus 7 full-likelihood configurations.

Prior-only tests pass at machine precision (max grad diff ~1e-16)
Likelihood tests pass at ~1e-3 tolerance

What changed

New files

File Lines Purpose
inst/python/coev_jax_model.py 1297 Core JAX log-density, mirrors Stan blocks (data transforms, parameters, model, generated quantities)
R/jax_helpers.R 338 R-side helpers: Python availability check, JAX model instantiation, expand-fn compilation
R/jax_sample.R 159 Orchestrates JAX sampling: build model, JIT-compile, call nutpie, collect draws
R/jax_wrapper.R 167 jax_fit S3 class with draws(), summary(), metadata() methods
R/coev_make_model_config.R 146 Generates Python-side model config dict (analogous to coev_make_stancode + coev_make_standata)
R/compare_stan_jax_logprob.R ~170 Evaluates Stan and JAX log-densities + gradients at the same unconstrained point
tests/testthat/test-logp_stan_jax.R ~340 19 tests covering all supported configs (prior-only and full-likelihood)
benchmarks/ (various) Benchmark scripts, debug utilities, PR drafting

Modified files

File What changed
R/coev_fit.R New nuts_sampler argument ("stan" default, "nutpie" for JAX). Legacy backend argument preserved with deprecation warning.
R/summary.R Cutpoint variable parsing now handles both Stan (c[i,j]) and JAX (c1[j]) naming conventions.
R/stancode.R Returns Stan code when available, or a message indicating JAX backend was used.
tests/testthat/test-coev_fit_nutpie.R Rewritten for the new nuts_sampler API plus integration tests for JAX fits.
README.Rmd / README.md Updated sampler docs: nuts_sampler API, setup instructions, reticulate config.
DESCRIPTION Title/description updated, version bumped to 1.0.0.9001.
NEWS.md Added JAX backend entry.
NAMESPACE New exports: coev_make_model_config, compare_stan_jax_logprob, plus jax_fit S3 methods.

Feature coverage

The JAX backend supports all model features available in the Stan backend:

  • All response distributions (normal, bernoulli_logit,
    ordered_logistic, poisson_softplus, negative_binomial_softplus,
    gamma_log)
  • Repeated measures
  • Effects matrices
  • Exact GP spatial control (lon_lat)
  • HSGP approximate GP (lon_lat + dist_k)
  • Correlated drift estimation (on/off)
  • Residual correlation estimation
  • Measurement error
  • Multiphylo (phylogenetic uncertainty)
  • Prior-only sampling

All combinations are covered by the logp agreement test suite.

What it does NOT change

  • The Stan backend (nuts_sampler = "stan") is untouched.
  • All downstream functions (coev_plot_*, coev_calculate_*,
    coev_pred_series, extract_samples, summary, plot) work with
    JAX fits via the same coevfit S3 class.
  • No changes to the Stan model code or Stan data generation.

Open questions

1. Maintainability of dual Stan/JAX code paths

This is the main long-term concern. The JAX log-density
(coev_jax_model.py, 1297 lines) is a manual reimplementation of the
Stan model. Any future change to the Stan model must be mirrored in the
Python code, or the two backends will diverge.

Mitigations in place:

  • tests/testthat/test-logp_stan_jax.R runs the logp-agreement check
    across all supported configs — will catch most divergence.
  • The Python code is structured to mirror Stan block names
    (_compute_priors ~ model{} priors, _likelihood ~ likelihood
    block, etc.) for easier cross-reference.

Mitigations worth considering:

  • Wire test-logp_stan_jax.R into CI (currently runs locally; would add
    significant CI runtime since it compiles Stan).
  • Contributor checklist: "if you change the Stan model, add a matching
    case to test-logp_stan_jax.R."
  • Longer term: auto-generate the JAX log-density from a shared model
    spec.

2. Python dependency management

Python deps are now pinned in inst/python/requirements.txt
(jax==0.5.3, numpyro==0.19.0, nutpie==0.16.8). check_jax_available()
passes these pinned specs to reticulate::py_require(), which resolves
them via uv into a managed environment.

This is a substantial improvement over unpinned deps (a JAX update
broke numpyro during development), but still has limitations:

  • GPU support would need jax[cuda] which has different deps.
  • Users with existing Python environments may still hit conflicts.
  • A proper lockfile (e.g., pixi.toml or uv.lock) would be more
    robust than pinned versions in requirements.txt.

3. API naming

nuts_sampler = "nutpie" selects the JAX backend, which is slightly
misleading — nutpie is the sampler, but the key change is the JAX
log-density. Should this be nuts_sampler = "jax" or remain
"nutpie" for brms/etc. consistency?

4. Scope of this PR

Large diff (~4,000 lines added). Could be split into: R-side plumbing
first, then the Python model, then tests/benchmarks. The pieces are
tightly coupled, so splitting may not buy much.

5. CI runtime cost

The logp comparison test suite takes ~10-15 min locally because each
test compiles Stan. If wired into CI it becomes the dominant cost. Could
be gated on COEVOLVE_EXTENDED_TESTS=true or run nightly instead of on
every PR.

CI status

  • lintr: passing (no lints)
  • R CMD check: 1 WARNING (CRAN incoming feasibility, expected for
    non-CRAN deps), 0 NOTEs on package code itself
  • Test suite: 1091 tests pass locally (including 19 logp agreement
    tests at machine precision for priors, ~1e-3 for likelihood)

How to test

# Install Python deps
# pip install jax numpyro nutpie

# Fit with JAX backend
fit <- coev_fit(
  data = authority$data,
  variables = list(
    political_authority = "ordered_logistic",
    religious_authority = "ordered_logistic"
  ),
  id = "language",
  tree = authority$phylogeny,
  nuts_sampler = "nutpie"
)

# Verify log-density agreement with Stan
compare_stan_jax_logprob(
  data = authority$data,
  variables = list(
    political_authority = "ordered_logistic",
    religious_authority = "ordered_logistic"
  ),
  id = "language",
  tree = authority$phylogeny
)

ErikRingen and others added 30 commits March 19, 2026 11:39
- Add PyMC/PyTensor backend with nutpie integration for accelerated
  sampling via JAX compilation and Rust NUTS sampler
- Add .Rprofile for managed reticulate/uv Python environment
- Fix summary() for PyMC-style ordered logistic cutpoints
- Fix nutpie kwarg marshalling via Python helper
- Add default parameter values for legacy BridgeStan nutpie path
- Remove NumPyro backend (code, exports, tests)
- Delete accumulated scripts/ directory
- Update DESCRIPTION, NAMESPACE, README to reflect new backends

Made-with: Cursor
…zed GP covariance

- Refactor tree traversal from O(N_seg) unrolled Python loops to O(depth)
  level-batched PyTensor ops, reducing graph complexity ~75% and compile
  times ~15x for 100-tip models
- Vectorize GP distance covariance assembly: replace nested Python loops
  with batched pt.exp + pt.fill_diagonal
- Add compute_tree_levels() in pymc_helpers.R to pre-compute topological
  depth groupings for batched processing
- Add parse_stan_prior() helper for Stan-to-PyMC prior conversion
- Fix nutpie refresh arg forwarding in coev_fit()

Made-with: Cursor
- Replace chol_lower_unrolled (scalar-by-scalar) with pt.linalg.cholesky
- Replace A_diag/Q_sigma Flat+softplus+Potential with native PyMC
  distributions (TruncatedNormal, HalfNormal)
- Rewrite LKJ Cholesky to use pt.stacklists instead of pt.set_subtensor
  loops, eliminating O(J^2) graph ops
- Remove dead helpers: expm_real, pos_softplus, neg_softplus,
  log_softplus_jac, pymc_prior_logp_expr, pymc_unrolled_chol_fn
- Force JAX_PLATFORMS=cpu in nutpie path to avoid GPU overhead for
  small phylogenetic models (benchmarked 2-4x faster than GPU)

Net: -168 lines, -91 lines of generated Python per model
Made-with: Cursor
- convert_pymc_draws: permute MultiTrace arrays when simplify2array puts
  chains on last axis (fixes assignment error with multiprocess PyMC)
- pymc_run_mcmc: pass cores to pm.sample (default all chains); document
- coev_fit: map cores/parallel_chains for pymc; clarify nutpie ... args;
  drop spurious nuts_* warnings for non-nutpie paths
- plotting/save: treat pymc_fit like nutpie_fit for draws access
- README, roxygen/man, tests

Made-with: Cursor
…prob comparison

Bug: needs_terminal_drift was FALSE for models with normal and non-normal
variables when there was no missing data. This meant the non-normal variable's
latent terminal drift was hardcoded to zero, producing a structurally different
(and much harder) posterior than Stan — causing ~140x more leapfrog steps/draw.
Fix sets needs_terminal_drift=TRUE whenever non-normal variables appear
alongside normal variables, matching Stan's parameterization.

Also adds compare_stan_pymc_logprob() for Stan/PyMC density parity checks,
with a Python helper that accounts for tip-edge z_drift and observed-normal
terminal_drift structural differences. Tests assert abs_diff < 1.0 for
prior-only and likelihood (mixed vars + missing data) scenarios.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ob comparison

The pymc_logprob_at_stan_primary_params function incorrectly used tr.backward
(unconstrained→constrained) instead of tr.forward (constrained→unconstrained)
when converting Stan primary parameters to PyMC's transformed parameter space.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…lass

Replace the ~800-line R string accumulator (coev_make_pymc.R) with a Python
CoevPymcModel class in inst/python/coev_pymc_model.py. The class reads
model config flags from data_dict and builds the PyMC model directly via
PyMC APIs, eliminating silent indentation bugs and exec()-based execution.

- coev_make_pymc() now returns a config list (flags + prior_specs) instead
  of a Python code string; caller merges it via embed_pymc_config()
- pymc_helpers.R: add prior_spec_from_stan(), embed_pymc_config(),
  load_pymc_model_module(); update convert_r_to_python_data_pymc() to
  handle character vectors and nested lists (for distributions/prior_specs)
- coev_pymc_logprob.py: drop exec(pymc_code), import CoevPymcModel from
  same directory; pymc_code arg removed
- pymc_sample.R: drop pymc_code param, use load_pymc_model_module()
- coev_fit.R: use pymc_cfg + embed_pymc_config(); fix stale sc reference
  in shared return object (stan_code = NULL for PyMC path)
- compare_stan_pymc_logprob.R: use cfg from coev_make_pymc() directly

All 1001 tests pass (0 failures).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Delete debug_compare.R, debug_params.R, quick_debug*.R, Rplots.pdf
- Add __pycache__/ and *.pyc to .gitignore
- Fix JAX dead-branch gradient bug: replace -9999 missing-data fill with
  valid values per distribution before calling pm.logp (gamma_log → 1.0,
  ordered_logistic → 1.0, others → 0.0)
- Fix nutpie initial_points format: compile_pymc_model expects a dict
  (overrides= to make_initial_point_fn), not a list; pass ip directly
- Keep calibrated initial_points for numba fallback after JAX failure
- Surface JAX error message in fallback warning
- Trim stale/verbose comments and intermediate-thought docstrings across
  coev_pymc_model.py, coev_pymc_logprob.py, and R helper files

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Fix se[0] bug in repeated-measures + measurement_error path: use
  per-observation SE matrices matching Stan's diag_matrix(se[i,])
- Subtract LKJ stick-breaking Jacobian (tanh transform) from PyMC
  logprob evaluations — this term exists in PyMC Potentials but not
  in Stan's lkj_corr_cholesky_lpdf, causing 0.15–0.51 discrepancies
- Tighten logprob comparison tolerance from 1.0 to 0.03 (33x stricter)
- Add test cases: repeated+ME (catches se[0] bug) and all-non-normal
- Remove dead stan_prior_to_pymc() function (-62 lines)
- Remove unused numpy imports, fix stale docstrings
- Error informatively when stancode() called on PyMC fits

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Implements a pure JAX log-density function with numpyro.distributions
for distribution log-probs, wired to nutpie's Rust NUTS sampler via
nutpie.compiled_pyfunc.from_pyfunc(). Achieves ~5.5x speedup over
Stan on the authority benchmark (500/500 draws, 4 chains).

New files:
- inst/python/coev_jax_model.py: JAX model (log-density, transforms,
  ksolve, matrix exp, tree traversal, all 6 distributions)
- R/jax_helpers.R: availability checks, data conversion, shared helpers
- R/jax_sample.R: nutpie sampling, draws conversion
- R/jax_wrapper.R: jax_fit S3 class
- R/coev_make_model_config.R: generic model config (renamed from
  coev_make_pymc)
- R/compare_stan_jax_logprob.R: Stan vs JAX log-density comparison

Removed PyMC/PyTensor dependency entirely. Only jax, numpyro, and
nutpie are needed for the accelerated backend.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
JIT-compile the parameter expansion function (called once per draw)
to eliminate Python-level loops and JAX→numpy conversion overhead.
expand_fn: 1.55ms → 0.036ms per call (43x faster).
Overall wall time: 38.2s → 30.0s for authority benchmark.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Implement full GP (Cholesky) and HSGP (basis function) spatial
  effects in the JAX/nutpie backend, matching the Stan implementation
- Support all three kernels: exp_quad, exponential, matern32
- Update coev_make_model_config to use lon_lat/dist_k API
- Add GP/HSGP tests for the JAX backend
- Add GP/HSGP benchmark models to nutpie_vs_stan.qmd
- Fix benchmark: use $summary() for both backends, add parallel_chains
- Merge main: incorporate lon_lat/HSGP, whisker templates, fixtures

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…nore

- Replace deprecated dist_mat parameter with lon_lat/dist_k in
  compare_stan_jax_logprob()
- Add .claude/, quick_debug.R, Rplots.pdf to .gitignore

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add # nolint annotations for cross-file function references
  (object_usage_linter false positives) and Stan-convention
  parameter names (object_name_linter)
- Exclude .claude/, benchmarks/, quick_debug.R from R CMD build
- Regenerate man pages for lon_lat/dist_k API changes

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
solve_triangular requires matching batch dimensions. L_cov_res was
(J, J) while residuals were (N_obs, J) — broadcast to (N_obs, J, J).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Unicode right single quotation marks (U+2019) in #' roxygen comment
markers silently broke parsing — params after prior_only were not
documented. Replace with ASCII apostrophes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The PTY approach caused sampler stalls when indicatif output filled
the buffer faster than R could drain it. Simplified to progress_bar=False
with a plain background thread and polling loop.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…tan blocks

Break monolithic log_density() into _build_A_matrix, _build_Q_matrix,
_compute_priors, _compute_caches, _tree_traversal, _transformed_params,
and _likelihood. Deduplicate A/Q building in make_expand_fn. Add
compile/sample timing to jax_sample.R and update benchmark doc.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Autoresearch loop (24 iterations) optimizing value_and_grad(log_density):
- matrix_exp_batch: Horner's method, 12→3 Taylor terms, 8→4 squarings
- Ordered logistic: log-space cumulative rewrite (no concat/clip)
- Batched cutpoint priors (single prior_logp call)
- Pre-convert level indexing arrays to jnp in build()
- Einsum for tree traversal matmuls
- lax.linalg.triangular_solve in mvn_chol_logp

Results on authority model (J=2, 97 tips):
  value_and_grad: 327 → 172 us/call (-47%)
  Full bench:     27.2 → 17.2s wall (-37%)
  ESS/sec:        1.2 → 2.3 (+92%)
  Compile:        1.66 → 1.18s (-29%)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Match Stan's exact unconstrained parameterization so both backends can
be evaluated at the same point:

- Transforms: exp/-exp for bounds, exp increments for ordered (was
  softplus). Jacobians updated to match Stan convention.
- z_drift: shape (N_tree, N_seg-1, J) matching Stan, not (N_tree,
  N_internal, J). Tip entries masked in traversal but receive
  std_normal prior like Stan.
- Cholesky: verified stick-breaking + LKJ prior already matched Stan.

Verification at 595-dim unconstrained point:
- Gradients agree to 2.4e-11 (machine precision)
- Log-densities differ by constant 551.39 (normalization terms)

Also:
- Vmap tree traversal over tree dimension (sublinear multiphylo scaling)
- Rewrite compare_stan_jax_logprob() to evaluate at same unconstrained
  point and check gradient agreement
- Add .gitignore entries for benchmark outputs/build artifacts
- Fix missing importFrom(stats, rnorm)
- Add PR draft, benchmark scripts

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
12 tests covering: ordered logistic, normal, bernoulli, poisson,
negative binomial, exact GP, HSGP, repeated measures, multiphylo,
measurement error, correlated drift disabled, effects_mat.

3 pass (ordered logistic, repeated, multiphylo).
Remaining failures document real discrepancies in the JAX backend:
- Gradient mismatches on bernoulli, poisson, GP, HSGP
- Dimension mismatches on normal (terminal_drift alignment)
- Test data bugs (neg binom needs integer, effects_mat needs names)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
All 12 logp agreement tests now pass at machine precision
(max gradient diff ~1e-16 across all model configurations).

Bugs fixed:
1. terminal_drift now always included in parameter layout
   (matching Stan's unconditional declaration)
2. terminal_drift prior logic matches Stan's conditional
   structure: global std_normal when tdrift flag is set,
   per-observation inline prior otherwise
3. sigma_dist/rho_dist parameter ordering swapped to match
   Stan (rho_dist before sigma_dist)

Also fixed test setup: neg_binomial data needs as.integer(),
effects_mat needs named rows/cols.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Prior-only tests (12/12) pass at machine precision, confirming
transforms, priors, and tree traversal match Stan.

Likelihood tests (0/7) fail with gradient diffs 0.5-80, documenting
real bugs in the JAX likelihood computation that need investigation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Stan's `matrix[R, C]` is stored column-major (first index varies
fastest), but `array[...] vector[J]` has J as innermost. These
three parameters are matrix-typed in Stan, so their unconstrained
vector layout differs from C-order reshape.

Fix: swap last two axes in reshape for matrix-typed params.

Verified by perturbing each unconstrained position and observing
which caused Stan/JAX logp to diverge — exactly the terminal_drift
block (positions 25-30 in the 4-tip test). All other positions
produced identical diff, confirming this was the only layout bug.

Also loosen likelihood test tolerance to 1e-2 since MVN evaluation
through tree traversal accumulates floating-point errors from
matrix_exp/Cholesky that differ between Stan C++ and JAX XLA at
the ~1e-3 level (the math is identical; the numerics aren't).

All 38 logp tests pass (12 prior-only at machine precision,
7 likelihood at ~1e-3 numerical precision). Full suite: 1091 tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Previously only _tree_traversal was vmapped over trees.
_transformed_params and _likelihood still used Python loops,
which limited the multiphylo speedup.

Changes:
- _tree_traversal returns stacked tensors (N_tree, ...) instead of
  Python lists — vmap-friendly for downstream.
- _transformed_params: tdrift = L_VCV @ terminal_drift as a single
  batched einsum across all trees.
- _likelihood: vmap the per-tree body over the N_tree axis.

Benchmark (authority dataset, value_and_grad per call):

| Trees | Before  | After   | Speedup |
|-------|---------|---------|---------|
| 1     | 299 us  | 199 us  | 1.5x    |
| 2     | 461 us  | 237 us  | 1.9x    |
| 4     | 706 us  | 301 us  | 2.3x    |

Scaling from 1→4 trees is now 1.5x (previously 2.4x) — highly
sublinear since the per-level/per-obs work runs in parallel across
trees via XLA's vmap.

All 1091 tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add inst/python/requirements.txt with exact pinned versions
(jax==0.5.3, numpyro==0.19.0, nutpie==0.16.8) that are verified
to work together. check_jax_available() now reads these pins and
passes them to reticulate::py_require().

Motivated by a real breakage during development: adding bridgestan
to the env caused uv to resolve JAX 0.10.0 which removed an API
that numpyro depends on, breaking the entire JAX backend.

Also update PR_DRAFT.md with revised benchmarks, correctness
verification details, and dependency management status.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace mixed-iteration benchmark claims with the clean 5-tree
comparison (3 reps, compilation excluded, fixed seeds, pop params
only): 3x wall-clock and 3x ESS/s advantage for JAX.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Verified against BridgeStan docs: it compiles Stan to a C++ shared
library, exposes log-density and gradient via C FFI, and gradients
use Stan Math's autodiff. Tightened the description accordingly.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
BridgeStan (679s, 1.6 min ESS/s) is slower than both Stan (322s)
and JAX (101s) on the 5-tree authority model.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ErikRingen

Copy link
Copy Markdown
Collaborator Author

@ErikRingen any movement on this? I'm keen to use JAX in a current project of mine. I know you're busy with other stuff though. Shall I try making JAX available on the runner? I don't want to conflict with any work you've been doing behind the scenes.

No, nothing uncommited on this

@ScottClaessens

Copy link
Copy Markdown
Owner

Hmm, including run-extended in the commit message doesn't seem to trigger the extended tests.

@ScottClaessens

Copy link
Copy Markdown
Owner

@ErikRingen I've tried a couple of changes to the CI workflow to trigger the extended tests from the commit message, but no luck. Any idea why this is failing now?

…ed label

The previous gate used contains(github.event.pull_request.head.commit.message, ...)
and contains(github.event.head_commit.message, ...), neither of which is
populated on pull_request events. As a result COEVOLVE_EXTENDED_TESTS was
always evaluating to "false" on this PR, regardless of commit message.

Replace with two explicit triggers:
  * workflow_dispatch with an `extended` boolean input — run the suite
    manually from the Actions UI or `gh workflow run`.
  * `run-extended` PR label — adding the label re-triggers CI with
    COEVOLVE_EXTENDED_TESTS=true. The pull_request types list now
    includes `labeled` so label additions actually fire the workflow.

Note: workflow_dispatch is only discoverable once this workflow is on the
default branch (a documented GHA limitation). The label mechanism works
on this PR immediately after this commit lands.
@ErikRingen ErikRingen added the run-extended Run the extended (logp Stan/JAX) test suite in CI label May 13, 2026
@ErikRingen

Copy link
Copy Markdown
Collaborator Author

@ErikRingen I've tried a couple of changes to the CI workflow to trigger the extended tests from the commit message, but no luck. Any idea why this is failing now?

I updated such that, if one labels a PR "run-extended", the extended tests will trigger (and upon any successive commits while label still applied). Alternatively one could trigger from the github cli without needing to make a commit:

gh workflow run R-CMD-check.yaml --repo ScottClaessens/coevolve --ref jax -f extended=true

Each test file runs in its own R subprocess. Files like test-extended.R
that contain multi-hour MCMC fits still bound the wall clock (the unit of
parallelism is the file, not test_that), but the rest of the suite can
run alongside it.

Tested locally: helper.R only defines functions, fixtures/ is read-only,
no snapshot tests — no obvious shared-state risks.
@ScottClaessens

Copy link
Copy Markdown
Owner

Not sure why that one extended test failed on Mac but not other platforms.

ErikRingen and others added 3 commits May 14, 2026 13:47
Previously, extended tests ran inside R-CMD-check, taking ~4h on the
slowest matrix entry (Windows). The wall-clock floor was test-extended.R
running its 5 multi-hour MCMC tests sequentially in a single R process.

Changes:
* Split test-extended.R into four files by topic (direction, recovery,
  scaling, prior), one logical scenario per file. Tests themselves
  unchanged.
* Add .github/workflows/extended-tests.yaml — a separate workflow that
  runs each extended test file (plus test-logp_stan_jax.R) as its own
  matrix job in parallel. Triggered by the run-extended label or
  workflow_dispatch; gated via job-level if so other PRs aren't affected.
  Uses concurrency: cancel-in-progress to supersede stale runs cleanly,
  replacing the third-party cancel-previous-runs action for this workflow.
* Strip the COEVOLVE_EXTENDED_TESTS gating from R-CMD-check so the
  standard CI signal stays fast regardless of label state.
* Update tests/README.md.

Expected wall-clock for the extended path: ~max(longest individual
test file) instead of sum, so ~1h instead of ~4h.
testthat::test_file() does not auto-load the package under test (unlike
test_check() called from tests/testthat.R, which does library(coevolve)
explicitly). Previous run failed at 'could not find function coev_fit'.
@ScottClaessens

Copy link
Copy Markdown
Owner

Everything is passing! Thanks @ErikRingen. I've had a final look over everything and I'm happy to merge if you are?

The agreement suite enumerated all response distributions except
gamma_log. Adding both prior-only (machine-precision tolerance) and
full-likelihood (1e-2 tolerance) configurations, mirroring the
poisson_softplus and negative_binomial_softplus test pairs.

Test count: 19 -> 21.
Previously the JAX/nutpie path defaulted to seed = 0L when the user did
not pass `seed`, while the Stan/cmdstanr path generates a random one.
Two consequences:

1. Successive coev_fit(nuts_sampler = "nutpie") calls produced bit-
   identical results without any user opt-in, defeating MCMC convergence
   diagnostics that depend on chain randomness across runs.
2. Behavior diverged from cmdstanr (and from rstanarm/brms), so users
   moving between backends got different defaults silently.

Now: when `seed` is omitted, draw one with sample.int(.Machine$integer.max).
The drawn seed is already stored on the fit via create_jax_wrapper(), so
fit$seed lets users reproduce a specific run.
…cmdstanr-only ones

Previously, sampling args passed via `...` were filtered through a
blacklist of names that cmdstanr accepts but nutpie does not
(parallel_chains, refresh, nuts_backend, nuts_gradient_backend,
compile_mode). Two problems with this approach:

1. The blacklist falls behind whenever cmdstanr adds a new argument:
   such args would be silently forwarded to nutpie and produce a
   cryptic Python error instead of a clear R-side rejection.
2. Typos and unrelated arguments were silently passed through too
   (e.g. `itr_sampling = 500` would be sent to nutpie verbatim).

Replace with an allow-list: anything in `...` must be either a
coev_fit-handled arg (chains/iter_sampling/iter_warmup/seed) or a
known nutpie::sample() argument we forward. Unknown args raise an
error that lists the recognized set and points users to
nuts_sampler = "stan" if they need cmdstanr-specific arguments.

Maintenance burden moves from "track every cmdstanr addition" to
"extend the nutpie passthrough list when nutpie adds new args" —
much smaller surface.
The function was the densest piece of pure-R logic in the JAX backend
(89 lines of nested loops with manual index arithmetic), but had only a
high-level one-paragraph docstring and was exercised solely through
end-to-end logp comparison tests. A regression there would have surfaced
as a confusing "Stan/JAX log-densities disagree" warning rather than a
precise diagnostic.

* Expanded the docstring with a step-by-step algorithm description, the
  meaning of each output array, the padding convention (slots beyond
  level_sizes point at root_id — a self-loop no-op), and the
  seg_drift_slot mapping to Stan's z_drift indexing.
* Added three inline `# Step N:` comments delineating the algorithm
  phases.
* New tests/testthat/test-jax_helpers.R asserts: output shape
  consistency; sum of level_sizes per tree equals N_seg - 1 (every
  non-root segment placed exactly once); root_ids match the first
  traversal entry; padded slots equal root_id; multiPhylo expands the
  N_tree dimension correctly; drift_idx covers 0..(N_seg-2) exactly
  once across all levels (mirroring Stan's z_drift slot allocation).
@ScottClaessens

Copy link
Copy Markdown
Owner

I'm also wondering whether we should make this version 1.1.0 since the JAX addition is quite major.

@ErikRingen

Copy link
Copy Markdown
Collaborator Author

I'm also wondering whether we should make this version 1.1.0 since the JAX addition is quite major.

Sounds fine to me. Happy to merge whenever you want.

@ScottClaessens ScottClaessens merged commit 4d02f34 into main May 15, 2026
13 checks passed
@ScottClaessens ScottClaessens deleted the jax branch May 15, 2026 15:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

run-extended Run the extended (logp Stan/JAX) test suite in CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants