diff --git a/.Rbuildignore b/.Rbuildignore index 7935e7a9a..ef1edf913 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -32,5 +32,6 @@ vignettes/.RData ^data-raw$ ^pkgdown$ ^revdep$ +^dev$ ^\.positai$ ^\.claude$ diff --git a/NEWS.md b/NEWS.md index 2badabdfe..788ac6460 100644 --- a/NEWS.md +++ b/NEWS.md @@ -10,6 +10,13 @@ - `ReadTntCharacters()` attaches an `xgroup` attribute (factor) when a TNT `xgroup` partition block is present, replacing the stand-alone `ReadXgroup()`. +## Performance + +- `Consensus()` computes majority-rule and threshold consensus trees in time + linear in the number of trees (previously quadratic), after + Jansson, Shen & Sung (2016); implementation informed by their `FACT` package. + `SplitFrequency()` inherits the same single-pass speed-up. + ## Fixes - `NexusTokens()` once again handles polymorphism tokens with internal diff --git a/R/Consensus.R b/R/Consensus.R index d7973c542..31bc81080 100644 --- a/R/Consensus.R +++ b/R/Consensus.R @@ -1,7 +1,17 @@ #' Construct consensus trees #' #' `Consensus()` calculates the consensus of a set of trees, using the -#' algorithm of \insertCite{Day1985}{TreeTools}. +#' cluster-table approach of \insertCite{Day1985}{TreeTools}. +#' +#' The strict consensus (`p = 1`) compares the clusters of the first tree +#' against every other tree in linear time. The majority-rule and threshold +#' consensus (`0.5 <= p < 1`) instead count the frequency of every split across +#' all trees in a single pass and retain those occurring in a proportion `p` or +#' more of trees; this runs in time linear in the number of trees, after +#' \insertCite{Jansson2016}{TreeTools} (implementation informed by the +#' \acronym{FACT} package of Jansson, Shen and Sung). By default the count uses +#' a 128-bit hash, whose results are exact with overwhelming probability; set +#' `exact = TRUE` for a slower but guaranteed-exact count. #' #' @param trees List of trees, optionally of class `multiPhylo`. #' @param p Proportion of trees that must contain a split for it to be reported @@ -9,6 +19,11 @@ #' default) gives the strict consensus. #' @param check.labels Logical specifying whether to check that all trees have #' identical labels. Defaults to `TRUE`, which is slower. +#' @param exact Logical; if `TRUE`, majority/threshold consensus uses a slower +#' but guaranteed-exact split count instead of the default 128-bit hashing +#' (whose results are exact unless a hash collision conflates two distinct +#' splits, which is vanishingly unlikely). Ignored when `p = 1`, which is +#' always exact. #' #' @return `Consensus()` returns an object of class `phylo`, rooted as in the #' first entry of `trees`. @@ -23,7 +38,7 @@ #' @references #' \insertAllCited{} #' @export -Consensus <- function(trees, p = 1, check.labels = TRUE) { +Consensus <- function(trees, p = 1, check.labels = TRUE, exact = FALSE) { if (length(trees) == 1L) { return(trees[[1]]) } @@ -69,7 +84,7 @@ Consensus <- function(trees, p = 1, check.labels = TRUE) { # Return: RootTree(.PreorderTree( - edge = splits_to_edge(consensus_tree(trees, p), nTip), + edge = splits_to_edge(consensus_tree(trees, p, exact = isTRUE(exact)), nTip), tip.label = TipLabels(trees[[1]]) ), root) } diff --git a/R/RcppExports.R b/R/RcppExports.R index ada9065f3..54c14a345 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -25,12 +25,12 @@ as_newick <- function(edge) { .Call(`_TreeTools_as_newick`, edge) } -split_frequencies <- function(trees) { - .Call(`_TreeTools_split_frequencies`, trees) +split_frequencies <- function(trees, exact = FALSE) { + .Call(`_TreeTools_split_frequencies`, trees, exact) } -consensus_tree <- function(trees, p) { - .Call(`_TreeTools_consensus_tree`, trees, p) +consensus_tree <- function(trees, p, exact = FALSE) { + .Call(`_TreeTools_consensus_tree`, trees, p, exact) } descendant_edges <- function(parent, child, postorder) { diff --git a/R/Support.R b/R/Support.R index e080f9cc2..dc6f5d021 100644 --- a/R/Support.R +++ b/R/Support.R @@ -14,6 +14,13 @@ #' or a `Splits` object. See #' [vignette](https://ms609.github.io/TreeTools/articles/load-trees.html) for #' possible methods of loading trees into R. +#' @param exact Logical specifying whether to use the slower but guaranteed +#' exact algorithm when counting the frequencies of _all_ splits (i.e. when +#' `reference = NULL`). The default (`FALSE`) uses a faster hashing approach +#' whose results are exact with overwhelming probability (a 128-bit hash +#' collision, which would conflate two distinct splits, is vanishingly +#' unlikely); set `exact = TRUE` if certainty is required. Ignored when +#' `reference` is a tree or `Splits` object. #' #' @return `SplitFrequency()` returns the number of trees in `forest` that #' contain each split in `reference`. @@ -31,7 +38,7 @@ #' @template MRS #' @family Splits operations #' @export -SplitFrequency <- function(reference, forest = NULL) { +SplitFrequency <- function(reference, forest = NULL, exact = FALSE) { if (is.null(reference) || is.null(forest)) { if (is.null(forest)) forest <- reference if (inherits(forest, "phylo")) forest <- list(forest) @@ -49,7 +56,7 @@ SplitFrequency <- function(reference, forest = NULL) { } forest <- RenumberTips(forest, tipLabels) forest <- Preorder(forest) - result <- split_frequencies(forest) + result <- split_frequencies(forest, exact = isTRUE(exact)) splits <- result[["splits"]] counts <- result[["counts"]] nTip <- length(tipLabels) diff --git a/dev/red-team/bench-majority-scaling.R b/dev/red-team/bench-majority-scaling.R new file mode 100644 index 000000000..3157715c9 --- /dev/null +++ b/dev/red-team/bench-majority-scaling.R @@ -0,0 +1,57 @@ +# Benchmark: majority consensus scaling in k (number of trees) at fixed n. +# The new core scales ~linearly in k; both count modes (hashed default, exact +# opt-in) are timed. +# +# Historical result vs the previous O(k^2 n) core (random trees, n = 100): +# k: 100 200 400 800 1600 +# new (s): 0.014 0.027 0.050 0.100 0.209 (exponent a = 0.97, linear) +# old (s): 0.027 0.095 0.353 1.36 5.36 (exponent a = 1.91, quadratic) +# speed-up: 1.9x 3.6x 7.0x 13.6x 25.6x +# +# Run: Rscript dev/red-team/bench-majority-scaling.R +suppressMessages(devtools::load_all(".", quiet = TRUE)) + +# Random trees: their majority consensus is ~unresolved, so the legacy core +# never hits its perfectly-resolved early exit and pays the full O(k^2 n). +n_tip <- 100L +ks <- c(100, 200, 400, 800, 1600) +reps <- 2L +set.seed(1) +make_forest <- function(k) { + forest <- lapply(seq_len(k), function(i) ape::rtree(n_tip, br = NULL)) + forest <- Preorder(RenumberTips(forest, forest[[1]])) + structure(forest, class = "multiPhylo") +} + +bench1 <- function(fn, forest, p) { + best <- Inf + for (i in seq_len(reps)) { + t0 <- Sys.time() + fn(forest, p) + best <- min(best, as.numeric(Sys.time() - t0, units = "secs")) + } + best +} + +hashed <- function(f, p) TreeTools:::consensus_tree(f, p) +exact <- function(f, p) TreeTools:::consensus_tree(f, p, exact = TRUE) + +cat(sprintf("Majority consensus, fixed n = %d tips, p = 0.5\n", n_tip)) +cat(sprintf("%6s %12s %12s\n", "k", "hashed (s)", "exact (s)")) +res <- data.frame() +for (k in ks) { + forest <- make_forest(k) + th <- bench1(hashed, forest, 0.5) + te <- bench1(exact, forest, 0.5) + cat(sprintf("%6d %12.4f %12.4f\n", k, th, te)) + res <- rbind(res, data.frame(k = k, hashed = th, exact = te)) +} + +# Scaling exponent: slope of log(time) vs log(k). +fit_exp <- function(k, t) { + ok <- t > 0 + unname(coef(lm(log(t[ok]) ~ log(k[ok])))[2]) +} +cat(sprintf("\nEmpirical scaling exponent (time ~ k^a; target ~1.0, linear):\n")) +cat(sprintf(" hashed a = %.2f\n", fit_exp(res$k, res$hashed))) +cat(sprintf(" exact a = %.2f\n", fit_exp(res$k, res$exact))) diff --git a/dev/red-team/verify-consensus.R b/dev/red-team/verify-consensus.R new file mode 100644 index 000000000..76ea701ef --- /dev/null +++ b/dev/red-team/verify-consensus.R @@ -0,0 +1,102 @@ +# Consistency cross-check for the O(kn) consensus / split-frequency machinery on +# adversarial inputs: the hashed (default) and exact split counts must agree, in +# both Consensus() and SplitFrequency(). +# +# Historical note: during development this script also compared the new core +# against a transient legacy O(k^2 n) oracle (`consensus_tree_legacy`); that gate +# passed with 0 failures across all fixtures below before the oracle was removed. +# The independent oracle for ongoing regression is `ape::consensus()`, exercised +# by tests/testthat/test-consensus.R. +# +# Run: Rscript dev/red-team/verify-consensus.R +suppressMessages(devtools::load_all(".", quiet = TRUE)) +set.seed(1) + +# Canonical set of splits from a packed RawMatrix (order-independent). +split_set <- function(m, n_tip) { + if (is.null(m) || nrow(m) == 0) return(character(0)) + out <- character(nrow(m)) + for (r in seq_len(nrow(m))) { + bits <- as.logical(rawToBits(as.raw(m[r, ])))[seq_len(n_tip)] + if (isTRUE(bits[1])) bits <- !bits # tip 1 on the FALSE side + side <- which(bits) + if (length(side) < 2 || length(side) > n_tip - 2) next # skip trivial + out[r] <- paste(side, collapse = ",") + } + sort(unique(out[out != ""])) +} + +fails <- 0L +check <- function(cond, msg) { + if (!isTRUE(cond)) { cat(" FAIL:", msg, "\n"); fails <<- fails + 1L } +} + +ps <- c(0.5, 0.55, 0.6, 2/3, 0.75, 0.8, 0.9, 0.99, 1) + +# ---- Random forests across a range of (n, k) ---------------------------------- +message("== random forests ==") +# k = 1 is handled by R's Consensus() wrapper (returns the single tree), never +# reaching consensus_tree(); the legacy oracle is not robust to it, so skip. +for (n_tip in c(4, 5, 6, 8, 13, 20, 50)) { + for (k in c(2, 3, 5, 7, 8, 15, 32, 60)) { + message(sprintf(" n=%d k=%d", n_tip, k)) + forest <- lapply(seq_len(k), function(i) ape::rtree(n_tip, br = NULL)) + forest <- RenumberTips(forest, forest[[1]]) + forest <- Preorder(forest) + class(forest) <- "multiPhylo" + for (p in ps) { + hashed <- split_set(TreeTools:::consensus_tree(forest, p), n_tip) + exact <- split_set(TreeTools:::consensus_tree(forest, p, exact = TRUE), n_tip) + check(identical(hashed, exact), + sprintf("consensus hashed!=exact n=%d k=%d p=%.3g", n_tip, k, p)) + } + } +} + +# ---- Adversarial structured fixtures ----------------------------------------- +cat("== structured fixtures ==\n") +adversarial <- list( + all_identical = rep(list(BalancedTree(8)), 7), + star_disagree = list(ape::read.tree(text = "((a,b),(c,d));"), + ape::read.tree(text = "((a,c),(b,d));")), + single_split = list(ape::read.tree(text = "((a,b,c,d),(e,f,g));"), + ape::read.tree(text = "((a,b,c,d),(e,f,g));")), + one_off_bal = c(rep(list(BalancedTree(10)), 3), list(PectinateTree(10))), + threshold_tie = c(rep(list(BalancedTree(8)), 2), rep(list(PectinateTree(8)), 2)), + mixed_resoln = list(BalancedTree(8), CollapseNode(BalancedTree(8), 11:12), + PectinateTree(8)) +) +for (nm in names(adversarial)) { + forest <- adversarial[[nm]] + forest <- RenumberTips(forest, forest[[1]]) + forest <- Preorder(forest) + class(forest) <- "multiPhylo" + n_tip <- NTip(forest[[1]]) + for (p in ps) { + hashed <- split_set(TreeTools:::consensus_tree(forest, p), n_tip) + exact <- split_set(TreeTools:::consensus_tree(forest, p, exact = TRUE), n_tip) + check(identical(hashed, exact), sprintf("%s p=%.3g: hashed!=exact", nm, p)) + } +} + +# ---- Hashed vs exact split frequencies (gate 5) ------------------------------ +cat("== hashed vs exact SplitFrequency ==\n") +for (n_tip in c(4, 6, 8, 13, 30)) { + for (k in c(1, 2, 5, 12, 40)) { + forest <- lapply(seq_len(k), function(i) ape::rtree(n_tip, br = NULL)) + class(forest) <- "multiPhylo" + sh <- SplitFrequency(forest, exact = FALSE) + se <- SplitFrequency(forest, exact = TRUE) + # Compare as (split -> count) maps, order-independent + key <- function(s) { + if (length(s) == 0) return(character(0)) + paste(as.character(s), attr(s, "count")) + } + check(setequal(key(sh), key(se)), + sprintf("SplitFrequency n=%d k=%d: hashed != exact", n_tip, k)) + } +} + +cat(sprintf("\n==== %s : %d failures ====\n", + if (fails == 0) "PASS" else "FAIL", fails)) +quit(status = if (fails == 0) 0 else 1) diff --git a/inst/REFERENCES.bib b/inst/REFERENCES.bib index 4a861824b..076db0bac 100644 --- a/inst/REFERENCES.bib +++ b/inst/REFERENCES.bib @@ -376,3 +376,14 @@ @article{Wilkinson1992 pages = {375--385}, doi = {10.1111/j.1096-0031.1992.tb00079.x} } + +@article{Jansson2016, + title = {Improved algorithms for constructing consensus trees}, + author = {Jansson, Jesper and Shen, Chuanqi and Sung, Wing-Kin}, + journal = {Journal of the ACM}, + volume = {63}, + number = {3}, + pages = {28:1--28:24}, + year = {2016}, + doi = {10.1145/2898436} +} diff --git a/inst/WORDLIST b/inst/WORDLIST index 73e0a08b2..f91dcfdb2 100644 --- a/inst/WORDLIST +++ b/inst/WORDLIST @@ -11,6 +11,7 @@ Foulds Guillerme HCL Hennig's +Jansson Klopfstein Krzywinski Lemant @@ -36,6 +37,7 @@ Rdpack Reweight SPR Sackin's +Shen Spasojevic Stemmier Stemwardness diff --git a/man/Consensus.Rd b/man/Consensus.Rd index d29239933..d61d5e82a 100644 --- a/man/Consensus.Rd +++ b/man/Consensus.Rd @@ -4,7 +4,7 @@ \alias{Consensus} \title{Construct consensus trees} \usage{ -Consensus(trees, p = 1, check.labels = TRUE) +Consensus(trees, p = 1, check.labels = TRUE, exact = FALSE) } \arguments{ \item{trees}{List of trees, optionally of class \code{multiPhylo}.} @@ -15,6 +15,12 @@ default) gives the strict consensus.} \item{check.labels}{Logical specifying whether to check that all trees have identical labels. Defaults to \code{TRUE}, which is slower.} + +\item{exact}{Logical; if \code{TRUE}, majority/threshold consensus uses a slower +but guaranteed-exact split count instead of the default 128-bit hashing +(whose results are exact unless a hash collision conflates two distinct +splits, which is vanishingly unlikely). Ignored when \code{p = 1}, which is +always exact.} } \value{ \code{Consensus()} returns an object of class \code{phylo}, rooted as in the @@ -22,7 +28,18 @@ first entry of \code{trees}. } \description{ \code{Consensus()} calculates the consensus of a set of trees, using the -algorithm of \insertCite{Day1985}{TreeTools}. +cluster-table approach of \insertCite{Day1985}{TreeTools}. +} +\details{ +The strict consensus (\code{p = 1}) compares the clusters of the first tree +against every other tree in linear time. The majority-rule and threshold +consensus (\verb{0.5 <= p < 1}) instead count the frequency of every split across +all trees in a single pass and retain those occurring in a proportion \code{p} or +more of trees; this runs in time linear in the number of trees, after +\insertCite{Jansson2016}{TreeTools} (implementation informed by the +\acronym{FACT} package of Jansson, Shen and Sung). By default the count uses +a 128-bit hash, whose results are exact with overwhelming probability; set +\code{exact = TRUE} for a slower but guaranteed-exact count. } \examples{ Consensus(as.phylo(0:2, 8)) diff --git a/man/SplitFrequency.Rd b/man/SplitFrequency.Rd index cb6d72731..ea338681e 100644 --- a/man/SplitFrequency.Rd +++ b/man/SplitFrequency.Rd @@ -4,7 +4,7 @@ \alias{SplitFrequency} \title{Frequency of splits} \usage{ -SplitFrequency(reference, forest = NULL) +SplitFrequency(reference, forest = NULL, exact = FALSE) } \arguments{ \item{reference}{A tree of class \code{phylo}, a \code{Splits} object. If \code{NULL}, @@ -14,6 +14,14 @@ the frequencies of all splits in \code{forest} will be returned.} or a \code{Splits} object. See \href{https://ms609.github.io/TreeTools/articles/load-trees.html}{vignette} for possible methods of loading trees into R.} + +\item{exact}{Logical specifying whether to use the slower but guaranteed +exact algorithm when counting the frequencies of \emph{all} splits (i.e. when +\code{reference = NULL}). The default (\code{FALSE}) uses a faster hashing approach +whose results are exact with overwhelming probability (a 128-bit hash +collision, which would conflate two distinct splits, is vanishingly +unlikely); set \code{exact = TRUE} if certainty is required. Ignored when +\code{reference} is a tree or \code{Splits} object.} } \value{ \code{SplitFrequency()} returns the number of trees in \code{forest} that diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 31ebde5bb..36c2de209 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -86,25 +86,27 @@ BEGIN_RCPP END_RCPP } // split_frequencies -List split_frequencies(const List trees); -RcppExport SEXP _TreeTools_split_frequencies(SEXP treesSEXP) { +List split_frequencies(const List trees, const bool exact); +RcppExport SEXP _TreeTools_split_frequencies(SEXP treesSEXP, SEXP exactSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const List >::type trees(treesSEXP); - rcpp_result_gen = Rcpp::wrap(split_frequencies(trees)); + Rcpp::traits::input_parameter< const bool >::type exact(exactSEXP); + rcpp_result_gen = Rcpp::wrap(split_frequencies(trees, exact)); return rcpp_result_gen; END_RCPP } // consensus_tree -RawMatrix consensus_tree(const List trees, const NumericVector p); -RcppExport SEXP _TreeTools_consensus_tree(SEXP treesSEXP, SEXP pSEXP) { +RawMatrix consensus_tree(const List trees, const NumericVector p, const bool exact); +RcppExport SEXP _TreeTools_consensus_tree(SEXP treesSEXP, SEXP pSEXP, SEXP exactSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< const List >::type trees(treesSEXP); Rcpp::traits::input_parameter< const NumericVector >::type p(pSEXP); - rcpp_result_gen = Rcpp::wrap(consensus_tree(trees, p)); + Rcpp::traits::input_parameter< const bool >::type exact(exactSEXP); + rcpp_result_gen = Rcpp::wrap(consensus_tree(trees, p, exact)); return rcpp_result_gen; END_RCPP } @@ -541,8 +543,8 @@ static const R_CallMethodDef CallEntries[] = { {"_TreeTools_ape_neworder_phylo", (DL_FUNC) &_TreeTools_ape_neworder_phylo, 5}, {"_TreeTools_ape_neworder_pruningwise", (DL_FUNC) &_TreeTools_ape_neworder_pruningwise, 5}, {"_TreeTools_as_newick", (DL_FUNC) &_TreeTools_as_newick, 1}, - {"_TreeTools_split_frequencies", (DL_FUNC) &_TreeTools_split_frequencies, 1}, - {"_TreeTools_consensus_tree", (DL_FUNC) &_TreeTools_consensus_tree, 2}, + {"_TreeTools_split_frequencies", (DL_FUNC) &_TreeTools_split_frequencies, 2}, + {"_TreeTools_consensus_tree", (DL_FUNC) &_TreeTools_consensus_tree, 3}, {"_TreeTools_descendant_edges", (DL_FUNC) &_TreeTools_descendant_edges, 3}, {"_TreeTools_descendant_edges_single", (DL_FUNC) &_TreeTools_descendant_edges_single, 5}, {"_TreeTools_descendant_tips", (DL_FUNC) &_TreeTools_descendant_tips, 3}, diff --git a/src/consensus.cpp b/src/consensus.cpp index 9f2cf3d9e..84bce7122 100644 --- a/src/consensus.cpp +++ b/src/consensus.cpp @@ -8,327 +8,409 @@ using namespace Rcpp; #include /* for fill */ #include /* for array */ #include /* for steady_clock (interrupt timing) */ +#include /* for uint64_t (split hashing) */ #include /* for string (hash key) */ #include /* for unordered_map */ +#include /* for vector */ using TreeTools::ct_stack_threshold; using TreeTools::ct_max_leaves_heap; +using TreeTools::ClusterTable; struct StackEntry { int32 L, R, N, W; }; -// Helper template function to perform consensus computation -// Uses StackContainer for the S array (either std::array or std::vector) +// Throttled (~1 s) user-interrupt check. +inline void throttled_interrupt( + std::chrono::steady_clock::time_point& last) { + const auto now = std::chrono::steady_clock::now(); + if (std::chrono::duration_cast(now - last).count() + >= 1) { + last = now; + Rcpp::checkUserInterrupt(); + } +} + +// --------------------------------------------------------------------------- +// Shared postorder cluster-enumeration primitive. +// +// Walks `tree` in (short) postorder, computing each internal node's cluster +// (leaf count N, weight W) under the leaf encoding of `ref`, and calls +// `visit(L, R, N, j_pos)` for each internal node (j_pos = 1-based index in the +// traversal). `NVERTEX_short` skips the three trivial top vertices (root, +// ingroup, sentinel), so only non-trivial clusters are visited. This single +// implementation backs the consensus count, the exact split-frequency count, +// and (with ref == tree) the all-splits count, so the delicate stack +// arithmetic lives in exactly one place. +template +inline void for_each_internal_node(ClusterTable& ref, ClusterTable& tree, + StackEntry* const S_start, Visit&& visit) { + int32 v = 0, w = 0, L, R, N, W; + tree.TRESET(); + tree.READT(&v, &w); + int32 j_pos = 0; + StackEntry* S_top = S_start; // Empty the stack S + do { + if (tree.is_leaf(v)) { + const auto enc_v = ref.ENCODE(v); + *S_top++ = {enc_v, enc_v, 1, 1}; + } else { + const StackEntry& entry = *--S_top; + L = entry.L; R = entry.R; N = entry.N; + W = 1 + entry.W; + w -= entry.W; + while (w) { + const StackEntry& next = *--S_top; + L = std::min(L, next.L); // Faster than ternary operator + R = std::max(R, next.R); + N += next.N; + W += next.W; + w -= next.W; + } + *S_top++ = {L, R, N, W}; + ++j_pos; + visit(L, R, N, j_pos); + } + tree.NVERTEX_short(&v, &w); + } while (v); +} + +// --------------------------------------------------------------------------- +// Exact single-pass count of every distinct non-trivial split. +// +// One O(k * sum-of-cluster-sizes) = O(k n h) pass: each tree's clusters are +// enumerated under its own encoding (so each is a contiguous L..R), the leaf +// set is packed into a bit pattern, and that exact pattern is the hash-map key. +// Counts accumulate directly (each tree contributes 1 per cluster it holds). +// Deterministic; no collision risk. Backs both `split_frequencies(exact=TRUE)` +// and the majority/threshold `Consensus()`. +template +void count_splits_exact(std::vector& tables, const int32 n_tip, + const int32 nbin, StackContainer& S, + std::vector>& split_patterns, + std::vector& counts) { + const int32 n_trees = int32(tables.size()); + const int32 ntip_3 = n_tip - 3; + std::unordered_map split_map; + split_map.reserve((ntip_3 > 0 ? ntip_3 : 1) * 2); + std::string key(nbin, '\0'); + StackEntry* const S_start = S.data(); + auto lastInterrupt = std::chrono::steady_clock::now(); + + for (int32 t = 0; t < n_trees; ++t) { + throttled_interrupt(lastInterrupt); + ClusterTable& tree = tables[t]; + for_each_internal_node(tree, tree, S_start, + [&tree, &key, &split_map, &split_patterns, &counts] + (int32 L, int32 R, int32 /* N */, int32 /* j_pos */) { + std::fill(key.begin(), key.end(), '\0'); + for (int32 j = L; j <= R; ++j) { + const int32 leaf_idx = tree.DECODE(j) - 1; // 0-based + key[leaf_idx >> 3] |= static_cast(1 << (leaf_idx & 7)); + } + auto it = split_map.find(key); + if (it == split_map.end()) { + split_map.emplace(key, int32(split_patterns.size())); + split_patterns.emplace_back(key.begin(), key.end()); + counts.push_back(1); + } else { + ++counts[it->second]; + } + }); + } +} + +// Forward declaration: the O(kn) hashed counter (defined below) is the default +// for the majority/threshold consensus. Non-dependent name in the template +// below, so it must be declared first. +void count_splits_hashed(std::vector& tables, const int32 n_tip, + const int32 nbin, + std::vector>& split_patterns, + std::vector& counts); + +// --------------------------------------------------------------------------- +// Consensus tree. +// +// Strict (p = 1, thresh == n_trees) keeps its already-optimal single-reference +// path over the first tree. Majority / threshold (0.5 <= p < 1) counts every +// split's frequency in one pass and keeps those reaching the threshold: any two +// such splits each occur in > k/2 trees, so they co-occur in some tree and are +// pairwise (hence globally) compatible, forming a valid tree directly. The +// count is hashed (O(kn), probabilistic) by default, or exact (deterministic, +// O(k.n.height)) when `exact` is set. template RawMatrix calc_consensus_tree( const List& trees, const NumericVector& p, + const bool exact, StackContainer& S ) { - int32 v = 0; - int32 w = 0; - int32 L, R, N, W; - const int32 n_trees = trees.length(); const int32 frac_thresh = int32(n_trees * p[0]) + 1; const int32 thresh = frac_thresh > n_trees ? n_trees : frac_thresh; - - std::vector tables; + + std::vector tables; tables.reserve(n_trees); for (int32 i = 0; i < n_trees; ++i) { - tables.emplace_back(TreeTools::ClusterTable(Rcpp::List(trees(i)))); + tables.emplace_back(ClusterTable(Rcpp::List(trees(i)))); } - + const int32 n_tip = tables[0].N(); const int32 ntip_3 = n_tip - 3; const int32 nbin = (n_tip + 7) / 8; // bytes per row in packed output - - int32* split_count; - std::array split_stack; - std::vector split_heap; - if (n_tip <= ct_stack_threshold) { - split_count = split_stack.data(); - } else { - split_heap.resize(n_tip); - split_count = split_heap.data(); - } StackEntry *const S_start = S.data(); - - // Packed output: each row has nbin bytes RawMatrix ret(ntip_3, nbin); - - int32 i = 0; int32 splits_found = 0; - auto lastInterrupt = std::chrono::steady_clock::now(); - - do { - // ~1 s user-interrupt check - { - const auto now = std::chrono::steady_clock::now(); - if (std::chrono::duration_cast( - now - lastInterrupt).count() >= 1) { - lastInterrupt = now; - Rcpp::checkUserInterrupt(); - } - } - if (tables[i].NOSWX(ntip_3)) { - continue; + + if (thresh >= n_trees) { + // ---- Strict consensus: single reference (tree 0) ---------------------- + int32* split_count; + std::array split_stack; + std::vector split_heap; + if (n_tip <= ct_stack_threshold) { + split_count = split_stack.data(); + } else { + split_heap.resize(n_tip); + split_count = split_heap.data(); } - - std::fill(split_count, split_count + n_tip, 1); - - for (int32 j = i + 1; j < n_trees; ++j) { - ASSERT(tables[i].N() == tables[j].N()); - - tables[i].CLEAR(); - - tables[j].TRESET(); - tables[j].READT(&v, &w); - - int32 j_pos = 0; - StackEntry* S_top = S_start; // Empty the stack S - - do { - if (CT_IS_LEAF(v)) { - const auto enc_v = tables[i].ENCODE(v); - *S_top++ = {enc_v, enc_v, 1, 1}; - } else { - const StackEntry& entry = *--S_top; - L = entry.L; R = entry.R; N = entry.N; - W = 1 + entry.W; - w -= entry.W; - while (w) { - const StackEntry& next = *--S_top; - L = std::min(L, next.L); // Faster than ternary operator - R = std::max(R, next.R); - N += next.N; - W += next.W; - w -= next.W; - } - - *S_top++ = {L, R, N, W}; - - ++j_pos; - - if (!tables[j].GETSWX(&j_pos)) { - if (N == R - L + 1) { // L..R is contiguous, and must be tested - if (tables[i].CLUSTONL(L, R)) { - tables[j].SETSWX(j_pos); - ASSERT(L > 0); - ++split_count[L - 1]; - } else if (tables[i].CLUSTONR(L, R)) { - tables[j].SETSWX(j_pos); - ASSERT(R > 0); - ++split_count[R - 1]; - } + std::fill(split_count, split_count + n_tip, 1); // tree 0 holds its clusters + + auto lastInterrupt = std::chrono::steady_clock::now(); + for (int32 j = 1; j < n_trees; ++j) { + throttled_interrupt(lastInterrupt); + ASSERT(tables[j].N() == n_tip); + for_each_internal_node(tables[0], tables[j], S_start, + [&tables, &split_count](int32 cl_L, int32 cl_R, int32 cl_N, int32) { + if (cl_N == cl_R - cl_L + 1) { // contiguous: testable against ref + if (tables[0].CLUSTONL(cl_L, cl_R)) { + ASSERT(cl_L > 0); + ++split_count[cl_L - 1]; + } else if (tables[0].CLUSTONR(cl_L, cl_R)) { + ASSERT(cl_R > 0); + ++split_count[cl_R - 1]; } } - } - tables[j].NVERTEX_short(&v, &w); - } while (v); + }); } - + // Pack reference clusters present in every tree. for (int32 k = 0; k < n_tip; ++k) { if (split_count[k] >= thresh) { - const int32 start = tables[i].X_left(k + 1); - const int32 end = tables[i].X_right(k + 1); - + const int32 start = tables[0].X_left(k + 1); + const int32 end = tables[0].X_right(k + 1); + if (start == 0 && end == 0) continue; // no cluster at this row for (int32 j = start; j <= end; ++j) { - const int32 leaf_idx = tables[i].DECODE(j) - 1; // 0-based - const int32 byte_idx = leaf_idx >> 3; // column index - const int32 bit_idx = leaf_idx & 7; // bit within byte - - // pointer to the first row of this column + const int32 leaf_idx = tables[0].DECODE(j) - 1; + const int32 byte_idx = leaf_idx >> 3; + const int32 bit_idx = leaf_idx & 7; Rbyte* col_ptr = &ret(0, byte_idx); - col_ptr[splits_found] |= (Rbyte(1) << bit_idx); // set bit in row + col_ptr[splits_found] |= (Rbyte(1) << bit_idx); } - ++splits_found; - // If we have a perfectly resolved tree, exit early. - if (splits_found == ntip_3) { - return ret; + if (splits_found == ntip_3) return ret; + } + } + } else { + // ---- Majority / threshold: count then threshold ----------------------- + std::vector> split_patterns; + std::vector counts; + if (exact) { + count_splits_exact(tables, n_tip, nbin, S, split_patterns, counts); + } else { + count_splits_hashed(tables, n_tip, nbin, split_patterns, counts); + } + + const int32 n_distinct = int32(split_patterns.size()); + for (int32 i = 0; i < n_distinct; ++i) { + if (counts[i] >= thresh) { + for (int32 c = 0; c < nbin; ++c) { + ret(splits_found, c) = split_patterns[i][c]; } + ++splits_found; + if (splits_found == ntip_3) return ret; } } - } while (i++ != n_trees - thresh); // All clades in p% consensus must occur in first q% of trees. - - return (splits_found == 0) ? RawMatrix(0, nbin) : + } + + return (splits_found == 0) ? RawMatrix(0, nbin) : (splits_found < ntip_3) ? ret(Range(0, splits_found - 1), _) : ret; } -// Helper template function to compute split frequencies for all splits -// Like calc_consensus_tree but without threshold or early exit -template -List calc_split_frequencies( - const List& trees, - StackContainer& S -) { - int32 v = 0; - int32 w = 0; - int32 L, R, N, W; - - const int32 n_trees = trees.length(); - - std::vector tables; - tables.reserve(n_trees); - for (int32 i = 0; i < n_trees; ++i) { - tables.emplace_back(TreeTools::ClusterTable(Rcpp::List(trees(i)))); +// --------------------------------------------------------------------------- +// Hashed split frequencies (the fast default for split_frequencies). +// +// A single O(kn) pass: each non-trivial cluster is identified by a 128-bit +// subtree hash = the (order-independent) sum of its leaves' fixed splitmix64 +// hashes, so the same split in different trees hashes identically without an +// O(cluster size) key. Counts accumulate directly; the bit pattern is +// materialised only the first time a split is seen. Exactness is therefore +// probabilistic (a 128-bit collision, ~1e-30, would conflate two splits). +inline uint64_t splitmix64(uint64_t x) { + x += 0x9e3779b97f4a7c15ULL; + x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9ULL; + x = (x ^ (x >> 27)) * 0x94d049bb133111ebULL; + return x ^ (x >> 31); +} + +struct HashEntry { int32 L, R, W; uint64_t lo, hi; }; +struct Hash128 { + uint64_t lo, hi; + bool operator==(const Hash128& o) const noexcept { + return lo == o.lo && hi == o.hi; } - - const int32 n_tip = tables[0].N(); - const int32 ntip_3 = n_tip - 3; - const int32 nbin = (n_tip + 7) / 8; // bytes per row in packed output - - int32* split_count; - std::array split_stack; - std::vector split_heap; +}; +struct Hash128Hasher { + std::size_t operator()(const Hash128& k) const noexcept { + return std::size_t(k.lo ^ (k.hi * 0x9e3779b97f4a7c15ULL)); + } +}; + +void count_splits_hashed(std::vector& tables, const int32 n_tip, + const int32 nbin, + std::vector>& split_patterns, + std::vector& counts) { + const int32 n_trees = int32(tables.size()); + const int32 ntip_3 = n_tip - 3; + + // Fixed per-leaf 128-bit hashes, keyed by original (1-based) leaf id. + std::vector leaf_lo(n_tip + 1), leaf_hi(n_tip + 1); + for (int32 i = 1; i <= n_tip; ++i) { + leaf_lo[i] = splitmix64(uint64_t(i)); + leaf_hi[i] = splitmix64(uint64_t(i) + 0x9e3779b97f4a7c15ULL); + } + + HashEntry* S_start; + std::array hash_stack; + std::vector hash_heap; if (n_tip <= ct_stack_threshold) { - split_count = split_stack.data(); + S_start = hash_stack.data(); } else { - split_heap.resize(n_tip); - split_count = split_heap.data(); + hash_heap.resize(n_tip); + S_start = hash_heap.data(); } - StackEntry *const S_start = S.data(); - - // Hash map for O(1) amortized split deduplication - std::unordered_map split_map; - split_map.reserve(ntip_3 * 2); - std::vector> split_patterns; - std::vector counts; + std::unordered_map split_map; + split_map.reserve((ntip_3 > 0 ? ntip_3 : 1) * 2); - // Reusable key buffer — avoids per-split heap allocation - std::string key(nbin, '\0'); auto lastInterrupt = std::chrono::steady_clock::now(); - for (int32 i = 0; i < n_trees; ++i) { - // ~1 s user-interrupt check - { - const auto now = std::chrono::steady_clock::now(); - if (std::chrono::duration_cast( - now - lastInterrupt).count() >= 1) { - lastInterrupt = now; - Rcpp::checkUserInterrupt(); - } - } - if (tables[i].NOSWX(ntip_3)) { - continue; - } - - std::fill(split_count, split_count + n_tip, 1); - - for (int32 j = i + 1; j < n_trees; ++j) { - ASSERT(tables[i].N() == tables[j].N()); - - tables[i].CLEAR(); - - tables[j].TRESET(); - tables[j].READT(&v, &w); - - int32 j_pos = 0; - StackEntry* S_top = S_start; // Empty the stack S - - do { - if (CT_IS_LEAF(v)) { - const auto enc_v = tables[i].ENCODE(v); - *S_top++ = {enc_v, enc_v, 1, 1}; - } else { - const StackEntry& entry = *--S_top; - L = entry.L; R = entry.R; N = entry.N; - W = 1 + entry.W; - w -= entry.W; - while (w) { - const StackEntry& next = *--S_top; - L = std::min(L, next.L); - R = std::max(R, next.R); - N += next.N; - W += next.W; - w -= next.W; - } - - *S_top++ = {L, R, N, W}; - - ++j_pos; - - if (!tables[j].GETSWX(&j_pos)) { - if (N == R - L + 1) { - if (tables[i].CLUSTONL(L, R)) { - tables[j].SETSWX(j_pos); - ASSERT(L > 0); - ++split_count[L - 1]; - } else if (tables[i].CLUSTONR(L, R)) { - tables[j].SETSWX(j_pos); - ASSERT(R > 0); - ++split_count[R - 1]; - } - } + for (int32 t = 0; t < n_trees; ++t) { + throttled_interrupt(lastInterrupt); + ClusterTable& tree = tables[t]; + int32 v = 0, w = 0, L, R, W; + tree.TRESET(); + tree.READT(&v, &w); + HashEntry* S_top = S_start; + do { + if (tree.is_leaf(v)) { + const int32 enc_v = tree.ENCODE(v); + *S_top++ = {enc_v, enc_v, 1, leaf_lo[v], leaf_hi[v]}; + } else { + const HashEntry& entry = *--S_top; + L = entry.L; R = entry.R; + W = 1 + entry.W; + uint64_t hlo = entry.lo, hhi = entry.hi; + w -= entry.W; + while (w) { + const HashEntry& next = *--S_top; + L = std::min(L, next.L); + R = std::max(R, next.R); + W += next.W; + hlo += next.lo; + hhi += next.hi; + w -= next.W; + } + *S_top++ = {L, R, W, hlo, hhi}; + + const Hash128 hkey{hlo, hhi}; + auto it = split_map.find(hkey); + if (it == split_map.end()) { + split_map.emplace(hkey, int32(split_patterns.size())); + std::vector pattern(nbin, 0); + for (int32 j = L; j <= R; ++j) { + const int32 leaf_idx = tree.DECODE(j) - 1; + pattern[leaf_idx >> 3] |= (Rbyte(1) << (leaf_idx & 7)); } + split_patterns.push_back(std::move(pattern)); + counts.push_back(1); + } else { + ++counts[it->second]; } - tables[j].NVERTEX_short(&v, &w); - } while (v); - } - - for (int32 k = 0; k < n_tip; ++k) { - const int32 start = tables[i].X_left(k + 1); - const int32 end = tables[i].X_right(k + 1); - if (start == 0 && end == 0) continue; // No valid cluster at this position - - // Build the bit pattern into the reusable key buffer - std::fill(key.begin(), key.end(), '\0'); - for (int32 j = start; j <= end; ++j) { - const int32 leaf_idx = tables[i].DECODE(j) - 1; // 0-based - const int32 byte_idx = leaf_idx >> 3; - const int32 bit_idx = leaf_idx & 7; - key[byte_idx] |= static_cast(1 << bit_idx); - } - - auto it = split_map.find(key); - if (it == split_map.end()) { - // New split: record it with count from this reference tree - const int32 idx = split_patterns.size(); - split_map.emplace(key, idx); - split_patterns.emplace_back(key.begin(), key.end()); - counts.push_back(split_count[k]); } - // If already found, the first reference tree that found it has the - // correct total count (it compared against all later trees). - } + tree.NVERTEX_short(&v, &w); + } while (v); } - - const int32 splits_found = split_patterns.size(); +} + +// Assemble a split-frequency result List from collected patterns + counts. +inline List frequencies_list( + const std::vector>& split_patterns, + const std::vector& counts, const int32 nbin) { + const int32 splits_found = int32(split_patterns.size()); RawMatrix ret(splits_found, nbin); - for (int32 r = 0; r < splits_found; ++r) { for (int32 c = 0; c < nbin; ++c) { ret(r, c) = split_patterns[r][c]; } } - IntegerVector count_vec(counts.begin(), counts.end()); - - return List::create( - Named("splits") = ret, - Named("counts") = count_vec - ); + return List::create(Named("splits") = ret, Named("counts") = count_vec); +} + +List calc_split_frequencies_hashed(const List& trees, const int32 n_tip) { + const int32 n_trees = trees.length(); + std::vector tables; + tables.reserve(n_trees); + for (int32 i = 0; i < n_trees; ++i) { + tables.emplace_back(ClusterTable(Rcpp::List(trees(i)))); + } + const int32 nbin = (n_tip + 7) / 8; + std::vector> split_patterns; + std::vector counts; + count_splits_hashed(tables, n_tip, nbin, split_patterns, counts); + return frequencies_list(split_patterns, counts, nbin); } +// Exact split frequencies: the same single-pass count, returned in full. +template +List calc_split_frequencies_exact(const List& trees, StackContainer& S) { + const int32 n_trees = trees.length(); + std::vector tables; + tables.reserve(n_trees); + for (int32 i = 0; i < n_trees; ++i) { + tables.emplace_back(ClusterTable(Rcpp::List(trees(i)))); + } + const int32 n_tip = tables[0].N(); + const int32 nbin = (n_tip + 7) / 8; + + std::vector> split_patterns; + std::vector counts; + count_splits_exact(tables, n_tip, nbin, S, split_patterns, counts); + return frequencies_list(split_patterns, counts, nbin); +} + +// --------------------------------------------------------------------------- +// Exports + // [[Rcpp::export]] -List split_frequencies(const List trees) { +List split_frequencies(const List trees, const bool exact = false) { try { - TreeTools::ClusterTable temp_table(Rcpp::List(trees(0))); + ClusterTable temp_table(Rcpp::List(trees(0))); const int32 n_tip = temp_table.N(); - + + if (!exact) { + return calc_split_frequencies_hashed(trees, n_tip); + } if (n_tip <= ct_stack_threshold) { std::array S; - return calc_split_frequencies(trees, S); + return calc_split_frequencies_exact(trees, S); } else { std::vector S(n_tip); - return calc_split_frequencies(trees, S); + return calc_split_frequencies_exact(trees, S); } } catch(const std::exception& e) { Rcpp::stop(e.what()); } - + ASSERT(false && "Unreachable code in split_frequencies"); return List(); } @@ -339,26 +421,23 @@ List split_frequencies(const List trees) { // Further investigation could be beneficial; for now, suggest applying // the function to preorder trees only. // [[Rcpp::export]] -RawMatrix consensus_tree(const List trees, const NumericVector p) { - // First, peek at the tree size to determine allocation strategy - // We'll create a temporary ClusterTable just to check the size +RawMatrix consensus_tree(const List trees, const NumericVector p, + const bool exact = false) { try { - TreeTools::ClusterTable temp_table(Rcpp::List(trees(0))); + ClusterTable temp_table(Rcpp::List(trees(0))); const int32 n_tip = temp_table.N(); - + if (n_tip <= ct_stack_threshold) { - // Small tree: use stack-allocated array std::array S; - return calc_consensus_tree(trees, p, S); + return calc_consensus_tree(trees, p, exact, S); } else { - // Large tree: use heap-allocated vector std::vector S(n_tip); - return calc_consensus_tree(trees, p, S); + return calc_consensus_tree(trees, p, exact, S); } } catch(const std::exception& e) { Rcpp::stop(e.what()); } - + ASSERT(false && "Unreachable code in consensus_tree"); return RawMatrix(0, 0); } diff --git a/tests/testthat/test-Support.R b/tests/testthat/test-Support.R index 2fa0e3421..ccb8204ea 100644 --- a/tests/testthat/test-Support.R +++ b/tests/testthat/test-Support.R @@ -50,6 +50,20 @@ test_that("Node supports calculated correctly", { structure(as.Splits(PectinateTree(8200)), count = rep(2, 8197))) }) +test_that("SplitFrequency() exact and hashed counts agree", { + set.seed(2) + freqKey <- function(sf) sort(paste(as.character(sf), attr(sf, "count"))) + forests <- list( + c(PectinateTree(16), PectinateTree(16)), + lapply(1:10, function(i) RandomTree(13, root = TRUE)), + lapply(1:30, function(i) RandomTree(8, root = TRUE)) + ) + for (f in forests) { + expect_equal(freqKey(SplitFrequency(f)), + freqKey(SplitFrequency(f, exact = TRUE))) + } +}) + test_that("Node support colours consistent", { expect_equal(SupportColour(NA), "red") expect_equal(SupportColour(1:2, show1 = FALSE), c("#ffffff00", "red")) diff --git a/tests/testthat/test-consensus.R b/tests/testthat/test-consensus.R index 169115c76..be65bf93f 100644 --- a/tests/testthat/test-consensus.R +++ b/tests/testthat/test-consensus.R @@ -80,6 +80,28 @@ test_that("Consensus() handles large sets of trees", { )) }) +test_that("Consensus() exact and hashed counts agree", { + # The hashed (default) and exact (opt-in) split counts must yield identical + # consensus trees; this also guards the shared counting core. + skip_if_not_installed("ape") + set.seed(1) + forests <- list( + balPec = list(BalancedTree(8), PectinateTree(8))[c(1, 1, 1, 1, 2, 2, 2)], + starlike = list(ape::read.tree(text = "((a, b), (c, d));"), + ape::read.tree(text = "((a, c), (b, d));")), + tie = c(rep(list(BalancedTree(8)), 2L), rep(list(PectinateTree(8)), 2L)), + rand12 = lapply(1:7, function(i) ape::rtree(12, br = NULL)), + rand9 = lapply(1:20, function(i) ape::rtree(9, br = NULL)) + ) + for (f in forests) { + for (p in c(0.5, 2 / 3, 1)) { + hashed <- Consensus(f, p = p) + exact <- Consensus(f, p = p, exact = TRUE) + expect_true(isTRUE(all.equal(RootTree(hashed, 1), RootTree(exact, 1)))) + } + } +}) + test_that("Consensus() handles non-preorder trees", { trees <- ape::read.nexus(test_path("testdata", "nonPreCons.nex")) expect_equal(Consensus(trees)$Nnode, 3)