Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions sklearn/semi_supervised/_self_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ class SelfTrainingClassifier(MetaEstimatorMixin, BaseEstimator):

.. versionadded:: 0.24

feature_names_in_ : ndarray of shape (`n_features_in_`,)
Names of features seen during :term:`fit`. Defined only when `X`
has feature names that are all strings.

.. versionadded:: 1.0

n_iter_ : int
The number of rounds of self-training, that is the number of times the
base estimator is fitted on relabeled variants of the training set.
Expand Down Expand Up @@ -285,6 +291,7 @@ def predict(self, X):
Array with predicted labels.
"""
check_is_fitted(self)
self._check_feature_names(X, reset=False)
return self.base_estimator_.predict(X)

def predict_proba(self, X):
Expand All @@ -301,6 +308,7 @@ def predict_proba(self, X):
Array with prediction probabilities.
"""
check_is_fitted(self)
self._check_feature_names(X, reset=False)
return self.base_estimator_.predict_proba(X)

@if_delegate_has_method(delegate="base_estimator")
Expand All @@ -318,6 +326,7 @@ def decision_function(self, X):
Result of the decision function of the `base_estimator`.
"""
check_is_fitted(self)
self._check_feature_names(X, reset=False)
return self.base_estimator_.decision_function(X)

@if_delegate_has_method(delegate="base_estimator")
Expand All @@ -335,6 +344,7 @@ def predict_log_proba(self, X):
Array with log prediction probabilities.
"""
check_is_fitted(self)
self._check_feature_names(X, reset=False)
return self.base_estimator_.predict_log_proba(X)

@if_delegate_has_method(delegate="base_estimator")
Expand All @@ -355,4 +365,5 @@ def score(self, X, y):
Result of calling score on the `base_estimator`.
"""
check_is_fitted(self)
self._check_feature_names(X, reset=False)
return self.base_estimator_.score(X, y)
1 change: 0 additions & 1 deletion sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,6 @@ def test_check_n_features_in_after_fitting(estimator):
"kernel_approximation",
"model_selection",
"multioutput",
"semi_supervised",
}

_estimators_to_test = list(
Expand Down