#
# MIT License
#
# Copyright (c) 2020-2024 Tskit Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Module for ranking and unranking trees. Trees are considered only
leaf-labelled and unordered, so order of children does not influence equality.
"""
import collections
import functools
import heapq
import itertools
import json
import random
from typing import NamedTuple
import attr
import numpy as np
import tskit
[docs]
class Rank(NamedTuple):
    """
    A tuple of 2 numbers, ``(shape, label)``, together defining a unique
    topology for a labeled tree. See :ref:`sec_combinatorics`.

    Being a ``NamedTuple``, instances compare equal to, and unpack like, the
    plain tuple ``(shape, label)``.
    """

    shape: int
    """
    A non-negative integer representing the (unlabelled) topology of a tree with a
    defined number of tips.
    """
    label: int
    """
    A non-negative integer representing the order of labels for a given tree topology.
    """
def equal_chunks(lst, k):
    """
    Yield k successive equally sized chunks from lst of size n.

    If k >= n, we return n chunks of size 1.

    Otherwise, we always return k chunks. The first k - 1 chunks will
    contain exactly n // k items, and the last chunk the remainder.
    Nothing is yielded when ``lst`` is empty.

    :param lst: A sliceable sequence supporting ``len`` (e.g. list or range).
    :param k: The number of chunks; a positive integer, or a number with an
        integral value such as ``2.0``.
    :raises ValueError: If ``k`` is not a positive integer value. As this is
        a generator, the error is raised when iteration starts.
    """
    n = len(lst)
    if k <= 0 or int(k) != k:
        raise ValueError("Number of chunks must be a positive integer")
    # The check above accepts integral floats such as 2.0; normalise to int
    # so that the chunk size is usable as a slice bound.
    k = int(k)
    if n > 0:
        chunk_size = max(1, n // k)
        offset = 0
        j = 0
        while offset < n - chunk_size and j < k - 1:
            yield lst[offset : offset + chunk_size]
            offset += chunk_size
            j += 1
        # Whatever is left over forms the final chunk.
        yield lst[offset:]
@attr.s(eq=False)
class TreeNode:
    """
    Minimal linked-node tree used while generating tree topologies.

    Each node holds a reference to its ``parent`` (``None`` at the root),
    the list of its ``children`` and a ``label``.
    """

    parent = attr.ib(default=None)
    children = attr.ib(factory=list)
    label = attr.ib(default=None)

    def as_tables(self, *, num_leaves, span, branch_length):
        """
        Build a TableCollection equivalent to the tree rooted at this node.

        ``num_leaves`` sample nodes at time 0 are added first. Internal nodes
        are then allocated in postorder, each placed ``branch_length`` above
        its oldest child, and their ``label`` is overwritten with the newly
        allocated node ID. Leaf labels are used directly as node IDs.
        """
        tables = tskit.TableCollection(span)
        for _ in range(num_leaves):
            tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)

        def visit(node):
            # Postorder: children are fully processed before their parent.
            if not node.children:
                node.time = 0
                return
            oldest_child_time = 0
            for child in node.children:
                visit(child)
                oldest_child_time = max(oldest_child_time, child.time)
            node.time = oldest_child_time + branch_length
            node.label = tables.nodes.add_row(time=node.time)
            for child in node.children:
                tables.edges.add_row(0, span, node.label, child.label)

        visit(self)
        tables.sort()
        return tables

    @staticmethod
    def random_binary_tree(leaf_labels, rng):
        """
        Return the root of a random binary tree whose leaves carry the
        given labels, drawing randomness from the supplied random.Random
        instance.

        Based on the description of Remy's method of generating "decorated"
        random binary trees in TAOCP 7.2.1.6. This is not a direct
        implementation of Algorithm R, because we are interested in
        the leaf node labellings.

        The pre-fascicle text is available here, page 16:
        http://www.cs.utsa.edu/~wagner/knuth/fasc4a.pdf
        """
        pool = [TreeNode(label=leaf_labels[0])]
        for leaf_label in leaf_labels[1:]:
            # Pick an existing node at random and splice a new internal node
            # above it, with the next labelled leaf as its sibling.
            target = rng.choice(pool)
            leaf = TreeNode(label=leaf_label)
            above = target.parent
            joint = TreeNode(parent=above, children=[target, leaf])
            if above is not None:
                siblings = above.children
                siblings[siblings.index(target)] = joint
            rng.shuffle(joint.children)
            target.parent = joint
            leaf.parent = joint
            pool.extend([leaf, joint])

        root = pool[0]
        while root.parent is not None:
            root = root.parent

        # Canonicalise the order of the children within a node. This
        # is given by (num_leaves, min_label). See also the
        # RankTree.canonical_order function for the definition of
        # how these are ordered during rank/unrank.
        def canonicalise(node):
            if not node.children:
                return 1, node.label
            keys = [canonicalise(child) for child in node.children]
            if keys[0] > keys[1]:
                node.children = list(reversed(node.children))
            total_leaves = sum(key[0] for key in keys)
            smallest_label = min(key[1] for key in keys)
            return total_leaves, smallest_label

        canonicalise(root)
        return root

    @classmethod
    def balanced_tree(cls, leaf_labels, arity):
        """
        Return the root of a balanced tree of the specified arity. At each
        node the leaf labels are divided among the children using
        equal_chunks, recursing until a single label remains.
        """
        assert len(leaf_labels) > 0
        if len(leaf_labels) == 1:
            return cls(label=leaf_labels[0])
        subtrees = [
            cls.balanced_tree(chunk, arity)
            for chunk in equal_chunks(leaf_labels, arity)
        ]
        root = cls(children=subtrees)
        for subtree in subtrees:
            subtree.parent = root
        return root
def generate_star(num_leaves, *, span, branch_length, record_provenance, **kwargs):
    """
    Build a star tree: ``num_leaves`` sample leaves at time 0 all attached
    to a single root at time ``branch_length``. Extra keyword arguments are
    passed on when extracting the first tree of the new tree sequence.

    See the documentation for :meth:`Tree.generate_star` for more details.

    :raises ValueError: If ``num_leaves`` is less than 2.
    """
    if num_leaves < 2:
        raise ValueError("The number of leaves must be 2 or greater")

    tables = tskit.TableCollection(sequence_length=span)
    sample_flags = np.full(num_leaves, tskit.NODE_IS_SAMPLE, dtype=np.uint32)
    leaf_times = np.zeros(num_leaves)
    tables.nodes.set_columns(flags=sample_flags, time=leaf_times)
    root = tables.nodes.add_row(time=branch_length)

    # One edge per leaf, all spanning [0, span) and sharing the root as parent.
    edge_left = np.full(num_leaves, 0)
    edge_right = np.full(num_leaves, span)
    edge_parent = np.full(num_leaves, root, dtype=np.int32)
    edge_child = np.arange(num_leaves, dtype=np.int32)
    tables.edges.set_columns(
        left=edge_left, right=edge_right, parent=edge_parent, child=edge_child
    )

    if record_provenance:
        # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243
        # TODO also make sure we convert all the arguments so that they are
        # definitely JSON encodable.
        parameters = {"command": "generate_star", "TODO": "add parameters"}
        provenance = tskit.provenance.get_provenance_dict(parameters)
        tables.provenances.add_row(record=json.dumps(provenance))

    ts = tables.tree_sequence()
    return ts.first(**kwargs)
def generate_comb(num_leaves, *, span, branch_length, record_provenance, **kwargs):
    """
    Build a comb (maximally unbalanced binary) tree with ``num_leaves``
    sample leaves at time 0. Internal nodes are stacked ``branch_length``
    apart. Extra keyword arguments are passed on when extracting the first
    tree of the new tree sequence.

    See the documentation for :meth:`Tree.generate_comb` for more details.

    :raises ValueError: If ``num_leaves`` is less than 2.
    """
    if num_leaves < 2:
        raise ValueError("The number of leaves must be 2 or greater")

    tables = tskit.TableCollection(sequence_length=span)
    sample_flags = np.full(num_leaves, tskit.NODE_IS_SAMPLE, dtype=np.uint32)
    tables.nodes.set_columns(flags=sample_flags, time=np.zeros(num_leaves))

    # Start by pairing the last two leaves, then repeatedly join the newest
    # internal node with the next leaf, working down to leaf 0.
    right_child = num_leaves - 1
    node_time = branch_length
    for left_child in reversed(range(num_leaves - 1)):
        new_parent = tables.nodes.add_row(time=node_time)
        node_time += branch_length
        for child in (left_child, right_child):
            tables.edges.add_row(0, span, new_parent, child)
        right_child = new_parent

    if record_provenance:
        # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243
        # TODO also make sure we convert all the arguments so that they are
        # definitely JSON encodable.
        parameters = {"command": "generate_comb", "TODO": "add parameters"}
        provenance = tskit.provenance.get_provenance_dict(parameters)
        tables.provenances.add_row(record=json.dumps(provenance))

    ts = tables.tree_sequence()
    return ts.first(**kwargs)
def generate_balanced(
    num_leaves, *, arity, span, branch_length, record_provenance, **kwargs
):
    """
    Build a balanced tree of the given ``arity`` over ``num_leaves`` leaves
    labelled ``0 .. num_leaves - 1``. Extra keyword arguments are passed on
    when extracting the first tree of the new tree sequence.

    See the documentation for :meth:`Tree.generate_balanced` for more details.

    :raises ValueError: If ``num_leaves`` is less than 1 or ``arity`` is
        less than 2.
    """
    if num_leaves < 1:
        raise ValueError("The number of leaves must be at least 1")
    if arity < 2:
        raise ValueError("The arity must be at least 2")

    leaf_labels = range(num_leaves)
    topology = TreeNode.balanced_tree(leaf_labels, arity)
    tables = topology.as_tables(
        num_leaves=num_leaves, span=span, branch_length=branch_length
    )

    if record_provenance:
        # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243
        # TODO also make sure we convert all the arguments so that they are
        # definitely JSON encodable.
        parameters = {"command": "generate_balanced", "TODO": "add parameters"}
        provenance = tskit.provenance.get_provenance_dict(parameters)
        tables.provenances.add_row(record=json.dumps(provenance))

    ts = tables.tree_sequence()
    return ts.first(**kwargs)
def generate_random_binary(
    num_leaves, *, span, branch_length, random_seed, record_provenance, **kwargs
):
    """
    Sample a leaf-labelled binary tree with ``num_leaves`` leaves, using a
    ``random.Random`` generator seeded with ``random_seed``. Extra keyword
    arguments are passed on when extracting the first tree of the new tree
    sequence.

    See the documentation for :meth:`Tree.generate_random_binary` for more details.

    :raises ValueError: If ``num_leaves`` is less than 1.
    """
    if num_leaves < 1:
        raise ValueError("The number of leaves must be at least 1")

    rng = random.Random(random_seed)
    leaf_labels = range(num_leaves)
    topology = TreeNode.random_binary_tree(leaf_labels, rng)
    tables = topology.as_tables(
        num_leaves=num_leaves, span=span, branch_length=branch_length
    )

    if record_provenance:
        # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243
        # TODO also make sure we convert all the arguments so that they are
        # definitely JSON encodable.
        parameters = {"command": "generate_random_binary", "TODO": "add parameters"}
        provenance = tskit.provenance.get_provenance_dict(parameters)
        tables.provenances.add_row(record=json.dumps(provenance))

    return tables.tree_sequence().first(**kwargs)
def split_polytomies(
    tree,
    *,
    epsilon=None,
    method=None,
    record_provenance=True,
    random_seed=None,
    **kwargs,
):
    """
    Return a new tree where extra nodes and edges have been inserted
    so that any node with more than two children is resolved into
    a binary tree.

    New nodes are placed ``epsilon`` below their parent, or at the next
    representable floating point value below the parent time when
    ``epsilon`` is None. Extra keyword arguments are passed on when
    extracting the resulting tree.

    See the documentation for :meth:`Tree.split_polytomies` for more details.

    :raises ValueError: If ``method`` is not one of the allowed methods.
    :raises tskit.LibraryError: If a new node cannot be placed strictly
        between a polytomy and its children, or if building the new tree
        sequence fails.
    """
    allowed_methods = ["random"]
    if method is None:
        method = "random"
    if method not in allowed_methods:
        raise ValueError(f"Method must be chosen from {allowed_methods}")

    # Work on a copy restricted to this tree's interval, and rebuild the
    # edge table from scratch.
    tables = tree.tree_sequence.dump_tables()
    tables.keep_intervals([tree.interval], simplify=False)
    tables.edges.clear()
    rng = random.Random(random_seed)

    for u in tree.nodes():
        if tree.num_children(u) <= 2:
            # Nothing to resolve: copy the existing edges unchanged.
            for v in tree.children(u):
                tables.edges.add_row(*tree.interval, u, v)
            continue

        root = TreeNode.random_binary_tree(tree.children(u), rng)
        root.label = u
        root_time = tree.time(u)
        pending = [(child, root_time) for child in root.children]
        while pending:
            node, parent_time = pending.pop()
            if node.label is not None:
                assert len(node.children) == 0
                # This is a leaf node connecting back into the original tree
                child_time = tree.time(node.label)
            else:
                # A newly created internal node: allocate it just below
                # its parent.
                if epsilon is None:
                    child_time = np.nextafter(parent_time, -np.inf)
                else:
                    child_time = parent_time - epsilon
                node.label = tables.nodes.add_row(time=child_time)

            if parent_time <= child_time:
                u = root.label
                min_child_time = min(tree.time(v) for v in tree.children(u))
                min_time = root_time - min_child_time
                message = (
                    f"Cannot resolve the degree {tree.num_children(u)} "
                    f"polytomy rooted at node {u} with minimum time difference "
                    f"of {min_time} to the resolved leaves."
                )
                if epsilon is None:
                    message += (
                        " The time difference between nodes is so small that "
                        "more nodes cannot be inserted between within the limits "
                        "of floating point precision."
                    )
                elif parent_time == child_time and len(node.children) > 0:
                    # We can also have parent_time == child_time if epsilon is
                    # chosen such that we exactly divide up the branch in the
                    # original tree. We avoid saying this is caused by a
                    # too-small epsilon by noting it can only happen when we
                    # are at leaf node in the randomly generated tree.
                    message += (
                        f" The fixed epsilon value of {epsilon} is too small, "
                        "resulting in the parent and child times being equal "
                        "within the limits of numerical precision."
                    )
                else:
                    message += (
                        f" The fixed epsilon value of {epsilon} is too large, "
                        "resulting in the parent time being less than the child "
                        "time."
                    )
                raise tskit.LibraryError(message)

            tables.edges.add_row(*tree.interval, node.parent.label, node.label)
            pending.extend((child, child_time) for child in node.children)

    if record_provenance:
        parameters = {"command": "split_polytomies"}
        provenance = tskit.provenance.get_provenance_dict(parameters)
        tables.provenances.add_row(record=json.dumps(provenance))

    try:
        tables.sort()
        ts = tables.tree_sequence()
    except tskit.LibraryError as e:
        msg = str(e)
        # We should have caught all topology time travel above.
        assert not msg.startswith("time[parent] must be greater than time[child]")
        if msg.startswith(
            "A mutation's time must be < the parent node of the edge on which it occurs"
        ):
            # Append a more specific explanation to the library's error.
            if epsilon is None:
                msg = (
                    "Cannot split polytomy: mutation with numerical precision "
                    "of the parent time."
                )
            else:
                msg = (
                    f"epsilon={epsilon} not small enough to create new nodes below a "
                    "polytomy, due to the time of a mutation above a child of the "
                    "polytomy."
                )
            e.args += (msg,)
        raise e
    return ts.at(tree.interval.left, **kwargs)
def treeseq_count_topologies(ts, sample_sets):
    """
    Generator yielding, for each tree of ``ts`` in order, the
    ``TopologyCounter`` obtained by merging the counters held at that
    tree's roots. Per-node counters are kept between trees and only
    recomputed along the paths affected by each edge difference.

    :raises ValueError: If a node listed in ``sample_sets`` is not a sample,
        or if the parent of an inserted edge is a sample.
    """
    counter_of = np.full(ts.num_nodes, None, dtype=object)
    parent_of = np.full(ts.num_nodes, -1)

    def refresh_upwards(tree, start):
        # Recompute the counter at ``start`` and then at each ancestor
        # recorded in ``parent_of``.
        pending = [start]
        while pending:
            v = pending.pop()
            below = [
                counter_of[c] for c in tree.children(v) if counter_of[c] is not None
            ]
            counter_of[v] = combine_child_topologies(below) if below else None
            p = parent_of[v]
            if p != -1:
                pending.append(p)

    def edge_parent(edge):
        return edge.parent

    for sample_set_index, sample_set in enumerate(sample_sets):
        for u in sample_set:
            if not ts.node(u).is_sample():
                raise ValueError(f"Node {u} in sample_sets is not a sample.")
            counter_of[u] = TopologyCounter.from_sample(sample_set_index)

    for tree, (_, edges_out, edges_in) in zip(ts.trees(), ts.edge_diffs()):
        # Avoid recomputing anything for the parent until all child edges
        # for that parent are inserted/removed
        for p, sibling_edges in itertools.groupby(edges_out, key=edge_parent):
            for edge in sibling_edges:
                parent_of[edge.child] = -1
            refresh_upwards(tree, p)
        for p, sibling_edges in itertools.groupby(edges_in, key=edge_parent):
            if tree.is_sample(p):
                raise ValueError("Internal samples not supported.")
            for edge in sibling_edges:
                parent_of[edge.child] = p
            refresh_upwards(tree, p)

        root_counters = [
            counter_of[root] for root in tree.roots if counter_of[root] is not None
        ]
        yield TopologyCounter.merge(root_counters)
def tree_count_topologies(tree, sample_sets):
    """
    Return the ``TopologyCounter`` for a single tree: per-node counters are
    built bottom-up in postorder and the counters at the roots are merged.

    :raises ValueError: If the tree has a sample that is not a leaf, or if
        a node listed in ``sample_sets`` is not a sample.
    """
    if any(not tree.is_leaf(u) for u in tree.samples()):
        raise ValueError("Internal samples not supported.")

    counter_of = np.full(tree.tree_sequence.num_nodes, None, dtype=object)
    for sample_set_index, sample_set in enumerate(sample_sets):
        for u in sample_set:
            if not tree.is_sample(u):
                raise ValueError(f"Node {u} in sample_sets is not a sample.")
            counter_of[u] = TopologyCounter.from_sample(sample_set_index)

    for u in tree.nodes(order="postorder"):
        # Only children that have a counter contribute to the parent.
        below = [counter_of[v] for v in tree.children(u) if counter_of[v] is not None]
        if below:
            counter_of[u] = combine_child_topologies(below)

    root_counters = [
        counter_of[root] for root in tree.roots if counter_of[root] is not None
    ]
    return TopologyCounter.merge(root_counters)
def combine_child_topologies(topology_counters):
    """
    Combine the counters of sibling subtrees into the counter for their
    parent.

    Every combination of topologies taken from different counters in
    ``topology_counters`` that can be joined into a single topology is
    selected: that is, any combination of at least two topologies, all from
    different children, in which no two topologies share a sample set index.
    """
    accumulator = PartialTopologyCounter()
    for child_counter in topology_counters:
        accumulator.add_sibling_topologies(child_counter)
    combined = accumulator.join_all_combinations()
    return combined
[docs]
class TopologyCounter:
    """
    Contains the distributions of embedded topologies for every combination
    of the sample sets used to generate the ``TopologyCounter``. It is
    indexable by a combination of sample set indexes and returns a
    ``collections.Counter`` whose keys are topology ranks
    (see :ref:`sec_tree_ranks`). See :meth:`Tree.count_topologies` for more
    detail on how this structure is used.

    Indexing accepts either a single sample set index or an iterable of
    them; the order of the indexes is irrelevant. Looking up a combination
    that has not been seen creates (and stores) an empty counter for it.
    """

    def __init__(self):
        # Maps a sorted tuple of sample set indexes to a Counter of ranks.
        self.topologies = collections.defaultdict(collections.Counter)

    def __getitem__(self, sample_set_indexes):
        k = TopologyCounter._to_key(sample_set_indexes)
        return self.topologies[k]

    def __setitem__(self, sample_set_indexes, counter):
        k = TopologyCounter._to_key(sample_set_indexes)
        self.topologies[k] = counter

    @staticmethod
    def _to_key(sample_set_indexes):
        """
        Normalise an index, or iterable of indexes, to a sorted tuple.
        """
        # Imported here because ``import collections`` on its own does not
        # guarantee that the ``collections.abc`` submodule has been loaded.
        from collections.abc import Iterable

        if not isinstance(sample_set_indexes, Iterable):
            sample_set_indexes = (sample_set_indexes,)
        return tuple(sorted(sample_set_indexes))

    def __eq__(self, other):
        return self.__class__ == other.__class__ and self.topologies == other.topologies

    @staticmethod
    def merge(topology_counters):
        """
        Union together independent topology counters into one, summing
        the counts stored under each key.
        """
        total = TopologyCounter()
        for tc in topology_counters:
            for k, v in tc.topologies.items():
                total.topologies[k] += v
        return total

    @staticmethod
    def from_sample(sample_set_index):
        """
        Generate the topologies covered by a single sample. This
        is the single-leaf topology representing the single sample
        set.
        """
        rank_tree = RankTree(children=[], label=sample_set_index)
        tc = TopologyCounter()
        tc[sample_set_index][rank_tree.rank()] = 1
        return tc
class PartialTopologyCounter:
    """
    Represents the possible combinations of children under a node in a tree
    and the combinations of embedded topologies that are rooted at the node.
    This allows an efficient way of calculating which unique embedded
    topologies arise by only ever storing a given pairing of sibling topologies
    once.

    ``partials`` is a dictionary where a key is a tuple of sample set indexes,
    and the value is a ``collections.Counter`` that counts combinations of
    sibling topologies whose tips represent the sample sets in the key.
    Each element of the counter is a homogeneous tuple where each element represents
    a topology. The topology is itself a tuple of the sample set indexes in that
    topology and the rank.
    """

    def __init__(self):
        self.partials = collections.defaultdict(collections.Counter)

    def add_sibling_topologies(self, topology_counter):
        """
        Fold the topologies of one more child into ``partials``.

        Each topology in the given TopologyCounter is crossed with every
        existing combination whose sample set indexes are disjoint from its
        own, and is also recorded on its own without being joined to
        anything.
        """
        # Collected separately so that ``partials`` is not modified while it
        # is being iterated over.
        pending = collections.defaultdict(collections.Counter)
        for indexes, ranks in topology_counter.topologies.items():
            for rank, count in ranks.items():
                single = ((indexes, rank),)
                # Cross with existing topology combinations
                for other_indexes, combos in self.partials.items():
                    if not isdisjoint(indexes, other_indexes):
                        continue
                    for combo, combo_count in combos.items():
                        joined_combo = merge_tuple(combo, single)
                        joined_indexes = merge_tuple(other_indexes, indexes)
                        pending[joined_indexes][joined_combo] += count * combo_count
                # Propagate without combining
                pending[indexes][single] += count

        for indexes, counter in pending.items():
            self.partials[indexes] += counter

    def join_all_combinations(self):
        """
        Convert the stored sibling combinations into a TopologyCounter:
        combinations of two or more topologies are joined under a new
        parent, single topologies are carried over unchanged.
        """
        result = TopologyCounter()
        for indexes, combos in self.partials.items():
            for combo, count in combos.items():
                if len(combo) < 2:
                    # Pass on the single tree without adding a parent
                    for _, rank in combo:
                        result[indexes][rank] += count
                else:
                    # A node must have at least two children
                    joined_rank = PartialTopologyCounter.join_topologies(combo)
                    result[indexes][joined_rank] += count
        return result

    @staticmethod
    def join_topologies(child_topologies):
        """
        Return the rank of the tree whose children are the given
        ``(sample_set_indexes, rank)`` topologies.
        """
        subtrees = [
            RankTree.unrank(len(indexes), rank, list(indexes))
            for indexes, rank in child_topologies
        ]
        ordered = sorted(subtrees, key=RankTree.canonical_order)
        return RankTree(ordered).rank()
[docs]
def all_trees(num_leaves, span=1):
    """
    Generates all unique leaf-labelled trees with ``num_leaves``
    leaves. See :ref:`sec_combinatorics` on the details of this
    enumeration. The leaf labels are selected from the set
    ``[0, num_leaves)``. The times and labels on internal nodes are
    chosen arbitrarily.

    :param int num_leaves: The number of leaves of the tree to generate.
    :param float span: The genomic span of each returned tree.
    :rtype: tskit.Tree
    """
    yield from (
        labelled.to_tsk_tree(span=span)
        for labelled in RankTree.all_labelled_trees(num_leaves)
    )
[docs]
def all_tree_shapes(num_leaves, span=1):
    """
    Generates all unique shapes of trees with ``num_leaves`` leaves.
    Each shape is returned with the labelling of label rank 0.

    :param int num_leaves: The number of leaves of the tree to generate.
    :param float span: The genomic span of each returned tree.
    :rtype: tskit.Tree
    """
    yield from (
        shape.label_unrank(0).to_tsk_tree(span=span)
        for shape in RankTree.all_unlabelled_trees(num_leaves)
    )
[docs]
def all_tree_labellings(tree, span=1):
    """
    Generates all unique labellings of the leaves of a
    :class:`tskit.Tree`. Leaves are labelled from the set
    ``[0, n)`` where ``n`` is the number of leaves of ``tree``.

    :param tskit.Tree tree: The tree used to generate
        labelled trees of the same shape.
    :param float span: The genomic span of each returned tree.
    :rtype: tskit.Tree
    """
    shape = RankTree.from_tsk_tree(tree)
    yield from (
        labelled.to_tsk_tree(span=span) for labelled in RankTree.all_labellings(shape)
    )
class RankTree:
"""
A tree class that maintains the topological ranks of each node in the tree.
This structure can be used to efficiently compute the rank of a tree of
n labelled leaves and produce a tree given a rank.
"""
def __init__(self, children, label=None):
# Children are assumed to be sorted by RankTree.canonical_order
self.children = children
if len(children) == 0:
self.num_leaves = 1
self.labels = [label]
else:
self.num_leaves = sum(c.num_leaves for c in children)
self.labels = list(heapq.merge(*(c.labels for c in children)))
self._shape_rank = None
self._label_rank = None
def compute_shape_rank(self):
    """
    Compute the rank of this tree's shape, mirroring the way in which
    unlabelled trees are enumerated.

    First, count the trees whose partition of the number of leaves ranks
    lower than this tree's partition.

    Then, within this tree's partition, consider each group of child
    subtrees having the same number of leaves, k. The trees in such a group
    were selected as a combination with replacement from S(k); the rank of
    that combination is the number of combinations preceding it in the
    group. That rank is multiplied by the number of arrangements possible
    in the following groups and added to the total, and the process
    continues with the remaining groups.
    """
    own_partition = self.leaf_partition()
    earlier_partitions = itertools.takewhile(
        lambda candidate: candidate != own_partition, partitions(self.num_leaves)
    )
    rank = sum(num_tree_pairings(candidate) for candidate in earlier_partitions)

    children_seen = 0
    for group in self.group_children_by_num_leaves():
        children_seen += len(group)
        shapes_available = num_shapes(group[0].num_leaves)
        ranks_in_group = [child.shape_rank() for child in group]
        group_rank_value = Combination.with_replacement_rank(
            ranks_in_group, shapes_available
        )
        # TODO precompute vector before loop
        later_pairings = num_tree_pairings(own_partition[children_seen:])
        rank += group_rank_value * later_pairings
    return rank
def compute_label_rank(self):
    """
    Compute the rank of this tree's labelling, mirroring how labelled
    trees are enumerated.

    The children are grouped into symmetric groups, i.e. groups of child
    trees of the same shape. Each group holds a combination of labels
    selected from the labels still available to the tree.

    The choices made when labelling are:

    1. Which combination of labels is assigned to the group.
    2. Given that combination, how the labels are distributed among the
       trees in the group.
    3. Given that distribution, how each tree in the group is labelled.

    For each group G, the rank of the combination of labels assigned to G,
    times the number of ways the trees in G could be labelled, times the
    number of possible labellings of the remaining groups, gives the number
    of labellings preceding those with the given combination assigned to G.
    To this is added the rank of the configuration within G, scaled by the
    number of labellings of the remaining groups.
    """
    available = self.labels
    groups = self.group_children_by_shape()
    rank = 0
    for position, group in enumerate(groups, start=1):
        later_groups = groups[position:]
        group_labels = list(heapq.merge(*(tree.labels for tree in group)))
        later_labellings = num_list_of_group_labellings(later_groups)

        # Preceded by all of the ways to label all the groups
        # with a lower ranking combination given to this group.
        combination_rank = Combination.rank(group_labels, available)
        labellings_of_group = num_group_labellings(group)
        rank += combination_rank * labellings_of_group * later_labellings

        # Preceded then by all the configurations of this group ranking
        # less than the current one
        rank += group_rank(group) * later_labellings

        available = set_minus(available, group_labels)
    return rank
# TODO I think this would boost performance if it were a field and not
# recomputed.
def num_labellings(self):
    """
    Return the number of labellings of the children, grouped by shape.
    """
    return num_list_of_group_labellings(self.group_children_by_shape())
def rank(self):
    """
    Return the ``Rank`` made of this tree's shape rank and label rank.
    """
    shape = self.shape_rank()
    label = self.label_rank()
    return Rank(shape, label)
def shape_rank(self):
if self._shape_rank is None:
self._shape_rank = self.compute_shape_rank()
return self._shape_rank
def label_rank(self):
if self._label_rank is None:
assert self.shape_rank() is not None
self._label_rank = self.compute_label_rank()
return self._label_rank
@staticmethod
def unrank(num_leaves, rank, labels=None):
"""
Produce a ``RankTree`` of the given ``rank`` with ``num_leaves`` leaves,
labelled with ``labels``. Labels must be sorted, and if ``None`` default
to ``[0, num_leaves)``.
"""
shape_rank, label_rank = rank
if shape_rank < 0 or label_rank < 0:
raise ValueError("Rank is out of bounds.")
unlabelled = RankTree.shape_unrank(num_leaves, shape_rank)
return unlabelled.label_unrank(label_rank, labels)
@staticmethod
def shape_unrank(n, shape_rank):
    """
    Build an unlabelled tree with ``n`` leaves whose shape corresponds to
    ``shape_rank``. The rank is stored on the returned tree so it need not
    be recomputed.
    """
    leaf_counts, subtree_ranks = children_shape_ranks(shape_rank, n)
    subtrees = []
    for leaf_count, subtree_rank in zip(leaf_counts, subtree_ranks):
        subtrees.append(RankTree.shape_unrank(leaf_count, subtree_rank))
    tree = RankTree(children=subtrees)
    tree._shape_rank = shape_rank
    return tree
def label_unrank(self, label_rank, labels=None):
"""
Generate a tree with the same shape, whose leaves are labelled
from ``labels`` with the labelling corresponding to ``label_rank``.
"""
if labels is None:
labels = list(range(self.num_leaves))
if self.is_leaf():
if label_rank != 0:
raise ValueError("Rank is out of bounds.")
return RankTree(children=[], label=labels[0])
child_groups = self.group_children_by_shape()
child_labels, child_label_ranks = children_label_ranks(
child_groups, label_rank, labels
)
children = self.children
labelled_children = [
RankTree.label_unrank(c, c_rank, c_labels)
for c, c_rank, c_labels in zip(children, child_label_ranks, child_labels)
]
t = RankTree(children=labelled_children)
t._shape_rank = self.shape_rank()
t._label_rank = label_rank
return t
@staticmethod
def canonical_order(c):
"""
Defines the canonical ordering of sibling subtrees.
"""
return c.num_leaves, c.shape_rank(), c.min_label()
@staticmethod
def from_tsk_tree_node(tree, u):
    """
    Build the ``RankTree`` for the subtree of ``tree`` rooted at node ``u``.
    Leaves are labelled with their node ID and children are placed in
    canonical order.

    :raises ValueError: If a node in the subtree has exactly one child.
    """
    if tree.is_leaf(u):
        return RankTree(children=[], label=u)
    if tree.num_children(u) == 1:
        raise ValueError("Cannot rank trees with unary nodes")
    subtrees = (RankTree.from_tsk_tree_node(tree, c) for c in tree.children(u))
    ordered = sorted(subtrees, key=RankTree.canonical_order)
    return RankTree(children=ordered)
@staticmethod
def from_tsk_tree(tree):
    """
    Build the ``RankTree`` corresponding to a single-rooted tree.

    :raises ValueError: If ``tree`` does not have exactly one root.
    """
    if tree.num_roots == 1:
        return RankTree.from_tsk_tree_node(tree, tree.root)
    raise ValueError("Cannot rank trees with multiple roots")
def to_tsk_tree(self, span=1, branch_length=1):
    """
    Convert a ``RankTree`` into the only tree in a new tree sequence. Internal
    nodes and their times are assigned via a postorder traversal of the tree.

    :param float span: The genomic span of the returned tree. The tree will cover
        the interval :math:`[0, span)` and the :attr:`~Tree.tree_sequence` from which
        the tree is taken will have its :attr:`~tskit.TreeSequence.sequence_length`
        equal to ``span``.
    :param float branch_length: The minimum length of a branch in the returned
        tree.
    :raises ValueError: If the leaf labels are not exactly the integers
        ``0 .. num_leaves - 1``.
    """
    if set(self.labels) != set(range(self.num_leaves)):
        raise ValueError("Labels set must be equivalent to [0, num_leaves)")

    tables = tskit.TableCollection(span)
    # Leaves are the samples, occupying node IDs [0, num_leaves) at time 0.
    for _ in range(self.num_leaves):
        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)

    def add_node(node):
        # Returns (node ID, node time). Times are tracked here rather than
        # read back through ``tables.nodes.time``, which would fetch the
        # whole time column again for every child looked up.
        if node.is_leaf():
            assert node.label is not None
            return node.label, 0.0
        added_children = [add_node(child) for child in node.children]
        parent_time = max(time for _, time in added_children) + branch_length
        parent_id = tables.nodes.add_row(time=parent_time)
        for child_id, _ in added_children:
            tables.edges.add_row(0, span, parent_id, child_id)
        return parent_id, parent_time

    add_node(self)

    # The way in which we're inserting nodes doesn't necessarily
    # adhere to the ordering constraint on edges, so we have
    # to sort.
    tables.sort()
    return tables.tree_sequence().first()
@staticmethod
def all_labelled_trees(n):
"""
Generate all unordered, leaf-labelled trees with n leaves.
"""
for tree in RankTree.all_unlabelled_trees(n):
yield from RankTree.all_labellings(tree)
@staticmethod
def all_unlabelled_trees(n):
"""
Generate all tree shapes with n leaves. See :ref:`sec_combinatorics`
for how tree shapes are enumerated.
"""
if n == 1:
yield RankTree(children=[])
else:
for part in partitions(n):
for subtree_pairing in RankTree.all_subtree_pairings(
group_partition(part)
):
yield RankTree(children=subtree_pairing)
@staticmethod
def all_subtree_pairings(grouped_part):
if len(grouped_part) == 0:
yield []
else:
g = grouped_part[0]
k = g[0]
all_k_leaf_trees = RankTree.all_unlabelled_trees(k)
num_k_leaf_trees = len(g)
g_trees = itertools.combinations_with_replacement(
all_k_leaf_trees, num_k_leaf_trees
)
for first_trees in g_trees:
for rest in RankTree.all_subtree_pairings(grouped_part[1:]):
yield list(first_trees) + rest
@staticmethod
def all_labellings(tree, labels=None):
"""
Given a tree, generate all the unique labellings of that tree.
See :ref:`sec_combinatorics` for how labellings of a tree are
enumerated.
"""
if labels is None:
labels = list(range(tree.num_leaves))
if tree.is_leaf():
assert len(labels) == 1
yield RankTree(children=[], label=labels[0])
else:
groups = tree.group_children_by_shape()
for labeled_children in RankTree.label_all_groups(groups, labels):
yield RankTree(children=labeled_children)
@staticmethod
def label_all_groups(groups, labels):
if len(groups) == 0:
yield []
else:
g, rest = groups[0], groups[1:]
x = len(g)
k = g[0].num_leaves
for g_labels in itertools.combinations(labels, x * k):
rest_labels = set_minus(labels, g_labels)
for labeled_g in RankTree.label_tree_group(g, g_labels):
for labeled_rest in RankTree.label_all_groups(rest, rest_labels):
yield labeled_g + labeled_rest
@staticmethod
def label_tree_group(trees, labels):
    """
    Enumerate every way to distribute ``labels`` over ``trees`` and label
    each tree with its share.

    The first tree always receives ``labels[0]``; its other labels are
    drawn from the rest in ``itertools.combinations`` order, and the
    remaining trees are handled recursively with the leftover labels.
    """
    if len(trees) == 0:
        assert len(labels) == 0
        yield []
        return
    head, tail = trees[0], trees[1:]
    extra_needed = head.num_leaves - 1
    smallest = labels[0]
    for extras in itertools.combinations(labels[1:], extra_needed):
        head_labels = [smallest, *extras]
        leftover = set_minus(labels, head_labels)
        for labelled_head in RankTree.all_labellings(head, head_labels):
            for labelled_tail in RankTree.label_tree_group(tail, leftover):
                yield [labelled_head, *labelled_tail]
def _newick(self):
    """
    Newick text for this subtree, without the trailing semicolon.

    A leaf renders as its label (or the empty string when unlabelled); an
    internal node renders as its children, comma separated, in parentheses.
    """
    if not self.is_leaf():
        inner = ",".join(child._newick() for child in self.children)
        return "(" + inner + ")"
    if self.labelled():
        return str(self.label)
    return ""
def newick(self):
    """Return the newick text of this tree, terminated by a semicolon."""
    body = self._newick()
    return body + ";"
@property
def label(self):
    """The first entry of ``self.labels``."""
    first_label = self.labels[0]
    return first_label
def labelled(self):
    """Return True when no entry of ``self.labels`` is None."""
    for lab in self.labels:
        if lab is None:
            return False
    return True
def min_label(self):
    """Return the first entry of ``self.labels``."""
    smallest = self.labels[0]
    return smallest
def is_leaf(self):
    """Return True when this node has no children."""
    return not self.children
def leaf_partition(self):
    """Return the ``num_leaves`` of every child, in child order."""
    sizes = [child.num_leaves for child in self.children]
    return sizes
def group_children_by_num_leaves(self):
    """
    Split the children into runs of consecutive children that have the
    same ``num_leaves``.
    """
    return group_by(self.children, lambda a, b: a.num_leaves == b.num_leaves)
def group_children_by_shape(self):
    """
    Split the children into runs of consecutive children that have both
    the same ``num_leaves`` and the same ``shape_rank()``.
    """

    def matches(a, b):
        # Only compare shape ranks when the leaf counts already agree.
        if a.num_leaves != b.num_leaves:
            return False
        return a.shape_rank() == b.shape_rank()

    return group_by(self.children, matches)
def __eq__(self, other):
    """
    Structural equality: same class, and either two leaves with equal
    labels or the same number of children that are pairwise equal in order.
    """
    if self.__class__ != other.__class__:
        return False
    both_leaves = self.is_leaf() and other.is_leaf()
    if both_leaves:
        return self.label == other.label
    mine, theirs = self.children, other.children
    if len(mine) != len(theirs):
        return False
    for left, right in zip(mine, theirs):
        if not left == right:
            return False
    return True
def __ne__(self, other):
    """Logical negation of ``__eq__``."""
    equal = self.__eq__(other)
    return not equal
def shape_equal(self, other):
    """
    Compare two trees ignoring labels: two leaves always match; otherwise
    the children must match pairwise, in order.
    """
    if self.is_leaf() and other.is_leaf():
        return True
    mine, theirs = self.children, other.children
    if len(mine) != len(theirs):
        return False
    for left, right in zip(mine, theirs):
        if not left.shape_equal(right):
            return False
    return True
def is_canonical(self):
    """
    Return True when, at every node of this tree, adjacent children are
    in non-decreasing ``RankTree.canonical_order``.
    """
    if self.is_leaf():
        return True
    kids = self.children
    out_of_order = any(
        RankTree.canonical_order(left) > RankTree.canonical_order(right)
        for left, right in zip(kids, kids[1:])
    )
    if out_of_order:
        return False
    return all(kid.is_canonical() for kid in kids)
def is_symmetrical(self):
    """
    A leaf is symmetrical; an internal node is symmetrical when all of its
    children have the same number of leaves and the same shape rank.
    """
    if self.is_leaf():
        return True
    distinct_sizes = set(self.leaf_partition())
    distinct_shapes = {child.shape_rank() for child in self.children}
    return len(distinct_sizes) == 1 and len(distinct_shapes) == 1
# TODO This is called repeatedly in ranking and unranking and has a perfect
# substructure for DP. It's only ever called on n in [0, num_leaves]
# so we should compute a vector of those results up front instead of using
# repeated calls to this function.
# Put an lru_cache on for now as a quick replacement (cuts test time down by 80%)
@functools.lru_cache
def num_shapes(n):
    """
    The cardinality of the set of unlabelled trees with n leaves,
    up to isomorphism.

    Values ``n <= 1`` are returned unchanged (so 0 -> 0 and 1 -> 1);
    otherwise the counts for every partition of n are added up.
    """
    if n <= 1:
        return n
    count = 0
    for part in partitions(n):
        count += num_tree_pairings(part)
    return count
def num_tree_pairings(part):
    """
    The number of unique tree shapes that can be assembled from a given
    partition of leaves.

    The partition is grouped by leaf count. Each group is counted
    independently and the counts are multiplied. Within a group, later
    trees must have equal or greater rank, so the number of choices is the
    number of combinations with replacement drawn from all shapes of that
    leaf count.
    """
    pairings = 1
    for group in group_partition(part):
        leaves, multiplicity = group[0], len(group)
        pairings *= Combination.comb_with_replacement(
            num_shapes(leaves), multiplicity
        )
    return pairings
def num_labellings(n, shape_rk):
    """
    Number of labellings of the n-leaf tree shape whose shape rank is
    ``shape_rk``.
    """
    shape = RankTree.shape_unrank(n, shape_rk)
    return shape.num_labellings()
def children_shape_ranks(rank, n):
    """
    Return the partition of leaves associated with the children of the
    n-leaf tree of shape rank ``rank``, together with the shape rank of
    each child tree.

    Raises ValueError("Rank is out of bounds.") when no partition of n
    contains the rank (unless n == 1, where an empty partition and an
    empty list of ranks are returned).
    """
    chosen_part = []
    # Walk the partitions in order, discounting the shapes each one covers
    # until the rank falls inside the current partition.
    for candidate in partitions(n):
        shapes_here = num_tree_pairings(candidate)
        if rank < shapes_here:
            chosen_part = candidate
            break
        rank -= shapes_here
    else:
        if n != 1:
            raise ValueError("Rank is out of bounds.")

    shape_ranks = []
    consumed = 0
    for group in group_partition(chosen_part):
        consumed += len(group)
        # TODO precompute vector up front
        tail_pairings = num_tree_pairings(chosen_part[consumed:])
        # The quotient selects this group's multiset of shapes; the
        # remainder is the rank left for the groups that follow.
        comb_rank, rank = divmod(rank, tail_pairings)
        shape_ranks.extend(
            Combination.with_replacement_unrank(
                comb_rank, num_shapes(group[0]), len(group)
            )
        )
    return chosen_part, shape_ranks
def children_label_ranks(child_groups, rank, labels):
    """
    Produces the subsets of labels assigned to each child
    and the associated label rank of each child.

    ``child_groups`` is a list of groups of same-shape trees; the result is
    a pair ``(child_labels, child_label_ranks)`` with one entry per tree,
    in group order.
    """
    labels_per_child = []
    ranks_per_child = []
    for idx, group in enumerate(child_groups):
        leaves_in_group = group[0].num_leaves * len(group)
        within_group = num_group_labellings(group)
        # TODO precompute vector of partial products outside of loop
        after_group = num_list_of_group_labellings(child_groups[idx + 1 :])
        per_label_choice = within_group * after_group
        # Quotient: which label subset goes to this group. Remainder: the
        # rank among all labellings that share that subset.
        comb_rank, remainder = divmod(rank, per_label_choice)
        rank_in_group = remainder // after_group
        group_labels = Combination.unrank(comb_rank, labels, leaves_in_group)
        tree_labels, tree_ranks = group_label_ranks(
            rank_in_group, group, group_labels
        )
        labels_per_child.extend(tree_labels)
        ranks_per_child.extend(tree_ranks)
        labels = set_minus(labels, group_labels)
        rank %= after_group
    return labels_per_child, ranks_per_child
def group_rank(g):
    """
    Compute the label rank of a group ``g`` of labelled trees, using the
    leaf count and number of labellings of ``g[0]`` for every tree.
    """
    leaves_per_tree = g[0].num_leaves
    total_leaves = len(g) * leaves_per_tree
    # Num ways to label a single one of the trees. Computed once because
    # all the trees in the group are of the same shape rank.
    labellings_per_tree = g[0].num_labellings()
    pool = list(heapq.merge(*(tree.labels for tree in g)))
    total = 0
    for idx, tree in enumerate(g):
        trees_left = len(g) - idx
        # Kind of cheating here leaving the selection of min labels implicit
        # because the rank of the comb without min labels is the same
        comb_rank = Combination.rank(tree.labels, pool)
        # number of ways to distribute labels to rest leaves
        leaves_after = total_leaves - (idx + 1) * leaves_per_tree
        later_splits = 1
        for j in range(trees_left - 1):
            later_splits *= Combination.comb(
                leaves_after - j * leaves_per_tree - 1, leaves_per_tree - 1
            )
        skipped = comb_rank * later_splits * labellings_per_tree**trees_left
        own = (
            tree.label_rank()
            * later_splits
            * labellings_per_tree ** (trees_left - 1)
        )
        total += skipped + own
        pool = set_minus(pool, tree.labels)
    return total
# TODO This is only used in a few cases and mostly in an n^2 way. It would
# be easy and useful to do this with DP and produce a list of partial products.
def num_list_of_group_labellings(groups):
    """
    Given a set of labels and a list of groups, how many unique ways are there
    to assign subsets of labels to each group in the list and subsequently
    label all the trees in all the groups.

    The number of labels is taken to be the total number of leaves in
    ``groups``.
    """
    unassigned = sum(len(g) * g[0].num_leaves for g in groups)
    ways = 1
    for g in groups:
        group_leaves = len(g) * g[0].num_leaves
        # Choose this group's labels from those still unassigned, then
        # label the trees inside the group.
        label_choices = Combination.comb(unassigned, group_leaves)
        ways *= label_choices * num_group_labellings(g)
        unassigned -= group_leaves
    return ways
def num_group_labellings(g):
    """
    Given a particular set of labels, how many unique ways are there
    to assign subsets of labels to each tree in the group and subsequently
    label those trees.
    """
    # Shortcut: all the trees are identical, so each can be labelled in
    # the same number of ways as the first.
    per_tree = g[0].num_labellings()
    return num_assignments_in_group(g) * per_tree ** len(g)
def num_assignments_in_group(g):
    """
    Given this group of identical trees, how many unique ways
    are there to divide up a set of n labels?

    n is the total number of leaves in the group.
    """
    unassigned = sum(tree.num_leaves for tree in g)
    ways = 1
    for tree in g:
        size = tree.num_leaves
        # Choose size - 1 from unassigned - 1 because the minimum label
        # must be assigned to the first tree for a canonical labelling.
        ways *= Combination.comb(unassigned - 1, size - 1)
        unassigned -= size
    return ways
def group_label_ranks(rank, child_group, labels):
    """
    Given a group of trees of the same shape, a label rank and list of labels,
    produce assignment of label subsets to each tree in the group and the
    label rank of each tree.

    Returns ``(child_labels, child_label_ranks)``, one entry per tree.
    """
    labels_per_tree = []
    ranks_per_tree = []
    for idx, tree in enumerate(child_group):
        size = tree.num_leaves
        per_tree = tree.num_labellings()
        later_trees = child_group[idx + 1 :]
        later_labellings = num_assignments_in_group(later_trees) * (
            per_tree ** len(later_trees)
        )
        per_label_choice = per_tree * later_labellings
        # Quotient: which label subset this tree gets. Remainder: the rank
        # among labellings that share that subset.
        comb_rank, remainder = divmod(rank, per_label_choice)
        tree_rank = remainder // later_labellings
        rank %= later_labellings
        # The smallest remaining label always goes to the current tree.
        tree_labels = [
            labels[0],
            *Combination.unrank(comb_rank, labels[1:], size - 1),
        ]
        labels = set_minus(labels, tree_labels)
        labels_per_tree.append(tree_labels)
        ranks_per_tree.append(tree_rank)
    return labels_per_tree, ranks_per_tree
class Combination:
    """
    Static helpers for counting, ranking and unranking combinations, with
    and without replacement, in lexicographic order.
    """

    @staticmethod
    def comb(n, k):
        """
        The number of times you can select k items from
        n items without order and without replacement.

        Returns 0 when the selection is impossible (``k < 0`` or ``k > n``).

        NOTE: ``math.comb`` (Python 3.8+) computes the same value for
        non-negative arguments but raises ``ValueError`` for negative ones.
        """
        if k < 0 or k > n:
            return 0
        # Use the symmetry C(n, k) == C(n, n - k) to shorten the loop.
        k = min(k, n - k)
        res = 1
        for i in range(1, k + 1):
            # After this step res == C(n - k + i, i), so the floor
            # division below is always exact.
            res *= n - k + i
            res //= i
        return res

    @staticmethod
    def comb_with_replacement(n, k):
        """
        Also called multichoose, the number of times you can select
        k items from n items without order but *with* replacement.

        Selecting nothing (``k == 0``) always counts as one way; selecting
        one or more items from an empty pool counts as zero ways.
        """
        if k == 0:
            return 1
        return Combination.comb(n + k - 1, k)

    @staticmethod
    def rank(combination, elements):
        """
        Find the rank of ``combination`` in the lexicographic ordering of
        all combinations of ``len(combination)`` items drawn from
        ``elements``.

        Every item of ``combination`` is located with ``elements.index``,
        so it must occur in ``elements``.
        """
        indices = [elements.index(x) for x in combination]
        return Combination.from_range_rank(indices, len(elements))

    @staticmethod
    def from_range_rank(combination, n):
        """
        Find the rank of ``combination``, a list of k ascending integers
        from [0, n), in the lexicographic ordering of all such combinations.
        """
        k = len(combination)
        # The empty combination and the full range are the only ones of
        # their size, so their rank is 0.
        if k == 0 or k == n:
            return 0
        j = combination[0]
        # Shift everything down by one so the recursion works on [0, n - 1).
        combination = [x - 1 for x in combination]
        if j == 0:
            # 0 is in the combination: drop it and rank the remainder.
            return Combination.from_range_rank(combination[1:], n - 1)
        # 0 is not in the combination: skip every combination that contains
        # it, of which there are C(n - 1, k - 1).
        first_rank = Combination.comb(n - 1, k - 1)
        rest_rank = Combination.from_range_rank(combination, n - 1)
        return first_rank + rest_rank

    @staticmethod
    def unrank(rank, elements, k):
        """
        Find the combination of k items from the list ``elements`` that has
        the given rank in a lexicographic ordering.

        Raises ValueError("Rank is out of bounds.") when the elements run
        out before k items have been chosen.
        """
        n = len(elements)
        if k == 0:
            return []
        if n == 0:
            raise ValueError("Rank is out of bounds.")
        # Number of combinations that include the first element.
        n_rest_combs = Combination.comb(n - 1, k - 1)
        if rank < n_rest_combs:
            return elements[:1] + Combination.unrank(rank, elements[1:], k - 1)
        return Combination.unrank(rank - n_rest_combs, elements[1:], k)

    @staticmethod
    def with_replacement_rank(combination, n):
        """
        Find the rank of ``combination`` in the lexicographic ordering of
        combinations with replacement of integers from [0, n).
        """
        k = len(combination)
        if k == 0:
            return 0
        j = combination[0]
        if k == 1:
            return j
        if j == 0:
            return Combination.with_replacement_rank(combination[1:], n)
        # Re-base the remaining items so they are drawn from [0, n - j).
        rest = [x - j for x in combination[1:]]
        # Count every combination whose first item is smaller than j.
        preceding = 0
        for i in range(j):
            preceding += Combination.comb_with_replacement(n - i, k - 1)
        return preceding + Combination.with_replacement_rank(rest, n - j)

    @staticmethod
    def with_replacement_unrank(rank, n, k):
        """
        Find the combination with replacement of k integers from [0, n)
        with the given rank in a lexicographic ordering.

        Raises ValueError("Rank is out of bounds.") when the rank is too
        large for the first item to be chosen from [0, n).
        """
        if k == 0:
            return []
        if n <= 0:
            raise ValueError("Rank is out of bounds.")
        i = 0
        # Number of combinations whose first item is i.
        preceding = Combination.comb_with_replacement(n, k - 1)
        while rank >= preceding:
            rank -= preceding
            i += 1
            if i >= n:
                # Every possible first item has been skipped.
                raise ValueError("Rank is out of bounds.")
            preceding = Combination.comb_with_replacement(n - i, k - 1)
        rest = Combination.with_replacement_unrank(rank, n - i, k - 1)
        return [i] + [x + i for x in rest]
def set_minus(arr, subset):
    """
    Return a list of the items of ``arr`` that do not occur in ``subset``,
    keeping the order (and any repeats) of ``arr``.

    ``subset`` may be any iterable of hashable items, including a one-shot
    iterator.
    """
    # Build the lookup set once. Rebuilding it per element is quadratic and
    # would exhaust a one-shot iterator after the first element.
    excluded = set(subset)
    return [x for x in arr if x not in excluded]
# TODO I think we can use part-count form everywhere. Right now
# there's a janky work-around of grouping the partition when
# we need it in part-count form, but it doesn't look like there's any
# place that can't just accept it from the start.
def partitions(n):
    """
    Ascending integer partitions of n, excluding the partition [n].
    Since trees with unary nodes are uncountable, the partition of
    leaves must be at least size two.

    Nothing is yielded unless n > 0.
    """
    if not n > 0:
        return
    for part in rule_asc(n):
        if len(part) <= 1:
            # The single-part partition [n] comes last, so stop here.
            break
        yield part
def rule_asc(n):
    """
    Produce the integer partitions of n as ascending compositions.
    See: http://jeromekelleher.net/generating-integer-partitions.html

    Each partition is yielded as a fresh list.
    """
    parts = list(itertools.repeat(0, n + 1))
    last = 1
    parts[1] = n
    while last != 0:
        smallest = parts[last - 1] + 1
        remainder = parts[last] - 1
        last -= 1
        # Peel off as many copies of ``smallest`` as still fit.
        while smallest <= remainder:
            parts[last] = smallest
            remainder -= smallest
            last += 1
        parts[last] = smallest + remainder
        yield parts[: last + 1]
def group_by(values, equal):
    """
    Split ``values`` into runs of consecutive items.

    An item joins the current run when ``equal(item, first_of_run)`` is
    true; otherwise it starts a new run. Returns a list of lists.
    """
    runs = []
    for item in values:
        if runs and equal(item, runs[-1][0]):
            runs[-1].append(item)
        else:
            runs.append([item])
    return runs
def group_partition(part):
    """Split a partition into runs of consecutive equal parts."""

    def same_size(a, b):
        return a == b

    return group_by(part, same_size)
def merge_tuple(tup1, tup2):
    """
    Combine two sequences into one tuple using ``heapq.merge`` (which
    expects each input to be sorted already).
    """
    merged = heapq.merge(tup1, tup2)
    return tuple(merged)
def isdisjoint(iterable1, iterable2):
    """Return True when the two iterables have no item in common."""
    seen = set(iterable1)
    return seen.isdisjoint(iterable2)