ENH avoid futile recomputation of R_sum in sparse_enet_coordinate_descent #31387
Conversation
ENH avoid futile recomputation of R_sum in sparse_enet_coordinate_descent * R_sum is np.sum(residual) and won't change by a coordinate update if X_mean is provided.
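In equation form (a restatement of the claim above, not notation taken from the PR itself): a coordinate update $w_j \leftarrow w_j + \delta$ on the implicitly centered data shifts the residual along a centered column, whose entries sum to zero:

```math
R^{\text{new}} = R - \delta\,(X_{:,j} - \bar{X}_j \mathbf{1}),
\qquad
\sum_{i=1}^{n} R_i^{\text{new}} = \sum_{i=1}^{n} R_i - \delta \sum_{i=1}^{n} \left(X_{ij} - \bar{X}_j\right) = \sum_{i=1}^{n} R_i .
```

Since `np.sum(X[:, j] - X_mean[j])` equals $n\bar{X}_j - n\bar{X}_j = 0$, `R_sum` stays constant across the whole coordinate sweep and only needs to be computed once.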
LGTM. Thanks @lorentzenchr
The code is quite low level, so it's not easy to check the math invariants by inspection. I therefore wrote this script to empirically validate the claims of the PR.
```python
# %%
import scipy.sparse as sp
import numpy as np
from sklearn.linear_model import ElasticNet
from time import perf_counter
rng = np.random.default_rng(42)
X = rng.uniform(size=(1_000, 100_000))
X[X < 0.9] = 0.0 # Sparsify the matrix
X_sparse = sp.csc_array(X)
w_true = np.zeros(X.shape[1])
w_true[0:300] = 10.0
assert X.mean(axis=0)[0:5].min() > 0.01
y = X_sparse @ w_true + 1 + rng.normal(size=X.shape[0]) * 0.1
# %%
reg = ElasticNet(
    alpha=0.1, l1_ratio=0.5, fit_intercept=True, selection="random", random_state=42
)
coef_dense = reg.fit(X, y).coef_
intercept_dense = reg.intercept_
print("10 first dense coefficients:", coef_dense[:10])
print("dense intercept:", intercept_dense)
# %%
for i in range(5):
    tic = perf_counter()
    coef_sparse = reg.fit(X_sparse, y).coef_
    toc = perf_counter()
    print(f"Time for sparse fit: {toc - tic:.3f} seconds")
intercept_sparse = reg.intercept_
print("10 first sparse coefficients:", coef_sparse[:10])
print("sparse intercept:", intercept_sparse)
assert np.allclose(coef_dense, coef_sparse)
assert np.allclose(intercept_dense, intercept_sparse)
```
The results are:
- Both `main` and this branch yield the same coefficients/intercept as expected, and the assertions always pass (they use the dense implementation as a reference that is not touched by this PR).
- The sparse fit time is approximately 3x faster with this optimization on this data (it strongly depends on the sparsity level).
Good job @lorentzenchr!
/cc @agramfort @mathurinm.
@ogrisel Thanks for the benchmark and review. I would not have guessed the size of the improvement 😇. The equivalence of the different solvers is already checked in a test. 😉
ENH avoid futile recomputation of R_sum in sparse_enet_coordinate_descent (#31387) Co-authored-by: Omar Salman <omar.salman2007@gmail.com> Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Reference Issues/PRs
None

What does this implement/fix? Explain your changes.
This PR removes the unnecessary updates of `R_sum = np.sum(residuals)`, because it does not change by a coordinate update if `X_mean` is provided, i.e., `np.sum(X[:, j] - X_mean[j])` equals 0.

Any other comments?
Should improve the runtime performance of `Lasso` and `ElasticNet` for sparse input `X` a bit.
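A minimal NumPy sketch of the argument (my own illustration with hypothetical names, not the actual Cython code in `sparse_enet_coordinate_descent`): one coordinate step changes the residual along a centered column, so the residual sum is invariant.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 50, 4
X = rng.uniform(size=(n_samples, n_features))
X_mean = X.mean(axis=0)
y = rng.normal(size=n_samples)
w = rng.normal(size=n_features)

# Residual of the implicitly centered model: R = y - (X - X_mean) @ w
R = y - (X - X_mean) @ w
R_sum = R.sum()

# One coordinate descent step on feature j: w[j] += delta
j, delta = 2, 0.7
w[j] += delta
R -= delta * (X[:, j] - X_mean[j])  # incremental residual update

# The centered column sums to zero, so R_sum is unchanged
assert np.isclose((X[:, j] - X_mean[j]).sum(), 0.0)
assert np.isclose(R.sum(), R_sum)
print("R_sum before and after the update:", R_sum, R.sum())
```

This is why the per-update recomputation of `R_sum` can be dropped: it only needs to be computed once before the sweep.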