
ENH Add trust-ncg solver to LogisticRegression #22236


Closed
wants to merge 41 commits

Commits (41)
70ecf59
Initial commit
rithvikrao Jul 9, 2020
1ad6650
Remove extraneous print statement
rithvikrao Jul 9, 2020
ba513a0
Merge branch 'master' into logistic
rithvikrao Jul 23, 2020
03381b3
remove trust-krylov solver
rubywerman Jul 24, 2020
9a0e192
remove trust-krylov solver
rubywerman Jul 24, 2020
a9e6208
revert changes on this file, not finished with it yet
rubywerman Jul 24, 2020
0559dd1
add hess suggestions
rubywerman Jul 28, 2020
ea103e1
change hessp parameter
rubywerman Jul 28, 2020
6b27b1d
change hessp signature
rubywerman Aug 7, 2020
b4a016d
add trust-ncg to table
rubywerman Aug 17, 2020
8702ed7
add use case for trust-ncg
rubywerman Aug 17, 2020
9c5adbb
Merge branch 'main' into logistic
Micky774 Jan 17, 2022
e3205c7
Merge branch 'main' into logistic
Micky774 Jan 22, 2022
d750f54
Merge branch 'main' into logistic
Micky774 Jan 23, 2022
d29e0f4
Updated docs for `_logistic.py`
Micky774 Jan 24, 2022
f624b78
Merge branch 'main' into logistic
Micky774 Jan 24, 2022
4a7b9f1
Added tmp files for establishing benchmark
Micky774 Jan 24, 2022
781d53c
Changed graph style
Micky774 Jan 25, 2022
36ef601
Merge remote-tracking branch 'upstream/main' into logistic
Micky774 Jan 25, 2022
559ca3f
Added saved benchmark results
Micky774 Jan 25, 2022
f77de69
Updated benchmark file w/ validation capability
Micky774 Jan 27, 2022
d4b60b4
Added title to graphs
Micky774 Jan 27, 2022
098fabd
Merge branch 'main' into logistic
Micky774 Jan 28, 2022
9a4de1a
Added new benchmark file
Micky774 Jan 28, 2022
0f2474e
Improved benchmarking display
Micky774 Jan 28, 2022
f9270d9
Saved df w/ dense and sparse comparison
Micky774 Jan 28, 2022
509498c
Merge branch 'main' into logistic
Micky774 Jan 30, 2022
0b5a209
Merge branch 'main' into logistic
Micky774 May 24, 2022
2b8114f
Fixed typo
Micky774 May 24, 2022
b62821c
Tried to reconcile loss update, and improved tests
Micky774 May 24, 2022
98f3970
Updated `trust-ncg` to use scalar loss w/o gradient
Micky774 May 24, 2022
3b058ec
Added broken version for debugging
Micky774 Jun 8, 2022
c36f628
Fixed implementation of `trust-ncg`
Micky774 Jun 8, 2022
87876ad
Add trust-ncg to solver equality tests
Micky774 Jun 8, 2022
0969a67
Removed old dataframe files
Micky774 Jun 8, 2022
c19244a
Merge branch 'main' into logistic
Micky774 Jun 8, 2022
bf66c88
Removed old benchmark files and expanded tests
Micky774 Jun 9, 2022
cb8d256
Added changelog entry
Micky774 Jun 9, 2022
3711628
Updated comments over current hack
Micky774 Jun 9, 2022
f9e1939
Merge branch 'main' into logistic
Micky774 Jun 9, 2022
0da5868
Merge branch 'main' into logistic
Micky774 Jun 20, 2022
51 changes: 26 additions & 25 deletions doc/modules/linear_model.rst
@@ -974,36 +974,37 @@ small data-sets but for larger datasets its performance suffers. [9]_

The following table summarizes the penalties supported by each solver:

+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| | **Solvers** |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| **Penalties** | **'liblinear'** | **'lbfgs'** | **'newton-cg'** | **'sag'** | **'saga'** |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| Multinomial + L2 penalty | no | yes | yes | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| OVR + L2 penalty | yes | yes | yes | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| Multinomial + L1 penalty | no | no | no | no | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| OVR + L1 penalty | yes | no | no | no | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| Elastic-Net | no | no | no | no | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| No penalty ('none') | no | yes | yes | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| **Behaviors** | |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| Penalize the intercept (bad) | yes | no | no | no | no |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| Faster for large datasets | no | no | no | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| Robust to unscaled datasets | yes | yes | yes | no | no |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| | **Solvers** |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| **Penalties** | **'liblinear'** | **'lbfgs'** | **'newton-cg'** | **'sag'** | **'saga'** | **'trust-ncg'** |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| Multinomial + L2 penalty | no | yes | yes | yes | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| OVR + L2 penalty | yes | yes | yes | yes | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| Multinomial + L1 penalty | no | no | no | no | yes | no |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| OVR + L1 penalty | yes | no | no | no | yes | no |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| Elastic-Net | no | no | no | no | yes | no |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| No penalty ('none') | no | yes | yes | yes | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| **Behaviors** | | |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| Penalize the intercept (bad) | yes | no | no | no | no | no |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| Faster for large datasets | no | no | no | yes | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| Robust to unscaled datasets | yes | yes | yes | no | no | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+

The "lbfgs" solver is used by default for its robustness. For large datasets
the "saga" solver is usually faster.
For large datasets, you may also consider using :class:`SGDClassifier`
with 'log' loss, which might be even faster but requires more tuning.
The "trust-ncg" solver is useful for sparse datasets.

.. topic:: Examples:

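To make the documented use case concrete, here is a minimal usage sketch. It assumes this PR's branch: `solver="trust-ncg"` is not available in any released scikit-learn, and the dataset shape and `max_iter` value are illustrative only.

    # Sketch only: solver="trust-ncg" exists solely on this PR's branch.
    import scipy.sparse as sp
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_samples=5000, n_features=100, random_state=0)
    X_sparse = sp.csr_matrix(X)  # the docs above suggest trust-ncg for sparse data

    clf = LogisticRegression(solver="trust-ncg", penalty="l2", max_iter=1000)
    clf.fit(X_sparse, y)
    print(clf.score(X_sparse, y))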
9 changes: 9 additions & 0 deletions doc/whats_new/v1.2.rst
@@ -142,6 +142,15 @@ Changelog
- |Fix| Use dtype-aware tolerances for the validation of gram matrices (passed by users
or precomputed). :pr:`22059` by :user:`Malte S. Kurz <MalteKurz>`.

- |Enhancement| Added `trust-ncg` solver to `linear_model.LogisticRegression`.
This solver is fast for large datasets, similar to `saga`, while also being
robust to unscaled data, similar to `lbfgs`. The solver may require a higher
value of `max_iter` for proper convergence. It is compatible with both `ovr`
and `multinomial` options for `multi_class`, and can optimize with an `L2`
penalty.
:pr:`17877` by :user:`Rithvik Rao <rithvikrao>` and :user:`Ruby Werman <rubywerman>`,
and :pr:`22236` by :user:`Meekail Zain <micky774>`.

- |Fix| Fixed an error in :class:`linear_model.LogisticRegression` with
`solver="newton-cg"`, `fit_intercept=True`, and a single feature. :pr:`23608`
by `Tom Dupre la Tour`_.
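As a quick sanity check of the solver/penalty support claimed in the table and changelog above, the following sketch (again assuming this PR's branch) fits with the supported `l2` and `none` penalties and shows that an unsupported pair such as `trust-ncg` with `l1` is rejected:

    # Sketch only: runs on this PR's branch, not on released scikit-learn.
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_samples=200, n_features=10, random_state=0)

    for penalty in ("l2", "none"):  # supported per the table above
        LogisticRegression(solver="trust-ncg", penalty=penalty, max_iter=500).fit(X, y)

    try:
        LogisticRegression(solver="trust-ncg", penalty="l1").fit(X, y)
    except ValueError as exc:
        print(exc)  # _check_solver rejects unsupported solver/penalty pairs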
77 changes: 56 additions & 21 deletions sklearn/linear_model/_logistic.py
@@ -43,7 +43,7 @@


def _check_solver(solver, penalty, dual):
all_solvers = ["liblinear", "newton-cg", "lbfgs", "sag", "saga"]
all_solvers = ["liblinear", "newton-cg", "lbfgs", "sag", "saga", "trust-ncg"]
if solver not in all_solvers:
raise ValueError(
"Logistic Regression supports only solvers in %s, got %s."
@@ -158,8 +158,8 @@ def _logistic_regression_path(
Maximum number of iterations for the solver.

tol : float, default=1e-4
Stopping criterion. For the newton-cg and lbfgs solvers, the iteration
will stop when ``max{|g_i | i = 1, ..., n} <= tol``
Stopping criterion. For the newton-cg, trust-ncg, and lbfgs solvers,
the iteration will stop when ``max{|g_i | i = 1, ..., n} <= tol``
where ``g_i`` is the i-th component of the gradient.

verbose : int, default=0
@@ -316,7 +316,7 @@
w0 = np.zeros(n_features + int(fit_intercept), dtype=X.dtype)
mask = y == pos_class
y_bin = np.ones(y.shape, dtype=X.dtype)
if solver in ["lbfgs", "newton-cg"]:
if solver in ["lbfgs", "newton-cg", "trust-ncg"]:
# HalfBinomialLoss, used for those solvers, represents y in [0, 1] instead
# of in [-1, 1].
mask_classes = np.array([0, 1])
@@ -333,7 +333,7 @@
sample_weight *= class_weight_[le.fit_transform(y_bin)]

else:
if solver in ["sag", "saga", "lbfgs", "newton-cg"]:
if solver in ["sag", "saga", "lbfgs", "newton-cg", "trust-ncg"]:
# SAG, lbfgs, newton-cg and trust-ncg multinomial solvers need LabelEncoder,
# not LabelBinarizer, i.e. y as a 1d-array of integers.
# LabelEncoder also saves memory compared to LabelBinarizer, especially
@@ -391,7 +391,7 @@
w0[:, : coef.shape[1]] = coef

if multi_class == "multinomial":
if solver in ["lbfgs", "newton-cg"]:
if solver in ["lbfgs", "trust-ncg", "newton-cg"]:
# scipy.optimize.minimize and newton-cg accept only ravelled parameters,
# i.e. 1d-arrays. LinearModelLoss expects classes to be contiguous and
# reconstructs the 2d-array via w0.reshape((n_classes, -1), order="F").
@@ -408,6 +408,17 @@
func = loss.loss
grad = loss.gradient
hess = loss.gradient_hessian_product # hess = [gradient, hessp]
# TODO: remove local LinearModelLoss after renaming `loss`
elif solver == "trust-ncg":
loss_ = LinearModelLoss(
base_loss=HalfMultinomialLoss(n_classes=classes.size),
fit_intercept=fit_intercept,
)
func = loss_.loss_gradient

def hessp(x, p, *args):
return loss_.gradient_hessian_product(x, *args)[1](p)

warm_start_sag = {"coef": w0.T}
else:
target = y_bin
@@ -423,24 +434,45 @@
func = loss.loss
grad = loss.gradient
hess = loss.gradient_hessian_product # hess = [gradient, hessp]
# TODO: remove local LinearModelLoss after renaming `loss`
elif solver == "trust-ncg":
loss_ = LinearModelLoss(
base_loss=HalfBinomialLoss(), fit_intercept=fit_intercept
)
Comment on lines +437 to +441 (Contributor Author):

The local instantiation of LinearModelLoss counteracts a bug caused by the reassignment of `loss` to a scalar value later in the code (at the end of the loop). In a separate PR, we ought to rename either the `loss` that refers to the internal loss module or the `loss` that refers to the scalar loss.

func = loss_.loss_gradient

def hessp(x, p, *args):
return loss_.gradient_hessian_product(x, *args)[1](p)

warm_start_sag = {"coef": np.expand_dims(w0, axis=1)}

coefs = list()
n_iter = np.zeros(len(Cs), dtype=np.int32)
for i, C in enumerate(Cs):
if solver == "lbfgs":
l2_reg_strength = 1.0 / C
iprint = [-1, 50, 1, 100, 101][
np.searchsorted(np.array([0, 1, 2, 3]), verbose)
]
opt_res = optimize.minimize(
func,
w0,
method="L-BFGS-B",
jac=True,
args=(X, target, sample_weight, l2_reg_strength, n_threads),
options={"iprint": iprint, "gtol": tol, "maxiter": max_iter},
)
if solver in ["lbfgs", "trust-ncg"]:
if solver == "lbfgs":
l2_reg_strength = 1.0 / C
iprint = [-1, 50, 1, 100, 101][
np.searchsorted(np.array([0, 1, 2, 3]), verbose)
]
opt_res = optimize.minimize(
func,
w0,
method="L-BFGS-B",
jac=True,
args=(X, target, sample_weight, l2_reg_strength, n_threads),
options={"iprint": iprint, "gtol": tol, "maxiter": max_iter},
)
elif solver == "trust-ncg":
opt_res = optimize.minimize(
func,
w0,
method=solver,
jac=True,
hessp=hessp,
args=(X, target, sample_weight, 1.0 / C),
options={"gtol": tol, "maxiter": max_iter},
)
n_iter_i = _check_optimize_result(
solver,
opt_res,
@@ -517,7 +549,7 @@ def _logistic_regression_path(

if multi_class == "multinomial":
n_classes = max(2, classes.size)
if solver in ["lbfgs", "newton-cg"]:
if solver in ["lbfgs", "newton-cg", "trust-ncg"]:
multi_w0 = np.reshape(w0, (n_classes, -1), order="F")
else:
multi_w0 = w0
@@ -842,7 +874,7 @@ class LogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator):
Used when ``solver`` == 'sag', 'saga' or 'liblinear' to shuffle the
data. See :term:`Glossary <random_state>` for details.

solver : {'newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'}, \
solver : {'newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga', 'trust-ncg'}, \
default='lbfgs'

Algorithm to use in the optimization problem. Default is 'lbfgs'.
@@ -863,6 +895,7 @@ class LogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator):
- 'liblinear' - ['l1', 'l2']
- 'sag' - ['l2', 'none']
- 'saga' - ['elasticnet', 'l1', 'l2', 'none']
- 'trust-ncg' - ['l2', 'none']

.. note::
'sag' and 'saga' fast convergence is only guaranteed on
@@ -881,6 +914,8 @@ class LogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator):
SAGA solver.
.. versionchanged:: 0.22
The default solver changed from 'liblinear' to 'lbfgs' in 0.22.
.. versionadded:: 1.2
Newton conjugate gradient trust-region (trust-ncg) solver.

max_iter : int, default=100
Maximum number of iterations taken for the solvers to converge.
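For readers unfamiliar with the `hessp` callback wired up in the diff above: `scipy.optimize.minimize(method="trust-ncg")` expects `hessp(x, p, *args)` to return the Hessian-vector product at `x` applied to `p`, and `jac=True` means the objective returns a `(value, gradient)` pair. A self-contained sketch on a toy quadratic, independent of scikit-learn (names and values are illustrative):

    # Standalone sketch of the scipy trust-ncg interface used in this PR:
    # the objective returns (value, gradient) because jac=True, and hessp
    # returns a Hessian-vector product without forming the full Hessian.
    import numpy as np
    from scipy import optimize

    A = np.diag([1.0, 10.0, 100.0])  # toy ill-conditioned quadratic
    b = np.ones(3)

    def f_and_grad(x):
        # f(x) = 0.5 * x^T A x - b^T x, with gradient A x - b
        return 0.5 * x @ A @ x - b @ x, A @ x - b

    def hessp(x, p):
        return A @ p  # the Hessian of a quadratic is constant

    res = optimize.minimize(
        f_and_grad,
        np.zeros(3),
        method="trust-ncg",
        jac=True,
        hessp=hessp,
        options={"gtol": 1e-8, "maxiter": 100},
    )
    print(res.x)  # approaches the solution of A x = b, i.e. [1, 0.1, 0.01]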