Skip to content

Add discrete-trait (DTA) model and treeflow_dta_hmc CLI#93

Open
alexeid wants to merge 3 commits into
masterfrom
feature/discrete-trait-model
Open

Add discrete-trait (DTA) model and treeflow_dta_hmc CLI#93
alexeid wants to merge 3 commits into
masterfrom
feature/discrete-trait-model

Conversation

@alexeid

@alexeid alexeid commented Apr 14, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds fixed-topology Bayesian phylogeography / discrete trait analysis (DTA) to TreeFlow: a K-state time-reversible substitution model, a trait-data CSV loader, discrete_trait support in the YAML model spec, and a new treeflow_dta_hmc CLI that runs HMC/NUTS on a time-tree with tip states.

Motivation: for large trees that are too expensive to infer topology and times over jointly in BEAST (e.g. real-time genomic surveillance settings with thousands of tips), a fixed-tree DTA analysis with NUTS is a fast way to infer migration rates, equilibrium frequencies, and ancestral states. This PR lays the groundwork; MASCOT-style structured-coalescent follow-on is then a natural next step in the same framework.

What's included

Core

  • treeflow.evolution.substitution.discrete_trait.dta.DiscreteTraitModel(K): time-reversible K-state substitution model (Lemey et al. 2009 style), inherits EigendecompositionSubstitutionModel so it plugs into the existing transition-probability and pruning machinery without modification.
  • treeflow.evolution.traitio.DiscreteTraitData: two-column CSV loader (taxon,trait), supports explicit or inferred state ordering and unknown-state partials. Duck-types as a one-site Alignment (site_count = 1) so it flows through the existing LeafCTMC / phylogenetic_likelihood pipeline unchanged.
  • discrete_trait block in treeflow.model.phylo_model: n_states is a config int that is pulled at model-construction time and stripped from the params dict. The clock.strict.clock_rate scalar acts as the mean migration-rate multiplier.
  • treeflow_dta_hmc CLI: fork of treeflow_hmc that takes --traits (CSV) instead of --input (alignment), validates that declared n_states matches observed states and that all tree taxa have trait assignments.

Docs

  • cli.treeflow_dta_hmc.rst with complete YAML and CSV examples.
  • model-definition.md extended with a discrete_trait section.
  • API-reference stubs wired into the existing toctrees.
  • Incidentally filled a pre-existing gap: treeflow_hmc CLI was not documented — now is.

Tests (34 new, all passing)

  • test/evolution/substitution/discrete_trait/test_dta.py — 10 tests: shape/zero-row-sum for K ∈ {2,3,4,5,8}, time-reversibility, JC reduction at K=4, eigendecomposition round-trip, batched inputs, input validation.
  • test/evolution/test_traitio.py — 14 tests: CSV round-trip, taxon reordering, unknown-state partials, explicit state ordering, error cases.
  • test/model/test_phylo_model_discrete_trait.py — 8 tests: end-to-end from model dict through joint distribution sample + log_prob across K ∈ {3,4,6}.
  • test/cli/test_dta_hmc_cli.py — 2 CLI tests.

Test plan

  • pytest test/evolution/ test/model/ — 86 passed (52 existing + 34 new), 0 regressions
  • pytest test/ (excluding cli/acceleration markers) — 249 passed, 0 regressions
  • pytest test/cli/test_dta_hmc_cli.py -m cli — 2 passed end-to-end (NUTS on a 3-taxon tree completes in ~35s)
  • sphinx-build -b html docs/source — clean build, all new pages present
  • Parameter recovery on simulated data (next step — not part of this PR)

Followups not in this PR

  1. Validation against BEAST: simulate from a known Q matrix under a known tree, fit with treeflow_dta_hmc, verify posterior means/CIs recover the generating rates. This is the natural next commit.
  2. Non-reversible K-state variant via tf.linalg.expm(Q*t) for asymmetric migration rates (e.g. source/sink phylogeography).
  3. Ancestral-state reconstruction output via the existing sample_ctmc_preorder primitive.
  4. MASCOT-style structured coalescent: with the differentiable pruning and fixed-tree infrastructure in place, adding the per-segment ODE for lineage-state probabilities becomes a focused extension.

alexeid added 3 commits April 14, 2026 11:20
Enables fixed-topology Bayesian phylogeography / discrete trait analysis
via HMC/NUTS on a time-tree with tip states given as a two-column CSV.

Adds:
- DiscreteTraitModel(K): time-reversible K-state substitution model
  (Lemey et al. 2009 style), reuses the existing
  EigendecompositionSubstitutionModel infrastructure.
- DiscreteTraitData: CSV loader with explicit/inferred state ordering
  and unknown-state partials, duck-typed to plug into the existing
  LeafCTMC/phylogenetic_likelihood pipeline as a one-site "alignment".
- discrete_trait block in phylo_model.py: n_states is a config int
  and is stripped from the params dict; model construction dispatches
  on DISCRETE_TRAIT_KEY.
- treeflow_dta_hmc CLI: fork of treeflow_hmc that loads traits instead
  of alignments; validates that model.n_states matches observed states
  and that all tree taxa have trait assignments.

Tests: 34 new across model, loader, integration, and CLI.
Adds user-facing CLI guide (cli.treeflow_dta_hmc.rst) with model
specification, traits CSV format, and example invocation; extends
model-definition.md with a discrete_trait section including a
complete K=5 YAML example. Adds API-reference stubs for the new
modules and wires them into existing toctrees.

Also fills a pre-existing gap by adding documentation for the
already-shipped treeflow_hmc CLI and its module.
End-to-end sanity check for treeflow_dta_hmc: simulate a K=4 discrete
trait under a known time-reversible Q on a Yule(1.0, n=400) tree with
LPhy, then fit with treeflow_dta_hmc and verify posterior coverage of
the simulation truth.

Components:
- dta-validation.lphy: LPhy script with fixed π, R, μ truth
- convert.py: LPhy nexus -> taxon/trait CSV + Newick
- model.yaml: DTA model spec matching the simulation
- run.sh: slphy -> convert -> treeflow_dta_hmc -> summarize
- summarize.py: posterior mean / 95% CI / ESS / coverage vs. truth

On the default seed the posterior covers 9/10 parameters; the one miss
(π₁) is a finite-tree-depth artefact — observed state-1 count is well
below the stationary expectation.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant