Source code for tskit.jit.numba

import functools
import os

import numpy as np

import tskit

try:
    import numba

except ImportError:
    raise ImportError(
        "Numba is not installed. Please install it with `pip install numba` "
        "or `conda install numba` to use the tskit.jit.numba module."
    )


FORWARD = 1  #: Direction constant for forward tree traversal
REVERSE = -1  #: Direction constant for reverse tree traversal

# Retrieve these here to avoid lookups in tight loops
NODE_IS_SAMPLE = tskit.NODE_IS_SAMPLE
NULL = tskit.NULL

edge_range_spec = [
    ("start", numba.int32),
    ("stop", numba.int32),
    ("order", numba.int32[:]),
]

parent_index_spec = [
    ("edge_index", numba.int32[:]),
    ("index_range", numba.int32[:, :]),
]


[docs] @numba.experimental.jitclass(edge_range_spec) class EdgeRange: """ Represents a range of edges during tree traversal. This class encapsulates information about a contiguous range of edges that are either being removed or added to step from one tree to another. The ``start`` and ``stop`` indices, when applied to the order array, define the ids of edges to process. Attributes ---------- start : int Starting index of the edge range (inclusive). stop : int Stopping index of the edge range (exclusive). order : numpy.ndarray Array (dtype=np.int32) containing edge IDs in the order they should be processed. The edge ids in this range are order[start:stop]. """ def __init__(self, start, stop, order): self.start = start self.stop = stop self.order = order
[docs] @numba.experimental.jitclass(parent_index_spec) class ParentIndex: """ Simple data container for parent index information. This class provides access to all edges where a given node is the child. Since edges are not sorted by child in the tskit edge table, a custom index (edge_index) is built that sorts edge IDs by child node. `index_range` then contains the [start, stop) range of edges for each child node in `edge_index`. Attributes ---------- edge_index : numpy.ndarray Array (dtype=np.int32) of edge IDs sorted by child node and left coordinate. index_range : numpy.ndarray Array (dtype=np.int32, shape=(num_nodes, 2)) where each row contains the [start, stop) range in edge_index where this node is the child. """ def __init__(self, edge_index, index_range): self.edge_index = edge_index self.index_range = index_range
[docs] class TreeIndex: """ Traverse trees in a numba compatible tree sequence. This class provides efficient forward and backward iteration through the trees in a tree sequence. It provides the tree interval, edge changes to create the current tree, along with its sites and mutations. A full pass over the trees using repeated `next` or `prev` requires O(E + M + S) time complexity. It should not be instantiated directly, but is returned by the `tree_index` method of `NumbaTreeSequence`. Attributes ---------- ts : NumbaTreeSequence Reference to the tree sequence being traversed. index : int Current tree index. -1 indicates no current tree (null state). direction : int Traversal direction: tskit.FORWARD or tskit.REVERSE. tskit.NULL if uninitialised. interval : tuple Genomic interval (left, right) covered by the current tree. in_range : EdgeRange Edges being added to form this current tree, relative to the last state out_range : EdgeRange Edges being removed to form this current tree, relative to the last state site_range : tuple Range of sites in the current tree (start, stop). mutation_range : tuple Range of mutations in the current tree (start, stop). Example -------- >>> tree_index = numba_ts.tree_index() >>> num_edges = 0 >>> while tree_index.next(): num_edges += (tree_index.in_range.stop - tree_index.in_range.start) num_edges -= (tree_index.out_range.stop - tree_index.out_range.start) print(f"Tree {tree_index.index}: {num_edges} edges") """ def __init__(self, ts): self.ts = ts self.index = -1 self.direction = NULL self.interval = (0, 0) self.in_range = EdgeRange(0, 0, np.zeros(0, dtype=np.int32)) self.out_range = EdgeRange(0, 0, np.zeros(0, dtype=np.int32)) self.site_range = (0, 0) self.mutation_range = (0, 0)
[docs] def set_null(self): """ Reset the tree index to null state. """ self.index = -1 self.interval = (0, 0) self.site_range = (0, 0) self.mutation_range = (0, 0)
[docs] def next(self): # noqa: A003 """ Move to the next tree in forward direction. Updates the tree index to the next tree in the sequence, computing the edges that need to be added and removed to transform from the previous tree to the current tree. On the first call, this initializes the iterator and moves to tree 0. :return: True if successfully moved to next tree, False if the end of the tree sequence is reached. When False is returned, the iterator is in null state (index=-1). :rtype: bool """ M = self.ts.num_edges NS = self.ts.num_sites NM = self.ts.num_mutations breakpoints = self.ts.breakpoints left_coords = self.ts.edges_left left_order = self.ts.indexes_edge_insertion_order right_coords = self.ts.edges_right right_order = self.ts.indexes_edge_removal_order sites_position = self.ts.sites_position mutations_site = self.ts.mutations_site if self.index == -1: self.interval = (self.interval[0], 0) self.out_range.stop = 0 self.in_range.stop = 0 self.direction = FORWARD self.site_range = (0, 0) self.mutation_range = (0, 0) if self.direction == FORWARD: left_current_index = self.in_range.stop right_current_index = self.out_range.stop else: left_current_index = self.out_range.stop + 1 right_current_index = self.in_range.stop + 1 left = self.interval[1] j = right_current_index self.out_range.start = j while j < M and right_coords[right_order[j]] == left: j += 1 self.out_range.stop = j self.out_range.order = right_order j = left_current_index self.in_range.start = j while j < M and left_coords[left_order[j]] == left: j += 1 self.in_range.stop = j self.in_range.order = left_order self.direction = FORWARD self.index += 1 if self.index == self.ts.num_trees: self.set_null() else: right = breakpoints[self.index + 1] self.interval = (left, right) # Find sites in current tree interval [left, right) old_site_left, old_site_right = self.site_range j = old_site_right while j < NS and sites_position[j] < right: j += 1 self.site_range = (old_site_right, j) # Find mutations for sites in this interval old_mutation_left, old_mutation_right = self.mutation_range k = old_mutation_right while k < NM and mutations_site[k] < j: k += 1 self.mutation_range = (old_mutation_right, k) return self.index != -1
[docs] def prev(self): """ Move to the previous tree in reverse direction. Updates the tree index to the previous tree in the sequence, computing the edges that need to be added and removed to transform from the next tree to the current tree. On the first call, this initializes the iterator and moves to the most rightward tree. :return: True if successfully moved to previous tree, False if the beginning of the tree sequence is reached. When False is returned, the iterator is in null state (index=-1). :rtype: bool """ M = self.ts.num_edges NS = self.ts.num_sites NM = self.ts.num_mutations breakpoints = self.ts.breakpoints right_coords = self.ts.edges_right right_order = self.ts.indexes_edge_removal_order left_coords = self.ts.edges_left left_order = self.ts.indexes_edge_insertion_order sites_position = self.ts.sites_position mutations_site = self.ts.mutations_site if self.index == -1: self.index = self.ts.num_trees self.interval = (self.ts.sequence_length, self.interval[1]) self.in_range.stop = M - 1 self.out_range.stop = M - 1 self.direction = REVERSE self.site_range = (NS, NS) self.mutation_range = (NM, NM) if self.direction == REVERSE: left_current_index = self.out_range.stop right_current_index = self.in_range.stop else: left_current_index = self.in_range.stop - 1 right_current_index = self.out_range.stop - 1 right = self.interval[0] j = left_current_index self.out_range.start = j while j >= 0 and left_coords[left_order[j]] == right: j -= 1 self.out_range.stop = j self.out_range.order = left_order j = right_current_index self.in_range.start = j while j >= 0 and right_coords[right_order[j]] == right: j -= 1 self.in_range.stop = j self.in_range.order = right_order self.direction = REVERSE self.index -= 1 if self.index == -1: self.set_null() else: left = breakpoints[self.index] self.interval = (left, right) # Find sites in current tree interval [left, right) going backward old_site_left, old_site_right = self.site_range j = old_site_left - 1 while j >= 0 and sites_position[j] >= left: j -= 1 self.site_range = (j + 1, old_site_left) # Find mutations for sites in this interval going backward old_mutation_left, old_mutation_right = self.mutation_range k = old_mutation_left - 1 while k >= 0 and mutations_site[k] >= self.site_range[0]: k -= 1 self.mutation_range = (k + 1, old_mutation_left) return self.index != -1
[docs] class NumbaTreeSequence: """ A Numba-compatible representation of a tree sequence. This class provides access a tree sequence class that can be used from within Numba "njit" compiled functions. :meth:`jitwrap` should be used to JIT compile this class from a :class:`tskit.TreeSequence` object, before it is passed to a Numba function. Attributes ---------- num_trees : int Number of trees in the tree sequence. num_nodes : int Number of nodes in the tree sequence. num_samples : int Number of samples in the tree sequence. num_edges : int Number of edges in the tree sequence. num_sites : int Number of sites in the tree sequence. num_mutations : int Number of mutations in the tree sequence. sequence_length : float Total sequence length of the tree sequence. edges_left : numpy.ndarray Array (dtype=np.float64) of left coordinates of edges. edges_right : numpy.ndarray Array (dtype=np.float64) of right coordinates of edges. edges_parent : numpy.ndarray Array (dtype=np.int32) of parent node IDs for each edge. edges_child : numpy.ndarray Array (dtype=np.int32) of child node IDs for each edge. nodes_time : numpy.ndarray Array (dtype=np.float64) of time values for each node. nodes_flags : numpy.ndarray Array (dtype=np.uint32) of flag values for each node. nodes_population : numpy.ndarray Array (dtype=np.int32) of population IDs for each node. nodes_individual : numpy.ndarray Array (dtype=np.int32) of individual IDs for each node. individuals_flags : numpy.ndarray Array (dtype=np.uint32) of flag values for each individual. sites_position : numpy.ndarray Array (dtype=np.float64) of positions of sites along the sequence. mutations_site : numpy.ndarray Array (dtype=np.int32) of site IDs for each mutation. mutations_node : numpy.ndarray Array (dtype=np.int32) of node IDs for each mutation. mutations_parent : numpy.ndarray Array (dtype=np.int32) of parent mutation IDs. mutations_time : numpy.ndarray Array (dtype=np.float64) of time values for each mutation. breakpoints : numpy.ndarray Array (dtype=np.float64) of genomic positions where trees change. indexes_edge_insertion_order : numpy.ndarray Array (dtype=np.int32) specifying the order in which edges are inserted during tree building. indexes_edge_removal_order : numpy.ndarray Array (dtype=np.int32) specifying the order in which edges are removed during tree building. """ def __init__( self, num_trees, num_nodes, num_samples, num_edges, num_sites, num_mutations, sequence_length, edges_left, edges_right, indexes_edge_insertion_order, indexes_edge_removal_order, individuals_flags, nodes_time, nodes_flags, nodes_population, nodes_individual, edges_parent, edges_child, sites_position, sites_ancestral_state, mutations_site, mutations_node, mutations_parent, mutations_time, mutations_derived_state, mutations_inherited_state, breakpoints, max_ancestral_length, max_derived_length, max_inherited_length, ): self.num_trees = num_trees self.num_nodes = num_nodes self.num_samples = num_samples self.num_edges = num_edges self.num_sites = num_sites self.num_mutations = num_mutations self.sequence_length = sequence_length self.edges_left = edges_left self.edges_right = edges_right self.indexes_edge_insertion_order = indexes_edge_insertion_order self.indexes_edge_removal_order = indexes_edge_removal_order self.individuals_flags = individuals_flags self.nodes_time = nodes_time self.nodes_flags = nodes_flags self.nodes_population = nodes_population self.nodes_individual = nodes_individual self.edges_parent = edges_parent self.edges_child = edges_child self.sites_position = sites_position self.sites_ancestral_state = sites_ancestral_state self.mutations_site = mutations_site self.mutations_node = mutations_node self.mutations_parent = mutations_parent self.mutations_time = mutations_time self.mutations_derived_state = mutations_derived_state self.mutations_inherited_state = mutations_inherited_state self.breakpoints = breakpoints self.max_ancestral_length = max_ancestral_length self.max_derived_length = max_derived_length self.max_inherited_length = max_inherited_length
[docs] def tree_index(self): """ Create a :class:`TreeIndex` for traversing this tree sequence. :return: A new tree index initialized to the null tree. Use next() or prev() to move to an actual tree. :rtype: TreeIndex """ # This method will be overriden when the concrete JIT class TreeIndex # is defined in `jitwrap`. return TreeIndex(self) # pragma: no cover
[docs] def child_index(self): """ Create child index array for finding child edges of nodes. This operation requires a linear pass over the edge table and therefore has a time complexity of O(E). :return: A numpy array (dtype=np.int32, shape=(num_nodes, 2)) where each row contains the [start, stop) range of edges where this node is the parent. :rtype: numpy.ndarray """ child_range = np.full((self.num_nodes, 2), -1, dtype=np.int32) edges_parent = self.edges_parent if self.num_edges == 0: return child_range # Find ranges in tskit edge ordering last_parent = -1 for edge_id in range(self.num_edges): parent = edges_parent[edge_id] if parent != last_parent: child_range[parent, 0] = edge_id if last_parent != -1: child_range[last_parent, 1] = edge_id last_parent = parent if last_parent != -1: child_range[last_parent, 1] = self.num_edges return child_range
[docs] def parent_index(self): """ Create a :class:`ParentIndex` for finding parent edges of nodes. Edges within each child's group are not guaranteed to be in any specific order. This operation uses a two-pass algorithm with O(N + E) time complexity and O(N) auxiliary space. :return: A new parent index container that can be used to efficiently find all edges where a given node is the child. :rtype: ParentIndex """ num_nodes = self.num_nodes num_edges = self.num_edges edges_child = self.edges_child child_counts = np.zeros(num_nodes, dtype=np.int32) edge_index = np.zeros(num_edges, dtype=np.int32) index_range = np.zeros((num_nodes, 2), dtype=np.int32) if num_edges == 0: return ParentIndex(edge_index, index_range) # Count how many children each node has for child_node in edges_child: child_counts[child_node] += 1 # From the counts build the index ranges, we set both the start and the # end index to the start - this lets us use the end index as a tracker # for where we should insert the next edge for that node - when all # edges are done these values will be the correct end values! current_start = 0 for i in range(num_nodes): index_range[i, :] = current_start current_start += child_counts[i] # Now go over the edges, inserting them at the index pointed to # by the node's current end value, then increment. for edge_id in range(num_edges): child = edges_child[edge_id] pos = index_range[child, 1] edge_index[pos] = edge_id index_range[child, 1] += 1 return ParentIndex(edge_index, index_range)
# We cache these classes to avoid repeated JIT compilation @functools.lru_cache(None) def _jitwrap(max_ancestral_length, max_derived_length, max_inherited_length): # We have a circular dependency in JIT compilation between NumbaTreeSequence # and NumbaTreeIndex so we used a deferred type to break it tree_sequence_type = numba.deferred_type() # We run this code on CI with this env var set so we can get coverage # of the jitted functions. EdgeRange doesn't have a class_type # in this case, so we skip the spec entirely. if os.environ.get("NUMBA_DISABLE_JIT") == "1": tree_index_spec = [] else: tree_index_spec = [ ("ts", tree_sequence_type), ("index", numba.int32), ("direction", numba.int32), ("interval", numba.types.UniTuple(numba.float64, 2)), ("in_range", EdgeRange.class_type.instance_type), ("out_range", EdgeRange.class_type.instance_type), ("site_range", numba.types.UniTuple(numba.int32, 2)), ("mutation_range", numba.types.UniTuple(numba.int32, 2)), ] JittedTreeIndex = numba.experimental.jitclass(tree_index_spec)(TreeIndex) tree_sequence_spec = [ ("num_trees", numba.int32), ("num_nodes", numba.int32), ("num_samples", numba.int32), ("num_edges", numba.int32), ("num_sites", numba.int32), ("num_mutations", numba.int32), ("sequence_length", numba.float64), ("edges_left", numba.float64[:]), ("edges_right", numba.float64[:]), ("indexes_edge_insertion_order", numba.int32[:]), ("indexes_edge_removal_order", numba.int32[:]), ("individuals_flags", numba.uint32[:]), ("nodes_time", numba.float64[:]), ("nodes_flags", numba.uint32[:]), ("nodes_population", numba.int32[:]), ("nodes_individual", numba.int32[:]), ("edges_parent", numba.int32[:]), ("edges_child", numba.int32[:]), ("sites_position", numba.float64[:]), ("sites_ancestral_state", numba.types.UnicodeCharSeq(max_ancestral_length)[:]), ("mutations_site", numba.int32[:]), ("mutations_node", numba.int32[:]), ("mutations_parent", numba.int32[:]), ("mutations_time", numba.float64[:]), ("mutations_derived_state", numba.types.UnicodeCharSeq(max_derived_length)[:]), ( "mutations_inherited_state", numba.types.UnicodeCharSeq(max_inherited_length)[:], ), ("breakpoints", numba.float64[:]), ("max_ancestral_length", numba.int32), ("max_derived_length", numba.int32), ("max_inherited_length", numba.int32), ] # The `tree_index` method on NumbaTreeSequence uses NumbaTreeIndex # which is the uncompiled version of the class. The compiled version isn't # known till now, so replace the method with this definition. class _NumbaTreeSequence(NumbaTreeSequence): def tree_index(self): return JittedTreeIndex(self) JittedTreeSequence = numba.experimental.jitclass(tree_sequence_spec)( _NumbaTreeSequence ) # Now both classes are setup we can resolve the deferred type if os.environ.get("NUMBA_DISABLE_JIT") != "1": tree_sequence_type.define(JittedTreeSequence.class_type.instance_type) return JittedTreeSequence
[docs] def jitwrap(ts): """ Convert a TreeSequence to a Numba-compatible format. Creates a NumbaTreeSequence object that can be used within Numba-compiled functions. :param tskit.TreeSequence ts: The tree sequence to convert. :return: A Numba-compatible representation of the input tree sequence. Contains all necessary data arrays and metadata for tree traversal. :rtype: NumbaTreeSequence """ max_ancestral_length = max(1, max(map(len, ts.sites_ancestral_state), default=1)) max_derived_length = max(1, max(map(len, ts.mutations_derived_state), default=1)) max_inherited_length = max( 1, max(map(len, ts.mutations_inherited_state), default=1) ) JittedTreeSequence = _jitwrap( max_ancestral_length, max_derived_length, max_inherited_length ) # Create the tree sequence instance numba_ts = JittedTreeSequence( num_trees=ts.num_trees, num_nodes=ts.num_nodes, num_samples=ts.num_samples, num_edges=ts.num_edges, num_sites=ts.num_sites, num_mutations=ts.num_mutations, sequence_length=ts.sequence_length, edges_left=ts.edges_left, edges_right=ts.edges_right, indexes_edge_insertion_order=ts.indexes_edge_insertion_order, indexes_edge_removal_order=ts.indexes_edge_removal_order, individuals_flags=ts.individuals_flags, nodes_time=ts.nodes_time, nodes_flags=ts.nodes_flags, nodes_population=ts.nodes_population, nodes_individual=ts.nodes_individual, edges_parent=ts.edges_parent, edges_child=ts.edges_child, sites_position=ts.sites_position, sites_ancestral_state=ts.sites_ancestral_state.astype( f"U{max_ancestral_length}" ), mutations_site=ts.mutations_site, mutations_node=ts.mutations_node, mutations_parent=ts.mutations_parent, mutations_time=ts.mutations_time, mutations_derived_state=ts.mutations_derived_state.astype( f"U{max_derived_length}" ), mutations_inherited_state=ts.mutations_inherited_state.astype( f"U{max_inherited_length}" ), breakpoints=ts.breakpoints(as_array=True), max_ancestral_length=max_ancestral_length, max_derived_length=max_derived_length, max_inherited_length=max_inherited_length, ) return numba_ts