Skip to content

TST add a few more tests to API checks #29832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 36 additions & 14 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,17 @@


def _yield_api_checks(estimator):
tags = get_tags(estimator)
yield check_estimator_cloneable
yield check_estimator_repr
yield check_no_attributes_set_in_init
yield check_fit_score_takes_y
yield check_estimators_overwrite_params
yield check_dont_overwrite_parameters
yield check_estimators_fit_returns_self
yield check_readonly_memmap_input
if tags.requires_fit:
yield check_estimators_unfitted
yield check_do_not_raise_errors_in_init_or_set_params


Expand All @@ -104,8 +111,6 @@ def _yield_checks(estimator):
yield check_sample_weights_not_overwritten
yield partial(check_sample_weights_invariance, kind="ones")
yield partial(check_sample_weights_invariance, kind="zeros")
yield check_estimators_fit_returns_self
yield partial(check_estimators_fit_returns_self, readonly_memmap=True)

# Check that all estimators yield informative messages when
# trained on empty datasets
Expand Down Expand Up @@ -173,8 +178,6 @@ def _yield_classifier_checks(classifier):
yield check_supervised_y_no_nan
if tags.target_tags.single_output:
yield check_supervised_y_2d
if tags.requires_fit:
yield check_estimators_unfitted
if "class_weight" in classifier.get_params().keys():
yield check_class_weight_classifiers

Expand Down Expand Up @@ -247,8 +250,6 @@ def _yield_regressor_checks(regressor):
if name != "CCA":
# check that the regressor handles int input
yield check_regressors_int
if tags.requires_fit:
yield check_estimators_unfitted
yield check_non_transformer_estimators_n_iter


Expand Down Expand Up @@ -311,9 +312,6 @@ def _yield_outliers_checks(estimator):
yield partial(check_outliers_train, readonly_memmap=True)
# test outlier detectors can handle non-array data
yield check_classifier_data_not_an_array
# test if NotFittedError is raised
if get_tags(estimator).requires_fit:
yield check_estimators_unfitted
yield check_non_transformer_estimators_n_iter


Expand Down Expand Up @@ -381,7 +379,6 @@ def _yield_all_checks(estimator, legacy: bool):
yield check_get_params_invariance
yield check_set_params
yield check_dict_unchanged
yield check_dont_overwrite_parameters
yield check_fit_idempotent
yield check_fit_check_is_fitted
if not tags.no_validation:
Expand Down Expand Up @@ -2724,18 +2721,34 @@ def check_get_feature_names_out_error(name, estimator_orig):


@ignore_warnings(category=FutureWarning)
def check_estimators_fit_returns_self(name, estimator_orig, readonly_memmap=False):
def check_estimators_fit_returns_self(name, estimator_orig):
"""Check if self is returned when calling fit."""
X, y = make_blobs(random_state=0, n_samples=21)
X = _enforce_estimator_tags_X(estimator_orig, X)

estimator = clone(estimator_orig)
y = _enforce_estimator_tags_y(estimator, y)

if readonly_memmap:
X, y = create_memmap_backed_data([X, y])
set_random_state(estimator)
assert estimator.fit(X, y) is estimator


@ignore_warnings(category=FutureWarning)
def check_readonly_memmap_input(name, estimator_orig):
"""Check that the estimator can handle readonly memmap backed data.

This is particularly needed to support joblib parallelisation.
"""
X, y = make_blobs(random_state=0, n_samples=21)
X = _enforce_estimator_tags_X(estimator_orig, X)

estimator = clone(estimator_orig)
y = _enforce_estimator_tags_y(estimator, y)

X, y = create_memmap_backed_data([X, y])

set_random_state(estimator)
# This should not raise an error and should return self
assert estimator.fit(X, y) is estimator


Expand All @@ -2745,6 +2758,15 @@ def check_estimators_unfitted(name, estimator_orig):

Unfitted estimators should raise a NotFittedError.
"""
err_msg = (
"Estimator should raise a NotFittedError when calling `{method}` before fit. "
"Either call `check_is_fitted(self)` at the beginning of `{method}` or "
"set `tags.requires_fit=False` on estimator tags to disable this check.\n"
"- `check_is_fitted`: https://scikit-learn.org/dev/modules/generated/sklearn."
"utils.validation.check_is_fitted.html\n"
"- Estimator Tags: https://scikit-learn.org/dev/developers/develop."
"html#estimator-tags"
)
# Common test for Regressors, Classifiers and Outlier detection estimators
X, y = _regression_dataset()

Expand All @@ -2756,7 +2778,7 @@ def check_estimators_unfitted(name, estimator_orig):
"predict_log_proba",
):
if hasattr(estimator, method):
with raises(NotFittedError):
with raises(NotFittedError, err_msg=err_msg.format(method=method)):
getattr(estimator, method)(X)


Expand Down
6 changes: 1 addition & 5 deletions sklearn/utils/tests/test_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,10 +753,6 @@ def test_check_estimator():
msg = "object has no attribute 'fit'"
with raises(AttributeError, match=msg):
check_estimator(BaseEstimator())
# check that fit does input validation
msg = "Did not raise"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we assert the right message, it became clear that this case was already tested in another test and was redundant here.

with raises(AssertionError, match=msg):
check_estimator(BaseBadClassifier())

# does error on binary_only untagged estimator
msg = "Only 2 classes are supported"
Expand Down Expand Up @@ -836,7 +832,7 @@ def test_check_estimator_clones():
def test_check_estimators_unfitted():
# check that a ValueError/AttributeError is raised when calling predict
# on an unfitted estimator
msg = "Did not raise"
msg = "Estimator should raise a NotFittedError when calling"
with raises(AssertionError, match=msg):
check_estimators_unfitted("estimator", NoSparseClassifier())

Expand Down