1818import  abc 
1919from  collections .abc  import  Mapping 
2020import  dataclasses 
21+ import  functools 
2122import  gc 
2223import  os 
2324import  time 
@@ -96,7 +97,6 @@ def export_model(self, model: keras.Model, model_dir: str):
9697      model: The Keras model constructed by `create_model`. 
9798      model_dir: The model directory passed to the trainer. 
9899    """ 
99-     model .save (os .path .join (model_dir , core .KERAS_MODEL_SAVEFILE ))
100100
101101
102102class  KerasTrainer (core .Trainer [KerasTask ]):
@@ -118,6 +118,7 @@ def __init__(
118118      max_checkpoints_to_keep : int  =  5 ,
119119      checkpoint_save_interval_epochs : int  =  1 ,
120120      rng_seed : int  =  core .DEFAULT_RNG_SEED ,
121+       legacy_checkpoint_format : bool  =  True ,
121122  ):
122123    """Initializes the instance.""" 
123124
@@ -143,60 +144,77 @@ def __init__(
143144    self ._steps_per_eval  =  steps_per_eval 
144145    self ._continuous_eval_timeout  =  continuous_eval_timeout 
145146    self ._steps_per_loop  =  steps_per_loop 
146-     self ._checkpoint_manager  =  None 
147147    self ._marker_path  =  os .path .join (
148148        model_dir , core .TRAINING_COMPLETE_MARKER_FILE 
149149    )
150150    self ._checkpoint_dir  =  os .path .join (model_dir , core .CHECKPOINT_DIR )
151+     self ._max_checkpoints_to_keep  =  max_checkpoints_to_keep 
152+     self ._checkpoint_save_interval_epochs  =  checkpoint_save_interval_epochs 
153+     self ._legacy_checkpoint_format  =  legacy_checkpoint_format 
151154
155+   @functools .cached_property  
156+   def  train_callbacks (self ) ->  list [keras .callbacks .Callback ]:
157+     """Returns the training callbacks.""" 
152158    if  keras .backend .backend () ==  "jax" :
153-       self ._checkpoint_manager  =  keras_utils .KerasOrbaxCheckpointManager (
154-           checkpoint_dir = self ._checkpoint_dir ,
155-           max_to_keep = max_checkpoints_to_keep ,
156-           save_interval_epochs = checkpoint_save_interval_epochs ,
157-       )
158-       self ._train_callbacks  =  [
159+       if  self ._legacy_checkpoint_format :
160+         checkpoint_manager  =  keras_utils .KerasOrbaxCheckpointManager (
161+             checkpoint_dir = self ._checkpoint_dir ,
162+             max_to_keep = self ._max_checkpoints_to_keep ,
163+             save_interval_epochs = self ._checkpoint_save_interval_epochs ,
164+         )
165+       else :
166+         checkpoint_manager  =  keras_utils .KerasOrbaxCheckpointManagerV2 (
167+             checkpoint_dir = self ._checkpoint_dir ,
168+             max_to_keep = self ._max_checkpoints_to_keep ,
169+             save_interval_epochs = self ._checkpoint_save_interval_epochs ,
170+         )
171+       return  [
159172          keras_utils .EpochSummaryCallback (
160-               log_dir = os .path .join (model_dir , core .LOG_DIR ),
161-               steps_per_epoch = steps_per_loop ,
173+               log_dir = os .path .join (self . _model_dir , core .LOG_DIR ),
174+               steps_per_epoch = self . _steps_per_loop ,
162175              write_steps_per_second = True ,
163176          ),
164177          keras_utils .EpochOrbaxCheckpointAndRestoreCallback (
165-               checkpoint_manager = self . _checkpoint_manager ,
178+               checkpoint_manager = checkpoint_manager ,
166179              marker_path = self ._marker_path ,
167180          ),
168181      ]
169-       self ._eval_callbacks  =  [
182+     return  [
183+         keras .callbacks .TensorBoard (
184+             log_dir = os .path .join (self ._model_dir , core .LOG_DIR ),
185+             write_steps_per_second = True ,
186+         ),
187+         keras .callbacks .BackupAndRestore (
188+             backup_dir = os .path .join (self ._model_dir , core .BACKUP_DIR ),
189+         ),
190+         keras .callbacks .ModelCheckpoint (
191+             filepath = os .path .join (
192+                 self ._model_dir ,
193+                 core .CHECKPOINT_DIR ,
194+                 "ckpt-{epoch:d}.weights.h5" ,
195+             ),
196+             save_weights_only = True ,
197+             verbose = 1 ,
198+         ),
199+     ]
200+ 
201+   @functools .cached_property  
202+   def  eval_callbacks (self ) ->  list [keras .callbacks .Callback ]:
203+     """Returns the evaluation callbacks.""" 
204+     if  keras .backend .backend () ==  "jax" :
205+       return  [
170206          keras_utils .EpochSummaryCallback (
171-               log_dir = os .path .join (model_dir , core .LOG_DIR ),
172-               steps_per_epoch = steps_per_loop ,
207+               log_dir = os .path .join (self . _model_dir , core .LOG_DIR ),
208+               steps_per_epoch = self . _steps_per_loop ,
173209              write_steps_per_second = False ,
174210          ),
175211      ]
176-     else :
177-       self ._checkpoint_manager  =  None 
178-       self ._train_callbacks  =  [
179-           keras .callbacks .TensorBoard (
180-               log_dir = os .path .join (model_dir , core .LOG_DIR ),
181-               write_steps_per_second = True ,
182-           ),
183-           keras .callbacks .BackupAndRestore (
184-               backup_dir = os .path .join (model_dir , core .BACKUP_DIR ),
185-           ),
186-           keras .callbacks .ModelCheckpoint (
187-               filepath = os .path .join (
188-                   model_dir , core .CHECKPOINT_DIR , "ckpt-{epoch:d}.weights.h5" 
189-               ),
190-               save_weights_only = True ,
191-               verbose = 1 ,
192-           ),
193-       ]
194-       self ._eval_callbacks  =  [
195-           keras .callbacks .TensorBoard (
196-               log_dir = os .path .join (model_dir , core .LOG_DIR ),
197-               write_steps_per_second = True ,
198-           ),
199-       ]
212+     return  [
213+         keras .callbacks .TensorBoard (
214+             log_dir = os .path .join (self ._model_dir , core .LOG_DIR ),
215+             write_steps_per_second = True ,
216+         ),
217+     ]
200218
201219  def  _maybe_get_model_kws (
202220      self , task : KerasTask , dataset : tf .data .Dataset 
@@ -218,7 +236,9 @@ def train(self, task: KerasTask) -> core.Logs:
218236        dataset ,
219237        epochs = self ._train_epochs ,
220238        steps_per_epoch = self ._steps_per_loop ,
221-         callbacks = self ._train_callbacks ,
239+         callbacks = self .train_callbacks ,
240+         verbose = 0 ,  # this disables the progbar update at train end, which 
241+         # causes low TPU duty cycle. 
222242    )
223243    model .summary (print_fn = logging .info )
224244
@@ -237,14 +257,14 @@ def evaluate(self, task: KerasTask) -> core.Logs:
237257    if  keras .backend .backend () ==  "jax" :
238258      [tb_cbk ] =  [
239259          cbk 
240-           for  cbk  in  self ._eval_callbacks 
260+           for  cbk  in  self .eval_callbacks 
241261          if  isinstance (cbk , keras_utils .EpochSummaryCallback )
242262      ]
243263      epoch_start_time  =  time .time ()
244264      history  =  model .evaluate (
245265          dataset ,
246266          steps = self ._steps_per_eval ,
247-           callbacks = self ._eval_callbacks ,
267+           callbacks = self .eval_callbacks ,
248268          return_dict = True ,
249269      )
250270      epoch_dt  =  time .time () -  epoch_start_time 
@@ -257,7 +277,7 @@ def evaluate(self, task: KerasTask) -> core.Logs:
257277    return  model .evaluate (
258278        dataset ,
259279        steps = self ._steps_per_eval ,
260-         callbacks = self ._eval_callbacks ,
280+         callbacks = self .eval_callbacks ,
261281    )
262282
263283  def  train_and_evaluate (self , task : KerasTask ) ->  core .Logs :
@@ -277,7 +297,7 @@ def train_and_evaluate(self, task: KerasTask) -> core.Logs:
277297        steps_per_epoch = self ._steps_per_loop ,
278298        # Explicitly set to None for deterministic evaluation. 
279299        validation_steps = None ,
280-         callbacks = self ._train_callbacks ,
300+         callbacks = self .train_callbacks ,
281301    )
282302    model .summary (print_fn = logging .info )
283303
@@ -308,7 +328,10 @@ def timeout_fn() -> bool:
308328    else :
309329      steps_msg  =  "running complete evaluation..." 
310330
331+     use_legacy_checkpoint_format  =  self ._legacy_checkpoint_format 
332+ 
311333    class  _RestoreCallback (keras .callbacks .Callback ):
334+       """Callback for restoring the model from the latest checkpoint.""" 
312335
313336      def  __init__ (
314337          self ,
@@ -319,9 +342,14 @@ def __init__(
319342        self ._epoch  =  epoch 
320343
321344      def  on_test_begin (self , logs : Mapping [str , Any ] |  None  =  None ):
322-         keras_utils .restore_keras_model (
323-             model , self ._checkpoint_dir , step = self ._epoch 
324-         )
345+         if  use_legacy_checkpoint_format :
346+           keras_utils .restore_keras_model (
347+               model , self ._checkpoint_dir , step = self ._epoch 
348+           )
349+         else :
350+           keras_utils .restore_keras_checkpoint (
351+               self ._checkpoint_dir , model = model , epoch = self ._epoch 
352+           )
325353
326354    history  =  None 
327355    for  epoch  in  ocp .checkpoint_utils .checkpoints_iterator (
@@ -332,7 +360,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
332360      restore_callback  =  _RestoreCallback (self ._checkpoint_dir , epoch )
333361      [tb_cbk ] =  [
334362          cbk 
335-           for  cbk  in  self ._eval_callbacks 
363+           for  cbk  in  self .eval_callbacks 
336364          if  isinstance (cbk , keras_utils .EpochSummaryCallback )
337365      ]
338366      try :
@@ -346,7 +374,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
346374        history  =  model .evaluate (
347375            eval_dataset ,
348376            steps = self ._steps_per_eval ,
349-             callbacks = [restore_callback ] +  self ._eval_callbacks ,
377+             callbacks = [restore_callback ] +  self .eval_callbacks ,
350378            return_dict = True ,
351379        )
352380
0 commit comments