Numba Integration
The tskit.jit.numba
module provides classes for working with tree sequences
from Numba jit-compiled Python code. Such code can run
up to hundreds of times faster than normal Python, yet avoids the difficulties of writing
C or other low-level code.
Note
Numba is not a direct dependency of tskit, so will not be available unless installed:
pip install numba
or
conda install numba
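Because Numba is optional, code that must also run without it may want to test for it before importing tskit.jit.numba. The following is a minimal sketch using only the standard library; the HAVE_NUMBA name is just an illustration and is not part of tskit.

```python
import importlib.util

# find_spec returns None when the package cannot be found,
# without actually importing it.
HAVE_NUMBA = importlib.util.find_spec("numba") is not None

if HAVE_NUMBA:
    print("numba is available; tskit.jit.numba can be imported")
else:
    print("numba is not installed; install it with pip or conda first")
```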
Overview
The numba integration provides a tskit.TreeSequence wrapper class, NumbaTreeSequence.
This class can be used directly in numba.njit compiled functions, and provides several efficient
methods for tree traversal:
- tree_index(): For efficient iteration through the trees in the sequence.
- parent_index(): For efficient access to parent edge information, to traverse upwards through the ARG.
- child_index(): For efficient access to child edge information, to traverse downwards through the ARG.
These methods are optimised to work within Numba’s @njit
decorated functions,
allowing you to write high-performance tree sequence analysis code in a plain Python style.
Basic Usage
The tskit.jit.numba module is not loaded by a plain import tskit, so it must be imported explicitly:
import numpy as np
import tskit
import tskit.jit.numba as tskit_numba
Normal third-party classes such as tskit.TreeSequence can’t be used in numba.njit compiled functions, so the tskit.TreeSequence must be wrapped in a NumbaTreeSequence by jitwrap(). This must be done outside njit code:
import msprime
ts = msprime.sim_ancestry(
    samples=50000,
    sequence_length=100000,
    recombination_rate=0.1,
    random_seed=42
)
numba_ts = tskit_numba.jitwrap(ts)
print(type(numba_ts))
<class 'numba.experimental.jitclass.boxing._NumbaTreeSequence'>
Tree Iteration
Tree iteration can be performed in numba.njit compiled functions using the TreeIndex class. This class provides next() and prev() methods for forward and backward iteration through the trees in a tree sequence. Its in_range and out_range attributes provide the edges that must be added or removed to form the current tree from the previous tree, along with the current tree interval and its sites and mutations through site_range and mutation_range.
A TreeIndex instance can be obtained from a NumbaTreeSequence using the tree_index() method. Its initial state is a “null” tree outside the range of the tree sequence; the first call to next() or prev() moves to the first or last tree of the tree sequence respectively. After that, the in_range and out_range attributes provide the edges that must be added or removed to form the current tree from the previous tree. For example, tree_index.in_range.order[in_range.start:in_range.stop] gives the edge IDs that are new in the current tree, and tree_index.out_range.order[out_range.start:out_range.stop] gives the edge IDs that are no longer present in the current tree. tree_index.site_range and tree_index.mutation_range give the indexes into the tree sequence's site and mutation arrays.
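To make the role of in_range and out_range concrete, the following pure-Python sketch reproduces the idea on a hand-written edge table. It is an illustration only, not the TreeIndex implementation: the toy edges, the edge_diffs helper and the simple tie-breaking of the two sort orders are all assumptions made for this example.

```python
# Toy edge table (hypothetical data): two trees over [0, 5) and [5, 10).
edges_left = [0.0, 0.0, 5.0, 0.0, 0.0, 5.0, 5.0]
edges_right = [10.0, 5.0, 10.0, 5.0, 5.0, 10.0, 10.0]
sequence_length = 10.0


def edge_diffs(edges_left, edges_right, sequence_length):
    """Yield (interval, edges_out, edges_in) for each tree, left to right."""
    num_edges = len(edges_left)
    # Edge IDs ordered by left coordinate (insertion) and by right
    # coordinate (removal).
    insertion = sorted(range(num_edges), key=lambda e: edges_left[e])
    removal = sorted(range(num_edges), key=lambda e: edges_right[e])
    j = k = 0
    left = 0.0
    while j < num_edges or left < sequence_length:
        # Edges ending at this breakpoint leave the tree.
        out_start = k
        while k < num_edges and edges_right[removal[k]] == left:
            k += 1
        # Edges starting at this breakpoint enter the tree.
        in_start = j
        while j < num_edges and edges_left[insertion[j]] == left:
            j += 1
        # The tree extends to the next coordinate where anything changes.
        right = sequence_length
        if j < num_edges:
            right = min(right, edges_left[insertion[j]])
        if k < num_edges:
            right = min(right, edges_right[removal[k]])
        yield (left, right), removal[out_start:k], insertion[in_start:j]
        left = right


current_num_edges = 0
counts = []
for interval, edges_out, edges_in in edge_diffs(
    edges_left, edges_right, sequence_length
):
    current_num_edges += len(edges_in) - len(edges_out)
    counts.append(current_num_edges)
    print(interval, "out:", edges_out, "in:", edges_in)
```

The slices removal[out_start:k] and insertion[in_start:j] play the same role here as order[start:stop] on out_range and in_range.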
As a simple example we can calculate the number of edges in each tree in a tree sequence:
import numba
@numba.njit
def edges_per_tree(numba_ts):
    tree_index = numba_ts.tree_index()
    current_num_edges = 0
    num_edges = []

    # Move forward through the trees
    while tree_index.next():
        # Access current tree information
        in_range = tree_index.in_range
        out_range = tree_index.out_range

        current_num_edges -= (out_range.stop - out_range.start)
        current_num_edges += (in_range.stop - in_range.start)
        num_edges.append(current_num_edges)
    return num_edges
import time
t = time.time()
jit_num_edges = edges_per_tree(numba_ts)
print(f"JIT Time taken: {time.time() - t:.4f} seconds")
JIT Time taken: 0.0320 seconds
Doing the same thing with the normal tskit
API would be much slower:
t = time.time()
python_num_edges = []
for tree in ts.trees():
    python_num_edges.append(tree.num_edges)
print(f"Normal Time taken: {time.time() - t:.4f} seconds")
assert jit_num_edges == python_num_edges, "JIT and normal results do not match!"
Normal Time taken: 0.3796 seconds
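Numba compiles an njit function the first time it is called with a given set of argument types, so timing the very first call to a function can include compilation time. A small helper along the following lines makes a warm-up call first and then reports the best of several runs. This is only a sketch: time_call is a hypothetical name, not part of tskit or Numba.

```python
import time


def time_call(func, *args, repeats=3):
    """Return the best wall-clock time of `repeats` calls to func(*args)."""
    func(*args)  # warm-up call; for an njit function this triggers compilation
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args)
        best = min(best, time.perf_counter() - start)
    return best


# Works with any callable; a built-in function is used here so that the
# example runs without Numba installed.
elapsed = time_call(sum, range(100_000))
print(f"Best of 3: {elapsed:.6f} seconds")
```

With the function defined above, this could be called as time_call(edges_per_tree, numba_ts).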
Example - diversity calculation
As a more interesting example we can calculate genetic diversity (also known as pi).
For this example we’ll be calculating based on the distance in the tree between samples.
(mode="branch"
in the tskit API.)
This example also shows the style of Python code that gives best performance under numba
JIT compilation - using simple loops and fixed-size arrays with minimal object attribute access.
@numba.njit
def diversity(numba_ts):
    # Cache arrays to avoid repeated attribute access in
    # tight loops
    edge_child = numba_ts.edges_child
    edge_parent = numba_ts.edges_parent
    node_times = numba_ts.nodes_time
    node_flags = numba_ts.nodes_flags

    if numba_ts.num_samples <= 1:
        return 0.0

    parent = np.full(numba_ts.num_nodes, -1, dtype=np.int32)
    branch_length = np.zeros(numba_ts.num_nodes, dtype=np.float64)
    state = np.zeros(numba_ts.num_nodes, dtype=np.int32)
    summary = np.zeros(numba_ts.num_nodes, dtype=np.float64)

    n = float(numba_ts.num_samples)
    two_over_denom = 2.0 / (n * (n - 1.0))
    sample_summary = 2.0 / n

    # Retrieve this constant outside the loop
    # to avoid repeated attribute access
    NODE_IS_SAMPLE = tskit.NODE_IS_SAMPLE
    # Find the sample nodes and initialize their states
    for node in range(numba_ts.num_nodes):
        if node_flags[node] & NODE_IS_SAMPLE:
            state[node] = 1.0
            summary[node] = sample_summary

    result = 0.0
    running_sum = 0.0
    tree_index = numba_ts.tree_index()

    # Now iterate through the trees
    while tree_index.next():
        # Process the outgoing edges
        for j in range(tree_index.out_range.start, tree_index.out_range.stop):
            h = tree_index.out_range.order[j]
            child = edge_child[h]
            child_parent = edge_parent[h]

            running_sum -= branch_length[child] * summary[child]
            parent[child] = -1
            branch_length[child] = 0.0

            u = child_parent
            parent_u = parent[u]
            while u != -1:
                running_sum -= branch_length[u] * summary[u]
                state[u] -= state[child]
                summary[u] = state[u] * (n - state[u]) * two_over_denom
                running_sum += branch_length[u] * summary[u]
                u = parent_u
                if u != -1:
                    parent_u = parent[u]

        # Process the incoming edges
        for j in range(tree_index.in_range.start, tree_index.in_range.stop):
            h = tree_index.in_range.order[j]
            child = edge_child[h]
            child_parent = edge_parent[h]

            parent[child] = child_parent
            branch_length[child] = node_times[child_parent] - node_times[child]
            running_sum += branch_length[child] * summary[child]

            u = child_parent
            parent_u = parent[u]
            while u != -1:
                running_sum -= branch_length[u] * summary[u]
                state[u] += state[child]
                summary[u] = state[u] * (n - state[u]) * two_over_denom
                running_sum += branch_length[u] * summary[u]
                u = parent_u
                if u != -1:
                    parent_u = parent[u]

        result += running_sum * (
            tree_index.interval[1] - tree_index.interval[0]
        )

    return result / numba_ts.sequence_length
t = time.time()
d = diversity(numba_ts)
print("Diversity:", d)
print("Time taken:", time.time() - t)
Diversity: 3.992633119145986
Time taken: 0.27684879302978516
As this code is written for this specific diversity calculation it is even faster than the tskit C implementation, called here from Python:
t = time.time()
d_tskit = ts.diversity(mode="branch")
print("Diversity (tskit):", d_tskit)
print("Time taken:", time.time() - t)
Diversity (tskit): 3.992633119146012
Time taken: 0.5163848400115967
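The per-branch weight used in diversity() above, state * (n - state) * 2 / (n * (n - 1)), can be checked by hand on a single small tree: summing branch length times this weight should equal the mean pairwise distance between samples. The sketch below does this in plain Python on a hypothetical five-node tree; the tree, the helper names and the dictionaries are assumptions for illustration and do not use tskit.

```python
from itertools import combinations

# Hypothetical tree: node 3 is the parent of samples 0 and 1,
# node 4 (the root) is the parent of sample 2 and node 3.
parent = {0: 3, 1: 3, 2: 4, 3: 4}
time = {0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0, 4: 2.0}
samples = [0, 1, 2]
n = len(samples)


def ancestors(u):
    """Return the path from u up to the root, starting with u itself."""
    path = [u]
    while path[-1] in parent:
        path.append(parent[path[-1]])
    return path


def pairwise_distance(a, b):
    """Total branch length separating samples a and b."""
    on_path_b = set(ancestors(b))
    mrca = next(v for v in ancestors(a) if v in on_path_b)
    return (time[mrca] - time[a]) + (time[mrca] - time[b])


# Direct definition: mean distance over all pairs of samples.
pairs = list(combinations(samples, 2))
mean_pairwise = sum(pairwise_distance(a, b) for a, b in pairs) / len(pairs)

# Branch-weighted form: count the samples below each node, then weight
# each branch by the fraction of sample pairs that it separates.
below = {u: 0 for u in time}
for s in samples:
    for v in ancestors(s):
        below[v] += 1

weighted = 0.0
for child, par in parent.items():
    k = below[child]
    weighted += (time[par] - time[child]) * 2.0 * k * (n - k) / (n * (n - 1))

print(mean_pairwise, weighted)
```

Both quantities come out as 10/3 for this tree. The Numba function maintains the same weighted sum incrementally as edges are removed and inserted, rather than recomputing it for each tree.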
ARG Traversal
Beyond iterating through trees, you may need to traverse the ARG vertically. The child_index() and parent_index() methods provide efficient access to parent-child relationships in the edge table within numba.njit functions.
The child_index() method returns an array that allows you to efficiently find all edges where a given node is the parent. Since edges are already sorted by parent in the tskit data model, this is implemented using simple range indexing. For any node u, the returned array child_index[u] gives the start and stop indices in the tskit edge table where node u is the parent. The index is recalculated on every call to child_index(), so call it once and reuse the result.
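The range-indexing idea can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not tskit's implementation: build_child_index and the toy edge table are hypothetical, and the empty range [0, 0) used for nodes without child edges is a choice made for this sketch.

```python
def build_child_index(edges_parent, num_nodes):
    """Return a [start, stop) row per node for a parent-sorted edge table."""
    # Nodes with no child edges keep the empty range [0, 0) in this sketch.
    child_range = [[0, 0] for _ in range(num_nodes)]
    for e, p in enumerate(edges_parent):
        if e == 0 or edges_parent[e - 1] != p:
            child_range[p][0] = e  # first edge with this parent
        child_range[p][1] = e + 1  # one past the last edge seen so far
    return child_range


# Toy edge table in which all edges with the same parent are adjacent.
edges_parent = [3, 3, 3, 4, 4, 5, 5]
edges_child = [0, 1, 2, 2, 3, 1, 3]

child_index = build_child_index(edges_parent, num_nodes=6)
start, stop = child_index[4]
print(edges_child[start:stop])  # children of node 4
```

Because the edges for a given parent are contiguous, a single pass over the edge table is enough, and each lookup afterwards is a plain slice.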
The parent_index() method creates a ParentIndex that allows you to efficiently find all edges where a given node is the child. Since edges are not sorted by child in the edge table, the returned class contains a custom index that sorts edge IDs by child node (and then by left coordinate). For any node u, parent_index.index_range[u] gives a tuple of the start and stop indices in the parent_index.edge_index array, and parent_index.edge_index[start:stop] gives the actual tskit edge IDs.
Both can be obtained from a NumbaTreeSequence:
# Get the indexes
child_index = numba_ts.child_index()
parent_index = numba_ts.parent_index()
# Example: find all left coordinates of edges where node 5 is the parent
start, stop = child_index[5]
left_coords = numba_ts.edges_left[start:stop]
print(left_coords)
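# Example: find all right coordinates of edges where node 3 is the child
start, stop = parent_index.index_range[3]
# index_range indexes into edge_index, which holds the actual edge IDs
edge_ids = parent_index.edge_index[start:stop]
right_coords = numba_ts.edges_right[edge_ids]
print(right_coords)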
[]
[100000.]
These indexes enable efficient algorithms that need to traverse parent-child relationships in the ARG, such as computing descendant sets, ancestral paths, or subtree properties.
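The extra level of indirection in the parent index (index_range points into edge_index, which in turn holds edge IDs) can be seen in a plain-Python sketch. As before this is illustrative only: build_parent_index and the toy edge table are hypothetical, and the sort by child and then left coordinate follows the description above rather than the library source.

```python
def build_parent_index(edges_child, edges_left, num_nodes):
    """Return (edge_index, index_range) grouping edge IDs by child node."""
    # Edge IDs sorted by child node, then by left coordinate.
    edge_index = sorted(
        range(len(edges_child)), key=lambda e: (edges_child[e], edges_left[e])
    )
    # Nodes that are never a child keep the empty range [0, 0) in this sketch.
    index_range = [[0, 0] for _ in range(num_nodes)]
    for i, e in enumerate(edge_index):
        c = edges_child[e]
        if i == 0 or edges_child[edge_index[i - 1]] != c:
            index_range[c][0] = i  # first position for this child
        index_range[c][1] = i + 1  # one past the last position seen so far
    return edge_index, index_range


# Toy edge table (hypothetical), sorted by parent as tskit requires.
edges_parent = [3, 3, 3, 4, 4, 5, 5]
edges_child = [0, 1, 2, 2, 3, 1, 3]
edges_left = [0.0, 0.0, 5.0, 0.0, 0.0, 5.0, 5.0]

edge_index, index_range = build_parent_index(edges_child, edges_left, num_nodes=6)
start, stop = index_range[2]
edge_ids = edge_index[start:stop]  # edge IDs where node 2 is the child
parents_of_2 = [edges_parent[e] for e in edge_ids]
print(edge_ids, parents_of_2)
```

Here start and stop are positions in edge_index, not in the edge table. Slicing edges_parent[start:stop] directly would pick up edges 3 and 4, and edge 4 has node 3 rather than node 2 as its child.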
Example - descendant span calculation
Here’s an example of using the ARG traversal indexes to calculate the total sequence length over which each node descends from a specified node:
@numba.njit
def descendant_span(numba_ts, u):
    """
    Calculate the total sequence length over which each node
    descends from the specified node u.
    """
    child_index = numba_ts.child_index()
    edges_left = numba_ts.edges_left
    edges_right = numba_ts.edges_right
    edges_child = numba_ts.edges_child

    total_descending = np.zeros(numba_ts.num_nodes)
    stack = [(u, 0.0, numba_ts.sequence_length)]

    while len(stack) > 0:
        node, left, right = stack.pop()
        total_descending[node] += right - left

        # Find all child edges for this node
        for e in range(child_index[node, 0], child_index[node, 1]):
            e_left = edges_left[e]
            e_right = edges_right[e]

            # Check if edge overlaps with current interval
            if e_right > left and right > e_left:
                inter_left = max(e_left, left)
                inter_right = min(e_right, right)
                e_child = edges_child[e]
                stack.append((e_child, inter_left, inter_right))

    return total_descending
# Calculate descendant span for the root node (highest numbered node)
root_node = numba_ts.num_nodes - 1
result = descendant_span(numba_ts, root_node)
# Show nodes that have non-zero descendant span
non_zero = result > 0
print(f"Nodes descended from {root_node}:")
print(f"Node IDs: {np.where(non_zero)[0]}")
print(f"Span lengths: {result[non_zero]}")
Nodes descended from 495178:
Node IDs: [ 0 1 2 ... 457696 469935 495178]
Span lengths: [1.e+00 1.e+00 1.e+00 ... 1.e+00 1.e+00 1.e+05]
Comparing performance with using the tskit Python API:
def descendant_span_tskit(ts, u):
    """Reference implementation using tskit trees"""
    total_descending = np.zeros(ts.num_nodes)
    for tree in ts.trees():
        descendants = tree.preorder(u)
        total_descending[descendants] += tree.span
    return total_descending
import time
t = time.time()
numba_result = descendant_span(numba_ts, root_node)
print(f"Numba time: {time.time() - t:.6f} seconds")
t = time.time()
tskit_result = descendant_span_tskit(ts, root_node)
print(f"tskit time: {time.time() - t:.6f} seconds")
np.testing.assert_array_almost_equal(numba_result, tskit_result, decimal=10)
print("Results match!")
Numba time: 0.014827 seconds
tskit time: 0.866002 seconds
Results match!
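The interval-clipping logic in descendant_span can be followed by hand on a small edge table. The sketch below is a pure-Python rewrite on hypothetical toy data (six nodes, two trees, and a hand-written child index); it is intended only to show how the stack entries are narrowed and does not use tskit or Numba.

```python
# Toy edge table (hypothetical): trees over [0, 5) and [5, 10).
edges_left = [0.0, 0.0, 5.0, 0.0, 0.0, 5.0, 5.0]
edges_right = [10.0, 5.0, 10.0, 5.0, 5.0, 10.0, 10.0]
edges_parent = [3, 3, 3, 4, 4, 5, 5]
edges_child = [0, 1, 2, 2, 3, 1, 3]
# [start, stop) range of edges for which each node is the parent.
child_index = [(0, 0), (0, 0), (0, 0), (0, 3), (3, 5), (5, 7)]
sequence_length = 10.0


def descendant_span_toy(u):
    """Span over which each node descends from u, for the toy table above."""
    total = [0.0] * len(child_index)
    stack = [(u, 0.0, sequence_length)]
    while stack:
        node, left, right = stack.pop()
        total[node] += right - left
        start, stop = child_index[node]
        for e in range(start, stop):
            # Keep only the part of the edge that overlaps [left, right).
            inter_left = max(edges_left[e], left)
            inter_right = min(edges_right[e], right)
            if inter_left < inter_right:
                stack.append((edges_child[e], inter_left, inter_right))
    return total


span_from_4 = descendant_span_toy(4)  # node 4 has child edges only on [0, 5)
span_from_5 = descendant_span_toy(5)  # node 5 has child edges only on [5, 10)
print(span_from_4)
print(span_from_5)
```

As in the Numba version, the starting node itself is credited with the whole sequence length, while every other node receives only the clipped intervals that reach it.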
Example - ARG descendant and ancestral edges calculation
As we have child_index and parent_index, we can efficiently find both descendant and ancestral sub-ARGs for a given node. This first example shows how to find all edges in the ARG that are descendants of a given node. It returns a boolean array indicating which edges are part of the sub-ARG rooted at the specified node:
@numba.njit
def descendant_edges(numba_ts, u):
    """
    Returns a boolean array which is only True for edges that are descendants of node u.
    """
    edge_select = np.zeros(numba_ts.num_edges, dtype=np.bool_)
    child_index = numba_ts.child_index()
    edges_left = numba_ts.edges_left
    edges_right = numba_ts.edges_right
    edges_child = numba_ts.edges_child

    # The stack stores (node_id, left_coord, right_coord)
    stack = [(u, 0.0, numba_ts.sequence_length)]

    while len(stack) > 0:
        node, left, right = stack.pop()

        # Find all edges where 'node' is the parent
        start, stop = child_index[node]
        for e in range(start, stop):
            e_left = edges_left[e]
            e_right = edges_right[e]

            # Check for genomic interval overlap
            if e_right > left and right > e_left:
                # This edge is part of the sub-ARG
                edge_select[e] = True

                # Calculate the intersection for the next traversal step
                inter_left = max(e_left, left)
                inter_right = min(e_right, right)
                e_child = edges_child[e]
                stack.append((e_child, inter_left, inter_right))

    return edge_select
# Find descendant edges for a high-numbered node (likely near root)
test_node = max(0, numba_ts.num_nodes - 5)
edge_select = descendant_edges(numba_ts, test_node)
# Show which edges are descendants
descendant_edge_ids = np.where(edge_select)[0]
print(f"Edges descended from node {test_node}: {descendant_edge_ids[:10]}...")
print(f"Total descendant edges: {np.sum(edge_select)}")
Edges descended from node 495174: [0 1 2 3 4 5 6 7 8 9]...
Total descendant edges: 199998
A tree sequence is easily made from the descendant edges array:
numba_ts_simple = tskit_numba.jitwrap(ts_simple)
node = 5
E = descendant_edges(numba_ts_simple, node)
tables_sub = ts_simple.dump_tables()
tables_sub.edges.replace_with(tables_sub.edges[E])
ts_sub = tables_sub.tree_sequence()
As an example, let's visualise the selection of a sub-ARG. Here is the full ARG with a highlighted node:
css_style = f".node.n{node} > .sym {{ fill: #c41e3a; }}"
ts_simple.draw_svg(size=(400, 200), node_labels={}, y_axis=True, style=css_style)
And the sub-ARG from that node:
ts_sub.draw_svg(size=(400, 200), node_labels={}, y_axis=True, style=css_style)
In the other direction, we can similarly find the sub-ARG that is ancestral to a given node:
@numba.njit
def ancestral_edges(numba_ts, u):
    """
    Returns a boolean array which is only True for edges that are ancestors of node u.
    """
    edge_select = np.zeros(numba_ts.num_edges, dtype=np.bool_)
    parent_index = numba_ts.parent_index()
    edges_left = numba_ts.edges_left
    edges_right = numba_ts.edges_right
    edges_parent = numba_ts.edges_parent

    # The stack stores (node_id, left_coord, right_coord)
    stack = [(u, 0.0, numba_ts.sequence_length)]

    while len(stack) > 0:
        node, left, right = stack.pop()

        # Find all edges where 'node' is the child
        start, stop = parent_index.index_range[node]
        for i in range(start, stop):
            e = parent_index.edge_index[i]
            e_left = edges_left[e]
            e_right = edges_right[e]

            # Check for genomic interval overlap
            if e_right > left and right > e_left:
                # This edge is part of the sub-ARG
                edge_select[e] = True

                # Calculate the intersection for the next traversal step
                inter_left = max(e_left, left)
                inter_right = min(e_right, right)
                e_parent = edges_parent[e]
                stack.append((e_parent, inter_left, inter_right))

    return edge_select
# Find ancestral edges for a sample node (low-numbered nodes are usually samples)
test_node = min(5, numba_ts.num_nodes - 1)
edge_select = ancestral_edges(numba_ts, test_node)
# Show which edges are ancestors
ancestral_edge_ids = np.where(edge_select)[0]
print(f"Edges ancestral to node {test_node}: {ancestral_edge_ids[:10]}...")
print(f"Total ancestral edges: {np.sum(edge_select)}")
Edges ancestral to node 5: [ 57753 59866 59868 112525 170304 206560 206561 206562 252004 280568]...
Total ancestral edges: 110058
Comparing performance with using the tskit Python API shows significant speedup:
def descendant_edges_tskit(ts, start_node):
    D = np.zeros(ts.num_edges, dtype=bool)
    for tree in ts.trees():
        for v in tree.preorder(start_node):
            if v != start_node:
                D[tree.edge(v)] = True
    return D


def ancestral_edges_tskit(ts, start_node):
    A = np.zeros(ts.num_edges, dtype=bool)
    for tree in ts.trees():
        curr_node = start_node
        parent = tree.parent(curr_node)
        while parent != tskit.NULL:
            edge_id = tree.edge(curr_node)
            A[edge_id] = True
            curr_node = parent
            parent = tree.parent(curr_node)
    return A
import time
# Test with root node for descendant edges
root_node = numba_ts.num_nodes - 1
t = time.time()
numba_desc = descendant_edges(numba_ts, root_node)
print(f"Numba descendant edges time: {time.time() - t:.6f} seconds")
t = time.time()
tskit_desc = descendant_edges_tskit(ts, root_node)
print(f"tskit descendant edges time: {time.time() - t:.6f} seconds")
# Test with sample node for ancestral edges
sample_node = 0
t = time.time()
numba_anc = ancestral_edges(numba_ts, sample_node)
print(f"Numba ancestral edges time: {time.time() - t:.6f} seconds")
t = time.time()
tskit_anc = ancestral_edges_tskit(ts, sample_node)
print(f"tskit ancestral edges time: {time.time() - t:.6f} seconds")
# Verify results match
np.testing.assert_array_equal(numba_desc, tskit_desc)
np.testing.assert_array_equal(numba_anc, tskit_anc)
print("Results match!")
Numba descendant edges time: 0.016229 seconds
tskit descendant edges time: 0.497277 seconds
Numba ancestral edges time: 0.025768 seconds
tskit ancestral edges time: 0.919687 seconds
Results match!
API Reference
- tskit.jit.numba.jitwrap(ts)
Convert a TreeSequence to a Numba-compatible format.
Creates a NumbaTreeSequence object that can be used within Numba-compiled functions.
- Parameters:
ts (tskit.TreeSequence) – The tree sequence to convert.
- Returns:
A Numba-compatible representation of the input tree sequence. Contains all necessary data arrays and metadata for tree traversal.
- Return type:
NumbaTreeSequence
- class tskit.jit.numba.NumbaTreeSequence(num_trees, num_nodes, num_samples, num_edges, num_sites, num_mutations, sequence_length, edges_left, edges_right, indexes_edge_insertion_order, indexes_edge_removal_order, individuals_flags, nodes_time, nodes_flags, nodes_population, nodes_individual, edges_parent, edges_child, sites_position, sites_ancestral_state, mutations_site, mutations_node, mutations_parent, mutations_time, mutations_derived_state, breakpoints, max_ancestral_length, max_derived_length)
A Numba-compatible representation of a tree sequence.
This class provides access to a tree sequence from within Numba “njit” compiled functions. jitwrap() should be used to JIT compile this class from a tskit.TreeSequence object, before it is passed to a Numba function.
Attributes
- num_trees (int)
Number of trees in the tree sequence.
- num_nodes (int)
Number of nodes in the tree sequence.
- num_samples (int)
Number of samples in the tree sequence.
- num_edges (int)
Number of edges in the tree sequence.
- num_sites (int)
Number of sites in the tree sequence.
- num_mutations (int)
Number of mutations in the tree sequence.
- sequence_length (float)
Total sequence length of the tree sequence.
- edges_left (numpy.ndarray)
Array (dtype=np.float64) of left coordinates of edges.
- edges_right (numpy.ndarray)
Array (dtype=np.float64) of right coordinates of edges.
- edges_parent (numpy.ndarray)
Array (dtype=np.int32) of parent node IDs for each edge.
- edges_child (numpy.ndarray)
Array (dtype=np.int32) of child node IDs for each edge.
- nodes_time (numpy.ndarray)
Array (dtype=np.float64) of time values for each node.
- nodes_flags (numpy.ndarray)
Array (dtype=np.uint32) of flag values for each node.
- nodes_population (numpy.ndarray)
Array (dtype=np.int32) of population IDs for each node.
- nodes_individual (numpy.ndarray)
Array (dtype=np.int32) of individual IDs for each node.
- individuals_flags (numpy.ndarray)
Array (dtype=np.uint32) of flag values for each individual.
- sites_position (numpy.ndarray)
Array (dtype=np.float64) of positions of sites along the sequence.
- mutations_site (numpy.ndarray)
Array (dtype=np.int32) of site IDs for each mutation.
- mutations_node (numpy.ndarray)
Array (dtype=np.int32) of node IDs for each mutation.
- mutations_parent (numpy.ndarray)
Array (dtype=np.int32) of parent mutation IDs.
- mutations_time (numpy.ndarray)
Array (dtype=np.float64) of time values for each mutation.
- breakpoints (numpy.ndarray)
Array (dtype=np.float64) of genomic positions where trees change.
- indexes_edge_insertion_order (numpy.ndarray)
Array (dtype=np.int32) specifying the order in which edges are inserted during tree building.
- indexes_edge_removal_order (numpy.ndarray)
Array (dtype=np.int32) specifying the order in which edges are removed during tree building.
- tree_index()
Create a TreeIndex for traversing this tree sequence.
- Returns:
A new tree index initialized to the null tree. Use next() or prev() to move to an actual tree.
- Return type:
TreeIndex
- child_index()
Create child index array for finding child edges of nodes. This operation requires a linear pass over the edge table and therefore has a time complexity of O(E).
- Returns:
A numpy array (dtype=np.int32, shape=(num_nodes, 2)) where each row contains the [start, stop) range of edges where this node is the parent.
- Return type:
numpy.ndarray
- parent_index()
Create a ParentIndex for finding parent edges of nodes.
Edges within each child’s group are not guaranteed to be in any specific order. This operation uses a two-pass algorithm with O(N + E) time complexity and O(N) auxiliary space.
- Returns:
A new parent index container that can be used to efficiently find all edges where a given node is the child.
- Return type:
ParentIndex
- class tskit.jit.numba.TreeIndex(ts)
Traverse trees in a numba compatible tree sequence.
This class provides efficient forward and backward iteration through the trees in a tree sequence. It provides the tree interval, edge changes to create the current tree, along with its sites and mutations. A full pass over the trees using repeated next or prev requires O(E + M + S) time complexity.
It should not be instantiated directly, but is returned by the tree_index method of NumbaTreeSequence.
Attributes
- ts (NumbaTreeSequence)
Reference to the tree sequence being traversed.
- index (int)
Current tree index. -1 indicates no current tree (null state).
- direction (int)
Traversal direction: tskit.FORWARD or tskit.REVERSE. tskit.NULL if uninitialised.
- interval (tuple)
Genomic interval (left, right) covered by the current tree.
- in_range (EdgeRange)
Edges being added to form this current tree, relative to the last state.
- out_range (EdgeRange)
Edges being removed to form this current tree, relative to the last state.
- site_range (tuple)
Range of sites in the current tree (start, stop).
- mutation_range (tuple)
Range of mutations in the current tree (start, stop).
Example
>>> tree_index = numba_ts.tree_index()
>>> num_edges = 0
>>> while tree_index.next():
...     num_edges += (tree_index.in_range.stop - tree_index.in_range.start)
...     num_edges -= (tree_index.out_range.stop - tree_index.out_range.start)
...     print(f"Tree {tree_index.index}: {num_edges} edges")
- next()
Move to the next tree in forward direction.
Updates the tree index to the next tree in the sequence, computing the edges that need to be added and removed to transform from the previous tree to the current tree. On the first call, this initializes the iterator and moves to tree 0.
- Returns:
True if successfully moved to next tree, False if the end of the tree sequence is reached. When False is returned, the iterator is in null state (index=-1).
- Return type:
bool
- prev()
Move to the previous tree in reverse direction.
Updates the tree index to the previous tree in the sequence, computing the edges that need to be added and removed to transform from the next tree to the current tree. On the first call, this initializes the iterator and moves to the most rightward tree.
- Returns:
True if successfully moved to previous tree, False if the beginning of the tree sequence is reached. When False is returned, the iterator is in null state (index=-1).
- Return type:
bool
- class tskit.jit.numba.EdgeRange(*args, **kwargs)
Represents a range of edges during tree traversal.
This class encapsulates information about a contiguous range of edges that are either being removed or added to step from one tree to another. The start and stop indices, when applied to the order array, define the ids of edges to process.
Attributes
- start (int)
Starting index of the edge range (inclusive).
- stop (int)
Stopping index of the edge range (exclusive).
- order (numpy.ndarray)
Array (dtype=np.int32) containing edge IDs in the order they should be processed. The edge ids in this range are order[start:stop].
- class tskit.jit.numba.ParentIndex(*args, **kwargs)
Simple data container for parent index information.
This class provides access to all edges where a given node is the child. Since edges are not sorted by child in the tskit edge table, a custom index (edge_index) is built that sorts edge IDs by child node. index_range then contains the [start, stop) range of edges for each child node in edge_index.
Attributes
- edge_index (numpy.ndarray)
Array (dtype=np.int32) of edge IDs sorted by child node and left coordinate.
- index_range (numpy.ndarray)
Array (dtype=np.int32, shape=(num_nodes, 2)) where each row contains the [start, stop) range in edge_index where this node is the child.