diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 4cc3f169f63c0..795a8a7708cbe 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -67,17 +67,6 @@ CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD'] -def _assert_raises(exp_type, strict_mode, func, *args, msg=None): - """Assert exp_type is raised when calling func(*args) - - The error message is validated if strict mode is True. - """ - if strict_mode: - assert_raises_regex(exp_type, msg, func, *args) - else: - assert_raises(exp_type, func, *args) - - def _yield_checks(estimator): name = estimator.__class__.__name__ tags = estimator._get_tags() @@ -168,9 +157,18 @@ def check_supervised_y_no_nan(name, estimator_orig, strict_mode=True): y = np.full(10, np.inf) y = _enforce_estimator_tags_y(estimator, y) - errmsg = ("Input contains NaN, infinity or a value too large for " - r"dtype\('float64'\).") - _assert_raises(ValueError, strict_mode, estimator.fit, X, y, msg=errmsg) + errmsg = "Input contains NaN, infinity or a value too large for " \ + "dtype('float64')." + try: + estimator.fit(X, y) + except ValueError as e: + if str(e) != errmsg: + raise ValueError("Estimator {0} raised error as expected, but " + "does not match expected error message" + .format(name)) + else: + raise ValueError("Estimator {0} should have raised error on fitting " + "array y with NaN value.".format(name)) def _yield_regressor_checks(regressor): @@ -977,17 +975,16 @@ def check_dtype_object(name, estimator_orig, strict_mode=True): if hasattr(estimator, "transform"): estimator.transform(X) - msg = "Unknown label type" try: estimator.fit(X, y.astype(object)) except Exception as e: - if strict_mode and "Unknown label type" not in str(e): + if "Unknown label type" not in str(e): raise if 'string' not in tags['X_types']: X[0, 0] = {'foo': 'bar'} msg = "argument must be a string.* number" - _assert_raises(TypeError, strict_mode, estimator.fit, X, y, msg=msg) + assert_raises_regex(TypeError, msg, estimator.fit, X, y) else: # Estimators supporting string will not call np.asarray to convert the # data to numeric and therefore, the error will not be raised. @@ -1002,8 +999,8 @@ def check_complex_data(name, estimator_orig, strict_mode=True): X = X.reshape(-1, 1) y = np.random.sample(10) + 1j * np.random.sample(10) estimator = clone(estimator_orig) - msg = "Complex data not supported" - _assert_raises(ValueError, strict_mode, estimator.fit, X, y, msg=msg) + assert_raises_regex(ValueError, "Complex data not supported", + estimator.fit, X, y) @ignore_warnings @@ -1209,14 +1206,12 @@ def check_fit2d_1sample(name, estimator_orig, strict_mode=True): msgs = ["1 sample", "n_samples = 1", "n_samples=1", "one sample", "1 class", "one class"] + try: estimator.fit(X, y) except ValueError as e: - if strict_mode and all(msg not in repr(e) for msg in msgs): - raise AssertionError( - "The error message should contain one of the following " - f"patterns: {', '.join(msgs)}." - ) from e + if all(msg not in repr(e) for msg in msgs): + raise e @ignore_warnings @@ -1245,14 +1240,12 @@ def check_fit2d_1feature(name, estimator_orig, strict_mode=True): set_random_state(estimator, 1) msgs = ["1 feature(s)", "n_features = 1", "n_features=1"] + try: estimator.fit(X, y) except ValueError as e: - if strict_mode and all(msg not in repr(e) for msg in msgs): - raise AssertionError( - "The error message should contain one of the following " - f"patterns: {', '.join(msgs)}." - ) from e + if all(msg not in repr(e) for msg in msgs): + raise e @ignore_warnings @@ -1533,7 +1526,7 @@ def check_estimators_empty_data_messages(name, estimator_orig, y = _enforce_estimator_tags_y(e, np.array([1, 0, 1])) msg = (r"0 feature\(s\) \(shape=\(3, 0\)\) while a minimum of \d* " "is required.") - _assert_raises(ValueError, strict_mode, e.fit, X_zero_features, y, msg=msg) + assert_raises_regex(ValueError, msg, e.fit, X_zero_features, y) @ignore_warnings(category=FutureWarning) @@ -2928,9 +2921,9 @@ def check_classifiers_regression_target(name, estimator_orig, X = X + 1 + abs(X.min(axis=0)) # be sure that X is non-negative e = clone(estimator_orig) + msg = 'Unknown label type: ' if not e._get_tags()["no_validation"]: - msg = 'Unknown label type: ' - _assert_raises(ValueError, strict_mode, e.fit, X, y, msg=msg) + assert_raises_regex(ValueError, msg, e.fit, X, y) @ignore_warnings(category=FutureWarning) @@ -3009,8 +3002,11 @@ def check_fit_non_negative(name, estimator_orig, strict_mode=True): X = np.array([[-1., 1], [-1., 1]]) y = np.array([1, 2]) estimator = clone(estimator_orig) - msg = "Negative values in data passed to" - _assert_raises(ValueError, strict_mode, estimator.fit, X, y, msg=msg) + if strict_mode: + assert_raises_regex(ValueError, "Negative values in data passed to", + estimator.fit, X, y) + else: # Don't check error message if strict mode is off + assert_raises(ValueError, estimator.fit, X, y) def check_fit_idempotent(name, estimator_orig, strict_mode=True): diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 149e1e9288535..fc42329c94933 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -23,8 +23,6 @@ from sklearn.utils.estimator_checks import check_no_attributes_set_in_init from sklearn.utils.estimator_checks import check_classifier_data_not_an_array from sklearn.utils.estimator_checks import check_regressor_data_not_an_array -from sklearn.utils.estimator_checks import check_fit2d_1sample -from sklearn.utils.estimator_checks import check_fit2d_1feature from sklearn.utils.validation import check_is_fitted from sklearn.utils.estimator_checks import check_outlier_corruption from sklearn.utils.fixes import np_version, parse_version @@ -665,37 +663,3 @@ def test_xfail_ignored_in_check_estimator(): # Make sure checks marked as xfail are just ignored and not run by # check_estimator(), but still raise a warning. assert_warns(SkipTestWarning, check_estimator, NuSVC()) - - -def test_check_fit2d_1sample(): - - class MyEst(SVC): - # raises a bad error message when only 1 sample is passed - def fit(self, X, y): - if X.shape[0] == 1: - raise ValueError("non informative error message") - - assert_raises_regex( - AssertionError, - "The error message should contain one of the following", - check_fit2d_1sample, - 'estimator_name', - MyEst() - ) - - -def test_check_fit2d_1feature(): - - class MyEst(SVC): - # raises a bad error message when only 1 feature is passed - def fit(self, X, y): - if X.shape[1] == 1: - raise ValueError("non informative error message") - - assert_raises_regex( - AssertionError, - "The error message should contain one of the following", - check_fit2d_1feature, - 'estimator_name', - MyEst() - )