55import numpy as np
66from multipledispatch import dispatch
77
8+ if getattr (dd , "_dask_expr_enabled" , lambda : False )():
9+ import dask_expr
810
9- @dispatch (dd . _Frame )
10- def exp (A ):
11- return da .exp (A )
11+ @dispatch (dask_expr . FrameBase )
12+ def exp (A ):
13+ return da .exp (A )
1214
15+ @dispatch (dask_expr .FrameBase )
16+ def absolute (A ):
17+ return da .absolute (A )
1318
14- @dispatch (dd . _Frame )
15- def absolute (A ):
16- return da .absolute (A )
19+ @dispatch (dask_expr . FrameBase )
20+ def sign (A ):
21+ return da .sign (A )
1722
23+ @dispatch (dask_expr .FrameBase )
24+ def log1p (A ):
25+ return da .log1p (A )
1826
19- @dispatch (dd ._Frame )
20- def sign (A ):
21- return da .sign (A )
27+ @dispatch (dask_expr .FrameBase ) # noqa: F811
28+ def add_intercept (X ): # noqa: F811
29+ columns = X .columns
30+ if "intercept" in columns :
31+ raise ValueError ("'intercept' column already in 'X'" )
32+ return X .assign (intercept = 1 )[["intercept" ] + list (columns )]
2233
34+ else :
2335
24- @dispatch (dd ._Frame )
25- def log1p (A ):
26- return da .log1p (A )
36+ @dispatch (dd ._Frame )
37+ def exp (A ):
38+ return da .exp (A )
2739
40+ @dispatch (dd ._Frame )
41+ def absolute (A ):
42+ return da .absolute (A )
2843
29- @dispatch (np .ndarray )
30- def add_intercept (X ):
44+ @dispatch (dd ._Frame )
45+ def sign (A ):
46+ return da .sign (A )
47+
48+ @dispatch (dd ._Frame )
49+ def log1p (A ):
50+ return da .log1p (A )
51+
52+ @dispatch (dd ._Frame ) # noqa: F811
53+ def add_intercept (X ): # noqa: F811
54+ columns = X .columns
55+ if "intercept" in columns :
56+ raise ValueError ("'intercept' column already in 'X'" )
57+ return X .assign (intercept = 1 )[["intercept" ] + list (columns )]
58+
59+
60+ @dispatch (np .ndarray ) # noqa: F811
61+ def add_intercept (X ): # noqa: F811
3162 return _add_intercept (X )
3263
3364
@@ -53,14 +84,6 @@ def add_intercept(X): # noqa: F811
5384 return X .map_blocks (_add_intercept , dtype = X .dtype , chunks = chunks )
5485
5586
56- @dispatch (dd .DataFrame ) # noqa: F811
57- def add_intercept (X ): # noqa: F811
58- columns = X .columns
59- if "intercept" in columns :
60- raise ValueError ("'intercept' column already in 'X'" )
61- return X .assign (intercept = 1 )[["intercept" ] + list (columns )]
62-
63-
6487@dispatch (np .ndarray ) # noqa: F811
6588def lr_prob_stack (prob ): # noqa: F811
6689 return np .vstack ([1 - prob , prob ]).T
0 commit comments