Commit 33a733e

ENH/FIX stopping criterion for coordinate descent gap <= tol (scikit-learn#31906)
1 parent 3c74809 commit 33a733e

7 files changed: +147 -136 lines
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+- :class:`linear_model.ElasticNet`, :class:`linear_model.ElasticNetCV`,
+  :class:`linear_model.Lasso`, :class:`linear_model.LassoCV`,
+  :class:`MultiTaskElasticNet`, :class:`MultiTaskElasticNetCV`,
+  :class:`MultiTaskLasso`, :class:`MultiTaskLassoCV`, as well as
+  :func:`linear_model.enet_path` and :func:`linear_model.lasso_path`
+  now use `dual gap <= tol` instead of `dual gap < tol` as stopping criterion.
+  The resulting coefficients might differ from those of previous versions of
+  scikit-learn in rare cases.
+  By :user:`Christian Lorentzen <lorentzenchr>`.
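To see what the documented criterion means in user-facing terms, here is a minimal sketch (not part of the commit): it fits a Lasso model and computes the threshold the updated Notes sections refer to, `tol * ||y||_2^2 / n_samples`. That `dual_gap_` is reported on this per-sample scale is my assumption, not something stated in the diff.

    # Minimal sketch, not from the commit: toy Lasso fit and the documented gap threshold.
    import numpy as np
    from sklearn.linear_model import Lasso

    rng = np.random.RandomState(0)
    X = rng.randn(100, 20)
    y = rng.randn(100)

    reg = Lasso(alpha=0.1, tol=1e-6).fit(X, y)

    # Threshold described in the Notes: tol * ||y||_2^2 / n_samples.
    gap_threshold = reg.tol * np.dot(y, y) / y.shape[0]

    # Assumption: dual_gap_ is stored per sample, so at convergence it should be
    # at most gap_threshold (now inclusive, after this change).
    print(reg.n_iter_, reg.dual_gap_, gap_threshold)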

sklearn/linear_model/_cd_fast.pyx

Lines changed: 7 additions & 7 deletions
@@ -259,7 +259,7 @@ def enet_coordinate_descent(
 
             if (
                 w_max == 0.0
-                or d_w_max / w_max < d_w_tol
+                or d_w_max / w_max <= d_w_tol
                 or n_iter == max_iter - 1
             ):
                 # the biggest coordinate update of this iteration was smaller
@@ -298,7 +298,7 @@ def enet_coordinate_descent(
                     - const_ * _dot(n_samples, &R[0], 1, &y[0], 1)  # np.dot(R.T, y)
                     + 0.5 * beta * (1 + const_ ** 2) * (w_norm2))
 
-                if gap < tol:
+                if gap <= tol:
                     # return if we reached desired tolerance
                     break
 
@@ -539,7 +539,7 @@ def sparse_enet_coordinate_descent(
 
                 w_max = fmax(w_max, fabs(w[ii]))
 
-            if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
+            if w_max == 0.0 or d_w_max / w_max <= d_w_tol or n_iter == max_iter - 1:
                 # the biggest coordinate update of this iteration was smaller than
                 # the tolerance: check the duality gap as ultimate stopping
                 # criterion
@@ -586,7 +586,7 @@ def sparse_enet_coordinate_descent(
                     - const_ * _dot(n_samples, &R[0], 1, &y[0], 1)  # np.dot(R.T, y)
                     + 0.5 * beta * (1 + const_ ** 2) * w_norm2)
 
-                if gap < tol:
+                if gap <= tol:
                     # return if we reached desired tolerance
                     break
 
@@ -714,7 +714,7 @@ def enet_coordinate_descent_gram(
                 if fabs(w[ii]) > w_max:
                     w_max = fabs(w[ii])
 
-            if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
+            if w_max == 0.0 or d_w_max / w_max <= d_w_tol or n_iter == max_iter - 1:
                 # the biggest coordinate update of this iteration was smaller than
                 # the tolerance: check the duality gap as ultimate stopping
                 # criterion
@@ -752,7 +752,7 @@ def enet_coordinate_descent_gram(
                     + 0.5 * beta * (1 + const_ ** 2) * w_norm2
                 )
 
-                if gap < tol:
+                if gap <= tol:
                     # return if we reached desired tolerance
                     break
 
@@ -931,7 +931,7 @@ def enet_coordinate_descent_multi_task(
                 if W_ii_abs_max > w_max:
                     w_max = W_ii_abs_max
 
-            if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
+            if w_max == 0.0 or d_w_max / w_max <= d_w_tol or n_iter == max_iter - 1:
                 # the biggest coordinate update of this iteration was smaller than
                 # the tolerance: check the duality gap as ultimate stopping
                 # criterion
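
All of the changed lines implement the same two-stage test. The following plain-Python sketch (simplified and not taken from the commit) shows that logic; the real solvers compute the duality gap in Cython and scale `tol` beforehand, which the sketch takes as given via `dual_gap` and `gap_tol`:

    import numpy as np

    def check_convergence(w_old, w_new, dual_gap, d_w_tol, gap_tol, last_iter=False):
        """Simplified sketch of the stopping test changed in this commit."""
        d_w_max = np.max(np.abs(w_new - w_old))  # largest coordinate update
        w_max = np.max(np.abs(w_new))            # largest coefficient magnitude
        # Stage 1: cheap check on the coordinate updates (now inclusive: <=).
        if w_max == 0.0 or d_w_max / w_max <= d_w_tol or last_iter:
            # Stage 2: the duality gap is the ultimate criterion (also inclusive).
            return dual_gap <= gap_tol
        return False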

sklearn/linear_model/_coordinate_descent.py

Lines changed: 31 additions & 39 deletions
@@ -786,10 +786,9 @@ class ElasticNet(MultiOutputMixin, RegressorMixin, LinearModel):
         If ``True``, X will be copied; else, it may be overwritten.
 
     tol : float, default=1e-4
-        The tolerance for the optimization: if the updates are
-        smaller than ``tol``, the optimization code checks the
-        dual gap for optimality and continues until it is smaller
-        than ``tol``, see Notes below.
+        The tolerance for the optimization: if the updates are smaller or equal to
+        ``tol``, the optimization code checks the dual gap for optimality and continues
+        until it is smaller or equal to ``tol``, see Notes below.
 
     warm_start : bool, default=False
         When set to ``True``, reuse the solution of the previous call to fit as
@@ -857,9 +856,9 @@ class ElasticNet(MultiOutputMixin, RegressorMixin, LinearModel):
 
     The precise stopping criteria based on `tol` are the following: First, check that
     the maximum coordinate update, i.e. :math:`\\max_j |w_j^{new} - w_j^{old}|`
-    is smaller than `tol` times the maximum absolute coefficient, :math:`\\max_j |w_j|`.
-    If so, then additionally check whether the dual gap is smaller than `tol` times
-    :math:`||y||_2^2 / n_{\text{samples}}`.
+    is smaller or equal to `tol` times the maximum absolute coefficient,
+    :math:`\\max_j |w_j|`. If so, then additionally check whether the dual gap is
+    smaller or equal to `tol` times :math:`||y||_2^2 / n_{\\text{samples}}`.
 
     Examples
     --------
@@ -1205,13 +1204,12 @@ class Lasso(ElasticNet):
         The maximum number of iterations.
 
     tol : float, default=1e-4
-        The tolerance for the optimization: if the updates are
-        smaller than ``tol``, the optimization code checks the
-        dual gap for optimality and continues until it is smaller
-        than ``tol``, see Notes below.
+        The tolerance for the optimization: if the updates are smaller or equal to
+        ``tol``, the optimization code checks the dual gap for optimality and continues
+        until it is smaller or equal to ``tol``, see Notes below.
 
     warm_start : bool, default=False
-        When set to True, reuse the solution of the previous call to fit as
+        When set to ``True``, reuse the solution of the previous call to fit as
         initialization, otherwise, just erase the previous solution.
         See :term:`the Glossary <warm_start>`.
 
@@ -1285,9 +1283,9 @@ class Lasso(ElasticNet):
 
     The precise stopping criteria based on `tol` are the following: First, check that
     the maximum coordinate update, i.e. :math:`\\max_j |w_j^{new} - w_j^{old}|`
-    is smaller than `tol` times the maximum absolute coefficient, :math:`\\max_j |w_j|`.
-    If so, then additionally check whether the dual gap is smaller than `tol` times
-    :math:`||y||_2^2 / n_{\\text{samples}}`.
+    is smaller or equal to `tol` times the maximum absolute coefficient,
+    :math:`\\max_j |w_j|`. If so, then additionally check whether the dual gap is
+    smaller or equal to `tol` times :math:`||y||_2^2 / n_{\\text{samples}}`.
 
     The target can be a 2-dimensional array, resulting in the optimization of the
     following objective::
@@ -1981,10 +1979,9 @@ class LassoCV(RegressorMixin, LinearModelCV):
         The maximum number of iterations.
 
     tol : float, default=1e-4
-        The tolerance for the optimization: if the updates are
-        smaller than ``tol``, the optimization code checks the
-        dual gap for optimality and continues until it is smaller
-        than ``tol``.
+        The tolerance for the optimization: if the updates are smaller or equal to
+        ``tol``, the optimization code checks the dual gap for optimality and continues
+        until it is smaller or equal to ``tol``.
 
     copy_X : bool, default=True
         If ``True``, X will be copied; else, it may be overwritten.
@@ -2252,10 +2249,9 @@ class ElasticNetCV(RegressorMixin, LinearModelCV):
         The maximum number of iterations.
 
     tol : float, default=1e-4
-        The tolerance for the optimization: if the updates are
-        smaller than ``tol``, the optimization code checks the
-        dual gap for optimality and continues until it is smaller
-        than ``tol``.
+        The tolerance for the optimization: if the updates are smaller or equal to
+        ``tol``, the optimization code checks the dual gap for optimality and continues
+        until it is smaller or equal to ``tol``.
 
     cv : int, cross-validation generator or iterable, default=None
         Determines the cross-validation splitting strategy.
@@ -2525,10 +2521,9 @@ class MultiTaskElasticNet(Lasso):
         The maximum number of iterations.
 
     tol : float, default=1e-4
-        The tolerance for the optimization: if the updates are
-        smaller than ``tol``, the optimization code checks the
-        dual gap for optimality and continues until it is smaller
-        than ``tol``.
+        The tolerance for the optimization: if the updates are smaller or equal to
+        ``tol``, the optimization code checks the dual gap for optimality and continues
+        until it is smaller or equal to ``tol``.
 
     warm_start : bool, default=False
         When set to ``True``, reuse the solution of the previous call to fit as
@@ -2770,10 +2765,9 @@ class MultiTaskLasso(MultiTaskElasticNet):
         The maximum number of iterations.
 
     tol : float, default=1e-4
-        The tolerance for the optimization: if the updates are
-        smaller than ``tol``, the optimization code checks the
-        dual gap for optimality and continues until it is smaller
-        than ``tol``.
+        The tolerance for the optimization: if the updates are smaller or equal to
+        ``tol``, the optimization code checks the dual gap for optimality and continues
+        until it is smaller or equal to ``tol``.
 
     warm_start : bool, default=False
         When set to ``True``, reuse the solution of the previous call to fit as
@@ -2949,10 +2943,9 @@ class MultiTaskElasticNetCV(RegressorMixin, LinearModelCV):
         The maximum number of iterations.
 
     tol : float, default=1e-4
-        The tolerance for the optimization: if the updates are
-        smaller than ``tol``, the optimization code checks the
-        dual gap for optimality and continues until it is smaller
-        than ``tol``.
+        The tolerance for the optimization: if the updates are smaller or equal to
+        ``tol``, the optimization code checks the dual gap for optimality and continues
+        until it is smaller or equal to ``tol``.
 
     cv : int, cross-validation generator or iterable, default=None
         Determines the cross-validation splitting strategy.
@@ -3205,10 +3198,9 @@ class MultiTaskLassoCV(RegressorMixin, LinearModelCV):
        The maximum number of iterations.
 
     tol : float, default=1e-4
-        The tolerance for the optimization: if the updates are
-        smaller than ``tol``, the optimization code checks the
-        dual gap for optimality and continues until it is smaller
-        than ``tol``.
+        The tolerance for the optimization: if the updates are smaller or equal to
+        ``tol``, the optimization code checks the dual gap for optimality and continues
+        until it is smaller or equal to ``tol``.
 
     copy_X : bool, default=True
         If ``True``, X will be copied; else, it may be overwritten.
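
As a worked numeric reading of these Notes, here is a small made-up example (not from the commit) computing both thresholds the docstrings describe: the coordinate-update check against `tol` times the largest coefficient, and the dual-gap check against `tol * ||y||_2^2 / n_samples`:

    import numpy as np

    # Toy values, purely illustrative.
    tol = 1e-4
    w = np.array([0.0, 2.5, -1.0])          # current coefficients
    y = np.array([1.0, -2.0, 0.5, 3.0])     # target with 4 samples

    # First check: max coordinate update vs. tol * max_j |w_j|.
    update_threshold = tol * np.max(np.abs(w))        # 2.5e-4
    # Second check: dual gap vs. tol * ||y||_2^2 / n_samples.
    gap_threshold = tol * np.dot(y, y) / y.shape[0]   # 3.5625e-4
    print(update_threshold, gap_threshold)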

sklearn/linear_model/tests/test_common.py

Lines changed: 65 additions & 13 deletions
@@ -43,9 +43,11 @@
     TheilSenRegressor,
     TweedieRegressor,
 )
-from sklearn.preprocessing import MinMaxScaler
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import MinMaxScaler, StandardScaler
 from sklearn.svm import LinearSVC, LinearSVR
-from sklearn.utils._testing import set_random_state
+from sklearn.utils._testing import assert_allclose, set_random_state
+from sklearn.utils.fixes import CSR_CONTAINERS
 
 
 # Note: GammaRegressor() and TweedieRegressor(power != 1) have a non-canonical link.
@@ -161,6 +163,7 @@ def test_balance_property(model, with_sample_weight, global_random_seed):
 
 @pytest.mark.filterwarnings("ignore:The default of 'normalize'")
 @pytest.mark.filterwarnings("ignore:lbfgs failed to converge")
+@pytest.mark.filterwarnings("ignore:A column-vector y was passed when a 1d array.*")
 @pytest.mark.parametrize(
     "Regressor",
     [
@@ -207,28 +210,77 @@ def test_linear_model_regressor_coef_shape(Regressor, ndim):
 
 
 @pytest.mark.parametrize(
-    "Classifier",
+    ["Classifier", "params"],
     [
-        LinearSVC,
-        LogisticRegression,
-        LogisticRegressionCV,
-        PassiveAggressiveClassifier,
-        Perceptron,
-        RidgeClassifier,
-        RidgeClassifierCV,
-        SGDClassifier,
+        (LinearSVC, {}),
+        (LogisticRegression, {}),
+        (LogisticRegressionCV, {"solver": "newton-cholesky"}),
+        (PassiveAggressiveClassifier, {}),
+        (Perceptron, {}),
+        (RidgeClassifier, {}),
+        (RidgeClassifierCV, {}),
+        (SGDClassifier, {}),
     ],
 )
 @pytest.mark.parametrize("n_classes", [2, 3])
-def test_linear_model_classifier_coef_shape(Classifier, n_classes):
+def test_linear_model_classifier_coef_shape(Classifier, params, n_classes):
     if Classifier in (RidgeClassifier, RidgeClassifierCV):
         pytest.xfail(f"{Classifier} does not follow `coef_` shape contract!")
 
     X, y = make_classification(n_informative=10, n_classes=n_classes, random_state=0)
     n_features = X.shape[1]
 
-    classifier = Classifier()
+    classifier = Classifier(**params)
     set_random_state(classifier)
     classifier.fit(X, y)
     expected_shape = (1, n_features) if n_classes == 2 else (n_classes, n_features)
     assert classifier.coef_.shape == expected_shape
+
+
+@pytest.mark.parametrize(
+    "LinearModel, params",
+    [
+        (Lasso, {"tol": 1e-15, "alpha": 0.01}),
+        (LassoCV, {"tol": 1e-15}),
+        (ElasticNetCV, {"tol": 1e-15}),
+        (RidgeClassifier, {"solver": "sparse_cg", "alpha": 0.1}),
+        (ElasticNet, {"tol": 1e-15, "l1_ratio": 1, "alpha": 0.01}),
+        (ElasticNet, {"tol": 1e-15, "l1_ratio": 1e-5, "alpha": 0.01}),
+        (Ridge, {"solver": "sparse_cg", "tol": 1e-12, "alpha": 0.1}),
+        (LinearRegression, {}),
+        (RidgeCV, {}),
+        (RidgeClassifierCV, {}),
+    ],
+)
+@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
+def test_model_pipeline_same_dense_and_sparse(LinearModel, params, csr_container):
+    """Test that sparse and dense linear models give same results.
+
+    Models use a preprocessing pipeline with a StandardScaler.
+    """
+    model_dense = make_pipeline(StandardScaler(with_mean=False), LinearModel(**params))
+
+    model_sparse = make_pipeline(StandardScaler(with_mean=False), LinearModel(**params))
+
+    # prepare the data
+    rng = np.random.RandomState(0)
+    n_samples = 100
+    n_features = 2
+    X = rng.randn(n_samples, n_features)
+    X[X < 0.1] = 0.0
+
+    X_sparse = csr_container(X)
+    y = rng.rand(n_samples)
+
+    if is_classifier(model_dense):
+        y = np.sign(y)
+
+    model_dense.fit(X, y)
+    model_sparse.fit(X_sparse, y)
+
+    assert_allclose(model_sparse[1].coef_, model_dense[1].coef_, atol=1e-16)
+    y_pred_dense = model_dense.predict(X)
+    y_pred_sparse = model_sparse.predict(X_sparse)
+    assert_allclose(y_pred_dense, y_pred_sparse)
+
+    assert_allclose(model_dense[1].intercept_, model_sparse[1].intercept_)
