diff --git a/doc/modules/linear_model.rst b/doc/modules/linear_model.rst index 0177165b2bb07..14902d67b263e 100644 --- a/doc/modules/linear_model.rst +++ b/doc/modules/linear_model.rst @@ -954,7 +954,7 @@ Solvers ------- The solvers implemented in the class :class:`LogisticRegression` -are "liblinear", "newton-cg", "lbfgs", "sag" and "saga": +are "lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag" and "saga": The solver "liblinear" uses a coordinate descent (CD) algorithm, and relies on the excellent C++ `LIBLINEAR library @@ -968,7 +968,7 @@ classifiers. For :math:`\ell_1` regularization :func:`sklearn.svm.l1_min_c` allo calculate the lower bound for C in order to get a non "null" (all feature weights to zero) model. -The "lbfgs", "sag" and "newton-cg" solvers only support :math:`\ell_2` +The "lbfgs", "newton-cg" and "sag" solvers only support :math:`\ell_2` regularization or no regularization, and are found to converge faster for some high-dimensional data. Setting `multi_class` to "multinomial" with these solvers learns a true multinomial logistic regression model [5]_, which means that its @@ -989,33 +989,41 @@ Broyden–Fletcher–Goldfarb–Shanno algorithm [8]_, which belongs to quasi-Newton methods. The "lbfgs" solver is recommended for use for small data-sets but for larger datasets its performance suffers. [9]_ +The "newton-cholesky" solver is an exact Newton solver that calculates the hessian +matrix and solves the resulting linear system. It is a very good choice for +`n_samples` >> `n_features`, but has a few shortcomings: Only :math:`\ell_2` +regularization is supported. Furthermore, because the hessian matrix is explicitly +computed, the memory usage has a quadratic dependency on `n_features` as well as on +`n_classes`. As a consequence, only the one-vs-rest scheme is implemented for the +multiclass case. 
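A minimal usage sketch of the new solver (illustrative only, not part of this patch; it assumes a scikit-learn version that ships `solver="newton-cholesky"`, i.e. 1.2 or later)::

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    # Tall-and-narrow data (n_samples >> n_features) is the regime where
    # "newton-cholesky" is attractive.
    X, y = make_classification(n_samples=100_000, n_features=30, random_state=0)

    # Only l2 (or no) penalty is supported; with more than two classes the
    # solver falls back to a one-vs-rest scheme.
    clf = LogisticRegression(solver="newton-cholesky", C=1.0).fit(X, y)
    print(clf.score(X, y))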
+ The following table summarizes the penalties supported by each solver: -+------------------------------+-----------------+-------------+-----------------+-----------+------------+ -| | **Solvers** | -+------------------------------+-----------------+-------------+-----------------+-----------+------------+ -| **Penalties** | **'liblinear'** | **'lbfgs'** | **'newton-cg'** | **'sag'** | **'saga'** | -+------------------------------+-----------------+-------------+-----------------+-----------+------------+ -| Multinomial + L2 penalty | no | yes | yes | yes | yes | -+------------------------------+-----------------+-------------+-----------------+-----------+------------+ -| OVR + L2 penalty | yes | yes | yes | yes | yes | -+------------------------------+-----------------+-------------+-----------------+-----------+------------+ -| Multinomial + L1 penalty | no | no | no | no | yes | -+------------------------------+-----------------+-------------+-----------------+-----------+------------+ -| OVR + L1 penalty | yes | no | no | no | yes | -+------------------------------+-----------------+-------------+-----------------+-----------+------------+ -| Elastic-Net | no | no | no | no | yes | -+------------------------------+-----------------+-------------+-----------------+-----------+------------+ -| No penalty ('none') | no | yes | yes | yes | yes | -+------------------------------+-----------------+-------------+-----------------+-----------+------------+ -| **Behaviors** | | -+------------------------------+-----------------+-------------+-----------------+-----------+------------+ -| Penalize the intercept (bad) | yes | no | no | no | no | -+------------------------------+-----------------+-------------+-----------------+-----------+------------+ -| Faster for large datasets | no | no | no | yes | yes | -+------------------------------+-----------------+-------------+-----------------+-----------+------------+ -| Robust to unscaled datasets | yes | yes | yes | no | no | -+------------------------------+-----------------+-------------+-----------------+-----------+------------+ ++------------------------------+-----------------+-------------+-----------------+-----------------------+-----------+------------+ +| | **Solvers** | ++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+ +| **Penalties** | **'lbfgs'** | **'liblinear'** | **'newton-cg'** | **'newton-cholesky'** | **'sag'** | **'saga'** | ++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+ +| Multinomial + L2 penalty | yes | no | yes | no | yes | yes | ++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+ +| OVR + L2 penalty | yes | yes | yes | yes | yes | yes | ++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+ +| Multinomial + L1 penalty | no | no | no | no | no | yes | ++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+ +| OVR + L1 penalty | no | yes | no | no | no | yes | ++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+ +| Elastic-Net | no | no | no | no | no | yes | 
++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+ +| No penalty ('none') | yes | no | yes | yes | yes | yes | ++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+ +| **Behaviors** | | ++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+ +| Penalize the intercept (bad) | no | yes | no | no | no | no | ++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+ +| Faster for large datasets | no | no | no | no | yes | yes | ++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+ +| Robust to unscaled datasets | yes | yes | yes | yes | no | no | ++------------------------------+-------------+-----------------+-----------------+-----------------------+-----------+------------+ The "lbfgs" solver is used by default for its robustness. For large datasets the "saga" solver is usually faster. diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index a60325c6b3e4f..94902911aa253 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -353,7 +353,8 @@ Changelog :mod:`sklearn.linear_model` ........................... -- |Enhancement| :class:`linear_model.GammaRegressor`, +- |Enhancement| :class:`linear_model.LogisticRegression`, + :class:`linear_model.LogisticRegressionCV`, :class:`linear_model.GammaRegressor`, :class:`linear_model.PoissonRegressor` and :class:`linear_model.TweedieRegressor` got a new solver `solver="newton-cholesky"`. This is a 2nd order (Newton) optimisation routine that uses a Cholesky decomposition of the hessian matrix. @@ -361,7 +362,7 @@ Changelog converge both faster and to a higher precision solution than the `"lbfgs"` solver on problems with one-hot encoded categorical variables with some rare categorical levels. - :pr:`24637` by :user:`Christian Lorentzen `. + :pr:`24637` and :pr:`24767` by :user:`Christian Lorentzen `. - |Enhancement| :class:`linear_model.GammaRegressor`, :class:`linear_model.PoissonRegressor` and :class:`linear_model.TweedieRegressor` diff --git a/sklearn/linear_model/_glm/glm.py b/sklearn/linear_model/_glm/glm.py index 6dd02a387e0f3..1ffa16b219c22 100644 --- a/sklearn/linear_model/_glm/glm.py +++ b/sklearn/linear_model/_glm/glm.py @@ -75,7 +75,10 @@ class _GeneralizedLinearRegressor(RegressorMixin, BaseEstimator): 'newton-cholesky' Uses Newton-Raphson steps (in arbitrary precision arithmetic equivalent to iterated reweighted least squares) with an inner Cholesky based solver. - This solver is suited for n_samples >> n_features. + This solver is a good choice for `n_samples` >> `n_features`, especially + with one-hot encoded categorical features with rare categories. Be aware + that the memory usage of this solver has a quadratic dependency on + `n_features` because it explicitly computes the Hessian matrix. .. 
versionadded:: 1.2 @@ -304,7 +307,7 @@ def fit(self, X, y, sample_weight=None): coef = sol.solve(X, y, sample_weight) self.n_iter_ = sol.iteration else: - raise TypeError(f"Invalid solver={self.solver}.") + raise ValueError(f"Invalid solver={self.solver}.") if self.fit_intercept: self.intercept_ = coef[-1] @@ -512,7 +515,10 @@ class PoissonRegressor(_GeneralizedLinearRegressor): 'newton-cholesky' Uses Newton-Raphson steps (in arbitrary precision arithmetic equivalent to iterated reweighted least squares) with an inner Cholesky based solver. - This solver is suited for n_samples >> n_features. + This solver is a good choice for `n_samples` >> `n_features`, especially + with one-hot encoded categorical features with rare categories. Be aware + that the memory usage of this solver has a quadratic dependency on + `n_features` because it explicitly computes the Hessian matrix. .. versionadded:: 1.2 @@ -640,7 +646,10 @@ class GammaRegressor(_GeneralizedLinearRegressor): 'newton-cholesky' Uses Newton-Raphson steps (in arbitrary precision arithmetic equivalent to iterated reweighted least squares) with an inner Cholesky based solver. - This solver is suited for n_samples >> n_features. + This solver is a good choice for `n_samples` >> `n_features`, especially + with one-hot encoded categorical features with rare categories. Be aware + that the memory usage of this solver has a quadratic dependency on + `n_features` because it explicitly computes the Hessian matrix. .. versionadded:: 1.2 @@ -799,7 +808,10 @@ class TweedieRegressor(_GeneralizedLinearRegressor): 'newton-cholesky' Uses Newton-Raphson steps (in arbitrary precision arithmetic equivalent to iterated reweighted least squares) with an inner Cholesky based solver. - This solver is suited for n_samples >> n_features. + This solver is a good choice for `n_samples` >> `n_features`, especially + with one-hot encoded categorical features with rare categories. Be aware + that the memory usage of this solver has a quadratic dependency on + `n_features` because it explicitly computes the Hessian matrix. .. versionadded:: 1.2 diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 375847f5de1f8..fecde099a3f60 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -23,6 +23,7 @@ from ._base import LinearClassifierMixin, SparseCoefMixin, BaseEstimator from ._linear_loss import LinearModelLoss from ._sag import sag_solver +from ._glm.glm import NewtonCholeskySolver from .._loss.loss import HalfBinomialLoss, HalfMultinomialLoss from ..preprocessing import LabelEncoder, LabelBinarizer from ..svm._base import _fit_liblinear @@ -73,14 +74,19 @@ def _check_solver(solver, penalty, dual): def _check_multi_class(multi_class, solver, n_classes): + """Computes the multi class type, either "multinomial" or "ovr". + + For `n_classes` > 2 and a solver that supports it, returns "multinomial". + For all other cases, in particular binary classification, return "ovr". + """ if multi_class == "auto": - if solver == "liblinear": + if solver in ("liblinear", "newton-cholesky"): multi_class = "ovr" elif n_classes > 2: multi_class = "multinomial" else: multi_class = "ovr" - if multi_class == "multinomial" and solver == "liblinear": + if multi_class == "multinomial" and solver in ("liblinear", "newton-cholesky"): raise ValueError("Solver %s does not support a multinomial backend." 
% solver) return multi_class @@ -153,7 +159,7 @@ def _logistic_regression_path( For the liblinear and lbfgs solvers set verbose to any positive number for verbosity. - solver : {'lbfgs', 'newton-cg', 'liblinear', 'sag', 'saga'}, \ + solver : {'lbfgs', 'liblinear', 'newton-cg', 'newton-cholesky', 'sag', 'saga'}, \ default='lbfgs' Numerical solver to use. @@ -272,7 +278,7 @@ def _logistic_regression_path( ) y = check_array(y, ensure_2d=False, dtype=None) check_consistent_length(X, y) - _, n_features = X.shape + n_samples, n_features = X.shape classes = np.unique(y) random_state = check_random_state(random_state) @@ -289,6 +295,23 @@ def _logistic_regression_path( # Otherwise set them to 1 for all examples sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype, copy=True) + if solver == "newton-cholesky": + # IMPORTANT NOTE: Rescaling of sample_weight: + # Same as in _GeneralizedLinearRegressor.fit(). + # We want to minimize + # obj = 1/(2*sum(sample_weight)) * sum(sample_weight * deviance) + # + 1/2 * alpha * L2, + # with + # deviance = 2 * log_loss. + # The objective is invariant to multiplying sample_weight by a constant. We + # choose this constant such that sum(sample_weight) = 1. Thus, we end up with + # obj = sum(sample_weight * loss) + 1/2 * alpha * L2. + # Note that LinearModelLoss.loss() computes sum(sample_weight * loss). + # + # This rescaling has to be done before multiplying by class_weights. + sw_sum = sample_weight.sum() # needed to rescale penalty, nasty matter! + sample_weight = sample_weight / sw_sum + # If class_weights is a dict (provided by the user), the weights # are assigned to the original labels. If it is "balanced", then # the class_weights are assigned after masking the labels with a OvR. @@ -297,13 +320,13 @@ def _logistic_regression_path( class_weight_ = compute_class_weight(class_weight, classes=classes, y=y) sample_weight *= class_weight_[le.fit_transform(y)] - # For doing a ovr, we need to mask the labels first. for the + # For doing a ovr, we need to mask the labels first. For the # multinomial case this is not necessary. if multi_class == "ovr": w0 = np.zeros(n_features + int(fit_intercept), dtype=X.dtype) mask = y == pos_class y_bin = np.ones(y.shape, dtype=X.dtype) - if solver in ["lbfgs", "newton-cg"]: + if solver in ["lbfgs", "newton-cg", "newton-cholesky"]: # HalfBinomialLoss, used for those solvers, represents y in [0, 1] instead # of in [-1, 1]. mask_classes = np.array([0, 1]) @@ -410,6 +433,10 @@ def _logistic_regression_path( func = loss.loss grad = loss.gradient hess = loss.gradient_hessian_product # hess = [gradient, hessp] + elif solver == "newton-cholesky": + loss = LinearModelLoss( + base_loss=HalfBinomialLoss(), fit_intercept=fit_intercept + ) warm_start_sag = {"coef": np.expand_dims(w0, axis=1)} coefs = list() @@ -441,6 +468,21 @@ def _logistic_regression_path( w0, n_iter_i = _newton_cg( hess, func, grad, w0, args=args, maxiter=max_iter, tol=tol ) + elif solver == "newton-cholesky": + # The division by sw_sum is a consequence of the rescaling of + # sample_weight, see comment above. 
+ l2_reg_strength = 1.0 / C / sw_sum + sol = NewtonCholeskySolver( + coef=w0, + linear_loss=loss, + l2_reg_strength=l2_reg_strength, + tol=tol, + max_iter=max_iter, + n_threads=n_threads, + verbose=verbose, + ) + w0 = sol.solve(X=X, y=target, sample_weight=sample_weight) + n_iter_i = sol.iteration elif solver == "liblinear": coef_, intercept_, n_iter_i, = _fit_liblinear( X, @@ -601,7 +643,7 @@ def _log_reg_scoring_path( For the liblinear and lbfgs solvers set verbose to any positive number for verbosity. - solver : {'lbfgs', 'newton-cg', 'liblinear', 'sag', 'saga'}, \ + solver : {'lbfgs', 'liblinear', 'newton-cg', 'newton-cholesky', 'sag', 'saga'}, \ default='lbfgs' Decides which solver to use. @@ -833,7 +875,7 @@ class LogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator): Used when ``solver`` == 'sag', 'saga' or 'liblinear' to shuffle the data. See :term:`Glossary ` for details. - solver : {'newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'}, \ + solver : {'lbfgs', 'liblinear', 'newton-cg', 'newton-cholesky', 'sag', 'saga'}, \ default='lbfgs' Algorithm to use in the optimization problem. Default is 'lbfgs'. @@ -843,22 +885,29 @@ class LogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator): and 'saga' are faster for large ones; - For multiclass problems, only 'newton-cg', 'sag', 'saga' and 'lbfgs' handle multinomial loss; - 'liblinear' is limited to one-versus-rest schemes. + - 'newton-cholesky' is a good choice for `n_samples` >> `n_features`, + especially with one-hot encoded categorical features with rare + categories. Note that it is limited to binary classification and the + one-versus-rest reduction for multiclass classification. Be aware that + the memory usage of this solver has a quadratic dependency on + `n_features` because it explicitly computes the Hessian matrix. .. warning:: - The choice of the algorithm depends on the penalty chosen: + The choice of the algorithm depends on the penalty chosen. Supported penalties by solver: - - 'newton-cg' - ['l2', None] - - 'lbfgs' - ['l2', None] - - 'liblinear' - ['l1', 'l2'] - - 'sag' - ['l2', None] - - 'saga' - ['elasticnet', 'l1', 'l2', None] + - 'lbfgs' - ['l2', None] + - 'liblinear' - ['l1', 'l2'] + - 'newton-cg' - ['l2', None] + - 'newton-cholesky' - ['l2', None] + - 'sag' - ['l2', None] + - 'saga' - ['elasticnet', 'l1', 'l2', None] .. note:: - 'sag' and 'saga' fast convergence is only guaranteed on - features with approximately the same scale. You can - preprocess the data with a scaler from :mod:`sklearn.preprocessing`. + 'sag' and 'saga' fast convergence is only guaranteed on features + with approximately the same scale. You can preprocess the data with + a scaler from :mod:`sklearn.preprocessing`. .. seealso:: Refer to the User Guide for more information regarding @@ -872,6 +921,8 @@ class LogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator): SAGA solver. .. versionchanged:: 0.22 The default solver changed from 'liblinear' to 'lbfgs' in 0.22. + .. versionadded:: 1.2 + newton-cholesky solver. max_iter : int, default=100 Maximum number of iterations taken for the solvers to converge.
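The rescaling above can be read as follows: every solver of :class:`LogisticRegression` minimizes `C * sum(sample_weight * log_loss) + 0.5 * ||coef||^2`; dividing this objective by `C * sum(sample_weight)` leaves the minimizer unchanged and yields exactly the normalized weights plus `l2_reg_strength = 1 / (C * sum(sample_weight))` passed to `NewtonCholeskySolver`. A hedged consistency sketch (the dataset, tolerances and `rtol` below are illustrative, not part of the patch)::

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_samples=2_000, n_features=10, random_state=0)
    sample_weight = np.random.RandomState(0).uniform(0.5, 2.0, size=y.shape[0])

    common = dict(C=0.5, tol=1e-8, max_iter=1_000)
    ref = LogisticRegression(solver="lbfgs", **common).fit(
        X, y, sample_weight=sample_weight
    )
    new = LogisticRegression(solver="newton-cholesky", **common).fit(
        X, y, sample_weight=sample_weight
    )

    # Because the penalty is rescaled together with sample_weight, both solvers
    # optimize the same objective and should agree closely.
    np.testing.assert_allclose(new.coef_, ref.coef_, rtol=1e-4)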
@@ -1027,7 +1078,11 @@ class LogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator): "intercept_scaling": [Interval(Real, 0, None, closed="neither")], "class_weight": [dict, StrOptions({"balanced"}), None], "random_state": ["random_state"], - "solver": [StrOptions({"newton-cg", "lbfgs", "liblinear", "sag", "saga"})], + "solver": [ + StrOptions( + {"lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga"} + ) + ], "max_iter": [Interval(Integral, 0, None, closed="left")], "multi_class": [StrOptions({"auto", "ovr", "multinomial"})], "verbose": ["verbose"], @@ -1223,7 +1278,7 @@ def fit(self, X, y, sample_weight=None): # and multinomial multiclass classification and use joblib only for the # one-vs-rest multiclass case. if ( - solver in ["lbfgs", "newton-cg"] + solver in ["lbfgs", "newton-cg", "newton-cholesky"] and len(classes_) == 1 and effective_n_jobs(self.n_jobs) == 1 ): @@ -1308,7 +1363,10 @@ def predict_proba(self, X): ovr = self.multi_class in ["ovr", "warn"] or ( self.multi_class == "auto" - and (self.classes_.size <= 2 or self.solver == "liblinear") + and ( + self.classes_.size <= 2 + or self.solver in ("liblinear", "newton-cholesky") + ) ) if ovr: return super()._predict_proba_lr(X) @@ -1409,7 +1467,7 @@ class LogisticRegressionCV(LogisticRegression, LinearClassifierMixin, BaseEstima that can be used, look at :mod:`sklearn.metrics`. The default scoring option used is 'accuracy'. - solver : {'newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'}, \ + solver : {'lbfgs', 'liblinear', 'newton-cg', 'newton-cholesky', 'sag', 'saga'}, \ default='lbfgs' Algorithm to use in the optimization problem. Default is 'lbfgs'. @@ -1422,15 +1480,23 @@ class LogisticRegressionCV(LogisticRegression, LinearClassifierMixin, BaseEstima - 'liblinear' might be slower in :class:`LogisticRegressionCV` because it does not handle warm-starting. 'liblinear' is limited to one-versus-rest schemes. + - 'newton-cholesky' is a good choice for `n_samples` >> `n_features`, + especially with one-hot encoded categorical features with rare + categories. Note that it is limited to binary classification and the + one-versus-rest reduction for multiclass classification. Be aware that + the memory usage of this solver has a quadratic dependency on + `n_features` because it explicitly computes the Hessian matrix. .. warning:: - The choice of the algorithm depends on the penalty chosen: + The choice of the algorithm depends on the penalty chosen. + Supported penalties by solver: - - 'newton-cg' - ['l2'] - - 'lbfgs' - ['l2'] - - 'liblinear' - ['l1', 'l2'] - - 'sag' - ['l2'] - - 'saga' - ['elasticnet', 'l1', 'l2'] + - 'lbfgs' - ['l2'] + - 'liblinear' - ['l1', 'l2'] + - 'newton-cg' - ['l2'] + - 'newton-cholesky' - ['l2'] + - 'sag' - ['l2'] + - 'saga' - ['elasticnet', 'l1', 'l2'] .. note:: 'sag' and 'saga' fast convergence is only guaranteed on features @@ -1441,6 +1507,8 @@ class LogisticRegressionCV(LogisticRegression, LinearClassifierMixin, BaseEstima Stochastic Average Gradient descent solver. .. versionadded:: 0.19 SAGA solver. + .. versionadded:: 1.2 + newton-cholesky solver. tol : float, default=1e-4 Tolerance for stopping criteria. 
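To make the one-vs-rest behaviour wired into `_check_multi_class` above concrete, a short sketch (illustrative only; the error text is the one raised by this patch)::

    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression

    X, y = load_iris(return_X_y=True)

    # With the default multi_class="auto", "newton-cholesky" falls back to
    # one-vs-rest: one binary subproblem per class, hence one row of coef_
    # per class.
    ovr = LogisticRegression(solver="newton-cholesky", max_iter=1_000).fit(X, y)
    print(ovr.coef_.shape)  # (3, 4)

    # Explicitly requesting the multinomial backend is rejected.
    try:
        LogisticRegression(solver="newton-cholesky", multi_class="multinomial").fit(X, y)
    except ValueError as exc:
        print(exc)  # Solver newton-cholesky does not support a multinomial backend.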
diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py index b9fee199380f3..47c6860fe653f 100644 --- a/sklearn/linear_model/tests/test_logistic.py +++ b/sklearn/linear_model/tests/test_logistic.py @@ -1,5 +1,6 @@ import itertools import os +import warnings import numpy as np from numpy.testing import assert_allclose, assert_almost_equal from numpy.testing import assert_array_almost_equal, assert_array_equal @@ -31,6 +32,8 @@ LogisticRegressionCV, ) + +SOLVERS = ("lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga") X = [[-1, 0], [0, 1], [1, 1]] X_sp = sparse.csr_matrix(X) Y1 = [0, 1, 1] @@ -121,16 +124,9 @@ def test_predict_3_classes(): check_predictions(LogisticRegression(C=10), X_sp, Y2) -def test_predict_iris(): - # Test logistic regression with the iris dataset - n_samples, n_features = iris.data.shape - - target = iris.target_names[iris.target] - - # Test that both multinomial and OvR solvers handle - # multiclass data correctly and give good accuracy - # score (>0.95) for the training data. - for clf in [ +@pytest.mark.parametrize( + "clf", + [ LogisticRegression(C=len(iris.data), solver="liblinear", multi_class="ovr"), LogisticRegression(C=len(iris.data), solver="lbfgs", multi_class="multinomial"), LogisticRegression( @@ -146,37 +142,57 @@ def test_predict_iris(): multi_class="ovr", random_state=42, ), - ]: + LogisticRegression( + C=len(iris.data), solver="newton-cholesky", multi_class="ovr" + ), + ], +) +def test_predict_iris(clf): + """Test logistic regression with the iris dataset. + + Test that both multinomial and OvR solvers handle multiclass data correctly and + give good accuracy score (>0.95) for the training data. + """ + n_samples, n_features = iris.data.shape + target = iris.target_names[iris.target] + + if clf.solver == "lbfgs": + # lbfgs has convergence issues on the iris data with its default max_iter=100 + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ConvergenceWarning) + clf.fit(iris.data, target) + else: clf.fit(iris.data, target) - assert_array_equal(np.unique(target), clf.classes_) + assert_array_equal(np.unique(target), clf.classes_) - pred = clf.predict(iris.data) - assert np.mean(pred == target) > 0.95 + pred = clf.predict(iris.data) + assert np.mean(pred == target) > 0.95 - probabilities = clf.predict_proba(iris.data) - assert_array_almost_equal(probabilities.sum(axis=1), np.ones(n_samples)) + probabilities = clf.predict_proba(iris.data) + assert_allclose(probabilities.sum(axis=1), np.ones(n_samples)) - pred = iris.target_names[probabilities.argmax(axis=1)] - assert np.mean(pred == target) > 0.95 + pred = iris.target_names[probabilities.argmax(axis=1)] + assert np.mean(pred == target) > 0.95 @pytest.mark.parametrize("LR", [LogisticRegression, LogisticRegressionCV]) def test_check_solver_option(LR): X, y = iris.data, iris.target - # only 'liblinear' solver - msg = "Solver liblinear does not support a multinomial backend." - lr = LR(solver="liblinear", multi_class="multinomial") - with pytest.raises(ValueError, match=msg): - lr.fit(X, y) + # only 'liblinear' and 'newton-cholesky' solver + for solver in ["liblinear", "newton-cholesky"]: + msg = f"Solver {solver} does not support a multinomial backend." 
+ lr = LR(solver=solver, multi_class="multinomial") + with pytest.raises(ValueError, match=msg): + lr.fit(X, y) # all solvers except 'liblinear' and 'saga' - for solver in ["newton-cg", "lbfgs", "sag"]: + for solver in ["lbfgs", "newton-cg", "newton-cholesky", "sag"]: msg = "Solver %s supports only 'l2' or 'none' penalties," % solver lr = LR(solver=solver, penalty="l1", multi_class="ovr") with pytest.raises(ValueError, match=msg): lr.fit(X, y) - for solver in ["newton-cg", "lbfgs", "sag", "saga"]: + for solver in ["lbfgs", "newton-cg", "newton-cholesky", "sag", "saga"]: msg = "Solver %s supports only dual=False, got dual=True" % solver lr = LR(solver=solver, dual=True, multi_class="ovr") with pytest.raises(ValueError, match=msg): @@ -343,7 +359,7 @@ def test_consistency_path(): ) # test for fit_intercept=True - for solver in ("lbfgs", "newton-cg", "liblinear", "sag", "saga"): + for solver in ("lbfgs", "newton-cg", "newton-cholesky", "liblinear", "sag", "saga"): Cs = [1e3] coefs, Cs, _ = f(_logistic_regression_path)( X, @@ -357,7 +373,7 @@ def test_consistency_path(): ) lr = LogisticRegression( C=Cs[0], - tol=1e-4, + tol=1e-6, intercept_scaling=10000.0, random_state=0, multi_class="ovr", @@ -623,11 +639,10 @@ def test_logistic_regression_solvers(): X, y = make_classification(n_features=10, n_informative=5, random_state=0) params = dict(fit_intercept=False, random_state=42, multi_class="ovr") - solvers = ("newton-cg", "lbfgs", "liblinear", "sag", "saga") regressors = { solver: LogisticRegression(solver=solver, **params).fit(X, y) - for solver in solvers + for solver in SOLVERS } for solver_1, solver_2 in itertools.combinations(regressors, r=2): @@ -643,7 +658,6 @@ def test_logistic_regression_solvers_multiclass(): ) tol = 1e-7 params = dict(fit_intercept=False, tol=tol, random_state=42, multi_class="ovr") - solvers = ("newton-cg", "lbfgs", "liblinear", "sag", "saga") # Override max iteration count for specific solvers to allow for # proper convergence. 
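As a user-facing illustration of the solver/penalty check exercised by `test_check_solver_option` above, a sketch (illustrative; the error text matches the one asserted in that test)::

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(random_state=0)

    # "newton-cholesky" only supports l2 regularization or no regularization.
    try:
        LogisticRegression(solver="newton-cholesky", penalty="l1").fit(X, y)
    except ValueError as exc:
        print(exc)  # Solver newton-cholesky supports only 'l2' or 'none' penalties, ...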
@@ -653,7 +667,7 @@ def test_logistic_regression_solvers_multiclass(): solver: LogisticRegression( solver=solver, max_iter=solver_max_iter.get(solver, 100), **params ).fit(X, y) - for solver in solvers + for solver in SOLVERS } for solver_1, solver_2 in itertools.combinations(regressors, r=2): @@ -662,70 +676,38 @@ def test_logistic_regression_solvers_multiclass(): ) -def test_logistic_regressioncv_class_weights(): - for weight in [{0: 0.1, 1: 0.2}, {0: 0.1, 1: 0.2, 2: 0.5}]: - n_classes = len(weight) - for class_weight in (weight, "balanced"): - X, y = make_classification( - n_samples=30, - n_features=3, - n_repeated=0, - n_informative=3, - n_redundant=0, - n_classes=n_classes, - random_state=0, - ) +@pytest.mark.parametrize("weight", [{0: 0.1, 1: 0.2}, {0: 0.1, 1: 0.2, 2: 0.5}]) +@pytest.mark.parametrize("class_weight", ["weight", "balanced"]) +def test_logistic_regressioncv_class_weights(weight, class_weight): + """Test class_weight for LogisticRegressionCV.""" + n_classes = len(weight) + if class_weight == "weight": + class_weight = weight - clf_lbf = LogisticRegressionCV( - solver="lbfgs", - Cs=1, - fit_intercept=False, - multi_class="ovr", - class_weight=class_weight, - ) - clf_ncg = LogisticRegressionCV( - solver="newton-cg", - Cs=1, - fit_intercept=False, - multi_class="ovr", - class_weight=class_weight, - ) - clf_lib = LogisticRegressionCV( - solver="liblinear", - Cs=1, - fit_intercept=False, - multi_class="ovr", - class_weight=class_weight, - ) - clf_sag = LogisticRegressionCV( - solver="sag", - Cs=1, - fit_intercept=False, - multi_class="ovr", - class_weight=class_weight, - tol=1e-5, - max_iter=10000, - random_state=0, - ) - clf_saga = LogisticRegressionCV( - solver="saga", - Cs=1, - fit_intercept=False, - multi_class="ovr", - class_weight=class_weight, - tol=1e-5, - max_iter=10000, - random_state=0, - ) - clf_lbf.fit(X, y) - clf_ncg.fit(X, y) - clf_lib.fit(X, y) - clf_sag.fit(X, y) - clf_saga.fit(X, y) - assert_array_almost_equal(clf_lib.coef_, clf_lbf.coef_, decimal=4) - assert_array_almost_equal(clf_ncg.coef_, clf_lbf.coef_, decimal=4) - assert_array_almost_equal(clf_sag.coef_, clf_lbf.coef_, decimal=4) - assert_array_almost_equal(clf_saga.coef_, clf_lbf.coef_, decimal=4) + X, y = make_classification( + n_samples=30, + n_features=3, + n_repeated=0, + n_informative=3, + n_redundant=0, + n_classes=n_classes, + random_state=0, + ) + params = dict( + Cs=1, + fit_intercept=False, + multi_class="ovr", + class_weight=class_weight, + ) + clf_lbfgs = LogisticRegressionCV(solver="lbfgs", **params) + clf_lbfgs.fit(X, y) + + for solver in set(SOLVERS) - set(["lbfgs"]): + clf = LogisticRegressionCV(solver=solver, **params) + if solver in ("sag", "saga"): + clf.set_params(tol=1e-5, max_iter=10000, random_state=0) + clf.fit(X, y) + assert_allclose(clf.coef_, clf_lbfgs.coef_, rtol=1e-3) def test_logistic_regression_sample_weights(): @@ -747,23 +729,18 @@ def test_logistic_regression_sample_weights(): clf_sw_ones = LR(solver=solver, **kw) clf_sw_none.fit(X, y) clf_sw_ones.fit(X, y, sample_weight=np.ones(y.shape[0])) - assert_array_almost_equal(clf_sw_none.coef_, clf_sw_ones.coef_, decimal=4) + assert_allclose(clf_sw_none.coef_, clf_sw_ones.coef_, rtol=1e-4) # Test that sample weights work the same with the lbfgs, - # newton-cg, and 'sag' solvers + # newton-cg, newton-cholesky and 'sag' solvers clf_sw_lbfgs = LR(**kw) clf_sw_lbfgs.fit(X, y, sample_weight=sample_weight) - clf_sw_n = LR(solver="newton-cg", **kw) - clf_sw_n.fit(X, y, sample_weight=sample_weight) - clf_sw_sag = LR(solver="sag", 
tol=1e-10, **kw) - # ignore convergence warning due to small dataset - with ignore_warnings(): - clf_sw_sag.fit(X, y, sample_weight=sample_weight) - clf_sw_liblinear = LR(solver="liblinear", **kw) - clf_sw_liblinear.fit(X, y, sample_weight=sample_weight) - assert_array_almost_equal(clf_sw_lbfgs.coef_, clf_sw_n.coef_, decimal=4) - assert_array_almost_equal(clf_sw_lbfgs.coef_, clf_sw_sag.coef_, decimal=4) - assert_array_almost_equal(clf_sw_lbfgs.coef_, clf_sw_liblinear.coef_, decimal=4) + for solver in set(SOLVERS) - set(("lbfgs", "saga")): + clf_sw = LR(solver=solver, tol=1e-10 if solver == "sag" else 1e-5, **kw) + # ignore convergence warning due to small dataset with sag + with ignore_warnings(): + clf_sw.fit(X, y, sample_weight=sample_weight) + assert_allclose(clf_sw_lbfgs.coef_, clf_sw.coef_, rtol=1e-4) # Test that passing class_weight as [1,2] is the same as # passing class weight = [1,1] but adjusting sample weights @@ -773,7 +750,7 @@ def test_logistic_regression_sample_weights(): clf_cw_12.fit(X, y) clf_sw_12 = LR(solver=solver, **kw) clf_sw_12.fit(X, y, sample_weight=sample_weight) - assert_array_almost_equal(clf_cw_12.coef_, clf_sw_12.coef_, decimal=4) + assert_allclose(clf_cw_12.coef_, clf_sw_12.coef_, rtol=1e-4) # Test the above for l1 penalty and l2 penalty with dual=True. # since the patched liblinear code is different. @@ -849,10 +826,9 @@ def test_logistic_regression_class_weights(): # Binary case: remove 90% of class 0 and 100% of class 2 X = iris.data[45:100, :] y = iris.target[45:100] - solvers = ("lbfgs", "newton-cg", "liblinear") class_weight_dict = _compute_class_weight_dictionary(y) - for solver in solvers: + for solver in set(SOLVERS) - set(("sag", "saga")): clf1 = LogisticRegression( solver=solver, multi_class="ovr", class_weight="balanced" ) @@ -1121,6 +1097,7 @@ def test_logreg_predict_proba_multinomial(): ("sag", "The max_iter was reached which means the coef_ did not converge"), ("saga", "The max_iter was reached which means the coef_ did not converge"), ("lbfgs", "lbfgs failed to converge"), + ("newton-cholesky", "Newton solver did not converge after [0-9]* iterations"), ], ) def test_max_iter(max_iter, multi_class, solver, message): @@ -1128,8 +1105,10 @@ def test_max_iter(max_iter, multi_class, solver, message): X, y_bin = iris.data, iris.target.copy() y_bin[y_bin == 2] = 0 - if solver == "liblinear" and multi_class == "multinomial": - pytest.skip("'multinomial' is unavailable when solver='liblinear'") + if solver in ("liblinear", "newton-cholesky") and multi_class == "multinomial": + pytest.skip("'multinomial' is not supported by liblinear and newton-cholesky") + if solver == "newton-cholesky" and max_iter > 1: + pytest.skip("solver newton-cholesky might converge very fast") lr = LogisticRegression( max_iter=max_iter, @@ -1144,7 +1123,7 @@ def test_max_iter(max_iter, multi_class, solver, message): assert lr.n_iter_[0] == max_iter -@pytest.mark.parametrize("solver", ["newton-cg", "liblinear", "sag", "saga", "lbfgs"]) +@pytest.mark.parametrize("solver", SOLVERS) def test_n_iter(solver): # Test that self.n_iter_ has the correct format. X, y = iris.data, iris.target @@ -1177,7 +1156,7 @@ def test_n_iter(solver): assert clf_cv.n_iter_.shape == (n_classes, n_cv_fold, n_Cs) # multinomial case - if solver == "liblinear": + if solver in ("liblinear", "newton-cholesky"): # This solver only supports one-vs-rest multiclass classification. 
return @@ -1190,7 +1169,7 @@ def test_n_iter(solver): assert clf_cv.n_iter_.shape == (1, n_cv_fold, n_Cs) -@pytest.mark.parametrize("solver", ("newton-cg", "sag", "saga", "lbfgs")) +@pytest.mark.parametrize("solver", sorted(set(SOLVERS) - set(["liblinear"]))) @pytest.mark.parametrize("warm_start", (True, False)) @pytest.mark.parametrize("fit_intercept", (True, False)) @pytest.mark.parametrize("multi_class", ["ovr", "multinomial"]) @@ -1200,6 +1179,10 @@ def test_warm_start(solver, warm_start, fit_intercept, multi_class): # Warm starting does not work with liblinear solver. X, y = iris.data, iris.target + if solver == "newton-cholesky" and multi_class == "multinomial": + # solver does only support OvR + return + clf = LogisticRegression( tol=1e-4, multi_class=multi_class, @@ -1274,14 +1257,16 @@ def test_saga_vs_liblinear(): @pytest.mark.parametrize("multi_class", ["ovr", "multinomial"]) -@pytest.mark.parametrize("solver", ["newton-cg", "liblinear", "saga"]) +@pytest.mark.parametrize( + "solver", ["liblinear", "newton-cg", "newton-cholesky", "saga"] +) @pytest.mark.parametrize("fit_intercept", [False, True]) def test_dtype_match(solver, multi_class, fit_intercept): # Test that np.float32 input data is not cast to np.float64 when possible # and that the output is approximately the same no matter the input format. - if solver == "liblinear" and multi_class == "multinomial": - pytest.skip("liblinear does not support multinomial logistic") + if solver in ("liblinear", "newton-cholesky") and multi_class == "multinomial": + pytest.skip(f"Solver={solver} does not support multinomial logistic.") out32_type = np.float64 if solver == "liblinear" else np.float32 @@ -1728,9 +1713,10 @@ def test_logistic_regression_path_coefs_multinomial(): ], ids=lambda x: x.__class__.__name__, ) -@pytest.mark.parametrize("solver", ["liblinear", "lbfgs", "newton-cg", "sag", "saga"]) +@pytest.mark.parametrize("solver", SOLVERS) def test_logistic_regression_multi_class_auto(est, solver): - # check multi_class='auto' => multi_class='ovr' iff binary y or liblinear + # check multi_class='auto' => multi_class='ovr' + # iff binary y or liblinear or newton-cholesky def fit(X, y, **kw): return clone(est).set_params(**kw).fit(X, y) @@ -1746,7 +1732,7 @@ def fit(X, y, **kw): assert_allclose(est_auto_bin.predict_proba(X2), est_ovr_bin.predict_proba(X2)) est_auto_multi = fit(X, y_multi, multi_class="auto", solver=solver) - if solver == "liblinear": + if solver in ("liblinear", "newton-cholesky"): est_ovr_multi = fit(X, y_multi, multi_class="ovr", solver=solver) assert_allclose(est_auto_multi.coef_, est_ovr_multi.coef_) assert_allclose( @@ -1770,7 +1756,7 @@ def fit(X, y, **kw): ) -@pytest.mark.parametrize("solver", ("lbfgs", "newton-cg", "sag", "saga")) +@pytest.mark.parametrize("solver", sorted(set(SOLVERS) - set(["liblinear"]))) def test_penalty_none(solver): # - Make sure warning is raised if penalty=None and C is set to a # non-default value. @@ -1944,7 +1930,7 @@ def test_sample_weight_not_modified(multi_class, class_weight): assert_allclose(expected, W) -@pytest.mark.parametrize("solver", ["liblinear", "lbfgs", "newton-cg", "sag", "saga"]) +@pytest.mark.parametrize("solver", SOLVERS) def test_large_sparse_matrix(solver): # Solvers either accept large sparse matrices, or raise helpful error. # Non-regression test for pull-request #21093.