     StreamEmbedLinear,
     StreamEmbedTransformer,
 )
-from weathergen.model.layers import MLP
+from weathergen.model.layers import MLP, MoEMLP
 from weathergen.model.utils import ActivationFactory
 from weathergen.utils.utils import get_dtype
 
+import logging
+logger = logging.getLogger(__name__)
 
 
 class EmbeddingEngine:
     name: "EmbeddingEngine"
@@ -249,17 +251,50 @@ def create(self) -> torch.nn.ModuleList:
                 )
             )
             # MLP block
-            self.ae_global_blocks.append(
-                MLP(
-                    self.cf.ae_global_dim_embed,
-                    self.cf.ae_global_dim_embed,
-                    with_residual=True,
-                    dropout_rate=self.cf.ae_global_dropout_rate,
-                    hidden_factor=self.cf.ae_global_mlp_hidden_factor,
-                    norm_type=self.cf.norm_type,
-                    norm_eps=self.cf.mlp_norm_eps,
-                )
+            # Add MoE option
+            use_moe = getattr(self.cf, "ae_global_mlp_type", "dense") == "moe"
+            mlp_common_kwargs = dict(
+                dim_in=self.cf.ae_global_dim_embed,
+                dim_out=self.cf.ae_global_dim_embed,
+                with_residual=True,
+                dropout_rate=self.cf.ae_global_dropout_rate,
+                norm_type=self.cf.norm_type,
+                norm_eps=self.cf.mlp_norm_eps,
             )
+            if use_moe:
+                self.ae_global_blocks.append(
+                    MoEMLP(
+                        **mlp_common_kwargs,
+                        num_experts=getattr(self.cf, "ae_global_moe_num_experts", 2),
+                        top_k=getattr(self.cf, "ae_global_moe_top_k", 1),
+                        router_noisy_std=getattr(self.cf, "ae_global_moe_router_noisy_std", 0.0),
+                        hidden_factor=getattr(self.cf, "ae_global_moe_hidden_factor", 2),
+                    )
+                )
+            else:
+                self.ae_global_blocks.append(
+                    MLP(
+                        self.cf.ae_global_dim_embed,
+                        self.cf.ae_global_dim_embed,
+                        with_residual=True,
+                        dropout_rate=self.cf.ae_global_dropout_rate,
+                        hidden_factor=self.cf.ae_global_mlp_hidden_factor,
+                        norm_type=self.cf.norm_type,
+                        norm_eps=self.cf.mlp_norm_eps,
+                    )
+                )
+        # Count MoE blocks
+        num_moe = sum(1 for m in self.ae_global_blocks if isinstance(m, MoEMLP))
+        logger.info(
+            "[MoE] GlobalAssimilationEngine: %d MoEMLP blocks "
+            "(ae_global_mlp_type=%s, experts=%s, top_k=%s, hidden_factor=%s)",
+            num_moe,
+            getattr(self.cf, "ae_global_mlp_type", "dense"),
+            getattr(self.cf, "ae_global_moe_num_experts", None),
+            getattr(self.cf, "ae_global_moe_top_k", None),
+            getattr(self.cf, "ae_global_moe_hidden_factor", None),
+        )
+
         return self.ae_global_blocks


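The new path above is driven entirely by optional config attributes read with `getattr` defaults, so existing configs fall back to the dense MLP. A minimal sketch of the switch, assuming a `SimpleNamespace` stands in for the real config object (`self.cf`); the key names and defaults mirror the `getattr` calls in the hunk:

```python
# Illustrative only: SimpleNamespace stands in for the actual config object (self.cf).
from types import SimpleNamespace

cf = SimpleNamespace(
    ae_global_mlp_type="moe",            # omit this key to keep the dense MLP path
    ae_global_moe_num_experts=4,
    ae_global_moe_top_k=1,
    ae_global_moe_router_noisy_std=0.1,
    ae_global_moe_hidden_factor=2,
)

# Same test as in the diff: missing attributes default to "dense".
use_moe = getattr(cf, "ae_global_mlp_type", "dense") == "moe"
print(use_moe)  # True; a config without the attribute prints False (dense branch)
```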
@@ -343,8 +378,8 @@ def create(self) -> torch.nn.ModuleList:
                 self.fe_blocks.append(
                     MoEMLP(
                         **mlp_common_kwargs,
-                        num_experts=getattr(self.cf, "fe_moe_num_experts", 8),
-                        top_k=getattr(self.cf, "fe_moe_top_k", 4),
+                        num_experts=getattr(self.cf, "fe_moe_num_experts", 2),
+                        top_k=getattr(self.cf, "fe_moe_top_k", 2),
                         router_noisy_std=getattr(self.cf, "fe_moe_router_noisy_std", 0.0),
                         hidden_factor=getattr(self.cf, "fe_moe_hidden_factor", 2),
                     )
@@ -362,15 +397,24 @@ def create(self) -> torch.nn.ModuleList:
                     )
                 )
         # ------------------------------------------------------------------
-        def init_weights_final(m):
-            if isinstance(m, torch.nn.Linear):
-                torch.nn.init.normal_(m.weight, mean=0, std=0.001)
-                if m.bias is not None:
-                    torch.nn.init.normal_(m.bias, mean=0, std=0.001)
-
-        for block in self.fe_blocks:
-            block.apply(init_weights_final)
-
+        # def init_weights_final(m):
+        #     if isinstance(m, torch.nn.Linear) and not getattr(m, "is_moe_router", False):
+        #         torch.nn.init.normal_(m.weight, mean=0, std=0.001)
+        #         if m.bias is not None:
+        #             torch.nn.init.normal_(m.bias, mean=0, std=0.001)
+
+        # for block in self.fe_blocks:
+        #     block.apply(init_weights_final)
+        num_moe = sum(1 for m in self.fe_blocks if isinstance(m, MoEMLP))
+        logger.info(
+            "[MoE] ForecastingEngine: %d MoEMLP blocks "
+            "(fe_mlp_type=%s, experts=%s, top_k=%s, hidden_factor=%s)",
+            num_moe,
+            getattr(self.cf, "fe_mlp_type", "dense"),
+            getattr(self.cf, "fe_moe_num_experts", None),
+            getattr(self.cf, "fe_moe_top_k", None),
+            getattr(self.cf, "fe_moe_hidden_factor", None),
+        )
         return self.fe_blocks


@@ -619,6 +663,14 @@ def __init__(
                     with_adanorm=False,
                     with_mlp=False,
                     attention_kwargs=attention_kwargs,
+                    ffn_mlp_type=getattr(self.cf, "decoder_ffn_mlp_type", "dense"),
+                    ffn_hidden_factor=getattr(self.cf, "decoder_ffn_hidden_factor", 4),
+                    moe_kwargs=dict(
+                        num_experts=getattr(self.cf, "decoder_moe_num_experts", 2),
+                        top_k=getattr(self.cf, "decoder_moe_top_k", 2),
+                        router_noisy_std=getattr(self.cf, "decoder_moe_router_noisy_std", 0.0),
+                        use_checkpoint=getattr(self.cf, "decoder_moe_use_checkpoint", False),
+                    )
                 )
             )
         elif self.cf.decoder_type == "AdaLayerNormConditioning":
@@ -674,6 +726,14 @@ def __init__(
                     tr_mlp_hidden_factor=tr_mlp_hidden_factor,
                     tro_type=tro_type,
                     mlp_norm_eps=self.cf.mlp_norm_eps,
+                    ffn_mlp_type=getattr(self.cf, "decoder_ffn_mlp_type", "dense"),
+                    ffn_hidden_factor=getattr(self.cf, "decoder_ffn_hidden_factor", 4),
+                    moe_kwargs=dict(
+                        num_experts=getattr(self.cf, "decoder_moe_num_experts", 2),
+                        top_k=getattr(self.cf, "decoder_moe_top_k", 2),
+                        router_noisy_std=getattr(self.cf, "decoder_moe_router_noisy_std", 0.0),
+                        use_checkpoint=getattr(self.cf, "decoder_moe_use_checkpoint", False),
+                    )
                 )
             )
         else:
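The constructor arguments threaded through the engines (num_experts, top_k, router_noisy_std, hidden_factor) suggest a noisy top-k routed feed-forward block. A minimal sketch of such a layer, under that assumption; the actual weathergen.model.layers.MoEMLP may differ (normalization, dropout, load-balancing loss, the use_checkpoint option passed to the decoder):

```python
# Illustrative only: a toy top-k MoE feed-forward layer with optional noisy routing.
import torch


class ToyMoEMLP(torch.nn.Module):
    def __init__(self, dim, num_experts=2, top_k=1, router_noisy_std=0.0, hidden_factor=2):
        super().__init__()
        hidden = dim * hidden_factor
        self.top_k = top_k
        self.router_noisy_std = router_noisy_std
        self.router = torch.nn.Linear(dim, num_experts, bias=False)
        self.experts = torch.nn.ModuleList(
            torch.nn.Sequential(
                torch.nn.Linear(dim, hidden), torch.nn.GELU(), torch.nn.Linear(hidden, dim)
            )
            for _ in range(num_experts)
        )

    def forward(self, x):
        # x: (tokens, dim); each token is routed to its top_k experts
        logits = self.router(x)
        if self.training and self.router_noisy_std > 0:
            logits = logits + torch.randn_like(logits) * self.router_noisy_std
        weights, idx = torch.topk(logits.softmax(dim=-1), self.top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize over selected experts
        out = torch.zeros_like(x)
        for slot in range(self.top_k):
            for e, expert in enumerate(self.experts):
                sel = idx[:, slot] == e
                if sel.any():
                    out[sel] += weights[sel, slot].unsqueeze(-1) * expert(x[sel])
        return x + out  # residual connection, mirroring with_residual=True


x = torch.randn(16, 32)
print(ToyMoEMLP(32, num_experts=4, top_k=2)(x).shape)  # torch.Size([16, 32])
```

With top_k=1 each token runs through a single expert, so per-token compute stays close to the dense MLP while parameter count scales with num_experts.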