Skip to content

Commit e85d3b7

Browse files
Merge pull request #90 from benjeffery/caching
Add disk caching
2 parents ad5d136 + d220b41 commit e85d3b7

File tree

4 files changed

+82
-0
lines changed

4 files changed

+82
-0
lines changed

cache.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import functools
2+
import pathlib
3+
4+
import appdirs
5+
import daiquiri
6+
import diskcache
7+
8+
# Module-level logger for the caching layer; disk_cache reports through it
# whether a value was fetched, calculated, or not cached at all.
logger = daiquiri.getLogger("cache")
9+
10+
11+
def get_cache_dir():
    """Return the per-user cache directory for tsqc, creating it if needed.

    The location is whatever ``appdirs.user_cache_dir("tsqc", "tsqc")``
    reports. Missing parent directories are created and an already existing
    directory is not an error.

    Returns:
        pathlib.Path: The cache directory.
    """
    location = appdirs.user_cache_dir("tsqc", "tsqc")
    directory = pathlib.Path(location)
    directory.mkdir(parents=True, exist_ok=True)
    return directory
15+
16+
17+
# Shared on-disk cache, opened at import time in the directory returned by
# get_cache_dir() (which creates that directory as a side effect).
cache = diskcache.Cache(get_cache_dir())
18+
19+
20+
def disk_cache(version):
    """Build a decorator that stores a method's result in the module ``cache``.

    The decorated callable must take ``self`` as its first argument, and
    ``self`` must expose a ``file_uuid`` attribute. When ``file_uuid`` is
    ``None`` the call is passed straight through and nothing is cached.

    Cache keys:
        * Calls without extra arguments use the string
          ``"{file_uuid}-{func.__name__}-{version}"``.
        * Calls with positional or keyword arguments use a tuple of that
          string, the positional arguments and the sorted keyword items, so
          different arguments never share an entry. The arguments must be
          usable as part of a cache key.

    Args:
        version: Tag embedded in the key; change it to invalidate entries
            written by an older implementation of the decorated function.

    Returns:
        A decorator that wraps a method with the caching behaviour above.
    """
    # Sentinel so that a legitimately cached ``None`` still counts as a hit.
    missing = object()

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            uuid = self.file_uuid
            if uuid is None:
                logger.info(f"No uuid, not caching {func.__name__}")
                return func(self, *args, **kwargs)

            key = f"{uuid}-{func.__name__}-{version}"
            if args or kwargs:
                # Include the call arguments so distinct calls do not collide.
                key = (key, args, tuple(sorted(kwargs.items())))

            # Single lookup: avoids the check-then-read race of
            # ``key in cache`` followed by ``cache[key]``.
            result = cache.get(key, missing)
            if result is not missing:
                logger.info(f"Fetching {key} from cache")
                return result

            logger.info(f"Calculating {key} and caching")
            result = func(self, *args, **kwargs)
            cache[key] = result
            return result

        return wrapper

    return decorator

model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
import pandas as pd
88
import tskit
99

10+
from cache import disk_cache
11+
12+
1013
logger = daiquiri.getLogger("model")
1114

1215
spec = [
@@ -271,7 +274,12 @@ def __init__(self, ts, name=None):
271274
self.ts.mutations_node, minlength=self.ts.num_nodes
272275
)
273276

277+
@property
def file_uuid(self):
    """Return the ``file_uuid`` attribute of the wrapped ``self.ts``."""
    tree_sequence = self.ts
    return tree_sequence.file_uuid
280+
274281
@cached_property
282+
@disk_cache("v1")
275283
def summary_df(self):
276284
nodes_with_zero_muts = np.sum(self.nodes_num_mutations == 0)
277285
sites_with_zero_muts = np.sum(self.sites_num_mutations == 0)
@@ -312,6 +320,7 @@ def child_bounds(num_nodes, edges_left, edges_right, edges_child):
312320
return child_left, child_right
313321

314322
@cached_property
323+
@disk_cache("v1")
315324
def mutations_df(self):
316325
# FIXME use tskit's impute mutations time
317326
ts = self.ts
@@ -386,6 +395,7 @@ def mutations_df(self):
386395
)
387396

388397
@cached_property
398+
@disk_cache("v1")
389399
def edges_df(self):
390400
ts = self.ts
391401
left = ts.edges_left
@@ -426,6 +436,7 @@ def edges_df(self):
426436
)
427437

428438
@cached_property
439+
@disk_cache("v1")
429440
def nodes_df(self):
430441
ts = self.ts
431442
child_left, child_right = self.child_bounds(
@@ -452,6 +463,7 @@ def nodes_df(self):
452463
)
453464

454465
@cached_property
466+
@disk_cache("v1")
455467
def trees_df(self):
456468
ts = self.ts
457469
num_trees = ts.num_trees

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
click
22
daiquiri
33
panel
4+
diskcache
45
hvplot
56
xarray
67
datashader
78
tskit
89
seaborn
910
pre-commit
1011
pytest
12+
tszip
13+
appdirs

tests/test_data_model.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
13
import numpy as np
24
import numpy.testing as nt
35
import tskit
@@ -212,3 +214,29 @@ def test_multi_tree_with_polytomies_example(self):
212214
nt.assert_array_equal(df.total_branch_length, [11.0, 12.0])
213215
# nt.assert_array_equal(df.mean_internal_arity, [2.25, 2.25])
214216
nt.assert_array_equal(df.max_internal_arity, [3.0, 3.0])
217+
218+
219+
def test_cache(caplog, tmpdir):
    """Exercise the disk cache on ``trees_df``, using log output as evidence."""
    caplog.set_level(logging.INFO)

    # Model built directly from the example tree sequence: the log is
    # expected to report that there is no uuid and nothing is cached.
    ts = multiple_trees_example_ts()
    tsm = model.TSModel(ts)
    first = tsm.trees_df
    second = tsm.trees_df
    assert first.equals(second)
    assert "No uuid, not caching trees_df" in caplog.text

    # Round-trip through a file; the first access on the loaded copy is
    # expected to be calculated (and stored).
    trees_path = tmpdir / "cache.trees"
    ts.dump(trees_path)
    ts = tskit.load(trees_path)
    tsm = model.TSModel(ts)
    caplog.clear()
    first = tsm.trees_df
    assert "Calculating" in caplog.text
    caplog.clear()

    # A fresh model over the same file should be served from the cache.
    ts2 = tskit.load(trees_path)
    tsm2 = model.TSModel(ts2)
    second = tsm2.trees_df
    assert "Fetching" in caplog.text
    assert first.equals(second)

0 commit comments

Comments
 (0)