# MIT License
#
# Copyright (c) 2021-24 Tskit Developers
# Copyright (c) 2020-21 University of Oxford
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Infer the age of nodes conditional on a tree sequence topology.
"""

import functools
import itertools
import logging
import multiprocessing
import operator
import time  # used for informational timing logs
from collections import defaultdict, namedtuple

import numba
import numpy as np
import scipy.stats
import tskit
from tqdm.auto import tqdm

from . import base, demography, prior, provenance, schemas, util, variational

logger = logging.getLogger(__name__)

FORMAT_NAME = "tsdate"
DEFAULT_RESCALING_INTERVALS = 1000
DEFAULT_RESCALING_ITERATIONS = 5
DEFAULT_MAX_ITERATIONS = 10
DEFAULT_EPSILON = 1e-6


class Likelihoods:
    """
    A class to store and process likelihoods. Likelihoods for edges are stored as a
    flattened lower triangular matrix of all the possible delta t's. This class also
    provides methods for accessing this lower triangular matrix, multiplying it, etc.

    If ``standardize`` is true, routines will operate to standardize the likelihoods
    such that their maximum is one (in linear space) or zero (in log space)
    """

    probability_space = base.LIN
    identity_constant = 1.0
    null_constant = 0.0

    def __init__(
        self,
        ts,
        timepoints,
        mutation_rate=None,
        recombination_rate=None,
        *,
        eps=0,
        fixed_node_set=None,
        standardize=False,
        progress=False,
    ):
        self.ts = ts
        self.timepoints = timepoints
        self.fixednodes = set(ts.samples()) if fixed_node_set is None else fixed_node_set
        self.mut_rate = mutation_rate
        self.rec_rate = recombination_rate
        self.standardize = standardize
        self.grid_size = len(timepoints)
        self.tri_size = self.grid_size * (self.grid_size + 1) // 2
        self.ll_mut = {}
        self.mut_edges = self.get_mut_edges(ts)
        self.progress = progress
        # Need to set eps properly in the 2 lines below, to account for values in the
        # same timeslice
        self.timediff_lower_tri = np.concatenate(
            [
                self.timepoints[time_index] - self.timepoints[0 : time_index + 1] + eps
                for time_index in np.arange(len(self.timepoints))
            ]
        )
        self.timediff = self.timepoints - self.timepoints[0] + eps

        # The mut_ll contains unpacked (1D) lower triangular matrices. We need to
        # index this by row and by column index.
        self.row_indices = []
        for t in range(self.grid_size):
            n = np.arange(self.grid_size)
            self.row_indices.append((((n * (n + 1)) // 2) + t)[t:])
        self.col_indices = []
        running_sum = 0  # tracks the flat start index of each row of the upper
        # triangle, used by rowsum_upper_tri to sum segments via np.add.reduceat
        for i in np.arange(self.grid_size):
            arr = np.arange(running_sum, running_sum + self.grid_size - i)
            index = arr[-1]
            running_sum = index + 1
            val = arr[0]
            self.col_indices.append(val)

        # These are used for transforming an array of grid_size into one of tri_size
        # By repeating elements over rows to form an upper or a lower triangular matrix
        self.to_lower_tri = np.concatenate(
            [np.arange(time_idx + 1) for time_idx in np.arange(self.grid_size)]
        )
        self.to_upper_tri = np.concatenate(
            [
                np.arange(time_idx, self.grid_size)
                for time_idx in np.arange(self.grid_size + 1)
            ]
        )
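        # Worked example (grid_size = 3, illustrative only):
        #   row_indices[0] == [0, 1, 3]   (flat start of each lower-tri row)
        #   col_indices    == [0, 3, 5]   (flat start of each upper-tri row)
        #   to_lower_tri   == [0, 0, 1, 0, 1, 2]
        #   to_upper_tri   == [0, 1, 2, 1, 2, 2]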

    @staticmethod
    def get_mut_edges(ts):
        """
        Get the number of mutations on each edge in the tree sequence.
        """
        mut_edges = np.zeros(ts.num_edges, dtype=np.int64)
        for m in ts.mutations():
            if m.edge != tskit.NULL:
                mut_edges[m.edge] += 1
        return mut_edges

    @staticmethod
    def _lik(muts, span, dt, mutation_rate, standardize=True):
        """
        The likelihood of an edge given a number of mutations, as set of time deltas (dt)
        and a span. This is a static function to allow parallelization
        """
        ll = scipy.stats.poisson.pmf(muts, dt * mutation_rate * span)
        if standardize:
            return ll / np.max(ll)
        else:
            return ll
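    # Numerical illustration (example values are an assumption, not from the
    # source): with muts=2, span=1e4, mutation_rate=1e-8 and dt=100, the
    # Poisson mean is dt * mutation_rate * span == 0.01, so
    # scipy.stats.poisson.pmf(2, 0.01) ~= 4.95e-5 (before standardization).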

    @staticmethod
    def _lik_wrapper(muts_span, dt, mutation_rate, standardize=True):
        """
        A wrapper to allow this _lik to be called by pool.imap_unordered, returning the
        mutation and span values
        """
        return muts_span, Likelihoods._lik(
            muts_span[0], muts_span[1], dt, mutation_rate, standardize=standardize
        )

    def precalculate_mutation_likelihoods(self, num_threads=None, unique_method=0):
        """
        We precalculate these because the pmf function is slow, but can be trivially
        parallelised. We store the likelihoods in a cache because they only depend on
        the number of mutations and the span, so can potentially be reused.

        However, we don't bother storing the likelihood for edges above a *fixed* node,
        because (a) these are only used once per node and (b) sample edges are often
        long, and hence their span will be unique. This also allows us to deal easily
        with fixed nodes at explicit times (rather than in time slices)
        """

        if self.mut_rate is None:
            raise RuntimeError(
                "Cannot calculate mutation likelihoods with no mutation_rate set"
            )
        if unique_method == 0:
            self.unfixed_likelihood_cache = {
                (muts, e.span): None
                for muts, e in zip(self.mut_edges, self.ts.edges())
                if e.child not in self.fixednodes
            }
        else:
            fixed_nodes = np.array(list(self.fixednodes))
            keys = np.unique(
                np.core.records.fromarrays(
                    (self.mut_edges, self.ts.edges_right - self.ts.edges_left),
                    names="muts,span",
                )[np.logical_not(np.isin(self.ts.edges_child, fixed_nodes))]
            )
            if unique_method == 1:
                self.unfixed_likelihood_cache = dict.fromkeys({tuple(t) for t in keys})
            else:
                self.unfixed_likelihood_cache = {tuple(t): None for t in keys}

        if num_threads:
            f = functools.partial(  # Set constant values for params for static _lik
                self._lik_wrapper,
                dt=self.timediff_lower_tri,
                mutation_rate=self.mut_rate,
                standardize=self.standardize,
            )
            if num_threads == 1:
                # Useful for testing
                for key in tqdm(
                    self.unfixed_likelihood_cache.keys(),
                    disable=not self.progress,
                    desc="Precalculating Likelihoods",
                ):
                    returned_key, likelihoods = f(key)
                    self.unfixed_likelihood_cache[returned_key] = likelihoods
            else:
                with tqdm(
                    total=len(self.unfixed_likelihood_cache.keys()),
                    disable=not self.progress,
                    desc="Precalculating Likelihoods",
                ) as prog_bar:
                    with multiprocessing.Pool(processes=num_threads) as pool:
                        for key, pmf in pool.imap_unordered(
                            f, self.unfixed_likelihood_cache.keys()
                        ):
                            self.unfixed_likelihood_cache[key] = pmf
                            prog_bar.update()
        else:
            for muts, span in tqdm(
                self.unfixed_likelihood_cache.keys(),
                disable=not self.progress,
                desc="Precalculating Likelihoods",
            ):
                self.unfixed_likelihood_cache[muts, span] = self._lik(
                    muts,
                    span,
                    dt=self.timediff_lower_tri,
                    mutation_rate=self.mut_rate,
                    standardize=self.standardize,
                )

    def get_mut_lik_fixed_node(self, edge):
        """
        Get the mutation likelihoods for an edge whose child is at a
        fixed time, but whose parent may take any of the time slices in the timepoints
        that are equal to or older than the child age. This is not cached, as it is
        likely to be unique for each edge
        """
        assert (
            edge.child in self.fixednodes
        ), "Wrongly called fixed node function on non-fixed node"
        assert (
            self.mut_rate is not None
        ), "Cannot calculate mutation likelihoods with no mutation_rate set"

        mutations_on_edge = self.mut_edges[edge.id]
        child_time = self.ts.node(edge.child).time
        assert child_time == 0
        # Temporary hack - we should really take a more precise likelihood
        return self._lik(
            mutations_on_edge,
            edge.span,
            self.timediff,
            self.mut_rate,
            standardize=self.standardize,
        )

    def get_mut_lik_lower_tri(self, edge):
        """
        Get the cached mutation likelihoods for an edge with non-fixed parent and child
        nodes, returning values for all the possible time differences between timepoints
        These values are returned as a flattened lower triangular matrix, the
        form required in the inside algorithm.

        """
        # Debugging asserts - should probably remove eventually
        assert (
            edge.child not in self.fixednodes
        ), "Wrongly called lower_tri function on fixed node"
        assert hasattr(
            self, "unfixed_likelihood_cache"
        ), "Must call `precalculate_mutation_likelihoods()` before getting likelihoods"

        mutations_on_edge = self.mut_edges[edge.id]
        return self.unfixed_likelihood_cache[mutations_on_edge, edge.span]

    def get_mut_lik_upper_tri(self, edge):
        """
        Same as :meth:`get_mut_lik_lower_tri`, but the returned array is ordered as
        flattened upper triangular matrix (suitable for the outside algorithm), rather
        than a lower triangular one
        """
        return self.get_mut_lik_lower_tri(edge)[np.concatenate(self.row_indices)]

    # The following functions don't access the likelihoods directly, but allow
    # other input arrays of length grid_size to be repeated in such a way that they can
    # be directly multiplied by the unpacked lower triangular matrix, or arrays of length
    # of the number of cells in the lower triangular matrix to be summed (e.g. by row)
    # to give a shorter array of length grid_size

    def make_lower_tri(self, input_array):
        """
        Repeat the input array row-wise to make a flattened lower triangular matrix
        """
        assert len(input_array) == self.grid_size
        return input_array[self.to_lower_tri]

    def rowsum_lower_tri(self, input_array):
        """
        Sum the flattened lower triangular matrix by rows, using np.add.reduceat
        at the precomputed row start indices. This is the reducing counterpart of
        make_lower_tri, returning an array of length grid_size
        """
        assert len(input_array) == self.tri_size
        return np.add.reduceat(input_array, self.row_indices[0])

    def make_upper_tri(self, input_array):
        """
        Repeat the input array row-wise to make a flattened upper triangular matrix
        """
        assert len(input_array) == self.grid_size
        return input_array[self.to_upper_tri]

    def rowsum_upper_tri(self, input_array):
        """
        Sum the flattened upper triangular matrix by rows, using np.add.reduceat
        at the precomputed start indices. This is the reducing counterpart of
        make_upper_tri, returning an array of length grid_size
        """
        assert len(input_array) == self.tri_size
        return np.add.reduceat(input_array, self.col_indices)
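    # Round-trip example (grid_size = 3, illustrative only):
    #   make_lower_tri([a, b, c]) -> [a, a, b, a, b, c], and rowsum_lower_tri
    #   reduces at starts [0, 1, 3] to the row sums [a, a+b, a+b+c];
    #   make_upper_tri([a, b, c]) -> [a, b, c, b, c, c], and rowsum_upper_tri
    #   reduces at starts [0, 3, 5] to [a+b+c, b+c, c].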

    # Mutation & recombination algorithms on a tree sequence

    def n_breaks(self, edge):
        """
        Number of known breakpoints, only used in recombination likelihood calc
        """
        return (edge.left != 0) + (edge.right != self.ts.get_sequence_length())

    def combine(self, lik_1, lik_2):
        return lik_1 * lik_2

    def ratio(self, lik_1, lik_2, div_0_null=False):
        """
        Return the ratio of lik_1 to lik_2. In linear space, this divides lik_1
        by lik_2. If div_0_null==True, then 0/0 is set to the null_constant
        """
        with np.errstate(divide="ignore", invalid="ignore"):
            ret = lik_1 / lik_2
        if div_0_null:
            ret[np.isnan(ret)] = self.null_constant
        return ret

    def marginalize(self, lik):
        """
        Return the sum of likelihoods
        """
        return np.sum(lik)

    def _recombination_lik(self, edge, fixed=True):
        # Needs to return a lower tri *or* flattened array depending on `fixed`
        raise NotImplementedError(
            "Using the recombination clock is not currently supported"
            ". See https://github.com/awohns/tsdate/issues/5 for details"
        )
        # return (
        #     np.power(prev_state, self.n_breaks(edge)) *
        #     np.exp(-(prev_state * self.rec_rate * edge.span * 2)))

    def get_inside(self, arr, edge):
        liks = self.identity_constant
        if self.rec_rate is not None:
            liks = self._recombination_lik(edge)
        if self.mut_rate is not None:
            liks *= self.get_mut_lik_lower_tri(edge)
        return self.rowsum_lower_tri(arr * liks)

    def get_outside(self, arr, edge):
        liks = self.identity_constant
        if self.rec_rate is not None:
            liks = self._recombination_lik(edge)
        if self.mut_rate is not None:
            liks *= self.get_mut_lik_upper_tri(edge)
        return self.rowsum_upper_tri(arr * liks)

    def get_fixed(self, arr, edge):
        liks = self.identity_constant
        if self.rec_rate is not None:
            liks = self._recombination_lik(edge, fixed=True)
        if self.mut_rate is not None:
            liks *= self.get_mut_lik_fixed_node(edge)
        return arr * liks

    def scale_geometric(self, fraction, value):
        return value**fraction
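    # Raising to the power `fraction` weights an edge's contribution by the
    # fraction of the child's total span that the edge covers; in log space the
    # same scaling becomes multiplication (see LogLikelihoods.scale_geometric).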


class LogLikelihoods(Likelihoods):
    """
    Identical to the Likelihoods class but stores and returns log likelihoods
    """

    probability_space = base.LOG
    identity_constant = 0.0
    null_constant = -np.inf

    """
    Uses an alternative to logsumexp, useful for large grid sizes, see
    http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html
    """

    @staticmethod
    @numba.jit(nopython=True)
    def logsumexp(X):
        alpha = -np.inf
        r = 0.0
        for x in X:
            if x != -np.inf:
                if x <= alpha:
                    r += np.exp(x - alpha)
                else:
                    r *= np.exp(alpha - x)
                    r += 1.0
                    alpha = x
        return -np.inf if r == 0 else np.log(r) + alpha
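    # Illustrative check: logsumexp(np.log(np.array([1.0, 2.0, 3.0]))) returns
    # log(6.0), matching scipy.special.logsumexp, but in a single streaming pass
    # that numba can compile.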

    @staticmethod
    def _lik(muts, span, dt, mutation_rate, standardize=True):
        """
        The log likelihood of an edge given a number of mutations, a set of time
        deltas (dt) and a span. This is a static function to allow parallelization
        """
        ll = scipy.stats.poisson.logpmf(muts, dt * mutation_rate * span)
        if standardize:
            return ll - np.max(ll)
        else:
            return ll

    @staticmethod
    def _lik_wrapper(muts_span, dt, mutation_rate, standardize=True):
        """
        Needs redefining to refer to the LogLikelihoods class
        """
        return muts_span, LogLikelihoods._lik(
            muts_span[0], muts_span[1], dt, mutation_rate, standardize=standardize
        )

    def rowsum_lower_tri(self, input_array):
        """
        The function below is equivalent to (but numba makes it faster than)
        np.logaddexp.reduceat(input_array, self.row_indices[0])
        """
        assert len(input_array) == self.tri_size
        res = list()
        i_start = self.row_indices[0][0]
        for i in self.row_indices[0][1:]:
            res.append(self.logsumexp(input_array[i_start:i]))
            i_start = i
        res.append(self.logsumexp(input_array[i:]))
        return np.array(res)

    def rowsum_upper_tri(self, input_array):
        """
        The function below is equivalent to (but numba makes it faster than)
        np.logaddexp.reduceat(input_array, self.col_indices)
        """
        assert len(input_array) == self.tri_size
        res = list()
        i_start = self.col_indices[0]
        for i in self.col_indices[1:]:
            res.append(self.logsumexp(input_array[i_start:i]))
            i_start = i
        res.append(self.logsumexp(input_array[i:]))
        return np.array(res)

    def _recombination_loglik(self, edge, fixed=True):
        # Needs to return a lower tri *or* flattened array depending on `fixed`
        raise NotImplementedError(
            "Using the recombination clock is not currently supported"
            ". See https://github.com/awohns/tsdate/issues/5 for details"
        )
        # return (
        #     np.power(prev_state, self.n_breaks(edge)) *
        #     np.exp(-(prev_state * self.rec_rate * edge.span * 2)))

    def combine(self, loglik_1, loglik_2):
        return loglik_1 + loglik_2

    def ratio(self, loglik_1, loglik_2, div_0_null=False):
        """
        In log space, the likelihood ratio is loglik_1 - loglik_2.
        If div_0_null==True, then -inf minus -inf (i.e. 0/0 in linear space) is
        set to the null_constant
        """
        with np.errstate(divide="ignore", invalid="ignore"):
            ret = loglik_1 - loglik_2
        if div_0_null:
            ret[np.isnan(ret)] = self.null_constant
        return ret

    def marginalize(self, loglik):
        """
        Return the logged sum of likelihoods
        """
        return self.logsumexp(loglik)

    def get_inside(self, arr, edge):
        log_liks = self.identity_constant
        if self.rec_rate is not None:
            log_liks = self._recombination_loglik(edge)
        if self.mut_rate is not None:
            log_liks += self.get_mut_lik_lower_tri(edge)
        return self.rowsum_lower_tri(arr + log_liks)

    def get_outside(self, arr, edge):
        log_liks = self.identity_constant
        if self.rec_rate is not None:
            log_liks = self._recombination_loglik(edge)
        if self.mut_rate is not None:
            log_liks += self.get_mut_lik_upper_tri(edge)
        return self.rowsum_upper_tri(arr + log_liks)

    def get_fixed(self, arr, edge):
        log_liks = self.identity_constant
        if self.rec_rate is not None:
            log_liks = self._recombination_loglik(edge, fixed=True)
        if self.mut_rate is not None:
            log_liks += self.get_mut_lik_fixed_node(edge)
        return arr + log_liks

    def scale_geometric(self, fraction, value):
        return fraction * value


class InOutAlgorithms:
    """
    Contains the inside and outside algorithms
    """

    def __init__(self, priors, lik, *, progress=False):
        if (
            lik.fixednodes.intersection(priors.nonfixed_nodes)
            or len(lik.fixednodes) + len(priors.nonfixed_nodes) != lik.ts.num_nodes
        ):
            raise ValueError(
                "The prior and likelihood objects disagree on which nodes are fixed"
            )
        if not np.allclose(lik.timepoints, priors.timepoints):
            raise ValueError(
                "The prior and likelihood objects disagree on the timepoints used"
            )

        self.priors = priors
        self.nonfixed_nodes = priors.nonfixed_nodes
        self.lik = lik
        self.ts = lik.ts

        self.fixednodes = lik.fixednodes
        self.progress = progress
        # If necessary, convert priors to log space
        self.priors.force_probability_space(lik.probability_space)

        self.spans = np.bincount(
            self.ts.edges_child,
            weights=self.ts.edges_right - self.ts.edges_left,
        )
        self.spans = np.pad(self.spans, (0, self.ts.num_nodes - len(self.spans)))
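        # self.spans[u] is now the total edge span above node u: a bincount over
        # edges_child weighted by edge length, padded so that nodes which never
        # appear as a child (e.g. isolated samples) still have an entry.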

        self.root_spans = defaultdict(float)
        for tree in self.ts.trees(root_threshold=2):
            if tree.has_single_root:
                self.root_spans[tree.root] += tree.span
        # Add on the spans when this is a root
        for root, span_when_root in self.root_spans.items():
            self.spans[root] += span_when_root

    # === Grouped edge iterators ===

    def edges_by_parent_asc(self, grouped=True):
        """
        Return an itertools.groupby object of edges grouped by parent in ascending order
        of the time of the parent. Since tree sequence properties guarantee that edges
        are listed in nondecreasing order of parent time
        (https://tskit.readthedocs.io/en/latest/data-model.html#edge-requirements)
        we can simply use the standard edge order
        """
        if grouped:
            return itertools.groupby(self.ts.edges(), operator.attrgetter("parent"))
        else:
            return self.ts.edges()

    def edges_by_child_desc(self, grouped=True):
        """
        Return an itertools.groupby object of edges grouped by child in descending order
        of the time of the child.
        """
        it = (
            self.ts.edge(u)
            for u in np.lexsort(
                (self.ts.edges_child, -self.ts.nodes_time[self.ts.edges_child])
            )
        )
        if grouped:
            return itertools.groupby(it, operator.attrgetter("child"))
        else:
            return it

    def edges_by_child_then_parent_desc(self, grouped=True):
        """
        Return an itertools.groupby object of edges grouped by child in descending
        order of the time of the child; within each child the edges are ordered
        by ascending time of the parent
        """
        wtype = np.dtype(
            [
                ("child_age", self.ts.nodes_time.dtype),
                ("child_node", self.ts.edges_child.dtype),
                ("parent_age", self.ts.nodes_time.dtype),
            ]
        )
        w = np.empty(self.ts.num_edges, dtype=wtype)
        w["child_age"] = self.ts.nodes_time[self.ts.edges_child]
        w["child_node"] = self.ts.edges_child
        w["parent_age"] = -self.ts.nodes_time[self.ts.edges_parent]
        sorted_child_parent = (
            self.ts.edge(i)
            for i in reversed(
                np.argsort(w, order=("child_age", "child_node", "parent_age"))
            )
        )
        if grouped:
            return itertools.groupby(sorted_child_parent, operator.attrgetter("child"))
        else:
            return sorted_child_parent

    # === MAIN ALGORITHMS ===

    def inside_pass(self, *, standardize=True, cache_inside=False, progress=None):
        """
        Use dynamic programming to find the approximate posterior to sample from,
        returning the marginal likelihood of the mutation data
        """
        if progress is None:
            progress = self.progress
        inside = self.priors.clone_with_new_data(  # store inside matrix values
            grid_data=np.nan, fixed_data=self.lik.identity_constant
        )
        if cache_inside:
            g_i = np.full(
                (self.ts.num_edges, self.lik.grid_size), self.lik.identity_constant
            )
        denominator = np.full(self.ts.num_nodes, np.nan)
        assert (
            self.lik.standardize is False
        ), "Marginal likelihood requires unstandardized mutation likelihoods"
        marginal_lik = self.lik.identity_constant
        # Iterate through the nodes via groupby on parent node
        for parent, edges in tqdm(
            self.edges_by_parent_asc(),
            desc="Inside",
            total=inside.num_nonfixed,
            disable=not progress,
        ):
            """
            for each node, find the conditional prob of age at every time
            in time grid
            """
            if parent in self.fixednodes:
                continue  # there is no hidden state for this parent - it's fixed
            val = self.priors[parent].copy()
            for edge in edges:
                spanfrac = edge.span / self.spans[edge.child]
                # Calculate vals for each edge
                if edge.child in self.fixednodes:
                    # NB: geometric scaling works exactly when all nodes fixed in graph
                    # but is an approximation when times are unknown.
                    daughter_val = self.lik.scale_geometric(spanfrac, inside[edge.child])
                    edge_lik = self.lik.get_fixed(daughter_val, edge)
                else:
                    inside_values = inside[edge.child]
                    if np.ndim(inside_values) == 0 or np.all(np.isnan(inside_values)):
                        # Child appears fixed, or we have not visited it. Either our
                        # edge order is wrong (bug) or we have hit a dangling node
                        raise ValueError(
                            "The input tree sequence includes "
                            "dangling nodes: please simplify it"
                        )
                    daughter_val = self.lik.scale_geometric(
                        spanfrac, self.lik.make_lower_tri(inside[edge.child])
                    )
                    edge_lik = self.lik.get_inside(daughter_val, edge)
                val = self.lik.combine(val, edge_lik)
                if cache_inside:
                    g_i[edge.id] = edge_lik
            denominator[parent] = (
                np.max(val) if standardize else self.lik.identity_constant
            )
            inside[parent] = self.lik.ratio(val, denominator[parent])
            if standardize:
                marginal_lik = self.lik.combine(marginal_lik, denominator[parent])
        if cache_inside:
            self.g_i = self.lik.ratio(g_i, denominator[self.ts.edges_child, None])
        # Keep the results in this object
        self.inside = inside
        self.denominator = denominator
        # Calculate marginal likelihood
        for root, span_when_root in self.root_spans.items():
            spanfrac = span_when_root / self.spans[root]
            root_val = self.lik.scale_geometric(spanfrac, inside[root])
            marginal_lik = self.lik.combine(marginal_lik, self.lik.marginalize(root_val))
        return marginal_lik

    def outside_pass(
        self,
        *,
        standardize=False,
        ignore_oldest_root=False,
        progress=None,
    ):
        """
        Computes the full posterior distribution on nodes, returning the
        posterior values. These are *not* probabilities, as they do not sum to one:
        to convert to probabilities, call posterior.to_probabilities()

        Standardizing *during* the outside process may be necessary if there is
        overflow, but means that we cannot check the total functional value at each node

        Ignoring the oldest root may also be necessary when the oldest root node
        causes numerical stability issues.

        The rows in the posterior returned correspond to node IDs as given by
        self.nodes
        """
        if progress is None:
            progress = self.progress
        if not hasattr(self, "inside"):
            raise RuntimeError("You have not yet run the inside algorithm")

        outside = self.inside.clone_with_new_data(grid_data=0, probability_space=base.LIN)
        for root, span_when_root in self.root_spans.items():
            outside[root] = span_when_root / self.spans[root]
        outside.force_probability_space(self.inside.probability_space)

        for child, edges in tqdm(
            self.edges_by_child_desc(),
            desc="Outside",
            total=len(np.unique(self.ts.edges_child)),
            disable=not progress,
        ):
            if child in self.fixednodes:
                continue
            val = np.full(self.lik.grid_size, self.lik.identity_constant)
            for edge in edges:
                if ignore_oldest_root:
                    if edge.parent == self.ts.num_nodes - 1:
                        continue
                if edge.parent in self.fixednodes:
                    raise RuntimeError(
                        "Fixed nodes cannot currently be parents in the TS"
                    )
                # Geometric scaling works exactly for all nodes fixed in graph
                # but is an approximation when times are unknown.
                spanfrac = edge.span / self.spans[child]
                try:
                    inside_div_gi = self.lik.ratio(
                        self.inside[edge.parent], self.g_i[edge.id], div_0_null=True
                    )
                except AttributeError:  # we haven't cached g_i so we recalculate
                    daughter_val = self.lik.scale_geometric(
                        spanfrac, self.lik.make_lower_tri(self.inside[edge.child])
                    )
                    edge_lik = self.lik.get_inside(daughter_val, edge)
                    cur_g_i = self.lik.ratio(edge_lik, self.denominator[child])
                    inside_div_gi = self.lik.ratio(
                        self.inside[edge.parent], cur_g_i, div_0_null=True
                    )
                parent_val = self.lik.scale_geometric(
                    spanfrac,
                    self.lik.make_upper_tri(
                        self.lik.combine(outside[edge.parent], inside_div_gi)
                    ),
                )
                if standardize:
                    parent_val = self.lik.ratio(parent_val, np.max(parent_val))
                edge_lik = self.lik.get_outside(parent_val, edge)
                val = self.lik.combine(val, edge_lik)

            # val[0] = 0  # Seems a hack: internal nodes should be allowed at time 0
            assert self.denominator[edge.child] > self.lik.null_constant
            outside[child] = self.lik.ratio(val, self.denominator[child])
            if standardize:
                outside[child] = self.lik.ratio(val, np.max(val))
        self.outside = outside
        posterior = outside.clone_with_new_data(
            grid_data=self.lik.combine(self.inside.grid_data, outside.grid_data),
            fixed_data=np.nan,
        )  # We should never use the posterior for a fixed node
        return posterior

    def outside_maximization(self, *, eps, progress=None):
        """
        Traverse from roots to leaves, assigning each nonfixed node the
        timepoint that maximizes the combination of its inside value with the
        mutation likelihoods to its already-placed parents, and return an
        array of point estimates of node times.
        """
        if progress is None:
            progress = self.progress
        if not hasattr(self, "inside"):
            raise RuntimeError("You have not yet run the inside algorithm")

        maximized_node_times = np.zeros(self.ts.num_nodes, dtype="int")

        if self.lik.probability_space == base.LOG:
            poisson = scipy.stats.poisson.logpmf
        elif self.lik.probability_space == base.LIN:
            poisson = scipy.stats.poisson.pmf

        mut_edges = self.lik.mut_edges
        mrcas = np.where(
            np.isin(np.arange(self.ts.num_nodes), self.ts.edges_child, invert=True)
        )[0]
        for i in mrcas:
            if i not in self.fixednodes:
                maximized_node_times[i] = np.argmax(self.inside[i])

        for child, edges in tqdm(
            self.edges_by_child_then_parent_desc(),
            desc="Maximization",
            total=len(np.unique(self.ts.edges_child)),
            disable=not progress,
        ):
            if child in self.fixednodes:
                continue
            for edge_index, edge in enumerate(edges):
                if edge_index == 0:
                    youngest_par_index = maximized_node_times[edge.parent]
                    parent_time = self.lik.timepoints[maximized_node_times[edge.parent]]
                    ll_mut = poisson(
                        mut_edges[edge.id],
                        (
                            parent_time
                            - self.lik.timepoints[: youngest_par_index + 1]
                            + eps
                        )
                        * self.lik.mut_rate
                        * edge.span,
                    )
                    result = self.lik.ratio(ll_mut, np.max(ll_mut))
                else:
                    cur_parent_index = maximized_node_times[edge.parent]
                    if cur_parent_index < youngest_par_index:
                        youngest_par_index = cur_parent_index
                    parent_time = self.lik.timepoints[maximized_node_times[edge.parent]]
                    ll_mut = poisson(
                        mut_edges[edge.id],
                        (
                            parent_time
                            - self.lik.timepoints[: youngest_par_index + 1]
                            + eps
                        )
                        * self.lik.mut_rate
                        * edge.span,
                    )
                    result[: youngest_par_index + 1] = self.lik.combine(
                        self.lik.ratio(
                            ll_mut[: youngest_par_index + 1],
                            np.max(ll_mut[: youngest_par_index + 1]),
                        ),
                        result[: youngest_par_index + 1],
                    )
            inside_val = self.inside[child][: (youngest_par_index + 1)]

            maximized_node_times[child] = np.argmax(
                self.lik.combine(result[: youngest_par_index + 1], inside_val)
            )

        return self.lik.timepoints[maximized_node_times]


# Classes for each method
Results = namedtuple(
    "Results",
    [
        "posterior_mean",
        "posterior_var",
        "posterior_obj",
        "mutation_mean",
        "mutation_var",
        "mutation_lik",
        "mutation_edge",
        "mutation_node",
    ],
)


class EstimationMethod:
    """
    Base class to hold the various estimation methods. Override prior_grid_func_name with
    something like "parameter_grid" or "prior_grid".
    """

    prior_grid_func_name = None

    def run(self):
        # Subclasses should override this to return a Results object
        raise NotImplementedError(
            "Base class 'EstimationMethod' not intended for direct use"
        )

    def __init__(
        self,
        ts,
        *,
        mutation_rate=None,
        population_size=None,
        recombination_rate=None,
        time_units=None,
        priors=None,
        return_posteriors=None,
        return_likelihood=None,
        record_provenance=None,
        constr_iterations=None,
        progress=None,
    ):
        # Set up all the generic params described in the tsdate.date function,
        # and define priors if they have not already been passed in
        self.ts = ts
        self.mutation_rate = mutation_rate
        self.recombination_rate = recombination_rate
        self.return_posteriors = return_posteriors
        self.return_likelihood = return_likelihood
        self.pbar = progress
        self.time_units = "generations" if time_units is None else time_units
        if record_provenance is None:
            record_provenance = True

        if recombination_rate is not None:
            raise NotImplementedError(
                "Using the recombination clock is not currently supported"
                ". See https://github.com/awohns/tsdate/issues/5 for details"
            )

        Ne = population_size  # shorthand
        if isinstance(Ne, dict):
            Ne = demography.PopulationSizeHistory(**Ne)

        self.provenance_params = None
        if record_provenance:
            self.provenance_params = dict(
                mutation_rate=mutation_rate,
                recombination_rate=recombination_rate,
                time_units=time_units,
                progress=progress,
                # demography.PopulationSizeHistory provides as_dict() for saving
                population_size=Ne.as_dict() if hasattr(Ne, "as_dict") else Ne,
            )

        if constr_iterations is None:
            self.constr_iterations = 0
        else:
            if not (isinstance(constr_iterations, int) and constr_iterations >= 0):
                raise ValueError(
                    "Number of constrained least squares iterations must be a "
                    "non-negative integer"
                )
            self.constr_iterations = constr_iterations

        if self.prior_grid_func_name is None:
            if priors is not None:
                raise ValueError(f"Priors are not used for method {self.name}")
            if Ne is not None:
                raise ValueError(f"Population size is not used for method {self.name}")
        else:
            if priors is None:
                if Ne is None:
                    raise ValueError(
                        "Must specify population size if priors are not already "
                        f"built using tsdate.build_{self.prior_grid_func_name}()"
                    )
                mk_prior = getattr(prior, self.prior_grid_func_name)
                # Default to not creating approximate priors unless ts has
                # greater than DEFAULT_APPROX_PRIOR_SIZE samples
                approx = ts.num_samples > base.DEFAULT_APPROX_PRIOR_SIZE
                self.priors = mk_prior(
                    ts, Ne, approximate_priors=approx, progress=progress
                )
            else:
                logger.info("Using user-specified priors")
                if Ne is not None:
                    raise ValueError(
                        "Cannot specify population size if specifying priors "
                        f"from tsdate.build_{self.prior_grid_func_name}()"
                    )
                self.priors = priors

        # mutation to edge mapping
        # TODO: this isn't needed except for mutations_edge in constrain_mutations
        self.edges_mutations, self.mutations_edge = util.mutation_span_array(ts)

    def get_modified_ts(self, result, eps):
        # Return a new ts based on the existing one, but with the various
        # time-related information correctly set.
        ts = self.ts
        node_mean_t = result.posterior_mean
        node_var_t = result.posterior_var
        mut_mean_t = result.mutation_mean
        mut_var_t = result.mutation_var
        mut_edge = result.mutation_edge
        mut_node = result.mutation_node
        tables = ts.dump_tables()
        nodes = tables.nodes
        mutations = tables.mutations

        if self.provenance_params is not None:
            provenance.record_provenance(tables, self.name, **self.provenance_params)
        # Constrain node ages for positive branch lengths
        constr_timing = time.time()
        nodes.time = util.constrain_ages(ts, node_mean_t, eps, self.constr_iterations)
        mutations.time = util.constrain_mutations(ts, nodes.time, mut_edge)
        mutations.node = mut_node
        mutations.parent = np.full(mutations.num_rows, tskit.NULL, dtype=np.int32)
        tables.time_units = self.time_units
        constr_timing -= time.time()
        logger.info(f"Constrained node ages in {abs(constr_timing):.2f} seconds")
        # Add posterior mean and variance to node/mutation metadata
        meta_timing = time.time()
        self.set_time_metadata(
            nodes, node_mean_t, node_var_t, schemas.default_node_schema, overwrite=True
        )
        self.set_time_metadata(
            mutations, mut_mean_t, mut_var_t, schemas.default_mutation_schema
        )
        meta_timing -= time.time()
        logger.info(f"Inserted node and mutation metadata in {abs(meta_timing)} seconds")
        sort_timing = time.time()
        tables.sort()
        tables.build_index()
        tables.compute_mutation_parents()
        sort_timing -= time.time()
        logger.info(f"Sorted tree sequence in {abs(sort_timing):.2f} seconds")
        return tables.tree_sequence()

    def set_time_metadata(self, table, mean, var, default_schema, overwrite=False):
        if var is not None:
            table_name = type(table).__name__
            assert len(mean) == len(var) == table.num_rows
            if table.metadata_schema.schema is None or overwrite:
                if len(table.metadata) == 0 or overwrite:
                    table.metadata_schema = default_schema
                    md_iter = ({} for _ in range(table.num_rows))
                    # For speed, assume we don't need to validate
                    encoder = table.metadata_schema.encode_row
                    logger.info(f"Set metadata schema on {table_name}")
                else:
                    logger.warning(
                        f"Could not set metadata on {table_name}: "
                        "data already exists with no schema"
                    )
                    return
            else:
                md_iter = (
                    table.metadata_schema.decode_row(md)
                    for md in tskit.unpack_bytes(table.metadata, table.metadata_offset)
                )
                encoder = table.metadata_schema.validate_and_encode_row
                # TODO: could try to add to the existing schema if it's compatible
            metadata_array = []
            try:
                # wrap entire loop in try/except so metadata is either all set or not
                for metadata_dict, mn, vr in zip(md_iter, mean, var):
                    metadata_dict.update((("mn", mn), ("vr", vr)))
                    # validate and replace
                    metadata_array.append(encoder(metadata_dict))
                table.packset_metadata(metadata_array)
            except tskit.MetadataValidationError as e:
                logger.warning(f"Could not set time metadata in {table_name}: {e}")

    def parse_result(self, result, epsilon, extra_posterior_cols=None):
        # Construct the tree sequence to return, plus any other requested
        # outputs. extra_posterior_cols is a dict appended to the output
        # posterior dict
        ret = [self.get_modified_ts(result, epsilon)]
        if self.return_posteriors:
            pst_dict = None
            if result.posterior_obj is not None:
                pst_dict = result.posterior_obj.nonfixed_dict()
                pst_dict.update(extra_posterior_cols or {})
            ret.append(pst_dict)
        if self.return_likelihood:
            ret.append(result.mutation_lik)
        return tuple(ret) if len(ret) > 1 else ret.pop()

    def get_fixed_nodes_set(self):
        # TODO: modify to allow non-contemporary samples. If these have priors specified
        # they should work fine with these algorithms.
        for sample in self.ts.samples():
            if self.ts.node(sample).time != 0:
                raise NotImplementedError("Samples must all be at time 0")
        return set(self.ts.samples())


class DiscreteTimeMethod(EstimationMethod):
    prior_grid_func_name = "prior_grid"

    @staticmethod
    def mean_var(ts, posterior):
        """
        Mean and variance of node age given an atomic time discretization. Fixed
        nodes will be given a mean of their exact time in the tree sequence, and
        zero variance. This is a static method for ease of testing.
        """
        mn_post = np.full(ts.num_nodes, np.nan)  # Fill with NaNs so we detect when
        va_post = np.full(ts.num_nodes, np.nan)  # there's been an error

        is_fixed = np.ones(posterior.num_nodes, dtype=bool)
        is_fixed[posterior.nonfixed_nodes] = False
        mn_post[is_fixed] = ts.nodes_time[is_fixed]
        va_post[is_fixed] = 0

        for u in posterior.nonfixed_nodes:
            probs = posterior[u]
            times = posterior.timepoints
            mn_post[u] = np.sum(probs * times) / np.sum(probs)
            va_post[u] = np.sum(((mn_post[u] - (times)) ** 2) * (probs / np.sum(probs)))

        return mn_post, va_post

    def main_algorithm(self, probability_space, epsilon, num_threads):
        # Algorithm class is shared by inside-outside & outside-maximization methods
        if probability_space != base.LOG:
            liklhd = Likelihoods(
                self.ts,
                self.priors.timepoints,
                self.mutation_rate,
                self.recombination_rate,
                eps=epsilon,
                fixed_node_set=self.get_fixed_nodes_set(),
                progress=self.pbar,
            )
        else:
            liklhd = LogLikelihoods(
                self.ts,
                self.priors.timepoints,
                self.mutation_rate,
                self.recombination_rate,
                eps=epsilon,
                fixed_node_set=self.get_fixed_nodes_set(),
                progress=self.pbar,
            )
        if self.mutation_rate is not None:
            liklhd.precalculate_mutation_likelihoods(num_threads=num_threads)

        return InOutAlgorithms(self.priors, liklhd, progress=self.pbar)


class InsideOutsideMethod(DiscreteTimeMethod):
    name = "inside_outside"

    def run(
        self,
        eps,
        outside_standardize,
        ignore_oldest_root,
        probability_space,
        num_threads=None,
        cache_inside=None,
    ):
        if self.provenance_params is not None:
            self.provenance_params.update(
                {k: v for k, v in locals().items() if k != "self"}
            )
        dynamic_prog = self.main_algorithm(probability_space, eps, num_threads)
        marginal_likl = dynamic_prog.inside_pass(cache_inside=cache_inside)
        posterior_obj = dynamic_prog.outside_pass(
            standardize=outside_standardize, ignore_oldest_root=ignore_oldest_root
        )
        # Turn the posterior into probabilities
        posterior_obj.standardize()  # Just to ensure there are no floating point issues
        posterior_obj.force_probability_space(base.LIN)
        posterior_obj.to_probabilities()

        posterior_mean, posterior_var = self.mean_var(self.ts, posterior_obj)
        mut_edge = np.full(self.ts.num_mutations, tskit.NULL)
        mut_node = self.ts.mutations_node
        return Results(
            posterior_mean,
            posterior_var,
            posterior_obj,
            None,
            None,
            marginal_likl,
            mut_edge,
            mut_node,
        )


class MaximizationMethod(DiscreteTimeMethod):
    name = "maximization"

    def __init__(self, ts, **kwargs):
        super().__init__(ts, **kwargs)
        if self.return_posteriors:
            raise ValueError("Cannot return posterior with maximization method")

    def run(
        self,
        eps,
        probability_space=None,
        num_threads=None,
        cache_inside=None,
    ):
        if self.mutation_rate is None:
            raise ValueError("Outside maximization method requires mutation rate")
        if self.provenance_params is not None:
            self.provenance_params.update(
                {k: v for k, v in locals().items() if k != "self"}
            )
        dynamic_prog = self.main_algorithm(probability_space, eps, num_threads)
        marginal_likl = dynamic_prog.inside_pass(cache_inside=cache_inside)
        posterior_mean = dynamic_prog.outside_maximization(eps=eps)
        mut_edge = np.full(self.ts.num_mutations, tskit.NULL)
        mut_node = self.ts.mutations_node
        return Results(
            posterior_mean, None, None, None, None, marginal_likl, mut_edge, mut_node
        )


class VariationalGammaMethod(EstimationMethod):
    prior_grid_func_name = None
    name = "variational_gamma"

    def __init__(self, ts, **kwargs):
        super().__init__(ts, **kwargs)

    def run(
        self,
        eps,
        max_iterations,
        max_shape,
        rescaling_intervals,
        rescaling_iterations,
        match_segregating_sites,
        regularise_roots,
        singletons_phased,
    ):
        if self.provenance_params is not None:
            self.provenance_params.update(
                {k: v for k, v in locals().items() if k != "self"}
            )
        if not max_iterations > 0:
            raise ValueError("Maximum number of EP iterations must be greater than 0")
        if self.mutation_rate is None:
            raise ValueError("Variational gamma method requires mutation rate")

        posterior = variational.ExpectationPropagation(
            self.ts,
            mutation_rate=self.mutation_rate,
            singletons_phased=singletons_phased,
        )
        posterior.run(
            ep_iterations=max_iterations,
            max_shape=max_shape,
            rescale_intervals=rescaling_intervals,
            rescale_iterations=rescaling_iterations,
            regularise=regularise_roots,
            rescale_segsites=match_segregating_sites,
            progress=self.pbar,
        )

        node_mn, node_va = posterior.node_moments()
        mutation_mn, mutation_va = posterior.mutation_moments()
        mutation_edge, mutation_node = posterior.mutation_mapping()

        return Results(
            node_mn,
            node_va,
            None,
            mutation_mn,
            mutation_va,
            None,
            mutation_edge,
            mutation_node,
        )


def maximization(
    tree_sequence,
    *,
    mutation_rate,
    population_size=None,
    priors=None,
    eps=None,
    num_threads=None,
    probability_space=None,
    # below deliberately undocumented
    cache_inside=None,
    Ne=None,
    # Other params documented in `.date()`
    **kwargs,
):
    """
    maximization(tree_sequence, *, mutation_rate, population_size=None, priors=None,\
        eps=None, num_threads=None, probability_space=None, **kwargs)

    Infer dates for nodes in a genealogical graph using the "outside maximization"
    algorithm. This approximates the marginal posterior distribution of a node's
    age using an atomic discretization of time (e.g. point masses at particular
    timepoints).

    This estimation method comprises a single "inside" step followed by an
    "outside maximization" step. The inside step passes backwards in time from
    the samples to the roots of the graph, taking account of the distributions
    of times of each node's child (and, if a ``mutation_rate`` is given, the
    number of mutations on each edge). The outside maximization step passes
    forwards in time from the roots, updating each node's time on the basis of
    the most likely timepoint for each parent of that node. This provides a
    reasonable point estimate for node times, but does not generate a true
    posterior time distribution.

    For example:

    .. code-block:: python

      new_ts = tsdate.maximization(ts, mutation_rate=1e-8, population_size=1e4)

    .. note::
        The prior parameters for each node-to-be-dated take the form of
        probabilities for each node at a set of discrete timepoints. If the
        ``priors`` parameter is used, it must specify an object constructed using
        :func:`build_prior_grid` (this can be used to define the number and
        position of the timepoints). If ``priors`` is not used,
        ``population_size`` must be provided, which is used to create a default
        prior derived from the conditional coalescent (tilted according to
        population size and weighted by the genomic span over which a node has a
        given number of descendant samples). This default prior assumes the nodes
        to be dated are all the non-sample nodes in the input tree sequence, and
        that they are contemporaneous.

    :param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated.
    :param float mutation_rate: The estimated mutation rate per unit of genome per
        unit time. If provided, the dating algorithm will use a mutation rate
        clock to help estimate node dates. Default: ``None``
    :param float or ~demography.PopulationSizeHistory population_size: The
        estimated (diploid) effective population size used to construct the
        (default) conditional coalescent prior. For a population with constant
        size, this can be given as a single value (for example, as commonly
        estimated by the observed genetic diversity of the sample divided by
        four-times the expected mutation rate). Alternatively, for a population
        with time-varying size, this can be given directly as a
        :class:`~demography.PopulationSizeHistory` object or a parameter
        dictionary passed to initialise a
        :class:`~demography.PopulationSizeHistory` object. The
        ``population_size`` parameter is only used when ``priors`` is ``None``.
        Conversely, if ``priors`` is not ``None``, no ``population_size`` value
        should be specified.
    :param tsdate.base.NodeGridValues priors: NodeGridValues object containing
        the prior parameters for each node-to-be-dated. Note that different
        estimation methods may require different types of prior, as described in
        the documentation for each estimation method.
    :param float eps: The error factor in time difference calculations, and the
        minimum distance separating parent and child ages in the returned tree
        sequence. Default: None, treated as 1e-6.
    :param int num_threads: The number of threads to use when precalculating
        likelihoods. A simpler unthreaded algorithm is used unless this is >= 1.
        Default: None
    :param string probability_space: Should the internal algorithm save
        probabilities in "logarithmic" (slower, less liable to overflow) or
        "linear" space (fast, may overflow). Default: None, treated as
        "logarithmic"
    :param \\**kwargs: Other keyword arguments as described in the :func:`date`
        wrapper function, notably ``mutation_rate``, and ``population_size`` or
        ``priors``. Further arguments include ``time_units``, ``progress``, and
        ``record_provenance``. The additional ``return_likelihood`` argument can
        be used to return additional information (see below). Posteriors cannot
        be returned using this estimation method.
    :return:
        - **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree
          sequence with updated node times based on the posterior mean, corrected
          where necessary to ensure that parents are strictly older than all
          their children by an amount given by the ``eps`` parameter.
        - **marginal_likelihood** (:py:class:`float`) -- (Only returned if
          ``return_likelihood`` is ``True``) The marginal likelihood of the
          mutation data given the inferred node times.
    """
    if Ne is not None:
        if population_size is not None:
            raise ValueError("Only provide one of Ne (deprecated) or population_size")
        else:
            population_size = Ne
    if eps is None:
        eps = DEFAULT_EPSILON
    if probability_space is None:
        probability_space = base.LOG
    dating_method = MaximizationMethod(
        tree_sequence,
        mutation_rate=mutation_rate,
        population_size=population_size,
        priors=priors,
        **kwargs,
    )
    result = dating_method.run(
        eps=eps,
        num_threads=num_threads,
        cache_inside=cache_inside,
        probability_space=probability_space,
    )
    return dating_method.parse_result(result, eps)
[docs] def inside_outside( tree_sequence, *, mutation_rate, population_size=None, priors=None, eps=None, num_threads=None, outside_standardize=None, ignore_oldest_root=None, probability_space=None, # below deliberately undocumented cache_inside=False, # Deprecated params Ne=None, # Other params documented in `.date()` **kwargs, ): """ inside_outside(tree_sequence, *, mutation_rate, population_size=None, priors=None,\ eps=None, num_threads=None, outside_standardize=None, ignore_oldest_root=None,\ probability_space=None, **kwargs) Infer dates for nodes in a genealogical graph using the "inside outside" algorithm. This approximates the marginal posterior distribution of a node's age using an atomic discretization of time (e.g. point masses at particular timepoints). Currently, this estimation method comprises a single "inside" followed by a similar "outside" step. The inside step passes backwards in time from the samples to the roots of the graph,taking account of the distributions of times of each node's child (and if a ``mutation_rate`` is given, the the number of mutations on each edge). The outside step passes forwards in time from the roots, incorporating the time distributions for each node's parents. If there are (undirected) cycles in the underlying graph, this method does not provide a theoretically exact estimate of the marginal posterior distribution of node ages, but in practice it results in an accurate approximation. For example: .. code-block:: python new_ts = tsdate.inside_outside(ts, mutation_rate=1e-8, population_size=1e4) .. note:: The prior parameters for each node-to-be-dated take the form of probabilities for each node at a set of discrete timepoints. If the ``priors`` parameter is used, it must specify an object constructed using :func:`build_prior_grid` (this can be used to define the number and position of the timepoints). If ``priors`` is not used, ``population_size`` must be provided, which is used to create a default prior derived from the conditional coalescent (tilted according to population size and weighted by the genomic span over which a node has a given number of descendant samples). This default prior assumes the nodes to be dated are all the non-sample nodes in the input tree sequence, and that they are contemporaneous. :param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated. :param float mutation_rate: The estimated mutation rate per unit of genome per unit time. If provided, the dating algorithm will use a mutation rate clock to help estimate node dates. Default: ``None`` :param float or ~demography.PopulationSizeHistory population_size: The estimated (diploid) effective population size used to construct the (default) conditional coalescent prior. For a population with constant size, this can be given as a single value (for example, as commonly estimated by the observed genetic diversity of the sample divided by four-times the expected mutation rate). Alternatively, for a population with time-varying size, this can be given directly as a :class:`~demography.PopulationSizeHistory` object or a parameter dictionary passed to initialise a :class:`~demography.PopulationSizeHistory` object. The ``population_size`` parameter is only used when ``priors`` is ``None``. Conversely, if ``priors`` is not ``None``, no ``population_size`` value should be specified. :param tsdate.base.NodeGridValues priors: NodeGridValues object containing the prior parameters for each node-to-be-dated. 
        Note that different estimation methods may require different types of
        prior, as described in the documentation for each estimation method.
    :param float eps: The error factor in time difference calculations, and the
        minimum distance separating parent and child ages in the returned tree
        sequence. Default: None, treated as 1e-6.
    :param int num_threads: The number of threads to use when precalculating
        likelihoods. A simpler unthreaded algorithm is used unless this is >= 1.
        Default: None
    :param bool outside_standardize: Should the likelihoods be standardized
        during the outside step? This can help to avoid numerical
        under/overflow. Using unstandardized values is mostly useful for
        testing (e.g. to obtain, in the outside step, the total functional
        value for each node). Default: None, treated as True.
    :param bool ignore_oldest_root: Should the oldest root in the tree sequence
        be ignored in the outside algorithm (if ``"inside_outside"`` is used as
        the method)? Ignoring the oldest root can provide greater stability
        when dating tree sequences inferred from real data, in particular if
        all local trees are assumed to coalesce in a single "grand MRCA", as in
        older versions of ``tsinfer``. Default: None, treated as False.
    :param string probability_space: Should the internal algorithm save
        probabilities in "logarithmic" (slower, less liable to overflow) or
        "linear" space (fast, may overflow). Default: "logarithmic"
    :param \\**kwargs: Other keyword arguments as described in the :func:`date`
        wrapper function, notably ``mutation_rate``, and ``population_size`` or
        ``priors``. Further arguments include ``time_units``, ``progress``, and
        ``record_provenance``. The additional arguments ``return_posteriors``
        and ``return_likelihood`` can be used to return additional information
        (see below).
    :return:
        - **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree
          sequence with updated node times based on the posterior mean,
          corrected where necessary to ensure that parents are strictly older
          than all their children by an amount given by the ``eps`` parameter.
        - **posteriors** (:py:class:`dict`) -- (Only returned if
          ``return_posteriors`` is ``True``) A dictionary of posterior
          probabilities. Each node whose time was inferred corresponds to an
          item in this dictionary whose key is the node ID and value is an
          array of probabilities of the node being at a list of timepoints.
          Timepoint values are provided in the returned dictionary under the
          key named "time". When read as a pandas ``DataFrame`` object using
          ``pd.DataFrame(posteriors)``, the rows correspond to labelled
          timepoints and columns are headed by their respective node ID.
        - **marginal_likelihood** (:py:class:`float`) -- (Only returned if
          ``return_likelihood`` is ``True``) The marginal likelihood of the
          mutation data given the inferred node times.
""" if Ne is not None: if population_size is not None: raise ValueError("Only provide one of Ne (deprecated) or population_size") else: population_size = Ne if eps is None: eps = DEFAULT_EPSILON if probability_space is None: probability_space = base.LOG if outside_standardize is None: outside_standardize = True if ignore_oldest_root is None: ignore_oldest_root = False dating_method = InsideOutsideMethod( tree_sequence, mutation_rate=mutation_rate, population_size=population_size, priors=priors, **kwargs, ) result = dating_method.run( eps=eps, num_threads=num_threads, outside_standardize=outside_standardize, ignore_oldest_root=ignore_oldest_root, cache_inside=cache_inside, probability_space=probability_space, ) return dating_method.parse_result( result, eps, {"time": result.posterior_obj.timepoints} )
def variational_gamma(
    tree_sequence,
    *,
    mutation_rate,
    eps=None,
    max_iterations=None,
    rescaling_intervals=None,
    rescaling_iterations=None,
    match_segregating_sites=None,
    # deliberately undocumented parameters below. We may eventually document these
    max_shape=None,
    regularise_roots=None,
    singletons_phased=None,
    **kwargs,
):
    """
    variational_gamma(tree_sequence, *, mutation_rate, eps=None, max_iterations=None,\
        rescaling_intervals=None, **kwargs)

    Infer dates for nodes in a tree sequence using expectation propagation,
    which approximates the marginal posterior distribution of a given node's
    age with a gamma distribution. Convergence to the correct posterior moments
    is obtained by updating the distributions for node dates over several
    rounds of iteration. For example:

    .. code-block:: python

        new_ts = tsdate.variational_gamma(ts, mutation_rate=1e-8, max_iterations=10)

    A piecewise-constant uniform distribution is used as a prior for each node,
    which is updated via expectation maximization in each iteration.
    Node-specific priors are not currently supported.

    :param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated.
    :param float mutation_rate: The estimated mutation rate per unit of genome
        per unit time.
    :param float eps: The minimum distance separating parent and child ages in
        the returned tree sequence. Default: None, treated as 1e-6
    :param int max_iterations: The number of iterations used in the expectation
        propagation algorithm. Default: None, treated as 10.
    :param float rescaling_intervals: For time rescaling, the number of time
        intervals within which to estimate a rescaling parameter. Setting this
        to zero means that rescaling is not performed. Default: ``None``,
        treated as 1000.
    :param float rescaling_iterations: The number of iterations for time
        rescaling. Setting this to zero means that rescaling is not performed.
        Default: ``None``, treated as 5.
    :param bool match_segregating_sites: If ``True``, then time is rescaled
        such that branch- and site-mode segregating sites are approximately
        equal. If ``False``, time is rescaled such that branch- and site-mode
        root-to-leaf length are approximately equal, which gives unbiased
        estimates when there are polytomies. Default: ``False``.
    :param \\**kwargs: Other keyword arguments as described in the :func:`date`
        wrapper function, including ``time_units``, ``progress``, and
        ``record_provenance``. The arguments ``return_posteriors`` and
        ``return_likelihood`` can be used to return additional information (see
        below).
    :return:
        - **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree
          sequence with updated node times based on the posterior mean,
          corrected where necessary to ensure that parents are strictly older
          than all their children by an amount given by the ``eps`` parameter.
        - **posteriors** (:py:class:`dict`) -- (Only returned if
          ``return_posteriors`` is ``True``) A dictionary of posterior
          probabilities. Each node whose time was inferred corresponds to an
          item in this dictionary whose key is the node ID and value is an
          array of the ``[shape, rate]`` parameters of the posterior gamma
          distribution for that node. When read as a pandas ``DataFrame``
          object using ``pd.DataFrame(posteriors)``, the first row of the data
          frame is the shape and the second the rate parameter, each column
          being headed by the respective node ID.
        - **marginal_likelihood** (:py:class:`float`) -- (Only returned if
          ``return_likelihood`` is ``True``) The marginal likelihood of the
          mutation data given the inferred node times.
          Not currently implemented for this method (set to ``None``).
    """
    if eps is None:
        eps = DEFAULT_EPSILON
    if max_iterations is None:
        max_iterations = DEFAULT_MAX_ITERATIONS
    if max_shape is None:
        # The maximum value for the shape parameter in the variational
        # posteriors. Equivalent to the maximum precision (inverse variance)
        # on a logarithmic scale.
        max_shape = 1000
    if rescaling_intervals is None:
        rescaling_intervals = DEFAULT_RESCALING_INTERVALS
    if rescaling_iterations is None:
        rescaling_iterations = DEFAULT_RESCALING_ITERATIONS
    if match_segregating_sites is None:
        match_segregating_sites = False
    if regularise_roots is None:
        regularise_roots = True
    if singletons_phased is None:
        singletons_phased = True
    if tree_sequence.num_mutations == 0:
        raise ValueError(
            "No mutations present: these are required for the variational_gamma method"
        )
    dating_method = VariationalGammaMethod(
        tree_sequence, mutation_rate=mutation_rate, **kwargs
    )
    result = dating_method.run(
        eps=eps,
        max_iterations=max_iterations,
        max_shape=max_shape,
        rescaling_intervals=rescaling_intervals,
        rescaling_iterations=rescaling_iterations,
        match_segregating_sites=match_segregating_sites,
        regularise_roots=regularise_roots,
        singletons_phased=singletons_phased,
    )
    return dating_method.parse_result(result, eps, {"parameter": ["shape", "rate"]})
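# A minimal sketch, not part of the tsdate API: converting the ``[shape, rate]``
# gamma parameters returned by ``variational_gamma`` into posterior means and
# variances. ``ts`` is assumed to be a tskit.TreeSequence containing mutations,
# the mutation rate is a placeholder, and (based on the ``parse_result`` call
# above) the posteriors dict is assumed to carry a "parameter" row-label entry
# alongside the per-node entries.
def _example_variational_gamma_moments(ts):
    dated_ts, posteriors = variational_gamma(
        ts, mutation_rate=1e-8, return_posteriors=True
    )
    moments = {}
    for key, value in posteriors.items():
        if key == "parameter":  # row labels ("shape", "rate"), not a node ID
            continue
        shape, rate = value
        # For a gamma distribution: mean = shape / rate, var = shape / rate**2
        moments[key] = (shape / rate, shape / rate**2)
    return dated_ts, moments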
estimation_methods = {
    "variational_gamma": variational_gamma,
    "inside_outside": inside_outside,
    "maximization": maximization,
}
"""
The names of available estimation methods, each mapped to a function to carry
out the appropriate method. Names can be passed as strings to the
:func:`~tsdate.date` function, or each named function can be called directly:

* :func:`tsdate.variational_gamma`: variational approximation, empirically the
  most accurate method.
* :func:`tsdate.inside_outside`: empirically more accurate than maximization,
  but theoretically problematic.
* :func:`tsdate.maximization`: empirically less accurate, especially with
  gamma-approximated priors, but theoretically robust.
"""
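# A minimal sketch, not part of the tsdate API: the mapping above lets a method
# name be resolved to its estimation function at runtime, which is essentially
# what :func:`date` does internally. ``ts`` is assumed to be a
# tskit.TreeSequence, the mutation rate is a placeholder, and method-specific
# arguments (e.g. ``population_size`` for "inside_outside") pass via ``kwargs``.
def _example_dispatch(ts, method="inside_outside", **kwargs):
    if method not in estimation_methods:
        raise ValueError(f"method must be one of {list(estimation_methods.keys())}")
    estimation_fn = estimation_methods[method]
    return estimation_fn(ts, mutation_rate=1e-8, **kwargs)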
def date(
    tree_sequence,
    *,
    mutation_rate,
    recombination_rate=None,
    time_units=None,
    method=None,
    constr_iterations=None,
    return_posteriors=None,
    return_likelihood=None,
    progress=None,
    record_provenance=True,
    # Other kwargs documented in the functions for each specific estimation method
    **kwargs,
):
    """
    Infer dates for nodes in a genealogical graph (or
    :ref:`ARG<tutorials:sec_args>`) stored in the
    :ref:`succinct tree sequence<tskit:sec_introduction>` format. New times are
    assigned to nodes using the estimation algorithm specified by ``method``
    (see note below). If a ``mutation_rate`` is given, the mutation clock is
    used. The recombination clock is unsupported at this time. If neither a
    ``mutation_rate`` nor a ``recombination_rate`` is given, a topology-only
    clock is used. Times associated with mutations and times associated with
    non-fixed (non-sample) nodes are overwritten. For example:

    .. code-block:: python

        mu = 1e-8
        Ne = ts.diversity() / 4 / mu  # In the absence of external info, use ts for prior Ne
        new_ts = tsdate.date(ts, mutation_rate=mu, population_size=Ne)

    .. note::
        This is a wrapper for the named functions that are listed in
        :data:`~tsdate.core.estimation_methods`. Details and specific
        parameters for each estimation method are given in the documentation
        for those functions.

    :param ~tskit.TreeSequence tree_sequence: The input tree sequence to be
        dated (for example one with
        :data:`uncalibrated<tskit.TIME_UNITS_UNCALIBRATED>` node times).
    :param float mutation_rate: The estimated mutation rate per unit of genome
        per unit time (see individual methods).
    :param float recombination_rate: The estimated recombination rate per unit
        of genome per unit time. If provided, the dating algorithm will use a
        recombination rate clock to help estimate node dates. Default: ``None``
        (not currently implemented)
    :param str time_units: The time units used by the ``mutation_rate`` and
        ``recombination_rate`` values, and stored in the ``time_units``
        attribute of the output tree sequence. If the conditional coalescent
        prior is used, then this also applies to the value of
        ``population_size``, which in standard coalescent theory is measured in
        generations. Therefore if you wish to use mutation and recombination
        rates measured in (say) years, and are using the conditional coalescent
        prior, the ``population_size`` value which you provide must be scaled
        by multiplying by the number of years per generation. If ``None``
        (default), assume ``"generations"``.
    :param string method: What estimation method to use. See
        :data:`~tsdate.core.estimation_methods` for possible values. If
        ``None`` (default) the "variational_gamma" method is currently chosen.
    :param bool return_posteriors: If ``True``, instead of returning just a
        dated tree sequence, return a tuple of ``(dated_ts, posteriors)``.
        Default: None, treated as False.
    :param int constr_iterations: The maximum number of constrained least
        squares iterations to use prior to forcing positive branch lengths.
        Default: None, treated as 0.
    :param bool return_likelihood: If ``True``, return the log marginal
        likelihood from the inside algorithm in addition to the dated tree
        sequence. If ``return_posteriors`` is also ``True``, then the marginal
        likelihood will be the last element of the tuple. Default: None,
        treated as False.
    :param bool progress: Show a progress bar. Default: None, treated as False.
    :param bool record_provenance: Should the tsdate command be appended to the
        provenance information in the returned tree sequence? Default: None,
        treated as True.
    :param \\**kwargs: Other keyword arguments specific to the
        :data:`estimation method<tsdate.core.estimation_methods>` used. These
        are documented in those specific functions.
    :return: A copy of the input tree sequence but with updated node times, or
        (if ``return_posteriors`` or ``return_likelihood`` is True) a tuple of
        that tree sequence plus a dictionary of posterior probabilities and/or
        the marginal likelihood given the mutations on the tree sequence.
    """
    # Only the .date() wrapper needs to consider the deprecated "Ne" param
    if method is None:
        method = "variational_gamma"
    if method not in estimation_methods:
        raise ValueError(f"method must be one of {list(estimation_methods.keys())}")
    return estimation_methods[method](
        tree_sequence,
        mutation_rate=mutation_rate,
        recombination_rate=recombination_rate,
        time_units=time_units,
        progress=progress,
        constr_iterations=constr_iterations,
        return_posteriors=return_posteriors,
        return_likelihood=return_likelihood,
        record_provenance=record_provenance,
        **kwargs,
    )
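# A minimal usage sketch, not part of the tsdate API: when both
# ``return_posteriors`` and ``return_likelihood`` are True, :func:`date`
# returns a tuple with the marginal likelihood as its last element, as
# documented above. ``ts`` is assumed to be a tskit.TreeSequence with mutations
# (required by the default "variational_gamma" method, for which the marginal
# likelihood is currently ``None``); the mutation rate is a placeholder.
def _example_date_full_output(ts):
    dated_ts, posteriors, marginal_lik = date(
        ts, mutation_rate=1e-8, return_posteriors=True, return_likelihood=True
    )
    return dated_ts, posteriors, marginal_lik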