Skip to content

TST suppress convergence warnings in coordinate descent #20144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 6, 2024
4 changes: 2 additions & 2 deletions sklearn/linear_model/_cd_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def enet_coordinate_descent_multi_task(
bint random=0
):
"""Cython version of the coordinate descent algorithm
for Elastic-Net mult-task regression
for Elastic-Net multi-task regression

We minimize

Expand Down Expand Up @@ -947,7 +947,7 @@ def enet_coordinate_descent_multi_task(
+ 0.5 * l2_reg * (1 + const ** 2) * (w_norm ** 2)
)

if gap < tol:
if gap <= tol:
# return if we reached desired tolerance
break
else:
Expand Down
24 changes: 16 additions & 8 deletions sklearn/linear_model/tests/test_coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,17 @@ def test_lasso_zero():
# Check that the lasso can handle zero data without crashing
X = [[0], [0], [0]]
y = [0, 0, 0]
clf = Lasso(alpha=0.1).fit(X, y)
# _cd_fast.pyx tests for gap < tol, but here we get 0.0 < 0.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I do not understand why it would be the case in this test. By default the tolerance is 1e-4:

>>> from sklearn.linear_model import Lasso
>>> Lasso().tol
0.0001

This test does not override the default. The same is true for MultiTaskLasso.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok I understand, this is because we rescale the tolerance by a data dependent quantity which is 0 in this case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is indeed sub-optimal - y == 0 are valid options. What about the following change to keep tol as is if it would become 0 (similar code for all the different coordinate descent function instances):

--- a/sklearn/linear_model/_cd_fast.pyx
+++ b/sklearn/linear_model/_cd_fast.pyx
@@ -672,6 +672,7 @@ def enet_coordinate_descent_multi_task(
     cdef floating W_ii_abs_max
     cdef floating gap = tol + 1.0
     cdef floating d_w_tol = tol
+    cdef floating temp_tol = tol
     cdef floating R_norm
     cdef floating w_norm
     cdef floating ry_sum
@@ -704,7 +705,9 @@ def enet_coordinate_descent_multi_task(
                           &R[0, jj], 1)
 
         # tol = tol * linalg.norm(Y, ord='fro') ** 2
-        tol = tol * _nrm2(n_samples * n_tasks, Y_ptr, 1) ** 2
+        temp_tol = tol * _nrm2(n_samples * n_tasks, Y_ptr, 1) ** 2
+        if temp_tol > 0:
+            tol = temp_tol
 
         for n_iter in range(max_iter):
             w_max = 0.0
@@ -819,7 +822,7 @@ def enet_coordinate_descent_multi_task(
                 gap += l1_reg * l21_norm - const * ry_sum + \
                      0.5 * l2_reg * (1 + const ** 2) * (w_norm ** 2)
 
-                if gap <= tol:
+                if gap < tol:
                     # return if we reached desired tolerance
                     break
         else:

Then we could also return to the previous (gap < tol).

# should probably be changed to gap <= tol ?
with ignore_warnings(category=ConvergenceWarning):
clf = Lasso(alpha=0.1).fit(X, y)
pred = clf.predict([[1], [2], [3]])
assert_array_almost_equal(clf.coef_, [0])
assert_array_almost_equal(pred, [0, 0, 0])
assert_almost_equal(clf.dual_gap_, 0)


@pytest.mark.filterwarnings("ignore::sklearn.exceptions.ConvergenceWarning")
def test_enet_nonfinite_params():
# Check ElasticNet throws ValueError when dealing with non-finite parameter
# values
Expand Down Expand Up @@ -358,6 +362,7 @@ def _scale_alpha_inplace(estimator, n_samples):
estimator.set_params(alpha=alpha)


@pytest.mark.filterwarnings("ignore::sklearn.exceptions.ConvergenceWarning")
@pytest.mark.parametrize(
"LinearModel, params",
[
Expand Down Expand Up @@ -593,14 +598,16 @@ def test_uniform_targets():
for model in models_single_task:
for y_values in (0, 5):
y1.fill(y_values)
assert_array_equal(model.fit(X_train, y1).predict(X_test), y1)
with ignore_warnings(category=ConvergenceWarning):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we ignore the warning all the time, we could decorate the test function as well.

assert_array_equal(model.fit(X_train, y1).predict(X_test), y1)
assert_array_equal(model.alphas_, [np.finfo(float).resolution] * 3)

for model in models_multi_task:
for y_values in (0, 5):
y2[:, 0].fill(y_values)
y2[:, 1].fill(2 * y_values)
assert_array_equal(model.fit(X_train, y2).predict(X_test), y2)
with ignore_warnings(category=ConvergenceWarning):
assert_array_equal(model.fit(X_train, y2).predict(X_test), y2)
assert_array_equal(model.alphas_, [np.finfo(float).resolution] * 3)


Expand Down Expand Up @@ -686,7 +693,7 @@ def test_multitask_enet_and_lasso_cv():

X, y, _, _ = build_dataset(n_targets=3)
clf = MultiTaskElasticNetCV(
n_alphas=10, eps=1e-3, max_iter=100, l1_ratio=[0.3, 0.5], tol=1e-3, cv=3
n_alphas=10, eps=1e-3, max_iter=200, l1_ratio=[0.3, 0.5], tol=1e-3, cv=3
)
clf.fit(X, y)
assert 0.5 == clf.l1_ratio_
Expand All @@ -696,7 +703,7 @@ def test_multitask_enet_and_lasso_cv():
assert (2, 10) == clf.alphas_.shape

X, y, _, _ = build_dataset(n_targets=3)
clf = MultiTaskLassoCV(n_alphas=10, eps=1e-3, max_iter=100, tol=1e-3, cv=3)
clf = MultiTaskLassoCV(n_alphas=10, eps=1e-3, max_iter=500, tol=1e-3, cv=3)
clf.fit(X, y)
assert (3, X.shape[1]) == clf.coef_.shape
assert (3,) == clf.intercept_.shape
Expand Down Expand Up @@ -945,7 +952,8 @@ def test_check_input_false():
# dtype is still cast in _preprocess_data to X's dtype. So the test should
# pass anyway
X = check_array(X, order="F", dtype="float32")
clf.fit(X, y, check_input=False)
with ignore_warnings(category=ConvergenceWarning):
clf.fit(X, y, check_input=False)
# With no input checking, providing X in C order should result in false
# computation
X = check_array(X, order="C", dtype="float64")
Expand Down Expand Up @@ -1061,6 +1069,7 @@ def test_enet_float_precision():
)


@pytest.mark.filterwarnings("ignore::sklearn.exceptions.ConvergenceWarning")
def test_enet_l1_ratio():
# Test that an error message is raised if an estimator that
# uses _alpha_grid is called with l1_ratio=0
Expand Down Expand Up @@ -1129,8 +1138,6 @@ def test_warm_start_multitask_lasso():
[
(Lasso, 1, dict(precompute=True)),
(Lasso, 1, dict(precompute=False)),
(MultiTaskLasso, 2, dict()),
(MultiTaskLasso, 2, dict()),
],
)
def test_enet_coordinate_descent(klass, n_classes, kwargs):
Expand Down Expand Up @@ -1474,6 +1481,7 @@ def test_enet_sample_weight_does_not_overwrite_sample_weight(check_input):
assert_array_equal(sample_weight, sample_weight_1_25)


@pytest.mark.filterwarnings("ignore::sklearn.exceptions.ConvergenceWarning")
@pytest.mark.parametrize("ridge_alpha", [1e-1, 1.0, 1e6])
def test_enet_ridge_consistency(ridge_alpha):
# Check that ElasticNet(l1_ratio=0) converges to the same solution as Ridge
Expand Down
Loading