
check_is_fitted gives false positive when extracted from ensemble classifier #18648


Closed

gkiar opened this issue Oct 19, 2020 · 7 comments

gkiar commented Oct 19, 2020

Describe the bug

I trained several classifiers on my dataset and then combined them into an ensemble (a VotingClassifier). Each of the estimators, stored in .estimators_, has been fit and used both independently and within the ensemble. Yet after extracting them from the ensemble, they fail a check_is_fitted test, so I cannot use them on their own in a context that checks for fit, or in another ensemble classifier.

Steps/Code to Reproduce

# Handle imports
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn.model_selection import KFold
from sklearn.ensemble import VotingClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.validation import check_is_fitted
from copy import deepcopy
import numpy as np

# Generate a dummy dataset
y = np.random.choice([0, 1], size=50)
X = np.zeros((len(y), 100))
for idx, _y in enumerate(y):
    X[idx, :] = 10*(np.random.random((100)) - 0.5) + int(_y)*0.75 + 20 * (np.random.random((100)) - 0.2)

yval = np.random.choice([0, 1], size=5)
Xval = np.zeros((len(yval), 100))

# Create and train classifiers across some folds
clf = Pipeline([('pca', PCA()), ('svm', SVC())])
cv = KFold(n_splits=5)

clfs = []
for idx, (train_idx, test_idx) in enumerate(cv.split(X, y)):
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]
    
    tmpclf = deepcopy(clf)
    tmpclf.fit(X_train, y_train)
    clfs += [('fold{0}'.format(idx), tmpclf)]
    
    print(tmpclf.score(X_test, y_test))
print(clfs)

# Create and initialize VotingClassifier
vclf = VotingClassifier(clfs)

vclf.estimators_ = [c[1] for c in clfs]  # pass pre-fit estimators
vclf.le_ = LabelEncoder().fit(yval)
vclf.classes_ = vclf.le_.classes_

print(vclf.score(Xval, yval))

# Finally, and this is where the error occurs, extract original classifiers
orig_clf = vclf.estimators_[0]

print(orig_clf.score(Xval, yval))
check_is_fitted(orig_clf)

Expected Results

No error is thrown.

Actual Results

---------------------------------------------------------------------------
NotFittedError                            Traceback (most recent call last)
<ipython-input-77-d38acec5e641> in <module>
      2 
      3 print(orig_clf.score(Xval, yval))
----> 4 check_is_fitted(orig_clf)

~/code/env/agg/lib/python3.7/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
     70                           FutureWarning)
     71         kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72         return f(**kwargs)
     73     return inner_f
     74 

~/code/env/agg/lib/python3.7/site-packages/sklearn/utils/validation.py in check_is_fitted(estimator, attributes, msg, all_or_any)
   1017 
   1018     if not attrs:
-> 1019         raise NotFittedError(msg % {'name': type(estimator).__name__})
   1020 
   1021 

NotFittedError: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.

Versions

System:
    python: 3.7.3 (default, Dec 13 2019, 19:58:14)  [Clang 11.0.0 (clang-1100.0.33.17)]
executable: /Users/greg/code/env/agg/bin/python3
   machine: Darwin-19.6.0-x86_64-i386-64bit

Python dependencies:
          pip: 20.2.3
   setuptools: 50.3.0
      sklearn: 0.23.2
        numpy: 1.19.2
        scipy: 1.5.2
       Cython: None
       pandas: 1.1.3
   matplotlib: 3.3.2
       joblib: 0.17.0
threadpoolctl: 2.1.0

Built with OpenMP: True
gkiar changed the title from "check_if_fitted gives false positive when extracted from ensemble classifier" to "check_is_fitted gives false positive when extracted from ensemble classifier" on Oct 19, 2020
@alfaro96 (Member)

Hey @gkiar,

The issue is not related to the ensemble but to the Pipeline (meta-)estimator, which has no fitted attributes (attributes ending with a trailing underscore).
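To illustrate what "fitted attributes" means here, below is a simplified sketch, an assumption rather than scikit-learn's exact code, of how check_is_fitted infers fittedness when no explicit attribute is passed: it scans the instance dictionary for attributes ending in a trailing underscore. The FakeEstimator class and looks_fitted helper are hypothetical names used only for illustration.

```python
class FakeEstimator:
    def fit(self):
        # A trailing underscore marks a "fitted attribute" by convention.
        self.coef_ = [1.0, 2.0]
        return self

def looks_fitted(est):
    # Sketch of the vars()-based heuristic: any non-dunder instance
    # attribute ending in "_" counts as evidence of fitting.
    return any(k.endswith("_") and not k.startswith("__") for k in vars(est))

est = FakeEstimator()
print(looks_fitted(est))        # False before fit
print(looks_fitted(est.fit()))  # True after fit
```

Since Pipeline sets no such instance attribute of its own, a check based on this heuristic reports it as unfitted even after fit has run.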

The following code snippet fails:

from sklearn.datasets import load_iris
from sklearn.pipeline import make_pipeline
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.validation import check_is_fitted

X, y = load_iris(return_X_y=True)
model = make_pipeline(DecisionTreeClassifier(random_state=0))
clf = model.fit(X, y)

check_is_fitted(clf)

I have found the following comment in the Pipeline (meta-)estimator code:

# shallow copy of steps - this should really be steps_
self.steps = list(self.steps)

I think we should rename self.steps to self.steps_ to address this comment. The Pipeline (meta-)estimator would then provide a fitted attribute, which would solve this issue.

Before confirming this is a bug, I would like to know the opinion of a core developer (pinging @glemaitre, who previously worked on the Pipeline (meta-)estimator).

@glemaitre (Member)

Before confirming this is a bug, I would like to know the opinion of a core developer (pinging @glemaitre, who previously worked on the Pipeline (meta-)estimator).

It is a bug that became a feature. It is not that easy to create a proper pipeline: #8350

We now have some usages that work because steps is modified directly, which makes backward compatibility quite difficult.

Regarding check_is_fitted, I would have thought that introducing n_features_in_ would have solved the issue with Pipeline, but it seems the implementation using vars is not enough. We still need to pass the specific attribute to check_is_fitted:

check_is_fitted(vclf.estimators_[0], "n_features_in_")
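To make this workaround concrete, here is a minimal self-contained sketch using a fitted Pipeline (the estimators chosen here are illustrative, not the ones from the original report). Passing an explicit attribute makes check_is_fitted test hasattr rather than scan vars(), so it succeeds even on versions where check_is_fitted(pipe) raises.

```python
from sklearn.datasets import load_iris
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.utils.validation import check_is_fitted

X, y = load_iris(return_X_y=True)
pipe = make_pipeline(StandardScaler(), SVC()).fit(X, y)

# Explicit attribute: check_is_fitted only verifies hasattr(pipe, attr),
# which works because n_features_in_ is exposed by the fitted pipeline.
check_is_fitted(pipe, "n_features_in_")
print(pipe.n_features_in_)  # 4 (iris has four features)
```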

@glemaitre (Member)

The reason is that n_features_in_ is a property in this case.
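This matters because a property lives on the class, not in the instance dictionary, so a vars()-based scan never sees it. A minimal illustration, with a toy class standing in for Pipeline:

```python
class WithProperty:
    # Like Pipeline's n_features_in_, this is a class-level property
    # delegating to computed state, not an instance attribute set by fit.
    @property
    def n_features_in_(self):
        return 4

obj = WithProperty()
print(hasattr(obj, "n_features_in_"))   # True: attribute lookup finds it
print("n_features_in_" in vars(obj))    # False: vars() only sees __dict__
```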

@jnothman (Member)

jnothman commented Oct 20, 2020 via email

@glemaitre (Member)

Hmm... We don't really recommend using check_is_fitted as a barrier
condition, precisely because we don't have a universal solution by which it
can be applied with a single argument.

This is true, and this is one of the reasons we reintroduced the attributes parameter.
@jnothman Do you think that n_features_in_ could play this universal role?
If so, hasattr should be enough, and it would make check_is_fitted efficient as well.
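The hasattr-based idea can be sketched as follows. The is_fitted helper is hypothetical, not a scikit-learn API; it assumes n_features_in_ is universally set during fit, which is the open question in this discussion.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

def is_fitted(est):
    # Hypothetical universal check: a fitted estimator exposes
    # n_features_in_ (set during fit), an unfitted one does not.
    return hasattr(est, "n_features_in_")

X, y = load_iris(return_X_y=True)
clf = LogisticRegression(max_iter=1000)
print(is_fitted(clf))            # False before fit
print(is_fitted(clf.fit(X, y)))  # True after fit
```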

@jonathan-taylor

jonathan-taylor commented Oct 6, 2021

I came across a similar issue with e.g. PCA and a scaler:

import numpy as np
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_is_fitted

scaler = StandardScaler(with_mean=True, with_std=True)
encoder = make_pipeline(scaler, PCA(n_components=2))
X = np.random.standard_normal((50, 5))
encoder.fit(X)
check_is_fitted(encoder)

This raises NotFittedError.

Meanwhile this works fine.

encoder.transform(X)

So it seems it would be nice if something were set to mark the pipeline as fitted. Surely the length of the pipeline would be a simple (genuinely "fitted") quantity.

The check with n_features_in_ satisfies my use case if that is the acceptable solution.

@thomasjpfan (Member)

With scikit-learn 1.0, we introduced a new __sklearn_is_fitted__ API that is now used by Pipeline to denote whether it is fitted. The advantage of this is that it allows "stateless" estimators such as FunctionTransformer to report that they are always fitted.

On 1.0, the snippet in #18648 (comment) and the original issue now work.

7 participants