Skip to content

TST Changes assert_raises to raises in sklearn/utils/test_estimator_checks.py #20138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
May 27, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 77 additions & 77 deletions sklearn/utils/tests/test_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils import deprecated
from sklearn.utils._testing import (
assert_raises,
assert_raises_regex,
raises,
assert_warns,
ignore_warnings,
MinimalClassifier,
Expand Down Expand Up @@ -413,7 +412,8 @@ def test_not_an_array_array_function():
raise SkipTest("array_function protocol not supported in numpy <1.17")
not_array = _NotAnArray(np.ones(10))
msg = "Don't want to call array_function sum!"
assert_raises_regex(TypeError, msg, np.sum, not_array)
with raises(TypeError, match=msg):
np.sum(not_array)
# always returns True
assert np.may_share_memory(not_array, None)

Expand All @@ -437,92 +437,93 @@ def test_check_estimator():

# check that we have a set_params and can clone
msg = "Passing a class was deprecated"
assert_raises_regex(TypeError, msg, check_estimator, object)
with raises(TypeError, match=msg):
check_estimator(object)
msg = (
"Parameter 'p' of estimator 'HasMutableParameters' is of type "
"object which is not allowed"
)
# check that the "default_constructible" test checks for mutable parameters
check_estimator(HasImmutableParameters()) # should pass
assert_raises_regex(
AssertionError, msg, check_estimator, HasMutableParameters()
)
with raises(AssertionError, match=msg):
check_estimator(HasMutableParameters())
# check that values returned by get_params match set_params
msg = "get_params result does not match what was passed to set_params"
assert_raises_regex(AssertionError, msg, check_estimator,
ModifiesValueInsteadOfRaisingError())
with raises(AssertionError, match=msg):
check_estimator(ModifiesValueInsteadOfRaisingError())
assert_warns(UserWarning, check_estimator, RaisesErrorInSetParams())
assert_raises_regex(AssertionError, msg, check_estimator,
ModifiesAnotherValue())
with raises(AssertionError, match=msg):
check_estimator(ModifiesAnotherValue())
# check that we have a fit method
msg = "object has no attribute 'fit'"
assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator())
with raises(AttributeError, match=msg):
check_estimator(BaseEstimator())
# check that fit does input validation
msg = "Did not raise"
assert_raises_regex(AssertionError, msg, check_estimator,
BaseBadClassifier())
with raises(AssertionError, match=msg):
check_estimator(BaseBadClassifier())
# check that sample_weights in fit accepts pandas.Series type
try:
from pandas import Series # noqa
msg = ("Estimator NoSampleWeightPandasSeriesType raises error if "
"'sample_weight' parameter is of type pandas.Series")
assert_raises_regex(
ValueError, msg, check_estimator, NoSampleWeightPandasSeriesType())
with raises(ValueError, match=msg):
check_estimator(NoSampleWeightPandasSeriesType())
except ImportError:
pass
# check that predict does input validation (rejects NaN and inf in input)
msg = "Estimator doesn't check for NaN and inf in predict"
assert_raises_regex(AssertionError, msg, check_estimator,
NoCheckinPredict())
with raises(AssertionError, match=msg):
check_estimator(NoCheckinPredict())
# check that estimator state does not change
# at transform/predict/predict_proba time
msg = 'Estimator changes __dict__ during predict'
assert_raises_regex(AssertionError, msg, check_estimator, ChangesDict())
with raises(AssertionError, match=msg):
check_estimator(ChangesDict())
# check that `fit` only changes attributes that
# are private (start with an _ or end with a _).
msg = ('Estimator ChangesWrongAttribute should not change or mutate '
'the parameter wrong_attribute from 0 to 1 during fit.')
assert_raises_regex(AssertionError, msg,
check_estimator, ChangesWrongAttribute())
with raises(AssertionError, match=msg):
check_estimator(ChangesWrongAttribute())
check_estimator(ChangesUnderscoreAttribute())
# check that `fit` doesn't add any public attribute
msg = (r'Estimator adds public attribute\(s\) during the fit method.'
' Estimators are only allowed to add private attributes'
' either started with _ or ended'
' with _ but wrong_attribute added')
assert_raises_regex(AssertionError, msg,
check_estimator, SetsWrongAttribute())
with raises(AssertionError, match=msg):
check_estimator(SetsWrongAttribute())
# check for sample order invariance
name = NotInvariantSampleOrder.__name__
method = 'predict'
msg = ("{method} of {name} is not invariant when applied to a dataset"
"with different sample order.").format(method=method, name=name)
assert_raises_regex(AssertionError, msg,
check_estimator, NotInvariantSampleOrder())
with raises(AssertionError, match=msg):
check_estimator(NotInvariantSampleOrder())
# check for invariant method
name = NotInvariantPredict.__name__
method = 'predict'
msg = ("{method} of {name} is not invariant when applied "
"to a subset.").format(method=method, name=name)
assert_raises_regex(AssertionError, msg,
check_estimator, NotInvariantPredict())
with raises(AssertionError, match=msg):
check_estimator(NotInvariantPredict())
# check for sparse matrix input handling
name = NoSparseClassifier.__name__
msg = "Estimator %s doesn't seem to fail gracefully on sparse data" % name
assert_raises_regex(
AssertionError, msg, check_estimator, NoSparseClassifier()
)
with raises(AssertionError, match=msg):
check_estimator(NoSparseClassifier())

# Large indices test on bad estimator
msg = ('Estimator LargeSparseNotSupportedClassifier doesn\'t seem to '
r'support \S{3}_64 matrix, and is not failing gracefully.*')
assert_raises_regex(AssertionError, msg, check_estimator,
LargeSparseNotSupportedClassifier())
with raises(AssertionError, match=msg):
check_estimator(LargeSparseNotSupportedClassifier())

# does error on binary_only untagged estimator
msg = 'Only 2 classes are supported'
assert_raises_regex(ValueError, msg, check_estimator,
UntaggedBinaryClassifier())
with raises(ValueError, match=msg):
check_estimator(UntaggedBinaryClassifier())

# non-regression test for estimators transforming to sparse data
check_estimator(SparseTransformer())
Expand All @@ -537,8 +538,8 @@ def test_check_estimator():

# Check regressor with requires_positive_y estimator tag
msg = 'negative y values not supported!'
assert_raises_regex(ValueError, msg, check_estimator,
RequiresPositiveYRegressor())
with raises(ValueError, match=msg):
check_estimator(RequiresPositiveYRegressor())

# Does not raise error on classifier with poor_score tag
check_estimator(PoorScoreLogisticRegression())
Expand All @@ -547,16 +548,17 @@ def test_check_estimator():
def test_check_outlier_corruption():
# should raise AssertionError
decision = np.array([0., 1., 1.5, 2.])
assert_raises(AssertionError, check_outlier_corruption, 1, 2, decision)
with raises(AssertionError):
check_outlier_corruption(1, 2, decision)
# should pass
decision = np.array([0., 1., 1., 2.])
check_outlier_corruption(1, 2, decision)


def test_check_estimator_transformer_no_mixin():
# check that TransformerMixin is not required for transformer tests to run
assert_raises_regex(AttributeError, '.*fit_transform.*',
check_estimator, BadTransformerWithoutMixin())
with raises(AttributeError, '.*fit_transform.*'):
check_estimator(BadTransformerWithoutMixin())


def test_check_estimator_clones():
Expand Down Expand Up @@ -593,8 +595,8 @@ def test_check_estimators_unfitted():
# check that a ValueError/AttributeError is raised when calling predict
# on an unfitted estimator
msg = "Did not raise"
assert_raises_regex(AssertionError, msg, check_estimators_unfitted,
"estimator", NoSparseClassifier())
with raises(AssertionError, match=msg):
check_estimators_unfitted("estimator", NoSparseClassifier())

# check that CorrectNotFittedError inherits from either ValueError
# or AttributeError
Expand All @@ -610,19 +612,22 @@ class NonConformantEstimatorNoParamSet(BaseEstimator):
def __init__(self, you_should_set_this_=None):
pass

assert_raises_regex(AssertionError,
"Estimator estimator_name should not set any"
" attribute apart from parameters during init."
r" Found attributes \['you_should_not_set_this_'\].",
check_no_attributes_set_in_init,
'estimator_name',
NonConformantEstimatorPrivateSet())
assert_raises_regex(AttributeError,
"Estimator estimator_name should store all "
"parameters as an attribute during init.",
check_no_attributes_set_in_init,
'estimator_name',
NonConformantEstimatorNoParamSet())
msg = (
"Estimator estimator_name should not set any"
" attribute apart from parameters during init."
r" Found attributes \['you_should_not_set_this_'\]."
)
with raises(AssertionError, match=msg):
check_no_attributes_set_in_init('estimator_name',
NonConformantEstimatorPrivateSet())

msg = (
"Estimator estimator_name should store all parameters as an attribute"
" during init"
)
with raises(AttributeError, match=msg):
check_no_attributes_set_in_init('estimator_name',
NonConformantEstimatorNoParamSet())


def test_check_estimator_pairwise():
Expand All @@ -639,32 +644,24 @@ def test_check_estimator_pairwise():


def test_check_classifier_data_not_an_array():
assert_raises_regex(AssertionError,
'Not equal to tolerance',
check_classifier_data_not_an_array,
'estimator_name',
EstimatorInconsistentForPandas())
with raises(AssertionError, match='Not equal to tolerance'):
check_classifier_data_not_an_array('estimator_name',
EstimatorInconsistentForPandas())


def test_check_regressor_data_not_an_array():
assert_raises_regex(AssertionError,
'Not equal to tolerance',
check_regressor_data_not_an_array,
'estimator_name',
EstimatorInconsistentForPandas())
with raises(AssertionError, match='Not equal to tolerance'):
check_regressor_data_not_an_array('estimator_name',
EstimatorInconsistentForPandas())


def test_check_estimator_get_tags_default_keys():
estimator = EstimatorMissingDefaultTags()
err_msg = (r"EstimatorMissingDefaultTags._get_tags\(\) is missing entries"
r" for the following default tags: {'allow_nan'}")
assert_raises_regex(
AssertionError,
err_msg,
check_estimator_get_tags_default_keys,
estimator.__class__.__name__,
estimator,
)
with raises(AssertionError, match=err_msg):
check_estimator_get_tags_default_keys(estimator.__class__.__name__,
estimator)

# noop check when _get_tags is not available
estimator = MinimalTransformer()
Expand All @@ -688,12 +685,15 @@ def run_tests_without_pytest():

def test_check_class_weight_balanced_linear_classifier():
# check that ill-computed balanced weights raises an exception
assert_raises_regex(AssertionError,
"Classifier estimator_name is not computing"
" class_weight=balanced properly.",
check_class_weight_balanced_linear_classifier,
'estimator_name',
BadBalancedWeightsClassifier)
msg = (
"Classifier estimator_name is not computing class_weight=balanced "
"properly"
)
with raises(AssertionError, match=msg):
check_class_weight_balanced_linear_classifier(
'estimator_name',
BadBalancedWeightsClassifier
)


def test_all_estimators_all_public():
Expand Down