|
2 | 2 | from pathlib import Path |
3 | 3 | from typing import Any |
4 | 4 | from typing import Callable |
| 5 | +from typing import Iterable |
5 | 6 | from typing import Iterator |
6 | 7 | from typing import Literal |
7 | 8 | from typing import Optional |
@@ -78,6 +79,12 @@ class BigWigDataset: |
78 | 79 | GPU. More threads means that more IO can take place while the GPU is busy doing |
79 | 80 | calculations (decompressing or neural network training for example). More threads |
80 | 81 | also means a higher GPU memory usage. Default: 4 |
| 82 | + custom_position_sampler: if set, this sampler will be used instead of the default |
| 83 | + position sampler (which samples randomly and uniform from regions of interest) |
| 84 | + This should be an iterable of tuples (chromosome, center). |
| 85 | + custom_track_sampler: if specified, this sampler will be used to sample tracks. When not |
| 86 | + specified, each batch simply contains all tracks, or a randomly sellected subset of |
| 87 | + tracks in case sub_sample_tracks is set. Should be Iterable batches of track indices. |
81 | 88 | return_batch_objects: if True, the batches will be returned as instances of |
82 | 89 | bigwig_loader.batch.Batch |
83 | 90 | """ |
@@ -107,6 +114,8 @@ def __init__( |
107 | 114 | repeat_same_positions: bool = False, |
108 | 115 | sub_sample_tracks: Optional[int] = None, |
109 | 116 | n_threads: int = 4, |
| 117 | + custom_position_sampler: Optional[Iterable[tuple[str, int]]] = None, |
| 118 | + custom_track_sampler: Optional[Iterable[list[int]]] = None, |
110 | 119 | return_batch_objects: bool = False, |
111 | 120 | ): |
112 | 121 | super().__init__() |
@@ -152,32 +161,34 @@ def __init__( |
152 | 161 | self._sub_sample_tracks = sub_sample_tracks |
153 | 162 | self._n_threads = n_threads |
154 | 163 | self._return_batch_objects = return_batch_objects |
155 | | - |
156 | | - def _create_dataloader(self) -> StreamedDataloader: |
157 | | - position_sampler = RandomPositionSampler( |
| 164 | + self._position_sampler = custom_position_sampler or RandomPositionSampler( |
158 | 165 | regions_of_interest=self.regions_of_interest, |
159 | 166 | buffer_size=self._position_sampler_buffer_size, |
160 | 167 | repeat_same=self._repeat_same_positions, |
161 | 168 | ) |
| 169 | + if custom_track_sampler is not None: |
| 170 | + self._track_sampler: Optional[Iterable[list[int]]] = custom_track_sampler |
| 171 | + elif sub_sample_tracks is not None: |
| 172 | + self._track_sampler = TrackSampler( |
| 173 | + total_number_of_tracks=len(self.bigwig_collection), |
| 174 | + sample_size=sub_sample_tracks, |
| 175 | + ) |
| 176 | + else: |
| 177 | + self._track_sampler = None |
162 | 178 |
|
| 179 | + def _create_dataloader(self) -> StreamedDataloader: |
163 | 180 | sequence_sampler = GenomicSequenceSampler( |
164 | 181 | reference_genome_path=self.reference_genome_path, |
165 | 182 | sequence_length=self.sequence_length, |
166 | | - position_sampler=position_sampler, |
| 183 | + position_sampler=self._position_sampler, |
167 | 184 | maximum_unknown_bases_fraction=self.maximum_unknown_bases_fraction, |
168 | 185 | ) |
169 | | - track_sampler = None |
170 | | - if self._sub_sample_tracks is not None: |
171 | | - track_sampler = TrackSampler( |
172 | | - total_number_of_tracks=len(self.bigwig_collection), |
173 | | - sample_size=self._sub_sample_tracks, |
174 | | - ) |
175 | 186 |
|
176 | 187 | query_batch_generator = QueryBatchGenerator( |
177 | 188 | genomic_location_sampler=sequence_sampler, |
178 | 189 | center_bin_to_predict=self.center_bin_to_predict, |
179 | 190 | batch_size=self.super_batch_size, |
180 | | - track_sampler=track_sampler, |
| 191 | + track_sampler=self._track_sampler, |
181 | 192 | ) |
182 | 193 |
|
183 | 194 | return StreamedDataloader( |
|
0 commit comments