25 changes: 15 additions & 10 deletions tsinfer/algorithm.py
@@ -58,6 +58,7 @@ class Site:
id = attr.ib()
time = attr.ib()
derived_count = attr.ib()
terminal = attr.ib()


class AncestorBuilder:
@@ -137,21 +138,23 @@ def store_site_genotypes(self, site_id, genotypes):
stop = start + self.encoded_genotypes_size
self.genotype_store[start:stop] = genotypes

def add_site(self, time, genotypes):
def add_site(self, time, genotypes, terminal):
"""
Adds a new site with the given time and genotypes to the builder.
"""
site_id = len(self.sites)
derived_count = np.sum(genotypes == 1)
self.store_site_genotypes(site_id, genotypes)
self.sites.append(Site(site_id, time, derived_count))
sites_at_fixed_timepoint = self.time_map[time]
# Sites with an identical variant distribution (i.e. with the same
# genotypes.tobytes() value) and at the same time, are put into the same ancestor
# to which we allocate a unique ID (just use the genotypes value)
ancestor_uid = tuple(genotypes)
# Add each site to the list for this ancestor_uid at this timepoint
sites_at_fixed_timepoint[ancestor_uid].append(site_id)
self.sites.append(Site(site_id, time, derived_count, terminal))
if not terminal:
self.store_site_genotypes(site_id, genotypes)
sites_at_fixed_timepoint = self.time_map[time]
# Sites with an identical variant distribution (i.e. with the same
# genotypes.tobytes() value) and at the same time, are put into the
# same ancestor to which we allocate a unique ID (just use the genotypes
# value)
ancestor_uid = tuple(genotypes)
# Add each site to the list for this ancestor_uid at this timepoint
sites_at_fixed_timepoint[ancestor_uid].append(site_id)

def print_state(self):
print("Ancestor builder")
@@ -221,6 +224,8 @@ def compute_ancestral_states(self, a, focal_site, sites):
disagree = np.zeros(self.num_samples, dtype=bool)

for site_index in sites:
if self.sites[site_index].terminal:
break
a[site_index] = 0
last_site = site_index
g_l = self.get_site_genotypes(site_index)
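For context on the `terminal` flag added above: the sketch below is a simplified, self-contained illustration of how `add_site` is assumed to treat terminal versus non-terminal sites. Non-terminal sites sharing a genotype pattern at the same time are collected into one candidate ancestor, while a terminal site is recorded but never grouped (and `compute_ancestral_states` stops when it reaches one). The `Site`, `add_site`, and `time_map` names here are stand-ins for illustration, not the tsinfer internals.

```python
# Simplified sketch (not the tsinfer implementation) of terminal vs.
# non-terminal handling in an add_site-style builder.
import collections

Site = collections.namedtuple("Site", ["id", "time", "derived_count", "terminal"])

sites = []
# time -> genotype pattern -> list of site ids sharing that ancestor
time_map = collections.defaultdict(lambda: collections.defaultdict(list))


def add_site(time, genotypes, terminal):
    site_id = len(sites)
    derived_count = sum(g == 1 for g in genotypes)
    sites.append(Site(site_id, time, derived_count, terminal))
    if not terminal:
        # Sites with identical genotypes at the same time share an ancestor.
        time_map[time][tuple(genotypes)].append(site_id)


add_site(1.0, (0, 1, 1, 0), terminal=False)
add_site(1.0, (0, 1, 1, 0), terminal=False)  # same pattern: same ancestor
add_site(float("nan"), (0, 0, 0, 0), terminal=True)  # sentinel: not grouped

assert time_map[1.0][(0, 1, 1, 0)] == [0, 1]
assert sites[2].terminal and len(time_map) == 1  # terminal site left ungrouped
```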
36 changes: 29 additions & 7 deletions tsinfer/formats.py
@@ -3093,7 +3093,14 @@ class AncestorData(DataContainer):
FORMAT_NAME = "tsinfer-ancestor-data"
FORMAT_VERSION = (3, 0)

def __init__(self, position, sequence_length, chunk_size_sites=None, **kwargs):
def __init__(
self,
inference_position,
terminal_position,
sequence_length,
chunk_size_sites=None,
**kwargs,
):
super().__init__(**kwargs)
self._last_time = 0
self.inference_sites_set = False
@@ -3111,15 +3118,22 @@ def __init__(self, position, sequence_length, chunk_size_sites=None, **kwargs):
self.create_dataset("sample_end", dtype=np.int32)
self.create_dataset("sample_time", dtype=np.float64)
self.create_dataset("sample_focal_sites", dtype="array:i4")

variant_position = np.concatenate([inference_position, terminal_position])
self.create_dataset(
"variant_position",
data=position,
shape=position.shape,
data=variant_position,
shape=variant_position.shape,
chunks=self._chunk_size_sites,
dtype=np.float64,
dimensions=["variants"],
)
self.create_dataset(
"terminal_position",
data=terminal_position,
shape=terminal_position.shape,
dtype=np.float64,
dimensions=["terminal_sites"],
)

# We have to include a ploidy dimension for sgkit compatibility
a = self.create_dataset(
@@ -3277,10 +3291,17 @@ def num_sites(self):
@property
def sites_position(self):
"""
The positions of the inference sites used to generate the ancestors
The positions of the inference and terminal sites used to generate the ancestors
"""
return self.data["variant_position"]

@property
def terminal_position(self):
"""
The positions of the terminal sites used to generate the ancestors
"""
return self.data["terminal_position"]

@property
def ancestors_start(self):
return self.data["sample_start"]
@@ -3314,10 +3335,10 @@ def ancestors_length(self):
"""
# Ancestor start and end are half-closed. The terminal site position
# bounds the rightmost ancestors, so no padding to sequence length is needed.
pos = np.hstack([self.sites_position[:], [self.sequence_length]])

start = self.ancestors_start[:]
end = self.ancestors_end[:]
return pos[end] - pos[start]
return self.sites_position[end] - self.sites_position[start]

def insert_proxy_samples(
self,
@@ -3683,6 +3704,7 @@ def add_ancestor(self, start, end, time, focal_sites, haplotype):
if start < 0:
raise ValueError("Start must be >= 0")
if end > self.num_sites:
raise ValueError("end must be <= num_sites")
if start >= end:
raise ValueError("start must be < end")
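To make the position bookkeeping in this file concrete, here is a small sketch (with illustrative values, not taken from a real data set) of what the constructor change implies: `variant_position` holds the inference positions followed by the terminal position, and `ancestors_length` can then be read directly from that array instead of from an array padded with `sequence_length`.

```python
# Sketch of the AncestorData position layout assumed above; the arrays and
# the [start, end) span are illustrative.
import numpy as np

inference_position = np.array([10.0, 55.0, 120.0, 340.0])
terminal_position = np.array([341.0])  # sentinel just past the last site

# What the constructor stores as "variant_position":
sites_position = np.concatenate([inference_position, terminal_position])

# Ancestor spans are half-closed [start, end) site indexes, so the length
# of an ancestor covering the sites at 55.0, 120.0 and 340.0 is:
start, end = 1, 4
length = sites_position[end] - sites_position[start]
assert length == 341.0 - 55.0
```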
86 changes: 65 additions & 21 deletions tsinfer/inference.py
@@ -1807,6 +1807,8 @@ def __init__(
self.num_samples = variant_data.num_samples
self.num_threads = num_threads
self.mmap_temp_file = None
self.sites_position = None
self.terminal_position = None
mmap_fd = -1

genotype_matrix_size = self.max_sites * self.num_samples
@@ -1865,6 +1867,8 @@ def add_sites(self, exclude_positions=None):
logger.info(f"Starting addition of {self.max_sites} sites")
progress = self.progress_monitor.get("ga_add_sites", self.max_sites)
inference_site_id = []
last_position = 0

for variant in self.variant_data.variants(recode_ancestral=True):
# If there's missing data the last allele is None
num_alleles = len(variant.alleles) - int(variant.alleles[-1] is None)
@@ -1879,6 +1883,7 @@
and site.ancestral_state is not None
):
use_site = True
last_position = site.position
time = site.time
if tskit.is_unknown_time(time):
# Non-variable sites have no obvious freq-as-time values
@@ -1888,12 +1893,22 @@
if np.isnan(time):
use_site = False # Site with meaningless time value: skip inference
if use_site:
self.ancestor_builder.add_site(time, variant.genotypes)
self.ancestor_builder.add_site(time, variant.genotypes, terminal=False)
inference_site_id.append(site.id)
self.num_sites += 1
progress.update()
progress.close()
self.inference_site_ids = inference_site_id
# Add terminal site at end of sequence
zeros = np.zeros(self.num_samples, dtype=np.int8)
self.ancestor_builder.add_site(tskit.UNKNOWN_TIME, zeros, terminal=True)
self.num_sites += 1

terminal_position = last_position + 1
if terminal_position == self.variant_data.sequence_length:
terminal_position -= 0.5
self.terminal_position = np.array([terminal_position], dtype=np.float64)

logger.info("Finished adding sites")

def _run_synchronous(self, progress):
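A short sketch of the terminal-position rule introduced in `add_sites` above: the sentinel is placed one unit past the last inference site, and pulled back by half a unit if that would land exactly on `sequence_length` (tskit site positions must be strictly less than the sequence length). The helper name and the values are purely illustrative.

```python
# Illustrative helper mirroring the terminal-position logic in add_sites;
# terminal_position_for is a hypothetical name used only for this sketch.
import numpy as np


def terminal_position_for(last_position, sequence_length):
    terminal = last_position + 1
    if terminal == sequence_length:
        # Keep the sentinel strictly inside [0, sequence_length).
        terminal -= 0.5
    return np.array([terminal], dtype=np.float64)


assert terminal_position_for(340.0, 1000.0)[0] == 341.0
assert terminal_position_for(999.0, 1000.0)[0] == 999.5
```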
@@ -2000,15 +2015,18 @@ def run(self):
if t not in self.timepoint_to_epoch:
self.timepoint_to_epoch[t] = len(self.timepoint_to_epoch) + 1
self.ancestor_data = formats.AncestorData(
self.variant_data.sites_position[:][self.inference_site_ids],
self.variant_data.sequence_length,
inference_position=self.variant_data.sites_position[:][
self.inference_site_ids
],
terminal_position=self.terminal_position,
sequence_length=self.variant_data.sequence_length,
path=self.ancestor_data_path,
**self.ancestor_data_kwargs,
)
if self.num_ancestors > 0:
logger.info(f"Starting build for {self.num_ancestors} ancestors")
progress = self.progress_monitor.get("ga_generate", self.num_ancestors)
a = np.zeros(self.num_sites, dtype=np.int8)
a = np.zeros(self.num_sites - 1, dtype=np.int8)
root_time = max(self.timepoint_to_epoch.keys())
av_timestep = root_time / len(self.timepoint_to_epoch)
root_time += av_timestep # Add a root a bit older than the oldest ancestor
@@ -2017,15 +2035,15 @@
# line up. It's normally removed when processing the final tree sequence.
self.ancestor_data.add_ancestor(
start=0,
end=self.num_sites,
end=self.num_sites - 1,
time=root_time + av_timestep,
focal_sites=np.array([], dtype=np.int32),
haplotype=a,
)
# This is the "ultimate ancestor" of all zeros
self.ancestor_data.add_ancestor(
start=0,
end=self.num_sites,
end=self.num_sites - 1,
time=root_time,
focal_sites=np.array([], dtype=np.int32),
haplotype=a,
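The off-by-one adjustments above follow from the sentinel added in `add_sites`: `self.num_sites` now counts the inference sites plus one terminal site, so the all-zero root haplotypes and their half-closed spans use `num_sites - 1`. A minimal sketch of that accounting, with illustrative numbers:

```python
# Illustrative accounting: num_sites includes the terminal sentinel, so the
# root/ultimate-ancestor haplotype and span cover only the inference sites.
import numpy as np

num_inference_sites = 4
num_sites = num_inference_sites + 1  # +1 for the terminal sentinel

a = np.zeros(num_sites - 1, dtype=np.int8)  # one state per inference site
start, end = 0, num_sites - 1               # half-closed [start, end)
assert len(a) == end - start == num_inference_sites
```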
@@ -2072,7 +2090,8 @@ class Matcher:
def __init__(
self,
variant_data,
inference_site_position,
combined_position,
terminal_position,
num_threads=1,
path_compression=True,
recombination_rate=None,
@@ -2090,30 +2109,33 @@ def __init__(
self.num_threads = num_threads
self.path_compression = path_compression
self.num_samples = self.variant_data.num_samples
self.num_sites = len(inference_site_position)
if self.num_sites == 0:
logger.warning("No sites used for inference")
num_intervals = max(self.num_sites - 1, 0)
self.progress_monitor = _get_progress_monitor(progress_monitor)
self.match_progress = None # Allocated by subclass
self.extended_checks = extended_checks

assert np.isin(terminal_position, combined_position).all()
inference_position = np.setdiff1d(
combined_position, terminal_position, assume_unique=True
)
self.num_sites = len(inference_position)
if self.num_sites == 0:
logger.warning("No sites used for inference")
num_intervals = max(self.num_sites - 1, 0)

all_sites = self.variant_data.sites_position[:]
index = np.searchsorted(all_sites, inference_site_position)
index = np.searchsorted(all_sites, inference_position)
num_alleles = variant_data.num_alleles()[index]
self.num_alleles = num_alleles
if not np.all(all_sites[index] == inference_site_position):
if not np.all(all_sites[index] == inference_position):
raise ValueError(
"Site positions for inference must be a subset of those in "
"the sample data file."
)
self.inference_site_id = index

# Map of site index to tree sequence position. Bracketing
# values of 0 and L are used for simplicity.
self.position_map = np.hstack(
[inference_site_position, [variant_data.sequence_length]]
)
# Map of site index to tree sequence position. The terminal site position
# is included, so the last entry is no longer set to sequence_length.
self.position_map = combined_position.copy()
self.position_map[0] = 0
self.recombination = np.zeros(self.num_sites) # TODO: reduce len by 1
self.mismatch = np.zeros(self.num_sites)
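A sketch of how the matcher is assumed to recover the inference positions from the combined array it now receives, mirroring the `np.isin` check and `np.setdiff1d` split above (positions are illustrative):

```python
# Splitting combined positions back into inference and terminal positions.
import numpy as np

combined_position = np.array([10.0, 55.0, 120.0, 340.0, 341.0])
terminal_position = np.array([341.0])

# Every terminal position must also appear in the combined array.
assert np.isin(terminal_position, combined_position).all()

inference_position = np.setdiff1d(
    combined_position, terminal_position, assume_unique=True
)
assert inference_position.tolist() == [10.0, 55.0, 120.0, 340.0]

# The position map keeps the terminal position as its last entry, with the
# first entry clamped to 0, rather than appending sequence_length.
position_map = combined_position.copy()
position_map[0] = 0
assert position_map[-1] == 341.0
```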
@@ -2149,7 +2171,7 @@ def __init__(
)
else:
genetic_dists = self.recombination_rate_to_dist(
recombination_rate, inference_site_position
recombination_rate, inference_position
)
recombination = self.recombination_dist_to_prob(genetic_dists)
if mismatch_ratio is None:
@@ -2342,6 +2364,12 @@ def convert_inference_mutations(self, tables):
progress.update()
progress.close()

site_id = tables.sites.add_row(
self.terminal_position[0],
ancestral_state="N",
metadata=b"",
)

def restore_tree_sequence_builder(self):
tables = self.ancestors_ts_tables
if self.variant_data.sequence_length != tables.sequence_length:
@@ -2407,8 +2435,14 @@ class AncestorMatcher(Matcher):
def __init__(
self, variant_data, ancestor_data, ancestors_ts=None, time_units=None, **kwargs
):
super().__init__(variant_data, ancestor_data.sites_position[:], **kwargs)
super().__init__(
variant_data,
combined_position=ancestor_data.sites_position[:],
terminal_position=ancestor_data.terminal_position[:],
**kwargs,
)
self.ancestor_data = ancestor_data
self.terminal_position = ancestor_data.terminal_position
if time_units is None:
time_units = tskit.TIME_UNITS_UNCALIBRATED
self.time_units = time_units
@@ -2674,8 +2708,18 @@ def store_output(self):
class SampleMatcher(Matcher):
def __init__(self, variant_data, ancestors_ts, **kwargs):
self.ancestors_ts_tables = ancestors_ts.dump_tables()

ancestral_state_vals = ancestors_ts.tables.sites.ancestral_state
ancestral_state = np.char.decode(ancestral_state_vals.view("S1"), "ascii")
terminal_sites = np.where(ancestral_state == "N")[0]
terminal_position = ancestors_ts.sites_position[terminal_sites]
self.terminal_position = terminal_position

super().__init__(
variant_data, self.ancestors_ts_tables.sites.position, **kwargs
variant_data,
combined_position=self.ancestors_ts_tables.sites.position,
terminal_position=terminal_position,
**kwargs,
)
self.restore_tree_sequence_builder()
# Map from input sample indexes (IDs in the SampleData file) to the
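The two hunks above form a round trip: `convert_inference_mutations` writes the terminal site into the ancestors tree sequence with ancestral state "N", and `SampleMatcher` later recovers its position by decoding the site table's `ancestral_state` column. Below is a self-contained sketch of that decode step using plain tskit table calls and illustrative positions; it assumes every ancestral state is a single character, as is the case for the sites written here.

```python
# Round-trip sketch: mark a terminal site with ancestral_state "N", then
# recover its position from the raw ragged-column storage.
import numpy as np
import tskit

tables = tskit.TableCollection(sequence_length=1000)
for pos, state in [(10.0, "A"), (55.0, "C"), (341.0, "N")]:
    tables.sites.add_row(position=pos, ancestral_state=state)

# ancestral_state is stored as a flat byte column; with one character per
# site it decodes to one letter per site.
ancestral_state = np.char.decode(tables.sites.ancestral_state.view("S1"), "ascii")
terminal_sites = np.where(ancestral_state == "N")[0]
terminal_position = tables.sites.position[terminal_sites]
assert terminal_position.tolist() == [341.0]
```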