Numba Integration

The tskit.jit.numba module provides classes for working with tree sequences from Numba jit-compiled Python code. Such code can run up to hundreds of times faster than normal Python, yet avoids the difficulties of writing C or other low-level code.

Note

Numba is not a direct dependency of tskit, so will not be available unless installed:

pip install numba

or

conda install numba

Overview

The numba integration provides a tskit.TreeSequence wrapper class NumbaTreeSequence. This class can be used directly in numba.njit compiled functions, and provides several efficient methods for tree traversal:

  • tree_index(): For efficient iteration through the trees in the sequence

  • parent_index(): For efficient access to parent edge information, to traverse upwards through the ARG.

  • child_index(): For efficient access to child edge information, to traverse downwards through the ARG.

These methods are optimised to work within Numba’s @njit decorated functions, allowing you to write high-performance tree sequence analysis code in a plain Python style.

Basic Usage

The tskit.jit.numba module is not imported by a plain import tskit, so it must be imported explicitly:

import numpy as np
import tskit
import tskit.jit.numba as tskit_numba

Normal third-party classes such as tskit.TreeSequence can’t be used in numba.njit compiled functions so the tskit.TreeSequence must be wrapped in a NumbaTreeSequence by jitwrap(). This must be done outside njit code:

import msprime

ts = msprime.sim_ancestry(
    samples=50000,
    sequence_length=100000,
    recombination_rate=0.1,
    random_seed=42
)
numba_ts = tskit_numba.jitwrap(ts)
print(type(numba_ts))
<class 'numba.experimental.jitclass.boxing._NumbaTreeSequence'>

Tree Iteration

Tree iteration can be performed in numba.njit compiled functions using the TreeIndex class. This class provides next() and prev() methods for forward and backward iteration through the trees in a tree sequence. Its in_range and out_range attributes give the edges that must be added or removed to form the current tree from the previous one. The interval attribute gives the genomic interval of the current tree, and site_range and mutation_range give its sites and mutations.

A TreeIndex instance can be obtained from a NumbaTreeSequence using the tree_index() method. Its initial state is a “null” tree outside the range of the tree sequence; the first call to next() moves to the first tree, and the first call to prev() moves to the last tree. After that, the in_range and out_range attributes give the edges that must be added or removed to form the current tree from the previous tree. For example, tree_index.in_range.order[in_range.start:in_range.stop] gives the edge IDs that are new in the current tree, and tree_index.out_range.order[out_range.start:out_range.stop] gives the edge IDs that are no longer present in the current tree. tree_index.site_range and tree_index.mutation_range give the indexes into the tree sequence’s site and mutation arrays.
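The edge bookkeeping behind in_range and out_range can be illustrated without Numba. The following is a minimal pure-Python sketch of forward iteration over a toy edge list. It assumes edges are plain (left, right, parent, child) tuples; iterate_edge_diffs, EDGES and SEQUENCE_LENGTH are hypothetical names used only for illustration, and the sort keys are simplified stand-ins for the real edge insertion and removal orders. It is not the library's implementation.

```python
# Toy edge list (hypothetical data): (left, right, parent, child).
EDGES = [
    (0.0, 5.0, 4, 0),
    (0.0, 10.0, 4, 1),
    (5.0, 10.0, 5, 0),
    (0.0, 10.0, 5, 2),
    (0.0, 10.0, 6, 4),
    (0.0, 10.0, 6, 5),
]
SEQUENCE_LENGTH = 10.0


def iterate_edge_diffs(edges, sequence_length):
    """Yield (interval, edges_out, edges_in) for each tree, left to right."""
    num_edges = len(edges)
    # Simplified stand-ins for the edge insertion and removal orders.
    insertion_order = sorted(range(num_edges), key=lambda e: edges[e][0])
    removal_order = sorted(range(num_edges), key=lambda e: edges[e][1])
    j = 0  # position in insertion_order
    k = 0  # position in removal_order
    left = 0.0
    while j < num_edges or left < sequence_length:
        # Edges ending at this breakpoint leave the tree.
        out_start = k
        while k < num_edges and edges[removal_order[k]][1] == left:
            k += 1
        # Edges starting at this breakpoint enter the tree.
        in_start = j
        while j < num_edges and edges[insertion_order[j]][0] == left:
            j += 1
        # The tree extends to the next point where any edge starts or ends.
        right = sequence_length
        if j < num_edges:
            right = min(right, edges[insertion_order[j]][0])
        if k < num_edges:
            right = min(right, edges[removal_order[k]][1])
        yield (left, right), removal_order[out_start:k], insertion_order[in_start:j]
        left = right


diffs = list(iterate_edge_diffs(EDGES, SEQUENCE_LENGTH))
edge_counts = []
current = 0
for interval, edges_out, edges_in in diffs:
    current += len(edges_in) - len(edges_out)
    edge_counts.append(current)
    print(interval, "out:", edges_out, "in:", edges_in, "edges:", current)
```

In this sketch the slices of removal_order and insertion_order play the role that order[start:stop] plays for out_range and in_range.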

As a simple example we can calculate the number of edges in each tree in a tree sequence:

import numba

@numba.njit
def edges_per_tree(numba_ts):
    tree_index = numba_ts.tree_index()
    current_num_edges = 0
    num_edges = []
    
    # Move forward through the trees
    while tree_index.next():
        # Access current tree information
        in_range = tree_index.in_range
        out_range = tree_index.out_range
        
        current_num_edges -= (out_range.stop - out_range.start)
        current_num_edges += (in_range.stop - in_range.start)
        num_edges.append(current_num_edges)
    return num_edges

# Warm up the JIT compiler
edges = edges_per_tree(numba_ts)
import time

t = time.time()
jit_num_edges = edges_per_tree(numba_ts)
print(f"JIT Time taken: {time.time() - t:.4f} seconds")
JIT Time taken: 0.0320 seconds

Doing the same thing with the normal tskit API would be much slower:

t = time.time()
python_num_edges = []
for tree in ts.trees():
    python_num_edges.append(tree.num_edges)
print(f"Normal Time taken: {time.time() - t:.4f} seconds")

assert jit_num_edges == python_num_edges, "JIT and normal results do not match!"
Normal Time taken: 0.3796 seconds

Example - diversity calculation

As a more interesting example we can calculate genetic diversity (also known as pi). Here we calculate it from the distance in the tree between samples (mode="branch" in the tskit API).

This example also shows the style of Python code that gives best performance under numba JIT compilation - using simple loops and fixed-size arrays with minimal object attribute access.

    @numba.njit
    def diversity(numba_ts):
        # Cache arrays to avoid repeated attribute access in
        # tight loops
        edge_child = numba_ts.edges_child
        edge_parent = numba_ts.edges_parent
        node_times = numba_ts.nodes_time
        node_flags = numba_ts.nodes_flags
        
        if numba_ts.num_samples <= 1:
            return 0.0

        parent = np.full(numba_ts.num_nodes, -1, dtype=np.int32)
        branch_length = np.zeros(numba_ts.num_nodes, dtype=np.float64)
        state = np.zeros(numba_ts.num_nodes, dtype=np.int32)
        summary = np.zeros(numba_ts.num_nodes, dtype=np.float64)

        n = float(numba_ts.num_samples)
        two_over_denom = 2.0 / (n * (n - 1.0))
        sample_summary = 2.0 / n

        # Retrieve this constant outside the loop
        # to avoid repeated attribute access
        NODE_IS_SAMPLE = tskit.NODE_IS_SAMPLE
        # Find the sample nodes and initialize their states
        for node in range(numba_ts.num_nodes):
            if node_flags[node] & NODE_IS_SAMPLE:
                state[node] = 1
                summary[node] = sample_summary

        result = 0.0
        running_sum = 0.0
        tree_index = numba_ts.tree_index()

        # Now iterate through the trees
        while tree_index.next():
            # Process the outgoing edges
            for j in range(tree_index.out_range.start, tree_index.out_range.stop):
                h = tree_index.out_range.order[j]
                child = edge_child[h]
                child_parent = edge_parent[h]

                running_sum -= branch_length[child] * summary[child]
                parent[child] = -1
                branch_length[child] = 0.0

                u = child_parent
                parent_u = parent[u]
                while u != -1:
                    running_sum -= branch_length[u] * summary[u]
                    state[u] -= state[child]
                    summary[u] = state[u] * (n - state[u]) * two_over_denom
                    running_sum += branch_length[u] * summary[u]
                    u = parent_u
                    if u != -1:
                        parent_u = parent[u]

            # Process the incoming edges
            for j in range(tree_index.in_range.start, tree_index.in_range.stop):
                h = tree_index.in_range.order[j]
                child = edge_child[h]
                child_parent = edge_parent[h]

                parent[child] = child_parent
                branch_length[child] = node_times[child_parent] - node_times[child]
                running_sum += branch_length[child] * summary[child]

                u = child_parent
                parent_u = parent[u]
                while u != -1:
                    running_sum -= branch_length[u] * summary[u]
                    state[u] += state[child]
                    summary[u] = state[u] * (n - state[u]) * two_over_denom
                    running_sum += branch_length[u] * summary[u]
                    u = parent_u
                    if u != -1:
                        parent_u = parent[u]

            result += running_sum * (
                tree_index.interval[1] - tree_index.interval[0]
            )

        return result / numba_ts.sequence_length
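The per-node weight used above, state[u] * (n - state[u]) * 2 / (n * (n - 1)), can be checked by hand on a single tree. The sketch below is a pure-Python illustration on a hypothetical three-sample tree (PARENT, TIME and SAMPLES are made-up toy data, and the function names are illustrative only). It compares the mean pairwise branch distance with the sum of branch lengths weighted by the number of sample pairs each branch separates.

```python
import math
from itertools import combinations

# Toy single tree (hypothetical data): child -> parent, and node times.
PARENT = {0: 3, 1: 3, 2: 4, 3: 4}
TIME = {0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0, 4: 2.5}
SAMPLES = [0, 1, 2]


def path_to_root(u):
    """Return the nodes from u up to the root, inclusive."""
    path = [u]
    while path[-1] in PARENT:
        path.append(PARENT[path[-1]])
    return path


def pairwise_distance(a, b):
    """Total branch length separating samples a and b."""
    ancestors_of_a = set(path_to_root(a))
    mrca = next(v for v in path_to_root(b) if v in ancestors_of_a)
    return (TIME[mrca] - TIME[a]) + (TIME[mrca] - TIME[b])


def diversity_by_pairs():
    """Mean branch distance over all pairs of samples."""
    pairs = list(combinations(SAMPLES, 2))
    return sum(pairwise_distance(a, b) for a, b in pairs) / len(pairs)


def diversity_by_branches():
    """Sum of branch lengths weighted by the fraction of pairs they separate."""
    n = len(SAMPLES)
    # Count the samples below each node by walking up from every sample.
    samples_below = {u: 0 for u in TIME}
    for s in SAMPLES:
        for v in path_to_root(s):
            samples_below[v] += 1
    total = 0.0
    for child, parent in PARENT.items():
        k = samples_below[child]
        weight = 2.0 * k * (n - k) / (n * (n - 1))
        total += (TIME[parent] - TIME[child]) * weight
    return total


by_pairs = diversity_by_pairs()
by_branches = diversity_by_branches()
print(by_pairs, by_branches)
```

Both approaches give the same value on this tree. The Numba function applies the same weighting incrementally as edges enter and leave, rather than recomputing it for every tree.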

# Warm up the JIT
d = diversity(numba_ts)
t = time.time()
d = diversity(numba_ts)
print("Diversity:", d)
print("Time taken:", time.time() - t)
Diversity: 3.992633119145986
Time taken: 0.27684879302978516

As this code is written for this specific diversity calculation it is even faster than the tskit C implementation, called here from Python:

t = time.time()
d_tskit = ts.diversity(mode="branch")
print("Diversity (tskit):", d_tskit)
print("Time taken:", time.time() - t)
Diversity (tskit): 3.992633119146012
Time taken: 0.5163848400115967

ARG Traversal

Beyond iterating through trees, you may need to traverse the ARG vertically. The child_index() and parent_index() methods provide efficient access to parent-child relationships in the edge table within numba.njit functions.

The child_index() method returns an array that allows you to efficiently find all edges where a given node is the parent. Since edges are already sorted by parent in the tskit data model, this is implemented using simple range indexing. For any node u, the returned array child_index[u] gives a tuple of the start and stop indices in the tskit edge table where node u is the parent. The index is recalculated on every call to child_index(), so call it once and reuse the result.
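To make the range indexing concrete, here is a minimal pure-Python sketch of how such an index could be built from a parent-sorted edge list. build_child_index and the toy arrays are hypothetical names for illustration; in particular, using (0, 0) for nodes without child edges is an assumption of this sketch, not a statement about what the library stores.

```python
def build_child_index(edges_parent, num_nodes):
    """Return per-node (start, stop) ranges into a parent-sorted edge list."""
    # Assumption of this sketch: nodes with no child edges get (0, 0).
    child_range = [(0, 0)] * num_nodes
    num_edges = len(edges_parent)
    e = 0
    while e < num_edges:
        parent = edges_parent[e]
        start = e
        # Advance over the contiguous run of edges sharing this parent.
        while e < num_edges and edges_parent[e] == parent:
            e += 1
        child_range[parent] = (start, e)
    return child_range


# Toy data (hypothetical): parent and child of each edge, grouped by parent.
EDGES_PARENT = [4, 4, 5, 5, 6, 6, 6]
EDGES_CHILD = [0, 1, 0, 2, 3, 4, 5]

child_index = build_child_index(EDGES_PARENT, num_nodes=7)
start, stop = child_index[5]
children_of_5 = EDGES_CHILD[start:stop]
print(children_of_5)
```

Because the edges for each parent are contiguous, the (start, stop) pair can slice any edge column directly.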

The parent_index() method creates a ParentIndex that allows you to efficiently find all edges where a given node is the child. Since edges are not sorted by child in the edge table, the returned class contains a custom index that sorts edge IDs by child node (and then by left coordinate). For any node u, parent_index.index_range[u] gives a tuple of the start and stop indices in the parent_index.edge_index array, and parent_index.edge_index[start:stop] gives the actual tskit edge IDs.
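The two-level lookup can be sketched in pure Python with a counting sort. build_parent_index and the toy arrays below are hypothetical names for illustration; whether the library builds its index in exactly this way, and the order of edges within each child's group, are assumptions of the sketch.

```python
def build_parent_index(edges_child, num_nodes):
    """Return (edge_index, index_range) grouping edge IDs by child node."""
    # Count how many edges have each node as the child.
    counts = [0] * num_nodes
    for child in edges_child:
        counts[child] += 1
    # Turn the counts into [start, stop) ranges, one per node.
    index_range = []
    start = 0
    for u in range(num_nodes):
        index_range.append((start, start + counts[u]))
        start += counts[u]
    # Drop each edge ID into the next free slot of its child's range.
    next_slot = [r[0] for r in index_range]
    edge_index = [0] * len(edges_child)
    for e, child in enumerate(edges_child):
        edge_index[next_slot[child]] = e
        next_slot[child] += 1
    return edge_index, index_range


# Toy data (hypothetical): parent and child of each edge, grouped by parent.
EDGES_PARENT = [4, 4, 5, 5, 6, 6, 6]
EDGES_CHILD = [0, 1, 0, 2, 3, 4, 5]

edge_index, index_range = build_parent_index(EDGES_CHILD, num_nodes=7)
start, stop = index_range[0]
# start and stop are positions in edge_index, not in the edge table.
parents_of_0 = [EDGES_PARENT[e] for e in edge_index[start:stop]]
print(parents_of_0)
```

Slicing the edge columns directly with start and stop would read the wrong rows (here EDGES_PARENT[0:2] is [4, 4]), which is why the lookup goes through edge_index first.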

Both indexes can be obtained from a NumbaTreeSequence:

# Get the indexes
child_index = numba_ts.child_index()
parent_index = numba_ts.parent_index()

# Example: find all left coordinates of edges where node 5 is the parent
start, stop = child_index[5]
left_coords = numba_ts.edges_left[start:stop]
print(left_coords)

# Example: find all right coordinates of edges where node 3 is the child.
# index_range gives positions in edge_index, which holds the edge IDs.
start, stop = parent_index.index_range[3]
right_coords = numba_ts.edges_right[parent_index.edge_index[start:stop]]
print(right_coords)
[]
[100000.]

These indexes enable efficient algorithms that need to traverse parent-child relationships in the ARG, such as computing descendant sets, ancestral paths, or subtree properties.

Example - descendant span calculation

Here’s an example of using the ARG traversal indexes to calculate the total sequence length over which each node descends from a specified node:

@numba.njit
def descendant_span(numba_ts, u):
    """
    Calculate the total sequence length over which each node 
    descends from the specified node u.
    """
    child_index = numba_ts.child_index()
    edges_left = numba_ts.edges_left
    edges_right = numba_ts.edges_right
    edges_child = numba_ts.edges_child
    
    total_descending = np.zeros(numba_ts.num_nodes)
    stack = [(u, 0.0, numba_ts.sequence_length)]
    
    while len(stack) > 0:
        node, left, right = stack.pop()
        total_descending[node] += right - left
        
        # Find all child edges for this node
        for e in range(child_index[node, 0], child_index[node, 1]):
            e_left = edges_left[e]
            e_right = edges_right[e]
            
            # Check if edge overlaps with current interval
            if e_right > left and right > e_left:
                inter_left = max(e_left, left)
                inter_right = min(e_right, right)
                e_child = edges_child[e]
                stack.append((e_child, inter_left, inter_right))
    
    return total_descending

# Warm up the JIT
result = descendant_span(numba_ts, 0)
# Calculate descendant span for the root node (highest numbered node)
root_node = numba_ts.num_nodes - 1
result = descendant_span(numba_ts, root_node)

# Show nodes that have non-zero descendant span
non_zero = result > 0
print(f"Nodes descended from {root_node}:")
print(f"Node IDs: {np.where(non_zero)[0]}")
print(f"Span lengths: {result[non_zero]}")
Nodes descended from 495178:
Node IDs: [     0      1      2 ... 457696 469935 495178]
Span lengths: [1.e+00 1.e+00 1.e+00 ... 1.e+00 1.e+00 1.e+05]

Comparing performance with using the tskit Python API:

def descendant_span_tskit(ts, u):
    """Reference implementation using tskit trees"""
    total_descending = np.zeros(ts.num_nodes)
    for tree in ts.trees():
        descendants = tree.preorder(u)
        total_descending[descendants] += tree.span
    return total_descending

import time
t = time.time()
numba_result = descendant_span(numba_ts, root_node)
print(f"Numba time: {time.time() - t:.6f} seconds")

t = time.time()
tskit_result = descendant_span_tskit(ts, root_node)
print(f"tskit time: {time.time() - t:.6f} seconds")

np.testing.assert_array_almost_equal(numba_result, tskit_result, decimal=10)
print("Results match!")
Numba time: 0.014827 seconds
tskit time: 0.866002 seconds
Results match!

Example - ARG descendant and ancestral edges calculation

As we have child_index and parent_index, we can efficiently find both descendant and ancestral sub-ARGs for a given node. This first example shows how to find all edges in the ARG that are descendants of a given node. It returns a boolean array indicating which edges are part of the sub-ARG rooted at the specified node:

@numba.njit
def descendant_edges(numba_ts, u):
    """
    Returns a boolean array which is only True for edges that are descendants of node u.
    """
    edge_select = np.zeros(numba_ts.num_edges, dtype=np.bool_)
    child_index = numba_ts.child_index()
    edges_left = numba_ts.edges_left
    edges_right = numba_ts.edges_right
    edges_child = numba_ts.edges_child
    
    # The stack stores (node_id, left_coord, right_coord)
    stack = [(u, 0.0, numba_ts.sequence_length)]
    
    while len(stack) > 0:
        node, left, right = stack.pop()
        
        # Find all edges where 'node' is the parent
        start, stop = child_index[node]
        for e in range(start, stop):
            e_left = edges_left[e]
            e_right = edges_right[e]
            
            # Check for genomic interval overlap
            if e_right > left and right > e_left:
                # This edge is part of the sub-ARG
                edge_select[e] = True
                
                # Calculate the intersection for the next traversal step
                inter_left = max(e_left, left)
                inter_right = min(e_right, right)
                e_child = edges_child[e]
                stack.append((e_child, inter_left, inter_right))
                
    return edge_select

# Find descendant edges for a high-numbered node (likely near root)
test_node = max(0, numba_ts.num_nodes - 5)
edge_select = descendant_edges(numba_ts, test_node)

# Show which edges are descendants
descendant_edge_ids = np.where(edge_select)[0]
print(f"Edges descended from node {test_node}: {descendant_edge_ids[:10]}...")
print(f"Total descendant edges: {np.sum(edge_select)}")
Edges descended from node 495174: [0 1 2 3 4 5 6 7 8 9]...
Total descendant edges: 199998
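The role of the interval overlap test is easier to see on a table small enough to check by hand. The sketch below is a pure-Python mirror of the traversal on hypothetical toy data (EDGES, CHILD_RANGE and descendant_edges_py are illustrative names only). Unlike the Numba function above, it takes an optional starting interval, which is an extension added here purely to show how restricting the interval changes which edges are selected.

```python
# Toy edge table (hypothetical data), grouped by parent:
# (left, right, parent, child)
EDGES = [
    (0.0, 5.0, 4, 0),
    (0.0, 10.0, 4, 1),
    (5.0, 10.0, 5, 0),
    (0.0, 10.0, 5, 2),
    (7.0, 10.0, 6, 3),
    (0.0, 7.0, 6, 4),
    (0.0, 10.0, 6, 5),
]
SEQUENCE_LENGTH = 10.0
# Hand-written (start, stop) edge ranges for each parent node.
CHILD_RANGE = {4: (0, 2), 5: (2, 4), 6: (4, 7)}


def descendant_edges_py(u, lo=0.0, hi=SEQUENCE_LENGTH):
    """Select edges below node u within the genomic interval [lo, hi)."""
    selected = [False] * len(EDGES)
    stack = [(u, lo, hi)]
    while stack:
        node, left, right = stack.pop()
        start, stop = CHILD_RANGE.get(node, (0, 0))
        for e in range(start, stop):
            e_left, e_right, _, e_child = EDGES[e]
            # Keep the edge only if it overlaps the interval inherited so far.
            if e_right > left and right > e_left:
                selected[e] = True
                stack.append((e_child, max(e_left, left), min(e_right, right)))
    return selected


below_5 = descendant_edges_py(5)
below_6 = descendant_edges_py(6)
below_6_right = descendant_edges_py(6, 7.0, 10.0)
print(below_5)
print(below_6)
print(below_6_right)
```

Restricting the start to [7, 10) drops the edge from node 6 to node 4, and with it both edges below node 4, because node 4 is only a child of node 6 over [0, 7).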

# Create a simple hard-coded example for consistent visualization
tables = tskit.TableCollection(sequence_length=10.0)

tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)  # node 0
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)  # node 1
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)  # node 2
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)  # node 3
tables.nodes.add_row(flags=0, time=1)  # node 4
tables.nodes.add_row(flags=0, time=2)  # node 5
tables.nodes.add_row(flags=0, time=3)  # node 6

tables.edges.add_row(left=0, right=5, parent=4, child=0)
tables.edges.add_row(left=0, right=10, parent=4, child=1)
tables.edges.add_row(left=5, right=10, parent=5, child=0)
tables.edges.add_row(left=0, right=10, parent=5, child=2)
tables.edges.add_row(left=0, right=7, parent=6, child=4)
tables.edges.add_row(left=0, right=10, parent=6, child=5)
tables.edges.add_row(left=7, right=10, parent=6, child=3)

tables.sort()
ts_simple = tables.tree_sequence()

A tree sequence is easily made from the descendant edges array:

numba_ts_simple = tskit_numba.jitwrap(ts_simple)
node = 5
E = descendant_edges(numba_ts_simple, node)
tables_sub = ts_simple.dump_tables()
tables_sub.edges.replace_with(tables_sub.edges[E])
ts_sub = tables_sub.tree_sequence()

As an example, let’s visualise the selection of a sub-ARG. Here is the full ARG with a highlighted node:

css_style = f".node.n{node} > .sym {{ fill: #c41e3a; }}"
ts_simple.draw_svg(size=(400, 200), node_labels={}, y_axis=True, style=css_style)
_images/1d26658c9c3aa9abef58d4e88fac63d97728c3711f129d3550a9c19bfa0fbca0.svg

And the sub-ARG from that node:

ts_sub.draw_svg(size=(400, 200), node_labels={}, y_axis=True, style=css_style)
_images/61396d35c0edfe4375826663e7d658aa21b68be27f4fbad52e11981d9d2e25ca.svg

In the other direction, we can similarly find the sub-ARG that is ancestral to a given node:

@numba.njit
def ancestral_edges(numba_ts, u):
    """
    Returns a boolean array which is only True for edges that are ancestors of node u.
    """
    edge_select = np.zeros(numba_ts.num_edges, dtype=np.bool_)
    parent_index = numba_ts.parent_index()
    edges_left = numba_ts.edges_left
    edges_right = numba_ts.edges_right
    edges_parent = numba_ts.edges_parent
    
    # The stack stores (node_id, left_coord, right_coord)
    stack = [(u, 0.0, numba_ts.sequence_length)]
    
    while len(stack) > 0:
        node, left, right = stack.pop()
        
        # Find all edges where 'node' is the child
        start, stop = parent_index.index_range[node]
        for i in range(start, stop):
            e = parent_index.edge_index[i]
            e_left = edges_left[e]
            e_right = edges_right[e]
            
            # Check for genomic interval overlap
            if e_right > left and right > e_left:
                # This edge is part of the sub-ARG
                edge_select[e] = True
                
                # Calculate the intersection for the next traversal step
                inter_left = max(e_left, left)
                inter_right = min(e_right, right)
                e_parent = edges_parent[e]
                stack.append((e_parent, inter_left, inter_right))

    return edge_select

# Find ancestral edges for a sample node (low-numbered nodes are usually samples)
test_node = min(5, numba_ts.num_nodes - 1)
edge_select = ancestral_edges(numba_ts, test_node)

# Show which edges are ancestors
ancestral_edge_ids = np.where(edge_select)[0]
print(f"Edges ancestral to node {test_node}: {ancestral_edge_ids[:10]}...")
print(f"Total ancestral edges: {np.sum(edge_select)}")
Edges ancestral to node 5: [ 57753  59866  59868 112525 170304 206560 206561 206562 252004 280568]...
Total ancestral edges: 110058

# Warm up the JIT for both functions
_ = descendant_edges(numba_ts, 0)
_ = ancestral_edges(numba_ts, 0)

Comparing performance with the tskit Python API shows a significant speedup:

def descendant_edges_tskit(ts, start_node):
    D = np.zeros(ts.num_edges, dtype=bool)
    for tree in ts.trees():
        for v in tree.preorder(start_node):
            if v != start_node:
                D[tree.edge(v)] = True
    return D

def ancestral_edges_tskit(ts, start_node):
    A = np.zeros(ts.num_edges, dtype=bool)
    for tree in ts.trees():
        curr_node = start_node
        parent = tree.parent(curr_node)
        while parent != tskit.NULL:
            edge_id = tree.edge(curr_node)
            A[edge_id] = True
            curr_node = parent
            parent = tree.parent(curr_node)
    return A

import time

# Test with root node for descendant edges
root_node = numba_ts.num_nodes - 1
t = time.time()
numba_desc = descendant_edges(numba_ts, root_node)
print(f"Numba descendant edges time: {time.time() - t:.6f} seconds")

t = time.time()
tskit_desc = descendant_edges_tskit(ts, root_node)
print(f"tskit descendant edges time: {time.time() - t:.6f} seconds")

# Test with sample node for ancestral edges  
sample_node = 0
t = time.time()
numba_anc = ancestral_edges(numba_ts, sample_node)
print(f"Numba ancestral edges time: {time.time() - t:.6f} seconds")

t = time.time()
tskit_anc = ancestral_edges_tskit(ts, sample_node)
print(f"tskit ancestral edges time: {time.time() - t:.6f} seconds")

# Verify results match
np.testing.assert_array_equal(numba_desc, tskit_desc)
np.testing.assert_array_equal(numba_anc, tskit_anc)
print("Results match!")
Numba descendant edges time: 0.016229 seconds
tskit descendant edges time: 0.497277 seconds
Numba ancestral edges time: 0.025768 seconds
tskit ancestral edges time: 0.919687 seconds
Results match!

API Reference

tskit.jit.numba.jitwrap(ts)

Convert a TreeSequence to a Numba-compatible format.

Creates a NumbaTreeSequence object that can be used within Numba-compiled functions.

Parameters:

ts (tskit.TreeSequence) – The tree sequence to convert.

Returns:

A Numba-compatible representation of the input tree sequence. Contains all necessary data arrays and metadata for tree traversal.

Return type:

NumbaTreeSequence

class tskit.jit.numba.NumbaTreeSequence(num_trees, num_nodes, num_samples, num_edges, num_sites, num_mutations, sequence_length, edges_left, edges_right, indexes_edge_insertion_order, indexes_edge_removal_order, individuals_flags, nodes_time, nodes_flags, nodes_population, nodes_individual, edges_parent, edges_child, sites_position, sites_ancestral_state, mutations_site, mutations_node, mutations_parent, mutations_time, mutations_derived_state, breakpoints, max_ancestral_length, max_derived_length)

A Numba-compatible representation of a tree sequence.

This class provides access to tree sequence data from within Numba “njit” compiled functions. jitwrap() should be used to create an instance of this class from a tskit.TreeSequence object before it is passed to a Numba function.

Attributes

num_trees : int

Number of trees in the tree sequence.

num_nodes : int

Number of nodes in the tree sequence.

num_samples : int

Number of samples in the tree sequence.

num_edges : int

Number of edges in the tree sequence.

num_sites : int

Number of sites in the tree sequence.

num_mutations : int

Number of mutations in the tree sequence.

sequence_length : float

Total sequence length of the tree sequence.

edges_left : numpy.ndarray

Array (dtype=np.float64) of left coordinates of edges.

edges_right : numpy.ndarray

Array (dtype=np.float64) of right coordinates of edges.

edges_parent : numpy.ndarray

Array (dtype=np.int32) of parent node IDs for each edge.

edges_child : numpy.ndarray

Array (dtype=np.int32) of child node IDs for each edge.

nodes_time : numpy.ndarray

Array (dtype=np.float64) of time values for each node.

nodes_flags : numpy.ndarray

Array (dtype=np.uint32) of flag values for each node.

nodes_population : numpy.ndarray

Array (dtype=np.int32) of population IDs for each node.

nodes_individual : numpy.ndarray

Array (dtype=np.int32) of individual IDs for each node.

individuals_flags : numpy.ndarray

Array (dtype=np.uint32) of flag values for each individual.

sites_position : numpy.ndarray

Array (dtype=np.float64) of positions of sites along the sequence.

mutations_site : numpy.ndarray

Array (dtype=np.int32) of site IDs for each mutation.

mutations_node : numpy.ndarray

Array (dtype=np.int32) of node IDs for each mutation.

mutations_parent : numpy.ndarray

Array (dtype=np.int32) of parent mutation IDs.

mutations_time : numpy.ndarray

Array (dtype=np.float64) of time values for each mutation.

breakpoints : numpy.ndarray

Array (dtype=np.float64) of genomic positions where trees change.

indexes_edge_insertion_order : numpy.ndarray

Array (dtype=np.int32) specifying the order in which edges are inserted during tree building.

indexes_edge_removal_order : numpy.ndarray

Array (dtype=np.int32) specifying the order in which edges are removed during tree building.

tree_index()

Create a TreeIndex for traversing this tree sequence.

Returns:

A new tree index initialized to the null tree. Use next() or prev() to move to an actual tree.

Return type:

TreeIndex

child_index()

Create child index array for finding child edges of nodes. This operation requires a linear pass over the edge table and therefore has a time complexity of O(E).

Returns:

A numpy array (dtype=np.int32, shape=(num_nodes, 2)) where each row contains the [start, stop) range of edges where this node is the parent.

Return type:

numpy.ndarray

parent_index()

Create a ParentIndex for finding parent edges of nodes.

Edges within each child’s group are not guaranteed to be in any specific order. This operation uses a two-pass algorithm with O(N + E) time complexity and O(N) auxiliary space.

Returns:

A new parent index container that can be used to efficiently find all edges where a given node is the child.

Return type:

ParentIndex

class tskit.jit.numba.TreeIndex(ts)

Traverse trees in a numba compatible tree sequence.

This class provides efficient forward and backward iteration through the trees in a tree sequence. It provides the tree interval, edge changes to create the current tree, along with its sites and mutations. A full pass over the trees using repeated next or prev requires O(E + M + S) time complexity.

It should not be instantiated directly, but is returned by the tree_index method of NumbaTreeSequence.

Attributes

ts : NumbaTreeSequence

Reference to the tree sequence being traversed.

index : int

Current tree index. -1 indicates no current tree (null state).

direction : int

Traversal direction: tskit.FORWARD or tskit.REVERSE. tskit.NULL if uninitialised.

interval : tuple

Genomic interval (left, right) covered by the current tree.

in_range : EdgeRange

Edges being added to form this current tree, relative to the last state.

out_range : EdgeRange

Edges being removed to form this current tree, relative to the last state.

site_range : tuple

Range of sites in the current tree (start, stop).

mutation_range : tuple

Range of mutations in the current tree (start, stop).

Example

>>> tree_index = numba_ts.tree_index()
>>> num_edges = 0
>>> while tree_index.next():
...     num_edges += (tree_index.in_range.stop - tree_index.in_range.start)
...     num_edges -= (tree_index.out_range.stop - tree_index.out_range.start)
...     print(f"Tree {tree_index.index}: {num_edges} edges")

set_null()

Reset the tree index to null state.

next()

Move to the next tree in forward direction.

Updates the tree index to the next tree in the sequence, computing the edges that need to be added and removed to transform from the previous tree to the current tree. On the first call, this initializes the iterator and moves to tree 0.

Returns:

True if successfully moved to next tree, False if the end of the tree sequence is reached. When False is returned, the iterator is in null state (index=-1).

Return type:

bool

prev()

Move to the previous tree in reverse direction.

Updates the tree index to the previous tree in the sequence, computing the edges that need to be added and removed to transform from the next tree to the current tree. On the first call, this initializes the iterator and moves to the rightmost tree.

Returns:

True if successfully moved to previous tree, False if the beginning of the tree sequence is reached. When False is returned, the iterator is in null state (index=-1).

Return type:

bool

class tskit.jit.numba.EdgeRange(*args, **kwargs)

Represents a range of edges during tree traversal.

This class encapsulates information about a contiguous range of edges that are either being removed or added to step from one tree to another. The start and stop indices, when applied to the order array, define the ids of edges to process.

Attributes

start : int

Starting index of the edge range (inclusive).

stop : int

Stopping index of the edge range (exclusive).

order : numpy.ndarray

Array (dtype=np.int32) containing edge IDs in the order they should be processed. The edge IDs in this range are order[start:stop].

class tskit.jit.numba.ParentIndex(*args, **kwargs)

Simple data container for parent index information.

This class provides access to all edges where a given node is the child. Since edges are not sorted by child in the tskit edge table, a custom index (edge_index) is built that sorts edge IDs by child node. index_range then contains the [start, stop) range of edges for each child node in edge_index.

Attributes

edge_index : numpy.ndarray

Array (dtype=np.int32) of edge IDs sorted by child node and left coordinate.

index_range : numpy.ndarray

Array (dtype=np.int32, shape=(num_nodes, 2)) where each row contains the [start, stop) range in edge_index where this node is the child.