Skip to content

Commit feff206

Browse files
committed
explicit __dask_optimize__; typing
1 parent 3affcae commit feff206

File tree

1 file changed

+44
-10
lines changed

1 file changed

+44
-10
lines changed

src/dask_histogram/core.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
from __future__ import annotations
44

55
import operator
6-
from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence
6+
from typing import TYPE_CHECKING, Any, Callable, Hashable, Mapping, Sequence
77

88
import boost_histogram as bh
9-
import dask.array as da
109
import numpy as np
10+
from dask.array.core import Array as DaskArray
11+
from dask.array.core import asarray
1112
from dask.bag.core import empty_safe_aggregate
12-
from dask.base import DaskMethodsMixin, is_dask_collection, tokenize
13+
from dask.base import DaskMethodsMixin, dont_optimize, is_dask_collection, tokenize
14+
from dask.blockwise import fuse_roots, optimize_blockwise
15+
from dask.context import globalmethod
16+
from dask.core import flatten
1317
from dask.dataframe.core import partitionwise_graph as partitionwise
1418
from dask.delayed import Delayed
1519
from dask.highlevelgraph import HighLevelGraph
@@ -195,6 +199,26 @@ def _blocked_dak(data: Any, *, histref: bh.Histogram | None = None) -> bh.Histog
195199
return clone(histref).fill(data)
196200

197201

202+
def optimize(
203+
dsk: Mapping,
204+
keys: Hashable | list[Hashable] | set[Hashable],
205+
**kwargs: Any,
206+
) -> Mapping:
# Graph-optimization hook for the histogram collections: elsewhere in this
# diff it is bound as `__dask_optimize__` through
# `globalmethod(optimize, key="histogram_optimize", falsey=dont_optimize)`.
#
# dsk:    the task graph, either a HighLevelGraph or some other Mapping.
# keys:   a single key, or a list/set of keys, that the result must retain.
# kwargs: accepted but never read in the body shown here.
# Returns the (possibly rewritten) graph as a HighLevelGraph.
#
# A lone key is first wrapped in a list; the keys are then passed through
# `flatten` and materialized as a list.
207+
if not isinstance(keys, (list, set)):
208+
keys = [keys]
209+
keys = list(flatten(keys))
210+
211+
# Non-HighLevelGraph input is wrapped into a HighLevelGraph with no
# dependencies; `id(dsk)` is passed as the first argument (presumably the
# layer name -- confirm against dask's `from_collections` signature).
if not isinstance(dsk, HighLevelGraph):
212+
dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
213+
else:
214+
# Perform Blockwise optimizations for HLG input
215+
dsk = optimize_blockwise(dsk, keys=keys)
216+
dsk = fuse_roots(dsk, keys=keys) # type: ignore
# NOTE(review): indentation is lost in this rendering, so it is not visible
# whether the cull below runs only in the `else:` branch or for both
# branches -- confirm against the actual file before relying on either.
217+
dsk = dsk.cull(set(keys)) # type: ignore
218+
219+
return dsk
220+
221+
198222
class AggHistogram(DaskMethodsMixin):
199223
"""Aggregated Histogram collection.
200224
@@ -249,6 +273,12 @@ def __dask_postcompute__(self) -> Any:
249273
def __dask_postpersist__(self) -> Any:
# Returns the 2-tuple `(self._rebuild, ())`: the bound `_rebuild` method and
# an empty tuple (presumably the extra positional arguments dask passes to
# the rebuild callable -- confirm against dask's collection protocol).
250274
return self._rebuild, ()
251275

276+
__dask_optimize__ = globalmethod(
277+
optimize, key="histogram_optimize", falsey=dont_optimize
278+
)
279+
280+
__dask_scheduler__ = staticmethod(tget)
281+
252282
def _rebuild(
253283
self,
254284
dsk: HighLevelGraph,
@@ -303,7 +333,6 @@ def __str__(self) -> str:
303333
)
304334

305335
__repr__ = __str__
306-
__dask_scheduler__ = staticmethod(tget)
307336

308337
@property
309338
def _args(self) -> tuple[HighLevelGraph, str, bh.Histogram]:
@@ -317,7 +346,7 @@ def __setstate__(self, state: tuple[HighLevelGraph, str, bh.Histogram]) -> None:
317346

318347
def to_dask_array(
319348
self, flow: bool = False, dd: bool = False
320-
) -> tuple[da.Array, ...] | tuple[da.Array, list[da.Array]]:
349+
) -> tuple[DaskArray, ...] | tuple[DaskArray, list[DaskArray]]:
321350
"""Convert histogram object to dask.array form.
322351
323352
Parameters
@@ -462,14 +491,19 @@ def _rebuild(self, dsk: Any, *, rename: Any = None) -> Any:
462491
name = rename.get(name, name)
463492
return type(self)(dsk, name, self.npartitions, self.histref)
464493

494+
__dask_optimize__ = globalmethod(
495+
optimize, key="histogram_optimize", falsey=dont_optimize
496+
)
497+
498+
__dask_scheduler__ = staticmethod(tget)
499+
465500
def __str__(self) -> str:
# Renders the collection as
# "dask_histogram.PartitionedHistogram,<NAME, npartitions=N>", where NAME is
# `key_split(self.name)` and N is `self.npartitions`.
466501
return "dask_histogram.PartitionedHistogram,<%s, npartitions=%d>" % (
467502
key_split(self.name),
468503
self.npartitions,
469504
)
470505

471506
__repr__ = __str__
472-
__dask_scheduler__ = staticmethod(tget)
473507

474508
@property
475509
def _args(self) -> tuple[HighLevelGraph, str, int, bh.Histogram]:
@@ -649,7 +683,7 @@ def to_dask_array(
649683
agghist: AggHistogram,
650684
flow: bool = False,
651685
dd: bool = False,
652-
) -> tuple[da.Array, ...] | tuple[da.Array, list[da.Array]]:
686+
) -> tuple[DaskArray, ...] | tuple[DaskArray, list[DaskArray]]:
653687
"""Convert `agghist` to a `dask.array` return style.
654688
655689
Parameters
@@ -687,15 +721,15 @@ def to_dask_array(
687721
bh.storage.AtomicInt64,
688722
)
689723
dt = int if int_storage else float
690-
c = da.Array(graph, name=name, shape=shape, chunks=shape, dtype=dt)
724+
c = DaskArray(graph, name=name, shape=shape, chunks=shape, dtype=dt)
691725
axes = agghist.histref.axes
692726

693727
if flow:
694728
edges = [
695-
da.asarray(np.concatenate([[-np.inf], ax.edges, [np.inf]])) for ax in axes
729+
asarray(np.concatenate([[-np.inf], ax.edges, [np.inf]])) for ax in axes
696730
]
697731
else:
698-
edges = [da.asarray(ax.edges) for ax in axes]
732+
edges = [asarray(ax.edges) for ax in axes]
699733
if dd:
700734
return c, edges
701735
return (c, *tuple(edges))

0 commit comments

Comments
 (0)