Skip to content

Commit 86c5c36

Browse files
committed
force default_value to be float32
1 parent cb2a764 commit 86c5c36

File tree

4 files changed

+112
-30
lines changed

4 files changed

+112
-30
lines changed

bigwig_loader/download_example_data.py

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,55 @@ def download_example_data() -> None:
1919

2020
def get_reference_genome(reference_genome_path: Path = config.reference_genome) -> Path:
2121
compressed_file = reference_genome_path.with_suffix(".fasta.gz")
22-
if reference_genome_path.exists():
23-
return reference_genome_path
24-
elif compressed_file.exists():
25-
# subprocess.run(["bgzip", "-d", compressed_file])
26-
unzip_gz_file(compressed_file, reference_genome_path)
27-
else:
28-
LOGGER.info("Need reference genome for tests. Downloading it from ENCODE.")
29-
url = "https://www.encodeproject.org/files/GRCh38_no_alt_analysis_set_GCA_000001405.15/@@download/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta.gz"
30-
urllib.request.urlretrieve(url, compressed_file)
22+
if compressed_file.exists() and not reference_genome_path.exists():
3123
# subprocess.run(["bgzip", "-d", compressed_file])
3224
unzip_gz_file(compressed_file, reference_genome_path)
25+
26+
if (
27+
reference_genome_path.exists()
28+
and checksum_md5_for_path(reference_genome_path)
29+
!= config.reference_genome_checksum
30+
):
31+
LOGGER.info(
32+
f"Reference genome checksum mismatch, downloading again from {reference_genome_path}"
33+
)
34+
_download_genome(
35+
url=config.reference_genome_url,
36+
compressed_file_path=compressed_file,
37+
uncompressed_file_path=reference_genome_path,
38+
md5_checksum=config.reference_genome_checksum,
39+
)
40+
41+
if not reference_genome_path.exists():
42+
LOGGER.info(
43+
f"Reference genome not found, downloading from {config.reference_genome_url}"
44+
)
45+
_download_genome(
46+
url=config.reference_genome_url,
47+
compressed_file_path=compressed_file,
48+
uncompressed_file_path=reference_genome_path,
49+
md5_checksum=config.reference_genome_checksum,
50+
)
3351
return reference_genome_path
3452

3553

54+
def _download_genome(
55+
url: str,
56+
compressed_file_path: Path,
57+
uncompressed_file_path: Path,
58+
md5_checksum: str,
59+
) -> Path:
60+
urllib.request.urlretrieve(url, compressed_file_path)
61+
# subprocess.run(["bgzip", "-d", compressed_file])
62+
unzip_gz_file(compressed_file_path, uncompressed_file_path)
63+
this_checksum = checksum_md5_for_path(uncompressed_file_path)
64+
if this_checksum != md5_checksum:
65+
raise RuntimeError(
66+
f"{uncompressed_file_path} has incorrect checksum: {this_checksum} vs. {md5_checksum}"
67+
)
68+
return uncompressed_file_path
69+
70+
3671
def unzip_gz_file(compressed_file_path: Path, output_file_path: Path) -> Path:
3772
with gzip.open(compressed_file_path, "rb") as gz_file:
3873
with open(output_file_path, "wb") as output_file:
@@ -52,6 +87,13 @@ def unzip_gz_file(compressed_file_path: Path, output_file_path: Path) -> Path:
5287
}
5388

5489

90+
def checksum_md5_for_path(path: Path, chunk_size: int = 10 * 1024 * 1024) -> str:
91+
"""return the md5sum"""
92+
with path.open(mode="rb") as f:
93+
checksum = checksum_md5(f, chunk_size=chunk_size)
94+
return checksum
95+
96+
5597
def checksum_md5(f: BinaryIO, *, chunk_size: int = 10 * 1024 * 1024) -> str:
5698
"""return the md5sum"""
5799
m = hashlib.md5(b"", usedforsecurity=False)
@@ -68,7 +110,7 @@ def get_example_bigwigs_files(bigwig_dir: Path = config.bigwig_dir) -> Path:
68110
file = bigwig_dir / fn
69111
if not file.exists():
70112
urllib.request.urlretrieve(url, file)
71-
with file.open(mode="rb") as f:
72-
if checksum_md5(f) != md5:
73-
raise RuntimeError(f"{fn} has incorrect checksum!")
113+
checksum = checksum_md5_for_path(file)
114+
if checksum != md5:
115+
raise RuntimeError(f"{fn} has incorrect checksum: {checksum} vs. {md5}")
74116
return bigwig_dir

bigwig_loader/intervals_to_values.py

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,15 @@ def intervals_to_values(
8787
)
8888

8989
if out is None:
90+
logging.debug(f"Creating new out tensor with default value {default_value}")
91+
9092
out = cp.full(
9193
(found_starts.shape[0], len(query_starts), sequence_length // window_size),
9294
default_value,
9395
dtype=cp.float32,
9496
)
97+
logging.debug(out)
98+
9599
else:
96100
logging.debug(f"Setting default value in output tensor to {default_value}")
97101
out.fill(default_value)
@@ -139,7 +143,7 @@ def intervals_to_values(
139143
sequence_length,
140144
max_number_intervals,
141145
window_size,
142-
default_value,
146+
cp.float32(default_value),
143147
default_value_isnan,
144148
out,
145149
),
@@ -171,8 +175,10 @@ def kernel_in_python_with_window(
171175
int,
172176
int,
173177
int,
174-
cp.ndarray,
175178
int,
179+
float,
180+
bool,
181+
cp.ndarray,
176182
],
177183
) -> cp.ndarray:
178184
"""Equivalent in python to cuda_kernel_with_window. Just for debugging."""
@@ -190,6 +196,8 @@ def kernel_in_python_with_window(
190196
sequence_length,
191197
max_number_intervals,
192198
window_size,
199+
default_value,
200+
default_value_isnan,
193201
out,
194202
) = args
195203

@@ -218,7 +226,7 @@ def kernel_in_python_with_window(
218226
print("reduced_dim")
219227
print(reduced_dim)
220228

221-
out_vector = [0.0] * reduced_dim * batch_size * num_tracks
229+
out_vector = [default_value] * reduced_dim * batch_size * num_tracks
222230

223231
for thread in range(n_threads):
224232
batch_index = thread % batch_size
@@ -239,7 +247,8 @@ def kernel_in_python_with_window(
239247

240248
cursor = found_start_index
241249
window_index = 0
242-
summation = 0
250+
summation = 0.0
251+
valid_count = 0
243252

244253
# cursor moves through the rows of the bigwig file
245254
# window_index moves through the sequence
@@ -265,19 +274,31 @@ def kernel_in_python_with_window(
265274
print("start index", start_index)
266275

267276
if start_index >= window_end:
268-
print("CONTINUE")
269-
out_vector[i * reduced_dim + window_index] = summation / window_size
270-
summation = 0
277+
if default_value_isnan:
278+
if valid_count > 0:
279+
out_vector[i * reduced_dim + window_index] = (
280+
summation / valid_count
281+
)
282+
else:
283+
out_vector[i * reduced_dim + window_index] = default_value
284+
else:
285+
summation = summation + (window_size - valid_count) * default_value
286+
out_vector[i * reduced_dim + window_index] = summation / window_size
287+
summation = 0.0
288+
valid_count = 0
271289
window_index += 1
290+
print("CONTINUE")
272291
continue
273292

274293
number = min(window_end, end_index) - max(window_start, start_index)
275294

276-
print(
277-
f"Add {number} x {track_values[cursor]} = {number * track_values[cursor]} to summation"
278-
)
279-
summation += number * track_values[cursor]
280-
print(f"Summation = {summation}")
295+
if number > 0:
296+
print(
297+
f"Add {number} x {track_values[cursor]} = {number * track_values[cursor]} to summation"
298+
)
299+
summation += number * track_values[cursor]
300+
print(f"Summation = {summation}")
301+
valid_count += number
281302

282303
print("end_index", "window_end")
283304
print(end_index, window_end)
@@ -292,8 +313,19 @@ def kernel_in_python_with_window(
292313
print(
293314
"cursor + 1 >= found_end_index \t\t calculate average, reset summation and move to next window"
294315
)
295-
out_vector[i * reduced_dim + window_index] = summation / window_size
296-
summation = 0
316+
# out_vector[i * reduced_dim + window_index] = summation / window_size
317+
if default_value_isnan:
318+
if valid_count > 0:
319+
out_vector[i * reduced_dim + window_index] = (
320+
summation / valid_count
321+
)
322+
else:
323+
out_vector[i * reduced_dim + window_index] = default_value
324+
else:
325+
summation = summation + (window_size - valid_count) * default_value
326+
out_vector[i * reduced_dim + window_index] = summation / window_size
327+
summation = 0.0
328+
valid_count = 0
297329
window_index += 1
298330
# move cursor
299331
if end_index < window_end:

bigwig_loader/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class Settings(BaseSettings):
3434
reference_genome: Path = (
3535
example_data_dir / "GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta"
3636
)
37+
reference_genome_checksum: str = "a6da8681616c05eb542f1d91606a7b2f"
3738
bigwig_dir: Path = example_data_dir / "bigwig"
3839

3940
def __str__(self) -> str:

tests/test_intervals_to_values_window_function.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,32 +37,35 @@ def test_get_values_from_intervals_window(default_value) -> None:
3737
assert (values == expected).all()
3838

3939

40-
@pytest.mark.parametrize("default_value", [0.0, cp.nan])
40+
@pytest.mark.parametrize("default_value", [0.0, cp.nan, 5.6, 10.0, 7565])
4141
def test_get_values_from_intervals_edge_case_1(default_value) -> None:
4242
"""Query start is somewhere in a "gap"."""
4343
track_starts = cp.asarray([1, 10, 12, 16], dtype=cp.int32)
4444
track_ends = cp.asarray([3, 12, 16, 20], dtype=cp.int32)
4545
track_values = cp.asarray([20.0, 30.0, 40.0, 50.0], dtype=cp.dtype("f4"))
4646
query_starts = cp.asarray([6], dtype=cp.int32)
4747
query_ends = cp.asarray([18], dtype=cp.int32)
48-
reserved = cp.zeros((1, 4), dtype=cp.dtype("<f4"))
48+
# reserved = cp.zeros((1, 4), dtype=cp.dtype("<f4"))
4949
values = intervals_to_values(
5050
track_starts,
5151
track_ends,
5252
track_values,
5353
query_starts,
5454
query_ends,
5555
default_value=default_value,
56-
out=reserved,
5756
window_size=3,
5857
)
5958
x = default_value
6059
if isnan(default_value):
6160
expected = cp.asarray([[x, 30.0, 40.0, 46.666668]])
61+
elif default_value != 0:
62+
expected = cp.asarray([[x, (x + 30.0 + 30.0) / 3, 40.0, 46.666668]])
6263
else:
6364
expected = cp.asarray([[x, 20.0, 40.0, 46.666668]])
6465

66+
print("expected:")
6567
print(expected)
68+
print("actual:")
6669
print(values)
6770

6871
assert (
@@ -227,7 +230,9 @@ def test_get_values_from_intervals_batch_of_2(default_value) -> None:
227230
expected = cp.asarray(
228231
[[20.0, 26.666666, 26.666666, 0.0], [23.333334, 36.666668, 0.0, 33.333332]]
229232
)
233+
print("expected:")
230234
print(expected)
235+
print("actual:")
231236
print(values)
232237
assert cp.allclose(values, expected, equal_nan=True)
233238

@@ -621,9 +626,11 @@ def test_combinations_window_size_batch_size_n_tracks_on_random_data(
621626
query_ends,
622627
sizes=sizes,
623628
window_size=1,
624-
default_value=default_value,
629+
default_value=cp.nan,
625630
)
626631

632+
cp.nan_to_num(values_with_window_size_1, copy=False, nan=default_value)
633+
627634
reduced_dim = sequence_length // window_size
628635
full_matrix = values_with_window_size_1[:, :, : reduced_dim * window_size]
629636
full_matrix = full_matrix.reshape(

0 commit comments

Comments
 (0)