Skip to content

Commit 580c3e6

Browse files
committed
Add support for terminal site to generate_ancestors with PY engine
1 parent 4bf2521 commit 580c3e6

File tree

3 files changed

+64
-23
lines changed

3 files changed

+64
-23
lines changed

tsinfer/algorithm.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class Site:
5858
id = attr.ib()
5959
time = attr.ib()
6060
derived_count = attr.ib()
61+
terminal = attr.ib()
6162

6263

6364
class AncestorBuilder:
@@ -137,21 +138,23 @@ def store_site_genotypes(self, site_id, genotypes):
137138
stop = start + self.encoded_genotypes_size
138139
self.genotype_store[start:stop] = genotypes
139140

140-
def add_site(self, time, genotypes):
141+
def add_site(self, time, genotypes, terminal):
141142
"""
142143
Adds a new site at the specified ID to the builder.
143144
"""
144145
site_id = len(self.sites)
145146
derived_count = np.sum(genotypes == 1)
146-
self.store_site_genotypes(site_id, genotypes)
147-
self.sites.append(Site(site_id, time, derived_count))
148-
sites_at_fixed_timepoint = self.time_map[time]
149-
# Sites with an identical variant distribution (i.e. with the same
150-
# genotypes.tobytes() value) and at the same time, are put into the same ancestor
151-
# to which we allocate a unique ID (just use the genotypes value)
152-
ancestor_uid = tuple(genotypes)
153-
# Add each site to the list for this ancestor_uid at this timepoint
154-
sites_at_fixed_timepoint[ancestor_uid].append(site_id)
147+
self.sites.append(Site(site_id, time, derived_count, terminal))
148+
if not terminal:
149+
self.store_site_genotypes(site_id, genotypes)
150+
sites_at_fixed_timepoint = self.time_map[time]
151+
# Sites with an identical variant distribution (i.e. with the same
152+
# genotypes.tobytes() value) and at the same time, are put into the
153+
# same ancestor to which we allocate a unique ID (just use the genotypes
154+
# value)
155+
ancestor_uid = tuple(genotypes)
156+
# Add each site to the list for this ancestor_uid at this timepoint
157+
sites_at_fixed_timepoint[ancestor_uid].append(site_id)
155158

156159
def print_state(self):
157160
print("Ancestor builder")
@@ -221,6 +224,8 @@ def compute_ancestral_states(self, a, focal_site, sites):
221224
disagree = np.zeros(self.num_samples, dtype=bool)
222225

223226
for site_index in sites:
227+
if self.sites[site_index].terminal:
228+
break
224229
a[site_index] = 0
225230
last_site = site_index
226231
g_l = self.get_site_genotypes(site_index)

tsinfer/formats.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3093,7 +3093,14 @@ class AncestorData(DataContainer):
30933093
FORMAT_NAME = "tsinfer-ancestor-data"
30943094
FORMAT_VERSION = (3, 0)
30953095

3096-
def __init__(self, position, sequence_length, chunk_size_sites=None, **kwargs):
3096+
def __init__(
3097+
self,
3098+
inference_position,
3099+
terminal_position,
3100+
sequence_length,
3101+
chunk_size_sites=None,
3102+
**kwargs,
3103+
):
30973104
super().__init__(**kwargs)
30983105
self._last_time = 0
30993106
self.inference_sites_set = False
@@ -3111,15 +3118,22 @@ def __init__(self, position, sequence_length, chunk_size_sites=None, **kwargs):
31113118
self.create_dataset("sample_end", dtype=np.int32)
31123119
self.create_dataset("sample_time", dtype=np.float64)
31133120
self.create_dataset("sample_focal_sites", dtype="array:i4")
3114-
3121+
variant_position = np.concatenate([inference_position, terminal_position])
31153122
self.create_dataset(
31163123
"variant_position",
3117-
data=position,
3118-
shape=position.shape,
3124+
data=variant_position,
3125+
shape=variant_position.shape,
31193126
chunks=self._chunk_size_sites,
31203127
dtype=np.float64,
31213128
dimensions=["variants"],
31223129
)
3130+
self.create_dataset(
3131+
"terminal_position",
3132+
data=terminal_position,
3133+
shape=terminal_position.shape,
3134+
dtype=np.float64,
3135+
dimensions=["terminal_sites"],
3136+
)
31233137

31243138
# We have to include a ploidy dimension for sgkit compatibility
31253139
a = self.create_dataset(
@@ -3277,10 +3291,17 @@ def num_sites(self):
32773291
@property
32783292
def sites_position(self):
32793293
"""
3280-
The positions of the inference sites used to generate the ancestors
3294+
The positions of the inference and terminal sites used to generate the ancestors
32813295
"""
32823296
return self.data["variant_position"]
32833297

3298+
@property
3299+
def terminal_position(self):
3300+
"""
3301+
The positions of the terminal sites used to generate the ancestors
3302+
"""
3303+
return self.data["terminal_position"]
3304+
32843305
@property
32853306
def ancestors_start(self):
32863307
return self.data["sample_start"]
@@ -3314,10 +3335,10 @@ def ancestors_length(self):
33143335
"""
33153336
# Ancestor start and end are half-closed. The last site is assumed
33163337
# to cover the region up to sequence length.
3317-
pos = np.hstack([self.sites_position[:], [self.sequence_length]])
3338+
33183339
start = self.ancestors_start[:]
33193340
end = self.ancestors_end[:]
3320-
return pos[end] - pos[start]
3341+
return self.sites_position[end] - self.sites_position[start]
33213342

33223343
def insert_proxy_samples(
33233344
self,
@@ -3683,6 +3704,7 @@ def add_ancestor(self, start, end, time, focal_sites, haplotype):
36833704
if start < 0:
36843705
raise ValueError("Start must be >= 0")
36853706
if end > self.num_sites:
3707+
print(f"[INFO] {end}, {self.num_sites}")
36863708
raise ValueError("end must be <= num_sites")
36873709
if start >= end:
36883710
raise ValueError("start must be < end")

tsinfer/inference.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1807,6 +1807,8 @@ def __init__(
18071807
self.num_samples = variant_data.num_samples
18081808
self.num_threads = num_threads
18091809
self.mmap_temp_file = None
1810+
self.sites_position = None
1811+
self.terminal_position = None
18101812
mmap_fd = -1
18111813

18121814
genotype_matrix_size = self.max_sites * self.num_samples
@@ -1865,6 +1867,8 @@ def add_sites(self, exclude_positions=None):
18651867
logger.info(f"Starting addition of {self.max_sites} sites")
18661868
progress = self.progress_monitor.get("ga_add_sites", self.max_sites)
18671869
inference_site_id = []
1870+
last_position = 0
1871+
18681872
for variant in self.variant_data.variants(recode_ancestral=True):
18691873
# If there's missing data the last allele is None
18701874
num_alleles = len(variant.alleles) - int(variant.alleles[-1] is None)
@@ -1879,6 +1883,7 @@ def add_sites(self, exclude_positions=None):
18791883
and site.ancestral_state is not None
18801884
):
18811885
use_site = True
1886+
last_position = site.position
18821887
time = site.time
18831888
if tskit.is_unknown_time(time):
18841889
# Non-variable sites have no obvious freq-as-time values
@@ -1888,12 +1893,18 @@ def add_sites(self, exclude_positions=None):
18881893
if np.isnan(time):
18891894
use_site = False # Site with meaningless time value: skip inference
18901895
if use_site:
1891-
self.ancestor_builder.add_site(time, variant.genotypes)
1896+
self.ancestor_builder.add_site(time, variant.genotypes, terminal=False)
18921897
inference_site_id.append(site.id)
18931898
self.num_sites += 1
18941899
progress.update()
18951900
progress.close()
18961901
self.inference_site_ids = inference_site_id
1902+
# Add terminal site at end of sequence
1903+
zeros = np.zeros(self.num_samples, dtype=np.int8)
1904+
self.ancestor_builder.add_site(tskit.UNKNOWN_TIME, zeros, terminal=True)
1905+
self.num_sites += 1
1906+
self.terminal_position = np.array([last_position + 1], dtype=np.float64)
1907+
18971908
logger.info("Finished adding sites")
18981909

18991910
def _run_synchronous(self, progress):
@@ -2000,15 +2011,18 @@ def run(self):
20002011
if t not in self.timepoint_to_epoch:
20012012
self.timepoint_to_epoch[t] = len(self.timepoint_to_epoch) + 1
20022013
self.ancestor_data = formats.AncestorData(
2003-
self.variant_data.sites_position[:][self.inference_site_ids],
2004-
self.variant_data.sequence_length,
2014+
inference_position=self.variant_data.sites_position[:][
2015+
self.inference_site_ids
2016+
],
2017+
terminal_position=self.terminal_position,
2018+
sequence_length=self.variant_data.sequence_length,
20052019
path=self.ancestor_data_path,
20062020
**self.ancestor_data_kwargs,
20072021
)
20082022
if self.num_ancestors > 0:
20092023
logger.info(f"Starting build for {self.num_ancestors} ancestors")
20102024
progress = self.progress_monitor.get("ga_generate", self.num_ancestors)
2011-
a = np.zeros(self.num_sites, dtype=np.int8)
2025+
a = np.zeros(self.num_sites - 1, dtype=np.int8)
20122026
root_time = max(self.timepoint_to_epoch.keys())
20132027
av_timestep = root_time / len(self.timepoint_to_epoch)
20142028
root_time += av_timestep # Add a root a bit older than the oldest ancestor
@@ -2017,15 +2031,15 @@ def run(self):
20172031
# line up. It's normally removed when processing the final tree sequence.
20182032
self.ancestor_data.add_ancestor(
20192033
start=0,
2020-
end=self.num_sites,
2034+
end=self.num_sites - 1,
20212035
time=root_time + av_timestep,
20222036
focal_sites=np.array([], dtype=np.int32),
20232037
haplotype=a,
20242038
)
20252039
# This is the "ultimate ancestor" of all zeros
20262040
self.ancestor_data.add_ancestor(
20272041
start=0,
2028-
end=self.num_sites,
2042+
end=self.num_sites - 1,
20292043
time=root_time,
20302044
focal_sites=np.array([], dtype=np.int32),
20312045
haplotype=a,

0 commit comments

Comments
 (0)