FIX pipeline now checks if it's fitted #29868
sklearn/pipeline.py
Outdated
@@ -37,6 +38,18 @@
__all__ = ["Pipeline", "FeatureUnion", "make_pipeline", "make_union"]


def _check_is_fitted(pipeline):
    try:
        check_is_fitted(pipeline)
I'm wondering if we should not also modify `check_is_fitted` to be lenient with stateless estimators. Right now, one could expect to implement `__sklearn_is_fitted__`, but I don't think this is part of the API per se. So I'm wondering if `check_is_fitted` should look at the tag `requires_fit`?
Done via: #29880
A couple of comments. But it looks good.
sklearn/pipeline.py
Outdated
@@ -575,18 +603,21 @@ def predict(self, X, **params):
        y_pred : ndarray
            Result of calling `predict` on the final estimator.
        """
        Xt = X
        with _handle_warnings(self):
It might be worth also adding a `TODO` next to each context manager, so that there are more occurrences to find later.
Maybe having a decorator instead of a context manager would avoid the extra indentation?
I would also like it to be a decorator, since this is concerning whole methods.
These methods are complicated, they already have decorators, and adding another decorator might complicate things. So I'd rather do the dirty-ish thing here and keep it as is.
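For readers weighing the two options in this thread, here is a minimal sketch of what the decorator variant could look like. It is not what the PR does (the PR keeps the context manager), and `warn_if_not_fitted`, `ToyPipeline`, and the `fitted_` convention used here are hypothetical stand-ins rather than scikit-learn code.

```python
import functools
import warnings


def warn_if_not_fitted(method):
    """Hypothetical decorator: warn when the instance looks unfitted."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        # Assumption: a truthy `fitted_` attribute marks a fitted instance.
        if not getattr(self, "fitted_", False):
            warnings.warn(
                f"This {type(self).__name__} instance is not fitted yet.",
                FutureWarning,
            )
        return method(self, *args, **kwargs)

    return wrapper


class ToyPipeline:
    """Stand-in class for illustration, not scikit-learn's Pipeline."""

    def fit(self, X):
        self.fitted_ = True
        return self

    @warn_if_not_fitted
    def predict(self, X):
        # Dummy prediction: one zero per input row.
        return [0 for _ in X]
```

The decorator saves one level of indentation in the method body, at the cost of stacking on top of whatever decorators the method already carries, which is the concern raised above.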
Thanks, @adrinjalali, I have worked through this and added some comments, which I hope will be helpful.
def __sklearn_is_fitted__(self):
    return True
What is the purpose of distributing `self.fitted_ = True` and `def __sklearn_is_fitted__(self)` across some of the mocking classes in this test file and in `model_selection/test/test_validation`?
I believe this is not needed, and I think it blurs the boundaries between the test cases and makes them difficult to read without knowing this PR or searching for it in the future.
I find it neater if test classes cleanly serve only their own purpose.
This is added because the test would fail without it.
Strictly speaking, `self.fitted_ = True` is used when you want `check_is_fitted` to pass after calling `fit` but have nothing else to set in `fit`. `__sklearn_is_fitted__` is added when you don't need the user to call `fit` and the estimator is always considered fitted.
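The two mechanisms described in this comment can be illustrated with a simplified stand-in for the fitted check. This is a sketch under assumptions, not the library code: `simple_check_is_fitted`, the local `NotFittedError`, `NeedsFit`, and `AlwaysFitted` are all hypothetical names defined here so that the example runs without scikit-learn.

```python
class NotFittedError(ValueError):
    """Local stand-in for sklearn.exceptions.NotFittedError."""


def simple_check_is_fitted(estimator):
    """Simplified stand-in for check_is_fitted; not the library code."""
    if hasattr(estimator, "__sklearn_is_fitted__"):
        # The estimator decides for itself whether it counts as fitted.
        fitted = estimator.__sklearn_is_fitted__()
    else:
        # Otherwise look for instance attributes with a trailing
        # underscore, such as `coef_` or `fitted_`.
        fitted = any(
            name.endswith("_") and not name.startswith("__")
            for name in vars(estimator)
        )
    if not fitted:
        raise NotFittedError(f"{type(estimator).__name__} is not fitted yet.")


class NeedsFit:
    """Has nothing to learn, but still requires a call to fit."""

    def fit(self, X, y=None):
        self.fitted_ = True
        return self


class AlwaysFitted:
    """Stateless: considered fitted even before fit is called."""

    def fit(self, X, y=None):
        return self

    def __sklearn_is_fitted__(self):
        return True
```

With these definitions, `simple_check_is_fitted(AlwaysFitted())` passes immediately, while `NeedsFit()` passes only after `fit` has been called.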
This PR is merged now, but still: I ran all the affected test files without the additions before making this comment, and they all passed.
No, they fail:
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[decision_function] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[inverse_transform] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[predict] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[predict_log_proba] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[predict_proba] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[score] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[transform] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
Running the tests with the `-Werror::FutureWarning` flag or the `SKLEARN_WARNINGS_AS_ERRORS=1` environment variable shows the errors.
Thanks for the hint, @adrinjalali.
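The reason the tests pass in one setup and fail in the other can be reproduced with Python's standard `warnings` filters alone. The sketch below promotes `FutureWarning` to an exception inside a function, which mirrors what an `error::FutureWarning` filter does for a whole test run; `predict_unfitted` and `run_strict` are hypothetical helpers, not part of scikit-learn or pytest.

```python
import warnings


def predict_unfitted():
    """Hypothetical function emitting the kind of warning quoted above."""
    warnings.warn("This Pipeline instance is not fitted yet.", FutureWarning)
    return "prediction"


def run_strict(func):
    """Call func with FutureWarning promoted to an exception."""
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=FutureWarning)
        try:
            return func()
        except FutureWarning as exc:
            return f"failed: {exc}"
```

Under the default filters `predict_unfitted()` returns normally and only prints a warning, so a test that does not inspect warnings passes; under the strict filter the same call raises.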
# TODO(1.8): remove this test
def test_pipeline_warns_not_fitted():
    class StatelessEstimator(BaseEstimator):
        def fit(self, X, y):
It would help to state explicitly what is lacking here:
- def fit(self, X, y):
+ def fit(self, X, y):
+     """Doesn't create learned attributes."""
The docstring now explains enough, I think.
for _, estimator in reversed(self.steps):
    if estimator != "passthrough":
        last_step = estimator
        break
I thought using `break` is considered bad practice, isn't it? Not sure about its actual downsides though. Alternatively a while loop with "last step that is not 'passthrough'" as a stopping criterion, but that would look very complicated compared to the `break`.
There's nothing wrong with using `break`.
One alternative might be something like
last_step = next(
    (
        estimator for _, estimator in reversed(self.steps)
        if estimator != "passthrough"
    ),
    None,
)
but not everyone thinks this is cleaner.
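To compare the two styles side by side, here is a small runnable sketch. The step lists use plain strings as stand-ins for estimators, and both function names are hypothetical; the only point is that the loop with `break` and the `next(...)` form select the same step.

```python
def last_step_with_break(steps):
    """Loop form: stop at the first non-passthrough step from the end."""
    last_step = None
    for _, estimator in reversed(steps):
        if estimator != "passthrough":
            last_step = estimator
            break
    return last_step


def last_step_with_next(steps):
    """Generator form: next() with a default of None."""
    return next(
        (
            estimator
            for _, estimator in reversed(steps)
            if estimator != "passthrough"
        ),
        None,
    )


# Strings stand in for estimator objects in this illustration.
example_steps = [
    ("scale", "scaler"),
    ("clf", "classifier"),
    ("skip", "passthrough"),
]
```

Both functions also agree on the edge case where every step is `"passthrough"`: each returns `None`.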
yeah that's hard to read 😁 but nice!
sklearn/pipeline.py
Outdated
"""A context manager to make sure a NotFittedError is raised, if a subestimator
raises the error.

Otherwise, we raise a warning if the pipeline is not fitted, with the deprecation.
This context manager raises a warning instead of a NotFittedError, which differs from what is written here.
Maybe like this:
- """A context manager to make sure a NotFittedError is raised, if a subestimator
- raises the error.
- Otherwise, we raise a warning if the pipeline is not fitted, with the deprecation.
+ """A context manager to raise a FutureWarning during the deprecation period,
+ if the last step of a pipeline raises a NotFittedError when it is not fitted.
Now I see I got this wrong.
Would still be good to explain better what is supposed to happen.
I think for a helper method which is only here for two versions during the deprecation cycle, it doesn't really matter.
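Since the thread leaves the intended behavior somewhat implicit, here is a minimal sketch of a context manager that matches the docstring under review: an inner `NotFittedError` is re-raised, and otherwise a `FutureWarning` is emitted if the pipeline is not fitted. This is an illustration under assumptions, not the code from the PR; `raise_or_warn_if_not_fitted`, the local `NotFittedError`, and the `is_fitted` callable are hypothetical.

```python
import warnings
from contextlib import contextmanager


class NotFittedError(ValueError):
    """Local stand-in for sklearn.exceptions.NotFittedError."""


@contextmanager
def raise_or_warn_if_not_fitted(is_fitted):
    """Hypothetical helper; `is_fitted` is a zero-argument callable."""
    try:
        yield
    except NotFittedError as exc:
        # A sub-estimator reported it is unfitted: surface one clear error.
        raise NotFittedError("Pipeline is not fitted yet.") from exc

    # Reached only when the wrapped body finished without raising.
    if not is_fitted():
        warnings.warn(
            "This Pipeline instance is not fitted yet.",
            FutureWarning,
        )
```

The ordering matters in this sketch: the warning is only considered after the body has run, so a call that already fails with `NotFittedError` does not additionally warn.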
@@ -253,6 +253,7 @@ def fit(
            P.shape[0],
            P.shape[1],
        )
        self.fitted_ = True
Why do we need it here?
Because `check_is_fitted` checks for an attribute with a trailing underscore. This makes `check_is_fitted(self)` pass.
LGTM
@Charlie-XIAO or @adam2392 might wanna have a look?
This overall LGTM! Just a small suggestion:
sklearn/pipeline.py
Outdated
@@ -37,6 +39,33 @@
__all__ = ["Pipeline", "FeatureUnion", "make_pipeline", "make_union"]


@contextmanager
def _handle_warnings(estimator):
I was confused by this name before reading its docstring. Would something like `_raise_or_warn_if_not_fitted`, `_ensure_fitted_or_warn`, or `_handle_fit_status` be better?
Fixes #27014
This PR makes `Pipeline` check whether it is fitted in methods other than `fit*`, with a deprecation.
cc @glemaitre @betatim @StefanieSenger