
[MRG] ENH Add trust-ncg option to LogisticRegression #17877


Closed
wants to merge 11 commits
51 changes: 26 additions & 25 deletions doc/modules/linear_model.rst
@@ -838,36 +838,37 @@ small datasets but for larger datasets its performance suffers. [9]_

The following table summarizes the penalties supported by each solver:

+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| | **Solvers** |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| **Penalties** | **'liblinear'** | **'lbfgs'** | **'newton-cg'** | **'sag'** | **'saga'** |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| Multinomial + L2 penalty | no | yes | yes | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| OVR + L2 penalty | yes | yes | yes | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| Multinomial + L1 penalty | no | no | no | no | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| OVR + L1 penalty | yes | no | no | no | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| Elastic-Net | no | no | no | no | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| No penalty ('none') | no | yes | yes | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| **Behaviors** | |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| Penalize the intercept (bad) | yes | no | no | no | no |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| Faster for large datasets | no | no | no | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| Robust to unscaled datasets | yes | yes | yes | no | no |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| | **Solvers** |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| **Penalties** | **'liblinear'** | **'lbfgs'** | **'newton-cg'** | **'sag'** | **'saga'** | **'trust-ncg'** |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| Multinomial + L2 penalty | no | yes | yes | yes | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| OVR + L2 penalty | yes | yes | yes | yes | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| Multinomial + L1 penalty | no | no | no | no | yes | no |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| OVR + L1 penalty | yes | no | no | no | yes | no |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| Elastic-Net | no | no | no | no | yes | no |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| No penalty ('none') | no | yes | yes | yes | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| **Behaviors** | | |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| Penalize the intercept (bad) | yes | no | no | no | no | no |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| Faster for large datasets | no | no | no | yes | yes | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
| Robust to unscaled datasets | yes | yes | yes | no | no | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+

The "lbfgs" solver is used by default for its robustness. For large datasets
the "saga" solver is usually faster.
For large datasets, you may also consider using :class:`SGDClassifier`
with 'log' loss, which might be even faster but requires more tuning.
The "trust-ncg" solver is useful for sparse datasets.

.. topic:: Examples:

30 changes: 22 additions & 8 deletions sklearn/linear_model/_logistic.py
@@ -428,7 +428,8 @@ def hessp(v):


def _check_solver(solver, penalty, dual):
    all_solvers = ['liblinear', 'newton-cg', 'lbfgs', 'sag', 'saga']
    all_solvers = ['liblinear', 'newton-cg', 'lbfgs', 'sag', 'saga',
                   'trust-ncg']
Comment on lines +431 to +432
Suggested change
    all_solvers = ['liblinear', 'newton-cg', 'lbfgs', 'sag', 'saga',
                   'trust-ncg']
    all_solvers = [
        'liblinear', 'newton-cg', 'lbfgs', 'sag', 'saga', 'trust-ncg'
    ]

    if solver not in all_solvers:
        raise ValueError("Logistic Regression supports only solvers in %s, got"
                         " %s." % (all_solvers, solver))
@@ -728,7 +729,7 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
    if multi_class == 'multinomial':
        # scipy.optimize.minimize and newton-cg accept only
        # ravelled parameters.
        if solver in ['lbfgs', 'newton-cg']:
        if solver in ['lbfgs', 'trust-ncg', 'newton-cg']:
            w0 = w0.ravel()
        target = Y_multi
        if solver == 'lbfgs':
@@ -737,6 +738,9 @@ def func(x, *args): return _multinomial_loss_grad(x, *args)[0:2]
            def func(x, *args): return _multinomial_loss(x, *args)[0]
            def grad(x, *args): return _multinomial_loss_grad(x, *args)[1]
            hess = _multinomial_grad_hess
        elif solver == 'trust-ncg':
            def func(x, *args): return _multinomial_loss_grad(x, *args)[0:2]
            def hessp(x, p, *args): return _multinomial_grad_hess(x, *args)[1](p)
        warm_start_sag = {'coef': w0.T}
    else:
        target = y_bin
@@ -746,19 +750,29 @@ def grad(x, *args): return _logistic_loss_and_grad(x, *args)[1]
            func = _logistic_loss
            def grad(x, *args): return _logistic_loss_and_grad(x, *args)[1]
            hess = _logistic_grad_hess
        elif solver == 'trust-ncg':
            func = _logistic_loss_and_grad
            def hessp(x, p, *args): return _logistic_grad_hess(x, *args)[1](p)
        warm_start_sag = {'coef': np.expand_dims(w0, axis=1)}

    coefs = list()
    n_iter = np.zeros(len(Cs), dtype=np.int32)
    for i, C in enumerate(Cs):
        if solver == 'lbfgs':
        if solver in ['lbfgs', 'trust-ncg']:
            iprint = [-1, 50, 1, 100, 101][
                np.searchsorted(np.array([0, 1, 2, 3]), verbose)]
            opt_res = optimize.minimize(
                func, w0, method="L-BFGS-B", jac=True,
                args=(X, target, 1. / C, sample_weight),
                options={"iprint": iprint, "gtol": tol, "maxiter": max_iter}
            )
            if solver == 'lbfgs':
                opt_res = optimize.minimize(
                    func, w0, method="L-BFGS-B", jac=True,
                    args=(X, target, 1. / C, sample_weight),
                    options={"iprint": iprint, "gtol": tol,
                             "maxiter": max_iter}
                )
            elif solver == 'trust-ncg':
                opt_res = optimize.minimize(
                    func, w0, method=solver, jac=True, hessp=hessp,
                    args=(X, target, 1. / C, sample_weight),
                    options={"gtol": tol, "maxiter": max_iter}
                )
            n_iter_i = _check_optimize_result(
                solver, opt_res, max_iter,
                extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)
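For reference, scipy.optimize.minimize with method='trust-ncg' requires a gradient (supplied here via jac=True, with func returning a (loss, gradient) pair) and a Hessian-vector product callback hessp(x, p, *args) that returns the Hessian at x applied to the direction p. A minimal, self-contained sketch of that calling convention on a toy quadratic (not scikit-learn code):

    import numpy as np
    from scipy import optimize

    # Minimize f(w) = 0.5 * w^T A w - b^T w; the Hessian is the constant A.
    A = np.array([[3.0, 1.0], [1.0, 2.0]])
    b = np.array([1.0, -1.0])

    def func(w):
        # With jac=True, the objective returns (loss, gradient).
        return 0.5 * w @ A @ w - b @ w, A @ w - b

    def hessp(w, p):
        # Hessian at w times an arbitrary vector p (constant here).
        return A @ p

    opt_res = optimize.minimize(func, np.zeros(2), method="trust-ncg",
                                jac=True, hessp=hessp,
                                options={"gtol": 1e-8, "maxiter": 100})
    print(opt_res.x)  # approx. np.linalg.solve(A, b) == [0.6, -0.8]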
13 changes: 7 additions & 6 deletions sklearn/linear_model/tests/test_logistic.py
@@ -204,7 +204,8 @@ def test_check_solver_option(LR):
    X, y = iris.data, iris.target

    msg = ("Logistic Regression supports only solvers in ['liblinear', "
           "'newton-cg', 'lbfgs', 'sag', 'saga'], got wrong_name.")
           "'newton-cg', 'lbfgs', 'sag', 'saga', 'trust-ncg'], "
           "got wrong_name.")
    lr = LR(solver="wrong_name", multi_class="ovr")
    assert_raise_message(ValueError, msg, lr.fit, X, y)
@@ -1692,7 +1693,7 @@ def test_logistic_regression_path_coefs_multinomial():
                             Cs=3, tol=1e-3)],
                            ids=lambda x: x.__class__.__name__)
@pytest.mark.parametrize('solver', ['liblinear', 'lbfgs', 'newton-cg', 'sag',
                                    'saga'])
                                    'saga', 'trust-ncg'])
def test_logistic_regression_multi_class_auto(est, solver):
    # check multi_class='auto' => multi_class='ovr' iff binary y or liblinear
@@ -1718,10 +1719,10 @@ def fit(X, y, **kw):
    else:
        est_multi_multi = fit(X, y_multi, multi_class='multinomial',
                              solver=solver)
        if sys.platform == 'darwin' and solver == 'lbfgs':
            pytest.xfail('Issue #11924: LogisticRegressionCV(solver="lbfgs", '
                         'multi_class="multinomial") is nondeterministic on '
                         'MacOS.')
        if sys.platform == 'darwin' and solver in ['lbfgs', 'trust-ncg']:
            pytest.xfail('Issue #11924: LogisticRegressionCV('
                         f'solver="{solver}", multi_class="multinomial") '
                         'is nondeterministic on MacOS.')
        assert_allclose(est_auto_multi.coef_, est_multi_multi.coef_)
        assert_allclose(est_auto_multi.predict_proba(X2),
                        est_multi_multi.predict_proba(X2))
7 changes: 4 additions & 3 deletions sklearn/utils/optimize.py
@@ -211,7 +211,8 @@ def _check_optimize_result(solver, result, max_iter=None,
    Parameters
    ----------
    solver : str
        Solver name. Currently only `lbfgs` is supported.
        Solver name. Currently only `lbfgs` and `trust-ncg`
        are supported.

    result : OptimizeResult
        Result of the scipy.optimize.minimize function.
@@ -228,15 +229,15 @@
        Number of iterations.
    """
    # handle both scipy and scikit-learn solver names
    if solver == "lbfgs":
    if solver in ["lbfgs", "trust-ncg"]:
        if result.status != 0:
            warning_msg = (
                "{} failed to converge (status={}):\n{}.\n\n"
                "Increase the number of iterations (max_iter) "
                "or scale the data as shown in:\n"
                "    https://scikit-learn.org/stable/modules/"
                "preprocessing.html"
            ).format(solver, result.status, result.message.decode("latin1"))
            ).format(solver, result.status, result.message)
            if extra_warning_msg is not None:
                warning_msg += "\n" + extra_warning_msg
            warnings.warn(warning_msg, ConvergenceWarning, stacklevel=2)
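This warning path is easy to exercise; a minimal sketch (runnable against a released scikit-learn), forcing non-convergence by capping max_iter:

    import warnings
    from sklearn.datasets import make_classification
    from sklearn.exceptions import ConvergenceWarning
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_samples=200, n_features=20, random_state=0)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # One iteration is not enough to converge, so lbfgs warns.
        LogisticRegression(solver="lbfgs", max_iter=1).fit(X, y)
    assert any(issubclass(w.category, ConvergenceWarning) for w in caught)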