Commit 36fea3a

Add a MoEMLP layer to the layers file, integrate the MoE layer into the forecasting engine, and add config options to control the use of this layer
1 parent aa3668a commit 36fea3a

3 files changed: +260 -11 lines

config/default_config.yml

Lines changed: 8 additions & 0 deletions
@@ -152,3 +152,11 @@ run_id: ???
 train_log:
   # The period to log metrics (in number of batch steps)
   log_interval: 20
+
+# Forecast MLP type: "dense" (default) or "moe"
+fe_mlp_type: "dense"  # set to "moe" to enable MoE
+
+# MoE-only params (ignored when fe_mlp_type != "moe")
+fe_moe_num_experts: 8
+fe_moe_top_k: 2
+fe_moe_hidden_factor: 0.5  # = HF_dense / 4
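
For illustration, a minimal sketch of how these keys are consumed: the engine (see engines.py below) reads them with getattr fallbacks, so only fe_mlp_type has to be changed to "moe" to switch paths. The SimpleNamespace stub here is an illustrative stand-in, not the project's actual config loader.

from types import SimpleNamespace

# Illustrative stand-in for the config object built from default_config.yml.
cf = SimpleNamespace(
    fe_mlp_type="moe",        # "dense" (default) or "moe"
    fe_moe_num_experts=8,
    fe_moe_top_k=2,
    fe_moe_hidden_factor=0.5,
)

# Same fallback semantics as in engines.py: missing keys fall back to defaults.
use_moe = getattr(cf, "fe_mlp_type", "dense") == "moe"
num_experts = getattr(cf, "fe_moe_num_experts", 8)
top_k = getattr(cf, "fe_moe_top_k", 4)
hidden_factor = getattr(cf, "fe_moe_hidden_factor", 2)
print(use_moe, num_experts, top_k, hidden_factor)  # True 8 2 0.5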

src/weathergen/model/engines.py

Lines changed: 43 additions & 11 deletions
@@ -318,18 +318,50 @@ def create(self) -> torch.nn.ModuleList:
                 )
             )
             # Add MLP block
-            self.fe_blocks.append(
-                MLP(
-                    self.cf.ae_global_dim_embed,
-                    self.cf.ae_global_dim_embed,
-                    with_residual=True,
-                    dropout_rate=self.cf.fe_dropout_rate,
-                    norm_type=self.cf.norm_type,
-                    dim_aux=1,
-                    norm_eps=self.cf.mlp_norm_eps,
-                )
+            use_moe = getattr(self.cf, "fe_mlp_type", "dense") == "moe"
+            mlp_common_kwargs = dict(
+                dim_in=self.cf.ae_global_dim_embed,
+                dim_out=self.cf.ae_global_dim_embed,
+                with_residual=True,
+                dropout_rate=self.cf.fe_dropout_rate,
+                norm_type=self.cf.norm_type,
+                dim_aux=1,
+                norm_eps=self.cf.mlp_norm_eps,
             )
-
+            # self.fe_blocks.append(
+            #     MLP(
+            #         self.cf.ae_global_dim_embed,
+            #         self.cf.ae_global_dim_embed,
+            #         with_residual=True,
+            #         dropout_rate=self.cf.fe_dropout_rate,
+            #         norm_type=self.cf.norm_type,
+            #         dim_aux=1,
+            #         norm_eps=self.cf.mlp_norm_eps,
+            #     )
+            # )
+            if use_moe:
+                self.fe_blocks.append(
+                    MoEMLP(
+                        **mlp_common_kwargs,
+                        num_experts=getattr(self.cf, "fe_moe_num_experts", 8),
+                        top_k=getattr(self.cf, "fe_moe_top_k", 4),
+                        router_noisy_std=getattr(self.cf, "fe_moe_router_noisy_std", 0.0),
+                        hidden_factor=getattr(self.cf, "fe_moe_hidden_factor", 2),
+                    )
+                )
+            else:
+                self.fe_blocks.append(
+                    MLP(
+                        self.cf.ae_global_dim_embed,
+                        self.cf.ae_global_dim_embed,
+                        with_residual=True,
+                        dropout_rate=self.cf.fe_dropout_rate,
+                        norm_type=self.cf.norm_type,
+                        dim_aux=1,
+                        norm_eps=self.cf.mlp_norm_eps,
+                    )
+                )
+            # ------------------------------------------------------------------
         def init_weights_final(m):
             if isinstance(m, torch.nn.Linear):
                 torch.nn.init.normal_(m.weight, mean=0, std=0.001)
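
A usage sketch of the toggle above. The import path weathergen.model.layers is assumed from the file location src/weathergen/model/layers.py, and SimpleNamespace stands in for the real config object; the sketch only shows that both branches build a block from the same shared kwargs.

from types import SimpleNamespace

from weathergen.model.layers import MLP, MoEMLP  # assumed import path

cf = SimpleNamespace(
    ae_global_dim_embed=256,
    fe_dropout_rate=0.1,
    norm_type="LayerNorm",
    mlp_norm_eps=1e-5,
    fe_mlp_type="moe",
    fe_moe_num_experts=8,
    fe_moe_top_k=2,
    fe_moe_hidden_factor=0.5,
)

mlp_common_kwargs = dict(
    dim_in=cf.ae_global_dim_embed,
    dim_out=cf.ae_global_dim_embed,
    with_residual=True,
    dropout_rate=cf.fe_dropout_rate,
    norm_type=cf.norm_type,
    dim_aux=1,
    norm_eps=cf.mlp_norm_eps,
)

if getattr(cf, "fe_mlp_type", "dense") == "moe":
    block = MoEMLP(
        **mlp_common_kwargs,
        num_experts=getattr(cf, "fe_moe_num_experts", 8),
        top_k=getattr(cf, "fe_moe_top_k", 4),
        router_noisy_std=getattr(cf, "fe_moe_router_noisy_std", 0.0),
        hidden_factor=getattr(cf, "fe_moe_hidden_factor", 2),
    )
else:
    block = MLP(
        cf.ae_global_dim_embed,
        cf.ae_global_dim_embed,
        with_residual=True,
        dropout_rate=cf.fe_dropout_rate,
        norm_type=cf.norm_type,
        dim_aux=1,
        norm_eps=cf.mlp_norm_eps,
    )

print(type(block).__name__)  # MoEMLP for this cf; both variants expose forward(*args)

Because default_config.yml ships with fe_mlp_type: "dense", existing configurations keep the original MLP path unless the key is overridden.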

src/weathergen/model/layers.py

Lines changed: 209 additions & 0 deletions
@@ -93,3 +93,212 @@ def forward(self, *args):
             x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]])
 
         return x
+
+class _DenseBlock(nn.Module):
+    """A tiny FFN that mirrors the structure of the current MLP stack."""
+    def __init__(self, dim_in, dim_hidden, dim_out, num_layers=2,
+                 nonlin=nn.GELU, dropout_rate=0.0):
+        super().__init__()
+        layers = [nn.Linear(dim_in, dim_hidden), nonlin(), nn.Dropout(dropout_rate)]
+        for _ in range(num_layers - 2):
+            layers += [nn.Linear(dim_hidden, dim_hidden), nonlin(), nn.Dropout(dropout_rate)]
+        layers += [nn.Linear(dim_hidden, dim_out)]
+        self.net = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.net(x)
+
+class MoEMLP(nn.Module):
+    """
+    Drop-in MoE MLP (memory-friendly):
+      - Same call pattern as the current MLP: forward(*args) where args=(x, ...) and optional aux at the end
+      - Supports residual add exactly like MLP
+      - Optional AdaLayerNorm when dim_aux is provided
+      - Simple top-k router; mixes experts with streaming accumulation (no big [E, ..., D] stack)
+    """
+    def __init__(
+        self,
+        dim_in,
+        dim_out,
+        num_layers=2,
+        hidden_factor=2,
+        pre_layer_norm=True,
+        dropout_rate=0.0,
+        nonlin=nn.GELU,
+        with_residual=False,
+        norm_type="LayerNorm",
+        dim_aux=None,
+        norm_eps=1e-5,
+        name: str | None = None,
+        # MoE bits
+        num_experts: int = 8,
+        top_k: int = 4,
+        router_noisy_std: float = 0.0,  # set >0 to add noise to router logits
+        # Memory bits
+        use_checkpoint: bool = False,  # checkpoint expert forward to save memory
+    ):
+        super().__init__()
+        if name is not None:
+            self.name = name
+
+        assert num_layers >= 2
+        assert 1 <= top_k <= num_experts
+
+        self.with_residual = with_residual
+        self.with_aux = dim_aux is not None
+        self.pre_layer_norm = pre_layer_norm
+        self.top_k = top_k
+        self.num_experts = num_experts
+        self.use_checkpoint = use_checkpoint
+
+        dim_hidden = int(dim_in * hidden_factor)
+
+        # Norm (match MLP behavior)
+        Norm = nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm
+        if pre_layer_norm:
+            self.norm = (
+                Norm(dim_in, eps=norm_eps)
+                if dim_aux is None
+                else AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps)
+            )
+        else:
+            self.norm = None  # no pre-norm
+
+        # Router
+        self.router = nn.Linear(dim_in, num_experts)
+        self.router_noisy_std = router_noisy_std
+
+        # Experts (identical shape)
+        self.experts = nn.ModuleList(
+            [
+                _DenseBlock(
+                    dim_in=dim_in,
+                    dim_hidden=dim_hidden,
+                    dim_out=dim_out,
+                    num_layers=num_layers,
+                    nonlin=nonlin,
+                    dropout_rate=dropout_rate,
+                )
+                for _ in range(num_experts)
+            ]
+        )
+
+        # For optional aux loss (load-balancing); not used unless you read it
+        self.register_buffer("last_aux_loss", torch.zeros((), dtype=torch.float32))
+
+    def _gate(self, x_norm):
+        # x_norm: [*, D]. Router works on the last dim.
+        logits = self.router(x_norm)
+        if self.router_noisy_std > 0:
+            logits = logits + torch.randn_like(logits) * self.router_noisy_std
+
+        if self.top_k == self.num_experts:
+            # softmax over all experts
+            weights = torch.softmax(logits, dim=-1)  # [..., E]
+            top_idx = None  # not needed
+        else:
+            # top-k softmax
+            top_vals, top_idx = torch.topk(logits, k=self.top_k, dim=-1)  # [*, k]
+            weights = torch.softmax(top_vals, dim=-1)  # [*, k]
+        return weights, top_idx
+
+    @torch.no_grad()
+    def _compute_load_balance_aux(self, weights, top_idx, num_experts):
+        """
+        Simple load-balancing penalty from Switch/MoE papers:
+        Encourage uniform expert probability and uniform usage.
+        Works with both full-softmax (top_idx None) and top-k.
+        """
+        if top_idx is None:
+            # weights over E
+            probs = weights.mean(dim=tuple(range(weights.dim() - 1)))  # [E]
+        else:
+            # Build usage over experts from top-k selection
+            # *prefix, K = weights.shape
+            # flat_w = weights.reshape(-1, K)  # [N, K]
+            # flat_i = top_idx.reshape(-1, K)  # [N, K]
+            if weights.shape != top_idx.shape:
+                raise ValueError(
+                    "Top-k weights and indices must share the same shape"
+                )
+
+            K = weights.shape[-1]
+            flat_w = weights.reshape(-1, K)  # [N, K]
+            flat_i = top_idx.reshape(-1, K)  # [N, K]
+            E = num_experts
+            usage = torch.zeros(E, device=weights.device, dtype=weights.dtype)
+            usage.scatter_add_(0, flat_i.reshape(-1), flat_w.reshape(-1))
+            usage = usage / usage.sum().clamp_min(1e-6)  # normalize
+            probs = usage  # proxy
+        # Target is uniform 1/E
+        E = num_experts
+        target = torch.full_like(probs, 1.0 / E)
+        aux = (probs * (probs.add(1e-6).log() - target.add(1e-6).log())).sum()
+        return aux
+
+    def forward(self, *args):
+        # Match the MLP(*args) calling convention
+        x = args[0]
+        x_in = x
+        aux = args[-1] if self.with_aux else None
+
+        # Optional pre-norm (possibly adaptive)
+        if self.norm is not None:
+            if self.with_aux:
+                x = self.norm(x, aux)
+            else:
+                x = self.norm(x)
+
+        # Router
+        weights, top_idx = self._gate(x)  # weights: [..., E] or [..., K]
+
+        # Build a full weight tensor [..., E] if we are in top-k mode,
+        # so we can stream over experts without stacking their outputs.
+        if top_idx is None:
+            w_full = weights  # [..., E]
+        else:
+            # scatter top-k weights into a zero tensor of size E
+            E = self.num_experts
+            w_full = torch.zeros(*weights.shape[:-1], E, device=weights.device, dtype=weights.dtype)  # [..., E]
+            w_full.scatter_(-1, top_idx, weights)
+
+        # Output accumulator (no expert stacking)
+        out_dim = self.experts[0].net[-1].out_features  # last Linear of _DenseBlock
+        y = x.new_zeros(*x.shape[:-1], out_dim)
+
+        # Optional gradient checkpoint
+        if self.use_checkpoint:
+            from torch.utils.checkpoint import checkpoint
+
+        # Stream over experts: y += expert(x) * w_full[..., e]
+        for e, expert in enumerate(self.experts):
+            # skip compute if weight mass is (nearly) zero for this expert
+            w_e = w_full[..., e]  # [...]
+            if torch.allclose(w_e, torch.zeros((), device=w_e.device, dtype=w_e.dtype)):
+                continue
+
+            if self.use_checkpoint and self.training:
+                y_e = checkpoint(expert, x)
+            else:
+                y_e = expert(x)
+            y = y + y_e * w_e.unsqueeze(-1)
+
+        # Residual (same logic as the MLP)
+        if self.with_residual:
+            if y.shape[-1] == x_in.shape[-1]:
+                y = x_in + y
+            else:
+                assert y.shape[-1] % x_in.shape[-1] == 0
+                y = y + x_in.repeat([*[1 for _ in y.shape[:-1]], y.shape[-1] // x_in.shape[-1]])
+
+        # Optional: update aux loss (not returned; read it if you want)
+        with torch.no_grad():
+            self.last_aux_loss = self._compute_load_balance_aux(
+                # w_full if top_idx is not None else weights,  # use full probs if we built them
+                # None if top_idx is None else top_idx,
+                weights,
+                top_idx,
+                self.num_experts,
+            )
+
+        return y
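
A minimal, self-contained usage sketch of the new layer. The import path weathergen.model.layers is an assumption from the file location, and dim_aux=None is used so that only the components defined above (LayerNorm pre-norm, router, experts) are exercised.

import torch

from weathergen.model.layers import MoEMLP  # assumed import path

torch.manual_seed(0)

# 4 experts, each token routed to its top-2 experts; per-expert hidden width
# is int(dim_in * hidden_factor) = int(32 * 0.5) = 16.
moe = MoEMLP(
    dim_in=32,
    dim_out=32,
    num_experts=4,
    top_k=2,
    hidden_factor=0.5,
    with_residual=True,
    dim_aux=None,        # plain LayerNorm pre-norm, no AdaLayerNorm conditioning
)

x = torch.randn(8, 10, 32)          # [batch, tokens, dim]
y = moe(x)                          # same forward(*args) convention as MLP
print(y.shape)                      # torch.Size([8, 10, 32])
print(float(moe.last_aux_loss))     # load-balancing proxy, refreshed on every forward

Each token's output is the residual input plus a weighted mix of its two selected experts; last_aux_loss is only a diagnostic unless the training loop explicitly reads it and adds it to the loss.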
