Skip to content

Commit 3fa8c3f

Browse files
committed
removed general stat stuff (moved to tskit-dev#3271)
1 parent d0892aa commit 3fa8c3f

File tree

1 file changed

+36
-112
lines changed

1 file changed

+36
-112
lines changed

python/tests/test_tree_stats.py

Lines changed: 36 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -169,95 +169,67 @@ def windowed_tree_stat(ts, stat, windows, span_normalise=True):
169169

170170

171171
def naive_branch_general_stat(
172-
ts, w, f, windows=None, time_windows=None, polarised=False, span_normalise=True
172+
ts, w, f, windows=None, polarised=False, span_normalise=True
173173
):
174174
# NOTE: does not behave correctly for unpolarised stats
175175
# with non-ancestral material.
176176
if windows is None:
177177
windows = [0.0, ts.sequence_length]
178-
drop_time_windows = time_windows is None
179-
if time_windows is None:
180-
time_windows = [0.0, np.inf]
181-
else:
182-
if time_windows[0] != 0:
183-
time_windows = [0] + time_windows
184178
n, k = w.shape
185-
tw = len(time_windows) - 1
186179
# hack to determine m
187180
m = len(f(w[0]))
188181
total = np.sum(w, axis=0)
189182

190-
sigma = np.zeros((ts.num_trees, tw, m))
191-
for j, upper_time in enumerate(time_windows[1:]):
192-
if np.isfinite(upper_time):
193-
decap_ts = ts.decapitate(upper_time)
183+
sigma = np.zeros((ts.num_trees, m))
184+
for tree in ts.trees():
185+
x = np.zeros((ts.num_nodes, k))
186+
x[ts.samples()] = w
187+
for u in tree.nodes(order="postorder"):
188+
for v in tree.children(u):
189+
x[u] += x[v]
190+
if polarised:
191+
s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes())
194192
else:
195-
decap_ts = ts
196-
assert np.all(list(ts.samples()) == list(decap_ts.samples()))
197-
for tree in decap_ts.trees():
198-
x = np.zeros((decap_ts.num_nodes, k))
199-
x[decap_ts.samples()] = w
200-
for u in tree.nodes(order="postorder"):
201-
for v in tree.children(u):
202-
x[u] += x[v]
203-
if polarised:
204-
s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes())
205-
else:
206-
s = sum(
207-
tree.branch_length(u) * (f(x[u]) + f(total - x[u]))
208-
for u in tree.nodes()
209-
)
210-
sigma[tree.index, j, :] = s * tree.span
211-
for j in range(1, tw):
212-
sigma[:, j, :] = sigma[:, j, :] - sigma[:, j - 1, :]
193+
s = sum(
194+
tree.branch_length(u) * (f(x[u]) + f(total - x[u]))
195+
for u in tree.nodes()
196+
)
197+
sigma[tree.index] = s * tree.span
213198
if isinstance(windows, str) and windows == "trees":
214199
# need to average across the windows
215200
if span_normalise:
216201
for j, tree in enumerate(ts.trees()):
217202
sigma[j] /= tree.span
218-
out = sigma
203+
return sigma
219204
else:
220-
out = windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise)
221-
if drop_time_windows:
222-
assert out.ndim == 3
223-
out = out[:, 0]
224-
return out
205+
return windowed_tree_stat(ts, sigma, windows, span_normalise=span_normalise)
225206

226207

227208
def branch_general_stat(
228-
ts,
229-
sample_weights,
230-
summary_func,
231-
windows=None,
232-
time_windows=None,
233-
polarised=False,
234-
span_normalise=True,
209+
ts, sample_weights, summary_func, windows=None, polarised=False, span_normalise=True
235210
):
236211
"""
237212
Efficient implementation of the algorithm used as the basis for the
238213
underlying C version.
239214
"""
240215
n, state_dim = sample_weights.shape
241216
windows = ts.parse_windows(windows)
242-
drop_time_windows = time_windows is None
243-
time_windows = ts.parse_time_windows(time_windows)
244217
num_windows = windows.shape[0] - 1
245-
num_time_windows = time_windows.shape[0] - 1
246218

247219
# Determine result_dim
248220
result_dim = len(summary_func(sample_weights[0]))
249-
result = np.zeros((num_windows, num_time_windows, result_dim))
221+
result = np.zeros((num_windows, result_dim))
250222
state = np.zeros((ts.num_nodes, state_dim))
251223
state[ts.samples()] = sample_weights
252224
total_weight = np.sum(sample_weights, axis=0)
253225

254226
time = ts.tables.nodes.time
255227
parent = np.zeros(ts.num_nodes, dtype=np.int32) - 1
256-
branch_length = np.zeros((num_time_windows, ts.num_nodes))
228+
branch_length = np.zeros(ts.num_nodes)
257229
# The value of summary_func(u) for every node.
258230
summary = np.zeros((ts.num_nodes, result_dim))
259231
# The result for the current tree *not* weighted by span.
260-
running_sum = np.zeros((num_time_windows, result_dim))
232+
running_sum = np.zeros(result_dim)
261233

262234
def polarised_summary(u):
263235
s = summary_func(state[u])
@@ -269,48 +241,31 @@ def polarised_summary(u):
269241
summary[u] = polarised_summary(u)
270242

271243
window_index = 0
272-
273-
def update_sum(u, sign):
274-
time_window_index = 0
275-
if parent[u] != -1:
276-
while (
277-
time_window_index < num_time_windows
278-
and time_windows[time_window_index] < time[parent[u]]
279-
):
280-
running_sum[time_window_index] += sign * (
281-
branch_length[time_window_index, u] * summary[u]
282-
)
283-
time_window_index += 1
284-
285244
for (t_left, t_right), edges_out, edges_in in ts.edge_diffs():
286245
for edge in edges_out:
287246
u = edge.child
288-
update_sum(u, sign=-1)
247+
running_sum -= branch_length[u] * summary[u]
289248
u = edge.parent
290249
while u != -1:
291-
update_sum(u, sign=-1)
250+
running_sum -= branch_length[u] * summary[u]
292251
state[u] -= state[edge.child]
293252
summary[u] = polarised_summary(u)
294-
update_sum(u, sign=+1)
253+
running_sum += branch_length[u] * summary[u]
295254
u = parent[u]
296255
parent[edge.child] = -1
297-
for tw in range(num_time_windows):
298-
branch_length[tw, edge.child] = 0
256+
branch_length[edge.child] = 0
299257

300258
for edge in edges_in:
301259
parent[edge.child] = edge.parent
302-
for tw in range(num_time_windows):
303-
branch_length[tw, edge.child] = min(
304-
time[edge.parent], time_windows[tw + 1]
305-
) - max(time[edge.child], time_windows[tw])
260+
branch_length[edge.child] = time[edge.parent] - time[edge.child]
306261
u = edge.child
307-
update_sum(u, sign=+1)
262+
running_sum += branch_length[u] * summary[u]
308263
u = edge.parent
309264
while u != -1:
310-
update_sum(u, sign=-1)
265+
running_sum -= branch_length[u] * summary[u]
311266
state[u] += state[edge.child]
312267
summary[u] = polarised_summary(u)
313-
update_sum(u, sign=+1)
268+
running_sum += branch_length[u] * summary[u]
314269
u = parent[u]
315270

316271
# Update the windows
@@ -322,22 +277,16 @@ def update_sum(u, sign):
322277
right = min(t_right, w_right)
323278
span = right - left
324279
assert span > 0
325-
time_window_index = 0
326-
while time_window_index < num_time_windows:
327-
result[window_index, time_window_index] += (
328-
running_sum[time_window_index] * span
329-
)
330-
time_window_index += 1
280+
result[window_index] += running_sum * span
331281
if w_right <= t_right:
332282
window_index += 1
333283
else:
334284
# This interval crosses a tree boundary, so we update it again
335285
# for the next tree
336286
break
287+
288+
# print("window_index:", window_index, windows.shape)
337289
assert window_index == windows.shape[0] - 1
338-
if drop_time_windows:
339-
assert result.ndim == 3
340-
result = result[:, 0]
341290
if span_normalise:
342291
for j in range(num_windows):
343292
result[j] /= windows[j + 1] - windows[j]
@@ -397,13 +346,7 @@ def naive_site_general_stat(
397346

398347

399348
def site_general_stat(
400-
ts,
401-
sample_weights,
402-
summary_func,
403-
windows=None,
404-
time_windows=None,
405-
polarised=False,
406-
span_normalise=True,
349+
ts, sample_weights, summary_func, windows=None, polarised=False, span_normalise=True
407350
):
408351
"""
409352
Problem: 'sites' is different from the other windowing options
@@ -506,19 +449,12 @@ def naive_node_general_stat(
506449

507450

508451
def node_general_stat(
509-
ts,
510-
sample_weights,
511-
summary_func,
512-
windows=None,
513-
time_windows=None,
514-
polarised=False,
515-
span_normalise=True,
452+
ts, sample_weights, summary_func, windows=None, polarised=False, span_normalise=True
516453
):
517454
"""
518455
Efficient implementation of the algorithm used as the basis for the
519456
underlying C version.
520457
"""
521-
assert time_windows is None
522458
n, state_dim = sample_weights.shape
523459
windows = ts.parse_windows(windows)
524460
num_windows = windows.shape[0] - 1
@@ -588,7 +524,6 @@ def general_stat(
588524
sample_weights,
589525
summary_func,
590526
windows=None,
591-
time_windows=None,
592527
polarised=False,
593528
mode="site",
594529
span_normalise=True,
@@ -607,7 +542,6 @@ def general_stat(
607542
sample_weights,
608543
summary_func,
609544
windows=windows,
610-
time_windows=time_windows,
611545
polarised=polarised,
612546
span_normalise=span_normalise,
613547
)
@@ -3735,9 +3669,7 @@ class TestSitef3(Testf3, MutatedTopologyExamplesMixin):
37353669
############################################
37363670

37373671

3738-
def branch_f4(
3739-
ts, sample_sets, indexes, windows=None, time_windows=None, span_normalise=True
3740-
):
3672+
def branch_f4(ts, sample_sets, indexes, windows=None, span_normalise=True):
37413673
windows = ts.parse_windows(windows)
37423674
out = np.zeros((len(windows) - 1, len(indexes)))
37433675
for j in range(len(windows) - 1):
@@ -3877,15 +3809,7 @@ def node_f4(ts, sample_sets, indexes, windows=None, span_normalise=True):
38773809
return out
38783810

38793811

3880-
def f4(
3881-
ts,
3882-
sample_sets,
3883-
indexes=None,
3884-
windows=None,
3885-
time_windows=None,
3886-
mode="site",
3887-
span_normalise=True,
3888-
):
3812+
def f4(ts, sample_sets, indexes=None, windows=None, mode="site", span_normalise=True):
38893813
"""
38903814
Patterson's f4 statistic definitions.
38913815
"""

0 commit comments

Comments
 (0)