Skip to content

Commit fa5a583

Browse files
authored
Integrate dask-expr and make CI happy (#980)
1 parent b5640cb commit fa5a583

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

45 files changed

+329
-113
lines changed

.github/workflows/lint.yaml

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,8 @@ jobs:
55
lint:
66
runs-on: ubuntu-latest
77
steps:
8-
- uses: actions/checkout@v3
9-
- uses: actions/setup-python@v3
10-
- uses: pre-commit/action@<version> (exact version tag garbled by email-address obfuscation during page extraction)
8+
- uses: actions/checkout@<version> (exact version tag garbled by email-address obfuscation during page extraction)
9+
- uses: actions/setup-python@v5
10+
with:
11+
python-version: '3.9'
12+
- uses: pre-commit/action@<version> (exact version tag garbled by email-address obfuscation during page extraction)

.github/workflows/tests.yaml

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,12 +10,14 @@ jobs:
1010
matrix:
1111
# os: ["windows-latest", "ubuntu-latest", "macos-latest"]
1212
os: ["ubuntu-latest"]
13-
python-version: ["3.8", "3.9", "3.10"]
13+
python-version: ["3.9", "3.10", "3.11"]
14+
query-planning: [true, false]
1415

1516
env:
1617
PYTHON_VERSION: ${{ matrix.python-version }}
1718
PARALLEL: "true"
1819
COVERAGE: "true"
20+
DASK_DATAFRAME__QUERY_PLANNING: ${{ matrix.query-planning }}
1921

2022
steps:
2123
- name: Checkout source

.pre-commit-config.yaml

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,19 @@
11
repos:
2-
- repo: https://github.com/python/black
3-
rev: 22.3.0
2+
- repo: https://github.com/psf/black
3+
rev: 23.12.1
44
hooks:
55
- id: black
66
language_version: python3
7+
args:
8+
- --target-version=py39
79
- repo: https://github.com/pycqa/flake8
8-
rev: 3.7.9
10+
rev: 7.0.0
911
hooks:
1012
- id: flake8
1113
language_version: python3
12-
- repo: https://github.com/timothycrosley/isort
13-
rev: 4.3.21
14+
args: ["--ignore=E501,W503,E203,E741,E731"]
15+
- repo: https://github.com/pycqa/isort
16+
rev: 5.13.2
1417
hooks:
1518
- id: isort
1619
language_version: python3

ci/environment-3.10.yaml

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@ channels:
33
- conda-forge
44
- defaults
55
dependencies:
6-
- dask
76
- dask-glm
87
- multipledispatch >=0.4.9
98
- mypy
@@ -21,3 +20,7 @@ dependencies:
2120
- scipy
2221
- sparse
2322
- toolz
23+
- pip
24+
- pip:
25+
- git+https://github.com/dask-contrib/dask-expr
26+
- git+https://github.com/dask/dask

ci/environment-3.8.yaml renamed to ci/environment-3.11.yaml

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,8 @@
1-
name: dask-ml-3.8
1+
name: dask-ml-3.11
22
channels:
33
- conda-forge
44
- defaults
55
dependencies:
6-
- dask
76
- dask-glm
87
- multipledispatch >=0.4.9
98
- mypy
@@ -16,8 +15,12 @@ dependencies:
1615
- pytest
1716
- pytest-cov
1817
- pytest-mock
19-
- python=3.8.*
18+
- python=3.11.*
2019
- scikit-learn >=1.2.0
2120
- scipy
2221
- sparse
2322
- toolz
23+
- pip
24+
- pip:
25+
- git+https://github.com/dask-contrib/dask-expr
26+
- git+https://github.com/dask/dask

ci/environment-3.9.yaml

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@ channels:
33
- conda-forge
44
- defaults
55
dependencies:
6-
- dask
76
- dask-glm
87
- multipledispatch >=0.4.9
98
- mypy
@@ -21,3 +20,7 @@ dependencies:
2120
- scipy
2221
- sparse
2322
- toolz
23+
- pip
24+
- pip:
25+
- git+https://github.com/dask-contrib/dask-expr
26+
- git+https://github.com/dask/dask

ci/environment-docs.yaml

Lines changed: 26 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -5,28 +5,21 @@ channels:
55
dependencies:
66
- black
77
- coverage
8-
- graphviz
98
- heapdict
109
- ipykernel
1110
- ipython
1211
- multipledispatch
1312
- mypy
14-
- nbsphinx
1513
- nomkl
1614
- nose
1715
- numba
1816
- numpy
19-
- numpydoc
20-
- pandas
2117
- psutil
2218
- python=3.10
2319
- sortedcontainers
2420
- scikit-learn >=1.2.0
2521
- scipy
2622
- sparse
27-
- sphinx
28-
- sphinx_rtd_theme
29-
- sphinx-gallery
3023
- tornado
3124
- toolz
3225
- zict
@@ -35,5 +28,30 @@ dependencies:
3528
- dask-glm
3629
- dask-xgboost
3730
- pip:
38-
- dask-sphinx-theme >=3.0.0
3931
- graphviz
32+
- numpydoc
33+
- sphinx>=4.0.0,<5.0.0
34+
- dask-sphinx-theme>=3.0.0
35+
- sphinx-click
36+
- sphinx-copybutton
37+
- sphinx-remove-toctrees
38+
- sphinx_autosummary_accessors
39+
- sphinx-tabs
40+
- sphinx-design
41+
- jupyter_sphinx
42+
# FIXME: `sphinxcontrib-*` pins are a workaround until we have sphinx>=5.
43+
# See https://github.com/dask/dask-sphinx-theme/issues/68.
44+
- sphinxcontrib-applehelp>=1.0.0,<1.0.7
45+
- sphinxcontrib-devhelp>=1.0.0,<1.0.6
46+
- sphinxcontrib-htmlhelp>=2.0.0,<2.0.5
47+
- sphinxcontrib-serializinghtml>=1.1.0,<1.1.10
48+
- sphinxcontrib-qthelp>=1.0.0,<1.0.7
49+
- toolz
50+
- cloudpickle>=1.5.0
51+
- pandas>=1.4.0
52+
- dask-expr
53+
- fsspec
54+
- scipy
55+
- pytest
56+
- pytest-check-links
57+
- requests-cache

dask_ml/_partial.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ def fit(
2929
shuffle_blocks=True,
3030
random_state=None,
3131
assume_equal_chunks=False,
32-
**kwargs
32+
**kwargs,
3333
):
3434
"""Fit scikit learn model against dask arrays
3535

dask_ml/ensemble/_blockwise.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,11 @@
11
import dask
22
import dask.array as da
3-
import dask.dataframe as dd
43
import numpy as np
54
import sklearn.base
65
from sklearn.utils.validation import check_is_fitted
76

87
from ..base import ClassifierMixin, RegressorMixin
9-
from ..utils import check_array
8+
from ..utils import check_array, is_frame_base
109

1110

1211
class BlockwiseBase(sklearn.base.BaseEstimator):
@@ -62,7 +61,7 @@ def _predict(self, X):
6261
dtype=np.dtype(dtype),
6362
chunks=chunks,
6463
)
65-
elif isinstance(X, dd._Frame):
64+
elif is_frame_base(X):
6665
meta = np.empty((0, len(self.classes_)), dtype=dtype)
6766
combined = X.map_partitions(
6867
_predict_stack, estimators=self.estimators_, meta=meta
@@ -184,7 +183,7 @@ def _collect_probas(self, X):
184183
chunks=chunks,
185184
meta=meta,
186185
)
187-
elif isinstance(X, dd._Frame):
186+
elif is_frame_base(X):
188187
# TODO: replace with a _predict_proba_stack version.
189188
# This current raises; dask.dataframe doesn't like map_partitions that
190189
# return new axes.

dask_ml/linear_model/utils.py

Lines changed: 45 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -5,29 +5,60 @@
55
import numpy as np
66
from multipledispatch import dispatch
77

8+
if getattr(dd, "_dask_expr_enabled", lambda: False)():
9+
import dask_expr
810

9-
@dispatch(dd._Frame)
10-
def exp(A):
11-
return da.exp(A)
11+
@dispatch(dask_expr.FrameBase)
12+
def exp(A):
13+
return da.exp(A)
1214

15+
@dispatch(dask_expr.FrameBase)
16+
def absolute(A):
17+
return da.absolute(A)
1318

14-
@dispatch(dd._Frame)
15-
def absolute(A):
16-
return da.absolute(A)
19+
@dispatch(dask_expr.FrameBase)
20+
def sign(A):
21+
return da.sign(A)
1722

23+
@dispatch(dask_expr.FrameBase)
24+
def log1p(A):
25+
return da.log1p(A)
1826

19-
@dispatch(dd._Frame)
20-
def sign(A):
21-
return da.sign(A)
27+
@dispatch(dask_expr.FrameBase) # noqa: F811
28+
def add_intercept(X): # noqa: F811
29+
columns = X.columns
30+
if "intercept" in columns:
31+
raise ValueError("'intercept' column already in 'X'")
32+
return X.assign(intercept=1)[["intercept"] + list(columns)]
2233

34+
else:
2335

24-
@dispatch(dd._Frame)
25-
def log1p(A):
26-
return da.log1p(A)
36+
@dispatch(dd._Frame)
37+
def exp(A):
38+
return da.exp(A)
2739

40+
@dispatch(dd._Frame)
41+
def absolute(A):
42+
return da.absolute(A)
2843

29-
@dispatch(np.ndarray)
30-
def add_intercept(X):
44+
@dispatch(dd._Frame)
45+
def sign(A):
46+
return da.sign(A)
47+
48+
@dispatch(dd._Frame)
49+
def log1p(A):
50+
return da.log1p(A)
51+
52+
@dispatch(dd._Frame) # noqa: F811
53+
def add_intercept(X): # noqa: F811
54+
columns = X.columns
55+
if "intercept" in columns:
56+
raise ValueError("'intercept' column already in 'X'")
57+
return X.assign(intercept=1)[["intercept"] + list(columns)]
58+
59+
60+
@dispatch(np.ndarray) # noqa: F811
61+
def add_intercept(X): # noqa: F811
3162
return _add_intercept(X)
3263

3364

@@ -53,14 +84,6 @@ def add_intercept(X): # noqa: F811
5384
return X.map_blocks(_add_intercept, dtype=X.dtype, chunks=chunks)
5485

5586

56-
@dispatch(dd.DataFrame) # noqa: F811
57-
def add_intercept(X): # noqa: F811
58-
columns = X.columns
59-
if "intercept" in columns:
60-
raise ValueError("'intercept' column already in 'X'")
61-
return X.assign(intercept=1)[["intercept"] + list(columns)]
62-
63-
6487
@dispatch(np.ndarray) # noqa: F811
6588
def lr_prob_stack(prob): # noqa: F811
6689
return np.vstack([1 - prob, prob]).T

0 commit comments

Comments
 (0)