1 change: 1 addition & 0 deletions config/evaluate/eval_config.yml
@@ -21,6 +21,7 @@ evaluation:
summary_plots : true
summary_dir: "./plots/"
plot_ensemble: "members" #supported: false, "std", "minmax", "members"
plot_score_maps: false #plot scores on 2D maps; slows down score computation
print_summary: false #print score values to the screen; can be verbose
log_scale: false
add_grid: false
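For orientation, a minimal sketch of how a plotting routine might consume these options; the `cfg` mapping and the helper name below are hypothetical, not part of this change:

from typing import Any

def make_summary_plots(cfg: dict[str, Any]) -> None:
    # Hypothetical consumer of the evaluation plotting options (sketch only)
    ens_mode = cfg.get("plot_ensemble", False)  # false, "std", "minmax", "members"
    if ens_mode == "members":
        ...  # draw one line per ensemble member
    elif ens_mode in ("std", "minmax"):
        ...  # draw a shaded spread around the ensemble mean
    if cfg.get("plot_score_maps", False):
        ...  # render 2D score maps; noticeably slower, hence off by default
    if cfg.get("print_summary", False):
        ...  # echo score values to the screen (verbose)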
15 changes: 10 additions & 5 deletions packages/evaluate/src/weathergen/evaluate/clim_utils.py
@@ -124,8 +124,13 @@ def align_clim_data(
for fstep, target_data in target_output.items():
samples = np.unique(target_data.sample.values)
for sample in tqdm(samples, f"Aligning climatology for forecast step {fstep}"):
sample_mask = target_data.sample.values == sample
timestamp = target_data.valid_time.values[sample_mask][0]
sel_key = "sample" if "sample" in target_data.dims else "ipoint"
sel_val = (
sample if "sample" in target_data.dims else (target_data.sample.values == sample)
)
sel_mask = {sel_key: sel_val}

timestamp = target_data.sel(sel_mask).valid_time.values[0]
# Prepare climatology data for each sample
matching_time_idx = match_climatology_time(timestamp, clim_data)

@@ -141,8 +146,8 @@
)
.transpose("grid_points", "channels") # dimensions specific to anemoi
)
target_lats = target_data.loc[{"ipoint": sample_mask}].lat.values
target_lons = target_data.loc[{"ipoint": sample_mask}].lon.values
target_lats = target_data.loc[sel_mask].lat.values
target_lons = target_data.loc[sel_mask].lon.values
# check if target coords match cached target coords
# if they do, use cached clim_indices
if (
@@ -174,7 +179,7 @@
clim_values = prepared_clim_data.isel(grid_points=clim_indices).values
try:
if len(samples) > 1:
aligned_clim_data[fstep].loc[{"ipoint": sample_mask}] = clim_values
aligned_clim_data[fstep].loc[sel_mask] = clim_values
else:
aligned_clim_data[fstep] = clim_values
except (ValueError, IndexError) as e:
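Context for the change above: the target data comes in two layouts, and the new `sel_key`/`sel_val` switch selects by label when a real `sample` dimension exists, and by boolean mask along `ipoint` otherwise. A self-contained sketch of the two selections on toy data (dimension names assumed from the diff; `isel` with a boolean mask performs the same positional selection as the diff's `.sel`/`.loc` calls on the index-less `ipoint` dimension):

import numpy as np
import xarray as xr

# Scattered layout: "sample" is a per-point coordinate along "ipoint"
scattered = xr.DataArray(
    np.arange(6.0),
    dims="ipoint",
    coords={"sample": ("ipoint", [0, 0, 1, 1, 2, 2])},
)
mask = scattered.sample.values == 1
print(scattered.isel(ipoint=mask).values)  # [2. 3.]

# Gridded layout: "sample" is a proper dimension, so we select by label
gridded = xr.DataArray(
    np.arange(6.0).reshape(3, 2),
    dims=("sample", "ipoint"),
    coords={"sample": [0, 1, 2]},
)
print(gridded.sel(sample=1).values)  # [2. 3.]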
180 changes: 136 additions & 44 deletions packages/evaluate/src/weathergen/evaluate/io_reader.py
@@ -88,7 +88,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] |
self.run_id = run_id
self.private_paths = private_paths
self.streams = eval_cfg.streams.keys()
self.data = None
# TODO: propagate it to the other functions using global plotting opts
self.global_plotting_options = eval_cfg.get("global_plotting_options", {})

# If results_base_dir and model_base_dir are not provided, default paths are used
self.model_base_dir = self.eval_cfg.get("model_base_dir", None)
@@ -130,6 +131,13 @@ def get_ensemble(self, stream: str | None = None) -> list[str]:
"""Placeholder implementation ensemble member names getter. Override in subclass."""
return list()

def is_regular(self, stream: str) -> bool:
"""
Placeholder implementation to check if lat/lon are regularly spaced.
Override in subclass.
"""
return True

def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None:
"""Placeholder to load pre-computed scores for a given run, stream, metric"""
return None
@@ -496,9 +504,9 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non

if not self.fname_zarr.exists() or not self.fname_zarr.is_dir():
_logger.error(f"Zarr file {self.fname_zarr} does not exist.")
# raise FileNotFoundError(
# f"Zarr file {self.fname_zarr} does not exist or is not a directory."
# )
raise FileNotFoundError(
f"Zarr file {self.fname_zarr} does not exist or is not a directory."
)

def get_inference_config(self):
"""
@@ -610,8 +618,7 @@ def get_data(

for fstep in fsteps:
_logger.info(f"RUN {self.run_id} - {stream}: Processing fstep {fstep}...")
da_tars_fs, da_preds_fs = [], []
pps = []
da_tars_fs, da_preds_fs, pps = [], [], []

for sample in tqdm(samples, desc=f"Processing {self.run_id} - {stream} - {fstep}"):
out = zio.get_data(sample, stream, fstep)
Expand Down Expand Up @@ -642,60 +649,57 @@ def get_data(
_logger.debug(f"Selecting ensemble members {ensemble}.")
pred = pred.sel(ens=ensemble)

if ensemble == ["mean"]:
_logger.debug("Averaging over ensemble members.")
pred = pred.mean("ens", keepdims=True)
else:
_logger.debug(f"Selecting ensemble members {ensemble}.")
pred = pred.sel(ens=ensemble)

da_tars_fs.append(target.squeeze())
da_preds_fs.append(pred.squeeze())

if len(da_tars_fs) > 0:
fsteps_final.append(fstep)
if not da_tars_fs:
_logger.info(
f"[{self.run_id} - {stream}] No valid data found for fstep {fstep}."
)
continue

fsteps_final.append(fstep)

_logger.debug(
f"Concatenating targets and predictions for stream {stream}, "
f"forecast_step {fstep}..."
)

if da_tars_fs:
# faster processing
if self.is_regular(stream):
# Efficient concatenation for regular grid
da_preds_fs = _force_consistent_grids(da_preds_fs)
da_tars_fs = _force_consistent_grids(da_tars_fs)

else:
# Irregular (scatter) case: concatenate over ipoint
da_tars_fs = xr.concat(da_tars_fs, dim="ipoint")
da_preds_fs = xr.concat(da_preds_fs, dim="ipoint")
if len(samples) == 1:
# Ensure sample coordinate is repeated along ipoint even if only one sample
da_tars_fs = da_tars_fs.assign_coords(
sample=(
"ipoint",
np.repeat(da_tars_fs.sample.values, len(da_tars_fs.ipoint)),
)
)
da_preds_fs = da_preds_fs.assign_coords(
sample=(
"ipoint",
np.repeat(da_preds_fs.sample.values, len(da_preds_fs.ipoint)),
)
)

if set(channels) != set(all_channels):
_logger.debug(
f"Restricting targets and predictions to channels {channels} "
f"for stream {stream}..."
if len(samples) == 1:
_logger.debug("Repeating sample coordinate for single-sample case.")
# assign_coords returns a new object, so the results must be captured
da_tars_fs, da_preds_fs = (
    da.assign_coords(
        sample=("ipoint", np.repeat(da.sample.values, da.sizes["ipoint"]))
    )
    for da in (da_tars_fs, da_preds_fs)
)

da_tars_fs, da_preds_fs, channels = dc.get_derived_channels(
da_tars_fs, da_preds_fs
)
if set(channels) != set(all_channels):
_logger.debug(
f"Restricting targets and predictions to channels {channels} "
f"for stream {stream}..."
)

da_tars_fs = da_tars_fs.sel(channel=channels)
da_preds_fs = da_preds_fs.sel(channel=channels)
da_tars_fs, da_preds_fs, channels = dc.get_derived_channels(
da_tars_fs, da_preds_fs
)

da_tars.append(da_tars_fs)
da_preds.append(da_preds_fs)
da_tars_fs = da_tars_fs.sel(channel=channels)
da_preds_fs = da_preds_fs.sel(channel=channels)

if return_counts:
points_per_sample.loc[{"forecast_step": fstep}] = np.array(pps)
da_tars.append(da_tars_fs)
da_preds.append(da_preds_fs)
if return_counts:
points_per_sample.loc[{"forecast_step": fstep}] = np.array(pps)

# Keyed by forecast step: safer than relying on positional list order
da_tars = {fstep: da for fstep, da in zip(fsteps_final, da_tars, strict=True)}
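The ensemble handling added above either collapses members to their mean or sub-selects them. A standalone sketch of the two branches on a toy prediction array (dimension and member names are assumptions, not taken from the run output):

import numpy as np
import xarray as xr

pred = xr.DataArray(
    np.random.rand(3, 4),
    dims=("ens", "ipoint"),
    coords={"ens": ["m0", "m1", "m2"]},
)

ensemble = ["mean"]
if ensemble == ["mean"]:
    # keepdims=True preserves a length-1 "ens" dimension so downstream
    # code can treat the ensemble mean like a single member
    pred = pred.mean("ens", keepdims=True)
else:
    pred = pred.sel(ens=ensemble)

print(pred.sizes)  # ens is now a length-1 dimension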
@@ -796,14 +800,65 @@ def get_channels(self, stream: str) -> list[str]:
return all_channels

def get_ensemble(self, stream: str | None = None) -> list[str]:
"""Get the list of ensemble member names for a given stream from the config."""
"""Get the list of ensemble member names for a given stream from the config.
Parameters
----------
stream : str
The name of the stream to get channels for.

Returns
-------
list[str]
A list of ensemble members.
"""
_logger.debug(f"Getting ensembles for stream {stream}...")

# TODO: improve this to get ensemble from io class
with ZarrIO(self.fname_zarr) as zio:
dummy = zio.get_data(0, stream, zio.forecast_steps[0])
return list(dummy.prediction.as_xarray().coords["ens"].values)

# TODO: improve this
def is_regular(self, stream: str) -> bool:
"""Check if the latitude and longitude coordinates are regularly spaced for a given stream.
Parameters
----------
stream : str
The name of the stream to get channels for.

Returns
-------
bool
True if the stream is regularly spaced. False otherwise.
"""
_logger.debug(f"Checking regular spacing for stream {stream}...")

with ZarrIO(self.fname_zarr) as zio:
dummy = zio.get_data(0, stream, zio.forecast_steps[0])

# Compare the first sample against a second sample / forecast step when
# available; with a single sample and step the comparison is trivially equal
sample_idx = zio.samples[1] if len(zio.samples) > 1 else zio.samples[0]
fstep_idx = (
    zio.forecast_steps[1] if len(zio.forecast_steps) > 1 else zio.forecast_steps[0]
)
dummy1 = zio.get_data(sample_idx, stream, fstep_idx)

da = dummy.prediction.as_xarray()
da1 = dummy1.prediction.as_xarray()

if (
da["lat"].shape != da1["lat"].shape
or da["lon"].shape != da1["lon"].shape
or not (
np.allclose(sorted(da["lat"].values), sorted(da1["lat"].values))
and np.allclose(sorted(da["lon"].values), sorted(da1["lon"].values))
)
):
_logger.debug("Latitude and/or longitude coordinates are not regularly spaced.")
return False

_logger.debug("Latitude and longitude coordinates are regularly spaced.")
return True
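The check above compares the coordinate sets of two samples after sorting, so a mere reordering of the same points still counts as the same grid. A toy illustration of the heuristic:

import numpy as np

# Same grid in a different point order -> still "regular"
lat0, lon0 = np.array([0.0, 1.0, 2.0]), np.array([10.0, 11.0, 12.0])
lat1, lon1 = lat0[::-1], lon0[::-1]

same_grid = (
    lat0.shape == lat1.shape
    and lon0.shape == lon1.shape
    and np.allclose(sorted(lat0), sorted(lat1))
    and np.allclose(sorted(lon0), sorted(lon1))
)
print(same_grid)  # True

# Scattered observations: coordinate sets differ between samples -> irregular
lat2 = np.array([0.5, 1.5, 2.5])
print(np.allclose(sorted(lat0), sorted(lat2)))  # False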

def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None:
"""
Load the pre-computed scores for a given run, stream and metric and epoch.
@@ -859,3 +914,40 @@ def get_inference_stream_attr(self, stream_name: str, key: str, default=None):
if stream.get("name") == stream_name:
return stream.get(key, default)
return default


################### Helper functions ########################


def _force_consistent_grids(ref: list[xr.DataArray]) -> xr.DataArray:
"""
Force all samples to share the same ipoint order.

Parameters
----------
ref:
Input dataset
Returns
-------
xr.DataArray
Returns a Dataset where all samples have the same lat lon and ipoint ordering
"""

# Pick first sample as reference
ref_lat = ref[0].lat
ref_lon = ref[0].lon

# np.lexsort treats the last key as primary: sort by lat, then by lon
sort_idx = np.lexsort((ref_lon.values, ref_lat.values))
npoints = sort_idx.size
aligned = []
for a in ref:
    # Sort each sample by its own coordinates so that samples whose points
    # arrive in a different order still line up with the reference
    a_sorted = a.isel(ipoint=np.lexsort((a.lon.values, a.lat.values)))

a_sorted = a_sorted.assign_coords(
ipoint=np.arange(npoints),
lat=("ipoint", ref_lat.values[sort_idx]),
lon=("ipoint", ref_lon.values[sort_idx]),
)
aligned.append(a_sorted)

return xr.concat(aligned, dim="sample")
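A short usage sketch of the helper on toy data: two single-sample arrays on the same grid but in different point orders come out stacked along `sample` on a common ordering (coordinate layout assumed from the diff; this relies on each sample being sorted by its own coordinates, as in the loop above):

import numpy as np
import xarray as xr

def _toy(order: list[int]) -> xr.DataArray:
    lats = np.array([0.0, 0.0, 1.0, 1.0])[order]
    lons = np.array([0.0, 1.0, 0.0, 1.0])[order]
    return xr.DataArray(
        lats * 10 + lons,  # values encode the grid point they sit on
        dims="ipoint",
        coords={"lat": ("ipoint", lats), "lon": ("ipoint", lons)},
    )

a = _toy([0, 1, 2, 3])
b = _toy([3, 2, 1, 0])  # same grid, reversed order

stacked = _force_consistent_grids([a, b])
print(stacked.dims)  # ('sample', 'ipoint')
print(bool(np.allclose(stacked.isel(sample=0), stacked.isel(sample=1))))  # True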