Birth-death tree sampling with tf.function compatibility #90
christiaanjs wants to merge 9 commits into
Replaces the python import check with pytest --collect-only, which confirms not just that packages are importable but that the full test suite can be loaded successfully. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ices Replaces np_top_ops.get_preorder_indices with a pure-TF iterative DFS (compute_preorder_indices) that handles both unbatched and batched child_indices and is fully tf.function-compatible. The CPP sampler now imports this shared function instead of keeping a local copy. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
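The iterative DFS described in this commit can be illustrated with a minimal pure-Python sketch. It is not the TensorFlow implementation of compute_preorder_indices: the function name `preorder_indices`, the explicit `root` argument, and the convention that leaves carry `(-1, -1)` in `child_indices` are assumptions made for illustration only.

```python
from typing import List, Sequence, Tuple


def preorder_indices(
    child_indices: Sequence[Tuple[int, int]], root: int
) -> List[int]:
    """Iterative depth-first preorder (assumes leaves are marked by (-1, -1))."""
    order: List[int] = []
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        left, right = child_indices[node]
        if left >= 0:
            # Push the right child first so the left subtree is visited first.
            stack.append(right)
            stack.append(left)
    return order


# 4-taxon tree: leaves 0..3, internal nodes 4 = (0, 1), 5 = (2, 3), root 6 = (4, 5).
children = [(-1, -1)] * 4 + [(0, 1), (2, 3), (4, 5)]
order = preorder_indices(children, root=6)
```

An explicit stack replaces recursion, which is the property that makes this style of traversal expressible as a loop with a fixed number of iterations (one per node).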
…andom_topologies Internal nodes in the CPP sampler are created in postorder (each parent created after both its children are active lineages), so the reversed creation order is a valid preorder without running a separate DFS. Leaves are appended last; since every leaf's parent is an internal node the parent-before-child invariant holds. The resulting preorder is identical for every sample in the batch, so it is constructed once and tiled. Documents that build_random_topologies samples uniformly over ranked labeled tree topologies (labeled histories). Adds four new tests: preorder contains all nodes, parent precedes child, all samples share the same preorder, and compatibility with ratios_to_node_heights. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
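The argument above (internal nodes are created in postorder, so the reversed creation order followed by the leaves puts every parent before its children) can be checked with a small pure-Python sketch. The function `build_random_topology` below is hypothetical and only mimics the idea of merging random pairs of active lineages; it is not the batched TensorFlow build_random_topologies from the PR.

```python
import random
from typing import List, Tuple


def build_random_topology(
    n_taxa: int, rng: random.Random
) -> Tuple[List[int], List[int]]:
    """Merge random pairs of active lineages; return (parent_indices, preorder)."""
    n_nodes = 2 * n_taxa - 1
    parent = [-1] * n_nodes       # -1 marks the root
    active = list(range(n_taxa))  # leaves are nodes 0 .. n_taxa - 1
    for new_node in range(n_taxa, n_nodes):
        # Both children are already active lineages, so every internal node
        # is created after its children (postorder creation).
        left = active.pop(rng.randrange(len(active)))
        right = active.pop(rng.randrange(len(active)))
        parent[left] = parent[right] = new_node
        active.append(new_node)
    # Reversed creation order of the internal nodes, then all leaves.
    preorder = list(range(n_nodes - 1, n_taxa - 1, -1)) + list(range(n_taxa))
    return parent, preorder


parent, preorder = build_random_topology(6, random.Random(0))
position = {node: i for i, node in enumerate(preorder)}
```

Because the ordering depends only on `n_taxa` and not on which lineages were merged, it is the same for every sampled topology, which is why it can be built once and tiled.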
Allows a caller to supply a fixed tree topology and still draw valid trees from the birth-death distribution: the CPP sampler (Stage 2) generates sorted speciation times, then a uniformly random linear extension (ranking) of the topology's internal-node partial order is sampled, and the heights are permuted to match. New function sample_ranking(n_taxa, child_indices, parent_indices, n_total, seed) samples n_total independent linear extensions via an iterative eligible-set algorithm: at each step, one node is drawn uniformly from the set of internal nodes whose every internal child has already been ranked. Sampling uses float multiplication (not modulo) to avoid bias. BirthDeathContemporarySampling.__init__ gains a fixed_topology kwarg that is threaded through to sample_bd_tree. The fixed topology is tiled and reshaped into [n_samples, *batch_shape, ...] form so that the base class machinery works unchanged. Adds 7 new tests: ranking permutation correctness, parent-always-higher-rank invariant, empirical uniformity over linear extensions, reproducibility, and fixed-topology shape / height-monotonicity / log-prob-finite checks. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
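The eligible-set step described in this commit can be sketched in pure Python for a single sample. The name `sample_ranking_sketch`, the dictionary return type, and the convention that nodes `0 .. n_taxa - 1` are leaves are assumptions for illustration; the PR's sample_ranking has a different signature and draws `n_total` rankings at once. The sketch only demonstrates that the result respects the partial order.

```python
import random
from typing import Dict, List, Sequence, Tuple


def sample_ranking_sketch(
    n_taxa: int, child_indices: Sequence[Tuple[int, int]], rng: random.Random
) -> Dict[int, int]:
    """Return {internal node: rank}; rank 0 is assigned first."""
    n_nodes = len(child_indices)
    # Number of internal children that still need a rank, per internal node.
    pending = {
        node: sum(child >= n_taxa for child in child_indices[node])
        for node in range(n_taxa, n_nodes)
    }
    parent = {child: p for p in pending for child in child_indices[p]}
    eligible: List[int] = [node for node, k in pending.items() if k == 0]
    ranks: Dict[int, int] = {}
    while eligible:
        # Float multiplication maps u in [0, 1) to an index, without modulo.
        index = int(rng.random() * len(eligible))
        node = eligible.pop(index)
        ranks[node] = len(ranks)
        up = parent.get(node)
        if up is not None:
            pending[up] -= 1
            if pending[up] == 0:
                eligible.append(up)
    return ranks


# 4-taxon tree: internal nodes 4 = (0, 1), 5 = (2, 3), root 6 = (4, 5).
children = [(-1, -1)] * 4 + [(0, 1), (2, 3), (4, 5)]
ranks = sample_ranking_sketch(4, children, random.Random(0))
```

Given sorted speciation times, a node with rank `k` would then receive the `k`-th smallest height, so a parent always ends up above its internal children.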
Pull request overview
Implements a TensorFlow-native birth–death (CPP backward algorithm) tree sampler and updates traversal/topology utilities to be compatible with tf.function, including support for topology batch dimensions.
Changes:
- Added a new CPP birth–death sampler (`cpp_sampler.py`) and wired it into `BirthDeathContemporarySampling._sample_n`.
- Introduced `compute_preorder_indices` (pure-TF iterative DFS) and used it in `numpy_topology_to_tensor`.
- Extended CTMC preorder sampling to handle batched topologies; added substantial test coverage for the new sampler and traversal logic.
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| treeflow/distributions/tree/birthdeath/cpp_sampler.py | New CPP sampler implementation (origin age, speciation times, random topology, ranking, orchestration). |
| treeflow/distributions/tree/birthdeath/birth_death_contemporary_sampling.py | Uses new sampler; adds fixed_topology; enables topology batch dims. |
| treeflow/tree/topology/tensorflow_tree_topology.py | Adds compute_preorder_indices and switches numpy→TF topology conversion to use it. |
| treeflow/traversal/sample_ctmc.py | Adds a batched-topology execution path for preorder CTMC sampling. |
| treeflow/distributions/tree/rooted_tree_distribution.py | Implements dummy topology samples when topology batch dims are supported. |
| treeflow/distributions/leaf_ctmc.py | Removes the sampling-time topology-batching guard (relies on updated traversal). |
| test/distributions/tree/birthdeath/test_cpp_sampler.py | New comprehensive tests for CPP sampling and fixed-topology path. |
| test/tree/topology/test_tensorflow_topology.py | Adds tests for compute_preorder_indices (batched/unbatched + tf.function). |
| test/vi/test_marginal_likelihood.py | Marks one test as flaky with reruns. |
| test/evolution/substitution/test_probabilities.py | Removes an unused/incorrect import. |
| setup.cfg | Adds pytest-rerunfailures to test dependencies. |
| dev/requirements.txt | Pins pytest-rerunfailures. |
| CLAUDE.md | Updates guidance about pytest path usage. |
| .claude/settings.json | Updates Claude permissions configuration. |
| .claude/hooks/session-start.sh | Changes session-start fast-path to use pytest collection. |
        n_batch = int(tf.reduce_prod(tf.shape(lambda_)))
    else:
        n_batch = int(np.prod(shape_list)) if shape_list else 1
n_batch = int(tf.reduce_prod(tf.shape(lambda_))) will raise TypeError under tf.function/graph mode when lambda_ has any dynamic batch dimension (a tf.Tensor can’t be converted to a Python int). Since n_total must be a Python int for build_random_topologies, this path breaks the stated tf.function compatibility. Consider requiring a fully static lambda_.shape (and raising a clear ValueError otherwise), or refactoring Stage 3 to avoid Python-int n_total (e.g., make topology construction work with tensor-sized loops/TensorArrays).
    -     n_batch = int(tf.reduce_prod(tf.shape(lambda_)))
    - else:
    -     n_batch = int(np.prod(shape_list)) if shape_list else 1
    +     raise ValueError(
    +         "lambda_ must have a fully static batch shape when sampling tree "
    +         "topologies, because build_random_topologies requires a Python int "
    +         "for n_total. Got lambda_.shape={!r}.".format(lambda_.shape)
    +     )
    + n_batch = int(np.prod(shape_list)) if shape_list else 1
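The static-shape requirement in this suggestion can be sketched without TensorFlow. The helper name `static_batch_size` is hypothetical, and `shape_list` is assumed to be a list such as the one `lambda_.shape.as_list()` returns, where an unknown dimension appears as `None`.

```python
import math
from typing import Optional, Sequence


def static_batch_size(shape_list: Sequence[Optional[int]]) -> int:
    """Product of the batch dimensions, or ValueError if any is unknown."""
    if any(dim is None for dim in shape_list):
        raise ValueError(
            "batch shape must be fully static, got {!r}".format(list(shape_list))
        )
    # math.prod of an empty sequence is 1, which covers scalar parameters.
    return math.prod(shape_list)


n_batch = static_batch_size([2, 3])

try:
    static_batch_size([None, 3])
    failed_fast = False
except ValueError:
    failed_fast = True
```

Failing at this point gives the caller a clear message at trace time rather than a TypeError from converting a symbolic tensor to a Python int.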
* Add HMC + BirthDeathContemporarySampling validation notebook

  Adds experiments/validate_hmc_birth_death.ipynb which checks that fit_fixed_topology_hmc (NUTS kernel) samples from the correct distribution by comparing against direct simulation from BirthDeathContemporarySampling on a fixed topology. Uses KS two-sample tests and histogram overlays for all 7 internal node heights on an 8-taxon tree.

  https://claude.ai/code/session_019wmV49E4JVWySo3cAoYim7

* Add experiments extras and README for notebook dependencies

  Adds a new [experiments] extras_require entry to setup.cfg with matplotlib, scipy, jupyter, and nbconvert. Adds experiments/README.md documenting the install step and how to run notebooks.

  https://claude.ai/code/session_019wmV49E4JVWySo3cAoYim7

* Run HMC notebook

* Rerun HMC notebook

* Add report on HMC geometry and preconditioning for time trees

  Analyses the funnel geometry arising from the ratio transform, explains why NUTS undersamples shallow nodes, surveys alternative samplers, and proposes fixed metric preconditioning (G* = J(h*)ᵀ D(h*) J(h*)) as a tractable near-term fix using TransformedTransitionKernel.

  https://claude.ai/code/session_019wmV49E4JVWySo3cAoYim7

---------

Co-authored-by: Claude <noreply@anthropic.com>
Summary
- … `BirthDeathContemporarySampling`)
- `build_random_topologies` in pure TensorFlow so it runs correctly inside `tf.function` (eliminates all `.numpy()` calls and numpy-based array construction)
- `compute_preorder_indices` in `TensorflowTreeTopology`, shared across the sampler and topology utilities, with full `tf.function` compatibility for both unbatched and batched inputs

Key changes
- `treeflow/distributions/tree/birthdeath/cpp_sampler.py`: new module with `sample_bd_tree` and `build_random_topologies`; uses `tf.gather_nd` / `tf.tensor_scatter_nd_update` / `tf.where` for batched topology construction without `.numpy()` calls
- `treeflow/tree/topology/tensorflow_tree_topology.py`: adds `compute_preorder_indices(child_indices)` (pure-TF iterative DFS, handles batched and unbatched); `numpy_topology_to_tensor` now uses it instead of the numpy fallback
- `treeflow/distributions/tree/birthdeath/birth_death_contemporary_sampling.py`: `_sample_n` wired to new sampler with correct parameter conversion (`r`, `a` → `lambda_`, `mu`)
- `test/distributions/tree/birthdeath/test_cpp_sampler.py`: 26 tests covering sampler correctness, tf.function tracing, log-prob consistency, and edge cases
- `test/tree/topology/test_tensorflow_topology.py`: 6 new tests for `compute_preorder_indices` (unbatched/batched vs numpy, root-first, all-nodes-present, inside `tf.function`)

Test plan
- `pytest test/distributions/tree/birthdeath/test_cpp_sampler.py`: 26 sampler tests pass
- `pytest test/tree/topology/test_tensorflow_topology.py`: all topology tests, including 6 new preorder tests, pass
- `pytest`: full suite passes

🤖 Generated with Claude Code
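The parameter conversion listed under Key changes (r, a → lambda_, mu) is not spelled out in the description. A minimal sketch follows, under the assumption that `r` is the birth-difference rate (`lambda_ - mu`) and `a` is the relative death rate (`mu / lambda_`), a common parameterization for this model; the function name `to_birth_death_rates` is hypothetical.

```python
from typing import Tuple


def to_birth_death_rates(r: float, a: float) -> Tuple[float, float]:
    """Assumes r = lambda_ - mu and a = mu / lambda_, with 0 <= a < 1."""
    # From r = lambda_ * (1 - a) it follows that lambda_ = r / (1 - a).
    lambda_ = r / (1.0 - a)
    mu = a * lambda_
    return lambda_, mu


lambda_, mu = to_birth_death_rates(2.0, 0.5)
```

If the PR uses a different definition of `r` and `a`, the algebra changes accordingly; the round-trip check (`lambda_ - mu` recovers `r`, `mu / lambda_` recovers `a`) is the part that carries over.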