@@ -439,7 +439,7 @@ def map( # type: ignore[override]
439
439
440
440
441
441
class DataTree (
442
- NamedNode [ "DataTree" ] ,
442
+ NamedNode ,
443
443
DataTreeAggregations ,
444
444
DataTreeOpsMixin ,
445
445
TreeAttrAccessMixin ,
@@ -559,9 +559,12 @@ def _node_coord_variables_with_index(self) -> Mapping[Hashable, Variable]:
559
559
560
560
@property
561
561
def _coord_variables (self ) -> ChainMap [Hashable , Variable ]:
562
+ # ChainMap is incorrected typed in typeshed (only the first argument
563
+ # needs to be mutable)
564
+ # https://github.com/python/typeshed/issues/8430
562
565
return ChainMap (
563
566
self ._node_coord_variables ,
564
- * (p ._node_coord_variables_with_index for p in self .parents ),
567
+ * (p ._node_coord_variables_with_index for p in self .parents ), # type: ignore[arg-type]
565
568
)
566
569
567
570
@property
@@ -1340,7 +1343,7 @@ def equals(self, other: DataTree) -> bool:
1340
1343
)
1341
1344
1342
1345
def _inherited_coords_set (self ) -> set [str ]:
1343
- return set (self .parent .coords if self .parent else [])
1346
+ return set (self .parent .coords if self .parent else []) # type: ignore[arg-type]
1344
1347
1345
1348
def identical (self , other : DataTree ) -> bool :
1346
1349
"""
@@ -1563,9 +1566,33 @@ def match(self, pattern: str) -> DataTree:
1563
1566
}
1564
1567
return DataTree .from_dict (matching_nodes , name = self .name )
1565
1568
1569
+ @overload
1566
1570
def map_over_datasets (
1567
1571
self ,
1568
- func : Callable ,
1572
+ func : Callable [..., Dataset | None ],
1573
+ * args : Any ,
1574
+ kwargs : Mapping [str , Any ] | None = None ,
1575
+ ) -> DataTree : ...
1576
+
1577
+ @overload
1578
+ def map_over_datasets (
1579
+ self ,
1580
+ func : Callable [..., tuple [Dataset | None , Dataset | None ]],
1581
+ * args : Any ,
1582
+ kwargs : Mapping [str , Any ] | None = None ,
1583
+ ) -> tuple [DataTree , DataTree ]: ...
1584
+
1585
+ @overload
1586
+ def map_over_datasets (
1587
+ self ,
1588
+ func : Callable [..., tuple [Dataset | None , ...]],
1589
+ * args : Any ,
1590
+ kwargs : Mapping [str , Any ] | None = None ,
1591
+ ) -> tuple [DataTree , ...]: ...
1592
+
1593
+ def map_over_datasets (
1594
+ self ,
1595
+ func : Callable [..., Dataset | None | tuple [Dataset | None , ...]],
1569
1596
* args : Any ,
1570
1597
kwargs : Mapping [str , Any ] | None = None ,
1571
1598
) -> DataTree | tuple [DataTree , ...]:
@@ -1600,8 +1627,7 @@ def map_over_datasets(
1600
1627
map_over_datasets
1601
1628
"""
1602
1629
# TODO this signature means that func has no way to know which node it is being called upon - change?
1603
- # TODO fix this typing error
1604
- return map_over_datasets (func , self , * args , kwargs = kwargs )
1630
+ return map_over_datasets (func , self , * args , kwargs = kwargs ) # type: ignore[arg-type]
1605
1631
1606
1632
@overload
1607
1633
def pipe (
@@ -1695,7 +1721,7 @@ def groups(self):
1695
1721
1696
1722
def _unary_op (self , f , * args , ** kwargs ) -> DataTree :
1697
1723
# TODO do we need to any additional work to avoid duplication etc.? (Similar to aggregations)
1698
- return self .map_over_datasets (functools .partial (f , ** kwargs ), * args ) # type: ignore[return-value]
1724
+ return self .map_over_datasets (functools .partial (f , ** kwargs ), * args )
1699
1725
1700
1726
def _binary_op (self , other , f , reflexive = False , join = None ) -> DataTree :
1701
1727
from xarray .core .groupby import GroupBy
@@ -1911,7 +1937,7 @@ def to_zarr(
1911
1937
)
1912
1938
1913
1939
def _get_all_dims (self ) -> set :
1914
- all_dims = set ()
1940
+ all_dims : set [ Any ] = set ()
1915
1941
for node in self .subtree :
1916
1942
all_dims .update (node ._node_dims )
1917
1943
return all_dims
0 commit comments