@@ -93,3 +93,212 @@ def forward(self, *args):
        x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]])

        return x

class _DenseBlock(nn.Module):
    """A tiny FFN that mirrors the structure of the current MLP stack."""
    def __init__(self, dim_in, dim_hidden, dim_out, num_layers=2,
                 nonlin=nn.GELU, dropout_rate=0.0):
        super().__init__()
        layers = [nn.Linear(dim_in, dim_hidden), nonlin(), nn.Dropout(dropout_rate)]
        for _ in range(num_layers - 2):
            layers += [nn.Linear(dim_hidden, dim_hidden), nonlin(), nn.Dropout(dropout_rate)]
        layers += [nn.Linear(dim_hidden, dim_out)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class MoEMLP(nn.Module):
    """
    Drop-in MoE MLP (memory-friendly):
    - Same call pattern as the current MLP: forward(*args), where args = (x, ...) and the
      optional aux tensor comes last
    - Supports the residual add exactly like MLP
    - Optional AdaLayerNorm when dim_aux is provided
    - Simple top-k router; mixes experts with streaming accumulation (no big [E, ..., D] stack)
    See the usage sketch after the class definition.
    """
    def __init__(
        self,
        dim_in,
        dim_out,
        num_layers=2,
        hidden_factor=2,
        pre_layer_norm=True,
        dropout_rate=0.0,
        nonlin=nn.GELU,
        with_residual=False,
        norm_type="LayerNorm",
        dim_aux=None,
        norm_eps=1e-5,
        name: str | None = None,
        # MoE bits
        num_experts: int = 8,
        top_k: int = 4,
        router_noisy_std: float = 0.0,  # set >0 to add noise to router logits
        # Memory bits
        use_checkpoint: bool = False,  # checkpoint expert forward to save memory
    ):
        super().__init__()
        if name is not None:
            self.name = name

        assert num_layers >= 2
        assert 1 <= top_k <= num_experts

        self.with_residual = with_residual
        self.with_aux = dim_aux is not None
        self.pre_layer_norm = pre_layer_norm
        self.top_k = top_k
        self.num_experts = num_experts
        self.use_checkpoint = use_checkpoint

        dim_hidden = int(dim_in * hidden_factor)

        # Norm (match MLP behavior)
        Norm = nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm
        if pre_layer_norm:
            self.norm = (
                Norm(dim_in, eps=norm_eps)
                if dim_aux is None
                else AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps)
            )
        else:
            self.norm = None  # no pre-norm

        # Router
        self.router = nn.Linear(dim_in, num_experts)
        self.router_noisy_std = router_noisy_std

        # Experts (identical shape)
        self.experts = nn.ModuleList(
            [
                _DenseBlock(
                    dim_in=dim_in,
                    dim_hidden=dim_hidden,
                    dim_out=dim_out,
                    num_layers=num_layers,
                    nonlin=nonlin,
                    dropout_rate=dropout_rate,
                )
                for _ in range(num_experts)
            ]
        )

        # For optional aux loss (load-balancing); not used unless you read it
        self.register_buffer("last_aux_loss", torch.zeros((), dtype=torch.float32))
    def _gate(self, x_norm):
        # x_norm: [*, D]. The router works on the last dim.
        logits = self.router(x_norm)
        if self.router_noisy_std > 0:
            logits = logits + torch.randn_like(logits) * self.router_noisy_std

        if self.top_k == self.num_experts:
            # softmax over all experts
            weights = torch.softmax(logits, dim=-1)  # [..., E]
            top_idx = None  # not needed
        else:
            # top-k softmax
            top_vals, top_idx = torch.topk(logits, k=self.top_k, dim=-1)  # [*, k]
            weights = torch.softmax(top_vals, dim=-1)  # [*, k]
        return weights, top_idx

    @torch.no_grad()
    def _compute_load_balance_aux(self, weights, top_idx, num_experts):
        """
        Simple load-balancing penalty from the Switch/MoE papers:
        encourage uniform expert probability and uniform usage.
        Works with both full softmax (top_idx is None) and top-k.
        """
        if top_idx is None:
            # Average routing probability per expert over all leading dims
            probs = weights.mean(dim=tuple(range(weights.dim() - 1)))  # [E]
        else:
            # Build usage over experts from the top-k selection
            if weights.shape != top_idx.shape:
                raise ValueError(
                    "Top-k weights and indices must share the same shape"
                )

            K = weights.shape[-1]
            flat_w = weights.reshape(-1, K)  # [N, K]
            flat_i = top_idx.reshape(-1, K)  # [N, K]
            usage = torch.zeros(num_experts, device=weights.device, dtype=weights.dtype)
            usage.scatter_add_(0, flat_i.reshape(-1), flat_w.reshape(-1))
            usage = usage / usage.sum().clamp_min(1e-6)  # normalize
            probs = usage  # proxy for routing probability
        # Target is the uniform distribution 1/E
        target = torch.full_like(probs, 1.0 / num_experts)
        aux = (probs * (probs.add(1e-6).log() - target.add(1e-6).log())).sum()
        return aux
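
    # Note: up to the 1e-6 smoothing, the aux value computed above is the KL
    # divergence between the expert distribution p (routing probabilities, or
    # top-k weight mass as a proxy) and the uniform target u_e = 1 / E:
    #     aux = sum_e p_e * (log p_e - log u_e) = KL(p || u)
    # It is zero exactly when every expert receives an equal share.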

    def forward(self, *args):
        # Match the MLP(*args) calling convention
        x = args[0]
        x_in = x
        aux = args[-1] if self.with_aux else None

        # Optional pre-norm (possibly adaptive)
        if self.norm is not None:
            if self.with_aux:
                x = self.norm(x, aux)
            else:
                x = self.norm(x)

        # Router
        weights, top_idx = self._gate(x)  # weights: [..., E] or [..., K]

        # Build a full weight tensor [..., E] if we are in top-k mode,
        # so we can stream over experts without stacking their outputs.
        if top_idx is None:
            w_full = weights  # [..., E]
        else:
            # Scatter the top-k weights into a zero tensor of size E
            E = self.num_experts
            w_full = torch.zeros(*weights.shape[:-1], E, device=weights.device, dtype=weights.dtype)  # [..., E]
            w_full.scatter_(-1, top_idx, weights)

        # Output accumulator (no expert stacking)
        out_dim = self.experts[0].net[-1].out_features  # last Linear of _DenseBlock
        y = x.new_zeros(*x.shape[:-1], out_dim)

        # Optional gradient checkpointing
        if self.use_checkpoint:
            from torch.utils.checkpoint import checkpoint

        # Stream over experts: y += expert(x) * w_full[..., e]
        for e, expert in enumerate(self.experts):
            # Skip the compute if the weight mass is (nearly) zero for this expert
            w_e = w_full[..., e]  # [...]
            if torch.allclose(w_e, torch.zeros((), device=w_e.device, dtype=w_e.dtype)):
                continue

            if self.use_checkpoint and self.training:
                y_e = checkpoint(expert, x)
            else:
                y_e = expert(x)
            y = y + y_e * w_e.unsqueeze(-1)

        # Residual (same logic as the MLP)
        if self.with_residual:
            if y.shape[-1] == x_in.shape[-1]:
                y = x_in + y
            else:
                assert y.shape[-1] % x_in.shape[-1] == 0
                y = y + x_in.repeat([*[1 for _ in y.shape[:-1]], y.shape[-1] // x_in.shape[-1]])

        # Optional: update the aux loss (not returned; read it if you want)
        with torch.no_grad():
            self.last_aux_loss = self._compute_load_balance_aux(
                weights,
                top_idx,
                self.num_experts,
            )

        return y
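
A minimal usage sketch for the MoEMLP added above, assuming it is imported from this module together with torch; the dims, shapes, and flags below are illustrative values only:

    import torch

    moe = MoEMLP(dim_in=256, dim_out=256, with_residual=True, num_experts=8, top_k=2)
    x = torch.randn(4, 128, 256)      # [batch, tokens, dim_in], illustrative shape
    y = moe(x)                        # same forward(*args) call pattern as the MLP it mirrors
    assert y.shape == (4, 128, 256)
    balance = moe.last_aux_loss       # load-balancing diagnostic refreshed on each forward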