# MIT License
#
# Copyright (c) 2020 University of Oxford
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Utility functions for tsdate. Many of these can be removed when tskit is updated to
a more recent version which has the functionality built-in
"""
import json
import logging
import numba
import numpy as np
import tskit
from numba.types import UniTuple as _unituple # NOQA: N813
import tsdate
from . import provenance
from .approx import _b, _b1r, _f, _f1r, _f1w, _i, _i1r, _i1w
logger = logging.getLogger(__name__)
def reduce_to_contemporaneous(ts):
    """
    Simplify the ts to only the contemporaneous samples, and return the new ts + node map
    """
    all_samples = ts.samples()
    # keep only the samples whose node time is exactly zero
    at_time_zero = ts.nodes_time[all_samples] == 0
    return ts.simplify(
        all_samples[at_time_zero],
        map_nodes=True,
        keep_unary=True,
        filter_populations=False,
        filter_sites=False,
        record_provenance=False,
        filter_individuals=False,
    )
def preprocess_ts(
    tree_sequence,
    *,
    minimum_gap=None,
    remove_telomeres=None,
    delete_intervals=None,
    split_disjoint=None,
    filter_populations=False,
    filter_individuals=False,
    filter_sites=False,
    record_provenance=None,
    **kwargs,
):
    """
    Function to prepare tree sequences for dating by modifying the tree sequence
    to increase the accuracy of dating. This can involve removing data-poor regions,
    removing locally-unary segments of nodes via simplification, and splitting
    discontinuous nodes.

    :param tskit.TreeSequence tree_sequence: The input tree sequence
        to be preprocessed.
    :param float minimum_gap: The minimum gap between sites to remove from the tree
        sequence. Default: ``None`` treated as ``1000000``. Removed regions are recorded
        in the provenance of the resulting tree sequence.
    :param bool remove_telomeres: Should all material before the first site and after the
        last site be removed, regardless of the length. Default: ``None`` treated as
        ``True``
    :param array_like delete_intervals: A list (start, end) pairs describing the
        genomic intervals (gaps) to delete. This is usually left as ``None``
        (the default) in which case ``minimum_gap`` and ``remove_telomeres`` are used
        to determine the gaps to remove, and the calculated intervals are recorded in
        the provenance of the resulting tree sequence.
    :param bool split_disjoint: Run the {func}`split_disjoint_nodes` function
        on the returned tree sequence, breaking any disjoint node into nodes that can
        be dated separately (Default: ``None`` treated as ``True``).
    :param bool filter_populations: parameter passed to the
        {meth}`tskit.TreeSequence.simplify` command. Unlike calling that command
        directly, this defaults to ``False``, such that all populations in the tree
        sequence are kept.
    :param bool filter_individuals: parameter passed to the
        {meth}`tskit.TreeSequence.simplify` command. Unlike calling that command
        directly, this defaults to ``False``, such
        that all individuals in the tree sequence are kept.
    :param bool filter_sites: parameter passed to the
        {meth}`tskit.TreeSequence.simplify` command. Unlike calling that command
        directly, this defaults to ``False``, such
        that all sites in the tree sequence are kept.
    :param bool record_provenance: If ``True``, record details of this call to
        simplify in the returned tree sequence's provenance information
        (Default: ``None`` treated as ``True``).
    :param \\**kwargs: All further keyword arguments are passed to the
        {meth}`tskit.TreeSequence.simplify` command.

    :return: A tree sequence with gaps removed.
    :rtype: tskit.TreeSequence
    """
    logger.info("Beginning preprocessing")
    logger.info(f"Minimum_gap: {minimum_gap} and remove_telomeres: {remove_telomeres}")
    if split_disjoint is None:
        split_disjoint = True
    if record_provenance is None:
        record_provenance = True
    if delete_intervals is not None and (
        minimum_gap is not None or remove_telomeres is not None
    ):
        raise ValueError(
            "Cannot specify both delete_intervals and minimum_gap/remove_telomeres"
        )
    tables = tree_sequence.dump_tables()
    sites = tables.sites.position[:]
    if delete_intervals is None:
        # Infer the intervals to delete from the gap/telomere parameters
        if minimum_gap is None:
            minimum_gap = 1000000
        if remove_telomeres is None:
            remove_telomeres = True
        if tree_sequence.num_sites < 1:
            raise ValueError("Invalid tree sequence: no sites present")
        delete_intervals = []
        if remove_telomeres:
            # Snip everything more than 1bp outside the first/last site
            first_site = sites[0] - 1
            if first_site > 0:
                delete_intervals.append([0, first_site])
                logger.info(
                    "REMOVING TELOMERE: Snip topology "
                    f"from 0 to first site at {first_site}."
                )
            last_site = sites[-1] + 1
            sequence_length = tables.sequence_length
            if last_site < sequence_length:
                delete_intervals.append([last_site, sequence_length])
                logger.info(
                    "REMOVING TELOMERE: Snip topology "
                    f"from {last_site} to end of sequence at {sequence_length}."
                )
        # Remove internal regions between sites that are >= minimum_gap apart
        gaps = sites[1:] - sites[:-1]
        threshold_gaps = np.where(gaps >= minimum_gap)[0]
        for gap in threshold_gaps:
            gap_start = sites[gap] + 1
            gap_end = sites[gap + 1] - 1
            if gap_end > gap_start:
                logger.info(
                    f"Gap Size is {gap_end - gap_start}. Snip topology "
                    f"from {gap_start} to {gap_end}."
                )
                delete_intervals.append([gap_start, gap_end])
        delete_intervals = sorted(delete_intervals, key=lambda x: x[0])
    if len(delete_intervals) > 0:
        tables.delete_intervals(
            delete_intervals, simplify=False, record_provenance=False
        )
    else:
        logger.info("No gaps to remove")
    # Previously this simplify call was duplicated in both branches above;
    # it is identical in each, so run it once here.
    tables.simplify(
        filter_populations=filter_populations,
        filter_individuals=filter_individuals,
        filter_sites=filter_sites,
        record_provenance=False,
        **kwargs,
    )
    if record_provenance:
        provenance.record_provenance(
            tables,
            "preprocess_ts",
            minimum_gap=minimum_gap,
            remove_telomeres=remove_telomeres,
            split_disjoint=split_disjoint,
            filter_populations=filter_populations,
            filter_individuals=filter_individuals,
            filter_sites=filter_sites,
            delete_intervals=delete_intervals,
        )
    ts = tables.tree_sequence()
    if split_disjoint:
        ts = split_disjoint_nodes(ts, record_provenance=False)
    return ts
def nodes_time_unconstrained(tree_sequence):
    """
    Return the unconstrained node times for every node in a tree sequence that has
    been dated using ``tsdate`` with the inside-outside algorithm (these times are
    stored in the node metadata). Will produce an error if the tree sequence does
    not contain this information.

    :raises ValueError: if any non-sample node lacks valid JSON metadata with an
        ``"mn"`` (posterior mean) key.
    """
    nodes_time = tree_sequence.nodes_time.copy()
    # Hoist the sample lookup: calling tree_sequence.samples() inside the loop
    # allocated a fresh array and did a linear membership test per node.
    sample_ids = set(tree_sequence.samples())
    metadata = tree_sequence.tables.nodes.metadata
    metadata_offset = tree_sequence.tables.nodes.metadata_offset
    for index, met in enumerate(tskit.unpack_bytes(metadata, metadata_offset)):
        if index not in sample_ids:
            try:
                nodes_time[index] = json.loads(met.decode())["mn"]
            except (KeyError, json.decoder.JSONDecodeError) as err:
                raise ValueError(
                    "Tree Sequence must be tsdated with the Inside-Outside Method."
                ) from err
    return nodes_time
def sites_time_from_ts(
    tree_sequence, *, unconstrained=True, node_selection="child", min_time=1
):
    """
    Returns an estimated "time" for each site. This is the estimated age of the oldest
    MRCA which possesses a derived variant at that site, and is useful for performing
    (re)inference of a tree sequence. It is calculated from the ages of nodes, with the
    appropriate nodes identified by the position of mutations in the trees.

    If node times in the tree sequence have been estimated by ``tsdate`` using the
    inside-outside algorithm, then as well as a time in the tree sequence, nodes will
    store additional time estimates that have not been explictly constrained by the
    tree topology. By default, this function tries to use these "unconstrained" times,
    although this is likely to fail (with a warning) on tree sequences that have not
    been processed by ``tsdate``: in this case the standard node times can be used by
    setting ``unconstrained=False``.

    The concept of a site time is meaningless for non-variable sites, and
    so the returned time for these sites is ``np.nan`` (note that this is not exactly
    the same as tskit.UNKNOWN_TIME, which marks sites that could have a meaningful time
    but whose time estimate is unknown).

    :param tskit.TreeSequence tree_sequence: The input tree sequence.
    :param bool unconstrained: Use estimated node times which have not been constrained
        by tree topology. If ``True`` (default), this requires a tree sequence which has
        been dated using the ``tsdate`` inside-outside algorithm. If this is not the
        case, specify ``False`` to use the standard tree sequence node times.
    :param str node_selection: Defines how site times are calculated from the age of
        the upper and lower nodes that bound each mutation at the site. Options are
        "child", "parent", "arithmetic" or "geometric", with the following meanings

        * ``'child'`` (default): the site time is the age of the oldest node
          *below* each mutation at the site
        * ``'parent'``: the site time is the age of the oldest node *above* each
          mutation at the site
        * ``'arithmetic'``: the arithmetic mean of the ages of the node above and the
          node below each mutation is calculated; the site time is the oldest
          of these means.
        * ``'geometric'``: the geometric mean of the ages of the node above and the
          node below each mutation is calculated; the site time is the oldest
          of these means

    :param float min_time: A site time of zero implies that no MRCA in the past
        possessed the derived variant, so the variant cannot be used for inferring
        relationships between the samples. To allow all variants to be potentially
        available for inference, if a site time would otherwise be calculated as zero
        (for example, where the ``mutation_age`` parameter is "child" or "geometric"
        and all mutations at a site are associated with leaf nodes), a minimum site
        greater than 0 is recommended. By default this is set to 1, which is generally
        reasonable for times measured in generations or years, although it is also
        fine to set this to a small epsilon value.

    :return: Array of length tree_sequence.num_sites with estimated time of each site
    :rtype: numpy.ndarray(dtype=np.float64)
    """
    if tree_sequence.num_sites < 1:
        raise ValueError("Invalid tree sequence: no sites present")
    if node_selection not in ["arithmetic", "geometric", "child", "parent"]:
        raise ValueError(
            "The node_selection parameter must be "
            "'child', 'parent', 'arithmetic', or 'geometric'"
        )
    if unconstrained:
        try:
            nodes_time = nodes_time_unconstrained(tree_sequence)
        except ValueError as e:
            # BUGFIX: exception args is a tuple; appending a bare string raised
            # a TypeError here instead of augmenting the error message.
            e.args += ("Try calling sites_time_from_ts() with unconstrained=False.",)
            raise
    else:
        nodes_time = tree_sequence.nodes_time
    sites_time = np.full(tree_sequence.num_sites, np.nan)
    for tree in tree_sequence.trees():
        for site in tree.sites():
            for mutation in site.mutations:
                parent_node = tree.parent(mutation.node)
                if node_selection == "child" or parent_node == tskit.NULL:
                    age = nodes_time[mutation.node]
                else:
                    parent_age = nodes_time[parent_node]
                    if node_selection == "parent":
                        age = parent_age
                    elif node_selection == "arithmetic":
                        age = (nodes_time[mutation.node] + parent_age) / 2
                    elif node_selection == "geometric":
                        age = np.sqrt(nodes_time[mutation.node] * parent_age)
                # Keep the oldest age over all mutations at this site
                if np.isnan(sites_time[site.id]) or sites_time[site.id] < age:
                    sites_time[site.id] = age
                # Clamp to min_time so the variant remains usable for inference
                if sites_time[site.id] < min_time:
                    sites_time[site.id] = min_time
    return sites_time
def add_sampledata_times(samples, sites_time):
    """
    Return a tsinfer.SampleData file with estimated times associated with sites.
    Ensures that each site's time is at least as old as the oldest historic sample
    carrying a derived allele at that site.

    :param tsinfer.formats.SampleData samples: A tsinfer SampleData object to
        add site times to. Any historic individuals in this SampleData file are used to
        constrain site times.
    :param array_like sites_time: Estimated time for each site, one entry per site
        in ``samples``.

    :return: A tsinfer.SampleData file
    :rtype: tsinfer.SampleData
    """
    if len(sites_time) != samples.num_sites:
        raise ValueError(
            "sites_time should contain the same number of sites as the SampleData file"
        )
    # Historic (ancient) individuals impose a lower bound on each site's age
    lower_bounds = samples.min_site_times(individuals_only=True)
    constrained_times = np.maximum(sites_time, lower_bounds)
    updated = samples.copy()
    updated.sites_time[:] = constrained_times
    updated.finalise()
    return updated
def mutation_span_array(tree_sequence):
    """
    Extract mutation counts and spans per edge into a two-column array.

    Returns a ``(num_edges, 2)`` float array whose first column is the number of
    mutations on each edge and whose second column is the edge span, along with
    an int32 array mapping each mutation to its edge (``tskit.NULL`` if none).
    """
    counts_and_spans = np.zeros((tree_sequence.num_edges, 2))
    mut_to_edge = np.zeros(tree_sequence.num_mutations, dtype=np.int32)
    for mutation in tree_sequence.mutations():
        edge_id = mutation.edge
        mut_to_edge[mutation.id] = edge_id
        if edge_id != tskit.NULL:
            counts_and_spans[edge_id, 0] += 1
    for edge in tree_sequence.edges():
        counts_and_spans[edge.id, 1] = edge.span
    return counts_and_spans, mut_to_edge
# Some functions for changing tskit metadata
# See https://github.com/tskit-dev/tskit/discussions/2954
# TODO - potentially possible to speed up using numba?
def _reorder_nodes(node_table, order, extra_md_dict):
    """
    Rewrite ``node_table`` in place so its rows appear in the order given by
    ``order`` (an iterable of old row ids), optionally substituting new metadata.

    ``extra_md_dict`` ({rowid: new_byte_metadata}) can be used to pass metadata to
    replace the existing metadata in a row. This works by appending new virtual
    rows holding the replacement metadata and then selecting those rows instead
    of the originals, based on the algorithm in
    https://github.com/tskit-dev/tskit/discussions/2954
    """
    data = [node_table.metadata]
    # Append the replacement metadata as extra byte arrays, then concatenate so
    # the combined buffer can be addressed with a single offset array.
    md_dtype, md_off_dtype = node_table.metadata.dtype, node_table.metadata_offset.dtype
    data += [np.array(bytearray(v), dtype=md_dtype) for v in extra_md_dict.values()]
    md = np.concatenate(data)
    if len(md) == 0:  # Common edge case: no metadata at all -> all-zero offsets
        md_off = np.zeros(len(order) + 1, dtype=md_off_dtype)
    else:
        # Offsets of the appended rows, continuing on from the original buffer
        extra_offsets = np.cumsum([len(d) for d in data], dtype=md_off_dtype)[1:]
        md_off = np.concatenate((node_table.metadata_offset, extra_offsets))
    arr = tskit.unpack_arrays(md, md_off)
    if len(extra_md_dict) > 0:
        # Map each key in extra_md_dict to the virtual row id holding its
        # replacement metadata; rows not in the dict keep their own metadata.
        d = {k: i + node_table.num_rows for i, k in enumerate(extra_md_dict.keys())}
        md, md_off = tskit.pack_arrays([arr[d.get(i, i)] for i in order], md_dtype)
    else:
        md, md_off = tskit.pack_arrays([arr[i] for i in order], md_dtype)
    # Fancy-indexing the column arrays with `order` performs the actual reorder
    node_table.set_columns(
        flags=node_table.flags[order],
        time=node_table.time[order],
        population=node_table.population[order],
        individual=node_table.individual[order],
        metadata=md,
        metadata_offset=md_off,
    )
@numba.njit(_unituple(_i1w, 4)(_i1r, _i1r, _f1r, _f1r, _b1r))
def _split_disjoint_nodes(
    edges_parent, edges_child, edges_left, edges_right, node_excluded
):
    """
    Split disconnected regions of nodes into separate nodes.

    Nodes flagged in ``node_excluded`` are never split. Returns updated
    ``edges_parent`` and ``edges_child`` (relabelled with new node ids), a
    ``nodes_order`` array mapping each new node id back to its original node,
    and ``split_nodes``, the original ids of the nodes that were split.
    """
    assert edges_parent.size == edges_child.size == edges_left.size == edges_right.size
    num_edges = edges_parent.size
    num_nodes = node_excluded.size
    # For each edge, check whether parent/child is separated by a gap from the
    # previous edge involving either parent/child. Label disconnected segments
    # per node by integers starting at zero.
    edges_order = np.argsort(edges_left)
    # TODO: is a sort really needed here?
    edges_segments = np.full((2, num_edges), -1, dtype=np.int32)
    nodes_segments = np.full(num_nodes, -1, dtype=np.int32)
    nodes_right = np.full(node_excluded.size, -np.inf, dtype=np.float64)
    for e in edges_order:
        nodes = edges_parent[e], edges_child[e]
        for i, n in enumerate(nodes):
            if node_excluded[n]:
                continue
            # A new segment starts when this edge begins strictly after the
            # rightmost coordinate seen so far for node n (i.e. after a gap)
            nodes_segments[n] += edges_left[e] > nodes_right[n]
            edges_segments[i, e] = nodes_segments[n]
            nodes_right[n] = max(nodes_right[n], edges_right[e])
    # Create "nodes_segments[i]" supplementary nodes by copying node "i".
    # Store the id of the first supplement for each node in "nodes_map".
    split_nodes = []  # the nodes in the original that were split
    nodes_map = np.full(num_nodes, -1, dtype=np.int32)
    for i, s in enumerate(nodes_segments):
        for j in range(s):
            if j == 0:
                nodes_map[i] = num_nodes
            split_nodes.append(i)
            num_nodes += 1
    split_nodes = np.array(split_nodes, dtype=np.int32)
    # New nodes are appended at the end; each maps back to its source node
    nodes_order = np.arange(num_nodes, dtype=np.int32)
    if len(split_nodes) > 0:
        nodes_order[-len(split_nodes) :] = split_nodes
    # Relabel the nodes on each edge given "nodes_map": segment 0 keeps the
    # original id, later segments get consecutive supplementary ids
    for e in edges_order:
        nodes = edges_parent[e], edges_child[e]
        for i, n in enumerate(nodes):
            if edges_segments[i, e] > 0:
                edges_segments[i, e] += nodes_map[n] - 1
            else:
                edges_segments[i, e] = n
    edges_parent, edges_child = edges_segments[0, ...], edges_segments[1, ...]
    return edges_parent, edges_child, nodes_order, split_nodes
@numba.njit(_i1w(_i1r, _f1r, _i1r, _i1r, _i1r, _f1r, _f1r, _i1r, _i1r))
def _relabel_mutations_node(
    mutations_node,
    mutations_position,
    nodes_order,
    edges_parent,
    edges_child,
    edges_left,
    edges_right,
    insert_index,
    remove_index,
):
    """
    Traverse trees, maintaining a mapping between old and new node IDs in the
    current tree. Update `mutations_node` to reflect new IDs.

    ``edges_parent``/``edges_child`` carry the NEW node ids; ``nodes_order``
    maps new id -> old id, so old ids on mutations can be translated to the
    new id in scope at the mutation's position. Assumes ``mutations_position``
    is sorted to match the left-to-right sweep (TODO confirm: guaranteed by
    tskit mutation ordering).
    """
    assert edges_parent.size == edges_child.size == edges_left.size == edges_right.size
    assert edges_parent.size == insert_index.size == remove_index.size
    assert mutations_position.size == mutations_node.size
    num_nodes = nodes_order.size
    num_edges = edges_parent.size
    num_mutations = mutations_position.size
    # Edge insertion/removal breakpoints in sweep order
    insert_position = edges_left[insert_index]
    remove_position = edges_right[remove_index]
    # NOTE(review): uses the last removal position as the sequence length;
    # assumes edges extend to the end of the mutated region — confirm
    sequence_length = remove_position[-1]
    output = np.full(num_mutations, tskit.NULL, dtype=np.int32)
    # nodes_map[old_id] = new id currently in scope (NULL if not in this tree)
    nodes_map = np.full(num_nodes, tskit.NULL, dtype=np.int32)
    a, b, m = 0, 0, 0  # insertion, removal, and mutation cursors
    left = 0.0
    while left < sequence_length:
        while b < num_edges and remove_position[b] == left:  # edges out
            b += 1
        while a < num_edges and insert_position[a] == left:  # edges in
            e = insert_index[a]
            c, p = edges_child[e], edges_parent[e]
            # Record that these (old) nodes are represented by ids c/p here
            nodes_map[nodes_order[c]] = c
            nodes_map[nodes_order[p]] = p
            a += 1
        # Advance to the next breakpoint (next insertion or removal position)
        right = sequence_length
        if b < num_edges:
            right = min(right, remove_position[b])
        if a < num_edges:
            right = min(right, insert_position[a])
        left = right
        # Relabel all mutations falling before the next breakpoint
        while m < num_mutations and mutations_position[m] < right:
            assert nodes_map[mutations_node[m]] != tskit.NULL
            output[m] = nodes_map[mutations_node[m]]
            m += 1
    return output
def split_disjoint_nodes(ts, *, record_provenance=None):
    """
    For each non-sample node, split regions separated by gaps into distinct
    nodes, returning a tree sequence with potentially duplicated nodes.

    Where there are multiple disconnected regions, the leftmost one is assigned
    the ID of the original node, and the remainder are assigned new node IDs.
    Population, flags, individual, time, and metadata are all copied into the
    new nodes. Nodes that have been split will be flagged with
    ``tsdate.NODE_SPLIT_BY_PREPROCESS``. The metadata of these nodes will also be
    updated with an `unsplit_node_id` field giving the node ID in the input tree
    sequence to which they correspond. If this metadata cannot be set, a warning
    is emitted.

    :param bool record_provenance: If ``True``, record details of this call in the
        returned tree sequence's provenance information (Default: ``None`` treated
        as ``True``).
    """
    metadata_key = "unsplit_node_id"
    record_provenance = True if record_provenance is None else record_provenance
    # Samples are never split, so exclude them from the splitting pass
    is_sample = np.bitwise_and(ts.nodes_flags, tskit.NODE_IS_SAMPLE).astype(bool)
    new_parent, new_child, node_order, split = _split_disjoint_nodes(
        ts.edges_parent, ts.edges_child, ts.edges_left, ts.edges_right, is_sample
    )
    # Relabel mutation nodes so each mutation points at the node copy that is
    # in scope at the mutation's position
    new_mutation_node = _relabel_mutations_node(
        ts.mutations_node,
        ts.sites_position[ts.mutations_site],
        node_order,
        new_parent,
        new_child,
        ts.edges_left,
        ts.edges_right,
        ts.indexes_edge_insertion_order,
        ts.indexes_edge_removal_order,
    )
    tables = ts.dump_tables()
    # Flag the split nodes, then rebuild the (enlarged) node table
    flags = tables.nodes.flags
    flags[split] |= tsdate.NODE_SPLIT_BY_PREPROCESS
    tables.nodes.flags = flags
    extra_metadata = {}
    try:
        for node_id in split:
            md = ts.node(node_id).metadata
            md[metadata_key] = int(node_id)
            encoded = tables.nodes.metadata_schema.validate_and_encode_row(md)
            extra_metadata[node_id] = encoded
    except (TypeError, tskit.MetadataValidationError):
        logger.warning(f"Could not set '{metadata_key}' on node metadata")
    _reorder_nodes(tables.nodes, node_order, extra_metadata)
    # Install the relabelled edges and mutations
    tables.edges.parent = new_parent
    tables.edges.child = new_child
    tables.mutations.node = new_mutation_node
    tables.sort()
    # Sanity check: relabelling must not change any mutation's node time
    assert np.array_equal(
        tables.nodes.time[tables.mutations.node], ts.nodes_time[ts.mutations_node]
    )
    if record_provenance:
        provenance.record_provenance(
            tables,
            "split_disjoint_nodes",
        )
    return tables.tree_sequence()
@numba.njit(_f1w(_f1r, _b1r, _i1r, _i1r, _f, _i))
def _constrain_ages(
    nodes_time, nodes_fixed, edges_parent, edges_child, epsilon, max_iterations
):
    """
    Approximate least squares solution to the positive branch length
    constraint, using the method of alternating projections. Loosely based on
    Dykstra's algorithm, see:

    Dykstra RL, "An algorithm for restricted least squares regression", JASA
    1983

    Returns a copy of ``nodes_time``; the input array is not modified. After
    ``max_iterations`` passes, any remaining violations are resolved by forcing
    each parent to be at least ``epsilon`` older than its child.
    """
    assert nodes_time.size == nodes_fixed.size
    assert edges_parent.size == edges_child.size
    num_edges = edges_parent.size
    nodes_time = nodes_time.copy()
    # Per-edge correction applied on the previous pass (child, parent columns)
    edges_cavity = np.zeros((num_edges, 2))
    for _ in range(max_iterations):  # method of alternating projections
        # Early exit once all branch lengths are strictly positive
        if np.all(nodes_time[edges_parent] - nodes_time[edges_child] > 0):
            return nodes_time
        for e in range(num_edges):
            p, c = edges_parent[e], edges_child[e]
            # Undo this edge's previous correction before reprojecting
            nodes_time[c] -= edges_cavity[e, 0]
            nodes_time[p] -= edges_cavity[e, 1]
            adjustment = nodes_time[c] - nodes_time[p]  # + epsilon
            edges_cavity[e, :] = 0.0
            if adjustment > 0:
                # Edge violates the constraint: split the adjustment between
                # parent and child, unless the child's time is fixed, in which
                # case the parent absorbs the whole adjustment
                assert not nodes_fixed[p]  # TODO: no reason not to support this
                edges_cavity[e, 0] = 0 if nodes_fixed[c] else -adjustment / 2
                edges_cavity[e, 1] = adjustment if nodes_fixed[c] else adjustment / 2
            nodes_time[c] += edges_cavity[e, 0]
            nodes_time[p] += edges_cavity[e, 1]
    for e in range(num_edges):  # force constraint
        p, c = edges_parent[e], edges_child[e]
        if nodes_time[c] >= nodes_time[p]:
            nodes_time[p] = nodes_time[c] + epsilon
    return nodes_time
def constrain_ages(ts, nodes_time, epsilon=1e-6, max_iterations=0):
    """
    Use a hybrid approach to adjust node times such that branch lengths are
    positive. The first pass iteratively solves a constrained least squares
    problem that seeks to find constrained ages as close as possible to
    unconstrained ages. Progress is initially fast but typically becomes quite
    slow, so after a fixed number of iterations the iterative algorithm
    terminates and the constraint is forced.

    :param tskit.TreeSequence ts: The input tree sequence, with arbitrary node
        times.
    :param np.ndarray nodes_time: Unconstrained node ages to inject into the
        tree sequence.
    :param float epsilon: The minimum allowed branch length when forcing
        positive branch lengths.
    :param int max_iterations: The number of iterations of alternating
        projections before forcing positive branch lengths.
    :return np.ndarray: Constrained node ages
    """
    assert nodes_time.size == ts.num_nodes
    assert epsilon >= 0
    assert max_iterations >= 0
    # Sample nodes have fixed (known) times and must not be moved
    node_is_sample = np.bitwise_and(ts.nodes_flags, tskit.NODE_IS_SAMPLE).astype(bool)
    constrained_nodes_time = _constrain_ages(
        nodes_time,
        node_is_sample,
        ts.edges_parent,
        ts.edges_child,
        epsilon,
        max_iterations,
    )
    modified = np.sum(~np.isclose(nodes_time, constrained_nodes_time))
    if modified:
        # BUGFIX: use the module-level logger rather than the root logger
        # (logging.info), consistent with the rest of this module.
        logger.info(f"Modified ages of {modified} nodes to satisfy constraints")
    return constrained_nodes_time
def constrain_mutations(ts, nodes_time, mutations_edge):
    """
    If the mutation is above a root, its age set to the age of the root. If
    the mutation is between two internal nodes, the edge midpoint is used.

    :param tskit.TreeSequence ts: The input tree sequence, with arbitrary node
        times.
    :param np.ndarray nodes_time: Constrained node ages.
    :param np.ndarray mutations_edge: The edge that each mutation falls on
        (``tskit.NULL`` for mutations not on an edge, e.g. above a root).
    :return np.ndarray: Constrained mutation ages
    """
    parent = ts.edges_parent[mutations_edge]
    child = ts.edges_child[mutations_edge]
    parent_time = nodes_time[parent]
    child_time = nodes_time[child]
    assert np.all(parent_time > child_time), "Negative branch lengths"
    # Midpoint of the edge for mutations on an edge
    mutations_time = (child_time + parent_time) / 2
    internal = mutations_edge != tskit.NULL
    constrained_time = np.full(mutations_time.size, tskit.UNKNOWN_TIME)
    constrained_time[internal] = mutations_time[internal]
    # Mutations not on an edge take the time of the node they sit on
    constrained_time[~internal] = nodes_time[ts.mutations_node[~internal]]
    external = np.sum(~internal)
    if external:
        # BUGFIX: use the module-level logger rather than the root logger
        # (logging.info), consistent with the rest of this module.
        logger.info(f"Set ages of {external} nonsegregating mutations to root times.")
    return constrained_time
@numba.njit(_b(_i1r, _f1r, _f1r, _i1r, _i1r, _f, _i))
def _contains_unary_nodes(
    edges_parent,
    edges_left,
    edges_right,
    indexes_insert,
    indexes_remove,
    sequence_length,
    num_nodes,
):
    """
    Return True if any node has exactly one child over some genomic interval.

    Sweeps left-to-right over edge insertions/removals, maintaining a per-node
    child count, and checks nodes whose count changed at each breakpoint.
    """
    assert edges_parent.size == edges_left.size == edges_right.size
    assert indexes_insert.size == indexes_remove.size == edges_parent.size
    num_edges = edges_parent.size
    # Current number of children of each node in the tree at `left`
    nodes_children = np.zeros(num_nodes, dtype=np.int32)
    # Edge insertion/removal breakpoints in sweep order
    position_insert = edges_left[indexes_insert]
    position_remove = edges_right[indexes_remove]
    left = 0.0
    a, b = 0, 0  # insertion and removal cursors
    while a < num_edges or b < num_edges:
        # Only nodes whose child count changed at this breakpoint need checking
        check = set()
        while b < num_edges and position_remove[b] == left:  # edges out
            e = indexes_remove[b]
            p = edges_parent[e]
            nodes_children[p] -= 1
            check.add(p)
            b += 1
        while a < num_edges and position_insert[a] == left:  # edges in
            e = indexes_insert[a]
            p = edges_parent[e]
            nodes_children[p] += 1
            check.add(p)
            a += 1
        for p in check:
            if nodes_children[p] == 1:
                return True
        # Advance to the next breakpoint
        right = sequence_length
        if b < num_edges:
            right = min(right, position_remove[b])
        if a < num_edges:
            right = min(right, position_insert[a])
        left = right
    return False
def contains_unary_nodes(ts):
    """
    Check if any node in the tree sequence is unary over some portion of its span
    """
    # Pull out the columnar arrays once and delegate to the jitted sweep
    edge_arrays = (ts.edges_parent, ts.edges_left, ts.edges_right)
    index_arrays = (ts.indexes_edge_insertion_order, ts.indexes_edge_removal_order)
    return _contains_unary_nodes(
        *edge_arrays, *index_arrays, ts.sequence_length, ts.num_nodes
    )