13
13
14
14
# standard library
15
15
from collections .abc import Collection as Collection_ , Hashable
16
+ from dataclasses import Field
16
17
from enum import auto
17
- from typing import Annotated , Callable , Protocol , TypeVar , Union
18
+ from typing import (
19
+ Annotated ,
20
+ Any ,
21
+ Callable ,
22
+ ClassVar ,
23
+ ParamSpec ,
24
+ Protocol ,
25
+ TypeVar ,
26
+ Union ,
27
+ )
18
28
19
29
20
30
# dependencies
23
33
24
34
25
35
# type hints
36
+ PAny = ParamSpec ("PAny" )
26
37
TAny = TypeVar ("TAny" )
38
+ TDataArray = TypeVar ("TDataArray" , bound = DataArray )
39
+ TDataset = TypeVar ("TDataset" , bound = Dataset )
27
40
TDims = TypeVar ("TDims" , covariant = True )
28
41
TDtype = TypeVar ("TDtype" , covariant = True )
29
42
THashable = TypeVar ("THashable" , bound = Hashable )
30
- TXarray = TypeVar ("TXarray" , bound = "Xarray" )
43
+ TXarray = TypeVar ("TXarray" , covariant = True , bound = "Xarray" )
31
44
Xarray = Union [DataArray , Dataset ]
32
45
33
46
@@ -37,6 +50,23 @@ class Collection(Collection_[TDtype], Protocol[TDims, TDtype]):
37
50
pass
38
51
39
52
53
+ class DataClass (Protocol [PAny ]):
54
+ """Protocol for a dataclass object."""
55
+
56
+ __dataclass_fields__ : ClassVar [dict [str , Field [Any ]]]
57
+
58
+ def __init__ (self , * args : PAny .args , ** kwargs : PAny .kwargs ) -> None : ...
59
+
60
+
61
+ class DataClassOf (Protocol [PAny , TXarray ]):
62
+ """Protocol for a dataclass object with an xarray factory."""
63
+
64
+ _xarray_factory : Callable [..., TXarray ]
65
+ __dataclass_fields__ : ClassVar [dict [str , Field [Any ]]]
66
+
67
+ def __init__ (self , * args : PAny .args , ** kwargs : PAny .kwargs ) -> None : ...
68
+
69
+
40
70
# constants
41
71
class Tag (TagBase ):
42
72
"""Collection of xarray-related tags for annotating type hints."""
0 commit comments