Skip to content

Commit 7e6afe5

Browse files
committed
set pg names
1 parent 22a1a9a commit 7e6afe5

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

torchtitan/distributed/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def maybe_enable_amp(
232232

233233

234234
def init_distributed(
235-
comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = ""
235+
comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = "", ranks: list[int] = []
236236
):
237237
def _warn_overwrite_env(env, val):
238238
if env in os.environ:
@@ -276,6 +276,7 @@ def _get_distributed_backend(enable_cpu_backend):
276276
torch.distributed.init_process_group(
277277
backend=_get_distributed_backend(enable_cpu_backend),
278278
timeout=timedelta(seconds=comm_config.init_timeout_seconds),
279+
_ranks=ranks,
279280
)
280281

281282

torchtitan/train.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,21 @@ def __init__(self, job_config: JobConfig):
8383
# Device has to be set before creating TorchFT manager.
8484
device_module.set_device(self.device)
8585

86+
ranks = []
87+
ft_config = job_config.fault_tolerance
88+
if ft_config.enable:
89+
group_size = ft_config.group_size
90+
replica_id = ft_config.replica_id
91+
first_rank = replica_id * group_size
92+
last_rank = first_rank + group_size - 1
93+
ranks = list(range(first_rank, last_rank + 1))
94+
8695
# init distributed and build meshes
8796
dist_utils.init_distributed(
8897
job_config.comm,
8998
enable_cpu_backend=job_config.training.enable_cpu_offload,
9099
base_folder=job_config.job.dump_folder,
100+
ranks=ranks,
91101
)
92102

93103
job_config.maybe_log()

0 commit comments

Comments
 (0)