45 changes: 39 additions & 6 deletions src/weathergen/datasets/masking.py
@@ -144,7 +144,9 @@ def mask_source(
coords: torch.Tensor,
geoinfos: torch.Tensor,
source: torch.Tensor,
) -> list[torch.Tensor]:
idxs_inv: torch.Tensor,
idxs_inv_lens: list[list[int]],
) -> tuple[list[torch.Tensor], torch.Tensor]:
"""
Receives tokenized data, generates a mask, and returns the source data (unmasked)
and the permutation selection mask (perm_sel) to be used for the target.
@@ -160,6 +162,14 @@ def mask_source(
token_lens = [len(t) for t in tokenized_data]
num_tokens = sum(token_lens)

# TODO: wrap in function
lens_per_cell = [np.array(i).sum() for i in idxs_inv_lens]
idxs_inv_cell = torch.split(idxs_inv, lens_per_cell)
idxs_inv_out = [
list(torch.split(idxs, ll))
for idxs, ll in zip(idxs_inv_cell, idxs_inv_lens, strict=False)
]

# If there are no tokens, return empty lists.
if num_tokens == 0:
return tokenized_data
@@ -181,7 +191,7 @@ def mask_source(
token_lens = [len(t) for t in tokenized_data]
self.perm_sel = [np.ones(tl, dtype=bool) for tl in token_lens]
source_data = [data[~p] for data, p in zip(tokenized_data, self.perm_sel, strict=True)]
return source_data
return (source_data,)

# Implementation of different masking strategies.
# Generate a flat boolean mask for random, block, or healpix masking at cell level.
@@ -239,16 +249,26 @@ def mask_source(

# Apply the mask to get the source data (where mask is False)
source_data = [data[~p] for data, p in zip(tokenized_data, self.perm_sel, strict=True)]
# TODO: when np.where(p) is empty, insert an empty tensor so that the cat below works
idxs_inv_out = [
[idxs[i] for i in np.where(p)[0]]
for idxs, p in zip(idxs_inv_out, self.perm_sel, strict=True)
]

idxs_inv = torch.cat([torch.cat(a) for a in idxs_inv_out if len(a) > 0])

return source_data
return source_data, idxs_inv

def mask_target(
self,
target_tokenized_data: list[list[torch.Tensor]],
coords: torch.Tensor,
geoinfos: torch.Tensor,
source: torch.Tensor,
) -> list[torch.Tensor]:
idxs_inv: torch.Tensor,
idxs_inv_lens: list[list[int]],
) -> tuple[list[torch.Tensor], torch.Tensor]:
"""
Applies the permutation selection mask to
the tokenized data to create the target data.
@@ -265,6 +285,14 @@ def mask_target(
list[torch.Tensor]: The target data with masked tokens, one tensor per cell.
"""

# TODO: wrap in function
lens_per_cell = [np.array(i).sum() for i in idxs_inv_lens]
idxs_inv_cell = torch.split(idxs_inv, lens_per_cell)
idxs_inv_out = [
list(torch.split(idxs, ll))
for idxs, ll in zip(idxs_inv_cell, idxs_inv_lens, strict=False)
]

# perm_sel must have been set by a preceding call to mask_source
assert self.perm_sel is not None, "Masker.perm_sel must be set before calling mask_target."

@@ -274,9 +302,10 @@ def mask_target(
feature_dim = self.dim_time_enc + coords.shape[-1] + geoinfos.shape[-1] + source.shape[-1]

processed_target_tokens = []
processed_idxs_inv = []

# process all tokens used for embedding
for cc, pp in zip(target_tokenized_data, self.perm_sel, strict=True):
for cc, idxs, pp in zip(target_tokenized_data, idxs_inv_out, self.perm_sel, strict=True):
if self.current_strategy == "channel":
# If masking strategy is channel, handle target tokens differently.
# We don't have Booleans per cell, instead per channel per cell,
@@ -298,16 +327,20 @@ def mask_target(
else:
# For other masking strategies, we simply select the tensors where the mask is True.
selected_tensors = [c for c, p in zip(cc, pp, strict=True) if p]
selected_idxs_inv = [idxs[i] for i in np.where(pp)[0]]

# Append the selected tensors to the processed_target_tokens list.
if selected_tensors:
processed_target_tokens.append(torch.cat(selected_tensors))
processed_idxs_inv.append(selected_idxs_inv)
else:
processed_target_tokens.append(
torch.empty(0, feature_dim, dtype=coords.dtype, device=coords.device)
)

return processed_target_tokens
idxs_inv = torch.cat([torch.cat(a) for a in processed_idxs_inv if len(a) > 0])

return processed_target_tokens, idxs_inv

def _get_sampling_rate(self):
"""
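
The per-cell splitting of idxs_inv is duplicated verbatim in mask_source and mask_target, both marked "TODO: wrap in function". A minimal sketch of a shared helper, under the same argument layout as the two methods; the name split_idxs_inv is hypothetical:

import numpy as np
import torch

def split_idxs_inv(
    idxs_inv: torch.Tensor, idxs_inv_lens: list[list[int]]
) -> list[list[torch.Tensor]]:
    # total number of inverse indices per healpix cell
    lens_per_cell = [int(np.sum(lens)) for lens in idxs_inv_lens]
    # split the flat indices per cell, then per token within each cell
    idxs_inv_cell = torch.split(idxs_inv, lens_per_cell)
    return [
        list(torch.split(idxs, lens))
        for idxs, lens in zip(idxs_inv_cell, idxs_inv_lens, strict=True)
    ]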
25 changes: 14 additions & 11 deletions src/weathergen/datasets/multi_stream_data_sampler.py
@@ -34,7 +34,7 @@
compute_source_cell_lens,
)
from weathergen.utils.distributed import is_root
from weathergen.utils.train_logger import Stage
from weathergen.utils.train_logger import TRAIN, Stage

type AnyDataReader = DataReaderBase | DataReaderAnemoi | DataReaderObs

@@ -45,11 +45,12 @@ def readerdata_to_torch(rdata: IOReaderData) -> IOReaderData:
"""
Convert data, coords, and geoinfos to torch tensor
"""
rdata.coords = torch.tensor(rdata.coords)
rdata.geoinfos = torch.tensor(rdata.geoinfos)
rdata.data = torch.tensor(rdata.data)

return rdata
return IOReaderData(
coords=torch.tensor(rdata.coords),
geoinfos=torch.tensor(rdata.geoinfos),
data=torch.tensor(rdata.data),
datetimes=rdata.datetimes,
)


def collect_datasources(stream_datasets: list, idx: int, type: str) -> IOReaderData:
@@ -225,10 +226,10 @@ def __init__(
self.num_healpix_cells: int = 12 * 4**self.healpix_level

if cf.training_mode == "forecast":
self.tokenizer = TokenizerForecast(cf.healpix_level)
self.tokenizer = TokenizerForecast(cf)
elif cf.training_mode == "masking":
masker = Masker(cf)
self.tokenizer = TokenizerMasking(cf.healpix_level, masker)
self.tokenizer = TokenizerMasking(cf, masker)
assert self.forecast_offset == 0, "masked token modeling requires auto-encoder training"
msg = "masked token modeling does not support self.input_window_steps > 1; "
msg += "increase window length"
@@ -388,7 +389,9 @@ def __iter__(self):
stream_ds[0].normalize_coords,
)

# TODO: rdata only be collected in validation mode
# rdata does not need to be retained in training mode; it is only used for output
if self._stage == TRAIN:
rdata = None
stream_data.add_source(rdata, ss_lens, ss_cells, ss_centroids)

# target
@@ -417,14 +420,14 @@ def __iter__(self):
stream_data.target_is_spoof = True

# preprocess data for model input
(tt_cells, tc, tt_c, tt_t) = self.tokenizer.batchify_target(
(tt_cells, tc, tt_c, tt_t, idxs_inv) = self.tokenizer.batchify_target(
stream_info,
self.sampling_rate_target,
readerdata_to_torch(rdata),
(time_win_target.start, time_win_target.end),
)

stream_data.add_target(fstep, tt_cells, tc, tt_c, tt_t)
stream_data.add_target(fstep, tt_cells, tc, tt_c, tt_t, idxs_inv)

# merge inputs for sources and targets for current stream
streams_data += [stream_data]
8 changes: 7 additions & 1 deletion src/weathergen/datasets/stream_data.py
@@ -57,6 +57,8 @@ def __init__(self, idx: int, forecast_steps: int, healpix_cells: int) -> None:
self.target_tokens_lens = [
torch.tensor([0 for _ in range(self.healpix_cells)]) for _ in range(forecast_steps + 1)
]
# index to recover original ordering of data
self.target_idxs_inv = None

# source tokens per cell
self.source_tokens_cells = []
@@ -153,6 +155,7 @@ def add_source(
[ torch.tensor( tokens per cell, token size, number of channels) ]
ss_centroids : list(number of healpix cells )
[ torch.tensor( for source , 5) ]
idxs_inv : index to recover original ordering of datapoints

Returns
-------
@@ -174,6 +177,7 @@ def add_target(
target_coords: torch.tensor,
target_coords_raw: torch.tensor,
times_raw: torch.tensor,
idxs_inv: torch.tensor,
) -> None:
"""
Add data for target for one input.
@@ -193,6 +197,7 @@ def add_target(
target_times : list( number of healpix cells)
[ torch.tensor( points per cell) ]
absolute target times
idxs_inv : index to recover original ordering of datapoints

Returns
-------
@@ -202,7 +207,7 @@ def add_target(
self.target_tokens[fstep] = torch.cat(targets)
self.target_coords[fstep] = torch.cat(target_coords)
self.target_times_raw[fstep] = np.concatenate(times_raw)
self.target_coords_raw[fstep] = target_coords_raw
self.target_coords_raw[fstep] = torch.cat(target_coords_raw)

tc = target_coords
self.target_coords_lens[fstep] = torch.tensor(
@@ -213,6 +218,7 @@ def add_target(
[len(f) for f in targets] if len(targets) > 1 else self.target_tokens_lens[fstep],
dtype=torch.int,
)
self.target_idxs_inv = idxs_inv

def target_empty(self) -> bool:
"""
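
For reference, a toy example of the per-cell layout that add_target concatenates, with the stored lengths allowing the flat tensor to be split back per cell (all shapes hypothetical):

import torch

# two healpix cells with 2 and 1 target points, 3 channels each
targets = [torch.randn(2, 3), torch.randn(1, 3)]
flat = torch.cat(targets)                                        # (3, 3), points stacked across cells
lens = torch.tensor([len(t) for t in targets], dtype=torch.int)  # tensor([2, 1])
per_cell = list(flat.split(lens.tolist()))                       # recovers the per-cell list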
8 changes: 5 additions & 3 deletions src/weathergen/datasets/tokenizer.py
@@ -24,11 +24,13 @@ class Tokenizer:
Base class for tokenizers.
"""

def __init__(self, healpix_level: int):
def __init__(self, cf: dict):
self.permute_target_points = cf.get("permute_target_points", True)

ref = torch.tensor([1.0, 0.0, 0.0])

self.hl_source = healpix_level
self.hl_target = healpix_level
self.hl_source = cf.healpix_level
self.hl_target = cf.healpix_level

self.num_healpix_cells_source = 12 * 4**self.hl_source
self.num_healpix_cells_target = 12 * 4**self.hl_target
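
The cell counts follow the HEALPix scheme: 12 base cells, each subdivided four-fold per refinement level, hence 12 * 4**level cells overall:

# HEALPix cell counts grow by a factor of 4 per level
num_cells = [12 * 4**lvl for lvl in range(4)]
assert num_cells == [12, 48, 192, 768]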
27 changes: 14 additions & 13 deletions src/weathergen/datasets/tokenizer_forecast.py
@@ -87,10 +87,7 @@ def batchify_target(
rdata: IOReaderData,
time_win: tuple,
):
target_tokens = torch.zeros([self.num_healpix_cells_target], dtype=torch.int32)
target_coords = torch.zeros([self.num_healpix_cells_target], dtype=torch.int32)
target_tokens_lens = torch.zeros([self.num_healpix_cells_target], dtype=torch.int32)

# limit the number of target points
sampling_rate_target = stream_info.get("sampling_rate_target", sampling_rate_target)
if sampling_rate_target < 1.0:
mask = self.rng.uniform(0.0, 1.0, rdata.data.shape[0]) < sampling_rate_target
@@ -101,20 +98,24 @@

# TODO: currently treated as empty to avoid special case handling
if len(rdata.data) < 2:
target_tokens = torch.zeros([self.num_healpix_cells_target], dtype=torch.int32)
target_coords = torch.zeros([self.num_healpix_cells_target], dtype=torch.int32)
return (target_tokens, target_coords, torch.tensor([]), torch.tensor([]), None)

# compute indices for each cell
# compute indices for each cell (no tokenization is used/needed here since the decoding
# uses the cell structure but no tokens)
hpy_idxs_ord_split, _, _, _ = hpy_cell_splits(rdata.coords, self.hl_target)

# TODO: expose parameter
with_perm_target = True
if with_perm_target:
hpy_idxs_ord_split = [
idx[self.rng.permutation(len(idx))[: int(len(idx))]] for idx in hpy_idxs_ord_split
]
# permute the indices for each cell to improve generalization
if self.permute_target_points:
perms = [self.rng.permutation(len(idx))[: int(len(idx))] for idx in hpy_idxs_ord_split]
hpy_idxs_ord_split = [idx[p] for idx, p in zip(hpy_idxs_ord_split, perms, strict=False)]

# helper variables to split according to cells
# reordering is done in contiguous memory for efficiency
idxs_ord = np.concatenate(hpy_idxs_ord_split)
# inverse map for reordering to output data points in same order as input
idxs_ord_inv = np.argsort(idxs_ord)
# helper variables to split according to cells
ll = np.cumsum(np.array([len(a) for a in hpy_idxs_ord_split]))[:-1]

# compute encoding of time
@@ -146,4 +147,4 @@ def batchify_target(
target_coords.requires_grad = False
target_coords = list(target_coords.split(target_tokens_lens.tolist()))

return (target_tokens, target_coords, target_coords_raw, target_times_raw)
return (target_tokens, target_coords, target_coords_raw, target_times_raw, idxs_ord_inv)
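
The inverse map relies on np.argsort of a permutation yielding its inverse; a minimal check of the idiom:

import numpy as np

idxs_ord = np.array([3, 0, 2, 1])    # points reordered cell by cell
idxs_ord_inv = np.argsort(idxs_ord)  # inverse permutation: [1, 3, 2, 0]
data = np.array([10, 20, 30, 40])
assert (data[idxs_ord][idxs_ord_inv] == data).all()  # reorder then invert is the identity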
4 changes: 3 additions & 1 deletion src/weathergen/datasets/tokenizer_masking.py
@@ -214,7 +214,9 @@ def id(arg):
target_coords.requires_grad = False
target_coords = list(target_coords.split(tt_lens))

return (target_tokens, target_coords, target_coords_raw, target_times_raw)
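# inverse reordering of points is not yet implemented for masking; downstream code skips it when None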
idxs_inv = None

return (target_tokens, target_coords, target_coords_raw, target_times_raw, idxs_inv)

def sample_tensors_uniform_vectorized(
self, tensor_list: list, lengths: list, max_total_points: int
9 changes: 6 additions & 3 deletions src/weathergen/datasets/tokenizer_utils.py
@@ -102,9 +102,10 @@ def hpy_cell_splits(coords: torch.tensor, hl: int):
phis : phis in rad
posr3 : (thetas,phis) as position in R3
"""

thetas = ((90.0 - coords[:, 0]) / 180.0) * np.pi
phis = ((coords[:, 1] + 180.0) / 360.0) * 2.0 * np.pi
# healpix cells for all points
# healpix cell index for all points
hpy_idxs = ang2pix(2**hl, thetas, phis, nest=True)
posr3 = s2tor3(thetas, phis)

Expand All @@ -115,7 +116,7 @@ def hpy_cell_splits(coords: torch.tensor, hl: int):
# extract per cell data
hpy_idxs_ord_temp = np.split(hpy_idxs_ord, splits + 1)
hpy_idxs_ord_split = [np.array([], dtype=np.int64) for _ in range(12 * 4**hl)]
# TODO: split smarter (with a augmented splits list?) so that this loop is not needed
# split according to cells
for b, x in zip(np.unique(np.unique(hpy_idxs[hpy_idxs_ord])), hpy_idxs_ord_temp, strict=True):
hpy_idxs_ord_split[b] = x

@@ -249,7 +250,8 @@ def tokenize_window_spacetime(
pad_tokens=True,
local_coords=True,
):
"""Tokenize respecting an intrinsic time step in the data, i.e. each time step is tokenized
"""
Tokenize respecting an intrinsic time step in the data, i.e. each time step is tokenized
separately
"""

Expand All @@ -275,6 +277,7 @@ def tokenize_window_spacetime(
local_coords,
)

# merge tokens originating from different time steps
tokens_cells = [t + tc for t, tc in zip(tokens_cells, tokens_cells_cur, strict=True)]

return tokens_cells
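
For context, a standalone sketch of the angle convention used in hpy_cell_splits, assuming the healpy package provides ang2pix; the example coordinates are arbitrary:

import healpy as hp
import numpy as np

lat = np.array([52.5])                              # degrees north
lon = np.array([13.4])                              # degrees east
thetas = ((90.0 - lat) / 180.0) * np.pi             # colatitude in [0, pi]
phis = ((lon + 180.0) / 360.0) * 2.0 * np.pi        # longitude in [0, 2*pi), shifted by 180 deg
hl = 3                                              # healpix level, nside = 2**hl
cells = hp.ang2pix(2**hl, thetas, phis, nest=True)  # nested-scheme cell indices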
2 changes: 1 addition & 1 deletion src/weathergen/train/loss_calculator.py
@@ -314,7 +314,7 @@

# normalize by all targets and forecast steps that were non-empty
# (with each having an expected loss of 1 for an uninitialized neural net)
loss = loss / ctr_streams
loss = loss / ctr_streams if ctr_streams > 0 else loss

# Return all computed loss components encapsulated in a ModelLoss dataclass
return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all)
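
The added guard only matters in the degenerate case where every stream and forecast step was empty; a minimal illustration with hypothetical values:

import torch

loss, ctr_streams = torch.tensor(0.0), 0                # no non-empty targets contributed
loss = loss / ctr_streams if ctr_streams > 0 else loss  # without the guard: 0 / 0 -> nan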
15 changes: 10 additions & 5 deletions src/weathergen/train/trainer.py
@@ -512,7 +512,7 @@ def _prepare_logging(
targets_all = [[[] for _ in self.cf.streams] for _ in range(fsteps)]
targets_lens = [[[] for _ in self.cf.streams] for _ in range(fsteps)]

# TODO: iterate over batches here in future, and change loop order to batch, stream, fstep
# TODO: iterate over batches, and change loop order to batch, stream, fstep
for fstep in range(len(targets_rt)):
for i_strm, target in enumerate(targets_rt[fstep]):
pred = preds[fstep][i_strm]
@@ -524,13 +524,18 @@ def _prepare_logging(
pred = pred.reshape([pred.shape[0], *target.shape])
assert pred.shape[1] > 0

mask_nan = ~torch.isnan(target)
if pred[:, mask_nan].shape[1] == 0:
continue

targets_lens[fstep][i_strm] += [target.shape[0]]
dn_data = self.dataset_val.denormalize_target_channels

# revert the per-cell reordering of points (and the optional random permutation)
# currently only implemented for forecasting
idxs_inv = streams_data[fstep][i_strm].target_idxs_inv
if idxs_inv is not None:
targets_coords_raw[fstep][i_strm] = targets_coords_raw[fstep][i_strm][idxs_inv]
targets_times_raw[fstep][i_strm] = targets_times_raw[fstep][i_strm][idxs_inv]
pred = pred[:, idxs_inv]
target = target[idxs_inv]

f32 = torch.float32
preds_all[fstep][i_strm] += [dn_data(i_strm, pred.to(f32)).detach().cpu()]
targets_all[fstep][i_strm] += [dn_data(i_strm, target.to(f32)).detach().cpu()]
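
A toy illustration of the inverse-reordering step above, with a hypothetical inverse index and the ensemble dimension leading pred:

import torch

idxs_inv = torch.tensor([1, 3, 2, 0])  # hypothetical inverse permutation from the tokenizer
pred = torch.randn(5, 4, 2)            # (ensemble members, points, channels)
target = torch.randn(4, 2)             # (points, channels)
pred = pred[:, idxs_inv]               # restore original point order along the point dim
target = target[idxs_inv]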