The fit performance of LinearRegression is sub-optimal #22855
Comments
Is there any insight on what sklearnex does for linear regression? The deepest level of the call stack I can get to is here: https://github.com/intel/scikit-learn-intelex/blob/59fee38028c53bed0bcc3f259880603a997020d1/daal4py/sklearn/linear_model/_linear.py#L51-L64 |
After that it calls daal4py, which is just Python bindings for DAAL (now oneDAL), which is itself a C++ library. So it's hard to follow the stack. It's possible to take a look at their code base https://github.com/oneapi-src/oneDAL but it's not reader-friendly :) |
In particular: https://github.com/oneapi-src/oneDAL/tree/master/cpp/daal/src/algorithms/linear_regression
EDIT: which is called by (based on a ...) this C++ call, which comes from the following Python wrapper. Note the use of the ... |
Indeed, one can get a good speedup (x13 on my machine) with Ridge(alpha=0). Arguably, ...

```python
import time

import numpy as np
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.datasets import make_regression


def measure(estimator, X, y, n_executions=10):
    times = []
    while len(times) < n_executions:
        t0 = time.perf_counter()
        estimator.fit(X, y)
        t1 = time.perf_counter()
        times.append(t1 - t0)
    return np.mean(times)


X, y = make_regression(n_samples=1_000_000, n_features=100, n_informative=10)

mean_time_linear_regression = measure(estimator=LinearRegression(), X=X, y=y)
mean_time_ridge = measure(estimator=Ridge(alpha=0), X=X, y=y)
print("speedup =", mean_time_linear_regression / mean_time_ridge)
```
|
We could add a |
I think we need tests for this: in particular, it would be great to call LinearRegression and Ridge on such a badly conditioned learning problem and check that a warning with an informative message is raised, suggesting a more stable solver by name (or that the switch happens automatically, with a warning message to inform the user). |
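A rough sketch of what such a test could look like (the warning type and message matched here are hypothetical; no such warning exists in scikit-learn at this point, so this describes the proposed behavior, not the current one):

```python
import numpy as np
import pytest

from sklearn.linear_model import LinearRegression, Ridge


def make_ill_conditioned_data(n_samples=1_000, n_features=5, seed=0):
    # Duplicated columns plus machine-precision noise yield a nearly singular design matrix.
    rng = np.random.RandomState(seed)
    y = rng.randn(n_samples)
    X = np.repeat(y.reshape(-1, 1), n_features, axis=1)
    X += rng.randn(*X.shape) * np.finfo(X.dtype).eps
    return X, y


@pytest.mark.parametrize("estimator", [LinearRegression(), Ridge(alpha=0)])
def test_warns_on_ill_conditioned_problem(estimator):
    X, y = make_ill_conditioned_data()
    # Hypothetical warning: the message is assumed to name a more stable solver.
    with pytest.warns(UserWarning, match="solver"):
        estimator.fit(X, y)
```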
I just played with this for a course that I will give tomorrow, and indeed our LinearRegression is slow. Given the following dataset:

```python
from sklearn.datasets import make_regression

X, y = make_regression(
    n_samples=100_000,
    n_features=1_000,
    noise=10,
)
```

I did a quick comparison:

```python
%%time
from sklearn.linear_model import LinearRegression

linear_regression = LinearRegression()
linear_regression.fit(X, y)
```

```
CPU times: user 1min 50s, sys: 3.84 s, total: 1min 54s
Wall time: 18.5 s
```

and then I thought about using the normal equation (here `X_with_dummy` is `X` with an extra column of ones for the intercept, and `minimize` below comes from `scipy.optimize`):

```python
%%time
coef = np.linalg.inv(X_with_dummy.T @ X_with_dummy) @ X_with_dummy.T @ y
```

```
CPU times: user 33.1 s, sys: 850 ms, total: 33.9 s
Wall time: 6.36 s
```

and finally using an L-BFGS-B solver:

```python
def loss(coef, X, y):
    y_estimated = X @ coef
    error = y_estimated - y
    # note: this factor was probably meant to be 1 / (2 * len(y)) to match grad below
    cost = (1 / 2 * len(y)) * np.sum(error ** 2)
    return cost


def grad(coef, X, y):
    y_estimated = X @ coef
    error = y_estimated - y
    gradient = (1 / len(y)) * X.T @ error
    return gradient


coef = np.random.randn(X_with_dummy.shape[1])
res = minimize(loss, coef, args=(X_with_dummy, y), method="L-BFGS-B", jac=grad)
coef = res.x
```

```
CPU times: user 6.84 s, sys: 3.73 s, total: 10.6 s
Wall time: 1.89 s
      fun: 495382014981.9635
 hess_inv: <1001x1001 LbfgsInvHessProduct with dtype=float64>
      jac: array([ 7.67078546e-07, -6.21855026e-07,  7.71316058e-07, ...,
        1.05001006e-06, -2.00795501e-06,  2.40024978e-08])
  message: 'CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL'
     nfev: 10
      nit: 7
     njev: 10
   status: 0
  success: True
        x: array([ 0.08592853, -0.03905589,  0.01284664, ...,  0.03827157,
        0.08447174, -0.03549184])
```

I recall a discussion in the book of @ageron regarding the complexity of the ... It might be nice to investigate as well the impact of the LAPACK driver, because I am not sure that they are all equivalent in different ...

NB: I hope that I did not mess up the gradient computation; it was done while catching a train, so be prepared for a stupid error.

NB2: I missed reading the above discussion, which already covers the inclusion of the different solvers. |
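On the LAPACK driver question: `scipy.linalg.lstsq` exposes the driver choice directly, so a quick comparison is possible. A minimal sketch reusing the same dataset shape as above ('gelsd' is SciPy's default; timings will depend on the BLAS/LAPACK build):

```python
import time

import numpy as np
from scipy.linalg import lstsq
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=100_000, n_features=1_000, noise=10, random_state=0)
X_with_dummy = np.hstack([X, np.ones((X.shape[0], 1))])  # column of ones for the intercept

for driver in ("gelsd", "gelsy", "gelss"):
    t0 = time.perf_counter()
    coef, residues, rank, singular_values = lstsq(X_with_dummy, y, lapack_driver=driver)
    print(f"{driver}: {time.perf_counter() - t0:.2f} s (rank={rank})")
```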
@glemaitre Could you replace ...? I would not go so far as to say ...

Actionable way forward ~~finish/pick up #17560~~: ...
|
It is what I found reading your comment afterwards :) Hoping that ... Regarding using the LAPACK ...:

```python
# %%
%%time
coef = np.linalg.solve(X_with_dummy.T @ X_with_dummy, X_with_dummy.T @ y)
```

```
CPU times: user 10.2 s, sys: 370 ms, total: 10.5 s
Wall time: 1.47 s
```

It is then as efficient as LBFGS. |
I ran a quick experiment with a well specified, over-complete model with duplicated columns + machine precision noise to probe the behavior of different solvers on near-degenerate problems:

```python
import numpy as np
from sklearn.linear_model import LinearRegression, Ridge
from time import perf_counter

rng = np.random.RandomState(42)

n_samples = 10_000_000
n_features = 2

y = rng.randn(n_samples)
X = np.repeat(y.reshape(-1, 1), n_features, axis=1)
X += rng.randn(*X.shape) * np.finfo(X.dtype).eps / 1e2


class PInvLinearRegression(LinearRegression):
    def fit(self, X, y):
        self.coef_ = np.linalg.pinv(X.T @ X) @ X.T @ y
        self.intercept_ = 0
        return self


models = [
    LinearRegression(fit_intercept=False),
    Ridge(fit_intercept=False, alpha=0, solver="lsqr"),
    Ridge(fit_intercept=False, alpha=0, solver="svd"),
    Ridge(fit_intercept=False, alpha=0, solver="cholesky"),
    PInvLinearRegression(),
]

print(f"n_samples: {n_samples}, n_features: {n_features}")
print()
for model in models:
    durations = []
    print(model)
    for _ in range(3):
        t0 = perf_counter()
        model.fit(X, y)
        durations.append(perf_counter() - t0)
    print(f"{model.coef_.max() - model.coef_.min() = }")
    print(f"duration: {np.mean(durations):.3f} ± {np.std(durations):.3f} s")
    print()
```

which outputs, for different values of ...
My conclusions:
Of course the pinv-based solution would be catastrophic from a memory usage point of view for large ... So I think we should at least expose the ... Then we might want to also add a pinv-based solver dedicated to the case where ... |
Note: the experiments above were conducted on an Apple M1 processor, so I cannot compare with the solver of ... |
By profiling this script with |
For |
As seen in the profiling report of the ..., a possible plan to fix the performance problem of LinearRegression would be to: ...
|
I am a great supporter of getting rid of ... Then for ... About the solvers, it's always the same story: ...
|
Thanks for the feedback. I guess we need to do some more systematic benchmarking (including various combinations: dense vs sparse, with or without sample weight, fit intercept or not, well specified / misspecified, well-conditioned / ill-conditioned). But in my experience (with ...), ... Since I also tried ... It seems that for biased problems, and with ..., ... One thing that I have not evaluated yet is the memory usage induced by the different solvers. |
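A rough sketch of what such a systematic benchmark could look like, here only varying the solver, the intercept and sample weights on dense data (the sparse, conditioning and misspecification axes would be added the same way; dataset sizes are placeholders):

```python
import itertools
import time

import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge

X, y = make_regression(n_samples=100_000, n_features=200, noise=10, random_state=0)
sample_weight = np.random.RandomState(0).uniform(0.5, 1.5, size=y.shape[0])

for solver, fit_intercept, sw in itertools.product(
    ["cholesky", "lsqr", "svd"], [True, False], [None, sample_weight]
):
    model = Ridge(alpha=1e-6, solver=solver, fit_intercept=fit_intercept)
    t0 = time.perf_counter()
    model.fit(X, y, sample_weight=sw)
    print(
        f"solver={solver:<8} fit_intercept={fit_intercept!s:<5} "
        f"weighted={sw is not None!s:<5} -> {time.perf_counter() - t0:.2f} s"
    )
```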
Then let's switch to lsqr or lsmr as default. |
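For reference, a minimal sketch of what calling the SciPy LSQR solver directly looks like on a dense problem (`scipy.sparse.linalg.lsqr` also accepts dense arrays); centering X and y beforehand plays the role of `fit_intercept=True`, and the dataset parameters are placeholders:

```python
import numpy as np
from scipy.sparse.linalg import lsqr
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression

X, y = make_regression(n_samples=1_000_000, n_features=100, noise=10, random_state=0)

# Center the data so that the intercept can be recovered afterwards.
X_offset, y_offset = X.mean(axis=0), y.mean()
coef = lsqr(X - X_offset, y - y_offset)[0]
intercept = y_offset - X_offset @ coef

# Compare against the current lstsq-based LinearRegression.
reference = LinearRegression().fit(X, y)
print("max |coef diff| =", np.max(np.abs(coef - reference.coef_)))
print("intercept diff  =", abs(intercept - reference.intercept_))
```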
Actually we should not get rid of |
@ogrisel The viztracer result for |
This should be fixed by #26207. I will re-run a bunch of viztracer profiles to check that this is actually the case. |
Let's reopen this issue because #26207 only fixed some of the performance problems caused by preprocessing/rescaling, but there are still more significant improvements to expect from enabling, and maybe selecting as default, a better solver such as ... |
Just following up: are there any updates? |
It seems that the performance of Linear Regression is sub-optimal when the number of samples is very large.
sklearn_benchmarks measures a speed-up of 48 compared to an optimized implementation from scikit-learn-intelex on a 1000000x100 dataset. For a given set of parameters and a given dataset, we compute the speed-up as time_scikit-learn / time_sklearnex. A speed-up of 48 means that sklearnex is 48 times faster than scikit-learn on the given dataset.

Profiling allows a more detailed analysis of the execution of the algorithm. We observe that most of the execution time is spent in the lstsq solver of scipy. The profiling reports of sklearn_benchmarks can be viewed with Perfetto UI.
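For a quick local check of where `LinearRegression.fit` spends its time on a dense problem, plain `cProfile` is enough (this is not the sklearn_benchmarks/Perfetto setup used for the report above; the dataset size is a placeholder):

```python
import cProfile
import pstats

from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression

X, y = make_regression(n_samples=200_000, n_features=100, noise=10, random_state=0)

with cProfile.Profile() as profiler:
    LinearRegression().fit(X, y)

# The scipy.linalg.lstsq call is expected to dominate the cumulative time.
pstats.Stats(profiler).sort_stats("cumulative").print_stats(15)
```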
It seems that the solver could be better chosen when the number of samples is very large. Perhaps Ridge's solver with a zero penalty could be chosen in this case. On the same dimensions, it shows better performance.
Speedups can be reproduced with the following code:
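The original snippet does not appear above; the following is a sketch of the comparison described, using the documented `patch_sklearn` entry point of scikit-learn-intelex (dataset parameters and the number of runs are placeholders):

```python
import time

import numpy as np
from sklearn.datasets import make_regression


def mean_fit_time(estimator, X, y, n_runs=3):
    durations = []
    for _ in range(n_runs):
        t0 = time.perf_counter()
        estimator.fit(X, y)
        durations.append(time.perf_counter() - t0)
    return np.mean(durations)


X, y = make_regression(n_samples=1_000_000, n_features=100, n_informative=10, random_state=0)

# Stock scikit-learn estimator.
from sklearn.linear_model import LinearRegression as StockLinearRegression

time_sklearn = mean_fit_time(StockLinearRegression(), X, y)

# sklearnex replaces supported estimators with oneDAL-backed implementations;
# estimators must be (re-)imported after the patch is applied.
from sklearnex import patch_sklearn

patch_sklearn()
from sklearn.linear_model import LinearRegression as PatchedLinearRegression

time_sklearnex = mean_fit_time(PatchedLinearRegression(), X, y)

print("speed-up =", time_sklearn / time_sklearnex)
```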