@@ -401,27 +401,28 @@ def optimize(
401
401
keys : Hashable | list [Hashable ] | set [Hashable ],
402
402
** kwargs : Any ,
403
403
) -> Mapping :
404
- if not isinstance (keys , (list , set )):
405
- keys = [keys ]
406
- keys = list (flatten (keys ))
404
+ keys = tuple (flatten (keys ))
407
405
408
406
if not isinstance (dsk , HighLevelGraph ):
409
- dsk = HighLevelGraph .from_collections (id (dsk ), dsk , dependencies = ())
407
+ dsk = HighLevelGraph .from_collections (str ( id (dsk ) ), dsk , dependencies = ())
410
408
409
+ dsk = optimize_blockwise (dsk , keys = keys )
410
+ dsk = fuse_roots (dsk , keys = keys ) # type: ignore
411
+ dsk = dsk .cull (set (keys )) # type: ignore
412
+ return dsk
413
+
414
+
415
+ def _get_optimization_function ():
411
416
# Here we try to run optimizations from dask-awkward (if we detect
412
417
# that dask-awkward has been imported). There is no cost to
413
418
# running this optimization even in cases where it's unnecessary
414
- # because if no AwkwardInputLayers from daks -awkward are not
419
+ # because if no AwkwardInputLayers from dask-awkward are
415
420
# detected then the original graph is returned unchanged.
416
421
if dask .config .get ("awkward" , default = False ):
417
- from dask_awkward .lib .optimize import optimize
422
+ from dask_awkward .lib .optimize import all_optimizations
418
423
419
- dsk = optimize (dsk , keys = keys ) # type: ignore[arg-type]
420
-
421
- dsk = optimize_blockwise (dsk , keys = keys )
422
- dsk = fuse_roots (dsk , keys = keys ) # type: ignore
423
- dsk = dsk .cull (set (keys )) # type: ignore
424
- return dsk
424
+ return all_optimizations
425
+ return optimize
425
426
426
427
427
428
class AggHistogram (DaskMethodsMixin ):
@@ -479,7 +480,7 @@ def __dask_postpersist__(self) -> Any:
479
480
return self ._rebuild , ()
480
481
481
482
__dask_optimize__ = globalmethod (
482
- optimize , key = "histogram_optimize" , falsey = dont_optimize
483
+ _get_optimization_function () , key = "histogram_optimize" , falsey = dont_optimize
483
484
)
484
485
485
486
__dask_scheduler__ = staticmethod (tget )
@@ -706,7 +707,7 @@ def _rebuild(self, dsk: Any, *, rename: Any = None) -> Any:
706
707
return type (self )(dsk , name , self .npartitions , self .histref )
707
708
708
709
__dask_optimize__ = globalmethod (
709
- optimize , key = "histogram_optimize" , falsey = dont_optimize
710
+ _get_optimization_function () , key = "histogram_optimize" , falsey = dont_optimize
710
711
)
711
712
712
713
__dask_scheduler__ = staticmethod (tget )
0 commit comments