Skip to content

TST be more specific in test_estimator_checks #29834

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 107 additions & 44 deletions sklearn/utils/tests/test_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,19 @@
check_classifiers_multilabel_output_format_decision_function,
check_classifiers_multilabel_output_format_predict,
check_classifiers_multilabel_output_format_predict_proba,
check_classifiers_one_label_sample_weights,
check_dataframe_column_names_consistency,
check_decision_proba_consistency,
check_dict_unchanged,
check_dont_overwrite_parameters,
check_estimator,
check_estimator_cloneable,
check_estimator_repr,
check_estimator_sparse_array,
check_estimator_sparse_matrix,
check_estimator_tags_renamed,
check_estimators_nan_inf,
check_estimators_overwrite_params,
check_estimators_unfitted,
check_fit_check_is_fitted,
check_fit_score_takes_y,
Expand All @@ -62,8 +69,10 @@
check_no_attributes_set_in_init,
check_outlier_contamination,
check_outlier_corruption,
check_parameters_default_constructible,
check_regressor_data_not_an_array,
check_requires_y_none,
check_sample_weights_pandas_series,
check_set_params,
set_random_state,
)
Expand Down Expand Up @@ -573,40 +582,58 @@ def fit(self, X, y):
check_fit_score_takes_y("test", TestEstimatorWithDeprecatedFitMethod())


def test_check_estimator():
# tests that the estimator actually fails on "bad" estimators.
# not a complete test of all checks, which are very extensive.

# check that we have a set_params and can clone
def test_check_estimator_with_class_removed():
    """Test that passing a class instead of an instance fails.

    `check_estimator` only accepts instances; passing the class itself must
    raise a TypeError mentioning the deprecation of class-passing.
    """
    msg = "Passing a class was deprecated"
    with raises(TypeError, match=msg):
        # Diff-extraction artifact fixed: the pre-PR call
        # `check_estimator(object)` was interleaved here; only the
        # post-PR call on an actual estimator class belongs in the body.
        check_estimator(LogisticRegression)


def test_mutable_default_params():
    """Test that constructor cannot have mutable default parameters.

    `check_parameters_default_constructible` must pass for an estimator with
    immutable defaults and raise an AssertionError for one whose constructor
    uses a mutable default value.
    """
    msg = (
        "Parameter 'p' of estimator 'HasMutableParameters' is of type "
        "object which is not allowed"
    )
    # Stale pre-PR `check_estimator(...)` calls left by the diff scrape were
    # removed; the specific check is exercised directly instead.
    check_parameters_default_constructible(
        "Immutable", HasImmutableParameters()
    )  # should pass
    with raises(AssertionError, match=msg):
        check_parameters_default_constructible("Mutable", HasMutableParameters())


def test_check_set_params():
    """Check set_params doesn't fail and sets the right values.

    Three behaviors of `check_set_params` are pinned:
    - an estimator whose `set_params` silently modifies a different value
      fails with an assertion about get_params/set_params mismatch;
    - an estimator that raises inside `set_params` only produces a
      UserWarning (the check is tolerant of raising implementations);
    - an estimator that modifies another parameter's value also fails.
    """
    msg = "get_params result does not match what was passed to set_params"
    with raises(AssertionError, match=msg):
        check_set_params("test", ModifiesValueInsteadOfRaisingError())

    with warnings.catch_warnings(record=True) as records:
        check_set_params("test", RaisesErrorInSetParams())
    assert UserWarning in [rec.category for rec in records]

    with raises(AssertionError, match=msg):
        # Diff artifact removed: the pre-PR `check_estimator(...)` call and
        # the fit-method checks (moved to test_check_estimator) were
        # interleaved here.
        check_set_params("test", ModifiesAnotherValue())


def test_check_estimators_nan_inf():
    """Check that `check_estimators_nan_inf` flags a non-validating predict.

    `NoCheckinPredict` does not reject NaN/inf inputs in `predict`, so the
    check must fail with the expected assertion message.
    """
    msg = "Estimator NoCheckinPredict doesn't check for NaN and inf in predict"
    with raises(AssertionError, match=msg):
        # Diff artifact removed: the unrelated pre-PR line
        # `check_estimator(BaseBadClassifier())` was interleaved here.
        check_estimators_nan_inf("NoCheckinPredict", NoCheckinPredict())


def test_check_dict_unchanged():
    """Check that `check_dict_unchanged` flags state mutation at predict time.

    `ChangesDict` mutates its `__dict__` inside `predict`, which the check
    must report as an assertion failure.
    """
    expected = "Estimator changes __dict__ during predict"
    with raises(AssertionError, match=expected):
        check_dict_unchanged("test", ChangesDict())


def test_check_sample_weights_pandas_series():
# check that sample_weights in fit accepts pandas.Series type
try:
from pandas import Series # noqa
Expand All @@ -616,27 +643,28 @@ def test_check_estimator():
"'sample_weight' parameter is of type pandas.Series"
)
with raises(ValueError, match=msg):
check_estimator(NoSampleWeightPandasSeriesType())
check_sample_weights_pandas_series(
"NoSampleWeightPandasSeriesType", NoSampleWeightPandasSeriesType()
)
except ImportError:
pass
# check that predict does input validation (doesn't accept dicts in input)
msg = "Estimator NoCheckinPredict doesn't check for NaN and inf in predict"
with raises(AssertionError, match=msg):
check_estimator(NoCheckinPredict())
# check that estimator state does not change
# at transform/predict/predict_proba time
msg = "Estimator changes __dict__ during predict"
with raises(AssertionError, match=msg):
check_estimator(ChangesDict())


def test_check_estimators_overwrite_params():
    """Check that `fit` only changes private attributes (leading/trailing _).

    An estimator mutating a public constructor parameter during `fit` must
    fail the check; one that only touches underscore-prefixed attributes
    must pass.
    """
    msg = (
        "Estimator ChangesWrongAttribute should not change or mutate "
        "the parameter wrong_attribute from 0 to 1 during fit."
    )
    with raises(AssertionError, match=msg):
        # Diff artifact removed: the pre-PR `check_estimator(...)` calls
        # were interleaved with the new specific-check calls.
        check_estimators_overwrite_params(
            "ChangesWrongAttribute", ChangesWrongAttribute()
        )
    check_estimators_overwrite_params("test", ChangesUnderscoreAttribute())


def test_check_dont_overwrite_parameters():
# check that `fit` doesn't add any public attribute
msg = (
r"Estimator adds public attribute\(s\) during the fit method."
Expand All @@ -645,7 +673,10 @@ def test_check_estimator():
" with _ but wrong_attribute added"
)
with raises(AssertionError, match=msg):
check_estimator(SetsWrongAttribute())
check_dont_overwrite_parameters("test", SetsWrongAttribute())


def test_check_methods_sample_order_invariance():
# check for sample order invariance
name = NotInvariantSampleOrder.__name__
method = "predict"
Expand All @@ -654,25 +685,53 @@ def test_check_estimator():
"with different sample order."
).format(method=method, name=name)
with raises(AssertionError, match=msg):
check_estimator(NotInvariantSampleOrder())
check_methods_sample_order_invariance(
"NotInvariantSampleOrder", NotInvariantSampleOrder()
)


def test_check_methods_subset_invariance():
    """Check that `check_methods_subset_invariance` flags subset variance.

    `NotInvariantPredict.predict` returns different results on a subset of
    the data, so the check must fail with the formatted message below.
    """
    name = NotInvariantPredict.__name__
    method = "predict"
    msg = ("{method} of {name} is not invariant when applied to a subset.").format(
        method=method, name=name
    )
    with raises(AssertionError, match=msg):
        # Diff artifact removed: the pre-PR call
        # `check_estimator(NotInvariantPredict())` was interleaved here.
        check_methods_subset_invariance("NotInvariantPredict", NotInvariantPredict())


def test_check_estimator_sparse_data():
    """Check the sparse-input checks fail gracefully on bad estimators.

    Two failure modes are pinned, each for both sparse matrices and (when
    available) sparse arrays:
    - `NoSparseClassifier` rejects sparse input without a helpful error;
    - `LargeSparseNotSupportedClassifier` mishandles 64-bit indices.
    """
    # Stale pre-PR `check_estimator(...)` calls left by the diff scrape were
    # removed; the sparse-specific checks are exercised directly.
    name = NoSparseClassifier.__name__
    msg = "Estimator %s doesn't seem to fail gracefully on sparse data" % name
    with raises(AssertionError, match=msg):
        check_estimator_sparse_matrix(name, NoSparseClassifier("sparse_matrix"))

    if SPARRAY_PRESENT:
        with raises(AssertionError, match=msg):
            check_estimator_sparse_array(name, NoSparseClassifier("sparse_array"))

    # Large indices test on bad estimator
    msg = (
        "Estimator LargeSparseNotSupportedClassifier doesn't seem to "
        r"support \S{3}_64 matrix, and is not failing gracefully.*"
    )
    with raises(AssertionError, match=msg):
        check_estimator_sparse_matrix(
            "LargeSparseNotSupportedClassifier",
            LargeSparseNotSupportedClassifier("sparse_matrix"),
        )

    if SPARRAY_PRESENT:
        with raises(AssertionError, match=msg):
            check_estimator_sparse_array(
                "LargeSparseNotSupportedClassifier",
                LargeSparseNotSupportedClassifier("sparse_array"),
            )


def test_check_classifiers_one_label_sample_weights():
# check for classifiers reducing to less than two classes via sample weights
name = OneClassSampleErrorClassifier.__name__
msg = (
Expand All @@ -681,19 +740,23 @@ def test_check_estimator():
"'class'."
)
with raises(AssertionError, match=msg):
check_estimator(OneClassSampleErrorClassifier())
check_classifiers_one_label_sample_weights(
"OneClassSampleErrorClassifier", OneClassSampleErrorClassifier()
)

# Large indices test on bad estimator
msg = (
"Estimator LargeSparseNotSupportedClassifier doesn't seem to "
r"support \S{3}_64 matrix, and is not failing gracefully.*"
)
with raises(AssertionError, match=msg):
check_estimator(LargeSparseNotSupportedClassifier("sparse_matrix"))

if SPARRAY_PRESENT:
with raises(AssertionError, match=msg):
check_estimator(LargeSparseNotSupportedClassifier("sparse_array"))
def test_check_estimator():
# tests that the estimator actually fails on "bad" estimators.
# not a complete test of all checks, which are very extensive.

# check that we have a fit method
msg = "object has no attribute 'fit'"
with raises(AttributeError, match=msg):
check_estimator(BaseEstimator())
# check that fit does input validation
msg = "Did not raise"
with raises(AssertionError, match=msg):
check_estimator(BaseBadClassifier())

# does error on binary_only untagged estimator
msg = "Only 2 classes are supported"
Expand Down