From 28029479092fcbde34efd7e201c4dbce3f026117 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 11 Sep 2024 10:26:17 +0200 Subject: [PATCH 1/2] TST add a few more tests to API checks --- sklearn/utils/estimator_checks.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 6da7c8eb1c7ff..24fe11f7371ec 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -83,11 +83,17 @@ def _yield_api_checks(estimator): + tags = get_tags(estimator) yield check_estimator_cloneable yield check_estimator_repr yield check_no_attributes_set_in_init yield check_fit_score_takes_y yield check_estimators_overwrite_params + yield check_dont_overwrite_parameters + yield check_estimators_fit_returns_self + yield partial(check_estimators_fit_returns_self, readonly_memmap=True) + if tags.requires_fit: + yield check_estimators_unfitted def _yield_checks(estimator): @@ -105,8 +111,6 @@ def _yield_checks(estimator): yield check_sample_weights_not_overwritten yield partial(check_sample_weights_invariance, kind="ones") yield partial(check_sample_weights_invariance, kind="zeros") - yield check_estimators_fit_returns_self - yield partial(check_estimators_fit_returns_self, readonly_memmap=True) # Check that all estimator yield informative messages when # trained on empty datasets @@ -172,8 +176,6 @@ def _yield_classifier_checks(classifier): yield check_supervised_y_no_nan if tags.target_tags.single_output: yield check_supervised_y_2d - if tags.requires_fit: - yield check_estimators_unfitted if "class_weight" in classifier.get_params().keys(): yield check_class_weight_classifiers @@ -246,8 +248,6 @@ def _yield_regressor_checks(regressor): if name != "CCA": # check that the regressor handles int input yield check_regressors_int - if tags.requires_fit: - yield check_estimators_unfitted yield check_non_transformer_estimators_n_iter @@ -310,9 +310,6 @@ def _yield_outliers_checks(estimator): yield partial(check_outliers_train, readonly_memmap=True) # test outlier detectors can handle non-array data yield check_classifier_data_not_an_array - # test if NotFittedError is raised - if get_tags(estimator).requires_fit: - yield check_estimators_unfitted yield check_non_transformer_estimators_n_iter @@ -380,7 +377,6 @@ def _yield_all_checks(estimator, legacy: bool): yield check_get_params_invariance yield check_set_params yield check_dict_unchanged - yield check_dont_overwrite_parameters yield check_fit_idempotent yield check_fit_check_is_fitted if not tags.no_validation: From 391d0d0330ee420468fec4fc4ecc8cee680ebc14 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 13 Sep 2024 10:07:35 +0200 Subject: [PATCH 2/2] Address comments --- sklearn/utils/estimator_checks.py | 35 +++++++++++++++++--- sklearn/utils/tests/test_estimator_checks.py | 6 +--- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 5d06b06651b9d..992aa0a43edb9 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -90,7 +90,7 @@ def _yield_api_checks(estimator): yield check_estimators_overwrite_params yield check_dont_overwrite_parameters yield check_estimators_fit_returns_self - yield partial(check_estimators_fit_returns_self, readonly_memmap=True) + yield check_readonly_memmap_input if tags.requires_fit: yield check_estimators_unfitted yield check_do_not_raise_errors_in_init_or_set_params @@ -2721,7 +2721,7 @@ def check_get_feature_names_out_error(name, estimator_orig): @ignore_warnings(category=FutureWarning) -def check_estimators_fit_returns_self(name, estimator_orig, readonly_memmap=False): +def check_estimators_fit_returns_self(name, estimator_orig): """Check if self is returned when calling fit.""" X, y = make_blobs(random_state=0, n_samples=21) X = _enforce_estimator_tags_X(estimator_orig, X) @@ -2729,10 +2729,26 @@ def check_estimators_fit_returns_self(name, estimator_orig, readonly_memmap=Fals estimator = clone(estimator_orig) y = _enforce_estimator_tags_y(estimator, y) - if readonly_memmap: - X, y = create_memmap_backed_data([X, y]) + set_random_state(estimator) + assert estimator.fit(X, y) is estimator + + +@ignore_warnings(category=FutureWarning) +def check_readonly_memmap_input(name, estimator_orig): + """Check that the estimator can handle readonly memmap backed data. + + This is particularly needed to support joblib parallelisation. + """ + X, y = make_blobs(random_state=0, n_samples=21) + X = _enforce_estimator_tags_X(estimator_orig, X) + + estimator = clone(estimator_orig) + y = _enforce_estimator_tags_y(estimator, y) + + X, y = create_memmap_backed_data([X, y]) set_random_state(estimator) + # This should not raise an error and should return self assert estimator.fit(X, y) is estimator @@ -2742,6 +2758,15 @@ def check_estimators_unfitted(name, estimator_orig): Unfitted estimators should raise a NotFittedError. """ + err_msg = ( + "Estimator should raise a NotFittedError when calling `{method}` before fit. " + "Either call `check_is_fitted(self)` at the beginning of `{method}` or " + "set `tags.requires_fit=False` on estimator tags to disable this check.\n" + "- `check_is_fitted`: https://scikit-learn.org/dev/modules/generated/sklearn." + "utils.validation.check_is_fitted.html\n" + "- Estimator Tags: https://scikit-learn.org/dev/developers/develop." + "html#estimator-tags" + ) # Common test for Regressors, Classifiers and Outlier detection estimators X, y = _regression_dataset() @@ -2753,7 +2778,7 @@ def check_estimators_unfitted(name, estimator_orig): "predict_log_proba", ): if hasattr(estimator, method): - with raises(NotFittedError): + with raises(NotFittedError, err_msg=err_msg.format(method=method)): getattr(estimator, method)(X) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 133aa713a6208..fbff767160bf5 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -753,10 +753,6 @@ def test_check_estimator(): msg = "object has no attribute 'fit'" with raises(AttributeError, match=msg): check_estimator(BaseEstimator()) - # check that fit does input validation - msg = "Did not raise" - with raises(AssertionError, match=msg): - check_estimator(BaseBadClassifier()) # does error on binary_only untagged estimator msg = "Only 2 classes are supported" @@ -836,7 +832,7 @@ def test_check_estimator_clones(): def test_check_estimators_unfitted(): # check that a ValueError/AttributeError is raised when calling predict # on an unfitted estimator - msg = "Did not raise" + msg = "Estimator should raise a NotFittedError when calling" with raises(AssertionError, match=msg): check_estimators_unfitted("estimator", NoSparseClassifier())