@@ -49,16 +49,18 @@ class CloudPathwaysArrayHandler(type_handlers.ArrayHandler):
4949
5050 def __init__ (
5151 self ,
52- read_timeout : datetime .timedelta | None = None ,
52+ timeout : datetime .timedelta | None = None ,
5353 use_ocdbt : bool = False ,
5454 ):
55- """Constructor .
55+ """Orbax array handler for Pathways on Cloud with Persistence API .
5656
5757 Args:
58- read_timeout : Duration indicating the timeout for reading arrays
58+ timeout : Duration indicating the timeout for reading and writing arrays
5959 use_ocdbt: allows using Tensorstore OCDBT driver.
6060 """
61- self ._read_timeout = read_timeout
61+ if timeout is None :
62+ timeout = datetime .timedelta (hours = 1 )
63+ self .timeout = timeout
6264
6365 if use_ocdbt :
6466 raise ValueError ("OCDBT not supported for Pathways." )
@@ -92,7 +94,7 @@ async def serialize(
9294
9395 self ._wait_for_directory_creation_signals ()
9496 locations , names = extract_parent_dir_and_name (infos )
95- f = functools .partial (helper .write_one_array , timeout = self ._read_timeout )
97+ f = functools .partial (helper .write_one_array , timeout = self .timeout )
9698 futures_results = list (map (f , locations , names , values ))
9799
98100 return [
@@ -181,7 +183,7 @@ async def deserialize(
181183 grouped_global_shapes ,
182184 grouped_shardings ,
183185 global_mesh .devices ,
184- timeout = self ._read_timeout ,
186+ timeout = self .timeout ,
185187 )
186188 # each persistence call is awaited serially.
187189 read_future .result ()
@@ -191,7 +193,7 @@ async def deserialize(
191193
192194
193195def register_pathways_handlers (
194- read_timeout : datetime .timedelta | None = None ,
196+ timeout : datetime .timedelta | None = None ,
195197):
196198 """Function that must be called before saving or restoring with Pathways."""
197199 logger .debug (
@@ -200,7 +202,7 @@ def register_pathways_handlers(
200202 type_handlers .register_type_handler (
201203 jax .Array ,
202204 CloudPathwaysArrayHandler (
203- read_timeout = read_timeout ,
205+ timeout = timeout ,
204206 ),
205207 override = True ,
206208 )
0 commit comments