Skip to content

Commit 4dafa52

Browse files
rwolst authored and
jnothman committed
FIX incorrect predict_proba for LogisticRegression in binary case using multinomial parameter. (#9939)
1 parent f247ad5 commit 4dafa52

File tree

3 files changed

+38
-6
lines changed

3 files changed

+38
-6
lines changed

doc/whats_new/v0.20.rst

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -180,6 +180,11 @@ Classifiers and regressors
180180
error for prior list which summed to 1.
181181
:issue:`10005` by :user:`Gaurav Dhingra <gxyd>`.
182182

183+
- Fixed a bug in :class:`linear_model.LogisticRegression` where, when using the
184+
parameter ``multi_class='multinomial'``, the ``predict_proba`` method was
185+
returning incorrect probabilities in the case of binary outcomes.
186+
:issue:`9939` by :user:`Roger Westover <rwolst>`.
187+
183188
- Fixed a bug in :class:`linear_model.OrthogonalMatchingPursuit` that was
184189
broken when setting ``normalize=False``.
185190
:issue:`10071` by `Alexandre Gramfort`_.

sklearn/linear_model/logistic.py

Lines changed: 16 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1101,14 +1101,18 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
11011101
coef_ : array, shape (1, n_features) or (n_classes, n_features)
11021102
Coefficient of the features in the decision function.
11031103
1104-
`coef_` is of shape (1, n_features) when the given problem
1105-
is binary.
1104+
`coef_` is of shape (1, n_features) when the given problem is binary.
1105+
In particular, when `multi_class='multinomial'`, `coef_` corresponds
1106+
to outcome 1 (True) and `-coef_` corresponds to outcome 0 (False).
11061107
11071108
intercept_ : array, shape (1,) or (n_classes,)
11081109
Intercept (a.k.a. bias) added to the decision function.
11091110
11101111
If `fit_intercept` is set to False, the intercept is set to zero.
1111-
`intercept_` is of shape(1,) when the problem is binary.
1112+
`intercept_` is of shape (1,) when the given problem is binary.
1113+
In particular, when `multi_class='multinomial'`, `intercept_`
1114+
corresponds to outcome 1 (True) and `-intercept_` corresponds to
1115+
outcome 0 (False).
11121116
11131117
n_iter_ : array, shape (n_classes,) or (1, )
11141118
Actual number of iterations for all classes. If binary or multinomial,
@@ -1332,11 +1336,17 @@ def predict_proba(self, X):
13321336
"""
13331337
if not hasattr(self, "coef_"):
13341338
raise NotFittedError("Call fit before prediction")
1335-
calculate_ovr = self.coef_.shape[0] == 1 or self.multi_class == "ovr"
1336-
if calculate_ovr:
1339+
if self.multi_class == "ovr":
13371340
return super(LogisticRegression, self)._predict_proba_lr(X)
13381341
else:
1339-
return softmax(self.decision_function(X), copy=False)
1342+
decision = self.decision_function(X)
1343+
if decision.ndim == 1:
1344+
# Workaround for multi_class="multinomial" and binary outcomes
1345+
# which requires softmax prediction with only a 1D decision.
1346+
decision_2d = np.c_[-decision, decision]
1347+
else:
1348+
decision_2d = decision
1349+
return softmax(decision_2d, copy=False)
13401350

13411351
def predict_log_proba(self, X):
13421352
"""Log of probability estimates.

sklearn/linear_model/tests/test_logistic.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -198,6 +198,23 @@ def test_multinomial_binary():
198198
assert_greater(np.mean(pred == target), .9)
199199

200200

201+
def test_multinomial_binary_probabilities():
202+
# Test multinomial LR gives expected probabilities based on the
203+
# decision function, for a binary problem.
204+
X, y = make_classification()
205+
clf = LogisticRegression(multi_class='multinomial', solver='saga')
206+
clf.fit(X, y)
207+
208+
decision = clf.decision_function(X)
209+
proba = clf.predict_proba(X)
210+
211+
expected_proba_class_1 = (np.exp(decision) /
212+
(np.exp(decision) + np.exp(-decision)))
213+
expected_proba = np.c_[1-expected_proba_class_1, expected_proba_class_1]
214+
215+
assert_almost_equal(proba, expected_proba)
216+
217+
201218
def test_sparsify():
202219
# Test sparsify and densify members.
203220
n_samples, n_features = iris.data.shape

0 commit comments

Comments
 (0)