@@ -70,18 +70,10 @@ def parallelize_deepseekv3(
     if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
         apply_moe_ep_tp(
             model,
-            tp_mesh=parallel_dims.get_mesh("tp") if parallel_dims.tp_enabled else None,
-            ep_mesh=parallel_dims.get_mesh("ep") if parallel_dims.ep_enabled else None,
-            etp_mesh=parallel_dims.get_mesh("etp")
-            if parallel_dims.etp_enabled
-            else None,
-            ep_etp_mesh=(
-                parallel_dims.get_mesh("ep_etp")
-                if parallel_dims.tp_enabled
-                and parallel_dims.ep_enabled
-                and parallel_dims.etp_enabled
-                else None
-            ),
+            tp_mesh=parallel_dims.get_mesh("tp"),
+            ep_mesh=parallel_dims.get_mesh("ep"),
+            etp_mesh=parallel_dims.get_mesh("etp"),
+            ep_etp_mesh=parallel_dims.get_mesh(["ep", "etp"]),
         )

     if job_config.activation_checkpoint.mode != "none":
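
Note on the hunk above: the new call site drops the per-argument enable checks and passes a list of dim names for the combined mesh. That only works if ParallelDims.get_mesh itself returns None for disabled dims and accepts a list of dims for a spanning submesh; this contract is my reading of the change, not something stated in the patch. A minimal sketch of that assumed contract (ParallelDimsSketch is hypothetical, not the real torchtitan class):

from typing import Optional, Sequence, Union


class ParallelDimsSketch:
    """Hypothetical stand-in, only to illustrate the assumed get_mesh contract."""

    def __init__(self, meshes: dict):
        # meshes maps enabled dim names (and pre-built flattened names such as
        # "ep_etp") to their submeshes; disabled dims are simply absent.
        self._meshes = meshes

    def get_mesh(self, dims: Union[str, Sequence[str]]) -> Optional[object]:
        if isinstance(dims, str):
            # Single dim: return its submesh, or None when that dim is disabled,
            # which is what lets the call site drop its `... else None` guards.
            return self._meshes.get(dims)
        # A list such as ["ep", "etp"] requests the submesh spanning those dims;
        # return None unless every requested dim is enabled.
        if any(d not in self._meshes for d in dims):
            return None
        return self._meshes.get("_".join(dims))


# Example: with only TP enabled, the EP-related lookups come back as None.
pd = ParallelDimsSketch({"tp": "tp_submesh"})
assert pd.get_mesh("tp") == "tp_submesh"
assert pd.get_mesh("ep") is None
assert pd.get_mesh(["ep", "etp"]) is None
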
@@ -139,13 +131,13 @@ def parallelize_deepseekv3(
             assert edp_mesh is not None
             assert hasattr(transformer_block, "moe")
             if (
-                dp_mod_ep_mesh.size() * parallel_dims.ep
+                edp_mesh.size() * parallel_dims.ep
                 > transformer_block.moe.experts.num_experts
             ):
                 experts_shard_dim = 1

             # when EP is enable, the routed experts' gradient reduction is done over
-            # dp_mod_ep_mesh instead of whole dp_mesh.
+            # edp_mesh instead of whole dp_mesh.
             # we add a `fsdp_gradient_divide_factor` to scale gradient over dp_mesh
             # to be consistent with data.
             # TODO (ruisizhang123): update the logic following the link below instead
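
The comment block above carries the key reasoning: with EP enabled, the routed experts' gradients are reduced only over edp_mesh, so an extra divide factor is needed to keep them on the same scale as dense parameters reduced over the whole dp_mesh. The sketch below works through that scaling and the shard-dim check with made-up sizes; every name and number in it is illustrative, not taken from the patch, and the choice of divisor reflects my reading of the comment rather than the actual implementation.

# Illustrative numbers only.
dp_size = 16                      # full data-parallel world (dp_mesh)
ep_size = 4                       # expert-parallel degree (parallel_dims.ep)
edp_size = dp_size // ep_size     # ranks that reduce a routed expert's gradient

# Left alone, FSDP would average the expert gradients over edp_size (4) ranks
# while dense parameters are averaged over dp_size (16). The comment's
# fsdp_gradient_divide_factor is, as I read it, the full dp size, giving both
# parameter groups the same divisor.
fsdp_gradient_divide_factor = dp_size
assert fsdp_gradient_divide_factor == edp_size * ep_size

# The shard-dim check in the hunk follows the same arithmetic: each EP rank
# holds num_experts / ep_size experts, and when edp_size exceeds that count
# (i.e. edp_size * ep_size > num_experts), dim 0 is too small to shard
# further, so sharding presumably falls back to dim 1.
num_experts = 8
experts_shard_dim = 1 if edp_size * ep_size > num_experts else 0
assert experts_shard_dim == 1
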