
Conversation

@lorentzenchr (Member) commented Mar 12, 2022

Reference Issues/PRs

Fixes #3702 by implementing the final bits.

What does this implement/fix? Explain your changes.

This PR implements sample_weight support with sparse X for ElasticNet, ElasticNetCV, Lasso and LassoCV.

Details

The objective with sample weight sw is given by

sum(sw * (y - X w - w0)^2) + penalty

Solving for the intercept w0 gives

sum(sw * (y1 - X1 w)^2) + penalty,  y1 = y - y_mean, X1 = X - X_mean

where the mean is a weighted average, weighted by sw.
Dense solvers go on and rescale y1 and X1 by sqrt(sw), but sparse solvers cannot set X1 = X - X_mean, as this would destroy the sparsity of X. Therefore, X_mean is passed to the coordinate descent solver.
This PR goes on and also passes the sample weights to the CD solver. The alternative would be to precompute sw * X_mean (an outer product), a dense matrix of the same dimensions as X, which is not a good idea.
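For illustration, the centering trick relies on the identity `(X - X_mean) @ w = X @ w - X_mean @ w`, so the solver only needs the weighted column means, never a centered (dense) copy of `X`. A minimal NumPy/SciPy sketch of that identity (illustration only, not the actual Cython solver code):

```python
import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
X = sparse.random(100, 20, density=0.1, format="csr", random_state=0)
w = rng.standard_normal(20)
sw = rng.uniform(0.5, 1.5, size=100)

# Weighted column means of X: the quantity passed to the CD solver.
X_mean = np.asarray(X.multiply(sw[:, None]).sum(axis=0)).ravel() / sw.sum()

# Centering X explicitly would densify it, but the centered product equals
# the sparse product plus a rank-one correction:
dense_centered = (X.toarray() - X_mean) @ w
sparse_equivalent = X @ w - X_mean @ w  # X stays sparse throughout

np.testing.assert_allclose(dense_centered, sparse_equivalent)
```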

@lorentzenchr (Member, Author) commented:

@agramfort @TomDLT @rth You might be interested.
Maybe you can help with the final bits. For instance, there might be a problem with the stopping criterion of the CD solver: I had to decrease tol (making it stricter) for the sparse cases compared to the dense cases.
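For reference, a minimal example of the call this PR enables (hypothetical data, just to show the API):

```python
import numpy as np
from scipy import sparse
from sklearn.linear_model import ElasticNet

rng = np.random.default_rng(42)
X = sparse.random(200, 50, density=0.1, format="csr", random_state=42)
y = rng.standard_normal(200)
sw = rng.uniform(0.5, 2.0, size=200)

# sample_weight together with sparse X: the combination this PR adds.
reg = ElasticNet(alpha=0.1).fit(X, y, sample_weight=sw)
print(reg.intercept_, (reg.coef_ != 0).sum())
```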

@agramfort (Member) left a comment:

maybe @mathurinm has time to look

@lorentzenchr force-pushed the cd_sparse_sample_weight branch from 509844d to b4557c9 on March 18, 2022 17:53
@agramfort (Member) left a comment:

LGTM ! thx @lorentzenchr

just to check: did you observe any slowdown in the no sample weight case due to the extra branching now present in the new code?

@lorentzenchr (Member, Author) commented:

> just to check: did you observe any slowdown in the no sample weight case due to the extra branching now present in the new code?

```python
from sklearn.linear_model import ElasticNet
from sklearn.linear_model.tests.test_sparse_coordinate_descent import make_sparse_data

X, y = make_sparse_data(n_samples=1000, n_features=10_000)

%timeit ElasticNet().fit(X, y)
```

Results with `%timeit -r50`:

| branch  | time [ms]  |
|---------|------------|
| main    | 114 ± 1.95 |
| this PR | 110 ± 2.59 |

Note that I noticed quite some variation in those timings.

@mathurinm (Contributor) commented:

@lorentzenchr beware that with such data and `alpha=1` you get very sparse `coef_` and very few iterations:

```python
In [2]: %paste
from sklearn.linear_model import ElasticNet
from sklearn.linear_model.tests.test_sparse_coordinate_descent import make_sparse_data

X, y = make_sparse_data(n_samples=1000, n_features=10_000)
## -- End pasted text --

In [3]: clf = ElasticNet().fit(X, y)

In [4]: (clf.coef_ != 0).sum()
Out[4]: 2

In [5]: clf.n_iter_
Out[5]: 3
```

so computations like the Lipschitz constants (`norm(X, axis=0) ** 2`) account for a large share of the total time.

```python
In [19]: %timeit ElasticNet(max_iter=0).fit(X, y)
84.3 ms ± 849 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
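For context, the per-feature constants mentioned above are the squared column norms of `X`; one way to compute them for sparse input (an illustrative sketch, not the solver's code):

```python
import numpy as np
from scipy import sparse

X = sparse.random(1000, 10_000, density=0.01, format="csc", random_state=0)

# Coordinate-wise Lipschitz constants of the squared loss: squared column
# norms of X. This setup cost is paid once up front, even when max_iter=0.
lipschitz = np.asarray(X.multiply(X).sum(axis=0)).ravel()
```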

@lorentzenchr (Member, Author) commented:

@mathurinm good point.
Next try:

```python
from sklearn.linear_model import ElasticNet
from sklearn.linear_model.tests.test_sparse_coordinate_descent import make_sparse_data

X, y = make_sparse_data(n_samples=1000, n_features=10_000, n_informative=1000)

%timeit ElasticNet(alpha=0.01).fit(X, y)
```

Gives `n_iter_=601`.
Results:

| branch  | time [s]    |
|---------|-------------|
| main    | 12.4 ± 0.17 |
| this PR | 11.7 ± 0.10 |

@agramfort (Member) commented:

perfect ! thx @lorentzenchr

just need another +1 to MRG this one. 🙏

@jeremiedbb (Member) left a comment:

Overall looks good. I can't really review the `sparse_enet_coordinate_descent` details (besides trying to check that the code does what you wrote in the comments), but I trust the tests and @agramfort :). And it does not modify the no sample weight case, so no big risk.

@jeremiedbb added this to the 1.1 milestone on Mar 24, 2022

@jeremiedbb (Member) commented Mar 25, 2022

issue #3702 was created in 2014. We are on schedule 😄
Thanks @lorentzenchr !

@jeremiedbb merged commit bf0ece8 into scikit-learn:main on Mar 25, 2022
@lorentzenchr deleted the cd_sparse_sample_weight branch on March 25, 2022 10:37
glemaitre pushed a commit to glemaitre/scikit-learn that referenced this pull request Apr 6, 2022
@mathurinm (Contributor) commented:

@lorentzenchr on second thought, I don't understand why in the sparse case R is scaled by the sample weights while X is not scaled, instead of both being scaled by sqrt(sample_weight) (here).

Is the current design adopted in order to have the same code compute the gradient, with or without sample weights?
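For what it's worth, both scalings yield the same gradient of the weighted squared loss, since `X.T @ (sw * R) == (sqrt(sw) * X).T @ (sqrt(sw) * R)`; a small sketch of that equivalence (illustrative variable names, not the solver's code):

```python
import numpy as np
from scipy import sparse

rng = np.random.default_rng(1)
X = sparse.random(50, 8, density=0.3, format="csr", random_state=1)
R = rng.standard_normal(50)          # residuals
sw = rng.uniform(0.5, 1.5, size=50)  # sample weights

# Scaling discussed in this PR: keep X untouched (and sparse), fold sw into R.
grad_a = X.T @ (sw * R)

# Alternative: scale both X and R by sqrt(sw); this touches every stored
# entry of X.
sqrt_sw = np.sqrt(sw)
X_scaled = X.multiply(sqrt_sw[:, None]).tocsr()
grad_b = X_scaled.T @ (sqrt_sw * R)

np.testing.assert_allclose(grad_a, grad_b)
```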
