 10  10   import boost_histogram.storage as storage
 11  11   import dask
 12  12   import dask.array as da
 13      - from dask.bag.core import empty_safe_aggregate, partition_all
 14  13   from dask.base import DaskMethodsMixin, dont_optimize, is_dask_collection, tokenize
 15  14   from dask.context import globalmethod
 16  15   from dask.delayed import Delayed, delayed
 20  19   from tlz import first
 21  20
 22  21   from dask_histogram.bins import normalize_bins_range
 23      - from dask_histogram.core import AggHistogram, _get_optimization_function, factory
     22 + from dask_histogram.core import (
     23 +     AggHistogram,
     24 +     _get_optimization_function,
     25 +     _partitioned_histogram_multifill,
     26 +     _reduction,
     27 + )
 24  28
 25  29   if TYPE_CHECKING:
 26  30       from dask_histogram.typing import (
 36  40   __all__ = ("Histogram", "histogram", "histogram2d", "histogramdd")
 37  41
 38  42
 39      - def _build_staged_tree_reduce(
 40      -     stages: list[AggHistogram], split_every: int | bool
 41      - ) -> HighLevelGraph:
 42      -     if not split_every:
 43      -         split_every = len(stages)
 44      -
 45      -     reducer = sum
 46      -
 47      -     token = tokenize(stages, reducer, split_every)
 48      -
 49      -     k = len(stages)
 50      -     b = ""
 51      -     fmt = f"staged-fill-aggregate-{token}"
 52      -     depth = 0
 53      -
 54      -     dsk = {}
 55      -
 56      -     if k > 1:
 57      -         while k > split_every:
 58      -             c = fmt + str(depth)
 59      -             for i, inds in enumerate(partition_all(split_every, range(k))):
 60      -                 dsk[(c, i)] = (
 61      -                     empty_safe_aggregate,
 62      -                     reducer,
 63      -                     [
 64      -                         (stages[j].name if depth == 0 else b, 0 if depth == 0 else j)
 65      -                         for j in inds
 66      -                     ],
 67      -                     False,
 68      -                 )
 69      -
 70      -             k = i + 1
 71      -             b = c
 72      -             depth += 1
 73      -
 74      -         dsk[(fmt, 0)] = (
 75      -             empty_safe_aggregate,
 76      -             reducer,
 77      -             [
 78      -                 (stages[j].name if depth == 0 else b, 0 if depth == 0 else j)
 79      -                 for j in range(k)
 80      -             ],
 81      -             True,
 82      -         )
 83      -         return fmt, HighLevelGraph.from_collections(fmt, dsk, dependencies=stages)
 84      -
 85      -     return stages[0].name, stages[0].dask
 86      -
 87      -
 88  43   class Histogram(bh.Histogram, DaskMethodsMixin, family=dask_histogram):
 89  44       """Histogram object capable of lazy computation.
 90  45
@@ -97,9 +52,6 @@ class Histogram(bh.Histogram, DaskMethodsMixin, family=dask_histogram):
 97  52           type is :py:class:`boost_histogram.storage.Double`.
 98  53       metadata : Any
 99  54           Data that is passed along if a new histogram is created.
100      -     split_every : int | bool | None, default None
101      -         Width of aggregation layers for staged fills.
102      -         If False, all staged fills are added in one layer (memory intensive!).
103  55
104  56       See Also
105  57       --------
@@ -139,7 +91,7 @@ def __init__(
139  91       ) -> None:
140  92           """Construct a Histogram object."""
141  93           super().__init__(*axes, storage=storage, metadata=metadata)
142      -         self._staged: list[AggHistogram] | None = None
     94 +         self._staged: AggHistogram | None = None
143  95           self._dask_name: str | None = (
144  96               f"empty-histogram-{tokenize(*axes, storage, metadata)}"
145  97           )
@@ -148,29 +100,19 @@ def __init__(
148 100               {},
149 101           )
150 102           self._split_every = split_every
151      -         if self._split_every is None:
152      -             self._split_every = dask.config.get("histogram.aggregation.split_every", 8)
153 103
154 104       @property
155 105       def _histref(self):
156 106           return (
157 107               tuple(self.axes),
158      -             self.storage_type,
    108 +             self.storage_type(),
159 109               self.metadata,
160 110           )
161 111
162 112       def __iadd__(self, other):
163      -         if self.staged_fills() and other.staged_fills():
164      -             self._staged += other._staged
165      -         elif not self.staged_fills() and other.staged_fills():
166      -             self._staged = other._staged
167      -         if self.staged_fills():
168      -             new_name, new_graph = _build_staged_tree_reduce(
169      -                 self._staged, self._split_every
170      -             )
171      -             self._dask = new_graph
172      -             self._dask_name = new_name
173      -         return self
    113 +         raise NotImplementedError(
    114 +             "dask-boost-histograms are not addable, please sum them after computation!"
    115 +         )
174 116
175 117       def __add__(self, other):
176 118           return self.__iadd__(other)
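A minimal sketch of the new addition behaviour, assuming only the public `dask_histogram.boost` API visible in this diff (array sizes, chunking, and the axis below are illustrative placeholders): adding two lazy histograms now raises, so concrete results are summed after computation instead.

import boost_histogram as bh
import dask.array as da
import dask_histogram.boost as dhb

x1 = da.random.uniform(size=(1000,), chunks=250)
x2 = da.random.uniform(size=(1000,), chunks=250)

h1 = dhb.Histogram(bh.axis.Regular(10, 0, 1))
h2 = dhb.Histogram(bh.axis.Regular(10, 0, 1))
h1.fill(x1)
h2.fill(x2)

try:
    h1 + h2  # staged (lazy) histograms are no longer addable
except NotImplementedError:
    pass

# Sum after computation instead; converting to plain boost_histogram
# objects keeps the addition on the in-memory type.
total = bh.Histogram(h1.compute()) + bh.Histogram(h2.compute())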
@@ -234,6 +176,8 @@ def _in_memory_type(self) -> type[bh.Histogram]:
234 176
235 177       @property
236 178       def dask_name(self) -> str:
    179 +         if self._dask_name == "__not_yet_calculated__" and self._dask is None:
    180 +             self._build_taskgraph()
237 181           if self._dask_name is None:
238 182               raise RuntimeError(
239 183                   "The dask name should never be None when it's requested."
@@ -242,12 +186,45 @@ def dask_name(self) -> str:
242 186
243 187       @property
244 188       def dask(self) -> HighLevelGraph:
    189 +         if self._dask_name == "__not_yet_calculated__" and self._dask is None:
    190 +             self._build_taskgraph()
245 191           if self._dask is None:
246 192               raise RuntimeError(
247 193                   "The dask graph should never be None when it's requested."
248 194               )
249 195           return self._dask
250 196
    197 +     def _build_taskgraph(self):
    198 +         data_list = []
    199 +         weights = []
    200 +         samples = []
    201 +
    202 +         for afill in self._staged:
    203 +             data_list.append(afill["args"])
    204 +             weights.append(afill["kwargs"]["weight"])
    205 +             samples.append(afill["kwargs"]["sample"])
    206 +
    207 +         if all(weight is None for weight in weights):
    208 +             weights = None
    209 +
    210 +         if not all(sample is None for sample in samples):
    211 +             samples = None
    212 +
    213 +         split_every = self._split_every or dask.config.get(
    214 +             "histogram.aggregation.split-every", 8
    215 +         )
    216 +
    217 +         fills = _partitioned_histogram_multifill(
    218 +             data_list, self._histref, weights, samples
    219 +         )
    220 +
    221 +         output_hist = _reduction(fills, split_every)
    222 +
    223 +         self._staged = None
    224 +         self._staged_result = output_hist
    225 +         self._dask = output_hist.dask
    226 +         self._dask_name = output_hist.name
    227 +
251 228       def fill(  # type: ignore
252 229           self,
253 230           *args: DaskCollection,
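Taken together, the hunk above makes graph construction lazy: `fill()` only records its inputs (see the next hunk), and the first access to the collection's graph batches every staged fill into a single partitioned multifill followed by a tree reduction. An illustrative sketch, assuming only the public API shown in this diff (data, chunking, and axis are placeholders):

import boost_histogram as bh
import dask.array as da
import dask_histogram.boost as dhb

x1 = da.random.standard_normal(size=(4000,), chunks=1000)
x2 = da.random.standard_normal(size=(4000,), chunks=1000)

h = dhb.Histogram(bh.axis.Regular(50, -3, 3))
h.fill(x1)  # only records the inputs; no task graph is built yet
h.fill(x2)  # a second staged fill, still no graph

# Accessing the graph (or computing) triggers _build_taskgraph(), which
# collapses both staged fills into one multifill plus a tree reduction.
graph = h.dask
result = h.compute()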
@@ -318,14 +295,13 @@ def fill( # type: ignore
318 295           else:
319 296               raise ValueError(f"Cannot interpret input data: {args}")
320 297
321      -         new_fill = factory(*args, histref=self._histref, weights=weight, sample=sample)
    298 +         new_fill = {"args": args, "kwargs": {"weight": weight, "sample": sample}}
322 299           if self._staged is None:
323 300               self._staged = [new_fill]
324 301           else:
325      -             self._staged += [new_fill]
326      -         new_name, new_graph = _build_staged_tree_reduce(self._staged, self._split_every)
327      -         self._dask = new_graph
328      -         self._dask_name = new_name
    302 +             self._staged.append(new_fill)
    303 +         self._dask = None
    304 +         self._dask_name = "__not_yet_calculated__"
329 305
330 306           return self
331 307
@@ -383,7 +359,8 @@ def to_delayed(self) -> Delayed:
383 359
384 360           """
385 361           if self._staged is not None:
386      -             return sum(self._staged[1:], start=self._staged[0]).to_delayed()
    362 +             self._build_taskgraph()
    363 +             return self._staged_result.to_delayed()
387 364           return delayed(bh.Histogram(self))
388 365
389 366       def __repr__(self) -> str:
@@ -449,7 +426,8 @@ def to_dask_array(self, flow: bool = False, dd: bool = True) -> Any:
449 426
450 427           """
451 428           if self._staged is not None:
452      -             return sum(self._staged).to_dask_array(flow=flow, dd=dd)
    429 +             self._build_taskgraph()
    430 +             return self._staged_result.to_dask_array(flow=flow, dd=dd)
453 431           else:
454 432               counts, edges = self.to_numpy(flow=flow, dd=True, view=False)
455 433               counts = da.from_array(counts)
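For completeness, an illustrative sketch of the converted-collection path after this change, assuming the same public API as above (every name below is a placeholder): with fills staged, `to_dask_array` now builds the deferred graph on demand and returns dask collections backed by the reduced histogram.

import boost_histogram as bh
import dask.array as da
import dask_histogram.boost as dhb

x = da.random.uniform(size=(2000,), chunks=500)

h = dhb.Histogram(bh.axis.Regular(20, 0, 1))
h.fill(x)

# With staged fills present, to_dask_array() triggers _build_taskgraph()
# internally and delegates to the reduced histogram collection.
counts, edges = h.to_dask_array(flow=False, dd=True)
counts_np = counts.compute()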