
[MRG+2] ENH : SAGA support for LogisticRegression (+ L1 support), Ridge #8446


Merged
1 commit merged into scikit-learn:master on Mar 27, 2017

Conversation

arthurmensch
Contributor

@arthurmensch arthurmensch commented Feb 23, 2017

This PR proposes slight adaptations of the existing sag_fast.pyx module to be able to use SAGA in addition to the SAG algorithm. This is the way to go if we want to propose enet/l1 penalties for fast incremental solvers for ridge and logistic regression.

A SAGA implementation is already available in lightning, which is adapted from the original paper. This one is slightly different, as it is built around understanding the SAGA update as a "corrected" version of SAG.
I believe it would also be possible to slightly adapt the module to provide SVRG in addition to SAGA, if this interests people.
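
For intuition, here is a minimal dense-NumPy sketch of the two update rules (the names `grad_fn`, `memory` and `avg_grad` are illustrative, not the actual variables in sag_fast.pyx, and the l2/prox terms and the weight-scaling tricks are omitted):

import numpy as np

def sag_like_epoch(w, X, y, grad_fn, step, memory, avg_grad, saga=False):
    """One pass of SAG / SAGA updates on dense data (illustration only)."""
    n_samples = X.shape[0]
    for i in np.random.permutation(n_samples):
        g_new = grad_fn(w, X[i], y[i])   # fresh gradient for sample i
        if saga:
            # SAGA: unbiased step built from the stale average, "corrected"
            # by the difference between the fresh and stored gradients.
            w -= step * (g_new - memory[i] + avg_grad)
        else:
            # SAG: biased step using the running average of stored gradients.
            w -= step * (avg_grad + (g_new - memory[i]) / n_samples)
        avg_grad += (g_new - memory[i]) / n_samples  # keep the average current
        memory[i] = g_new                            # store the fresh gradient
    return w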

I tried to keep the changes made to sag_fast.pyx as small as possible. I reckon that the sag_fast.pyx module could be made a little more readable by using 2d memoryviews instead of strided pointers everywhere. For further work.

For the moment I adapted the test_logistic.py file to ensure correctness of the algorithm, but the saga algorithm should be tested within test_sag.py.
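
For example, a minimal cross-solver correctness check (at the estimator level, so it could live in either test file) might look like the sketch below; the `saga` solver string is what this PR introduces, and the tolerance is only indicative:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

def test_sag_saga_agree_on_l2():
    # With a tight tol on a small, well-conditioned problem, both solvers
    # should reach (nearly) the same l2-penalized optimum.
    X, y = make_classification(n_samples=200, n_features=10, random_state=0)
    sag = LogisticRegression(solver='sag', tol=1e-10, max_iter=1000,
                             random_state=0).fit(X, y)
    saga = LogisticRegression(solver='saga', tol=1e-10, max_iter=1000,
                              random_state=0).fit(X, y)
    np.testing.assert_allclose(sag.coef_, saga.coef_, rtol=1e-3)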

SAGA paper

TODO

  • Documentation
  • Reference for step size
  • Check optimal step size
  • Ridge API + tests
  • Test module sag.py directly
  • Implement l1 penalty (simple projection)
  • Different PR: Use minibatches? I cannot recall whether this is actually interesting.
  • Add nice benchmarks.
  • Benchmarks against liblinear and lightning
  • Different PR: Add a SAGA solver for Lasso (which might imply a bit of refactoring...)
  • Add rcv1 example with multinomial + L1
  • Different PR: Add elastic net l1_ratio in LogisticRegression
    ping @TomDLT @agramfort you might be interested in this :)

@GaelVaroquaux
Member

GaelVaroquaux commented Feb 23, 2017 via email

@GaelVaroquaux
Member

GaelVaroquaux commented Feb 23, 2017 via email

@arthurmensch
Contributor Author

Indeed, quite a few benchmarks would be necessary. Added!

from ..externals import six
from ..metrics import SCORERS
from ..utils.optimize import newton_cg
from ..utils.validation import check_X_y
Contributor Author

I guess my editor just put these in alphabetical order automatically; I can revert if necessary.

@GaelVaroquaux
Member

GaelVaroquaux commented Feb 23, 2017 via email

@codecov

codecov bot commented Feb 23, 2017

Codecov Report

Merging #8446 into master will increase coverage by <.01%.
The diff coverage is 100%.

@@            Coverage Diff             @@
##           master    #8446      +/-   ##
==========================================
+ Coverage   95.47%   95.48%   +<.01%     
==========================================
  Files         342      342              
  Lines       60907    61007     +100     
==========================================
+ Hits        58154    58255     +101     
+ Misses       2753     2752       -1
Impacted Files Coverage Δ
sklearn/linear_model/tests/test_sag.py 98.6% <100%> (+0.07%)
sklearn/linear_model/sag.py 94.36% <100%> (+0.52%)
sklearn/linear_model/tests/test_logistic.py 100% <100%> (ø)
sklearn/linear_model/logistic.py 97.65% <100%> (+0.03%)
sklearn/linear_model/tests/test_ridge.py 100% <100%> (ø)
sklearn/linear_model/ridge.py 93.88% <100%> (ø)
sklearn/linear_model/coordinate_descent.py 96.94% <0%> (ø)
sklearn/tree/tree.py 98.41% <0%> (ø)
sklearn/metrics/classification.py 97.77% <0%> (ø)
sklearn/decomposition/tests/test_pca.py 100% <0%> (ø)
... and 4 more

Continue to review full report at Codecov.

Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update fcb1403...2a5c4bf. Read the comment docs.

@arthurmensch
Contributor Author

A quick benchmark:

Classification performance:
===========================
Classifier               train-time   test-time   error-rate
------------------------------------------------------------
LinearRegression-SAGA        17.25s       0.07s       0.0810
LinearRegression-SAG         20.64s       0.10s       0.0824

@arthurmensch
Contributor Author

arthurmensch commented Feb 24, 2017

figure_1

I added some code to do benchmarks, which should be removed before merging.

With the conservative auto step size that is used, SAGA performs better in the first epochs and SAG gets better for finer convergence (a behavior already observed in the literature). With more aggressive step sizes, SAGA performs better than SAG. We could use line search, as it tends to produce better results; is it used in SGD in scikit-learn?

@GaelVaroquaux
Member

GaelVaroquaux commented Feb 24, 2017 via email

@agramfort
Member

@arthurmensch can you show a benchmark with $\log_{10}(f(x^k) - f(x^*))$ on the y axis and time on the x axis?
I suspect marginal improvements for L2 logistic regression, so it would need a convincing figure for L1-regularized logistic regression, i.e. a comparison with liblinear, which we use presently.
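
For reference, a sketch of how such a curve could be produced (the objective below is the standard l2-penalized logistic loss, and f(x^*) is approximated by the best value seen across runs; X and y are assumed to be already loaded):

import time
import numpy as np
from sklearn.linear_model import LogisticRegression

def objective(clf, X, y, C):
    # C * sum(log(1 + exp(-y * (Xw + b)))) + 0.5 * ||w||^2, with y in {-1, +1}
    w, b = clf.coef_.ravel(), clf.intercept_[0]
    z = y * (X.dot(w) + b)
    return C * np.logaddexp(0, -z).sum() + 0.5 * w.dot(w)

C, times, losses = 1.0, [], []
for n_iter in (1, 2, 5, 10, 20, 50, 100):
    clf = LogisticRegression(C=C, solver='saga', max_iter=n_iter, tol=1e-10)
    t0 = time.time()
    clf.fit(X, y)
    times.append(time.time() - t0)
    losses.append(objective(clf, X, y, C))

f_star = min(losses)  # proxy for f(x^*)
suboptimality = np.log10(np.array(losses) - f_star + 1e-12)
# plot suboptimality against times for each solver being compared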

@arthurmensch
Contributor Author

@agramfort can you recall a dataset on which SAG would beat liblinear with l2 penalty?

@agramfort
Member

agramfort commented Feb 27, 2017 via email

@arthurmensch
Contributor Author

I did some benchmarks against lightning for the moment, on rcv1. For some reason we are 10x faster than lightning with L1 penalty; any thoughts, @fabianp?

l2 logistic:

log_l2

l1 logistic:

log l1

Credit for the shortcut in the composition of prox operators goes to @fabianp and @mblondel, but I think it looks a bit cleaner this way.

I will post benchmarks against liblinear ASAP.

@GaelVaroquaux
Member

This is looking great. And the fun part is that it seems that reimplementing code and comparing teaches us a lot, even when the two implementations are done by people who share so much.

@arthurmensch arthurmensch changed the title [WIP] SAGA support for LogisticRegression and Ridge [MRG] SAGA support for LogisticRegression and Ridge Feb 28, 2017
@arthurmensch
Contributor Author

arthurmensch commented Feb 28, 2017

For single class l1/l2 logistic regression liblinear is hard to beat. But for multiclass, speed improvements look great:

figure_1-1

There is a discrepancy in the final accuracy as the two models are not the same (the loss is different, and liblinear penalizes the intercept), but I reckon they perform the same with correct cross-validation.

Two-class problem (there is a problem there, as the final training score should not be that different).

figure_1-2

One of the advantages of using SAGA instead of liblinear is that we can perform CV with memory-mapped data, which can be crucial when datasets are huge.
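
For illustration, a rough sketch of that workflow (the file names and shapes are made up; the point is that saga streams over samples, so the OS can page the memory-mapped arrays in lazily instead of each CV worker needing a full in-memory copy):

import numpy as np
from sklearn.linear_model import LogisticRegressionCV

# Hypothetical design matrix and targets previously dumped to disk.
X = np.memmap('X_train.dat', dtype=np.float64, mode='r', shape=(1000000, 300))
y = np.memmap('y_train.dat', dtype=np.float64, mode='r', shape=(1000000,))

clf = LogisticRegressionCV(Cs=10, penalty='l1', solver='saga', n_jobs=4)
clf.fit(X, y)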

@mblondel
Member

It could be interesting to see if both implementations return more or less the same weight vectors.

I remember it took @fabianp and @zermelozf several iterations to implement the lazy updates correctly. We have a naive pure-Python implementation in our tests to ensure correctness:

https://github.com/scikit-learn-contrib/lightning/blob/master/lightning/impl/tests/test_sag.py#L85

@arthurmensch
Contributor Author

arthurmensch commented Mar 1, 2017

single_target_l2
single_target_l1
multi_target_l2
multi_target_l1

These are the proper benchmarks: liblinear vs. lightning vs. sklearn SAGA, on 1/20 of rcv1.

  • SAGA sklearn is better for multitarget regression for l2 and l1
  • SAGA sklearn is better for single target regression for l2, and a little slower than liblinear with l1
  • SAGA sklearn and lightning are roughly on par with l2 regularisation, single-target
  • SAGA sklearn is ten times faster than lightning with l1 regularisation, single-target. This is surprising.

I am currently running the benches on the whole dataset, to see whether saga gets better than liblinear at some point.

The benchmark file that is in this PR does not use callbacks or any hacks, so I think it can stay in the repo for future reference.

@GaelVaroquaux
Member

GaelVaroquaux commented Mar 1, 2017 via email

@arthurmensch
Contributor Author

arthurmensch commented Mar 2, 2017

Should I try to add a solver for Lasso? At the moment scikit-learn has quite a variety of Lasso solvers (Lasso, LassoLars, CV versions, IC versions, RandomizedLasso), so I am not sure where to put it. In Lasso, by adding a solver option?

Member

@TomDLT TomDLT left a comment


Nice work! The code looks good so far.
I still need to understand the JIT prox update.

Use minibatches ? I cannot recall whether this is actually interesting.

Mark Schmidt says it's interesting in SAG (slide 73), yet I am not sure we need it in scikit-learn.


if penalty == 'l1':
    if solver == 'sag':
        raise ValueError("Unsupported penalty. Use `saga` instead.")
Member

This is never reached, thanks to _check_solver_option, isn't it?



def get_auto_step_size(max_squared_sum, alpha_scaled, loss, fit_intercept):
def get_auto_step_size(max_squared_sum, alpha_scaled, loss, fit_intercept,
                       n_samples=None,
Member

update docstring

Contributor Author

Done




cdef double lagged_update(double* weights, double wscale, int xnnz,
Member

update the docstring description

else:
    for class_ind in range(n_classes):
        idx = f_idx + class_ind
        if fabs(sum_gradient[idx] * cum_sum) < cum_sum_prox:
Member

can you explain this? and maybe add some comments

Contributor Author

@arthurmensch arthurmensch Mar 2, 2017

This is a nice trick from lightning: instead of unrolling the whole delayed composition prox(prox(prox(w - grad_update) - grad_update) - ...) in the loop below, we factorize it, as we do not cross the non-linearity of the soft-thresholding operator. There is no academic reference for this, but we should indeed add some comments.
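
A hedged illustration of the factorization in plain NumPy (the names are mine): as long as the running weight stays on the same side of zero, composing several delayed soft-thresholding steps equals a single step with the accumulated threshold, which is what the cumulative prox sum exploits; when the weight's magnitude is too small relative to that accumulated threshold (the fabs(...) < cum_sum_prox branch above), the sign could flip and the missed steps have to be unrolled one by one.

import numpy as np

def soft_threshold(x, t):
    # prox of t * |.| (the l1 proximal operator)
    return np.sign(x) * np.maximum(np.abs(x) - t, 0.0)

x = 5.0
thresholds = [0.3, 0.4, 0.2]   # thresholds of the skipped iterations

# Unrolled: apply each missed prox step in sequence.
unrolled = x
for t in thresholds:
    unrolled = soft_threshold(unrolled, t)

# Factorized: one prox step with the accumulated threshold,
# valid because x never crosses zero along the way.
factorized = soft_threshold(x, sum(thresholds))

assert np.isclose(unrolled, factorized)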

Contributor Author

It would be nice if there were a small blog post with the derivations for this on the web. @fabianp sent me an unfinished draft for this, but I gathered it was to stay a draft.

Contributor Author

If you do not have time, I can write a small blog post with the mathematical derivation for this part, with due reference to @fabianp @zermelozf @mblondel. Then we can reference it in a comment.

Member

That would be awesome. I'm sending you the .tex sources of that draft so you can use it as you please.

Member

Even without the full blog post with the derivation, it would be great to have an inline comment to state:

if reset:
    cumulative_sums[sample_itr - 1] = 0.0
    if prox:
        cumulative_sums_prox[sample_itr - 1] = 0.0

# reset wscale to 1.0
Member

You should return void and return 1.0 only in scale_weights

cdef np.ndarray[double, ndim=1] cumulative_sums_prox_array
cdef double* cumulative_sums_prox

cdef bint prox = beta > 0
Member

it could be safe to check that saga is True

else:
    raise ValueError("Unknown loss function for SAG solver, got %s "
                     "instead of 'log' or 'squared'" % loss)
if is_saga:
    mun = min(2 * n_samples * alpha_scaled, L)
Member

where does it come from?

Contributor Author

@arthurmensch arthurmensch Mar 2, 2017

SAGA original paper: the proofs of convergence require step_size < 1 / (3L) or step_size < 1 / (2(L + mu * n)), where mu is the strong convexity modulus of the objective. We could use 1 / (3L), but this one is closer to optimal in the low-dimensional regime. I will add a reference. By the way, SAG uses 1 / L whereas the only step size for which proofs are available is 1 / (16L). I think this is a sound heuristic, but we should add a reference as well. 1 / L is also a good heuristic for SAGA in most cases, but it actually makes one test fail in the test suite :P We could allow the user to specify the step size, but it would complicate the API.
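
A rough sketch of the heuristic described above (the helper name and layout are illustrative; the real logic lives in get_auto_step_size in sag.py):

def auto_step_size(max_squared_sum, alpha_scaled, loss, fit_intercept,
                   n_samples=None, is_saga=False):
    # L: Lipschitz constant of the gradient of the smooth part
    # (data-fit term plus the optional l2 penalty).
    if loss == 'log':
        L = 0.25 * (max_squared_sum + int(fit_intercept)) + alpha_scaled
    elif loss == 'squared':
        L = max_squared_sum + int(fit_intercept) + alpha_scaled
    else:
        raise ValueError("loss must be 'log' or 'squared'")
    if is_saga:
        # step < 1 / (2(L + mu * n)); the min() keeps the step size from
        # dropping below 1 / (3L) when the strong convexity term dominates.
        mun = min(2 * n_samples * alpha_scaled, L)
        return 1.0 / (2 * L + mun)
    # SAG heuristic: 1 / L (theory only covers 1 / (16L)).
    return 1.0 / L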

Member

Ok.
For SAG, I used the recommendation of Mark Schmidt (slide 65).

@@ -261,26 +272,42 @@ def sag_solver(X, y, sample_weight=None, loss='log', alpha=1.,
if max_squared_sum is None:
    max_squared_sum = row_norms(X, squared=True).max()
step_size = get_auto_step_size(max_squared_sum, alpha_scaled, loss,
Member

should the step size depend on beta_scaled in the L1 case?

Contributor Author

No, it depends on L, which is the Lipschitz constant of the gradient of the smooth part of the objective (i.e. logistic loss + optional l2).

Member

oh yes of course

@arthurmensch arthurmensch changed the title [MRG] SAGA support for LogisticRegression and Ridge [MRG] ENH : SAGA support for LogisticRegression, Ridge and Lasso (+ l1 support) Mar 2, 2017
@arthurmensch
Contributor Author

arthurmensch commented Mar 2, 2017

Should we add an example for SAGA + some narrative documentation? I suggest multinomial logistic + l1 on rcv1, with a comparison against one-versus-all with liblinear, although this would take a few minutes on the whole dataset (because of ovr).

I suggest we work on merging this PR before working on using SAGA for Lasso.

There is also a question regarding whether we should propose both sag and saga to the end user, or deprecate sag? @GaelVaroquaux

@arthurmensch
Contributor Author

arthurmensch commented Mar 2, 2017

ping @TomDLT @agramfort @GaelVaroquaux I think this is ready for review. Reviews welcome from @fabianp and @mblondel for the core code if you have time.

A few UX questions to answer before I move on:

  • Do we agree on a multinomial + L1 example on rcv1, with a comparison with ovr + liblinear?
  • Do we deprecate the sag solver or keep both sag and saga? Is the sag solver useful in any way compared to saga, @fabianp?
  • Do we add an l1_ratio to LogisticRegression as we now can? C + l1_ratio sounds a bit unusual, but it is still meaningful.
  • Naming is now a bit fishy (sag_solver for saga). I guess we can keep it that way for the moment, but it will require refactoring if we add, for instance, an svrg solver to the code base.

I reckon adding l1_ratio can be done in a new PR to keep this one minimal.

@arthurmensch arthurmensch changed the title [MRG] ENH : SAGA support for LogisticRegression, Ridge and Lasso (+ l1 support) [MRG] ENH : SAGA support for LogisticRegression, Ridge and Lasso (+ L1 support) Mar 2, 2017
@arthurmensch
Contributor Author

arthurmensch commented Mar 3, 2017

On the whole rcv1 dataset, single target with high regularization (C=0.1), saga is faster than liblinear:

single_target_l1

This is not true for low regularization (C=1):

single_target_l1

This is expected (with high regularization the soft thresholding shortcut trick works more often).

@TomDLT
Member

TomDLT commented Mar 3, 2017

How does SAGA perform on small datasets (e.g. iris)?
It would be great if it could compete with liblinear, so we can change the default solver in LogisticRegression.
The regularization of the intercept in liblinear is often confusing for users.

@GaelVaroquaux
Member

We will need a very clear paragraph in the documentation that explains which solver to choose when.

@@ -0,0 +1,114 @@
"""
=====================================================
Multiclass logistic regression on newsgroups20
Member

Please also put "sparse" in the title.

features to zero. This is good if the goal is to extract the strongly
discriminative vocabulary of each class. If the goal is to get the best
predictive accuracy, it is better to use the non sparsity-inducing
l2 penalty instead.
Member

I think you should also mention univariate feature selection as an alternative way to extract sparse discriminative vocabularies.

Maybe you could even extend the example by adding a pipeline of a sparse univariate feature selection model + l2-penalized logistic regression, to showcase a classification model with a similar sparsity level to the l1-penalized variant.
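
A rough sketch of what that extension could look like (the selection method, k, and C values are arbitrary choices for illustration):

from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

data = fetch_20newsgroups_vectorized(subset='train')
X, y = data.data, data.target

# l1-penalized multinomial model: sparsity comes from the penalty itself.
l1_clf = LogisticRegression(penalty='l1', solver='saga',
                            multi_class='multinomial', C=1.0)

# Alternative: univariate selection down to a comparable number of features,
# followed by a dense l2-penalized multinomial model.
l2_pipeline = make_pipeline(
    SelectKBest(chi2, k=10000),
    LogisticRegression(penalty='l2', solver='saga', multi_class='multinomial'),
)

for model in (l1_clf, l2_pipeline):
    model.fit(X, y)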

Performance of multinomial logistic regression with
L1 penalty. We use the SAGA algorithm for this purpose, which is fast. Test
accuracy reaches > 0.8, while weight vectors remain *sparse* and
*interpretable*.
Member

@ogrisel ogrisel Mar 20, 2017

Please add a note such that:

Note that this accuracy is far below what can be reached by a non-penalized linear model (I think ~0.93, but this should be checked), and even further below the accuracy of non-linear models such as a multi-layer perceptron (0.98+).

np.ndarray[double, ndim=2, mode='c'] sum_gradient_init,
np.ndarray[double, ndim=2, mode='c'] gradient_memory_init,
np.ndarray[bint, ndim=1, mode='c'] seen_init,
int num_seen,
bint fit_intercept,
np.ndarray[double, ndim=1, mode='c'] intercept_sum_gradient_init,
double intercept_decay,
bint saga,
bint verbose):
"""Stochastic Average Gradient (SAG) solver.
Member

Please update this docstring to make it explicit that this function implements both SAG and SAGA.

@ogrisel
Member

ogrisel commented Mar 22, 2017

I pushed better checks for the l1 penalty logistic regression tests. I still want to review other parts of the code / tests but maybe we should do the Lasso part in a separate PR.

Member

@ogrisel ogrisel left a comment

I did a review and pushed some fixes (mostly missing updates in the documentation). I think we can merge this PR without waiting for elasticnet penalty and integration in the Lasso* classes.

@ogrisel ogrisel changed the title [MRG] ENH : SAGA support for LogisticRegression, Ridge and Lasso (+ L1 support) [MRG+1] ENH : SAGA support for LogisticRegression, Ridge and Lasso (+ L1 support) Mar 27, 2017
@ogrisel ogrisel changed the title [MRG+1] ENH : SAGA support for LogisticRegression, Ridge and Lasso (+ L1 support) [MRG+1] ENH : SAGA support for LogisticRegression (+ L1 support), Ridge Mar 27, 2017
Member

@TomDLT TomDLT left a comment

LGTM except for minor nitpicks

I think we can merge this PR without waiting for elasticnet penalty and integration in the Lasso* classes.

I agree

Very Large dataset (`n_samples`) "sag" or "saga"
================================= =====================================

The "saga" solver is almost always a the best choice. The "liblinear"
Member

stray "a" ("almost always a the best choice")

regression yields more accurate results and is faster to train on the larger
scale dataset.

Here we use the l1 sparsity that trims the weights of no to informative
Member

-> of not informative

@@ -967,6 +976,9 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
Used to specify the norm used in the penalization. The 'newton-cg',
'sag' and 'lbfgs' solvers support only l2 penalties.

.. versionadded:: 0.19
l1 penalty with SAGA solver (allowing 'multinomial + L1)
Member

fix the '

@@ -6,16 +6,16 @@
# Tom Dupre la Tour <tom.dupre-la-tour@m4x.org>
#
Member

add author

@@ -860,6 +895,71 @@ def test_logreg_intercept_scaling_zero():
    assert_equal(clf.intercept_, 0.)


def test_logreg_l1():
    # Because liblinear penalizes the intercept and saga does not, we do
Member

we do not



def test_logreg_l1_sparse_data():
    # Because liblinear penalizes the intercept and saga does not, we do
Member

we do not

@ogrisel ogrisel changed the title [MRG+1] ENH : SAGA support for LogisticRegression (+ L1 support), Ridge [MRG+2] ENH : SAGA support for LogisticRegression (+ L1 support), Ridge Mar 27, 2017
@ogrisel
Member

ogrisel commented Mar 27, 2017

I rebased and squashed everything down to a single commit. If CI is still green, let's merge.

@ogrisel ogrisel merged commit 5147fd0 into scikit-learn:master Mar 27, 2017
@ogrisel
Member

ogrisel commented Mar 27, 2017

Merged! Thanks @arthurmensch!

@arthurmensch
Contributor Author

arthurmensch commented Mar 27, 2017 via email

@TomDLT
Member

TomDLT commented Mar 27, 2017

🍻

@GaelVaroquaux
Member

GaelVaroquaux commented Mar 28, 2017 via email

@jnothman
Member

jnothman commented Mar 28, 2017 via email

@agramfort
Member

agramfort commented Mar 28, 2017 via email

@fabianp
Member

fabianp commented Mar 28, 2017

Congrats @arthurmensch and co. I think this is a great example of development that started in scikit-learn-contrib and (with a lot of work and improvements) ended upstream.
