Skip to content

Commit 64158be

Browse files
committed
add public gain_scale parameter
1 parent 8d2330b commit 64158be

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

python/interpret-core/interpret/develop.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
"min_samples_leaf_nominal": None,
1919
"max_cat_threshold": 9223372036854775807,
2020
"cat_include": 1.0,
21-
"cat_scale": 1.0,
2221
"purify_boosting": False,
2322
"purify_result": False,
2423
"randomize_initial_feature_order": True,

python/interpret-core/interpret/glassbox/_ebm/_boost.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def boost(
2828
reg_alpha,
2929
reg_lambda,
3030
max_delta_step,
31+
gain_scale,
3132
min_cat_samples,
3233
cat_smooth,
3334
missing,
@@ -198,7 +199,7 @@ def boost(
198199

199200
if contains_nominals and len(term_features[term_idx]) == 1:
200201
# penalize nominals a bit because they benefit from sorting categories
201-
avg_gain *= develop.get_option("cat_scale")
202+
avg_gain *= gain_scale
202203

203204
gainkey = (-avg_gain, native.generate_seed(rng), term_idx)
204205
if not make_progress:

python/interpret-core/interpret/glassbox/_ebm/_ebm.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ def __init__(
361361
reg_alpha,
362362
reg_lambda,
363363
max_delta_step,
364+
gain_scale,
364365
min_cat_samples,
365366
cat_smooth,
366367
missing,
@@ -411,6 +412,7 @@ def __init__(
411412
self.reg_alpha = reg_alpha
412413
self.reg_lambda = reg_lambda
413414
self.max_delta_step = max_delta_step
415+
self.gain_scale = gain_scale
414416
self.min_cat_samples = min_cat_samples
415417
self.cat_smooth = cat_smooth
416418
self.missing = missing
@@ -942,6 +944,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
942944
reg_alpha = 0.0
943945
reg_lambda = 0.0
944946
max_delta_step = 0.0
947+
gain_scale = 1.0
945948
min_cat_samples = 0
946949
cat_smooth = 0.0
947950
missing = "low"
@@ -965,6 +968,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
965968
reg_alpha = self.reg_alpha
966969
reg_lambda = self.reg_lambda
967970
max_delta_step = self.max_delta_step
971+
gain_scale = self.gain_scale
968972
min_cat_samples = self.min_cat_samples
969973
cat_smooth = self.cat_smooth
970974
missing = self.missing
@@ -1084,6 +1088,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
10841088
reg_alpha,
10851089
reg_lambda,
10861090
max_delta_step,
1091+
gain_scale,
10871092
min_cat_samples,
10881093
cat_smooth,
10891094
missing,
@@ -1359,6 +1364,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
13591364
reg_alpha,
13601365
reg_lambda,
13611366
max_delta_step,
1367+
gain_scale,
13621368
min_cat_samples,
13631369
cat_smooth,
13641370
missing,
@@ -1486,6 +1492,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
14861492
0.0,
14871493
0.0,
14881494
0.0,
1495+
1.0,
14891496
min_cat_samples,
14901497
cat_smooth,
14911498
missing,
@@ -2785,6 +2792,9 @@ class ExplainableBoostingClassifier(ClassifierMixin, EBMModel):
27852792
L2 regularization.
27862793
max_delta_step : float, default=0.0
27872794
Used to limit the max output of tree leaves. <=0.0 means no constraint.
2795+
gain_scale : float, default=1.0
2796+
Scale factor to apply to nominal categoricals. A scale factor above 1.0 will cause the
2797+
algorithm focus more on the nominal categoricals.
27882798
min_cat_samples : int, default=10
27892799
Minimum number of samples in order to treat a category separately. If lower than this threshold
27902800
the category is combined with other categories that have low numbers of samples.
@@ -2964,6 +2974,7 @@ def __init__(
29642974
reg_alpha: Optional[float] = 0.0,
29652975
reg_lambda: Optional[float] = 0.0,
29662976
max_delta_step: Optional[float] = 0.0,
2977+
gain_scale: Optional[float] = 1.0,
29672978
min_cat_samples: Optional[int] = 10,
29682979
cat_smooth: Optional[float] = 10.0,
29692980
missing: str = "separate",
@@ -2997,6 +3008,7 @@ def __init__(
29973008
reg_alpha=reg_alpha,
29983009
reg_lambda=reg_lambda,
29993010
max_delta_step=max_delta_step,
3011+
gain_scale=gain_scale,
30003012
min_cat_samples=min_cat_samples,
30013013
cat_smooth=cat_smooth,
30023014
missing=missing,
@@ -3167,6 +3179,9 @@ class ExplainableBoostingRegressor(RegressorMixin, EBMModel):
31673179
L2 regularization.
31683180
max_delta_step : float, default=0.0
31693181
Used to limit the max output of tree leaves. <=0.0 means no constraint.
3182+
gain_scale : float, default=1.0
3183+
Scale factor to apply to nominal categoricals. A scale factor above 1.0 will cause the
3184+
algorithm focus more on the nominal categoricals.
31703185
min_cat_samples : int, default=10
31713186
Minimum number of samples in order to treat a category separately. If lower than this threshold
31723187
the category is combined with other categories that have low numbers of samples.
@@ -3346,6 +3361,7 @@ def __init__(
33463361
reg_alpha: Optional[float] = 0.0,
33473362
reg_lambda: Optional[float] = 0.0,
33483363
max_delta_step: Optional[float] = 0.0,
3364+
gain_scale: Optional[float] = 1.0,
33493365
min_cat_samples: Optional[int] = 10,
33503366
cat_smooth: Optional[float] = 10.0,
33513367
missing: str = "separate",
@@ -3379,6 +3395,7 @@ def __init__(
33793395
reg_alpha=reg_alpha,
33803396
reg_lambda=reg_lambda,
33813397
max_delta_step=max_delta_step,
3398+
gain_scale=gain_scale,
33823399
min_cat_samples=min_cat_samples,
33833400
cat_smooth=cat_smooth,
33843401
missing=missing,
@@ -3615,6 +3632,7 @@ def __init__(
36153632
reg_alpha=0.0,
36163633
reg_lambda=0.0,
36173634
max_delta_step=0.0,
3635+
gain_scale=1.0,
36183636
min_cat_samples=0,
36193637
cat_smooth=0.0,
36203638
missing=None,
@@ -3896,6 +3914,7 @@ def __init__(
38963914
reg_alpha=0.0,
38973915
reg_lambda=0.0,
38983916
max_delta_step=0.0,
3917+
gain_scale=1.0,
38993918
min_cat_samples=0,
39003919
cat_smooth=0.0,
39013920
missing=None,

0 commit comments

Comments
 (0)