#
# Copyright (C) 2018-2023 University of Oxford
#
# This file is part of tsinfer.
#
# tsinfer is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tsinfer is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with tsinfer.  If not, see <http://www.gnu.org/licenses/>.
#
"""
Central module for high-level inference. The actual implementation of
of the core tasks like ancestor generation and matching are delegated
to other modules.
"""
import collections
import copy
import dataclasses
import heapq
import json
import logging
import math
import operator
import os
import pathlib
import pickle
import queue
import tempfile
import threading
import time as time_

import humanize
import numpy as np
import tskit

import _tsinfer
import tsinfer.algorithm as algorithm
import tsinfer.ancestors as ancestors
import tsinfer.constants as constants
import tsinfer.formats as formats
import tsinfer.progress as progress
import tsinfer.provenance as provenance
import tsinfer.threads as threads

logger = logging.getLogger(__name__)


variant_data_time_metadata_definition = {
    "description": "Time of an individual from the SampleData file.",
    "type": "number",
    # Defaults aren't currently used, see
    # https://github.com/tskit-dev/tskit/issues/1073
    "default": 0,
}

inference_type_metadata_definition = {
    "description": (
        "The type of inference used at this site. This can be one of the following: "
        f"'{constants.INFERENCE_FULL}' for sites which used the standard tsinfer "
        f"algorithm; '{constants.INFERENCE_NONE}' for sites containing only missing "
        f"data or the ancestral state; '{constants.INFERENCE_PARSIMONY}' for sites "
        "that used a parsimony algorithm to place mutations based on trees inferred "
        "from the remaining data."
    ),
    "type": "string",
    "enum": [
        constants.INFERENCE_NONE,
        constants.INFERENCE_FULL,
        constants.INFERENCE_PARSIMONY,
    ],
}

node_ancestor_data_id_metadata_definition = {
    "description": (
        "The ID of the tsinfer ancestor data node from which this node is derived."
    ),
    "type": "number",
}

node_variant_data_id_metadata_definition = {
    "description": (
        "The ID of the tsinfer sample data node from which this node is derived. "
        "Only present for nodes in which historical samples are treated as ancestors."
    ),
    "type": "number",
}


def add_to_schema(schema, name, definition=None, required=False):
    """
    Adds the specified metadata name to the schema, with the specified definition.
    If the metadata name is already in the schema then either will warn about
    potential overwriting (if the definition is the same and there is a description),
    or will raise an error otherwise (to avoid conflicting metadata definitions).
    """
    schema = copy.deepcopy(schema)
    if definition is None:
        definition = {}
    try:
        if name in schema["properties"]:
            try:
                if (
                    schema["properties"][name] == definition
                    and definition["description"] != ""
                ):
                    logger.warning(
                        f"Metadata {name} with identical description already in schema."
                        " Schema left unchanged: existing metadata may be overwritten."
                    )
                    return schema
            except KeyError:
                pass
            raise ValueError(f"The metadata {name} is reserved for use by tsinfer")
    except KeyError:
        schema["properties"] = {}
    schema["properties"][name] = definition
    if required:
        if "required" not in schema:
            schema["required"] = []
        schema["required"].append(name)
    return schema
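

# Illustrative sketch (not part of the library API): how add_to_schema extends
# a tskit-style JSON schema dict. The starting schema here is a hypothetical
# minimal example.
def _example_add_to_schema():
    schema = {"codec": "json", "type": "object", "properties": {}}
    schema = add_to_schema(
        schema,
        "variant_data_time",
        definition=variant_data_time_metadata_definition,
        required=True,
    )
    # The new property is registered and listed as required.
    assert "variant_data_time" in schema["properties"]
    assert schema["required"] == ["variant_data_time"]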


def is_pc_ancestor(flags):
    """
    Returns True if the path compression ancestor flag is set on the specified
    flags value.
    """
    return (flags & constants.NODE_IS_PC_ANCESTOR) != 0


def is_srb_ancestor(flags):
    """
    Returns True if the shared recombination breakpoint flag is set on the
    specified flags value.
    """
    return (flags & constants.NODE_IS_SRB_ANCESTOR) != 0


def count_pc_ancestors(flags):
    """
    Returns the number of values in the specified array which have the
    NODE_IS_PC_ANCESTOR set.
    """
    flags = np.asarray(flags, dtype=np.uint32)
    return np.sum(is_pc_ancestor(flags))


def count_srb_ancestors(flags):
    """
    Returns the number of values in the specified array which have the
    NODE_IS_SRB_ANCESTOR set.
    """
    flags = np.asarray(flags, dtype=np.uint32)
    return np.sum(np.bitwise_and(flags, constants.NODE_IS_SRB_ANCESTOR) != 0)


AlleleCounts = collections.namedtuple("AlleleCounts", "known ancestral derived")


def allele_counts(genotypes):
    """
    Return summary counts of the number of different allele types for a
    genotypes array.
    """
    n_known = np.sum(genotypes != tskit.MISSING_DATA)
    n_ancestral = np.sum(genotypes == 0)
    return AlleleCounts(
        known=n_known, ancestral=n_ancestral, derived=n_known - n_ancestral
    )
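

# Illustrative sketch: allele_counts on a genotypes array with missing data.
# With genotypes [0, 1, -1, 1, 0] there are 4 known calls, of which 2 are
# ancestral (0) and 2 derived; tskit.MISSING_DATA is -1.
def _example_allele_counts():
    genotypes = np.array([0, 1, tskit.MISSING_DATA, 1, 0], dtype=np.int8)
    counts = allele_counts(genotypes)
    assert (counts.known, counts.ancestral, counts.derived) == (4, 2, 2)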


def _get_progress_monitor(progress_monitor, **kwargs):
    """
    Check if this really is a ProgressMonitor; if not, return something usable
    as one.
    """
    if isinstance(progress_monitor, progress.ProgressMonitor):
        return progress_monitor
    if progress_monitor:
        return progress.ProgressMonitor(**kwargs)
    return progress.DummyProgressMonitor()


def _encode_raw_metadata(obj):
    return json.dumps(obj).encode()


def _update_site_metadata(current_metadata, inference_type):
    return {"inference_type": inference_type, **current_metadata}


def verify(variant_data, tree_sequence, progress_monitor=None):
    """
    verify(variant_data, tree_sequence)

    Verifies that the specified variant data and tree sequence files encode the
    same data.

    :param VariantData variant_data: The input :class:`VariantData` instance
        representing the observed data that we wish to compare to.
    :param TreeSequence tree_sequence: The input :class:`tskit.TreeSequence`
        instance, an encoding of the specified samples that we wish to verify.
    """
    progress_monitor = _get_progress_monitor(progress_monitor, verify=True)
    if variant_data.num_sites != tree_sequence.num_sites:
        raise ValueError("numbers of sites not equal")
    if variant_data.num_samples != tree_sequence.num_samples:
        raise ValueError("numbers of samples not equal")
    if variant_data.sequence_length != tree_sequence.sequence_length:
        raise ValueError("Sequence lengths not equal")
    progress = progress_monitor.get("verify", tree_sequence.num_sites)
    for var1, var2 in zip(
        variant_data.variants(recode_ancestral=True), tree_sequence.variants()
    ):
        if var1.site.position != var2.site.position:
            raise ValueError(
                "site positions not equal: {} != {}".format(
                    var1.site.position, var2.site.position
                )
            )
        # First (ancestral) allele should always be the same
        if var1.alleles[0] != var2.alleles[0]:
            raise ValueError(f"Ancestral allele not equal at site {var1.site.id}")
        if var1.alleles != var2.alleles:
            # Alleles may be in a different order, or even present/absent if not in the
            # genotype matrix so we need to explicitly compare the decoded values (slow)
            for i, (g1, g2) in enumerate(zip(var1.genotypes, var2.genotypes)):
                # We don't expect missingness in a tsinfer generated tree sequence
                assert g2 != tskit.NULL
                if g1 != tskit.NULL and var1.alleles[g1] != var2.alleles[g2]:
                    raise ValueError(
                        f"Alleles for sample {i} not equal at site {var1.site.id}"
                    )
        else:
            g1 = var1.genotypes
            g2 = np.copy(var2.genotypes)
            missing_mask = g1 == tskit.NULL
            g2[missing_mask] = tskit.NULL
            if not np.array_equal(g1, g2):
                raise ValueError(f"Genotypes not equal at site {var1.site.id}")
        progress.update()
    progress.close()
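

# Hedged usage sketch: round-trip check that an inferred tree sequence encodes
# the same data as its input. The "data.vcz" path and the ancestral-state
# array name are hypothetical.
def _example_verify():
    import tsinfer  # deferred import; this module is part of the package

    vdata = tsinfer.VariantData("data.vcz", ancestral_state="variant_ancestral_allele")
    ts = tsinfer.infer(vdata)
    verify(vdata, ts)  # raises ValueError on any mismatch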


def check_sample_indexes(variant_data, indexes):
    """
    Checks that the specified sample indexes are valid for the specified
    sample data file.
    """
    if indexes is None:
        return np.arange(variant_data.num_samples, dtype=np.int32)
    indexes = np.array(indexes)
    if len(indexes) == 0:
        raise ValueError("Must supply at least one sample to match")
    if np.any(indexes < 0) or np.any(indexes >= variant_data.num_samples):
        raise ValueError("Sample index out of bounds")
    if np.any(indexes[:-1] >= indexes[1:]):
        raise ValueError("Sample indexes must be in increasing order")
    return indexes


def infer(
    variant_data,
    *,
    recombination_rate=None,
    mismatch_ratio=None,
    path_compression=True,
    exclude_positions=None,
    post_process=None,
    num_threads=0,
    # Deliberately undocumented parameters below
    precision=None,
    engine=constants.C_ENGINE,
    progress_monitor=None,
    time_units=None,
    simplify=None,  # Deprecated
    record_provenance=True,
):
    """
    infer(variant_data, *, recombination_rate=None, mismatch_ratio=None,\
        path_compression=True, exclude_positions=None, post_process=None,\
        num_threads=0)

    Runs the full :ref:`inference pipeline <sec_inference>` on the specified
    :class:`VariantData` instance and returns the inferred
    :class:`tskit.TreeSequence`. See
    :ref:`matching ancestors & samples<sec_inference_match_ancestors_and_samples>`
    in the documentation for details of ``recombination_rate``,
    ``mismatch_ratio`` and ``path_compression``.

    .. note::
        For finer grained control over inference, for example to set mismatch
        ratios when matching ancestors as well as when matching samples, run
        :func:`tsinfer.generate_ancestors`, :func:`tsinfer.match_ancestors`
        and :func:`tsinfer.match_samples` separately.

    :param VariantData variant_data: The input :class:`VariantData` instance
        representing the observed data that we wish to make inferences from.
    :param recombination_rate: Either a floating point value giving a constant
        rate :math:`\\rho` per unit length of genome, or an
        :class:`msprime.RateMap` object. This is used to calculate the
        probability of recombination between adjacent sites in the
        match_samples stage. If ``None``, all matching conflicts are resolved
        by recombination and all inference sites will have a single mutation
        (equivalent to mismatch_ratio near zero).
    :type recombination_rate: float, msprime.RateMap
    :param float mismatch_ratio: The probability of a mismatch relative to the
        median probability of recombination between adjacent sites: can only
        be used if a recombination rate has been set (default: ``None``
        treated as 1 if ``recombination_rate`` is set). This is only applied
        in the match_samples stage.
    :param bool path_compression: Whether to merge edges that share identical
        paths (essentially taking advantage of shared recombination
        breakpoints).
    :param bool post_process: Whether to run the :func:`post_process` method
        on the tree sequence which, among other things, removes ancestral
        material that does not end up in the current samples (if not
        specified, defaults to ``True``).
    :param array_like exclude_positions: A list of site positions to exclude
        for full inference. Sites with these positions will not be used to
        generate ancestors, and not used during the copying process. Any such
        sites that exist in the sample data file will be included in the trees
        after the main inference process using parsimony. The list does not
        need to be in any particular order, and can include site positions
        that are not present in the sample data file.
    :param int num_threads: The number of worker threads to use in
        parallelised sections of the algorithm. If <= 0, do not spawn any
        threads and use simpler sequential algorithms (default).
    :param bool simplify: When post_processing, only simplify the tree
        sequence. Deprecated but retained for backwards compatibility
        (default: ``None``).
    :return: The :class:`tskit.TreeSequence` object inferred from the input
        sample data.
    :rtype: tskit.TreeSequence
    """
    with provenance.TimingAndMemory() as timing:
        progress_monitor = _get_progress_monitor(
            progress_monitor,
            generate_ancestors=True,
            match_ancestors=True,
            match_samples=True,
        )
        ancestor_data = generate_ancestors(
            variant_data,
            num_threads=num_threads,
            exclude_positions=exclude_positions,
            engine=engine,
            progress_monitor=progress_monitor,
            record_provenance=False,
        )
        # NB: do not pass or encourage use of the mismatch ratio /
        # recombination rate in the ancestor matching phase. See
        # https://github.com/tskit-dev/tsinfer/issues/980
        ancestors_ts = match_ancestors(
            variant_data,
            ancestor_data,
            engine=engine,
            num_threads=num_threads,
            precision=precision,
            path_compression=path_compression,
            progress_monitor=progress_monitor,
            time_units=time_units,
            record_provenance=False,
        )
        inferred_ts = match_samples(
            variant_data,
            ancestors_ts,
            engine=engine,
            num_threads=num_threads,
            recombination_rate=recombination_rate,
            mismatch_ratio=mismatch_ratio,
            precision=precision,
            post_process=post_process,
            path_compression=path_compression,
            progress_monitor=progress_monitor,
            simplify=simplify,
            record_provenance=False,
        )
    if record_provenance:
        tables = inferred_ts.dump_tables()
        record = provenance.get_provenance_dict(
            command="infer",
            resources=timing.metrics.asdict(),
            mismatch_ratio=mismatch_ratio,
            path_compression=path_compression,
            precision=precision,
            simplify=simplify,
            post_process=post_process,
            # TODO: maybe record recombination rate (which could be a RateMap)
        )
        tables.provenances.add_row(record=json.dumps(record))
        inferred_ts = tables.tree_sequence()
    return inferred_ts
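

# Hedged usage sketch: the one-shot inference entry point, equivalent to
# running generate_ancestors, match_ancestors and match_samples in sequence.
# The file names and the constant recombination rate are hypothetical.
def _example_infer():
    import tsinfer

    vdata = tsinfer.VariantData("data.vcz", ancestral_state="variant_ancestral_allele")
    ts = tsinfer.infer(
        vdata,
        recombination_rate=1e-8,  # enables mismatch; mismatch_ratio defaults to 1
        num_threads=4,
    )
    ts.dump("inferred.trees")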


def generate_ancestors(
    variant_data,
    *,
    path=None,
    exclude_positions=None,
    num_threads=0,
    genotype_encoding=None,
    mmap_temp_dir=None,
    # Deliberately undocumented parameters below
    engine=constants.C_ENGINE,
    progress_monitor=None,
    record_provenance=True,
    **kwargs,
):
    """
    generate_ancestors(variant_data, *, path=None, exclude_positions=None,\
        num_threads=0, genotype_encoding=None, mmap_temp_dir=None, **kwargs)

    Runs the ancestor generation
    :ref:`algorithm <sec_inference_generate_ancestors>` on the specified
    :class:`VariantData` instance and returns the resulting
    :class:`AncestorData` instance. If you wish to store generated ancestors
    persistently on file you must pass the ``path`` keyword argument to this
    function. For example,

    .. code-block:: python

        ancestor_data = tsinfer.generate_ancestors(variant_data, path="mydata.ancestors")

    Other keyword arguments are passed to the :class:`AncestorData`
    constructor, which may be used to control the storage properties of the
    generated file.

    Ancestor generation involves loading the entire genotype matrix into
    memory, by default using one byte per haploid genotype, which can be
    prohibitively large when working with sample sizes of 100,000 or more.
    There are two options to help mitigate memory usage. The
    ``genotype_encoding`` parameter allows the user to specify a more compact
    encoding scheme, which reduces storage space for datasets with small
    numbers of alleles. Currently, the :attr:`.GenotypeEncoding.ONE_BIT`
    encoding is supported, which provides 8-fold compression of biallelic,
    non-missing data. An error is raised if an encoding that does not support
    the range of values present in a given dataset is provided.

    The second option for reducing the RAM footprint of this function is to
    use the ``mmap_temp_dir`` parameter. This allows the genotype data to be
    cached on file, transparently using the operating system's virtual memory
    subsystem to swap the data in and out. This can work well if the encoded
    genotype matrix *almost* fits in RAM and fast local storage is available.
    However, if the size of the encoded genotype matrix is more than, say,
    twice the available RAM it is unlikely that this function will complete
    in a reasonable time-frame. A temporary file is created in the specified
    ``mmap_temp_dir``, which is automatically deleted when the function
    completes.

    .. warning:: The ``mmap_temp_dir`` option is a silent no-op on Windows!

    :param VariantData variant_data: The :class:`VariantData` instance that
        we are generating putative ancestors from.
    :param str path: The path of the file to store the ancestor data. If
        None, the information is stored in memory and not persistent.
    :param array_like exclude_positions: A list of site positions to exclude
        for full inference. Sites with these positions will not be used to
        generate ancestors, and not used during the copying process. The list
        does not need to be in any particular order.
    :param int num_threads: The number of worker threads to use. If < 1, use
        a simpler synchronous algorithm.
    :param int genotype_encoding: The encoding to use for genotype data
        internally when generating ancestors. See the
        :class:`.GenotypeEncoding` class for the available options. Defaults
        to one byte per genotype.
    :param str mmap_temp_dir: The directory within which to create the
        temporary backing file when using mmapped memory for bulk genotype
        storage. If None (the default) allocate memory directly using the
        standard mechanism. This is an advanced option, usually only relevant
        when working with very large datasets (see above for more
        information).
    :return: The inferred ancestors stored in an :class:`AncestorData`
        instance.
    :rtype: AncestorData
    """
    with provenance.TimingAndMemory() as timing:
        variant_data._check_finalised()
        if np.any(np.isfinite(variant_data.sites_time[:])) and np.any(
            tskit.is_unknown_time(variant_data.sites_time[:])
        ):
            raise ValueError(
                "Cannot generate ancestors from a variant_data instance that mixes"
                " user-specified times with times-as-frequencies. To explicitly"
                " set an undefined time for a site, permanently excluding it"
                " from inference, set it to np.nan."
            )
        if genotype_encoding is None:
            # TODO should we provide some functionality to automatically figure
            # out what the minimum encoding is?
            genotype_encoding = constants.GenotypeEncoding.EIGHT_BIT
        generator = AncestorsGenerator(
            variant_data,
            ancestor_data_path=path,
            ancestor_data_kwargs=kwargs,
            num_threads=num_threads,
            engine=engine,
            genotype_encoding=genotype_encoding,
            mmap_temp_dir=mmap_temp_dir,
            progress_monitor=progress_monitor,
        )
        generator.add_sites(exclude_positions)
        ancestor_data = generator.run()
        for timestamp, record in variant_data.provenances():
            ancestor_data.add_provenance(timestamp, record)
    if record_provenance:
        ancestor_data.record_provenance(
            "generate_ancestors", timing.metrics.asdict()
        )
    ancestor_data.finalise()
    return ancestor_data
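

# Hedged usage sketch: reducing the memory footprint of ancestor generation
# with the one-bit genotype encoding and a memory-mapped temporary file, as
# described in the docstring above. Paths are hypothetical, and we assume
# GenotypeEncoding is exported at the package level.
def _example_generate_ancestors_low_memory():
    import tsinfer

    vdata = tsinfer.VariantData("data.vcz", ancestral_state="variant_ancestral_allele")
    ancestor_data = tsinfer.generate_ancestors(
        vdata,
        path="mydata.ancestors",
        genotype_encoding=tsinfer.GenotypeEncoding.ONE_BIT,  # biallelic, no missingness
        mmap_temp_dir="/fast/local/scratch",  # silent no-op on Windows
        num_threads=8,
    )
    return ancestor_data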


def match_ancestors(
    variant_data,
    ancestor_data,
    *,
    recombination_rate=None,
    mismatch_ratio=None,
    path_compression=True,
    num_threads=0,
    # Deliberately undocumented parameters below
    recombination=None,  # See :class:`Matcher`
    mismatch=None,  # See :class:`Matcher`
    precision=None,
    engine=constants.C_ENGINE,
    progress_monitor=None,
    extended_checks=False,
    time_units=None,
    record_provenance=True,
):
    """
    match_ancestors(variant_data, ancestor_data, *, recombination_rate=None,\
        mismatch_ratio=None, path_compression=True, num_threads=0)

    Run the ancestor matching :ref:`algorithm <sec_inference_match_ancestors>`
    on the specified :class:`VariantData` and :class:`AncestorData` instances,
    returning the resulting :class:`tskit.TreeSequence` representing the
    complete ancestry of the putative ancestors. See
    :ref:`matching ancestors & samples<sec_inference_match_ancestors_and_samples>`
    in the documentation for details of ``recombination_rate``,
    ``mismatch_ratio`` and ``path_compression``.

    :param VariantData variant_data: The :class:`VariantData` instance
        representing the input data.
    :param AncestorData ancestor_data: The :class:`AncestorData` instance
        representing the set of ancestral haplotypes for which we are finding
        a history.
    :param recombination_rate: Either a floating point value giving a constant
        rate :math:`\\rho` per unit length of genome, or an
        :class:`msprime.RateMap` object. This is used to calculate the
        probability of recombination between adjacent sites. If ``None``, all
        matching conflicts are resolved by recombination and all inference
        sites will have a single mutation (equivalent to mismatch_ratio near
        zero).
    :type recombination_rate: float, msprime.RateMap
    :param float mismatch_ratio: The probability of a mismatch relative to the
        median probability of recombination between adjacent sites: can only
        be used if a recombination rate has been set (default: ``None``
        treated as 1 if ``recombination_rate`` is set).
    :param bool path_compression: Whether to merge edges that share identical
        paths (essentially taking advantage of shared recombination
        breakpoints).
    :param int num_threads: The number of match worker threads to use. If
        this is <= 0 then a simpler sequential algorithm is used (default).
    :return: The ancestors tree sequence representing the inferred history
        of the set of ancestors.
    :rtype: tskit.TreeSequence
    """
    with provenance.TimingAndMemory() as timing:
        progress_monitor = _get_progress_monitor(
            progress_monitor, match_ancestors=True
        )
        variant_data._check_finalised()
        ancestor_data._check_finalised()
        matcher = AncestorMatcher(
            variant_data,
            ancestor_data,
            time_units=time_units,
            recombination_rate=recombination_rate,
            recombination=recombination,
            mismatch_ratio=mismatch_ratio,
            mismatch=mismatch,
            path_compression=path_compression,
            num_threads=num_threads,
            precision=precision,
            extended_checks=extended_checks,
            engine=engine,
            progress_monitor=progress_monitor,
        )
        ancestor_grouping = matcher.group_by_linesweep()
        ts = matcher.match_ancestors(ancestor_grouping)
        tables = ts.dump_tables()
        for timestamp, record in ancestor_data.provenances():
            tables.provenances.add_row(timestamp=timestamp, record=json.dumps(record))
    if record_provenance:
        record = provenance.get_provenance_dict(
            command="match_ancestors",
            resources=timing.metrics.asdict(),
            mismatch_ratio=mismatch_ratio,
            path_compression=path_compression,
            precision=precision,
            # TODO: maybe record recombination rate (which could be a RateMap)
        )
        tables.provenances.add_row(record=json.dumps(record))
    ts = tables.tree_sequence()
    return ts
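

# Hedged sketch: the three-stage pipeline run explicitly, which allows a
# mismatch ratio to be set for the sample-matching stage while (as recommended
# above) leaving ancestor matching without one. File names and the rate are
# hypothetical.
def _example_staged_pipeline():
    import tsinfer

    vdata = tsinfer.VariantData("data.vcz", ancestral_state="variant_ancestral_allele")
    ancestor_data = tsinfer.generate_ancestors(vdata)
    ancestors_ts = tsinfer.match_ancestors(vdata, ancestor_data)
    ts = tsinfer.match_samples(
        vdata, ancestors_ts, recombination_rate=1e-8, mismatch_ratio=1
    )
    return ts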


def match_ancestors_batch_init(
    work_dir,
    variant_data_path,
    ancestral_state,
    ancestor_data_path,
    min_work_per_job,
    *,
    max_num_partitions=None,
    sample_mask=None,
    site_mask=None,
    recombination_rate=None,
    mismatch_ratio=None,
    path_compression=True,
    # Deliberately undocumented parameters below
    recombination=None,  # See :class:`Matcher`
    mismatch=None,  # See :class:`Matcher`
    precision=None,
    engine=constants.C_ENGINE,
    extended_checks=False,
    time_units=None,
    record_provenance=True,
):
    """
    match_ancestors_batch_init(work_dir, variant_data_path, ancestral_state,
    ancestor_data_path, min_work_per_job, \\*, max_num_partitions=None,
    sample_mask=None, site_mask=None, recombination_rate=None,
    mismatch_ratio=None, path_compression=True)

    Initialise a batched ancestor matching job by preparing a working
    directory for the run. The job is split into groups of ancestors, with
    each group further split into partitions of ancestors if necessary.
    `work_dir` is created and details are written to `metadata.json` in
    `work_dir`. The job can then be run using
    :meth:`match_ancestors_batch_groups` and
    :meth:`match_ancestors_batch_group_partition`, then finally
    :meth:`match_ancestors_batch_group_finalise`. See
    :ref:`large scale inference<sec_large_scale>` for more details about how
    these methods work together. See :meth:`match_ancestors` for details on
    ancestor matching.

    :param str work_dir: The directory in which to store the working files.
    :param str variant_data_path: The input dataset in
        `VCF Zarr <https://github.com/sgkit-dev/vcf-zarr-spec>`_ format. Path
        to the Zarr dataset saved on disk. See :class:`VariantData`.
    :param Union(array, str) ancestral_state: A numpy array of strings
        specifying the ancestral states (alleles) used in inference. This must
        be the same length as the number of unmasked sites in the dataset.
        Alternatively, a single string can be provided, giving the name of an
        array in the input dataset which contains the ancestral states.
        Unknown ancestral states can be specified using "N". Any ancestral
        states which do not match any of the known alleles at that site will
        be tallied, and a warning issued summarizing the unknown ancestral
        states.
    :param str ancestor_data_path: The path to the file containing the
        ancestors generated by :meth:`generate_ancestors`.
    :param int min_work_per_job: The minimum amount of work (as a count of
        genotypes) to allocate to a single parallel job. If the amount of work
        in a group of ancestors exceeds this level it will be broken up into
        parallel partitions, subject to the constraint of
        `max_num_partitions`.
    :param int max_num_partitions: The maximum number of partitions to split a
        group of ancestors into. Useful for limiting the number of jobs in a
        workflow to avoid job overhead. Defaults to 1000.
    :param Union(array, str) sample_mask: A numpy array of booleans specifying
        which samples to mask out (exclude) from the dataset. Alternatively, a
        string can be provided, giving the name of an array in the input
        dataset which contains the sample mask. If ``None`` (default), all
        samples are included.
    :param Union(array, str) site_mask: A numpy array of booleans specifying
        which sites to mask out (exclude) from the dataset. Alternatively, a
        string can be provided, giving the name of an array in the input
        dataset which contains the site mask. If ``None`` (default), all sites
        are included.
    :param recombination_rate: Either a floating point value giving a constant
        rate :math:`\\rho` per unit length of genome, or an
        :class:`msprime.RateMap` object. This is used to calculate the
        probability of recombination between adjacent sites. If ``None``, all
        matching conflicts are resolved by recombination and all inference
        sites will have a single mutation (equivalent to mismatch_ratio near
        zero).
    :type recombination_rate: float, msprime.RateMap
    :param float mismatch_ratio: The probability of a mismatch relative to the
        median probability of recombination between adjacent sites: can only
        be used if a recombination rate has been set (default: ``None``
        treated as 1 if ``recombination_rate`` is set).
    :param bool path_compression: Whether to merge edges that share identical
        paths (essentially taking advantage of shared recombination
        breakpoints).
    :return: A dictionary of the job metadata, as written to `metadata.json`
        in `work_dir`. `ancestor_grouping` in this dict contains the grouping
        of ancestors into groups and should be used to guide calling
        :meth:`match_ancestors_batch_groups` and
        :meth:`match_ancestors_batch_group_partition`.
    :rtype: dict
    """
    if max_num_partitions is None:
        max_num_partitions = 1000
    start_time = time_.perf_counter()
    work_dir = pathlib.Path(work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)
    ancestors = formats.AncestorData.load(ancestor_data_path)
    variant_data = formats.VariantData(
        variant_data_path,
        ancestral_state=ancestral_state,
        sample_mask=sample_mask,
        site_mask=site_mask,
    )
    ancestors._check_finalised()
    variant_data._check_finalised()
    matcher = AncestorMatcher(
        variant_data,
        ancestors,
    )
    ancestor_grouping = []
    ancestor_lengths = ancestors.ancestors_length
    for group_index, group_ancestors in matcher.group_by_linesweep().items():
        # Make ancestor_ids JSON serialisable
        group_ancestors = list(map(int, group_ancestors))
        # The first group is trivial so never partition
        if group_index == 0:
            partitions = [group_ancestors]
        else:
            total_work = sum(
                ancestor_lengths[ancestor] for ancestor in group_ancestors
            )
            partition_count = math.ceil(total_work / min_work_per_job)
            if partition_count > max_num_partitions:
                partition_count = max_num_partitions
            # Partition into roughly equal sized bins (by work)
            sorted_ancestors = sorted(
                group_ancestors, key=lambda x: ancestor_lengths[x], reverse=True
            )
            # Use greedy bin packing - place each ancestor in the bin with
            # lowest total length
            heap = [(0, []) for _ in range(partition_count)]
            for ancestor in sorted_ancestors:
                sum_len, partition = heapq.heappop(heap)
                partition.append(ancestor)
                sum_len += ancestor_lengths[ancestor]
                heapq.heappush(heap, (sum_len, partition))
            partitions = [
                sorted(partition) for sum_len, partition in heap if sum_len > 0
            ]
        if len(partitions) > 1:
            group_dir = work_dir / f"group_{group_index}"
            group_dir.mkdir()
        group = {
            "ancestors": group_ancestors,
            "partitions": partitions if len(partitions) > 1 else None,
        }
        ancestor_grouping.append(group)
    metadata = {
        "variant_data_path": str(variant_data_path),
        "ancestral_state": ancestral_state,
        "ancestor_data_path": str(ancestor_data_path),
        "sample_mask": sample_mask,
        "site_mask": site_mask,
        "recombination_rate": recombination_rate,
        "mismatch_ratio": mismatch_ratio,
        "path_compression": path_compression,
        "recombination": recombination,
        "mismatch": mismatch,
        "precision": precision,
        "engine": engine,
        "extended_checks": extended_checks,
        "time_units": time_units,
        "record_provenance": record_provenance,
        "ancestor_grouping": ancestor_grouping,
        "start_time": start_time,
    }
    metadata_path = work_dir / "metadata.json"
    metadata_path.write_text(json.dumps(metadata))
    return metadata


def initialize_ancestor_matcher(metadata, ancestors_ts=None, **kwargs):
    variant_data = formats.VariantData(
        metadata["variant_data_path"],
        ancestral_state=metadata["ancestral_state"],
        sample_mask=metadata["sample_mask"],
        site_mask=metadata["site_mask"],
    )
    ancestors = formats.AncestorData.load(metadata["ancestor_data_path"])
    variant_data._check_finalised()
    ancestors._check_finalised()
    return AncestorMatcher(
        variant_data,
        ancestors,
        ancestors_ts=ancestors_ts,
        time_units=metadata["time_units"],
        recombination_rate=metadata["recombination_rate"],
        recombination=metadata["recombination"],
        mismatch_ratio=metadata["mismatch_ratio"],
        mismatch=metadata["mismatch"],
        path_compression=metadata["path_compression"],
        precision=metadata["precision"],
        extended_checks=metadata["extended_checks"],
        engine=metadata["engine"],
        **kwargs,
    )


def match_ancestors_batch_groups(
    work_dir, group_index_start, group_index_end, num_threads=0
):
    """
    match_ancestors_batch_groups(work_dir, group_index_start,
    group_index_end, num_threads=0)

    Match a set of ancestor groups from `group_index_start` (inclusive) to
    `group_index_end` (exclusive) in a batched ancestor matching job. See
    :ref:`large scale inference<sec_large_scale>` for more details.

    A tree sequence file for `group_index_start - 1` must exist in `work_dir`,
    unless `group_index_start` is 0. After matching, the tree sequence for
    `group_index_end - 1` is written to `work_dir`.

    :param str work_dir: The working directory for the batch job, as written
        by :meth:`match_ancestors_batch_init`.
    :param int group_index_start: The first group index to match.
    :param int group_index_end: The group index to stop matching at.
    :param int num_threads: The number of worker threads to use. If this is
        <= 1 then match sequentially.
    :return: The tree sequence representing the inferred ancestors for the
        last group matched.
    :rtype: tskit.TreeSequence
    """
    metadata_path = os.path.join(work_dir, "metadata.json")
    with open(metadata_path) as f:
        metadata = json.load(f)
    if group_index_start >= len(metadata["ancestor_grouping"]) or group_index_start < 0:
        raise ValueError(f"Group {group_index_start} is out of range")
    if group_index_end > len(metadata["ancestor_grouping"]) or group_index_end < 1:
        raise ValueError(f"Group {group_index_end} is out of range")
    if group_index_end <= group_index_start:
        raise ValueError("Group index end must be greater than start")
    with provenance.TimingAndMemory() as timing:
        if group_index_start == 0:
            ancestors_ts = None
        else:
            ancestors_ts = tskit.load(
                os.path.join(work_dir, f"ancestors_{group_index_start-1}.trees")
            )
        matcher = initialize_ancestor_matcher(
            metadata, ancestors_ts, num_threads=num_threads
        )
        ts = matcher.match_ancestors(
            {
                group_index: metadata["ancestor_grouping"][group_index]["ancestors"]
                for group_index in range(group_index_start, group_index_end)
            }
        )
    path = os.path.join(work_dir, f"ancestors_{group_index_end-1}.trees")
    logger.info(f"Dumping to {path}")
    ts.dump(path)
    with open(path + ".resources", "w") as f:
        f.write(json.dumps(timing.metrics.asdict()))
    return ts


def match_ancestors_batch_group_partition(work_dir, group_index, partition_index):
    """
    match_ancestors_batch_group_partition(work_dir, group_index, partition_index)

    Match a single partition of ancestors from a group in a batched ancestor
    matching job. See :ref:`large scale inference<sec_large_scale>` for more
    details. The tree sequence for the preceding group must exist in
    `work_dir`. After matching, the results for the partition are written to
    `work_dir`. Once all partitions for a group have been matched, the group
    can be finalised using :meth:`match_ancestors_batch_group_finalise`. The
    number of partitions in a group is recorded in `metadata.json` in the
    work dir under the `ancestor_grouping` key. This method uses a single
    thread.

    :param str work_dir: The working directory for the batch job, as written
        by :meth:`match_ancestors_batch_init`.
    :param int group_index: The group index that contains the partition to
        match.
    :param int partition_index: The partition index to match. Must be less
        than the number of partitions in the batch job metadata for this
        group.
    """
    start_time = time_.perf_counter()
    metadata_path = os.path.join(work_dir, "metadata.json")
    with open(metadata_path) as f:
        metadata = json.load(f)
    group = metadata["ancestor_grouping"][group_index]
    if group["partitions"] is None:
        raise ValueError(f"Group {group_index} has no partitions")
    if partition_index >= len(group["partitions"]) or partition_index < 0:
        raise ValueError(f"Partition {partition_index} is out of range")
    with provenance.TimingAndMemory() as timing:
        ancestors_ts = tskit.load(
            os.path.join(work_dir, f"ancestors_{group_index-1}.trees")
        )
        matcher = initialize_ancestor_matcher(metadata, ancestors_ts)
        ancestors_to_match = group["partitions"][partition_index]
        results = matcher.match_partition(
            ancestors_to_match, group_index, partition_index
        )
    partition_path = os.path.join(
        work_dir, f"group_{group_index}", f"partition_{partition_index}.pkl"
    )
    logger.info(f"Dumping to {partition_path}")
    with open(partition_path, "wb") as f:
        pickle.dump((start_time, timing.metrics, ancestors_to_match, results), f)


def match_ancestors_batch_group_finalise(work_dir, group_index):
    """
    match_ancestors_batch_group_finalise(work_dir, group_index)

    Finalise a group of partitioned ancestors in a batched ancestor matching
    job. See :ref:`large scale inference<sec_large_scale>` for more details.
    The tree sequence for the preceding group must exist in `work_dir`, along
    with the results for all partitions in this group. Writes the tree
    sequence for the group to `work_dir`.

    :param str work_dir: The working directory for the batch job, as written
        by :meth:`match_ancestors_batch_init`.
    :param int group_index: The group index to finalise.
    :return: The tree sequence representing the inferred ancestors for the
        group.
    :rtype: tskit.TreeSequence
    """
    with provenance.TimingAndMemory() as timing:
        metadata_path = os.path.join(work_dir, "metadata.json")
        with open(metadata_path) as f:
            metadata = json.load(f)
        group = metadata["ancestor_grouping"][group_index]
        ancestors_ts = tskit.load(
            os.path.join(work_dir, f"ancestors_{group_index-1}.trees")
        )
        matcher = initialize_ancestor_matcher(metadata, ancestors_ts)
        logger.info(
            f"Finalising group {group_index}, loading "
            f"{len(group['partitions'])} partitions"
        )
        start_times = []
        timings = []
        results = {}
        for partition_index in range(len(group["partitions"])):
            partition_path = os.path.join(
                work_dir, f"group_{group_index}", f"partition_{partition_index}.pkl"
            )
            with open(partition_path, "rb") as f:
                start_time, part_timing, ancestors, result = pickle.load(f)
                start_times.append(start_time)
                for ancestor, r in zip(ancestors, result):
                    results[ancestor] = r
                timings.append(part_timing)
        results = list(map(operator.itemgetter(1), sorted(results.items())))
        ts = matcher.finalise_group(group, results, group_index)
        path = os.path.join(work_dir, f"ancestors_{group_index}.trees")
        ts.dump(path)
    combined_metrics = provenance.ResourceMetrics.combine(timings + [timing.metrics])
    combined_metrics.elapsed_time = time_.perf_counter() - min(start_times)
    with open(path + ".resources", "w") as f:
        f.write(json.dumps(combined_metrics.asdict()))
    return ts


def match_ancestors_batch_finalise(work_dir):
    """
    match_ancestors_batch_finalise(work_dir)

    Finalise a batched ancestor matching job. This method should be called
    after all groups have been matched, either by
    :meth:`match_ancestors_batch_groups` or
    :meth:`match_ancestors_batch_group_finalise`. Returns the final ancestors
    tree sequence for the batch job. `work_dir` is retained and not deleted.

    :param str work_dir: The working directory for the batch job, as written
        by :meth:`match_ancestors_batch_init`.
    :return: The tree sequence representing the inferred ancestors for the
        batch job.
    :rtype: tskit.TreeSequence
    """
    with provenance.TimingAndMemory() as timing:
        metadata_path = os.path.join(work_dir, "metadata.json")
        with open(metadata_path) as f:
            metadata = json.load(f)
        ancestor_data = formats.AncestorData.load(metadata["ancestor_data_path"])
        final_group = len(metadata["ancestor_grouping"]) - 1
        ts = tskit.load(os.path.join(work_dir, f"ancestors_{final_group}.trees"))
        tables = ts.dump_tables()
        for timestamp, record in ancestor_data.provenances():
            tables.provenances.add_row(timestamp=timestamp, record=json.dumps(record))
    if metadata["record_provenance"]:
        # Find all the .resources files and combine them
        resources = []
        for root, _, files in os.walk(work_dir):
            for file in files:
                if file.endswith(".resources"):
                    with open(os.path.join(root, file)) as f:
                        resource = provenance.ResourceMetrics(**json.load(f))
                    resources.append(resource)
        combined_resources = provenance.ResourceMetrics.combine(
            resources + [timing.metrics]
        )
        combined_resources.elapsed_time = (
            time_.perf_counter() - metadata["start_time"]
        )
        record = provenance.get_provenance_dict(
            command="match_ancestors",
            resources=combined_resources.asdict(),
            mismatch_ratio=metadata["mismatch_ratio"],
            path_compression=metadata["path_compression"],
            precision=metadata["precision"],
        )
        tables.provenances.add_row(record=json.dumps(record))
    ts = tables.tree_sequence()
    return ts
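

# Hedged sketch of a complete batched ancestor-matching workflow, as driven by
# the functions above. In practice each call in the loops would be dispatched
# as a separate cluster job; the paths and work budget are hypothetical.
def _example_match_ancestors_batch(work_dir="batch_work"):
    metadata = match_ancestors_batch_init(
        work_dir,
        "data.vcz",
        "variant_ancestral_allele",
        "mydata.ancestors",
        min_work_per_job=100_000_000,
    )
    for group_index, group in enumerate(metadata["ancestor_grouping"]):
        if group["partitions"] is None:
            # Unpartitioned group: match it directly
            match_ancestors_batch_groups(work_dir, group_index, group_index + 1)
        else:
            # Partitioned group: match each partition, then finalise the group
            for partition_index in range(len(group["partitions"])):
                match_ancestors_batch_group_partition(
                    work_dir, group_index, partition_index
                )
            match_ancestors_batch_group_finalise(work_dir, group_index)
    return match_ancestors_batch_finalise(work_dir)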


def augment_ancestors(
    variant_data,
    ancestors_ts,
    indexes,
    *,
    recombination_rate=None,
    mismatch_ratio=None,
    path_compression=True,
    num_threads=0,
    # Deliberately undocumented parameters below
    recombination=None,  # See :class:`Matcher`
    mismatch=None,  # See :class:`Matcher`
    precision=None,
    extended_checks=False,
    engine=constants.C_ENGINE,
    progress_monitor=None,
    record_provenance=True,
):
    """
    augment_ancestors(variant_data, ancestors_ts, indexes, *,\
        recombination_rate=None, mismatch_ratio=None, path_compression=True,\
        num_threads=0)

    Runs the sample matching :ref:`algorithm <sec_inference_match_samples>`
    on the specified :class:`VariantData` instance and ancestors tree
    sequence, for the specified subset of sample indexes, returning the
    :class:`tskit.TreeSequence` instance including these samples. This tree
    sequence can then be used as an ancestors tree sequence for subsequent
    matching against all samples. See
    :ref:`matching ancestors & samples<sec_inference_match_ancestors_and_samples>`
    in the documentation for details of ``recombination_rate``,
    ``mismatch_ratio`` and ``path_compression``.

    :param VariantData variant_data: The :class:`VariantData` instance
        representing the input data.
    :param tskit.TreeSequence ancestors_ts: The :class:`tskit.TreeSequence`
        instance representing the inferred history among ancestral haplotypes.
    :param array indexes: The sample indexes to insert into the ancestors
        tree sequence, in increasing order.
    :param recombination_rate: Either a floating point value giving a constant
        rate :math:`\\rho` per unit length of genome, or an
        :class:`msprime.RateMap` object. This is used to calculate the
        probability of recombination between adjacent sites. If ``None``, all
        matching conflicts are resolved by recombination and all inference
        sites will have a single mutation (equivalent to mismatch_ratio near
        zero).
    :type recombination_rate: float, msprime.RateMap
    :param float mismatch_ratio: The probability of a mismatch relative to the
        median probability of recombination between adjacent sites: can only
        be used if a recombination rate has been set (default: ``None``
        treated as 1 if ``recombination_rate`` is set).
    :param bool path_compression: Whether to merge edges that share identical
        paths (essentially taking advantage of shared recombination
        breakpoints).
    :param int num_threads: The number of match worker threads to use. If
        this is <= 0 then a simpler sequential algorithm is used (default).
    :return: The specified ancestors tree sequence augmented with copying
        paths for the specified samples.
    :rtype: tskit.TreeSequence
    """
    variant_data._check_finalised()
    progress_monitor = _get_progress_monitor(progress_monitor, augment_ancestors=True)
    manager = SampleMatcher(
        variant_data,
        ancestors_ts,
        recombination_rate=recombination_rate,
        mismatch_ratio=mismatch_ratio,
        recombination=recombination,
        mismatch=mismatch,
        path_compression=path_compression,
        num_threads=num_threads,
        precision=precision,
        extended_checks=extended_checks,
        engine=engine,
        progress_monitor=progress_monitor,
    )
    sample_indexes = check_sample_indexes(variant_data, indexes)
    sample_times = np.zeros(
        len(sample_indexes), dtype=variant_data.individuals_time.dtype
    )
    manager.match_samples(sample_indexes, sample_times)
    ts = manager.get_augmented_ancestors_tree_sequence(sample_indexes)
    if record_provenance:
        tables = ts.dump_tables()
        record = provenance.get_provenance_dict(
            command="augment_ancestors",
            mismatch_ratio=mismatch_ratio,
            path_compression=path_compression,
            precision=precision,
        )
        tables.provenances.add_row(record=json.dumps(record))
        ts = tables.tree_sequence()
    return ts
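

# Hedged sketch: augmenting an ancestors tree sequence with a subset of sample
# paths before matching all samples against the augmented ancestors. The
# choice of every tenth sample is arbitrary and illustrative.
def _example_augment_then_match():
    import tsinfer

    vdata = tsinfer.VariantData("data.vcz", ancestral_state="variant_ancestral_allele")
    ancestor_data = tsinfer.generate_ancestors(vdata)
    ancestors_ts = tsinfer.match_ancestors(vdata, ancestor_data)
    subset = np.arange(0, vdata.num_samples, 10)  # must be in increasing order
    augmented_ts = tsinfer.augment_ancestors(vdata, ancestors_ts, subset)
    return tsinfer.match_samples(vdata, augmented_ts)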


@dataclasses.dataclass
class SampleBatchWorkDescriptor:
    variant_data_path: str
    ancestral_state: str
    sample_mask: np.ndarray
    site_mask: np.ndarray
    ancestor_ts_path: str
    recombination_rate: float
    mismatch_ratio: float
    path_compression: bool
    indexes: np.ndarray
    recombination: float
    mismatch: float
    precision: int
    engine: str
    extended_checks: bool
    post_process: bool
    force_sample_times: bool
    overlay_non_inference_sites: bool
    record_provenance: bool
    sample_indexes: list
    sample_times: list
    num_samples_per_partition: int
    num_partitions: int
    start_time: float

    def common_params(self) -> dict:
        return {
            "recombination_rate": self.recombination_rate,
            "mismatch_ratio": self.mismatch_ratio,
            "path_compression": self.path_compression,
            "recombination": self.recombination,
            "mismatch": self.mismatch,
            "precision": self.precision,
            "engine": self.engine,
            "extended_checks": self.extended_checks,
        }

    def save(self, path):
        def numpy_encoder(obj):
            if isinstance(obj, np.ndarray):
                return {
                    "__numpy__": True,
                    "data": obj.tolist(),
                    "dtype": str(obj.dtype),
                }
            return obj

        with open(path, "w") as f:
            json.dump(dataclasses.asdict(self), f, indent=2, default=numpy_encoder)

    @classmethod
    def load(cls, path):
        def numpy_decoder(dct):
            if "__numpy__" in dct:
                return np.array(dct["data"], dtype=dct["dtype"])
            return dct

        with open(path) as f:
            wd_dict = json.load(f, object_hook=numpy_decoder)
        return cls(**wd_dict)


def load_variant_data_and_ancestors_ts(wd: SampleBatchWorkDescriptor):
    variant_data = formats.VariantData(
        wd.variant_data_path,
        wd.ancestral_state,
        sample_mask=wd.sample_mask,
        site_mask=wd.site_mask,
    )
    variant_data._check_finalised()
    ancestor_ts = tskit.load(wd.ancestor_ts_path)
    matcher = SampleMatcher(
        variant_data,
        ancestor_ts,
        **wd.common_params(),
    )
    return variant_data, ancestor_ts, matcher
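

# Illustrative sketch: the numpy-aware JSON round-trip used by
# SampleBatchWorkDescriptor.save/load. Arrays are tagged with "__numpy__" and
# their dtype so they can be rebuilt exactly on load.
def _example_numpy_json_roundtrip():
    mask = np.array([True, False, True])
    encoded = json.dumps(
        {"__numpy__": True, "data": mask.tolist(), "dtype": str(mask.dtype)}
    )
    decoded = json.loads(
        encoded,
        object_hook=lambda d: (
            np.array(d["data"], dtype=d["dtype"]) if "__numpy__" in d else d
        ),
    )
    assert np.array_equal(decoded, mask) and decoded.dtype == mask.dtype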


def match_samples_batch_init(
    work_dir,
    variant_data_path,
    ancestral_state,
    ancestor_ts_path,
    min_work_per_job,
    *,
    sample_mask=None,
    site_mask=None,
    recombination_rate=None,
    mismatch_ratio=None,
    path_compression=True,
    indexes=None,
    post_process=None,
    force_sample_times=False,
    overlay_non_inference_sites=None,
    # Deliberately undocumented parameters below
    recombination=None,  # See :class:`Matcher`
    mismatch=None,  # See :class:`Matcher`
    precision=None,
    extended_checks=False,
    engine=constants.C_ENGINE,
    record_provenance=True,
):
    """
    match_samples_batch_init(work_dir, variant_data_path, ancestral_state,
    ancestor_ts_path, min_work_per_job, \\*, sample_mask=None, site_mask=None,
    recombination_rate=None, mismatch_ratio=None, path_compression=True,
    indexes=None, post_process=None, force_sample_times=False,
    overlay_non_inference_sites=None)

    Initialise a batched sample matching job. Creates `work_dir` and writes
    job details to `metadata.json`. The job can then be run using parallel
    calls to :meth:`match_samples_batch_partition` and, once those are
    complete, finally :meth:`match_samples_batch_finalise`. The
    `num_partitions` key in the metadata dict contains the number of
    partitions that need to be processed.

    :param str work_dir: The directory in which to store the working files.
    :param str variant_data_path: The input dataset in
        `VCF Zarr <https://github.com/sgkit-dev/vcf-zarr-spec>`_ format. Path
        to the Zarr dataset saved on disk. See :class:`VariantData`.
    :param Union(array, str) ancestral_state: A numpy array of strings
        specifying the ancestral states (alleles) used in inference. This must
        be the same length as the number of unmasked sites in the dataset.
        Alternatively, a single string can be provided, giving the name of an
        array in the input dataset which contains the ancestral states.
        Unknown ancestral states can be specified using "N". Any ancestral
        states which do not match any of the known alleles at that site will
        be tallied, and a warning issued summarizing the unknown ancestral
        states.
    :param str ancestor_ts_path: The path to the tree sequence file containing
        the ancestors generated by :meth:`match_ancestors_batch_finalise`, or
        :meth:`match_ancestors`.
    :param int min_work_per_job: The minimum amount of work (as a count of
        genotypes) to allocate to a single parallel job. If the amount of work
        in a group of samples exceeds this level it will be broken up into
        parallel partitions.
    :param Union(array, str) sample_mask: A numpy array of booleans specifying
        which samples to mask out (exclude) from the dataset. Alternatively, a
        string can be provided, giving the name of an array in the input
        dataset which contains the sample mask. If ``None`` (default), all
        samples are included.
    :param Union(array, str) site_mask: A numpy array of booleans specifying
        which sites to mask out (exclude) from the dataset. Alternatively, a
        string can be provided, giving the name of an array in the input
        dataset which contains the site mask. If ``None`` (default), all sites
        are included.
    :param recombination_rate: Either a floating point value giving a constant
        rate :math:`\\rho` per unit length of genome, or an
        :class:`msprime.RateMap` object. This is used to calculate the
        probability of recombination between adjacent sites. If ``None``, all
        matching conflicts are resolved by recombination and all inference
        sites will have a single mutation (equivalent to mismatch_ratio near
        zero).
    :type recombination_rate: float, msprime.RateMap
    :param float mismatch_ratio: The probability of a mismatch relative to the
        median probability of recombination between adjacent sites: can only
        be used if a recombination rate has been set (default: ``None``
        treated as 1 if ``recombination_rate`` is set).
    :param bool path_compression: Whether to merge edges that share identical
        paths (essentially taking advantage of shared recombination
        breakpoints).
    :param indexes: The sample indexes to match. If ``None`` (default), all
        samples are matched.
    :type indexes: arraylike
    :param bool post_process: Whether to run the :func:`post_process` method
        on the tree sequence which, among other things, removes ancestral
        material that does not end up in the current samples (if not
        specified, defaults to ``True``).
    :param bool force_sample_times: After matching, should an attempt be made
        to adjust the time of "historical samples" (those associated with an
        individual having a non-zero time) such that the sample nodes in the
        tree sequence appear at the time of the individual with which they
        are associated.
    :param bool overlay_non_inference_sites: If True, sites that were included
        in the selected sites, but were not used for inference, will be added
        to the tree sequence by mapping their mutations over the inferred
        topology. Defaults to True.
    :return: A dictionary of the job metadata, as written to `metadata.json`
        in `work_dir`.
    """
    # Convert working_dir to pathlib.Path
    work_dir = pathlib.Path(work_dir)
    # Create work dir
    work_dir.mkdir(parents=True, exist_ok=True)
    wd = SampleBatchWorkDescriptor(
        variant_data_path=str(variant_data_path),
        ancestral_state=ancestral_state,
        sample_mask=sample_mask,
        site_mask=site_mask,
        ancestor_ts_path=str(ancestor_ts_path),
        recombination_rate=recombination_rate,
        mismatch_ratio=mismatch_ratio,
        path_compression=path_compression,
        indexes=indexes,
        recombination=recombination,
        mismatch=mismatch,
        precision=precision,
        engine=engine,
        extended_checks=extended_checks,
        post_process=post_process,
        force_sample_times=force_sample_times,
        overlay_non_inference_sites=overlay_non_inference_sites,
        record_provenance=record_provenance,
        sample_indexes=[],
        sample_times=[],
        num_samples_per_partition=0,
        num_partitions=0,
        start_time=time_.perf_counter(),
    )
    variant_data, ancestor_ts, matcher = load_variant_data_and_ancestors_ts(wd)
    sample_indexes = check_sample_indexes(variant_data, indexes).tolist()
    sample_times = np.zeros(
        len(sample_indexes), dtype=variant_data.individuals_time.dtype
    )
    if force_sample_times:
        individuals = variant_data.samples_individual[:][sample_indexes]
        # By construction all samples in a sample data file have an
        # individual: but check anyway
        assert np.all(individuals >= 0)
        sample_times = variant_data.individuals_time[:][individuals]
        # Here we might want to re-order sample_indexes and sample_times
        # so that any historical ones come first, and we bomb out early if
        # they conflict, but that would mean re-ordering the sample nodes in
        # the final ts, and we sometimes assume they are in the same order as
        # in the file
    sample_times = sample_times.tolist()
    wd.sample_indexes = sample_indexes
    wd.sample_times = sample_times
    # Each sample contributes num_sites genotypes of work, so the number of
    # samples per partition is the work budget divided by the number of sites
    num_samples_per_partition = max(
        1, math.ceil(min_work_per_job / variant_data.num_sites)
    )
    wd.num_samples_per_partition = num_samples_per_partition
    wd.num_partitions = math.ceil(len(sample_indexes) / num_samples_per_partition)
    wd_path = work_dir / "metadata.json"
    wd.save(wd_path)
    return wd
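

# Worked sketch of the partitioning arithmetic above: with a work budget of
# 100 million genotypes per job and 500,000 sites, each partition holds
# ceil(100_000_000 / 500_000) = 200 samples, so 10,000 samples need 50 jobs.
# The numbers are hypothetical.
def _example_partition_arithmetic():
    min_work_per_job, num_sites, num_samples = 100_000_000, 500_000, 10_000
    samples_per_partition = max(1, math.ceil(min_work_per_job / num_sites))
    num_partitions = math.ceil(num_samples / samples_per_partition)
    assert (samples_per_partition, num_partitions) == (200, 50)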


def match_samples_batch_partition(work_dir, partition_index):
    """
    match_samples_batch_partition(work_dir, partition_index)

    Match a single partition of samples in a batched sample matching job. See
    :ref:`large scale inference<sec_large_scale>` for more details. Match
    data for the partition is written to `work_dir`. Uses a single thread to
    perform matching.

    :param str work_dir: The working directory for the batch job, as written
        by :meth:`match_samples_batch_init`.
    :param int partition_index: The partition index to match. Must be less
        than the number of partitions in the batch job metadata key
        `num_partitions`.
    """
    wd_path = pathlib.Path(work_dir) / "metadata.json"
    wd = SampleBatchWorkDescriptor.load(wd_path)
    if partition_index >= wd.num_partitions or partition_index < 0:
        raise ValueError(f"Partition {partition_index} is out of range")
    sample_indexes = wd.sample_indexes
    sample_times = wd.sample_times
    partition_slice = slice(
        partition_index * wd.num_samples_per_partition,
        min(
            (partition_index + 1) * wd.num_samples_per_partition,
            len(sample_indexes),
        ),
    )
    logger.info(
        f"Matching partition {partition_index} with {partition_slice.start} to"
        f" {partition_slice.stop} of {len(sample_indexes)} samples"
    )
    with provenance.TimingAndMemory() as timing:
        variant_data, ancestor_ts, matcher = load_variant_data_and_ancestors_ts(wd)
        results = matcher.match_samples(
            sample_indexes, sample_times, slice_=partition_slice
        )
    path = os.path.join(work_dir, f"partition_{partition_index}.pkl")
    logger.info(f"Dumping to {path}")
    with open(path, "wb") as f:
        pickle.dump((timing.metrics, results), f)


def match_samples_batch_finalise(work_dir):
    """
    match_samples_batch_finalise(work_dir)

    Finalise a batched sample matching job. This method should be called
    after all partitions have been matched by
    :meth:`match_samples_batch_partition`. Returns the final tree sequence
    for the batch job. `work_dir` is retained and not deleted.

    :param str work_dir: The working directory for the batch job, as written
        by :meth:`match_samples_batch_init`.
    :return: The tree sequence representing the inferred history of the
        samples.
    :rtype: tskit.TreeSequence
    """
    with provenance.TimingAndMemory() as timing:
        wd_path = os.path.join(work_dir, "metadata.json")
        wd = SampleBatchWorkDescriptor.load(wd_path)
        variant_data, ancestor_ts, matcher = load_variant_data_and_ancestors_ts(wd)
        results = []
        timings = []
        for partition_index in range(wd.num_partitions):
            path = os.path.join(work_dir, f"partition_{partition_index}.pkl")
            with open(path, "rb") as f:
                part_timing, part_results = pickle.load(f)
                results.extend(part_results)
                timings.append(part_timing)
        ts = match_samples(
            variant_data,
            ancestor_ts,
            indexes=wd.indexes,
            post_process=wd.post_process,
            force_sample_times=wd.force_sample_times,
            record_provenance=wd.record_provenance,
            overlay_non_inference_sites=wd.overlay_non_inference_sites,
            results=results,
            **wd.common_params(),
        )
    # Rewrite the last provenance with the correct info
    start_time = wd.start_time
    combined_metrics = provenance.ResourceMetrics.combine(timings + [timing.metrics])
    combined_metrics.elapsed_time = time_.perf_counter() - start_time
    tables = ts.dump_tables()
    prov = tables.provenances[-1]
    record = json.loads(prov.record)
    record["resources"] = combined_metrics.asdict()
    tables.provenances[-1] = prov.replace(record=json.dumps(record))
    return tables.tree_sequence()
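

# Hedged sketch of a complete batched sample-matching workflow, mirroring the
# batched ancestor matching above. Each partition call is independent and
# would typically run as a separate cluster job; paths and the work budget
# are hypothetical.
def _example_match_samples_batch(work_dir="batch_work"):
    wd = match_samples_batch_init(
        work_dir,
        "data.vcz",
        "variant_ancestral_allele",
        "ancestors.trees",
        min_work_per_job=100_000_000,
    )
    for partition_index in range(wd.num_partitions):
        match_samples_batch_partition(work_dir, partition_index)
    return match_samples_batch_finalise(work_dir)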
[docs] def match_samples( variant_data, ancestors_ts, *, recombination_rate=None, mismatch_ratio=None, path_compression=True, indexes=None, post_process=None, force_sample_times=False, num_threads=0, overlay_non_inference_sites=None, # Deliberately undocumented parameters below recombination=None, # See :class:`Matcher` mismatch=None, # See :class:`Matcher` precision=None, extended_checks=False, engine=constants.C_ENGINE, progress_monitor=None, simplify=None, # deprecated record_provenance=True, results=None, ): """ match_samples(variant_data, ancestors_ts, *, recombination_rate=None,\ mismatch_ratio=None, path_compression=True, post_process=None,\ indexes=None, force_sample_times=False, num_threads=0) Runs the sample matching :ref:`algorithm <sec_inference_match_samples>` on the specified :class:`VariantData` instance and ancestors tree sequence, returning the final :class:`tskit.TreeSequence` instance containing the full inferred history for all samples and sites. See :ref:`matching ancestors & samples<sec_inference_match_ancestors_and_samples>` in the documentation for details of ``recombination_rate``, ``mismatch_ratio`` and ``path_compression``. :param VariantData variant_data: The :class:`VariantData` instance representing the input data. :param tskit.TreeSequence ancestors_ts: The :class:`tskit.TreeSequence` instance representing the inferred history among ancestral ancestral haplotypes. :param recombination_rate: Either a floating point value giving a constant rate :math:`\\rho` per unit length of genome, or an :class:`msprime.RateMap` object. This is used to calculate the probability of recombination between adjacent sites. If ``None``, all matching conflicts are resolved by recombination and all inference sites will have a single mutation (equivalent to mismatch_ratio near zero) :type recombination_rate: float, msprime.RateMap :param float mismatch_ratio: The probability of a mismatch relative to the median probability of recombination between adjacent sites: can only be used if a recombination rate has been set (default: ``None`` treated as 1 if ``recombination_rate`` is set). :param bool path_compression: Whether to merge edges that share identical paths (essentially taking advantage of shared recombination breakpoints). :param array_like indexes: An array of indexes into the variant_data file of the samples to match (in increasing order) or None for all samples. :param bool post_process: Whether to run the :func:`post_process` method on the the tree sequence which, among other things, removes ancestral material that does not end up in the current samples (if not specified, defaults to ``True``) :param bool force_sample_times: After matching, should an attempt be made to adjust the time of "historical samples" (those associated with an individual having a non-zero time) such that the sample nodes in the tree sequence appear at the time of the individual with which they are associated. :param int num_threads: The number of match worker threads to use. If this is <= 0 then a simpler sequential algorithm is used (default). :param bool simplify: Treated as an alias for ``post_process``, deprecated but currently retained for backwards compatibility if set to ``False``. :param bool overlay_non_inference_sites: If True, sites that were included in the selected sites, but were not used for inference, will be added to the tree sequence by mapping their mutations over the inferred topology. Defaults to True. :return: The tree sequence representing the inferred history of the sample. 
    :rtype: tskit.TreeSequence
    """
    simplify_only = False  # if true, carry out "old" (deprecated) simplify behaviour
    if simplify is None:
        if post_process is None:
            post_process = True
    else:
        if post_process is not None:
            raise ValueError("Can't specify both `simplify` and `post_process`")
        else:
            if simplify:
                logger.warning(
                    "The `simplify` parameter is deprecated in favour of "
                    "`post_process`"
                )
                simplify_only = True
                post_process = True
            else:
                post_process = False
    if overlay_non_inference_sites is None:
        overlay_non_inference_sites = True

    with provenance.TimingAndMemory() as timing:
        variant_data._check_finalised()
        progress_monitor = _get_progress_monitor(progress_monitor, match_samples=True)
        manager = SampleMatcher(
            variant_data,
            ancestors_ts,
            recombination_rate=recombination_rate,
            mismatch_ratio=mismatch_ratio,
            recombination=recombination,
            mismatch=mismatch,
            path_compression=path_compression,
            num_threads=num_threads,
            precision=precision,
            extended_checks=extended_checks,
            engine=engine,
            progress_monitor=progress_monitor,
        )
        sample_indexes = check_sample_indexes(variant_data, indexes)
        sample_times = np.zeros(
            len(sample_indexes), dtype=variant_data.individuals_time.dtype
        )
        if force_sample_times:
            individuals = variant_data.samples_individual[:][sample_indexes]
            # By construction all samples in a sample data file have an
            # individual: but check anyway
            assert np.all(individuals >= 0)
            sample_times = variant_data.individuals_time[:][individuals]
        # Here we might want to re-order sample_indexes and sample_times
        # so that any historical ones come first, and we bomb out early
        # if they conflict. But that would mean re-ordering the sample
        # nodes in the final ts, and we sometimes assume they are in
        # the same order as in the file
        manager.match_samples(sample_indexes, sample_times, results)
        ts = manager.finalise(overlay_non_inference_sites)
        if post_process:
            ts = _post_process(
                ts, warn_if_unexpected_format=True, simplify_only=simplify_only
            )
    if record_provenance:
        tables = ts.dump_tables()
        # We don't have a source here because tree sequence files
        # don't have a UUID yet.
        record = provenance.get_provenance_dict(
            command="match_samples",
            resources=timing.metrics.asdict(),
            mismatch_ratio=mismatch_ratio,
            path_compression=path_compression,
            precision=precision,
            simplify=simplify,
            post_process=post_process,
            # TODO: maybe record recombination rate (which could be a RateMap)
        )
        tables.provenances.add_row(record=json.dumps(record))
        ts = tables.tree_sequence()
    return ts
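
# Example (illustrative sketch): in typical use, match_samples is the final
# step of the standard three-stage inference pipeline, after
# generate_ancestors and match_ancestors. The rate values below are
# placeholder assumptions, not recommendations.
#
#     import tsinfer
#
#     ancestors = tsinfer.generate_ancestors(variant_data)
#     ancestors_ts = tsinfer.match_ancestors(
#         variant_data, ancestors, recombination_rate=1e-8, mismatch_ratio=1
#     )
#     ts = tsinfer.match_samples(
#         variant_data, ancestors_ts, recombination_rate=1e-8, mismatch_ratio=1
#     )
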
def insert_missing_sites(
    variant_data, tree_sequence, *, sample_id_map=None, progress_monitor=None
):
    """
    Return a new tree sequence containing extra sites that are present in a
    :class:`VariantData` instance but are missing from a corresponding tree
    sequence. At each newly inserted site, mutations are overlaid
    parsimoniously, using :meth:`tskit.Tree.map_mutations`, such that the
    realised variation at that site corresponds to the allelic distribution
    seen in the variant_data file. Sites that have mutations overlaid in this
    way can be identified in the output tree sequence as their
    :ref:`metadata<tskit.sec_metadata_definition>` will contain a key named
    ``inference`` set to ``tsinfer.INFERENCE_PARSIMONY``. Newly inserted
    sites that do not require mutations will have this set to
    ``tsinfer.INFERENCE_NONE`` instead. Sites in ``variant_data`` that
    already exist in the tree sequence are left untouched.

    By default, sample 0 in ``variant_data`` is assumed to correspond to the
    first sample node in the input tree sequence (i.e.
    ``tree_sequence.samples()[0]``), sample 1 to the second sample node, and
    so on. If this is not the case, a map can be provided, which specifies
    the sample ids in ``variant_data`` that correspond to the sample nodes in
    the tree sequence. This also allows the use of :class:`VariantData`
    instances that contain samples in addition to those in the original tree
    sequence.

    .. note::
        Sample states observed as missing in the input ``variant_data`` need
        not correspond to samples whose nodes are actually "missing" (i.e.
        :ref:`isolated<tskit.sec_data_model_missing_data>`) in the input tree
        sequence. In this case, the allelic state of the sample in the
        returned tree sequence will be imputed to the most parsimonious
        state.

    .. note::
        If the ancestral state at a site is unknown (i.e.
        ``tskit.MISSING_DATA``), it will be inferred by parsimony. If it is
        unknown and all sample data is missing at that site, the site will be
        created with an ancestral state set to the empty string.

    :param VariantData variant_data: The :class:`VariantData` instance
        containing some sites that are not in the input tree sequence.
    :param tskit.TreeSequence tree_sequence: The input
        :class:`tskit.TreeSequence` whose sample nodes correspond to a set of
        samples in the variant_data.
    :param array sample_id_map: An array of length
        ``tree_sequence.num_samples`` specifying the indexes of samples in
        the variant_data file that correspond to sample nodes
        ``0..(num_samples-1)`` in the tree sequence. If None, assume that all
        the samples in variant_data correspond to the sample nodes in the
        tree sequence, and are in the same order.
    :return: The input tree sequence with additional sites and mutations.
    :rtype: tskit.TreeSequence
    """
    if variant_data.sequence_length != tree_sequence.sequence_length:
        raise ValueError(
            "variant_data and tree_sequence must have the same sequence length"
        )
    if sample_id_map is None:
        sample_id_map = np.arange(variant_data.num_samples)
    if len(sample_id_map) != tree_sequence.num_samples:
        raise ValueError(
            "You must specify the same number of samples in variant_data "
            "as in the tree_sequence"
        )
    progress_monitor = _get_progress_monitor(progress_monitor)
    tables = tree_sequence.dump_tables()
    trees = tree_sequence.trees()
    tree = next(trees)
    positions = variant_data.sites_position[:]
    new_sd_sites = np.where(np.isin(positions, tables.sites.position) == 0)[0]
    schema = tables.sites.metadata_schema.schema

    # Create new sites and add the mutations
    progress = progress_monitor.get("ms_extra_sites", len(new_sd_sites))
    for variant in variant_data.variants(sites=new_sd_sites, recode_ancestral=True):
        site = variant.site
        pos = site.position
        anc_state = site.ancestral_state
        anc_value = 0  # variant(recode_ancestral=True) always has 0 as the anc index
        G = variant.genotypes[sample_id_map]
        # We can't perform parsimony inference if all genotypes are missing,
        # and there's no point if all non-missing genotypes are the ancestral
        # state, so skip these cases
        if np.all(np.logical_or(G == tskit.MISSING_DATA, G == anc_value)):
            metadata = _update_site_metadata(
                site.metadata, inference_type=constants.INFERENCE_NONE
            )
            if schema is None:
                metadata = _encode_raw_metadata(metadata)
            tables.sites.add_row(
                position=pos,
                ancestral_state="" if anc_state is None else anc_state,
                metadata=metadata,
            )
        else:
            while tree.interval[1] <= pos:
                tree = next(trees)
            anc_state, mapped_mutations = tree.map_mutations(
                G, variant.alleles, ancestral_state=anc_state
            )
            metadata = _update_site_metadata(
                site.metadata, inference_type=constants.INFERENCE_PARSIMONY
            )
            if schema is None:
                metadata = _encode_raw_metadata(metadata)
            new_site_id = tables.sites.add_row(
                position=pos,
                ancestral_state=anc_state,
                metadata=metadata,
            )
            mut_map = {tskit.NULL: tskit.NULL}
            for i, mutation in enumerate(mapped_mutations):
                mut_map[i] = tables.mutations.add_row(
                    site=new_site_id,
                    node=mutation.node,
                    derived_state=mutation.derived_state,
                    parent=mut_map[mutation.parent],
                )
        progress.update()
    progress.close()
    tables.sort()
    return tables.tree_sequence()


class AncestorsGenerator:
    """
    Manages the process of building ancestors.
""" def __init__( self, variant_data, ancestor_data_path, ancestor_data_kwargs, num_threads=0, engine=constants.C_ENGINE, genotype_encoding=constants.GenotypeEncoding.EIGHT_BIT, mmap_temp_dir=None, progress_monitor=None, ): self.variant_data = variant_data self.ancestor_data_path = ancestor_data_path self.ancestor_data_kwargs = ancestor_data_kwargs self.progress_monitor = _get_progress_monitor( progress_monitor, generate_ancestors=True ) self.max_sites = variant_data.num_sites self.num_sites = 0 self.inference_site_ids = [] self.num_samples = variant_data.num_samples self.num_threads = num_threads self.mmap_temp_file = None mmap_fd = -1 genotype_matrix_size = self.max_sites * self.num_samples if genotype_encoding == constants.GenotypeEncoding.ONE_BIT: genotype_matrix_size /= 8 genotype_mem = humanize.naturalsize(genotype_matrix_size, binary=True) logging.info(f"Max encoded genotype matrix size={genotype_mem}") if mmap_temp_dir is not None: self.mmap_temp_file = tempfile.NamedTemporaryFile( dir=mmap_temp_dir, prefix="tsinfer-mmap-genotypes-" ) logging.info(f"Using mmapped {self.mmap_temp_file.name} for genotypes") mmap_fd = self.mmap_temp_file.fileno() if engine == constants.C_ENGINE: logger.debug("Using C AncestorBuilder implementation") self.ancestor_builder = _tsinfer.AncestorBuilder( self.num_samples, self.max_sites, genotype_encoding=genotype_encoding, mmap_fd=mmap_fd, ) elif engine == constants.PY_ENGINE: logger.debug("Using Python AncestorBuilder implementation") self.ancestor_builder = algorithm.AncestorBuilder( self.num_samples, self.max_sites, genotype_encoding=genotype_encoding, ) else: raise ValueError(f"Unknown engine:{engine}") def add_sites(self, exclude_positions=None): """ Add all sites that are suitable for inference into the ancestor builder (and subsequent inference), unless they are held in the specified list of excluded site positions. Suitable sites have only 2 listed alleles, one of which is defined as the ancestral_state, and where at least two samples carry the derived allele and at least one sample carries the ancestral allele. Suitable sites will be added at the time given by site.time, unless site.time is ``np.nan`` or ``tskit.UNKNOWN_TIME``. In the first case, the site will simply excluded as if it were in the list of ``excluded_positions``. In the second case, then the time associated with the site will be the frequency of the derived allele (i.e. the number of samples with the derived allele divided by the total number of samples with non-missing alleles). 
""" if exclude_positions is None: exclude_positions = set() else: exclude_positions = np.array(exclude_positions, dtype=np.float64) if len(exclude_positions.shape) != 1: raise ValueError("exclude_positions must be a 1D array of numbers") exclude_positions = set(exclude_positions) logger.info(f"Starting addition of {self.max_sites} sites") progress = self.progress_monitor.get("ga_add_sites", self.max_sites) inference_site_id = [] for variant in self.variant_data.variants(recode_ancestral=True): # If there's missing data the last allele is None num_alleles = len(variant.alleles) - int(variant.alleles[-1] is None) counts = allele_counts(variant.genotypes) use_site = False site = variant.site if ( site.position not in exclude_positions and num_alleles == 2 # This will ensure that the derived state is "1" and 1 < counts.derived < counts.known and site.ancestral_state is not None ): use_site = True time = site.time if tskit.is_unknown_time(time): # Non-variable sites have no obvious freq-as-time values assert counts.known != counts.derived assert counts.known != counts.ancestral time = counts.derived / counts.known if np.isnan(time): use_site = False # Site with meaningless time value: skip inference if use_site: self.ancestor_builder.add_site(time, variant.genotypes) inference_site_id.append(site.id) self.num_sites += 1 progress.update() progress.close() self.inference_site_ids = inference_site_id logger.info("Finished adding sites") def _run_synchronous(self, progress): a = np.zeros(self.num_sites, dtype=np.int8) for t, focal_sites in self.descriptors: before = time_.perf_counter() start, end = self.ancestor_builder.make_ancestor(focal_sites, a) duration = time_.perf_counter() - before logger.debug( "Made ancestor in {:.2f}s at timepoint {} " "from {} to {} (len={}) with {} focal sites ({})".format( duration, t, start, end, end - start, len(focal_sites), focal_sites, ) ) self.ancestor_data.add_ancestor( start=start, end=end, time=t, focal_sites=focal_sites, haplotype=a[start:end], ) progress.update() def _run_threaded(self, progress): # This works by pushing the ancestor descriptors onto the build_queue, # which the worker threads pop off and process. We need to add ancestors # in the the ancestor_data object in the correct order, so we maintain # a priority queue (add_queue) which allows us to track the next smallest # index of the generated ancestor. We add build ancestors to this queue # as they are built, and drain it when we can. 
        queue_depth = 8 * self.num_threads  # Seems like a reasonable limit
        build_queue = queue.Queue(queue_depth)
        add_lock = threading.Lock()
        next_add_index = 0
        add_queue = []

        def drain_add_queue():
            nonlocal next_add_index
            num_drained = 0
            while len(add_queue) > 0 and add_queue[0][0] == next_add_index:
                _, t, focal_sites, s, e, haplotype = heapq.heappop(add_queue)
                self.ancestor_data.add_ancestor(
                    start=s, end=e, time=t, focal_sites=focal_sites, haplotype=haplotype
                )
                progress.update()
                next_add_index += 1
                num_drained += 1
            logger.debug(f"Drained {num_drained} ancestors from add queue")

        def build_worker(thread_index):
            a = np.zeros(self.num_sites, dtype=np.int8)
            while True:
                work = build_queue.get()
                if work is None:
                    break
                index, t, focal_sites = work
                start, end = self.ancestor_builder.make_ancestor(focal_sites, a)
                with add_lock:
                    haplotype = a[start:end].copy()
                    heapq.heappush(
                        add_queue, (index, t, focal_sites, start, end, haplotype)
                    )
                    drain_add_queue()
                build_queue.task_done()
            build_queue.task_done()

        build_threads = [
            threads.queue_consumer_thread(
                build_worker, build_queue, name=f"build-worker-{j}", index=j
            )
            for j in range(self.num_threads)
        ]
        logger.debug(f"Started {self.num_threads} build worker threads")

        for index, (t, focal_sites) in enumerate(self.descriptors):
            build_queue.put((index, t, focal_sites))

        # Stop the worker threads.
        for _ in range(self.num_threads):
            build_queue.put(None)
        for j in range(self.num_threads):
            build_threads[j].join()
        drain_add_queue()

    def run(self):
        descriptors = self.ancestor_builder.ancestor_descriptors()
        peak_ram = humanize.naturalsize(self.ancestor_builder.mem_size, binary=True)
        logger.info(f"Ancestor builder peak RAM: {peak_ram}")
        # Sort the descriptors so that we deterministically create ancestors
        # in the same order across implementations
        d = [(t, tuple(focal_sites)) for t, focal_sites in descriptors]
        self.descriptors = sorted(d, reverse=True)
        self.num_ancestors = len(self.descriptors)
        # Maps epoch numbers to their corresponding ancestor times.
        self.timepoint_to_epoch = {}
        for t, _ in reversed(self.descriptors):
            if t not in self.timepoint_to_epoch:
                self.timepoint_to_epoch[t] = len(self.timepoint_to_epoch) + 1
        self.ancestor_data = formats.AncestorData(
            self.variant_data.sites_position[:][self.inference_site_ids],
            self.variant_data.sequence_length,
            path=self.ancestor_data_path,
            **self.ancestor_data_kwargs,
        )
        if self.num_ancestors > 0:
            logger.info(f"Starting build for {self.num_ancestors} ancestors")
            progress = self.progress_monitor.get("ga_generate", self.num_ancestors)
            a = np.zeros(self.num_sites, dtype=np.int8)
            root_time = max(self.timepoint_to_epoch.keys())
            av_timestep = root_time / len(self.timepoint_to_epoch)
            root_time += av_timestep  # Add a root a bit older than the oldest ancestor
            # Add an extra ancestor to act as a type of "virtual root" for the
            # matching algorithm: rather an awkward hack, but it also allows
            # the ancestor IDs to line up. It's normally removed when
            # processing the final tree sequence.
            self.ancestor_data.add_ancestor(
                start=0,
                end=self.num_sites,
                time=root_time + av_timestep,
                focal_sites=np.array([], dtype=np.int32),
                haplotype=a,
            )
            # This is the "ultimate ancestor" of all zeros
            self.ancestor_data.add_ancestor(
                start=0,
                end=self.num_sites,
                time=root_time,
                focal_sites=np.array([], dtype=np.int32),
                haplotype=a,
            )
            if self.num_threads <= 0:
                self._run_synchronous(progress)
            else:
                self._run_threaded(progress)
            progress.close()
        logger.info("Finished building ancestors")
        if self.mmap_temp_file is not None:
            try:
                self.mmap_temp_file.close()
            except:  # noqa
                pass
        return self.ancestor_data


class Matcher:
    """
    A matching instance, used in both ``tsinfer.match_ancestors`` and
    ``tsinfer.match_samples``. For details of the ``path_compression``,
    ``recombination_rate`` and ``mismatch_ratio`` parameters, see
    :ref:`matching ancestors & samples<sec_inference_match_ancestors_and_samples>`.

    Note that the ``recombination`` and ``mismatch`` parameters can be used in
    ``match_ancestors`` and ``match_samples`` and are passed directly to this
    function, but are deliberately not publicly documented in those methods.
    They are expected to be numpy arrays of length ``num_inference_sites - 1``
    and ``num_inference_sites`` respectively, containing values between 0 and
    1, and allow recombination and mismatch probabilities to be set directly.
    The ``recombination`` probabilities measure the probability of a
    recombination event between adjacent inference sites, used to calculate
    the HMM transition probabilities in the L&S-like matching algorithm. The
    ``mismatch`` probabilities are used to calculate the emission
    probabilities in the HMM. Note that values > 0.5 in the recombination and
    (particularly) the mismatch arrays are likely to lead to pathological
    behaviour: for example, a mismatch probability of 1 means that a mismatch
    is *required* at every site. For this reason, the probabilities created
    for recombination and mismatch when using the public-facing
    ``recombination_rate`` and ``mismatch_ratio`` parameters are never > 0.5.

    TODO: include deliberately non-public details of precision here.
    """

    def __init__(
        self,
        variant_data,
        inference_site_position,
        num_threads=1,
        path_compression=True,
        recombination_rate=None,
        mismatch_ratio=None,
        recombination=None,
        mismatch=None,
        precision=None,
        extended_checks=False,
        engine=constants.C_ENGINE,
        progress_monitor=None,
        allow_multiallele=False,
    ):
        self.variant_data = variant_data
        self.num_threads = num_threads
        self.path_compression = path_compression
        self.num_samples = self.variant_data.num_samples
        self.num_sites = len(inference_site_position)
        if self.num_sites == 0:
            logger.warning("No sites used for inference")
        num_intervals = max(self.num_sites - 1, 0)
        self.progress_monitor = _get_progress_monitor(progress_monitor)
        self.match_progress = None  # Allocated by subclass
        self.extended_checks = extended_checks

        all_sites = self.variant_data.sites_position[:]
        index = np.searchsorted(all_sites, inference_site_position)
        num_alleles = variant_data.num_alleles()[index]
        self.num_alleles = num_alleles
        if not np.all(all_sites[index] == inference_site_position):
            raise ValueError(
                "Site positions for inference must be a subset of those in "
                "the sample data file."
            )
        self.inference_site_id = index

        # Map of site index to tree sequence position. Bracketing
        # values of 0 and L are used for simplicity.
        self.position_map = np.hstack(
            [inference_site_position, [variant_data.sequence_length]]
        )
        self.position_map[0] = 0
        self.recombination = np.zeros(self.num_sites)  # TODO: reduce len by 1
        self.mismatch = np.zeros(self.num_sites)

        if recombination is not None or mismatch is not None:
            if recombination is None or mismatch is None:
                raise ValueError(
                    "Directly setting probabilities requires specifying "
                    "both 'recombination' and 'mismatch'"
                )
            if recombination_rate is not None or mismatch_ratio is not None:
                raise ValueError(
                    "Cannot simultaneously specify recombination & "
                    "recombination_rate, or mismatch and mismatch_ratio"
                )
            logger.info("Recombination and mismatch probabilities given by user")
        else:
            # Must set recombination and mismatch arrays
            if recombination_rate is None and mismatch_ratio is not None:
                raise ValueError("Cannot use mismatch without setting recombination")
            if (
                recombination_rate is None and mismatch_ratio is None
            ) or num_intervals == 0:
                # Special case: revert to tsinfer 0.1 behaviour with no
                # mismatch allowed
                default_recombination_prob = 1e-2
                default_mismatch_prob = 1e-20  # Substantially < the value above
                recombination = np.full(num_intervals, default_recombination_prob)
                mismatch = np.full(self.num_sites, default_mismatch_prob)
                logger.info(
                    "Mismatch prevented by setting constant high recombination "
                    "and low mismatch probabilities"
                )
            else:
                genetic_dists = self.recombination_rate_to_dist(
                    recombination_rate, inference_site_position
                )
                recombination = self.recombination_dist_to_prob(genetic_dists)
                if mismatch_ratio is None:
                    mismatch_ratio = 1.0
                mismatch = np.full(
                    self.num_sites,
                    self.mismatch_ratio_to_prob(
                        mismatch_ratio, np.median(genetic_dists), num_alleles
                    ),
                )
                logger.info(
                    "Recombination and mismatch probabilities calculated from "
                    f"specified recomb rates with mismatch ratio = {mismatch_ratio}"
                )

        if len(recombination) != num_intervals:
            raise ValueError("Bad length for recombination array")
        if len(mismatch) != self.num_sites:
            raise ValueError("Bad length for mismatch array")
        if not (np.all(recombination >= 0) and np.all(recombination <= 1)):
            raise ValueError(
                "Underlying recombination probabilities not between 0 & 1"
            )
        if not (np.all(mismatch >= 0) and np.all(mismatch <= 1)):
            raise ValueError("Underlying mismatch probabilities not between 0 & 1")

        if precision is None:
            precision = 13
        self.recombination[1:] = recombination
        self.mismatch[:] = mismatch
        self.precision = precision

        if len(recombination) == 0:
            logger.info("Fewer than two inference sites: no recombination possible")
        else:
            logger.info(
                "Summary of recombination probabilities between sites: "
                f"min={np.min(recombination):.5g}; "
                f"max={np.max(recombination):.5g}; "
                f"median={np.median(recombination):.5g}; "
                f"mean={np.mean(recombination):.5g}"
            )
        if len(mismatch) == 0:
            logger.info("No inference sites: no mismatch possible")
        else:
            logger.info(
                "Summary of mismatch probabilities over sites: "
                f"min={np.min(mismatch):.5g}; "
                f"max={np.max(mismatch):.5g}; "
                f"median={np.median(mismatch):.5g}; "
                f"mean={np.mean(mismatch):.5g}"
            )
        logger.info(
            f"Matching using {precision} digits of precision in likelihood calcs"
        )

        self.engine = engine
        if engine == constants.C_ENGINE:
            logger.debug("Using C matcher implementation")
            self.tree_sequence_builder_class = _tsinfer.TreeSequenceBuilder
            self.ancestor_matcher_class = _tsinfer.AncestorMatcher
        elif engine == constants.PY_ENGINE:
            logger.debug("Using Python matcher implementation")
            self.tree_sequence_builder_class = algorithm.TreeSequenceBuilder
            self.ancestor_matcher_class = algorithm.AncestorMatcher
        else:
            raise ValueError(f"Unknown engine:{engine}")
        self.tree_sequence_builder = None

        # Allocate 64K nodes and edges initially. This will double as needed
        # and will quickly be big enough even for very large instances.
        self.max_edges = 64 * 1024
        self.max_nodes = 64 * 1024
        if np.any(num_alleles > 2) and not allow_multiallele:
            # Currently only used for unsupported extend operation. We can
            # remove in future versions.
            raise ValueError("Cannot currently match with > 2 alleles.")
        self.tree_sequence_builder = self.tree_sequence_builder_class(
            num_alleles=num_alleles, max_nodes=self.max_nodes, max_edges=self.max_edges
        )
        logger.debug(f"Allocated tree sequence builder with max_nodes={self.max_nodes}")

    @staticmethod
    def find_path(matcher, child_id, haplotype, start, end):
        """
        Finds the path of the specified haplotype and returns the MatchResult
        object.
        """
        missing = haplotype == tskit.MISSING_DATA
        match = np.full(len(haplotype), tskit.MISSING_DATA, dtype=np.int8)
        left, right, parent = matcher.find_path(haplotype, start, end, match)
        match[missing] = tskit.MISSING_DATA
        diffs = start + np.where(haplotype[start:end] != match[start:end])[0]
        derived_state = haplotype[diffs]
        result = MatchResult(
            node=child_id,
            path=Path(left=left, right=right, parent=parent),
            mutations_site=diffs.astype(np.int32),
            mutations_derived_state=derived_state,
            mean_traceback_size=matcher.mean_traceback_size,
        )
        logger.debug(
            "Matched node {}; "
            "num_edges={} tb_size={:.2f} match_mem={}".format(
                child_id,
                left.shape[0],
                matcher.mean_traceback_size,
                humanize.naturalsize(matcher.total_memory, binary=True),
            )
        )
        return result

    @staticmethod
    def recombination_rate_to_dist(rho, positions):
        """
        Return the mean number of recombinations between adjacent positions
        (i.e. the genetic distance in Morgans) given either a fixed rate or a
        RateMap.
        """
        try:
            return np.diff(rho.get_cumulative_mass(positions))
        except AttributeError:
            return np.diff(positions) * rho

    @staticmethod
    def recombination_dist_to_prob(genetic_distances):
        """
        Convert genetic distances (in Morgans) to a probability of
        recombination (i.e. an odd number of events), assuming a Poisson
        distribution; see Haldane, 1919 J. Genetics 8: 299-309. This maxes
        out at 0.5 as dist -> inf.
        """
        return (1 - np.exp(-genetic_distances * 2)) / 2

    @staticmethod
    def mismatch_ratio_to_prob(ratio, genetic_distances, num_alleles=2):
        """
        Convert a mismatch ratio, relative to a genetic distance, to a
        probability of mismatch. A mismatch probability of 1 means that the
        emitted allele has a 100% probability of being different from the
        allele implied by the hidden state. For all allele types to be
        emitted with equal probability, regardless of the copying haplotype,
        the mismatch probability should be set to 1/num_alleles. For a small
        genetic_distance d, setting a ratio of X should give a probability of
        approximately X * r, where r is the recombination probability given
        by recombination_dist_to_prob(d).
        """
        return (1 - np.exp(-genetic_distances * ratio * num_alleles)) / num_alleles

    def create_matcher_instance(self):
        return self.ancestor_matcher_class(
            self.tree_sequence_builder,
            recombination=self.recombination,
            mismatch=self.mismatch,
            precision=self.precision,
            extended_checks=self.extended_checks,
        )

    def convert_inference_mutations(self, tables):
        """
        Convert the mutations stored in the tree sequence builder into the
        output format.
""" mut_site, node, derived_state, _ = self.tree_sequence_builder.dump_mutations() mutation_id = 0 num_mutations = len(mut_site) progress = self.progress_monitor.get( "ms_full_mutations", len(self.inference_site_id) ) schema = tables.sites.metadata_schema.schema for site in self.variant_data.sites(self.inference_site_id): metadata = _update_site_metadata(site.metadata, constants.INFERENCE_FULL) if schema is None: metadata = _encode_raw_metadata(metadata) site_id = tables.sites.add_row( site.position, ancestral_state=site.ancestral_state, metadata=metadata, ) while mutation_id < num_mutations and mut_site[mutation_id] == site_id: tables.mutations.add_row( site_id, node=node[mutation_id], derived_state=site.reorder_alleles()[derived_state[mutation_id]], ) mutation_id += 1 progress.update() progress.close() def restore_tree_sequence_builder(self): tables = self.ancestors_ts_tables if self.variant_data.sequence_length != tables.sequence_length: raise ValueError( "Ancestors tree sequence not compatible: sequence length is different to" " sample data file." ) if np.any(tables.nodes.time <= 0): raise ValueError("All nodes must have time > 0") edges = tables.edges # Get the indexes into the position array. left = np.searchsorted(self.position_map, edges.left) if np.any(self.position_map[left] != edges.left): raise ValueError("Invalid left coordinates") right = np.searchsorted(self.position_map, edges.right) if np.any(self.position_map[right] != edges.right): raise ValueError("Invalid right coordinates") # Need to sort by child ID here and left so that we can efficiently # insert the child paths. index = np.lexsort((left, edges.child)) nodes = tables.nodes self.tree_sequence_builder.restore_nodes(nodes.time, nodes.flags) self.tree_sequence_builder.restore_edges( left[index].astype(np.int32), right[index].astype(np.int32), edges.parent[index], edges.child[index], ) assert self.tree_sequence_builder.num_match_nodes == 1 + len( np.unique(edges.child) ) mutations = tables.mutations derived_state = np.zeros(len(mutations), dtype=np.int8) mutation_site = mutations.site site_id = 0 mutation_id = 0 for site in self.variant_data.sites(self.inference_site_id): while ( mutation_id < len(mutations) and mutation_site[mutation_id] == site_id ): allele = mutations[mutation_id].derived_state derived_state[mutation_id] = site.reorder_alleles().index(allele) mutation_id += 1 site_id += 1 self.tree_sequence_builder.restore_mutations( mutation_site, mutations.node, derived_state, mutations.parent ) logger.info( "Loaded {} samples {} nodes; {} edges; {} sites; {} mutations".format( self.num_samples, len(nodes), len(edges), self.num_sites, len(mutations), ) ) class AncestorMatcher(Matcher): def __init__( self, variant_data, ancestor_data, ancestors_ts=None, time_units=None, **kwargs ): super().__init__(variant_data, ancestor_data.sites_position[:], **kwargs) self.ancestor_data = ancestor_data if time_units is None: time_units = tskit.TIME_UNITS_UNCALIBRATED self.time_units = time_units self.num_ancestors = self.ancestor_data.num_ancestors if ancestors_ts is None: # Add nodes for all the ancestors so that the ancestor IDs are equal # to the node IDs. 
            for t in self.ancestor_data.ancestors_time[:]:
                self.tree_sequence_builder.add_node(t)
        else:
            self.ancestors_ts_tables = ancestors_ts.tables
            self.restore_tree_sequence_builder()

    def group_by_linesweep(self):
        t = time_.time()
        start = self.ancestor_data.ancestors_start[:]
        end = self.ancestor_data.ancestors_end[:]
        time = self.ancestor_data.ancestors_time[:]

        # We only need to perform the grouping for the small epochs at
        # earlier times. Skipping the later epochs _really_ helps as later
        # ancestors are dependent on almost all the earlier ones, so the
        # dependency graph becomes intractable.
        breaks = np.where(time[1:] != time[:-1])[0]
        epoch_start = np.hstack([[0], breaks + 1])
        epoch_end = np.hstack([breaks + 1, [self.num_ancestors]])
        time_slices = np.vstack([epoch_start, epoch_end]).T
        epoch_sizes = time_slices[:, 1] - time_slices[:, 0]
        median_size = np.median(epoch_sizes)
        cutoff = 500 * median_size
        # Zero out the first half so that an initial large epoch doesn't
        # get selected as the cutoff
        epoch_sizes[: len(epoch_sizes) // 2] = 0
        # To choose a cutoff point, find the first epoch that is 500 times
        # larger than the median epoch size. For a large set of human genomes
        # the median epoch size is around 10, so we'll stop grouping by
        # linesweep at 5000.
        if np.max(epoch_sizes) <= cutoff:
            large_epoch = len(time_slices)
            large_epoch_first_ancestor = self.num_ancestors
        else:
            large_epoch = np.where(epoch_sizes > cutoff)[0][0]
            large_epoch_first_ancestor = time_slices[large_epoch, 0]
        logger.info(f"{len(time_slices)} epochs with {median_size} median size.")
        logger.info(f"First large (>{cutoff}) epoch is {large_epoch}")
        logger.info(f"Grouping {large_epoch_first_ancestor} ancestors by linesweep")
        ancestor_grouping = ancestors.group_ancestors_by_linesweep(
            start[:large_epoch_first_ancestor],
            end[:large_epoch_first_ancestor],
            time[:large_epoch_first_ancestor],
        )
        # Add on the remaining epochs, grouped by time
        next_epoch = len(ancestor_grouping) + 1
        for epoch in range(large_epoch, len(time_slices)):
            ancestor_grouping[next_epoch] = np.arange(*time_slices[epoch])
            next_epoch += 1
        # Remove the "virtual root" ancestor
        try:
            assert 0 in ancestor_grouping[0]
            ancestor_grouping[0].remove(0)
        except KeyError:
            pass
        logger.info(
            f"Finished grouping into {len(ancestor_grouping)} groups in "
            f"{time_.time() - t:.2f} seconds"
        )
        return ancestor_grouping

    def __start_group(self, level, ancestor_ids):
        info = collections.OrderedDict(
            [("level", str(level)), ("nanc", str(len(ancestor_ids)))]
        )
        self.progress_monitor.set_detail(info)
        self.tree_sequence_builder.freeze_indexes()

    def __complete_group(self, group, ancestor_ids, results):
        nodes_before = self.tree_sequence_builder.num_nodes
        match_nodes_before = self.tree_sequence_builder.num_match_nodes
        for child_id, result in zip(ancestor_ids, results):
            assert result.node == child_id
            self.tree_sequence_builder.add_path(
                int(child_id),
                result.path.left,
                result.path.right,
                result.path.parent,
                compress=self.path_compression,
                extended_checks=self.extended_checks,
            )
            self.tree_sequence_builder.add_mutations(
                int(child_id), result.mutations_site, result.mutations_derived_state
            )
        extra_nodes = self.tree_sequence_builder.num_nodes - nodes_before
        assert (
            self.tree_sequence_builder.num_match_nodes
            == match_nodes_before + extra_nodes + len(ancestor_ids)
        )
        logger.debug(
            "Finished group {} with {} ancestors; {} extra nodes inserted; "
            "mean_tb_size={:.2f} edges={};".format(
                group,
                len(ancestor_ids),
                extra_nodes,
                sum(result.mean_traceback_size for result in results) / len(results)
                if len(results) > 0
float("nan"), self.tree_sequence_builder.num_edges, ) ) def match_locally(self, ancestor_ids): def thread_worker_function(ancestor): local_data = threading.local() if not hasattr(local_data, "matcher"): local_data.matcher = self.create_matcher_instance() result = self.find_path( matcher=local_data.matcher, child_id=ancestor.id, haplotype=ancestor.full_haplotype, start=ancestor.start, end=ancestor.end, ) self.match_progress.update() return result if self.num_threads > 0: results = list( threads.threaded_map( # noqa E731 thread_worker_function, self.ancestor_data.ancestors(indexes=ancestor_ids), self.num_threads, ) ) else: results = list( map( thread_worker_function, self.ancestor_data.ancestors(indexes=ancestor_ids), ) ) return results def match_ancestors(self, ancestor_grouping): logger.info(f"Starting ancestor matching for {len(ancestor_grouping)} groups") self.match_progress = self.progress_monitor.get( "ma_match", sum(len(ids) for ids in ancestor_grouping.values()) ) for group, ancestor_ids in ancestor_grouping.items(): t = time_.time() logger.info( f"Starting group {group} of {len(ancestor_grouping)} " f"with {len(ancestor_ids)} ancestors" ) self.__start_group(group, ancestor_ids) results = self.match_locally(ancestor_ids) self.__complete_group(group, ancestor_ids, results) logger.info( f"Finished group {group} of {len(ancestor_grouping)} in " f"{time_.time() - t:.2f} seconds" ) ts = self.store_output() self.match_progress.close() logger.info("Finished ancestor matching") return ts def match_partition(self, ancestors_to_match, group_index, partition_index): logger.info( f"Matching group {group_index} partition {partition_index} " f"with {len(ancestors_to_match)} ancestors" ) t = time_.time() self.__start_group(group_index, ancestors_to_match) self.match_progress = self.progress_monitor.get( "ma_match", len(ancestors_to_match) ) results = self.match_locally(ancestors_to_match) self.match_progress.close() logger.info(f"Matching took {time_.time() - t:.2f} seconds") return results def finalise_group(self, group, results, group_index): logger.info(f"Finalising group {group_index}") self.__start_group(group_index, group["ancestors"]) self.__complete_group(group_index, group["ancestors"], results) ts = self.store_output() logger.info(f"Finalised group {group_index}") return ts def fill_ancestors_tables(self, tables): """ Return the ancestors tree sequence tables. Only inference sites are included in this tree sequence. All nodes have the sample flag bit set, and if a node corresponds to an ancestor in the ancestors file, it is indicated via metadata. """ logger.debug("Building ancestors tree sequence") tsb = self.tree_sequence_builder flags, times = tsb.dump_nodes() pc_ancestors = is_pc_ancestor(flags) tables.nodes.set_columns(flags=flags, time=times) # Add metadata for any non-PC node, pointing to the original ancestor metadata = [] ancestor = 0 for is_pc in pc_ancestors: if is_pc: metadata.append(b"") else: metadata.append(_encode_raw_metadata({"ancestor_data_id": ancestor})) ancestor += 1 tables.nodes.packset_metadata(metadata) left, right, parent, child = tsb.dump_edges() tables.edges.set_columns( left=self.position_map[left], right=self.position_map[right], parent=parent, child=child, ) self.convert_inference_mutations(tables) logger.debug("Sorting ancestors tree sequence") tables.sort() # Note: it's probably possible to compute the mutation parents from the # tsb data structures but we're not doing it for now. 
        tables.build_index()
        tables.compute_mutation_parents()
        logger.debug("Sorting ancestors tree sequence done")
        logger.info(
            "Built ancestors tree sequence: {} nodes ({} pc ancestors); "
            "{} edges; {} sites; {} mutations".format(
                len(tables.nodes),
                np.sum(pc_ancestors),
                len(tables.edges),
                len(tables.sites),
                len(tables.mutations),
            )
        )

    def store_output(self):
        tables = tskit.TableCollection(
            sequence_length=self.ancestor_data.sequence_length
        )
        # We decided to use a permissive schema for the metadata, for flexibility
        dict_schema = tskit.MetadataSchema.permissive_json().schema
        dict_schema = add_to_schema(
            dict_schema, "ancestor_data_id", node_ancestor_data_id_metadata_definition
        )
        tables.nodes.metadata_schema = tskit.MetadataSchema(dict_schema)
        if self.num_ancestors > 0:
            self.fill_ancestors_tables(tables)
        tables.time_units = self.time_units
        return tables.tree_sequence()


class SampleMatcher(Matcher):
    def __init__(self, variant_data, ancestors_ts, **kwargs):
        self.ancestors_ts_tables = ancestors_ts.dump_tables()
        super().__init__(
            variant_data, self.ancestors_ts_tables.sites.position, **kwargs
        )
        self.restore_tree_sequence_builder()
        # Map from input sample indexes (IDs in the SampleData file) to the
        # node ID in the tree sequence.
        self.sample_id_map = {}

    def match_locally(self, sample_indexes):
        # A single thread-local store shared by all workers, so that each
        # worker thread creates (and then reuses) one matcher instance.
        local_data = threading.local()

        def thread_worker_function(j_haplotype):
            j, haplotype = j_haplotype
            assert len(haplotype) == self.num_sites
            if not hasattr(local_data, "matcher"):
                local_data.matcher = self.create_matcher_instance()
            logger.info(
                f"{time_.time()} Thread {threading.get_ident()} "
                f"starting haplotype {j}"
            )
            result = self.find_path(
                matcher=local_data.matcher,
                child_id=self.sample_id_map[j],
                haplotype=haplotype,
                start=0,
                end=self.num_sites,
            )
            self.match_progress.update()
            logger.info(
                f"{time_.time()} Thread {threading.get_ident()} "
                f"finished haplotype {j}"
            )
            return result

        sample_haplotypes = self.variant_data.haplotypes(
            sample_indexes,
            sites=self.inference_site_id,
            recode_ancestral=True,
        )
        if self.num_threads > 0:
            results = threads.threaded_map(
                thread_worker_function, sample_haplotypes, self.num_threads
            )
        else:
            results = map(thread_worker_function, sample_haplotypes)
        return list(results)

    def _match_samples(self, sample_indexes, results=None):
        num_samples = len(sample_indexes)
        builder = self.tree_sequence_builder
        _, times = builder.dump_nodes()
        logger.info(f"Started matching for {num_samples} samples")
        if self.num_sites == 0:
            return
        if results is None:
            self.match_progress = self.progress_monitor.get("ms_match", num_samples)
            t = time_.time()
            results = self.match_locally(sample_indexes)
            logger.info(
                f"Finished matching for all samples in {time_.time() - t:.2f} seconds"
            )
            self.match_progress.close()
        logger.info(
            "Inserting sample paths: {} edges in total".format(
                sum(len(r.path.left) for r in results)
            )
        )
        progress_monitor = self.progress_monitor.get("ms_paths", num_samples)
        for j, result in zip(sample_indexes, results):
            node_id = int(self.sample_id_map[j])
            assert node_id == result.node
            if np.any(times[node_id] > times[result.path.parent]):
                p = result.path.parent[np.argmin(times[result.path.parent])]
                raise ValueError(
                    f"Failed to put sample {j} (node {node_id}) at time "
                    f"{times[node_id]} as it has a younger parent (node {p})."
                )
            builder.add_path(
                result.node,
                result.path.left,
                result.path.right,
                result.path.parent,
                compress=self.path_compression,
            )
            builder.add_mutations(
                result.node,
                result.mutations_site,
                result.mutations_derived_state,
            )
            progress_monitor.update()
        progress_monitor.close()
        return results

    def match_samples(
        self, sample_indexes, sample_times=None, results=None, slice_=None
    ):
        if sample_times is None:
            sample_times = np.zeros(len(sample_indexes))
        builder = self.tree_sequence_builder
        for j, t in zip(sample_indexes, sample_times):
            self.sample_id_map[j] = builder.add_node(t)
        if slice_ is None:
            slice_ = slice(0, len(sample_indexes))
        return self._match_samples(sample_indexes[slice_], results)

    def finalise(self, overlay_non_inference_sites):
        logger.info("Finalising tree sequence")
        ts = self.get_samples_tree_sequence(overlay_non_inference_sites)
        # Check that there are the same number of samples as expected
        assert len(self.sample_id_map) == ts.num_samples
        return ts

    def get_samples_tree_sequence(self, overlay_non_inference_sites=True):
        """
        Returns the current state of the built tree sequence. Sample nodes
        will have the sample node flag set and be in the same order as the
        sample_indexes passed to match_samples. For correct sample
        reconstruction, the non-inference sites also need to be placed into
        the resulting tree sequence.
        """
        tsb = self.tree_sequence_builder
        tables = self.ancestors_ts_tables.copy()

        schema = self.variant_data.metadata_schema
        tables.metadata_schema = tskit.MetadataSchema(schema)
        tables.metadata = self.variant_data.metadata

        schema = self.variant_data.populations_metadata_schema
        if schema is not None:
            tables.populations.metadata_schema = tskit.MetadataSchema(schema)
        for metadata in self.variant_data.populations_metadata[:]:
            if schema is None:
                # Use the default json encoding to avoid breaking old code.
                tables.populations.add_row(_encode_raw_metadata(metadata))
            else:
                tables.populations.add_row(metadata)

        schema = self.variant_data.individuals_metadata_schema
        if schema is not None:
            schema = add_to_schema(
                schema,
                "variant_data_time",
                definition=variant_data_time_metadata_definition,
            )
            tables.individuals.metadata_schema = tskit.MetadataSchema(schema)
        num_ancestral_individuals = len(tables.individuals)
        for ind in self.variant_data.individuals():
            metadata = ind.metadata
            if ind.time != 0:
                metadata["variant_data_time"] = ind.time
            if schema is None:
                metadata = _encode_raw_metadata(metadata)
            tables.individuals.add_row(
                location=ind.location,
                metadata=metadata,
                flags=ind.flags,
            )

        logger.debug("Adding tree sequence nodes")
        flags, times = tsb.dump_nodes()
        num_pc_ancestors = count_pc_ancestors(flags)

        # All true ancestors are samples in the ancestors tree sequence. We
        # unset the SAMPLE flag but keep other flags intact.
        new_flags = tables.nodes.flags
        new_flags = np.bitwise_and(
            new_flags, ~new_flags.dtype.type(tskit.NODE_IS_SAMPLE)
        )
        tables.nodes.flags = new_flags.astype(np.uint32)
        sample_ids = list(self.sample_id_map.values())
        assert len(tables.nodes) == sample_ids[0]
        individuals_population = self.variant_data.individuals_population[:]
        samples_individual = self.variant_data.samples_individual[:]
        individuals_time = self.variant_data.individuals_time[:]
        for index, sample_id in self.sample_id_map.items():
            individual = samples_individual[index]
            if individuals_time[individual] != 0:
                flags[sample_id] = np.bitwise_or(
                    flags[sample_id], constants.NODE_IS_HISTORICAL_SAMPLE
                )
            population = individuals_population[individual]
            tables.nodes.add_row(
                flags=flags[sample_id],
                time=times[sample_id],
                population=population,
                individual=num_ancestral_individuals + individual,
            )
        # Add in the remaining non-sample nodes.
        for u in range(sample_ids[-1] + 1, tsb.num_nodes):
            tables.nodes.add_row(flags=flags[u], time=times[u])

        logger.debug("Adding tree sequence edges")
        tables.edges.clear()
        left, right, parent, child = tsb.dump_edges()
        if self.num_sites == 0:
            # We have no inference sites, so no edges have been estimated. To
            # ensure we have a rooted tree, we add in edges for each sample
            # to an artificial root.
            assert left.shape[0] == 0
            max_node_time = tables.nodes.time.max()
            root = tables.nodes.add_row(flags=0, time=max_node_time + 1)
            ultimate = tables.nodes.add_row(flags=0, time=max_node_time + 2)
            tables.edges.add_row(0, tables.sequence_length, ultimate, root)
            for sample_id in sample_ids:
                tables.edges.add_row(0, tables.sequence_length, root, sample_id)
        else:
            tables.edges.set_columns(
                left=self.position_map[left],
                right=self.position_map[right],
                parent=parent,
                child=child,
            )

        logger.debug("Sorting and building intermediate tree sequence.")
        tables.sites.clear()
        tables.mutations.clear()
        tables.sort()

        schema = self.variant_data.sites_metadata_schema
        if schema is not None:
            schema = add_to_schema(
                schema,
                "inference_type",
                definition=inference_type_metadata_definition,
            )
            tables.sites.metadata_schema = tskit.MetadataSchema(schema)
        self.convert_inference_mutations(tables)

        # FIXME this is a shortcut. We should be computing the mutation
        # parents above during insertion (probably)
        tables.build_index()
        tables.compute_mutation_parents()
        logger.info(
            "Built samples tree sequence: {} nodes ({} pc); {} edges; "
            "{} sites; {} mutations".format(
                len(tables.nodes),
                num_pc_ancestors,
                len(tables.edges),
                len(tables.sites),
                len(tables.mutations),
            )
        )
        ts = tables.tree_sequence()
        num_additional_sites = self.variant_data.num_sites - self.num_sites
        if overlay_non_inference_sites and num_additional_sites > 0:
            logger.info("Mapping additional sites")
            assert np.array_equal(ts.samples(), list(self.sample_id_map.values()))
            ts = insert_missing_sites(
                self.variant_data,
                ts,
                sample_id_map=np.array(list(self.sample_id_map.keys())),
                progress_monitor=self.progress_monitor,
            )
        else:
            logger.info("Skipping additional site mapping")
        return ts

    def get_augmented_ancestors_tree_sequence(self, sample_indexes):
        """
        Return the ancestors tree sequence augmented with samples as extra
        ancestors.
""" logger.debug("Building augmented ancestors tree sequence") tsb = self.tree_sequence_builder tables = self.ancestors_ts_tables.copy() dict_schema = tables.nodes.metadata_schema.schema assert dict_schema is not None dict_schema = add_to_schema( dict_schema, "variant_data_id", node_variant_data_id_metadata_definition ) tables.nodes.metadata_schema = tskit.MetadataSchema(dict_schema) flags, times = tsb.dump_nodes() s = 0 num_pc_ancestors = 0 for j in range(len(tables.nodes), len(flags)): if times[j] == 0.0: # This is an augmented ancestor node. tables.nodes.add_row( flags=constants.NODE_IS_SAMPLE_ANCESTOR, time=times[j], metadata={"variant_data_id": int(sample_indexes[s])}, ) s += 1 else: # This is a path compressed node tables.nodes.add_row(flags=flags[j], time=times[j]) assert is_pc_ancestor(flags[j]) num_pc_ancestors += 1 assert s == len(sample_indexes) assert len(tables.nodes) == len(flags) # Increment the time for all nodes so the augmented samples are no longer # at timepoint 0. tables.nodes.time = tables.nodes.time + 1 # TODO - check this works for augmented ancestors with missing data left, right, parent, child = tsb.dump_edges() tables.edges.set_columns( left=self.position_map[left], right=self.position_map[right], parent=parent, child=child, ) tables.sites.clear() tables.mutations.clear() self.convert_inference_mutations(tables) logger.debug("Sorting ancestors tree sequence") tables.sort() logger.debug("Sorting ancestors tree sequence done") logger.info( "Augmented ancestors tree sequence: {} nodes ({} extra pc ancestors); " "{} edges; {} sites; {} mutations".format( len(tables.nodes), num_pc_ancestors, len(tables.edges), len(tables.mutations), len(tables.sites), ) ) return tables.tree_sequence() @dataclasses.dataclass class Path: left: np.ndarray right: np.ndarray parent: np.ndarray @dataclasses.dataclass class MatchResult: node: int path: Path mutations_site: list mutations_derived_state: list mean_traceback_size: int def has_single_edge_over_grand_root(ts): # Internal function to check if this is a "raw" inferred tree sequence. if ts.num_edges < 2: # must have edge to grand root and above grand root return False last_edge = ts.edge(-1) if last_edge.left != 0 or last_edge.right != ts.sequence_length: return False # Not a single edge spanning the entire genome if ts.edge(-2).parent == last_edge.parent: return False # other edges point to the oldest node => not a virtual-like root return True def has_same_root_everywhere(ts): roots = set() for tree in ts.trees(): if not tree.has_single_root: return False roots.add(tree.root) if len(roots) > 1: return False return True
def post_process(
    ts,
    *,
    split_ultimate=None,
    erase_flanks=None,
    # Parameters below deliberately undocumented
    warn_if_unexpected_format=None,
    simplify_only=None,
):
    """
    post_process(ts, *, split_ultimate=None, erase_flanks=None)

    Post-process a tsinferred tree sequence into a more conventional form.
    This is the function run by default on the final tree sequence output by
    :func:`match_samples`. It involves the following 4 steps:

    #. If the oldest node is connected to a single child via an edge that
       spans the entire tree sequence, this oldest node is removed, so that
       its child becomes the new root (this step is undertaken to remove the
       "virtual-root-like node" which is added to ancestor tree sequences to
       enable matching).
    #. If the oldest node is removed in the first step and the new root spans
       the entire genome, it is treated as the "ultimate ancestor" and
       (unless ``split_ultimate`` is ``False``) the node is split into
       multiple coexisting nodes with the splits occurring whenever the
       children of the ultimate ancestor change. The rationale is that
       tsinfer creates a single ancestral haplotype with all inference sites
       in the ancestral state: this is, however, unlikely to represent a
       single ancestor in the past. If the tree sequence is then dated, the
       fact that the ultimate ancestor is split into separate nodes allows
       these nodes to be dated to different times.
    #. Often, extensive regions of genome exist before the first defined site
       and after the last defined site. Since no data exists in these
       sections of the genome, post processing by default erases the inferred
       topology in these regions. However, if ``erase_flanks`` is False, the
       flanking regions at the start and end will be assigned the same
       topology as inferred at the first and last site respectively.
    #. The sample nodes are reordered such that they are the first nodes
       listed in the node table, removing tree nodes and edges that are not
       on a path between the root and any of the samples (by applying the
       :meth:`~tskit.TreeSequence.simplify` method with ``keep_unary`` set to
       True but ``filter_sites``, ``filter_populations`` and
       ``filter_individuals`` set to False).

    :param bool split_ultimate: If ``True`` (default) and the oldest node is
        the only parent to a single "ultimate ancestor" node, attempt to
        split this node into many separate nodes (see above). If ``False``,
        do not attempt to identify or split an ultimate ancestor node.
    :param bool erase_flanks: If ``True`` (default), keep only the inferred
        topology between the first and last sites. If ``False``, output
        regions of topology inferred before the first site and after the last
        site.
    :return: The post-processed tree sequence.
    :rtype: tskit.TreeSequence
    """
    if split_ultimate is None:
        split_ultimate = True
    if erase_flanks is None:
        erase_flanks = True

    tables = ts.dump_tables()
    if not simplify_only:
        if has_single_edge_over_grand_root(ts):
            logger.info(
                "Removing the oldest edge to detach the virtual-root-like ancestor"
            )
            # Edge with oldest parent is last in the edge table
            last_edge = ts.edge(-1)
            tables.edges.truncate(tables.edges.num_rows - 1)
            # Move any mutations above the virtual-root-like ancestor to
            # above the ultimate ancestor instead (these will be mutations
            # placed by parsimony)
            mutations_node = tables.mutations.node
            mutations_node[mutations_node == last_edge.parent] = last_edge.child
            tables.mutations.node = mutations_node
            if split_ultimate:
                split_ultimate_ancestor(tables, warn_if_unexpected_format)
        elif warn_if_unexpected_format:
            logger.warning(
                "Cannot find a virtual-root-like ancestor during post-processing"
            )

        if erase_flanks and ts.num_sites > 0:
            # So that the last site falls within a tree, we must add one to
            # the site position (or simply extend to the end of the ts)
            keep_max = min(ts.sites_position[-1] + 1, ts.sequence_length)
            tables.keep_intervals(
                [[ts.sites_position[0], keep_max]],
                simplify=False,
                record_provenance=False,
            )
            erased = ts.sites_position[0] + ts.sequence_length - keep_max
            erased *= 100 / ts.sequence_length
            logger.info(
                f"Erased flanks covering {erased}% of the genome: "
                f"{ts.sites_position[0]} units at the start and "
                f"{ts.sequence_length - keep_max} units at the end"
            )

    logger.info(
        "Simplifying with filter_sites=False, filter_populations=False, "
        "filter_individuals=False, and keep_unary=True on "
        f"{tables.nodes.num_rows} nodes and {tables.edges.num_rows} edges"
    )
    # NB: if this is an inferred TS, match_samples is guaranteed to produce
    # samples in the same order as passed in to sample_indexes, and
    # simplification will simply stick all those at the start but keep their
    # order the same
    tables.simplify(
        filter_sites=False,
        filter_populations=False,
        filter_individuals=False,
        keep_unary=True,
        record_provenance=False,
    )
    logger.info(
        "Finished simplify; now have {} nodes and {} edges".format(
            tables.nodes.num_rows, tables.edges.num_rows
        )
    )
    return tables.tree_sequence()
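
# Example (illustrative sketch): post_process is applied automatically by
# match_samples, but it can also be run by hand, e.g. to keep the flanking
# topology that would otherwise be erased.
#
#     import tsinfer
#
#     raw_ts = tsinfer.match_samples(
#         variant_data, ancestors_ts, post_process=False
#     )
#     ts = tsinfer.post_process(raw_ts, erase_flanks=False)
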
def _post_process(*args, **kwargs):
    return post_process(*args, **kwargs)


def split_ultimate_ancestor(tables, warn_if_unexpected_format=None):
    # Internal function: if a single oldest node is a root across the entire
    # genome, split it up into a set of contemporaneous nodes whenever the
    # node children change.
    ts = tables.tree_sequence()
    if not has_same_root_everywhere(ts):
        if warn_if_unexpected_format:
            logger.warning(
                "Cannot find a single contiguous ultimate ancestor to split"
            )
        return
    # Split into multiple contemporaneous nodes whenever the node children change
    genomewide_ultimate_ancestor_id = ts.edge(-1).parent
    genomewide_ultimate_ancestor = ts.node(genomewide_ultimate_ancestor_id)
    logger.info("Located the all zeros ultimate ancestor")
    root_breaks = set()
    edges = tables.edges
    j = len(edges) - 1  # the last edges are the ones connecting to the genomewide UA
    while j >= 0 and edges[j].parent == genomewide_ultimate_ancestor_id:
        root_breaks |= {edges[j].left, edges[j].right}
        j -= 1
    root_breaks = sorted(root_breaks)
    assert root_breaks[0] == 0
    if root_breaks[1] == tables.sequence_length:
        # Only a single edge: no splitting needed
        return
    logger.info(f"Splitting ultimate ancestor into {len(root_breaks) - 1} nodes")
    # Detach the ultimate ancestor from all its children, so it can be
    # simplified out
    tables.edges.truncate(j + 1)
    # Move the mutations above the ultimate ancestor to the new nodes
    mutation_ids = np.where(
        tables.mutations.node == genomewide_ultimate_ancestor_id
    )[0]
    mutation_positions = tables.sites.position[tables.mutations.site[mutation_ids]]
    mut_iter = zip(mutation_ids, mutation_positions)
    mutation_id, mutation_pos = next(mut_iter, (None, ts.sequence_length))
    # Go through the trees, making a new root node wherever we hit a
    # root_break and recreating the edges to the children each time
    trees_iter = ts.trees()
    tree = next(trees_iter)
    left = root_breaks[0]
    for right in root_breaks[1:]:
        while tree.interval.right != right:
            tree = next(trees_iter)
        new_root = tables.nodes.append(genomewide_ultimate_ancestor)
        for c in tree.children(genomewide_ultimate_ancestor_id):
            tables.edges.add_row(parent=new_root, child=c, left=left, right=right)
        while mutation_pos < right:
            tables.mutations[mutation_id] = tables.mutations[mutation_id].replace(
                node=new_root
            )
            mutation_id, mutation_pos = next(mut_iter, (None, ts.sequence_length))
        left = right
    tables.sort()


def minimise(ts):
    """
    Returns a tree sequence with the minimal information required to
    represent the tree topologies at its sites. This is a convenience
    function used when we wish to use a subset of the sites in a tree
    sequence for ancestor matching. It is a thin wrapper around the simplify
    method.
    """
    return ts.simplify(
        reduce_to_site_topology=True,
        filter_sites=False,
        filter_individuals=False,
        filter_populations=False,
    )
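
# Example (illustrative equivalence): by construction, minimise(ts) should
# produce the same result as calling simplify directly with the same
# arguments.
#
#     ts1 = minimise(ts)
#     ts2 = ts.simplify(
#         reduce_to_site_topology=True,
#         filter_sites=False,
#         filter_individuals=False,
#         filter_populations=False,
#     )
#     assert ts1.equals(ts2, ignore_provenance=True)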