Skip to content

Commit 58fe392

Browse files
authored
Merge pull request #123 from lgray/improve_optimization_choice
fix: properly use dask-awkward optimizations in all scenarios
2 parents 5289f10 + 49516f1 commit 58fe392

File tree

3 files changed

+21
-16
lines changed

3 files changed

+21
-16
lines changed

src/dask_histogram/boost.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from tlz import first
2121

2222
from dask_histogram.bins import normalize_bins_range
23-
from dask_histogram.core import AggHistogram, factory, optimize
23+
from dask_histogram.core import AggHistogram, _get_optimization_function, factory
2424

2525
if TYPE_CHECKING:
2626
from dask_histogram.typing import (
@@ -201,7 +201,7 @@ def __dask_postpersist__(self) -> Any:
201201
return self._rebuild, ()
202202

203203
__dask_optimize__ = globalmethod(
204-
optimize, key="histogram_optimize", falsey=dont_optimize
204+
_get_optimization_function(), key="histogram_optimize", falsey=dont_optimize
205205
)
206206

207207
__dask_scheduler__ = staticmethod(tget)

src/dask_histogram/core.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -401,27 +401,28 @@ def optimize(
401401
keys: Hashable | list[Hashable] | set[Hashable],
402402
**kwargs: Any,
403403
) -> Mapping:
404-
if not isinstance(keys, (list, set)):
405-
keys = [keys]
406-
keys = list(flatten(keys))
404+
keys = tuple(flatten(keys))
407405

408406
if not isinstance(dsk, HighLevelGraph):
409-
dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
407+
dsk = HighLevelGraph.from_collections(str(id(dsk)), dsk, dependencies=())
410408

409+
dsk = optimize_blockwise(dsk, keys=keys)
410+
dsk = fuse_roots(dsk, keys=keys) # type: ignore
411+
dsk = dsk.cull(set(keys)) # type: ignore
412+
return dsk
413+
414+
415+
def _get_optimization_function():
411416
# Here we try to run optimizations from dask-awkward (if we detect
412417
# that dask-awkward has been imported). There is no cost to
413418
# running this optimization even in cases where it's unnecessary
414-
# because if no AwkwardInputLayers from daks-awkward are not
419+
# because if no AwkwardInputLayers from dask-awkward are
415420
# detected then the original graph is returned unchanged.
416421
if dask.config.get("awkward", default=False):
417-
from dask_awkward.lib.optimize import optimize
422+
from dask_awkward.lib.optimize import all_optimizations
418423

419-
dsk = optimize(dsk, keys=keys) # type: ignore[arg-type]
420-
421-
dsk = optimize_blockwise(dsk, keys=keys)
422-
dsk = fuse_roots(dsk, keys=keys) # type: ignore
423-
dsk = dsk.cull(set(keys)) # type: ignore
424-
return dsk
424+
return all_optimizations
425+
return optimize
425426

426427

427428
class AggHistogram(DaskMethodsMixin):
@@ -479,7 +480,7 @@ def __dask_postpersist__(self) -> Any:
479480
return self._rebuild, ()
480481

481482
__dask_optimize__ = globalmethod(
482-
optimize, key="histogram_optimize", falsey=dont_optimize
483+
_get_optimization_function(), key="histogram_optimize", falsey=dont_optimize
483484
)
484485

485486
__dask_scheduler__ = staticmethod(tget)
@@ -706,7 +707,7 @@ def _rebuild(self, dsk: Any, *, rename: Any = None) -> Any:
706707
return type(self)(dsk, name, self.npartitions, self.histref)
707708

708709
__dask_optimize__ = globalmethod(
709-
optimize, key="histogram_optimize", falsey=dont_optimize
710+
_get_optimization_function(), key="histogram_optimize", falsey=dont_optimize
710711
)
711712

712713
__dask_scheduler__ = staticmethod(tget)

tests/test_boost.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ def test_obj_5D_strcat_intcat_rectangular_dak(use_weights):
176176
dhb.axis.Regular(9, -3.2, 3.2),
177177
storage=storage,
178178
)
179+
180+
# check that we are using the correct optimizer
181+
assert h.__dask_optimize__ == dak.lib.optimize.all_optimizations
182+
179183
for i in range(25):
180184
h.fill(f"testcat{i+1}", i + 1, x, y, z, weight=weights)
181185
h = h.compute()

0 commit comments

Comments
 (0)