11import os
22import time
3- from typing import Any , Dict , List , Optional
4-
5- from tml .ml_logging .torch_logging import logging
6- from tml .common .filesystem import infer_fs , is_gcs_fs
3+ from typing import (
4+ Any ,
5+ Dict ,
6+ Generator ,
7+ List ,
8+ Optional ,
9+ )
710
811import torchsnapshot
9-
12+ from tml .common .filesystem import (
13+ infer_fs ,
14+ is_gcs_fs ,
15+ )
16+ from tml .ml_logging .torch_logging import (
17+ logging ,
18+ )
19+ from torch import (
20+ FloatTensor ,
21+ )
1022
1123DONE_EVAL_SUBDIR = "evaled_by"
1224GCS_PREFIX = "gs://"
@@ -25,22 +37,22 @@ def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
2537 self .state ["extra_state" ] = torchsnapshot .StateDict (step = 0 , walltime = 0.0 )
2638
@property
def step(self) -> int:
  """Global training step recorded in the snapshot's ``extra_state``."""
  extra_state = self.state["extra_state"]
  return extra_state["step"]

@step.setter
def step(self, step: int) -> None:
  """Store the global training step into the snapshot's ``extra_state``."""
  self.state["extra_state"]["step"] = step
3446
@property
def walltime(self) -> float:
  """Wall-clock time recorded in the snapshot's ``extra_state``."""
  extra_state = self.state["extra_state"]
  return extra_state["walltime"]

@walltime.setter
def walltime(self, walltime: float) -> None:
  """Store the wall-clock time into the snapshot's ``extra_state``."""
  self.state["extra_state"]["walltime"] = walltime
4254
43- def save (self , global_step : int ) -> "PendingSnapshot" :
55+ def save (self , global_step : int ) -> "PendingSnapshot" : # type: ignore
4456 """Saves checkpoint with given global_step."""
4557 path = os .path .join (self .save_dir , str (global_step ))
4658 logging .info (f"Saving snapshot global_step { global_step } to { path } ." )
@@ -98,7 +110,7 @@ def load_snapshot_to_weight(
98110 cls ,
99111 embedding_snapshot : torchsnapshot .Snapshot ,
100112 snapshot_emb_name : str ,
101- weight_tensor ,
113+ weight_tensor : FloatTensor ,
102114 ) -> None :
103115 """Loads pretrained embedding from the snapshot to the model.
104116 Utilise the partial loading mechanism from torchsnapshot.
@@ -128,19 +140,21 @@ def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:
128140 return os .path .join (_eval_subdir (checkpoint_path ), f"{ eval_partition } _DONE" )
129141
130142
def is_done_eval(checkpoint_path: str, eval_partition: str) -> bool:
  """Return True if `eval_partition` was already evaluated for this checkpoint.

  Evaluation completion is signalled by a `{eval_partition}_DONE` marker file
  under the checkpoint's eval subdirectory (see `_eval_done_path`).
  """
  done_marker = _eval_done_path(checkpoint_path, eval_partition)
  checkpoint = get_checkpoint(checkpoint_path)
  return checkpoint.exists(done_marker)  # type: ignore[attr-defined]
133145
134146
def mark_done_eval(checkpoint_path: str, eval_partition: str) -> None:
  """Mark `eval_partition` as evaluated for the given checkpoint.

  Touches the `{eval_partition}_DONE` marker file (see `_eval_done_path`) on
  whatever filesystem `infer_fs` resolves for `checkpoint_path` — presumably
  local or GCS given the `gs://` handling elsewhere in this module.

  Returns None; the function is side-effect only. (The previous `-> Any`
  annotation misrepresented the implicit None return.)
  """
  done_marker = _eval_done_path(checkpoint_path, eval_partition)
  infer_fs(checkpoint_path).touch(done_marker)
137149
138150
def step_from_checkpoint(checkpoint: str) -> int:
  """Parse the global step out of a checkpoint path.

  Checkpoint directories are named after their global step (the save path is
  built as `os.path.join(save_dir, str(global_step))`), so the final path
  component converts directly to an int.
  """
  final_component = os.path.split(checkpoint)[1]
  return int(final_component)
141153
142154
143- def checkpoints_iterator (save_dir : str , seconds_to_sleep : int = 30 , timeout : int = 1800 ):
155+ def checkpoints_iterator (
156+ save_dir : str , seconds_to_sleep : int = 30 , timeout : int = 1800
157+ ) -> Generator [str , None , None ]:
144158 """Simplified equivalent of tf.train.checkpoints_iterator.
145159
146160 Args:
@@ -149,7 +163,7 @@ def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int
149163
150164 """
151165
152- def _poll (last_checkpoint : Optional [str ] = None ):
166+ def _poll (last_checkpoint : Optional [str ] = None ) -> Optional [ str ] :
153167 stop_time = time .time () + timeout
154168 while True :
155169 _checkpoint_path = get_checkpoint (save_dir , missing_ok = True )
0 commit comments