4 changes: 3 additions & 1 deletion pyproject.toml
@@ -8,7 +8,9 @@ authors = [
]
requires-python = ">=3.9"
dependencies = [
"tsinfer==0.3.3", # https://github.com/jeromekelleher/sc2ts/issues/201
# "tsinfer==0.3.3", # https://github.com/jeromekelleher/sc2ts/issues/201
# FIXME
"tsinfer @ git+https://github.com/jeromekelleher/tsinfer.git@experimental-hmm",
"pyfaidx",
"tskit>=0.5.3",
"tszip",
53 changes: 21 additions & 32 deletions sc2ts/inference.py
@@ -402,46 +402,45 @@ def match_samples(
show_progress=False,
num_threads=None,
):
# First pass, compute the matches at precision=0.
run_batch = samples

# Values based on https://github.com/jeromekelleher/sc2ts/issues/242,
# but somewhat arbitrary.
for precision, cost_threshold in [(0, 1), (1, 2), (2, 3)]:
logger.info(f"Running batch of {len(run_batch)} at p={precision}")

mu = 0.125 ## FIXME
for k in range(int(num_mismatches)):
# To catch k mismatches we need a likelihood threshold of mu**k
likelihood_threshold = mu**k - 1e-15
logger.info(f"Running match={k} batch of {len(run_batch)} at threshold={likelihood_threshold}")
match_tsinfer(
samples=run_batch,
ts=base_ts,
num_mismatches=num_mismatches,
precision=precision,
likelihood_threshold=likelihood_threshold,
num_threads=num_threads,
show_progress=show_progress,
)

exceeding_threshold = []
for sample in run_batch:
cost = sample.get_hmm_cost(num_mismatches)
logger.debug(f"HMM@p={precision}: hmm_cost={cost} {sample.summary()}")
if cost > cost_threshold:
logger.debug(f"HMM@k={k}: hmm_cost={cost} {sample.summary()}")
if cost > k + 1:
sample.path.clear()
sample.mutations.clear()
exceeding_threshold.append(sample)

num_matches_found = len(run_batch) - len(exceeding_threshold)
logger.info(
f"{num_matches_found} final matches for found p={precision}; "
f"{num_matches_found} final matches found at k={k}; "
f"{len(exceeding_threshold)} remain"
)
run_batch = exceeding_threshold

precision = 6
logger.info(f"Running final batch of {len(run_batch)} at p={precision}")
logger.info(f"Running final batch of {len(run_batch)} at full precision")
match_tsinfer(
samples=run_batch,
ts=base_ts,
num_mismatches=num_mismatches,
precision=precision,
num_threads=num_threads,
likelihood_threshold=1e-200,
show_progress=show_progress,
)
for sample in run_batch:
@@ -798,36 +797,26 @@ def add_matching_results(
return ts # , excluded_samples, added_samples


def solve_num_mismatches(ts, k):
def solve_num_mismatches(k, num_sites, mu=0.125):
"""
Return the low-level LS parameters corresponding to accepting
k mismatches in favour of a single recombination.

NOTE! This is NOT taking into account the spatial distance along
the genome, and so is not a very good model in some ways.
"""
# We can match against any node in tsinfer
m = ts.num_sites
n = ts.num_nodes
# values of k <= 1 are not relevant for SC2 and lead to awkward corner cases
assert k > 1

# NOTE: the magnitude of mu matters because it puts a limit
# on how low we can push the HMM precision. We should be able to solve
# for the optimal value of this parameter such that the magnitude of the
# values within the HMM are as large as possible (so that we can truncate
# usefully).
# mu = 1e-2
mu = 0.125
denom = (1 - mu) ** k + (n - 1) * mu**k
r = n * mu**k / denom
denom = (1 - mu) ** k
r = mu**k / denom

# Add a little bit of extra mass for recombination so that we deterministically
# chose to recombine over k mutations
# NOTE: the magnitude of this value will depend also on mu, see above.
r += r * 0.01
ls_recomb = np.full(m - 1, r)
ls_mismatch = np.full(m, mu)
r += r * 0.125
ls_recomb = np.full(num_sites - 1, r)
ls_mismatch = np.full(num_sites, mu)
return ls_recomb, ls_mismatch


Expand Down Expand Up @@ -1268,7 +1257,7 @@ def match_tsinfer(
ts,
*,
num_mismatches,
precision=None,
likelihood_threshold=None,
num_threads=0,
show_progress=False,
mirror_coordinates=False,
@@ -1284,7 +1273,7 @@
sd = convert_tsinfer_sample_data(ts, genotypes)

L = int(ts.sequence_length)
ls_recomb, ls_mismatch = solve_num_mismatches(ts, num_mismatches)
ls_recomb, ls_mismatch = solve_num_mismatches(num_mismatches, ts.num_sites)
pm = tsinfer.inference._get_progress_monitor(
show_progress,
generate_ancestors=False,
@@ -1309,7 +1298,7 @@
mismatch=ls_mismatch,
progress_monitor=pm,
num_threads=num_threads,
precision=precision,
likelihood_threshold=likelihood_threshold
)
results = manager.run_match(np.arange(sd.num_samples))

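
Note on the re-parameterised solve_num_mismatches above: with the default mu = 0.125 the per-link recombination probability reduces to the closed form r = (mu / (1 - mu))**k * 1.125, where the extra 12.5% is the tie-breaking mass that makes one recombination deterministically preferred over k mutations. A minimal standalone sketch (the helper name is illustrative, not part of the sc2ts API) that reproduces the values checked in TestSolveNumMismatches below:

import numpy as np

def ls_parameters_sketch(k, num_sites, mu=0.125):
    # One recombination should be slightly cheaper than k mismatches:
    # mu**k / (1 - mu)**k, plus 12.5% extra mass for deterministic tie-breaking.
    r = (mu / (1 - mu)) ** k * 1.125
    ls_recomb = np.full(num_sites - 1, r)
    ls_mismatch = np.full(num_sites, mu)
    return ls_recomb, ls_mismatch

for k in (2, 3, 4):
    rho, _ = ls_parameters_sketch(k, num_sites=2)
    print(k, round(rho[0], 8))  # 0.02295918, 0.00327988, 0.00046855
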
56 changes: 33 additions & 23 deletions tests/test_inference.py
@@ -1,4 +1,5 @@
import numpy as np
import numpy.testing as nt
import pytest
import tsinfer
import tskit
@@ -8,6 +9,18 @@
import util


class TestSolveNumMismatches:

@pytest.mark.parametrize(
["k", "expected_rho"],
[(2, 0.02295918), (3, 0.00327988), (4, 0.00046855), (1000, 0)],
)
def test_examples(self, k, expected_rho):
rho, mu = sc2ts.solve_num_mismatches(k, num_sites=2)
assert mu[0] == 0.125
nt.assert_almost_equal(rho[0], expected_rho)


class TestInitialTs:
def test_reference_sequence(self):
ts = sc2ts.initial_ts()
@@ -612,13 +625,13 @@ def test_node_mutation_counts(self, fx_ts_map, date):
"2020-02-03": {"nodes": 36, "mutations": 42},
"2020-02-04": {"nodes": 41, "mutations": 48},
"2020-02-05": {"nodes": 42, "mutations": 48},
"2020-02-06": {"nodes": 49, "mutations": 51},
"2020-02-07": {"nodes": 51, "mutations": 57},
"2020-02-08": {"nodes": 57, "mutations": 58},
"2020-02-09": {"nodes": 59, "mutations": 61},
"2020-02-10": {"nodes": 60, "mutations": 65},
"2020-02-11": {"nodes": 62, "mutations": 66},
"2020-02-13": {"nodes": 66, "mutations": 68},
"2020-02-06": {"nodes": 48, "mutations": 51},
"2020-02-07": {"nodes": 50, "mutations": 57},
"2020-02-08": {"nodes": 56, "mutations": 58},
"2020-02-09": {"nodes": 58, "mutations": 61},
"2020-02-10": {"nodes": 59, "mutations": 65},
"2020-02-11": {"nodes": 61, "mutations": 66},
"2020-02-13": {"nodes": 65, "mutations": 68},
}
assert ts.num_nodes == expected[date]["nodes"]
assert ts.num_mutations == expected[date]["mutations"]
@@ -631,9 +644,9 @@ def test_node_mutation_counts(self, fx_ts_map, date):
(13, "SRR11597132", 10),
(16, "SRR11597177", 10),
(41, "SRR11597156", 10),
(57, "SRR11597216", 1),
(60, "SRR11597207", 40),
(62, "ERR4205570", 58),
(56, "SRR11597216", 1),
(59, "SRR11597207", 40),
(61, "ERR4205570", 57),
],
)
def test_exact_matches(self, fx_ts_map, node, strain, parent):
@@ -693,10 +706,9 @@ class TestMatchingDetails:
# assert s.path[0].parent == 37

@pytest.mark.parametrize(
("strain", "parent"), [("SRR11597207", 40), ("ERR4205570", 58)]
("strain", "parent"), [("SRR11597207", 40), ("ERR4205570", 57)]
)
@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
@pytest.mark.parametrize("precision", [0, 1, 2, 12])
def test_exact_matches(
self,
fx_ts_map,
@@ -705,17 +717,18 @@ def test_exact_matches(
strain,
parent,
num_mismatches,
precision,
):
ts = fx_ts_map["2020-02-10"]
samples = sc2ts.preprocess(
[fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store
)
# FIXME
mu = 0.125
sc2ts.match_tsinfer(
samples=samples,
ts=ts,
num_mismatches=num_mismatches,
precision=precision,
likelihood_threshold=mu**num_mismatches - 1e-12,
num_threads=0,
)
s = samples[0]
@@ -725,10 +738,10 @@

@pytest.mark.parametrize(
("strain", "parent", "position", "derived_state"),
[("SRR11597218", 10, 289, "T"), ("ERR4206593", 58, 26994, "T")],
[("SRR11597218", 10, 289, "T"), ("ERR4206593", 57, 26994, "T")],
)
@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
@pytest.mark.parametrize("precision", [0, 1, 2, 12])
# @pytest.mark.parametrize("precision", [0, 1, 2, 12])
def test_one_mismatch(
self,
fx_ts_map,
@@ -739,7 +752,6 @@ def test_one_mismatch(
position,
derived_state,
num_mismatches,
precision,
):
ts = fx_ts_map["2020-02-10"]
samples = sc2ts.preprocess(
@@ -749,7 +761,8 @@ def test_one_mismatch(
samples=samples,
ts=ts,
num_mismatches=num_mismatches,
precision=precision,
# FIXME
likelihood_threshold=0.12499999,
num_threads=0,
)
s = samples[0]
@@ -760,30 +773,27 @@ def test_one_mismatch(
assert s.path[0].parent == parent

@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
@pytest.mark.parametrize("precision", [0, 1, 2, 12])
def test_two_mismatches(
self,
fx_ts_map,
fx_alignment_store,
fx_metadata_db,
num_mismatches,
precision,
):
strain = "ERR4204459"
ts = fx_ts_map["2020-02-10"]
samples = sc2ts.preprocess(
[fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store
)
mu = 0.125
sc2ts.match_tsinfer(
samples=samples,
ts=ts,
num_mismatches=num_mismatches,
precision=precision,
likelihood_threshold=mu**2 - 1e-12,
num_threads=0,
)
s = samples[0]
assert len(s.path) == 1
assert s.path[0].parent == 5
assert len(s.mutations) == 2
# assert s.mutations[0].site_position == position
# assert s.mutations[0].derived_state == derived_state
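
The likelihood_threshold values used throughout these tests follow the same rule as the staged loop in match_samples: per the comment in the diff, a matched path with j mismatches has likelihood proportional to mu**j, so a threshold just below mu**k admits paths with up to k mismatches and prunes anything costlier. A small illustrative helper, assuming mu = 0.125 as above (the function name is hypothetical, not part of the sc2ts API):

def mismatch_threshold(k, mu=0.125, eps=1e-12):
    # A path with j mismatches has likelihood ~ mu**j, so thresholding just
    # below mu**k keeps matches with at most k mismatches.
    return mu**k - eps

print(mismatch_threshold(1))  # just under 0.125, cf. the 0.12499999 used in test_one_mismatch
print(mismatch_threshold(2))  # the value passed to match_tsinfer in test_two_mismatches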