Numba Integration

The tskit.jit.numba module provides classes for working with tree sequences from Numba jit-compiled Python code. Such code can run up to hundreds of times faster than normal Python, yet avoids the difficulties of writing C or other low-level code.

Note

Numba is not a direct dependency of tskit, so it must be installed separately:

pip install numba

or

conda install numba

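Overview

The numba integration provides:

- NumbaTreeSequence: a Numba-compatible representation of tree sequence data
- NumbaTreeIndex: a class for efficient forward and backward iteration through the trees
- NumbaEdgeRange: a container for the range of edges added or removed when moving between trees

These classes are designed to work within Numba's @njit-decorated functions, allowing you to write high-performance tree sequence analysis code.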

Basic Usage

The tskit.jit.numba module is not loaded by import tskit, so it must be imported explicitly:

import numpy as np
import tskit
import tskit.jit.numba as tskit_numba

Ordinary Python classes such as tskit.TreeSequence can't be used inside numba.njit-compiled functions. A tskit.TreeSequence must therefore first be wrapped in a NumbaTreeSequence using jitwrap(). This wrapping must be done outside njit code:

import msprime

ts = msprime.sim_ancestry(
    samples=50000,
    sequence_length=100000,
    recombination_rate=0.1,
    random_seed=42
)
numba_ts = tskit_numba.jitwrap(ts)
print(type(numba_ts))
<class 'numba.experimental.jitclass.boxing._NumbaTreeSequence'>

Tree Iteration

Tree iteration is performed with the NumbaTreeIndex class. Its next() and prev() methods move forward and backward through the trees in a tree sequence. The in_range and out_range attributes give the edges that must be added or removed to form the current tree from the previous one. The interval, site_range and mutation_range attributes give the current tree's genomic interval and the ranges of its sites and mutations.

A NumbaTreeIndex instance is obtained from a NumbaTreeSequence with the tree_index() method. It starts in a "null" state outside the range of the tree sequence. The first call to next() moves to the first tree, and the first call to prev() moves to the last tree. After each move, tree_index.in_range.order[tree_index.in_range.start:tree_index.in_range.stop] gives the ids of edges that are new in the current tree. Likewise, tree_index.out_range.order[tree_index.out_range.start:tree_index.out_range.stop] gives the ids of edges that are no longer present. tree_index.site_range and tree_index.mutation_range give (start, stop) indexes into the tree sequence's site and mutation arrays.
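The in_range/out_range bookkeeping can be made concrete with a pure-Python sketch of forward iteration over edge differences. This is not the tskit implementation. It assumes, as tskit's indexes provide, that `insertion_order` lists edge ids sorted primarily by left coordinate and `removal_order` sorted primarily by right coordinate. It returns plain lists of edge ids rather than NumbaEdgeRange objects. The function name `forward_edge_diffs` and the example data are hypothetical.

```python
def forward_edge_diffs(edges_left, edges_right, insertion_order, removal_order,
                       sequence_length):
    """Yield ((left, right), out_ids, in_ids) for each tree, left to right.

    Hypothetical helper: a pure-Python sketch of what next() computes.
    Assumes insertion_order is sorted by edges_left and removal_order by
    edges_right.
    """
    num_edges = len(edges_left)
    j = 0  # next position in insertion_order
    k = 0  # next position in removal_order
    left = 0.0
    while j < num_edges or left < sequence_length:
        # Edges whose right end equals the new left boundary leave the tree
        out_start = k
        while k < num_edges and edges_right[removal_order[k]] == left:
            k += 1
        # Edges whose left end equals the new left boundary enter the tree
        in_start = j
        while j < num_edges and edges_left[insertion_order[j]] == left:
            j += 1
        # The tree extends to the next edge insertion or removal point
        right = sequence_length
        if j < num_edges:
            right = min(right, edges_left[insertion_order[j]])
        if k < num_edges:
            right = min(right, edges_right[removal_order[k]])
        yield (left, right), removal_order[out_start:k], insertion_order[in_start:j]
        left = right


# Two trees on [0, 10): edge 0 spans the whole sequence,
# edge 1 covers [0, 5) and edge 2 covers [5, 10).
edges_left = [0.0, 0.0, 5.0]
edges_right = [10.0, 5.0, 10.0]
insertion_order = [0, 1, 2]  # sorted by left
removal_order = [1, 0, 2]    # sorted by right

diffs = list(forward_edge_diffs(edges_left, edges_right,
                                insertion_order, removal_order, 10.0))

# Running edge count per tree, the same idea as edges_per_tree() below
num_edges = []
current = 0
for interval, out_ids, in_ids in diffs:
    current += len(in_ids) - len(out_ids)
    num_edges.append(current)

print(diffs)      # [((0.0, 5.0), [], [0, 1]), ((5.0, 10.0), [1], [2])]
print(num_edges)  # [2, 2]
```

Each tree is built from the previous one by touching only the edges that change. This is why iterating with in_range and out_range is cheap.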

As a simple example, we can calculate the number of edges in each tree of a tree sequence:

import numba

@numba.njit
def edges_per_tree(numba_ts):
    tree_index = numba_ts.tree_index()
    current_num_edges = 0
    num_edges = []
    
    # Move forward through the trees
    while tree_index.next():
        # Access current tree information
        in_range = tree_index.in_range
        out_range = tree_index.out_range
        
        current_num_edges -= (out_range.stop - out_range.start)
        current_num_edges += (in_range.stop - in_range.start)
        num_edges.append(current_num_edges)
    return num_edges


# Warm up the JIT compiler
edges = edges_per_tree(numba_ts)
import time

t = time.time()
jit_num_edges = edges_per_tree(numba_ts)
print(f"JIT Time taken: {time.time() - t:.4f} seconds")
JIT Time taken: 0.0363 seconds

The same calculation with the standard tskit Python API is more than ten times slower in this run:

t = time.time()
python_num_edges = []
for tree in ts.trees():
    python_num_edges.append(tree.num_edges)
print(f"Normal Time taken: {time.time() - t:.4f} seconds")

assert jit_num_edges == python_num_edges, "JIT and normal results do not match!"
Normal Time taken: 0.4326 seconds

Example - diversity calculation

As a more interesting example, we can calculate genetic diversity (also known as pi). Here diversity is computed from the tree distance (total branch length) between pairs of samples. This corresponds to mode="branch" in the tskit API.

This example also shows the style of Python code that performs best under Numba JIT compilation: simple loops, fixed-size arrays and minimal object attribute access.

@numba.njit
def diversity(numba_ts):
    # Cache arrays to avoid repeated attribute access in
    # tight loops
    edge_child = numba_ts.edges_child
    edge_parent = numba_ts.edges_parent
    node_times = numba_ts.nodes_time
    node_flags = numba_ts.nodes_flags

    if numba_ts.num_samples <= 1:
        return 0.0

    # state[u]: number of samples below node u
    # summary[u]: fraction of sample pairs separated by the branch above u
    parent = np.full(numba_ts.num_nodes, -1, dtype=np.int32)
    branch_length = np.zeros(numba_ts.num_nodes, dtype=np.float64)
    state = np.zeros(numba_ts.num_nodes, dtype=np.int32)
    summary = np.zeros(numba_ts.num_nodes, dtype=np.float64)

    n = float(numba_ts.num_samples)
    two_over_denom = 2.0 / (n * (n - 1.0))
    sample_summary = 2.0 / n

    # Retrieve this constant outside the loop
    # to avoid repeated attribute access
    NODE_IS_SAMPLE = tskit.NODE_IS_SAMPLE
    # Find the sample nodes and initialize their states
    for node in range(numba_ts.num_nodes):
        if node_flags[node] & NODE_IS_SAMPLE:
            state[node] = 1
            summary[node] = sample_summary

    result = 0.0
    # running_sum: sum of branch_length * summary over the current tree
    running_sum = 0.0
    tree_index = numba_ts.tree_index()

    # Now iterate through the trees
    while tree_index.next():
        # Process the outgoing edges
        for j in range(tree_index.out_range.start, tree_index.out_range.stop):
            h = tree_index.out_range.order[j]
            child = edge_child[h]
            child_parent = edge_parent[h]

            running_sum -= branch_length[child] * summary[child]
            parent[child] = -1
            branch_length[child] = 0.0

            # Remove the child's samples from every ancestor
            u = child_parent
            parent_u = parent[u]
            while u != -1:
                running_sum -= branch_length[u] * summary[u]
                state[u] -= state[child]
                summary[u] = state[u] * (n - state[u]) * two_over_denom
                running_sum += branch_length[u] * summary[u]
                u = parent_u
                if u != -1:
                    parent_u = parent[u]

        # Process the incoming edges
        for j in range(tree_index.in_range.start, tree_index.in_range.stop):
            h = tree_index.in_range.order[j]
            child = edge_child[h]
            child_parent = edge_parent[h]

            parent[child] = child_parent
            branch_length[child] = node_times[child_parent] - node_times[child]
            running_sum += branch_length[child] * summary[child]

            # Add the child's samples to every ancestor
            u = child_parent
            parent_u = parent[u]
            while u != -1:
                running_sum -= branch_length[u] * summary[u]
                state[u] += state[child]
                summary[u] = state[u] * (n - state[u]) * two_over_denom
                running_sum += branch_length[u] * summary[u]
                u = parent_u
                if u != -1:
                    parent_u = parent[u]

        # Weight this tree's contribution by the length of its interval
        result += running_sum * (
            tree_index.interval[1] - tree_index.interval[0]
        )

    return result / numba_ts.sequence_length


# Warm up the JIT
d = diversity(numba_ts)
t = time.time()
d = diversity(numba_ts)
print("Diversity:", d)
print("Time taken:", time.time() - t)
Diversity: 3.992633119145986
Time taken: 0.28753113746643066

Because this code is written for one specific statistic, it is even faster than tskit's more general C implementation, called here from Python:

t = time.time()
d_tskit = ts.diversity(mode="branch")
print("Diversity (tskit):", d_tskit)
print("Time taken:", time.time() - t)
Diversity (tskit): 3.992633119146012
Time taken: 0.5368030071258545
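The summary array in diversity() relies on a standard identity. The mean pairwise branch distance between samples equals the sum over branches of branch length times the fraction of sample pairs that the branch separates. For a branch with k samples below it, that fraction is 2k(n - k)/(n(n - 1)). The pure-Python check below uses a hand-built three-sample tree. The `parent` and `time` dictionaries are hypothetical test data, not tskit objects.

```python
from itertools import combinations
from math import isclose

# Hypothetical single tree: samples 0, 1, 2; node 3 joins 0 and 1 at time 1,
# root 4 joins 3 and 2 at time 2.
parent = {0: 3, 1: 3, 2: 4, 3: 4, 4: -1}
time = {0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0, 4: 2.0}
samples = [0, 1, 2]
n = len(samples)


def path_to_root(u):
    """Return the set of nodes from u up to and including the root."""
    path = set()
    while u != -1:
        path.add(u)
        u = parent[u]
    return path


def branch_length(u):
    return time[parent[u]] - time[u]


# Direct definition: average branch distance over all pairs of samples.
# The branch above u is on the a-b path iff u is an ancestor of exactly one.
pair_distances = [
    sum(branch_length(u) for u in path_to_root(a) ^ path_to_root(b))
    for a, b in combinations(samples, 2)
]
direct = sum(pair_distances) / len(pair_distances)

# Per-branch form used by diversity(): length * 2k(n - k) / (n(n - 1)),
# where k is the number of samples below the branch (state[u] in the code).
per_branch = 0.0
for u in parent:
    if parent[u] == -1:
        continue  # the root has no branch above it
    k = sum(1 for s in samples if u in path_to_root(s))
    per_branch += branch_length(u) * k * (n - k) * 2.0 / (n * (n - 1.0))

print(pair_distances)      # [2.0, 4.0, 4.0]
print(direct, per_branch)  # both equal 10/3 up to floating-point rounding
```

The per-branch form lets diversity() update running_sum incrementally as edges change. It avoids recomputing distances between all pairs of samples for every tree.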

API Reference

tskit.jit.numba.jitwrap(ts)

Convert a TreeSequence to a Numba-compatible format.

Creates a NumbaTreeSequence object that can be used within Numba-compiled functions.

Parameters

ts : tskit.TreeSequence

The tree sequence to convert.

Returns

NumbaTreeSequence

A Numba-compatible representation of the input tree sequence. Contains all necessary data arrays and metadata for tree traversal.

class tskit.jit.numba.NumbaTreeSequence(num_trees, num_nodes, num_samples, num_edges, num_sites, num_mutations, sequence_length, edges_left, edges_right, indexes_edge_insertion_order, indexes_edge_removal_order, individuals_flags, nodes_time, nodes_flags, nodes_population, nodes_individual, edges_parent, edges_child, sites_position, sites_ancestral_state, mutations_site, mutations_node, mutations_parent, mutations_time, mutations_derived_state, breakpoints, max_ancestral_length, max_derived_length)

A Numba-compatible representation of a tree sequence.

This class provides access to tree sequence data from within Numba "njit"-compiled functions. Use jitwrap() to create an instance from a tskit.TreeSequence object before passing it to a Numba function.

Attributes

num_trees : int32

Number of trees in the tree sequence.

num_nodes : int32

Number of nodes in the tree sequence.

num_samples : int32

Number of samples in the tree sequence.

num_edges : int32

Number of edges in the tree sequence.

num_sites : int32

Number of sites in the tree sequence.

num_mutations : int32

Number of mutations in the tree sequence.

sequence_length : float64

Total sequence length of the tree sequence.

edges_left : float64[]

Left coordinates of edges.

edges_right : float64[]

Right coordinates of edges.

edges_parent : int32[]

Parent node IDs for each edge.

edges_child : int32[]

Child node IDs for each edge.

nodes_time : float64[]

Time values for each node.

nodes_flags : uint32[]

Flag values for each node.

nodes_population : int32[]

Population IDs for each node.

nodes_individual : int32[]

Individual IDs for each node.

individuals_flags : uint32[]

Flag values for each individual.

sites_position : float64[]

Positions of sites along the sequence.

mutations_site : int32[]

Site IDs for each mutation.

mutations_node : int32[]

Node IDs for each mutation.

mutations_parent : int32[]

Parent mutation IDs.

mutations_time : float64[]

Time values for each mutation.

breakpoints : float64[]

Genomic positions where trees change.

indexes_edge_insertion_order : int32[]

Order in which edges are inserted during tree building.

indexes_edge_removal_order : int32[]

Order in which edges are removed during tree building.
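The breakpoints and num_trees attributes are tied to the edge coordinates. In tskit's model, a new tree starts at every distinct edge left or right coordinate. The pure-Python sketch below derives both values this way. It assumes that the breakpoints array includes both 0 and the sequence length, as TreeSequence.breakpoints() does. The helper name `breakpoints_from_edges` is hypothetical.

```python
def breakpoints_from_edges(edges_left, edges_right, sequence_length):
    """Hypothetical helper: sorted distinct tree boundaries, including both ends."""
    points = {0.0, float(sequence_length)}
    points.update(edges_left)
    points.update(edges_right)
    return sorted(points)


# Edge 0 spans [0, 10), edge 1 covers [0, 5) and edge 2 covers [5, 10)
breakpoints = breakpoints_from_edges([0.0, 0.0, 5.0], [10.0, 5.0, 10.0], 10.0)
num_trees = len(breakpoints) - 1
print(breakpoints, num_trees)  # [0.0, 5.0, 10.0] 2
```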

tree_index()

Create a NumbaTreeIndex for traversing this tree sequence.

Returns

NumbaTreeIndex

A new tree index initialized to the null tree. Use next() or prev() to move to an actual tree.

Examples

>>> tree_index = numba_ts.tree_index()
>>> while tree_index.next():
...     # Process current tree at tree_index.index
...     print(f"Tree {tree_index.index}: {tree_index.interval}")

class tskit.jit.numba.NumbaTreeIndex(ts)

Traverse trees in a numba compatible tree sequence.

This class provides efficient forward and backward iteration through the trees in a tree sequence. For the current tree, it provides the genomic interval, the edge changes needed to build it from the previous tree, and the ranges of its sites and mutations.

It should not be instantiated directly. Instances are returned by the tree_index() method of NumbaTreeSequence.

Attributes

ts : NumbaTreeSequence

Reference to the tree sequence being traversed.

index : int32

Current tree index. -1 indicates no current tree (null state).

direction : int32

Traversal direction: tskit.FORWARD or tskit.REVERSE, or tskit.NULL if uninitialised.

interval : tuple of float64

Genomic interval (left, right) covered by the current tree.

in_range : NumbaEdgeRange

Edges being added to form the current tree, relative to the previous state.

out_range : NumbaEdgeRange

Edges being removed to form the current tree, relative to the previous state.

site_range : tuple of int32

Range of sites in the current tree (start, stop).

mutation_range : tuple of int32

Range of mutations in the current tree (start, stop).
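Sites are sorted by position and mutations by site, so the site and mutation ranges for a tree interval can be found by binary search. The pure-Python sketch below uses the standard bisect module to compute (start, stop) ranges for a half-open interval [left, right). It illustrates one possible approach and is not necessarily how tskit computes site_range and mutation_range. The function name and data are hypothetical.

```python
from bisect import bisect_left


def site_and_mutation_ranges(sites_position, mutations_site, left, right):
    """Hypothetical helper: (start, stop) ranges for sites in [left, right)
    and for the mutations at those sites."""
    site_start = bisect_left(sites_position, left)
    site_stop = bisect_left(sites_position, right)
    # Mutations are sorted by site id, so their range follows from the site range
    mut_start = bisect_left(mutations_site, site_start)
    mut_stop = bisect_left(mutations_site, site_stop)
    return (site_start, site_stop), (mut_start, mut_stop)


sites_position = [1.0, 4.5, 5.0, 8.0]
mutations_site = [0, 1, 1, 2, 3]  # two mutations at site 1

first = site_and_mutation_ranges(sites_position, mutations_site, 0.0, 5.0)
second = site_and_mutation_ranges(sites_position, mutations_site, 5.0, 10.0)
print(first)   # ((0, 2), (0, 3))
print(second)  # ((2, 4), (3, 5))
```

Given these ranges, a loop such as `for m in range(mut_start, mut_stop)` visits exactly the mutations in the current tree.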

Example

>>> tree_index = numba_ts.tree_index()
>>> num_edges = 0
>>> while tree_index.next():
...     num_edges += (tree_index.in_range.stop - tree_index.in_range.start)
...     num_edges -= (tree_index.out_range.stop - tree_index.out_range.start)
...     print(f"Tree {tree_index.index}: {num_edges} edges")

set_null()

Reset the tree index to null state.

next()

Move to the next tree in forward direction.

Updates the tree index to the next tree in the sequence, computing the edges that need to be added and removed to transform from the previous tree to the current tree, storing them in self.in_range and self.out_range.

Returns

bool

True if successfully moved to next tree, False if the end of the tree sequence is reached. When False is returned, the iterator is in null state (index=-1).

Notes

On the first call, this initializes the iterator and moves to tree 0. The in_range and out_range attributes are updated to reflect the edge changes needed for the current tree.
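The null-state protocol of next() and prev() can be mimicked with a small pure-Python class. The iterator starts at index -1, moves to the first or last tree, and returns False and resets to null when it runs off either end. `ToyTreeIndex` is hypothetical and tracks only index and interval over a list of breakpoints. It computes no edge ranges. The (0.0, 0.0) null interval is an assumption of this sketch.

```python
class ToyTreeIndex:
    """Hypothetical stand-in mimicking the null-state protocol of next()/prev()."""

    def __init__(self, breakpoints):
        self.breakpoints = breakpoints
        self.num_trees = len(breakpoints) - 1
        self.set_null()

    def set_null(self):
        self.index = -1
        self.interval = (0.0, 0.0)  # assumed placeholder for the null tree

    def _move_to(self, index):
        if 0 <= index < self.num_trees:
            self.index = index
            self.interval = (self.breakpoints[index], self.breakpoints[index + 1])
            return True
        # Ran off either end: return to the null state
        self.set_null()
        return False

    def next(self):
        # From the null state, next() moves to the first tree
        return self._move_to(0 if self.index == -1 else self.index + 1)

    def prev(self):
        # From the null state, prev() moves to the last tree
        last = self.num_trees - 1
        return self._move_to(last if self.index == -1 else self.index - 1)


ti = ToyTreeIndex([0.0, 5.0, 10.0])
forward = []
while ti.next():
    forward.append((ti.index, ti.interval))
print(forward, ti.index)  # [(0, (0.0, 5.0)), (1, (5.0, 10.0))] -1

backward = []
while ti.prev():
    backward.append(ti.index)
print(backward)  # [1, 0]
```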

prev()

Move to the previous tree in reverse direction.

Updates the tree index to the previous tree in the sequence, computing the edges that need to be added and removed to transform from the next tree to the current tree, storing them in self.in_range and self.out_range.

Returns

bool

True if successfully moved to previous tree, False if the beginning of the tree sequence is reached. When False is returned, the iterator is in null state (index=-1).

Notes

On the first call, this initializes the iterator and moves to the rightmost (last) tree. The in_range and out_range attributes are updated to reflect the edge changes needed for the current tree when traversing backward.
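Reverse traversal mirrors the forward case. Moving leftwards, edges enter the tree at their right coordinate and leave at their left coordinate. The removal order is therefore read backwards to find incoming edges, and the insertion order backwards to find outgoing ones. This pure-Python sketch is not the tskit implementation. It assumes `insertion_order` is sorted primarily by left coordinate and `removal_order` by right coordinate. It returns plain lists instead of NumbaEdgeRange objects, so it says nothing about how tskit lays out start and stop in reverse. The name `reverse_edge_diffs` is hypothetical.

```python
def reverse_edge_diffs(edges_left, edges_right, insertion_order, removal_order,
                       sequence_length):
    """Yield ((left, right), out_ids, in_ids) for each tree, right to left.

    Hypothetical helper: a pure-Python sketch of what prev() computes.
    Assumes insertion_order is sorted by edges_left and removal_order by
    edges_right; both are read backwards here.
    """
    j = len(edges_left) - 1  # position in removal_order (edges entering)
    k = len(edges_left) - 1  # position in insertion_order (edges leaving)
    right = sequence_length
    while j >= 0 or right > 0:
        # Moving left, edges whose left end equals the boundary leave the tree
        out_ids = []
        while k >= 0 and edges_left[insertion_order[k]] == right:
            out_ids.append(insertion_order[k])
            k -= 1
        # Edges whose right end equals the boundary enter the tree
        in_ids = []
        while j >= 0 and edges_right[removal_order[j]] == right:
            in_ids.append(removal_order[j])
            j -= 1
        # The tree extends left to the nearest remaining edge boundary
        left = 0.0
        if j >= 0:
            left = max(left, edges_right[removal_order[j]])
        if k >= 0:
            left = max(left, edges_left[insertion_order[k]])
        yield (left, right), out_ids, in_ids
        right = left


# Edge 0 spans [0, 10), edge 1 covers [0, 5) and edge 2 covers [5, 10)
diffs = list(reverse_edge_diffs([0.0, 0.0, 5.0], [10.0, 5.0, 10.0],
                                [0, 1, 2], [1, 0, 2], 10.0))
print(diffs)  # [((5.0, 10.0), [], [2, 0]), ((0.0, 5.0), [2], [1])]
```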

class tskit.jit.numba.NumbaEdgeRange(*args, **kwargs)

Represents a range of edges during tree traversal.

This class encapsulates a contiguous range of edges that are either being removed or added to step from one tree to another. The start and stop indices, when applied to the order array, define the ids of the edges to process.

Attributes

start : int32

Starting index of the edge range (inclusive).

stop : int32

Stopping index of the edge range (exclusive).

order : int32[]

Array containing edge IDs in the order they should be processed. The edge ids in this range are order[start:stop].
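This page reads an edge range in two ways: by slicing order[start:stop], and by looping `for j in range(start, stop)` over order[j] as in the diversity example. Both select the same edge ids. The minimal pure-Python stand-in below shows this. The `EdgeRange` dataclass is hypothetical and is not the tskit class.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class EdgeRange:
    """Hypothetical stand-in for NumbaEdgeRange: ids are order[start:stop]."""
    start: int
    stop: int
    order: List[int]

    def edge_ids(self):
        return self.order[self.start:self.stop]


# An insertion order over five edges; this range selects positions 1, 2 and 3
in_range = EdgeRange(start=1, stop=4, order=[3, 0, 4, 1, 2])

# The loop form used inside njit code
looped = []
for j in range(in_range.start, in_range.stop):
    looped.append(in_range.order[j])

print(in_range.edge_ids(), looped)     # [0, 4, 1] [0, 4, 1]
print(in_range.stop - in_range.start)  # 3
```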