Skip to content

Commit 51364ad

Browse files
committed
python implementation in testing code
1 parent 1d8f453 commit 51364ad

File tree

1 file changed

+113
-34
lines changed

1 file changed

+113
-34
lines changed

python/tests/test_tree_stats.py

Lines changed: 113 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -145,67 +145,95 @@ def windowed_tree_stat(ts, stat, windows, span_normalise=True):
145145

146146

147147
def naive_branch_general_stat(
    ts, w, f, windows=None, time_windows=None, polarised=False, span_normalise=True
):
    """
    Naive reference implementation of the branch-mode general statistic,
    optionally broken down by time window.

    For every upper time-window boundary the tree sequence is decapitated at
    that time and the statistic is summed over all branches of every tree,
    which gives the value accumulated between time 0 and that boundary.
    Consecutive cumulative values are then differenced to obtain the value
    within each time window.

    :param ts: The tree sequence to compute the statistic on.
    :param w: Two-dimensional array of sample weights, one row per sample
        (in the order of ``ts.samples()``) and one column per weight.
    :param f: Summary function mapping a weight vector to a sequence of
        values; the result dimension is ``len(f(w[0]))``.
    :param windows: Genomic windows passed on to ``windowed_tree_stat``, or
        the string ``"trees"`` for one output row per tree. Defaults to a
        single window spanning the whole sequence.
    :param time_windows: Increasing time breakpoints. If the first one is
        not 0, a leading 0 is prepended. ``None`` means a single window
        ``[0, inf)`` and the time-window axis is removed from the result.
    :param polarised: If False, ``f(total - x)`` is added to ``f(x)`` for
        every branch.
    :param span_normalise: Whether to normalise by the span of each window.
    :return: Array indexed by (window, time window, result dimension), or by
        (window, result dimension) when ``time_windows`` is None.
    """
    # NOTE: does not behave correctly for unpolarised stats
    # with non-ancestral material.
    if windows is None:
        windows = [0.0, ts.sequence_length]
    drop_time_windows = time_windows is None
    if time_windows is None:
        time_windows = [0.0, np.inf]
    elif time_windows[0] != 0:
        # Unpack instead of ``[0] + time_windows``: with a numpy array the
        # latter adds 0 elementwise and does not prepend the breakpoint.
        time_windows = [0, *time_windows]
    _, k = w.shape
    tw = len(time_windows) - 1
    # hack to determine m
    m = len(f(w[0]))
    total = np.sum(w, axis=0)

    sigma = np.zeros((ts.num_trees, tw, m))
    for j, upper_time in enumerate(time_windows[1:]):
        # Removing everything older than upper_time leaves exactly the branch
        # segments lying in [0, upper_time).
        decap_ts = ts.decapitate(upper_time) if np.isfinite(upper_time) else ts
        assert np.array_equal(ts.samples(), decap_ts.samples())
        # NOTE(review): rows of sigma are indexed by the decapitated tree's
        # index, which assumes decapitation keeps the same trees as ts --
        # TODO confirm.
        for tree in decap_ts.trees():
            x = np.zeros((decap_ts.num_nodes, k))
            x[decap_ts.samples()] = w
            # Accumulate the weights of all samples below each node.
            for u in tree.nodes(order="postorder"):
                for v in tree.children(u):
                    x[u] += x[v]
            if polarised:
                s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes())
            else:
                s = sum(
                    tree.branch_length(u) * (f(x[u]) + f(total - x[u]))
                    for u in tree.nodes()
                )
            sigma[tree.index, j, :] = s * tree.span
    # sigma[:, j] holds the cumulative value up to time_windows[j + 1].
    # np.diff builds the differences from the untouched cumulative values;
    # differencing forwards in place would subtract already-differenced
    # columns and give wrong results for three or more time windows.
    sigma[:, 1:, :] = np.diff(sigma, axis=1)
    if isinstance(windows, str) and windows == "trees":
        # need to average across the windows
        if span_normalise:
            for j, tree in enumerate(ts.trees()):
                sigma[j] /= tree.span
        out = sigma
    else:
        out = windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise)
    if drop_time_windows:
        assert out.ndim == 3
        out = out[:, 0]
    return out
182201

183202

184203
def branch_general_stat(
185-
ts, sample_weights, summary_func, windows=None, polarised=False, span_normalise=True
204+
ts,
205+
sample_weights,
206+
summary_func,
207+
windows=None,
208+
time_windows=None,
209+
polarised=False,
210+
span_normalise=True,
186211
):
187212
"""
188213
Efficient implementation of the algorithm used as the basis for the
189214
underlying C version.
190215
"""
191216
n, state_dim = sample_weights.shape
192217
windows = ts.parse_windows(windows)
218+
drop_time_windows = time_windows is None
219+
time_windows = ts.parse_time_windows(time_windows)
193220
num_windows = windows.shape[0] - 1
221+
num_time_windows = time_windows.shape[0] - 1
194222

195223
# Determine result_dim
196224
result_dim = len(summary_func(sample_weights[0]))
197-
result = np.zeros((num_windows, result_dim))
225+
result = np.zeros((num_windows, num_time_windows, result_dim))
198226
state = np.zeros((ts.num_nodes, state_dim))
199227
state[ts.samples()] = sample_weights
200228
total_weight = np.sum(sample_weights, axis=0)
201229

202230
time = ts.tables.nodes.time
203231
parent = np.zeros(ts.num_nodes, dtype=np.int32) - 1
204-
branch_length = np.zeros(ts.num_nodes)
232+
branch_length = np.zeros((num_time_windows, ts.num_nodes))
205233
# The value of summary_func(u) for every node.
206234
summary = np.zeros((ts.num_nodes, result_dim))
207235
# The result for the current tree *not* weighted by span.
208-
running_sum = np.zeros(result_dim)
236+
running_sum = np.zeros((num_time_windows, result_dim))
209237

210238
def polarised_summary(u):
211239
s = summary_func(state[u])
@@ -217,31 +245,48 @@ def polarised_summary(u):
217245
summary[u] = polarised_summary(u)
218246

219247
window_index = 0
248+
249+
def update_sum(u, sign):
    """
    Add (``sign=+1``) or remove (``sign=-1``) the contribution of the branch
    above node ``u`` to ``running_sum``, one time window at a time.

    Nodes without a parent (``parent[u] == -1``) contribute nothing.
    """
    if parent[u] == -1:
        return
    parent_time = time[parent[u]]
    for tw_index in range(num_time_windows):
        if time_windows[tw_index] >= parent_time:
            # Same stopping rule as before: quit at the first window whose
            # lower bound is not below the parent's time.
            break
        # branch_length[tw, u] is min(parent, upper) - max(child, lower),
        # which is negative when the window lies entirely below the child.
        # Such a window does not overlap the branch, so clamp at zero
        # instead of adding a negative length to the running sum.
        overlap = max(branch_length[tw_index, u], 0.0)
        running_sum[tw_index] += sign * (overlap * summary[u])
260+
220261
for (t_left, t_right), edges_out, edges_in in ts.edge_diffs():
221262
for edge in edges_out:
222263
u = edge.child
223-
running_sum -= branch_length[u] * summary[u]
264+
update_sum(u, sign=-1)
224265
u = edge.parent
225266
while u != -1:
226-
running_sum -= branch_length[u] * summary[u]
267+
update_sum(u, sign=-1)
227268
state[u] -= state[edge.child]
228269
summary[u] = polarised_summary(u)
229-
running_sum += branch_length[u] * summary[u]
270+
update_sum(u, sign=+1)
230271
u = parent[u]
231272
parent[edge.child] = -1
232-
branch_length[edge.child] = 0
273+
for tw in range(num_time_windows):
274+
branch_length[tw, edge.child] = 0
233275

234276
for edge in edges_in:
235277
parent[edge.child] = edge.parent
236-
branch_length[edge.child] = time[edge.parent] - time[edge.child]
278+
for tw in range(num_time_windows):
279+
branch_length[tw, edge.child] = min(
280+
time[edge.parent], time_windows[tw + 1]
281+
) - max(time[edge.child], time_windows[tw])
237282
u = edge.child
238-
running_sum += branch_length[u] * summary[u]
283+
update_sum(u, sign=+1)
239284
u = edge.parent
240285
while u != -1:
241-
running_sum -= branch_length[u] * summary[u]
286+
update_sum(u, sign=-1)
242287
state[u] += state[edge.child]
243288
summary[u] = polarised_summary(u)
244-
running_sum += branch_length[u] * summary[u]
289+
update_sum(u, sign=+1)
245290
u = parent[u]
246291

247292
# Update the windows
@@ -253,7 +298,12 @@ def polarised_summary(u):
253298
right = min(t_right, w_right)
254299
span = right - left
255300
assert span > 0
256-
result[window_index] += running_sum * span
301+
time_window_index = 0
302+
while time_window_index < num_time_windows:
303+
result[window_index, time_window_index] += (
304+
running_sum[time_window_index] * span
305+
)
306+
time_window_index += 1
257307
if w_right <= t_right:
258308
window_index += 1
259309
else:
@@ -263,6 +313,9 @@ def polarised_summary(u):
263313

264314
# print("window_index:", window_index, windows.shape)
265315
assert window_index == windows.shape[0] - 1
316+
if drop_time_windows:
317+
assert result.ndim == 3
318+
result = result[:, 0]
266319
if span_normalise:
267320
for j in range(num_windows):
268321
result[j] /= windows[j + 1] - windows[j]
@@ -322,13 +375,20 @@ def naive_site_general_stat(
322375

323376

324377
def site_general_stat(
325-
ts, sample_weights, summary_func, windows=None, polarised=False, span_normalise=True
378+
ts,
379+
sample_weights,
380+
summary_func,
381+
windows=None,
382+
time_windows=None,
383+
polarised=False,
384+
span_normalise=True,
326385
):
327386
"""
328387
Problem: 'sites' is different from the other windowing options
329388
because if we output by site we don't want to normalize by length of the window.
330389
Solution: we pass an argument "normalize", to the windowing function.
331390
"""
391+
assert time_windows is None
332392
windows = ts.parse_windows(windows)
333393
num_windows = windows.shape[0] - 1
334394
n, state_dim = sample_weights.shape
@@ -425,12 +485,19 @@ def naive_node_general_stat(
425485

426486

427487
def node_general_stat(
428-
ts, sample_weights, summary_func, windows=None, polarised=False, span_normalise=True
488+
ts,
489+
sample_weights,
490+
summary_func,
491+
windows=None,
492+
time_windows=None,
493+
polarised=False,
494+
span_normalise=True,
429495
):
430496
"""
431497
Efficient implementation of the algorithm used as the basis for the
432498
underlying C version.
433499
"""
500+
assert time_windows is None
434501
n, state_dim = sample_weights.shape
435502
windows = ts.parse_windows(windows)
436503
num_windows = windows.shape[0] - 1
@@ -500,6 +567,7 @@ def general_stat(
500567
sample_weights,
501568
summary_func,
502569
windows=None,
570+
time_windows=None,
503571
polarised=False,
504572
mode="site",
505573
span_normalise=True,
@@ -518,6 +586,7 @@ def general_stat(
518586
sample_weights,
519587
summary_func,
520588
windows=windows,
589+
time_windows=time_windows,
521590
polarised=polarised,
522591
span_normalise=span_normalise,
523592
)
@@ -3534,7 +3603,9 @@ class TestSitef3(Testf3, MutatedTopologyExamplesMixin):
35343603
############################################
35353604

35363605

3537-
def branch_f4(ts, sample_sets, indexes, windows=None, span_normalise=True):
3606+
def branch_f4(
3607+
ts, sample_sets, indexes, windows=None, time_windows=None, span_normalise=True
3608+
):
35383609
windows = ts.parse_windows(windows)
35393610
out = np.zeros((len(windows) - 1, len(indexes)))
35403611
for j in range(len(windows) - 1):
@@ -3674,7 +3745,15 @@ def node_f4(ts, sample_sets, indexes, windows=None, span_normalise=True):
36743745
return out
36753746

36763747

3677-
def f4(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise=True):
3748+
def f4(
3749+
ts,
3750+
sample_sets,
3751+
indexes=None,
3752+
windows=None,
3753+
time_windows=None,
3754+
mode="site",
3755+
span_normalise=True,
3756+
):
36783757
"""
36793758
Patterson's f4 statistic definitions.
36803759
"""

0 commit comments

Comments
 (0)