[MRG] Prototype 4 for strict check_estimator mode #17361

Merged

Changes from all commits (23 commits):

8775c1e
treat strict checks as xfail checks
NicolasHug May 16, 2020
8c07ca6
Merge branch 'master' of github.com:scikit-learn/scikit-learn into st…
NicolasHug May 19, 2020
664552f
different names
NicolasHug May 19, 2020
ecff04c
Comments
NicolasHug May 19, 2020
8e66d47
Merge branch 'master' of github.com:scikit-learn/scikit-learn into st…
NicolasHug May 20, 2020
c7c5c8d
some clearning
NicolasHug May 20, 2020
e7d5f7c
Merge branch 'master' of github.com:scikit-learn/scikit-learn into st…
NicolasHug May 25, 2020
ff09e3a
Merge branch 'master' of github.com:scikit-learn/scikit-learn into st…
NicolasHug May 26, 2020
a2b5bf5
This is hard
NicolasHug May 26, 2020
2e2bebd
put back reasons
NicolasHug May 27, 2020
0a61d69
comments and cleaning
NicolasHug May 27, 2020
e1f1761
Merge branch 'master' of github.com:scikit-learn/scikit-learn into st…
NicolasHug May 27, 2020
9c9757b
typo
NicolasHug May 27, 2020
92b45e1
check name in xfail message
NicolasHug Jun 3, 2020
6c8af6a
Merge branch 'master' into strict_mode_xfails_partially_strict_checks
rth Jul 10, 2020
a92320d
Lint
rth Jul 10, 2020
82584dc
Lint
rth Jul 11, 2020
c88329b
Merge branch 'master' of github.com:scikit-learn/scikit-learn into st…
NicolasHug Jul 20, 2020
a0195f7
removed use of parenthesis in xfail_checks tag
NicolasHug Jul 20, 2020
3b69426
Addressed comments from Joel
NicolasHug Aug 2, 2020
735791c
Merge branch 'master' of github.com:scikit-learn/scikit-learn into st…
NicolasHug Aug 2, 2020
993dd94
probably fixed test
NicolasHug Aug 2, 2020
b55ee3c
use generator function instead of comprehension, hopefully clearer
NicolasHug Aug 2, 2020
2 changes: 1 addition & 1 deletion sklearn/calibration.py

@@ -371,7 +371,7 @@ class that has the highest probability, and can thus be different
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

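For context on the pattern repeated across the estimator files below: an estimator declares known failures of the common checks in the _xfail_checks entry of its tags, and this PR changes the dictionary keys from parenthesized names like 'check_sample_weights_invariance(kind=zeros)' to bare check names. A minimal sketch of the tag declaration, using a hypothetical estimator (only the key/reason pair is taken from the diff above; the rest is illustrative):

from sklearn.base import BaseEstimator


class MyEstimator(BaseEstimator):
    # Hypothetical estimator that marks one common check as expected to fail.

    def fit(self, X, y=None):
        return self

    def _more_tags(self):
        # Keys are bare check names (the parenthesized kwargs suffix is
        # gone); values are the reasons reported when the check is xfailed.
        return {
            '_xfail_checks': {
                'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }
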
4 changes: 2 additions & 2 deletions sklearn/cluster/_kmeans.py

@@ -1163,7 +1163,7 @@ def score(self, X, y=None, sample_weight=None):
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

@@ -1889,7 +1889,7 @@ def predict(self, X, sample_weight=None):
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

2 changes: 1 addition & 1 deletion sklearn/ensemble/_iforest.py

@@ -457,7 +457,7 @@ def _compute_score_samples(self, X, subsample_features):
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

2 changes: 1 addition & 1 deletion sklearn/linear_model/_logistic.py

@@ -2090,7 +2090,7 @@ def score(self, X, y, sample_weight=None):
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

2 changes: 1 addition & 1 deletion sklearn/linear_model/_ransac.py

@@ -506,7 +506,7 @@ def score(self, X, y):
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

2 changes: 1 addition & 1 deletion sklearn/linear_model/_ridge.py

@@ -1913,7 +1913,7 @@ def classes_(self):
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

4 changes: 2 additions & 2 deletions sklearn/linear_model/_stochastic_gradient.py

@@ -1098,7 +1098,7 @@ def _predict_log_proba(self, X):
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

@@ -1588,7 +1588,7 @@ def __init__(self, loss="squared_loss", *, penalty="l2", alpha=0.0001,
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

2 changes: 1 addition & 1 deletion sklearn/neighbors/_kde.py

@@ -284,7 +284,7 @@ def sample(self, n_samples=1, random_state=None):
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'sample_weight must have positive values',
            }
        }

14 changes: 7 additions & 7 deletions sklearn/svm/_classes.py

@@ -248,7 +248,7 @@ def fit(self, X, y, sample_weight=None):
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

@@ -436,7 +436,7 @@ def fit(self, X, y, sample_weight=None):
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

@@ -670,7 +670,7 @@ def __init__(self, *, C=1.0, kernel='rbf', degree=3, gamma='scale',
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

@@ -895,7 +895,7 @@ def _more_tags(self):
                'check_methods_subset_invariance':
                'fails for the decision_function method',
                'check_class_weight_classifiers': 'class_weight is ignored.',
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

@@ -1072,7 +1072,7 @@ def probB_(self):
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

@@ -1226,7 +1226,7 @@ def __init__(self, *, nu=0.5, C=1.0, kernel='rbf', degree=3,
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

@@ -1459,7 +1459,7 @@ def probB_(self):
    def _more_tags(self):
        return {
            '_xfail_checks': {
-               'check_sample_weights_invariance(kind=zeros)':
+               'check_sample_weights_invariance':
                'zero sample_weight is not equivalent to removing samples',
            }
        }

74 changes: 69 additions & 5 deletions sklearn/tests/test_common.py

@@ -15,24 +15,27 @@
from functools import partial

import pytest

import numpy as np

from sklearn.utils import all_estimators
from sklearn.utils._testing import ignore_warnings
-from sklearn.exceptions import ConvergenceWarning
+from sklearn.exceptions import ConvergenceWarning, SkipTestWarning
from sklearn.utils.estimator_checks import check_estimator

import sklearn
from sklearn.base import BiclusterMixin

from sklearn.decomposition import NMF
from sklearn.utils.validation import check_non_negative, check_array
from sklearn.linear_model._base import LinearClassifierMixin
from sklearn.linear_model import LogisticRegression
from sklearn.svm import NuSVC
from sklearn.utils import IS_PYPY
from sklearn.utils._testing import SkipTest
from sklearn.utils.estimator_checks import (
    _construct_instance,
    _set_checking_parameters,
-   _set_check_estimator_ids,
+   _get_check_estimator_ids,
    check_class_weight_balanced_linear_classifier,
    parametrize_with_checks)

@@ -59,8 +62,8 @@ def _sample_func(x, y=1):
        "LogisticRegression(class_weight='balanced',random_state=1,"
        "solver='newton-cg',warm_start=True)")
])
-def test_set_check_estimator_ids(val, expected):
-    assert _set_check_estimator_ids(val) == expected
+def test_get_check_estimator_ids(val, expected):
+    assert _get_check_estimator_ids(val) == expected


def _tested_estimators():

@@ -204,3 +207,64 @@ def test_class_support_removed():

    with pytest.raises(TypeError, match=msg):
        parametrize_with_checks([LogisticRegression])
+
+
+class MyNMFWithBadErrorMessage(NMF):
+    # Same as NMF, but raises an uninformative error message if X has
+    # negative values. This estimator would fail the check suite in strict
+    # mode; specifically, it would fail check_fit_non_negative.
+    def fit(self, X, y=None, **params):
+        X = check_array(X, accept_sparse=('csr', 'csc'),
+                        dtype=[np.float64, np.float32])
+        try:
+            check_non_negative(X, whom='')
+        except ValueError:
+            raise ValueError("Some non-informative error msg")
+
+        return super().fit(X, y, **params)
+
+
+def test_strict_mode_check_estimator():

[Inline review thread on test_strict_mode_check_estimator]
Reviewer (scikit-learn member): Just curious, how long does this test take? Running check_estimator on an estimator takes a while, and we should likely avoid it if possible, though here I guess we have no choice?
NicolasHug (member, author): About 6 seconds on my machine. I agree it's a bit long, but I think it's worth it?

+    # Tests various conditions for the strict mode of check_estimator().
+    # Details are in the comments.
+
+    # LogisticRegression has no _xfail_checks, so when strict_mode is on,
+    # there should be no skipped tests.
+    with pytest.warns(None) as catched_warnings:
+        check_estimator(LogisticRegression(), strict_mode=True)
+    assert not any(isinstance(w, SkipTestWarning) for w in catched_warnings)
+    # When strict mode is off, check_n_features_in should be skipped because
+    # it's a fully strict check.
+    msg_check_n_features_in = 'check_n_features_in is fully strict '
+    with pytest.warns(SkipTestWarning, match=msg_check_n_features_in):
+        check_estimator(LogisticRegression(), strict_mode=False)
+
+    # NuSVC has some _xfail_checks. They should be skipped regardless of
+    # strict_mode.
+    with pytest.warns(SkipTestWarning,
+                      match='fails for the decision_function method'):
+        check_estimator(NuSVC(), strict_mode=True)
+    # When strict mode is off, check_n_features_in is skipped along with the
+    # rest of the xfail_checks.
+    with pytest.warns(SkipTestWarning, match=msg_check_n_features_in):
+        check_estimator(NuSVC(), strict_mode=False)
+
+    # MyNMFWithBadErrorMessage will fail check_fit_non_negative() in strict
+    # mode because it yields a bad error message.
+    with pytest.raises(AssertionError, match='does not match'):
+        check_estimator(MyNMFWithBadErrorMessage(), strict_mode=True)
+    # However, it should pass the test suite in non-strict mode because when
+    # strict mode is off, check_fit_non_negative() will not check the exact
+    # error message. (We still assert that the warning from
+    # check_n_features_in is raised.)
+    with pytest.warns(SkipTestWarning, match=msg_check_n_features_in):
+        check_estimator(MyNMFWithBadErrorMessage(), strict_mode=False)
+
+
+@parametrize_with_checks([LogisticRegression(),
+                          NuSVC(),
+                          MyNMFWithBadErrorMessage()],
+                         strict_mode=False)
+def test_strict_mode_parametrize_with_checks(estimator, check):
+    # Ideally we should assert that the strict checks are xfailed...
+    check(estimator)
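
Based on the call signatures exercised in this test file, this is how a third-party estimator developer might run the common-check suite with the new parameter. A minimal sketch, assuming only that check_estimator and parametrize_with_checks accept the strict_mode keyword added in this PR; the test function name is hypothetical:

from sklearn.linear_model import LogisticRegression
from sklearn.utils.estimator_checks import (check_estimator,
                                            parametrize_with_checks)

# Run all common checks at once; with strict_mode=False, the fully strict
# checks (e.g. check_n_features_in) are skipped with a SkipTestWarning.
check_estimator(LogisticRegression(), strict_mode=False)


# Or generate one pytest test per (estimator, check) pair, forwarding
# strict_mode to the underlying checks.
@parametrize_with_checks([LogisticRegression()], strict_mode=False)
def test_sklearn_compatible_estimator(estimator, check):
    check(estimator)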