diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 4792f50f2baef..301ba2ffd6776 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -12,8 +12,7 @@ from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.utils import deprecated from sklearn.utils._testing import ( - assert_raises, - assert_raises_regex, + raises, assert_warns, ignore_warnings, MinimalClassifier, @@ -413,7 +412,8 @@ def test_not_an_array_array_function(): raise SkipTest("array_function protocol not supported in numpy <1.17") not_array = _NotAnArray(np.ones(10)) msg = "Don't want to call array_function sum!" - assert_raises_regex(TypeError, msg, np.sum, not_array) + with raises(TypeError, match=msg): + np.sum(not_array) # always returns True assert np.may_share_memory(not_array, None) @@ -437,92 +437,93 @@ def test_check_estimator(): # check that we have a set_params and can clone msg = "Passing a class was deprecated" - assert_raises_regex(TypeError, msg, check_estimator, object) + with raises(TypeError, match=msg): + check_estimator(object) msg = ( "Parameter 'p' of estimator 'HasMutableParameters' is of type " "object which is not allowed" ) # check that the "default_constructible" test checks for mutable parameters check_estimator(HasImmutableParameters()) # should pass - assert_raises_regex( - AssertionError, msg, check_estimator, HasMutableParameters() - ) + with raises(AssertionError, match=msg): + check_estimator(HasMutableParameters()) # check that values returned by get_params match set_params msg = "get_params result does not match what was passed to set_params" - assert_raises_regex(AssertionError, msg, check_estimator, - ModifiesValueInsteadOfRaisingError()) + with raises(AssertionError, match=msg): + check_estimator(ModifiesValueInsteadOfRaisingError()) assert_warns(UserWarning, check_estimator, RaisesErrorInSetParams()) - assert_raises_regex(AssertionError, msg, check_estimator, - ModifiesAnotherValue()) + with raises(AssertionError, match=msg): + check_estimator(ModifiesAnotherValue()) # check that we have a fit method msg = "object has no attribute 'fit'" - assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator()) + with raises(AttributeError, match=msg): + check_estimator(BaseEstimator()) # check that fit does input validation msg = "Did not raise" - assert_raises_regex(AssertionError, msg, check_estimator, - BaseBadClassifier()) + with raises(AssertionError, match=msg): + check_estimator(BaseBadClassifier()) # check that sample_weights in fit accepts pandas.Series type try: from pandas import Series # noqa msg = ("Estimator NoSampleWeightPandasSeriesType raises error if " "'sample_weight' parameter is of type pandas.Series") - assert_raises_regex( - ValueError, msg, check_estimator, NoSampleWeightPandasSeriesType()) + with raises(ValueError, match=msg): + check_estimator(NoSampleWeightPandasSeriesType()) except ImportError: pass # check that predict does input validation (doesn't accept dicts in input) msg = "Estimator doesn't check for NaN and inf in predict" - assert_raises_regex(AssertionError, msg, check_estimator, - NoCheckinPredict()) + with raises(AssertionError, match=msg): + check_estimator(NoCheckinPredict()) # check that estimator state does not change # at transform/predict/predict_proba time msg = 'Estimator changes __dict__ during predict' - assert_raises_regex(AssertionError, msg, check_estimator, ChangesDict()) + with raises(AssertionError, match=msg): + check_estimator(ChangesDict()) # check that `fit` only changes attribures that # are private (start with an _ or end with a _). msg = ('Estimator ChangesWrongAttribute should not change or mutate ' 'the parameter wrong_attribute from 0 to 1 during fit.') - assert_raises_regex(AssertionError, msg, - check_estimator, ChangesWrongAttribute()) + with raises(AssertionError, match=msg): + check_estimator(ChangesWrongAttribute()) check_estimator(ChangesUnderscoreAttribute()) # check that `fit` doesn't add any public attribute msg = (r'Estimator adds public attribute\(s\) during the fit method.' ' Estimators are only allowed to add private attributes' ' either started with _ or ended' ' with _ but wrong_attribute added') - assert_raises_regex(AssertionError, msg, - check_estimator, SetsWrongAttribute()) + with raises(AssertionError, match=msg): + check_estimator(SetsWrongAttribute()) # check for sample order invariance name = NotInvariantSampleOrder.__name__ method = 'predict' msg = ("{method} of {name} is not invariant when applied to a dataset" "with different sample order.").format(method=method, name=name) - assert_raises_regex(AssertionError, msg, - check_estimator, NotInvariantSampleOrder()) + with raises(AssertionError, match=msg): + check_estimator(NotInvariantSampleOrder()) # check for invariant method name = NotInvariantPredict.__name__ method = 'predict' msg = ("{method} of {name} is not invariant when applied " "to a subset.").format(method=method, name=name) - assert_raises_regex(AssertionError, msg, - check_estimator, NotInvariantPredict()) + with raises(AssertionError, match=msg): + check_estimator(NotInvariantPredict()) # check for sparse matrix input handling name = NoSparseClassifier.__name__ msg = "Estimator %s doesn't seem to fail gracefully on sparse data" % name - assert_raises_regex( - AssertionError, msg, check_estimator, NoSparseClassifier() - ) + with raises(AssertionError, match=msg): + check_estimator(NoSparseClassifier()) # Large indices test on bad estimator msg = ('Estimator LargeSparseNotSupportedClassifier doesn\'t seem to ' r'support \S{3}_64 matrix, and is not failing gracefully.*') - assert_raises_regex(AssertionError, msg, check_estimator, - LargeSparseNotSupportedClassifier()) + with raises(AssertionError, match=msg): + check_estimator(LargeSparseNotSupportedClassifier()) # does error on binary_only untagged estimator msg = 'Only 2 classes are supported' - assert_raises_regex(ValueError, msg, check_estimator, - UntaggedBinaryClassifier()) + with raises(ValueError, match=msg): + check_estimator(UntaggedBinaryClassifier()) # non-regression test for estimators transforming to sparse data check_estimator(SparseTransformer()) @@ -537,8 +538,8 @@ def test_check_estimator(): # Check regressor with requires_positive_y estimator tag msg = 'negative y values not supported!' - assert_raises_regex(ValueError, msg, check_estimator, - RequiresPositiveYRegressor()) + with raises(ValueError, match=msg): + check_estimator(RequiresPositiveYRegressor()) # Does not raise error on classifier with poor_score tag check_estimator(PoorScoreLogisticRegression()) @@ -547,7 +548,8 @@ def test_check_estimator(): def test_check_outlier_corruption(): # should raise AssertionError decision = np.array([0., 1., 1.5, 2.]) - assert_raises(AssertionError, check_outlier_corruption, 1, 2, decision) + with raises(AssertionError): + check_outlier_corruption(1, 2, decision) # should pass decision = np.array([0., 1., 1., 2.]) check_outlier_corruption(1, 2, decision) @@ -555,8 +557,8 @@ def test_check_outlier_corruption(): def test_check_estimator_transformer_no_mixin(): # check that TransformerMixin is not required for transformer tests to run - assert_raises_regex(AttributeError, '.*fit_transform.*', - check_estimator, BadTransformerWithoutMixin()) + with raises(AttributeError, '.*fit_transform.*'): + check_estimator(BadTransformerWithoutMixin()) def test_check_estimator_clones(): @@ -593,8 +595,8 @@ def test_check_estimators_unfitted(): # check that a ValueError/AttributeError is raised when calling predict # on an unfitted estimator msg = "Did not raise" - assert_raises_regex(AssertionError, msg, check_estimators_unfitted, - "estimator", NoSparseClassifier()) + with raises(AssertionError, match=msg): + check_estimators_unfitted("estimator", NoSparseClassifier()) # check that CorrectNotFittedError inherit from either ValueError # or AttributeError @@ -610,19 +612,22 @@ class NonConformantEstimatorNoParamSet(BaseEstimator): def __init__(self, you_should_set_this_=None): pass - assert_raises_regex(AssertionError, - "Estimator estimator_name should not set any" - " attribute apart from parameters during init." - r" Found attributes \['you_should_not_set_this_'\].", - check_no_attributes_set_in_init, - 'estimator_name', - NonConformantEstimatorPrivateSet()) - assert_raises_regex(AttributeError, - "Estimator estimator_name should store all " - "parameters as an attribute during init.", - check_no_attributes_set_in_init, - 'estimator_name', - NonConformantEstimatorNoParamSet()) + msg = ( + "Estimator estimator_name should not set any" + " attribute apart from parameters during init." + r" Found attributes \['you_should_not_set_this_'\]." + ) + with raises(AssertionError, match=msg): + check_no_attributes_set_in_init('estimator_name', + NonConformantEstimatorPrivateSet()) + + msg = ( + "Estimator estimator_name should store all parameters as an attribute" + " during init" + ) + with raises(AttributeError, match=msg): + check_no_attributes_set_in_init('estimator_name', + NonConformantEstimatorNoParamSet()) def test_check_estimator_pairwise(): @@ -639,32 +644,24 @@ def test_check_estimator_pairwise(): def test_check_classifier_data_not_an_array(): - assert_raises_regex(AssertionError, - 'Not equal to tolerance', - check_classifier_data_not_an_array, - 'estimator_name', - EstimatorInconsistentForPandas()) + with raises(AssertionError, match='Not equal to tolerance'): + check_classifier_data_not_an_array('estimator_name', + EstimatorInconsistentForPandas()) def test_check_regressor_data_not_an_array(): - assert_raises_regex(AssertionError, - 'Not equal to tolerance', - check_regressor_data_not_an_array, - 'estimator_name', - EstimatorInconsistentForPandas()) + with raises(AssertionError, match='Not equal to tolerance'): + check_regressor_data_not_an_array('estimator_name', + EstimatorInconsistentForPandas()) def test_check_estimator_get_tags_default_keys(): estimator = EstimatorMissingDefaultTags() err_msg = (r"EstimatorMissingDefaultTags._get_tags\(\) is missing entries" r" for the following default tags: {'allow_nan'}") - assert_raises_regex( - AssertionError, - err_msg, - check_estimator_get_tags_default_keys, - estimator.__class__.__name__, - estimator, - ) + with raises(AssertionError, match=err_msg): + check_estimator_get_tags_default_keys(estimator.__class__.__name__, + estimator) # noop check when _get_tags is not available estimator = MinimalTransformer() @@ -688,12 +685,15 @@ def run_tests_without_pytest(): def test_check_class_weight_balanced_linear_classifier(): # check that ill-computed balanced weights raises an exception - assert_raises_regex(AssertionError, - "Classifier estimator_name is not computing" - " class_weight=balanced properly.", - check_class_weight_balanced_linear_classifier, - 'estimator_name', - BadBalancedWeightsClassifier) + msg = ( + "Classifier estimator_name is not computing class_weight=balanced " + "properly" + ) + with raises(AssertionError, match=msg): + check_class_weight_balanced_linear_classifier( + 'estimator_name', + BadBalancedWeightsClassifier + ) def test_all_estimators_all_public():