Skip to content

Commit 150ff34

Browse files
[MRG] DOC add __sklearn_is_fitted__ example (scikit-learn#26618)
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
1 parent 02d20c1 commit 150ff34

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

doc/developers/develop.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,20 @@ only wrap the first array and not alter the other arrays.
709709
See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
710710
for an example on how to use the API.
711711

712+
.. _developer_api_check_is_fitted:
713+
714+
Developer API for `check_is_fitted`
715+
===================================
716+
717+
By default :func:`~sklearn.utils.validation.check_is_fitted` checks if there
718+
are any attributes in the instance with a trailing underscore, e.g. `coef_`.
719+
An estimator can change the behavior by implementing a `__sklearn_is_fitted__`
720+
method taking no input and returning a boolean. If this method exists,
721+
:func:`~sklearn.utils.validation.check_is_fitted` simply returns its output.
722+
723+
See :ref:`sphx_glr_auto_examples_developing_estimators_sklearn_is_fitted.py`
724+
for an example on how to use the API.
725+
712726
.. _coding-guidelines:
713727

714728
Coding guidelines
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
.. _developing_estimator_examples:
2+
3+
Developing Estimators
4+
---------------------
5+
6+
Examples concerning the development of Custom Estimator.
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""
2+
========================================
3+
`__sklearn_is_fitted__` as Developer API
4+
========================================
5+
6+
The `__sklearn_is_fitted__` method is a convention used in scikit-learn for
7+
checking whether an estimator object has been fitted or not. This method is
8+
typically implemented in custom estimator classes that are built on top of
9+
scikit-learn's base classes like `BaseEstimator` or its subclasses.
10+
11+
Developers should use :func:`~sklearn.sklearn.utils.validation.check_is_fitted`
12+
at the beginning of all methods except `fit`. If they need to customize or
13+
speed-up the check, they can implement the `__sklearn_is_fitted__` method as
14+
shown below.
15+
16+
In this example the custom estimator showcases the usage of the
17+
`__sklearn_is_fitted__` method and the `check_is_fitted` utility function
18+
as developer APIs. The `__sklearn_is_fitted__` method checks fitted status
19+
by verifying the presence of the `_is_fitted` attribute.
20+
"""
21+
22+
# %%
23+
# An example custom estimator implementing a simple classifier
24+
# ------------------------------------------------------------
25+
# This code snippet defines a custom estimator class called `CustomEstimator`
26+
# that extends both the `BaseEstimator` and `ClassifierMixin` classes from
27+
# scikit-learn and showcases the usage of the `__sklearn_is_fitted__` method
28+
# and the `check_is_fitted` utility function.
29+
30+
# Author: Kushan <kushansharma1@gmail.com>
31+
#
32+
# License: BSD 3 clause
33+
34+
from sklearn.base import BaseEstimator, ClassifierMixin
35+
from sklearn.utils.validation import check_is_fitted
36+
37+
38+
class CustomEstimator(BaseEstimator, ClassifierMixin):
39+
def __init__(self, parameter=1):
40+
self.parameter = parameter
41+
42+
def fit(self, X, y):
43+
"""
44+
Fit the estimator to the training data.
45+
"""
46+
self.classes_ = sorted(set(y))
47+
# Custom attribute to track if the estimator is fitted
48+
self._is_fitted = True
49+
return self
50+
51+
def predict(self, X):
52+
"""
53+
Perform Predictions
54+
55+
If the estimator is not fitted, then raise NotFittedError
56+
"""
57+
check_is_fitted(self)
58+
# Perform prediction logic
59+
predictions = [self.classes_[0]] * len(X)
60+
return predictions
61+
62+
def score(self, X, y):
63+
"""
64+
Calculate Score
65+
66+
If the estimator is not fitted, then raise NotFittedError
67+
"""
68+
check_is_fitted(self)
69+
# Perform scoring logic
70+
return 0.5
71+
72+
def __sklearn_is_fitted__(self):
73+
"""
74+
Check fitted status and return a Boolean value.
75+
"""
76+
return hasattr(self, "_is_fitted") and self._is_fitted

0 commit comments

Comments
 (0)