Skip to content

Commit a6373aa

Browse files
committed
use TimeWindowHandler to convert forecast steps to indices
1 parent 17ba0a5 commit a6373aa

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

src/weathergen/datasets/data_reader_base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@ def window(self, idx: TIndex) -> DTRange:
188188

189189
return DTRange(t_start_win, t_end_win)
190190

191+
def get_n_steps(self, forecast_step: int) -> int:
192+
return (int(self.t_window_len) * forecast_step) // int(self.t_window_step)
193+
191194

192195
@dataclass
193196
class ReaderData:

src/weathergen/datasets/multi_stream_data_sampler.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def reset(self):
243243
# value in worker_workset()
244244
self.rng = np.random.default_rng(self.data_loader_rng_seed)
245245

246-
fsm = (
246+
fsm: int = (
247247
self.forecast_steps[min(self.epoch, len(self.forecast_steps) - 1)]
248248
if self.forecast_policy != "random"
249249
else self.forecast_steps.max()
@@ -255,7 +255,7 @@ def reset(self):
255255
index_range = self.time_window_handler.get_index_range()
256256
idx_end = index_range.end
257257
# native length of datasets, independent of epoch length that has potentially been specified
258-
forecast_len = (self.len_hrs * (fsm + 1)) // self.step_hrs
258+
forecast_len = self.time_window_handler.get_n_steps(fsm + 1)
259259
idx_end -= forecast_len + self.forecast_offset
260260
assert idx_end > 0, "dataset size too small for forecast range"
261261
self.perms = np.arange(index_range.start, idx_end)
@@ -373,12 +373,10 @@ def __iter__(self):
373373
for fstep in range(
374374
self.forecast_offset, self.forecast_offset + forecast_dt + 1
375375
):
376-
step_forecast_dt = (
377-
idx + (self.forecast_delta_hrs * fstep) // self.step_hrs
378-
)
379-
time_win2 = self.time_window_handler.window(step_forecast_dt)
376+
forecast_idx = idx + self.time_window_handler.get_n_steps(fstep)
377+
time_win2 = self.time_window_handler.window(forecast_idx)
380378

381-
rdata = ds.get_target(step_forecast_dt)
379+
rdata = ds.get_target(forecast_idx)
382380

383381
sample_is_empty = rdata.is_empty()
384382
if sample_is_empty:

0 commit comments

Comments
 (0)