Skip to content

Commit c0c66dd

Browse files
committed
set pg names
1 parent e43621c commit c0c66dd

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

torchtitan/distributed/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,12 +305,13 @@ def set_pg_timeouts(timeout, world_mesh):
305305
torch.distributed.barrier(device_ids=[device_module.current_device()])
306306
device_module.synchronize()
307307

308-
groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)]
308+
groups = [(mesh_dim, world_mesh.get_group(mesh_dim)) for mesh_dim in range(world_mesh.ndim)]
309309

310310
# None represents the 'default' PG, not part of the mesh
311-
groups.append(None)
312-
for group in groups:
311+
groups.append((-1, None))
312+
for mesh_dim, group in groups:
313313
torch.distributed.distributed_c10d._set_pg_timeout(timeout, group)
314+
group.setGroupName(f"mesh_dim_{mesh_dim}")
314315

315316

316317
@torch.no_grad()

0 commit comments

Comments
 (0)