@@ -1101,14 +1101,18 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
1101
1101
coef_ : array, shape (1, n_features) or (n_classes, n_features)
1102
1102
Coefficient of the features in the decision function.
1103
1103
1104
- `coef_` is of shape (1, n_features) when the given problem
1105
- is binary.
1104
+ `coef_` is of shape (1, n_features) when the given problem is binary.
1105
+ In particular, when `multi_class='multinomial'`, `coef_` corresponds
1106
+ to outcome 1 (True) and `-coef_` corresponds to outcome 0 (False).
1106
1107
1107
1108
intercept_ : array, shape (1,) or (n_classes,)
1108
1109
Intercept (a.k.a. bias) added to the decision function.
1109
1110
1110
1111
If `fit_intercept` is set to False, the intercept is set to zero.
1111
- `intercept_` is of shape(1,) when the problem is binary.
1112
+ `intercept_` is of shape (1,) when the given problem is binary.
1113
+ In particular, when `multi_class='multinomial'`, `intercept_`
1114
+ corresponds to outcome 1 (True) and `-intercept_` corresponds to
1115
+ outcome 0 (False).
1112
1116
1113
1117
n_iter_ : array, shape (n_classes,) or (1, )
1114
1118
Actual number of iterations for all classes. If binary or multinomial,
@@ -1332,11 +1336,17 @@ def predict_proba(self, X):
1332
1336
"""
1333
1337
if not hasattr (self , "coef_" ):
1334
1338
raise NotFittedError ("Call fit before prediction" )
1335
- calculate_ovr = self .coef_ .shape [0 ] == 1 or self .multi_class == "ovr"
1336
- if calculate_ovr :
1339
+ if self .multi_class == "ovr" :
1337
1340
return super (LogisticRegression , self )._predict_proba_lr (X )
1338
1341
else :
1339
- return softmax (self .decision_function (X ), copy = False )
1342
+ decision = self .decision_function (X )
1343
+ if decision .ndim == 1 :
1344
+ # Workaround for multi_class="multinomial" and binary outcomes
1345
+ # which requires softmax prediction with only a 1D decision.
1346
+ decision_2d = np .c_ [- decision , decision ]
1347
+ else :
1348
+ decision_2d = decision
1349
+ return softmax (decision_2d , copy = False )
1340
1350
1341
1351
def predict_log_proba (self , X ):
1342
1352
"""Log of probability estimates.
0 commit comments