check_is_fitted gives false positive when extracted from ensemble classifier #18648
Comments
Hey @gkiar, the issue is not with the ensemble but with the … The following code snippet fails:

```python
from sklearn.datasets import load_iris
from sklearn.pipeline import make_pipeline
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.validation import check_is_fitted

X, y = load_iris(return_X_y=True)
model = make_pipeline(DecisionTreeClassifier(random_state=0))
clf = model.fit(X, y)  # Pipeline.fit returns the fitted pipeline itself
check_is_fitted(clf)  # fails even though the pipeline was just fitted
```

I have found the following comment in scikit-learn/sklearn/pipeline.py, lines 263 to 264 at commit 5d5329c: …

I think that we should change … Before confirming this is a bug, I would like to know the opinion of a core developer (pinging @glemaitre, who previously worked on the …).
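To make the failure mode concrete, here is a minimal, self-contained sketch of the default heuristic `check_is_fitted` applies when no `attributes` argument is given, as we understand it at the time of this issue: an estimator counts as fitted only if `vars(estimator)` contains a name that ends in `_` and does not start with `__`. `toy_check_is_fitted`, `ToyTree`, and `ToyPipeline` are hypothetical stand-ins written for illustration, not scikit-learn code.

```python
class NotFittedError(ValueError):
    """Hypothetical stand-in for sklearn.exceptions.NotFittedError."""


def toy_check_is_fitted(estimator):
    # Simplified default heuristic: look for public attributes ending in "_"
    # that were set on the instance itself (dunder names are ignored).
    fitted_attrs = [
        v for v in vars(estimator) if v.endswith("_") and not v.startswith("__")
    ]
    if not fitted_attrs:
        raise NotFittedError(f"{type(estimator).__name__} is not fitted yet.")


class ToyTree:
    def fit(self, X, y):
        # Fitted state is stored in trailing-underscore attributes.
        self.n_features_in_ = len(X[0])
        self.classes_ = sorted(set(y))
        return self


class ToyPipeline:
    def __init__(self, steps):
        self.steps = steps

    def fit(self, X, y):
        # Toy simplification: every step is fit on the raw X. Fitting only
        # mutates the step objects; the pipeline gains no "_" attribute.
        for _, step in self.steps:
            step.fit(X, y)
        return self


X, y = [[0.0, 1.0], [1.0, 0.0]], [0, 1]
tree = ToyTree().fit(X, y)
toy_check_is_fitted(tree)  # passes: the tree has n_features_in_ and classes_

pipe = ToyPipeline([("tree", ToyTree())]).fit(X, y)
try:
    toy_check_is_fitted(pipe)
except NotFittedError as exc:
    print("false alarm:", exc)  # the fitted pipeline is reported as unfitted
```

The real `Pipeline` differs in many details, but the point carries over: its fitted state lives inside `steps`, which a scan of the pipeline's own attribute names cannot see.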
It is a bug that became a feature. It is not that easy to create a proper pipeline (see #8350). We now have some usages that work because we modify … Regarding …:

```python
check_is_fitted(vclf.estimators_[0], "n_features_in_")
```
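Passing an explicit attribute, as in `check_is_fitted(vclf.estimators_[0], "n_features_in_")`, skips the name scan and only asks whether that attribute is reachable. The toy below sketches why this can work for a pipeline, assuming the pipeline exposes `n_features_in_` as a property delegating to its first step (our understanding of recent `Pipeline` versions, not confirmed in this thread): `hasattr` returns `False` before the step is fitted and `True` afterwards. `toy_check_is_fitted`, `ToyTree`, and `DelegatingPipeline` are hypothetical names, not library code.

```python
class NotFittedError(ValueError):
    """Hypothetical stand-in for sklearn.exceptions.NotFittedError."""


def toy_check_is_fitted(estimator, attributes=None):
    if attributes is not None:
        # Explicit mode: only ask whether the named attributes are reachable
        # (instance attributes or properties), ignoring naming conventions.
        if isinstance(attributes, str):
            attributes = [attributes]
        fitted = all(hasattr(estimator, attr) for attr in attributes)
    else:
        # Default mode: scan instance attributes for a trailing underscore.
        fitted = any(
            v.endswith("_") and not v.startswith("__") for v in vars(estimator)
        )
    if not fitted:
        raise NotFittedError(f"{type(estimator).__name__} is not fitted yet.")


class ToyTree:
    def fit(self, X, y):
        self.n_features_in_ = len(X[0])
        return self


class DelegatingPipeline:
    def __init__(self, steps):
        self.steps = steps

    @property
    def n_features_in_(self):
        # Delegate to the first step; raises AttributeError until that step
        # is fitted, which makes hasattr() return False.
        return self.steps[0][1].n_features_in_

    def fit(self, X, y):
        for _, step in self.steps:
            step.fit(X, y)
        return self


X, y = [[0.0, 1.0], [1.0, 0.0]], [0, 1]
pipe = DelegatingPipeline([("tree", ToyTree())])

try:
    toy_check_is_fitted(pipe, "n_features_in_")  # before fit: property raises
except NotFittedError:
    print("unfitted pipeline correctly rejected")

pipe.fit(X, y)
toy_check_is_fitted(pipe, "n_features_in_")  # after fit: passes
print("fitted pipeline accepted with an explicit attribute")
```

Note that the default mode would still reject the fitted toy pipeline, because a property lives on the class and never appears in `vars(pipe)`.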
The reason is that …
Hmm... We don't really recommend using check_is_fitted as a barrier condition, precisely because we don't have a universal solution by which it can be applied with a single argument.
This is true, and it is one of the reasons we reintroduced the …
I came across a similar issue with, e.g., PCA and a scaler. It … This raises … Meanwhile, this works fine: …
So it seems it would be nice if something were set as being "fitted" in the pipeline. Surely the length of the pipeline would be a simple (genuinely "fitted") quantity. The check with …
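The suggestion above can be sketched with the same kind of toy: if a pipeline's `fit` also recorded a trailing-underscore attribute on the pipeline itself (here a hypothetical `n_steps_`, following the idea that the pipeline's length is a genuinely fitted quantity), the single-argument heuristic would recognise it. This illustrates the proposal only; `n_steps_`, `ToyScaler`, and `MarkingPipeline` are made-up names, not what scikit-learn implemented.

```python
class NotFittedError(ValueError):
    """Hypothetical stand-in for sklearn.exceptions.NotFittedError."""


def toy_check_is_fitted(estimator):
    # Default heuristic: any public instance attribute ending in "_" means fitted.
    if not any(v.endswith("_") and not v.startswith("__") for v in vars(estimator)):
        raise NotFittedError(f"{type(estimator).__name__} is not fitted yet.")


class ToyScaler:
    def fit(self, X, y=None):
        # Per-column means as a stand-in for fitted scaler state.
        self.mean_ = [sum(col) / len(col) for col in zip(*X)]
        return self


class MarkingPipeline:
    def __init__(self, steps):
        self.steps = steps

    def fit(self, X, y=None):
        # Toy simplification: every step is fit on the raw X.
        for _, step in self.steps:
            step.fit(X, y)
        # Proposed marker: record a fitted quantity on the pipeline itself.
        self.n_steps_ = len(self.steps)
        return self


pipe = MarkingPipeline([("scaler", ToyScaler())])
try:
    toy_check_is_fitted(pipe)
except NotFittedError:
    print("before fit: not fitted")

pipe.fit([[1.0, 2.0], [3.0, 4.0]])
toy_check_is_fitted(pipe)  # passes now that n_steps_ exists
print("after fit:", pipe.n_steps_)
```

The marker is cheap to set and needs no knowledge of which attributes the individual steps define, which is what makes it attractive for a single-argument check.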
With … On 1.0, the snippet in #18648 (comment) and the original issue now work.
Describe the bug
I trained several classifiers on my dataset and then created an ensemble classifier (a voting classifier) from them. Each of the estimators stored in `.estimators_` has been fit and used independently, within the ensemble, and even after being extracted from it; nevertheless, they fail a `check_is_fitted` test, so I cannot use them on their own in a context that checks for fit, or in another ensemble classifier.

Steps/Code to Reproduce
Expected Results
No error is thrown.
Actual Results
Versions