[MRG] simplify check_is_fitted to use any fitted attributes #14545
sklearn/utils/validation.py
Outdated
@@ -866,21 +866,18 @@ def check_symmetric(array, tol=1E-10, raise_warning=True,
     return array


-def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
+def check_is_fitted(estimator, *, msg=None):
Aren't you changing the signature here? Did you mean `check_is_fitted(estimator, *args, msg=None):` to preserve backward compatibility?
I'm changing the signature. We could have `*args` here if we consider this public, which maybe we should?
OK, actually using `*args` doesn't work unless I also add `**kwargs`. So if we want backward compatibility, we need to just do a usual deprecation cycle.
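To illustrate the point about `*args`, here is a minimal sketch. The function names are hypothetical stand-ins with empty bodies, not scikit-learn's code; they only show which old call styles each signature would accept.

```python
# Hypothetical stand-ins for the signatures under discussion; the bodies
# are placeholders and do not perform any fitted-ness check.

def star_args_only(estimator, *args, msg=None):
    """Swallows extra positional arguments, but not extra keywords."""
    return "ok"


def star_args_and_kwargs(estimator, *args, msg=None, **kwargs):
    """Swallows extra positional and keyword arguments alike."""
    return "ok"


est = object()

# Old-style positional use of `attributes` still goes through.
positional_result = star_args_only(est, ["coef_"])

# Old-style keyword use does not: `attributes=` is an unexpected keyword.
try:
    star_args_only(est, attributes=["coef_"])
    keyword_call_failed = False
except TypeError:
    keyword_call_failed = True

# Only the variant that also takes **kwargs accepts the keyword call.
kwargs_result = star_args_and_kwargs(est, attributes=["coef_"], all_or_any=any)
```

Silently swallowing arguments also hides caller mistakes, which is one reason a regular deprecation cycle can be the cleaner route.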
Yes I think we consider the validation utils public. But I'm happy to see them go private.
sklearn/utils/validation.py
Outdated
@@ -910,10 +904,10 @@ def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
     if not hasattr(estimator, 'fit'):
         raise TypeError("%s is not an estimator instance." % (estimator))

-    if not isinstance(attributes, (list, tuple)):
-        attributes = [attributes]
+    attrs = [v for v in vars(estimator) if v.endswith("_")
I think NearestNeighbors has stored only `_fit_X`.
It has this:
scikit-learn/sklearn/neighbors/base.py
Lines 166 to 169 in 7c60ead
        if self.metric_params is None:
            self.effective_metric_params_ = {}
        else:
            self.effective_metric_params_ = self.metric_params.copy()
and it was very recently documented.
Common tests pass so it must work ;)
Hm, vectorizers were not caught by the common tests, of course :-/
And
See #14559, but should be passing now. This is not the cleanest work-around but that's mostly because
This is actually handy. I will need it in #14028.
Could you add an entry in the what's new?
I wonder if there is an issue with some of the meta-estimators which are not tested with the common tests. I'll check that.
I'm ambivalent about adding a whatsnew but I can do it if you think it's worth it. Probably should add a
I would say that a note in the
sklearn/utils/validation.py
Outdated
@@ -866,21 +865,18 @@ def check_symmetric(array, tol=1E-10, raise_warning=True,
     return array


-def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
+def check_is_fitted(estimator, attributes='deprecated', msg=None):
The behavior is already not backward compatible. Since we mention in the documentation that these utils can change from one version to another, I would not bother with a deprecation warning for the `attributes` parameter, knowing that one can have some side effects with `all_or_any`.
How is it not backward compatible? Oh, I could deprecate `all_or_any` as well?
If somebody is using `all_or_any` now, nothing would happen, and it is no longer a parameter of the function either. But as I mentioned, we clearly state in the documentation that `utils` do not follow the deprecation cycle and can change: https://scikit-learn.org/stable/developers/utilities.html
True, only deprecating one doesn't make sense. But also see the discussion at #6616. Basically, the docs say that, but people ignore it; it might not be good to enforce it, and we should make things private instead.
OK, I see. I am sure I was one of those people who complained at least once (then @lesteve showed me the red box :))
I really feel that making the utils private could help us move quickly sometimes and still help third-party projects (at the cost of potential breakage if they use them). So deprecation it is :)
So should I add one for `all_or_any` then?
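A deprecation cycle covering both parameters could look roughly like the sketch below. It reuses the `'deprecated'` sentinel default from the diff above; the function name, the helper and the exact warning wording are assumptions made for illustration, and the fitted-ness check itself is left out.

```python
import warnings

_DEPRECATED = "deprecated"  # sentinel default, as in the diff above


def _is_sentinel(value):
    # isinstance guard so that lists or arrays are never compared to a string
    return isinstance(value, str) and value == _DEPRECATED


def check_is_fitted_sketch(estimator, attributes=_DEPRECATED, msg=None,
                           all_or_any=_DEPRECATED):
    """Hypothetical sketch: warn when a deprecated argument is passed."""
    if not _is_sentinel(attributes):
        warnings.warn("Passing attributes to check_is_fitted is deprecated; "
                      "the argument is ignored.", DeprecationWarning)
    if not _is_sentinel(all_or_any):
        warnings.warn("Passing all_or_any to check_is_fitted is deprecated; "
                      "the argument is ignored.", DeprecationWarning)
    # The actual fitted-ness check is omitted from this sketch.


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_is_fitted_sketch(object(), attributes=["coef_"])  # warns
    check_is_fitted_sketch(object(), all_or_any=any)        # warns
    check_is_fitted_sketch(object())                        # silent
```

Using a string sentinel rather than `None` keeps it possible to tell "argument not passed" apart from any value a caller might have supplied under the old signature.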
What is the expected behaviour for the following?

from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.utils.validation import check_is_fitted
X, y = load_iris(return_X_y=True)
pipe = make_pipeline(StandardScaler(), LogisticRegression())
pipe.fit(X, y)
check_is_fitted(pipe)

---------------------------------------------------------------------------
NotFittedError                            Traceback (most recent call last)
/tmp/tmp.py in <module>
      8 pipe = make_pipeline(StandardScaler(), LogisticRegression())
      9 pipe.fit(X, y)
---> 10 check_is_fitted(pipe)

~/Documents/code/toolbox/scikit-learn/sklearn/utils/validation.py in check_is_fitted(estimator, msg)
    910
    911     if not len(attrs):
--> 912         raise NotFittedError(msg % {'name': type(estimator).__name__})
    913
    914

NotFittedError: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using this method.
We should probably make a recursive call on each element of the Pipeline instance then?
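The recursive call suggested above could be sketched as follows. Everything here is a toy stand-in: `ToyScaler`, `ToyPipeline`, the local `NotFittedError` and `check_is_fitted_recursive` are hypothetical names, and the only assumption borrowed from scikit-learn is that a pipeline exposes `steps` as a list of `(name, estimator)` pairs.

```python
class NotFittedError(ValueError):
    """Local stand-in for scikit-learn's exception of the same name."""


class ToyScaler:
    def fit(self, X):
        self.mean_ = sum(X) / len(X)  # trailing underscore marks fitted state
        return self


class ToyPipeline:
    def __init__(self, steps):
        self.steps = steps  # list of (name, estimator) pairs

    def fit(self, X):
        for _name, step in self.steps:
            step.fit(X)
        return self


def check_is_fitted_recursive(estimator):
    """Raise NotFittedError unless every leaf estimator has a fitted attribute."""
    steps = getattr(estimator, "steps", None)
    if steps is not None:
        # A pipeline stores nothing ending in "_" itself, so check its steps.
        for _name, step in steps:
            if step is None or step == "passthrough":
                continue  # placeholder steps have nothing to check
            check_is_fitted_recursive(step)
        return
    if not any(name.endswith("_") for name in vars(estimator)):
        raise NotFittedError(
            "This %s instance is not fitted yet." % type(estimator).__name__)


fitted = ToyPipeline([("scale", ToyScaler())]).fit([1.0, 2.0, 3.0])
unfitted = ToyPipeline([("scale", ToyScaler())])

fitted_result = check_is_fitted_recursive(fitted)
try:
    check_is_fitted_recursive(unfitted)
    unfitted_raised = False
except NotFittedError:
    unfitted_raised = True
```

Whether the recursion belongs in the utility or in the meta-estimator itself is a design question this sketch does not settle.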
Co-Authored-By: Guillaume Lemaitre <g.lemaitre58@gmail.com>
…to anything_fitted
Do you still want to make a deprecation?
Since we clearly state that the utils are not guaranteed to be stable, I would prefer not to go through a deprecation cycle.
@thomasjpfan I would say that remark's there mostly to limit liability ;) see my remarks above. I think I'll edit to also deprecate the
Could you at least add the suggestion in the docstring? It would make it easier to find for removal.
Otherwise LGTM
    assert check_is_fitted(ard) is None
    assert check_is_fitted(svr) is None

    assert_warns_message(
@pytest.mark.parametrize("params", [{'attributes': ['coefs_']}, {'all_or_any': any}])
def test_check_is_fitted_deprecation(params):
    # FIXME: to be removed in 0.23
    warn_msg = 'Passing {} to check_is_fitted is deprecated'.format(list(params.keys())[0])
    with pytest.warns(DeprecationWarning, match=warn_msg):
        check_is_fitted(ard, **params)
It could be handy to have a separate test function that can be removed in the next version.
We might use pytest (but the test will be removed anyway).
Is it easier to remove a test than to remove the asserts? A comment might be nice but it will also just fail and so we won't forget ;)
Good to go
yay thanks for the reviews :)
This simplifies `check_is_fitted` to error if no fitted attribute is found. This is clearly less strict than what we had before, but I did not need to change any tests, so according to our tests (i.e. the guaranteed functionality), this implementation is as good as the previous one.
The main motivation for this change is to allow us to reduce boilerplate in the future. If we introduce a validation method as in #13603, we could now include the `check_is_fitted` call there.
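As a rough illustration of the simplified check, here is a minimal, self-contained sketch. It uses only the part of the condition visible in the diff (`v.endswith("_")`); the rest of that line is cut off in the review view and is not reproduced. `ToyMean`, the local `NotFittedError` and the function name are hypothetical stand-ins, not scikit-learn's code.

```python
class NotFittedError(ValueError):
    """Local stand-in for scikit-learn's exception of the same name."""


def check_any_fitted_attribute(estimator, msg=None):
    """Raise NotFittedError if the instance has no attribute ending in '_'."""
    if not hasattr(estimator, "fit"):
        raise TypeError("%s is not an estimator instance." % (estimator,))
    if msg is None:
        msg = ("This %(name)s instance is not fitted yet. Call 'fit' with "
               "appropriate arguments before using this method.")
    # Instance attributes set during fit conventionally end with "_".
    attrs = [v for v in vars(estimator) if v.endswith("_")]
    if not attrs:
        raise NotFittedError(msg % {"name": type(estimator).__name__})


class ToyMean:
    """Toy estimator that stores one fitted attribute."""

    def fit(self, X):
        self.mean_ = sum(X) / len(X)
        return self


try:
    check_any_fitted_attribute(ToyMean())
    error_message = None
except NotFittedError as exc:
    error_message = str(exc)

fitted_result = check_any_fitted_attribute(ToyMean().fit([1.0, 2.0, 3.0]))
```

Because the check no longer needs a list of attribute names, callers do not have to keep that list in sync with what `fit` actually sets.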