@@ -67,10 +67,6 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
 class ExpertParallel(ParallelStyle):
     def __init__(self):
         super().__init__()
-        self.input_splits = None
-        self.output_splits = None
-        self.input_shape = None
-        self.permuted_indices = None
 
     # performing all-to-all dispatch on the input
     def _token_dispatch(self, mod, inputs, device_mesh):
@@ -103,14 +99,14 @@ def _token_dispatch(self, mod, inputs, device_mesh):
                 .sum(dim=1)
                 .to(torch.device("cpu"), non_blocking=False)
             )
-            self.input_splits = input_splits.tolist()
-            self.output_splits = output_splits.tolist()
+            input_splits = input_splits.tolist()
+            output_splits = output_splits.tolist()
 
         # perform all-to-all
         routed_input = all_to_all_single_autograd(
             routed_input,
-            self.output_splits,
-            self.input_splits,
+            output_splits,
+            input_splits,
             device_mesh.get_group(),
         )
 
@@ -127,15 +123,22 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # of GroupedExperts, as it does not need padding.
 
         (
-            self.input_shape,
+            input_shape,
             routed_input,
-            self.permuted_indices,
+            permuted_indices,
             num_tokens_per_expert_group,
         ) = _permute(
             routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts
         )
 
-        return routed_input, num_tokens_per_expert_group
+        return (
+            routed_input,
+            num_tokens_per_expert_group,
+            input_shape,
+            permuted_indices,
+            input_splits,
+            output_splits,
+        )
 
     @staticmethod
     def _partition_fn(name, mod, device_mesh):
@@ -145,15 +148,20 @@ def _partition_fn(name, mod, device_mesh):
             mod.register_parameter(name, dist_param)
 
     # performing all-to-all combine on the output
-    def _token_combine(self, mod, routed_output, device_mesh):
-        routed_output = _unpermute(
-            routed_output, self.input_shape, self.permuted_indices
-        )
+    def _token_combine(self, mod, mod_outputs, device_mesh):
+        (
+            routed_output,
+            input_shape,
+            permuted_indices,
+            input_splits,
+            output_splits,
+        ) = mod_outputs
+        routed_output = _unpermute(routed_output, input_shape, permuted_indices)
 
         routed_output = all_to_all_single_autograd(
             routed_output,
-            self.input_splits,
-            self.output_splits,
+            input_splits,
+            output_splits,
             device_mesh.get_group(),
         )
         return routed_output
@@ -204,9 +212,9 @@ def _partition_fn_2d(self, name, mod, ep_tp_mesh):
             nn.Parameter(distribute_tensor(mod.w3, ep_tp_mesh, [Shard(0), Shard(1)])),
         )  # Column-wise sharding
 
-    def _token_combine(self, mod, routed_output, device_mesh):
+    def _token_combine(self, mod, mod_outputs, device_mesh):
         # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
-        return super()._token_combine(mod, routed_output, device_mesh["ep"])
+        return super()._token_combine(mod, mod_outputs, device_mesh["ep"])
 
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         return distribute_module(
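
The net effect of this commit is that `ExpertParallel` no longer stashes per-batch routing state (`input_splits`, `output_splits`, `input_shape`, `permuted_indices`) on `self`: `_token_dispatch` returns that metadata together with the routed tokens, and `_token_combine` unpacks it from `mod_outputs`. The snippet below is a minimal standalone sketch of that stateless dispatch/combine pattern, not the torchtitan API; the names (`token_dispatch`, `token_combine`, `ToyExperts`) and the permutation-only "routing" are illustrative assumptions, and it ignores the real all-to-all and expert sharding.

```python
import torch
import torch.nn as nn


def token_dispatch(x: torch.Tensor):
    # Fake "dispatch": permute tokens as a stand-in for routing them to experts.
    # The permutation (the routing metadata) is returned, not stored on an object.
    perm = torch.randperm(x.size(0))
    return x[perm], perm


def token_combine(mod_outputs):
    # Unpack the wrapped module's outputs: processed tokens plus the metadata
    # produced by token_dispatch and passed through by the module.
    y, perm = mod_outputs
    inverse = torch.empty_like(perm)
    inverse[perm] = torch.arange(perm.numel())
    return y[inverse]


class ToyExperts(nn.Module):
    # Stand-in for the wrapped experts module: it forwards the routing metadata
    # alongside its own output so the combine step can consume it.
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, perm):
        return self.proj(x), perm


tokens = torch.randn(16, 8)
experts = ToyExperts(8)
routed, perm = token_dispatch(tokens)             # plays the role of the dispatch (input) hook
restored = token_combine(experts(routed, perm))   # plays the role of the combine (output) hook
assert restored.shape == tokens.shape
```

In this sketch the toy module passes the metadata through explicitly; how the wrapped experts module and the surrounding hooks carry these extra values in torchtitan is outside the scope of this diff.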