Skip to content

Commit 78e1530

Browse files
glemaitrethomasjpfanogrisel
authored andcommitted
FIX accept meta-estimator in SelfTrainingClassifier (#19126)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 6eeb145 commit 78e1530

File tree

3 files changed

+47
-5
lines changed

3 files changed

+47
-5
lines changed

doc/whats_new/v0.24.rst

+17
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,23 @@
22

33
.. currentmodule:: sklearn
44

5+
.. _changes_0_24_1:
6+
7+
Version 0.24.1
8+
==============
9+
10+
Changelog
11+
---------
12+
13+
:mod:`sklearn.semi_supervised`
14+
..............................
15+
16+
- |Fix| :class:`semi_supervised.SelfTrainingClassifier` is now accepting
17+
meta-estimator (e.g. :class:`ensemble.StackingClassifier`). The validation
18+
of this estimator is done on the fitted estimator, once we know the existence
19+
of the method `predict_proba`.
20+
:pr:`19126` by :user:`Guillaume Lemaitre <glemaitre>`.
21+
522
.. _changes_0_24:
623

724
Version 0.24.0

sklearn/semi_supervised/_self_training.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,10 @@ def fit(self, X, y):
205205
X[safe_mask(X, has_label)],
206206
self.transduction_[has_label])
207207

208-
if self.n_iter_ == 1:
209-
# Only validate in the first iteration so that n_iter=0 is
210-
# equivalent to the base_estimator itself.
211-
_validate_estimator(self.base_estimator)
208+
# Validate the fitted estimator since `predict_proba` can be
209+
# delegated to an underlying "final" fitted estimator as
210+
# generally done in meta-estimator or pipeline.
211+
_validate_estimator(self.base_estimator_)
212212

213213
# Predict on the unlabeled samples
214214
prob = self.base_estimator_.predict_proba(

sklearn/semi_supervised/tests/test_self_training.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
from numpy.testing import assert_array_equal
55
import pytest
66

7+
from sklearn.ensemble import StackingClassifier
78
from sklearn.exceptions import NotFittedError
8-
from sklearn.semi_supervised import SelfTrainingClassifier
99
from sklearn.neighbors import KNeighborsClassifier
1010
from sklearn.svm import SVC
1111
from sklearn.model_selection import train_test_split
1212
from sklearn.datasets import load_iris, make_blobs
1313
from sklearn.metrics import accuracy_score
1414

15+
from sklearn.semi_supervised import SelfTrainingClassifier
16+
1517
# Author: Oliver Rausch <rauscho@ethz.ch>
1618
# License: BSD 3 clause
1719

@@ -318,3 +320,26 @@ def test_k_best_selects_best():
318320

319321
for row in most_confident_svc.tolist():
320322
assert row in added_by_st
323+
324+
325+
def test_base_estimator_meta_estimator():
326+
# Check that a meta-estimator relying on an estimator implementing
327+
# `predict_proba` will work even if it does expose this method before being
328+
# fitted.
329+
# Non-regression test for:
330+
# https://github.com/scikit-learn/scikit-learn/issues/19119
331+
332+
base_estimator = StackingClassifier(
333+
estimators=[
334+
("svc_1", SVC(probability=True)), ("svc_2", SVC(probability=True)),
335+
],
336+
final_estimator=SVC(probability=True), cv=2
337+
)
338+
339+
# make sure that the `base_estimator` does not expose `predict_proba`
340+
# without being fitted
341+
assert not hasattr(base_estimator, "predict_proba")
342+
343+
clf = SelfTrainingClassifier(base_estimator=base_estimator)
344+
clf.fit(X_train, y_train_missing_labels)
345+
clf.predict_proba(X_test)

0 commit comments

Comments
 (0)