Skip to content

Commit 7d15d78

Browse files
committed
ENH make tools to list estimators within the project
1 parent 251b4f9 commit 7d15d78

File tree

11 files changed

+274
-4
lines changed

11 files changed

+274
-4
lines changed

doc/api.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,15 @@ Predictor
3232
:template: class.rst
3333

3434
TemplateClassifier
35+
36+
37+
Utilities
38+
=========
39+
40+
.. autosummary::
41+
:toctree: generated/
42+
:template: functions.rst
43+
44+
utils.discovery.all_estimators
45+
utils.discovery.all_displays
46+
utils.discovery.all_functions

skltemplate/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Authors: scikit-learn-contrib developers
2+
# License: BSD 3 clause
3+
14
from ._template import TemplateClassifier, TemplateEstimator, TemplateTransformer
25
from ._version import __version__
36

skltemplate/_template.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
"""
22
This is a module to be used as a reference for building other modules
33
"""
4+
5+
# Authors: scikit-learn-contrib developers
6+
# License: BSD 3 clause
7+
48
import numpy as np
59
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, _fit_context
610
from sklearn.metrics import euclidean_distances
711
from sklearn.utils.multiclass import check_classification_targets
8-
from sklearn.utils.validation import check_is_fitted
12+
from sklearn.utils.validation import check_is_fitted
913

1014

1115
class TemplateEstimator(BaseEstimator):
@@ -301,4 +305,4 @@ def _more_tags(self):
301305
# https://scikit-learn.org/dev/developers/develop.html#estimator-tags
302306
# Here, our transformer does not do any operation in `fit` and only validate
303307
# the parameters. Thus, it is stateless.
304-
return {'stateless': True}
308+
return {"stateless": True}

skltemplate/_version.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1+
# Authors: scikit-learn-contrib developers
2+
# License: BSD 3 clause
3+
14
__version__ = "0.0.4.dev0"

skltemplate/tests/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Authors: scikit-learn-contrib developers
2+
# License: BSD 3 clause

skltemplate/tests/test_common.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
"""This file shows how to write test based on the scikit-learn common tests."""
22

3+
# Authors: scikit-learn-contrib developers
4+
# License: BSD 3 clause
5+
36
from sklearn.utils.estimator_checks import parametrize_with_checks
47

5-
from skltemplate import TemplateClassifier, TemplateEstimator, TemplateTransformer
8+
from skltemplate.utils.discovery import all_estimators
69

710

811
# parametrize_with_checks allows to get a generator of check that is more fine-grained
912
# than check_estimator
10-
@parametrize_with_checks([TemplateEstimator(), TemplateTransformer(), TemplateClassifier()])
13+
@parametrize_with_checks([est() for _, est in all_estimators()])
1114
def test_estimators(estimator, check, request):
1215
"""Check the compatibility with scikit-learn API"""
1316
check(estimator)

skltemplate/tests/test_template.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
from skltemplate import TemplateClassifier, TemplateEstimator, TemplateTransformer
88

9+
# Authors: scikit-learn-contrib developers
10+
# License: BSD 3 clause
11+
912

1013
@pytest.fixture
1114
def data():

skltemplate/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Authors: scikit-learn-contrib developers
2+
# License: BSD 3 clause

skltemplate/utils/discovery.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
"""
2+
The :mod:`skltemplate.utils.discovery` module includes utilities to discover
3+
objects (i.e. estimators, displays, functions) from the `skltemplate` package.
4+
"""
5+
6+
# Adapted from scikit-learn
7+
# Authors: scikit-learn-contrib developers
8+
# License: BSD 3 clause
9+
10+
import inspect
11+
import pkgutil
12+
from importlib import import_module
13+
from operator import itemgetter
14+
from pathlib import Path
15+
16+
from sklearn.base import (
17+
BaseEstimator,
18+
ClassifierMixin,
19+
ClusterMixin,
20+
RegressorMixin,
21+
TransformerMixin,
22+
)
23+
from sklearn.utils._testing import ignore_warnings
24+
25+
_MODULE_TO_IGNORE = {"tests"}
26+
27+
28+
def all_estimators(type_filter=None):
29+
"""Get a list of all estimators from `skltemplate`.
30+
31+
This function crawls the module and gets all classes that inherit
32+
from `BaseEstimator`. Classes that are defined in test-modules are not
33+
included.
34+
35+
Parameters
36+
----------
37+
type_filter : {"classifier", "regressor", "cluster", "transformer"} \
38+
or list of such str, default=None
39+
Which kind of estimators should be returned. If None, no filter is
40+
applied and all estimators are returned. Possible values are
41+
'classifier', 'regressor', 'cluster' and 'transformer' to get
42+
estimators only of these specific types, or a list of these to
43+
get the estimators that fit at least one of the types.
44+
45+
Returns
46+
-------
47+
estimators : list of tuples
48+
List of (name, class), where ``name`` is the class name as string
49+
and ``class`` is the actual type of the class.
50+
51+
Examples
52+
--------
53+
>>> from skltemplate.utils.discovery import all_estimators
54+
>>> estimators = all_estimators()
55+
>>> type(estimators)
56+
<class 'list'>
57+
"""
58+
59+
def is_abstract(c):
60+
if not (hasattr(c, "__abstractmethods__")):
61+
return False
62+
if not len(c.__abstractmethods__):
63+
return False
64+
return True
65+
66+
all_classes = []
67+
root = str(Path(__file__).parent.parent) # skltemplate package
68+
# Ignore deprecation warnings triggered at import time and from walking
69+
# packages
70+
with ignore_warnings(category=FutureWarning):
71+
for _, module_name, _ in pkgutil.walk_packages(
72+
path=[root], prefix="skltemplate."
73+
):
74+
module_parts = module_name.split(".")
75+
if any(part in _MODULE_TO_IGNORE for part in module_parts):
76+
continue
77+
module = import_module(module_name)
78+
classes = inspect.getmembers(module, inspect.isclass)
79+
classes = [
80+
(name, est_cls) for name, est_cls in classes if not name.startswith("_")
81+
]
82+
83+
all_classes.extend(classes)
84+
85+
all_classes = set(all_classes)
86+
87+
estimators = [
88+
c
89+
for c in all_classes
90+
if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator")
91+
]
92+
# get rid of abstract base classes
93+
estimators = [c for c in estimators if not is_abstract(c[1])]
94+
95+
if type_filter is not None:
96+
if not isinstance(type_filter, list):
97+
type_filter = [type_filter]
98+
else:
99+
type_filter = list(type_filter) # copy
100+
filtered_estimators = []
101+
filters = {
102+
"classifier": ClassifierMixin,
103+
"regressor": RegressorMixin,
104+
"transformer": TransformerMixin,
105+
"cluster": ClusterMixin,
106+
}
107+
for name, mixin in filters.items():
108+
if name in type_filter:
109+
type_filter.remove(name)
110+
filtered_estimators.extend(
111+
[est for est in estimators if issubclass(est[1], mixin)]
112+
)
113+
estimators = filtered_estimators
114+
if type_filter:
115+
raise ValueError(
116+
"Parameter type_filter must be 'classifier', "
117+
"'regressor', 'transformer', 'cluster' or "
118+
"None, got"
119+
f" {repr(type_filter)}."
120+
)
121+
122+
# drop duplicates, sort for reproducibility
123+
# itemgetter is used to ensure the sort does not extend to the 2nd item of
124+
# the tuple
125+
return sorted(set(estimators), key=itemgetter(0))
126+
127+
128+
def all_displays():
129+
"""Get a list of all displays from `skltemplate`.
130+
131+
Returns
132+
-------
133+
displays : list of tuples
134+
List of (name, class), where ``name`` is the display class name as
135+
string and ``class`` is the actual type of the class.
136+
137+
Examples
138+
--------
139+
>>> from skltemplate.utils.discovery import all_displays
140+
>>> displays = all_displays()
141+
"""
142+
all_classes = []
143+
root = str(Path(__file__).parent.parent) # skltemplate package
144+
# Ignore deprecation warnings triggered at import time and from walking
145+
# packages
146+
with ignore_warnings(category=FutureWarning):
147+
for _, module_name, _ in pkgutil.walk_packages(
148+
path=[root], prefix="skltemplate."
149+
):
150+
module_parts = module_name.split(".")
151+
if any(part in _MODULE_TO_IGNORE for part in module_parts):
152+
continue
153+
module = import_module(module_name)
154+
classes = inspect.getmembers(module, inspect.isclass)
155+
classes = [
156+
(name, display_class)
157+
for name, display_class in classes
158+
if not name.startswith("_") and name.endswith("Display")
159+
]
160+
all_classes.extend(classes)
161+
162+
return sorted(set(all_classes), key=itemgetter(0))
163+
164+
165+
def _is_checked_function(item):
166+
if not inspect.isfunction(item):
167+
return False
168+
169+
if item.__name__.startswith("_"):
170+
return False
171+
172+
mod = item.__module__
173+
if not mod.startswith("skltemplate.") or mod.endswith("estimator_checks"):
174+
return False
175+
176+
return True
177+
178+
179+
def all_functions():
180+
"""Get a list of all functions from `skltemplate`.
181+
182+
Returns
183+
-------
184+
functions : list of tuples
185+
List of (name, function), where ``name`` is the function name as
186+
string and ``function`` is the actual function.
187+
188+
Examples
189+
--------
190+
>>> from skltemplate.utils.discovery import all_functions
191+
>>> functions = all_functions()
192+
"""
193+
all_functions = []
194+
root = str(Path(__file__).parent.parent) # skltemplate package
195+
# Ignore deprecation warnings triggered at import time and from walking
196+
# packages
197+
with ignore_warnings(category=FutureWarning):
198+
for _, module_name, _ in pkgutil.walk_packages(
199+
path=[root], prefix="skltemplate."
200+
):
201+
module_parts = module_name.split(".")
202+
if any(part in _MODULE_TO_IGNORE for part in module_parts):
203+
continue
204+
205+
module = import_module(module_name)
206+
functions = inspect.getmembers(module, _is_checked_function)
207+
functions = [
208+
(func.__name__, func)
209+
for name, func in functions
210+
if not name.startswith("_")
211+
]
212+
all_functions.extend(functions)
213+
214+
# drop duplicates, sort for reproducibility
215+
# itemgetter is used to ensure the sort does not extend to the 2nd item of
216+
# the tuple
217+
return sorted(set(all_functions), key=itemgetter(0))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Authors: scikit-learn-contrib developers
2+
# License: BSD 3 clause

0 commit comments

Comments
 (0)