Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,9 @@ Changelog
parameter so progress on fitting can be seen.
:pr:`22508` by :user:`Chris Combs <combscCode>`.

- |Fix| :meth:`multiclass.OneVsOneClassifier.predict` now returns correct predictions
  when the inner classifier implements only :term:`predict_proba` and no
  :term:`decision_function`. :pr:`22604` by `Thomas Fan`_.

:mod:`sklearn.neighbors`
........................

Expand Down
19 changes: 12 additions & 7 deletions sklearn/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@ def _predict_binary(estimator, X):
return score


def _threshold_for_binary_predict(estimator):
"""Threshold for predictions from binary estimator."""
if hasattr(estimator, "decision_function") and is_classifier(estimator):
return 0.0
else:
# predict_proba threshold
return 0.5


def _check_estimator(estimator):
"""Make sure that an estimator implements the necessary methods."""
if not hasattr(estimator, "decision_function") and not hasattr(
Expand Down Expand Up @@ -426,12 +435,7 @@ def predict(self, X):
argmaxima[maxima == pred] = i
return self.classes_[argmaxima]
else:
if hasattr(self.estimators_[0], "decision_function") and is_classifier(
self.estimators_[0]
):
thresh = 0
else:
thresh = 0.5
thresh = _threshold_for_binary_predict(self.estimators_[0])
indices = array.array("i")
indptr = array.array("i", [0])
for e in self.estimators_:
Expand Down Expand Up @@ -770,7 +774,8 @@ def predict(self, X):
"""
Y = self.decision_function(X)
if self.n_classes_ == 2:
return self.classes_[(Y > 0).astype(int)]
thresh = _threshold_for_binary_predict(self.estimators_[0])
return self.classes_[(Y > thresh).astype(int)]
return self.classes_[Y.argmax(axis=1)]

def decision_function(self, X):
Expand Down
18 changes: 18 additions & 0 deletions sklearn/tests/test_multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@
SGDClassifier,
)
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV, cross_val_score
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.impute import SimpleImputer
from sklearn import svm
from sklearn.exceptions import NotFittedError
from sklearn import datasets
from sklearn.datasets import load_breast_cancer

iris = datasets.load_iris()
rng = np.random.RandomState(0)
Expand Down Expand Up @@ -906,3 +908,19 @@ def test_constant_int_target(make_y):
expected = np.zeros((X.shape[0], 2))
expected[:, 0] = 1
assert_allclose(y_pred, expected)


def test_ovo_consistent_binary_classification():
    """One-vs-one must agree with its inner classifier on a binary problem.

    The inner classifier is a ``KNeighborsClassifier``; the same estimator is
    fitted directly and wrapped in ``OneVsOneClassifier``, and both must
    predict identical labels on the training data.

    Non-regression test for #13617.
    """
    X, y = load_breast_cancer(return_X_y=True)

    base_clf = KNeighborsClassifier(n_neighbors=8, weights="distance")
    ovo_clf = OneVsOneClassifier(base_clf)

    base_clf.fit(X, y)
    ovo_clf.fit(X, y)

    expected_labels = base_clf.predict(X)
    ovo_labels = ovo_clf.predict(X)
    assert_array_equal(expected_labels, ovo_labels)