Add convergence criterion, full-rank and IAF approximations, ADVI stability notebooks#108
Open
christiaanjs wants to merge 27 commits into
Open
Add convergence criterion, full-rank and IAF approximations, ADVI stability notebooks#108christiaanjs wants to merge 27 commits into
christiaanjs wants to merge 27 commits into
Conversation
Implements a compiled TensorFlow custom op for the node-height ratio transform (a preorder tree traversal mapping height ratios to node heights) with an analytic reverse-mode gradient, mirroring the existing native phylogenetic-likelihood op. - cc/node_height_ratio_op.cc: forward (NodeHeightRatio) + analytic gradient (NodeHeightRatioGrad) walking nodes in reverse preorder, reusing the saved forward heights; float32/float64, int32/int64 indices, batched over leading sample/site dims. - cc/tree_traversal.h: shared host-side index helpers (ReadIndices) and documented index conventions, now used by both native ops; the likelihood op is refactored to consume it. - node_height_ratio.py: Python wrapper + RegisterGradient, a drop-in for traversal.ratio_transform.ratios_to_node_heights. - build.sh/build.py: build both ops (selectable by name). - NodeHeightRatioBijector gains use_native (False default, True/"auto"), routing only the forward transform through the native op. - profile CLI times the native ratio transform and reports its speedup. - Tests mirror the native-likelihood suite: forward/gradient vs the reference, finite differences, batching, anchor-broadcast reduction, and bijector integration. https://claude.ai/code/session_0116ZUM3pEYTRmSCLPybq1kh
The bijector now auto-detects the native ratio-transform op (matching LeafCTMC's use_native="auto" convention) instead of being opt-in. Since the native op registers only a first-order gradient, the one test that differentiates the forward log-det-Jacobian (a higher-order derivative through the forward transform) is pinned to use_native=False, and the profile CLI's pure-TensorFlow baseline timing is likewise pinned so the native-vs-TF comparison stays meaningful. https://claude.ai/code/session_0116ZUM3pEYTRmSCLPybq1kh
A spike (no library changes) to decide how to reuse native-accelerated tree traversals more widely and host neural-network blocks per node. Benchmarks a generic differentiable traversal across execution backends: - graph-mode tf.while_loop (baseline) - the same under XLA (jit_compile) - an unrolled fixed-topology graph (+/- XLA) - the existing native C++ ops (ratio transform, phylogenetic likelihood) - an optional JAX lax.scan prototype that runs only if JAX is importable over three per-node operations: the affine node-height ratio transform, the multilinear Felsenstein partial likelihood (both with a native op for a ceiling), and an NN message-passing block with a hidden-width sweep. Every backend is correctness-gated against the reference before timing; compile/first-call times are reported separately. Findings: for tiny per-node ops the native op dominates and its margin grows with tree size, while for NN-heavy nodes graph/XLA/JAX converge as the per-node matmuls dominate -- evidence that programmable/NN traversals should be built on a compiled generic combinator rather than a hand-written C++ kernel. The tf.while_loop driver doubles as a prototype of that combinator. https://claude.ai/code/session_0116ZUM3pEYTRmSCLPybq1kh
Revert the einsum node-compute and scatter leaf-init tweaks in phylo_likelihood: both are faster on the forward pass but have a more expensive gradient, and since the backward dominates, value+gradient (the inference path the profiler measures) was ~15-39% slower across tree sizes up to 1024 taxa -- matching the CI profiling regression. Add a use_matvec parameter (default False) to phylogenetic_likelihood and phylogenetic_log_likelihood_rescaled, via a shared _combine_child_partials helper. Default keeps multiply+reduce_sum (cheap gradient). use_matvec=True uses tf.linalg.matvec: ~2x faster forward on large trees but slower value+gradient, so it is documented as forward-only. Values and gradients are identical either way. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
- TF likelihood backends use multiply+reduce_sum (matches the library node compute; cheaper gradient than einsum). - Decouple NN sizes: NN_TAXA=[16,64,256] (smaller, slower op) and NN_WIDTHS=[8,32,128], kept disjoint so the taxa/width sweeps no longer collide at size 64; NN_WIDTH_TAXA=64. - Add a §6 note that underflow is accepted (timings are value-independent and stay valid; fixing large problems needs the rescaled likelihood). - Rewrite §9 takeaways with conclusions from the fresh run (two regimes by per-node cost; avoid full-state scatter; XLA only helps overhead-bound ops; unrolling wins for fixed topology; per-op design conclusions). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ixed-topology VI Change ratios_to_node_heights and phylogenetic_likelihood (+ rescaled, + dispatch) to take a topology object and delegate to the generic preorder/postorder traversal drivers, removing the shim/index-lifting glue. Drop the @tf.function decorators on the likelihood functions so a captured topology folds and the unrolled path fires inside a traced VI step. Add use_native + unroll to fit_fixed_topology_variational_approximation, threaded to the approximation's node-height bijector via the coalescent/ birth-death default event-space bijectors -> TreeRatioBijector -> NodeHeightRatioBijector. Thread unroll through phylo_model_to_joint_distribution -> LeafCTMC -> likelihood dispatch. All new args default to "auto". Rewrite test/vi/test_fixed_topology_advi.py to fit through both native and pure-TensorFlow engines with unroll on/off, plus a graph-level check that unroll removes the traversal while_loop. Update native-comparison and traversal tests to the new signatures. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…native Replace the single pure-TensorFlow engine with two traversal engines -- 'static' (unrolled for the static topology, unroll=True) and 'dynamic' (TensorArray while_loop, unroll=False) -- alongside 'native'. Time both the likelihood and the ratio transform for each, and report speedups relative to the dynamic engine. 'tf' is kept as an alias for 'dynamic'; the CI profiling workflow now requests native,static,dynamic. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
tf.range no longer constant-folds in-graph on TF 2.21, so the unrolled postorder driver could not statically detect the topology inside a traced model and unroll=True raised (the profiling CI 'static' engine and the fixed-topology VI likelihood). Fold a numpy arange into a constant when the taxon count is static, mirroring preorder_node_indices. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…failures
The unrolled (static) engine cannot statically detect the topology when the tree
is routed through a JointDistribution (TFP decomposes it; get_static_value only
folds for small trees), so the full-joint static timing fails past ~32 taxa.
Rather than abort, time each measurement defensively: a failure records a NaN
("missing") time plus a warning surfaced in the output, and the run continues.
Also time the phylogenetic likelihood *directly* per engine (captured-constant
topology), which unrolls at any size, giving a meaningful native/static/dynamic
likelihood comparison alongside the full-joint breakdown.
Add >64-taxon unroll coverage: parametrise the 600-taxon rescaled test over
unroll, and a 128-taxon graph check that unroll removes the traversal while_loop.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add StaticNumpyTreeTopology: a tf.nest-atom topology with pre-computed index arrays that threads through traced code as a captured constant, so the node-height ratio-transform traversal folds it and unrolls at any tree size. A static topology cannot live in a JointDistribution's tree value (the JD coerces the value to tensors via tf.convert_to_tensor, which can't preserve a non-tensor atom). So FixedTopologyRootedTreeBijector rebuilds it as an in-graph-constant TensorflowTreeTopology in _forward via the new StaticNumpyTreeTopology.to_constant_tensor_topology() helper. Because _forward runs inside the trace, those Const ops fold (tf.get_static_value) at any size, so the downstream likelihood traversal unrolls too -- unlike a captured tensor topology, which is re-materialised as a Placeholder for >64 taxa. Also: - preorder dynamic (unroll=False) driver tf.convert_to_tensor's the topology index arrays so a static NumPy topology works there too (a NumPy array can't be indexed by a symbolic loop counter). - profiler: build the full-joint tree value from an in-trace constant topology so the static engine unrolls through the JointDistribution -- the full-joint static timings are now real at 64/128 taxa instead of NaN. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Replace the boolean `unroll` with "auto" plus three named traversal modes in
postorder/preorder:
- "unrolled" : straight-line graph, no TensorArray. Needs the topology index
*values* statically known. Fastest, and the only mode that
XLA-compiles for value+gradient.
- "tensorarray" : Python-unrolled loop over a TensorArray. Needs only a static
node *count* (index values may be runtime tensors).
- "while_loop" : AutoGraph tf.while_loop over a TensorArray. Needs nothing
static; O(1) graph, for a varying/very large topology.
"auto" prefers "unrolled" when the values are static, else "tensorarray" when the
count is static, else "while_loop". Static-count detection uses the prefer_static
topology.taxon_count (robust to the boolean_mask shape-drop on preorder indices).
The docstrings record the empirically-verified XLA/jit_compile behaviour: only
"unrolled" compiles for value+gradient; the TensorArray modes compile only the
forward pass and only with xla_compatible=True; and an inner tf.function is not a
jit boundary.
Profiler engines map static->"unrolled", dynamic->"while_loop", plus a new
"tensorarray" engine. StaticNumpyTreeTopology is now accepted by the numpy rooted
tree and numpy_topology_to_tensor. Tests parametrize over the named modes.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Implements a custom TFP ConvergenceCriterion that tracks EWMA of per-step loss decrease normalised by |ELBO|, making the threshold invariant to dataset scale and starting conditions. A min_consecutive parameter (default 10) requires the condition to hold for N consecutive steps, preventing spurious early stopping from transient single-step dips in rel_rate. Also adds: - VIResults.convergence_criterion_state field (backward-compat default None) so criterion state (ewma, rel_rate, consecutive_below) is available in traces - 20 unit tests covering EWMA update, NaN handling, and convergence logic - experiments/advi_stability.ipynb: 5-run YFV ADVI stability study; conclusion notes that clock_rate / root_height variation reflects the mean-field approximation's inability to capture the clock-rate × root-height ridge Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds cells to run NUM_IAF_RUNS independent IAF fits alongside the existing mean-field runs, then compares: - Within-IAF stability (inter_run/post_sd table) - Loss traces: mean-field (dashed) vs IAF (solid) on shared axes - Pooled posterior marginals: MF (blue) vs IAF (orange) histograms - Side-by-side MF vs IAF summary table Conclusion has a placeholder IAF findings section with the key questions to fill in once the cells have been run. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Implements get_fixed_topology_full_rank_approximation: a multivariate Normal in the joint unconstrained parameter space, capable of capturing correlations (e.g. the clock rate x root height ridge that degrades mean-field stability). Key design: - _FullRankAffineBijector stores loc (D,) and raw_scale (D,D) as tf.Variables - lower-triangular extraction via band_part, diagonal positivity via softplus - Composed with the existing split/restructure/event-space bijector chain (same pattern as the IAF approximation) - Initialised to loc=prior-medians (unconstrained), scale_tril ≈ identity Updates advi_stability.ipynb to run 3 full-rank fits alongside the 5 mean-field runs, with loss-trace comparison, within-FR stability summary, and pooled MF vs FR posterior marginal histograms. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Full-rank ELBO ~2700-3700 nats better than mean-field (FR: -6099 to -6123 vs MF: -8845 to -9790), confirming mean-field loses significant posterior mass on the strict-clock model. Root height posterior std expands 2.5x (86 -> 213) and pop_size 1.4x wider under full-rank, reflecting correct representation of the clock rate x root height ridge. Inter-run stability for clock_rate is similar between approximations (0.60 vs 0.62), showing the degeneracy is inherent in the posterior geometry, not an artifact of the approximation family. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Quantify the clock-rate × root-height ridge: Pearson corr = -0.48, log-log slope = -0.75 (partial non-identifiability), quadratic coeff = 0.08 - Test IAF as a more flexible approximation family: all three IAF runs fail (run 1 stalls at ELBO ≈ -14700; runs 2-3 crash with NaN gradients), showing optimisation difficulty is the limiting factor, not expressiveness - Update conclusion with geometry diagnostics and IAF findings - Refresh native_vs_tf_vi_validation.ipynb outputs Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Move IAF cells from advi_stability.ipynb to new advi_iaf.ipynb, which includes self-contained MF/FR reference runs for comparison - advi_stability.ipynb now covers mean-field and full-rank only - iaf.py: add trainable affine base (loc_var + log_scale_var) so the IAF is Normal(init_loc, 1) at network init rather than a random scramble - iaf.py: use DeferredTensor for softplus(log_scale_var) so gradients flow back to log_scale_var during training - iaf.py: use TruncatedNormal(stddev=0.01) kernel init for the autoregressive network so IAF bijectors start near-identity, preventing NaN warm-up losses from extreme initial samples Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
RelativeLossNotDecreasingconvergence criterion — tracks an EWMA of the per-step ELBO decrease normalised by |ELBO|, with amin_consecutivethreshold to avoid spurious early stops from single-step dips in the convergence ratetreeflow/model/approximation/full_rank.py) — multivariate Normal with full lower-triangular covariance in joint unconstrained space (~6k parameters for YFV vs 154 for mean-field); achieves ~2700–3700 nat ELBO improvement over mean-field on YFViaf.py) — trainable affine base (loc_var,log_scale_var) so the IAF starts as a mean-field at network init;DeferredTensorfix solog_scale_varreceives gradients; small kernel init (stddev=0.01) to prevent NaN warm-up losses from extreme initial samples; surrogate warm-up against a fitted mean-field target before ELBO optimisationadvi_stability.ipynb— experiment notebook covering mean-field stability across seeds, full-rank comparison, and posterior geometry analysis of the clock rate × root height non-identifiabilityadvi_iaf.ipynb— self-contained IAF experiment notebook with MF/FR reference runs and the surrogate warm-up strategy; IAF achieves ~−6010 ELBO vs full-rank ~−6110RelativeLossNotDecreasing; phylo likelihood test parametrized over all three unroll modes (unrolled,tensorarray,while_loop)vi/util.py—VIResultsnamedtuple anddefault_vi_trace_fnto capture convergence criterion state in tracesTest plan
pytest test/vi/test_relative_loss_not_decreasing.py— all 20 convergence criterion tests passpytest test/traversal/test_phylo_likelihood.py— parametrized over unroll modesadvi_stability.ipynbend-to-end (MF + FR cells)advi_iaf.ipynbend-to-end (verifies warm-up + ELBO convergence)🤖 Generated with Claude Code