Skip to content

Commit 3affcae

Browse files
authored
Experimental support dask_awkward collections, formatting and organization (#12)
1 parent 6085b37 commit 3affcae

File tree

6 files changed

+72
-33
lines changed

6 files changed

+72
-33
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,3 +22,7 @@ repos:
2222
- id: pyupgrade
2323
args:
2424
- --py37-plus
25+
- repo: https://github.com/MarcoGorelli/absolufy-imports
26+
rev: v0.3.1
27+
hooks:
28+
- id: absolufy-imports

src/dask_histogram/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,9 +3,9 @@
33
import boost_histogram.axis as _axis
44
import boost_histogram.storage as _storage
55

6-
from ._version import version as __version__
7-
from .core import AggHistogram, PartitionedHistogram, factory
8-
from .routines import histogram, histogram2d, histogramdd
6+
from dask_histogram._version import version as __version__
7+
from dask_histogram.core import AggHistogram, PartitionedHistogram, factory
8+
from dask_histogram.routines import histogram, histogram2d, histogramdd
99

1010
version_info = tuple(__version__.split("."))
1111

src/dask_histogram/bins.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
if TYPE_CHECKING:
1111
from typing import Sequence
1212

13-
from .typing import BinArg, BinType, RangeArg, RangeType
13+
from dask_histogram.typing import BinArg, BinType, RangeArg, RangeType
1414

1515

1616
class BinsStyle(Enum):

src/dask_histogram/boost.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,11 +14,17 @@
1414
from dask.delayed import Delayed, delayed
1515
from dask.utils import is_arraylike, is_dataframe_like
1616

17-
from .bins import normalize_bins_range
18-
from .core import AggHistogram, factory
17+
from dask_histogram.bins import normalize_bins_range
18+
from dask_histogram.core import AggHistogram, factory
1919

2020
if TYPE_CHECKING:
21-
from .typing import BinArg, BinType, DaskCollection, RangeArg, RangeType
21+
from dask_histogram.typing import (
22+
BinArg,
23+
BinType,
24+
DaskCollection,
25+
RangeArg,
26+
RangeType,
27+
)
2228

2329
import dask_histogram
2430

src/dask_histogram/core.py

Lines changed: 46 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -8,18 +8,19 @@
88
import boost_histogram as bh
99
import dask.array as da
1010
import numpy as np
11-
from dask.bag.core import empty_safe_aggregate, partition_all
11+
from dask.bag.core import empty_safe_aggregate
1212
from dask.base import DaskMethodsMixin, is_dask_collection, tokenize
1313
from dask.dataframe.core import partitionwise_graph as partitionwise
1414
from dask.delayed import Delayed
1515
from dask.highlevelgraph import HighLevelGraph
1616
from dask.threaded import get as tget
1717
from dask.utils import is_dataframe_like, key_split
18+
from tlz import partition_all
1819

1920
if TYPE_CHECKING:
2021
from numpy.typing import NDArray
2122

22-
from .typing import DaskCollection
23+
from dask_histogram.typing import DaskCollection
2324

2425
__all__ = (
2526
"AggHistogram",
@@ -29,7 +30,7 @@
2930
)
3031

3132

32-
def clone(histref: bh.Histogram = None) -> bh.Histogram:
33+
def clone(histref: bh.Histogram | None = None) -> bh.Histogram:
3334
"""Create a Histogram object based on another.
3435
3536
The axes and storage of the `histref` will be used to create a new
@@ -54,7 +55,7 @@ def clone(histref: bh.Histogram = None) -> bh.Histogram:
5455
def _blocked_sa(
5556
data: Any,
5657
*,
57-
histref: bh.Histogram = None,
58+
histref: bh.Histogram | None = None,
5859
) -> bh.Histogram:
5960
"""Blocked calculation; single argument; unweighted; no sample."""
6061
if data.ndim == 1:
@@ -69,7 +70,7 @@ def _blocked_sa_s(
6970
data: Any,
7071
sample: Any,
7172
*,
72-
histref: bh.Histogram = None,
73+
histref: bh.Histogram | None = None,
7374
) -> bh.Histogram:
7475
"""Blocked calculation; single argument; unweighted; with sample."""
7576
if data.ndim == 1:
@@ -84,7 +85,7 @@ def _blocked_sa_w(
8485
data: Any,
8586
weights: Any,
8687
*,
87-
histref: bh.Histogram = None,
88+
histref: bh.Histogram | None = None,
8889
) -> bh.Histogram:
8990
"""Blocked calculation; single argument; weighted; no sample."""
9091
if data.ndim == 1:
@@ -100,7 +101,7 @@ def _blocked_sa_w_s(
100101
weights: Any,
101102
sample: Any,
102103
*,
103-
histref: bh.Histogram = None,
104+
histref: bh.Histogram | None = None,
104105
) -> bh.Histogram:
105106
"""Blocked calculation; single argument; weighted; with sample."""
106107
if data.ndim == 1:
@@ -113,15 +114,15 @@ def _blocked_sa_w_s(
113114

114115
def _blocked_ma(
115116
*data: Any,
116-
histref: bh.Histogram = None,
117+
histref: bh.Histogram | None = None,
117118
) -> bh.Histogram:
118119
"""Blocked calculation; multiargument; unweighted; no sample."""
119120
return clone(histref).fill(*data)
120121

121122

122123
def _blocked_ma_s(
123124
*data: Any,
124-
histref: bh.Histogram = None,
125+
histref: bh.Histogram | None = None,
125126
) -> bh.Histogram:
126127
"""Blocked calculation; multiargument; unweighted; with sample."""
127128
sample = data[-1]
@@ -131,7 +132,7 @@ def _blocked_ma_s(
131132

132133
def _blocked_ma_w(
133134
*data: Any,
134-
histref: bh.Histogram = None,
135+
histref: bh.Histogram | None = None,
135136
) -> bh.Histogram:
136137
"""Blocked calculation; multiargument; weighted; no sample."""
137138
weights = data[-1]
@@ -141,7 +142,7 @@ def _blocked_ma_w(
141142

142143
def _blocked_ma_w_s(
143144
*data: Any,
144-
histref: bh.Histogram = None,
145+
histref: bh.Histogram | None = None,
145146
) -> bh.Histogram:
146147
"""Blocked calculation; multiargument; weighted; with sample."""
147148
weights = data[-2]
@@ -153,7 +154,7 @@ def _blocked_ma_w_s(
153154
def _blocked_df(
154155
data: Any,
155156
*,
156-
histref: bh.Histogram = None,
157+
histref: bh.Histogram | None = None,
157158
) -> bh.Histogram:
158159
return clone(histref).fill(*(data[c] for c in data.columns), weight=None)
159160

@@ -162,7 +163,7 @@ def _blocked_df_s(
162163
data: Any,
163164
sample: Any,
164165
*,
165-
histref: bh.Histogram = None,
166+
histref: bh.Histogram | None = None,
166167
) -> bh.Histogram:
167168
return clone(histref).fill(*(data[c] for c in data.columns), sample=sample)
168169

@@ -171,7 +172,7 @@ def _blocked_df_w(
171172
data: Any,
172173
weights: Any,
173174
*,
174-
histref: bh.Histogram = None,
175+
histref: bh.Histogram | None = None,
175176
) -> bh.Histogram:
176177
"""Blocked calculation; single argument; weighted; no sample."""
177178
return clone(histref).fill(*(data[c] for c in data.columns), weight=weights)
@@ -182,14 +183,18 @@ def _blocked_df_w_s(
182183
weights: Any,
183184
sample: Any,
184185
*,
185-
histref: bh.Histogram = None,
186+
histref: bh.Histogram | None = None,
186187
) -> bh.Histogram:
187188
"""Blocked calculation; single argument; weighted; with sample."""
188189
return clone(histref).fill(
189190
*(data[c] for c in data.columns), weight=weights, sample=sample
190191
)
191192

192193

194+
def _blocked_dak(data: Any, *, histref: bh.Histogram | None = None) -> bh.Histogram:
195+
return clone(histref).fill(data)
196+
197+
193198
class AggHistogram(DaskMethodsMixin):
194199
"""Aggregated Histogram collection.
195200
@@ -346,13 +351,13 @@ def to_delayed(self) -> Delayed:
346351
return Delayed(self.name, dsk, layer=self._layer)
347352

348353
def values(self, flow: bool = False) -> NDArray[Any]:
349-
return self.to_boost().values()
354+
return self.to_boost().values(flow=flow)
350355

351356
def variances(self, flow: bool = False) -> NDArray[Any] | None:
352-
return self.to_boost().variances()
357+
return self.to_boost().variances(flow=flow)
353358

354359
def counts(self, flow: bool = False) -> NDArray[Any]:
355-
return self.to_boost().counts()
360+
return self.to_boost().counts(flow=flow)
356361

357362
def __array__(self) -> NDArray[Any]:
358363
return self.compute().__array__()
@@ -483,12 +488,15 @@ def histref(self) -> bh.Histogram:
483488
"""boost_histogram.Histogram: reference histogram."""
484489
return self._histref
485490

486-
def to_agg(self, split_every: int = None) -> AggHistogram:
491+
def to_agg(self, split_every: int | None = None) -> AggHistogram:
487492
"""Translate into a reduced aggregated histogram."""
488493
return _reduction(self, split_every=split_every)
489494

490495

491-
def _reduction(ph: PartitionedHistogram, split_every: int = None) -> AggHistogram:
496+
def _reduction(
497+
ph: PartitionedHistogram,
498+
split_every: int | None = None,
499+
) -> AggHistogram:
492500
if split_every is None:
493501
split_every = 4
494502
if split_every is False:
@@ -568,7 +576,19 @@ def _partitioned_histogram(
568576
name = f"hist-on-block-{tokenize(data, histref, weights, sample)}"
569577
data_is_df = is_dataframe_like(data[0])
570578
_weight_sample_check(*data, weights=weights)
571-
if len(data) == 1 and not data_is_df:
579+
if len(data) == 1 and hasattr(data[0], "_typetracer"):
580+
from dask_awkward.core import partitionwise_layer as pwlayer
581+
582+
x = data[0]
583+
if weights is not None and sample is not None:
584+
raise NotImplementedError()
585+
elif weights is not None and sample is None:
586+
raise NotImplementedError()
587+
elif weights is None and sample is not None:
588+
raise NotImplementedError()
589+
else:
590+
g = pwlayer(_blocked_dak, name, x, histref=histref)
591+
elif len(data) == 1 and not data_is_df:
572592
x = data[0]
573593
if weights is not None and sample is not None:
574594
g = partitionwise(
@@ -621,7 +641,6 @@ def _reduced_histogram(
621641
histref=histref,
622642
weights=weights,
623643
sample=sample,
624-
split_every=split_every,
625644
)
626645
return ph.to_agg(split_every=split_every)
627646

@@ -683,7 +702,11 @@ def to_dask_array(
683702

684703

685704
class BinaryOpAgg:
686-
def __init__(self, func: Callable[[Any, Any], Any], name: str = None) -> None:
705+
def __init__(
706+
self,
707+
func: Callable[[Any, Any], Any],
708+
name: str | None = None,
709+
) -> None:
687710
self.func = func
688711
self.__name__ = func.__name__ if name is None else name
689712

src/dask_histogram/routines.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,11 +10,17 @@
1010
from dask.base import is_dask_collection
1111
from dask.utils import is_arraylike, is_dataframe_like
1212

13-
from .bins import normalize_bins_range
14-
from .core import AggHistogram, factory
13+
from dask_histogram.bins import normalize_bins_range
14+
from dask_histogram.core import AggHistogram, factory
1515

1616
if TYPE_CHECKING:
17-
from .typing import BinArg, BinType, DaskCollection, RangeArg, RangeType
17+
from dask_histogram.typing import (
18+
BinArg,
19+
BinType,
20+
DaskCollection,
21+
RangeArg,
22+
RangeType,
23+
)
1824
else:
1925
DaskCollection = object
2026

0 commit comments

Comments
 (0)