Skip to content

Commit 76a2e2e

Browse files
committed
play around with batch loop in MultiStreamDataSampler
1 parent a6373aa commit 76a2e2e

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

src/weathergen/datasets/multi_stream_data_sampler.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -309,16 +309,19 @@ def __iter__(self):
309309
# bidx is used to count the #batches that have been emitted
310310
# idx_raw is used to index into the dataset; the decoupling is needed
311311
# since there are empty batches
312-
idx_raw = iter_start
313-
for i, _bidx in enumerate(range(iter_start, iter_end, self.batch_size)):
314-
# forecast_dt needs to be constant per batch (amortized through data parallel training)
315-
forecast_dt = self.perms_forecast_dt[i]
316-
312+
idx_raw = iter_start # start step index
313+
assert (iter_end - iter_start) // self.batch_size == len(self.perms_forecast_dt)
314+
# forecast_dt needs to be constant per batch (amortized through data parallel training)
315+
for forecast_dt in self.perms_forecast_dt: # bidx loop
317316
# use while loop due to the scattered nature of the data in time and to
318317
# ensure batches are not empty
319318
batch = []
320319
while len(batch) < self.batch_size:
321-
idx: TIndex = self.perms[idx_raw % self.perms.shape[0]]
320+
# TODO: identity? len(self.perms) should most likely be longer than
321+
# idx_raw since it contains all the dataset steps (- small adjustment)
322+
# whereas iter_end-iter_start should be smaller since it is only a subset
323+
perm_idx = idx_raw % len(self.perms)
324+
idx: TIndex = self.perms[perm_idx]
322325
idx_raw += 1
323326

324327
time_win1 = self.time_window_handler.window(idx)

0 commit comments

Comments
 (0)