Skip to content

Commit 4d05aa8

Browse files
gcariaGiacomo Caria
andauthored
Fix bug indexing with boolean scalars (#10635)
* add check for bool type * clean up * clean up * add test * add test * add test * add test --------- Co-authored-by: Giacomo Caria <[email protected]>
1 parent b1ce76e commit 4d05aa8

File tree

5 files changed

+25
-21
lines changed

5 files changed

+25
-21
lines changed

xarray/core/dataset.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,7 @@
3535
from xarray.computation import ops
3636
from xarray.computation.arithmetic import DatasetArithmetic
3737
from xarray.core import dtypes as xrdtypes
38-
from xarray.core import (
39-
duck_array_ops,
40-
formatting,
41-
formatting_html,
42-
utils,
43-
)
38+
from xarray.core import duck_array_ops, formatting, formatting_html, utils
4439
from xarray.core._aggregations import DatasetAggregations
4540
from xarray.core.common import (
4641
DataWithCoords,
@@ -2636,7 +2631,7 @@ def _validate_indexers(
26362631

26372632
# all indexers should be int, slice, np.ndarrays, or Variable
26382633
for k, v in indexers.items():
2639-
if isinstance(v, int | slice | Variable):
2634+
if isinstance(v, int | slice | Variable) and not isinstance(v, bool):
26402635
yield k, v
26412636
elif isinstance(v, DataArray):
26422637
yield k, v.variable

xarray/core/indexing.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ def __init__(
496496

497497
new_key = []
498498
for k in key:
499-
if isinstance(k, integer_types):
499+
if isinstance(k, integer_types) and not isinstance(k, bool):
500500
k = int(k)
501501
elif isinstance(k, slice):
502502
k = as_integer_slice(k)
@@ -513,7 +513,7 @@ def __init__(
513513
k = duck_array_ops.astype(k, np.int64, copy=False)
514514
else:
515515
raise TypeError(
516-
f"unexpected indexer type for {type(self).__name__}: {k!r}"
516+
f"unexpected indexer type for {type(self).__name__}: {k!r}, {type(k)}"
517517
)
518518
new_key.append(k)
519519

@@ -1639,7 +1639,7 @@ def is_fancy_indexer(indexer: Any) -> bool:
16391639
"""Return False if indexer is a int, slice, a 1-dimensional list, or a 0 or
16401640
1-dimensional ndarray; in all other cases return True
16411641
"""
1642-
if isinstance(indexer, int | slice):
1642+
if isinstance(indexer, int | slice) and not isinstance(indexer, bool):
16431643
return False
16441644
if isinstance(indexer, np.ndarray):
16451645
return indexer.ndim > 1
@@ -1771,11 +1771,7 @@ def transpose(self, order):
17711771

17721772

17731773
def _apply_vectorized_indexer_dask_wrapper(indices, coord):
1774-
from xarray.core.indexing import (
1775-
VectorizedIndexer,
1776-
apply_indexer,
1777-
as_indexable,
1778-
)
1774+
from xarray.core.indexing import VectorizedIndexer, apply_indexer, as_indexable
17791775

17801776
return apply_indexer(
17811777
as_indexable(coord), VectorizedIndexer((indices.squeeze(axis=-1),))

xarray/core/variable.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,10 @@ def _broadcast_indexes(self, key):
647647
k.item() if isinstance(k, np.ndarray) and k.ndim == 0 else k for k in key
648648
)
649649

650-
if all(isinstance(k, BASIC_INDEXING_TYPES) for k in key):
650+
if all(
651+
(isinstance(k, BASIC_INDEXING_TYPES) and not isinstance(k, bool))
652+
for k in key
653+
):
651654
return self._broadcast_indexes_basic(key)
652655

653656
self._validate_indexers(key)

xarray/tests/test_dataarray.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,7 @@
3434
from xarray.core import dtypes
3535
from xarray.core.common import full_like
3636
from xarray.core.coordinates import Coordinates, CoordinateValidationError
37-
from xarray.core.indexes import (
38-
Index,
39-
PandasIndex,
40-
filter_indexes_from_coords,
41-
)
37+
from xarray.core.indexes import Index, PandasIndex, filter_indexes_from_coords
4238
from xarray.core.types import QueryEngineOptions, QueryParserOptions
4339
from xarray.core.utils import is_scalar
4440
from xarray.testing import _assert_internal_invariants
@@ -749,6 +745,16 @@ def test_getitem_empty_index(self) -> None:
749745
)
750746
assert_identical(da[[]], DataArray(np.zeros((0, 4)), dims=["x", "y"]))
751747

748+
def test_getitem_typeerror(self) -> None:
749+
with pytest.raises(TypeError, match=r"unexpected indexer type"):
750+
self.dv[True]
751+
with pytest.raises(TypeError, match=r"unexpected indexer type"):
752+
self.dv[np.array(True)]
753+
with pytest.raises(TypeError, match=r"invalid indexer array"):
754+
self.dv[3.0]
755+
with pytest.raises(TypeError, match=r"invalid indexer array"):
756+
self.dv[None]
757+
752758
def test_setitem(self) -> None:
753759
# basic indexing should work as numpy's indexing
754760
tuples: list[tuple[int | list[int] | slice, int | list[int] | slice]] = [

xarray/tests/test_indexing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,10 @@ def test_invalid_for_all(indexer_cls) -> None:
568568
indexer_cls((slice("foo"),))
569569
with pytest.raises(TypeError):
570570
indexer_cls((np.array(["foo"]),))
571+
with pytest.raises(TypeError):
572+
indexer_cls(True)
573+
with pytest.raises(TypeError):
574+
indexer_cls(np.array(True))
571575

572576

573577
def check_integer(indexer_cls):

0 commit comments

Comments
 (0)