[MRG] FIX towards disregarding multi_class in the case of binary y #11868

Closed · wants to merge 2 commits
Changes from all commits
10 changes: 5 additions & 5 deletions doc/whats_new/v0.20.rst
@@ -440,10 +440,11 @@ Support for Python 3.3 has been officially dropped.
incorrect results in :class:`linear_model.LogisticRegressionCV`.
:issue:`11724` by :user:`Nicolas Hug <NicolasHug>`.

- |Fix| Fixed a bug in :class:`linear_model.LogisticRegression` where when using
the parameter ``multi_class='multinomial'``, the ``predict_proba`` method was
returning incorrect probabilities in the case of binary outcomes.
:issue:`9939` by :user:`Roger Westover <rwolst>`.
- |Fix| Fixed a bug in :class:`linear_model.LogisticRegression` where incorrect
coefficients and probabilities were reported for binary targets if
``multi_class='multinomial'``.
``multi_class`` is now disregarded for binary targets.
:issue:`XXX` by :user:`Roger Westover <rwolst>` and `Joel Nothman`_.
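
A minimal sketch of the behaviour this entry describes (assuming a build with this patch applied; the dataset and solver choice are arbitrary): for binary targets, ``multi_class='multinomial'`` and ``multi_class='ovr'`` should now yield the same model.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_classes=2, random_state=0)

# With the patch, multi_class is disregarded for binary y, so both
# settings go through the same (OvR) code path.
ovr = LogisticRegression(multi_class='ovr', solver='lbfgs').fit(X, y)
multi = LogisticRegression(multi_class='multinomial', solver='lbfgs').fit(X, y)

assert np.allclose(ovr.coef_, multi.coef_)
assert np.allclose(ovr.predict_proba(X), multi.predict_proba(X))
```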

- |Fix| Fixed a bug in :class:`linear_model.LogisticRegressionCV` where the
``score`` method always computes accuracy, not the metric given by
@@ -765,7 +766,6 @@ Support for Python 3.3 has been officially dropped.
vector valued pseudocounts (alpha).
:issue:`10346` by :user:`Tobias Madsen <TobiasMadsen>`


:mod:`sklearn.neighbors`
........................

60 changes: 28 additions & 32 deletions sklearn/linear_model/logistic.py
@@ -605,11 +605,12 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
classes = np.unique(y)
random_state = check_random_state(random_state)

if pos_class is None and multi_class != 'multinomial':
if (classes.size > 2):
raise ValueError('To fit OvR, use the pos_class argument')
if classes.size <= 2:
multi_class = 'ovr'
# np.unique(y) gives labels in sorted order.
pos_class = classes[1]
pos_class = classes[-1]
elif multi_class == 'ovr' and pos_class is None:
raise ValueError('To fit OvR, use the pos_class argument')

# If sample weights exist, convert them to array (support for lists)
# and check length
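
A small plain-numpy illustration of the label handling the new branch relies on (no scikit-learn internals assumed): `np.unique` returns labels sorted, so for a binary problem `classes[-1]` is the larger label, which becomes the positive class while `multi_class` is forced to `'ovr'`.

```python
import numpy as np

classes = np.unique([0, 1, 0, 1])  # labels come back sorted: array([0, 1])
assert classes.size <= 2           # binary, so the patch forces multi_class='ovr'
pos_class = classes[-1]            # the larger label, 1, is the positive class
```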
@@ -868,7 +869,7 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
chosen is 'ovr', then a binary problem is fit for each label. Else
the loss minimised is the multinomial loss fit across
the entire probability distribution. Does not work for 'liblinear'
solver.
solver. For binary problems, 'ovr' must be used here.

random_state : int, RandomState instance or None, optional, default None
The seed of the pseudo random number generator to use when shuffling
@@ -932,6 +933,7 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
log_reg.classes_ = np.array([-1, 1])
elif multi_class == 'multinomial':
log_reg.classes_ = np.unique(y_train)
assert log_reg.classes_.size > 2, 'Must have multi_class="ovr" for binary'
else:
raise ValueError("multi_class should be either multinomial or ovr, "
"got %d" % multi_class)
@@ -1107,18 +1109,14 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
coef_ : array, shape (1, n_features) or (n_classes, n_features)
Coefficient of the features in the decision function.

`coef_` is of shape (1, n_features) when the given problem is binary.
In particular, when `multi_class='multinomial'`, `coef_` corresponds
to outcome 1 (True) and `-coef_` corresponds to outcome 0 (False).
`coef_` is of shape (1, n_features) when the given problem
is binary.

intercept_ : array, shape (1,) or (n_classes,)
Intercept (a.k.a. bias) added to the decision function.

If `fit_intercept` is set to False, the intercept is set to zero.
`intercept_` is of shape (1,) when the given problem is binary.
In particular, when `multi_class='multinomial'`, `intercept_`
corresponds to outcome 1 (True) and `-intercept_` corresponds to
outcome 0 (False).
`intercept_` is of shape (1,) when the given problem is binary.

n_iter_ : array, shape (n_classes,) or (1, )
Actual number of iterations for all classes. If binary or multinomial,
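
As a quick companion to the attribute documentation above, a shape check (hypothetical data; assumes the patched estimator and an explicit solver):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.random.RandomState(0).randn(30, 4)

# Binary problem: one row of coefficients, one intercept.
clf2 = LogisticRegression(solver='lbfgs').fit(X, np.r_[[0] * 15, [1] * 15])
assert clf2.coef_.shape == (1, 4) and clf2.intercept_.shape == (1,)

# Three classes: one coefficient row (and intercept) per class.
clf3 = LogisticRegression(solver='lbfgs').fit(X, np.r_[[0] * 10, [1] * 10, [2] * 10])
assert clf3.coef_.shape == (3, 4) and clf3.intercept_.shape == (3,)
```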
@@ -1275,9 +1273,11 @@ def fit(self, X, y, sample_weight=None):
" in the data, but the data contains only one"
" class: %r" % classes_[0])

multi_class = self.multi_class
if len(self.classes_) == 2:
n_classes = 1
classes_ = classes_[1:]
multi_class = 'ovr'

if self.warm_start:
warm_start_coef = getattr(self, 'coef_', None)
@@ -1292,7 +1292,7 @@ def fit(self, X, y, sample_weight=None):
self.intercept_ = np.zeros(n_classes)

# Hack so that we iterate only once for the multinomial case.
if self.multi_class == 'multinomial':
if multi_class == 'multinomial':
classes_ = [None]
warm_start_coef = [warm_start_coef]
if warm_start_coef is None:
@@ -1311,7 +1311,7 @@ def fit(self, X, y, sample_weight=None):
path_func(X, y, pos_class=class_, Cs=[self.C],
fit_intercept=self.fit_intercept, tol=self.tol,
verbose=self.verbose, solver=self.solver,
multi_class=self.multi_class, max_iter=self.max_iter,
multi_class=multi_class, max_iter=self.max_iter,
class_weight=self.class_weight, check_input=False,
random_state=self.random_state, coef=warm_start_coef_,
penalty=self.penalty,
@@ -1322,7 +1322,7 @@ def fit(self, X, y, sample_weight=None):
fold_coefs_, _, n_iter_ = zip(*fold_coefs_)
self.n_iter_ = np.asarray(n_iter_, dtype=np.int32)[:, 0]

if self.multi_class == 'multinomial':
if multi_class == 'multinomial':
self.coef_ = fold_coefs_[0][0]
else:
self.coef_ = np.asarray(fold_coefs_)
@@ -1360,17 +1360,11 @@ def predict_proba(self, X):
"""
if not hasattr(self, "coef_"):
raise NotFittedError("Call fit before prediction")
if self.multi_class == "ovr":
calculate_ovr = self.coef_.shape[0] == 1 or self.multi_class == "ovr"
if calculate_ovr:
return super(LogisticRegression, self)._predict_proba_lr(X)
else:
decision = self.decision_function(X)
if decision.ndim == 1:
# Workaround for multi_class="multinomial" and binary outcomes
# which requires softmax prediction with only a 1D decision.
decision_2d = np.c_[-decision, decision]
else:
decision_2d = decision
return softmax(decision_2d, copy=False)
return softmax(self.decision_function(X), copy=False)
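
The deleted workaround existed because, for a binary problem fit as multinomial with symmetric coefficients, the softmax over `[-d, d]` collapses to a sigmoid of `2 * d`. A small numpy check of that identity (independent of scikit-learn):

```python
import numpy as np
from scipy.special import expit  # the logistic sigmoid

d = np.array([-1.5, 0.0, 0.7, 3.2])          # arbitrary decision values
z = np.c_[-d, d]                              # the old 2-column workaround
p1 = np.exp(z[:, 1]) / np.exp(z).sum(axis=1)  # softmax probability of class 1

# softmax([-d, d])[1] == sigmoid(2 * d), so once binary fits always take
# the OvR path (coef_.shape[0] == 1), the plain sigmoid suffices.
assert np.allclose(p1, expit(2 * d))
```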

def predict_log_proba(self, X):
"""Log of probability estimates.
@@ -1697,16 +1691,18 @@ def fit(self, X, y, sample_weight=None):
" in the data, but the data contains only one"
" class: %r" % classes[0])

multi_class = self.multi_class
if n_classes == 2:
# OvR in case of binary problems is as good as fitting
# the higher label
n_classes = 1
encoded_labels = encoded_labels[1:]
classes = classes[1:]
multi_class = 'ovr'

# We need this hack to iterate only once over labels, in the case of
# multi_class = multinomial, without changing the value of the labels.
if self.multi_class == 'multinomial':
if multi_class == 'multinomial':
iter_encoded_labels = iter_classes = [None]
else:
iter_encoded_labels = encoded_labels
@@ -1734,7 +1730,7 @@ def fit(self, X, y, sample_weight=None):
dual=self.dual, solver=self.solver, tol=self.tol,
max_iter=self.max_iter, verbose=self.verbose,
class_weight=class_weight, scoring=self.scoring,
multi_class=self.multi_class,
multi_class=multi_class,
intercept_scaling=self.intercept_scaling,
random_state=self.random_state,
max_squared_sum=max_squared_sum,
@@ -1743,7 +1739,7 @@ def fit(self, X, y, sample_weight=None):
for label in iter_encoded_labels
for train, test in folds)

if self.multi_class == 'multinomial':
if multi_class == 'multinomial':
multi_coefs_paths, Cs, multi_scores, n_iter_ = zip(*fold_coefs_)
multi_coefs_paths = np.asarray(multi_coefs_paths)
multi_scores = np.asarray(multi_scores)
@@ -1780,14 +1776,14 @@ def fit(self, X, y, sample_weight=None):
self.intercept_ = np.zeros(n_classes)

# hack to iterate only once for multinomial case.
if self.multi_class == 'multinomial':
if multi_class == 'multinomial':
scores = multi_scores
coefs_paths = multi_coefs_paths

for index, (cls, encoded_label) in enumerate(
zip(iter_classes, iter_encoded_labels)):

if self.multi_class == 'ovr':
if multi_class == 'ovr':
# The scores_ / coefs_paths_ dict have unencoded class
# labels as their keys
scores = self.scores_[cls]
@@ -1798,7 +1794,7 @@ def fit(self, X, y, sample_weight=None):

C_ = self.Cs_[best_index]
self.C_.append(C_)
if self.multi_class == 'multinomial':
if multi_class == 'multinomial':
coef_init = np.mean(coefs_paths[:, best_index, :, :],
axis=0)
else:
@@ -1812,7 +1808,7 @@ def fit(self, X, y, sample_weight=None):
max_iter=self.max_iter, tol=self.tol,
penalty=self.penalty,
class_weight=class_weight,
multi_class=self.multi_class,
multi_class=multi_class,
verbose=max(0, self.verbose - 1),
random_state=self.random_state,
check_input=False, max_squared_sum=max_squared_sum,
@@ -1827,7 +1823,7 @@ def fit(self, X, y, sample_weight=None):
for i in range(len(folds))], axis=0)
self.C_.append(np.mean(self.Cs_[best_indices]))

if self.multi_class == 'multinomial':
if multi_class == 'multinomial':
self.C_ = np.tile(self.C_, n_classes)
self.coef_ = w[:, :X.shape[1]]
if self.fit_intercept:
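
Across both `LogisticRegression.fit` and `LogisticRegressionCV.fit`, the change follows one pattern: copy the public parameter into a local variable and override only the copy, so `self.multi_class` still reflects what the user set, in line with scikit-learn's convention that `fit` must not mutate constructor parameters. A stripped-down sketch of that idiom (hypothetical class, not the actual estimator):

```python
class Sketch:
    """Hypothetical estimator illustrating the local-override idiom."""

    def __init__(self, multi_class='multinomial'):
        self.multi_class = multi_class  # public parameter, never mutated in fit

    def fit(self, n_classes):
        multi_class = self.multi_class  # work on a local copy
        if n_classes == 2:
            # Binary target: multinomial adds nothing, so fall back to OvR
            # without clobbering the user-visible parameter.
            multi_class = 'ovr'
        # ... all downstream calls receive the local `multi_class` ...
        return multi_class


est = Sketch()
assert est.fit(n_classes=2) == 'ovr'
assert est.multi_class == 'multinomial'  # get_params/set_params stay consistent
```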
39 changes: 21 additions & 18 deletions sklearn/linear_model/tests/test_logistic.py
@@ -247,23 +247,6 @@ def test_multinomial_binary(solver):
assert_greater(np.mean(pred == target), .9)


def test_multinomial_binary_probabilities():
# Test multinomial LR gives expected probabilities based on the
# decision function, for a binary problem.
X, y = make_classification()
clf = LogisticRegression(multi_class='multinomial', solver='saga')
clf.fit(X, y)

decision = clf.decision_function(X)
proba = clf.predict_proba(X)

expected_proba_class_1 = (np.exp(decision) /
(np.exp(decision) + np.exp(-decision)))
expected_proba = np.c_[1-expected_proba_class_1, expected_proba_class_1]

assert_almost_equal(proba, expected_proba)


def test_sparsify():
# Test sparsify and densify members.
n_samples, n_features = iris.data.shape
@@ -1294,7 +1277,7 @@ def test_warm_start_converge_LR():

rng = np.random.RandomState(0)
X = np.concatenate((rng.randn(100, 2) + [1, 1], rng.randn(100, 2)))
y = np.array([1] * 100 + [-1] * 100)
y = np.array([1] * 100 + [2] * 50 + [3] * 50)
lr_no_ws = LogisticRegression(multi_class='multinomial',
solver='sag', warm_start=False,
random_state=0)
@@ -1327,3 +1310,23 @@ def test_logistic_regression_path_coefs_multinomial():
assert_array_almost_equal(coefs[0], coefs[2], decimal=1)
with pytest.raises(AssertionError):
assert_array_almost_equal(coefs[1], coefs[2], decimal=1)


@pytest.mark.parametrize('est', [LogisticRegression(),
LogisticRegressionCV(cv=3)])
@pytest.mark.parametrize('solver', ['newton-cg', 'lbfgs', 'sag', 'saga'])
def test_logistic_binary_invariant_to_multi_class(est, solver):
# Make sure that multi_class does not affect binary classification
X, y = make_classification(n_samples=200, n_classes=2, n_informative=2,
n_redundant=0, n_clusters_per_class=1,
random_state=0, n_features=2)
est.set_params(multi_class='ovr', solver=solver, max_iter=10000,
random_state=0)
est.fit(X, y)
model_ovr = np.hstack([np.squeeze(est.coef_), est.intercept_])
assert est.coef_.shape == (1, X.shape[1])
est.set_params(multi_class='multinomial')
est.fit(X, y)
assert est.coef_.shape == (1, X.shape[1])
model_multi = np.hstack([np.squeeze(est.coef_), est.intercept_])
assert_allclose(model_ovr, model_multi, rtol=0.01)