|
| 1 | +""" |
| 2 | +======================================== |
| 3 | +`__sklearn_is_fitted__` as Developer API |
| 4 | +======================================== |
| 5 | +
|
| 6 | +The `__sklearn_is_fitted__` method is a convention used in scikit-learn for |
| 7 | +checking whether an estimator object has been fitted or not. This method is |
| 8 | +typically implemented in custom estimator classes that are built on top of |
| 9 | +scikit-learn's base classes like `BaseEstimator` or its subclasses. |
| 10 | +
|
| 11 | +Developers should use :func:`~sklearn.sklearn.utils.validation.check_is_fitted` |
| 12 | +at the beginning of all methods except `fit`. If they need to customize or |
| 13 | +speed-up the check, they can implement the `__sklearn_is_fitted__` method as |
| 14 | +shown below. |
| 15 | +
|
| 16 | +In this example the custom estimator showcases the usage of the |
| 17 | +`__sklearn_is_fitted__` method and the `check_is_fitted` utility function |
| 18 | +as developer APIs. The `__sklearn_is_fitted__` method checks fitted status |
| 19 | +by verifying the presence of the `_is_fitted` attribute. |
| 20 | +""" |
| 21 | + |
| 22 | +# %% |
| 23 | +# An example custom estimator implementing a simple classifier |
| 24 | +# ------------------------------------------------------------ |
| 25 | +# This code snippet defines a custom estimator class called `CustomEstimator` |
| 26 | +# that extends both the `BaseEstimator` and `ClassifierMixin` classes from |
| 27 | +# scikit-learn and showcases the usage of the `__sklearn_is_fitted__` method |
| 28 | +# and the `check_is_fitted` utility function. |
| 29 | + |
| 30 | +# Author: Kushan <kushansharma1@gmail.com> |
| 31 | +# |
| 32 | +# License: BSD 3 clause |
| 33 | + |
| 34 | +from sklearn.base import BaseEstimator, ClassifierMixin |
| 35 | +from sklearn.utils.validation import check_is_fitted |
| 36 | + |
| 37 | + |
| 38 | +class CustomEstimator(BaseEstimator, ClassifierMixin): |
| 39 | + def __init__(self, parameter=1): |
| 40 | + self.parameter = parameter |
| 41 | + |
| 42 | + def fit(self, X, y): |
| 43 | + """ |
| 44 | + Fit the estimator to the training data. |
| 45 | + """ |
| 46 | + self.classes_ = sorted(set(y)) |
| 47 | + # Custom attribute to track if the estimator is fitted |
| 48 | + self._is_fitted = True |
| 49 | + return self |
| 50 | + |
| 51 | + def predict(self, X): |
| 52 | + """ |
| 53 | + Perform Predictions |
| 54 | +
|
| 55 | + If the estimator is not fitted, then raise NotFittedError |
| 56 | + """ |
| 57 | + check_is_fitted(self) |
| 58 | + # Perform prediction logic |
| 59 | + predictions = [self.classes_[0]] * len(X) |
| 60 | + return predictions |
| 61 | + |
| 62 | + def score(self, X, y): |
| 63 | + """ |
| 64 | + Calculate Score |
| 65 | +
|
| 66 | + If the estimator is not fitted, then raise NotFittedError |
| 67 | + """ |
| 68 | + check_is_fitted(self) |
| 69 | + # Perform scoring logic |
| 70 | + return 0.5 |
| 71 | + |
| 72 | + def __sklearn_is_fitted__(self): |
| 73 | + """ |
| 74 | + Check fitted status and return a Boolean value. |
| 75 | + """ |
| 76 | + return hasattr(self, "_is_fitted") and self._is_fitted |
0 commit comments