
The fit performance of LinearRegression is sub-optimal #22855


Open
mbatoul opened this issue Mar 15, 2022 · 24 comments · Fixed by #26207

@mbatoul (Contributor) commented Mar 15, 2022

It seems that the fit performance of LinearRegression is sub-optimal when the number of samples is very large.

sklearn_benchmarks measures a speedup of 48 compared to an optimized implementation from scikit-learn-intelex on a 1,000,000 x 100 dataset. For a given set of parameters and a given dataset, we compute the speed-up as time(scikit-learn) / time(sklearnex). A speed-up of 48 means that sklearnex is 48 times faster than scikit-learn on the given dataset.

Profiling allows a more detailed analysis of the execution of the algorithm. We observe that most of the execution time is spent in the lstsq solver of scipy.

[profiling report screenshot]

The profiling reports of sklearn_benchmarks can be viewed with Perfetto UI.

[benchmark environment information screenshot]

It seems that the solver could be chosen more appropriately when the number of samples is very large. Perhaps Ridge's solver with a zero penalty could be used in this case; on the same dimensions, it shows better performance.

The speedup can be reproduced with the following code. Environment setup:

conda create -n lr_perf -c conda-forge scikit-learn scikit-learn-intelex numpy jupyter
conda activate lr_perf

Benchmark script:

from sklearn.linear_model import LinearRegression as LinearRegressionSklearn
from sklearnex.linear_model import LinearRegression as LinearRegressionSklearnex
from sklearn.datasets import make_regression
import time
import numpy as np

X, y = make_regression(n_samples=1_000_000, n_features=100, n_informative=10)

def measure(estimator, X, y, n_executions=10):
    times = []
    while len(times) < n_executions:
        t0 = time.perf_counter()
        estimator.fit(X, y)
        t1 = time.perf_counter()
        times.append(t1 - t0)

    return np.mean(times)

mean_time_sklearn = measure(
    estimator=LinearRegressionSklearn(),
    X=X,
    y=y
)

mean_time_sklearnex = measure(
    estimator=LinearRegressionSklearnex(),
    X=X,
    y=y
)

speedup = mean_time_sklearn / mean_time_sklearnex
speedup
The github-actions bot added the Needs Triage (Issue requires triage) label on Mar 15, 2022.
@thomasjpfan (Member)

Is there any insight on what sklearnex does for linear regression? The deepest level of the call stack I can get to is here: https://github.com/intel/scikit-learn-intelex/blob/59fee38028c53bed0bcc3f259880603a997020d1/daal4py/sklearn/linear_model/_linear.py#L51-L64

@jeremiedbb (Member)

The deepest level of the call stack I can get to is here:

After that it calls daal4py, which is just Python bindings for DAAL (now oneDAL), which is itself a C++ library, so it's hard to follow the stack. It's possible to take a look at their code base https://github.com/oneapi-src/oneDAL but it's not reader-friendly :)

@ogrisel (Member) commented Mar 15, 2022

In particular: https://github.com/oneapi-src/oneDAL/tree/master/cpp/daal/src/algorithms/linear_regression

EDIT: based on a py-spy --native trace output by Julien, the kernel that gets called is:

daal::algorithms::linear_model::normal_equations::training::internal::UpdateKernel<double, (daal::CpuType)6>::compute

This C++ call comes from the following Python wrapper:

https://github.com/intel/scikit-learn-intelex/blob/f542d87b3d131fd839a2d4bcb86c849d77aae2b8/daal4py/sklearn/linear_model/_linear.py#L60

Note the use of the "defaultDense" and "qrDense" method names.

@TomDLT (Member) commented Mar 16, 2022

Perhaps Ridge's solver with a zero penalty could be used in this case; on the same dimensions, it shows better performance.

Indeed, one can get a good speedup (x13 on my machine) with Ridge(alpha=0).

Arguably, Ridge should always be preferred over LinearRegression anyway.

import time
import numpy as np
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.datasets import make_regression

def measure(estimator, X, y, n_executions=10):
    times = []
    while len(times) < n_executions:
        t0 = time.perf_counter()
        estimator.fit(X, y)
        t1 = time.perf_counter()
        times.append(t1 - t0)
    return np.mean(times)

X, y = make_regression(n_samples=1_000_000, n_features=100, n_informative=10)

mean_time_linear_regression = measure(estimator=LinearRegression(), X=X, y=y)
mean_time_ridge = measure(estimator=Ridge(alpha=0), X=X, y=y)
print("speedup =", mean_time_linear_regression / mean_time_ridge)

@lorentzenchr (Member)

Ridge(solver="auto") is the same as Ridge(solver="cholesky"), which means solving the normal equations by a (relatively cheap) Cholesky (LDL) decomposition. For long data (n_samples >> n_features), this is, of course, faster than scipy.linalg.lstsq as used in LinearRegression.
The great advantage of lstsq is its stability w.r.t. badly conditioned X (e.g. singular).

We could add a solver argument for LinearRegression.
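
For illustration, here is a minimal sketch (not scikit-learn's actual implementation) of the two approaches being contrasted: solving the normal equations with a Cholesky factorization versus calling scipy.linalg.lstsq on X directly.

# Minimal sketch of the two approaches discussed above (not the actual
# scikit-learn code): normal equations via Cholesky vs. scipy.linalg.lstsq.
import numpy as np
from scipy import linalg
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=100_000, n_features=100, random_state=0)

# Normal equations + Cholesky: cheap for n_samples >> n_features, but it
# works on X.T @ X, which squares the condition number of X.
gram = X.T @ X
c, low = linalg.cho_factor(gram)
coef_cholesky = linalg.cho_solve((c, low), X.T @ y)

# lstsq works on X directly (SVD-based LAPACK driver by default): slower
# here, but more robust when X is badly conditioned or singular.
coef_lstsq, *_ = linalg.lstsq(X, y)

print(np.max(np.abs(coef_cholesky - coef_lstsq)))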

@ogrisel
Copy link
Member

ogrisel commented Mar 16, 2022

The great advantage of lstsq is its stability w.r.t. badly conditioned X (e.g. singular).

I think we need tests for this: in particular, it would be great to call LinearRegression and Ridge on such badly conditioned learning problems and check that a warning with an informative message, suggesting a more stable solver by name, is raised (or that the switch happens automatically with a warning message to inform the user).
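
A hypothetical sketch of what such a test could look like (the assertion encodes the desired minimum norm behaviour; as the experiments further down show, the current solvers do not always satisfy it):

# Hypothetical test sketch, only illustrating the kind of check discussed
# above; the asserted behaviour is the desired one, not necessarily current.
import numpy as np
from sklearn.linear_model import LinearRegression, Ridge


def test_solvers_on_badly_conditioned_X():
    rng = np.random.RandomState(0)
    n_samples, n_features = 1_000, 10
    y = rng.randn(n_samples)
    # n_features near-duplicates of y -> (near) singular X.T @ X
    X = np.repeat(y.reshape(-1, 1), n_features, axis=1)
    X += rng.randn(*X.shape) * np.finfo(X.dtype).eps

    # the minimum norm solution spreads the weight evenly over the columns
    expected = np.full(n_features, 1 / n_features)
    for est in [
        LinearRegression(fit_intercept=False),
        Ridge(alpha=0, fit_intercept=False, solver="lsqr"),
    ]:
        coef = est.fit(X, y).coef_
        assert np.allclose(coef, expected, atol=1e-6)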

@glemaitre (Member) commented Mar 17, 2022

I just played with this for a course that I will give tomorrow, and indeed our LinearRegression is terrible.

Given the following dataset:

import numpy as np
from sklearn.datasets import make_regression

X, y = make_regression(
    n_samples=100_000,
    n_features=1_000,
    noise=10,
)
# Design matrix with a constant column of ones appended to fit the intercept
# (used by the normal-equations and L-BFGS-B variants below; assumed here,
# since the original snippet did not show its definition).
X_with_dummy = np.hstack([X, np.ones((X.shape[0], 1))])

I did a quick comparison:

%%time
from sklearn.linear_model import LinearRegression

linear_regression = LinearRegression()
linear_regression.fit(X, y)
CPU times: user 1min 50s, sys: 3.84 s, total: 1min 54s
Wall time: 18.5 s

and then I thought about using the normal equations:

%%time
coef = np.linalg.inv(X_with_dummy.T @ X_with_dummy) @ X_with_dummy.T @ y
CPU times: user 33.1 s, sys: 850 ms, total: 33.9 s
Wall time: 6.36 s

and finally using an L-BFGS-B solver:

from scipy.optimize import minimize


def loss(coef, X, y):
    y_estimated = X @ coef
    error = y_estimated - y
    # scale by 1 / (2 * n_samples) to be consistent with the gradient below
    cost = (1 / (2 * len(y))) * np.sum(error ** 2)
    return cost


def grad(coef, X, y):
    y_estimated = X @ coef
    error = y_estimated - y
    gradient = (1 / len(y)) * X.T @ error
    return gradient


coef = np.random.randn(X_with_dummy.shape[1])
res = minimize(loss, coef, args=(X_with_dummy, y), method="L-BFGS-B", jac=grad)
coef = res.x
CPU times: user 6.84 s, sys: 3.73 s, total: 10.6 s
Wall time: 1.89 s

fun: 495382014981.9635
 hess_inv: <1001x1001 LbfgsInvHessProduct with dtype=float64>
      jac: array([ 7.67078546e-07, -6.21855026e-07,  7.71316058e-07, ...,
        1.05001006e-06, -2.00795501e-06,  2.40024978e-08])
  message: 'CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL'
     nfev: 10
      nit: 7
     njev: 10
   status: 0
  success: True
        x: array([ 0.08592853, -0.03905589,  0.01284664, ...,  0.03827157,
        0.08447174, -0.03549184])

I recall a discussion in @ageron's book regarding the complexity of lstsq vs. gradient descent: https://www.oreilly.com/library/view/hands-on-machine-learning/9781491962282/ch04.html

It might also be nice to investigate the impact of the LAPACK driver, because I am not sure that they are all equivalent in the different n_samples/n_features regimes.

NB: I hope that I did not mess up the gradient computation; it was done while catching a train, so be prepared for a stupid error.

NB2: I missed the above discussion, which already covers the inclusion of different solvers.

@lorentzenchr (Member) commented Mar 17, 2022

@glemaitre Could you replace np.linalg.inv by np.linalg.solve?
Next thought: lbfgs is a good reference, but dedicated linalg routines should (we expect them to) always be faster and more robust for OLS-like problems.

I would not go so far as to say LinearRegression is terrible. It is always a tradeoff between speed and robustness and this one is clearly on the more/most robust side.

Actionable way forward ~ finish/pick up #17560 ~:

  • LinearRegression could be made a special case of Ridge.
  • For alpha=0, we might consider changing Ridge's svd solver to use lstsq instead of our own SVD-based implementation. This way, both classes would be 100% identical.
  • Tests for singular X

@glemaitre (Member)

I would not go so far as to say LinearRegression is terrible.

That is what I found when reading your comment afterwards :) Hoping that LinearRegression can forgive me for this comment :)

Regarding using the LAPACK _gesv with np.linalg.solve:

# %%
%%time
coef = np.linalg.solve(X_with_dummy.T @ X_with_dummy, X_with_dummy.T @ y)
CPU times: user 10.2 s, sys: 370 ms, total: 10.5 s
Wall time: 1.47 s

It is then as efficient as L-BFGS.

@ogrisel (Member) commented Mar 10, 2023

The great advantage of lstsq is its stability w.r.t. badly conditioned X (e.g. singular).

I think we need tests for this: in particular, it would be great to call LinearRegression and Ridge on such badly conditioned learning problems and check that a warning with an informative message, suggesting a more stable solver by name, is raised (or that the switch happens automatically with a warning message to inform the user).

I ran a quick experiment with a well-specified, over-complete model with duplicated columns plus machine-precision noise to probe the behavior of different solvers on near-degenerate problems:

import numpy as np
from sklearn.linear_model import LinearRegression, Ridge
from time import perf_counter

rng = np.random.RandomState(42)
n_samples = 10_000_000
n_features = 2
y = rng.randn(n_samples)
X = np.repeat(y.reshape(-1, 1), n_features, axis=1)
X += rng.randn(*X.shape) * np.finfo(X.dtype).eps / 1e2


class PInvLinearRegression(LinearRegression):

    def fit(self, X, y):
        self.coef_ = np.linalg.pinv(X.T @ X) @ X.T @ y
        self.intercept_ = 0
        return self


models = [
    LinearRegression(fit_intercept=False),
    Ridge(fit_intercept=False, alpha=0, solver="lsqr"),
    Ridge(fit_intercept=False, alpha=0, solver="svd"),
    Ridge(fit_intercept=False, alpha=0, solver="cholesky"),
    PInvLinearRegression(),
]


print(f"n_samples: {n_samples}, n_features: {n_features}")
print()


for model in models:
    durations = []
    print(model)
    for _ in range(3):
        t0 = perf_counter()
        model.fit(X, y)
        durations.append(perf_counter() - t0)
    print(f"{model.coef_.max() - model.coef_.min() = }")
    print(f"duration: {np.mean(durations):.3f} ± {np.std(durations):.3f} s")
    print()

which outputs for different values of n_samples and n_features:

n_samples: 10000000, n_features: 2

LinearRegression(fit_intercept=False)
model.coef_.max() - model.coef_.min() = 0.9999999165424127
duration: 0.375 ± 0.007 s

Ridge(alpha=0, fit_intercept=False, solver='lsqr')
model.coef_.max() - model.coef_.min() = 0.0
duration: 0.119 ± 0.003 s

Ridge(alpha=0, fit_intercept=False, solver='svd')
model.coef_.max() - model.coef_.min() = 2.979069482632696
duration: 0.163 ± 0.002 s

Ridge(alpha=0, fit_intercept=False, solver='cholesky')
model.coef_.max() - model.coef_.min() = 2.979069482632696
duration: 0.205 ± 0.001 s

PInvLinearRegression()
model.coef_.max() - model.coef_.min() = 5.551115123125783e-17
duration: 0.047 ± 0.000 s

n_samples: 1000000, n_features: 100

LinearRegression(fit_intercept=False)
model.coef_.max() - model.coef_.min() = 6.938893903907228e-18
duration: 4.316 ± 0.134 s

Ridge(alpha=0, fit_intercept=False, solver='lsqr')
model.coef_.max() - model.coef_.min() = 0.0
duration: 0.257 ± 0.016 s

Ridge(alpha=0, fit_intercept=False, solver='svd')
model.coef_.max() - model.coef_.min() = 20.90600371073267
duration: 7.155 ± 0.320 s

Ridge(alpha=0, fit_intercept=False, solver='cholesky')
model.coef_.max() - model.coef_.min() = 20.90600371073267
duration: 7.013 ± 0.083 s

PInvLinearRegression()
model.coef_.max() - model.coef_.min() = 1.734723475976807e-17
duration: 0.501 ± 0.038 s

My conclusions:

  • the current solver of LinearRegression, scipy.linalg.lstsq, is indeed one of the slowest for these tall and narrow problems;
  • it is often more numerically stable than 'svd' and 'cholesky', but it's not always that stable: depending on the number of features, the random seed, the magnitude of the near-machine-level perturbations and the number of samples, it sometimes fails to recover the minimum norm solution. I think this is surprising;
  • 'svd' and 'cholesky' are not numerically stable, but that was expected (at least for 'cholesky');
  • the scipy.sparse.linalg.lsqr solver of Ridge is both very fast and numerically stable: in my experiments it would always recover the minimum norm solution;
  • my naive textbook pinv-based solution is even faster and also seems numerically stable.

Of course, the pinv-based solution would be catastrophic from a memory usage point of view for large n_features. I am not sure what the space complexities of the other solvers are.

So I think we should at least expose lsqr in LinearRegression (by reusing the _solver_lsqr private function from the Ridge code base with alpha=0). We should further conduct a more systematic review of the performance profiles (both time and space), the numerical stability, and the ability to recover the minimum norm solution for different n_samples, n_features and condition numbers of X.T @ X or X @ X.T.

Then we might also want to add a pinv-based solver dedicated to the case where n_features >> n_samples, but I am not sure how to implement fit_intercept=True and sample weights efficiently.

@ogrisel (Member) commented Mar 10, 2023

Note: the experiments above were conducted on an Apple M1 processor, so I cannot compare with the solver of scikit-learn-intelex, but for the 1M x 100 problem size I only observe a 16x speed-up (on a machine with 8 cores) instead of the 48x speed-up reported by @mbatoul's benchmark.

@ogrisel (Member) commented Mar 10, 2023

By profiling this script with viztracer I discovered that the cholesky solver always silently falls back to svd for this problem (hence the nearly identical results for both solvers). This is OK, but I think we should issue a warning to the user. Furthermore, we might want to fall back to lsqr instead of svd, since svd does not typically converge to the minimum norm solution on singular problems.

@ogrisel (Member) commented Mar 10, 2023

Also found on the profile report: close to half of the time of our LinearRegression.fit is spent rescaling the data even though I did not pass sample weights:

[viztracer profile screenshot]

We should special-case sample_weight=None to avoid this overhead.
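
For illustration, a standalone sketch of the kind of special-casing meant here (the helper name _maybe_rescale is made up for this sketch; it is not scikit-learn's actual _rescale_data):

# Hypothetical sketch: only pay the rescaling cost when sample weights are
# actually provided; with sample_weight=None, return the inputs untouched.
import numpy as np


def _maybe_rescale(X, y, sample_weight=None):
    """Multiply rows of X and y by sqrt(sample_weight), only if needed."""
    if sample_weight is None:
        # fast path: no copy, no rescaling pass over the data
        return X, y
    sw_sqrt = np.sqrt(sample_weight)
    return X * sw_sqrt[:, np.newaxis], y * sw_sqrt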

@ogrisel (Member) commented Mar 10, 2023

For comparison, here is the zoomed-in trace of a fit call with lsqr:

[viztracer profile screenshot]

There is still some significant overhead in _preprocess_data, but at least it is less than the time spent in the lsqr call itself.

@ogrisel (Member) commented Mar 10, 2023

For LinearRegression, I think we should only offer solvers that converge to the minimum norm solution in the case of (near) multicollinear features. For the l2-penalized variant (Ridge), we should document the list of solvers that are expected to recover the same solution when alpha=0. In the case of Ridge, this problem is tracked under:

@ogrisel (Member) commented Mar 11, 2023

As seen in the profiling report of the lsqr case in #22855 (comment), the _preprocess_data function used in Ridge induces a significant overhead. Furthermore, as confirmed in #22947 (comment), this also prevents the model from converging to the minimum norm solution for singular problems (e.g. when n_features > n_samples) when alpha is 0.

So a possible plan to fix the performance problem of LinearRegression would be to:

  • stop relying on scipy.linalg.lstsq and _preprocess_data,
  • instead use a linear operator with scipy.sparse.linalg.lsqr (or scipy.sparse.linalg.lsmr), both for dense and sparse data, with or without sample_weight, with or without fit_intercept, always converging to the minimum norm solution, and hopefully with little extra computational overhead and no memory overhead (no input data copy); see the sketch below.
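
A rough sketch of the linear-operator idea, for illustration only (dense numpy X, hypothetical fit_ols_lsqr helper; note that this naive formulation puts the intercept column inside the operator, so on singular problems the minimum norm solution would also involve the intercept, unlike the centering approach discussed further down):

# Hypothetical sketch of an lsqr-based OLS fit through a LinearOperator that
# never copies X: sample weights and the intercept column are applied
# implicitly in matvec/rmatvec.
import numpy as np
from scipy.sparse.linalg import LinearOperator, lsqr


def fit_ols_lsqr(X, y, sample_weight=None, fit_intercept=True):
    n_samples, n_features = X.shape
    sw_sqrt = None if sample_weight is None else np.sqrt(sample_weight)

    def matvec(v):
        # represents [X, 1] @ v, optionally row-scaled by sqrt(sample_weight)
        out = X @ v[:n_features]
        if fit_intercept:
            out = out + v[n_features]
        return out if sw_sqrt is None else out * sw_sqrt

    def rmatvec(u):
        # adjoint of matvec: [X, 1].T @ (sqrt(sample_weight) * u)
        u = u if sw_sqrt is None else u * sw_sqrt
        out = X.T @ u
        if fit_intercept:
            out = np.append(out, u.sum())
        return out

    n_cols = n_features + (1 if fit_intercept else 0)
    op = LinearOperator((n_samples, n_cols), matvec=matvec, rmatvec=rmatvec)
    target = y if sw_sqrt is None else y * sw_sqrt
    sol = lsqr(op, target, atol=1e-12, btol=1e-12)[0]
    coef = sol[:n_features]
    intercept = sol[n_features] if fit_intercept else 0.0
    return coef, intercept


# quick sanity check on synthetic data
from sklearn.datasets import make_regression
X, y = make_regression(n_samples=100_000, n_features=100, bias=2.0, random_state=0)
coef, intercept = fit_ols_lsqr(X, y)
print(intercept)  # close to the true bias of 2.0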

@lorentzenchr (Member)

I am a great supporter of getting rid of _preprocess_data which often took the fun out of my PRs in the past. I am, however, a bit reluctant to agree to stop using lstsq.

Then for LinearRegression, I guess we want to keep it for educational reasons and not "merge" it with, e.g., Ridge(alpha=0).

About the solvers, it's always the same story:

  • "cholesky": Solving the normal equations is fastest for n_samples >> n_features. But it is not numerically stable for ill-conditioned problems. Under the hood, it uses linalg.solve which should be preferred to linalg.pinv for efficiency. As to stability, both solve and pinv work with X.T @ X and therefore square the condition number of X which makes them less stable for such problems.
  • lstsq is good for n_samples ~ n_features and should be the best for ill-conditioned problems in general, because it works with X and not with X.T @ X. What it does under the hood depends on the driver; usually an SVD is involved in some fashion. It is always a bit tricky to include the intercept with this solution, usually done with hstack and therefore doubling the memory usage (see the sketch after this list).
  • scipy.sparse.linalg.lsqr (or even better lsmr) is just great as an iterative solver and replacement for lstsq, and it only uses matrix-vector products. But it means that we get additional hyperparameters like stopping criteria and the number of iterations; that's the downside.
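
A small illustration of the intercept handling mentioned in the lstsq item above (the np.hstack makes a full copy of X, roughly doubling peak memory for dense inputs):

# Sketch: fitting the intercept with lstsq by augmenting X with a column of
# ones; note that np.hstack copies X.
import numpy as np
from scipy import linalg
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=10_000, n_features=20, bias=3.0, random_state=0)

X_aug = np.hstack([X, np.ones((X.shape[0], 1))])  # full copy of X + ones column
sol, *_ = linalg.lstsq(X_aug, y)
coef, intercept = sol[:-1], sol[-1]
print(intercept)  # close to the true bias of 3.0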

@ogrisel (Member) commented Mar 12, 2023

Thanks for the feedback. I guess we need to do some more systematic benchmarking (including various combinations of dense vs. sparse, with or without sample weights, with or without intercept, on well-specified / misspecified and well-conditioned / ill-conditioned problems).

But in my experience (with make_regression for various n_samples / n_features, with or without biased data, ...), I found that lsqr and lsmr with tol=1e-12 are almost always faster than cholesky (sometimes both have similar speed, and I never saw a case with cholesky being more than 2x faster). The mean_squared_error values are the same to 12 significant digits.

Since lsqr and lsmr are also numerically stable (and recover the minimum norm solution in case of singular X.T @ X), I think either of them would be a fool-proof default, at least for LinearRegression. lsmr is sometimes a tiny bit faster than lsqr, but not always.

I also tried lsqr and lsmr with tol=1e-5 and they are even faster (typically around 2-3x) but then this is another hyper-parameter to adjust and we probably don't want to expose that for LinearRegression.

It seems that for biased problems, and with fit_intercept=True, the y_mean centering in _preprocess_data can make the convergence a tiny bit faster for lsqr. So we might want to keep it, but I think we would then converge to the minimum np.linalg.norm(coef_) solution instead of the minimum np.linalg.norm(np.concatenate([coef_, [intercept_]])) solution. I think either is fine as long as we document and properly test those behaviors to avoid surprises in the future.

One thing that I have not evaluated yet is the memory usage induced by the different solvers.

@lorentzenchr (Member)

Then let's switch to lsqr or lsmr as the default.

@ogrisel (Member) commented Mar 22, 2023

I am a great supporter of getting rid of _preprocess_data which often took the fun out of my PRs in the past. I am, however, a bit reluctant to agree to stop using lstsq.

Actually, we should not get rid of _preprocess_data and _set_intercept. They are the valid way to compute the least norm estimator when we have an intercept that should not contribute to the norm, so as to naturally extend to penalized models that do not penalize the intercept. I documented this in more detail in the following PR: #25948 (draft).
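
For reference, a minimal sketch of the centering trick that _preprocess_data / _set_intercept implement (simplified to dense X without sample weights; note that the centering itself copies X, which is part of the overhead discussed above):

# Sketch of intercept handling by centering: solve on centered data so the
# minimum norm solution only involves coef_, then recover the intercept.
import numpy as np
from scipy.sparse.linalg import lsqr
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=10_000, n_features=20, bias=3.0, random_state=0)

X_mean, y_mean = X.mean(axis=0), y.mean()
coef = lsqr(X - X_mean, y - y_mean, atol=1e-12, btol=1e-12)[0]
intercept = y_mean - X_mean @ coef
print(intercept)  # close to the true bias of 3.0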

@lorentzenchr (Member)

Also found on the profile report: close to half of the time of our LinearRegression.fit is spent rescaling the data even though I did not pass sample weights:

@ogrisel The viztracer result for LinearRegression.fit seems strange because it calls tocsr inside _rescale_data!!! I'm becoming more and more of an opponent of our data rescaling.

@ogrisel (Member) commented Apr 18, 2023

The viztracer result for LinearRegression.fit seems strange because it calls tocsr inside _rescale_data!!!

This should be fixed by #26207. I will re-run a bunch of viztracer profiles to check that this is actually the case.

@ogrisel (Member) commented Apr 19, 2023

Let's reopen this issue: #26207 only fixed some of the performance problems caused by preprocessing/rescaling, but there are still more significant improvements to expect from enabling, and maybe selecting as the default, a better solver such as scipy.sparse.linalg.lsqr instead of the SVD/LAPACK-based scipy.linalg.lstsq solver.

@AhmedThahir

Just following up: are there any updates?

Projects
Status: Discussion
9 participants