
Commit 3ba22e9

support multiargument dask-awkward

1 parent 0fa081d

1 file changed: +73 -59 lines

src/dask_histogram/core.py (73 additions, 59 deletions)
@@ -199,6 +199,10 @@ def _blocked_dak(data: Any, *, histref: bh.Histogram | None = None) -> bh.Histogram:
     return clone(histref).fill(data)
 
 
+def _blocked_dak_ma(*data: Any, histref: bh.Histogram | None = None) -> bh.Histogram:
+    return clone(histref).fill(*data)
+
+
 def optimize(
     dsk: Mapping,
     keys: Hashable | list[Hashable] | set[Hashable],
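
Note: the new helper leans on boost-histogram's positional fill(), which accepts one array per axis, so clone(histref).fill(*data) spreads a block's arrays across the histogram's dimensions. A minimal, dask-free sketch of that behaviour (plain boost_histogram and numpy; the names are illustrative):

import boost_histogram as bh
import numpy as np

# fill() takes one positional array per axis, which is what
# clone(histref).fill(*data) in _blocked_dak_ma relies on.
h2d = bh.Histogram(bh.axis.Regular(3, 0, 3), bh.axis.Regular(3, 0, 3))
h2d.fill(np.array([0.5, 1.5, 2.5]), np.array([2.5, 0.5, 1.5]))
print(h2d.values())  # 3x3 array of per-bin counts
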
@@ -210,12 +214,10 @@ def optimize(
 
     if not isinstance(dsk, HighLevelGraph):
         dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
-    else:
-        # Perform Blockwise optimizations for HLG input
-        dsk = optimize_blockwise(dsk, keys=keys)
-        dsk = fuse_roots(dsk, keys=keys)  # type: ignore
-        dsk = dsk.cull(set(keys))  # type: ignore
 
+    dsk = optimize_blockwise(dsk, keys=keys)
+    dsk = fuse_roots(dsk, keys=keys)  # type: ignore
+    dsk = dsk.cull(set(keys))  # type: ignore
     return dsk
 
 
@@ -334,15 +336,8 @@ def __str__(self) -> str:
 
     __repr__ = __str__
 
-    @property
-    def _args(self) -> tuple[HighLevelGraph, str, bh.Histogram]:
-        return (self.dask, self.name, self.histref)
-
-    def __getstate__(self) -> tuple[HighLevelGraph, str, bh.Histogram]:
-        return self._args
-
-    def __setstate__(self, state: tuple[HighLevelGraph, str, bh.Histogram]) -> None:
-        self._dask, self._name, self._histref = state
+    def __reduce__(self):
+        return (AggHistogram, (self._dask, self._name, self._histref))
 
     def to_dask_array(
         self, flow: bool = False, dd: bool = False
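
Note: swapping __getstate__/__setstate__ for __reduce__ means pickle rebuilds the collection by calling AggHistogram(dask, name, histref) directly. A self-contained toy sketch of the same pattern (the Box class is illustrative, not part of dask-histogram):

import pickle

class Box:
    def __init__(self, payload):
        self.payload = payload

    def __reduce__(self):
        # Same shape as AggHistogram.__reduce__: (callable, argument tuple).
        return (Box, (self.payload,))

restored = pickle.loads(pickle.dumps(Box(42)))
assert restored.payload == 42
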
@@ -375,9 +370,15 @@ def to_boost(self) -> bh.Histogram:
         """
         return self.compute()
 
-    def to_delayed(self) -> Delayed:
-        dsk = self.__dask_graph__()
-        return Delayed(self.name, dsk, layer=self._layer)
+    def to_delayed(self, optimize_graph: bool = True) -> Delayed:
+        keys = self.__dask_keys__()
+        graph = self.__dask_graph__()
+        layer = self.__dask_layers__()[0]
+        if optimize_graph:
+            graph = self.__dask_optimize__(graph, keys)
+            layer = f"delayed-{self.name}"
+            graph = HighLevelGraph.from_collections(layer, graph, dependencies=())
+        return Delayed(keys[0], graph, layer=layer)
 
     def values(self, flow: bool = False) -> NDArray[Any]:
         return self.to_boost().values(flow=flow)
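
Note: a hedged usage sketch for the reworked AggHistogram.to_delayed, assuming dask_histogram.core.factory accepts a dask array plus axes= as in the factory() hunk further down:

import boost_histogram as bh
import dask.array as da
from dask_histogram.core import factory

x = da.random.uniform(size=(1000,), chunks=250)
agg = factory(x, axes=(bh.axis.Regular(10, 0, 1),))  # AggHistogram

# With optimize_graph=True (the default) the graph is optimized and wrapped
# in a fresh "delayed-<name>" layer before the Delayed is constructed.
d = agg.to_delayed()
print(d.compute())  # the reduced boost_histogram.Histogram
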
@@ -454,7 +455,7 @@ def __init__(
         self, dsk: HighLevelGraph, name: str, npartitions: int, histref: bh.Histogram
     ) -> None:
         self._dask: HighLevelGraph = dsk
-        self._name = name
+        self._name: str = name
         self._npartitions: int = npartitions
         self._histref: bh.Histogram = histref
 
@@ -505,27 +506,36 @@ def __str__(self) -> str:
 
     __repr__ = __str__
 
-    @property
-    def _args(self) -> tuple[HighLevelGraph, str, int, bh.Histogram]:
-        return (self.dask, self.name, self.npartitions, self.histref)
-
-    def __getstate__(self) -> tuple[HighLevelGraph, str, int, bh.Histogram]:
-        return self._args
-
-    def __setstate__(
-        self, state: tuple[HighLevelGraph, str, int, bh.Histogram]
-    ) -> None:
-        self._dask, self._name, self._npartitions, self._histref = state
+    def __reduce__(self):
+        return (
+            PartitionedHistogram,
+            (
+                self._dask,
+                self._name,
+                self._npartitions,
+                self._histref,
+            ),
+        )
 
     @property
     def histref(self) -> bh.Histogram:
         """boost_histogram.Histogram: reference histogram."""
         return self._histref
 
-    def to_agg(self, split_every: int | None = None) -> AggHistogram:
+    def collapse(self, split_every: int | None = None) -> AggHistogram:
         """Translate into a reduced aggregated histogram."""
         return _reduction(self, split_every=split_every)
 
+    def to_delayed(self, optimize_graph: bool = True) -> list[Delayed]:
+        keys = self.__dask_keys__()
+        graph = self.__dask_graph__()
+        layer = self.__dask_layers__()[0]
+        if optimize_graph:
+            graph = self.__dask_optimize__(graph, keys)
+            layer = f"delayed-{self.name}"
+            graph = HighLevelGraph.from_collections(layer, graph, dependencies=())
+        return [Delayed(k, graph, layer=layer) for k in keys]
+
 
 def _reduction(
     ph: PartitionedHistogram,
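
Note: with to_agg renamed to collapse, a partition-level result can be reduced explicitly; a hedged sketch, assuming the keep_partitioned=True path of factory() shown at the end of the diff:

import boost_histogram as bh
import dask.array as da
from dask_histogram.core import factory

x = da.random.uniform(size=(1000,), chunks=250)
ph = factory(x, axes=(bh.axis.Regular(10, 0, 1),), keep_partitioned=True)

parts = ph.to_delayed()            # one Delayed per partition-level histogram
agg = ph.collapse(split_every=2)   # reduce to an AggHistogram; split_every caps the fan-in
print(agg.compute())
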
@@ -607,11 +617,14 @@ def _partitioned_histogram(
     sample: DaskCollection | None = None,
     split_every: int | None = None,
 ) -> PartitionedHistogram:
-    name = f"hist-on-block-{tokenize(data, histref, weights, sample)}"
+    name = f"hist-on-block-{tokenize(data, histref, weights, sample, split_every)}"
     data_is_df = is_dataframe_like(data[0])
+    data_is_dak = is_awkward_like(data[0])
     _weight_sample_check(*data, weights=weights)
-    if len(data) == 1 and hasattr(data[0], "_typetracer"):
-        from dask_awkward.core import partitionwise_layer as pwlayer
+
+    # Single awkward array object.
+    if len(data) == 1 and data_is_dak:
+        from dask_awkward.core import partitionwise_layer as dak_pwl
 
         x = data[0]
         if weights is not None and sample is not None:
@@ -621,7 +634,9 @@ def _partitioned_histogram(
         elif weights is None and sample is not None:
             raise NotImplementedError()
         else:
-            g = pwlayer(_blocked_dak, name, x, histref=histref)
+            g = dak_pwl(_blocked_dak, name, x, histref=histref)
+
+    # Single object, not a dataframe
     elif len(data) == 1 and not data_is_df:
         x = data[0]
         if weights is not None and sample is not None:
@@ -634,6 +649,8 @@ def _partitioned_histogram(
             g = partitionwise(_blocked_sa_s, name, x, sample, histref=histref)
         else:
             g = partitionwise(_blocked_sa, name, x, histref=histref)
+
+    # Single object, is a dataframe
     elif len(data) == 1 and data_is_df:
         x = data[0]
         if weights is not None and sample is not None:
@@ -646,8 +663,20 @@ def _partitioned_histogram(
             g = partitionwise(_blocked_df_s, name, x, sample, histref=histref)
         else:
             g = partitionwise(_blocked_df, name, x, histref=histref)
+
+    # Multiple objects
     else:
-        if weights is not None and sample is not None:
+
+        # Awkward array collection detected as first argument
+        if data_is_dak:
+            from dask_awkward.core import partitionwise_layer as dak_pwl
+
+            if weights is None and sample is None:
+                g = dak_pwl(_blocked_dak_ma, name, *data, histref=histref)
+            else:
+                raise NotImplementedError()
+        # Not an awkward array collection
+        elif weights is not None and sample is not None:
             g = partitionwise(
                 _blocked_ma_w_s, name, *data, weights, sample, histref=histref
             )
@@ -663,22 +692,6 @@ def _partitioned_histogram(
     return PartitionedHistogram(hlg, name, data[0].npartitions, histref=histref)
 
 
-def _reduced_histogram(
-    *data: DaskCollection,
-    histref: bh.Histogram,
-    weights: DaskCollection | None = None,
-    sample: DaskCollection | None = None,
-    split_every: int | None = None,
-) -> AggHistogram:
-    ph = _partitioned_histogram(
-        *data,
-        histref=histref,
-        weights=weights,
-        sample=sample,
-    )
-    return ph.to_agg(split_every=split_every)
-
-
 def to_dask_array(
     agghist: AggHistogram,
     flow: bool = False,
@@ -888,11 +901,12 @@ def factory(
     if storage is None:
         storage = bh.storage.Double()
     histref = bh.Histogram(*axes, storage=storage)  # type: ignore
-    f = _partitioned_histogram if keep_partitioned else _reduced_histogram
-    return f(  # type: ignore
-        *data,
-        histref=histref,
-        weights=weights,
-        sample=sample,
-        split_every=split_every,
-    )
+
+    ph = _partitioned_histogram(*data, histref=histref, weights=weights, sample=sample)
+    if keep_partitioned:
+        return ph
+    return ph.collapse(split_every=split_every)
+
+
+def is_awkward_like(x: Any) -> bool:
+    return is_dask_collection(x) and hasattr(x, "_typetracer")
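
Note: end to end, the commit lets several dask-awkward collections fill one multi-dimensional histogram (previously only the single-argument awkward path existed). A hedged sketch, assuming a dask-awkward version that provides from_awkward alongside the dask_awkward.core.partitionwise_layer import used above:

import awkward as ak
import boost_histogram as bh
import dask_awkward as dak
from dask_histogram.core import factory

x = dak.from_awkward(ak.Array([0.1, 0.7, 1.3, 2.2]), npartitions=2)
y = dak.from_awkward(ak.Array([2.5, 0.4, 1.8, 0.9]), npartitions=2)

# Two awkward collections, no weights/sample: dispatches to the new
# _blocked_dak_ma block filler via dask-awkward's partitionwise layer.
h = factory(x, y, axes=(bh.axis.Regular(4, 0, 4), bh.axis.Regular(4, 0, 4)))
print(h.compute())  # aggregated 2-D boost_histogram.Histogram
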
