diff --git a/src/anemoi/inference/input.py b/src/anemoi/inference/input.py
index d1d49307..013d0c18 100644
--- a/src/anemoi/inference/input.py
+++ b/src/anemoi/inference/input.py
@@ -113,13 +113,15 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.purpose})"
 
     @abstractmethod
-    def create_input_state(self, *, date: Date | None) -> State:
+    def create_input_state(self, *, date: Date | None, **kwargs) -> State:
        """Create the input state dictionary.
 
        Parameters
        ----------
        date : Optional[Date]
            The date for which to create the input state.
+        **kwargs : Any
+            Additional keyword arguments.
 
        Returns
        -------
diff --git a/src/anemoi/inference/inputs/cds.py b/src/anemoi/inference/inputs/cds.py
index 0a3c1bfc..1404d44e 100644
--- a/src/anemoi/inference/inputs/cds.py
+++ b/src/anemoi/inference/inputs/cds.py
@@ -138,13 +138,15 @@ def __init__(
         self.dataset = dataset
         self.kwargs = kwargs
 
-    def create_input_state(self, *, date: Date | None) -> State:
+    def create_input_state(self, *, date: Date | None, **kwargs) -> State:
        """Create the input state for the given date.
 
        Parameters
        ----------
        date : Optional[Date]
            The date for which to create the input state.
+        **kwargs : Any
+            Additional keyword arguments.
 
        Returns
        -------
@@ -162,6 +164,7 @@ def create_input_state(self, *, date: Date | None) -> State:
             ),
             variables=self.variables,
             date=date,
+            **kwargs,
         )
 
     def retrieve(self, variables: list[str], dates: list[Date]) -> Any:
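Note: the widened abstract signature is the backbone of this PR: options such as `ref_date_index` can now travel from the runner down to any concrete input. A minimal sketch of the expected forwarding pattern, assuming an ekd-style base that provides `_create_input_state()` (as the GRIB/MARS inputs below do); `MyInput` and `fetch_fields` are hypothetical, not part of the PR:

```python
# Sketch only: forwarding the new **kwargs in a concrete Input subclass.
from anemoi.inference.input import Input
from anemoi.inference.types import Date, State


class MyInput(Input):  # hypothetical example class
    def create_input_state(self, *, date: Date | None, **kwargs) -> State:
        fields = self.fetch_fields(date)  # hypothetical retrieval helper
        # Options such as ref_date_index=0 flow through unchanged:
        return self._create_input_state(fields, date=date, **kwargs)
```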
diff --git a/src/anemoi/inference/inputs/cutout.py b/src/anemoi/inference/inputs/cutout.py
index dc6e79df..ac37f6c9 100644
--- a/src/anemoi/inference/inputs/cutout.py
+++ b/src/anemoi/inference/inputs/cutout.py
@@ -11,20 +11,31 @@
 import logging
 from collections import defaultdict
 from collections.abc import Iterable
+from collections.abc import Mapping
 
 import numpy as np
 
 from anemoi.inference.input import Input
 from anemoi.inference.inputs import create_input
+from anemoi.inference.inputs import input_registry
 from anemoi.inference.types import Date
 from anemoi.inference.types import ProcessorConfig
 from anemoi.inference.types import State
 
-from . import input_registry
-
 LOG = logging.getLogger(__name__)
 
 
+def contains_key(obj, key: str) -> bool:
+    """Recursively check if `key` exists anywhere in a nested config (dict/DotDict/lists)."""
+    if isinstance(obj, Mapping):
+        if key in obj:
+            return True
+        return any(contains_key(v, key) for v in obj.values())
+    if isinstance(obj, (list, tuple, set)):
+        return any(contains_key(v, key) for v in obj)
+    return False
+
+
 def _mask_and_combine_states(
     existing_state: State,
     new_state: State,
@@ -138,9 +149,12 @@ def __init__(
             cfg = cfg.copy()
             mask = cfg.pop("mask", f"{src}/cutout_mask")
 
-            self.sources[src] = create_input(
-                context, cfg, variables=variables, pre_processors=pre_processors, purpose=purpose
-            )
+            if contains_key(cfg, "pre_processors"):
+                self.sources[src] = create_input(context, cfg, variables=variables, purpose=purpose)
+            else:
+                self.sources[src] = create_input(
+                    context, cfg, variables=variables, purpose=purpose, pre_processors=pre_processors
+                )
 
             if isinstance(mask, str):
                 self.masks[src] = self.sources[src].checkpoint.load_supporting_array(mask)
@@ -151,13 +165,15 @@ def __repr__(self):
        """Return a string representation of the Cutout object."""
         return f"Cutout({self.sources})"
 
-    def create_input_state(self, *, date: Date | None) -> State:
+    def create_input_state(self, *, date: Date | None, **kwargs) -> State:
        """Create the input state for the given date.
 
        Parameters
        ----------
        date : Optional[Date]
            The date for which to create the input state.
+        **kwargs : dict
+            Additional keyword arguments for the source input state creation.
 
        Returns
        -------
@@ -173,7 +189,7 @@ def create_input_state(self, *, date: Date | None) -> State:
         combined_state = {}
 
         for source in self.sources.keys():
-            source_state = self.sources[source].create_input_state(date=date)
+            source_state = self.sources[source].create_input_state(date=date, **kwargs)
             source_mask = self.masks[source]
 
             # Create the mask front padded with zeros
diff --git a/src/anemoi/inference/inputs/dataset.py b/src/anemoi/inference/inputs/dataset.py
index f9506b05..4f514f58 100644
--- a/src/anemoi/inference/inputs/dataset.py
+++ b/src/anemoi/inference/inputs/dataset.py
@@ -104,13 +104,15 @@ def __repr__(self) -> str:
        """Return a string representation of the DatasetInput."""
         return f"DatasetInput({self.open_dataset_args}, {self.open_dataset_kwargs})"
 
-    def create_input_state(self, *, date: Date | None = None) -> State:
+    def create_input_state(self, *, date: Date | None = None, **kwargs) -> State:
        """Create the input state for the given date.
 
        Parameters
        ----------
        date : Optional[Any]
            The date for which to create the input state.
+        **kwargs : Any
+            Additional keyword arguments.
 
        Returns
        -------
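Note: `contains_key()` lets `Cutout` detect whether a nested source config already declares its own `pre_processors`; if so, the outer ones are not forwarded, which avoids applying them twice. A sketch of the assumed behaviour (the config shape is illustrative):

```python
# Sketch: contains_key() on a nested cutout source config.
from anemoi.inference.inputs.cutout import contains_key

cfg = {
    "grib": {
        "path": "lam.grib",
        "pre_processors": [{"extract_mask": "lam/cutout_mask"}],
    }
}
assert contains_key(cfg, "pre_processors")  # found one level down
assert not contains_key({"grib": {"path": "lam.grib"}}, "pre_processors")
```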
diff --git a/src/anemoi/inference/inputs/ekd.py b/src/anemoi/inference/inputs/ekd.py
index 0ffb73bf..3219a09a 100644
--- a/src/anemoi/inference/inputs/ekd.py
+++ b/src/anemoi/inference/inputs/ekd.py
@@ -111,8 +111,6 @@ def __init__(
        ----------
        context : Any
            The context in which the input is used.
-        pre_processors : Optional[List[ProcessorConfig]], default None
-            Pre-processors to apply to the input
        namer : Optional[Union[Callable[[Any, Dict[str, Any]], str], Dict[str, Any]]]
            Optional namer for the input.
        """
@@ -128,12 +126,12 @@ def __init__(
         assert callable(self._namer), type(self._namer)
 
     def _filter_and_sort(self, data: Any, *, dates: list[Any], title: str) -> Any:
-        """Filter and sort the data.
+        """Filter and sort the data (earthkit FieldList/FieldArray).
 
        Parameters
        ----------
        data : Any
-            The data to filter and sort.
+            The data to filter and sort (FieldList or FieldArray).
        dates : List[Any]
            The list of dates to select.
        title : str
@@ -142,7 +140,7 @@ def _filter_and_sort(self, data: Any, *, dates: list[Any], title: str) -> Any:
        Returns
        -------
        Any
-            The filtered and sorted data.
+            The filtered and sorted data (FieldArray).
        """
 
         def _name(field: Any, _: Any, original_metadata: dict[str, Any]) -> str:
@@ -166,12 +164,12 @@ def _name(field: Any, _: Any, original_metadata: dict[str, Any]) -> str:
         return data
 
     def _find_variable(self, data: Any, name: str, **kwargs: Any) -> Any:
-        """Find a variable in the data.
+        """Find a variable in the data (earthkit FieldList/FieldArray selection).
 
        Parameters
        ----------
        data : Any
-            The data to search.
+            The data to search (FieldList or FieldArray).
        name : str
            The name of the variable to find.
        **kwargs : Any
@@ -180,7 +178,7 @@ def _find_variable(self, data: Any, name: str, **kwargs: Any) -> Any:
        Returns
        -------
        Any
-            The selected variable.
+            The selected variable (FieldArray subset).
        """
 
         def _name(field: Any, _: Any, original_metadata: dict[str, Any]) -> str:
@@ -198,9 +196,17 @@ def _create_state(
         longitudes: FloatArray | None = None,
         dtype: DTypeLike = np.float32,
         flatten: bool = True,
+        ref_date_index: int = -1,
     ) -> State:
        """Create a state from an ekd.FieldList.
 
+        Notes
+        -----
+        - The `fields` argument must be an earthkit FieldList (or FieldArray-compatible).
+        - This method intentionally converts state["fields"] from a FieldList to
+          a Dict[str, np.ndarray] with shape (len(dates), n_points).
+        - Pre-processors are run while state["fields"] is still a FieldList.
+
        Parameters
        ----------
        fields : ekd.FieldList
@@ -215,30 +221,17 @@ def _create_state(
            The data type.
        flatten : bool
            Whether to flatten the data.
+        ref_date_index : int
+            The index of the reference date in the dates list.
 
        Returns
        -------
        State
-            The created input state.
+            The created input state with state["fields"] as Dict[str, np.ndarray].
        """
-        fields = self.pre_process(fields)
-
-        dates = sorted([to_datetime(d) for d in dates])
-        date_to_index = {d.isoformat(): i for i, d in enumerate(dates)}
-
-        state = dict(date=dates[-1], latitudes=latitudes, longitudes=longitudes, fields=dict())
-
-        if len(fields) == 0:
-            LOG.warning("No input fields found for dates %s (%s)", [d.isoformat() for d in dates], self)
-            return state
-
-        state_fields = state["fields"]
-
-        fields = self._filter_and_sort(fields, dates=dates, title="Create input state")
-
         if latitudes is None and longitudes is None:
             try:
-                state["latitudes"], state["longitudes"] = fields[0].grid_points()
+                latitudes, longitudes = fields[0].grid_points()
                 LOG.info(
                     "%s: using `latitudes` and `longitudes` from the first input field",
                     self.__class__.__name__,
@@ -251,8 +244,6 @@ def _create_state(
                 latitudes = self.checkpoint.latitudes
                 longitudes = self.checkpoint.longitudes
                 if latitudes is not None and longitudes is not None:
-                    state["latitudes"] = latitudes
-                    state["longitudes"] = longitudes
                     LOG.info(
                         "%s: using `latitudes` and `longitudes` found in the checkpoint.",
                         self.__class__.__name__,
@@ -264,6 +255,21 @@ def _create_state(
             )
             raise e
 
+        # Sort the dates before picking the reference date, so that
+        # ref_date_index addresses a chronologically ordered list
+        dates = sorted([to_datetime(d) for d in dates])
+        date_to_index = {d.isoformat(): i for i, d in enumerate(dates)}
+
+        state = dict(date=dates[ref_date_index], latitudes=latitudes, longitudes=longitudes, fields=fields)
+
+        # allow hooks to operate on the FieldList before conversion to numpy
+        state = self.pre_process(state)
+
+        fields = state["fields"]
+        state_fields = {}
+
+        if len(fields) == 0:
+            raise ValueError("No input fields provided")
+
+        fields = self._filter_and_sort(fields, dates=dates, title="Create input state")
+
         check = defaultdict(set)
 
         n_points = fields[0].to_numpy(dtype=dtype, flatten=flatten).size
@@ -295,7 +301,7 @@ def _create_state(
                     raise ValueError(f"Duplicate dates for {name}")
 
                 check[name].add(date_idx)
-
+        state["fields"] = state_fields
         for name, idx in check.items():
             if len(idx) != len(dates):
                 LOG.error("Missing dates for %s: %s", name, idx)
@@ -326,6 +332,7 @@ def _create_input_state(
         longitudes: FloatArray | None = None,
         dtype: DTypeLike = np.float32,
         flatten: bool = True,
+        ref_date_index: int = -1,
     ) -> State:
        """Create the input state.
 
@@ -345,6 +352,8 @@ def _create_input_state(
            The data type.
        flatten : bool
            Whether to flatten the data.
+        ref_date_index : int
+            The index of the reference date in the dates list.
 
        Returns
        -------
@@ -366,6 +375,7 @@ def _create_input_state(
             longitudes=longitudes,
             dtype=dtype,
             flatten=flatten,
+            ref_date_index=ref_date_index,
         )
 
     def _load_forcings_state(self, fields: ekd.FieldList, *, dates: list[Date], current_state: State) -> State:
@@ -397,7 +407,7 @@ def _load_forcings_state(self, fields: ekd.FieldList, *, dates: list[Date], current_state: State) -> State:
     def set_private_attributes(self, state: State, fields: ekd.FieldList) -> None:  # type: ignore
        """Set private attributes to the state.
 
-        Provides geography information if available retrieved from the fields.
+        Provides geography information, if available, retrieved from the fields (FieldList/FieldArray).
        """
 
         geography_information = {}
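Note: the refactored `_create_state()` now runs pre-processors on a state dict whose `fields` entry is still a FieldList, and only afterwards converts to numpy; `ref_date_index` selects which of the (sorted) input dates becomes `state["date"]`. A sketch of the resulting state contract, with illustrative values:

```python
import numpy as np

# Sketch: the shape of the state produced by _create_state(), per the Notes
# in the diff. Two lagged dates, four grid points.
dates = ["2025-01-01T00:00", "2025-01-01T06:00"]
n_points = 4
state = {
    "date": dates[-1],  # dates[ref_date_index], default -1
    "latitudes": np.zeros(n_points),
    "longitudes": np.zeros(n_points),
    "fields": {"2t": np.zeros((len(dates), n_points))},
}
assert state["fields"]["2t"].shape == (len(dates), n_points)
```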
diff --git a/src/anemoi/inference/inputs/empty.py b/src/anemoi/inference/inputs/empty.py
index 794d1e36..230592bc 100644
--- a/src/anemoi/inference/inputs/empty.py
+++ b/src/anemoi/inference/inputs/empty.py
@@ -46,13 +46,15 @@ def __init__(self, context: Context, **kwargs: Any) -> None:
         super().__init__(context, **kwargs)
         assert self.variables in (None, []), "EmptyInput should not have variables"
 
-    def create_input_state(self, *, date: Date | None) -> State:
+    def create_input_state(self, *, date: Date | None, **kwargs) -> State:
        """Create an empty input state.
 
        Parameters
        ----------
        date : Date or None
            The date for the input state.
+        **kwargs : Any
+            Additional keyword arguments.
 
        Returns
        -------
diff --git a/src/anemoi/inference/inputs/fdb.py b/src/anemoi/inference/inputs/fdb.py
index 6838068d..1614296a 100644
--- a/src/anemoi/inference/inputs/fdb.py
+++ b/src/anemoi/inference/inputs/fdb.py
@@ -55,11 +55,11 @@ def __init__(
         # NOTE: this is a temporary workaround for #191 thus not documented
         self.param_id_map = kwargs.pop("param_id_map", {})
 
-    def create_input_state(self, *, date: Date | None) -> State:
+    def create_input_state(self, *, date: Date | None, **kwargs) -> State:
         date = np.datetime64(date).astype(datetime.datetime)
         dates = [date + h for h in self.checkpoint.lagged]
         ds = self.retrieve(variables=self.variables, dates=dates)
-        res = self._create_input_state(ds, variables=None, date=date)
+        res = self._create_input_state(ds, variables=None, date=date, **kwargs)
         return res
 
     def load_forcings_state(self, *, dates: list[Date], current_state: State) -> State:
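Note: the FDB input derives its retrieval dates from the checkpoint's lagged offsets before forwarding the new `**kwargs`. A small sketch of that date arithmetic, with an illustrative `lagged` list:

```python
import datetime

import numpy as np

# Sketch: how create_input_state() above expands the reference date into
# the list of retrieval dates. The lagged offsets here are illustrative.
date = np.datetime64("2025-01-01T00:00").astype(datetime.datetime)
lagged = [datetime.timedelta(hours=-6), datetime.timedelta(hours=0)]
dates = [date + h for h in lagged]
# -> [datetime(2024, 12, 31, 18, 0), datetime(2025, 1, 1, 0, 0)]
```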
diff --git a/src/anemoi/inference/inputs/gribfile.py b/src/anemoi/inference/inputs/gribfile.py
index aff5bfef..60d462c8 100644
--- a/src/anemoi/inference/inputs/gribfile.py
+++ b/src/anemoi/inference/inputs/gribfile.py
@@ -8,6 +8,7 @@
 # nor does it submit to any jurisdiction.
 
 
+import glob
 import logging
 import os
 from functools import cached_property
@@ -47,9 +48,11 @@ def __init__(
        context : Any
            The context in which the input is used.
        path : str
-            The path to the GRIB file.
-        pre_processors : Optional[List[ProcessorConfig]], default None
-            Pre-processors to apply to the input
+            Path, directory or glob pattern to GRIB file(s). Examples:
+            - "/path/to/file.grib"
+            - "/path/to/*.grib"
+            - "/path/to/**/*.grib2"
+            - "/path/to/directory/"
        namer : Optional[Any]
            Optional namer for the input.
        **kwargs : Any
@@ -58,21 +61,24 @@ def __init__(
         super().__init__(context, **kwargs)
         self.path = path
 
-    def create_input_state(self, *, date: Date | None) -> State:
+    def create_input_state(self, *, date: Date | None, ref_date_index: int = -1, **kwargs) -> State:
        """Create the input state for the given date.
 
        Parameters
        ----------
        date : Optional[Date]
            The date for which to create the input state.
+        ref_date_index : int, default -1
+            The reference date index to use.
+        **kwargs : Any
+            Additional keyword arguments.
 
        Returns
        -------
        State
            The created input state.
        """
-
-        return self._create_input_state(self._fieldlist, date=date)
+        return self._create_input_state(self._fieldlist, date=date, ref_date_index=ref_date_index)
 
     def load_forcings_state(self, *, dates: list[Date], current_state: State) -> State:
        """Load the forcings state for the given variables and dates.
@@ -98,8 +104,37 @@ def load_forcings_state(self, *, dates: list[Date], current_state: State) -> State:
 
     @cached_property
     def _fieldlist(self) -> ekd.FieldList:
-        """Get the input fieldlist from the GRIB file."""
-        if os.path.getsize(self.path) == 0:
-            LOG.warning("GRIB file %r is empty", self.path)
+        """Get the input fieldlist from the GRIB file or collection."""
+        path = self.path
+
+        # Case 1: explicit glob pattern
+        if glob.has_magic(path):
+            matches = glob.glob(path, recursive=True)
+            files = [p for p in matches if os.path.isfile(p)]
+            if not files:
+                LOG.warning("No GRIB files matched pattern %r", path)
+                return ekd.from_source("empty")
+            return ekd.from_source("file", sorted(files))
+
+        # Case 2: directory path -> search for GRIB files recursively
+        if os.path.isdir(path):
+            patterns = ("*.grib", "*.grib1", "*.grib2", "*.grb", "*.grb2")
+            files = []
+            for pat in patterns:
+                files.extend(glob.glob(os.path.join(path, "**", pat), recursive=True))
+            files = [f for f in sorted(set(files)) if os.path.isfile(f)]
+            if not files:
+                LOG.warning("GRIB directory %r contains no GRIB files", path)
+                return ekd.from_source("empty")
+            return ekd.from_source("file", files)
+
+        # Case 3: single file path
+        try:
+            if os.path.getsize(path) == 0:
+                LOG.warning("GRIB file %r is empty", path)
+                return ekd.from_source("empty")
+        except FileNotFoundError:
+            LOG.warning("GRIB path %r not found", path)
             return ekd.from_source("empty")
-        return ekd.from_source("file", self.path)
+
+        return ekd.from_source("file", path)
diff --git a/src/anemoi/inference/inputs/icon.py b/src/anemoi/inference/inputs/icon.py
index 8de05737..b0962454 100644
--- a/src/anemoi/inference/inputs/icon.py
+++ b/src/anemoi/inference/inputs/icon.py
@@ -63,13 +63,15 @@ def __init__(
         self.grid = grid
         self.refinement_level_c = refinement_level_c
 
-    def create_input_state(self, *, date: Date | None) -> State:
+    def create_input_state(self, *, date: Date | None, **kwargs) -> State:
        """Creates the input state for the given date.
 
        Parameters
        ----------
        date : Optional[Date]
            The date for which to create the input state.
+        **kwargs : Any
+            Additional keyword arguments.
 
        Returns
        -------
@@ -86,6 +88,7 @@ def create_input_state(self, *, date: Date | None) -> State:
             date=date,
             latitudes=latitudes,
             longitudes=longitudes,
+            **kwargs,
         )
 
     def load_forcings_state(self, *, dates: list[Date], current_state: State) -> State:
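Note: `_fieldlist` now resolves a path in three ways: glob pattern, directory (searched recursively for common GRIB extensions), or single file. A runnable sketch of that classification logic, using the same stdlib calls as the diff:

```python
import glob
import os

# Sketch: the three-way path resolution used by GribFileInput._fieldlist.
def classify(path: str) -> str:
    if glob.has_magic(path):  # contains "*", "?" or "[...]"
        return "glob pattern"
    if os.path.isdir(path):
        return "directory (searched recursively for *.grib, *.grb, ...)"
    return "single file"

print(classify("/data/**/*.grib2"))  # glob pattern
print(classify("/data/"))            # directory ...
print(classify("/data/fc.grib"))     # single file
```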
diff --git a/src/anemoi/inference/inputs/mars.py b/src/anemoi/inference/inputs/mars.py
index d4b66b23..9a8eddc6 100644
--- a/src/anemoi/inference/inputs/mars.py
+++ b/src/anemoi/inference/inputs/mars.py
@@ -239,13 +239,15 @@ def __init__(
         self.patches = patches or []
         self.log = log
 
-    def create_input_state(self, *, date: Date | None) -> State:
+    def create_input_state(self, *, date: Date | None, **kwargs) -> State:
        """Create the input state for the given date.
 
        Parameters
        ----------
        date : Optional[Date]
            The date for which to create the input state.
+        **kwargs : Any
+            Additional keyword arguments.
 
        Returns
        -------
@@ -263,6 +265,7 @@ def create_input_state(self, *, date: Date | None) -> State:
             ),
             variables=self.variables,
             date=date,
+            **kwargs,
         )
 
     def retrieve(self, variables: list[str], dates: list[Date]) -> Any:
diff --git a/src/anemoi/inference/inputs/repeated_dates.py b/src/anemoi/inference/inputs/repeated_dates.py
index 8bc3df03..38a889e3 100644
--- a/src/anemoi/inference/inputs/repeated_dates.py
+++ b/src/anemoi/inference/inputs/repeated_dates.py
@@ -48,13 +48,15 @@ def __init__(self, context: Context, *, source: str, mode: str = "constant", **kwargs
         super().__init__(context, **kwargs)
         self.source = create_input(context, source, variables=self.variables, purpose=self.purpose)
 
-    def create_input_state(self, *, date: Date | None) -> State:
+    def create_input_state(self, *, date: Date | None, **kwargs) -> State:
        """Create the input state for the repeated-dates input.
 
        Parameters
        ----------
        date : Date or None
            The date for the input state.
+        **kwargs : Any
+            Additional keyword arguments.
 
        Returns
        -------
@@ -63,7 +65,7 @@ def create_input_state(self, *, date: Date | None) -> State:
        """
 
         # TODO: Consider caching the result
-        state = self.source.create_input_state(date=self.date)
+        state = self.source.create_input_state(date=self.date, **kwargs)
         state["_input"] = self
         state["date"] = date
         return state
diff --git a/src/anemoi/inference/inputs/split.py b/src/anemoi/inference/inputs/split.py
index f60b35bb..6451540d 100644
--- a/src/anemoi/inference/inputs/split.py
+++ b/src/anemoi/inference/inputs/split.py
@@ -100,13 +100,15 @@ def __init__(self, context: Context, *splits, **kwargs: Any) -> None:
 
         super().__init__(context, **kwargs)
 
-    def create_input_state(self, *, date: Date | None) -> State:
+    def create_input_state(self, *, date: Date | None, **kwargs) -> State:
        """Create the input state for the repeated-dates input.
 
        Parameters
        ----------
        date : Date or None
            The date for the input state.
+        **kwargs : Any
+            Additional keyword arguments.
 
        Returns
        -------
@@ -115,7 +117,7 @@ def create_input_state(self, *, date: Date | None) -> State:
        """
 
         # TODO: Consider caching the result
-        states = [split.create_input_state(date=date) for split in self.splits]
+        states = [split.create_input_state(date=date, **kwargs) for split in self.splits]
 
         state = combine_states(*states)
diff --git a/src/anemoi/inference/output.py b/src/anemoi/inference/output.py
index 42714f7a..35493c4f 100644
--- a/src/anemoi/inference/output.py
+++ b/src/anemoi/inference/output.py
@@ -51,7 +51,7 @@ def __init__(
        """
         self.context = context
         self.checkpoint = context.checkpoint
-        self.reference_date = None
+        self.reference_date = context.reference_date
 
         self._post_processor_confs = post_processors or []
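Note: together with the runner changes further down, `Output` no longer starts with `reference_date = None`; it reads the value from the context, and the runner only falls back to the first input date if nothing was set. A sketch of the effective precedence, with illustrative dates:

```python
import datetime

# Sketch: reference_date precedence after this PR.
config_date = datetime.datetime(2025, 1, 1)       # DefaultRunner.__init__ (if config has a date)
reference_date = config_date
first_input_date = datetime.datetime(2025, 1, 2)  # seen in add_initial_forcings_to_input_state
# The runner keeps an existing value and only fills in a missing one:
reference_date = reference_date or first_input_date
assert reference_date == config_date
```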
""" - return self._create_input_state(self.input_fields, date=date) + return self._create_input_state(self.input_fields, date=date, **kwargs) def load_forcings_state( self, diff --git a/src/anemoi/inference/post_processors/extract.py b/src/anemoi/inference/post_processors/extract.py index 983295a3..c56189b0 100644 --- a/src/anemoi/inference/post_processors/extract.py +++ b/src/anemoi/inference/post_processors/extract.py @@ -30,6 +30,24 @@ class ExtractBase(Processor): # this needs to be set in subclasses indexer: BoolArray | slice + def __init__(self, context: Context) -> None: + super().__init__(context) + self._indexer: BoolArray | slice | None = None + + @property + def indexer(self) -> BoolArray | slice: + if self._indexer is None: + raise RuntimeError(f"{type(self).__name__}.indexer is not set. Set it before process().") + return self._indexer + + @indexer.setter + def indexer(self, value: BoolArray | slice) -> None: + if isinstance(value, slice): + self._indexer = value + return + arr = np.asarray(value) + self._indexer = arr + def process(self, state: State) -> State: """Process the state to extract a subset of points based on the indexer. @@ -46,10 +64,11 @@ def process(self, state: State) -> State: state = state.copy() state["fields"] = state["fields"].copy() - state["latitudes"] = state["latitudes"][self.indexer] - state["longitudes"] = state["longitudes"][self.indexer] + idx = self.indexer # validate indexer is set + state["latitudes"] = state["latitudes"][idx] + state["longitudes"] = state["longitudes"][idx] for field in state["fields"]: - state["fields"][field] = state["fields"][field][self.indexer] + state["fields"][field] = state["fields"][field][idx] return state diff --git a/src/anemoi/inference/pre_processors/extract.py b/src/anemoi/inference/pre_processors/extract.py new file mode 100644 index 00000000..f3a5f040 --- /dev/null +++ b/src/anemoi/inference/pre_processors/extract.py @@ -0,0 +1,135 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from pathlib import Path + +import numpy as np +from anemoi.transform.fields import new_field_from_numpy +from anemoi.transform.fields import new_fieldlist_from_list + +from anemoi.inference.context import Context +from anemoi.inference.decorators import main_argument +from anemoi.inference.types import BoolArray +from anemoi.inference.types import State + +from ..processor import Processor +from . import pre_processor_registry + +LOG = logging.getLogger(__name__) + + +class ExtractBase(Processor): + """Base class for pre-processors that extract data from the state.""" + + # this needs to be set in subclasses + indexer: BoolArray | slice + + def __init__(self, context: Context) -> None: + super().__init__(context) + self._indexer: BoolArray | slice | None = None + + @property + def indexer(self) -> BoolArray | slice: + if self._indexer is None: + raise RuntimeError(f"{type(self).__name__}.indexer is not set. 
diff --git a/src/anemoi/inference/pre_processors/extract.py b/src/anemoi/inference/pre_processors/extract.py
new file mode 100644
index 00000000..f3a5f040
--- /dev/null
+++ b/src/anemoi/inference/pre_processors/extract.py
@@ -0,0 +1,135 @@
+# (C) Copyright 2025- Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+import logging
+from pathlib import Path
+
+import numpy as np
+from anemoi.transform.fields import new_field_from_numpy
+from anemoi.transform.fields import new_fieldlist_from_list
+
+from anemoi.inference.context import Context
+from anemoi.inference.decorators import main_argument
+from anemoi.inference.types import BoolArray
+from anemoi.inference.types import State
+
+from ..processor import Processor
+from . import pre_processor_registry
+
+LOG = logging.getLogger(__name__)
+
+
+class ExtractBase(Processor):
+    """Base class for pre-processors that extract data from the state."""
+
+    # this needs to be set in subclasses
+    indexer: BoolArray | slice
+
+    def __init__(self, context: Context) -> None:
+        super().__init__(context)
+        self._indexer: BoolArray | slice | None = None
+
+    @property
+    def indexer(self) -> BoolArray | slice:
+        if self._indexer is None:
+            raise RuntimeError(f"{type(self).__name__}.indexer is not set. Set it before process().")
+        return self._indexer
+
+    @indexer.setter
+    def indexer(self, value: BoolArray | slice) -> None:
+        if isinstance(value, slice):
+            self._indexer = value
+            return
+        arr = np.asarray(value)
+        self._indexer = arr
+
+    def process(self, state: State) -> State:
+        """Process the state to extract a subset of points based on the indexer."""
+        state = state.copy()
+        idx = self.indexer  # validate indexer is set and correct
+        result = []
+        for field in state["fields"]:
+            data = field.to_numpy().flatten()[idx]
+            result.append(new_field_from_numpy(data, template=field))
+        state["fields"] = new_fieldlist_from_list(result)
+        state["latitudes"] = state["latitudes"][idx]
+        state["longitudes"] = state["longitudes"][idx]
+        return state
+
+
+@pre_processor_registry.register("extract_mask")
+@main_argument("mask")
+class ExtractMask(ExtractBase):
+    """Extract a subset of points from the state based on a boolean mask.
+
+    Parameters
+    ----------
+    context : Context
+        The context in which the processor is running.
+    mask : str
+        Either a path to a `.npy` file containing the boolean mask or
+        the name of a supporting array found in the checkpoint.
+    """
+
+    def __init__(self, context: Context, mask: str) -> None:
+        super().__init__(context)
+
+        self._maskname = mask
+
+        if Path(mask).is_file():
+            mask = np.load(mask)
+        else:
+            mask = context.checkpoint.load_supporting_array(mask)
+
+        if not isinstance(mask, np.ndarray) or mask.dtype != bool:
+            raise ValueError(
+                f"Expected the mask to be a boolean numpy array. Got {type(mask)} with dtype {getattr(mask, 'dtype', None)}."
+            )
+
+        self.indexer = mask
+        self.npoints = np.sum(mask)
+
+    def __repr__(self) -> str:
+        """Return a string representation of the ExtractMask object.
+
+        Returns
+        -------
+        str
+            String representation of the object.
+        """
+        return f"ExtractMask({self._maskname}, points={self.npoints}/{self.indexer.size})"
+
+
+@pre_processor_registry.register("extract_slice")
+class ExtractSlice(ExtractBase):
+    """Extract a subset of points from the state based on a slice.
+
+    Parameters
+    ----------
+    context : Context
+        The context in which the processor is running.
+    slice_args : int
+        Arguments to create a slice object. This can be a single integer or
+        a tuple of integers representing the start, stop, and step of the slice.
+    """
+
+    def __init__(self, context: Context, *slice_args: int) -> None:
+        super().__init__(context)
+        self.indexer = slice(*slice_args)
+
+    def __repr__(self) -> str:
+        """Return a string representation of the ExtractSlice object.
+
+        Returns
+        -------
+        str
+            String representation of the object.
+        """
+        return f"ExtractSlice({self.indexer})"
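Note: the two registered pre-processors differ only in how they build the indexer: `extract_mask` loads a boolean array (from a `.npy` file or a checkpoint supporting array), `extract_slice` builds `slice(*slice_args)`. In plain numpy terms, the extraction they perform is:

```python
import numpy as np

# Sketch: what extract_mask and extract_slice select, on a toy array.
values = np.arange(10.0)

mask = values > 6.5          # extract_mask: boolean indexer
print(values[mask])          # [7. 8. 9.]

indexer = slice(2, 8, 2)     # extract_slice: e.g. slice_args = (2, 8, 2)
print(values[indexer])       # [2. 4. 6.]
```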
diff --git a/src/anemoi/inference/pre_processors/forward_transform_filter.py b/src/anemoi/inference/pre_processors/forward_transform_filter.py
index 4cbffe57..ed73f44e 100644
--- a/src/anemoi/inference/pre_processors/forward_transform_filter.py
+++ b/src/anemoi/inference/pre_processors/forward_transform_filter.py
@@ -11,10 +11,10 @@
 import logging
 from typing import Any
 
-import earthkit.data as ekd
 from anemoi.transform.filters import filter_registry
 
 from anemoi.inference.decorators import main_argument
+from anemoi.inference.types import State
 
 from ..processor import Processor
 from . import pre_processor_registry
@@ -51,20 +51,21 @@ def __init__(self, context: Any, filter: str, **kwargs: Any) -> None:
         super().__init__(context)
         self.filter = filter_registry.create(filter, **kwargs)
 
-    def process(self, fields: ekd.FieldList) -> ekd.FieldList:
+    def process(self, state: State) -> State:
        """Process the given fields using the forward filter.
 
        Parameters
        ----------
-        fields : ekd.FieldList
-            The fields to be processed.
+        state : State
+            The state containing the fields to be processed.
 
        Returns
        -------
-        ekd.FieldList
-            The processed fields.
+        State
+            The processed state.
        """
-        return self.filter.forward(fields)
+        state["fields"] = self.filter.forward(state["fields"])
+        return state
 
     def patch_data_request(self, data_request: Any) -> Any:
        """Patch the data request using the filter.
diff --git a/src/anemoi/inference/pre_processors/no_missing_values.py b/src/anemoi/inference/pre_processors/no_missing_values.py
index e2fff804..4e2ea2a2 100644
--- a/src/anemoi/inference/pre_processors/no_missing_values.py
+++ b/src/anemoi/inference/pre_processors/no_missing_values.py
@@ -11,12 +11,12 @@
 import logging
 from typing import Any
 
-import earthkit.data as ekd
 import tqdm
 from anemoi.transform.fields import new_field_from_numpy
 from anemoi.transform.fields import new_fieldlist_from_list
 
 from anemoi.inference.context import Context
+from anemoi.inference.types import State
 
 from ..processor import Processor
 from . import pre_processor_registry
@@ -40,20 +40,21 @@ def __init__(self, context: Context, **kwargs: Any) -> None:
        """
         super().__init__(context)
 
-    def process(self, fields: ekd.FieldList) -> ekd.FieldList:
+    def process(self, state: State) -> State:
        """Process the fields to replace NaNs with the mean value.
 
        Parameters
        ----------
-        fields : list
-            List of fields to process.
+        state : State
+            The state to process.
 
        Returns
        -------
-        list
-            List of processed fields with NaNs replaced by the mean value.
+        State
+            The processed state with NaNs replaced by the mean value.
        """
         result = []
+        fields = state["fields"]
         for field in tqdm.tqdm(fields):
             import numpy as np
 
@@ -64,4 +65,5 @@ def process(self, fields: ekd.FieldList) -> ekd.FieldList:
             data = np.where(np.isnan(data), mean_value, data)
             result.append(new_field_from_numpy(data, template=field))
 
-        return new_fieldlist_from_list(result)
+        state["fields"] = new_fieldlist_from_list(result)
+        return state
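Note: both pre-processors now follow the same `process(state) -> state` protocol, mutating `state["fields"]` in place rather than receiving a bare FieldList. The per-field NaN replacement itself boils down to the following array operation (the mean computation is outside the hunk context; a NaN-ignoring mean is assumed here):

```python
import numpy as np

# Sketch: the NaN fill applied to each field's values.
data = np.array([1.0, np.nan, 3.0])
mean_value = np.nanmean(data)                     # ignores the NaNs -> 2.0
data = np.where(np.isnan(data), mean_value, data)
print(data)                                       # [1. 2. 3.]
```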
diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py
index 6a764d23..ce4998e4 100644
--- a/src/anemoi/inference/runner.py
+++ b/src/anemoi/inference/runner.py
@@ -348,7 +348,7 @@ def add_initial_forcings_to_input_state(self, input_state: State) -> None:
         dates = [date + h for h in self.checkpoint.lagged]
 
         # For output object. Should be moved elsewhere
-        self.reference_date = dates[-1]
+        self.reference_date = self.reference_date or date
         self.initial_dates = dates
 
         # TODO: Check for user provided forcings
@@ -689,13 +689,10 @@ def forecast(
                 variable_to_input_tensor_index,
                 self.checkpoint.timestep,
             )
+            amp_ctx = torch.autocast(device_type=self.device.type, dtype=self.autocast)
 
             # Predict next state of atmosphere
-            with (
-                torch.autocast(device_type=self.device.type, dtype=self.autocast),
-                ProfilingLabel("Predict step", self.use_profiler),
-                Timer(title),
-            ):
+            with torch.inference_mode(), amp_ctx, ProfilingLabel("Predict step", self.use_profiler), Timer(title):
                 y_pred = self.predict_step(self.model, input_tensor_torch, fcstep=s, step=step, date=date)
 
                 output = torch.squeeze(y_pred, dim=(0, 1))  # shape: (values, variables)
@@ -777,21 +774,35 @@ def copy_prognostic_fields_to_input_tensor(
        """
         # input_tensor_torch is shape: (batch, multi_step_input, values, variables)
         # batch is always 1
+        pmask_in = torch.as_tensor(
+            self.checkpoint.prognostic_input_mask,
+            device=input_tensor_torch.device,
+            dtype=torch.long,
+        )
 
-        prognostic_output_mask = self.checkpoint.prognostic_output_mask
-        prognostic_input_mask = self.checkpoint.prognostic_input_mask
+        pmask_out = torch.as_tensor(
+            self.checkpoint.prognostic_output_mask,
+            device=y_pred.device,
+            dtype=torch.long,
+        )  # index_select requires an integer (long) index tensor
+        # (the checkpoint masks hold integer positions, not booleans)
 
-        # Copy prognostic fields to input tensor
-        prognostic_fields = y_pred[..., prognostic_output_mask]  # Get new predicted values
-        input_tensor_torch = input_tensor_torch.roll(-1, dims=1)  # Roll the tensor in the multi_step_input dimension
-        input_tensor_torch[:, -1, :, self.checkpoint.prognostic_input_mask] = (
-            prognostic_fields  # Add new values to last 'multi_step_input' row
-        )
+        prognostic_fields = torch.index_select(y_pred, dim=-1, index=pmask_out)
+
+        input_tensor_torch = input_tensor_torch.roll(-1, dims=1)
+        input_tensor_torch[:, -1, :, pmask_in] = prognostic_fields
 
-        assert not check[prognostic_input_mask].any()  # Make sure we are not overwriting some values
-        check[prognostic_input_mask] = True
+        pmask_in_np = pmask_in.detach().cpu().numpy()
+        if check[pmask_in_np].any():
+            # Report which variables are conflicting
+            conflicting = [self._input_tensor_by_name[i] for i in pmask_in_np[check[pmask_in_np]]]
+            raise AssertionError(
+                f"Attempting to overwrite existing prognostic input slots for variables: {conflicting}"
+            )
+
+        check[pmask_in_np] = True
 
-        for n in prognostic_input_mask:
+        for n in pmask_in_np:
             self._input_kinds[self._input_tensor_by_name[n]] = Kind(prognostic=True)
             if self.trace:
                 self.trace.from_rollout(self._input_tensor_by_name[n])
@@ -876,7 +887,6 @@ def add_boundary_forcings_to_input_tensor(
        """
         # input_tensor_torch is shape: (batch, multi_step_input, values, variables)
         # batch is always 1
-
         sources = self.boundary_forcings_inputs
         for source in sources:
             forcings = source.load_forcings_array([date], state)  # shape: (variables, dates, values)
diff --git a/src/anemoi/inference/runners/default.py b/src/anemoi/inference/runners/default.py
index cee04278..929c08db 100644
--- a/src/anemoi/inference/runners/default.py
+++ b/src/anemoi/inference/runners/default.py
@@ -68,6 +68,7 @@ def __init__(self, config: Configuration) -> None:
             config = DotDict(config.model_dump())
 
         self.config = config
+        self.reference_date = self.config.date if hasattr(self.config, "date") else None
 
         super().__init__(
             config.checkpoint,
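Note: the prognostic masks are materialised as `torch.long` tensors because `torch.index_select` requires integer indices (boolean masks are not accepted). A runnable sketch of the gather on the variables axis, with illustrative shapes:

```python
import torch

# Sketch: selecting prognostic variables along the last axis, as in
# copy_prognostic_fields_to_input_tensor above.
y_pred = torch.arange(12.0).reshape(1, 1, 2, 6)  # (batch, step, values, variables)
pmask_out = torch.tensor([1, 4], dtype=torch.long)
prognostic = torch.index_select(y_pred, dim=-1, index=pmask_out)
print(prognostic.shape)                           # torch.Size([1, 1, 2, 2])
```

The switch to `torch.inference_mode()` in the predict step serves the same goal: it disables autograd bookkeeping entirely, which is both faster and a guard against accidental gradient tracking during rollout.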
diff --git a/src/anemoi/inference/runners/interpolator.py b/src/anemoi/inference/runners/interpolator.py
index f2b82fd9..6374da3f 100644
--- a/src/anemoi/inference/runners/interpolator.py
+++ b/src/anemoi/inference/runners/interpolator.py
@@ -18,6 +18,7 @@
 from anemoi.inference.config import Configuration
 from anemoi.inference.config.run import RunConfiguration
+from anemoi.inference.device import get_available_device
 from anemoi.inference.lazy import torch
 from anemoi.inference.runner import Kind
 from anemoi.inference.types import State
@@ -69,17 +70,13 @@ def __init__(self, config: Configuration) -> None:
         # if not isinstance(config, BaseModel):
         #     config = RunConfiguration.load(config)
-        from anemoi.models.models import AnemoiModelEncProcDecInterpolator
-
         super().__init__(config)
-
+        self.from_analysis = any("use_original_paths" in keys for keys in config.input.values())
+        self.device = get_available_device()
         self.patch_checkpoint_lagged_property()
 
         assert (
             self.config.write_initial_state
         ), "Interpolator output should include temporal start state, end state and boundary conditions"
-        assert isinstance(
-            self.model.model, AnemoiModelEncProcDecInterpolator
-        ), "Model must be an interpolator model for this runner"
 
         self.target_forcings = self.target_computed_forcings(
             self.checkpoint._metadata._config_training.target_forcing.data
         )
@@ -123,9 +120,30 @@ def get_lagged(instance):
         # Replace the lagged property on this specific instance
         self.checkpoint.__class__.lagged = property(get_lagged)
 
+    def create_input_state(self, *, date: datetime.datetime, **kwargs) -> State:
+        prognostic_input = self.create_prognostics_input()
+        LOG.info("📥 Prognostic input: %s", prognostic_input)
+        prognostic_state = prognostic_input.create_input_state(date=date, **kwargs)
+        self._check_state(prognostic_state, "prognostics")
+
+        constants_input = self.create_constant_coupled_forcings_input()
+        LOG.info("📥 Constant forcings input: %s", constants_input)
+        constants_state = constants_input.create_input_state(date=date, **kwargs)
+        self._check_state(constants_state, "constant_forcings")
+
+        forcings_input = self.create_dynamic_forcings_input()
+        LOG.info("📥 Dynamic forcings input: %s", forcings_input)
+        forcings_state = forcings_input.create_input_state(date=date, **kwargs)
+        self._check_state(forcings_state, "dynamic_forcings")
+        input_state = self._combine_states(
+            prognostic_state,
+            constants_state,
+            forcings_state,
+        )
+        return input_state
+
     def execute(self) -> None:
        """Execute the interpolator runner with support for multiple interpolation periods."""
-
         if self.config.description is not None:
             LOG.info("%s", self.config.description)
@@ -135,9 +153,7 @@ def execute(self) -> None:
         self.interpolation_window = get_interpolation_window(
             self.checkpoint.data_frequency, self.checkpoint.input_explicit_times
         )
-        # Not really timestep but the size of the interpolation window, not sure if this is used
         self.time_step = self.interpolation_window
-        input = self.create_input()
         output = self.create_output()
         post_processors = self.post_processors
@@ -148,18 +164,25 @@ def execute(self) -> None:
         num_windows = int(lead_time / self.interpolation_window)
         if lead_time % self.interpolation_window != to_timedelta(0):
             LOG.warning(
-                f"Lead time {lead_time} is not a multiple of interpolation window {self.interpolation_window}. "
-                f"Will interpolate for {num_windows * self.interpolation_window}"
+                "Lead time %s is not a multiple of interpolation window %s. Will interpolate for %s",
+                lead_time,
+                self.interpolation_window,
+                num_windows * self.interpolation_window,
             )
 
         # Process each interpolation window
         for window_idx in range(num_windows):
             window_start_date = self.config.date + window_idx * self.interpolation_window
-            LOG.info(f"Processing interpolation window {window_idx + 1}/{num_windows} starting at {window_start_date}")
+            LOG.info(
+                "Processing interpolation window %d/%d starting at %s", window_idx + 1, num_windows, window_start_date
+            )
 
             # Create input state for this window
-            input_state = input.create_input_state(date=window_start_date)
+            if self.from_analysis:
+                input_state = self.create_input_state(date=window_start_date)
+            else:
+                input_state = self.create_input_state(date=window_start_date, ref_date_index=0)
             self.input_state_hook(input_state)
 
             # Run interpolation for this window
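Note: the `ref_date_index=0` here is the consumer of the plumbing added throughout the inputs: a forecast anchors its state on the last input date, while an interpolation window is anchored on its start date. In terms of the `_create_state()` contract above:

```python
# Sketch: reference date selection for the two use cases; dates illustrative.
dates = ["2025-01-01T00:00", "2025-01-01T06:00"]  # sorted window bounds
assert dates[-1] == "2025-01-01T06:00"  # forecast default (ref_date_index=-1)
assert dates[0] == "2025-01-01T00:00"   # interpolation window (ref_date_index=0)
```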
Will interpolate for %s", + lead_time, + self.interpolation_window, + num_windows * self.interpolation_window, ) # Process each interpolation window for window_idx in range(num_windows): window_start_date = self.config.date + window_idx * self.interpolation_window - LOG.info(f"Processing interpolation window {window_idx + 1}/{num_windows} starting at {window_start_date}") + LOG.info( + "Processing interpolation window %d/%d starting at %s", window_idx + 1, num_windows, window_start_date + ) # Create input state for this window - input_state = input.create_input_state(date=window_start_date) + if self.from_analysis: + input_state = self.create_input_state(date=window_start_date) + else: + input_state = self.create_input_state(date=window_start_date, ref_date_index=0) self.input_state_hook(input_state) # Run interpolation for this window @@ -313,8 +336,6 @@ def forecast( Any The forecasted state. """ - import torch - # This does interpolation but called forecast so we can reuse run() self.model.eval() torch.set_grad_enabled(False) @@ -375,11 +396,13 @@ def forecast( result["interpolated"] = True if self.trace: - self.trace.write_input_tensor(date, s, input_tensor_torch.cpu().numpy(), variable_to_input_tensor_index) + self.trace.write_input_tensor( + date, s, input_tensor_torch.cpu().numpy(), variable_to_input_tensor_index, self.checkpoint.timestep + ) # Predict next state of atmosphere with ( - torch.autocast(device_type=self.device, dtype=self.autocast), + torch.autocast(device_type=self.device.type, dtype=self.autocast), ProfilingLabel("Predict step", self.use_profiler), Timer(title), ): @@ -391,7 +414,9 @@ def forecast( output = np.squeeze(y_pred.cpu().numpy()) # shape: (values, variables) if self.trace: - self.trace.write_output_tensor(date, s, output, self.checkpoint.output_tensor_index_to_variable) + self.trace.write_output_tensor( + date, s, output, self.checkpoint.output_tensor_index_to_variable, self.checkpoint.timestep + ) # Update state with ProfilingLabel("Updating state (CPU)", self.use_profiler):