55 changes: 45 additions & 10 deletions src/anemoi/inference/runners/parallel.py
@@ -276,6 +276,23 @@ def _bootstrap_processes(self) -> None:
LOG.warning(
f"world size ({self.config.world_size}) set in the config is ignored because we are launching via srun, using 'SLURM_NTASKS' instead"
)
elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
# New branch for Azure ML / general distributed env (e.g., env:// mode)
self.global_rank = int(os.environ["RANK"])
self.local_rank = int(
os.environ.get("LOCAL_RANK", self.global_rank)
) # Fallback to global if LOCAL_RANK unset
self.world_size = int(os.environ["WORLD_SIZE"])
self.master_addr = os.environ.get("MASTER_ADDR")
self.master_port = os.environ.get("MASTER_PORT")
if self.master_addr is None or self.master_port is None:
raise ValueError(
"MASTER_ADDR and MASTER_PORT must be set for distributed initialization (e.g., in Azure ML)"
)
if self.config.world_size != 1 and self.config.world_size != self.world_size:
LOG.warning(
f"Config world_size ({self.config.world_size}) ignored; using WORLD_SIZE from environment ({self.world_size})"
)
else:
# If srun is not available, spawn procs manually on a node

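For context on the hunk above: the new `elif` branch reads the generic variables that launchers such as `torchrun` or the Azure ML job runner export. Below is a minimal standalone sketch (not part of this diff; the default values are invented for illustration) of how those variables resolve to rank information in the same way the branch does.

# Standalone sketch, not part of the PR: resolve rank information from the
# environment the way the new elif branch does. The defaults are illustrative.
import os

os.environ.setdefault("RANK", "0")            # normally set by the launcher
os.environ.setdefault("WORLD_SIZE", "2")
os.environ.setdefault("MASTER_ADDR", "10.0.0.4")
os.environ.setdefault("MASTER_PORT", "29500")

global_rank = int(os.environ["RANK"])
local_rank = int(os.environ.get("LOCAL_RANK", global_rank))  # fall back to the global rank
world_size = int(os.environ["WORLD_SIZE"])

# Mirrors the branch's validation: both address and port must be present.
if os.environ.get("MASTER_ADDR") is None or os.environ.get("MASTER_PORT") is None:
    raise ValueError("MASTER_ADDR and MASTER_PORT must be set")

print(f"rank={global_rank} local_rank={local_rank} world_size={world_size}")
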
@@ -377,19 +394,29 @@ def _init_parallel(self) -> Optional["torch.distributed.ProcessGroup"]:
else:
backend = "gloo"

dist.init_process_group(
backend=backend,
init_method=f"tcp://{self.master_addr}:{self.master_port}",
timeout=datetime.timedelta(minutes=3),
world_size=self.world_size,
rank=self.global_rank,
)
if backend == "mpi":
# MPI backend: No init_method or explicit sizes needed
dist.init_process_group(backend="mpi")
model_comm_group_ranks = np.arange(self.world_size, dtype=int)
model_comm_group = dist.new_group(model_comm_group_ranks)
else:
if self._using_distributed_env():
init_method = "env://"  # recommended init method for Azure ML-style launchers
else:
init_method = f"tcp://{self.master_addr}:{self.master_port}"
dist.init_process_group(
backend=backend,
init_method=init_method,
timeout=datetime.timedelta(minutes=3),
world_size=self.world_size,
rank=self.global_rank,
)
model_comm_group_ranks = np.arange(self.world_size, dtype=int)
model_comm_group = dist.new_group(model_comm_group_ranks)
LOG.info(f"Creating a model communication group with {self.world_size} devices with the {backend} backend")

model_comm_group_ranks = np.arange(self.world_size, dtype=int)
model_comm_group = dist.new_group(model_comm_group_ranks)
else:
model_comm_group = None
LOG.warning("ParallelRunner selected but world size of 1 detected")

return model_comm_group

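The restructured initialisation above either lets the MPI backend bootstrap itself or picks between `env://` and an explicit `tcp://{addr}:{port}` init method. The single-process sketch below (gloo backend, loopback address, port 29501 and world size 1 are all illustrative assumptions, not the runner's real configuration) shows the `env://` path feeding the same `init_process_group` call, followed by a trivial one-rank communication group.

# Single-process sketch, not part of the PR: initialise a process group via
# env:// (address and port are read from the environment) and create a
# one-rank group, mirroring what the runner does when world_size > 1.
import datetime
import os

import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")

dist.init_process_group(
    backend="gloo",
    init_method="env://",  # explicit equivalent: "tcp://127.0.0.1:29501"
    timeout=datetime.timedelta(minutes=3),
    world_size=1,
    rank=0,
)
model_comm_group = dist.new_group([0])  # one-rank stand-in for new_group(ranks)
dist.destroy_process_group()
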
@@ -406,3 +433,11 @@ def _get_parallel_info_from_slurm(self) -> tuple[int, int, int]:
world_size = int(os.environ.get("SLURM_NTASKS", 1)) # Total number of processes

return global_rank, local_rank, world_size

def _using_distributed_env(self) -> bool:
"""Checks for distributed env vars like those in Azure ML."""
return "RANK" in os.environ and "WORLD_SIZE" in os.environ

def _is_mpi_env(self) -> bool:
"""Detects common MPI implementations (optional, for generality)."""
return "OMPI_COMM_WORLD_SIZE" in os.environ or "PMI_SIZE" in os.environ