@@ -189,15 +189,25 @@ def __init__(
189189 self .enable = checkpoint_config .enable
190190 self .load_only = checkpoint_config .load_only
191191
192+ self .states = states
193+ self .states .update (
194+ {
195+ MODEL : ModelWrapper (model_parts ),
196+ OPTIMIZER : optimizers ,
197+ DATALOADER : dataloader ,
198+ LR_SCHEDULER : lr_schedulers ,
199+ }
200+ )
201+
192202 self .ft_manager = (
193- ft_manager .manager
194- if ft_manager
195- and ft_manager .enabled
196- and checkpoint_config .enable_ft_dataloader_checkpoints
197- else None
203+ ft_manager .manager if ft_manager and ft_manager .enabled else None
198204 )
199205
200- if ft_manager and ft_manager .enabled and not self .ft_manager :
206+ self .enable_ft_dataloader_checkpoints = (
207+ self .ft_manager and checkpoint_config .enable_ft_dataloader_checkpoints
208+ )
209+
210+ if self .ft_manager and not self .enable_ft_dataloader_checkpoints :
201211 logger .warn (
202212 "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. "
203213 "This means replicas can retrain over the same data multiple times, which can result in overfitting."
@@ -229,20 +239,11 @@ def load_state_dict(state_dict):
229239 async_mode = checkpoint_config .async_mode .lower ()
230240 self .enable_staging = (
231241 self .enable and async_mode == AsyncMode .ASYNC_WITH_PINNED_MEM
232- ) or self .ft_manager
242+ ) or self .enable_ft_dataloader_checkpoints
233243
234- if not self .enable and self .ft_manager is None :
244+ if not self .enable and not self .enable_ft_dataloader_checkpoints :
235245 return
236246
237- self .states = states
238- self .states .update (
239- {
240- MODEL : ModelWrapper (model_parts ),
241- OPTIMIZER : optimizers ,
242- DATALOADER : dataloader ,
243- LR_SCHEDULER : lr_schedulers ,
244- }
245- )
246247 self .ft_states = {DATALOADER : dataloader }
247248
248249 self .staging = False
@@ -279,7 +280,7 @@ def load_state_dict(state_dict):
279280 if (
280281 async_mode == AsyncMode .ASYNC
281282 or async_mode == AsyncMode .ASYNC_WITH_PINNED_MEM
282- or self .ft_manager
283+ or self .enable_ft_dataloader_checkpoints
283284 ):
284285 self .pg = dist .new_group (backend = "gloo" )
285286
@@ -480,14 +481,16 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
480481 None
481482 """
482483
483- if self .ft_manager :
484+ if self .enable_ft_dataloader_checkpoints :
484485 self ._ft_save (curr_step )
485486
486487 if not self ._should_save (curr_step , last_step ):
487488 return
488489
489490 begin = time .monotonic ()
490- if not self .ft_manager or self .ft_manager .participating_rank () == 0 :
491+ if not self .enable_ft_dataloader_checkpoints or (
492+ self .ft_manager and self .ft_manager .participating_rank () == 0
493+ ):
491494 logger .info ("Saving the checkpoint (or staging if async is enabled)." )
492495 checkpoint_id = self ._create_checkpoint_id (curr_step )
493496 self ._async_wait ()
@@ -530,7 +533,8 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
530533 "Finished saving the checkpoint (or staging if async is enabled)"
531534 f"in { time .monotonic () - begin :.2f} seconds."
532535 )
533- elif self .ft_manager :
536+ elif self .enable_ft_dataloader_checkpoints :
537+ assert self .ft_manager is not None
534538 logger .info (
535539 "Replica %d doesn't save checkpoint." ,
536540 self .ft_manager .participating_rank (),
@@ -551,7 +555,7 @@ def load(self, step: int = -1) -> bool:
551555 bool: Whether the checkpoint was loaded successfully.
552556 """
553557
554- if self .ft_manager :
558+ if self .enable_ft_dataloader_checkpoints :
555559 self ._ft_load ()
556560
557561 if not self .enable :
@@ -749,7 +753,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
749753
750754 states_to_load = self ._flattened_model_states_sd (states_to_load )
751755
752- if self .ft_manager :
756+ if self .enable_ft_dataloader_checkpoints :
753757 states_to_load .pop (DATALOADER )
754758
755759 return states_to_load
@@ -805,7 +809,9 @@ def _async_wait(self) -> None:
805809 if self .async_mode == AsyncMode .ASYNC_WITH_PINNED_MEM :
806810 if self .save_future is not None :
807811 self .save_future .result ()
808- elif self .async_mode == AsyncMode .ASYNC or self .ft_manager is not None :
812+ elif (
813+ self .async_mode == AsyncMode .ASYNC or self .enable_ft_dataloader_checkpoints
814+ ):
809815 if self .save_future is not None :
810816 self .save_future .result ()
811817 self .save_future = None
@@ -820,7 +826,10 @@ def _purge_stale_checkpoints(self):
820826 self .keep_latest_k > 0
821827 and dist .get_rank () == 0
822828 and os .path .isdir (self .folder )
823- and (not self .ft_manager or self .ft_manager .participating_rank () == 0 )
829+ and (
830+ not self .enable_ft_dataloader_checkpoints
831+ or (self .ft_manager and self .ft_manager .participating_rank () == 0 )
832+ )
824833 ):
825834 discovered_checkpoints = []
826835 for filename in os .listdir (self .folder ):
0 commit comments