Skip to content

[MRG] simplify check_is_fitted to use any fitted attributes #14545

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Aug 13, 2019

Conversation

amueller
Copy link
Member

@amueller amueller commented Aug 1, 2019

This simplifies check_is_fitted to error if no fitted attribute is found.
This clearly is less strict than what we had before, but I did not need to change any tests, so according to our tests (i.e. the guaranteed functionality), this implementation is as good as the previous one.

The main motivation for this change is to allow us to reduce boiler-plate in the future. If we introduce a validation method as in #13603, we could now include the check_is_fitted there.

@@ -866,21 +866,18 @@ def check_symmetric(array, tol=1E-10, raise_warning=True,
return array


def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
def check_is_fitted(estimator, *, msg=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aren't you changing the signature here? did you mean

check_is_fitted(estimator, *args, msg=None): to preserve backward compatibility?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm changing the signature. We could have args here if we consider this public, which maybe we should?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok actually using args doesn't work unless I also add *kwargs. So if we want backward-compatibility we need to just do a usual deprecation cycle.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think we consider the validation utils public. But I'm happy to see them go private.

@@ -910,10 +904,10 @@ def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
if not hasattr(estimator, 'fit'):
raise TypeError("%s is not an estimator instance." % (estimator))

if not isinstance(attributes, (list, tuple)):
attributes = [attributes]
attrs = [v for v in vars(estimator) if v.endswith("_")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think NearedtNeighbors has stored only _fit_X

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has this:

if self.metric_params is None:
self.effective_metric_params_ = {}
else:
self.effective_metric_params_ = self.metric_params.copy()

and it was very recently documented.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Common tests pass so it must work ;)

@amueller
Copy link
Member Author

amueller commented Aug 2, 2019

hm vectorizer were not caught by common tests of course :-/ TfidfTransformer only has _idf_diag

@amueller
Copy link
Member Author

amueller commented Aug 2, 2019

And CountVectorizer is misbehaving....

@amueller
Copy link
Member Author

amueller commented Aug 2, 2019

See #14559, but should be passing now. This is not the cleanest work-around but that's mostly because CountVectorizer doesn't adhere to conventions. This is a minimum change (that also fixes a bug where you had to call transform before being able to call inverse_transform).

Copy link
Member

@glemaitre glemaitre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually handy. I will need it in #14028.
Could you add an entry in the what's new?

I wonder if there is an issue with some of the meta-estimator which are note tested with the common tests. I'll check that.

@amueller
Copy link
Member Author

amueller commented Aug 5, 2019

I'm ambivalent about adding a whatsnew but I can do it if you think it's worth it. Probably should add a versionchanged?

@glemaitre
Copy link
Member

glemaitre commented Aug 5, 2019

I'm ambivalent about adding a whatsnew but I can do it if you think it's worth it. Probably should add a versionchanged?

I would say that a note in the versionchanged could be nice. In fact only a version changed would be needed I think. This is handy for 3rd party library to know what is going on where their tests are failing :)

@@ -866,21 +865,18 @@ def check_symmetric(array, tol=1E-10, raise_warning=True,
return array


def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
def check_is_fitted(estimator, attributes='deprecated', msg=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behavior is already not back-compatible. Since we mention in the documentation that these utils can change from a version to another, I would not bother with a deprecation warning for the attributes parameters knowing that one can have some side-effect with all_or_any.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is it not backward compatible? Oh I could deprecate all_or_any as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If somebody is using all_or_any now, nothing would happen and this is not an attribute of the function as well. But as I mentioned, we clearly state in the documentation that utils are not following the deprecation cycle and can change: https://scikit-learn.org/stable/developers/utilities.html

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, only deprecating one doesn't make sense. But also see the discussion at #6616. Basically, the docs say that but people ignore it and it might not be good if we enforce it and should make things private instead.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK I see. I am sure I was one of these people that complain at least once (then @lesteve show me the red box :))
I really feel that having the utils private could help to move quickly sometimes and help third-party project (at the cost of potential breaking if they use them). So deprecation it is :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so should I add one for all_or_any then?

@glemaitre
Copy link
Member

What is the behaviour expected on Pipeline?
This will fail:

from sklearn.datasets import load_iris                                               
from sklearn.preprocessing import StandardScaler                                     
from sklearn.linear_model import LogisticRegression                                  
from sklearn.pipeline import make_pipeline                                           
from sklearn.utils.validation import check_is_fitted                                 
                                                                                     
X, y = load_iris(return_X_y=True)                                                    
pipe = make_pipeline(StandardScaler(), LogisticRegression())                         
pipe.fit(X, y)                                                                       
check_is_fitted(pipe)
---------------------------------------------------------------------------
NotFittedError                            Traceback (most recent call last)
/tmp/tmp.py in <module>
      8 pipe = make_pipeline(StandardScaler(), LogisticRegression())
      9 pipe.fit(X, y)
---> 10 check_is_fitted(pipe)

~/Documents/code/toolbox/scikit-learn/sklearn/utils/validation.py in check_is_fitted(estimator, msg)
    910 
    911     if not len(attrs):
--> 912         raise NotFittedError(msg % {'name': type(estimator).__name__})
    913 
    914 

NotFittedError: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using this method.

@glemaitre
Copy link
Member

we should probably make a recursive call on each element of the Pipeline instance then?

@glemaitre
Copy link
Member

Do you still want to make a deprecation?

@thomasjpfan
Copy link
Member

Since we clearly state that the utils are not guaranteed to be stable, I would prefer not go through a deprecation cycle.

@amueller
Copy link
Member Author

amueller commented Aug 9, 2019

@thomasjpfan I would say that remark's there mostly to limit liability ;) see my remarks above. I think I'll edit to also deprecate the all_or_any and then we can merge?

Copy link
Member

@glemaitre glemaitre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you could add at least the suggestion in the docstring. It could be easier to find it for removal.
Otherwise LGTM

assert check_is_fitted(ard) is None
assert check_is_fitted(svr) is None

assert_warns_message(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pytest.mark.parametrize("params", [{'attributes': ['coefs_']}, {all_or_any=any}]
def test_check_is_fitted_deprecation(params):
    # FIXME: to be removed in 0.23
    warn_msg = 'Passing {} to check_is_fitted is deprecated'.format(list(params.keys())[0])
    with pytest.warns(DeprecationWarning, match=warn_msg):
        check_is_fitted(ard, **params)

It could be handy to have a separated test function to be removed next version.
We might use pytest (but the test will be removed anyway).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it easier to remove a test than to remove the asserts? A comment might be nice but it will also just fail and so we won't forget ;)

amueller and others added 4 commits August 12, 2019 15:48
Co-Authored-By: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Co-Authored-By: Guillaume Lemaitre <g.lemaitre58@gmail.com>
@glemaitre glemaitre merged commit 92af3da into scikit-learn:master Aug 13, 2019
@glemaitre
Copy link
Member

Good to go

@amueller
Copy link
Member Author

yay thanks for the reviews :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants