@@ -57,12 +57,6 @@ def _validate(self):
5757
5858 if ep > 1 :
5959 assert etp == tp or etp == 1 , "Currently we only support ETP=TP or ETP=1"
60- if etp == tp :
61- # EP would borrow all cp and some dp_shard degree
62- assert ep % cp == 0 and (dp_shard * cp ) % ep == 0
63- elif etp == 1 :
64- # EP would borrow all cp and tp and some dp_shard degree
65- assert ep % (cp * tp ) == 0 and (dp_shard * cp * tp ) % ep == 0
6660
6761 def build_mesh (self ) -> DeviceMesh :
6862 """
@@ -71,15 +65,14 @@ def build_mesh(self) -> DeviceMesh:
7165 The following mesh dimensions will be created:
7266
7367 pp: Pipeline Parallelism (PP).
74- spmd: Used by SPMD DTensor RNG seed.
7568 batch: Used by data loading to determine the global batch size and which
7669 part of the data each rank should read. This dimension includes both
7770 ``dp_replicate`` and ``dp_shard``. The backend is set to ``fake`` for
7871 this dimension to avoid unnecessary process group creation.
7972 loss: Used by all-reduce when computing the loss. Includes ``dp_replicate``,
8073 ``dp_shard``, and ``cp`` degrees, as all are data parallelisms.
8174 dp_replicate: For DDP or HSDP replicate dimension.
82- fsdp: For FSDP dimension. This includes ``cp``.
75+ fsdp: For FSDP dimension. This includes ``dp_shard`` and ``cp``.
8376 cp: Context Parallelism (CP).
8477 tp: Tensor Parallelism (TP).
8578 ep: Expert Parallelism (EP).
@@ -89,7 +82,6 @@ def build_mesh(self) -> DeviceMesh:
8982 Note: All the dimensions above are created by unflattening the world mesh.
9083 This API performs the following unflatten operations:
9184
92- ["pp", "spmd"]
9385 ["pp", "batch", "cp", "tp"]
9486 ["pp", "loss", "tp"]
9587 ["pp", "dp_replicate", "fsdp", "tp"]
@@ -127,20 +119,16 @@ def unflatten_mesh(
127119 loss = self .dp_replicate * self .dp_shard * self .cp
128120 fsdp = self .dp_shard * self .cp
129121 efsdp = fsdp * self .tp // (self .etp * self .ep )
130- spmd = self .world_size // self .pp
131122
132123 self ._world_mesh = init_device_mesh (
133124 device_type , (self .world_size ,), mesh_dim_names = ("world" ,)
134125 )
135- pp_spmd_mesh = unflatten_mesh (self ._world_mesh , ("pp" , "spmd" ), (self .pp , spmd ))
136- data_mesh = unflatten_mesh (
126+ dataloading_mesh = unflatten_mesh (
137127 self ._world_mesh ,
138128 ("pp" , "batch" , "cp" , "tp" ),
139129 (self .pp , batch , self .cp , self .tp ),
140130 )
141- loss_mesh = unflatten_mesh (
142- self ._world_mesh , ("pp" , "loss" , "tp" ), (self .pp , loss , self .tp )
143- )
131+ loss_mesh = dataloading_mesh ["batch" , "cp" ].flatten ("loss_mesh" )
144132 dense_mesh = unflatten_mesh (
145133 self ._world_mesh ,
146134 ("pp" , "dp_replicate" , "fsdp" , "tp" ),
@@ -153,14 +141,13 @@ def unflatten_mesh(
153141 )
154142
155143 self ._meshes = {
156- "pp" : pp_spmd_mesh ["pp" ],
157- "spmd" : pp_spmd_mesh ["spmd" ],
158- "batch" : data_mesh ["batch" ],
144+ "pp" : dataloading_mesh ["pp" ],
145+ "batch" : dataloading_mesh ["batch" ],
159146 "loss" : loss_mesh ["loss" ],
160147 "dp_replicate" : dense_mesh ["dp_replicate" ],
161148 "fsdp" : dense_mesh ["fsdp" ],
162- "cp" : data_mesh ["cp" ],
163- "tp" : data_mesh ["tp" ],
149+ "cp" : dataloading_mesh ["cp" ],
150+ "tp" : dataloading_mesh ["tp" ],
164151 "ep" : sparse_mesh ["ep" ],
165152 "efsdp" : sparse_mesh ["efsdp" ],
166153 "etp" : sparse_mesh ["etp" ],
@@ -180,7 +167,6 @@ def _validate_meshes(self):
180167 """Validate that created meshes have the expected sizes."""
181168 expected_sizes = {
182169 "pp" : self .pp ,
183- "spmd" : self .world_size // self .pp ,
184170 "batch" : self .dp_replicate * self .dp_shard ,
185171 "loss" : self .dp_replicate * self .dp_shard * self .cp ,
186172 "dp_replicate" : self .dp_replicate ,
@@ -199,34 +185,38 @@ def _validate_meshes(self):
199185 f"expected { expected_size } , got { actual_size } "
200186 )
201187
def get_mesh(self, dims: str | list[str]) -> DeviceMesh | None:
    """Get a device mesh by one or more dimension names.

    Args:
        dims: A single mesh dimension name, or a sequence of names. Valid
            options include:
            'pp', 'batch', 'loss', 'dp_replicate', 'fsdp',
            'cp', 'tp', 'ep', 'etp', 'efsdp'

    Returns:
        The DeviceMesh for the requested dimension when a single name is
        given, or the concatenation of the requested meshes when several
        names are given. Returns None if any requested dimension has
        size 1 (i.e., parallelism is disabled for that dimension).

    Raises:
        ValueError: If ``dims`` is empty or contains a name that is not a
            known mesh dimension.
    """
    # Meshes are built lazily on first access.
    if not self._meshes:
        self.build_mesh()

    # Normalize to a list so a bare string is not iterated per character.
    names = [dims] if isinstance(dims, str) else list(dims)

    # An empty request would otherwise reach DeviceMesh._concatenate([]).
    if not names:
        valid_dims = sorted(self._meshes.keys())
        raise ValueError(
            f"No mesh dim was requested. Valid dimensions are: {valid_dims}"
        )

    # Report only the offending names, not the whole request.
    unknown = [name for name in names if name not in self._meshes]
    if unknown:
        valid_dims = sorted(self._meshes.keys())
        unknown_repr = ", ".join(repr(name) for name in unknown)
        raise ValueError(
            f"Invalid mesh dim: {unknown_repr}. Valid dimensions are: {valid_dims}"
        )

    meshes = [self._meshes[name] for name in names]

    # A size-1 dimension means that parallelism is disabled.
    if any(mesh.size() == 1 for mesh in meshes):
        return None

    if len(meshes) == 1:
        return meshes[0]
    # NOTE: relies on DeviceMesh's private _concatenate API to combine the
    # per-dimension submeshes into one multi-dimensional mesh.
    return DeviceMesh._concatenate(meshes)
230220
231221 def get_all_meshes (self ) -> dict [str , DeviceMesh ]:
232222 if not self ._meshes :
@@ -256,7 +246,7 @@ def cp_enabled(self):
256246 return self .cp > 1
257247
@property
def dp_cp_enabled(self):
    """Whether data parallelism or context parallelism is enabled.

    Evaluates to ``self.dp_enabled`` when that is truthy, and to
    ``self.cp_enabled`` otherwise.
    """
    dp_active = self.dp_enabled
    return dp_active if dp_active else self.cp_enabled
261251
262252 @property
0 commit comments