@@ -145,67 +145,95 @@ def windowed_tree_stat(ts, stat, windows, span_normalise=True):
145
145
146
146
147
147
def naive_branch_general_stat(
    ts, w, f, windows=None, time_windows=None, polarised=False, span_normalise=True
):
    """
    Naive reference implementation of a branch-mode general statistic,
    optionally split into time windows.

    For each upper time-window boundary the tree sequence is decapitated at
    that time (an infinite boundary uses ``ts`` unchanged). In every tree of
    the result the sample weights are summed up the tree, and
    ``branch_length(u) * f(x[u])`` is accumulated over all nodes (plus
    ``f(total - x[u])`` when ``polarised`` is false) and multiplied by the
    tree span. These cumulative "everything below this time" values are then
    differenced so that each time window holds only its own contribution.

    :param ts: The tree sequence.
    :param w: Two-dimensional array of weights with one row per sample.
    :param f: Summary function applied to a weight vector; the length of
        ``f(w[0])`` fixes the output dimension.
    :param windows: ``None`` for a single window spanning the sequence, the
        string ``"trees"`` for one row per tree, or anything accepted by
        ``windowed_tree_stat``.
    :param time_windows: ``None`` for a single window ``[0, inf)``, otherwise
        a sequence of time breakpoints (presumably increasing - TODO
        confirm). A leading 0 is prepended when the first breakpoint is
        not 0.
    :param polarised: If true, only ``f(x[u])`` is counted for each branch.
    :param span_normalise: For ``windows="trees"``, divide each tree's row by
        its span; otherwise forwarded to ``windowed_tree_stat``.
    :return: The statistic, with the time-window axis removed when
        ``time_windows`` was not given.
    """
    # NOTE: does not behave correctly for unpolarised stats
    # with non-ancestral material.
    if windows is None:
        windows = [0.0, ts.sequence_length]
    drop_time_windows = time_windows is None
    if time_windows is None:
        time_windows = [0.0, np.inf]
    elif time_windows[0] != 0:
        # Unpack rather than use ``[0] + time_windows``: list concatenation
        # fails for tuples and turns into an elementwise add for arrays.
        time_windows = [0, *time_windows]
    _, k = w.shape
    tw = len(time_windows) - 1
    # hack to determine m
    m = len(f(w[0]))
    total = np.sum(w, axis=0)

    # sigma[i, j] first holds the statistic of tree i restricted to branch
    # material below time_windows[j + 1] (cumulative in j).
    sigma = np.zeros((ts.num_trees, tw, m))
    for j, upper_time in enumerate(time_windows[1:]):
        if np.isfinite(upper_time):
            decap_ts = ts.decapitate(upper_time)
        else:
            decap_ts = ts
        assert np.all(list(ts.samples()) == list(decap_ts.samples()))
        # NOTE(review): rows are addressed by decap_ts tree indexes, which
        # assumes decapitation keeps the same trees in the same order as
        # ts - TODO confirm.
        for tree in decap_ts.trees():
            x = np.zeros((decap_ts.num_nodes, k))
            x[decap_ts.samples()] = w
            for u in tree.nodes(order="postorder"):
                for v in tree.children(u):
                    x[u] += x[v]
            if polarised:
                s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes())
            else:
                s = sum(
                    tree.branch_length(u) * (f(x[u]) + f(total - x[u]))
                    for u in tree.nodes()
                )
            sigma[tree.index, j, :] = s * tree.span
    # Turn cumulative values into per-window values. np.diff evaluates every
    # difference from the untouched cumulative array before assignment; a
    # forward in-place loop would subtract already-differenced columns and
    # corrupt every window from the third one onwards.
    sigma[:, 1:, :] = np.diff(sigma, axis=1)
    if isinstance(windows, str) and windows == "trees":
        # need to average across the windows
        if span_normalise:
            for j, tree in enumerate(ts.trees()):
                sigma[j] /= tree.span
        out = sigma
    else:
        out = windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise)
    if drop_time_windows:
        # The caller did not ask for time windows: hide the length-1 axis.
        assert out.ndim == 3
        out = out[:, 0]
    return out
182
201
183
202
184
203
def branch_general_stat (
185
- ts , sample_weights , summary_func , windows = None , polarised = False , span_normalise = True
204
+ ts ,
205
+ sample_weights ,
206
+ summary_func ,
207
+ windows = None ,
208
+ time_windows = None ,
209
+ polarised = False ,
210
+ span_normalise = True ,
186
211
):
187
212
"""
188
213
Efficient implementation of the algorithm used as the basis for the
189
214
underlying C version.
190
215
"""
191
216
n , state_dim = sample_weights .shape
192
217
windows = ts .parse_windows (windows )
218
+ drop_time_windows = time_windows is None
219
+ time_windows = ts .parse_time_windows (time_windows )
193
220
num_windows = windows .shape [0 ] - 1
221
+ num_time_windows = time_windows .shape [0 ] - 1
194
222
195
223
# Determine result_dim
196
224
result_dim = len (summary_func (sample_weights [0 ]))
197
- result = np .zeros ((num_windows , result_dim ))
225
+ result = np .zeros ((num_windows , num_time_windows , result_dim ))
198
226
state = np .zeros ((ts .num_nodes , state_dim ))
199
227
state [ts .samples ()] = sample_weights
200
228
total_weight = np .sum (sample_weights , axis = 0 )
201
229
202
230
time = ts .tables .nodes .time
203
231
parent = np .zeros (ts .num_nodes , dtype = np .int32 ) - 1
204
- branch_length = np .zeros (ts .num_nodes )
232
+ branch_length = np .zeros (( num_time_windows , ts .num_nodes ) )
205
233
# The value of summary_func(u) for every node.
206
234
summary = np .zeros ((ts .num_nodes , result_dim ))
207
235
# The result for the current tree *not* weighted by span.
208
- running_sum = np .zeros (result_dim )
236
+ running_sum = np .zeros (( num_time_windows , result_dim ) )
209
237
210
238
def polarised_summary (u ):
211
239
s = summary_func (state [u ])
@@ -217,31 +245,48 @@ def polarised_summary(u):
217
245
summary [u ] = polarised_summary (u )
218
246
219
247
window_index = 0
248
+
249
+ def update_sum (u , sign ):
250
+ time_window_index = 0
251
+ if parent [u ] != - 1 :
252
+ while (
253
+ time_window_index < num_time_windows
254
+ and time_windows [time_window_index ] < time [parent [u ]]
255
+ ):
256
+ running_sum [time_window_index ] += sign * (
257
+ branch_length [time_window_index , u ] * summary [u ]
258
+ )
259
+ time_window_index += 1
260
+
220
261
for (t_left , t_right ), edges_out , edges_in in ts .edge_diffs ():
221
262
for edge in edges_out :
222
263
u = edge .child
223
- running_sum -= branch_length [ u ] * summary [ u ]
264
+ update_sum ( u , sign = - 1 )
224
265
u = edge .parent
225
266
while u != - 1 :
226
- running_sum -= branch_length [ u ] * summary [ u ]
267
+ update_sum ( u , sign = - 1 )
227
268
state [u ] -= state [edge .child ]
228
269
summary [u ] = polarised_summary (u )
229
- running_sum += branch_length [ u ] * summary [ u ]
270
+ update_sum ( u , sign = + 1 )
230
271
u = parent [u ]
231
272
parent [edge .child ] = - 1
232
- branch_length [edge .child ] = 0
273
+ for tw in range (num_time_windows ):
274
+ branch_length [tw , edge .child ] = 0
233
275
234
276
for edge in edges_in :
235
277
parent [edge .child ] = edge .parent
236
- branch_length [edge .child ] = time [edge .parent ] - time [edge .child ]
278
+ for tw in range (num_time_windows ):
279
+ branch_length [tw , edge .child ] = min (
280
+ time [edge .parent ], time_windows [tw + 1 ]
281
+ ) - max (time [edge .child ], time_windows [tw ])
237
282
u = edge .child
238
- running_sum += branch_length [ u ] * summary [ u ]
283
+ update_sum ( u , sign = + 1 )
239
284
u = edge .parent
240
285
while u != - 1 :
241
- running_sum -= branch_length [ u ] * summary [ u ]
286
+ update_sum ( u , sign = - 1 )
242
287
state [u ] += state [edge .child ]
243
288
summary [u ] = polarised_summary (u )
244
- running_sum += branch_length [ u ] * summary [ u ]
289
+ update_sum ( u , sign = + 1 )
245
290
u = parent [u ]
246
291
247
292
# Update the windows
@@ -253,7 +298,12 @@ def polarised_summary(u):
253
298
right = min (t_right , w_right )
254
299
span = right - left
255
300
assert span > 0
256
- result [window_index ] += running_sum * span
301
+ time_window_index = 0
302
+ while time_window_index < num_time_windows :
303
+ result [window_index , time_window_index ] += (
304
+ running_sum [time_window_index ] * span
305
+ )
306
+ time_window_index += 1
257
307
if w_right <= t_right :
258
308
window_index += 1
259
309
else :
@@ -263,6 +313,9 @@ def polarised_summary(u):
263
313
264
314
# print("window_index:", window_index, windows.shape)
265
315
assert window_index == windows .shape [0 ] - 1
316
+ if drop_time_windows :
317
+ assert result .ndim == 3
318
+ result = result [:, 0 ]
266
319
if span_normalise :
267
320
for j in range (num_windows ):
268
321
result [j ] /= windows [j + 1 ] - windows [j ]
@@ -322,13 +375,20 @@ def naive_site_general_stat(
322
375
323
376
324
377
def site_general_stat (
325
- ts , sample_weights , summary_func , windows = None , polarised = False , span_normalise = True
378
+ ts ,
379
+ sample_weights ,
380
+ summary_func ,
381
+ windows = None ,
382
+ time_windows = None ,
383
+ polarised = False ,
384
+ span_normalise = True ,
326
385
):
327
386
"""
328
387
Problem: 'sites' is different from the other windowing options
329
388
because if we output by site we don't want to normalize by length of the window.
330
389
Solution: we pass an argument "normalize", to the windowing function.
331
390
"""
391
+ assert time_windows is None
332
392
windows = ts .parse_windows (windows )
333
393
num_windows = windows .shape [0 ] - 1
334
394
n , state_dim = sample_weights .shape
@@ -425,12 +485,19 @@ def naive_node_general_stat(
425
485
426
486
427
487
def node_general_stat (
428
- ts , sample_weights , summary_func , windows = None , polarised = False , span_normalise = True
488
+ ts ,
489
+ sample_weights ,
490
+ summary_func ,
491
+ windows = None ,
492
+ time_windows = None ,
493
+ polarised = False ,
494
+ span_normalise = True ,
429
495
):
430
496
"""
431
497
Efficient implementation of the algorithm used as the basis for the
432
498
underlying C version.
433
499
"""
500
+ assert time_windows is None
434
501
n , state_dim = sample_weights .shape
435
502
windows = ts .parse_windows (windows )
436
503
num_windows = windows .shape [0 ] - 1
@@ -500,6 +567,7 @@ def general_stat(
500
567
sample_weights ,
501
568
summary_func ,
502
569
windows = None ,
570
+ time_windows = None ,
503
571
polarised = False ,
504
572
mode = "site" ,
505
573
span_normalise = True ,
@@ -518,6 +586,7 @@ def general_stat(
518
586
sample_weights ,
519
587
summary_func ,
520
588
windows = windows ,
589
+ time_windows = time_windows ,
521
590
polarised = polarised ,
522
591
span_normalise = span_normalise ,
523
592
)
@@ -3534,7 +3603,9 @@ class TestSitef3(Testf3, MutatedTopologyExamplesMixin):
3534
3603
############################################
3535
3604
3536
3605
3537
- def branch_f4 (ts , sample_sets , indexes , windows = None , span_normalise = True ):
3606
+ def branch_f4 (
3607
+ ts , sample_sets , indexes , windows = None , time_windows = None , span_normalise = True
3608
+ ):
3538
3609
windows = ts .parse_windows (windows )
3539
3610
out = np .zeros ((len (windows ) - 1 , len (indexes )))
3540
3611
for j in range (len (windows ) - 1 ):
@@ -3674,7 +3745,15 @@ def node_f4(ts, sample_sets, indexes, windows=None, span_normalise=True):
3674
3745
return out
3675
3746
3676
3747
3677
- def f4 (ts , sample_sets , indexes = None , windows = None , mode = "site" , span_normalise = True ):
3748
+ def f4 (
3749
+ ts ,
3750
+ sample_sets ,
3751
+ indexes = None ,
3752
+ windows = None ,
3753
+ time_windows = None ,
3754
+ mode = "site" ,
3755
+ span_normalise = True ,
3756
+ ):
3678
3757
"""
3679
3758
Patterson's f4 statistic definitions.
3680
3759
"""
0 commit comments