Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
0eca942
make pip installable, move to proper package folder
romeokienzler Oct 23, 2024
2d8bd59
fix dependencies
romeokienzler Oct 23, 2024
0031225
externalize config
romeokienzler Oct 23, 2024
79ac327
externalize config
romeokienzler Oct 23, 2024
cba5c36
fix config.py
romeokienzler Oct 23, 2024
480f76d
add wxc dep
romeokienzler Oct 24, 2024
bad4308
move from setup.py to pyproject.toml
romeokienzler Oct 24, 2024
49cd211
fix pyproject
romeokienzler Oct 24, 2024
eb96f73
remove submodules (add pip install), update readme
romeokienzler Oct 24, 2024
fecb7e6
Create init.py
romeokienzler Oct 25, 2024
55ec3a4
Update pyproject.toml
romeokienzler Oct 25, 2024
fff5d47
Update pyproject.toml
romeokienzler Oct 25, 2024
11b2361
Update pyproject.toml
romeokienzler Oct 25, 2024
b1c3d07
fix install
romeokienzler Oct 25, 2024
245ba13
fix import
romeokienzler Oct 25, 2024
c1e59ad
fix path after refactoring
romeokienzler Oct 25, 2024
fed0248
fix import
romeokienzler Oct 25, 2024
bacc93c
fix for robustness
romeokienzler Oct 25, 2024
f370566
fix path to config file
romeokienzler Oct 25, 2024
a9c8b6f
add cpu support for inference
romeokienzler Oct 28, 2024
a4ca275
add proper inheritance of ERA5DataModule
romeokienzler Oct 31, 2024
1dffc45
add lightning
romeokienzler Oct 31, 2024
22503cf
relax python version
romeokienzler Oct 31, 2024
f864e76
Delete prithviwxc/init.py
romeokienzler Nov 1, 2024
199628b
Create __init__.py
romeokienzler Nov 1, 2024
fc18b64
Delete prithviwxc/gravitywave/init.py
romeokienzler Nov 1, 2024
2e96eb9
Create __init__.py
romeokienzler Nov 1, 2024
4f74146
Update setup.py
romeokienzler Nov 1, 2024
f01a097
fix subpackage inclusion
romeokienzler Nov 4, 2024
45c3ed2
fix pyproject.toml
romeokienzler Nov 4, 2024
fed4238
fix pyproject.toml
romeokienzler Nov 4, 2024
a7580a5
send data to gpu where appropriate
romeokienzler Nov 26, 2024
7b85bee
send data to gpu where appropriate
romeokienzler Nov 26, 2024
f62f5d3
Merge branch '201' of github.com:romeokienzler/terratorch into 201
romeokienzler Jan 7, 2025
e8bfab4
remove version constraints
romeokienzler Jan 7, 2025
7732662
fix pyproject.toml
romeokienzler Jan 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .gitmodules

This file was deleted.

11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ This repository contains code and resources for training and inferring gravity w

## Setup

1. Clone the repository with submodules:
1. Clone the repository and download config:

git clone --recurse-submodules [email protected]:NASA-IMPACT/gravity-wave-finetuning.git gravity_wave_finetuning
cd gravity_wave_finetuning
git clone [email protected]:NASA-IMPACT/gravity-wave-finetuning.git
cd gravity-wave-finetuning
wget https://huggingface.co/Prithvi-WxC/Gravity_wave_Parameterization/resolve/main/config.yaml

2. Create and activate a Conda environment for the project:

Expand Down Expand Up @@ -55,7 +56,7 @@ To run the training on a single node and a single GPU, execute the following com
--nproc_per_node=1 \
--nnodes=1 \
--rdzv_backend=c10d \
finetune_gravity_wave.py
prithviwxc/gravitywave/finetune_gravity_wave.py
--split uvtp122

### Multi-node Training
Expand All @@ -71,7 +72,7 @@ After training, you can run inferences using the following command. Make sure to
--nnodes=1 \
--nproc_per_node=1 \
--rdzv_backend=c10d \
inference.py \
prithviwxc/gravitywave/inference.py \
--split=uvtp122 \
--ckpt_path=/path/to/checkpoint \
--data_path=/path/to/data \
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,4 @@ dependencies:
- wandb==0.17.7
- xarray==2024.7.0
- yacs==0.1.8
- git+https://github.com/NASA-IMPACT/Prithvi-WxC.git
1 change: 1 addition & 0 deletions prithviwxc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

1 change: 1 addition & 0 deletions prithviwxc/gravitywave/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

19 changes: 7 additions & 12 deletions config.py → prithviwxc/gravitywave/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,12 @@

_CN = CN()

# Declare all the configuration keys
_CN.wandb_mode = "disabled"
_CN.vartype = "uvtp122"
_CN.train_data_path = (
"gravity_wave_flux/uvtp122"
)
_CN.valid_data_path = (
"gravity_wave_flux/uvtp122/test"
)
_CN.singular_sharded_checkpoint = (
"prithvi_wxc/v0.8.50.rollout_step3.1.pth"
)
_CN.train_data_path = "gravity_wave_flux/uvtp122"
_CN.valid_data_path = "gravity_wave_flux/uvtp122/test"
_CN.singular_sharded_checkpoint = "prithvi_wxc/v0.8.50.rollout_step3.1.pth"
_CN.file_glob_pattern = "wxc_input_u_v_t_p_output_theta_uw_vw_era5_*.nc"

_CN.lr = 0.0001
Expand All @@ -23,14 +18,14 @@
_CN.mask_unit_size_px = [8, 16]
_CN.patch_size_px = [1, 1]


### Training Params

# Training parameters
_CN.max_epochs = 100
_CN.batch_size = 12
_CN.num_data_workers = 8

_CN.merge_from_file("config.yaml")

# Function to clone the config
def get_cfg():
return _CN.clone()

20 changes: 11 additions & 9 deletions datamodule.py → prithviwxc/gravitywave/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import xarray as xr
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

import lightning as pl

def get_era5_uvtp122(ds: xr.Dataset, index: int = 0) -> dict[str, torch.Tensor]:
"""Retrieve climate data variables at 122 pressure levels.
Expand Down Expand Up @@ -134,7 +134,7 @@ def __getitem__(self, index: int = 0) -> dict[str, torch.Tensor]:
return batch


class ERA5DataModule:
class ERA5DataModule(pl.LightningDataModule):
"""
This module handles data loading, batching, and train/validation splits.

Expand All @@ -143,7 +143,7 @@ class ERA5DataModule:
valid_data_path: Path to validation data.
file_glob_pattern: Pattern to match NetCDF files.
batch_size: Size of each mini-batch.
num_workers: Number of subprocesses for data loading.
num_workers: Number of subprocesses for data loading.
"""

def __init__(
Expand Down Expand Up @@ -171,24 +171,26 @@ def __init__(
self.batch_size: int = batch_size
self.num_workers: int = num_data_workers

def setup(self, stage: str | None = None) -> tuple[Dataset, Dataset]:
"""Sets up the datasets for different stages
(train, validation, predict).
def prepare_data(self):
pass

Args:
stage: Stage for which the setup is performed ("fit", "predict").
"""
def setup(self, stage: str | None = None) -> tuple[Dataset, Dataset]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if stage == "fit":
self.dataset_train = ERA5Dataset(
data_path=self.train_data_path, file_glob_pattern=self.file_glob_pattern
)
self.dataset_train = self.dataset_train.to(device)
self.dataset_val = ERA5Dataset(
data_path=self.valid_data_path, file_glob_pattern=self.file_glob_pattern
)
self.dataset_val = self.dataset_val.to(device)
elif stage == "predict":
self.dataset_predict = ERA5Dataset(
data_path=self.valid_data_path, file_glob_pattern=self.file_glob_pattern
)
self.dataset_predict = self.dataset_predict.to(device)


def train_dataloader(self) -> DataLoader:
"""Returns a DataLoader for the training data."""
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import torch.nn as nn

import importlib
model = importlib.import_module('Prithvi-WxC.PrithviWxC.model')
model = importlib.import_module('PrithviWxC.model')

from distributed import print0
from prithviwxc.gravitywave.distributed import print0

torch.set_float32_matmul_precision("high")

Expand Down
16 changes: 10 additions & 6 deletions inference.py → prithviwxc/gravitywave/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
import tqdm
import xarray as xr

from datamodule import ERA5DataModule
from gravity_wave_model import UNetWithTransformer
from prithviwxc.gravitywave.datamodule import ERA5DataModule
from prithviwxc.gravitywave.gravity_wave_model import UNetWithTransformer
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
device = f"cuda:{local_rank}"
local_rank = int(os.getenv("LOCAL_RANK",'0'))
rank = int(os.getenv("RANK",'0'))
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

def setup():
Expand Down Expand Up @@ -60,7 +60,11 @@ def get_model(cfg, vartype,ckpt_singular: str) -> torch.nn.Module:
patch_size_px=cfg.patch_size_px,
device=device,
)
model = DDP(model.to(local_rank, dtype=dtype), device_ids=[local_rank])

if device==torch.device('cpu'):
model = DDP(model.to(device))
else:
model = DDP(model.to(local_rank, dtype=dtype), device_ids=[local_rank])
model = load_checkpoint(model,ckpt_singular)

return model
Expand Down
99 changes: 99 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "prithviwxc-gravitywave"
version = "0.1.0"
description = "Gravity Wave Parameterization"
readme = { file = "README.md", content-type = "text/markdown" }
requires-python = ">=3.10"
license = { text = "MIT" }
authors = [
{ name = "Sujit Roy", email = "[email protected]" }
]
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent"
]
dependencies = [
"cachetools",
"cartopy",
"click",
"cloudpickle",
"contourpy",
"cycler",
"dask",
"docker-pycreds",
"fonttools",
"fsspec",
"gitdb",
"gitpython",
"h5netcdf",
"h5py",
"kiwisolver",
"locket",
"matplotlib",
"nvidia-ml-py",
"nvitop",
"pandas",
"partd",
"platformdirs",
"protobuf",
"pyhelpme",
"pyparsing",
"pyproj",
"pyshp",
"python-dateutil",
"pytz",
"sentry-sdk",
"setproctitle",
"shapely",
"smmap",
"tabulate",
"termcolor",
"toolz",
"tqdm",
"tzdata",
"wandb",
"xarray",
"yacs",
"asttokens",
"brotli",
"certifi",
"charset-normalizer",
"comm",
"debugpy",
"decorator",
"exceptiongroup",
"executing",
"filelock",
"idna",
"importlib-metadata",
"intel-openmp",
"ipykernel",
"ipython",
"jedi",
"jinja2",
"jupyter_client",
"jupyter_core",
"lightning",
"markupsafe",
"matplotlib-inline",
"mkl",
"mpmath",
"numpy",
"scipy",
"torch",
"torchaudio",
"torchvision",
"yacs"
]

[project.urls]
homepage = "https://github.com/NASA-IMPACT/gravity-wave-finetuning"

[tool.setuptools]
packages=["prithviwxc", "prithviwxc.gravitywave"]
include-package-data = true
41 changes: 41 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
cachetools==5.5.0
cartopy==0.23.0
click==8.1.7
cloudpickle==3.0.0
contourpy==1.3.0
cycler==0.12.1
dask==2024.8.1
docker-pycreds==0.4.0
fonttools==4.53.1
fsspec==2024.6.1
gitdb==4.0.11
gitpython==3.1.43
h5netcdf==1.3.0
h5py==3.11.0
kiwisolver==1.4.7
locket==1.0.0
matplotlib==3.9.2
nvidia-ml-py==12.535.161
nvitop==1.3.2
pandas==2.2.2
partd==1.4.2
platformdirs==4.2.2
protobuf==5.27.3
pyhelpme==0.1
pyparsing==3.1.4
pyproj==3.6.1
pyshp==2.3.1
python-dateutil==2.9.0.post0
pytz==2024.1
sentry-sdk==2.13.0
setproctitle==1.3.3
shapely==2.0.6
smmap==5.0.1
tabulate==0.9.0
termcolor==2.4.0
toolz==0.12.1
tqdm==4.66.5
tzdata==2024.1
wandb==0.17.7
xarray==2024.7.0
yacs==0.1.8
2 changes: 1 addition & 1 deletion scripts/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,6 @@ torchrun \
--rdzv_id=$RDZV_ID \
--rdzv_endpoint "$RDZV_ADDR:$RDZV_PORT" \
--rdzv_backend=c10d \
<SCRIPT_NAME>.py --split $data_args
./prithviwxc/gravitywave/<SCRIPT_NAME>.py --split $data_args

conda deactivate
9 changes: 9 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from setuptools import setup, find_packages

setup(
name="prithviwxc-gravitywave",
version="0.1.0",
description="Gravity Wave Parameterization",
packages=["prithviwxc", "prithviwxc.gravitywave"],
include_package_data=True,
)