Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ dependencies = [
"earthkit-data>=0.12.4",
"eccodes>=2.38.3",
"numpy",
"omegaconf>=2.2,<2.4",
"omegaconf>=2.2",
"packaging",
"pydantic",
"pyyaml",
Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/inference/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
class Configuration(BaseModel):
"""Configuration class."""

model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra="allow")

date: datetime | None = None
"""The starting date for the forecast. If not provided, the date will depend on the selected Input object. If a string, it is parsed by :func:`earthkit.data.utils.dates`."""
Expand Down
18 changes: 15 additions & 3 deletions src/anemoi/inference/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,22 @@ def get_available_device() -> "torch.device":
torch.device
The available device, either 'cuda', 'mps', or 'cpu'.
"""
import os

import torch

if torch.cuda.is_available():
return torch.device("cuda")
elif torch.backends.mps.is_available():
local_rank_env = os.environ.get("LOCAL_RANK")
slurm_local = os.environ.get("SLURM_LOCALID")
if local_rank_env is not None:
local_rank = int(local_rank_env)
elif slurm_local is not None:
local_rank = int(slurm_local)
else:
local_rank = 0
torch.cuda.set_device(local_rank) # important for NCCL
return torch.device(f"cuda:{local_rank}")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): # mac fallback
return torch.device("mps")
return torch.device("cpu")
else:
return torch.device("cpu")
6 changes: 4 additions & 2 deletions src/anemoi/inference/inputs/cutout.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,15 @@ def __repr__(self):
"""Return a string representation of the Cutout object."""
return f"Cutout({self.sources})"

def create_input_state(self, *, date: Date | None) -> State:
def create_input_state(self, *, date: Date | None, **kwargs) -> State:
"""Create the input state for the given date.

Parameters
----------
date : Optional[Date]
The date for which to create the input state.
**kwargs : dict
Additional keyword arguments for the source input state creation.

Returns
-------
Expand All @@ -173,7 +175,7 @@ def create_input_state(self, *, date: Date | None) -> State:
combined_state = {}

for source in self.sources.keys():
source_state = self.sources[source].create_input_state(date=date)
source_state = self.sources[source].create_input_state(date=date, **kwargs)
source_mask = self.masks[source]

# Create the mask front padded with zeros
Expand Down
27 changes: 23 additions & 4 deletions src/anemoi/inference/inputs/ekd.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ def _create_state(
longitudes: FloatArray | None = None,
dtype: DTypeLike = np.float32,
flatten: bool = True,
ref_date_index: int = -1,
title: str = "Create state",
) -> State:
"""Create a state from an ekd.FieldList.

Expand All @@ -215,6 +217,10 @@ def _create_state(
The data type.
flatten : bool
Whether to flatten the data.
ref_date_index : int
The index of the reference date in the dates list.
title : str
The title for logging.

Returns
-------
Expand All @@ -230,15 +236,15 @@ def _create_state(
dates = sorted([to_datetime(d) for d in dates])
date_to_index = {d.isoformat(): i for i, d in enumerate(dates)}

state = dict(date=dates[-1], latitudes=latitudes, longitudes=longitudes, fields=dict())
state = dict(date=dates[ref_date_index], latitudes=latitudes, longitudes=longitudes, fields=dict())

state_fields = state["fields"]

fields = self._filter_and_sort(fields, dates=dates, title="Create input state")

if latitudes is None and longitudes is None:
try:
state["latitudes"], state["longitudes"] = fields[0].grid_points()
latitudes, longitudes = fields[0].grid_points()
LOG.info(
"%s: using `latitudes` and `longitudes` from the first input field",
self.__class__.__name__,
Expand All @@ -251,8 +257,6 @@ def _create_state(
latitudes = self.checkpoint.latitudes
longitudes = self.checkpoint.longitudes
if latitudes is not None and longitudes is not None:
state["latitudes"] = latitudes
state["longitudes"] = longitudes
LOG.info(
"%s: using `latitudes` and `longitudes` found in the checkpoint.",
self.__class__.__name__,
Expand All @@ -264,6 +268,15 @@ def _create_state(
)
raise e

state = dict(date=dates[ref_date_index], latitudes=latitudes, longitudes=longitudes, fields=fields)
state = self.pre_process(state)
fields = state["fields"]
state_fields = {}

if len(fields) == 0:
raise ValueError("No input fields provided")

# dates and date_to_index are already sorted and created earlier
check = defaultdict(set)

n_points = fields[0].to_numpy(dtype=dtype, flatten=flatten).size
Expand All @@ -288,6 +301,7 @@ def _create_state(
LOG.error("number_of_grid_points %s", self.checkpoint.number_of_grid_points)
raise

state["fields"] = state_fields
if date_idx in check[name]:
LOG.error("Duplicate dates for %s: %s", name, date_idx)
LOG.error("Expected %s", list(date_to_index.keys()))
Expand Down Expand Up @@ -326,6 +340,7 @@ def _create_input_state(
longitudes: FloatArray | None = None,
dtype: DTypeLike = np.float32,
flatten: bool = True,
ref_date_index: int = -1,
) -> State:
"""Create the input state.

Expand All @@ -345,6 +360,8 @@ def _create_input_state(
The data type.
flatten : bool
Whether to flatten the data.
ref_date_index : int
The index of the reference date in the dates list.

Returns
-------
Expand All @@ -366,6 +383,8 @@ def _create_input_state(
longitudes=longitudes,
dtype=dtype,
flatten=flatten,
ref_date_index=ref_date_index,
title="Create input state",
)

def _load_forcings_state(self, fields: ekd.FieldList, *, dates: list[Date], current_state: State) -> State:
Expand Down
10 changes: 7 additions & 3 deletions src/anemoi/inference/inputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,25 @@ def __init__(
super().__init__(context, **kwargs)
self.path = path

def create_input_state(self, *, date: Date | None) -> State:
def create_input_state(self, *, date: Date | None, **kwargs) -> State:
"""Create the input state for the given date.

Parameters
----------
date : Optional[Date]
The date for which to create the input state.
**kwargs : Any
Additional keyword arguments, including:
- ref_date_index: int, default -1
The reference date index to use.

Returns
-------
State
The created input state.
"""

return self._create_input_state(self._fieldlist, date=date)
ref_date_index = kwargs.get("ref_date_index", -1)
return self._create_input_state(self._fieldlist, date=date, ref_date_index=ref_date_index)

def load_forcings_state(self, *, dates: list[Date], current_state: State) -> State:
"""Load the forcings state for the given variables and dates.
Expand Down
128 changes: 128 additions & 0 deletions src/anemoi/inference/pre_processors/extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# (C) Copyright 2025- Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging
from pathlib import Path

import numpy as np
from anemoi.transform.fields import new_field_from_numpy
from anemoi.transform.fields import new_fieldlist_from_list

from anemoi.inference.context import Context
from anemoi.inference.decorators import main_argument
from anemoi.inference.types import BoolArray
from anemoi.inference.types import State

from ..processor import Processor
from . import pre_processor_registry

LOG = logging.getLogger(__name__)


class ExtractBase(Processor):
"""Base class for processors that extract data from the state."""

# this needs to be set in subclasses
indexer: BoolArray | slice

def process(self, state: State) -> State:
"""Process the state to extract a subset of points based on the indexer.

Parameters
----------
state : State
The state containing fields to be extracted.

Returns
-------
State
The updated state with extracted fields.
"""
state = state.copy()
result = []
for field in state["fields"]:
data = field.to_numpy().flatten()[self.indexer]
result.append(new_field_from_numpy(data, template=field))
state["fields"] = new_fieldlist_from_list(result)
state["latitudes"] = state["latitudes"][self.indexer]
state["longitudes"] = state["longitudes"][self.indexer]

return state


@pre_processor_registry.register("extract_mask")
@main_argument("mask")
class ExtractMask(ExtractBase):
"""Extract a subset of points from the state based on a boolean mask.

Parameters
----------
context : Any
The context in which the processor is running.
mask : str
Either a path to a `.npy` file containing the boolean mask or
the name of a supporting array found in the checkpoint.
"""

def __init__(self, context: Context, mask: str) -> None:
super().__init__(context)

self._maskname = mask

if Path(mask).is_file():
mask = np.load(mask)
else:
mask = context.checkpoint.load_supporting_array(mask)

if not isinstance(mask, np.ndarray) or mask.dtype != bool:
raise ValueError(
"Expected the mask to be a boolean numpy array. " f"Got {type(mask)} with dtype {mask.dtype}."
)

self.indexer = mask
self.npoints = np.sum(mask)

def __repr__(self) -> str:
"""Return a string representation of the ExtractMask object.

Returns
-------
str
String representation of the object.
"""
return f"ExtractMask({self._maskname}, points={self.npoints}/{self.indexer.size})"


@pre_processor_registry.register("extract_slice")
class ExtractSlice(ExtractBase):
"""Extract a subset of points from the state based on a slice.

Parameters
----------
context : Context
The context in which the processor is running.
slice_args : int
Arguments to create a slice object. This can be a single integer or
a tuple of integers representing the start, stop, and step of the slice.
"""

def __init__(self, context: Context, *slice_args: int) -> None:
super().__init__(context)
self.indexer = slice(*slice_args)

def __repr__(self) -> str:
"""Return a string representation of the ExtractSlice object.

Returns
-------
str
String representation of the object.
"""
return f"ExtractSlice({self.indexer})"
17 changes: 10 additions & 7 deletions src/anemoi/inference/pre_processors/forward_transform_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
import logging
from typing import Any

import earthkit.data as ekd
from anemoi.transform.filters import filter_registry

from anemoi.inference.decorators import main_argument
from anemoi.inference.types import State

from ..processor import Processor
from . import pre_processor_registry
Expand Down Expand Up @@ -51,20 +51,23 @@ def __init__(self, context: Any, filter: str, **kwargs: Any) -> None:
super().__init__(context)
self.filter = filter_registry.create(filter, **kwargs)

def process(self, fields: ekd.FieldList) -> ekd.FieldList:
def process(self, state: State) -> State:
"""Process the given fields using the forward filter.

Parameters
----------
fields : ekd.FieldList
The fields to be processed.
state : State
The state containing the fields to be processed.

Returns
-------
ekd.FieldList
The processed fields.
State
The processed state.
"""
return self.filter.forward(fields)
fields = state["fields"].copy()
result = self.filter.forward(fields)
state["fields"] = result
return state

def patch_data_request(self, data_request: Any) -> Any:
"""Patch the data request using the filter.
Expand Down
Loading
Loading