Skip to content

Commit 743ab4a

Browse files
committed
misc
1 parent b126f6e commit 743ab4a

File tree

6 files changed

+92
-82
lines changed

6 files changed

+92
-82
lines changed

torchtitan/config/job_config.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -397,19 +397,7 @@ class Parallelism:
397397
"""
398398
Expert parallelism degree. 1 means disabled. No effect for non-MoE models.
399399
400-
Currently, it is supported with the following constraints:
401-
402-
- when etp = tp:
403-
404-
- cp <= ep <= dp_shard * cp
405-
- ep % cp == 0
406-
- dp_shard * cp % ep == 0
407-
408-
- when etp = 1:
409-
410-
- cp * tp <= ep <= dp_shard * cp * tp
411-
- ep % (cp * tp) == 0
412-
- dp_shard * cp * tp % ep == 0
400+
Currently, etp is either 1 or is the same as tp.
413401
414402
Note that this is still an experimental feature. Some constraints will be
415403
relaxed soon when we have more flexible DeviceMesh support.

torchtitan/distributed/parallel_dims.py

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,6 @@ def _validate(self):
5757

5858
if ep > 1:
5959
assert etp == tp or etp == 1, "Currently we only support ETP=TP or ETP=1"
60-
if etp == tp:
61-
# EP would borrow all cp and some dp_shard degree
62-
assert ep % cp == 0 and (dp_shard * cp) % ep == 0
63-
elif etp == 1:
64-
# EP would borrow all cp and tp and some dp_shard degree
65-
assert ep % (cp * tp) == 0 and (dp_shard * cp * tp) % ep == 0
6660

6761
def build_mesh(self) -> DeviceMesh:
6862
"""
@@ -71,15 +65,14 @@ def build_mesh(self) -> DeviceMesh:
7165
The following mesh dimensions will be created:
7266
7367
pp: Pipeline Parallelism (PP).
74-
spmd: Used by SPMD DTensor RNG seed.
7568
batch: Used by data loading to determine the global batch size and which
7669
part of the data each rank should read. This dimension includes both
7770
``dp_replicate`` and ``dp_shard``. The backend is set to ``fake`` for
7871
this dimension to avoid unnecessary process group creation.
7972
loss: Used by all-reduce when computing the loss. Includes ``dp_replicate``,
8073
``dp_shard``, and ``cp`` degrees, as all are data parallelisms.
8174
dp_replicate: For DDP or HSDP replicate dimension.
82-
fsdp: For FSDP dimension. This includes ``cp``.
75+
fsdp: For FSDP dimension. This includes ``dp_shard`` and ``cp``.
8376
cp: Context Parallelism (CP).
8477
tp: Tensor Parallelism (TP).
8578
ep: Expert Parallelism (EP).
@@ -89,7 +82,6 @@ def build_mesh(self) -> DeviceMesh:
8982
Note: All the dimensions above are created by unflattening the world mesh.
9083
This API performs the following unflatten operations:
9184
92-
["pp", "spmd"]
9385
["pp", "batch", "cp", "tp"]
9486
["pp", "loss", "tp"]
9587
["pp", "dp_replicate", "fsdp", "tp"]
@@ -127,20 +119,16 @@ def unflatten_mesh(
127119
loss = self.dp_replicate * self.dp_shard * self.cp
128120
fsdp = self.dp_shard * self.cp
129121
efsdp = fsdp * self.tp // (self.etp * self.ep)
130-
spmd = self.world_size // self.pp
131122

132123
self._world_mesh = init_device_mesh(
133124
device_type, (self.world_size,), mesh_dim_names=("world",)
134125
)
135-
pp_spmd_mesh = unflatten_mesh(self._world_mesh, ("pp", "spmd"), (self.pp, spmd))
136-
data_mesh = unflatten_mesh(
126+
dataloading_mesh = unflatten_mesh(
137127
self._world_mesh,
138128
("pp", "batch", "cp", "tp"),
139129
(self.pp, batch, self.cp, self.tp),
140130
)
141-
loss_mesh = unflatten_mesh(
142-
self._world_mesh, ("pp", "loss", "tp"), (self.pp, loss, self.tp)
143-
)
131+
loss_mesh = dataloading_mesh["batch", "cp"].flatten("loss_mesh")
144132
dense_mesh = unflatten_mesh(
145133
self._world_mesh,
146134
("pp", "dp_replicate", "fsdp", "tp"),
@@ -153,14 +141,13 @@ def unflatten_mesh(
153141
)
154142

155143
self._meshes = {
156-
"pp": pp_spmd_mesh["pp"],
157-
"spmd": pp_spmd_mesh["spmd"],
158-
"batch": data_mesh["batch"],
144+
"pp": dataloading_mesh["pp"],
145+
"batch": dataloading_mesh["batch"],
159146
"loss": loss_mesh["loss"],
160147
"dp_replicate": dense_mesh["dp_replicate"],
161148
"fsdp": dense_mesh["fsdp"],
162-
"cp": data_mesh["cp"],
163-
"tp": data_mesh["tp"],
149+
"cp": dataloading_mesh["cp"],
150+
"tp": dataloading_mesh["tp"],
164151
"ep": sparse_mesh["ep"],
165152
"efsdp": sparse_mesh["efsdp"],
166153
"etp": sparse_mesh["etp"],
@@ -180,7 +167,6 @@ def _validate_meshes(self):
180167
"""Validate that created meshes have the expected sizes."""
181168
expected_sizes = {
182169
"pp": self.pp,
183-
"spmd": self.world_size // self.pp,
184170
"batch": self.dp_replicate * self.dp_shard,
185171
"loss": self.dp_replicate * self.dp_shard * self.cp,
186172
"dp_replicate": self.dp_replicate,
@@ -199,34 +185,38 @@ def _validate_meshes(self):
199185
f"expected {expected_size}, got {actual_size}"
200186
)
201187

202-
def get_mesh(self, dim: str) -> DeviceMesh | None:
203-
"""Get a device mesh by dimension name.
188+
def get_mesh(self, dims: str | list[str]) -> DeviceMesh | None:
189+
"""Get a device mesh by dimension names.
204190
205191
Args:
206-
dim: Name of the mesh dimension. Valid options include:
207-
'pp', 'spmd', 'batch', 'loss', 'dp_replicate', 'fsdp',
192+
dims: Names of the mesh dimension. Valid options include:
193+
'pp', 'batch', 'loss', 'dp_replicate', 'fsdp',
208194
'cp', 'tp', 'ep', 'etp', 'efsdp'
209195
210196
Returns:
211-
DeviceMesh for the requested dimension, or None if the dimension
212-
has size 1 (i.e., parallelism is disabled for that dimension).
197+
DeviceMesh for the requested dimension(s), or None if any of
198+
dimension(s) has size 1 (i.e., parallelism is disabled for that dimension).
213199
214200
Raises:
215-
ValueError: If the requested dimension name is not valid.
201+
ValueError: If the requested dimension name(s) is not valid.
216202
"""
217203
if not self._meshes:
218204
self.build_mesh()
219205

220-
if dim not in self._meshes:
206+
if isinstance(dims, str):
207+
dims = [dims]
208+
209+
if not all(dim in self._meshes for dim in dims):
221210
valid_dims = sorted(self._meshes.keys())
222211
raise ValueError(
223-
f"Invalid mesh dim: '{dim}'. Valid dimensions are: {valid_dims}"
212+
f"Invalid mesh dim: '{dims}'. Valid dimensions are: {valid_dims}"
224213
)
225214

226-
if self._meshes[dim].size() == 1:
215+
if any(self._meshes[dim].size() == 1 for dim in dims):
227216
return None
228217

229-
return self._meshes[dim]
218+
meshes = [self._meshes[dim] for dim in dims]
219+
return meshes[0] if len(meshes) == 1 else DeviceMesh._concatenate(meshes)
230220

231221
def get_all_meshes(self) -> dict[str, DeviceMesh]:
232222
if not self._meshes:
@@ -256,7 +246,7 @@ def cp_enabled(self):
256246
return self.cp > 1
257247

258248
@property
259-
def batch_enabled(self):
249+
def dp_cp_enabled(self):
260250
return self.dp_enabled or self.cp_enabled
261251

262252
@property

torchtitan/distributed/utils.py

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,29 @@ def set_determinism(
8686
device: torch.device,
8787
seed: int | None = None,
8888
deterministic: bool = False,
89-
distinct_seed_mesh_dim: str = "pp",
89+
distinct_seed_mesh_dims: list[str] | None = None,
9090
) -> None:
9191
"""
9292
Set the same DTensor manual seed for all dimensions in world mesh, but only different seeds
93-
across dimension denoted by `distinct_seed_mesh_dim`. An example use case is pipeline parallelism,
93+
across dimensions denoted by `distinct_seed_mesh_dims`. An example use case is pipeline parallelism,
9494
where we want to have the same seed across SPMD groups, but different seeds across PP groups.
9595
9696
Currently, does not set seeds for the CUDA RNG since TorchTitan always uses DTensor for SPMD parallelisms,
9797
and DTensor manages its own RNG tracker, but we could extend to support both if needed.
9898
9999
Set Determinism flags for increased reproducibility with loss of performance.
100+
101+
Args:
102+
world_mesh: Device mesh for distributed training
103+
device: Device to use
104+
seed: Base seed value (if None, will be determined automatically)
105+
deterministic: Whether to enable deterministic algorithms
106+
distinct_seed_mesh_dims: List of mesh dimension names to have distinct seeds across.
107+
If None, defaults to ["pp"] for backward compatibility.
100108
"""
109+
if distinct_seed_mesh_dims is None:
110+
distinct_seed_mesh_dims = ["pp"]
111+
101112
if deterministic:
102113
logger.info("Deterministic algorithm enabled (expect perf degradation).")
103114
torch.use_deterministic_algorithms(True)
@@ -115,7 +126,7 @@ def set_determinism(
115126

116127
FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention)
117128

118-
if parallel_dims.world_size == 1:
129+
if not parallel_dims.world_size == 1:
119130
if seed is not None:
120131
torch.manual_seed(seed)
121132
os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
@@ -131,19 +142,46 @@ def set_determinism(
131142
torch.distributed.broadcast(seed_tensor, src=0)
132143
seed = seed_tensor.to("cpu").view(torch.uint64).item()
133144

134-
# Set distinct seed for each rank in mesh dimensions, with dimension name provided by `distinct_seed_mesh_dim`
145+
# Set distinct seed for each rank in mesh dimensions, with dimension names provided by `distinct_seed_mesh_dims`
135146
# For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh,
136147
# and choose a unique seed for each rank on the PP mesh.
137-
# TODO(jianiw): We could further extend this to support multiple distinct dimensions instead of just one.
138-
duplicate_seed_mesh = parallel_dims.get_mesh("spmd")
139-
logger.debug(f"Global Rank {c10d.get_rank()} using seed: {seed}")
140-
all_meshes = parallel_dims.get_all_meshes()
141-
if distinct_seed_mesh_dim in all_meshes.keys():
142-
distinct_mesh = all_meshes[distinct_seed_mesh_dim]
143-
seed += distinct_mesh.get_local_rank()
148+
# We support multiple distinct dimensions by adding each distinct dimension's local rank to the seed.
149+
distinct_seed_meshes = [
150+
parallel_dims.get_mesh(dim) for dim in distinct_seed_mesh_dims
151+
]
152+
distinct_seed_meshes = [mesh for mesh in distinct_seed_meshes if mesh is not None]
153+
154+
if distinct_seed_meshes:
155+
# Use mixed-radix positional system to ensure unique seed per coordinate
156+
# Each dimension contributes: local_rank * (product of all previous dimension sizes)
157+
# This guarantees uniqueness like multi-dimensional array indexing
158+
seed_offset = 0
159+
cumulative_size = 1
160+
161+
for distinct_mesh in distinct_seed_meshes:
162+
local_rank = distinct_mesh.get_local_rank()
163+
# Add contribution from this dimension
164+
seed_offset += local_rank * cumulative_size
165+
166+
# Update cumulative size for next dimension
167+
cumulative_size *= distinct_mesh.size()
168+
169+
seed += seed_offset
144170
seed %= 2**64
145171

146-
logger.debug(f"{distinct_seed_mesh_dim} rank {distinct_mesh.get_local_rank()}")
172+
logger.debug(
173+
f"Distinct dims {distinct_dims_in_mesh}, Global rank {c10d.get_rank()} using seed: {seed}"
174+
)
175+
176+
# Filter out all distinct dimensions to get duplicate_seed_mesh
177+
duplicate_seed_meshes = [
178+
v
179+
for k, v in parallel_dims.get_all_meshes()
180+
if k not in distinct_dims_in_mesh
181+
]
182+
else:
183+
duplicate_seed_meshes = [parallel_dims.world_mesh]
184+
logger.debug(f"Global Rank {c10d.get_rank()} using seed: {seed}")
147185

148186
# The native RNGs and python RNG may not be important, except for the 1-D PP case, but we seed them for consistency.
149187
torch.manual_seed(seed)
@@ -152,8 +190,8 @@ def set_determinism(
152190

153191
# As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh.
154192
# IF PP is also used, this seed is unique per PP rank.
155-
if duplicate_seed_mesh:
156-
torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_mesh)
193+
if duplicate_seed_meshes:
194+
torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_meshes[0])
157195

158196

159197
def create_context_parallel_ctx(

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,10 @@ def parallelize_llama(
111111

112112
if parallel_dims.fsdp_enabled:
113113
# dp_mesh is the mesh for FSDP/HSDP
114-
if parallel_dims.dp_replicate_enabled:
115-
dp_mesh = DeviceMesh._concatenate(
116-
[parallel_dims.get_mesh("dp_replicate"), parallel_dims.get_mesh("fsdp")]
117-
)
118-
else:
119-
dp_mesh = parallel_dims.get_mesh("fsdp")
114+
names = (
115+
["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
116+
)
117+
dp_mesh = parallel_dims.get_mesh(names)
120118
apply_fsdp(
121119
model,
122120
dp_mesh,

torchtitan/models/llama4/infra/parallelize.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def parallelize_llama(
101101
tp_mesh=tp_mesh,
102102
ep_mesh=parallel_dims.get_mesh("ep"),
103103
etp_mesh=parallel_dims.get_mesh("etp"),
104-
etp_enabled=parallel_dims.etp_enabled,
104+
ep_etp_mesh=parallel_dims.get_mesh(["ep", "etp"]),
105105
)
106106

107107
model_compile_enabled = (
@@ -123,23 +123,16 @@ def parallelize_llama(
123123

124124
if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
125125
# dp_mesh is the mesh for FSDP/HSDP
126-
if parallel_dims.dp_replicate_enabled:
127-
dp_mesh = DeviceMesh._concatenate(
128-
[parallel_dims.get_mesh("dp_replicate"), parallel_dims.get_mesh("fsdp")]
129-
)
130-
else:
131-
dp_mesh = parallel_dims.get_mesh("fsdp")
126+
names = (
127+
["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"]
128+
)
129+
dp_mesh = parallel_dims.get_mesh(names)
132130

133131
# the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
134132
dp_mod_ep_mesh = None
135133
if parallel_dims.ep_enabled:
136134
if parallel_dims.dp_replicate_enabled:
137-
dp_mod_ep_mesh = DeviceMesh._concatenate(
138-
[
139-
parallel_dims.get_mesh("dp_replicate"),
140-
parallel_dims.get_mesh("efsdp"),
141-
]
142-
)
135+
dp_mod_ep_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"])
143136
else:
144137
dp_mod_ep_mesh = parallel_dims.get_mesh("efsdp")
145138

@@ -434,6 +427,7 @@ def apply_moe_ep_tp(
434427
tp_mesh: DeviceMesh | None,
435428
ep_mesh: DeviceMesh | None,
436429
etp_mesh: DeviceMesh | None,
430+
ep_etp_mesh: DeviceMesh | None,
437431
):
438432
assert ep_mesh is not None or tp_mesh is not None
439433

@@ -477,17 +471,19 @@ def apply_moe_ep_tp(
477471
parallelize_plan=moe_layer_plan,
478472
)
479473

480-
experts_mesh, experts_plan = None, None
474+
expert_mesh, experts_plan = None, None
481475
if ep_mesh is None:
476+
assert ep_etp_mesh is None
482477
experts_mesh = tp_mesh
483478
# input Replicate, output Partial
484479
experts_plan = TensorParallel()
485480
elif tp_mesh is None or etp_mesh is None:
481+
assert ep_etp_mesh is None
486482
experts_mesh = ep_mesh
487483
# input / output sharding on the batch / tokens dim
488484
experts_plan = ExpertParallel()
489485
else:
490-
experts_mesh = DeviceMesh._concatenate([ep_mesh, etp_mesh])
486+
experts_mesh = ep_etp_mesh
491487
experts_plan = ExpertTensorParallel()
492488

493489
parallelize_module(

torchtitan/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ def train_step(
526526
if not self.metrics_processor.should_log(self.step):
527527
return
528528

529-
if parallel_dims.batch_enabled:
529+
if parallel_dims.dp_cp_enabled:
530530
loss = loss.detach()
531531
ft_pg = self.ft_manager.loss_sync_pg
532532
batch_mesh = parallel_dims.get_mesh("batch")

0 commit comments

Comments
 (0)