Skip to content

Commit b1ce76e

Browse files
authored
Improve DataTree typing (#10644)
* Improve DataTree typing * a few more fixes
1 parent 6ca51d3 commit b1ce76e

File tree

8 files changed

+109
-91
lines changed

8 files changed

+109
-91
lines changed

xarray/backends/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def _protect_dataset_variables_inplace(dataset: Dataset, cache: bool) -> None:
304304

305305
def _protect_datatree_variables_inplace(tree: DataTree, cache: bool) -> None:
306306
for node in tree.subtree:
307-
_protect_dataset_variables_inplace(node, cache)
307+
_protect_dataset_variables_inplace(node.dataset, cache)
308308

309309

310310
def _finalize_store(write, store):

xarray/core/datatree.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ def map( # type: ignore[override]
439439

440440

441441
class DataTree(
442-
NamedNode["DataTree"],
442+
NamedNode,
443443
DataTreeAggregations,
444444
DataTreeOpsMixin,
445445
TreeAttrAccessMixin,
@@ -559,9 +559,12 @@ def _node_coord_variables_with_index(self) -> Mapping[Hashable, Variable]:
559559

560560
@property
561561
def _coord_variables(self) -> ChainMap[Hashable, Variable]:
562+
# ChainMap is incorrectly typed in typeshed (only the first argument
563+
# needs to be mutable)
564+
# https://github.com/python/typeshed/issues/8430
562565
return ChainMap(
563566
self._node_coord_variables,
564-
*(p._node_coord_variables_with_index for p in self.parents),
567+
*(p._node_coord_variables_with_index for p in self.parents), # type: ignore[arg-type]
565568
)
566569

567570
@property
@@ -1340,7 +1343,7 @@ def equals(self, other: DataTree) -> bool:
13401343
)
13411344

13421345
def _inherited_coords_set(self) -> set[str]:
1343-
return set(self.parent.coords if self.parent else [])
1346+
return set(self.parent.coords if self.parent else []) # type: ignore[arg-type]
13441347

13451348
def identical(self, other: DataTree) -> bool:
13461349
"""
@@ -1563,9 +1566,33 @@ def match(self, pattern: str) -> DataTree:
15631566
}
15641567
return DataTree.from_dict(matching_nodes, name=self.name)
15651568

1569+
@overload
15661570
def map_over_datasets(
15671571
self,
1568-
func: Callable,
1572+
func: Callable[..., Dataset | None],
1573+
*args: Any,
1574+
kwargs: Mapping[str, Any] | None = None,
1575+
) -> DataTree: ...
1576+
1577+
@overload
1578+
def map_over_datasets(
1579+
self,
1580+
func: Callable[..., tuple[Dataset | None, Dataset | None]],
1581+
*args: Any,
1582+
kwargs: Mapping[str, Any] | None = None,
1583+
) -> tuple[DataTree, DataTree]: ...
1584+
1585+
@overload
1586+
def map_over_datasets(
1587+
self,
1588+
func: Callable[..., tuple[Dataset | None, ...]],
1589+
*args: Any,
1590+
kwargs: Mapping[str, Any] | None = None,
1591+
) -> tuple[DataTree, ...]: ...
1592+
1593+
def map_over_datasets(
1594+
self,
1595+
func: Callable[..., Dataset | None | tuple[Dataset | None, ...]],
15691596
*args: Any,
15701597
kwargs: Mapping[str, Any] | None = None,
15711598
) -> DataTree | tuple[DataTree, ...]:
@@ -1600,8 +1627,7 @@ def map_over_datasets(
16001627
map_over_datasets
16011628
"""
16021629
# TODO this signature means that func has no way to know which node it is being called upon - change?
1603-
# TODO fix this typing error
1604-
return map_over_datasets(func, self, *args, kwargs=kwargs)
1630+
return map_over_datasets(func, self, *args, kwargs=kwargs) # type: ignore[arg-type]
16051631

16061632
@overload
16071633
def pipe(
@@ -1695,7 +1721,7 @@ def groups(self):
16951721

16961722
def _unary_op(self, f, *args, **kwargs) -> DataTree:
16971723
# TODO do we need to do any additional work to avoid duplication etc.? (Similar to aggregations)
1698-
return self.map_over_datasets(functools.partial(f, **kwargs), *args) # type: ignore[return-value]
1724+
return self.map_over_datasets(functools.partial(f, **kwargs), *args)
16991725

17001726
def _binary_op(self, other, f, reflexive=False, join=None) -> DataTree:
17011727
from xarray.core.groupby import GroupBy
@@ -1911,7 +1937,7 @@ def to_zarr(
19111937
)
19121938

19131939
def _get_all_dims(self) -> set:
1914-
all_dims = set()
1940+
all_dims: set[Any] = set()
19151941
for node in self.subtree:
19161942
all_dims.update(node._node_dims)
19171943
return all_dims

xarray/core/datatree_io.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,13 @@ def _datatree_to_netcdf(
7474
at_root = node is dt
7575
ds = node.to_dataset(inherit=write_inherited_coords or at_root)
7676
group_path = None if at_root else "/" + node.relative_to(dt)
77-
ds.to_netcdf(
77+
ds.to_netcdf( # type: ignore[misc] # Not all union combinations were tried because there are too many unions
7878
target,
7979
group=group_path,
8080
mode=mode,
8181
encoding=encoding.get(node.path),
8282
unlimited_dims=unlimited_dims.get(node.path),
83-
engine=engine,
83+
engine=engine, # type: ignore[arg-type]
8484
format=format,
8585
compute=compute,
8686
**kwargs,
@@ -134,7 +134,7 @@ def _datatree_to_zarr(
134134
at_root = node is dt
135135
ds = node.to_dataset(inherit=write_inherited_coords or at_root)
136136
group_path = None if at_root else "/" + node.relative_to(dt)
137-
ds.to_zarr(
137+
ds.to_zarr( # type: ignore[call-overload]
138138
store,
139139
group=group_path,
140140
mode=mode,

xarray/core/datatree_mapping.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,14 @@
1313

1414
@overload
1515
def map_over_datasets(
16-
func: Callable[
17-
...,
18-
Dataset | None,
19-
],
16+
func: Callable[..., Dataset | None],
2017
*args: Any,
2118
kwargs: Mapping[str, Any] | None = None,
2219
) -> DataTree: ...
2320

2421

22+
# add an explicit overload for the most common case of two return values
23+
# (python typing does not have a way to match tuple lengths in general)
2524
@overload
2625
def map_over_datasets(
2726
func: Callable[..., tuple[Dataset | None, Dataset | None]],
@@ -30,8 +29,6 @@ def map_over_datasets(
3029
) -> tuple[DataTree, DataTree]: ...
3130

3231

33-
# add an expect overload for the most common case of two return values
34-
# (python typing does not have a way to match tuple lengths in general)
3532
@overload
3633
def map_over_datasets(
3734
func: Callable[..., tuple[Dataset | None, ...]],
@@ -41,7 +38,7 @@ def map_over_datasets(
4138

4239

4340
def map_over_datasets(
44-
func: Callable[..., Dataset | tuple[Dataset | None, ...] | None],
41+
func: Callable[..., Dataset | None | tuple[Dataset | None, ...]],
4542
*args: Any,
4643
kwargs: Mapping[str, Any] | None = None,
4744
) -> DataTree | tuple[DataTree, ...]:

0 commit comments

Comments
 (0)