# Source code for tscompare.methods

# MIT License
#
# Copyright (c) 2025 Tskit Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Tools for comparing node times between tree sequences with different node sets
"""
import copy
import warnings
from collections import defaultdict
from dataclasses import dataclass
from itertools import product

import numpy as np
import scipy.sparse
import tskit


def compare(*args, **kwargs):
    """
    Deprecated alias for :func:`haplotype_arf`.

    Emits a `DeprecationWarning` and then forwards all positional and keyword
    arguments, unchanged, to :func:`haplotype_arf`.

    :return: Whatever :func:`haplotype_arf` returns for the same arguments.
    """
    warnings.warn(
        "compare() is deprecated and will be removed in the future; "
        "please use haplotype_arf() instead.",
        DeprecationWarning,
        # stacklevel=2 attributes the warning to the code that called
        # compare(), rather than to this wrapper, so that the default
        # warning filters actually show it to the user.
        stacklevel=2,
    )
    return haplotype_arf(*args, **kwargs)


def node_spans(ts, include_missing=False):
    """
    Returns the array of "node spans", i.e., the `j`th entry gives
    the total span over which node `j` is in the tree sequence.

    The span of a node is accumulated from two sources: the lengths of all
    edges on which it is the child, plus the spans of all trees in which it
    is a root. A root with no children is an isolated sample ("missing
    data"); whether such spans are counted is controlled by
    `include_missing`. (If `include_missing` is `True` then the span of each
    sample is always equal to the sequence length.)

    :param ts: The tree sequence whose node spans are computed.
    :param bool include_missing: Whether to include spans of nodes on which
        they have missing data.
    :return: An array of spans, indexed by node id.
    """
    edge_lengths = ts.edges_right - ts.edges_left
    spans = np.bincount(
        ts.edges_child,
        weights=edge_lengths,
        minlength=ts.num_nodes,
    )
    for tree in ts.trees():
        for root in tree.roots:
            # a childless root is an isolated sample, i.e. 'missing data'
            if not include_missing and tree.num_children(root) == 0:
                continue
            spans[root] += tree.span
    return spans


class CladeMap:
    """
    An iterator across trees that maintains a mapping from a clade (a `frozenset` of
    sample IDs) to a `set` of nodes. When there are unary nodes, there may be multiple
    nodes associated with each clade.

    The map is initialised on the first tree of the tree sequence and is moved
    to each subsequent tree by calling :meth:`next`, which updates the mapping
    incrementally from the edge differences between adjacent trees.
    """

    def __init__(self, ts):
        """
        Build the clade map for the first tree of `ts`.

        :param ts: The tree sequence to iterate over.
        """
        self._nil = frozenset()
        self._nodes = defaultdict(set)  # nodes[clade] = {node ids}
        self._clades = defaultdict(frozenset)  # clades[node] = {sample ids}
        self.tree_sequence = ts
        self.tree = ts.first(sample_lists=True)
        for node in self.tree.nodes():
            clade = frozenset(self.tree.samples(node))
            self._nodes[clade].add(node)
            self._clades[node] = clade
        # snapshot of the per-node clades, used by next() to report changes
        self._prev = copy.deepcopy(self._clades)
        self._diff = ts.edge_diffs()
        # discard the edge diff that builds the first tree: it is already loaded
        next(self._diff)

    def _propagate(self, edge, downdate=False):
        """
        Traverse path from `edge.parent` to root, either adding or removing the
        state (clade) associated with `edge.child` from the state of each
        visited node. Return a set with the node ids encountered during
        traversal.

        :param edge: The edge being removed from (`downdate=True`) or added to
            (`downdate=False`) the current tree.
        :param bool downdate: If `True`, subtract the child's clade from each
            ancestor; otherwise add it.
        """
        nodes = set()
        node = edge.parent
        clade = self._clades[edge.child]
        while node != tskit.NULL:
            last = self._clades[node]
            self._clades[node] = last - clade if downdate else last | clade
            if last:
                # unregister the node from its old clade, dropping the clade
                # entirely once no node maps to it
                self._nodes[last].remove(node)
                if not self._nodes[last]:
                    del self._nodes[last]
            self._nodes[self._clades[node]].add(node)
            nodes.add(node)
            node = self.tree.parent(node)
        return nodes

    def next(self):  # noqa: A003
        """
        Advance to the next tree, returning the difference between trees as a
        dictionary of the form `node : (last_clade, next_clade)`

        Returns `None` (without advancing) if the current tree is the last one.
        """
        nodes = set()  # nodes with potentially altered clades
        diff = {}  # diff[node] = (prev_clade, curr_clade)

        if self.tree.index + 1 == self.tree_sequence.num_trees:
            return None

        # Subtract clades subtended by outgoing edges
        edge_diff = next(self._diff)
        for eo in edge_diff.edges_out:
            nodes |= self._propagate(eo, downdate=True)

        # Prune nodes that are no longer in tree
        for node in self._nodes[self._nil]:
            diff[node] = (self._prev[node], self._nil)
            del self._clades[node]
        nodes -= self._nodes[self._nil]
        self._nodes[self._nil].clear()

        # Add clades subtended by incoming edges
        self.tree.next()
        for ei in edge_diff.edges_in:
            nodes |= self._propagate(ei, downdate=False)

        # Find difference in clades between adjacent trees
        for node in nodes:
            diff[node] = (self._prev[node], self._clades[node])
            if self._prev[node] == self._clades[node]:
                del diff[node]

        # Sync previous and current states
        for node, (_, curr) in diff.items():
            if curr == self._nil:
                del self._prev[node]
            else:
                self._prev[node] = curr

        return diff

    @property
    def interval(self):
        """
        Return interval spanned by tree
        """
        return self.tree.interval

    def clades(self):
        """
        Return set of clades in tree

        The empty clade is never included: it is only used internally as a
        placeholder for nodes that have dropped out of the tree.
        """
        # NB: subtract the one-element set {nil}; subtracting `nil` itself (an
        # empty frozenset) would remove nothing and leak the placeholder key.
        return self._nodes.keys() - {self._nil}

    def __getitem__(self, clade):
        """
        Return set of nodes associated with a given clade.

        `clade` may be any iterable of sample ids; an empty `frozenset` is
        returned if the clade is not present in the tree.
        """
        # normalise once so that the membership test and the lookup agree, and
        # so that unhashable inputs (e.g. a `set`) can be used as the key
        key = frozenset(clade)
        return frozenset(self._nodes[key]) if key in self else self._nil

    def __contains__(self, clade):
        """
        Check if a clade is present in the tree
        """
        return clade in self._nodes


def shared_node_spans(ts, other):
    """
    Calculate the spans over which pairs of nodes in two tree sequences are
    ancestral to identical sets of samples.

    Returns a sparse matrix where rows correspond to nodes in `ts` and columns
    correspond to nodes in `other`, and whose value is the total amount of span
    over which the set of samples inheriting from the two nodes is identical.

    The shared span of a sample node with itself includes spans over which it
    has missing data.

    :param ts: The tree sequence whose nodes index the rows.
    :param other: The tree sequence whose nodes index the columns.
    :return: A sparse matrix of class `scipy.sparse.csr_matrix`.
    :raises ValueError: If the two tree sequences differ in sequence length
        or in number of samples.
    """
    if ts.sequence_length != other.sequence_length:
        raise ValueError("Tree sequences must be of equal sequence length.")
    if ts.num_samples != other.num_samples:
        raise ValueError("Tree sequences must have the same numbers of samples.")

    empty = frozenset()

    # One clade iterator per tree sequence, both starting on their first tree
    query = CladeMap(ts)
    target = CladeMap(other)

    # The result is accumulated as (value, (row, column)) triplets
    rows = []
    cols = []
    spans = []

    def record(row_nodes, col_nodes, length):
        # every node pair sharing the clade is credited with the same span
        for row, col in product(row_nodes, col_nodes):
            rows.append(row)
            cols.append(col)
            spans.append(length)

    # open_matches[clade] = (query_nodes, target_nodes, left_coord)
    open_matches = {}
    changed = query.clades() | target.clades()
    right = 0
    while True:
        left = right
        right = min(query.interval[1], target.interval[1])

        # Close matches whose clade changed in either tree sequence
        for clade in changed:
            if clade in open_matches:
                row_nodes, col_nodes, start = open_matches.pop(clade)
                record(row_nodes, col_nodes, left - start)

        # Open matches for changed clades now present in both tree sequences
        for clade in changed:
            assert clade not in open_matches
            if clade in query and clade in target:
                open_matches[clade] = (query[clade], target[clade], left)

        if right == ts.sequence_length:
            break

        # Advance whichever iterator(s) end at `right`, collecting the clades
        # that differ between the adjacent trees
        changed.clear()
        for clade_map in (query, target):
            if clade_map.interval[1] != right:
                continue
            for before, after in clade_map.next().values():
                if before != empty:
                    changed.add(before)
                if after != empty:
                    changed.add(after)

    # Whatever is still open extends to the end of the final tree
    for row_nodes, col_nodes, start in open_matches.values():
        record(row_nodes, col_nodes, right - start)

    triplets = (spans, (rows, cols))
    return scipy.sparse.coo_matrix(
        triplets,
        shape=(ts.num_nodes, other.num_nodes),
    ).tocsr()


def match_node_ages(ts, other):
    """
    For each node in `ts`, return the age of a matched node from `other`.

    The match for a node is the node in `other` with which it has the largest
    shared span, as computed by :func:`.shared_node_spans`.

    Returns a tuple of three vectors of length `ts.num_nodes`, in this order:
    the age of the best matching node in `other` (`nan` if there is no match);
    the span shared with the best match (zero if there is no match; note that
    this is a span, it is not divided by the node's total span);
    and the node id of the best match in `other` (`tskit.NULL` if there is
    no match).

    :param ts: The tree sequence whose nodes are to be matched.
    :param other: The tree sequence in which matches are sought.
    :return: A tuple of arrays of length `ts.num_nodes` containing
        (time of matching node, shared span with match, and node ID of match).
    """
    span_matrix = shared_node_spans(ts, other)
    matched_span = span_matrix.max(axis=1).toarray().ravel()
    # NB: if there are multiple nodes with the largest span in a row,
    # argmax returns the node with the smallest integer id
    best_match = span_matrix.argmax(axis=1).A1
    unmatched = matched_span == 0
    # look up times before the unmatched ids are overwritten with NULL
    matched_time = other.nodes_time[best_match]
    best_match[unmatched] = tskit.NULL
    matched_time[unmatched] = np.nan
    return matched_time, matched_span, best_match


@dataclass
class ARFResult:
    """
    The result of a call to tscompare.haplotype_arf(ts, other),
    returning metrics associated with the ARG Robinson-Foulds
    measures of similarity and dissimilarity.
    This contains:

    `arf`:
        The ARG Robinson-Foulds relative dissimilarity:
        the proportion of the total span of `ts` that is *not* represented
        in `other`. This is: `1 - matched_span[0] / total_span[0]`

    `tpr`:
        The "true proportion represented":
        the proportion of the total span of `other` that is represented
        in `ts`. This is: `matched_span[1] / total_span[1]`

    `matched_span`:
        The total matched node spans between `ts` and `other`, in order
        (`match`, `inverse_match`), where `match` is the total span of `ts`
        that is represented in `other`, and `inverse_match` is the total span
        of `other` that is represented in `ts`.

    `total_span`:
        The total of all node spans of the two tree sequences, in order
        (`ts`, `other`).

    `rmse`:
        The root-mean-squared error between the transformed times of the
        nodes in `ts` and the transformed times of their best-matching nodes
        in `other`, with the average taken weighting by span in `ts`.

    `transform`:
        The transformation function used to transform times for computing
        `rmse`.
    """

    arf: float
    tpr: float
    matched_span: tuple
    total_span: tuple
    rmse: float
    transform: callable

    def __str__(self):
        """
        Return a plain text summary of the ARF result.
        """
        pieces = [
            "Tree sequence comparison:\n",
            f" ARF: {100 * self.arf:.2f}%\n",
            f" TPR: {100 * self.tpr:.2f}%\n",
            f" matched_span: {self.matched_span}\n",
            f" total span (ts, other): {self.total_span[0]}, {self.total_span[1]}\n",
            f" time RMSE: {self.rmse}\n",
        ]
        return "".join(pieces)


def haplotype_arf(ts, other, transform=None):
    """
    For two tree sequences `ts` and `other`,
    this method returns an object of type :class:`.ARFResult`.
    The values reported summarize the degree to which nodes in `ts`
    "match" corresponding nodes in `other`.

    To match nodes,
    for each node in `ts`, the best matching node(s) from `other`
    has the longest matching span using :func:`.shared_node_spans`.
    If there are multiple matches with the same longest shared span
    for a single node, the best match is the match that is closest in
    (transformed) time.

    This requires that the samples are the same in both tree sequences:
    in other words, if node `i` is a sample node in `ts`, then
    node `i` is also a sample node in `other` (and vice-versa).
    A sample node in `ts` is only ever matched to the node with the same
    id in `other`.

    For each node in `other` we compute the inverse matched span
    as the maximum shared span amongst all nodes in `ts` for which that
    node is their best match.

    Then, :class:`.ARFResult` contains:

    - (`arf`)
        The fraction of the total span of `ts` over which each node's
        descendant sample set does not match its best match's descendant
        sample set (i.e., the total *un*-matched span divided by the total
        span of `ts`).

    - (`tpr`)
        The proportion of the span in `other` that is correctly
        represented in `ts` (i.e., the total inverse matching span divided
        by the total span of `other`).

    - (`matched_span`)
        The total "matching" and "inverse matching" spans between `ts` and
        `other`. The "matching span" is the total span of all nodes in `ts`
        over which each node is ancestral to the same set of samples as its
        best match in `other`. The "inverse matching span" is the total span
        of all nodes in `other` over which each node is ancestral to the same
        set of samples as its best match in `ts`.

    - (`total_span`)
        The total node spans of `ts` and `other`.

    - (`rmse`)
        The root mean squared difference
        between the transformed times of the nodes in `ts`
        and transformed times of their best matching nodes in `other`,
        with the average weighted by the nodes' spans in `ts`.
        This is `nan` if no node in `ts` has a match.

    The callable `transform` is used to transform times before computing
    root-mean-squared error (see :class:`.ARFResult`); the default
    is `log(1 + t)`.

    :param ts: The focal tree sequence.
    :param other: The tree sequence we compare to.
    :param transform: A callable that can take an array of times and
        return another array of numbers.
    :return: The quantities described above.
    :rtype: ARFResult
    :raises ValueError: If the sample nodes of `ts` and `other` differ.
    """
    samples = ts.samples()
    if ts.num_samples != other.num_samples or np.any(samples != other.samples()):
        raise ValueError("Samples in `ts` and `other` must agree.")

    # default time transformation: log(1 + t)
    def f(t):
        return np.log(1 + t)

    if transform is None:
        transform = f

    ts_node_spans = node_spans(ts, include_missing=True)
    shared_spans = shared_node_spans(ts, other)
    if min(ts.num_nodes, other.num_nodes) > 0:
        # (row, column) coordinates of each explicitly stored entry of the
        # CSR matrix, in the same order as `shared_spans.data`
        col_ind = shared_spans.indices
        row_ind = np.repeat(
            np.arange(shared_spans.shape[0]), repeats=np.diff(shared_spans.indptr)
        )
        # We require that the samples are the same in both trees!
        # If we did not require this, we could identify swapped samples,
        # but this is out of scope (people could detect this using
        # the shared spans matrix directly).
        # So: discard any match of a sample row to a column with a different id.
        is_sample = np.full(max(ts.num_nodes, other.num_nodes), False)
        is_sample[samples] = True
        index_not_equal = ~np.equal(row_ind, col_ind)
        shared_spans.data[np.logical_and(is_sample[row_ind], index_not_equal)] = 0.0
        # Find all potential matches for a node based on max shared span length
        max_span = shared_spans.max(axis=1).toarray().flatten()
        total_match_n1_span = np.sum(max_span)  # <---- one thing to output
        # zero out everything that's not a row max
        shared_spans.data[shared_spans.data != max_span[row_ind]] = 0.0
        # now re-sparsify the matrix: but, beware! don't do this again later.
        # (the coordinates below are recomputed for the re-sparsified matrix,
        # and `time_diff` is a copy that shares that sparsity structure)
        shared_spans.eliminate_zeros()
        col_ind = shared_spans.indices
        row_ind = np.repeat(
            np.arange(shared_spans.shape[0]), repeats=np.diff(shared_spans.indptr)
        )
        # now, make a matrix with differences in transformed times
        # in the places where shared_spans retains nonzero elements
        time_diff = shared_spans.copy()
        ts_times = ts.nodes_time[row_ind]
        other_times = other.nodes_time[col_ind]
        time_diff.data[:] = np.absolute(
            np.asarray(transform(ts_times) - transform(other_times))
        )
        # "explicit=True" takes the min of only the entries explicitly represented
        dt = time_diff.min(axis=1, explicit=True).toarray().flatten()
        # next, zero out also those non-best-time-match elements
        shared_spans.data[time_diff.data != dt[row_ind]] = 0.0
        # and, find sum of column maxima
        total_match_n2_span = shared_spans.max(
            axis=0
        ).sum()  # <--- the other thing we return
    else:
        # at least one tree sequence has no nodes, so nothing can match
        max_span = 0
        total_match_n1_span = 0
        total_match_n2_span = 0
    # nodes of `ts` that have a match with nonzero shared span
    has_match = max_span != 0
    if np.any(has_match):
        # span-weighted root mean square of the transformed time differences
        rmse = np.sqrt(
            np.sum(dt[has_match] ** 2 * ts_node_spans[has_match])
            / np.sum(ts_node_spans[has_match])
        )  # ^-- another thing to output
    else:
        rmse = np.nan
    total_span_ts = np.sum(ts_node_spans)
    total_span_other = np.sum(node_spans(other, include_missing=True))
    return ARFResult(
        arf=1.0 - total_match_n1_span / total_span_ts,
        tpr=total_match_n2_span / total_span_other,
        matched_span=(total_match_n1_span, total_match_n2_span),
        total_span=(total_span_ts, total_span_other),
        rmse=rmse,
        transform=transform,
    )