diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index f64a6bda6ea95..56eb642729b36 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -697,6 +697,9 @@ Changelog parameter so progress on fitting can be seen. :pr:`22508` by :user:`Chris Combs `. +- |Fix| :meth:`multiclass.OneVsOneClassifier.predict` returns correct predictions when + the inner classifier only has a :term:`predict_proba`. :pr:`22604` by `Thomas Fan`_. + :mod:`sklearn.neighbors` ........................ diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index e100bb4ef99dc..b46b4bfb8b5ef 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -102,6 +102,15 @@ def _predict_binary(estimator, X): return score +def _threshold_for_binary_predict(estimator): + """Threshold for predictions from binary estimator.""" + if hasattr(estimator, "decision_function") and is_classifier(estimator): + return 0.0 + else: + # predict_proba threshold + return 0.5 + + def _check_estimator(estimator): """Make sure that an estimator implements the necessary methods.""" if not hasattr(estimator, "decision_function") and not hasattr( @@ -426,12 +435,7 @@ def predict(self, X): argmaxima[maxima == pred] = i return self.classes_[argmaxima] else: - if hasattr(self.estimators_[0], "decision_function") and is_classifier( - self.estimators_[0] - ): - thresh = 0 - else: - thresh = 0.5 + thresh = _threshold_for_binary_predict(self.estimators_[0]) indices = array.array("i") indptr = array.array("i", [0]) for e in self.estimators_: @@ -770,7 +774,8 @@ def predict(self, X): """ Y = self.decision_function(X) if self.n_classes_ == 2: - return self.classes_[(Y > 0).astype(int)] + thresh = _threshold_for_binary_predict(self.estimators_[0]) + return self.classes_[(Y > thresh).astype(int)] return self.classes_[Y.argmax(axis=1)] def decision_function(self, X): diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index 9571b43b3d746..a3621414ae793 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -32,12 +32,14 @@ SGDClassifier, ) from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor +from sklearn.neighbors import KNeighborsClassifier from sklearn.model_selection import GridSearchCV, cross_val_score from sklearn.pipeline import Pipeline, make_pipeline from sklearn.impute import SimpleImputer from sklearn import svm from sklearn.exceptions import NotFittedError from sklearn import datasets +from sklearn.datasets import load_breast_cancer iris = datasets.load_iris() rng = np.random.RandomState(0) @@ -906,3 +908,19 @@ def test_constant_int_target(make_y): expected = np.zeros((X.shape[0], 2)) expected[:, 0] = 1 assert_allclose(y_pred, expected) + + +def test_ovo_consistent_binary_classification(): + """Check that ovo is consistent with binary classifier. + + Non-regression test for #13617. + """ + X, y = load_breast_cancer(return_X_y=True) + + clf = KNeighborsClassifier(n_neighbors=8, weights="distance") + ovo = OneVsOneClassifier(clf) + + clf.fit(X, y) + ovo.fit(X, y) + + assert_array_equal(clf.predict(X), ovo.predict(X))