@@ -199,6 +199,10 @@ def _blocked_dak(data: Any, *, histref: bh.Histogram | None = None) -> bh.Histog
199
199
return clone (histref ).fill (data )
200
200
201
201
202
def _blocked_dak_ma(*data: Any, histref: bh.Histogram | None = None) -> bh.Histogram:
    """Fill a clone of the reference histogram with multiple awkward-array blocks.

    One positional argument is passed per histogram dimension; the clone keeps
    ``histref``'s axes and storage while starting from empty counts.
    """
    hist = clone(histref)
    return hist.fill(*data)
202
206
def optimize (
203
207
dsk : Mapping ,
204
208
keys : Hashable | list [Hashable ] | set [Hashable ],
@@ -210,12 +214,10 @@ def optimize(
210
214
211
215
if not isinstance (dsk , HighLevelGraph ):
212
216
dsk = HighLevelGraph .from_collections (id (dsk ), dsk , dependencies = ())
213
- else :
214
- # Perform Blockwise optimizations for HLG input
215
- dsk = optimize_blockwise (dsk , keys = keys )
216
- dsk = fuse_roots (dsk , keys = keys ) # type: ignore
217
- dsk = dsk .cull (set (keys )) # type: ignore
218
217
218
+ dsk = optimize_blockwise (dsk , keys = keys )
219
+ dsk = fuse_roots (dsk , keys = keys ) # type: ignore
220
+ dsk = dsk .cull (set (keys )) # type: ignore
219
221
return dsk
220
222
221
223
@@ -334,15 +336,8 @@ def __str__(self) -> str:
334
336
335
337
__repr__ = __str__
336
338
337
- @property
338
- def _args (self ) -> tuple [HighLevelGraph , str , bh .Histogram ]:
339
- return (self .dask , self .name , self .histref )
340
-
341
- def __getstate__ (self ) -> tuple [HighLevelGraph , str , bh .Histogram ]:
342
- return self ._args
343
-
344
- def __setstate__ (self , state : tuple [HighLevelGraph , str , bh .Histogram ]) -> None :
345
- self ._dask , self ._name , self ._histref = state
339
+ def __reduce__ (self ):
340
+ return (AggHistogram , (self ._dask , self ._name , self ._histref ))
346
341
347
342
def to_dask_array (
348
343
self , flow : bool = False , dd : bool = False
@@ -375,9 +370,15 @@ def to_boost(self) -> bh.Histogram:
375
370
"""
376
371
return self .compute ()
377
372
378
- def to_delayed (self ) -> Delayed :
379
- dsk = self .__dask_graph__ ()
380
- return Delayed (self .name , dsk , layer = self ._layer )
373
+ def to_delayed (self , optimize_graph : bool = True ) -> Delayed :
374
+ keys = self .__dask_keys__ ()
375
+ graph = self .__dask_graph__ ()
376
+ layer = self .__dask_layers__ ()[0 ]
377
+ if optimize_graph :
378
+ graph = self .__dask_optimize__ (graph , keys )
379
+ layer = f"delayed-{ self .name } "
380
+ graph = HighLevelGraph .from_collections (layer , graph , dependencies = ())
381
+ return Delayed (keys [0 ], graph , layer = layer )
381
382
382
383
def values (self , flow : bool = False ) -> NDArray [Any ]:
383
384
return self .to_boost ().values (flow = flow )
    def __init__(
        self, dsk: HighLevelGraph, name: str, npartitions: int, histref: bh.Histogram
    ) -> None:
        # Task graph backing this partitioned collection.
        self._dask: HighLevelGraph = dsk
        # Key prefix identifying this collection's tasks in the graph.
        self._name: str = name
        # Number of partitions (one partial histogram per partition).
        self._npartitions: int = npartitions
        # Reference histogram defining the axes/storage of every partition.
        self._histref: bh.Histogram = histref
460
461
@@ -505,27 +506,36 @@ def __str__(self) -> str:
505
506
506
507
__repr__ = __str__
507
508
508
- @property
509
- def _args (self ) -> tuple [HighLevelGraph , str , int , bh .Histogram ]:
510
- return (self .dask , self .name , self .npartitions , self .histref )
511
-
512
- def __getstate__ (self ) -> tuple [HighLevelGraph , str , int , bh .Histogram ]:
513
- return self ._args
514
-
515
- def __setstate__ (
516
- self , state : tuple [HighLevelGraph , str , int , bh .Histogram ]
517
- ) -> None :
518
- self ._dask , self ._name , self ._npartitions , self ._histref = state
509
+ def __reduce__ (self ):
510
+ return (
511
+ PartitionedHistogram ,
512
+ (
513
+ self ._dask ,
514
+ self ._name ,
515
+ self ._npartitions ,
516
+ self ._histref ,
517
+ ),
518
+ )
519
519
520
520
    @property
    def histref(self) -> bh.Histogram:
        """boost_histogram.Histogram: reference histogram.

        Defines the axes and storage shared by every partition's fill.
        """
        return self._histref
524
524
525
- def to_agg (self , split_every : int | None = None ) -> AggHistogram :
525
    def collapse(self, split_every: int | None = None) -> AggHistogram:
        """Translate into a reduced aggregated histogram.

        Parameters
        ----------
        split_every : int, optional
            Fan-in of the tree reduction; ``None`` uses the default
            chosen by ``_reduction``.
        """
        return _reduction(self, split_every=split_every)
528
528
529
+ def to_delayed (self , optimize_graph : bool = True ) -> list [Delayed ]:
530
+ keys = self .__dask_keys__ ()
531
+ graph = self .__dask_graph__ ()
532
+ layer = self .__dask_layers__ ()[0 ]
533
+ if optimize_graph :
534
+ graph = self .__dask_optimize__ (graph , keys )
535
+ layer = f"delayed-{ self .name } "
536
+ graph = HighLevelGraph .from_collections (layer , graph , dependencies = ())
537
+ return [Delayed (k , graph , layer = layer ) for k in keys ]
538
+
529
539
530
540
def _reduction (
531
541
ph : PartitionedHistogram ,
@@ -607,11 +617,14 @@ def _partitioned_histogram(
607
617
sample : DaskCollection | None = None ,
608
618
split_every : int | None = None ,
609
619
) -> PartitionedHistogram :
610
- name = f"hist-on-block-{ tokenize (data , histref , weights , sample )} "
620
+ name = f"hist-on-block-{ tokenize (data , histref , weights , sample , split_every )} "
611
621
data_is_df = is_dataframe_like (data [0 ])
622
+ data_is_dak = is_awkward_like (data [0 ])
612
623
_weight_sample_check (* data , weights = weights )
613
- if len (data ) == 1 and hasattr (data [0 ], "_typetracer" ):
614
- from dask_awkward .core import partitionwise_layer as pwlayer
624
+
625
+ # Single awkward array object.
626
+ if len (data ) == 1 and data_is_dak :
627
+ from dask_awkward .core import partitionwise_layer as dak_pwl
615
628
616
629
x = data [0 ]
617
630
if weights is not None and sample is not None :
@@ -621,7 +634,9 @@ def _partitioned_histogram(
621
634
elif weights is None and sample is not None :
622
635
raise NotImplementedError ()
623
636
else :
624
- g = pwlayer (_blocked_dak , name , x , histref = histref )
637
+ g = dak_pwl (_blocked_dak , name , x , histref = histref )
638
+
639
+ # Single object, not a dataframe
625
640
elif len (data ) == 1 and not data_is_df :
626
641
x = data [0 ]
627
642
if weights is not None and sample is not None :
@@ -634,6 +649,8 @@ def _partitioned_histogram(
634
649
g = partitionwise (_blocked_sa_s , name , x , sample , histref = histref )
635
650
else :
636
651
g = partitionwise (_blocked_sa , name , x , histref = histref )
652
+
653
+ # Single object, is a dataframe
637
654
elif len (data ) == 1 and data_is_df :
638
655
x = data [0 ]
639
656
if weights is not None and sample is not None :
@@ -646,8 +663,20 @@ def _partitioned_histogram(
646
663
g = partitionwise (_blocked_df_s , name , x , sample , histref = histref )
647
664
else :
648
665
g = partitionwise (_blocked_df , name , x , histref = histref )
666
+
667
+ # Multiple objects
649
668
else :
650
- if weights is not None and sample is not None :
669
+
670
+ # Awkward array collection detected as first argument
671
+ if data_is_dak :
672
+ from dask_awkward .core import partitionwise_layer as dak_pwl
673
+
674
+ if weights is None and sample is None :
675
+ g = dak_pwl (_blocked_dak_ma , name , * data , histref = histref )
676
+ else :
677
+ raise NotImplementedError ()
678
+ # Not an awkward array collection
679
+ elif weights is not None and sample is not None :
651
680
g = partitionwise (
652
681
_blocked_ma_w_s , name , * data , weights , sample , histref = histref
653
682
)
@@ -663,22 +692,6 @@ def _partitioned_histogram(
663
692
return PartitionedHistogram (hlg , name , data [0 ].npartitions , histref = histref )
664
693
665
694
666
- def _reduced_histogram (
667
- * data : DaskCollection ,
668
- histref : bh .Histogram ,
669
- weights : DaskCollection | None = None ,
670
- sample : DaskCollection | None = None ,
671
- split_every : int | None = None ,
672
- ) -> AggHistogram :
673
- ph = _partitioned_histogram (
674
- * data ,
675
- histref = histref ,
676
- weights = weights ,
677
- sample = sample ,
678
- )
679
- return ph .to_agg (split_every = split_every )
680
-
681
-
682
695
def to_dask_array (
683
696
agghist : AggHistogram ,
684
697
flow : bool = False ,
@@ -888,11 +901,12 @@ def factory(
888
901
if storage is None :
889
902
storage = bh .storage .Double ()
890
903
histref = bh .Histogram (* axes , storage = storage ) # type: ignore
891
- f = _partitioned_histogram if keep_partitioned else _reduced_histogram
892
- return f ( # type: ignore
893
- * data ,
894
- histref = histref ,
895
- weights = weights ,
896
- sample = sample ,
897
- split_every = split_every ,
898
- )
904
+
905
+ ph = _partitioned_histogram (* data , histref = histref , weights = weights , sample = sample )
906
+ if keep_partitioned :
907
+ return ph
908
+ return ph .collapse (split_every = split_every )
909
+
910
+
911
def is_awkward_like(x: Any) -> bool:
    """Return ``True`` if ``x`` appears to be a dask-awkward collection.

    Detection is structural: any dask collection exposing the
    dask-awkward ``_typetracer`` attribute qualifies.
    """
    if not is_dask_collection(x):
        return False
    return hasattr(x, "_typetracer")
0 commit comments