
Birth-death tree sampling with tf.function compatibility#90

Open
christiaanjs wants to merge 9 commits into master from bd-tree-sampling

@christiaanjs

Owner

Summary

  • Implements CPP (Coalescent Point Process) backward algorithm for sampling trees from the birth-death contemporary sampling distribution (BirthDeathContemporarySampling)
  • Rewrites build_random_topologies in pure TensorFlow so it runs correctly inside tf.function (eliminates all .numpy() calls and numpy-based array construction)
  • Extracts preorder DFS traversal into compute_preorder_indices in TensorflowTreeTopology, shared across the sampler and topology utilities, with full tf.function compatibility for both unbatched and batched inputs
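
As a rough illustration of the speciation-time stage of a CPP sampler, the sketch below draws node depths by inverse-CDF sampling. It rests on several assumptions that the PR text does not confirm: complete sampling (sampling probability 1), a known origin age, unequal birth and death rates, and the standard conditioned constant-rate birth-death CDF for node depths. The function names are hypothetical and plain Python stands in for TensorFlow.

```python
import math
import random


def node_depth_cdf(s, origin, birth_rate, death_rate):
    """CDF of one node depth given the origin age (assumes complete sampling)."""
    r = birth_rate - death_rate

    def g(t):
        e = math.exp(-r * t)
        return (1.0 - e) / (birth_rate - death_rate * e)

    return g(s) / g(origin)


def sample_node_depth(u, origin, birth_rate, death_rate):
    """Invert the CDF above for a uniform draw u in [0, 1)."""
    r = birth_rate - death_rate
    e_origin = math.exp(-r * origin)
    c = u * (1.0 - e_origin) / (birth_rate - death_rate * e_origin)
    # Solve (1 - x) / (birth_rate - death_rate * x) = c for x = exp(-r * s).
    x = (1.0 - c * birth_rate) / (1.0 - c * death_rate)
    return -math.log(x) / r


def sample_sorted_node_depths(n_taxa, origin, birth_rate, death_rate, rng):
    """Draw n_taxa - 1 independent node depths and return them sorted."""
    depths = [
        sample_node_depth(rng.random(), origin, birth_rate, death_rate)
        for _ in range(n_taxa - 1)
    ]
    return sorted(depths)
```

Calling `sample_sorted_node_depths(8, 2.0, 1.5, 0.5, random.Random(0))` returns seven sorted depths in the interval from 0 to the origin age. A sampler that handles incomplete sampling would need a different CDF.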

Key changes

  • treeflow/distributions/tree/birthdeath/cpp_sampler.py — new module with sample_bd_tree and build_random_topologies; uses tf.gather_nd / tf.tensor_scatter_nd_update / tf.where for batched topology construction without .numpy() calls
  • treeflow/tree/topology/tensorflow_tree_topology.py — adds compute_preorder_indices(child_indices) (pure-TF iterative DFS, handles batched and unbatched); numpy_topology_to_tensor now uses it instead of the numpy fallback
  • treeflow/distributions/tree/birthdeath/birth_death_contemporary_sampling.py — _sample_n wired to new sampler with correct parameter conversion (r, a → lambda_, mu)
  • test/distributions/tree/birthdeath/test_cpp_sampler.py — 26 tests covering sampler correctness, tf.function tracing, log-prob consistency, and edge cases
  • test/tree/topology/test_tensorflow_topology.py — 6 new tests for compute_preorder_indices (unbatched/batched vs numpy, root-first, all-nodes-present, inside tf.function)

Test plan

  • pytest test/distributions/tree/birthdeath/test_cpp_sampler.py — 26 sampler tests pass
  • pytest test/tree/topology/test_tensorflow_topology.py — all topology tests including 6 new preorder tests pass
  • pytest — full suite passes

🤖 Generated with Claude Code

christiaanjs and others added 8 commits April 2, 2026 07:57
Replaces the python import check with pytest --collect-only, which
confirms not just that packages are importable but that the full test
suite can be loaded successfully.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ices

Replaces np_top_ops.get_preorder_indices with a pure-TF iterative DFS
(compute_preorder_indices) that handles both unbatched and batched
child_indices and is fully tf.function-compatible. The CPP sampler now
imports this shared function instead of keeping a local copy.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
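
The iterative DFS described in the commit above can be sketched in plain Python. This is not the PR's TensorFlow code: the function name is hypothetical, and the child-table layout (leaves hold -1 for both children, the root is the last row) is an assumption. The fixed-size stack and explicit counters are used because they are the kind of state a tensor-based loop would carry.

```python
def iterative_preorder(child_indices):
    """Preorder traversal of a binary tree stored as a child table.

    Assumed layout: row i holds the two children of node i, leaves hold
    (-1, -1), and the root is the last row.
    """
    n_nodes = len(child_indices)
    stack = [0] * n_nodes      # fixed-size stack, filled from index 0
    stack[0] = n_nodes - 1     # start at the root
    stack_size = 1
    order = [0] * n_nodes
    n_visited = 0
    while stack_size > 0:
        # Pop the top of the stack and record the visit.
        stack_size -= 1
        node = stack[stack_size]
        order[n_visited] = node
        n_visited += 1
        left, right = child_indices[node]
        if left >= 0:
            # Push right first so that left is popped (visited) first.
            stack[stack_size] = right
            stack[stack_size + 1] = left
            stack_size += 2
    return order
```

For the three-taxon tree with child table `[(-1, -1), (-1, -1), (-1, -1), (0, 1), (3, 2)]`, the function returns `[4, 3, 0, 1, 2]`: the root, then its left subtree, then the remaining leaf.
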
…andom_topologies

Internal nodes in the CPP sampler are created in postorder (each parent
created after both its children are active lineages), so the reversed
creation order is a valid preorder without running a separate DFS.  Leaves
are appended last; since every leaf's parent is an internal node the
parent-before-child invariant holds.  The resulting preorder is identical
for every sample in the batch, so it is constructed once and tiled.

Documents that build_random_topologies samples uniformly over ranked
labeled tree topologies (labeled histories).

Adds four new tests: preorder contains all nodes, parent precedes child,
all samples share the same preorder, and compatibility with
ratios_to_node_heights.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
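
The reversed-creation-order argument in the commit above can be checked with a small plain-Python sketch. The helper names are hypothetical, and the random merging of active lineages is a simplified stand-in for build_random_topologies, not its actual implementation. The ordering is a parent-before-child order rather than a depth-first one.

```python
import random


def random_parent_indices(n_taxa, rng):
    """Merge random pairs of active lineages; node n_taxa + k is created at step k."""
    parent = [-1] * (2 * n_taxa - 1)
    active = list(range(n_taxa))
    for new_node in range(n_taxa, 2 * n_taxa - 1):
        left, right = rng.sample(active, 2)
        parent[left] = new_node
        parent[right] = new_node
        active = [x for x in active if x not in (left, right)] + [new_node]
    return parent


def creation_order_preorder(n_taxa):
    """Internal nodes in reverse creation order, followed by the leaves."""
    internal = list(range(2 * n_taxa - 2, n_taxa - 1, -1))
    leaves = list(range(n_taxa))
    return internal + leaves


def parent_precedes_child(parent, order):
    """True if every non-root node appears after its parent in `order`."""
    position = {node: i for i, node in enumerate(order)}
    return all(
        position[parent[node]] < position[node]
        for node in range(len(parent))
        if parent[node] >= 0
    )
```

Because `creation_order_preorder` depends only on `n_taxa`, the same ordering serves every sampled topology, which is why it can be built once and tiled across a batch.
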
Allows a caller to supply a fixed tree topology and still draw valid
trees from the birth-death distribution: the CPP sampler (Stage 2)
generates sorted speciation times, then a uniformly random linear
extension (ranking) of the topology's internal-node partial order is
sampled, and the heights are permuted to match.

New function sample_ranking(n_taxa, child_indices, parent_indices,
n_total, seed) samples n_total independent linear extensions via an
iterative eligible-set algorithm: at each step, one node is drawn
uniformly from the set of internal nodes whose every internal child has
already been ranked.  Sampling uses float multiplication (not modulo)
to avoid bias.

BirthDeathContemporarySampling.__init__ gains a fixed_topology kwarg
that is threaded through to sample_bd_tree.  The fixed topology is
tiled and reshaped into [n_samples, *batch_shape, ...] form so that
the base class machinery works unchanged.

Adds 7 new tests: ranking permutation correctness, parent-always-
higher-rank invariant, empirical uniformity over linear extensions,
reproducibility, and fixed-topology shape / height-monotonicity /
log-prob-finite checks.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
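
The eligible-set procedure described in the commit above can be sketched for a single draw as follows. The function name and simplified signature are hypothetical (the PR's sample_ranking also takes parent_indices, n_total, and seed), and plain Python stands in for TensorFlow. The child-table layout, with leaves numbered below n_taxa, is an assumption.

```python
import random


def sample_ranking_sketch(n_taxa, child_indices, rng):
    """Draw one ranking of internal nodes in which children precede parents."""
    n_nodes = 2 * n_taxa - 1
    internal = list(range(n_taxa, n_nodes))
    ranked = set()
    ranking = []
    while len(ranking) < len(internal):
        # A node is eligible once every internal child has been ranked;
        # leaf children (index < n_taxa) never block a node.
        eligible = [
            node
            for node in internal
            if node not in ranked
            and all(c < n_taxa or c in ranked for c in child_indices[node])
        ]
        k = len(eligible)
        # Float multiplication maps a uniform draw in [0, 1) to an index in
        # 0..k-1; the min() guards against rounding up to k.
        index = min(int(rng.random() * k), k - 1)
        chosen = eligible[index]
        ranked.add(chosen)
        ranking.append(chosen)
    return ranking
```

For a five-taxon tree with child table `[(-1, -1)] * 5 + [(0, 1), (5, 2), (3, 4), (6, 7)]`, every draw places node 5 before node 6 and both 6 and 7 before the root 8. The sketch covers validity only; whether the draws are uniform over linear extensions is a separate property that it does not test.
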

Copilot AI left a comment


Pull request overview

Implements a TensorFlow-native birth–death (CPP backward algorithm) tree sampler and updates traversal/topology utilities to be compatible with tf.function, including support for topology batch dimensions.

Changes:

  • Added a new CPP birth–death sampler (cpp_sampler.py) and wired it into BirthDeathContemporarySampling._sample_n.
  • Introduced compute_preorder_indices (pure-TF iterative DFS) and used it in numpy_topology_to_tensor.
  • Extended CTMC preorder sampling to handle batched topologies; added substantial test coverage for the new sampler and traversal logic.

Reviewed changes

Copilot reviewed 15 out of 15 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| treeflow/distributions/tree/birthdeath/cpp_sampler.py | New CPP sampler implementation (origin age, speciation times, random topology, ranking, orchestration). |
| treeflow/distributions/tree/birthdeath/birth_death_contemporary_sampling.py | Uses new sampler; adds fixed_topology; enables topology batch dims. |
| treeflow/tree/topology/tensorflow_tree_topology.py | Adds compute_preorder_indices and switches numpy→TF topology conversion to use it. |
| treeflow/traversal/sample_ctmc.py | Adds a batched-topology execution path for preorder CTMC sampling. |
| treeflow/distributions/tree/rooted_tree_distribution.py | Implements dummy topology samples when topology batch dims are supported. |
| treeflow/distributions/leaf_ctmc.py | Removes the sampling-time topology-batching guard (relies on updated traversal). |
| test/distributions/tree/birthdeath/test_cpp_sampler.py | New comprehensive tests for CPP sampling and fixed-topology path. |
| test/tree/topology/test_tensorflow_topology.py | Adds tests for compute_preorder_indices (batched/unbatched + tf.function). |
| test/vi/test_marginal_likelihood.py | Marks one test as flaky with reruns. |
| test/evolution/substitution/test_probabilities.py | Removes an unused/incorrect import. |
| setup.cfg | Adds pytest-rerunfailures to test dependencies. |
| dev/requirements.txt | Pins pytest-rerunfailures. |
| CLAUDE.md | Updates guidance about pytest path usage. |
| .claude/settings.json | Updates Claude permissions configuration. |
| .claude/hooks/session-start.sh | Changes session-start fast-path to use pytest collection. |


Comment on lines +587 to +589
    n_batch = int(tf.reduce_prod(tf.shape(lambda_)))
else:
    n_batch = int(np.prod(shape_list)) if shape_list else 1

Copilot AI Apr 7, 2026


n_batch = int(tf.reduce_prod(tf.shape(lambda_))) will raise TypeError under tf.function/graph mode when lambda_ has any dynamic batch dimension (a tf.Tensor can’t be converted to a Python int). Since n_total must be a Python int for build_random_topologies, this path breaks the stated tf.function compatibility. Consider requiring a fully static lambda_.shape (and raising a clear ValueError otherwise), or refactoring Stage 3 to avoid Python-int n_total (e.g., make topology construction work with tensor-sized loops/TensorArrays).

Suggested change
-     n_batch = int(tf.reduce_prod(tf.shape(lambda_)))
- else:
-     n_batch = int(np.prod(shape_list)) if shape_list else 1
+     raise ValueError(
+         "lambda_ must have a fully static batch shape when sampling tree "
+         "topologies, because build_random_topologies requires a Python int "
+         "for n_total. Got lambda_.shape={!r}.".format(lambda_.shape)
+     )
+ n_batch = int(np.prod(shape_list)) if shape_list else 1
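
The static-shape requirement in the suggestion above can be isolated in a small helper, sketched here in plain Python. The helper name is hypothetical, and it assumes shape_list is a list of dimension sizes in which None marks a dimension that is unknown at trace time (for example, the result of calling as_list() on a static shape).

```python
import math


def static_batch_size(shape_list):
    """Return the number of batch elements as a Python int.

    Raises ValueError if any dimension is None, i.e. not known statically.
    """
    if any(dim is None for dim in shape_list):
        raise ValueError(
            "Batch shape must be fully static to compute a Python int size; "
            "got shape {!r}.".format(shape_list)
        )
    # math.prod of an empty list is 1, which covers the scalar (unbatched) case.
    return math.prod(shape_list)
```

With this helper, `static_batch_size([])` gives 1 and `static_batch_size([3, 4])` gives 12, while a shape such as `[None, 4]` fails immediately with a clear message instead of a TypeError during tracing.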

Comment thread treeflow/tree/topology/tensorflow_tree_topology.py
Comment thread treeflow/traversal/sample_ctmc.py
* Add HMC + BirthDeathContemporarySampling validation notebook

Adds experiments/validate_hmc_birth_death.ipynb which checks that
fit_fixed_topology_hmc (NUTS kernel) samples from the correct distribution
by comparing against direct simulation from BirthDeathContemporarySampling
on a fixed topology. Uses KS two-sample tests and histogram overlays for
all 7 internal node heights on an 8-taxon tree.

https://claude.ai/code/session_019wmV49E4JVWySo3cAoYim7

* Add experiments extras and README for notebook dependencies

Adds a new [experiments] extras_require entry to setup.cfg with
matplotlib, scipy, jupyter, and nbconvert. Adds experiments/README.md
documenting the install step and how to run notebooks.

https://claude.ai/code/session_019wmV49E4JVWySo3cAoYim7

* Run HMC notebook

* Rerun HMC notebook

* Add report on HMC geometry and preconditioning for time trees

Analyses the funnel geometry arising from the ratio transform, explains
why NUTS undersamples shallow nodes, surveys alternative samplers, and
proposes fixed metric preconditioning (G* = J(h*)ᵀ D(h*) J(h*)) as a
tractable near-term fix using TransformedTransitionKernel.

https://claude.ai/code/session_019wmV49E4JVWySo3cAoYim7

---------

Co-authored-by: Claude <noreply@anthropic.com>
