[MRG+1] Expose SAGA solver for ElasticNet regression #12907 #12966

Closed
wants to merge 16 commits

Conversation

@lorentzenchr (Member)

Reference Issues/PRs

Fixes #12907.

What does this implement/fix? Explain your changes.

A new argument 'solver' is introduced in ElasticNet. In addition to the default 'cd', one can also choose 'saga' to use the SAGA solver; see also Ridge and LogisticRegression.
The same tests as for coordinate descent are done for saga.
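
A minimal sketch of the proposed usage (the solver argument is what this PR adds; it is not part of released scikit-learn):

    from sklearn.linear_model import ElasticNet

    # 'cd' remains the default; 'saga' selects the SAGA solver
    est = ElasticNet(alpha=0.1, l1_ratio=0.5, solver='saga')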

Any other comments?

Open question: So far ElasticNet via cd does not support sample_weight. It could be included via SAGA. Should this go in this PR or rather in a new issue?

@lorentzenchr (Member Author)

There is a test failure in test_enet_toy in 2 of 4 Travis settings that I can't reproduce on my machine, so I'm a bit stuck. Can someone please help figure this out so that the tests pass?

@jnothman (Member)

The failing settings use the latest numpy/scipy versions.

    Algorithm to use in the optimization problem.

    - 'auto' chooses the solver automatically based on the type of data.
      If the data is F-contiguous or a sparse 'csc' matrix, it chooses
Member:

This seems quite an obscure criterion... Is the memory order of the data really a sensible criterion for saying one solver is more appropriate than the other?

@lorentzenchr (Member Author):

If you have a large dataset X, reordering it from C to F order or vice versa requires a full copy of the array, which might be a heavy (memory-wise) and unwanted operation. The nice point is that 'cd' needs F-ordered and 'saga' needs C-ordered arrays. Furthermore, without benchmarks, we can't tell which solver is better suited for a given X, so I came up with this criterion. I can easily change 'auto' to 'cd' or remove 'auto'; it is just a suggestion.
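
A minimal sketch of the dispatch rule I have in mind (the helper name is illustrative, not the PR's actual code):

    from scipy import sparse

    def pick_solver(X):
        # 'cd' wants F-ordered dense arrays or CSC sparse matrices,
        # 'saga' wants C-ordered dense arrays or CSR sparse matrices;
        # dispatching on the current layout avoids a full copy of X.
        if sparse.issparse(X):
            return 'cd' if X.format == 'csc' else 'saga'
        return 'cd' if X.flags['F_CONTIGUOUS'] else 'saga'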

Member:

I'd be inclined to remove auto, but would be interested to hear from @agramfort

Member:

Alternatively, those benchmarks we don't have yet would be an argument for 'auto'. Would you be so nice as to run some benchmarks on this, @lorentzenchr? It may turn out that the copy is not that significant in most cases compared to the rest of the method, after all.

Member:

+1 for removing auto for now but I would be very interested in some extensive benchmarks to guide future decisions on "smart" solver selection.

@lorentzenchr (Member Author):

Benchmarks based on bench_glmnet.py indicate that coordinate descent is the clear winner, regardless of the contiguity of the data. So 'auto' will be removed and 'cd' will be the default solver.
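
The gist of my benchmark setup, as a sketch (sizes are illustrative; the solver argument is the one proposed in this PR):

    import time
    import numpy as np
    from sklearn.datasets import make_regression
    from sklearn.linear_model import ElasticNet

    X, y = make_regression(n_samples=100000, n_features=100, n_informative=5)

    for order in ('C', 'F'):
        X_ord = np.asarray(X, order=order)  # control the memory layout
        for solver in ('cd', 'saga'):       # solver kwarg per this PR
            tic = time.perf_counter()
            ElasticNet(alpha=1.0, solver=solver).fit(X_ord, y)
            print(order, solver, '%.2fs' % (time.perf_counter() - tic))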

@lorentzenchr (Member Author)

Even with the same versions of Python, numpy and scipy, I'm not able to reproduce the Travis test failure for test_enet_toy on my machine. I don't know what to do. Sorry.

@lorentzenchr (Member Author)

I think that I figured out why the test on Travis fails. See issue #13021.

@lorentzenchr changed the title from "[WIP] Expose SAGA solver for ElasticNet regression #12907" to "[MRG] Expose SAGA solver for ElasticNet regression #12907" on Jan 23, 2019
@lorentzenchr (Member Author)

With a new solver in ElasticNet, the current file structure does not make sense to me. If I were to do it from scratch (without backward-compatibility issues), I would separate the optimization algorithms (solvers) from the actual models (API/estimators). That is, I would put the class ElasticNet in a file like enet_regression.py (and in the future maybe also Lasso, ElasticNetCV, ...).

What do you think?

@adrinjalali (Member)

> With a new solver in ElasticNet, the current file structure does not make sense to me. If I were to do it from scratch (without backward-compatibility issues), I would separate the optimization algorithms (solvers) from the actual models (API/estimators). That is, I would put the class ElasticNet in a file like enet_regression.py (and in the future maybe also Lasso, ElasticNetCV, ...).

I guess we can only make such changes once we're clear on #12927.

@lorentzenchr (Member Author)

@adrinjalali I hope I resolved your review comments the way you intended.
I'd like to raise the question again whether this PR should also be used to add sample_weight to ElasticNet, as SAGA supports it.

    if isinstance(self.precompute, str):
        raise ValueError('precompute should be one of True, False or'
                         ' array-like. Got %r' % self.precompute)
    elif self.precompute is not False and self.solver == 'saga':
Member:

I don't think it's wrong, but self.precompute != False (or not self.precompute, if None should be considered the same here) seems more natural to me.

@lorentzenchr (Member Author):

self.precompute != False gives the flake8 error "E712 comparison to False should be 'if cond is not False:' or 'if not cond:'".
I personally find if not self.precompute less readable than if self.precompute is not False.
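
There is also a practical reason to prefer the identity check, sketched here (my own illustration, not code from the PR): precompute may be an array-like Gram matrix, and truthiness is undefined for arrays.

    import numpy as np

    precompute = np.eye(3)             # a precomputed Gram matrix
    print(precompute is not False)     # True: identity check works for any type
    try:
        not precompute                 # truth value of an array is ambiguous
    except ValueError as exc:
        print(exc)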

@adrinjalali (Member)

> I'd like to raise the question again whether this PR should also be used to add sample_weight to ElasticNet, as SAGA supports it.

I'd suggest having a separate PR for that.

Please also add an entry to the change log at doc/whats_new/v*.rst. Like the other entries there, please reference this pull request with :issue: and credit yourself (and other contributors if applicable) with :user:.

@jnothman (Member) left a comment

Not yet reviewed the tests.

    @@ -560,6 +561,7 @@ class ElasticNet(LinearModel, RegressorMixin):
            on an estimator with ``normalize=False``.

        precompute : True | False | array-like
            relevant only if ``solver='cd'``
Member:

Please start with an uppercase letter and end with a full stop.

In HTML the line break will disappear.

@lorentzenchr (Member Author):

Do you want to keep the line break in HTML?

@jnothman (Member) left a comment

Yes, I think that structure is cleaner...

One nitpick. Nice work!

@lorentzenchr changed the title from "[MRG] Expose SAGA solver for ElasticNet regression #12907" to "[MRG+1] Expose SAGA solver for ElasticNet regression #12907" on Feb 11, 2019
@lorentzenchr (Member Author)

@adrinjalali With your +1 this would be ready for merge. At least, I hope so 😏

@ogrisel (Member) commented Feb 21, 2019

@lorentzenchr thanks for your work. Could you please post a screenshot of your benchmark results? Aren't there any cases where SAGA is faster than CD (e.g. with a large number of samples)?

@ogrisel (Member) left a comment

Overall this PR looks nice but I would be more confident in the results if we could check the dual gap for the saga solver in the tests. See more details below:

    For ``solver='cd'``, the number of iterations run by the solver to
    reach the specified tolerance.
    For ``solver='saga'``, the number of full passes on all samples until
    convergence.
Member:

I realize that the dual_gap_ attribute is not mentioned here. This is an oversight that should be fixed.

@lorentzenchr (Member Author):

It never has been. I can add it in this PR or open a new issue. What do you prefer?

        is_saga=True)
    coef_[k] = this_coef
    self.n_iter_.append(this_iter)
    self.dual_gap_ = None
Member:

I think we should compute dual_gap_ after the final SAGA iteration (not necessarily use it as a convergence criterion for SAGA itself), but more as a final check, to make it possible for the user to verify that the two solvers find solutions of similar quality.

I don't really know for sure what the meaning of tol is in the context of saga. @TomDLT @arthurmensch @agramfort do you know if we could provide any kind of guarantee w.r.t. the dual gap when using the natural stopping criterion of the sag_solver?

@lorentzenchr (Member Author):

From what I see, saga uses the same convergence criterion, but without the additional check on the dual gap. Compare

    if ((max_weight != 0 and max_change / max_weight <= tol)

with

    if (w_max == 0.0 or
            d_w_max / w_max < d_w_tol or

@ogrisel (Member) commented Feb 21, 2019

I ran a couple of benchmarks with make_regression(n_samples=int(1e6), n_features=100, n_informative=5), and while the CD and SAGA solvers yield similar solutions as expected, SAGA is still significantly slower. CD can converge in just 3 iterations on this dataset, while SAGA needs at least 10.

Interestingly, SGDRegressor with early stopping becomes competitive in that regime (3 iterations as well), although I suspect that the dual gap would not be the same :).

    from sklearn.linear_model import SGDRegressor

    SGDRegressor(l1_ratio=0.5, penalty="elasticnet", max_iter=1000, tol=1e-6,
                 n_iter_no_change=1, alpha=1).fit(X, y)

It does not pick up the same non-zero features as SAGA and CD, but it also keeps only 5 non-zeros (because of correlation there is redundancy, I guess), and the cross-validation accuracy of SGDRegressor is the same or even slightly higher.

@ogrisel (Member) commented Feb 21, 2019

Note that in that large-sample regime, make_pipeline(SelectFdr(f_regression), LinearRegression()) finds the same support, gets a perfect 1.0 r2 score on the validation set, and runs in 1s...

And the selected support from SelectFdr(f_regression) matches the one by ElasticNet.
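
For reference, a minimal sketch of that baseline (the dataset parameters follow the make_regression call above; the train/validation split is an assumption):

    from sklearn.datasets import make_regression
    from sklearn.feature_selection import SelectFdr, f_regression
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import make_pipeline

    X, y = make_regression(n_samples=int(1e6), n_features=100,
                           n_informative=5)
    X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)

    pipe = make_pipeline(SelectFdr(f_regression), LinearRegression())
    pipe.fit(X_train, y_train)
    print(pipe.score(X_val, y_val))  # r2 on the validation set
    # indices of the selected support
    print(pipe.named_steps['selectfdr'].get_support(indices=True))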

@arthurmensch (Contributor)

Will review this during the sprint

@ogrisel (Member) commented Feb 22, 2019

But more importantly, is there really a reason to expose the saga solver in ElasticNet if the CD solver is always better?

@lorentzenchr (Member Author)

I did some more benchmarks:
[benchmark results screenshot]
For an increasing number of features, I get similar results. Once I set the training data size larger than my memory, saga started to be faster.
One reason to include saga could be to add sample_weight, but this could also be added to cd via rescaling of the data, and it would be another PR anyway.
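
The rescaling trick I mean, as a sketch (my illustration, not PR code): for a squared-error data-fit term, scaling the rows of X and y by sqrt(sample_weight) turns an unweighted fit into a weighted one, up to a rescaling of alpha.

    import numpy as np
    from sklearn.linear_model import ElasticNet

    def fit_weighted_enet(X, y, sample_weight, **params):
        # sum_i w_i * (y_i - x_i @ coef)**2
        #   == sum_i (sqrt(w_i)*y_i - sqrt(w_i)*x_i @ coef)**2
        sw = np.sqrt(np.asarray(sample_weight, dtype=float))
        return ElasticNet(**params).fit(X * sw[:, None], y * sw)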

@agramfort (Member)

could we reuse some of the benchmarks here:

https://github.com/scikit-learn/scikit-learn/tree/master/benchmarks

to see in which setting it helps.

@lorentzenchr (Member Author)

I did my benchmarks based on https://github.com/scikit-learn/scikit-learn/blob/master/benchmarks/bench_glmnet.py. If it helps, I could provide my adaptation. Also, I couldn't find a single (in-memory) case where saga is faster, so @ogrisel has a good point in questioning whether to expose saga at all.

@ogrisel (Member) commented Feb 28, 2019

It's possible that our saga solver is not optimal (e.g. we do not do any minibatching / importance sampling based on gradient magnitude...). But for now I don't feel like merging this: I see no value in exposing choices that are always found to be empirically suboptimal w.r.t. existing methods already implemented in the library.

@fabianp (Member) commented Mar 1, 2019

I'm not too surprised by this, since CD can do many optimizations in the case of a squared loss, while the saga solver is basically the same as the one for logistic regression.

AFAICT the benchmarks are on dense data. A big use case where SAGA shines is when the input data is very sparse (rcv1, the url dataset, etc.).

@ogrisel (Member) commented Mar 8, 2019

Thanks for your input @fabianp. @lorentzenchr I see you closed the issue; it might still be interesting to run some benchmarks with the SAGA solver on a medium-scale sparse text regression problem, e.g. predicting movie review scores from TF-IDF / bag-of-words features.
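
Something along these lines, as a sketch (toy data for illustration; the solver argument is the one proposed in this PR):

    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.linear_model import ElasticNet

    docs = ["a great movie", "terrible plot", "fine acting", "boring script"]
    scores = [9.0, 2.0, 6.5, 3.0]  # numeric review scores

    X = TfidfVectorizer().fit_transform(docs)  # sparse CSR matrix
    for solver in ('cd', 'saga'):              # solver kwarg per this PR
        ElasticNet(alpha=1e-4, solver=solver).fit(X, scores)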

@lorentzenchr (Member Author)

Also for sparse (csr) data, I get similar results to the benchmark on dense data above. I don't dare show the numbers. Conclusion: CD seems to be a factor of 10 faster for elastic net regression (squared error).

Successfully merging this pull request may close these issues.

Expose SAGA solver for ElasticNet