
Commit bc20f04

Merge pull request #129 from lgray/map_reduce_agg_hist_adds
feat!: when filling dask_histogram.boost.Histogram objects, delay creation of the task graph and use multi-fill
2 parents: 8caccd5 + 941e0b5
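
In practice the breaking change is that `Histogram.fill` only records ("stages") each fill; the task graph, a partitioned multi-fill followed by an aggregation step, is built lazily the first time the graph is requested, for example at compute time. A minimal usage sketch under that assumption, with illustrative variable names and dask.array input:

```python
import boost_histogram as bh
import dask.array as da
import dask_histogram.boost as dhb

x = da.random.uniform(size=(100_000,), chunks=25_000)

h = dhb.Histogram(bh.axis.Regular(50, 0, 1))
h.fill(x)             # staged: the fill is recorded, no task graph yet
h.fill(x ** 2)        # a second staged fill, still no graph
result = h.compute()  # the graph is built once here, then executed
```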

File tree

5 files changed: +245, -135 lines changed


src/dask_histogram/boost.py

Lines changed: 54 additions & 76 deletions
@@ -10,7 +10,6 @@
 import boost_histogram.storage as storage
 import dask
 import dask.array as da
-from dask.bag.core import empty_safe_aggregate, partition_all
 from dask.base import DaskMethodsMixin, dont_optimize, is_dask_collection, tokenize
 from dask.context import globalmethod
 from dask.delayed import Delayed, delayed
@@ -20,7 +19,12 @@
 from tlz import first

 from dask_histogram.bins import normalize_bins_range
-from dask_histogram.core import AggHistogram, _get_optimization_function, factory
+from dask_histogram.core import (
+    AggHistogram,
+    _get_optimization_function,
+    _partitioned_histogram_multifill,
+    _reduction,
+)

 if TYPE_CHECKING:
     from dask_histogram.typing import (
@@ -36,55 +40,6 @@
 __all__ = ("Histogram", "histogram", "histogram2d", "histogramdd")


-def _build_staged_tree_reduce(
-    stages: list[AggHistogram], split_every: int | bool
-) -> HighLevelGraph:
-    if not split_every:
-        split_every = len(stages)
-
-    reducer = sum
-
-    token = tokenize(stages, reducer, split_every)
-
-    k = len(stages)
-    b = ""
-    fmt = f"staged-fill-aggregate-{token}"
-    depth = 0
-
-    dsk = {}
-
-    if k > 1:
-        while k > split_every:
-            c = fmt + str(depth)
-            for i, inds in enumerate(partition_all(split_every, range(k))):
-                dsk[(c, i)] = (
-                    empty_safe_aggregate,
-                    reducer,
-                    [
-                        (stages[j].name if depth == 0 else b, 0 if depth == 0 else j)
-                        for j in inds
-                    ],
-                    False,
-                )
-
-            k = i + 1
-            b = c
-            depth += 1
-
-        dsk[(fmt, 0)] = (
-            empty_safe_aggregate,
-            reducer,
-            [
-                (stages[j].name if depth == 0 else b, 0 if depth == 0 else j)
-                for j in range(k)
-            ],
-            True,
-        )
-        return fmt, HighLevelGraph.from_collections(fmt, dsk, dependencies=stages)
-
-    return stages[0].name, stages[0].dask
-
-
 class Histogram(bh.Histogram, DaskMethodsMixin, family=dask_histogram):
     """Histogram object capable of lazy computation.
@@ -97,9 +52,6 @@ class Histogram(bh.Histogram, DaskMethodsMixin, family=dask_histogram):
         type is :py:class:`boost_histogram.storage.Double`.
     metadata : Any
         Data that is passed along if a new histogram is created.
-    split_every : int | bool | None, default None
-        Width of aggregation layers for staged fills.
-        If False, all staged fills are added in one layer (memory intensive!).

     See Also
     --------
@@ -139,7 +91,7 @@ def __init__(
     ) -> None:
         """Construct a Histogram object."""
         super().__init__(*axes, storage=storage, metadata=metadata)
-        self._staged: list[AggHistogram] | None = None
+        self._staged: AggHistogram | None = None
         self._dask_name: str | None = (
             f"empty-histogram-{tokenize(*axes, storage, metadata)}"
         )
@@ -148,29 +100,19 @@ def __init__(
             {},
         )
         self._split_every = split_every
-        if self._split_every is None:
-            self._split_every = dask.config.get("histogram.aggregation.split_every", 8)

     @property
     def _histref(self):
         return (
             tuple(self.axes),
-            self.storage_type,
+            self.storage_type(),
             self.metadata,
         )

     def __iadd__(self, other):
-        if self.staged_fills() and other.staged_fills():
-            self._staged += other._staged
-        elif not self.staged_fills() and other.staged_fills():
-            self._staged = other._staged
-        if self.staged_fills():
-            new_name, new_graph = _build_staged_tree_reduce(
-                self._staged, self._split_every
-            )
-            self._dask = new_graph
-            self._dask_name = new_name
-        return self
+        raise NotImplementedError(
+            "dask-boost-histograms are not addable, please sum them after computation!"
+        )

     def __add__(self, other):
         return self.__iadd__(other)
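
Because in-place and binary addition of lazy histograms now raise `NotImplementedError`, the pattern suggested by the new error message is to compute first and then add the concrete results. A hedged sketch, reusing the illustrative objects from the example above:

```python
h1 = dhb.Histogram(bh.axis.Regular(50, 0, 1))
h2 = dhb.Histogram(bh.axis.Regular(50, 0, 1))
h1.fill(x)
h2.fill(x ** 2)

# h1 + h2 would raise NotImplementedError after this change;
# sum the computed (in-memory) histograms instead.
total = h1.compute() + h2.compute()
```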
@@ -234,6 +176,8 @@ def _in_memory_type(self) -> type[bh.Histogram]:

     @property
     def dask_name(self) -> str:
+        if self._dask_name == "__not_yet_calculated__" and self._dask is None:
+            self._build_taskgraph()
         if self._dask_name is None:
             raise RuntimeError(
                 "The dask name should never be None when it's requested."
@@ -242,12 +186,45 @@ def dask_name(self) -> str:

     @property
     def dask(self) -> HighLevelGraph:
+        if self._dask_name == "__not_yet_calculated__" and self._dask is None:
+            self._build_taskgraph()
         if self._dask is None:
             raise RuntimeError(
                 "The dask graph should never be None when it's requested."
             )
         return self._dask

+    def _build_taskgraph(self):
+        data_list = []
+        weights = []
+        samples = []
+
+        for afill in self._staged:
+            data_list.append(afill["args"])
+            weights.append(afill["kwargs"]["weight"])
+            samples.append(afill["kwargs"]["sample"])
+
+        if all(weight is None for weight in weights):
+            weights = None
+
+        if not all(sample is None for sample in samples):
+            samples = None
+
+        split_every = self._split_every or dask.config.get(
+            "histogram.aggregation.split-every", 8
+        )
+
+        fills = _partitioned_histogram_multifill(
+            data_list, self._histref, weights, samples
+        )
+
+        output_hist = _reduction(fills, split_every)
+
+        self._staged = None
+        self._staged_result = output_hist
+        self._dask = output_hist.dask
+        self._dask_name = output_hist.name
+
     def fill(  # type: ignore
         self,
         *args: DaskCollection,
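
`_partitioned_histogram_multifill` and `_reduction` are provided by `dask_histogram.core`: conceptually, each partition is filled once with all staged inputs, and the per-partition histograms are then combined in layers of at most `split_every` inputs each. The following is a standalone sketch of that layered (tree) reduction idea only, not the library's implementation:

```python
from functools import reduce
import operator


def tree_sum(partials, split_every=8):
    """Combine partial results in layers of at most `split_every` at a time.

    Each pass shrinks the list by a factor of `split_every`; a scheduler can
    run the groups within one layer in parallel.
    """
    while len(partials) > 1:
        partials = [
            reduce(operator.add, partials[i : i + split_every])
            for i in range(0, len(partials), split_every)
        ]
    return partials[0]


# e.g. tree_sum(list(range(100))) == 4950, or pass a list of boost_histogram.Histogram
```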
@@ -318,14 +295,13 @@ def fill(  # type: ignore
         else:
             raise ValueError(f"Cannot interpret input data: {args}")

-        new_fill = factory(*args, histref=self._histref, weights=weight, sample=sample)
+        new_fill = {"args": args, "kwargs": {"weight": weight, "sample": sample}}
         if self._staged is None:
             self._staged = [new_fill]
         else:
-            self._staged += [new_fill]
-        new_name, new_graph = _build_staged_tree_reduce(self._staged, self._split_every)
-        self._dask = new_graph
-        self._dask_name = new_name
+            self._staged.append(new_fill)
+        self._dask = None
+        self._dask_name = "__not_yet_calculated__"

         return self

@@ -383,7 +359,8 @@ def to_delayed(self) -> Delayed:

         """
         if self._staged is not None:
-            return sum(self._staged[1:], start=self._staged[0]).to_delayed()
+            self._build_taskgraph()
+            return self._staged_result.to_delayed()
         return delayed(bh.Histogram(self))

     def __repr__(self) -> str:
@@ -449,7 +426,8 @@ def to_dask_array(self, flow: bool = False, dd: bool = True) -> Any:

         """
         if self._staged is not None:
-            return sum(self._staged).to_dask_array(flow=flow, dd=dd)
+            self._build_taskgraph()
+            return self._staged_result.to_dask_array(flow=flow, dd=dd)
         else:
             counts, edges = self.to_numpy(flow=flow, dd=True, view=False)
             counts = da.from_array(counts)
