Skip to content

Commit e3bb189

Browse files
committed
set pg names
1 parent e5606c9 commit e3bb189

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

torchtitan/distributed/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,12 +305,15 @@ def set_pg_timeouts(timeout, world_mesh):
305305
torch.distributed.barrier(device_ids=[device_module.current_device()])
306306
device_module.synchronize()
307307

308-
groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)]
308+
groups = [(mesh_dim, world_mesh.get_group(mesh_dim)) for mesh_dim in range(world_mesh.ndim)]
309309

310310
# None represents the 'default' PG, not part of the mesh
311-
groups.append(None)
312-
for group in groups:
311+
groups.append((0, None))
312+
for mesh_dim, group in groups:
313313
torch.distributed.distributed_c10d._set_pg_timeout(timeout, group)
314+
if group is None:
315+
continue
316+
group.setGroupName(f"mesh_dim_{mesh_dim}")
314317

315318

316319
@torch.no_grad()

0 commit comments

Comments
 (0)