From ae6ca8faff3b0077661c747d7179f71656e23db4 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi <13254278+astropenguin@users.noreply.github.com> Date: Sat, 4 Jan 2025 11:24:29 +0000 Subject: [PATCH 1/4] #237 Update project dependencies --- pyproject.toml | 2 +- uv.lock | 18 ++++-------------- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f590724..cf15f0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ description = "xarray data creation by data classes" readme = "README.md" keywords = ["dataclasses", "specifications", "typing", "xarray"] requires-python = ">=3.9,<3.14" -dependencies = ["dataspecs>=2.0,<3.0", "xarray>=2022.3,<2026.0"] +dependencies = ["dataspecs>=3.0,<4.0", "xarray>=2022.3,<2026.0"] classifiers = [ "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", diff --git a/uv.lock b/uv.lock index b686c10..1a45891 100644 --- a/uv.lock +++ b/uv.lock @@ -219,15 +219,14 @@ wheels = [ [[package]] name = "dataspecs" -version = "2.0.1" +version = "3.0.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyhumps" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/26/9a/77afad0968f54f7f0388b5548bfe4ef18f967c5fbf7ac039ebb71ae4a8dc/dataspecs-2.0.1.tar.gz", hash = "sha256:701792d4d6196bb534d3bb9a6c99e2c222d13a9f2a698200c7787968288fa627", size = 49977 } +sdist = { url = "https://files.pythonhosted.org/packages/6f/a3/246cb603391ef6f3de8f256cc5a8beabc93896e01b9ed31b9c63b2fa15a4/dataspecs-3.0.1.tar.gz", hash = "sha256:568a484bf40f4087f4f4a729674a41cd89f4ce02eaa5e162d0d6fd3fc347cff4", size = 49758 } wheels = [ - { url = "https://files.pythonhosted.org/packages/3d/19/56c7345f0f4e8892dc828d22318ec3c7a82f887d7dae858dc2851cf2fcd3/dataspecs-2.0.1-py3-none-any.whl", hash = "sha256:55e24cb54ff1cf95ced1e16c715bbcbdad3fe8036e88bee22d075634f394e058", size = 12363 }, + { url = 
"https://files.pythonhosted.org/packages/02/18/6b9081fb19cfd562277b65518ff8e119ee06053cf41e7a703b47cc91b769/dataspecs-3.0.1-py3-none-any.whl", hash = "sha256:d9b2b9a12c87f629b334354400bc6caf5a34273bbe6e27fe65ac590afb972793", size = 12294 }, ] [[package]] @@ -847,15 +846,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/3f/01c8b82017c199075f8f788d0d906b9ffbbc5a47dc9918a945e13d5a2bda/pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a", size = 1205513 }, ] -[[package]] -name = "pyhumps" -version = "3.8.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c4/83/fa6f8fb7accb21f39e8f2b6a18f76f6d90626bdb0a5e5448e5cc9b8ab014/pyhumps-3.8.0.tar.gz", hash = "sha256:498026258f7ee1a8e447c2e28526c0bea9407f9a59c03260aee4bd6c04d681a3", size = 9018 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9e/11/a1938340ecb32d71e47ad4914843775011e6e9da59ba1229f181fef3119e/pyhumps-3.8.0-py3-none-any.whl", hash = "sha256:060e1954d9069f428232a1adda165db0b9d8dfdce1d265d36df7fbff540acfd6", size = 6095 }, -] - [[package]] name = "pyright" version = "1.1.391" @@ -1281,7 +1271,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "dataspecs", specifier = ">=2.0,<3.0" }, + { name = "dataspecs", specifier = ">=3.0,<4.0" }, { name = "xarray", specifier = ">=2022.3,<2026.0" }, ] From 5e597523e386fb6662065be5ca02b09152df5c68 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi <13254278+astropenguin@users.noreply.github.com> Date: Sat, 4 Jan 2025 12:00:11 +0000 Subject: [PATCH 2/4] #237 Add typing module for tags and type aliases --- xarray_dataclasses/__init__.py | 22 ++++++++ xarray_dataclasses/typing.py | 95 ++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+) create mode 100644 xarray_dataclasses/typing.py diff --git a/xarray_dataclasses/__init__.py b/xarray_dataclasses/__init__.py index 8c0d5d5..56d4dec 100644 --- 
a/xarray_dataclasses/__init__.py +++ b/xarray_dataclasses/__init__.py @@ -1 +1,23 @@ +__all__ = [ + # submodules + "typing", + # aliases + "Attr", + "Attrs", + "Coord", + "Coords", + "Data", + "DataVars", + "Factory", + "Name", + "Tag", +] __version__ = "2.0.0" + + +# submodules +from . import typing + + +# aliases +from .typing import * diff --git a/xarray_dataclasses/typing.py b/xarray_dataclasses/typing.py new file mode 100644 index 0000000..afac858 --- /dev/null +++ b/xarray_dataclasses/typing.py @@ -0,0 +1,95 @@ +__all__ = [ + "Attr", + "Attrs", + "Coord", + "Coords", + "Data", + "DataVars", + "Factory", + "Name", + "Tag", +] + + +# standard library +from collections.abc import Collection as Collection_, Hashable +from enum import auto +from typing import Annotated, Callable, Protocol, TypeVar, Union + + +# dependencies +from dataspecs import TagBase +from xarray import DataArray, Dataset + + +# type hints +TAny = TypeVar("TAny") +TDims = TypeVar("TDims", covariant=True) +TDtype = TypeVar("TDtype", covariant=True) +THashable = TypeVar("THashable", bound=Hashable) +TXarray = TypeVar("TXarray", bound="Xarray") +Xarray = Union[DataArray, Dataset] + + +class Collection(Collection_[TDtype], Protocol[TDims, TDtype]): + """Same as Collection[T] but accepts additional type variable for dims.""" + + pass + + +# constants +class Tag(TagBase): + """Collection of xarray-related tags for annotating type hints.""" + + ATTR = auto() + """Tag for specifying an attribute of DataArray/set.""" + + COORD = auto() + """Tag for specifying a coordinate of DataArray/set.""" + + DATA = auto() + """Tag for specifying a data object of DataArray/set.""" + + DIMS = auto() + """Tag for specifying a dims object of DataArray/set.""" + + DTYPE = auto() + """Tag for specifying a dtype object of DataArray/set.""" + + FACTORY = auto() + """Tag for specifying a factory of DataArray/set.""" + + MULTIPLE = auto() + """Tag for specifying multiple items (attrs, coords, data vars).""" + + NAME = 
auto() + """Tag for specifying an item name (attr, coord, data). """ + + +# type aliases +Arrayable = Collection[Annotated[TDims, Tag.DIMS], Annotated[TDtype, Tag.DTYPE]] +"""Type alias for Collection[TDims, TDtype] annotated by tags.""" + +Attr = Annotated[TAny, Tag.ATTR] +"""Type alias for an attribute of DataArray/set.""" + +Attrs = Annotated[dict[str, TAny], Tag.ATTR, Tag.MULTIPLE] +"""Type alias for attributes of DataArray/set.""" + +Coord = Annotated[Arrayable[TDims, TDtype], Tag.COORD] +"""Type alias for a coordinate of DataArray/set.""" + +Coords = Annotated[dict[str, Arrayable[TDims, TDtype]], Tag.COORD, Tag.MULTIPLE] +"""Type alias for coordinates of DataArray/set.""" + +Data = Annotated[Arrayable[TDims, TDtype], Tag.DATA] +"""Type alias for a data object of DataArray/set.""" + +DataVars = Annotated[dict[str, Arrayable[TDims, TDtype]], Tag.DATA, Tag.MULTIPLE] +"""Type alias for data objects of DataArray/set.""" + +Factory = Annotated[Callable[..., TXarray], Tag.FACTORY] +"""Type alias for a factory of DataArray/set.""" + +Name = Annotated[THashable, Tag.NAME] +"""Type alias for an item name (attr, coord, data).""" From 30b82e2ab0c6ba98331b549fe94e29686721f3ae Mon Sep 17 00:00:00 2001 From: Akio Taniguchi <13254278+astropenguin@users.noreply.github.com> Date: Sat, 4 Jan 2025 12:37:51 +0000 Subject: [PATCH 3/4] #237 Add api module for dataclass converters --- xarray_dataclasses/__init__.py | 6 +++ xarray_dataclasses/api.py | 97 ++++++++++++++++++++++++++++++++++ xarray_dataclasses/typing.py | 34 +++++++++++- 3 files changed, 135 insertions(+), 2 deletions(-) create mode 100644 xarray_dataclasses/api.py diff --git a/xarray_dataclasses/__init__.py b/xarray_dataclasses/__init__.py index 56d4dec..c5bb51f 100644 --- a/xarray_dataclasses/__init__.py +++ b/xarray_dataclasses/__init__.py @@ -1,5 +1,6 @@ __all__ = [ # submodules + "api", "typing", # aliases "Attr", @@ -11,13 +12,18 @@ "Factory", "Name", "Tag", + "asdataarray", + "asdataset", + "asxarray", ] 
__version__ = "2.0.0" # submodules +from . import api from . import typing # aliases +from .api import * from .typing import * diff --git a/xarray_dataclasses/api.py b/xarray_dataclasses/api.py new file mode 100644 index 0000000..3aad157 --- /dev/null +++ b/xarray_dataclasses/api.py @@ -0,0 +1,97 @@ +__all__ = ["asdataarray", "asdataset", "asxarray"] + + +# standard library +from typing import Any, Callable, overload + + +# dependencies +from xarray import DataArray, Dataset +from .typing import DataClass, DataClassOf, PAny, TDataArray, TDataset, TXarray + + +@overload +def asdataarray( + obj: DataClassOf[PAny, TDataArray], + /, + *, + factory: None = None, +) -> TDataArray: ... + + +@overload +def asdataarray( + obj: DataClass[PAny], + /, + *, + factory: Callable[..., TDataArray], +) -> TDataArray: ... + + +@overload +def asdataarray( + obj: DataClass[PAny], + /, + *, + factory: None = None, +) -> DataArray: ... + + +def asdataarray(obj: Any, /, *, factory: Any = None) -> Any: + """Create a DataArray object from a dataclass object.""" + ... + + +@overload +def asdataset( + obj: DataClassOf[PAny, TDataset], + /, + *, + factory: None = None, +) -> TDataset: ... + + +@overload +def asdataset( + obj: DataClass[PAny], + /, + *, + factory: Callable[..., TDataset], +) -> TDataset: ... + + +@overload +def asdataset( + obj: DataClass[PAny], + /, + *, + factory: None = None, +) -> Dataset: ... + + +def asdataset(obj: Any, /, *, factory: Any = None) -> Any: + """Create a Dataset object from a dataclass object.""" + ... + + +@overload +def asxarray( + obj: DataClassOf[PAny, TXarray], + /, + *, + factory: None = None, +) -> TXarray: ... + + +@overload +def asxarray( + obj: DataClass[PAny], + /, + *, + factory: Callable[..., TXarray], +) -> TXarray: ... + + +def asxarray(obj: Any, /, *, factory: Any = None) -> Any: + """Create a DataArray/set object from a dataclass object.""" + ... 
diff --git a/xarray_dataclasses/typing.py b/xarray_dataclasses/typing.py index afac858..dbd857f 100644 --- a/xarray_dataclasses/typing.py +++ b/xarray_dataclasses/typing.py @@ -13,8 +13,18 @@ # standard library from collections.abc import Collection as Collection_, Hashable +from dataclasses import Field from enum import auto -from typing import Annotated, Callable, Protocol, TypeVar, Union +from typing import ( + Annotated, + Any, + Callable, + ClassVar, + ParamSpec, + Protocol, + TypeVar, + Union, +) # dependencies @@ -23,11 +33,14 @@ # type hints +PAny = ParamSpec("PAny") TAny = TypeVar("TAny") +TDataArray = TypeVar("TDataArray", bound=DataArray) +TDataset = TypeVar("TDataset", bound=Dataset) TDims = TypeVar("TDims", covariant=True) TDtype = TypeVar("TDtype", covariant=True) THashable = TypeVar("THashable", bound=Hashable) -TXarray = TypeVar("TXarray", bound="Xarray") +TXarray = TypeVar("TXarray", covariant=True, bound="Xarray") Xarray = Union[DataArray, Dataset] @@ -37,6 +50,23 @@ class Collection(Collection_[TDtype], Protocol[TDims, TDtype]): pass +class DataClass(Protocol[PAny]): + """Protocol for a dataclass object.""" + + __dataclass_fields__: ClassVar[dict[str, Field[Any]]] + + def __init__(self, *args: PAny.args, **kwargs: PAny.kwargs) -> None: ... + + +class DataClassOf(Protocol[PAny, TXarray]): + """Protocol for a dataclass object with an xarray factory.""" + + _xarray_factory: Callable[..., TXarray] + __dataclass_fields__: ClassVar[dict[str, Field[Any]]] + + def __init__(self, *args: PAny.args, **kwargs: PAny.kwargs) -> None: ... 
+ + # constants class Tag(TagBase): """Collection of xarray-related tags for annotating type hints.""" From 8458a357f9608b211dae04c52d414a22e86611e3 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi <13254278+astropenguin@users.noreply.github.com> Date: Sat, 4 Jan 2025 17:17:16 +0000 Subject: [PATCH 4/4] #237 Add helper dataclass converters --- xarray_dataclasses/api.py | 107 ++++++++++++++++++++++++++++++++--- xarray_dataclasses/typing.py | 7 ++- 2 files changed, 104 insertions(+), 10 deletions(-) diff --git a/xarray_dataclasses/api.py b/xarray_dataclasses/api.py index 3aad157..504907b 100644 --- a/xarray_dataclasses/api.py +++ b/xarray_dataclasses/api.py @@ -2,12 +2,31 @@ # standard library -from typing import Any, Callable, overload - +from dataclasses import replace +from typing import Any, ForwardRef, Literal, Optional, overload # dependencies -from xarray import DataArray, Dataset -from .typing import DataClass, DataClassOf, PAny, TDataArray, TDataset, TXarray +from dataspecs import ID, ROOT, Spec, Specs +from numpy import asarray, array +from typing_extensions import get_args, get_origin +from xarray import DataArray, Dataset, Variable +from .typing import ( + DataClass, + DataClassOf, + Factory, + HashDict, + PAny, + TAny, + TDataArray, + TDataset, + TXarray, + Tag, +) + + +# type hints +Attrs = HashDict[Any] +Vars = HashDict[Variable] @overload @@ -24,7 +43,7 @@ def asdataarray( obj: DataClass[PAny], /, *, - factory: Callable[..., TDataArray], + factory: Factory[TDataArray], ) -> TDataArray: ... @@ -56,7 +75,7 @@ def asdataset( obj: DataClass[PAny], /, *, - factory: Callable[..., TDataset], + factory: Factory[TDataset], ) -> TDataset: ... @@ -88,10 +107,84 @@ def asxarray( obj: DataClass[PAny], /, *, - factory: Callable[..., TXarray], + factory: Factory[TXarray], ) -> TXarray: ... def asxarray(obj: Any, /, *, factory: Any = None) -> Any: """Create a DataArray/set object from a dataclass object.""" ... 
+ + +def get_attrs(specs: Specs[Spec[Any]], /, *, at: ID = ROOT) -> Attrs: + """Create attributes from data specs.""" + attrs: Attrs = {} + + for spec in specs[at.children][Tag.ATTR]: + options = specs[spec.id.children] + factory = maybe(options[Tag.FACTORY].unique).data or identity + name = maybe(options[Tag.NAME].unique).data or spec.id.name + + if Tag.MULTIPLE not in spec.tags: + spec = replace(spec, data={name: spec.data}) + + for name, data in spec[HashDict[Any]].data.items(): + attrs[name] = factory(data) + + return attrs + + +def get_vars(specs: Specs[Spec[Any]], of: Tag, /, *, at: ID = ROOT) -> Vars: + """Create variables of given tag from data specs.""" + vars: Vars = {} + + for spec in specs[at.children][of]: + options = specs[spec.id.children] + attrs = get_attrs(specs, at=spec.id) + factory = maybe(options[Tag.FACTORY].unique).data or Variable + name = maybe(options[Tag.NAME].unique).data or spec.id.name + + if (type_ := maybe(options[Tag.DIMS].unique).type) is None: + raise RuntimeError("Could not find any data spec for dims.") + elif get_origin(type_) is tuple: + dims = tuple(str(unwrap(arg)) for arg in get_args(type_)) + else: + dims = (str(unwrap(type_)),) + + if (type_ := maybe(options[Tag.DTYPE].unique).type) is None: + raise RuntimeError("Could not find any data spec for dtype.") + elif type_ is type(None) or type_ is Any: + dtype = None + else: + dtype = unwrap(type_) + + if Tag.MULTIPLE not in spec.tags: + spec = replace(spec, data={name: spec.data}) + + for name, data in spec[HashDict[Any]].data.items(): + if not (data := asarray(data, dtype)).ndim: + data = array(data, ndmin=len(dims)) + + vars[name] = factory(attrs=attrs, data=data, dims=dims) + + return vars + + +def identity(obj: TAny, /) -> TAny: + """Identity function used for the default factory.""" + return obj + + +def maybe(obj: Optional[Spec[Any]], /) -> Spec[Any]: + """Return a dummy (``None``-filled) data spec if an object is not one.""" + return Spec(ROOT, (), None, None) if obj 
is None else obj + + +def unwrap(obj: Any, /) -> Any: + """Unwrap if an object is a literal or a forward reference.""" + if get_origin(obj) is Literal: + return args[0] if len(args := get_args(obj)) == 1 else obj + elif isinstance(obj, ForwardRef): + return obj.__forward_arg__ + else: + return obj diff --git a/xarray_dataclasses/typing.py b/xarray_dataclasses/typing.py index dbd857f..d9e122e 100644 --- a/xarray_dataclasses/typing.py +++ b/xarray_dataclasses/typing.py @@ -41,6 +41,7 @@ TDtype = TypeVar("TDtype", covariant=True) THashable = TypeVar("THashable", bound=Hashable) TXarray = TypeVar("TXarray", covariant=True, bound="Xarray") +HashDict = dict[Hashable, TAny] Xarray = Union[DataArray, Dataset] @@ -103,19 +104,19 @@ class Tag(TagBase): Attr = Annotated[TAny, Tag.ATTR] """Type alias for an attribute of DataArray/set.""" -Attrs = Annotated[dict[str, TAny], Tag.ATTR, Tag.MULTIPLE] +Attrs = Annotated[HashDict[TAny], Tag.ATTR, Tag.MULTIPLE] """Type alias for attributes of DataArray/set.""" Coord = Annotated[Arrayable[TDims, TDtype], Tag.COORD] """Type alias for a coordinate of DataArray/set.""" -Coords = Annotated[dict[str, Arrayable[TDims, TDtype]], Tag.COORD, Tag.MULTIPLE] +Coords = Annotated[HashDict[Arrayable[TDims, TDtype]], Tag.COORD, Tag.MULTIPLE] """Type alias for coordinates of DataArray/set.""" Data = Annotated[Arrayable[TDims, TDtype], Tag.DATA] """Type alias for a data object of DataArray/set.""" -DataVars = Annotated[dict[str, Arrayable[TDims, TDtype]], Tag.DATA, Tag.MULTIPLE] +DataVars = Annotated[HashDict[Arrayable[TDims, TDtype]], Tag.DATA, Tag.MULTIPLE] """Type alias for data objects of DataArray/set.""" Factory = Annotated[Callable[..., TXarray], Tag.FACTORY]