Skip to content

Commit d6eae58

Browse files
committed
misc
1 parent 8db2f16 commit d6eae58

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

torchtitan/experiments/simple_fsdp/tests/test_numerics.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ def init_test(self):
2020
self.loss_fn = cross_entropy_loss
2121
data_parallel_shard_degree = -1
2222
if self.mode == "replicate":
23-
self.dp_mesh_dim_names = ("dp_replicate",)
23+
self.dp_mesh_dim_names = ["dp_replicate"]
2424
data_parallel_replicate_degree = self.world_size
2525
elif self.mode == "fully_shard":
26-
self.dp_mesh_dim_names = ("dp_shard_cp",)
26+
self.dp_mesh_dim_names = ["fsdp"]
2727
data_parallel_replicate_degree = 1
2828
elif self.mode == "hybrid_shard":
29-
self.dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
29+
self.dp_mesh_dim_names = ["dp_replicate", "fsdp"]
3030
data_parallel_replicate_degree = self.world_size // 2
3131
else:
3232
raise ValueError(f"Unsupported mode {self.mode}")
@@ -41,7 +41,6 @@ def init_test(self):
4141
etp=1,
4242
world_size=self.world_size,
4343
)
44-
self.device_mesh = self.parallel_dims.world_mesh
4544

4645
def get_input(self):
4746
inputs = torch.randn(8, 8).cuda()
@@ -50,7 +49,7 @@ def get_input(self):
5049
return model, inputs, labels
5150

5251
def run_fsdp2(self, model, inputs, labels, epoch=20):
53-
fully_shard(model, mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)])
52+
fully_shard(model, mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names))
5453
optim = self.optimizer(model.parameters(), lr=1e-4)
5554
losses = []
5655
for _ in range(epoch):
@@ -65,7 +64,7 @@ def run_fsdp2(self, model, inputs, labels, epoch=20):
6564
def run_simple_fsdp(self, model, inputs, labels, epoch=20):
6665
model = data_parallel(
6766
model,
68-
device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)],
67+
device_mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names),
6968
mode=self.mode,
7069
)
7170
optim = self.optimizer(model.parameters(), lr=1e-4)
@@ -82,7 +81,7 @@ def run_simple_fsdp(self, model, inputs, labels, epoch=20):
8281
def run_simple_fsdp_compiled_aot_eager(self, model, inputs, labels, epoch=20):
8382
model = data_parallel(
8483
model,
85-
device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)],
84+
device_mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names),
8685
mode=self.mode,
8786
)
8887
# TODO: Add "inductor" backend when its numerical issues are fixed

torchtitan/models/flux/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(self, job_config: JobConfig):
3535
self.parallel_dims,
3636
self.device,
3737
job_config.debug,
38-
distinct_seed_mesh_dims=["dp_shard", "dp_replicate"],
38+
distinct_seed_mesh_dims=["fsdp", "dp_replicate"],
3939
)
4040

4141
# NOTE: self._dtype is the data type used for encoders (image encoder, T5 text encoder, CLIP text encoder).

0 commit comments

Comments
 (0)