@@ -140,6 +140,13 @@ def unflatten_mesh(
140140 (self .pp , self .dp_replicate , efsdp , self .ep , self .etp ),
141141 )
142142
143+ # We have created all the required 1D meshes. This part is to create the
144+ # all the 2D meshes. We pre-created 2D meshes and error out if the users
145+ # try to access a 2D mesh that is not pre-created.
146+ hsdp_mesh = dense_mesh ["dp_replicate" , "fsdp" ]
147+ ehsdp_mesh = sparse_mesh ["dp_replicate" , "efsdp" ]
148+ ep_etp_mesh = sparse_mesh ["ep" , "etp" ]
149+
143150 self ._meshes = {
144151 "pp" : dataloading_mesh ["pp" ],
145152 "batch" : dataloading_mesh ["batch" ],
@@ -151,6 +158,9 @@ def unflatten_mesh(
151158 "ep" : sparse_mesh ["ep" ],
152159 "efsdp" : sparse_mesh ["efsdp" ],
153160 "etp" : sparse_mesh ["etp" ],
161+ "dp_replicate_fsdp" : hsdp_mesh ,
162+ "dp_replicate_efsdp" : ehsdp_mesh ,
163+ "ep_etp" : ep_etp_mesh ,
154164 }
155165
156166 # Validate mesh sizes
@@ -176,6 +186,12 @@ def _validate_meshes(self):
176186 "ep" : self .ep ,
177187 "efsdp" : self .dp_shard * self .cp * self .tp // (self .etp * self .ep ),
178188 "etp" : self .etp ,
189+ "dp_replicate_fsdp" : (self .dp_replicate , self .dp_shard * self .cp ),
190+ "dp_replicate_efsdp" : (
191+ self .dp_replicate ,
192+ self .dp_shard * self .cp * self .tp // (self .etp * self .ep ),
193+ ),
194+ "ep_etp" : (self .ep , self .etp ),
179195 }
180196
181197 for mesh_name , expected_size in expected_sizes .items ():
@@ -206,17 +222,17 @@ def get_mesh(self, dims: str | list[str]) -> DeviceMesh | None:
206222 if isinstance (dims , str ):
207223 dims = [dims ]
208224
209- if not all ( dim in self . _meshes for dim in dims ):
210- valid_dims = sorted ( self ._meshes . keys ())
225+ mesh_name = "_" . join ( dims )
226+ if mesh_name not in self ._meshes :
211227 raise ValueError (
212- f"Invalid mesh dim: '{ dims } '. Valid dimensions are: { valid_dims } "
228+ f"Invalid mesh dim: '{ mesh_name } '. "
229+ f"Valid dimensions are: { list (self ._meshes .keys ())} "
213230 )
214231
215232 if any (self ._meshes [dim ].size () == 1 for dim in dims ):
216233 return None
217234
218- meshes = [self ._meshes [dim ] for dim in dims ]
219- return meshes [0 ] if len (meshes ) == 1 else DeviceMesh ._concatenate (meshes )
235+ return self ._meshes [mesh_name ]
220236
221237 def get_all_meshes (self ) -> dict [str , DeviceMesh ]:
222238 if not self ._meshes :
0 commit comments