ENH check_is_fitted calls __sklearn_is_fitted__ if available #20657

Merged · 18 commits · Aug 20, 2021
6 changes: 6 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -723,6 +723,12 @@ Changelog
unavailable on the basis of state, in a more readable way.
:pr:`19948` by `Joel Nothman`_.

- |Enhancement| :func:`utils.validation.check_is_fitted` now uses
  ``__sklearn_is_fitted__`` if available, instead of checking for attributes
  ending with an underscore. This also makes :class:`pipeline.Pipeline` and
  :class:`preprocessing.FunctionTransformer` pass
  ``check_is_fitted(estimator)``. :pr:`20657` by `Adrin Jalali`_.

- |Fix| Fixed a bug in :func:`utils.sparsefuncs.mean_variance_axis` where the
precision of the computed variance was very poor when the real variance is
exactly zero. :pr:`19766` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
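For context, a minimal sketch of the behavior this changelog entry describes. The `CustomEstimator` class and its `_fitted` flag are illustrative only, not part of scikit-learn:

```python
from sklearn.base import BaseEstimator
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

class CustomEstimator(BaseEstimator):
    """Illustrative estimator that sets no trailing-underscore attributes."""

    def fit(self, X, y=None):
        self._fitted = True  # leading underscore: invisible to the old check
        return self

    def __sklearn_is_fitted__(self):
        # check_is_fitted now consults this hook before falling back to
        # scanning vars(estimator) for trailing-underscore attributes.
        return getattr(self, "_fitted", False)

est = CustomEstimator()
try:
    check_is_fitted(est)
except NotFittedError:
    print("not fitted yet")
check_is_fitted(est.fit([[1], [2]]))  # passes silently once fitted
```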
14 changes: 14 additions & 0 deletions sklearn/pipeline.py
@@ -26,7 +26,9 @@
from .utils.deprecation import deprecated
from .utils._tags import _safe_tags
from .utils.validation import check_memory
from .utils.validation import check_is_fitted
from .utils.fixes import delayed
from .exceptions import NotFittedError

from .utils.metaestimators import _BaseComposition

@@ -657,6 +659,18 @@ def n_features_in_(self):
# delegate to first step (which will call _check_is_fitted)
return self.steps[0][1].n_features_in_

def __sklearn_is_fitted__(self):
"""Indicate whether pipeline has been fit."""
try:
# check if the last step of the pipeline is fitted
# we only check the last step since if the last step is fit, it
# means the previous steps should also be fit. This is faster than
# checking if every step of the pipeline is fit.
check_is_fitted(self.steps[-1][1])

[Inline review thread on the line above]

@ogrisel (Member, Aug 6, 2021): Is the reason we only check the last step of
the pipeline to be nice to users who have a currently working pipeline with
their own custom stateless transformers, which would fail check_is_fitted if
this method called it on every step? Or is there another reason? In both
cases it might be worth making that explicit in the inline comment.

Adrin Jalali (Member, Author): @glemaitre's script was checking the first
step; I thought it made more sense to do it for the last step. I'm agnostic
on how we do it. I'll add a comment.

return True
except NotFittedError:
return False

def _sk_visual_block_(self):
_, estimators = zip(*self.steps)

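A quick sketch of what this buys at the user level. The estimators here are chosen arbitrarily for illustration:

```python
from sklearn.datasets import load_iris
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_is_fitted

X, y = load_iris(return_X_y=True)
pipe = make_pipeline(StandardScaler(), LogisticRegression())

try:
    check_is_fitted(pipe)  # delegates to Pipeline.__sklearn_is_fitted__
except NotFittedError:
    print("pipeline is not fitted")

pipe.fit(X, y)
check_is_fitted(pipe)  # passes: the last step is now fitted
```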
4 changes: 4 additions & 0 deletions sklearn/preprocessing/_function_transformer.py
@@ -176,5 +176,9 @@ def _transform(self, X, func=None, kw_args=None):

return func(X, **(kw_args if kw_args else {}))

def __sklearn_is_fitted__(self):
"""Return True since FunctionTransfomer is stateless."""
return True

def _more_tags(self):
return {"no_validation": not self.validate, "stateless": True}
16 changes: 15 additions & 1 deletion sklearn/tests/test_pipeline.py
@@ -21,7 +21,8 @@
MinimalRegressor,
MinimalTransformer,
)

from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted
from sklearn.base import clone, is_classifier, BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline, make_union
from sklearn.svm import SVC
@@ -1361,3 +1362,16 @@ def test_search_cv_using_minimal_compatible_estimator(Predictor):
else:
assert_allclose(y_pred, y.mean())
assert model.score(X, y) == pytest.approx(r2_score(y, y_pred))


def test_pipeline_check_if_fitted():
class Estimator(BaseEstimator):
def fit(self, X, y):
self.fitted_ = True
return self

pipeline = Pipeline([("clf", Estimator())])
with pytest.raises(NotFittedError):
check_is_fitted(pipeline)
pipeline.fit(iris.data, iris.target)
check_is_fitted(pipeline)
41 changes: 41 additions & 0 deletions sklearn/utils/estimator_checks.py
@@ -51,6 +51,7 @@
from ..model_selection import ShuffleSplit
from ..model_selection._validation import _safe_split
from ..metrics.pairwise import rbf_kernel, linear_kernel, pairwise_distances
from ..utils.validation import check_is_fitted

from . import shuffle
from ._tags import (
@@ -305,6 +306,7 @@ def _yield_all_checks(estimator):
yield check_dict_unchanged
yield check_dont_overwrite_parameters
yield check_fit_idempotent
yield check_fit_check_is_fitted
if not tags["no_validation"]:
yield check_n_features_in
yield check_fit1d
@@ -3493,6 +3495,45 @@ def check_fit_idempotent(name, estimator_orig):
)


def check_fit_check_is_fitted(name, estimator_orig):
# Make sure that the estimator doesn't pass check_is_fitted before calling
# fit, and that it does pass check_is_fitted once it has been fit.

rng = np.random.RandomState(42)

estimator = clone(estimator_orig)
set_random_state(estimator)
if "warm_start" in estimator.get_params():
estimator.set_params(warm_start=False)

n_samples = 100
X = rng.normal(loc=100, size=(n_samples, 2))
X = _pairwise_estimator_convert_X(X, estimator)
if is_regressor(estimator_orig):
y = rng.normal(size=n_samples)
else:
y = rng.randint(low=0, high=2, size=n_samples)
y = _enforce_estimator_tags_y(estimator, y)

if not _safe_tags(estimator).get("stateless", False):
# stateless estimators (such as FunctionTransformer) are always "fit"!
try:
check_is_fitted(estimator)
raise AssertionError(
f"{estimator.__class__.__name__} passes check_is_fitted before being"
" fit!"
)
except NotFittedError:
pass
estimator.fit(X, y)
try:
check_is_fitted(estimator)
except NotFittedError as e:
raise NotFittedError(
"Estimator fails to pass `check_is_fitted` even though it has been fit."
) from e


def check_n_features_in(name, estimator_orig):
# Make sure that n_features_in_ attribute doesn't exist until fit is
# called, and that its value is correct.
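A sketch of invoking the new common check directly. The `ToyEstimator` class is illustrative; in practice the check runs automatically as part of `check_estimator`, since it is yielded from `_yield_all_checks`:

```python
from sklearn.base import BaseEstimator
from sklearn.utils.estimator_checks import check_fit_check_is_fitted

class ToyEstimator(BaseEstimator):
    def fit(self, X, y=None):
        self.coef_ = 0.0  # trailing underscore marks the fitted state
        return self

# Raises no error: ToyEstimator fails check_is_fitted before fit and
# passes it afterwards, exactly as the check requires.
check_fit_check_is_fitted("ToyEstimator", ToyEstimator())
```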
27 changes: 27 additions & 0 deletions sklearn/utils/tests/test_estimator_checks.py
@@ -34,6 +34,7 @@
from sklearn.utils.validation import check_array
from sklearn.utils import all_estimators
from sklearn.exceptions import SkipTestWarning
from sklearn.utils.metaestimators import available_if

from sklearn.utils.estimator_checks import (
_NotAnArray,
@@ -51,6 +52,7 @@
check_regressor_data_not_an_array,
check_outlier_corruption,
set_random_state,
check_fit_check_is_fitted,
)


@@ -986,3 +988,28 @@ def test_minimal_class_implementation_checks():
minimal_estimators = [MinimalTransformer(), MinimalRegressor(), MinimalClassifier()]
for estimator in minimal_estimators:
check_estimator(estimator)


def test_check_fit_check_is_fitted():
class Estimator(BaseEstimator):
def __init__(self, behavior="attribute"):
self.behavior = behavior

def fit(self, X, y, **kwargs):
if self.behavior == "attribute":
self.is_fitted_ = True
elif self.behavior == "method":
self._is_fitted = True
return self

@available_if(lambda self: self.behavior in {"method", "always-true"})
def __sklearn_is_fitted__(self):
if self.behavior == "always-true":
return True
return hasattr(self, "_is_fitted")

with raises(Exception, match="passes check_is_fitted before being fit"):
check_fit_check_is_fitted("estimator", Estimator(behavior="always-true"))

check_fit_check_is_fitted("estimator", Estimator(behavior="method"))
check_fit_check_is_fitted("estimator", Estimator(behavior="attribute"))
16 changes: 15 additions & 1 deletion sklearn/utils/tests/test_validation.py
@@ -51,7 +51,7 @@
FLOAT_DTYPES,
)
from sklearn.utils.validation import _check_fit_params

from sklearn.base import BaseEstimator
import sklearn

from sklearn.exceptions import NotFittedError, PositiveSpectrumWarning
@@ -750,6 +750,20 @@ def test_check_symmetric():
assert_array_equal(output, arr_sym)


def test_check_is_fitted_with_is_fitted():
class Estimator(BaseEstimator):
def fit(self, **kwargs):
self._is_fitted = True
return self

def __sklearn_is_fitted__(self):
return hasattr(self, "_is_fitted") and self._is_fitted

with pytest.raises(NotFittedError):
check_is_fitted(Estimator())
check_is_fitted(Estimator().fit())


def test_check_is_fitted():
# Check is TypeError raised when non estimator instance passed
with pytest.raises(TypeError):
13 changes: 8 additions & 5 deletions sklearn/utils/validation.py
@@ -1142,8 +1142,9 @@ def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
fitted attributes (ending with a trailing underscore) and otherwise
raises a NotFittedError with the given message.

-    This utility is meant to be used internally by estimators themselves,
-    typically in their own predict / transform methods.
+    If an estimator does not set any attributes with a trailing underscore, it
+    can define a ``__sklearn_is_fitted__`` method returning a boolean to
+    specify if the estimator is fitted or not.

Parameters
----------
@@ -1194,13 +1195,15 @@ def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
     if attributes is not None:
         if not isinstance(attributes, (list, tuple)):
             attributes = [attributes]
-        attrs = all_or_any([hasattr(estimator, attr) for attr in attributes])
+        fitted = all_or_any([hasattr(estimator, attr) for attr in attributes])
+    elif hasattr(estimator, "__sklearn_is_fitted__"):
+        fitted = estimator.__sklearn_is_fitted__()
     else:
-        attrs = [
+        fitted = [
             v for v in vars(estimator) if v.endswith("_") and not v.startswith("__")
         ]

-    if not attrs:
+    if not fitted:
         raise NotFittedError(msg % {"name": type(estimator).__name__})


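The dispatch order above is: an explicit `attributes` argument wins, then `__sklearn_is_fitted__`, then the trailing-underscore scan. A sketch, with the `AlwaysFitted` class illustrative only:

```python
from sklearn.base import BaseEstimator
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

class AlwaysFitted(BaseEstimator):
    def __sklearn_is_fitted__(self):
        return True

est = AlwaysFitted()
check_is_fitted(est)  # passes via the __sklearn_is_fitted__ hook

try:
    # An explicit attributes list takes precedence over the hook,
    # so this raises because est has no coef_ attribute.
    check_is_fitted(est, attributes=["coef_"])
except NotFittedError:
    print("explicit attributes are checked first")
```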