[MRG] Fix selection of solver in ridge_regression when solver=='auto' #13363
Conversation
@btel : I've merged the other PR, but now this one has conflicts (I think that it is not a good idea to do one PR on top of another).
Force-pushed from 0a2cfd7 to 5a53d57
@GaelVaroquaux : indeed, it wasn't a good idea to implement this PR on top of #13336. I forgot that the commits would be squashed at merge time... Lesson learned. I fixed the PR by replaying the changes on top of current master.
Can you update what's new?
sklearn/linear_model/ridge.py
Outdated
if return_intercept and solver != 'sag':
    warnings.warn("In Ridge, only 'sag' solver can currently fit the "
                  "intercept. Solver has been "
                  "automatically changed into 'sag'.")
Is this still true? I thought that ridge now supports fitting the intercept via other solvers.
Only the Ridge.fit method supports fitting the intercept with other solvers (i.e. cholesky/sparse_cg). This was already the case before, and it should be changed when the ridge_regression function is refactored, as discussed in #13336.
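To illustrate the distinction, here is a minimal sketch using the public API (treat it as an illustration only; the exact fallback behaviour of the last call is what this thread is about and varies across versions):

import numpy as np
from sklearn.linear_model import Ridge, ridge_regression

X = np.random.RandomState(0).rand(100, 3)
y = X @ np.array([1., 2., 0.1]) + 0.5

# The estimator can fit an intercept with e.g. 'cholesky' (it centers the data).
est = Ridge(alpha=1.0, solver='cholesky', fit_intercept=True).fit(X, y)

# The ridge_regression function only returns an intercept through the 'sag'
# path, so requesting it with another (or the 'auto') solver hits the code
# path discussed in this thread.
coef, intercept = ridge_regression(X, y, alpha=1.0, solver='auto',
                                   return_intercept=True)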
Should we not just raise an error in this case?
I think an error is better too.
Doesn't 'saga' also support fitting the intercept?
Since #13336 is merged, this needs to be updated.
I agree an error should be raised instead of a warning
sklearn/linear_model/ridge.py
Outdated
def _select_auto_mode():
    if return_intercept:
        # only sag supports fitting intercept directly
        return "sag"
This logic duplicates that of line 390.
This line executes when the solver argument is set to 'auto'; line 390 executes when the user sets the argument to any other value (except 'sag'). We need to distinguish between these two cases, because in one we generate a warning whereas in the other we don't.
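For clarity, a condensed, hypothetical helper showing why the two branches aren't duplicates: the 'auto' path selects a solver silently, while the explicit path warns before switching (a sketch, not the actual diff):

import warnings

def _pick_solver(solver, return_intercept):
    # hypothetical condensation of the two code paths discussed above
    if solver == 'auto':
        # path 1 (this line): silent auto-selection, no warning
        return 'sag' if return_intercept else 'cholesky'
    if return_intercept and solver != 'sag':
        # path 2 (line 390): the user asked for an incompatible solver, so warn
        warnings.warn("In Ridge, only 'sag' solver can currently fit the "
                      "intercept. Solver has been automatically changed "
                      "into 'sag'.")
        return 'sag'
    return solver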
@agramfort I added a new section in the what's new doc.
Force-pushed from f17610c to 6ae88b8
Thanks @btel
sklearn/linear_model/ridge.py
Outdated
solver = "cholesky" | ||
|
||
if solver not in ('sparse_cg', 'cholesky', 'svd', 'lsqr', 'sag', 'saga'): | ||
raise ValueError("Known solver are 'sparse_cg', 'cholesky', 'svd'" |
solver -> solvers
Or "solver must be one of ..."
sklearn/linear_model/ridge.py
Outdated
if return_intercept and solver != 'sag':
    warnings.warn("In Ridge, only 'sag' solver can currently fit the "
                  "intercept. Solver has been "
                  "automatically changed into 'sag'.")
Should we not just raise an error in this case?
Thanks @jnothman. I fixed the error message. I agree that it might be better to raise instead of automagically changing the solver, but it was the old behaviour and I didn't want to change it in this PR. However, if there is a consensus to raise, I will update this PR.
@GaelVaroquaux @agramfort could you approve/merge?
@jeremiedbb would you mind reviewing/merging this PR? It's good to merge. I talked to @GaelVaroquaux and he does not have time to have a look at it now.
for solver in ['sparse_cg', 'cholesky', 'svd', 'lsqr', 'saga']:
    with pytest.warns(UserWarning) as record:
        target = ridge_regression(X, y, 1,
you can match the warning message directly
with pytest.warns(UserWarning, match='return_intercept=True is only'):
...
done
for r in record:
    r.message.args[0].startswith("return_intercept=True is only")
you forgot an assert here. But you'd not need these lines if you do as above.
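i.e. either add the missing assert, or drop the loop entirely in favour of the match= form suggested above (a sketch against the test fragment shown here):

for r in record:
    assert r.message.args[0].startswith("return_intercept=True is only")

# or, more simply:
with pytest.warns(UserWarning, match='return_intercept=True is only'):
    ridge_regression(X, y, 1, solver=solver, return_intercept=True)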
I removed these lines
sklearn/linear_model/ridge.py
Outdated
if return_intercept:
    # only sag supports fitting intercept directly
    solver = "sag"
elif has_sw:
I thought 'saga' also supports fitting the intercept.
I tried fitting the intercept with saga, and I get a warning in test test_ridge_fit_intercept_sparse with solver='saga':
ConvergenceWarning('The max_iter was reached which means the coef_ did not converge')
so I guess it's not supported.
sklearn/linear_model/ridge.py
Outdated
    # this should be changed since all solvers support sample_weights
    solver = "cholesky"
elif sparse.issparse(X):
    solver = "sparse_cg"
I find the sequence of conditions hard to follow; the initial pattern seems easier:
if return_intercept:
    solver = "sag"
elif not sparse.issparse(X) or has_sw:
    solver = "cholesky"
else:
    solver = "sparse_cg"
If all solvers support sample weight, I'd be in favor of removing the has_sw condition.
ok, I changed the sequence and removed the has_sw option.
if return_intercept:
    coef, intercept = target
    assert_array_almost_equal(coef, true_coefs, decimal=1)
    assert_array_almost_equal(intercept, 0, decimal=1)
Please use assert_allclose instead.
done
sklearn/linear_model/ridge.py
Outdated
if return_intercept and solver != 'sag':
    warnings.warn("In Ridge, only 'sag' solver can currently fit the "
                  "intercept. Solver has been "
                  "automatically changed into 'sag'.")
I think an error is better too.
Doesn't 'saga' also support fitting the intercept?
I made a few comments. But I'm sorry I can't merge it because I'm just a contributor :)
@jeremiedbb I made the changes as you suggested. I still kept the warning, because I don't want to break people's code when they used return_intercept=True and set the solver to a different value. I tried fitting the intercept with saga and it did not work (see my comment above).
Sorry, I am a bit confused by the review process. I meant approving, not merging, the PR.
    coef, intercept = target
    assert_allclose(coef, true_coefs, atol=0.1)
    assert_allclose(intercept, 0, atol=0.1)
else:
is an absolute tol of 0.1 necessary?
It needs to be atol when comparing to 0, but 0.1 seems big for checking equality.
it's true, but the differences in the estimations are around 0.02, so I can change to atol=0.03.
The true coefs are: 1, 2, 0.1, intercept 0
.solver auto, return_intercept=True, dense, no sw: (array([0.98738028, 1.97544885, 0.0983598 ]), 0.019808458760372082)
.solver auto, return_intercept=False, dense, with sw: [0.99946651 1.98731769 0.10972413]
.solver auto, return_intercept=True, dense, with sw: (array([0.98793847, 1.97696595, 0.1002527 ]), 0.01906343205046055)
.solver auto, return_intercept=False, sparse, no sw: [0.99836423 1.98816093 0.10940093]
.solver auto, return_intercept=True, sparse, no sw: (array([0.98888637, 1.97631928, 0.10069256]), 0.016659883979934714)
.solver auto, return_intercept=False, sparse, with sw: [0.99966452 1.98818802 0.10897127]
.solver auto, return_intercept=True, sparse, with sw: (array([0.9893669 , 1.97592915, 0.09956311]), 0.017160723846287827)
.solver sparse_cg, return_intercept=False, dense, no sw: [0.9991974 1.98769984 0.10987237]
.solver sparse_cg, return_intercept=True, dense, no sw: (array([0.98920646, 1.97558791, 0.09788704]), 0.021340292109204684)
.solver sparse_cg, return_intercept=False, dense, with sw: [0.99024222 1.99174497 0.11463053]
.solver sparse_cg, return_intercept=True, dense, with sw: (array([0.98761024, 1.97533341, 0.09901903]), 0.01854073215870237)
.solver sparse_cg, return_intercept=False, sparse, no sw: [0.9994363 1.98744613 0.1091634 ]
.solver sparse_cg, return_intercept=True, sparse, no sw: (array([0.9880755 , 1.97727505, 0.09942038]), 0.01744125549431998)
.solver sparse_cg, return_intercept=False, sparse, with sw: [0.99054007 1.99162964 0.1142885 ]
.solver sparse_cg, return_intercept=True, sparse, with sw: (array([0.98772017, 1.9773033 , 0.10027304]), 0.017187277743091305)
.solver cholesky, return_intercept=False, dense, no sw: [0.99913301 1.98694126 0.10987065]
.solver cholesky, return_intercept=True, dense, no sw: (array([0.98777785, 1.97725736, 0.09918002]), 0.019112808226411305)
.solver cholesky, return_intercept=False, dense, with sw: [0.99911638 1.98733099 0.10993568]
.solver cholesky, return_intercept=True, dense, with sw: (array([0.98614477, 1.97491618, 0.09864355]), 0.01985489984301323)
.solver cholesky, return_intercept=False, sparse, no sw: [0.99941728 1.98712034 0.1096048 ]
.solver cholesky, return_intercept=True, sparse, no sw: (array([0.98829793, 1.97603815, 0.09797809]), 0.019415378825590024)
.solver cholesky, return_intercept=False, sparse, with sw: [0.99934764 1.98748946 0.10933262]
.solver cholesky, return_intercept=True, sparse, with sw: (array([0.98731514, 1.97603497, 0.09975062]), 0.018854865586171575)
.solver lsqr, return_intercept=False, dense, no sw: [0.99951604 1.98764362 0.10910706]
.solver lsqr, return_intercept=True, dense, no sw: (array([0.9870815 , 1.97716204, 0.0995867 ]), 0.015630561450455646)
.solver lsqr, return_intercept=False, dense, with sw: [0.99976968 1.98730972 0.10931853]
.solver lsqr, return_intercept=True, dense, with sw: (array([0.9884556 , 1.97708318, 0.0994495 ]), 0.020868291821179098)
.solver lsqr, return_intercept=False, sparse, no sw: [0.99895738 1.98781999 0.1097974 ]
.solver lsqr, return_intercept=True, sparse, no sw: (array([0.98793535, 1.97810114, 0.10162843]), 0.01745574357981361)
.solver lsqr, return_intercept=False, sparse, with sw: [0.99888735 1.98760483 0.11008256]
.solver lsqr, return_intercept=True, sparse, with sw: (array([0.98908756, 1.97644754, 0.09898548]), 0.0179269121290331)
.solver sag, return_intercept=False, dense, no sw: [0.99851148 1.98589838 0.10733857]
.solver sag, return_intercept=True, dense, no sw: (array([0.98735441, 1.97573369, 0.09955717]), 0.019699491413496736)
.solver sag, return_intercept=False, dense, with sw: [1.00098593 1.98703252 0.10881227]
.solver sag, return_intercept=True, dense, with sw: (array([0.98860356, 1.97635123, 0.09842113]), 0.01970872473553323)
.solver sag, return_intercept=False, sparse, no sw: [0.99926502 1.9880752 0.11024386]
.solver sag, return_intercept=True, sparse, no sw: (array([0.98690006, 1.97720689, 0.09983899]), 0.018324191901867053)
.solver sag, return_intercept=False, sparse, with sw: [1.00034609 1.98534772 0.10857017]
.solver sag, return_intercept=True, sparse, with sw: (array([0.98836367, 1.97905487, 0.09644623]), 0.018058154068734605)
.solver saga, return_intercept=False, dense, no sw: [0.999297 1.98728781 0.10987951]
.solver saga, return_intercept=True, dense, no sw: (array([0.98715784, 1.97528129, 0.09813245]), 0.02077689395931017)
.solver saga, return_intercept=False, dense, with sw: [0.99869864 1.98810477 0.10949984]
.solver saga, return_intercept=True, dense, with sw: (array([0.98897455, 1.97488114, 0.09956767]), 0.021863847921240975)
.solver saga, return_intercept=False, sparse, no sw: [0.99885491 1.98692853 0.11079414]
.solver saga, return_intercept=True, sparse, no sw: (array([0.98841551, 1.97610063, 0.10014244]), 0.01736938519414459)
.solver saga, return_intercept=False, sparse, with sw: [0.99622644 1.99183347 0.10405478]
.solver saga, return_intercept=True, sparse, with sw: (array([0.98847833, 1.97703492, 0.09749779]), 0.018143914964244622)
so I can change to atol=0.03
I guess the default tol of 1e-3 might be the reason for this poor comparison. Could you try with a zero tol?
@btel I don't understand why you cannot use a lower tolerance as you have no noise added to data.
What do you mean? The data are randomly generated, so I don't get exactly the coefficients I put in. I can freeze the seed and test against the coefficients that I get after a test run, but I might still get some small differences between the solvers.
What @agramfort meant is that you could pass a smaller tolerance to ridge_regression (as pushed in 604f7d1). Since your data isn't noisy, the solvers should converge just fine.
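Something along these lines inside the test (a sketch; the alpha, tol and atol values here are illustrative and mirror what was pushed later):

# pass an explicit small tol so the iterative solvers converge tightly
out = ridge_regression(X_testing, y, alpha=1e-3, solver=solver, tol=1e-6,
                       return_intercept=True)
coef, intercept = out
assert_allclose(coef, true_coefs, rtol=0, atol=1e-4)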
@jeremiedbb is it good to approve? Let me know if you have extra comments.
I just made one last comment. Otherwise LGTM.
    coef, intercept = target
    assert_allclose(coef, true_coefs, atol=0.1)
    assert_allclose(intercept, 0, atol=0.1)
else:
I guess the default tol of 1e-3 might be the reason for this poor comparison. Could you try with a zero tol?
Force-pushed from f74d2f9 to 97a5326
@jeremiedbb @jnothman I changed rtol to 0 and fixed the merge conflict. Should be good to merge now.
Formally we need another core dev to review.
@agramfort @GaelVaroquaux would you mind approving?
# test excludes 'svd' solver because it raises exception for sparse inputs

X = np.random.rand(1000, 3)
please use a fixed random_state
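e.g. (one possible way, assuming the test keeps building X with NumPy):

rng = np.random.RandomState(42)
X = rng.rand(1000, 3)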
done
doc/whats_new/v0.21.rst
Outdated
@@ -340,6 +340,16 @@ Support for Python 3.4 and below has been officially dropped.
  deterministic when trained in a multi-class setting on several threads.
  :issue:`13422` by :user:`Clément Doumouro <ClemDoum>`.

- |Fix| Fixed bug in :func:`linear_model.ridge.ridge_regression` that
This is also a bugfix for RidgeClassifier at least.
True, I updated the entry
sklearn/linear_model/ridge.py
Outdated
if return_intercept and solver != 'sag':
    warnings.warn("In Ridge, only 'sag' solver can currently fit the "
                  "intercept. Solver has been "
                  "automatically changed into 'sag'.")
Since #13336 is merged, this needs to be updated.
@NicolasHug for the moment only the _BaseRidge estimator implements fit_intercept for sparse_cg. This still does not work for the ridge_regression function, which only supports the intercept in solvers that fit it directly (sag). This will be changed in the future, but requires some refactoring; see comment: #13336 (comment)
ok I pushed a stricter test. @btel by reducing the tol and the alpha I managed to get a much lower atol that passes the tests for me. If CIs are green it's good to go from my end.
X_testing = arr_type(X)

alpha, atol, tol = 1e-3, 1e-4, 1e-6
target = ridge_regression(X_testing, y, alpha=alpha,
not a big fan of target, which is usually what we use for y. Maybe out (not great either).
changed to out
X, y = make_regression(n_samples=1000, n_features=2, n_informative=2,
                       bias=10., random_state=42)

for solver in ['sparse_cg', 'cholesky', 'svd', 'lsqr', 'saga']:
this could be parametrized but OK
sklearn/linear_model/ridge.py
Outdated
if return_intercept and solver != 'sag':
    warnings.warn("In Ridge, only 'sag' solver can currently fit the "
                  "intercept. Solver has been "
                  "automatically changed into 'sag'.")
I agree an error should be raised instead of a warning
if return_intercept:
    coef, intercept = target
    assert_allclose(coef, true_coefs, rtol=0, atol=atol)
    assert_allclose(intercept, intercept, rtol=0, atol=atol)
Please change the name of the true intercept to true_intercept, because this doesn't check anything.
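i.e. something like this (a sketch, assuming the test exposes the ground truth as true_intercept):

assert_allclose(intercept, true_intercept, rtol=0, atol=atol)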
True! Thanks!
    coef, intercept = target
    assert_allclose(coef, true_coefs, atol=0.1)
    assert_allclose(intercept, 0, atol=0.1)
else:
What @agramfort meant is that you could pass a smaller tolerance to ridge_regression (as pushed in 604f7d1). Since your data isn't noisy, the solvers should converge just fine.
(covered by test_ridge_regression_check_arguments_validity)
@NicolasHug I addressed the points that you raised. Thanks for reviewing! I also changed the warning to an exception, since it seems that everyone was in favour of it.
Thanks @btel!
Yay! Thanks @btel.
… classes (scikit-learn#13363)" This reverts commit e1f66b8.
Continues the work started in #13336.
Must be reviewed as an addition to #13336 and merged after (merged).
Fixes #13362.
What does this implement/fix? Explain your changes.
The solver that was selected when the solver argument of ridge_regression was set to 'auto' was ambiguous and sometimes even incorrect (see #13362). This PR tries to make it more explicit.

Any other comments?
Ping: @agramfort
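For reference, a minimal sketch of the kind of 'auto' selection discussed in this thread (an illustration based on the conversation above, not the exact merged code; the helper name and error wording are assumptions):

import warnings
from scipy import sparse


def _resolve_ridge_solver(X, solver, return_intercept):
    """Hypothetical helper mirroring the selection logic discussed above."""
    if solver == 'auto':
        if return_intercept:
            # only 'sag' can fit the intercept directly in ridge_regression
            return 'sag'
        if sparse.issparse(X):
            return 'sparse_cg'
        return 'cholesky'
    if solver not in ('sparse_cg', 'cholesky', 'svd', 'lsqr', 'sag', 'saga'):
        raise ValueError("Known solvers are 'sparse_cg', 'cholesky', 'svd', "
                         "'lsqr', 'sag' or 'saga'. Got %s." % solver)
    if return_intercept and solver != 'sag':
        # the thread converged on raising here rather than silently
        # switching the solver and warning
        raise ValueError("In Ridge, only 'sag' solver can directly fit the "
                         "intercept. Please change solver to 'sag' or set "
                         "return_intercept=False.")
    return solver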