Skip to content

MRG Allow list of strings for type_filter in all_estimators. #3934

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 4, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
from sklearn.utils.testing import ignore_warnings

import sklearn
from sklearn.base import (ClassifierMixin, RegressorMixin,
TransformerMixin, ClusterMixin)
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_classification

Expand Down Expand Up @@ -86,9 +84,7 @@ def test_all_estimators():
def test_estimators_sparse_data():
# All estimators should either deal with sparse data or raise an
# exception with type TypeError and an intelligible error message
estimators = all_estimators()
estimators = [(name, Estimator) for name, Estimator in estimators
if issubclass(Estimator, (ClassifierMixin, RegressorMixin))]
estimators = all_estimators(type_filter=['classifier', 'regressor'])
for name, Estimator in estimators:
yield check_regressors_classifiers_sparse_data, name, Estimator

Expand All @@ -113,12 +109,8 @@ def test_transformers():

def test_estimators_nan_inf():
# Test that all estimators check their input for NaNs and infs
estimators = all_estimators()
estimators = [(name, E) for name, E in estimators
if (issubclass(E, ClassifierMixin) or
issubclass(E, RegressorMixin) or
issubclass(E, TransformerMixin) or
issubclass(E, ClusterMixin))]
estimators = all_estimators(type_filter=['classifier', 'regressor',
'transformer', 'cluster'])
for name, Estimator in estimators:
if name not in CROSS_DECOMPOSITION + ['Imputer']:
yield check_estimators_nan_inf, name, Estimator
Expand Down
56 changes: 30 additions & 26 deletions sklearn/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def assert_not_in(x, container):
except ImportError:
# for Py 2.6
def assert_raises_regex(expected_exception, expected_regexp,
callable_obj=None, *args, **kwargs):
callable_obj=None, *args, **kwargs):
"""Helper function to check for message patterns in exceptions"""

not_raised = False
Expand Down Expand Up @@ -157,7 +157,7 @@ def assert_warns(warning_class, func, *args, **kw):
if hasattr(np, 'VisibleDeprecationWarning'):
# Filter out numpy-specific warnings in numpy >= 1.9
w = [e for e in w
if not e.category is np.VisibleDeprecationWarning]
if e.category is not np.VisibleDeprecationWarning]

# Verify some things
if not len(w) > 0:
Expand Down Expand Up @@ -227,7 +227,7 @@ def assert_warns_message(warning_class, message, func, *args, **kw):
if not check_in_message(msg):
raise AssertionError("The message received ('%s') for <%s> is "
"not the one you expected ('%s')"
% (msg, func.__name__, message
% (msg, func.__name__, message
))
return result

Expand All @@ -246,7 +246,7 @@ def assert_no_warnings(func, *args, **kw):
if hasattr(np, 'VisibleDeprecationWarning'):
# Filter out numpy-specific warnings in numpy >= 1.9
w = [e for e in w
if not e.category is np.VisibleDeprecationWarning]
if e.category is not np.VisibleDeprecationWarning]

if len(w) > 0:
raise AssertionError("Got warnings when calling %s: %s"
Expand Down Expand Up @@ -510,11 +510,12 @@ def all_estimators(include_meta_estimators=False, include_other=False,
include_dont_test : boolean, default=False
Whether to include "special" label estimators or test processors.

type_filter : string or None, default=None
type_filter : string, list of strings, or None, default=None
Which kind of estimators should be returned. If None, no filter is
applied and all estimators are returned. Possible values are
'classifier', 'regressor', 'cluster' and 'transformer' to get
estimators only of these specific types.
estimators only of these specific types, or a list of these to
get the estimators that fit at least one of the types.

Returns
-------
Expand Down Expand Up @@ -556,26 +557,29 @@ def is_abstract(c):
# possibly get rid of meta estimators
if not include_meta_estimators:
estimators = [c for c in estimators if not c[0] in META_ESTIMATORS]

if type_filter == 'classifier':
estimators = [est for est in estimators
if issubclass(est[1], ClassifierMixin)]
elif type_filter == 'regressor':
estimators = [est for est in estimators
if issubclass(est[1], RegressorMixin)]
elif type_filter == 'transformer':
estimators = [est for est in estimators
if issubclass(est[1], TransformerMixin)]
elif type_filter == 'cluster':
estimators = [est for est in estimators
if issubclass(est[1], ClusterMixin)]
elif type_filter is not None:
raise ValueError("Parameter type_filter must be 'classifier', "
"'regressor', 'transformer', 'cluster' or None, got"
" %s." % repr(type_filter))

# We sort in order to have reproducible test failures
return sorted(estimators)
if type_filter is not None:
if not isinstance(type_filter, list):
type_filter = [type_filter]
else:
type_filter = list(type_filter) # copy
filtered_estimators = []
filters = {'classifier': ClassifierMixin,
'regressor': RegressorMixin,
'transformer': TransformerMixin,
'cluster': ClusterMixin}
for name, mixin in filters.items():
if name in type_filter:
type_filter.remove(name)
filtered_estimators.extend([est for est in estimators
if issubclass(est[1], mixin)])
estimators = filtered_estimators
if type_filter:
raise ValueError("Parameter type_filter must be 'classifier', "
"'regressor', 'transformer', 'cluster' or None, got"
" %s." % repr(type_filter))

# drop duplicates, sort for reproducibility
return sorted(set(estimators))


def set_random_state(estimator, random_state=0):
Expand Down