
Add convergence criterion, full-rank and IAF approximations, ADVI stability notebooks #108

Open
christiaanjs wants to merge 27 commits into master from convergence

@christiaanjs christiaanjs commented Jun 16, 2026

Owner

Summary

  • RelativeLossNotDecreasing convergence criterion — tracks an EWMA of the per-step ELBO decrease normalised by |ELBO|, with a min_consecutive threshold to avoid spurious early stops from single-step dips in the convergence rate
  • Full-rank variational approximation (treeflow/model/approximation/full_rank.py) — multivariate Normal with full lower-triangular covariance in joint unconstrained space (~6k parameters for YFV vs 154 for mean-field); achieves ~2700–3700 nat ELBO improvement over mean-field on YFV
  • IAF approximation improvements (iaf.py) — trainable affine base (loc_var, log_scale_var) so the IAF starts as a mean-field at network init; DeferredTensor fix so log_scale_var receives gradients; small kernel init (stddev=0.01) to prevent NaN warm-up losses from extreme initial samples; surrogate warm-up against a fitted mean-field target before ELBO optimisation
  • advi_stability.ipynb — experiment notebook covering mean-field stability across seeds, full-rank comparison, and posterior geometry analysis of the clock rate × root height non-identifiability
  • advi_iaf.ipynb — self-contained IAF experiment notebook with MF/FR reference runs and the surrogate warm-up strategy; IAF achieves ~−6010 ELBO vs full-rank ~−6110
  • Tests — 20 tests for RelativeLossNotDecreasing; phylo likelihood test parametrized over all three unroll modes (unrolled, tensorarray, while_loop)
  • vi/util.py: VIResults namedtuple and default_vi_trace_fn to capture convergence criterion state in traces
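
As a rough check on the parameter counts quoted above, the sketch below counts trainable parameters for each family. It assumes the unconstrained dimension is D = 77 (inferred from the 154 mean-field parameters as one location and one scale per dimension) and that the full-rank family stores an unmasked D×D raw scale matrix. Both are assumptions, not values read from the code, and the function names are hypothetical.

```python
def mean_field_param_count(dim: int) -> int:
    # One location and one scale per unconstrained dimension.
    return 2 * dim


def full_rank_param_count(dim: int, store_full_matrix: bool = True) -> int:
    # Location vector plus either the full (dim, dim) raw matrix or
    # only its lower triangle (diagonal included).
    scale_params = dim * dim if store_full_matrix else dim * (dim + 1) // 2
    return dim + scale_params


D = 77  # assumed: 154 mean-field parameters / 2
print(mean_field_param_count(D))        # 154
print(full_rank_param_count(D))         # 6006
print(full_rank_param_count(D, False))  # 3080
```

Under these assumptions the stored count lands near the quoted ~6k, while only the lower-triangular entries (about 3k) influence the distribution.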

Test plan

  • pytest test/vi/test_relative_loss_not_decreasing.py — all 20 convergence criterion tests pass
  • pytest test/traversal/test_phylo_likelihood.py — parametrized over unroll modes
  • Run advi_stability.ipynb end-to-end (MF + FR cells)
  • Run advi_iaf.ipynb end-to-end (verifies warm-up + ELBO convergence)

🤖 Generated with Claude Code

claude and others added 24 commits June 14, 2026 05:25
Implements a compiled TensorFlow custom op for the node-height ratio
transform (a preorder tree traversal mapping height ratios to node
heights) with an analytic reverse-mode gradient, mirroring the existing
native phylogenetic-likelihood op.

- cc/node_height_ratio_op.cc: forward (NodeHeightRatio) + analytic
  gradient (NodeHeightRatioGrad) walking nodes in reverse preorder,
  reusing the saved forward heights; float32/float64, int32/int64
  indices, batched over leading sample/site dims.
- cc/tree_traversal.h: shared host-side index helpers (ReadIndices) and
  documented index conventions, now used by both native ops; the
  likelihood op is refactored to consume it.
- node_height_ratio.py: Python wrapper + RegisterGradient, a drop-in for
  traversal.ratio_transform.ratios_to_node_heights.
- build.sh/build.py: build both ops (selectable by name).
- NodeHeightRatioBijector gains use_native (False default, True/"auto"),
  routing only the forward transform through the native op.
- profile CLI times the native ratio transform and reports its speedup.
- Tests mirror the native-likelihood suite: forward/gradient vs the
  reference, finite differences, batching, anchor-broadcast reduction,
  and bijector integration.

https://claude.ai/code/session_0116ZUM3pEYTRmSCLPybq1kh
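
For orientation, a minimal pure-Python sketch of a preorder ratio-to-height traversal of the kind this commit describes. The affine update (anchor plus ratio times the gap to the parent), the argument layout, and the function name are assumptions for illustration; this is not the native op or the library's `ratios_to_node_heights`.

```python
from typing import List, Sequence


def ratios_to_heights_sketch(
    root_height: float,
    ratios: Sequence[float],
    anchors: Sequence[float],
    parents: Sequence[int],
    preorder: Sequence[int],
) -> List[float]:
    """Map ratios in (0, 1) to internal-node heights by a preorder walk.

    preorder[0] is the root; every other node is visited after its parent,
    so the parent's height is already known when a node is processed.
    """
    heights = [0.0] * len(parents)
    heights[preorder[0]] = root_height
    for node in preorder[1:]:
        parent_height = heights[parents[node]]
        # Assumed affine rule: interpolate between the node's anchor
        # (a lower bound on its height) and its parent's height.
        heights[node] = anchors[node] + ratios[node] * (parent_height - anchors[node])
    return heights


# Three internal nodes: node 2 is the root, nodes 0 and 1 are its children.
heights = ratios_to_heights_sketch(
    root_height=5.0,
    ratios=[0.5, 0.25, 1.0],   # the root's ratio entry is unused
    anchors=[0.0, 1.0, 0.0],
    parents=[2, 2, -1],
    preorder=[2, 0, 1],
)
print(heights)  # [2.5, 2.0, 5.0]
```

Because each height depends only on the parent's height, a reverse-preorder pass can accumulate gradients from children back to parents, which is the structure the analytic gradient described above relies on.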
The bijector now auto-detects the native ratio-transform op (matching
LeafCTMC's use_native="auto" convention) instead of being opt-in. Since
the native op registers only a first-order gradient, the one test that
differentiates the forward log-det-Jacobian (a higher-order derivative
through the forward transform) is pinned to use_native=False, and the
profile CLI's pure-TensorFlow baseline timing is likewise pinned so the
native-vs-TF comparison stays meaningful.

https://claude.ai/code/session_0116ZUM3pEYTRmSCLPybq1kh
A spike (no library changes) to decide how to reuse native-accelerated tree
traversals more widely and host neural-network blocks per node. Benchmarks a
generic differentiable traversal across execution backends:

- graph-mode tf.while_loop (baseline)
- the same under XLA (jit_compile)
- an unrolled fixed-topology graph (+/- XLA)
- the existing native C++ ops (ratio transform, phylogenetic likelihood)
- an optional JAX lax.scan prototype that runs only if JAX is importable

over three per-node operations: the affine node-height ratio transform, the
multilinear Felsenstein partial likelihood (both with a native op for a ceiling),
and an NN message-passing block with a hidden-width sweep. Every backend is
correctness-gated against the reference before timing; compile/first-call times
are reported separately.

Findings: for tiny per-node ops the native op dominates and its margin grows
with tree size, while for NN-heavy nodes graph/XLA/JAX converge as the per-node
matmuls dominate -- evidence that programmable/NN traversals should be built on a
compiled generic combinator rather than a hand-written C++ kernel. The
tf.while_loop driver doubles as a prototype of that combinator.

https://claude.ai/code/session_0116ZUM3pEYTRmSCLPybq1kh
Revert the einsum node-compute and scatter leaf-init tweaks in
phylo_likelihood: both are faster on the forward pass but have a more
expensive gradient, and since the backward dominates, value+gradient
(the inference path the profiler measures) was ~15-39% slower across
tree sizes up to 1024 taxa -- matching the CI profiling regression.

Add a use_matvec parameter (default False) to phylogenetic_likelihood
and phylogenetic_log_likelihood_rescaled, via a shared
_combine_child_partials helper. Default keeps multiply+reduce_sum
(cheap gradient). use_matvec=True uses tf.linalg.matvec: ~2x faster
forward on large trees but slower value+gradient, so it is documented
as forward-only. Values and gradients are identical either way.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
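
The two ways of combining child partials mentioned above compute the same values. The sketch below uses NumPy as a stand-in for the TensorFlow ops and does not reproduce `_combine_child_partials`; the array names and shapes are assumptions. It only illustrates the equivalence, not the forward or gradient timings reported in the commit message.

```python
import numpy as np

rng = np.random.default_rng(0)
num_states = 4

# A row-stochastic transition matrix and one child's partial likelihoods.
transition_probs = rng.random((num_states, num_states))
transition_probs /= transition_probs.sum(axis=-1, keepdims=True)
child_partials = rng.random(num_states)

# multiply + reduce_sum: broadcast the partials across rows, sum over child states.
via_reduce = np.sum(transition_probs * child_partials[np.newaxis, :], axis=-1)

# matvec: the same contraction expressed as a matrix-vector product.
via_matvec = transition_probs @ child_partials

print(np.allclose(via_reduce, via_matvec))  # True
```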
- TF likelihood backends use multiply+reduce_sum (matches the library node
  compute; cheaper gradient than einsum).
- Decouple NN sizes: NN_TAXA=[16,64,256] (smaller, slower op) and
  NN_WIDTHS=[8,32,128], kept disjoint so the taxa/width sweeps no longer
  collide at size 64; NN_WIDTH_TAXA=64.
- Add a §6 note that underflow is accepted (timings are value-independent
  and stay valid; fixing large problems needs the rescaled likelihood).
- Rewrite §9 takeaways with conclusions from the fresh run (two regimes by
  per-node cost; avoid full-state scatter; XLA only helps overhead-bound
  ops; unrolling wins for fixed topology; per-op design conclusions).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ixed-topology VI

Change ratios_to_node_heights and phylogenetic_likelihood (+ rescaled, +
dispatch) to take a topology object and delegate to the generic
preorder/postorder traversal drivers, removing the shim/index-lifting glue.
Drop the @tf.function decorators on the likelihood functions so a captured
topology folds and the unrolled path fires inside a traced VI step.

Add use_native + unroll to fit_fixed_topology_variational_approximation,
threaded to the approximation's node-height bijector via the coalescent/
birth-death default event-space bijectors -> TreeRatioBijector ->
NodeHeightRatioBijector. Thread unroll through phylo_model_to_joint_distribution
-> LeafCTMC -> likelihood dispatch. All new args default to "auto".

Rewrite test/vi/test_fixed_topology_advi.py to fit through both native and
pure-TensorFlow engines with unroll on/off, plus a graph-level check that
unroll removes the traversal while_loop. Update native-comparison and traversal
tests to the new signatures.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…native

Replace the single pure-TensorFlow engine with two traversal engines --
'static' (unrolled for the static topology, unroll=True) and 'dynamic'
(TensorArray while_loop, unroll=False) -- alongside 'native'. Time both the
likelihood and the ratio transform for each, and report speedups relative to
the dynamic engine. 'tf' is kept as an alias for 'dynamic'; the CI profiling
workflow now requests native,static,dynamic.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
tf.range no longer constant-folds in-graph on TF 2.21, so the unrolled
postorder driver could not statically detect the topology inside a traced
model and unroll=True raised (the profiling CI 'static' engine and the
fixed-topology VI likelihood). Fold a numpy arange into a constant when the
taxon count is static, mirroring preorder_node_indices.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…failures

The unrolled (static) engine cannot statically detect the topology when the tree
is routed through a JointDistribution (TFP decomposes it; get_static_value only
folds for small trees), so the full-joint static timing fails past ~32 taxa.
Rather than abort, time each measurement defensively: a failure records a NaN
("missing") time plus a warning surfaced in the output, and the run continues.

Also time the phylogenetic likelihood *directly* per engine (captured-constant
topology), which unrolls at any size, giving a meaningful native/static/dynamic
likelihood comparison alongside the full-joint breakdown.

Add >64-taxon unroll coverage: parametrise the 600-taxon rescaled test over
unroll, and a 128-taxon graph check that unroll removes the traversal while_loop.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add StaticNumpyTreeTopology: a tf.nest-atom topology with pre-computed index
arrays that threads through traced code as a captured constant, so the
node-height ratio-transform traversal folds it and unrolls at any tree size.

A static topology cannot live in a JointDistribution's tree value (the JD
coerces the value to tensors via tf.convert_to_tensor, which can't preserve a
non-tensor atom). So FixedTopologyRootedTreeBijector rebuilds it as an
in-graph-constant TensorflowTreeTopology in _forward via the new
StaticNumpyTreeTopology.to_constant_tensor_topology() helper. Because _forward
runs inside the trace, those Const ops fold (tf.get_static_value) at any size,
so the downstream likelihood traversal unrolls too -- unlike a captured tensor
topology, which is re-materialised as a Placeholder for >64 taxa.

Also:
- preorder dynamic (unroll=False) driver tf.convert_to_tensor's the topology
  index arrays so a static NumPy topology works there too (a NumPy array can't
  be indexed by a symbolic loop counter).
- profiler: build the full-joint tree value from an in-trace constant topology
  so the static engine unrolls through the JointDistribution -- the full-joint
  static timings are now real at 64/128 taxa instead of NaN.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Note: we should still be able to unroll the loop for a dynamic topology because the taxon count should be statically known
Replace the boolean `unroll` with "auto" plus three named traversal modes in
postorder/preorder:

  - "unrolled"    : straight-line graph, no TensorArray. Needs the topology index
                    *values* statically known. Fastest, and the only mode that
                    XLA-compiles for value+gradient.
  - "tensorarray" : Python-unrolled loop over a TensorArray. Needs only a static
                    node *count* (index values may be runtime tensors).
  - "while_loop"  : AutoGraph tf.while_loop over a TensorArray. Needs nothing
                    static; O(1) graph, for a varying/very large topology.

"auto" prefers "unrolled" when the values are static, else "tensorarray" when the
count is static, else "while_loop". Static-count detection uses the prefer_static
topology.taxon_count (robust to the boolean_mask shape-drop on preorder indices).

The docstrings record the empirically-verified XLA/jit_compile behaviour: only
"unrolled" compiles for value+gradient; the TensorArray modes compile only the
forward pass and only with xla_compatible=True; and an inner tf.function is not a
jit boundary.

Profiler engines map static->"unrolled", dynamic->"while_loop", plus a new
"tensorarray" engine. StaticNumpyTreeTopology is now accepted by the numpy rooted
tree and numpy_topology_to_tensor. Tests parametrize over the named modes.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
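
The "auto" selection rule described above can be sketched as a small pure-Python function. The function name and the two boolean flags are hypothetical; in the library the static checks are made on the topology itself.

```python
VALID_MODES = ("unrolled", "tensorarray", "while_loop")


def resolve_traversal_mode(mode: str, values_static: bool, count_static: bool) -> str:
    """Pick a traversal mode, following the stated preference order for "auto"."""
    if mode != "auto":
        if mode not in VALID_MODES:
            raise ValueError(f"Unknown traversal mode: {mode!r}")
        return mode
    if values_static:
        # Index values known at trace time: straight-line graph.
        return "unrolled"
    if count_static:
        # Only the node count is known: Python loop over a TensorArray.
        return "tensorarray"
    # Nothing static: fall back to a graph-level loop.
    return "while_loop"


print(resolve_traversal_mode("auto", values_static=False, count_static=True))  # tensorarray
```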
Implements a custom TFP ConvergenceCriterion that tracks EWMA of per-step
loss decrease normalised by |ELBO|, making the threshold invariant to dataset
scale and starting conditions.  A min_consecutive parameter (default 10)
requires the condition to hold for N consecutive steps, preventing spurious
early stopping from transient single-step dips in rel_rate.

Also adds:
- VIResults.convergence_criterion_state field (backward-compat default None)
  so criterion state (ewma, rel_rate, consecutive_below) is available in traces
- 20 unit tests covering EWMA update, NaN handling, and convergence logic
- experiments/advi_stability.ipynb: 5-run YFV ADVI stability study; conclusion
  notes that clock_rate / root_height variation reflects the mean-field
  approximation's inability to capture the clock-rate × root-height ridge

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
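
A minimal pure-Python sketch of the stopping rule described above, independent of TFP. The state field names (ewma, rel_rate, consecutive_below) and the min_consecutive default of 10 come from the commit message; the exact update formula, the decay and threshold values, and the NaN policy are assumptions made for illustration.

```python
import math
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple


@dataclass(frozen=True)
class CriterionState:
    ewma: float = math.nan           # EWMA of the per-step loss decrease
    rel_rate: float = math.nan       # ewma normalised by |loss|
    consecutive_below: int = 0       # steps in a row with rel_rate < threshold
    previous_loss: float = math.nan


def step(
    state: CriterionState,
    loss: float,
    decay: float = 0.9,          # assumed value
    threshold: float = 1e-4,     # assumed value
    min_consecutive: int = 10,
) -> Tuple[CriterionState, bool]:
    """Advance the criterion by one loss value; return (state, converged)."""
    if math.isnan(loss):
        # Assumed policy: a NaN loss resets the streak and leaves statistics alone.
        return CriterionState(state.ewma, state.rel_rate, 0, state.previous_loss), False
    if math.isnan(state.previous_loss):
        # First finite loss: nothing to compare against yet.
        return CriterionState(previous_loss=loss), False
    decrease = state.previous_loss - loss
    ewma = decrease if math.isnan(state.ewma) else decay * state.ewma + (1.0 - decay) * decrease
    rel_rate = ewma / max(abs(loss), 1e-12)
    below = state.consecutive_below + 1 if rel_rate < threshold else 0
    return CriterionState(ewma, rel_rate, below, loss), below >= min_consecutive


def first_converged_step(losses: Iterable[float], **kwargs) -> Optional[int]:
    """Return the index of the first step flagged as converged, or None."""
    state = CriterionState()
    for i, loss in enumerate(losses):
        state, converged = step(state, loss, **kwargs)
        if converged:
            return i
    return None


print(first_converged_step([100.0] * 20))                           # 10
print(first_converged_step([1000.0 - 10.0 * i for i in range(30)]))  # None
```

Normalising by |loss| is what makes a single threshold usable across datasets with very different ELBO magnitudes, and the streak counter is what keeps one flat step from ending the run.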
Adds cells to run NUM_IAF_RUNS independent IAF fits alongside the existing
mean-field runs, then compares:
- Within-IAF stability (inter_run/post_sd table)
- Loss traces: mean-field (dashed) vs IAF (solid) on shared axes
- Pooled posterior marginals: MF (blue) vs IAF (orange) histograms
- Side-by-side MF vs IAF summary table

Conclusion has a placeholder IAF findings section with the key questions
to fill in once the cells have been run.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Implements get_fixed_topology_full_rank_approximation: a multivariate Normal
in the joint unconstrained parameter space, capable of capturing correlations
(e.g. the clock rate x root height ridge that degrades mean-field stability).

Key design:
- _FullRankAffineBijector stores loc (D,) and raw_scale (D,D) as tf.Variables
- lower-triangular extraction via band_part, diagonal positivity via softplus
- Composed with the existing split/restructure/event-space bijector chain
  (same pattern as the IAF approximation)
- Initialised to loc=prior-medians (unconstrained), scale_tril ≈ identity

Updates advi_stability.ipynb to run 3 full-rank fits alongside the 5
mean-field runs, with loss-trace comparison, within-FR stability summary,
and pooled MF vs FR posterior marginal histograms.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
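
The scale construction listed under "Key design" can be sketched in NumPy as follows. `np.tril` stands in for `tf.linalg.band_part(raw, -1, 0)`; the helper name and the initial raw value (the softplus inverse of 1) are assumptions, chosen so the result starts at the identity as the commit message describes.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def scale_tril_from_raw(raw_scale):
    """Build a lower-triangular scale factor with a strictly positive diagonal."""
    lower = np.tril(raw_scale)                      # drop the upper triangle
    strict_lower = lower - np.diag(np.diag(lower))  # off-diagonal part, unconstrained
    positive_diag = softplus(np.diag(raw_scale))    # diagonal forced positive
    return strict_lower + np.diag(positive_diag)


dim = 3
# softplus(log(e - 1)) = 1, so this raw matrix maps to the identity.
raw_init = np.diag(np.full(dim, np.log(np.e - 1.0)))
scale_tril = scale_tril_from_raw(raw_init)

# Reparameterised draw: loc + L @ eps with eps ~ Normal(0, I).
loc = np.zeros(dim)
eps = np.random.default_rng(0).standard_normal(dim)
sample = loc + scale_tril @ eps

print(np.allclose(scale_tril, np.eye(dim)))  # True
```

A positive diagonal keeps the factor invertible, so the implied covariance `scale_tril @ scale_tril.T` is positive definite for any raw matrix.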
Full-rank ELBO ~2700-3700 nats better than mean-field (FR: -6099 to -6123
vs MF: -8845 to -9790), confirming mean-field loses significant posterior
mass on the strict-clock model.

Root height posterior std expands 2.5x (86 -> 213) and pop_size 1.4x wider
under full-rank, reflecting correct representation of the clock rate x root
height ridge.  Inter-run stability for clock_rate is similar between
approximations (0.60 vs 0.62), showing the degeneracy is inherent in the
posterior geometry, not an artifact of the approximation family.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Quantify the clock-rate × root-height ridge: Pearson corr = -0.48,
  log-log slope = -0.75 (partial non-identifiability), quadratic coeff = 0.08
- Test IAF as a more flexible approximation family: all three IAF runs
  fail (run 1 stalls at ELBO ≈ -14700; runs 2-3 crash with NaN gradients),
  showing optimisation difficulty is the limiting factor, not expressiveness
- Update conclusion with geometry diagnostics and IAF findings
- Refresh native_vs_tf_vi_validation.ipynb outputs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Move IAF cells from advi_stability.ipynb to new advi_iaf.ipynb,
  which includes self-contained MF/FR reference runs for comparison
- advi_stability.ipynb now covers mean-field and full-rank only
- iaf.py: add trainable affine base (loc_var + log_scale_var) so the
  IAF is Normal(init_loc, 1) at network init rather than a random scramble
- iaf.py: use DeferredTensor for softplus(log_scale_var) so gradients
  flow back to log_scale_var during training
- iaf.py: use TruncatedNormal(stddev=0.01) kernel init for the
  autoregressive network so IAF bijectors start near-identity,
  preventing NaN warm-up losses from extreme initial samples

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
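
To illustrate why a small kernel init keeps the flow close to its affine base, here is a NumPy sketch with a single linear, strictly lower-triangular layer standing in for the autoregressive network. Everything beyond the names `loc`/`log_scale_var`, the softplus, and stddev=0.01 is an assumption: the layer form, the clipping used in place of a truncated normal, and the example values.

```python
import numpy as np


def softplus(x):
    return np.logaddexp(0.0, x)


def small_masked_weights(rng, dim, stddev):
    # Stand-in for a truncated-normal init: draw normals, clip to two stddevs,
    # then keep the strictly lower triangle so output i depends only on inputs < i.
    weights = np.clip(rng.normal(0.0, stddev, size=(dim, dim)), -2 * stddev, 2 * stddev)
    return np.tril(weights, k=-1)


def iaf_layer(x, shift_weights, log_scale_weights):
    # Affine autoregressive transform: y_i = x_i * exp(s_i(x_<i)) + m_i(x_<i).
    shift = shift_weights @ x
    log_scale = log_scale_weights @ x
    return x * np.exp(log_scale) + shift


rng = np.random.default_rng(0)
init_loc = np.array([0.5, -1.0, 2.0, 0.0])       # hypothetical initial locations
log_scale_var = np.full(4, np.log(np.e - 1.0))   # softplus maps this to 1
base_scale = softplus(log_scale_var)

eps = np.array([0.3, -1.2, 0.8, 0.1])            # fixed stand-in for a Normal(0, 1) draw
x = init_loc + base_scale * eps                  # sample from the affine base
y = iaf_layer(x, small_masked_weights(rng, 4, 0.01), small_masked_weights(rng, 4, 0.01))

max_deviation = np.max(np.abs(y - x))
print(max_deviation < 0.3)  # True
```

With weights this small the layer perturbs the base sample only slightly, so the initial distribution is approximately Normal(init_loc, 1); larger initial weights would let `exp(log_scale)` inflate samples, which is the failure mode the commit message attributes the NaN warm-up losses to.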
@christiaanjs christiaanjs changed the title from "Add native C++ op for the node-height ratio transform" to "Add convergence criterion, full-rank and IAF approximations, ADVI stability notebooks" Jun 16, 2026