Deep (neural networks for) Phylogenetics Via Traversals
This package can easily be installed with mamba (or conda). First, create the
dpvt conda environment:
    mamba env create -f environment.yml

To install the package locally, run the following in the root folder (the one containing the pyproject.toml file):

    pip install -e .

Currently, datasets are stored in the dpvt-experiments-1 repository.
We currently assume that training/testing/validation data are pickled as one
dictionary whose keys are trees and whose values are lists of labels indicating
whether an edge is in an MP (maximum parsimony) tree (label 0) or not (label 1).
These labels are sorted according to a pre-order traversal. Our
training/validation split is 0.8/0.2, and when splitting the data we ensure that
the training and validation sets are balanced, i.e. that both contain roughly
the same ratio of MP to non-MP data.
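As an illustration of the balanced 0.8/0.2 split described above, here is a minimal sketch. It is not the package's own splitting code: the helper names `is_mp_tree` and `balanced_split` are hypothetical, string names stand in for the tree keys, and it assumes that a tree counts as "MP" when all of its edge labels are 0.

```python
import random


def is_mp_tree(labels):
    """Assumption: a tree counts as 'MP' when every edge label is 0."""
    return all(label == 0 for label in labels)


def balanced_split(data, train_fraction=0.8, seed=0):
    """Split {tree: labels} into train/validation dicts, stratified by is_mp_tree."""
    rng = random.Random(seed)
    train, val = {}, {}
    for want_mp in (True, False):
        group = [t for t, labels in data.items() if is_mp_tree(labels) == want_mp]
        rng.shuffle(group)
        cut = round(train_fraction * len(group))
        for tree in group[:cut]:
            train[tree] = data[tree]
        for tree in group[cut:]:
            val[tree] = data[tree]
    return train, val


# Toy data: string names stand in for the tree keys of the pickled dictionary.
toy = {f"tree_{i}": ([0, 0, 0] if i % 2 == 0 else [0, 1, 0]) for i in range(20)}
train, val = balanced_split(toy)
```

Splitting each group separately keeps the MP to non-MP ratio the same in both sets, up to rounding.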
We implement two different classes for datasets in wrapper.py:
TraversalDataset and TreeDataset. By default we use the TraversalDataset;
this is set in the dpvt-experiments-1 repo.
A TraversalDataset holds the following tensors:
- traversal: triples (node1, node2, node3) for all nodes node3 that are below an internal edge, for every tree in the input. node1 and node2 are the inputs to the RNN that predicts the feature of node3. For the upward traversal this means that node1 and node2 are the children of node3. Nodes are indexed by preorder traversal.
  - dimension: (num_trees, 2, num_int_edges, 3) (the 2 is for upward and downward traversal)
- mutations: contains, for each tree, node, and site, a tensor $(m_A,m_G,m_C,m_T)$, where $m_i=1$ and $m_j=-1$ if there is a mutation from base $j$ to base $i$ at this node; all other entries are $0$.
  - dimension: (num_trees, num_nodes, num_sites, 4)
- labels: for each tree and each node, indicates whether the edge above this node is in an MP tree (0) or not (1).
  - dimension: (num_trees, num_nodes)
- masks: for each tree and each node, indicates whether the edge above this node is an internal edge (True) or not (False).
  - dimension: (num_trees, num_nodes)
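To make the mutation encoding concrete, the following sketch builds the per-site vectors for a single edge. It uses plain Python lists in place of tensors, the function name `encode_mutations` is hypothetical, and it assumes that the mutation "at this node" is measured relative to the node's parent and that sequences contain only the four bases.

```python
BASES = "AGCT"  # order taken from the (m_A, m_G, m_C, m_T) description above


def encode_mutations(parent_seq, child_seq):
    """Return one 4-vector per site: +1 at the new base, -1 at the old base, else 0."""
    rows = []
    for old, new in zip(parent_seq, child_seq):
        row = [0, 0, 0, 0]
        if old != new:
            row[BASES.index(new)] = 1   # m_i = 1 for the base mutated to
            row[BASES.index(old)] = -1  # m_j = -1 for the base mutated from
        rows.append(row)
    return rows


# A single G -> T mutation at the third site.
encoded = encode_mutations("ACGT", "ACTT")
```

Stacking such per-edge results over all nodes and trees would give the (num_trees, num_nodes, num_sites, 4) layout listed above.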
Note that if input trees have different numbers of taxa and/or the sequences on
leaves have different lengths in different trees, traversal, mutations, and
labels are padded with -1, and masks are padded with False.
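The padding scheme can be sketched as follows for labels and masks. This is an illustration with plain lists and a hypothetical helper `pad_to`, not the code in wrapper.py; the toy values are made up.

```python
def pad_to(row, length, fill):
    """Right-pad a list with `fill` up to `length` entries."""
    return row + [fill] * (length - len(row))


# Two toy trees with 3 and 5 nodes respectively.
labels = [[0, 1, 0], [0, 0, 1, 0, 0]]
masks = [[False, True, False], [False, True, True, False, False]]

max_nodes = max(len(row) for row in labels)
padded_labels = [pad_to(row, max_nodes, -1) for row in labels]
padded_masks = [pad_to(row, max_nodes, False) for row in masks]
```

After padding, every row has the same length, so the rows can be stacked into one rectangular tensor.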
When iterating through the TraversalDataset in the forward function, we stop as
soon as we see two -1 in the traversal tensor, as this means that we have
reached the padding. Because the masks are set to False for all padded entries,
none of the -1 values are used when calculating the loss.
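The two mechanisms described above, stopping at the padding and masking the loss, can be sketched like this. The function names are hypothetical, the exact stopping test (a triple with at least two -1 entries) is an assumption about what "two -1" means, and binary cross-entropy is assumed as the loss since the source does not name one.

```python
import math


def real_triples(traversal):
    """Yield (node1, node2, node3) until the first padded triple is reached.

    Assumption: a triple counts as padding when at least two entries are -1.
    """
    for triple in traversal:
        if list(triple).count(-1) >= 2:
            break
        yield tuple(triple)


def masked_bce(probs, labels, mask):
    """Mean binary cross-entropy over the entries where mask is True."""
    terms = [
        -(y * math.log(p) + (1 - y) * math.log(1 - p))
        for p, y, keep in zip(probs, labels, mask)
        if keep
    ]
    return sum(terms) / len(terms) if terms else 0.0


traversal = [(3, 4, 2), (1, 2, 0), (-1, -1, -1), (-1, -1, -1)]
visited = list(real_triples(traversal))

probs = [0.5, 0.9, 0.2, 0.5, 0.5]
labels = [0, 1, 0, -1, -1]            # trailing -1 entries are padding
mask = [False, True, True, False, False]
loss = masked_bce(probs, labels, mask)
```

Only the second and third entries contribute to `loss`; the padded labels never enter the computation.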
A TreeDataset holds the following attributes:
- data: ete3 trees with the node attribute sequence, which is needed to assess which mutations occur on each edge.
  - dimension: num_trees
- labels: for each tree and each node, indicates whether the edge above this node is in an MP tree (0) or not (1).
  - dimension: (num_trees, num_nodes)
- masks: for each tree and each node, indicates whether the edge above this node is an internal edge (True) or not (False).
  - dimension: (num_trees, num_nodes)
We define a PyTorch module TraverseNN which evaluates whether edges in a given
labeled tree appear in a maximum parsimony tree, for the given sequences on the
leaf nodes. This module is defined in dpvt/models.py.
In the following we describe how the models work for the two different data
structures outlined above. Though the description of the models is slightly
different for the two datasets, the two versions do exactly the same thing.
The advantage of the TraversalDataset is, however, that it uses torch.tensors
only and can therefore be run on GPUs.
For the TraversalDataset:
- Traversal step: traverse the tree by iterating through the traversal tensor to learn features for each node, which are saved in the learned_features tensor. Due to the setup, iterating through traversal automatically applies first the upward and then the downward traversal of the tree. When we are at an element (node1, node2, node3) of the tensor, we input the parts of the mutations and learned_features tensors corresponding to node1 and node2 into our RNN to learn the learned_features of node3. For each node triple we iterate over all sites of the alignment, so the feature for node3 is learned separately for each site of the sequences.
- Site-aggregation step: we apply a transformer to combine the learned_features over all sites. The learned_features are the input and the output is a tensor of the same size. We then average the output over all sites to obtain the final feature for each node.
- Final output step: the output from the previous step is passed through a linear layer, the classifier attribute, to produce a tensor in logit space of dimension (n_nodes). Then a sigmoid function is applied. At entry $i$, values near 0.0 mean the $i$-th edge is in a maximum parsimony tree, while values near 1.0 mean the $i$-th edge is not in a maximum parsimony tree. The output values are arranged to correspond to edges in preorder traversal order.
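The control flow of the traversal step, followed by averaging over sites, can be sketched as below. This only shows the order in which features are filled in: `combine` is a hypothetical stand-in for the RNN, one scalar per node and site stands in for the feature vector, the transformer and the classifier are omitted, and the toy triples are hand-written rather than produced by the package's preprocessing.

```python
def combine(mut1, feat1, mut2, feat2):
    """Hypothetical stand-in for the RNN: any function of the two inputs."""
    return 0.5 * (feat1 + feat2) + sum(abs(m) for m in mut1 + mut2)


def traverse(traversal, mutations, num_nodes, num_sites):
    """Fill learned[node][site] by visiting the triples in the given order."""
    learned = [[0.0] * num_sites for _ in range(num_nodes)]
    for node1, node2, node3 in traversal:
        for site in range(num_sites):  # each site is processed separately
            learned[node3][site] = combine(
                mutations[node1][site], learned[node1][site],
                mutations[node2][site], learned[node2][site],
            )
    return learned


def site_average(learned):
    """Average the per-site features of every node."""
    return [sum(row) / len(row) for row in learned]


# Toy input: 5 nodes, 2 sites, one mutation on node 2 at site 0.
mutations = [[[0, 0, 0, 0] for _ in range(2)] for _ in range(5)]
mutations[2][0] = [1, -1, 0, 0]
toy_traversal = [(2, 3, 1), (1, 3, 2)]  # one upward, then one downward triple

learned = traverse(toy_traversal, mutations, num_nodes=5, num_sites=2)
node_features = site_average(learned)
```

Because the upward triple comes first, the feature of node 1 is already available when the downward triple uses it as an input.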
For the TreeDataset:
- Edge mutation annotation: at each node, we assign an edge_mutation attribute which encodes the difference between the node's sequence and its parent's sequence. The edge_mutation attribute is a PyTorch tensor of dimension (n_sites, 4). A mutation A -> T from parent to child is encoded as [..., [-1, 1, 0, 0], ...].
- Traversal step: we apply two traversals to the tree, combining mutation data across the tree. This step applies to each site separately.
  - Post-order traversal: we first traverse the tree root-ward, where at each step we assign a node the attribute node.to_parent["clade_mutation"], a tensor of dimension (n_sites, 4), which is the output of a single-hidden-layer neural network, stored in the class attribute traverse_stack. As input, the neural network takes the edge_mutation and clade_mutation features of the node's two children, and applies symmetrization so that the order of the children does not matter. For leaf nodes, which have no children, node.to_parent["clade_mutation"] is initialized to a zero tensor.
  - Pre-order traversal: we then traverse the tree leaf-ward, where at each step we assign a node the attribute node.from_parent["clade_mutation"], a tensor of dimension (n_sites, 4), which is the output of the single-hidden-layer neural network traverse_stack. As input, the neural network takes the node.from_parent["edge_mutation"] and node.from_parent["clade_mutation"] tensors of the parent node and the node.to_parent["edge_mutation"] and node.to_parent["clade_mutation"] tensors of the sister node.
- Site-aggregation step: we apply a transformer encoder to combine the clade mutation data across sites. This uses the encoder attribute. As input, we concatenate the tensors node.to_parent["clade_mutation"] and node.from_parent["clade_mutation"] to form a tensor of dimension (n_sites, 8). This is passed through the transformer encoder, and the first row, a size-8 tensor, is kept. These tensors from each node are stacked together in preorder-traversal order, forming a tensor of dimension (n_nodes, 8).
- Final output step: the output from the previous step is passed through a linear layer, the classifier attribute, to produce a tensor in logit space of dimension (n_nodes). Then a sigmoid function is applied. At entry $i$, values near 0.0 mean the $i$-th edge is in a maximum parsimony tree, while values near 1.0 mean the $i$-th edge is not in a maximum parsimony tree. The output values are arranged to correspond to edges in preorder traversal order.
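The symmetrization mentioned in the post-order traversal can be achieved in several ways, and the source does not say which one is used. One common option, sketched here under that caveat, is to evaluate an order-sensitive function on both child orders and average the results. The function `f` is a hypothetical stand-in for traverse_stack, and plain lists stand in for tensors.

```python
def f(x, y):
    """Hypothetical order-sensitive stand-in for the traverse_stack network."""
    return [2.0 * a - b for a, b in zip(x, y)]


def symmetrized(x, y):
    """Average f over both child orders so that swapping children changes nothing."""
    return [0.5 * (u + v) for u, v in zip(f(x, y), f(y, x))]


left = [1.0, 0.0, -1.0, 0.0]    # toy feature vector of one child
right = [0.0, 0.0, 1.0, -1.0]   # toy feature vector of the other child

plain_lr, plain_rl = f(left, right), f(right, left)
sym_lr, sym_rl = symmetrized(left, right), symmetrized(right, left)
```

Here `plain_lr` and `plain_rl` differ, while `sym_lr` and `sym_rl` are identical, which is the property the traversal needs.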
A further PyTorch module inherits from TraverseNN and changes the order of the
steps described for that module: it first aggregates per-site information at
every node of a tree (the site-aggregation step, step 2) and then uses the
learned features for the tree traversal (the traversal step, step 1).
Another model is very similar to TraverseNN, but replaces the site-aggregation
step (step 2) with a simpler aggregation method. It aggregates sites by simply
outputting as feature the maximum of learned_features[i] over all sites.
A further model is likewise very similar to TraverseNN, but replaces the
site-aggregation step (step 2) with a simpler aggregation method. It aggregates
sites by simply outputting as feature the average of learned_features[i] over
all sites.
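The maximum and average aggregation methods described above can be sketched for a single node as follows. The function names are hypothetical, plain lists stand in for tensors, and the reduction is assumed to be taken elementwise over the feature dimension.

```python
def aggregate_max(site_features):
    """Elementwise maximum over sites; site_features is a list of per-site vectors."""
    return [max(column) for column in zip(*site_features)]


def aggregate_mean(site_features):
    """Elementwise average over sites."""
    return [sum(column) / len(column) for column in zip(*site_features)]


# Toy learned features of one node: 3 sites, 2 feature entries per site.
features_for_node = [[1.0, 4.0], [3.0, 0.0], [2.0, 2.0]]

max_feature = aggregate_max(features_for_node)
mean_feature = aggregate_mean(features_for_node)
```

Both reductions remove the site dimension without any trainable parameters, in contrast to the transformer used in TraverseNN.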
As a baseline against which to compare our neural network model, we have
implemented the BaselineReversion model. For every edge in the given tree, this
model checks whether at any site there is a mutation back to a state that
appeared at this site at an ancestor of this edge. If so, we label the edge as
a non-MP edge. This model does not require any training.
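The reversion check can be sketched on a toy tree as follows. This is an illustration rather than the BaselineReversion implementation: the tree is given as a hypothetical parent map instead of an ete3 tree, and sequences are assumed to be available at every node, including internal ones.

```python
def reversion_labels(parent, sequences):
    """Label each non-root node 1 if the edge above it reverts a site to an ancestral state.

    parent: dict mapping node -> parent node (the root maps to None)
    sequences: dict mapping node -> sequence string (all of equal length)
    """
    labels = {}
    for node, par in parent.items():
        if par is None:
            continue  # the root has no edge above it
        label = 0
        for site, (old, new) in enumerate(zip(sequences[par], sequences[node])):
            if old == new:
                continue  # no mutation on this edge at this site
            ancestor = parent[par]
            while ancestor is not None and label == 0:
                if sequences[ancestor][site] == new:
                    label = 1  # mutation back to a state seen higher up
                ancestor = parent[ancestor]
            if label:
                break
        labels[node] = label
    return labels


parent = {"root": None, "n1": "root", "n2": "n1", "n3": "n1", "n4": "n1"}
sequences = {"root": "A", "n1": "C", "n2": "A", "n3": "C", "n4": "G"}
labels = reversion_labels(parent, sequences)
```

In this toy tree only the edge above n2 is labeled 1, because it mutates C back to A, the state at the root.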
To view training logs, run tensorboard --logdir . and point your browser to
http://localhost:6006/. TensorBoard additionally shows ROC curves for the
classification performance on the test set.
- models.py: contains definitions of models.
- wrapper.py: contains wrappers for models and datasets.
- dpvtex: contains dpvt_data.py, which implements functions to get datasets for a given nickname, and dpvt_zoo.py, which creates models for a given nickname. These nicknames are provided to the Snakefile in config.yaml. It furthermore contains scripts to generate training and testing data. More details can be found in the README of the dpvt-experiments-1 repo.
- train: contains Snakefile and config.yaml, in which models and datasets for training are specified.