Skip to content

FIX pipeline now checks if it's fitted #29868

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 17, 2024

Conversation

adrinjalali
Copy link
Member

@adrinjalali adrinjalali commented Sep 17, 2024

Fixes #27014

This PR makes Pipeline to check if it's fitted in methods other than fit*, with a deprecation.

cc @glemaitre @betatim @StefanieSenger

Copy link

github-actions bot commented Sep 17, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 6094f5a. Link to the linter CI: here

@adrinjalali adrinjalali added this to the 1.6 milestone Sep 17, 2024
@@ -37,6 +38,18 @@
__all__ = ["Pipeline", "FeatureUnion", "make_pipeline", "make_union"]


def _check_is_fitted(pipeline):
try:
check_is_fitted(pipeline)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we should not also modify check_is_fitted to be lenient with stateless estimator. Right now, one could expect to implement __sklearn_is_fitted__ but I don't think this is part of the API per se. So I'm wondering if check_is_fitted should look at the tag requires_fit?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done via: #29880

@glemaitre glemaitre self-requested a review September 28, 2024 11:17
Copy link
Member

@glemaitre glemaitre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple of comments. But it looks good.

@@ -575,18 +603,21 @@ def predict(self, X, **params):
y_pred : ndarray
Result of calling `predict` on the final estimator.
"""
Xt = X
with _handle_warnings(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be worth to also add some TODO next to each context manager to have more occurences.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe having a decorator instead of a context manager would avoid the extra indentation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also like it to be a decorator, since this is concerning whole methods.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these methods are complicated, they already have decorators, and adding another decorator might complicate things. So I rather do the dirty-ish thing here and keep it as is.

Copy link
Contributor

@StefanieSenger StefanieSenger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @adrinjalali, I have worked through this and added some comments, which I hope will be helpful.

@@ -575,18 +603,21 @@ def predict(self, X, **params):
y_pred : ndarray
Result of calling `predict` on the final estimator.
"""
Xt = X
with _handle_warnings(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also like it to be a decorator, since this is concerning whole methods.

Comment on lines +1879 to +1880
def __sklearn_is_fitted__(self):
return True
Copy link
Contributor

@StefanieSenger StefanieSenger Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of distributing self.fitted_ = True and def __sklearn_is_fitted__(self) across some of the mocking classes in this test file and in model_selection/test/test_validation?

I believe this is not needed and I think it's blurring the boundaries between the test cases and makes them difficult to read without knowing this PR or searching for it in the future.

I find it neater if test classes are very cleanly only serving their own purpose.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is added since w/o it the test would fail.

Strictly speaking, adding a self.fitted_ = True is used when you want check_is_fitted to be okay after calling fit, but you don't have anything else to set in fit. __sklearn_check_is_fitted__ is added when you don't need the user to call fit and the estimator is always considered fitted.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now this PR is merged, but still: I had run all the concerned tests files without the additions before making this comment. They all passed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No they fail:

FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[decision_function] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[inverse_transform] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[predict] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[predict_log_proba] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[predict_proba] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[score] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.
FAILED sklearn/tests/test_pipeline.py::test_metadata_routing_for_pipeline[transform] - FutureWarning: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using other methods such as transform, predict, etc. This will raise an error in 1.8 instead of the current warning.

Copy link
Contributor

@StefanieSenger StefanieSenger Oct 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Running the tests with the -Werror::FutureWarning flag or the SKLEARN_WARNINGS_AS_ERRORS=1 environmental variable shows the errors.
Thanks for the hint, @adrinjalali.

# TODO(1.8): remove this test
def test_pipeline_warns_not_fitted():
class StatelessEstimator(BaseEstimator):
def fit(self, X, y):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here mentioning explicitly what is lacking:

Suggested change
def fit(self, X, y):
def fit(self, X, y):
"""Doesn't create learned attributes."""

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the docstring now does enough explanation I think.

for _, estimator in reversed(self.steps):
if estimator != "passthrough":
last_step = estimator
break
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought using break is considered bad practice, isn't it? Not sure about its actual downsides though. Alternatively a while loop with "last step that is not 'passthrough'" as a stopping criterion, but that would look very complicated compared to the break.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's nothing wrong with using break.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One alternative might be something like

        last_step = next(
            (
                estimator for _, estimator in reversed(self.steps)
                if estimator != "passthrough"
            ),
            None,
        )

but not everyone thinks this is cleaner.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah that's hard to read 😁 but nice!

Comment on lines 44 to 47
"""A context manager to make sure a NotFittedError is raised, if a subestimator
raises the error.

Otherwise, we raise a warning if the pipeline is not fitted, with the deprecation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This context manager raises a warning instead of a NotFittedError, which differs from what is written here.

Maybe like this:

Suggested change
"""A context manager to make sure a NotFittedError is raised, if a subestimator
raises the error.
Otherwise, we raise a warning if the pipeline is not fitted, with the deprecation.
"""A context manager to raise a FutureWarning during the deprecation period,
if the last step of a pipeline raises a NotFittedError when it is not fitted.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I see I got this wrong.
Would still be good to explain better what is supposed to happen.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for a helper method which is only here for two versions during the deprecation cycle, it doesn't really matter.

@@ -253,6 +253,7 @@ def fit(
P.shape[0],
P.shape[1],
)
self.fitted_ = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need it here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cause check_is_fitted checks for an attribute with a trailing underscore. This makes check_is_fitted(self) to pass.

Copy link
Member

@glemaitre glemaitre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@adrinjalali
Copy link
Member Author

@Charlie-XIAO or @adam2392 might wanna have a look?

@adrinjalali adrinjalali added the Waiting for Second Reviewer First reviewer is done, need a second one! label Oct 17, 2024
Copy link
Contributor

@Charlie-XIAO Charlie-XIAO left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This overall LGTM! Just a small suggestion:

@@ -37,6 +39,33 @@
__all__ = ["Pipeline", "FeatureUnion", "make_pipeline", "make_union"]


@contextmanager
def _handle_warnings(estimator):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was confused by this name before reading its docstring. Would something like _raise_or_warn_if_not_fitted, or _ensure_fitted_or_warn, or _handle_fit_status be better?

@adrinjalali adrinjalali added No Changelog Needed and removed Waiting for Second Reviewer First reviewer is done, need a second one! labels Oct 17, 2024
@adrinjalali adrinjalali enabled auto-merge (squash) October 17, 2024 12:40
@adrinjalali adrinjalali merged commit 4dfbfb9 into scikit-learn:main Oct 17, 2024
32 of 33 checks passed
@adrinjalali adrinjalali deleted the pipeline/fit branch October 18, 2024 09:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Pipeline throws TypeError on stateless transformers
4 participants