Skip to content

Commit d49a7db

Browse files
committed
add option flags for shuffling feature order
1 parent 058d66a commit d49a7db

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

docs/benchmarks/ebm-benchmark.ipynb

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@
3434
"source": [
3535
"# install powerlift if not already installed\n",
3636
"\n",
37-
"# !! IMPORTANT !! : until the next release, install locally with \"pip install -e .[datasets,postgres]\" from powerlift directory\n",
38-
"\n",
3937
"try:\n",
4038
" import powerlift\n",
4139
"except ModuleNotFoundError:\n",

python/interpret-core/interpret/develop.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,15 @@
66
_current_module = sys.modules[__name__]
77
_current_module.is_debug_mode = False
88

9+
# Global options
910
_purify_boosting = False
1011
_purify_result = False
12+
_randomize_initial_feature_order = True
13+
# TODO: investigate if _randomize_feature_order actually decreases accuracy
14+
# https://github.com/interpretml/interpret/issues/563#issuecomment-2240820952
15+
# this seems to decrease accuracy slightly, but helps with collinearity
16+
_randomize_greedy_feature_order = True # randomize feature order only if greedy enabled
17+
_randomize_feature_order = False # randomize feature order always
1118

1219

1320
def print_debug_info(file=None):

python/interpret-core/interpret/glassbox/_ebm/_boost.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import heapq
99
from ...utils._native import Native
10+
from ... import develop
1011

1112
import logging
1213

@@ -74,7 +75,7 @@ def boost(
7475
_log.info("Start boosting")
7576
native = Native.get_native_singleton()
7677
nominals = native.extract_nominals(dataset)
77-
random_cyclic_ordering = np.empty(len(term_features), np.int64)
78+
random_cyclic_ordering = np.arange(len(term_features), dtype=np.int64)
7879

7980
while step_idx < max_steps:
8081
term_boost_flags_local = term_boost_flags
@@ -85,7 +86,13 @@ def boost(
8586
bestkey = None
8687
heap = []
8788
# shuffling is governed by the develop._randomize_* flags; with their defaults, pure cyclical boosting only randomizes at the start
88-
if 0 < greedy_steps or step_idx == 0:
89+
if (
90+
step_idx == 0
91+
and develop._randomize_initial_feature_order
92+
or develop._randomize_greedy_feature_order
93+
and 0 < greedy_steps
94+
or develop._randomize_feature_order
95+
):
8996
# TODO: test if shuffling during pure cyclic is better
9097
native.shuffle(rng, random_cyclic_ordering)
9198

0 commit comments

Comments
 (0)