Add pure JAX backend with nutpie sampler#109
Conversation
- Add PyMC/PyTensor backend with nutpie integration for accelerated sampling via JAX compilation and Rust NUTS sampler - Add .Rprofile for managed reticulate/uv Python environment - Fix summary() for PyMC-style ordered logistic cutpoints - Fix nutpie kwarg marshalling via Python helper - Add default parameter values for legacy BridgeStan nutpie path - Remove NumPyro backend (code, exports, tests) - Delete accumulated scripts/ directory - Update DESCRIPTION, NAMESPACE, README to reflect new backends Made-with: Cursor
…zed GP covariance - Refactor tree traversal from O(N_seg) unrolled Python loops to O(depth) level-batched PyTensor ops, reducing graph complexity ~75% and compile times ~15x for 100-tip models - Vectorize GP distance covariance assembly: replace nested Python loops with batched pt.exp + pt.fill_diagonal - Add compute_tree_levels() in pymc_helpers.R to pre-compute topological depth groupings for batched processing - Add parse_stan_prior() helper for Stan-to-PyMC prior conversion - Fix nutpie refresh arg forwarding in coev_fit() Made-with: Cursor
- Replace chol_lower_unrolled (scalar-by-scalar) with pt.linalg.cholesky - Replace A_diag/Q_sigma Flat+softplus+Potential with native PyMC distributions (TruncatedNormal, HalfNormal) - Rewrite LKJ Cholesky to use pt.stacklists instead of pt.set_subtensor loops, eliminating O(J^2) graph ops - Remove dead helpers: expm_real, pos_softplus, neg_softplus, log_softplus_jac, pymc_prior_logp_expr, pymc_unrolled_chol_fn - Force JAX_PLATFORMS=cpu in nutpie path to avoid GPU overhead for small phylogenetic models (benchmarked 2-4x faster than GPU) Net: -168 lines, -91 lines of generated Python per model Made-with: Cursor
- convert_pymc_draws: permute MultiTrace arrays when simplify2array puts chains on last axis (fixes assignment error with multiprocess PyMC) - pymc_run_mcmc: pass cores to pm.sample (default all chains); document - coev_fit: map cores/parallel_chains for pymc; clarify nutpie ... args; drop spurious nuts_* warnings for non-nutpie paths - plotting/save: treat pymc_fit like nutpie_fit for draws access - README, roxygen/man, tests Made-with: Cursor
…prob comparison Bug: needs_terminal_drift was FALSE for models with normal and non-normal variables when there was no missing data. This meant the non-normal variable's latent terminal drift was hardcoded to zero, producing a structurally different (and much harder) posterior than Stan — causing ~140x more leapfrog steps/draw. Fix sets needs_terminal_drift=TRUE whenever non-normal variables appear alongside normal variables, matching Stan's parameterization. Also adds compare_stan_pymc_logprob() for Stan/PyMC density parity checks, with a Python helper that accounts for tip-edge z_drift and observed-normal terminal_drift structural differences. Tests assert abs_diff < 1.0 for prior-only and likelihood (mixed vars + missing data) scenarios. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ob comparison The pymc_logprob_at_stan_primary_params function incorrectly used tr.backward (unconstrained→constrained) instead of tr.forward (constrained→unconstrained) when converting Stan primary parameters to PyMC's transformed parameter space. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…lass Replace the ~800-line R string accumulator (coev_make_pymc.R) with a Python CoevPymcModel class in inst/python/coev_pymc_model.py. The class reads model config flags from data_dict and builds the PyMC model directly via PyMC APIs, eliminating silent indentation bugs and exec()-based execution. - coev_make_pymc() now returns a config list (flags + prior_specs) instead of a Python code string; caller merges it via embed_pymc_config() - pymc_helpers.R: add prior_spec_from_stan(), embed_pymc_config(), load_pymc_model_module(); update convert_r_to_python_data_pymc() to handle character vectors and nested lists (for distributions/prior_specs) - coev_pymc_logprob.py: drop exec(pymc_code), import CoevPymcModel from same directory; pymc_code arg removed - pymc_sample.R: drop pymc_code param, use load_pymc_model_module() - coev_fit.R: use pymc_cfg + embed_pymc_config(); fix stale sc reference in shared return object (stan_code = NULL for PyMC path) - compare_stan_pymc_logprob.R: use cfg from coev_make_pymc() directly All 1001 tests pass (0 failures). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Delete debug_compare.R, debug_params.R, quick_debug*.R, Rplots.pdf - Add __pycache__/ and *.pyc to .gitignore - Fix JAX dead-branch gradient bug: replace -9999 missing-data fill with valid values per distribution before calling pm.logp (gamma_log → 1.0, ordered_logistic → 1.0, others → 0.0) - Fix nutpie initial_points format: compile_pymc_model expects a dict (overrides= to make_initial_point_fn), not a list; pass ip directly - Keep calibrated initial_points for numba fallback after JAX failure - Surface JAX error message in fallback warning - Trim stale/verbose comments and intermediate-thought docstrings across coev_pymc_model.py, coev_pymc_logprob.py, and R helper files Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Fix se[0] bug in repeated-measures + measurement_error path: use per-observation SE matrices matching Stan's diag_matrix(se[i,]) - Subtract LKJ stick-breaking Jacobian (tanh transform) from PyMC logprob evaluations — this term exists in PyMC Potentials but not in Stan's lkj_corr_cholesky_lpdf, causing 0.15–0.51 discrepancies - Tighten logprob comparison tolerance from 1.0 to 0.03 (33x stricter) - Add test cases: repeated+ME (catches se[0] bug) and all-non-normal - Remove dead stan_prior_to_pymc() function (-62 lines) - Remove unused numpy imports, fix stale docstrings - Error informatively when stancode() called on PyMC fits Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Implements a pure JAX log-density function with numpyro.distributions for distribution log-probs, wired to nutpie's Rust NUTS sampler via nutpie.compiled_pyfunc.from_pyfunc(). Achieves ~5.5x speedup over Stan on the authority benchmark (500/500 draws, 4 chains). New files: - inst/python/coev_jax_model.py: JAX model (log-density, transforms, ksolve, matrix exp, tree traversal, all 6 distributions) - R/jax_helpers.R: availability checks, data conversion, shared helpers - R/jax_sample.R: nutpie sampling, draws conversion - R/jax_wrapper.R: jax_fit S3 class - R/coev_make_model_config.R: generic model config (renamed from coev_make_pymc) - R/compare_stan_jax_logprob.R: Stan vs JAX log-density comparison Removed PyMC/PyTensor dependency entirely. Only jax, numpyro, and nutpie are needed for the accelerated backend. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
JIT-compile the parameter expansion function (called once per draw) to eliminate Python-level loops and JAX→numpy conversion overhead. expand_fn: 1.55ms → 0.036ms per call (43x faster). Overall wall time: 38.2s → 30.0s for authority benchmark. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Implement full GP (Cholesky) and HSGP (basis function) spatial effects in the JAX/nutpie backend, matching the Stan implementation - Support all three kernels: exp_quad, exponential, matern32 - Update coev_make_model_config to use lon_lat/dist_k API - Add GP/HSGP tests for the JAX backend - Add GP/HSGP benchmark models to nutpie_vs_stan.qmd - Fix benchmark: use $summary() for both backends, add parallel_chains - Merge main: incorporate lon_lat/HSGP, whisker templates, fixtures Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…nore - Replace deprecated dist_mat parameter with lon_lat/dist_k in compare_stan_jax_logprob() - Add .claude/, quick_debug.R, Rplots.pdf to .gitignore Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add # nolint annotations for cross-file function references (object_usage_linter false positives) and Stan-convention parameter names (object_name_linter) - Exclude .claude/, benchmarks/, quick_debug.R from R CMD build - Regenerate man pages for lon_lat/dist_k API changes Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
solve_triangular requires matching batch dimensions. L_cov_res was (J, J) while residuals were (N_obs, J) — broadcast to (N_obs, J, J). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Unicode right single quotation marks (U+2019) in #' roxygen comment markers silently broke parsing — params after prior_only were not documented. Replace with ASCII apostrophes. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The PTY approach caused sampler stalls when indicatif output filled the buffer faster than R could drain it. Simplified to progress_bar=False with a plain background thread and polling loop. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…tan blocks Break monolithic log_density() into _build_A_matrix, _build_Q_matrix, _compute_priors, _compute_caches, _tree_traversal, _transformed_params, and _likelihood. Deduplicate A/Q building in make_expand_fn. Add compile/sample timing to jax_sample.R and update benchmark doc. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Autoresearch loop (24 iterations) optimizing value_and_grad(log_density): - matrix_exp_batch: Horner's method, 12→3 Taylor terms, 8→4 squarings - Ordered logistic: log-space cumulative rewrite (no concat/clip) - Batched cutpoint priors (single prior_logp call) - Pre-convert level indexing arrays to jnp in build() - Einsum for tree traversal matmuls - lax.linalg.triangular_solve in mvn_chol_logp Results on authority model (J=2, 97 tips): value_and_grad: 327 → 172 us/call (-47%) Full bench: 27.2 → 17.2s wall (-37%) ESS/sec: 1.2 → 2.3 (+92%) Compile: 1.66 → 1.18s (-29%) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Match Stan's exact unconstrained parameterization so both backends can be evaluated at the same point: - Transforms: exp/-exp for bounds, exp increments for ordered (was softplus). Jacobians updated to match Stan convention. - z_drift: shape (N_tree, N_seg-1, J) matching Stan, not (N_tree, N_internal, J). Tip entries masked in traversal but receive std_normal prior like Stan. - Cholesky: verified stick-breaking + LKJ prior already matched Stan. Verification at 595-dim unconstrained point: - Gradients agree to 2.4e-11 (machine precision) - Log-densities differ by constant 551.39 (normalization terms) Also: - Vmap tree traversal over tree dimension (sublinear multiphylo scaling) - Rewrite compare_stan_jax_logprob() to evaluate at same unconstrained point and check gradient agreement - Add .gitignore entries for benchmark outputs/build artifacts - Fix missing importFrom(stats, rnorm) - Add PR draft, benchmark scripts Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
12 tests covering: ordered logistic, normal, bernoulli, poisson, negative binomial, exact GP, HSGP, repeated measures, multiphylo, measurement error, correlated drift disabled, effects_mat. 3 pass (ordered logistic, repeated, multiphylo). Remaining failures document real discrepancies in the JAX backend: - Gradient mismatches on bernoulli, poisson, GP, HSGP - Dimension mismatches on normal (terminal_drift alignment) - Test data bugs (neg binom needs integer, effects_mat needs names) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
All 12 logp agreement tests now pass at machine precision (max gradient diff ~1e-16 across all model configurations). Bugs fixed: 1. terminal_drift now always included in parameter layout (matching Stan's unconditional declaration) 2. terminal_drift prior logic matches Stan's conditional structure: global std_normal when tdrift flag is set, per-observation inline prior otherwise 3. sigma_dist/rho_dist parameter ordering swapped to match Stan (rho_dist before sigma_dist) Also fixed test setup: neg_binomial data needs as.integer(), effects_mat needs named rows/cols. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Prior-only tests (12/12) pass at machine precision, confirming transforms, priors, and tree traversal match Stan. Likelihood tests (0/7) fail with gradient diffs 0.5-80, documenting real bugs in the JAX likelihood computation that need investigation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Stan's `matrix[R, C]` is stored column-major (first index varies fastest), but `array[...] vector[J]` has J as innermost. These three parameters are matrix-typed in Stan, so their unconstrained vector layout differs from C-order reshape. Fix: swap last two axes in reshape for matrix-typed params. Verified by perturbing each unconstrained position and observing which caused Stan/JAX logp to diverge — exactly the terminal_drift block (positions 25-30 in the 4-tip test). All other positions produced identical diff, confirming this was the only layout bug. Also loosen likelihood test tolerance to 1e-2 since MVN evaluation through tree traversal accumulates floating-point errors from matrix_exp/Cholesky that differ between Stan C++ and JAX XLA at the ~1e-3 level (the math is identical; the numerics aren't). All 38 logp tests pass (12 prior-only at machine precision, 7 likelihood at ~1e-3 numerical precision). Full suite: 1091 tests. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Previously only _tree_traversal was vmapped over trees. _transformed_params and _likelihood still used Python loops, which limited the multiphylo speedup. Changes: - _tree_traversal returns stacked tensors (N_tree, ...) instead of Python lists — vmap-friendly for downstream. - _transformed_params: tdrift = L_VCV @ terminal_drift as a single batched einsum across all trees. - _likelihood: vmap the per-tree body over the N_tree axis. Benchmark (authority dataset, value_and_grad per call): | Trees | Before | After | Speedup | |-------|---------|---------|---------| | 1 | 299 us | 199 us | 1.5x | | 2 | 461 us | 237 us | 1.9x | | 4 | 706 us | 301 us | 2.3x | Scaling from 1→4 trees is now 1.5x (previously 2.4x) — highly sublinear since the per-level/per-obs work runs in parallel across trees via XLA's vmap. All 1091 tests pass. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add inst/python/requirements.txt with exact pinned versions (jax==0.5.3, numpyro==0.19.0, nutpie==0.16.8) that are verified to work together. check_jax_available() now reads these pins and passes them to reticulate::py_require(). Motivated by a real breakage during development: adding bridgestan to the env caused uv to resolve JAX 0.10.0 which removed an API that numpyro depends on, breaking the entire JAX backend. Also update PR_DRAFT.md with revised benchmarks, correctness verification details, and dependency management status. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace mixed-iteration benchmark claims with the clean 5-tree comparison (3 reps, compilation excluded, fixed seeds, pop params only): 3x wall-clock and 3x ESS/s advantage for JAX. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Verified against BridgeStan docs: it compiles Stan to a C++ shared library, exposes log-density and gradient via C FFI, and gradients use Stan Math's autodiff. Tightened the description accordingly. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
BridgeStan (679s, 1.6 min ESS/s) is slower than both Stan (322s) and JAX (101s) on the 5-tree authority model. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
No, nothing uncommited on this |
|
Hmm, including |
|
@ErikRingen I've tried a couple of changes to the CI workflow to trigger the extended tests from the commit message, but no luck. Any idea why this is failing now? |
…ed label
The previous gate used contains(github.event.pull_request.head.commit.message, ...)
and contains(github.event.head_commit.message, ...), neither of which is
populated on pull_request events. As a result COEVOLVE_EXTENDED_TESTS was
always evaluating to "false" on this PR, regardless of commit message.
Replace with two explicit triggers:
* workflow_dispatch with an `extended` boolean input — run the suite
manually from the Actions UI or `gh workflow run`.
* `run-extended` PR label — adding the label re-triggers CI with
COEVOLVE_EXTENDED_TESTS=true. The pull_request types list now
includes `labeled` so label additions actually fire the workflow.
Note: workflow_dispatch is only discoverable once this workflow is on the
default branch (a documented GHA limitation). The label mechanism works
on this PR immediately after this commit lands.
I updated such that, if one labels a PR "run-extended", the extended tests will trigger (and upon any successive commits while label still applied). Alternatively one could trigger from the github cli without needing to make a commit:
|
Each test file runs in its own R subprocess. Files like test-extended.R that contain multi-hour MCMC fits still bound the wall clock (the unit of parallelism is the file, not test_that), but the rest of the suite can run alongside it. Tested locally: helper.R only defines functions, fixtures/ is read-only, no snapshot tests — no obvious shared-state risks.
This reverts commit 5e200d2.
|
Not sure why that one extended test failed on Mac but not other platforms. |
Previously, extended tests ran inside R-CMD-check, taking ~4h on the slowest matrix entry (Windows). The wall-clock floor was test-extended.R running its 5 multi-hour MCMC tests sequentially in a single R process. Changes: * Split test-extended.R into four files by topic (direction, recovery, scaling, prior), one logical scenario per file. Tests themselves unchanged. * Add .github/workflows/extended-tests.yaml — a separate workflow that runs each extended test file (plus test-logp_stan_jax.R) as its own matrix job in parallel. Triggered by the run-extended label or workflow_dispatch; gated via job-level if so other PRs aren't affected. Uses concurrency: cancel-in-progress to supersede stale runs cleanly, replacing the third-party cancel-previous-runs action for this workflow. * Strip the COEVOLVE_EXTENDED_TESTS gating from R-CMD-check so the standard CI signal stays fast regardless of label state. * Update tests/README.md. Expected wall-clock for the extended path: ~max(longest individual test file) instead of sum, so ~1h instead of ~4h.
testthat::test_file() does not auto-load the package under test (unlike test_check() called from tests/testthat.R, which does library(coevolve) explicitly). Previous run failed at 'could not find function coev_fit'.
|
Everything is passing! Thanks @ErikRingen. I've had a final look over everything and I'm happy to merge if you are? |
The agreement suite enumerated all response distributions except gamma_log. Adding both prior-only (machine-precision tolerance) and full-likelihood (1e-2 tolerance) configurations, mirroring the poisson_softplus and negative_binomial_softplus test pairs. Test count: 19 -> 21.
Previously the JAX/nutpie path defaulted to seed = 0L when the user did not pass `seed`, while the Stan/cmdstanr path generates a random one. Two consequences: 1. Successive coev_fit(nuts_sampler = "nutpie") calls produced bit- identical results without any user opt-in, defeating MCMC convergence diagnostics that depend on chain randomness across runs. 2. Behavior diverged from cmdstanr (and from rstanarm/brms), so users moving between backends got different defaults silently. Now: when `seed` is omitted, draw one with sample.int(.Machine$integer.max). The drawn seed is already stored on the fit via create_jax_wrapper(), so fit$seed lets users reproduce a specific run.
…cmdstanr-only ones Previously, sampling args passed via `...` were filtered through a blacklist of names that cmdstanr accepts but nutpie does not (parallel_chains, refresh, nuts_backend, nuts_gradient_backend, compile_mode). Two problems with this approach: 1. The blacklist falls behind whenever cmdstanr adds a new argument: such args would be silently forwarded to nutpie and produce a cryptic Python error instead of a clear R-side rejection. 2. Typos and unrelated arguments were silently passed through too (e.g. `itr_sampling = 500` would be sent to nutpie verbatim). Replace with an allow-list: anything in `...` must be either a coev_fit-handled arg (chains/iter_sampling/iter_warmup/seed) or a known nutpie::sample() argument we forward. Unknown args raise an error that lists the recognized set and points users to nuts_sampler = "stan" if they need cmdstanr-specific arguments. Maintenance burden moves from "track every cmdstanr addition" to "extend the nutpie passthrough list when nutpie adds new args" — much smaller surface.
The function was the densest piece of pure-R logic in the JAX backend (89 lines of nested loops with manual index arithmetic), but had only a high-level one-paragraph docstring and was exercised solely through end-to-end logp comparison tests. A regression there would have surfaced as a confusing "Stan/JAX log-densities disagree" warning rather than a precise diagnostic. * Expanded the docstring with a step-by-step algorithm description, the meaning of each output array, the padding convention (slots beyond level_sizes point at root_id — a self-loop no-op), and the seg_drift_slot mapping to Stan's z_drift indexing. * Added three inline `# Step N:` comments delineating the algorithm phases. * New tests/testthat/test-jax_helpers.R asserts: output shape consistency; sum of level_sizes per tree equals N_seg - 1 (every non-root segment placed exactly once); root_ids match the first traversal entry; padded slots equal root_id; multiPhylo expands the N_tree dimension correctly; drift_idx covers 0..(N_seg-2) exactly once across all levels (mirroring Stan's z_drift slot allocation).
|
I'm also wondering whether we should make this version |
Sounds fine to me. Happy to merge whenever you want. |
Motivation
JAX traces the entire log-density function into a single XLA computation
graph, which is then compiled holistically — enabling operation fusion,
memory optimization, and efficient vectorization that Stan's
per-operation C++ compilation doesn't support. This also makes it
natural to express batched operations that Stan's language requires loops
for:
one at a time. JAX computes matrix exponentials, Cholesky
decompositions, and drift caches for all branch lengths in single
batched calls.
JAX groups nodes by tree depth and processes all nodes at the same
level in one
einsumcall. Levels are still sequential (parent-childdependency), but within-level work is batched.
jax.vmapruns thetree traversal, derived quantities, and likelihood in parallel across
trees — something Stan's language has no mechanism for.
Benchmarks
Clean benchmarks (compilation excluded, 3 replicates, fixed seeds) on
two datasets show consistent ~3–4x speedups:
Authority (97 taxa, 2 ordered-logistic variables, 5 trees):
Primates (143 taxa, 2 gamma-log variables, missing data):
JAX is ~3–4x faster wall-clock and ~3x better ESS/s than Stan
on the population parameters of interest. The primates model is a harder
geometry (Stan hits max treedepth on 99%+ of transitions), but the
relative advantage holds. These are local CPU benchmarks; GPU
acceleration is supported but not yet tested.
How this replaces the existing nutpie-via-BridgeStan path
On
main,backend = "nutpie"uses BridgeStan: the Stan model iscompiled to a C++ shared library, and nutpie's Rust NUTS sampler calls
into it via C FFI for log-density and gradient evaluations (using
Stan's autodiff). This replaces CmdStan's sampler but not the model
evaluation — gradients are still computed by Stan Math in C++.
This branch replaces that entire path. Instead of compiling Stan to
C++ and calling it through BridgeStan:
coev_make_model_config()generates a Python dict describing themodel structure (variable types, tree structure, priors, GP config).
inst/python/coev_jax_model.py(CoevJaxModel) builds a pure JAXlog-density function from that config — no Stan code is generated or
compiled.
jax.value_and_gradprovides gradients; nutpie samples using itsRust NUTS with those JAX-computed gradients.
The old BridgeStan helpers (
nutpie_compile_stan_model,nutpie_sample,convert_nutpie_draws) are replaced by the newjax_*modules. The
backendargument is deprecated in favor ofnuts_sampler("stan"or"nutpie").Correctness verification
Since the JAX backend reimplements the Stan log-density in a different
language, correctness is verified by evaluating both models at the
same unconstrained parameter vector and comparing log-density and
gradient. This is analogous to the approach used by https://github.com/pymc-labs/alchemize.
compare_stan_jax_logprob()does this:n_pointsrandom unconstrained vectors with a fixed seed.log_prob()andgrad_log_prob()in bothbackends.
points (any constant offset comes from normalization terms Stan
drops) and that gradients agree to a tolerance.
This comparison is wired into the test suite as
tests/testthat/test-logp_stan_jax.R, which runs 12 prior-onlyconfigurations (all response distributions, GP/HSGP, repeated measures,
multiphylo, measurement error, effects matrices, correlated drift on/off)
plus 7 full-likelihood configurations.
Prior-only tests pass at machine precision (max grad diff ~1e-16)
Likelihood tests pass at ~1e-3 tolerance
What changed
New files
inst/python/coev_jax_model.pyR/jax_helpers.RR/jax_sample.RR/jax_wrapper.Rjax_fitS3 class withdraws(),summary(),metadata()methodsR/coev_make_model_config.Rcoev_make_stancode+coev_make_standata)R/compare_stan_jax_logprob.Rtests/testthat/test-logp_stan_jax.Rbenchmarks/(various)Modified files
R/coev_fit.Rnuts_samplerargument ("stan"default,"nutpie"for JAX). Legacybackendargument preserved with deprecation warning.R/summary.Rc[i,j]) and JAX (c1[j]) naming conventions.R/stancode.Rtests/testthat/test-coev_fit_nutpie.Rnuts_samplerAPI plus integration tests for JAX fits.README.Rmd/README.mdnuts_samplerAPI, setup instructions, reticulate config.DESCRIPTIONNEWS.mdNAMESPACEcoev_make_model_config,compare_stan_jax_logprob, plusjax_fitS3 methods.Feature coverage
The JAX backend supports all model features available in the Stan backend:
normal,bernoulli_logit,ordered_logistic,poisson_softplus,negative_binomial_softplus,gamma_log)lon_lat)lon_lat+dist_k)All combinations are covered by the logp agreement test suite.
What it does NOT change
nuts_sampler = "stan") is untouched.coev_plot_*,coev_calculate_*,coev_pred_series,extract_samples,summary,plot) work withJAX fits via the same
coevfitS3 class.Open questions
1. Maintainability of dual Stan/JAX code paths
This is the main long-term concern. The JAX log-density
(
coev_jax_model.py, 1297 lines) is a manual reimplementation of theStan model. Any future change to the Stan model must be mirrored in the
Python code, or the two backends will diverge.
Mitigations in place:
tests/testthat/test-logp_stan_jax.Rruns the logp-agreement checkacross all supported configs — will catch most divergence.
(
_compute_priors~model{}priors,_likelihood~ likelihoodblock, etc.) for easier cross-reference.
Mitigations worth considering:
test-logp_stan_jax.Rinto CI (currently runs locally; would addsignificant CI runtime since it compiles Stan).
case to
test-logp_stan_jax.R."spec.
2. Python dependency management
Python deps are now pinned in
inst/python/requirements.txt(jax==0.5.3, numpyro==0.19.0, nutpie==0.16.8).
check_jax_available()passes these pinned specs to
reticulate::py_require(), which resolvesthem via uv into a managed environment.
This is a substantial improvement over unpinned deps (a JAX update
broke numpyro during development), but still has limitations:
jax[cuda]which has different deps.robust than pinned versions in requirements.txt.
3. API naming
nuts_sampler = "nutpie"selects the JAX backend, which is slightlymisleading — nutpie is the sampler, but the key change is the JAX
log-density. Should this be
nuts_sampler = "jax"or remain"nutpie"for brms/etc. consistency?4. Scope of this PR
Large diff (~4,000 lines added). Could be split into: R-side plumbing
first, then the Python model, then tests/benchmarks. The pieces are
tightly coupled, so splitting may not buy much.
5. CI runtime cost
The logp comparison test suite takes ~10-15 min locally because each
test compiles Stan. If wired into CI it becomes the dominant cost. Could
be gated on
COEVOLVE_EXTENDED_TESTS=trueor run nightly instead of onevery PR.
CI status
non-CRAN deps), 0 NOTEs on package code itself
tests at machine precision for priors, ~1e-3 for likelihood)
How to test