File tree Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -305,12 +305,15 @@ def set_pg_timeouts(timeout, world_mesh):
305305 torch .distributed .barrier (device_ids = [device_module .current_device ()])
306306 device_module .synchronize ()
307307
308- groups = [world_mesh .get_group (mesh_dim ) for mesh_dim in range (world_mesh .ndim )]
308+ groups = [( mesh_dim , world_mesh .get_group (mesh_dim ) ) for mesh_dim in range (world_mesh .ndim )]
309309
310310 # None represents the 'default' PG, not part of the mesh
311- groups .append (None )
312- for group in groups :
311+ groups .append (( 0 , None ) )
312+ for mesh_dim , group in groups :
313313 torch .distributed .distributed_c10d ._set_pg_timeout (timeout , group )
314+ if group is None :
315+ continue
316+ group .setGroupName (f"mesh_dim_{ mesh_dim } " )
314317
315318
316319@torch .no_grad ()
You can’t perform that action at this time.
0 commit comments