|
4 | 4 | from numpy.testing import assert_array_equal
|
5 | 5 | import pytest
|
6 | 6 |
|
| 7 | +from sklearn.ensemble import StackingClassifier |
7 | 8 | from sklearn.exceptions import NotFittedError
|
8 |
| -from sklearn.semi_supervised import SelfTrainingClassifier |
9 | 9 | from sklearn.neighbors import KNeighborsClassifier
|
10 | 10 | from sklearn.svm import SVC
|
11 | 11 | from sklearn.model_selection import train_test_split
|
12 | 12 | from sklearn.datasets import load_iris, make_blobs
|
13 | 13 | from sklearn.metrics import accuracy_score
|
14 | 14 |
|
| 15 | +from sklearn.semi_supervised import SelfTrainingClassifier |
| 16 | + |
15 | 17 | # Author: Oliver Rausch <rauscho@ethz.ch>
|
16 | 18 | # License: BSD 3 clause
|
17 | 19 |
|
@@ -318,3 +320,26 @@ def test_k_best_selects_best():
|
318 | 320 |
|
319 | 321 | for row in most_confident_svc.tolist():
|
320 | 322 | assert row in added_by_st
|
| 323 | + |
| 324 | + |
| 325 | +def test_base_estimator_meta_estimator(): |
| 326 | + # Check that a meta-estimator relying on an estimator implementing |
| 327 | + # `predict_proba` will work even if it does expose this method before being |
| 328 | + # fitted. |
| 329 | + # Non-regression test for: |
| 330 | + # https://github.com/scikit-learn/scikit-learn/issues/19119 |
| 331 | + |
| 332 | + base_estimator = StackingClassifier( |
| 333 | + estimators=[ |
| 334 | + ("svc_1", SVC(probability=True)), ("svc_2", SVC(probability=True)), |
| 335 | + ], |
| 336 | + final_estimator=SVC(probability=True), cv=2 |
| 337 | + ) |
| 338 | + |
| 339 | + # make sure that the `base_estimator` does not expose `predict_proba` |
| 340 | + # without being fitted |
| 341 | + assert not hasattr(base_estimator, "predict_proba") |
| 342 | + |
| 343 | + clf = SelfTrainingClassifier(base_estimator=base_estimator) |
| 344 | + clf.fit(X_train, y_train_missing_labels) |
| 345 | + clf.predict_proba(X_test) |
0 commit comments