Skip to content

Commit 5d94859

Browse files
authored
Include dask-awkward's optimize if awkward key exists in dask.config. (#49)
1 parent 9f5ef9d commit 5d94859

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

src/dask_histogram/core.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import TYPE_CHECKING, Any, Callable, Hashable, Mapping, Sequence
77

88
import boost_histogram as bh
9+
import dask.config
910
import numpy as np
1011
from dask.base import DaskMethodsMixin, dont_optimize, is_dask_collection, tokenize
1112
from dask.blockwise import fuse_roots, optimize_blockwise
@@ -226,6 +227,16 @@ def optimize(
226227
if not isinstance(dsk, HighLevelGraph):
227228
dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
228229

230+
# Here we try to run optimizations from dask-awkward (if we detect
231+
# that dask-awkward has been imported). There is no cost to
232+
# running this optimization even in cases where it's unncessary
233+
# because if no AwkwardInputLayers from daks-awkward are not
234+
# detected then the original graph is returned unchanged.
235+
if dask.config.get("awkward", default=False):
236+
from dask_awkward.lib.optimize import optimize
237+
238+
dsk = optimize(dsk, keys=keys)
239+
229240
dsk = optimize_blockwise(dsk, keys=keys)
230241
dsk = fuse_roots(dsk, keys=keys) # type: ignore
231242
dsk = dsk.cull(set(keys)) # type: ignore

0 commit comments

Comments
 (0)