From 1f4fb212ec84274d5d855d07ebadd1c6029e738a Mon Sep 17 00:00:00 2001
From: adrinjalali
Date: Wed, 11 Sep 2024 14:44:52 +0200
Subject: [PATCH] TST be more specific in test_estimator_checks

---
 sklearn/utils/tests/test_estimator_checks.py | 156 +++++++++++++------
 1 file changed, 110 insertions(+), 46 deletions(-)

diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py
index 34e549ba143a9..6401342014731
--- a/sklearn/utils/tests/test_estimator_checks.py
+++ b/sklearn/utils/tests/test_estimator_checks.py
@@ -48,11 +48,18 @@
     check_classifiers_multilabel_output_format_decision_function,
     check_classifiers_multilabel_output_format_predict,
     check_classifiers_multilabel_output_format_predict_proba,
+    check_classifiers_one_label_sample_weights,
     check_dataframe_column_names_consistency,
     check_decision_proba_consistency,
+    check_dict_unchanged,
+    check_dont_overwrite_parameters,
     check_estimator,
     check_estimator_cloneable,
     check_estimator_repr,
+    check_estimator_sparse_array,
+    check_estimator_sparse_matrix,
+    check_estimators_nan_inf,
+    check_estimators_overwrite_params,
     check_estimators_unfitted,
     check_fit_check_is_fitted,
     check_fit_score_takes_y,
@@ -61,8 +68,11 @@
     check_no_attributes_set_in_init,
     check_outlier_contamination,
     check_outlier_corruption,
+    check_parameters_default_constructible,
     check_regressor_data_not_an_array,
     check_requires_y_none,
+    check_sample_weights_pandas_series,
+    check_set_params,
     set_random_state,
 )
 from sklearn.utils.fixes import CSR_CONTAINERS, SPARRAY_PRESENT
@@ -571,40 +581,58 @@ def fit(self, X, y):
     check_fit_score_takes_y("test", TestEstimatorWithDeprecatedFitMethod())
 
 
-def test_check_estimator():
-    # tests that the estimator actually fails on "bad" estimators.
-    # not a complete test of all checks, which are very extensive.
-
-    # check that we have a set_params and can clone
+def test_check_estimator_with_class_removed():
+    """Test that passing a class instead of an instance fails."""
     msg = "Passing a class was deprecated"
     with raises(TypeError, match=msg):
-        check_estimator(object)
+        check_estimator(LogisticRegression)
+
+
+def test_mutable_default_params():
+    """Test that the constructor cannot have mutable default parameters."""
     msg = (
         "Parameter 'p' of estimator 'HasMutableParameters' is of type "
         "object which is not allowed"
     )
     # check that the "default_constructible" test checks for mutable parameters
-    check_estimator(HasImmutableParameters())  # should pass
+    check_parameters_default_constructible(
+        "Immutable", HasImmutableParameters()
+    )  # should pass
     with raises(AssertionError, match=msg):
-        check_estimator(HasMutableParameters())
+        check_parameters_default_constructible("Mutable", HasMutableParameters())
+
+
+def test_check_set_params():
+    """Check set_params doesn't fail and sets the right values."""
     # check that values returned by get_params match set_params
     msg = "get_params result does not match what was passed to set_params"
     with raises(AssertionError, match=msg):
-        check_estimator(ModifiesValueInsteadOfRaisingError())
+        check_set_params("test", ModifiesValueInsteadOfRaisingError())
+
     with warnings.catch_warnings(record=True) as records:
-        check_estimator(RaisesErrorInSetParams())
+        check_set_params("test", RaisesErrorInSetParams())
     assert UserWarning in [rec.category for rec in records]
     with raises(AssertionError, match=msg):
-        check_estimator(ModifiesAnotherValue())
-    # check that we have a fit method
-    msg = "object has no attribute 'fit'"
-    with raises(AttributeError, match=msg):
-        check_estimator(BaseEstimator())
-    # check that fit does input validation
-    msg = "Did not raise"
+        check_set_params("test", ModifiesAnotherValue())
+
+
+def test_check_estimators_nan_inf():
+    # check that predict validates its input and rejects NaN and inf
+    msg = "Estimator NoCheckinPredict doesn't check for NaN and inf in predict"
     with raises(AssertionError, match=msg):
-        check_estimator(BaseBadClassifier())
+        check_estimators_nan_inf("NoCheckinPredict", NoCheckinPredict())
+
+
+def test_check_dict_unchanged():
+    # check that estimator state does not change
+    # at transform/predict/predict_proba time
+    msg = "Estimator changes __dict__ during predict"
+    with raises(AssertionError, match=msg):
+        check_dict_unchanged("test", ChangesDict())
+
+
+def test_check_sample_weights_pandas_series():
     # check that sample_weights in fit accepts pandas.Series type
     try:
         from pandas import Series  # noqa
@@ -614,18 +642,14 @@ def test_check_estimator():
         "'sample_weight' parameter is of type pandas.Series"
     )
     with raises(ValueError, match=msg):
-        check_estimator(NoSampleWeightPandasSeriesType())
+        check_sample_weights_pandas_series(
+            "NoSampleWeightPandasSeriesType", NoSampleWeightPandasSeriesType()
+        )
     except ImportError:
         pass
-    # check that predict does input validation (doesn't accept dicts in input)
-    msg = "Estimator NoCheckinPredict doesn't check for NaN and inf in predict"
-    with raises(AssertionError, match=msg):
-        check_estimator(NoCheckinPredict())
-    # check that estimator state does not change
-    # at transform/predict/predict_proba time
-    msg = "Estimator changes __dict__ during predict"
-    with raises(AssertionError, match=msg):
-        check_estimator(ChangesDict())
+
+
+def test_check_estimators_overwrite_params():
     # check that `fit` only changes attributes that
     # are private (start with an _ or end with a _).
msg = ( @@ -633,8 +657,13 @@ def test_check_estimator(): "the parameter wrong_attribute from 0 to 1 during fit." ) with raises(AssertionError, match=msg): - check_estimator(ChangesWrongAttribute()) - check_estimator(ChangesUnderscoreAttribute()) + check_estimators_overwrite_params( + "ChangesWrongAttribute", ChangesWrongAttribute() + ) + check_estimators_overwrite_params("test", ChangesUnderscoreAttribute()) + + +def test_check_dont_overwrite_parameters(): # check that `fit` doesn't add any public attribute msg = ( r"Estimator adds public attribute\(s\) during the fit method." @@ -643,7 +672,10 @@ def test_check_estimator(): " with _ but wrong_attribute added" ) with raises(AssertionError, match=msg): - check_estimator(SetsWrongAttribute()) + check_dont_overwrite_parameters("test", SetsWrongAttribute()) + + +def test_check_methods_sample_order_invariance(): # check for sample order invariance name = NotInvariantSampleOrder.__name__ method = "predict" @@ -652,7 +684,12 @@ def test_check_estimator(): "with different sample order." ).format(method=method, name=name) with raises(AssertionError, match=msg): - check_estimator(NotInvariantSampleOrder()) + check_methods_sample_order_invariance( + "NotInvariantSampleOrder", NotInvariantSampleOrder() + ) + + +def test_check_methods_subset_invariance(): # check for invariant method name = NotInvariantPredict.__name__ method = "predict" @@ -660,17 +697,40 @@ def test_check_estimator(): method=method, name=name ) with raises(AssertionError, match=msg): - check_estimator(NotInvariantPredict()) + check_methods_subset_invariance("NotInvariantPredict", NotInvariantPredict()) + + +def test_check_estimator_sparse_data(): # check for sparse data input handling name = NoSparseClassifier.__name__ msg = "Estimator %s doesn't seem to fail gracefully on sparse data" % name with raises(AssertionError, match=msg): - check_estimator(NoSparseClassifier("sparse_matrix")) + check_estimator_sparse_matrix(name, NoSparseClassifier("sparse_matrix")) if SPARRAY_PRESENT: with raises(AssertionError, match=msg): - check_estimator(NoSparseClassifier("sparse_array")) + check_estimator_sparse_array(name, NoSparseClassifier("sparse_array")) + # Large indices test on bad estimator + msg = ( + "Estimator LargeSparseNotSupportedClassifier doesn't seem to " + r"support \S{3}_64 matrix, and is not failing gracefully.*" + ) + with raises(AssertionError, match=msg): + check_estimator_sparse_matrix( + "LargeSparseNotSupportedClassifier", + LargeSparseNotSupportedClassifier("sparse_matrix"), + ) + + if SPARRAY_PRESENT: + with raises(AssertionError, match=msg): + check_estimator_sparse_array( + "LargeSparseNotSupportedClassifier", + LargeSparseNotSupportedClassifier("sparse_array"), + ) + + +def test_check_classifiers_one_label_sample_weights(): # check for classifiers reducing to less than two classes via sample weights name = OneClassSampleErrorClassifier.__name__ msg = ( @@ -679,19 +739,23 @@ def test_check_estimator(): "'class'." 
) with raises(AssertionError, match=msg): - check_estimator(OneClassSampleErrorClassifier()) + check_classifiers_one_label_sample_weights( + "OneClassSampleErrorClassifier", OneClassSampleErrorClassifier() + ) - # Large indices test on bad estimator - msg = ( - "Estimator LargeSparseNotSupportedClassifier doesn't seem to " - r"support \S{3}_64 matrix, and is not failing gracefully.*" - ) - with raises(AssertionError, match=msg): - check_estimator(LargeSparseNotSupportedClassifier("sparse_matrix")) - if SPARRAY_PRESENT: - with raises(AssertionError, match=msg): - check_estimator(LargeSparseNotSupportedClassifier("sparse_array")) +def test_check_estimator(): + # tests that the estimator actually fails on "bad" estimators. + # not a complete test of all checks, which are very extensive. + + # check that we have a fit method + msg = "object has no attribute 'fit'" + with raises(AttributeError, match=msg): + check_estimator(BaseEstimator()) + # check that fit does input validation + msg = "Did not raise" + with raises(AssertionError, match=msg): + check_estimator(BaseBadClassifier()) # does error on binary_only untagged estimator msg = "Only 2 classes are supported"