diff --git a/.gitignore b/.gitignore index fa4c1b7..f7fac1f 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,8 @@ dmypy.json # Pyre type checker .pyre/ + +# Lightning and wandb outputs +checkpoints +logs +wandb diff --git a/configs/model/iam4vp.yaml b/configs/model/iam4vp.yaml new file mode 100644 index 0000000..46de027 --- /dev/null +++ b/configs/model/iam4vp.yaml @@ -0,0 +1,14 @@ +_target_: sat_pred.training_module.TrainingModule +model: + _target_: sat_pred.models.iam4vp_model.IAM4VP + num_channels: 11 + history_len: 12 + forecast_len: 12 + hid_S: 4 + hid_T: 4 + N_S: 6 + N_T: 6 +optimizer: + _target_: sat_pred.optimizers.AdamWReduceLROnPlateau + lr: 0.0005 +target_loss: MAE diff --git a/requirements.txt b/requirements.txt index ade5916..8ede35e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ -lightning -torch -numpy hydra-core +lightning matplotlib +numpy +ocf-iam4vp @ git+https://github.com/alan-turing-institute/ocf-iam4vp.git@0.4.2 pyaml_env +torch \ No newline at end of file diff --git a/sat_pred/models/iam4vp_model.py b/sat_pred/models/iam4vp_model.py new file mode 100644 index 0000000..ce48a58 --- /dev/null +++ b/sat_pred/models/iam4vp_model.py @@ -0,0 +1,26 @@ +from ocf_iam4vp import IAM4VP as IAM4VPBase +from torch import nn, stack, Tensor + + +class IAM4VP(nn.Module): + + def __init__( + self, num_channels, history_len, forecast_len, hid_S=16, hid_T=256, N_S=4, N_T=8 + ): + super().__init__() + self.model = IAM4VPBase( + num_channels, + num_history_steps=history_len, + num_forecast_steps=forecast_len, + hid_S=hid_S, + hid_T=hid_T, + N_S=N_S, + N_T=N_T, + ) + + def forward(self, X): + # Input batches have shape: (batch, channel, time, height, width) + y_hats: list[Tensor] = [] + for _ in range(self.model.num_forecast_steps): + y_hats.append(self.model(X, y_hats)) + return stack(y_hats, dim=2)