Skip to content

Commit 6ea371a

Browse files
committed
Merge branch 'pr/3710'
2 parents aa66dea + b584ac4 commit 6ea371a

File tree

3 files changed

+34
-8
lines changed

3 files changed

+34
-8
lines changed

doc/whats_new.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ API changes summary
120120
- `n_jobs` parameter of the fit method shifted to the constructor of the
121121
LinearRegression class.
122122

123+
- The ``predict_proba`` method of :class:`multiclass.OneVsRestClassifier`
124+
now returns two probabilities per sample in the binary (two-class) case; this
125+
is consistent with other estimators and with the method's documentation,
126+
but previous versions accidentally returned only the positive
127+
probability. Fixed by Will Lamond and `Lars Buitinck`_.
128+
129+
123130
.. _changes_0_15_2:
124131

125132
0.15.2

sklearn/multiclass.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
"OutputCodeClassifier",
5555
]
5656

57+
5758
def _fit_binary(estimator, X, y, classes=None):
5859
"""Fit a single binary estimator."""
5960
unique_y = np.unique(y)
@@ -299,17 +300,17 @@ def predict(self, X):
299300
else:
300301
thresh = .5
301302

303+
n_samples = _num_samples(X)
302304
if self.label_binarizer_.y_type_ == "multiclass":
303-
maxima = np.empty(X.shape[0], dtype=float)
305+
maxima = np.empty(n_samples, dtype=float)
304306
maxima.fill(-np.inf)
305-
argmaxima = np.zeros(X.shape[0], dtype=int)
307+
argmaxima = np.zeros(n_samples, dtype=int)
306308
for i, e in enumerate(self.estimators_):
307309
pred = _predict_binary(e, X)
308310
np.maximum(maxima, pred, out=maxima)
309311
argmaxima[maxima == pred] = i
310312
return self.label_binarizer_.classes_[np.array(argmaxima.T)]
311313
else:
312-
n_samples = _num_samples(X)
313314
indices = array.array('i')
314315
indptr = array.array('i', [0])
315316
for e in self.estimators_:
@@ -347,6 +348,11 @@ def predict_proba(self, X):
347348
# In the multi-label case, these are not disjoint.
348349
Y = np.array([e.predict_proba(X)[:, 1] for e in self.estimators_]).T
349350

351+
if len(self.estimators_) == 1:
352+
# Only one estimator, but we still want to return probabilities
353+
# for two classes.
354+
Y = np.concatenate(((1 - Y), Y), axis=1)
355+
350356
if not self.multilabel_:
351357
# Then, probabilities should be normalized to 1.
352358
Y /= np.sum(Y, axis=1)[:, np.newaxis]

sklearn/tests/test_multiclass.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from sklearn.preprocessing import LabelBinarizer
2929

30-
from sklearn.svm import LinearSVC
30+
from sklearn.svm import LinearSVC, SVC
3131
from sklearn.naive_bayes import MultinomialNB
3232
from sklearn.linear_model import (LinearRegression, Lasso, ElasticNet, Ridge,
3333
Perceptron, LogisticRegression)
@@ -186,20 +186,33 @@ def test_ovr_binary():
186186

187187
classes = set("eggs spam".split())
188188

189-
for base_clf in (MultinomialNB(), LinearSVC(random_state=0),
190-
LinearRegression(), Ridge(),
191-
ElasticNet()):
192-
189+
def conduct_test(base_clf, test_predict_proba=False):
193190
clf = OneVsRestClassifier(base_clf).fit(X, y)
194191
assert_equal(set(clf.classes_), classes)
195192
y_pred = clf.predict(np.array([[0, 0, 4]]))[0]
196193
assert_equal(set(y_pred), set("eggs"))
197194

195+
if test_predict_proba:
196+
X_test = np.array([[0, 0, 4]])
197+
probabilities = clf.predict_proba(X_test)
198+
assert_equal(2, len(probabilities[0]))
199+
assert_equal(clf.classes_[np.argmax(probabilities, axis=1)],
200+
clf.predict(X_test))
201+
198202
# test input as label indicator matrix
199203
clf = OneVsRestClassifier(base_clf).fit(X, Y)
200204
y_pred = clf.predict([[3, 0, 0]])[0]
201205
assert_equal(y_pred, 1)
202206

207+
for base_clf in (LinearSVC(random_state=0), LinearRegression(),
208+
Ridge(), ElasticNet()):
209+
conduct_test(base_clf)
210+
211+
for base_clf in (MultinomialNB(), SVC(probability=True),
212+
LogisticRegression()):
213+
conduct_test(base_clf, test_predict_proba=True)
214+
215+
203216
@ignore_warnings
204217
def test_ovr_multilabel():
205218
# Toy dataset where features correspond directly to labels.

0 commit comments

Comments
 (0)