commit c0c66dd (1 parent: e43621c)
torchtitan/distributed/utils.py
@@ -305,12 +305,13 @@ def set_pg_timeouts(timeout, world_mesh):
     torch.distributed.barrier(device_ids=[device_module.current_device()])
     device_module.synchronize()
 
-    groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)]
+    groups = [(mesh_dim, world_mesh.get_group(mesh_dim)) for mesh_dim in range(world_mesh.ndim)]
 
     # None represents the 'default' PG, not part of the mesh
-    groups.append(None)
-    for group in groups:
+    groups.append((-1, None))
+    for mesh_dim, group in groups:
         torch.distributed.distributed_c10d._set_pg_timeout(timeout, group)
+        group.setGroupName(f"mesh_dim_{mesh_dim}")
 
 
 @torch.no_grad()
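For context, here is a minimal sketch (not part of the commit) of what the updated loop does: it pairs each mesh dimension with its process group, resets the timeout on every group including the default one, and gives each mesh group a dimension-derived name. The helper name set_pg_timeouts_sketch, the gloo/CPU mesh setup, and the guard that skips naming the None placeholder are illustrative assumptions; _set_pg_timeout and setGroupName are private/binding-level PyTorch APIs used by the commit and may vary across versions.

# Minimal sketch, assuming a recent PyTorch and a launch such as:
#   torchrun --nproc_per_node=2 sketch.py
from datetime import timedelta

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def set_pg_timeouts_sketch(timeout: timedelta, world_mesh) -> None:
    # Pair each mesh dimension with its process group so the name can refer to the dim.
    groups = [
        (mesh_dim, world_mesh.get_group(mesh_dim))
        for mesh_dim in range(world_mesh.ndim)
    ]
    # (-1, None) stands in for the 'default' PG, which is not part of the mesh.
    groups.append((-1, None))
    for mesh_dim, group in groups:
        # None is accepted here and means the default process group.
        dist.distributed_c10d._set_pg_timeout(timeout, group)
        if group is not None:
            # setGroupName is the ProcessGroup binding the commit relies on; this
            # sketch skips the None placeholder since there is no object to name.
            group.setGroupName(f"mesh_dim_{mesh_dim}")


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    # A 1-D mesh over all ranks on CPU, just to exercise the helper.
    mesh = init_device_mesh("cpu", (dist.get_world_size(),), mesh_dim_names=("dp",))
    set_pg_timeouts_sketch(timedelta(minutes=5), mesh)
    dist.destroy_process_group()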