2 changes: 1 addition & 1 deletion config/default_config.yml
@@ -23,7 +23,7 @@ ae_adapter_with_qk_lnorm: True
ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 256
ae_global_dim_embed: 2048
ae_global_num_blocks: 4
ae_global_num_heads: 16
ae_global_dropout_rate: 0.1
4 changes: 2 additions & 2 deletions config/streams/era5_1deg/era5.yml
@@ -9,8 +9,8 @@

ERA5 :
type : anemoi
#filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
filenames : ['aifs-ea-an-oper-0001-mars-o48-1979-2024-6h-v1.zarr']
filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
# filenames : ['aifs-ea-an-oper-0001-mars-o48-1979-2024-6h-v1.zarr']
source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp']
target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp']
loss_weight : 1.
39 changes: 26 additions & 13 deletions src/weathergen/model/engines.py
@@ -10,8 +10,8 @@
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from weathergen.common.config import Config

from weathergen.model.attention import (
MultiCrossAttentionHeadVarlen,
MultiCrossAttentionHeadVarlenSlicedQ,
@@ -24,7 +24,7 @@
StreamEmbedLinear,
StreamEmbedTransformer,
)
from weathergen.model.layers import MLP
from weathergen.model.layers import FEMLP, MLP
from weathergen.model.utils import ActivationFactory
from weathergen.utils.utils import get_dtype

@@ -317,18 +317,31 @@ def create(self) -> torch.nn.ModuleList:
attention_dtype=get_dtype(self.cf.attention_dtype),
)
)
# Add MLP block
self.fe_blocks.append(
MLP(
self.cf.ae_global_dim_embed,
self.cf.ae_global_dim_embed,
with_residual=True,
dropout_rate=self.cf.fe_dropout_rate,
norm_type=self.cf.norm_type,
dim_aux=1,
norm_eps=self.cf.mlp_norm_eps,

if i + 1 == self.cf.ae_global_num_blocks:
self.fe_blocks.append(
FEMLP(
self.cf.ae_global_dim_embed,
self.cf.ae_global_dim_embed,
with_residual=True,
dropout_rate=self.cf.fe_dropout_rate,
norm_type=self.cf.norm_type,
dim_aux=1,
norm_eps=self.cf.mlp_norm_eps,
)
)
else:
self.fe_blocks.append(
MLP(
self.cf.ae_global_dim_embed,
self.cf.ae_global_dim_embed,
with_residual=True,
dropout_rate=self.cf.fe_dropout_rate,
norm_type=self.cf.norm_type,
dim_aux=1,
norm_eps=self.cf.mlp_norm_eps,
)
)
)

Contributor:
For the sake of having less redundant code, can we modify the MLP to receive an additional argument with_residual_layer_norm=False instead of introducing FEMLP? When calling MLP here, we can set with_residual_layer_norm=(i + 1) == self.cf.fe_num_blocks to add the residual layer norm in the last MLP layer.

Collaborator:
Why would we need to modify the MLP? It can, and in my opinion definitely should, be implemented in the forecast engine, where we can just have the LayerNorm as the last block.

Contributor (author):
Can you check now? Is that better?

def init_weights_final(m):
if isinstance(m, torch.nn.Linear):
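For context on the review thread above: the reviewer's proposal, a single MLP class with an optional post-residual LayerNorm instead of a separate FEMLP, could look roughly like the sketch below. ResidualMLP, its arguments, and the call-site loop are hypothetical illustrations based on the comment, not code from this PR.

import torch

# Hypothetical sketch of the reviewer's suggestion: one MLP class with an optional
# LayerNorm applied after the residual connection, enabled only for the last block.
class ResidualMLP(torch.nn.Module):
    def __init__(self, dim, hidden_factor=2, dropout_rate=0.0,
                 with_residual_layer_norm=False, norm_eps=1e-5):
        super().__init__()
        dim_hidden = int(dim * hidden_factor)
        self.net = torch.nn.Sequential(
            torch.nn.LayerNorm(dim, eps=norm_eps),
            torch.nn.Linear(dim, dim_hidden),
            torch.nn.GELU(),
            torch.nn.Dropout(p=dropout_rate),
            torch.nn.Linear(dim_hidden, dim),
        )
        # Identity when the flag is off, so every block can share one class.
        self.residual_norm = (
            torch.nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=False)
            if with_residual_layer_norm
            else torch.nn.Identity()
        )

    def forward(self, x):
        return self.residual_norm(x + self.net(x))

# Call-site pattern from the comment: only the final block gets the extra norm.
num_blocks, dim = 4, 2048
blocks = torch.nn.ModuleList(
    ResidualMLP(dim, with_residual_layer_norm=(i + 1) == num_blocks)
    for i in range(num_blocks)
)
x = torch.randn(8, dim)
for block in blocks:
    x = block(x)

The collaborator's alternative, appending a bare torch.nn.LayerNorm as the final entry of the forecast engine's block list, would give the same effect without touching MLP at all.
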
77 changes: 77 additions & 0 deletions src/weathergen/model/layers.py
@@ -93,3 +93,80 @@ def forward(self, *args):
x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]])

return x


class FEMLP(torch.nn.Module):
def __init__(
self,
dim_in,
dim_out,
num_layers=2,
hidden_factor=2,
pre_layer_norm=True,
dropout_rate=0.0,
nonlin=torch.nn.GELU,
with_residual=False,
norm_type="LayerNorm",
dim_aux=None,
norm_eps=1e-5,
name: str | None = None,
):
"""Constructor"""

super(FEMLP, self).__init__()

if name is not None:
self.name = name

assert num_layers >= 2

self.with_residual = with_residual
self.with_aux = dim_aux is not None
dim_hidden = int(dim_in * hidden_factor)

self.layers = torch.nn.ModuleList()

norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm

if pre_layer_norm:
self.layers.append(
norm(dim_in, eps=norm_eps)
if dim_aux is None
else AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps)
)

self.layers.append(torch.nn.Linear(dim_in, dim_hidden))
self.layers.append(nonlin())
self.layers.append(torch.nn.Dropout(p=dropout_rate))

for _ in range(num_layers - 2):
self.layers.append(torch.nn.Linear(dim_hidden, dim_hidden))
self.layers.append(nonlin())
self.layers.append(torch.nn.Dropout(p=dropout_rate))

self.layers.append(torch.nn.Linear(dim_hidden, dim_out))

# Add LayerNorm after skip connection if residuals are used
if self.with_residual:
# self.residual_norm = AdaLayerNorm(
# dim_out, dim_aux, norm_eps=norm_eps
# ) # norm(dim_out, eps=norm_eps)
self.residual_norm = torch.nn.LayerNorm(dim_out, eps=norm_eps, elementwise_affine=False)

def forward(self, *args):
x, x_in, aux = args[0], args[0], args[-1]

for i, layer in enumerate(self.layers):
x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x)

if self.with_residual:
if x.shape[-1] == x_in.shape[-1]:
x = x_in + x
else:
assert x.shape[-1] % x_in.shape[-1] == 0
x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]])

# Apply LayerNorm to the residual connection
x = self.residual_norm(x)

return x
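
A minimal usage sketch of the FEMLP block defined above; the shapes and configuration values are illustrative only, and dim_aux is left unset so the plain pre-LayerNorm path is taken rather than AdaLayerNorm conditioning.

import torch
from weathergen.model.layers import FEMLP

# Illustrative token tensor: (batch, tokens, embedding dim)
x = torch.randn(4, 1024, 2048)

block = FEMLP(
    dim_in=2048,
    dim_out=2048,
    with_residual=True,   # adds the skip connection and the trailing affine-free LayerNorm
    dropout_rate=0.1,
    dim_aux=None,         # plain LayerNorm pre-norm; no auxiliary conditioning
)

y = block(x)  # forward reads args[0] as the input; aux is ignored when dim_aux is None
assert y.shape == x.shape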