|
| 1 | +""" |
| 2 | +The :mod:`skltemplate.utils.discovery` module includes utilities to discover |
| 3 | +objects (i.e. estimators, displays, functions) from the `skltemplate` package. |
| 4 | +""" |
| 5 | + |
| 6 | +# Adapted from scikit-learn |
| 7 | +# Authors: scikit-learn-contrib developers |
| 8 | +# License: BSD 3 clause |
| 9 | + |
| 10 | +import inspect |
| 11 | +import pkgutil |
| 12 | +from importlib import import_module |
| 13 | +from operator import itemgetter |
| 14 | +from pathlib import Path |
| 15 | + |
| 16 | +from sklearn.base import ( |
| 17 | + BaseEstimator, |
| 18 | + ClassifierMixin, |
| 19 | + ClusterMixin, |
| 20 | + RegressorMixin, |
| 21 | + TransformerMixin, |
| 22 | +) |
| 23 | +from sklearn.utils._testing import ignore_warnings |
| 24 | + |
| 25 | +_MODULE_TO_IGNORE = {"tests"} |
| 26 | + |
| 27 | + |
def all_estimators(type_filter=None):
    """Get a list of all estimators from `skltemplate`.

    This function crawls the module and gets all classes that inherit
    from `BaseEstimator`. Classes that are defined in test-modules are not
    included.

    Parameters
    ----------
    type_filter : {"classifier", "regressor", "cluster", "transformer"} \
            or list of such str, default=None
        Which kind of estimators should be returned. If None, no filter is
        applied and all estimators are returned.  Possible values are
        'classifier', 'regressor', 'cluster' and 'transformer' to get
        estimators only of these specific types, or a list of these to
        get the estimators that fit at least one of the types. Repeating a
        value in the list has no effect.

    Returns
    -------
    estimators : list of tuples
        List of (name, class), where ``name`` is the class name as string
        and ``class`` is the actual type of the class. The list is free of
        duplicates and sorted by name.

    Raises
    ------
    ValueError
        If `type_filter` contains an entry that is not one of the supported
        type names.

    Examples
    --------
    >>> from skltemplate.utils.discovery import all_estimators
    >>> estimators = all_estimators()
    >>> type(estimators)
    <class 'list'>
    """

    def is_abstract(c):
        # A class counts as abstract only when it carries a non-empty
        # ``__abstractmethods__`` attribute.
        return bool(getattr(c, "__abstractmethods__", ()))

    all_classes = []
    root = str(Path(__file__).parent.parent)  # skltemplate package
    # Ignore deprecation warnings triggered at import time and from walking
    # packages
    with ignore_warnings(category=FutureWarning):
        for _, module_name, _ in pkgutil.walk_packages(
            path=[root], prefix="skltemplate."
        ):
            module_parts = module_name.split(".")
            if any(part in _MODULE_TO_IGNORE for part in module_parts):
                continue
            module = import_module(module_name)
            # Keep public classes only (names not starting with "_").
            all_classes.extend(
                (name, est_cls)
                for name, est_cls in inspect.getmembers(module, inspect.isclass)
                if not name.startswith("_")
            )

    # The same class can be reachable from several modules, hence the set.
    # Keep concrete subclasses of `BaseEstimator`, excluding any class that is
    # itself named "BaseEstimator" and any abstract base class.
    estimators = [
        (name, est_cls)
        for name, est_cls in set(all_classes)
        if issubclass(est_cls, BaseEstimator)
        and name != "BaseEstimator"
        and not is_abstract(est_cls)
    ]

    if type_filter is not None:
        if isinstance(type_filter, list):
            requested = list(type_filter)  # copy, the caller's list is untouched
        else:
            requested = [type_filter]
        filters = {
            "classifier": ClassifierMixin,
            "regressor": RegressorMixin,
            "transformer": TransformerMixin,
            "cluster": ClusterMixin,
        }
        # Membership is tested against a tuple (``==`` comparison) so that
        # unhashable entries end up in the ValueError below instead of
        # triggering a TypeError. Only genuinely unknown entries are reported:
        # a valid name given more than once is not an error.
        known_names = tuple(filters)
        unknown = [entry for entry in requested if entry not in known_names]
        if unknown:
            raise ValueError(
                "Parameter type_filter must be 'classifier', "
                "'regressor', 'transformer', 'cluster' or "
                "None, got"
                f" {repr(unknown)}."
            )
        selected_mixins = tuple(
            mixin for name, mixin in filters.items() if name in requested
        )
        # An estimator is kept when it matches at least one requested type.
        estimators = [
            est
            for est in estimators
            if any(issubclass(est[1], mixin) for mixin in selected_mixins)
        ]

    # drop duplicates, sort for reproducibility
    # itemgetter is used to ensure the sort does not extend to the 2nd item of
    # the tuple
    return sorted(set(estimators), key=itemgetter(0))
| 126 | + |
| 127 | + |
def all_displays():
    """Get a list of all displays from `skltemplate`.

    Every module of the package is imported, except those that have a path
    component listed in `_MODULE_TO_IGNORE`, and the public classes whose
    name ends with ``"Display"`` are collected.

    Returns
    -------
    displays : list of tuples
        List of (name, class), where ``name`` is the display class name as
        string and ``class`` is the actual type of the class. The list is
        free of duplicates and sorted by name.

    Examples
    --------
    >>> from skltemplate.utils.discovery import all_displays
    >>> displays = all_displays()
    """

    def _looks_like_display(class_name):
        # Public class whose name carries the "Display" suffix.
        return class_name.endswith("Display") and not class_name.startswith("_")

    package_dir = str(Path(__file__).parent.parent)  # skltemplate package
    found = []
    # FutureWarning raised while importing or walking the package is silenced.
    with ignore_warnings(category=FutureWarning):
        walker = pkgutil.walk_packages(path=[package_dir], prefix="skltemplate.")
        for _, dotted_name, _ in walker:
            skip = False
            for part in dotted_name.split("."):
                if part in _MODULE_TO_IGNORE:
                    skip = True
                    break
            if skip:
                continue
            members = inspect.getmembers(import_module(dotted_name), inspect.isclass)
            for class_name, display_class in members:
                if _looks_like_display(class_name):
                    found.append((class_name, display_class))

    # Deduplicate, then order by name only (classes themselves are not
    # compared).
    unique_displays = set(found)
    return sorted(unique_displays, key=itemgetter(0))
| 163 | + |
| 164 | + |
| 165 | +def _is_checked_function(item): |
| 166 | + if not inspect.isfunction(item): |
| 167 | + return False |
| 168 | + |
| 169 | + if item.__name__.startswith("_"): |
| 170 | + return False |
| 171 | + |
| 172 | + mod = item.__module__ |
| 173 | + if not mod.startswith("skltemplate.") or mod.endswith("estimator_checks"): |
| 174 | + return False |
| 175 | + |
| 176 | + return True |
| 177 | + |
| 178 | + |
def all_functions():
    """Get a list of all functions from `skltemplate`.

    Every module of the package is imported, except those that have a path
    component listed in `_MODULE_TO_IGNORE`, and the members accepted by
    `_is_checked_function` are collected.

    Returns
    -------
    functions : list of tuples
        List of (name, function), where ``name`` is the function name as
        string and ``function`` is the actual function. The list is free of
        duplicates and sorted by name.

    Examples
    --------
    >>> from skltemplate.utils.discovery import all_functions
    >>> functions = all_functions()
    """
    package_dir = str(Path(__file__).parent.parent)  # skltemplate package
    collected = []
    # FutureWarning raised while importing or walking the package is silenced.
    with ignore_warnings(category=FutureWarning):
        walker = pkgutil.walk_packages(path=[package_dir], prefix="skltemplate.")
        for _, dotted_name, _ in walker:
            parts = dotted_name.split(".")
            if any(part in _MODULE_TO_IGNORE for part in parts):
                continue

            members = inspect.getmembers(
                import_module(dotted_name), _is_checked_function
            )
            for member_name, func in members:
                if member_name.startswith("_"):
                    continue
                # The tuple records the function's own ``__name__``, which may
                # differ from the attribute name it is bound to in the module.
                collected.append((func.__name__, func))

    # Deduplicate, then order by name only (functions themselves are not
    # compared).
    unique_functions = set(collected)
    return sorted(unique_functions, key=itemgetter(0))
0 commit comments