Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def test_get_check_estimator_ids(val, expected):


@parametrize_with_checks(
list(_tested_estimators()), expected_failed_checks=_get_expected_failed_checks
list(_tested_estimators()),
expected_failed_checks=_get_expected_failed_checks,
xfail_strict=True,
)
def test_estimators(estimator, check, request):
# Common tests for estimator instances
Expand Down
84 changes: 24 additions & 60 deletions sklearn/utils/_test_common/instance_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,9 @@
RegressorChain,
)
from sklearn.neighbors import (
KernelDensity,
KNeighborsClassifier,
KNeighborsRegressor,
KNeighborsTransformer,
NeighborhoodComponentsAnalysis,
RadiusNeighborsTransformer,
)
from sklearn.neural_network import BernoulliRBM, MLPClassifier, MLPRegressor
from sklearn.pipeline import FeatureUnion, Pipeline
Expand Down Expand Up @@ -845,24 +842,6 @@ def _yield_instances_for_check(check, estimator_orig):


PER_ESTIMATOR_XFAIL_CHECKS = {
AdaBoostClassifier: {
# TODO: replace by a statistical test, see meta-issue #16298
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an instance where the check doesn't fail but I am not sure if this is because AdaBoostClassifier has been fixed or because the check is not "good enough" to detect that sample weight handling is broken?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible that the check is a bit too weak. Still we can remove the XFAIL markers for this and maybe re-add them later when needed if we ever make the check stronger.

"check_sample_weight_equivalence_on_dense_data": (
"sample_weight is not equivalent to removing/repeating samples."
),
"check_sample_weight_equivalence_on_sparse_data": (
"sample_weight is not equivalent to removing/repeating samples."
),
},
AdaBoostRegressor: {
# TODO: replace by a statistical test, see meta-issue #16298
"check_sample_weight_equivalence_on_dense_data": (
"sample_weight is not equivalent to removing/repeating samples."
),
"check_sample_weight_equivalence_on_sparse_data": (
"sample_weight is not equivalent to removing/repeating samples."
),
},
BaggingClassifier: {
# TODO: replace by a statistical test, see meta-issue #16298
"check_sample_weight_equivalence_on_dense_data": (
Expand Down Expand Up @@ -914,7 +893,6 @@ def _yield_instances_for_check(check, estimator_orig):
"check_dont_overwrite_parameters": "FIXME",
},
FixedThresholdClassifier: {
"check_classifiers_train": "Threshold at probability 0.5 does not hold",
"check_sample_weight_equivalence_on_dense_data": (
"Due to the cross-validation and sample ordering, removing a sample"
" is not strictly equal to putting is weight to zero. Specific unit"
Expand Down Expand Up @@ -950,21 +928,15 @@ def _yield_instances_for_check(check, estimator_orig):
"check_fit2d_1sample": (
"Fail during parameter check since min/max resources requires more samples"
),
"check_estimators_nan_inf": "FIXME",
"check_classifiers_one_label_sample_weights": "FIXME",
"check_fit2d_1feature": "FIXME",
"check_supervised_y_2d": "DataConversionWarning not caught",
"check_requires_y_none": "Doesn't fail gracefully",
"check_supervised_y_2d": "DataConversionWarning not caught",
},
HalvingRandomSearchCV: {
"check_fit2d_1sample": (
"Fail during parameter check since min/max resources requires more samples"
),
"check_estimators_nan_inf": "FIXME",
"check_classifiers_one_label_sample_weights": "FIXME",
"check_fit2d_1feature": "FIXME",
"check_supervised_y_2d": "DataConversionWarning not caught",
"check_requires_y_none": "Doesn't fail gracefully",
"check_supervised_y_2d": "DataConversionWarning not caught",
},
HistGradientBoostingClassifier: {
# TODO: replace by a statistical test, see meta-issue #16298
Expand Down Expand Up @@ -993,11 +965,6 @@ def _yield_instances_for_check(check, estimator_orig):
"sample_weight is not equivalent to removing/repeating samples."
),
},
KernelDensity: {
"check_sample_weight_equivalence_on_dense_data": (
"sample_weight must have positive values"
),
},
KMeans: {
# TODO: replace by a statistical test, see meta-issue #16298
"check_sample_weight_equivalence_on_dense_data": (
Expand All @@ -1007,9 +974,6 @@ def _yield_instances_for_check(check, estimator_orig):
"sample_weight is not equivalent to removing/repeating samples."
),
},
KNeighborsTransformer: {
"check_methods_sample_order_invariance": "check is not applicable."
},
LinearSVC: {
# TODO: replace by a statistical test when _dual=True, see meta-issue #16298
"check_sample_weight_equivalence_on_dense_data": (
Expand All @@ -1018,9 +982,6 @@ def _yield_instances_for_check(check, estimator_orig):
"check_sample_weight_equivalence_on_sparse_data": (
"sample_weight is not equivalent to removing/repeating samples."
),
"check_non_transformer_estimators_n_iter": (
"n_iter_ cannot be easily accessed."
),
},
LinearSVR: {
# TODO: replace by a statistical test, see meta-issue #16298
Expand All @@ -1031,15 +992,6 @@ def _yield_instances_for_check(check, estimator_orig):
"sample_weight is not equivalent to removing/repeating samples."
),
},
LogisticRegression: {
# TODO: fix sample_weight handling of this estimator, see meta-issue #16298
"check_sample_weight_equivalence_on_dense_data": (
"sample_weight is not equivalent to removing/repeating samples."
),
"check_sample_weight_equivalence_on_sparse_data": (
"sample_weight is not equivalent to removing/repeating samples."
),
},
MiniBatchKMeans: {
# TODO: replace by a statistical test, see meta-issue #16298
"check_sample_weight_equivalence_on_dense_data": (
Expand Down Expand Up @@ -1106,9 +1058,6 @@ def _yield_instances_for_check(check, estimator_orig):
"Therefore this test is x-fail until we fix this."
),
},
RadiusNeighborsTransformer: {
"check_methods_sample_order_invariance": "check is not applicable."
},
RandomForestClassifier: {
# TODO: replace by a statistical test, see meta-issue #16298
"check_sample_weight_equivalence_on_dense_data": (
Expand All @@ -1128,8 +1077,8 @@ def _yield_instances_for_check(check, estimator_orig):
),
},
RandomizedSearchCV: {
"check_supervised_y_2d": "DataConversionWarning not caught",
"check_requires_y_none": "Doesn't fail gracefully",
"check_supervised_y_2d": "DataConversionWarning not caught",
},
RandomTreesEmbedding: {
# TODO: replace by a statistical test, see meta-issue #16298
Expand Down Expand Up @@ -1206,11 +1155,6 @@ def _yield_instances_for_check(check, estimator_orig):
"check_estimators_dtypes": "raises nan error",
"check_fit2d_1sample": "_scale_normalize fails",
"check_fit2d_1feature": "raises apply_along_axis error",
"check_estimator_sparse_matrix": "does not fail gracefully",
"check_estimator_sparse_array": "does not fail gracefully",
"check_methods_subset_invariance": "empty array passed inside",
"check_dont_overwrite_parameters": "empty array passed inside",
"check_fit2d_predict1d": "empty array passed inside",
},
SVC: {
# TODO: fix sample_weight handling of this estimator when probability=False
Expand Down Expand Up @@ -1254,7 +1198,9 @@ def _yield_instances_for_check(check, estimator_orig):

def _get_expected_failed_checks(estimator):
"""Get the expected failed checks for all estimators in scikit-learn."""
failed_checks = PER_ESTIMATOR_XFAIL_CHECKS.get(type(estimator), {})
# Make a copy so that our modifications are not permanent. Important for
# estimators that are tested with multiple hyper-parameters.
failed_checks = dict(PER_ESTIMATOR_XFAIL_CHECKS.get(type(estimator), {}))

tags = get_tags(estimator)

Expand Down Expand Up @@ -1287,4 +1233,22 @@ def _get_expected_failed_checks(estimator):
}
)

if type(estimator) == DummyClassifier and estimator.strategy == "most_frequent":
failed_checks.pop("check_methods_sample_order_invariance")
failed_checks.pop("check_methods_subset_invariance")

if type(estimator) in (
GridSearchCV,
HalvingGridSearchCV,
HalvingRandomSearchCV,
RandomizedSearchCV,
) and (
type(estimator.estimator) == LogisticRegression
or (
type(estimator.estimator) == Pipeline
and estimator.estimator.steps[1][0] == "logisticregression"
)
):
failed_checks.pop("check_supervised_y_2d")

return failed_checks
Loading