DEP deprecate multi_class in LogisticRegression #28703

Merged: 25 commits, May 3, 2024
103d9ef
DEP deprecate multi_class in LogisticRegression
lorentzenchr Mar 26, 2024
c0901a3
DOC improve deprecation messages
lorentzenchr Apr 7, 2024
2e3605f
DOC improve solver support matrix
lorentzenchr Apr 7, 2024
fced0b2
FIX do not change self.multi_class in fit
lorentzenchr Apr 8, 2024
df4091e
DEP multi_class in LogisticRegressionCV
lorentzenchr Apr 8, 2024
c192198
DOC update linear model user guide
lorentzenchr Apr 8, 2024
a80142b
FIX typos
lorentzenchr Apr 8, 2024
1b67f3e
DOC / EXA adapt for deprecated ovr
lorentzenchr Apr 8, 2024
b4a6629
EXA fix coef in example
lorentzenchr Apr 9, 2024
398fd01
FIX develop.rst
lorentzenchr Apr 9, 2024
100ae1d
TST fix linear model tests
lorentzenchr Apr 9, 2024
1072da0
DOC / EXA fix examples and docstring with LogisticRegression(multi_cl…
lorentzenchr Apr 9, 2024
fc8f319
EXA fix plot_logistic_multinomial
lorentzenchr Apr 9, 2024
94a7ec5
TST catch all FutureWarnings in test_logistic.py
lorentzenchr Apr 10, 2024
9f03542
Merge branch 'main' into deprecate_ovr_logreg
lorentzenchr Apr 10, 2024
81161e5
TST make all tests pass with -Werror::FutureWarning
lorentzenchr Apr 11, 2024
8dc3c0d
Merge branch 'main' into deprecate_ovr_logreg
lorentzenchr Apr 11, 2024
a67bb05
DOC add whatsnew
lorentzenchr Apr 11, 2024
29f6b52
CLN whatsnew entry text alignment
lorentzenchr Apr 12, 2024
63a9399
CLN address review comments
lorentzenchr Apr 24, 2024
1350563
Merge branch 'main' into deprecate_ovr_logreg
jeremiedbb Apr 30, 2024
535fa9f
CLN address review comments
lorentzenchr Apr 30, 2024
179741e
ENH add additional warning for binary multinomial
lorentzenchr May 2, 2024
29d2e79
CLN filterwarnings in test_multinomial_binary
lorentzenchr May 2, 2024
10e73d8
Merge branch 'main' into deprecate_ovr_logreg
lorentzenchr May 2, 2024
1 change: 0 additions & 1 deletion asv_benchmarks/benchmarks/linear_model.py
@@ -52,7 +52,6 @@ def make_estimator(self, params):
estimator = LogisticRegression(
solver=solver,
penalty=penalty,
multi_class="multinomial",
tol=0.01,
n_jobs=n_jobs,
random_state=0,
11 changes: 9 additions & 2 deletions benchmarks/bench_saga.py
@@ -20,6 +20,7 @@
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import LabelBinarizer, LabelEncoder
from sklearn.utils.extmath import safe_sparse_dot, softmax
from sklearn.utils.parallel import Parallel, delayed
@@ -95,14 +96,15 @@ def fit_single(
else:
lr = LogisticRegression(
solver=solver,
multi_class=multi_class,
C=C,
penalty=penalty,
fit_intercept=False,
tol=0,
max_iter=this_max_iter,
random_state=42,
)
if multi_class == "ovr":
lr = OneVsRestClassifier(lr)

# Makes cpu cache even for all fit calls
X_train.max()
@@ -118,8 +120,12 @@ def fit_single(
except NotImplementedError:
# Lightning predict_proba is not implemented for n_classes > 2
y_pred = _predict_proba(lr, X)
if isinstance(lr, OneVsRestClassifier):
coef = np.concatenate([est.coef_ for est in lr.estimators_])
else:
coef = lr.coef_
score = log_loss(y, y_pred, normalize=False) / n_samples
score += 0.5 * alpha * np.sum(lr.coef_**2) + beta * np.sum(np.abs(lr.coef_))
score += 0.5 * alpha * np.sum(coef**2) + beta * np.sum(np.abs(coef))
scores.append(score)
train_score, test_score = tuple(scores)

@@ -133,6 +139,7 @@ def fit_single(


def _predict_proba(lr, X):
"""Predict proba for lightning for n_classes >=3."""
pred = safe_sparse_dot(X, lr.coef_.T)
if hasattr(lr, "intercept_"):
pred += lr.intercept_
2 changes: 1 addition & 1 deletion doc/developers/develop.rst
@@ -353,7 +353,7 @@ The parameter `deep` will control whether or not the parameters of the
subestimator__intercept_scaling -> 1
subestimator__l1_ratio -> None
subestimator__max_iter -> 100
subestimator__multi_class -> auto
subestimator__multi_class -> deprecated
subestimator__n_jobs -> None
subestimator__penalty -> l2
subestimator__random_state -> None
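The `get_params` listing in this hunk can be reproduced with a short sketch. It assumes `OneVsRestClassifier` as a stand-in meta-estimator (the develop.rst example uses its own wrapper, hence its `subestimator__` prefix; here the prefix is `estimator__`). The value printed for `multi_class` depends on the installed scikit-learn version, so the sketch reads it with `.get()` rather than assuming the key exists.

```python
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# OneVsRestClassifier stands in for the meta-estimator of the user guide.
meta = OneVsRestClassifier(LogisticRegression())

shallow = meta.get_params(deep=False)
deep = meta.get_params(deep=True)

# deep=True adds the wrapped estimator's parameters as "estimator__<name>".
nested = sorted(key for key in deep if key.startswith("estimator__"))
print(nested)

# multi_class is scheduled for removal, so do not assume the key is present.
print(deep.get("estimator__multi_class", "<not present in this version>"))
```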
14 changes: 7 additions & 7 deletions doc/modules/linear_model.rst
@@ -1047,24 +1047,24 @@ Solvers
The solvers implemented in the class :class:`LogisticRegression`
are "lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag" and "saga":

The following table summarizes the penalties supported by each solver:
The following table summarizes the penalties and multinomial multiclass supported by each solver:

+------------------------------+-----------------+-------------+-----------------+-----------------------+-----------+------------+
| | **Solvers** |
+------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
| **Penalties** | **'lbfgs'** | **'liblinear'** | **'newton-cg'** | **'newton-cholesky'** | **'sag'** | **'saga'** |
+------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
| Multinomial + L2 penalty | yes | no | yes | no | yes | yes |
| L2 penalty | yes | no | yes | no | yes | yes |
+------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
| OVR + L2 penalty | yes | yes | yes | yes | yes | yes |
| L1 penalty | no | yes | no | no | no | yes |
+------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
| Multinomial + L1 penalty | no | no | no | no | no | yes |
| Elastic-Net (L1 + L2) | no | no | no | no | no | yes |
+------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
| OVR + L1 penalty | no | yes | no | no | no | yes |
| No penalty ('none') | yes | no | yes | yes | yes | yes |
+------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
| Elastic-Net | no | no | no | no | no | yes |
| **Multiclass support** | |
+------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
| No penalty ('none') | yes | no | yes | yes | yes | yes |
| multinomial multiclass | yes | no | yes | no | yes | yes |
+------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
| **Behaviors** | |
+------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+
8 changes: 4 additions & 4 deletions doc/modules/multiclass.rst
@@ -63,8 +63,8 @@ can provide additional strategies beyond what is built-in:
- :class:`semi_supervised.LabelSpreading`
- :class:`discriminant_analysis.LinearDiscriminantAnalysis`
- :class:`svm.LinearSVC` (setting multi_class="crammer_singer")
- :class:`linear_model.LogisticRegression` (setting multi_class="multinomial")
- :class:`linear_model.LogisticRegressionCV` (setting multi_class="multinomial")
- :class:`linear_model.LogisticRegression` (with most solvers)
- :class:`linear_model.LogisticRegressionCV` (with most solvers)
- :class:`neural_network.MLPClassifier`
- :class:`neighbors.NearestCentroid`
- :class:`discriminant_analysis.QuadraticDiscriminantAnalysis`
@@ -86,8 +86,8 @@ can provide additional strategies beyond what is built-in:
- :class:`ensemble.GradientBoostingClassifier`
- :class:`gaussian_process.GaussianProcessClassifier` (setting multi_class = "one_vs_rest")
- :class:`svm.LinearSVC` (setting multi_class="ovr")
- :class:`linear_model.LogisticRegression` (setting multi_class="ovr")
- :class:`linear_model.LogisticRegressionCV` (setting multi_class="ovr")
- :class:`linear_model.LogisticRegression` (most solvers)
- :class:`linear_model.LogisticRegressionCV` (most solvers)
- :class:`linear_model.SGDClassifier`
- :class:`linear_model.Perceptron`
- :class:`linear_model.PassiveAggressiveClassifier`
8 changes: 8 additions & 0 deletions doc/whats_new/v1.5.rst
@@ -357,6 +357,14 @@ Changelog
:class:`linear_model.SGDOneClassSVM`. Pass `average=False` instead.
:pr:`28582` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

- |API| Parameter `multi_class` was deprecated in
:class:`linear_model.LogisticRegression` and
:class:`linear_model.LogisticRegressionCV`. `multi_class` will be removed in 1.7,
and internally, for 3 and more classes, it will always use multinomial.
If you still want to use the one-vs-rest scheme, you can use
`OneVsRestClassifier(LogisticRegression(..))`.
:pr:`28703` by :user:`Christian Lorentzen <lorentzenchr>`.

- |API| `store_cv_values` and `cv_values_` are deprecated in favor of
`store_cv_results` and `cv_results_` in `RidgeCV` and `RidgeClassifierCV`.
:pr:`28915` by :user:`Lucy Liu <lucyleeow>`.
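For readers migrating code, the `multi_class` changelog entry in this hunk boils down to the following minimal sketch. The iris dataset, the default solver and `max_iter=1000` are illustrative choices, not taken from the PR.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

X, y = load_iris(return_X_y=True)  # 3 classes, 4 features

# Without multi_class, a solver that supports it fits a multinomial model
# when there are 3 or more classes.
multinomial = LogisticRegression(max_iter=1000).fit(X, y)

# One-vs-rest: wrap the estimator instead of passing multi_class="ovr".
ovr = OneVsRestClassifier(LogisticRegression(max_iter=1000)).fit(X, y)

print(multinomial.coef_.shape)  # one row of coefficients per class
print(len(ovr.estimators_))     # one binary model per class
```

Note that the wrapper exposes the fitted binary models through `estimators_`, not through a single `coef_` array, which is why several examples in this PR rebuild the coefficients by hand.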
15 changes: 7 additions & 8 deletions examples/classification/plot_classification_probability.py
@@ -5,8 +5,8 @@

Plot the classification probability for different classifiers. We use a 3 class
dataset, and we classify it with a Support Vector classifier, L1 and L2
penalized logistic regression with either a One-Vs-Rest or multinomial setting,
and Gaussian process classification.
penalized logistic regression (multinomial multiclass), a One-Vs-Rest version with
logistic regression, and Gaussian process classification.

Linear SVC is not a probabilistic classifier by default but it has a built-in
calibration option enabled in this example (`probability=True`).
@@ -30,6 +30,7 @@
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC

iris = datasets.load_iris()
@@ -43,14 +44,12 @@

# Create different classifiers.
classifiers = {
"L1 logistic": LogisticRegression(
C=C, penalty="l1", solver="saga", multi_class="multinomial", max_iter=10000
),
"L1 logistic": LogisticRegression(C=C, penalty="l1", solver="saga", max_iter=10000),
"L2 logistic (Multinomial)": LogisticRegression(
C=C, penalty="l2", solver="saga", multi_class="multinomial", max_iter=10000
C=C, penalty="l2", solver="saga", max_iter=10000
),
"L2 logistic (OvR)": LogisticRegression(
C=C, penalty="l2", solver="saga", multi_class="ovr", max_iter=10000
"L2 logistic (OvR)": OneVsRestClassifier(
LogisticRegression(C=C, penalty="l2", solver="saga", max_iter=10000)
),
"Linear SVC": SVC(kernel="linear", C=C, probability=True, random_state=0),
"GPC": GaussianProcessClassifier(kernel),
16 changes: 11 additions & 5 deletions examples/linear_model/plot_logistic_multinomial.py
@@ -18,6 +18,7 @@
from sklearn.datasets import make_blobs
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# make 3-class dataset for classification
centers = [[-5, 0], [0, 1.5], [5, -1]]
@@ -26,9 +27,10 @@
X = np.dot(X, transformation)

for multi_class in ("multinomial", "ovr"):
clf = LogisticRegression(
solver="sag", max_iter=100, random_state=42, multi_class=multi_class
).fit(X, y)
clf = LogisticRegression(solver="sag", max_iter=100, random_state=42)
if multi_class == "ovr":
clf = OneVsRestClassifier(clf)
clf.fit(X, y)

# print the training scores
print("training score : %.3f (%s)" % (clf.score(X, y), multi_class))
@@ -51,8 +53,12 @@
# Plot the three one-against-all classifiers
xmin, xmax = plt.xlim()
ymin, ymax = plt.ylim()
coef = clf.coef_
intercept = clf.intercept_
if multi_class == "ovr":
coef = np.concatenate([est.coef_ for est in clf.estimators_])
intercept = np.concatenate([est.intercept_ for est in clf.estimators_])
else:
coef = clf.coef_
intercept = clf.intercept_
Comment on lines +56 to +61
Member

All this coef reconstruction makes me wonder if we should enable OneVsRestClassifier to do it automatically. It could have a coef_ property as a fitted attribute, defined with an available_if on estimators_.coef_.

Not to do in this PR though, just a thought for later
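
Until something like that exists, the reconstruction repeated across the examples could be factored into a small helper. The sketch below is only an illustration: `stacked_linear_params` is a hypothetical name, not a scikit-learn function, and it assumes every sub-estimator is a fitted linear model exposing `coef_` and `intercept_`.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier


def stacked_linear_params(clf):
    """Return (coef, intercept) for a linear classifier or an OvR wrapper.

    Hypothetical helper, not part of scikit-learn.
    """
    if isinstance(clf, OneVsRestClassifier):
        # Each binary sub-estimator holds a (1, n_features) coef_ and a
        # (1,) intercept_; stack them into one row per class.
        coef = np.concatenate([est.coef_ for est in clf.estimators_])
        intercept = np.concatenate([est.intercept_ for est in clf.estimators_])
        return coef, intercept
    return clf.coef_, clf.intercept_


X, y = make_blobs(n_samples=300, centers=3, random_state=40)
ovr = OneVsRestClassifier(LogisticRegression()).fit(X, y)
plain = LogisticRegression().fit(X, y)

coef, intercept = stacked_linear_params(ovr)
print(coef.shape, intercept.shape)
```

With three classes and two features, both the wrapped and the plain estimator yield arrays of the same shape, so plotting code downstream does not need to branch on the strategy.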


def plot_hyperplane(c, color):
def line(x0):
@@ -32,6 +32,7 @@
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier

warnings.filterwarnings("ignore", category=ConvergenceWarning, module="sklearn")
t0 = timeit.default_timer()
@@ -76,20 +77,25 @@
"[model=%s, solver=%s] Number of epochs: %s"
% (model_params["name"], solver, this_max_iter)
)
lr = LogisticRegression(
clf = LogisticRegression(
solver=solver,
multi_class=model,
penalty="l1",
max_iter=this_max_iter,
random_state=42,
)
if model == "ovr":
clf = OneVsRestClassifier(clf)
t1 = timeit.default_timer()
lr.fit(X_train, y_train)
clf.fit(X_train, y_train)
train_time = timeit.default_timer() - t1

y_pred = lr.predict(X_test)
y_pred = clf.predict(X_test)
accuracy = np.sum(y_pred == y_test) / y_test.shape[0]
density = np.mean(lr.coef_ != 0, axis=1) * 100
if model == "ovr":
coef = np.concatenate([est.coef_ for est in clf.estimators_])
else:
coef = clf.coef_
density = np.mean(coef != 0, axis=1) * 100
accuracies.append(accuracy)
densities.append(density)
times.append(train_time)
2 changes: 1 addition & 1 deletion sklearn/ensemble/_voting.py
@@ -287,7 +287,7 @@ class VotingClassifier(ClassifierMixin, _BaseVoting):
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.naive_bayes import GaussianNB
>>> from sklearn.ensemble import RandomForestClassifier, VotingClassifier
>>> clf1 = LogisticRegression(multi_class='multinomial', random_state=1)
>>> clf1 = LogisticRegression(random_state=1)
>>> clf2 = RandomForestClassifier(n_estimators=50, random_state=1)
>>> clf3 = GaussianNB()
>>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])