
Commit 22a1a9a

fix setting ft state dicts when ft checkpointing is disabled
Summary:
- When ft dataloader checkpointing is disabled, we currently also don't set the ft state dicts.
- Make it so that when ft dataloader checkpointing is disabled, we still set the state dicts, so that the model, optimizer, etc. can be recovered from a different replica.
1 parent 2a7a148 commit 22a1a9a
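
To make the effect concrete, here is a minimal sketch of the gating before and after this commit. It is not the actual CheckpointManager code: the function names and boolean parameters below are illustrative stand-ins, and only the flag names mirror the diff that follows.

# Minimal sketch, assuming simplified boolean inputs; the real logic lives in
# CheckpointManager.__init__ (see the diff below).

def states_registered_before(enable: bool, ft_enabled: bool, ft_dl_ckpts: bool) -> bool:
    # Before: ft_manager was kept only when FT dataloader checkpointing was on,
    # and the states dict was populated only after this early return.
    ft_manager = ft_enabled and ft_dl_ckpts
    if not enable and not ft_manager:
        return False  # model/optimizer/lr_scheduler state never registered
    return True

def states_registered_after(enable: bool, ft_enabled: bool, ft_dl_ckpts: bool) -> bool:
    # After: states are registered unconditionally, up front, so a peer replica
    # can still hand over model and optimizer state; the new
    # enable_ft_dataloader_checkpoints flag gates only the dataloader FT paths.
    return True

# FT on, regular checkpointing off, FT dataloader checkpoints disabled:
print(states_registered_before(enable=False, ft_enabled=True, ft_dl_ckpts=False))  # False
print(states_registered_after(enable=False, ft_enabled=True, ft_dl_ckpts=False))   # True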

File tree

1 file changed: +34 additions, −25 deletions


torchtitan/components/checkpoint.py

Lines changed: 34 additions & 25 deletions
@@ -189,15 +189,25 @@ def __init__(
         self.enable = checkpoint_config.enable
         self.load_only = checkpoint_config.load_only
 
+        self.states = states
+        self.states.update(
+            {
+                MODEL: ModelWrapper(model_parts),
+                OPTIMIZER: optimizers,
+                DATALOADER: dataloader,
+                LR_SCHEDULER: lr_schedulers,
+            }
+        )
+
         self.ft_manager = (
-            ft_manager.manager
-            if ft_manager
-            and ft_manager.enabled
-            and checkpoint_config.enable_ft_dataloader_checkpoints
-            else None
+            ft_manager.manager if ft_manager and ft_manager.enabled else None
         )
 
-        if ft_manager and ft_manager.enabled and not self.ft_manager:
+        self.enable_ft_dataloader_checkpoints = (
+            self.ft_manager and checkpoint_config.enable_ft_dataloader_checkpoints
+        )
+
+        if self.ft_manager and not self.enable_ft_dataloader_checkpoints:
             logger.warn(
                 "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. "
                 "This means replicas can retrain over the same data multiple times, which can result in overfitting."
@@ -229,20 +239,11 @@ def load_state_dict(state_dict):
         async_mode = checkpoint_config.async_mode.lower()
         self.enable_staging = (
             self.enable and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        ) or self.ft_manager
+        ) or self.enable_ft_dataloader_checkpoints
 
-        if not self.enable and self.ft_manager is None:
+        if not self.enable and not self.enable_ft_dataloader_checkpoints:
             return
 
-        self.states = states
-        self.states.update(
-            {
-                MODEL: ModelWrapper(model_parts),
-                OPTIMIZER: optimizers,
-                DATALOADER: dataloader,
-                LR_SCHEDULER: lr_schedulers,
-            }
-        )
         self.ft_states = {DATALOADER: dataloader}
 
         self.staging = False
@@ -279,7 +280,7 @@ def load_state_dict(state_dict):
         if (
             async_mode == AsyncMode.ASYNC
             or async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-            or self.ft_manager
+            or self.enable_ft_dataloader_checkpoints
         ):
             self.pg = dist.new_group(backend="gloo")
 
@@ -480,14 +481,16 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
             None
         """
 
-        if self.ft_manager:
+        if self.enable_ft_dataloader_checkpoints:
             self._ft_save(curr_step)
 
         if not self._should_save(curr_step, last_step):
             return
 
         begin = time.monotonic()
-        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
+        if not self.enable_ft_dataloader_checkpoints or (
+            self.ft_manager and self.ft_manager.participating_rank() == 0
+        ):
             logger.info("Saving the checkpoint (or staging if async is enabled).")
             checkpoint_id = self._create_checkpoint_id(curr_step)
             self._async_wait()
@@ -530,7 +533,8 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
                 "Finished saving the checkpoint (or staging if async is enabled)"
                 f"in {time.monotonic() - begin:.2f} seconds."
             )
-        elif self.ft_manager:
+        elif self.enable_ft_dataloader_checkpoints:
+            assert self.ft_manager is not None
             logger.info(
                 "Replica %d doesn't save checkpoint.",
                 self.ft_manager.participating_rank(),
@@ -551,7 +555,7 @@ def load(self, step: int = -1) -> bool:
            bool: Whether the checkpoint was loaded successfully.
         """
 
-        if self.ft_manager:
+        if self.enable_ft_dataloader_checkpoints:
             self._ft_load()
 
         if not self.enable:
@@ -749,7 +753,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
 
         states_to_load = self._flattened_model_states_sd(states_to_load)
 
-        if self.ft_manager:
+        if self.enable_ft_dataloader_checkpoints:
             states_to_load.pop(DATALOADER)
 
         return states_to_load
@@ -805,7 +809,9 @@ def _async_wait(self) -> None:
         if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
             if self.save_future is not None:
                 self.save_future.result()
-        elif self.async_mode == AsyncMode.ASYNC or self.ft_manager is not None:
+        elif (
+            self.async_mode == AsyncMode.ASYNC or self.enable_ft_dataloader_checkpoints
+        ):
             if self.save_future is not None:
                 self.save_future.result()
                 self.save_future = None
@@ -820,7 +826,10 @@ def _purge_stale_checkpoints(self):
            self.keep_latest_k > 0
             and dist.get_rank() == 0
             and os.path.isdir(self.folder)
-            and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
+            and (
+                not self.enable_ft_dataloader_checkpoints
+                or (self.ft_manager and self.ft_manager.participating_rank() == 0)
+            )
         ):
             discovered_checkpoints = []
             for filename in os.listdir(self.folder):
