diff --git a/config/default_config.yml b/config/default_config.yml
index efb6e95b3..8f0f2d459 100644
--- a/config/default_config.yml
+++ b/config/default_config.yml
@@ -23,7 +23,7 @@
 ae_adapter_with_qk_lnorm: True
 ae_adapter_with_residual: True
 ae_adapter_dropout_rate: 0.1
-ae_global_dim_embed: 256
+ae_global_dim_embed: 2048
 ae_global_num_blocks: 4
 ae_global_num_heads: 16
 ae_global_dropout_rate: 0.1
diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml
index e9cc9a6b8..912075c4b 100644
--- a/config/streams/era5_1deg/era5.yml
+++ b/config/streams/era5_1deg/era5.yml
@@ -9,8 +9,8 @@
 ERA5 :
   type : anemoi
-  #filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
-  filenames : ['aifs-ea-an-oper-0001-mars-o48-1979-2024-6h-v1.zarr']
+  filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
+  # filenames : ['aifs-ea-an-oper-0001-mars-o48-1979-2024-6h-v1.zarr']
   source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp']
   target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp']
   loss_weight : 1.
 
diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py
index 78d11a4a6..fbb930ad3 100644
--- a/src/weathergen/model/engines.py
+++ b/src/weathergen/model/engines.py
@@ -10,8 +10,8 @@
 import torch
 import torch.nn as nn
 from torch.utils.checkpoint import checkpoint
-
 from weathergen.common.config import Config
+
 from weathergen.model.attention import (
     MultiCrossAttentionHeadVarlen,
     MultiCrossAttentionHeadVarlenSlicedQ,
@@ -24,7 +24,7 @@
     StreamEmbedLinear,
     StreamEmbedTransformer,
 )
-from weathergen.model.layers import MLP
+from weathergen.model.layers import FEMLP, MLP
 from weathergen.model.utils import ActivationFactory
 from weathergen.utils.utils import get_dtype
 
@@ -317,18 +317,31 @@ def create(self) -> torch.nn.ModuleList:
                     attention_dtype=get_dtype(self.cf.attention_dtype),
                 )
             )
-            # Add MLP block
-            self.fe_blocks.append(
-                MLP(
-                    self.cf.ae_global_dim_embed,
-                    self.cf.ae_global_dim_embed,
-                    with_residual=True,
-                    dropout_rate=self.cf.fe_dropout_rate,
-                    norm_type=self.cf.norm_type,
-                    dim_aux=1,
-                    norm_eps=self.cf.mlp_norm_eps,
+
+            if i + 1 == self.cf.ae_global_num_blocks:
+                self.fe_blocks.append(
+                    FEMLP(
+                        self.cf.ae_global_dim_embed,
+                        self.cf.ae_global_dim_embed,
+                        with_residual=True,
+                        dropout_rate=self.cf.fe_dropout_rate,
+                        norm_type=self.cf.norm_type,
+                        dim_aux=1,
+                        norm_eps=self.cf.mlp_norm_eps,
+                    )
+                )
+            else:
+                self.fe_blocks.append(
+                    MLP(
+                        self.cf.ae_global_dim_embed,
+                        self.cf.ae_global_dim_embed,
+                        with_residual=True,
+                        dropout_rate=self.cf.fe_dropout_rate,
+                        norm_type=self.cf.norm_type,
+                        dim_aux=1,
+                        norm_eps=self.cf.mlp_norm_eps,
+                    )
                 )
-            )
 
     def init_weights_final(m):
         if isinstance(m, torch.nn.Linear):
diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py
index 1f7b8df5d..17cca11e8 100644
--- a/src/weathergen/model/layers.py
+++ b/src/weathergen/model/layers.py
@@ -93,3 +93,80 @@ def forward(self, *args):
                 x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]])
 
         return x
+
+
+class FEMLP(torch.nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_out,
+        num_layers=2,
+        hidden_factor=2,
+        pre_layer_norm=True,
+        dropout_rate=0.0,
+        nonlin=torch.nn.GELU,
+        with_residual=False,
+        norm_type="LayerNorm",
+        dim_aux=None,
+        norm_eps=1e-5,
+        name: str | None = None,
+    ):
+        """Constructor"""
+
+        super(FEMLP, self).__init__()
+
+        if name is not None:
+            self.name = name
+
+        assert num_layers >= 2
+
+        self.with_residual = with_residual
+        self.with_aux = dim_aux is not None
+        dim_hidden = int(dim_in * hidden_factor)
+
+        self.layers = torch.nn.ModuleList()
+
+        norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm
+
+        if pre_layer_norm:
+            self.layers.append(
+                norm(dim_in, eps=norm_eps)
+                if dim_aux is None
+                else AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps)
+            )
+
+        self.layers.append(torch.nn.Linear(dim_in, dim_hidden))
+        self.layers.append(nonlin())
+        self.layers.append(torch.nn.Dropout(p=dropout_rate))
+
+        for _ in range(num_layers - 2):
+            self.layers.append(torch.nn.Linear(dim_hidden, dim_hidden))
+            self.layers.append(nonlin())
+            self.layers.append(torch.nn.Dropout(p=dropout_rate))
+
+        self.layers.append(torch.nn.Linear(dim_hidden, dim_out))
+
+        # Add LayerNorm after skip connection if residuals are used
+        if self.with_residual:
+            # self.residual_norm = AdaLayerNorm(
+            #     dim_out, dim_aux, norm_eps=norm_eps
+            # )  # norm(dim_out, eps=norm_eps)
+            self.residual_norm = torch.nn.LayerNorm(dim_out, eps=norm_eps, elementwise_affine=False)
+
+    def forward(self, *args):
+        x, x_in, aux = args[0], args[0], args[-1]
+
+        for i, layer in enumerate(self.layers):
+            x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x)
+
+        if self.with_residual:
+            if x.shape[-1] == x_in.shape[-1]:
+                x = x_in + x
+            else:
+                assert x.shape[-1] % x_in.shape[-1] == 0
+                x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]])
+
+            # Apply LayerNorm to the residual connection
+            x = self.residual_norm(x)
+
+        return x
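
A minimal usage sketch for the new FEMLP block, assuming the patched src/weathergen/model/layers.py is importable; the dimensions, the dim_aux=None setting, and all variable names below are illustrative assumptions and not part of the patch.

# Hypothetical smoke test (not part of the patch): exercise FEMLP stand-alone.
import torch

from weathergen.model.layers import FEMLP

fe_mlp = FEMLP(
    dim_in=256,
    dim_out=256,
    with_residual=True,   # enables the affine-free LayerNorm applied after the skip connection
    dropout_rate=0.1,
    norm_type="LayerNorm",
    dim_aux=None,         # skip the AdaLayerNorm conditioning path in this sketch
)

x = torch.randn(8, 256)   # (tokens, channels); shape is an assumption
out = fe_mlp(x)           # forward(*args) reads x from args[0]; aux is unused when dim_aux is None
assert out.shape == x.shape

Because the post-residual norm is created with elementwise_affine=False, it only re-standardizes the output of the final forecast-engine block and adds no learnable parameters; all earlier blocks keep the existing MLP behaviour.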