Skip to content

Commit 178bc11

Browse files
committed
misc
1 parent 743ab4a commit 178bc11

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

torchtitan/distributed/parallel_dims.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def unflatten_mesh(
128128
("pp", "batch", "cp", "tp"),
129129
(self.pp, batch, self.cp, self.tp),
130130
)
131-
loss_mesh = dataloading_mesh["batch", "cp"].flatten("loss_mesh")
131+
loss_mesh = dataloading_mesh["batch", "cp"]._flatten("loss_mesh")
132132
dense_mesh = unflatten_mesh(
133133
self._world_mesh,
134134
("pp", "dp_replicate", "fsdp", "tp"),
@@ -143,7 +143,7 @@ def unflatten_mesh(
143143
self._meshes = {
144144
"pp": dataloading_mesh["pp"],
145145
"batch": dataloading_mesh["batch"],
146-
"loss": loss_mesh["loss"],
146+
"loss": loss_mesh,
147147
"dp_replicate": dense_mesh["dp_replicate"],
148148
"fsdp": dense_mesh["fsdp"],
149149
"cp": dataloading_mesh["cp"],

torchtitan/distributed/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ def set_determinism(
191191
# As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh.
192192
# If PP is also used, this seed is unique per PP rank.
193193
if duplicate_seed_meshes:
194+
# We only need to pass one submesh, because DTensor manual_seed uses it to verify
195+
# if the current rank is actually in the submesh.
194196
torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_meshes[0])
195197

196198

0 commit comments

Comments
 (0)