Commit 18eef9a

FEA Add positive argument to Ridge using L-BFGS-B solver (#20231)
1 parent e4ef854 commit 18eef9a

File tree: 3 files changed, +341 −38 lines changed

doc/whats_new/v1.0.rst (+6)

@@ -405,6 +405,12 @@ Changelog
     :user:`Oliver Grisel <ogrisel>` and
     :user:`Christian Lorentzen <lorentzenchr>`.
 
+- |Feature| Added new solver `lbfgs` (available with `solver="lbfgs"`)
+  and `positive` argument to :class:`linear_model.Ridge`.
+  When `positive` is set to `True`, forces the coefficients to be positive
+  (only supported by `lbfgs`).
+  :pr:`20231` by :user:`Toshihiro Nakae <tnakae>`.
+
 :mod:`sklearn.manifold`
 .......................

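Before the code diff, a quick usage sketch of the feature this changelog entry describes. The data here is synthetic and for illustration only; it assumes a scikit-learn build that includes this commit:

    import numpy as np
    from sklearn.linear_model import Ridge

    rng = np.random.RandomState(0)
    X = rng.randn(50, 3)
    y = X @ np.array([1.0, 2.0, 3.0]) + 0.1 * rng.randn(50)

    # With positive=True, solver='auto' dispatches to the new 'lbfgs' solver
    # and every fitted coefficient is constrained to be nonnegative.
    reg = Ridge(alpha=1.0, positive=True).fit(X, y)
    assert (reg.coef_ >= 0).all()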
sklearn/linear_model/_ridge.py (+159 −23)

@@ -15,6 +15,7 @@
 import numpy as np
 from scipy import linalg
 from scipy import sparse
+from scipy import optimize
 from scipy.sparse import linalg as sp_linalg
 
 from ._base import LinearClassifierMixin, LinearModel
@@ -235,6 +236,64 @@ def _solve_svd(X, y, alpha):
     return np.dot(Vt.T, d_UT_y).T
 
 
+def _solve_lbfgs(
+    X, y, alpha, positive=True, max_iter=None, tol=1e-3, X_offset=None, X_scale=None
+):
+    """Solve ridge regression with LBFGS.
+
+    The main purpose is fitting with forcing coefficients to be positive.
+    For unconstrained ridge regression, there are faster dedicated solver methods.
+    Note that with positive bounds on the coefficients, LBFGS seems faster
+    than scipy.optimize.lsq_linear.
+    """
+    n_samples, n_features = X.shape
+
+    options = {}
+    if max_iter is not None:
+        options["maxiter"] = max_iter
+    config = {
+        "method": "L-BFGS-B",
+        "tol": tol,
+        "jac": True,
+        "options": options,
+    }
+    if positive:
+        config["bounds"] = [(0, np.inf)] * n_features
+
+    if X_offset is not None and X_scale is not None:
+        X_offset_scale = X_offset / X_scale
+    else:
+        X_offset_scale = None
+
+    coefs = np.empty((y.shape[1], n_features), dtype=X.dtype)
+
+    for i in range(y.shape[1]):
+        x0 = np.zeros((n_features,))
+        y_column = y[:, i]
+
+        def func(w):
+            residual = X.dot(w) - y_column
+            if X_offset_scale is not None:
+                residual -= w.dot(X_offset_scale)
+            f = 0.5 * residual.dot(residual) + 0.5 * alpha[i] * w.dot(w)
+            grad = X.T @ residual + alpha[i] * w
+            if X_offset_scale is not None:
+                grad -= X_offset_scale * np.sum(residual)
+
+            return f, grad
+
+        result = optimize.minimize(func, x0, **config)
+        if not result["success"]:
+            warnings.warn(
+                "The lbfgs solver did not converge. Try increasing max_iter "
+                f"or tol. Currently: max_iter={max_iter} and tol={tol}",
+                ConvergenceWarning,
+            )
+        coefs[i] = result["x"]
+
+    return coefs
+
+
 def _get_valid_accept_sparse(is_X_sparse, solver):
     if is_X_sparse and solver in ["auto", "sag", "saga"]:
         return "csr"
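For reference, the objective/gradient pair that `_solve_lbfgs` hands to `scipy.optimize.minimize` can be exercised standalone. A minimal sketch mirroring the hunk above; the data and `alpha` value are illustrative, not taken from the commit:

    import numpy as np
    from scipy import optimize

    rng = np.random.RandomState(0)
    X = rng.randn(40, 4)
    y = rng.randn(40)
    alpha = 1.0

    def func(w):
        # Ridge objective 0.5*||Xw - y||^2 + 0.5*alpha*||w||^2, returned
        # together with its gradient because the solver is set up with jac=True.
        residual = X @ w - y
        f = 0.5 * residual @ residual + 0.5 * alpha * (w @ w)
        grad = X.T @ residual + alpha * w
        return f, grad

    # Box bounds [0, inf) per coefficient implement the positivity constraint,
    # exactly as config["bounds"] does in the diff above.
    result = optimize.minimize(
        func,
        np.zeros(X.shape[1]),
        method="L-BFGS-B",
        jac=True,
        bounds=[(0, np.inf)] * X.shape[1],
    )
    assert (result.x >= 0).all()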
@@ -252,6 +311,7 @@ def ridge_regression(
     max_iter=None,
     tol=1e-3,
     verbose=0,
+    positive=False,
     random_state=None,
     return_n_iter=False,
     return_intercept=False,
@@ -287,8 +347,8 @@
 
         .. versionadded:: 0.17
 
-    solver : {'auto', 'svd', 'cholesky', 'lsqr', 'sparse_cg', 'sag', 'saga'}, \
-        default='auto'
+    solver : {'auto', 'svd', 'cholesky', 'lsqr', 'sparse_cg', \
+        'sag', 'saga', 'lbfgs'}, default='auto'
         Solver to use in the computational routines:
 
         - 'auto' chooses the solver automatically based on the type of data.
@@ -317,10 +377,13 @@
           approximately the same scale. You can preprocess the data with a
           scaler from sklearn.preprocessing.
 
+        - 'lbfgs' uses L-BFGS-B algorithm implemented in
+          `scipy.optimize.minimize`. It can be used only when `positive`
+          is True.
 
-        All last five solvers support both dense and sparse data. However, only
-        'sag' and 'sparse_cg' supports sparse input when `fit_intercept` is
-        True.
+        All last six solvers support both dense and sparse data. However, only
+        'sag', 'sparse_cg', and 'lbfgs' support sparse input when `fit_intercept`
+        is True.
 
         .. versionadded:: 0.17
            Stochastic Average Gradient descent solver.
@@ -331,7 +394,7 @@
         Maximum number of iterations for conjugate gradient solver.
         For the 'sparse_cg' and 'lsqr' solvers, the default value is determined
         by scipy.sparse.linalg. For 'sag' and saga solver, the default value is
-        1000.
+        1000. For 'lbfgs' solver, the default value is 15000.
 
     tol : float, default=1e-3
         Precision of the solution.
@@ -340,6 +403,10 @@
         Verbosity level. Setting verbose > 0 will display additional
         information depending on the solver used.
 
+    positive : bool, default=False
+        When set to ``True``, forces the coefficients to be positive.
+        Only 'lbfgs' solver is supported in this case.
+
     random_state : int, RandomState instance, default=None
         Used when ``solver`` == 'sag' or 'saga' to shuffle the data.
         See :term:`Glossary <random_state>` for details.
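A sketch of the function-level API with the new keyword (synthetic data; assumes this commit). With positive=True, solver='auto' resolves to 'lbfgs' and the returned coefficients are all nonnegative:

    import numpy as np
    from sklearn.linear_model import ridge_regression

    rng = np.random.RandomState(42)
    X = rng.randn(30, 4)
    y = X @ np.array([2.0, 0.0, 1.0, 3.0]) + 0.1 * rng.randn(30)

    coef = ridge_regression(X, y, alpha=1.0, positive=True)
    print(coef)  # every entry is >= 0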
@@ -389,6 +456,7 @@
         max_iter=max_iter,
         tol=tol,
         verbose=verbose,
+        positive=positive,
         random_state=random_state,
         return_n_iter=return_n_iter,
         return_intercept=return_intercept,
@@ -407,6 +475,7 @@ def _ridge_regression(
     max_iter=None,
     tol=1e-3,
     verbose=0,
+    positive=False,
     random_state=None,
     return_n_iter=False,
     return_intercept=False,
@@ -418,18 +487,33 @@
     has_sw = sample_weight is not None
 
     if solver == "auto":
-        if return_intercept:
-            # only sag supports fitting intercept directly
+        if positive:
+            solver = "lbfgs"
+        elif return_intercept:
+            # sag supports fitting intercept directly
             solver = "sag"
         elif not sparse.issparse(X):
             solver = "cholesky"
         else:
             solver = "sparse_cg"
 
-    if solver not in ("sparse_cg", "cholesky", "svd", "lsqr", "sag", "saga"):
+    if solver not in ("sparse_cg", "cholesky", "svd", "lsqr", "sag", "saga", "lbfgs"):
         raise ValueError(
             "Known solvers are 'sparse_cg', 'cholesky', 'svd'"
-            " 'lsqr', 'sag' or 'saga'. Got %s." % solver
+            " 'lsqr', 'sag', 'saga' or 'lbfgs'. Got %s." % solver
+        )
+
+    if positive and solver != "lbfgs":
+        raise ValueError(
+            "When positive=True, only 'lbfgs' solver can be used. "
+            f"Please change solver {solver} to 'lbfgs' "
+            "or set positive=False."
+        )
+
+    if solver == "lbfgs" and not positive:
+        raise ValueError(
+            "'lbfgs' solver can be used only when positive=True. "
+            "Please use another solver."
         )
 
     if return_intercept and solver != "sag":
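The two new guards are observable through the public estimator as well. A quick illustration; the exact messages come from the estimator-level checks and may differ slightly from the function-level wording above:

    import numpy as np
    from sklearn.linear_model import Ridge

    X, y = np.eye(3), np.ones(3)

    # positive=True is rejected for any solver other than 'auto' or 'lbfgs'...
    try:
        Ridge(solver="cholesky", positive=True).fit(X, y)
    except ValueError as exc:
        print(exc)

    # ...and solver='lbfgs' without positive=True is rejected as well.
    try:
        Ridge(solver="lbfgs").fit(X, y)
    except ValueError as exc:
        print(exc)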
@@ -554,6 +638,18 @@
             intercept = intercept[0]
         coef = np.asarray(coef)
 
+    elif solver == "lbfgs":
+        coef = _solve_lbfgs(
+            X,
+            y,
+            alpha,
+            positive=positive,
+            tol=tol,
+            max_iter=max_iter,
+            X_offset=X_offset,
+            X_scale=X_scale,
+        )
+
     if solver == "svd":
         if sparse.issparse(X):
             raise TypeError("SVD solver does not support sparse inputs currently")
@@ -585,6 +681,7 @@ def __init__(
         max_iter=None,
         tol=1e-3,
         solver="auto",
+        positive=False,
         random_state=None,
     ):
         self.alpha = alpha
@@ -594,6 +691,7 @@
         self.max_iter = max_iter
         self.tol = tol
         self.solver = solver
+        self.positive = positive
         self.random_state = random_state
 
     def fit(self, X, y, sample_weight=None):
@@ -612,16 +710,31 @@ def fit(self, X, y, sample_weight=None):
             multi_output=True,
             y_numeric=True,
         )
-        if sparse.issparse(X) and self.fit_intercept:
-            if self.solver not in ["auto", "sparse_cg", "sag"]:
+        if self.solver == "lbfgs" and not self.positive:
+            raise ValueError(
+                "'lbfgs' solver can be used only when positive=True. "
+                "Please use another solver."
+            )
+
+        if self.positive:
+            if self.solver not in ["auto", "lbfgs"]:
+                raise ValueError(
+                    f"solver='{self.solver}' does not support positive fitting. Please"
+                    " set the solver to 'auto' or 'lbfgs', or set `positive=False`"
+                )
+            else:
+                solver = self.solver
+        elif sparse.issparse(X) and self.fit_intercept:
+            if self.solver not in ["auto", "sparse_cg", "sag", "lbfgs"]:
                 raise ValueError(
                     "solver='{}' does not support fitting the intercept "
                     "on sparse data. Please set the solver to 'auto' or "
-                    "'sparse_cg', 'sag', or set `fit_intercept=False`".format(
-                        self.solver
-                    )
+                    "'sparse_cg', 'sag', 'lbfgs' "
+                    "or set `fit_intercept=False`".format(self.solver)
                 )
-            if self.solver == "sag" and self.max_iter is None and self.tol > 1e-4:
+            if self.solver == "lbfgs":
+                solver = "lbfgs"
+            elif self.solver == "sag" and self.max_iter is None and self.tol > 1e-4:
                 warnings.warn(
                     '"sag" solver requires many iterations to fit '
                     "an intercept with sparse inputs. Either set the "
@@ -658,6 +771,7 @@
                 max_iter=self.max_iter,
                 tol=self.tol,
                 solver="sag",
+                positive=self.positive,
                 random_state=self.random_state,
                 return_n_iter=True,
                 return_intercept=True,
@@ -682,6 +796,7 @@
                 max_iter=self.max_iter,
                 tol=self.tol,
                 solver=solver,
+                positive=self.positive,
                 random_state=self.random_state,
                 return_n_iter=True,
                 return_intercept=False,
@@ -744,12 +859,13 @@ class Ridge(MultiOutputMixin, RegressorMixin, _BaseRidge):
         Maximum number of iterations for conjugate gradient solver.
         For 'sparse_cg' and 'lsqr' solvers, the default value is determined
         by scipy.sparse.linalg. For 'sag' solver, the default value is 1000.
+        For 'lbfgs' solver, the default value is 15000.
 
     tol : float, default=1e-3
         Precision of the solution.
 
-    solver : {'auto', 'svd', 'cholesky', 'lsqr', 'sparse_cg', 'sag', 'saga'}, \
-        default='auto'
+    solver : {'auto', 'svd', 'cholesky', 'lsqr', 'sparse_cg', \
+        'sag', 'saga', 'lbfgs'}, default='auto'
         Solver to use in the computational routines:
 
         - 'auto' chooses the solver automatically based on the type of data.
@@ -777,15 +893,23 @@
           approximately the same scale. You can preprocess the data with a
           scaler from sklearn.preprocessing.
 
-        All last five solvers support both dense and sparse data. However, only
-        'sag' and 'sparse_cg' supports sparse input when `fit_intercept` is
-        True.
+        - 'lbfgs' uses L-BFGS-B algorithm implemented in
+          `scipy.optimize.minimize`. It can be used only when `positive`
+          is True.
+
+        All last six solvers support both dense and sparse data. However, only
+        'sag', 'sparse_cg', and 'lbfgs' support sparse input when `fit_intercept`
+        is True.
 
         .. versionadded:: 0.17
            Stochastic Average Gradient descent solver.
         .. versionadded:: 0.19
            SAGA solver.
 
+    positive : bool, default=False
+        When set to ``True``, forces the coefficients to be positive.
+        Only 'lbfgs' solver is supported in this case.
+
     random_state : int, RandomState instance, default=None
         Used when ``solver`` == 'sag' or 'saga' to shuffle the data.
         See :term:`Glossary <random_state>` for details.
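To exercise the sparse-input support claimed above, a small sketch with synthetic sparse data (assumes this commit); the 'lbfgs' path, selected via positive=True, fits an intercept on sparse X:

    import numpy as np
    from scipy import sparse
    from sklearn.linear_model import Ridge

    rng = np.random.RandomState(0)
    X = sparse.random(100, 10, density=0.3, format="csr", random_state=rng)
    y = rng.rand(100)

    model = Ridge(alpha=1.0, positive=True, fit_intercept=True).fit(X, y)
    assert (model.coef_ >= 0).all()
    print(model.intercept_)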
@@ -843,6 +967,7 @@ def __init__(
         max_iter=None,
         tol=1e-3,
         solver="auto",
+        positive=False,
         random_state=None,
     ):
         super().__init__(
@@ -853,6 +978,7 @@
             max_iter=max_iter,
             tol=tol,
             solver=solver,
+            positive=positive,
             random_state=random_state,
         )
 
@@ -932,8 +1058,8 @@ class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
         weights inversely proportional to class frequencies in the input data
         as ``n_samples / (n_classes * np.bincount(y))``.
 
-    solver : {'auto', 'svd', 'cholesky', 'lsqr', 'sparse_cg', 'sag', 'saga'}, \
-        default='auto'
+    solver : {'auto', 'svd', 'cholesky', 'lsqr', 'sparse_cg', \
+        'sag', 'saga', 'lbfgs'}, default='auto'
         Solver to use in the computational routines:
 
         - 'auto' chooses the solver automatically based on the type of data.
@@ -966,6 +1092,14 @@
         .. versionadded:: 0.19
            SAGA solver.
 
+        - 'lbfgs' uses L-BFGS-B algorithm implemented in
+          `scipy.optimize.minimize`. It can be used only when `positive`
+          is True.
+
+    positive : bool, default=False
+        When set to ``True``, forces the coefficients to be positive.
+        Only 'lbfgs' solver is supported in this case.
+
     random_state : int, RandomState instance, default=None
         Used when ``solver`` == 'sag' or 'saga' to shuffle the data.
         See :term:`Glossary <random_state>` for details.
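A matching sketch for the classifier (synthetic data; assumes this commit). positive=True constrains the weights of the decision function to be nonnegative:

    from sklearn.datasets import make_classification
    from sklearn.linear_model import RidgeClassifier

    X, y = make_classification(n_samples=100, n_features=5, random_state=0)

    clf = RidgeClassifier(positive=True).fit(X, y)
    assert (clf.coef_ >= 0).all()
    print(clf.score(X, y))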
@@ -1025,6 +1159,7 @@ def __init__(
         tol=1e-3,
         class_weight=None,
         solver="auto",
+        positive=False,
         random_state=None,
     ):
         super().__init__(
@@ -1035,6 +1170,7 @@
             max_iter=max_iter,
             tol=tol,
             solver=solver,
+            positive=positive,
             random_state=random_state,
         )
         self.class_weight = class_weight
