@@ -169,95 +169,67 @@ def windowed_tree_stat(ts, stat, windows, span_normalise=True):
169
169
170
170
171
171
def naive_branch_general_stat (
172
- ts , w , f , windows = None , time_windows = None , polarised = False , span_normalise = True
172
+ ts , w , f , windows = None , polarised = False , span_normalise = True
173
173
):
174
174
# NOTE: does not behave correctly for unpolarised stats
175
175
# with non-ancestral material.
176
176
if windows is None :
177
177
windows = [0.0 , ts .sequence_length ]
178
- drop_time_windows = time_windows is None
179
- if time_windows is None :
180
- time_windows = [0.0 , np .inf ]
181
- else :
182
- if time_windows [0 ] != 0 :
183
- time_windows = [0 ] + time_windows
184
178
n , k = w .shape
185
- tw = len (time_windows ) - 1
186
179
# hack to determine m
187
180
m = len (f (w [0 ]))
188
181
total = np .sum (w , axis = 0 )
189
182
190
- sigma = np .zeros ((ts .num_trees , tw , m ))
191
- for j , upper_time in enumerate (time_windows [1 :]):
192
- if np .isfinite (upper_time ):
193
- decap_ts = ts .decapitate (upper_time )
183
+ sigma = np .zeros ((ts .num_trees , m ))
184
+ for tree in ts .trees ():
185
+ x = np .zeros ((ts .num_nodes , k ))
186
+ x [ts .samples ()] = w
187
+ for u in tree .nodes (order = "postorder" ):
188
+ for v in tree .children (u ):
189
+ x [u ] += x [v ]
190
+ if polarised :
191
+ s = sum (tree .branch_length (u ) * f (x [u ]) for u in tree .nodes ())
194
192
else :
195
- decap_ts = ts
196
- assert np .all (list (ts .samples ()) == list (decap_ts .samples ()))
197
- for tree in decap_ts .trees ():
198
- x = np .zeros ((decap_ts .num_nodes , k ))
199
- x [decap_ts .samples ()] = w
200
- for u in tree .nodes (order = "postorder" ):
201
- for v in tree .children (u ):
202
- x [u ] += x [v ]
203
- if polarised :
204
- s = sum (tree .branch_length (u ) * f (x [u ]) for u in tree .nodes ())
205
- else :
206
- s = sum (
207
- tree .branch_length (u ) * (f (x [u ]) + f (total - x [u ]))
208
- for u in tree .nodes ()
209
- )
210
- sigma [tree .index , j , :] = s * tree .span
211
- for j in range (1 , tw ):
212
- sigma [:, j , :] = sigma [:, j , :] - sigma [:, j - 1 , :]
193
+ s = sum (
194
+ tree .branch_length (u ) * (f (x [u ]) + f (total - x [u ]))
195
+ for u in tree .nodes ()
196
+ )
197
+ sigma [tree .index ] = s * tree .span
213
198
if isinstance (windows , str ) and windows == "trees" :
214
199
# need to average across the windows
215
200
if span_normalise :
216
201
for j , tree in enumerate (ts .trees ()):
217
202
sigma [j ] /= tree .span
218
- out = sigma
203
+ return sigma
219
204
else :
220
- out = windowed_tree_stat (ts , sigma , windows , span_normalise = span_normalise )
221
- if drop_time_windows :
222
- assert out .ndim == 3
223
- out = out [:, 0 ]
224
- return out
205
+ return windowed_tree_stat (ts , sigma , windows , span_normalise = span_normalise )
225
206
226
207
227
208
def branch_general_stat (
228
- ts ,
229
- sample_weights ,
230
- summary_func ,
231
- windows = None ,
232
- time_windows = None ,
233
- polarised = False ,
234
- span_normalise = True ,
209
+ ts , sample_weights , summary_func , windows = None , polarised = False , span_normalise = True
235
210
):
236
211
"""
237
212
Efficient implementation of the algorithm used as the basis for the
238
213
underlying C version.
239
214
"""
240
215
n , state_dim = sample_weights .shape
241
216
windows = ts .parse_windows (windows )
242
- drop_time_windows = time_windows is None
243
- time_windows = ts .parse_time_windows (time_windows )
244
217
num_windows = windows .shape [0 ] - 1
245
- num_time_windows = time_windows .shape [0 ] - 1
246
218
247
219
# Determine result_dim
248
220
result_dim = len (summary_func (sample_weights [0 ]))
249
- result = np .zeros ((num_windows , num_time_windows , result_dim ))
221
+ result = np .zeros ((num_windows , result_dim ))
250
222
state = np .zeros ((ts .num_nodes , state_dim ))
251
223
state [ts .samples ()] = sample_weights
252
224
total_weight = np .sum (sample_weights , axis = 0 )
253
225
254
226
time = ts .tables .nodes .time
255
227
parent = np .zeros (ts .num_nodes , dtype = np .int32 ) - 1
256
- branch_length = np .zeros (( num_time_windows , ts .num_nodes ) )
228
+ branch_length = np .zeros (ts .num_nodes )
257
229
# The value of summary_func(u) for every node.
258
230
summary = np .zeros ((ts .num_nodes , result_dim ))
259
231
# The result for the current tree *not* weighted by span.
260
- running_sum = np .zeros (( num_time_windows , result_dim ) )
232
+ running_sum = np .zeros (result_dim )
261
233
262
234
def polarised_summary (u ):
263
235
s = summary_func (state [u ])
@@ -269,48 +241,31 @@ def polarised_summary(u):
269
241
summary [u ] = polarised_summary (u )
270
242
271
243
window_index = 0
272
-
273
- def update_sum (u , sign ):
274
- time_window_index = 0
275
- if parent [u ] != - 1 :
276
- while (
277
- time_window_index < num_time_windows
278
- and time_windows [time_window_index ] < time [parent [u ]]
279
- ):
280
- running_sum [time_window_index ] += sign * (
281
- branch_length [time_window_index , u ] * summary [u ]
282
- )
283
- time_window_index += 1
284
-
285
244
for (t_left , t_right ), edges_out , edges_in in ts .edge_diffs ():
286
245
for edge in edges_out :
287
246
u = edge .child
288
- update_sum ( u , sign = - 1 )
247
+ running_sum -= branch_length [ u ] * summary [ u ]
289
248
u = edge .parent
290
249
while u != - 1 :
291
- update_sum ( u , sign = - 1 )
250
+ running_sum -= branch_length [ u ] * summary [ u ]
292
251
state [u ] -= state [edge .child ]
293
252
summary [u ] = polarised_summary (u )
294
- update_sum ( u , sign = + 1 )
253
+ running_sum += branch_length [ u ] * summary [ u ]
295
254
u = parent [u ]
296
255
parent [edge .child ] = - 1
297
- for tw in range (num_time_windows ):
298
- branch_length [tw , edge .child ] = 0
256
+ branch_length [edge .child ] = 0
299
257
300
258
for edge in edges_in :
301
259
parent [edge .child ] = edge .parent
302
- for tw in range (num_time_windows ):
303
- branch_length [tw , edge .child ] = min (
304
- time [edge .parent ], time_windows [tw + 1 ]
305
- ) - max (time [edge .child ], time_windows [tw ])
260
+ branch_length [edge .child ] = time [edge .parent ] - time [edge .child ]
306
261
u = edge .child
307
- update_sum ( u , sign = + 1 )
262
+ running_sum += branch_length [ u ] * summary [ u ]
308
263
u = edge .parent
309
264
while u != - 1 :
310
- update_sum ( u , sign = - 1 )
265
+ running_sum -= branch_length [ u ] * summary [ u ]
311
266
state [u ] += state [edge .child ]
312
267
summary [u ] = polarised_summary (u )
313
- update_sum ( u , sign = + 1 )
268
+ running_sum += branch_length [ u ] * summary [ u ]
314
269
u = parent [u ]
315
270
316
271
# Update the windows
@@ -322,22 +277,16 @@ def update_sum(u, sign):
322
277
right = min (t_right , w_right )
323
278
span = right - left
324
279
assert span > 0
325
- time_window_index = 0
326
- while time_window_index < num_time_windows :
327
- result [window_index , time_window_index ] += (
328
- running_sum [time_window_index ] * span
329
- )
330
- time_window_index += 1
280
+ result [window_index ] += running_sum * span
331
281
if w_right <= t_right :
332
282
window_index += 1
333
283
else :
334
284
# This interval crosses a tree boundary, so we update it again in the
335
285
# for the next tree
336
286
break
287
+
288
+ # print("window_index:", window_index, windows.shape)
337
289
assert window_index == windows .shape [0 ] - 1
338
- if drop_time_windows :
339
- assert result .ndim == 3
340
- result = result [:, 0 ]
341
290
if span_normalise :
342
291
for j in range (num_windows ):
343
292
result [j ] /= windows [j + 1 ] - windows [j ]
@@ -397,13 +346,7 @@ def naive_site_general_stat(
397
346
398
347
399
348
def site_general_stat (
400
- ts ,
401
- sample_weights ,
402
- summary_func ,
403
- windows = None ,
404
- time_windows = None ,
405
- polarised = False ,
406
- span_normalise = True ,
349
+ ts , sample_weights , summary_func , windows = None , polarised = False , span_normalise = True
407
350
):
408
351
"""
409
352
Problem: 'sites' is different that the other windowing options
@@ -506,19 +449,12 @@ def naive_node_general_stat(
506
449
507
450
508
451
def node_general_stat (
509
- ts ,
510
- sample_weights ,
511
- summary_func ,
512
- windows = None ,
513
- time_windows = None ,
514
- polarised = False ,
515
- span_normalise = True ,
452
+ ts , sample_weights , summary_func , windows = None , polarised = False , span_normalise = True
516
453
):
517
454
"""
518
455
Efficient implementation of the algorithm used as the basis for the
519
456
underlying C version.
520
457
"""
521
- assert time_windows is None
522
458
n , state_dim = sample_weights .shape
523
459
windows = ts .parse_windows (windows )
524
460
num_windows = windows .shape [0 ] - 1
@@ -588,7 +524,6 @@ def general_stat(
588
524
sample_weights ,
589
525
summary_func ,
590
526
windows = None ,
591
- time_windows = None ,
592
527
polarised = False ,
593
528
mode = "site" ,
594
529
span_normalise = True ,
@@ -607,7 +542,6 @@ def general_stat(
607
542
sample_weights ,
608
543
summary_func ,
609
544
windows = windows ,
610
- time_windows = time_windows ,
611
545
polarised = polarised ,
612
546
span_normalise = span_normalise ,
613
547
)
@@ -3735,9 +3669,7 @@ class TestSitef3(Testf3, MutatedTopologyExamplesMixin):
3735
3669
############################################
3736
3670
3737
3671
3738
- def branch_f4 (
3739
- ts , sample_sets , indexes , windows = None , time_windows = None , span_normalise = True
3740
- ):
3672
+ def branch_f4 (ts , sample_sets , indexes , windows = None , span_normalise = True ):
3741
3673
windows = ts .parse_windows (windows )
3742
3674
out = np .zeros ((len (windows ) - 1 , len (indexes )))
3743
3675
for j in range (len (windows ) - 1 ):
@@ -3877,15 +3809,7 @@ def node_f4(ts, sample_sets, indexes, windows=None, span_normalise=True):
3877
3809
return out
3878
3810
3879
3811
3880
- def f4 (
3881
- ts ,
3882
- sample_sets ,
3883
- indexes = None ,
3884
- windows = None ,
3885
- time_windows = None ,
3886
- mode = "site" ,
3887
- span_normalise = True ,
3888
- ):
3812
+ def f4 (ts , sample_sets , indexes = None , windows = None , mode = "site" , span_normalise = True ):
3889
3813
"""
3890
3814
Patterson's f4 statistic definitions.
3891
3815
"""
0 commit comments