@@ -243,7 +243,7 @@ def reset(self):
243243 # value in worker_workset()
244244 self .rng = np .random .default_rng (self .data_loader_rng_seed )
245245
246- fsm = (
246+ fsm : int = (
247247 self .forecast_steps [min (self .epoch , len (self .forecast_steps ) - 1 )]
248248 if self .forecast_policy != "random"
249249 else self .forecast_steps .max ()
@@ -255,7 +255,7 @@ def reset(self):
255255 index_range = self .time_window_handler .get_index_range ()
256256 idx_end = index_range .end
257257 # native length of datasets, independent of epoch length that has potentially been specified
258- forecast_len = ( self .len_hrs * (fsm + 1 )) // self . step_hrs
258+ forecast_len = self .time_window_handler . get_n_steps (fsm + 1 )
259259 idx_end -= forecast_len + self .forecast_offset
260260 assert idx_end > 0 , "dataset size too small for forecast range"
261261 self .perms = np .arange (index_range .start , idx_end )
@@ -373,12 +373,10 @@ def __iter__(self):
373373 for fstep in range (
374374 self .forecast_offset , self .forecast_offset + forecast_dt + 1
375375 ):
376- step_forecast_dt = (
377- idx + (self .forecast_delta_hrs * fstep ) // self .step_hrs
378- )
379- time_win2 = self .time_window_handler .window (step_forecast_dt )
376+ forecast_idx = idx + self .time_window_handler .get_n_steps (fstep )
377+ time_win2 = self .time_window_handler .window (forecast_idx )
380378
381- rdata = ds .get_target (step_forecast_dt )
379+ rdata = ds .get_target (forecast_idx )
382380
383381 sample_is_empty = rdata .is_empty ()
384382 if sample_is_empty :
0 commit comments