@@ -20,13 +20,13 @@ def init_test(self):
2020 self .loss_fn = cross_entropy_loss
2121 data_parallel_shard_degree = - 1
2222 if self .mode == "replicate" :
23- self .dp_mesh_dim_names = ( "dp_replicate" ,)
23+ self .dp_mesh_dim_names = [ "dp_replicate" ]
2424 data_parallel_replicate_degree = self .world_size
2525 elif self .mode == "fully_shard" :
26- self .dp_mesh_dim_names = ( "dp_shard_cp" ,)
26+ self .dp_mesh_dim_names = [ "fsdp" ]
2727 data_parallel_replicate_degree = 1
2828 elif self .mode == "hybrid_shard" :
29- self .dp_mesh_dim_names = ( "dp_replicate" , "dp_shard_cp" )
29+ self .dp_mesh_dim_names = [ "dp_replicate" , "fsdp" ]
3030 data_parallel_replicate_degree = self .world_size // 2
3131 else :
3232 raise ValueError (f"Unsupported mode { self .mode } " )
@@ -41,7 +41,6 @@ def init_test(self):
4141 etp = 1 ,
4242 world_size = self .world_size ,
4343 )
44- self .device_mesh = self .parallel_dims .world_mesh
4544
4645 def get_input (self ):
4746 inputs = torch .randn (8 , 8 ).cuda ()
@@ -50,7 +49,7 @@ def get_input(self):
5049 return model , inputs , labels
5150
5251 def run_fsdp2 (self , model , inputs , labels , epoch = 20 ):
53- fully_shard (model , mesh = self .device_mesh [ tuple (self .dp_mesh_dim_names )] )
52+ fully_shard (model , mesh = self .parallel_dims . get_mesh (self .dp_mesh_dim_names ))
5453 optim = self .optimizer (model .parameters (), lr = 1e-4 )
5554 losses = []
5655 for _ in range (epoch ):
@@ -65,7 +64,7 @@ def run_fsdp2(self, model, inputs, labels, epoch=20):
6564 def run_simple_fsdp (self , model , inputs , labels , epoch = 20 ):
6665 model = data_parallel (
6766 model ,
68- device_mesh = self .device_mesh [ tuple (self .dp_mesh_dim_names )] ,
67+ device_mesh = self .parallel_dims . get_mesh (self .dp_mesh_dim_names ),
6968 mode = self .mode ,
7069 )
7170 optim = self .optimizer (model .parameters (), lr = 1e-4 )
@@ -82,7 +81,7 @@ def run_simple_fsdp(self, model, inputs, labels, epoch=20):
8281 def run_simple_fsdp_compiled_aot_eager (self , model , inputs , labels , epoch = 20 ):
8382 model = data_parallel (
8483 model ,
85- device_mesh = self .device_mesh [ tuple (self .dp_mesh_dim_names )] ,
84+ device_mesh = self .parallel_dims . get_mesh (self .dp_mesh_dim_names ),
8685 mode = self .mode ,
8786 )
8887 # TODO: Add "inductor" backend when its numerical issues are fixed
0 commit comments