Skip to content

[MRG+1] make more explicit which checks are run #7317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 5, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
check_class_weight_balanced_linear_classifier,
check_transformer_n_iter,
check_non_transformer_estimators_n_iter,
check_get_params_invariance)
check_get_params_invariance,
_set_test_name)


def test_all_estimator_no_base_class():
Expand All @@ -55,7 +56,8 @@ def test_all_estimators():

for name, Estimator in estimators:
# some can just not be sensibly default constructed
yield check_parameters_default_constructible, name, Estimator
yield (_set_test_name(check_parameters_default_constructible, name),
name, Estimator)


def test_non_meta_estimators():
Expand All @@ -70,9 +72,9 @@ def test_non_meta_estimators():
if issubclass(Estimator, ProjectedGradientNMF):
# The ProjectedGradientNMF class is deprecated
with ignore_warnings():
yield check, name, Estimator
yield _set_test_name(check, name), name, Estimator
else:
yield check, name, Estimator
yield _set_test_name(check, name), name, Estimator


def test_configure():
Expand Down Expand Up @@ -114,7 +116,8 @@ def test_class_weight_balanced_linear_classifiers():
issubclass(clazz, LinearClassifierMixin))]

for name, Classifier in linear_classifiers:
yield check_class_weight_balanced_linear_classifier, name, Classifier
yield _set_test_name(check_class_weight_balanced_linear_classifier,
name), name, Classifier


@ignore_warnings
Expand Down Expand Up @@ -196,8 +199,9 @@ def test_non_transformer_estimators_n_iter():
else:
# Multitask models related to ENet cannot handle
# if y is mono-output.
yield (check_non_transformer_estimators_n_iter,
name, estimator, 'Multi' in name)
yield (_set_test_name(
check_non_transformer_estimators_n_iter, name),
name, estimator, 'Multi' in name)


def test_transformer_n_iter():
Expand All @@ -218,9 +222,12 @@ def test_transformer_n_iter():
if isinstance(estimator, ProjectedGradientNMF):
# The ProjectedGradientNMF class is deprecated
with ignore_warnings():
yield check_transformer_n_iter, name, estimator
yield _set_test_name(
check_transformer_n_iter, name), name, estimator
else:
yield check_transformer_n_iter, name, estimator
yield _set_test_name(
check_transformer_n_iter, name), name, estimator


def test_get_params_invariance():
# Test for estimators that support get_params, that
Expand All @@ -234,6 +241,8 @@ def test_get_params_invariance():
# If class is deprecated, ignore deprecated warnings
if hasattr(Estimator.__init__, "deprecated_original"):
with ignore_warnings():
yield check_get_params_invariance, name, Estimator
yield _set_test_name(
check_get_params_invariance, name), name, Estimator
else:
yield check_get_params_invariance, name, Estimator
yield _set_test_name(
check_get_params_invariance, name), name, Estimator
6 changes: 6 additions & 0 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@
"GradientBoostingClassifier", "GradientBoostingRegressor"]


def _set_test_name(function, name):
function.description = ("sklearn.tests.test_common.{0}({1})".format(
function.__name__, name))
return function


def _yield_non_meta_checks(name, Estimator):
yield check_estimators_dtypes
yield check_fit_score_takes_y
Expand Down