diff --git a/doc/modules/linear_model.rst b/doc/modules/linear_model.rst
index c4767d0cb2d64..728d01a22ea0d 100644
--- a/doc/modules/linear_model.rst
+++ b/doc/modules/linear_model.rst
@@ -838,36 +838,38 @@ small data-sets but for larger datasets its performance suffers. [9]_
 
 The following table summarizes the penalties supported by each solver:
 
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-|                              |                               **Solvers**                                |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| **Penalties**                | **'liblinear'** | **'lbfgs'** | **'newton-cg'** | **'sag'** | **'saga'** |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| Multinomial + L2 penalty     | no              | yes         | yes             | yes       | yes        |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| OVR + L2 penalty             | yes             | yes         | yes             | yes       | yes        |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| Multinomial + L1 penalty     | no              | no          | no              | no        | yes        |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| OVR + L1 penalty             | yes             | no          | no              | no        | yes        |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| Elastic-Net                  | no              | no          | no              | no        | yes        |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| No penalty ('none')          | no              | yes         | yes             | yes       | yes        |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| **Behaviors**                |                                                                          |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| Penalize the intercept (bad) | yes             | no          | no              | no        | no         |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| Faster for large datasets    | no              | no          | no              | yes       | yes        |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
-| Robust to unscaled datasets  | yes             | yes         | yes             | no        | no         |
-+------------------------------+-----------------+-------------+-----------------+-----------+------------+
++------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
+|                              |                                        **Solvers**                                         |
++------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
+| **Penalties**                | **'liblinear'** | **'lbfgs'** | **'newton-cg'** | **'sag'** | **'saga'** | **'trust-ncg'** |
++------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
+| Multinomial + L2 penalty     | no              | yes         | yes             | yes       | yes        | yes             |
++------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
+| OVR + L2 penalty             | yes             | yes         | yes             | yes       | yes        | yes             |
++------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
+| Multinomial + L1 penalty     | no              | no          | no              | no        | yes        | no              |
++------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
+| OVR + L1 penalty             | yes             | no          | no              | no        | yes        | no              |
++------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
+| Elastic-Net                  | no              | no          | no              | no        | yes        | no              |
++------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
+| No penalty ('none')          | no              | yes         | yes             | yes       | yes        | yes             |
++------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
+| **Behaviors**                |                                                                                            |
++------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
+| Penalize the intercept (bad) | yes             | no          | no              | no        | no         | no              |
++------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
+| Faster for large datasets    | no              | no          | no              | yes       | yes        | yes             |
++------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
+| Robust to unscaled datasets  | yes             | yes         | yes             | no        | no         | yes             |
++------------------------------+-----------------+-------------+-----------------+-----------+------------+-----------------+
 
 The "lbfgs" solver is used by default for its robustness. For large datasets
 the "saga" solver is usually faster.
 For large dataset, you may also consider using :class:`SGDClassifier` with
 'log' loss, which might be even faster but requires more tuning.
+The "trust-ncg" solver only requires Hessian-vector products, which can make
+it well suited to sparse, high-dimensional data.
 
 .. topic:: Examples:
 
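Note for reviewers: with this patch applied, the new solver is selected like
any existing one via the ``solver`` parameter. A minimal usage sketch (the
dataset and hyper-parameters below are illustrative only, and
``solver='trust-ncg'`` exists only on this branch):

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_samples=200, n_features=20, random_state=0)
    # 'trust-ncg' is not available in released scikit-learn.
    clf = LogisticRegression(solver='trust-ncg', C=1.0).fit(X, y)
    print(clf.score(X, y))
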
diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py
index da86a1755c2f1..2a52595a5bfb3 100644
--- a/sklearn/linear_model/_logistic.py
+++ b/sklearn/linear_model/_logistic.py
@@ -428,7 +428,8 @@ def hessp(v):
 
 
 def _check_solver(solver, penalty, dual):
-    all_solvers = ['liblinear', 'newton-cg', 'lbfgs', 'sag', 'saga']
+    all_solvers = ['liblinear', 'newton-cg', 'lbfgs', 'sag', 'saga',
+                   'trust-ncg']
     if solver not in all_solvers:
         raise ValueError("Logistic Regression supports only solvers in %s, got"
                          " %s." % (all_solvers, solver))
@@ -728,7 +729,7 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
     if multi_class == 'multinomial':
         # scipy.optimize.minimize and newton-cg accepts only
         # ravelled parameters.
-        if solver in ['lbfgs', 'newton-cg']:
+        if solver in ['lbfgs', 'trust-ncg', 'newton-cg']:
             w0 = w0.ravel()
         target = Y_multi
         if solver == 'lbfgs':
@@ -737,6 +738,12 @@ def func(x, *args): return _multinomial_loss_grad(x, *args)[0:2]
             def func(x, *args): return _multinomial_loss(x, *args)[0]
             def grad(x, *args): return _multinomial_loss_grad(x, *args)[1]
             hess = _multinomial_grad_hess
+        elif solver == 'trust-ncg':
+            def func(x, *args): return _multinomial_loss_grad(x, *args)[0:2]
+            # scipy calls hessp(x, p, *args): evaluate the Hessian at the
+            # current point x and apply it to the direction p.
+            def hessp(x, p, *args):
+                return _multinomial_grad_hess(x, *args)[1](p)
         warm_start_sag = {'coef': w0.T}
     else:
         target = y_bin
@@ -746,19 +753,34 @@ def grad(x, *args): return _multinomial_loss_grad(x, *args)[1]
             func = _logistic_loss
             def grad(x, *args): return _logistic_loss_and_grad(x, *args)[1]
             hess = _logistic_grad_hess
+        elif solver == 'trust-ncg':
+            func = _logistic_loss_and_grad
+            # As above: Hessian at x, applied to the direction p.
+            def hessp(x, p, *args):
+                return _logistic_grad_hess(x, *args)[1](p)
         warm_start_sag = {'coef': np.expand_dims(w0, axis=1)}
 
     coefs = list()
     n_iter = np.zeros(len(Cs), dtype=np.int32)
     for i, C in enumerate(Cs):
-        if solver == 'lbfgs':
+        if solver in ['lbfgs', 'trust-ncg']:
             iprint = [-1, 50, 1, 100, 101][
                 np.searchsorted(np.array([0, 1, 2, 3]), verbose)]
-            opt_res = optimize.minimize(
-                func, w0, method="L-BFGS-B", jac=True,
-                args=(X, target, 1. / C, sample_weight),
-                options={"iprint": iprint, "gtol": tol, "maxiter": max_iter}
-            )
+            if solver == 'lbfgs':
+                opt_res = optimize.minimize(
+                    func, w0, method="L-BFGS-B", jac=True,
+                    args=(X, target, 1. / C, sample_weight),
+                    options={"iprint": iprint, "gtol": tol,
+                             "maxiter": max_iter}
+                )
+            elif solver == 'trust-ncg':
+                # trust-ncg has no iprint option; forward tol and max_iter
+                # so the user-supplied stopping criteria are honoured.
+                opt_res = optimize.minimize(
+                    func, w0, method=solver, jac=True, hessp=hessp,
+                    args=(X, target, 1. / C, sample_weight),
+                    options={"gtol": tol, "maxiter": max_iter}
+                )
             n_iter_i = _check_optimize_result(
                 solver, opt_res, max_iter,
                 extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)
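Note on the hessp wrappers above: scipy's trust-region Newton-CG solver
invokes the callback as hessp(x, p, *args) and expects the Hessian evaluated
at the current point x applied to the direction p, which is why the wrappers
take both arguments (the earlier draft evaluated the Hessian at p and applied
it to x). A self-contained sketch against plain scipy; the quadratic and all
names are illustrative:

    import numpy as np
    from scipy import optimize

    A = np.array([[3.0, 1.0], [1.0, 2.0]])
    b = np.array([1.0, 1.0])

    def func(x):
        # jac=True convention: return (loss, gradient)
        return 0.5 * x @ A @ x - b @ x, A @ x - b

    def hessp(x, p):
        # Hessian of the quadratic is the constant matrix A;
        # return the Hessian-vector product with the direction p.
        return A @ p

    res = optimize.minimize(func, np.zeros(2), method="trust-ncg",
                            jac=True, hessp=hessp)
    print(res.x, np.linalg.solve(A, b))  # both should agree
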
diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py
index e215400b53b80..460a9fac0ef16 100644
--- a/sklearn/linear_model/tests/test_logistic.py
+++ b/sklearn/linear_model/tests/test_logistic.py
@@ -204,7 +204,8 @@ def test_check_solver_option(LR):
     X, y = iris.data, iris.target
 
     msg = ("Logistic Regression supports only solvers in ['liblinear', "
-           "'newton-cg', 'lbfgs', 'sag', 'saga'], got wrong_name.")
+           "'newton-cg', 'lbfgs', 'sag', 'saga', 'trust-ncg'], "
+           "got wrong_name.")
 
     lr = LR(solver="wrong_name", multi_class="ovr")
     assert_raise_message(ValueError, msg, lr.fit, X, y)
@@ -1692,7 +1693,7 @@ def test_logistic_regression_path_coefs_multinomial():
                               Cs=3, tol=1e-3)],
                          ids=lambda x: x.__class__.__name__)
 @pytest.mark.parametrize('solver', ['liblinear', 'lbfgs', 'newton-cg', 'sag',
-                                    'saga'])
+                                    'saga', 'trust-ncg'])
 def test_logistic_regression_multi_class_auto(est, solver):
     # check multi_class='auto' => multi_class='ovr' iff binary y or liblinear
 
@@ -1718,10 +1719,10 @@ def fit(X, y, **kw):
     else:
         est_multi_multi = fit(X, y_multi, multi_class='multinomial',
                               solver=solver)
-        if sys.platform == 'darwin' and solver == 'lbfgs':
-            pytest.xfail('Issue #11924: LogisticRegressionCV(solver="lbfgs", '
-                         'multi_class="multinomial") is nondeterministic on '
-                         'MacOS.')
+        if sys.platform == 'darwin' and solver in ['lbfgs', 'trust-ncg']:
+            pytest.xfail('Issue #11924: LogisticRegressionCV('
+                         f'solver="{solver}", multi_class="multinomial") '
+                         'is nondeterministic on MacOS.')
         assert_allclose(est_auto_multi.coef_, est_multi_multi.coef_)
         assert_allclose(est_auto_multi.predict_proba(X2),
                         est_multi_multi.predict_proba(X2))
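The parametrized test above asserts that multi_class='auto' resolves to
'multinomial' for every non-liblinear solver, including the new one. A
condensed sketch of what it checks for 'trust-ncg' (iris and max_iter are
illustrative; this only runs with the branch applied):

    from numpy.testing import assert_allclose
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression

    X, y = load_iris(return_X_y=True)
    # Requires this branch; 'trust-ncg' is not in released scikit-learn.
    auto = LogisticRegression(solver='trust-ncg', multi_class='auto',
                              max_iter=500).fit(X, y)
    multi = LogisticRegression(solver='trust-ncg', multi_class='multinomial',
                               max_iter=500).fit(X, y)
    assert_allclose(auto.coef_, multi.coef_)
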
diff --git a/sklearn/utils/optimize.py b/sklearn/utils/optimize.py
index d10639ce5bd85..9b44eb594da5a 100644
--- a/sklearn/utils/optimize.py
+++ b/sklearn/utils/optimize.py
@@ -211,7 +211,8 @@ def _check_optimize_result(solver, result, max_iter=None,
     Parameters
     ----------
     solver : str
-        Solver name. Currently only `lbfgs` is supported.
+        Solver name. Currently only `lbfgs` and `trust-ncg`
+        are supported.
 
     result : OptimizeResult
         Result of the scipy.optimize.minimize function.
@@ -228,7 +229,7 @@ def _check_optimize_result(solver, result, max_iter=None,
         Number of iterations.
     """
     # handle both scipy and scikit-learn solver names
-    if solver == "lbfgs":
+    if solver in ["lbfgs", "trust-ncg"]:
         if result.status != 0:
             warning_msg = (
                 "{} failed to converge (status={}):\n{}.\n\n"
@@ -236,7 +237,12 @@ def _check_optimize_result(solver, result, max_iter=None,
                 "or scale the data as shown in:\n"
                 "    https://scikit-learn.org/stable/modules/"
                 "preprocessing.html"
-            ).format(solver, result.status, result.message.decode("latin1"))
+            ).format(
+                solver, result.status,
+                # lbfgs on older scipy reports a bytes message while
+                # trust-ncg reports str; decode only when necessary.
+                result.message.decode("latin1")
+                if isinstance(result.message, bytes) else result.message)
             if extra_warning_msg is not None:
                 warning_msg += "\n" + extra_warning_msg
             warnings.warn(warning_msg, ConvergenceWarning, stacklevel=2)
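Background for the message handling above: depending on the scipy version,
OptimizeResult.message may be bytes (older L-BFGS-B builds) or str (trust-ncg
and newer scipy), hence the isinstance guard rather than dropping the decode
outright. A quick check against the installed scipy (toy objective,
illustrative only):

    import numpy as np
    from scipy import optimize

    res = optimize.minimize(lambda x: float((x ** 2).sum()), np.ones(3),
                            method="L-BFGS-B")
    print(type(res.message))  # bytes or str, depending on the scipy version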