Skip to content

Commit 54d1ec9

Browse files
authored
ENH make check_is_fitted to always pass on stateless estimators (scikit-learn#29880)
1 parent c7839c4 commit 54d1ec9

File tree

6 files changed

+41
-3
lines changed

6 files changed

+41
-3
lines changed

doc/whats_new/v1.6.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,11 @@ Changelog
367367
calling :func:`utils.validation.check_non_negative`.
368368
:pr:`29540` by :user:`Tamara Atanasoska <tamaraatanasoska>`.
369369

370+
- |Enhancement| :func:`utils.validation.check_is_fitted` now passes on stateless
371+
estimators. An estimator can indicate it's stateless by setting the `requires_fit`
372+
tag. See :ref:`estimator_tags` for more information.
373+
:pr:`29880` by `Adrin Jalali`_.
374+
370375
- |API| the `assert_all_finite` parameter of functions :func:`utils.check_array`,
371376
:func:`utils.check_X_y`, :func:`utils.as_float_array` is renamed into
372377
`ensure_all_finite`. `force_all_finite` will be removed in 1.8.

sklearn/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -863,7 +863,10 @@ def get_feature_names_out(self, input_features=None):
863863
feature_names_out : ndarray of str objects
864864
Same as input features.
865865
"""
866-
check_is_fitted(self, "n_features_in_")
866+
# Note that passing attributes="n_features_in_" forces check_is_fitted
867+
# to check if the attribute is present. Otherwise it will pass on
868+
# stateless estimators (requires_fit=False)
869+
check_is_fitted(self, attributes="n_features_in_")
867870
return _check_feature_names_in(self, input_features)
868871

869872

sklearn/kernel_approximation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,10 @@ def get_feature_names_out(self, input_features=None):
747747
feature_names_out : ndarray of str objects
748748
Transformed feature names.
749749
"""
750-
check_is_fitted(self, "n_features_in_")
750+
# Note that passing attributes="n_features_in_" forces check_is_fitted
751+
# to check if the attribute is present. Otherwise it will pass on this
752+
# stateless estimator (requires_fit=False)
753+
check_is_fitted(self, attributes="n_features_in_")
751754
input_features = _check_feature_names_in(
752755
self, input_features, generate_names=True
753756
)

sklearn/utils/_tags.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,8 @@ class RegressorTags:
180180
class Tags:
181181
"""Tags for the estimator.
182182
183+
See :ref:`estimator_tags` for more information.
184+
183185
Parameters
184186
----------
185187
target_tags : :class:`TargetTags`

sklearn/utils/tests/test_validation.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,21 @@ def __sklearn_is_fitted__(self):
894894
check_is_fitted(Estimator().fit())
895895

896896

897+
def test_check_is_fitted_stateless():
898+
"""Check that check_is_fitted passes for stateless estimators."""
899+
900+
class StatelessEstimator(BaseEstimator):
901+
def fit(self, **kwargs):
902+
return self # pragma: no cover
903+
904+
def __sklearn_tags__(self):
905+
tags = super().__sklearn_tags__()
906+
tags.requires_fit = False
907+
return tags
908+
909+
check_is_fitted(StatelessEstimator())
910+
911+
897912
def test_check_is_fitted():
898913
# Check is TypeError raised when non estimator instance passed
899914
with pytest.raises(TypeError):

sklearn/utils/validation.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1656,14 +1656,19 @@ def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
16561656
16571657
Checks if the estimator is fitted by verifying the presence of
16581658
fitted attributes (ending with a trailing underscore) and otherwise
1659-
raises a NotFittedError with the given message.
1659+
raises a :class:`~sklearn.exceptions.NotFittedError` with the given message.
16601660
16611661
If an estimator does not set any attributes with a trailing underscore, it
16621662
can define a ``__sklearn_is_fitted__`` method returning a boolean to
16631663
specify if the estimator is fitted or not. See
16641664
:ref:`sphx_glr_auto_examples_developing_estimators_sklearn_is_fitted.py`
16651665
for an example on how to use the API.
16661666
1667+
If no `attributes` are passed, this fuction will pass if an estimator is stateless.
1668+
An estimator can indicate it's stateless by setting the `requires_fit` tag. See
1669+
:ref:`estimator_tags` for more information. Note that the `requires_fit` tag
1670+
is ignored if `attributes` are passed.
1671+
16671672
Parameters
16681673
----------
16691674
estimator : estimator instance
@@ -1724,6 +1729,11 @@ def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
17241729
if not hasattr(estimator, "fit"):
17251730
raise TypeError("%s is not an estimator instance." % (estimator))
17261731

1732+
tags = get_tags(estimator)
1733+
1734+
if not tags.requires_fit and attributes is None:
1735+
return
1736+
17271737
if not _is_fitted(estimator, attributes, all_or_any):
17281738
raise NotFittedError(msg % {"name": type(estimator).__name__})
17291739

0 commit comments

Comments
 (0)