Skip to content

Commit a67e87a

Browse files
committed
misc
1 parent 9a71cff commit a67e87a

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

torchtitan/distributed/parallel_dims.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,13 @@ def unflatten_mesh(
140140
(self.pp, self.dp_replicate, efsdp, self.ep, self.etp),
141141
)
142142

143+
# We have created all the required 1D meshes. This part is to create the
144+
# all the 2D meshes. We pre-created 2D meshes and error out if the users
145+
# try to access a 2D mesh that is not pre-created.
146+
hsdp_mesh = dense_mesh["dp_replicate", "fsdp"]
147+
ehsdp_mesh = sparse_mesh["dp_replicate", "efsdp"]
148+
ep_etp_mesh = sparse_mesh["ep", "etp"]
149+
143150
self._meshes = {
144151
"pp": dataloading_mesh["pp"],
145152
"batch": dataloading_mesh["batch"],
@@ -151,6 +158,9 @@ def unflatten_mesh(
151158
"ep": sparse_mesh["ep"],
152159
"efsdp": sparse_mesh["efsdp"],
153160
"etp": sparse_mesh["etp"],
161+
"dp_replicate_fsdp": hsdp_mesh,
162+
"dp_replicate_efsdp": ehsdp_mesh,
163+
"ep_etp": ep_etp_mesh,
154164
}
155165

156166
# Validate mesh sizes
@@ -176,6 +186,12 @@ def _validate_meshes(self):
176186
"ep": self.ep,
177187
"efsdp": self.dp_shard * self.cp * self.tp // (self.etp * self.ep),
178188
"etp": self.etp,
189+
"dp_replicate_fsdp": (self.dp_replicate, self.dp_shard * self.cp),
190+
"dp_replicate_efsdp": (
191+
self.dp_replicate,
192+
self.dp_shard * self.cp * self.tp // (self.etp * self.ep),
193+
),
194+
"ep_etp": (self.ep, self.etp),
179195
}
180196

181197
for mesh_name, expected_size in expected_sizes.items():
@@ -206,17 +222,17 @@ def get_mesh(self, dims: str | list[str]) -> DeviceMesh | None:
206222
if isinstance(dims, str):
207223
dims = [dims]
208224

209-
if not all(dim in self._meshes for dim in dims):
210-
valid_dims = sorted(self._meshes.keys())
225+
mesh_name = "_".join(dims)
226+
if mesh_name not in self._meshes:
211227
raise ValueError(
212-
f"Invalid mesh dim: '{dims}'. Valid dimensions are: {valid_dims}"
228+
f"Invalid mesh dim: '{mesh_name}'. "
229+
f"Valid dimensions are: {list(self._meshes.keys())}"
213230
)
214231

215232
if any(self._meshes[dim].size() == 1 for dim in dims):
216233
return None
217234

218-
meshes = [self._meshes[dim] for dim in dims]
219-
return meshes[0] if len(meshes) == 1 else DeviceMesh._concatenate(meshes)
235+
return self._meshes[mesh_name]
220236

221237
def get_all_meshes(self) -> dict[str, DeviceMesh]:
222238
if not self._meshes:

0 commit comments

Comments
 (0)