@@ -3,13 +3,17 @@
 from __future__ import annotations

 import operator
-from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Hashable, Mapping, Sequence

 import boost_histogram as bh
-import dask.array as da
 import numpy as np
+from dask.array.core import Array as DaskArray
+from dask.array.core import asarray
 from dask.bag.core import empty_safe_aggregate
-from dask.base import DaskMethodsMixin, is_dask_collection, tokenize
+from dask.base import DaskMethodsMixin, dont_optimize, is_dask_collection, tokenize
+from dask.blockwise import fuse_roots, optimize_blockwise
+from dask.context import globalmethod
+from dask.core import flatten
 from dask.dataframe.core import partitionwise_graph as partitionwise
 from dask.delayed import Delayed
 from dask.highlevelgraph import HighLevelGraph
@@ -195,6 +199,26 @@ def _blocked_dak(data: Any, *, histref: bh.Histogram | None = None) -> bh.Histog
     return clone(histref).fill(data)


+def optimize(
+    dsk: Mapping,
+    keys: Hashable | list[Hashable] | set[Hashable],
+    **kwargs: Any,
+) -> Mapping:
+    if not isinstance(keys, (list, set)):
+        keys = [keys]
+    keys = list(flatten(keys))
+
+    if not isinstance(dsk, HighLevelGraph):
+        dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
+    else:
+        # Perform Blockwise optimizations for HLG input
+        dsk = optimize_blockwise(dsk, keys=keys)
+        dsk = fuse_roots(dsk, keys=keys)  # type: ignore
+        dsk = dsk.cull(set(keys))  # type: ignore
+
+    return dsk
+
+
 class AggHistogram(DaskMethodsMixin):
     """Aggregated Histogram collection.

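As a rough illustration (not taken from the patch), the new `optimize` function follows dask's collection-optimization contract: the hook receives the task graph plus the keys being computed and returns a graph trimmed to what those keys need. The sketch below calls it directly on a toy graph; the layer name `"h"` and key `("h", 0)` are made up for the example.

from dask.highlevelgraph import HighLevelGraph

# Toy graph with a single task producing the key ("h", 0).
layer = {("h", 0): (sum, [1, 2, 3])}
hlg = HighLevelGraph.from_collections("h", layer, dependencies=())

# Blockwise and root fusion are no-ops for a plain materialized layer,
# so this effectively culls the graph down to the requested keys.
culled = optimize(hlg, keys=[("h", 0)])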
@@ -249,6 +273,12 @@ def __dask_postcompute__(self) -> Any:
     def __dask_postpersist__(self) -> Any:
         return self._rebuild, ()

+    __dask_optimize__ = globalmethod(
+        optimize, key="histogram_optimize", falsey=dont_optimize
+    )
+
+    __dask_scheduler__ = staticmethod(tget)
+
     def _rebuild(
         self,
         dsk: HighLevelGraph,
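Because `__dask_optimize__` is registered through `globalmethod` under the `"histogram_optimize"` config key (with `dont_optimize` as the falsey fallback), the pass can be swapped or disabled via dask's configuration, mirroring the pattern dask.array uses for its `"array_optimize"` key. A hedged sketch, assuming an existing `AggHistogram` instance named `agghist` and standard `globalmethod` config-lookup semantics:

import dask

# A falsey config value makes the globalmethod fall back to dont_optimize,
# so the graph reaches the scheduler unoptimized.
with dask.config.set({"histogram_optimize": False}):
    result = agghist.compute()

# Alternatively, substitute a custom optimizer with the same (dsk, keys) signature.
with dask.config.set({"histogram_optimize": lambda dsk, keys, **kw: dsk}):
    result = agghist.compute()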
@@ -303,7 +333,6 @@ def __str__(self) -> str:
         )

     __repr__ = __str__
-    __dask_scheduler__ = staticmethod(tget)

     @property
     def _args(self) -> tuple[HighLevelGraph, str, bh.Histogram]:
@@ -317,7 +346,7 @@ def __setstate__(self, state: tuple[HighLevelGraph, str, bh.Histogram]) -> None:

     def to_dask_array(
         self, flow: bool = False, dd: bool = False
-    ) -> tuple[da.Array, ...] | tuple[da.Array, list[da.Array]]:
+    ) -> tuple[DaskArray, ...] | tuple[DaskArray, list[DaskArray]]:
         """Convert histogram object to dask.array form.

         Parameters
@@ -462,14 +491,19 @@ def _rebuild(self, dsk: Any, *, rename: Any = None) -> Any:
             name = rename.get(name, name)
         return type(self)(dsk, name, self.npartitions, self.histref)

+    __dask_optimize__ = globalmethod(
+        optimize, key="histogram_optimize", falsey=dont_optimize
+    )
+
+    __dask_scheduler__ = staticmethod(tget)
+
     def __str__(self) -> str:
         return "dask_histogram.PartitionedHistogram,<%s, npartitions=%d>" % (
             key_split(self.name),
             self.npartitions,
         )

     __repr__ = __str__
-    __dask_scheduler__ = staticmethod(tget)

     @property
     def _args(self) -> tuple[HighLevelGraph, str, int, bh.Histogram]:
@@ -649,7 +683,7 @@ def to_dask_array(
     agghist: AggHistogram,
     flow: bool = False,
     dd: bool = False,
-) -> tuple[da.Array, ...] | tuple[da.Array, list[da.Array]]:
+) -> tuple[DaskArray, ...] | tuple[DaskArray, list[DaskArray]]:
    """Convert `agghist` to a `dask.array` return style.

    Parameters
@@ -687,15 +721,15 @@ def to_dask_array(
         bh.storage.AtomicInt64,
     )
     dt = int if int_storage else float
-    c = da.Array(graph, name=name, shape=shape, chunks=shape, dtype=dt)
+    c = DaskArray(graph, name=name, shape=shape, chunks=shape, dtype=dt)
     axes = agghist.histref.axes

     if flow:
         edges = [
-            da.asarray(np.concatenate([[-np.inf], ax.edges, [np.inf]])) for ax in axes
+            asarray(np.concatenate([[-np.inf], ax.edges, [np.inf]])) for ax in axes
         ]
     else:
-        edges = [da.asarray(ax.edges) for ax in axes]
+        edges = [asarray(ax.edges) for ax in axes]
     if dd:
         return c, edges
     return (c, *tuple(edges))
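For context (not part of the patch), the two return styles shown above line up with NumPy's conventions: `dd=True` gives a counts array plus a list of per-axis edge arrays, like `np.histogramdd`, while the default flattens the edges into the tuple, like `np.histogram`. A sketch assuming an existing `AggHistogram` named `agghist`:

# dd=True: (counts, [edges_axis0, edges_axis1, ...]), all dask arrays
counts, edges = agghist.to_dask_array(dd=True)

# default: counts followed by one edge array per axis
counts, *edges = agghist.to_dask_array(flow=False)

counts.compute()  # materialize the aggregated bin counts as a NumPy array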