[MRG] LinearRegression Optimizations #17560
Changes from all commits: 3ea2b72, bf0ed78, d3f7b74, bb40fd2, a9579c2, c9ac69e, 3fe0996, 75cde9b, a1740b3, 0283e52, 3931d91, f8baf27, b437ab1, 5c61941, be3224e, 8e6545b, aafd093, 6fc98c9, 76d3994, 955c0d9
@@ -61,6 +61,12 @@ example, when data are collected without an experimental design.

* :ref:`sphx_glr_auto_examples_linear_model_plot_ols.py`

Ordinary Least Squares uses a Singular Value Decomposition (SVD) based
approach. The ``LinearRegression`` class has an additional, optional
``solver`` parameter, which, if set to ``"cholesky"``, uses the Cholesky
factorization instead. See `these notes
<https://www.cs.ubc.ca/~schmidtm/Courses/540-F14/leastSquares.pdf>`_ for a
discussion of the tradeoffs.

Ordinary Least Squares Complexity
---------------------------------

Review comments:

- It might be better to place references under a …
- Correct web link syntax at a minimum: the reST external link needs a trailing underscore (``>`_`` instead of ``>``).
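The tradeoff the added paragraph describes can be sketched outside scikit-learn with plain SciPy: on a well-conditioned problem, the SVD-based ``lstsq`` path and a Cholesky solve of the normal equations recover the same coefficients. This is a minimal illustration, not the PR's implementation:

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 5))
y = X @ np.arange(1.0, 6.0) + 0.01 * rng.standard_normal(100)

# SVD-based least squares, as used by the default "lstsq" solver.
coef_lstsq, _, _, _ = linalg.lstsq(X, y)

# Cholesky route: solve the normal equations X^T X w = X^T y.
c, low = linalg.cho_factor(X.T @ X)
coef_chol = linalg.cho_solve((c, low), X.T @ y)

print(np.allclose(coef_lstsq, coef_chol))  # True on well-conditioned data
```

The two routes diverge only when X is ill-conditioned, which is exactly the tradeoff the linked notes discuss.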
@@ -71,6 +77,16 @@ this method has a cost of
:math:`O(n_{\text{samples}} n_{\text{features}}^2)`, assuming that
:math:`n_{\text{samples}} \geq n_{\text{features}}`.

Cholesky Complexity
-------------------

The Cholesky solution is computed using the Cholesky factorization of
X. If X is a matrix of shape `(n_samples, n_features)`, this method has
a cost of :math:`O(n_{\text{samples}} n_{\text{features}}^2)` to form
:math:`X^{\intercal} X` and :math:`O(n_{\text{features}}^3)` to run the
solver.

Review comments (lines +83 to +84):

- The cholesky solver solves the normal equations by a Cholesky decomposition of X'X or XX', whichever has the smaller dimension. That is why the condition number of the least-squares problem is squared and the numerical solution can become more unstable compared to approaches that use a decomposition of X alone, like …
- I think it would be interesting to expand the doc to include this remark, possibly as a new paragraph.
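The reviewer's stability point can be checked numerically: in the 2-norm, the singular values of :math:`X^{\intercal} X` are the squares of those of X, so forming the Gram matrix squares the condition number. A small demonstration, not part of the PR:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 4))
# Make one column nearly collinear with another to worsen the conditioning.
X[:, 3] = X[:, 0] + 1e-4 * rng.standard_normal(50)

kappa_X = np.linalg.cond(X)           # 2-norm condition number of X
kappa_gram = np.linalg.cond(X.T @ X)  # condition number of the Gram matrix

# Forming the normal equations squares the condition number.
print(np.isclose(kappa_gram, kappa_X ** 2, rtol=1e-3))  # True
```

This is why an SVD of X alone (``lstsq``) stays accurate on problems where the Cholesky route loses roughly twice as many digits.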
.. _ridge_regression:

Ridge regression and classification
@@ -26,6 +26,7 @@

from ..base import (BaseEstimator, ClassifierMixin, RegressorMixin,
                    MultiOutputMixin)
from ._ridge_solvers import _cholesky_helper
from ..utils import check_array
from ..utils.validation import FLOAT_DTYPES
from ..utils.validation import _deprecate_positional_args
@@ -419,6 +420,13 @@ class LinearRegression(MultiOutputMixin, RegressorMixin, LinearModel):
        ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
        for more details.

    solver : str, default="lstsq"
        The solver to use. ``"lstsq"`` uses an SVD-based least-squares
        approach, by calling ``scipy.linalg.lstsq``. ``"cholesky"`` uses the
        Cholesky decomposition. If X is singular, then ``"cholesky"`` will
        instead use an SVD-based solver. ``"cholesky"`` does not support `X`
        matrices which are both singular and sparse.

    Attributes
    ----------
    coef_ : array of shape (n_features, ) or (n_targets, n_features)

Review comment (lines +426 to +428): I see this is the behavior for …
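The fallback behaviour the docstring describes (Cholesky failing on singular ``X``, with an SVD solver taking over) can be sketched with SciPy alone. The code below is illustrative and is not the PR's ``_cholesky_helper``:

```python
import numpy as np
from scipy import linalg

# A singular design matrix: the duplicated column makes X^T X rank-deficient.
X = np.ones((10, 2))
y = np.arange(10.0)

try:
    # Cholesky requires a positive-definite Gram matrix, so this raises.
    c, low = linalg.cho_factor(X.T @ X)
    coef = linalg.cho_solve((c, low), X.T @ y)
except linalg.LinAlgError:
    # Fall back to SVD-based least squares (minimum-norm solution).
    coef, *_ = linalg.lstsq(X, y)

print(coef)  # approximately [2.25, 2.25]: the minimum-norm fit of mean(y)
```

``scipy.linalg.cho_factor`` raises ``LinAlgError`` whenever the Gram matrix is not positive definite, which is what makes this dispatch pattern possible.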
@@ -472,11 +480,14 @@ class LinearRegression(MultiOutputMixin, RegressorMixin, LinearModel):
    """
    @_deprecate_positional_args
    def __init__(self, *, fit_intercept=True, normalize=False, copy_X=True,
-                 n_jobs=None):
+                 n_jobs=None, solver="lstsq"):
        self.fit_intercept = fit_intercept
        self.normalize = normalize
        self.copy_X = copy_X
        self.n_jobs = n_jobs
+        if solver not in ["lstsq", "cholesky"]:
+            raise ValueError("Solver must be either `lstsq` or `cholesky`")
+        self.solver = solver

Review comments (lines +488 to +489):

- By convention we validate hyperparameters in …
- Furthermore it would be great to also expand the error message to report the observed invalid value: ``raise ValueError(f'Solver must be either "lstsq" or "cholesky", got: {repr(solver)}')``
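The two review comments above can be combined into one sketch: keep ``__init__`` assignment-only and validate at fit time, reporting the offending value. The class name and body are hypothetical, not the PR's code:

```python
class LinearRegressionSketch:
    """Illustrative skeleton only; not scikit-learn's implementation."""

    def __init__(self, solver="lstsq"):
        # By convention, __init__ only stores hyperparameters unvalidated,
        # so that clone() and set_params() keep working.
        self.solver = solver

    def fit(self, X, y):
        # Validate at fit time and report the observed invalid value.
        if self.solver not in ("lstsq", "cholesky"):
            raise ValueError(f'Solver must be either "lstsq" or "cholesky", '
                             f'got: {self.solver!r}')
        # ... actual fitting logic would go here ...
        return self
```

With this layout, ``LinearRegressionSketch(solver="bogus")`` constructs without error; the ``ValueError`` only surfaces when ``fit`` is called.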
||||||
def fit(self, X, y, sample_weight=None): | ||||||
""" | ||||||
|
@@ -518,7 +529,29 @@ def fit(self, X, y, sample_weight=None): | |||||
# Sample weight can be implemented via a simple rescaling. | ||||||
X, y = _rescale_data(X, y, sample_weight) | ||||||
|
||||||
-        if sp.issparse(X):
+        if self.solver == "cholesky":
+            n_samples, n_features = X.shape
+
+            ravel = False
+            if y.ndim == 1:
+                y = y.reshape(-1, 1)
+                ravel = True
+            n_samples_, n_targets = y.shape
+            alpha = np.array([0], dtype=X.dtype)
+
+            if n_targets > 1:
+                alpha = np.repeat(alpha, n_targets)
+
+            try:
+                self.coef_ = _cholesky_helper(X, y, alpha, n_features,
+                                              n_samples)
+            except TypeError:
+                raise TypeError('X matrix is singular and sparse, and not '
+                                'supported by the Cholesky solver.')
+
+            if ravel:
+                # When y was passed as a 1d-array, we flatten the coefficients.
+                self.coef_ = self.coef_.ravel()
+        elif sp.issparse(X):
            X_offset_scale = X_offset / X_scale

            def matvec(b):

Review comment (on ``ravel = False``): This would be more explicit. Suggested change: ``ravel = y.ndim == 1``.
@@ -544,7 +577,7 @@ def rmatvec(b):
            self._residues = np.vstack([out[3] for out in outs])
        else:
            self.coef_, self._residues, self.rank_, self.singular_ = \
-                linalg.lstsq(X, y)
+                linalg.lstsq(X, y, check_finite=False)
            self.coef_ = self.coef_.T

        if y.ndim == 1:
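The one-line change above passes ``check_finite=False`` to ``scipy.linalg.lstsq``, which skips SciPy's NaN/inf scan of the inputs and saves one pass over the data; it is safe here only because scikit-learn validates ``X`` and ``y`` earlier in ``fit``. A standalone illustration:

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 10))
y = X.sum(axis=1)  # the exact solution is a coefficient of 1.0 per feature

# check_finite=False skips one validation pass over the data; the caller
# must guarantee there are no NaNs or infs in X and y.
coef, residues, rank, sv = linalg.lstsq(X, y, check_finite=False)
print(rank)  # 10: X has full column rank
```

If the inputs could contain non-finite values, the flag must stay at its default ``True``, since LAPACK may otherwise crash or return garbage.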