Numba Integration
The tskit.jit.numba module provides classes for working with tree sequences from Numba jit-compiled Python code. Such code can run up to hundreds of times faster than normal Python, yet avoids the difficulties of writing C or other low-level code.
Note
Numba is not a direct dependency of tskit, so it will not be available unless installed:
pip install numba
or
conda install numba
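If your code should still run when Numba is absent, one option is to check for it before importing tskit.jit.numba. Below is a minimal standard-library sketch. The helper name numba_available is our own and is not part of tskit.

```python
import importlib.util


def numba_available() -> bool:
    """Return True if the numba package is installed in this environment."""
    # find_spec returns None when no top-level module of that name exists,
    # and it does so without actually importing the module
    return importlib.util.find_spec("numba") is not None


# Example use: only import the Numba integration when it can work, e.g.
#     if numba_available():
#         import tskit.jit.numba as tskit_numba
print("numba installed:", numba_available())
```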
Overview
The numba integration provides:
- NumbaTreeSequence: A Numba-compatible representation of tree sequence data
- NumbaTreeIndex: A class for efficient tree iteration
- NumbaEdgeRange: Container class for edge ranges during iteration
These classes are designed to work within Numba's @njit-decorated functions, allowing you to write high-performance tree sequence analysis code.
Basic Usage
The tskit.jit.numba module is not imported by a plain import tskit, so it must be imported explicitly:
import numpy as np
import tskit
import tskit.jit.numba as tskit_numba
Normal Python classes such as tskit.TreeSequence can't be used in numba.njit-compiled functions, so the tskit.TreeSequence must be wrapped in a NumbaTreeSequence using jitwrap(). This must be done outside njit code:
import msprime
ts = msprime.sim_ancestry(
    samples=50000,
    sequence_length=100000,
    recombination_rate=0.1,
    random_seed=42,
)
numba_ts = tskit_numba.jitwrap(ts)
print(type(numba_ts))
<class 'numba.experimental.jitclass.boxing._NumbaTreeSequence'>
Tree Iteration
Tree iteration can be performed using the NumbaTreeIndex class. This class provides next() and prev() methods for forward and backward iteration through the trees in a tree sequence. Its in_range and out_range attributes give the edges that must be added or removed to form the current tree from the previous tree. Its interval attribute gives the genomic interval covered by the current tree, and its site_range and mutation_range attributes give the tree's sites and mutations.
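The sketch below illustrates what site_range and mutation_range describe. tskit requires sites to be sorted by position and mutations to be sorted by site. As a result, the sites of a tree covering [left, right) form one contiguous block of the site table, and their mutations form one contiguous block of the mutation table. This pure-Python version locates both (start, stop) ranges by binary search. The function name and the toy arrays are hypothetical, and this is not a claim about how tskit itself computes the ranges.

```python
from bisect import bisect_left

# Hypothetical toy data: site positions (sorted) and the site of each
# mutation (sorted by site, as tskit requires)
sites_position = [1.0, 3.5, 5.0, 7.25, 9.0]
mutations_site = [0, 0, 1, 3, 3, 4]


def site_and_mutation_ranges(interval, positions, mut_sites):
    """Return ((site_start, site_stop), (mut_start, mut_stop)) for [left, right)."""
    left, right = interval
    # Sites with left <= position < right
    site_start = bisect_left(positions, left)
    site_stop = bisect_left(positions, right)
    # Mutations whose site ID lies in [site_start, site_stop)
    mut_start = bisect_left(mut_sites, site_start)
    mut_stop = bisect_left(mut_sites, site_stop)
    return (site_start, site_stop), (mut_start, mut_stop)


for interval in [(0.0, 5.0), (5.0, 10.0)]:
    (s0, s1), (m0, m1) = site_and_mutation_ranges(interval, sites_position, mutations_site)
    print(interval, "sites:", list(range(s0, s1)), "mutations:", list(range(m0, m1)))
```

Because both ranges are half-open, the stop of one tree's range equals the start of the next tree's range.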
A NumbaTreeIndex instance can be obtained from a NumbaTreeSequence using the tree_index() method. It starts in a "null" state, outside the range of the tree sequence. The first call to next() or prev() moves to the first or last tree of the tree sequence, respectively. After that, the in_range and out_range attributes provide the edges that must be added or removed to form the current tree from the previous tree.
For example, tree_index.in_range.order[tree_index.in_range.start:tree_index.in_range.stop] gives the IDs of the edges that are new in the current tree. Likewise, tree_index.out_range.order[tree_index.out_range.start:tree_index.out_range.stop] gives the IDs of the edges that are no longer present in the current tree. tree_index.site_range and tree_index.mutation_range give (start, stop) indexes into the tree sequence's site and mutation arrays.
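The following pure-Python sketch makes these ranges concrete; it needs neither tskit nor Numba. It shows how per-tree edge ranges can be derived from each edge's left and right coordinates together with an insertion order and a removal order, in the spirit of tskit's edge-diff iteration. Each range is a (start, stop) pair applied to an order array, mirroring the start, stop and order attributes of NumbaEdgeRange. The toy edge table, the hard-coded orders and the function name forward_edge_ranges are illustrative assumptions, not tskit API.

```python
# Toy tree sequence on [0, 10) with samples 0, 1, 2 and internal nodes 3, 4:
#   [0, 5):  3 -> (0, 1), 4 -> (3, 2)
#   [5, 10): 3 -> (1, 2), 4 -> (3, 0)
edges_left = [0.0, 0.0, 5.0, 0.0, 5.0, 0.0]
edges_right = [10.0, 5.0, 10.0, 5.0, 10.0, 10.0]
insertion_order = [1, 0, 3, 5, 2, 4]  # edge IDs sorted by left coordinate
removal_order = [3, 1, 5, 4, 2, 0]  # edge IDs sorted by right coordinate
sequence_length = 10.0


def forward_edge_ranges(left_coords, right_coords, ins, rem, seq_len):
    """Yield (interval, out_range, in_range) for each tree, left to right.

    out_range indexes into `rem` and in_range indexes into `ins`, so the
    edge IDs are rem[start:stop] and ins[start:stop] respectively.
    """
    num_edges = len(ins)
    j = k = 0  # next positions in the insertion and removal orders
    left = 0.0
    while left < seq_len:
        out_start = k
        # Edges ending exactly at `left` leave the tree
        while k < num_edges and right_coords[rem[k]] == left:
            k += 1
        in_start = j
        # Edges starting exactly at `left` enter the tree
        while j < num_edges and left_coords[ins[j]] == left:
            j += 1
        # The tree ends at the next coordinate where any edge starts or stops
        right = seq_len
        if j < num_edges:
            right = min(right, left_coords[ins[j]])
        if k < num_edges:
            right = min(right, right_coords[rem[k]])
        yield (left, right), (out_start, k), (in_start, j)
        left = right


current = 0
for interval, (o0, o1), (i0, i1) in forward_edge_ranges(
    edges_left, edges_right, insertion_order, removal_order, sequence_length
):
    current += (i1 - i0) - (o1 - o0)  # running number of edges in the tree
    print(interval, "out:", removal_order[o0:o1], "in:", insertion_order[i0:i1],
          "edges:", current)
```

The loops only ever advance j and k. Stepping through all trees therefore touches each edge twice in total, once when it enters and once when it leaves.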
As a simple example we can calculate the number of edges in each tree in a tree sequence:
import numba
@numba.njit
def edges_per_tree(numba_ts):
    tree_index = numba_ts.tree_index()
    current_num_edges = 0
    num_edges = []
    # Move forward through the trees
    while tree_index.next():
        # Access current tree information
        in_range = tree_index.in_range
        out_range = tree_index.out_range
        current_num_edges -= (out_range.stop - out_range.start)
        current_num_edges += (in_range.stop - in_range.start)
        num_edges.append(current_num_edges)
    return num_edges
import time
t = time.time()
jit_num_edges = edges_per_tree(numba_ts)
print(f"JIT Time taken: {time.time() - t:.4f} seconds")
JIT Time taken: 0.0363 seconds
Doing the same thing with the normal tskit API is much slower:
t = time.time()
python_num_edges = []
for tree in ts.trees():
    python_num_edges.append(tree.num_edges)
print(f"Normal Time taken: {time.time() - t:.4f} seconds")
assert jit_num_edges == python_num_edges, "JIT and normal results do not match!"
Normal Time taken: 0.4326 seconds
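Numba compiles an njit function lazily, on its first call with a given combination of argument types. A single timing of that first call therefore includes compilation. When comparing performance, it can help to call the function once to warm it up and then time later calls. Below is a minimal standard-library sketch; time_call is a hypothetical helper, not part of tskit or Numba.

```python
import time


def time_call(func, *args, warmup=True, repeats=3):
    """Return the best wall-clock time (seconds) over `repeats` calls of func(*args).

    With warmup=True the function is called once beforehand and that call is
    not timed; for a Numba njit function this keeps compilation out of the
    measurement.
    """
    if warmup:
        func(*args)
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args)
        best = min(best, time.perf_counter() - start)
    return best


# Usage with the functions above (requires numba and tskit):
#     print(f"JIT: {time_call(edges_per_tree, numba_ts):.4f} s")
print(f"sum of 100,000 ints: {time_call(sum, range(100_000)):.6f} s")
```

Taking the minimum over several runs reduces the influence of unrelated background activity on the measurement.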
Example - diversity calculation
As a more interesting example, we can calculate genetic diversity (also known as pi). Here the calculation is based on the branch-length distance between pairs of samples in each tree (mode="branch" in the tskit API).
This example also shows the style of Python code that gives the best performance under Numba JIT compilation: simple loops and fixed-size arrays, with minimal object attribute access.
@numba.njit
def diversity(numba_ts):
    # Cache arrays to avoid repeated attribute access in
    # tight loops
    edge_child = numba_ts.edges_child
    edge_parent = numba_ts.edges_parent
    node_times = numba_ts.nodes_time
    node_flags = numba_ts.nodes_flags
    if numba_ts.num_samples <= 1:
        return 0.0
    parent = np.full(numba_ts.num_nodes, -1, dtype=np.int32)
    branch_length = np.zeros(numba_ts.num_nodes, dtype=np.float64)
    # state[u]: number of samples below node u in the current tree
    state = np.zeros(numba_ts.num_nodes, dtype=np.int32)
    # summary[u]: weight of the branch above u, 2 * k * (n - k) / (n * (n - 1))
    # with k = state[u]
    summary = np.zeros(numba_ts.num_nodes, dtype=np.float64)
    n = float(numba_ts.num_samples)
    two_over_denom = 2.0 / (n * (n - 1.0))
    sample_summary = 2.0 / n
    # Retrieve this constant outside the loop
    # to avoid repeated attribute access
    NODE_IS_SAMPLE = tskit.NODE_IS_SAMPLE
    # Find the sample nodes and initialize their states
    for node in range(numba_ts.num_nodes):
        if node_flags[node] & NODE_IS_SAMPLE:
            state[node] = 1
            summary[node] = sample_summary
    result = 0.0
    # running_sum: sum of branch_length * summary over the current tree
    running_sum = 0.0
    tree_index = numba_ts.tree_index()
    # Now iterate through the trees
    while tree_index.next():
        # Process the outgoing edges
        for j in range(tree_index.out_range.start, tree_index.out_range.stop):
            h = tree_index.out_range.order[j]
            child = edge_child[h]
            child_parent = edge_parent[h]
            running_sum -= branch_length[child] * summary[child]
            parent[child] = -1
            branch_length[child] = 0.0
            u = child_parent
            parent_u = parent[u]
            while u != -1:
                running_sum -= branch_length[u] * summary[u]
                state[u] -= state[child]
                summary[u] = state[u] * (n - state[u]) * two_over_denom
                running_sum += branch_length[u] * summary[u]
                u = parent_u
                if u != -1:
                    parent_u = parent[u]
        # Process the incoming edges
        for j in range(tree_index.in_range.start, tree_index.in_range.stop):
            h = tree_index.in_range.order[j]
            child = edge_child[h]
            child_parent = edge_parent[h]
            parent[child] = child_parent
            branch_length[child] = node_times[child_parent] - node_times[child]
            running_sum += branch_length[child] * summary[child]
            u = child_parent
            parent_u = parent[u]
            while u != -1:
                running_sum -= branch_length[u] * summary[u]
                state[u] += state[child]
                summary[u] = state[u] * (n - state[u]) * two_over_denom
                running_sum += branch_length[u] * summary[u]
                u = parent_u
                if u != -1:
                    parent_u = parent[u]
        result += running_sum * (
            tree_index.interval[1] - tree_index.interval[0]
        )
    return result / numba_ts.sequence_length
t = time.time()
d = diversity(numba_ts)
print("Diversity:", d)
print("Time taken:", time.time() - t)
Diversity: 3.992633119145986
Time taken: 0.28753113746643066
Because this code is specialized for this one diversity calculation, it is even faster than the general tskit C implementation, called here from Python:
t = time.time()
d_tskit = ts.diversity(mode="branch")
print("Diversity (tskit):", d_tskit)
print("Time taken:", time.time() - t)
Diversity (tskit): 3.992633119146012
Time taken: 0.5368030071258545
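For each tree, the quantity accumulated in running_sum is a sum over branches: each branch contributes its length times 2k(n - k) / (n(n - 1)). Here k is the number of samples below the branch and n is the total number of samples. A branch separating k samples from the other n - k lies on the path between k(n - k) pairs, so this sum equals the mean branch-length distance over all n(n - 1)/2 pairs of distinct samples. The pure-Python sketch below evaluates both forms on one small tree as a sanity check. The function names and the toy tree are our own. A check like this can serve as a reference when testing a Numba version on tiny inputs.

```python
from itertools import combinations

# Toy tree: samples 0, 1, 2 at time 0; node 3 (time 1) is the parent of 0 and 1;
# node 4 (time 2) is the root and the parent of 3 and 2.
parent_of = [3, 3, 4, 4, -1]
node_time = [0.0, 0.0, 0.0, 1.0, 2.0]
samples = [0, 1, 2]


def path_to_root(parent_of, u):
    """Return [u, parent(u), ..., root]."""
    path = [u]
    while parent_of[u] != -1:
        u = parent_of[u]
        path.append(u)
    return path


def pairwise_mean_distance(parent_of, node_time, samples):
    """Mean branch-length distance over all pairs of distinct samples."""
    pairs = list(combinations(samples, 2))
    total = 0.0
    for a, b in pairs:
        ancestors_a = set(path_to_root(parent_of, a))
        # The first ancestor of b that is also an ancestor of a is the MRCA
        mrca = next(x for x in path_to_root(parent_of, b) if x in ancestors_a)
        total += (node_time[mrca] - node_time[a]) + (node_time[mrca] - node_time[b])
    return total / len(pairs)


def branch_weighted_sum(parent_of, node_time, samples):
    """Sum over branches of length * 2k(n - k) / (n(n - 1))."""
    n = len(samples)
    below = [0] * len(parent_of)  # number of samples in the subtree of each node
    for s in samples:
        for x in path_to_root(parent_of, s):
            below[x] += 1
    total = 0.0
    for u, p in enumerate(parent_of):
        if p != -1:
            k = below[u]
            total += (node_time[p] - node_time[u]) * 2 * k * (n - k) / (n * (n - 1))
    return total


# Both values are 10/3 up to floating-point rounding
print(pairwise_mean_distance(parent_of, node_time, samples))
print(branch_weighted_sum(parent_of, node_time, samples))
```

Weighting each tree's value by its interval length and dividing by the sequence length gives the whole-sequence value, as the Numba function does.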
API Reference
- tskit.jit.numba.jitwrap(ts)
Convert a TreeSequence to a Numba-compatible format.
Creates a NumbaTreeSequence object that can be used within Numba-compiled functions.
Parameters
- ts : tskit.TreeSequence
The tree sequence to convert.
Returns
- NumbaTreeSequence
A Numba-compatible representation of the input tree sequence. Contains all necessary data arrays and metadata for tree traversal.
- class tskit.jit.numba.NumbaTreeSequence(num_trees, num_nodes, num_samples, num_edges, num_sites, num_mutations, sequence_length, edges_left, edges_right, indexes_edge_insertion_order, indexes_edge_removal_order, individuals_flags, nodes_time, nodes_flags, nodes_population, nodes_individual, edges_parent, edges_child, sites_position, sites_ancestral_state, mutations_site, mutations_node, mutations_parent, mutations_time, mutations_derived_state, breakpoints, max_ancestral_length, max_derived_length)
A Numba-compatible representation of a tree sequence.
This class provides access to a tree sequence from within Numba "njit"-compiled functions.
jitwrap() should be used to create an instance of this class from a tskit.TreeSequence object before it is passed to a Numba function.
Attributes
- num_trees : int32
Number of trees in the tree sequence.
- num_nodes : int32
Number of nodes in the tree sequence.
- num_samples : int32
Number of samples in the tree sequence.
- num_edges : int32
Number of edges in the tree sequence.
- num_sites : int32
Number of sites in the tree sequence.
- num_mutations : int32
Number of mutations in the tree sequence.
- sequence_length : float64
Total sequence length of the tree sequence.
- edges_left : float64[]
Left coordinates of edges.
- edges_right : float64[]
Right coordinates of edges.
- edges_parent : int32[]
Parent node IDs for each edge.
- edges_child : int32[]
Child node IDs for each edge.
- nodes_time : float64[]
Time values for each node.
- nodes_flags : uint32[]
Flag values for each node.
- nodes_population : int32[]
Population IDs for each node.
- nodes_individual : int32[]
Individual IDs for each node.
- individuals_flags : uint32[]
Flag values for each individual.
- sites_position : float64[]
Positions of sites along the sequence.
- mutations_site : int32[]
Site IDs for each mutation.
- mutations_node : int32[]
Node IDs for each mutation.
- mutations_parent : int32[]
Parent mutation IDs.
- mutations_time : float64[]
Time values for each mutation.
- breakpoints : float64[]
Genomic positions where trees change.
- indexes_edge_insertion_order : int32[]
Order in which edges are inserted during tree building.
- indexes_edge_removal_order : int32[]
Order in which edges are removed during tree building.
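The two index arrays can be thought of as edge IDs sorted by where each edge enters and leaves the trees. As far as we know, tskit sorts the insertion order by (left, parent time, parent, child) and the removal order by (right, -parent time, -parent, -child). Treat those exact tie-breaking keys as an assumption. The pure-Python sketch below uses a hypothetical toy edge table and a hypothetical function name; it is an illustration, not tskit's indexing code.

```python
# Hypothetical toy edge table: two trees on [0, 10) over samples 0, 1, 2
edges_left = [0.0, 0.0, 5.0, 0.0, 5.0, 0.0]
edges_right = [10.0, 5.0, 10.0, 5.0, 10.0, 10.0]
edges_parent = [3, 3, 3, 4, 4, 4]
edges_child = [1, 0, 2, 2, 0, 3]
nodes_time = [0.0, 0.0, 0.0, 1.0, 2.0]


def edge_indexes(left, right, parent, child, time):
    """Return (insertion_order, removal_order) as lists of edge IDs.

    Assumed sort keys: insertion by (left, parent time, parent, child);
    removal by (right, -parent time, -parent, -child).
    """
    ids = range(len(left))
    insertion = sorted(
        ids, key=lambda e: (left[e], time[parent[e]], parent[e], child[e])
    )
    removal = sorted(
        ids, key=lambda e: (right[e], -time[parent[e]], -parent[e], -child[e])
    )
    return insertion, removal


insertion_order, removal_order = edge_indexes(
    edges_left, edges_right, edges_parent, edges_child, nodes_time
)
print(insertion_order)  # [1, 0, 3, 5, 2, 4]
print(removal_order)  # [3, 1, 5, 4, 2, 0]
```

Only the primary keys (left for insertion, right for removal) are needed for the edge ranges to be contiguous. The secondary keys just fix a deterministic processing order within each range.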
- tree_index()
Create a NumbaTreeIndex for traversing this tree sequence.
Returns
- NumbaTreeIndex
A new tree index initialized to the null tree. Use next() or prev() to move to an actual tree.
Examples
>>> tree_index = numba_ts.tree_index()
>>> while tree_index.next():
...     # Process current tree at tree_index.index
...     print(f"Tree {tree_index.index}: {tree_index.interval}")
- class tskit.jit.numba.NumbaTreeIndex(ts)
Traverse trees in a numba compatible tree sequence.
This class provides efficient forward and backward iteration through the trees in a tree sequence. It provides the tree interval, edge changes to create the current tree, along with its sites and mutations.
It should not be instantiated directly, but is returned by the tree_index method of NumbaTreeSequence.
Attributes
- ts : NumbaTreeSequence
Reference to the tree sequence being traversed.
- index : int32
Current tree index. -1 indicates no current tree (null state).
- direction : int32
Traversal direction: tskit.FORWARD or tskit.REVERSE; tskit.NULL if uninitialized.
- interval : tuple of float64
Genomic interval (left, right) covered by the current tree.
- in_range : NumbaEdgeRange
Edges being added to form the current tree, relative to the previous state.
- out_range : NumbaEdgeRange
Edges being removed to form the current tree, relative to the previous state.
- site_range : tuple of int32
Range of sites in the current tree (start, stop).
- mutation_range : tuple of int32
Range of mutations in the current tree (start, stop).
Example
>>> tree_index = numba_ts.tree_index()
>>> num_edges = 0
>>> while tree_index.next():
...     num_edges += (tree_index.in_range.stop - tree_index.in_range.start)
...     num_edges -= (tree_index.out_range.stop - tree_index.out_range.start)
...     print(f"Tree {tree_index.index}: {num_edges} edges")
- next()
Move to the next tree in forward direction.
Updates the tree index to the next tree in the sequence, computing the edges that need to be added and removed to transform from the previous tree to the current tree, storing them in self.in_range and self.out_range.
Returns
- bool
True if successfully moved to next tree, False if the end of the tree sequence is reached. When False is returned, the iterator is in null state (index=-1).
Notes
On the first call, this initializes the iterator and moves to tree 0. The in_range and out_range attributes are updated to reflect the edge changes needed for the current tree.
- prev()
Move to the previous tree in reverse direction.
Updates the tree index to the previous tree in the sequence, computing the edges that need to be added and removed to transform from the next tree to the current tree, storing them in self.in_range and self.out_range.
Returns
- bool
True if successfully moved to previous tree, False if the beginning of the tree sequence is reached. When False is returned, the iterator is in null state (index=-1).
Notes
On the first call, this initializes the iterator and moves to the rightmost (last) tree. The in_range and out_range attributes are updated to reflect the edge changes needed for the current tree when traversing backward.
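When moving backward, the roles of the two coordinates are mirrored. Stepping leftwards, edges whose left coordinate equals the current tree's right end leave the tree, and edges whose right coordinate equals it enter. Below is a pure-Python sketch of this mirrored iteration over a hypothetical toy edge table with hard-coded orders. It returns plain lists of edge IDs to illustrate the semantics only. We make no claim about how NumbaEdgeRange stores its ranges when traversing in reverse.

```python
# Toy two-tree example on [0, 10); edge e spans [edges_left[e], edges_right[e])
edges_left = [0.0, 0.0, 5.0, 0.0, 5.0, 0.0]
edges_right = [10.0, 5.0, 10.0, 5.0, 10.0, 10.0]
insertion_order = [1, 0, 3, 5, 2, 4]  # edge IDs sorted by left coordinate
removal_order = [3, 1, 5, 4, 2, 0]  # edge IDs sorted by right coordinate
sequence_length = 10.0


def reverse_edge_diffs(left_coords, right_coords, ins, rem, seq_len):
    """Yield (interval, edges_out, edges_in) for each tree, right to left."""
    k = len(ins) - 1  # position in the insertion order, walked backwards
    j = len(rem) - 1  # position in the removal order, walked backwards
    right = seq_len
    while right > 0:
        # Moving leftwards, edges that start at `right` leave the tree
        edges_out = []
        while k >= 0 and left_coords[ins[k]] == right:
            edges_out.append(ins[k])
            k -= 1
        # ...and edges that end at `right` enter it
        edges_in = []
        while j >= 0 and right_coords[rem[j]] == right:
            edges_in.append(rem[j])
            j -= 1
        # The tree begins at the nearest remaining edge endpoint left of `right`
        left = 0.0
        if k >= 0:
            left = max(left, left_coords[ins[k]])
        if j >= 0:
            left = max(left, right_coords[rem[j]])
        yield (left, right), edges_out, edges_in
        right = left


for interval, edges_out, edges_in in reverse_edge_diffs(
    edges_left, edges_right, insertion_order, removal_order, sequence_length
):
    print(interval, "out:", edges_out, "in:", edges_in)
```

The trees come out in the opposite order to forward iteration, and each tree still ends up with the same set of edges.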
- class tskit.jit.numba.NumbaEdgeRange(*args, **kwargs)
Represents a range of edges during tree traversal.
This class encapsulates information about a contiguous range of edges that are either being removed or added to step from one tree to another. The start and stop indices, when applied to the order array, define the IDs of the edges to process.
Attributes
- start : int32
Starting index of the edge range (inclusive).
- stop : int32
Stopping index of the edge range (exclusive).
- order : int32[]
Array containing edge IDs in the order they should be processed. The edge IDs in this range are order[start:stop].