From aed378ada885774395f1448b9e64c917c4adf7d7 Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Sat, 9 Aug 2025 12:55:04 +0200
Subject: [PATCH 1/7] ENH improve stopping criterion of coordinate descent

* gap <= tol instead of gap < tol

---
 sklearn/linear_model/_cd_fast.pyx             | 14 +++++-----
 sklearn/linear_model/_coordinate_descent.py   | 12 ++++----
 .../tests/test_coordinate_descent.py          | 28 +++++++++++--------
 .../tests/test_sparse_coordinate_descent.py   | 27 ++++++++++--------
 4 files changed, 45 insertions(+), 36 deletions(-)

diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx
index 3956a59d91b7f..3f15a2c8d6153 100644
--- a/sklearn/linear_model/_cd_fast.pyx
+++ b/sklearn/linear_model/_cd_fast.pyx
@@ -219,7 +219,7 @@ def enet_coordinate_descent(
 
             if (
                 w_max == 0.0
-                or d_w_max / w_max < d_w_tol
+                or d_w_max / w_max <= d_w_tol
                 or n_iter == max_iter - 1
             ):
                 # the biggest coordinate update of this iteration was smaller
@@ -258,7 +258,7 @@ def enet_coordinate_descent(
                     - const_ * _dot(n_samples, &R[0], 1, &y[0], 1)  # np.dot(R.T, y)
                     + 0.5 * beta * (1 + const_ ** 2) * (w_norm2))
 
-                if gap < tol:
+                if gap <= tol:
                     # return if we reached desired tolerance
                     break
 
@@ -499,7 +499,7 @@ def sparse_enet_coordinate_descent(
 
                 w_max = fmax(w_max, fabs(w[ii]))
 
-            if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
+            if w_max == 0.0 or d_w_max / w_max <= d_w_tol or n_iter == max_iter - 1:
                 # the biggest coordinate update of this iteration was smaller than
                 # the tolerance: check the duality gap as ultimate stopping
                 # criterion
@@ -546,7 +546,7 @@ def sparse_enet_coordinate_descent(
                     - const_ * _dot(n_samples, &R[0], 1, &y[0], 1)  # np.dot(R.T, y)
                     + 0.5 * beta * (1 + const_ ** 2) * w_norm2)
 
-                if gap < tol:
+                if gap <= tol:
                     # return if we reached desired tolerance
                     break
 
@@ -674,7 +674,7 @@ def enet_coordinate_descent_gram(
                 if fabs(w[ii]) > w_max:
                     w_max = fabs(w[ii])
 
-            if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
+            if w_max == 0.0 or d_w_max / w_max <= d_w_tol or n_iter == max_iter - 1:
                 # the biggest coordinate update of this iteration was smaller than
                 # the tolerance: check the duality gap as ultimate stopping
                 # criterion
@@ -712,7 +712,7 @@ def enet_coordinate_descent_gram(
                     + 0.5 * beta * (1 + const_ ** 2) * w_norm2
                 )
 
-                if gap < tol:
+                if gap <= tol:
                     # return if we reached desired tolerance
                     break
 
@@ -891,7 +891,7 @@ def enet_coordinate_descent_multi_task(
                 if W_ii_abs_max > w_max:
                     w_max = W_ii_abs_max
 
-            if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
+            if w_max == 0.0 or d_w_max / w_max <= d_w_tol or n_iter == max_iter - 1:
                 # the biggest coordinate update of this iteration was smaller than
                 # the tolerance: check the duality gap as ultimate stopping
                 # criterion
diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py
index 0db90c7b21b02..3245e240f7722 100644
--- a/sklearn/linear_model/_coordinate_descent.py
+++ b/sklearn/linear_model/_coordinate_descent.py
@@ -857,9 +857,9 @@ class ElasticNet(MultiOutputMixin, RegressorMixin, LinearModel):
     The precise stopping criteria based on `tol` are the following: First, check that
     the maximum coordinate update, i.e. :math:`\\max_j |w_j^{new} - w_j^{old}|`
-    is smaller than `tol` times the maximum absolute coefficient, :math:`\\max_j |w_j|`.
-    If so, then additionally check whether the dual gap is smaller than `tol` times
-    :math:`||y||_2^2 / n_{\text{samples}}`.
+    is smaller or equal to `tol` times the maximum absolute coefficient,
+    :math:`\\max_j |w_j|`. If so, then additionally check whether the dual gap is
+    smaller or equal to `tol` times :math:`||y||_2^2 / n_{\text{samples}}`.
 
     Examples
     --------
@@ -1285,9 +1285,9 @@ class Lasso(ElasticNet):
     The precise stopping criteria based on `tol` are the following: First, check that
     the maximum coordinate update, i.e. :math:`\\max_j |w_j^{new} - w_j^{old}|`
-    is smaller than `tol` times the maximum absolute coefficient, :math:`\\max_j |w_j|`.
-    If so, then additionally check whether the dual gap is smaller than `tol` times
-    :math:`||y||_2^2 / n_{\\text{samples}}`.
+    is smaller or equal to `tol` times the maximum absolute coefficient,
+    :math:`\\max_j |w_j|`. If so, then additionally check whether the dual gap is
+    smaller or equal to `tol` times :math:`||y||_2^2 / n_{\text{samples}}`.
 
     The target can be a 2-dimensional array, resulting in the optimization of the
     following objective::
diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py
index 70226210c010d..1194348dea11c 100644
--- a/sklearn/linear_model/tests/test_coordinate_descent.py
+++ b/sklearn/linear_model/tests/test_coordinate_descent.py
@@ -94,10 +94,7 @@ def test_lasso_zero():
     # Check that the lasso can handle zero data without crashing
     X = [[0], [0], [0]]
     y = [0, 0, 0]
-    # _cd_fast.pyx tests for gap < tol, but here we get 0.0 < 0.0
-    # should probably be changed to gap <= tol ?
-    with ignore_warnings(category=ConvergenceWarning):
-        clf = Lasso(alpha=0.1).fit(X, y)
+    clf = Lasso(alpha=0.1).fit(X, y)
     pred = clf.predict([[1], [2], [3]])
     assert_array_almost_equal(clf.coef_, [0])
     assert_array_almost_equal(pred, [0, 0, 0])
@@ -105,6 +102,7 @@ def test_lasso_zero():
 
 
 @pytest.mark.filterwarnings("ignore::sklearn.exceptions.ConvergenceWarning")
+@pytest.mark.filterwarnings("ignore::RuntimeWarning")  # overflow and similar
 def test_enet_nonfinite_params():
     # Check ElasticNet throws ValueError when dealing with non-finite parameter
     # values
@@ -521,6 +519,7 @@ def test_warm_start():
     assert_array_almost_equal(clf2.coef_, clf.coef_)
 
 
+@pytest.mark.filterwarnings("ignore:.*with no regularization.*:UserWarning")
def test_lasso_alpha_warning():
     X = [[-1], [0], [1]]
     Y = [-1, 0, 1]  # just a straight line
@@ -1140,13 +1139,18 @@ def test_warm_start_multitask_lasso():
 )
 def test_enet_coordinate_descent(klass, n_classes, kwargs):
     """Test that a warning is issued if model does not converge"""
-    clf = klass(max_iter=2, **kwargs)
-    n_samples = 5
-    n_features = 2
-    X = np.ones((n_samples, n_features)) * 1e50
-    y = np.ones((n_samples, n_classes))
-    if klass == Lasso:
-        y = y.ravel()
+    clf = klass(
+        alpha=1e-10,
+        fit_intercept=False,
+        warm_start=True,
+        max_iter=1,
+        tol=1e-10,
+        **kwargs,
+    )
+    # Set initial coefficients to very bad values.
+    clf.coef_ = np.array([1, 1, 1, 1000])
+    X = np.array([[-1, -1, 1, 1], [1, 1, -1, -1]])
+    y = np.array([-1, 1])
     warning_message = (
         "Objective did not converge. You might want to"
         " increase the number of iterations."
@@ -1730,6 +1734,7 @@ def test_linear_model_cv_deprecated_alphas_none(Estimator): # TODO(1.9): remove +@pytest.mark.filterwarnings("ignore:.*with no regularization.*:UserWarning") @pytest.mark.parametrize( "Estimator", [ElasticNetCV, LassoCV, MultiTaskLassoCV, MultiTaskElasticNetCV] ) @@ -1749,6 +1754,7 @@ def test_linear_model_cv_alphas_n_alphas_unset(Estimator): # TODO(1.9): remove @pytest.mark.filterwarnings("ignore:'n_alphas' was deprecated in 1.7") +@pytest.mark.filterwarnings("ignore:.*with no regularization.*:UserWarning") @pytest.mark.parametrize( "Estimator", [ElasticNetCV, LassoCV, MultiTaskLassoCV, MultiTaskElasticNetCV] ) diff --git a/sklearn/linear_model/tests/test_sparse_coordinate_descent.py b/sklearn/linear_model/tests/test_sparse_coordinate_descent.py index 1aab9babeeb40..3e68c41e8fce5 100644 --- a/sklearn/linear_model/tests/test_sparse_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_sparse_coordinate_descent.py @@ -79,7 +79,6 @@ def test_enet_toy_list_input(with_sample_weight, csc_container): @pytest.mark.parametrize("lil_container", LIL_CONTAINERS) def test_enet_toy_explicit_sparse_input(lil_container): # Test ElasticNet for various values of alpha and l1_ratio with sparse X - f = ignore_warnings # training samples X = lil_container((3, 1)) X[0, 0] = -1 @@ -95,7 +94,7 @@ def test_enet_toy_explicit_sparse_input(lil_container): # this should be the same as lasso clf = ElasticNet(alpha=0, l1_ratio=1.0) - f(clf.fit)(X, Y) + ignore_warnings(clf.fit)(X, Y) pred = clf.predict(T) assert_array_almost_equal(clf.coef_, [1]) assert_array_almost_equal(pred, [2, 3, 4]) @@ -254,18 +253,19 @@ def test_path_parameters(csc_container): max_iter = 50 n_alphas = 10 clf = ElasticNetCV( - n_alphas=n_alphas, + alphas=n_alphas, eps=1e-3, max_iter=max_iter, l1_ratio=0.5, fit_intercept=False, ) - ignore_warnings(clf.fit)(X, y) # new params + clf.fit(X, y) assert_almost_equal(0.5, clf.l1_ratio) - assert n_alphas == clf.n_alphas - assert n_alphas == len(clf.alphas_) + assert clf.alphas == n_alphas + assert len(clf.alphas_) == n_alphas sparse_mse_path = clf.mse_path_ - ignore_warnings(clf.fit)(X.toarray(), y) # compare with dense data + # compare with dense data + clf.fit(X.toarray(), y) assert_almost_equal(clf.mse_path_, sparse_mse_path) @@ -356,11 +356,14 @@ def test_same_multiple_output_sparse_dense(coo_container): @pytest.mark.parametrize("csc_container", CSC_CONTAINERS) def test_sparse_enet_coordinate_descent(csc_container): """Test that a warning is issued if model does not converge""" - clf = Lasso(max_iter=2) - n_samples = 5 - n_features = 2 - X = csc_container((n_samples, n_features)) * 1e50 - y = np.ones(n_samples) + clf = Lasso( + alpha=1e-10, fit_intercept=False, warm_start=True, max_iter=2, tol=1e-10 + ) + # Set initial coefficients to very bad values. + clf.coef_ = np.array([1, 1, 1, 1000]) + X = np.array([[-1, -1, 1, 1], [1, 1, -1, -1]]) + X = csc_container(X) + y = np.array([-1, 1]) warning_message = ( "Objective did not converge. You might want " "to increase the number of iterations." 
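A minimal end-to-end illustration of what this first patch changes (a sketch for review, not part of the diff; it mirrors the simplified `test_lasso_zero` above): with all-zero data the optimum is `w = 0`, so the duality gap is exactly `0.0`, and because the solvers rescale `tol` by `||y||_2^2` (the Notes wording in this patch reflects that), the old `gap < tol` check was `0.0 < 0.0` and never fired; the fit looped to `max_iter` and raised a spurious `ConvergenceWarning`. With `gap <= tol` the first check succeeds:

```python
import warnings

from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import Lasso

# Degenerate problem: y == 0 makes the rescaled tolerance tol * ||y||_2^2
# exactly 0.0, and the optimum w == 0 makes the duality gap exactly 0.0 too.
X = [[0.0], [0.0], [0.0]]
y = [0.0, 0.0, 0.0]

with warnings.catch_warnings():
    # Turn any spurious ConvergenceWarning into an error; with the
    # `gap <= tol` criterion this fit passes silently.
    warnings.simplefilter("error", category=ConvergenceWarning)
    clf = Lasso(alpha=0.1).fit(X, y)

print(clf.coef_, clf.dual_gap_)  # [0.] 0.0
```

This is exactly why the `test_lasso_zero` change above can drop its `ignore_warnings` wrapper.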
From 66140087303e766e5e6faf84714f2fe89570c315 Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Sat, 9 Aug 2025 12:58:01 +0200
Subject: [PATCH 2/7] TST move test_model_pipeline_same_dense_and_sparse to
 test_common.py

---
 sklearn/linear_model/tests/test_common.py     | 78 +++++++++++++++----
 .../tests/test_coordinate_descent.py          | 53 +------------
 sklearn/utils/tests/test_pprint.py            |  4 +-
 3 files changed, 68 insertions(+), 67 deletions(-)

diff --git a/sklearn/linear_model/tests/test_common.py b/sklearn/linear_model/tests/test_common.py
index 2483a26644cbb..348710e70af64 100644
--- a/sklearn/linear_model/tests/test_common.py
+++ b/sklearn/linear_model/tests/test_common.py
@@ -43,9 +43,11 @@
     TheilSenRegressor,
     TweedieRegressor,
 )
-from sklearn.preprocessing import MinMaxScaler
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import MinMaxScaler, StandardScaler
 from sklearn.svm import LinearSVC, LinearSVR
-from sklearn.utils._testing import set_random_state
+from sklearn.utils._testing import assert_allclose, set_random_state
+from sklearn.utils.fixes import CSR_CONTAINERS
 
 
 # Note: GammaRegressor() and TweedieRegressor(power != 1) have a non-canonical link.
@@ -161,6 +163,7 @@ def test_balance_property(model, with_sample_weight, global_random_seed):
 
 @pytest.mark.filterwarnings("ignore:The default of 'normalize'")
 @pytest.mark.filterwarnings("ignore:lbfgs failed to converge")
+@pytest.mark.filterwarnings("ignore:A column-vector y was passed when a 1d array.*")
 @pytest.mark.parametrize(
     "Regressor",
     [
@@ -207,28 +210,77 @@ def test_linear_model_regressor_coef_shape(Regressor, ndim):
 
 
 @pytest.mark.parametrize(
-    "Classifier",
+    ["Classifier", "params"],
     [
-        LinearSVC,
-        LogisticRegression,
-        LogisticRegressionCV,
-        PassiveAggressiveClassifier,
-        Perceptron,
-        RidgeClassifier,
-        RidgeClassifierCV,
-        SGDClassifier,
+        (LinearSVC, {}),
+        (LogisticRegression, {}),
+        (LogisticRegressionCV, {"solver": "newton-cholesky"}),
+        (PassiveAggressiveClassifier, {}),
+        (Perceptron, {}),
+        (RidgeClassifier, {}),
+        (RidgeClassifierCV, {}),
+        (SGDClassifier, {}),
     ],
 )
 @pytest.mark.parametrize("n_classes", [2, 3])
-def test_linear_model_classifier_coef_shape(Classifier, n_classes):
+def test_linear_model_classifier_coef_shape(Classifier, params, n_classes):
     if Classifier in (RidgeClassifier, RidgeClassifierCV):
         pytest.xfail(f"{Classifier} does not follow `coef_` shape contract!")
 
     X, y = make_classification(n_informative=10, n_classes=n_classes, random_state=0)
     n_features = X.shape[1]
 
-    classifier = Classifier()
+    classifier = Classifier(**params)
     set_random_state(classifier)
     classifier.fit(X, y)
 
     expected_shape = (1, n_features) if n_classes == 2 else (n_classes, n_features)
     assert classifier.coef_.shape == expected_shape
+
+
+@pytest.mark.parametrize(
+    "LinearModel, params",
+    [
+        (Lasso, {"tol": 1e-15, "alpha": 0.01}),
+        (LassoCV, {"tol": 1e-15}),
+        (ElasticNetCV, {"tol": 1e-15}),
+        (RidgeClassifier, {"solver": "sparse_cg", "alpha": 0.1}),
+        (ElasticNet, {"tol": 1e-15, "l1_ratio": 1, "alpha": 0.01}),
+        (ElasticNet, {"tol": 1e-15, "l1_ratio": 1e-5, "alpha": 0.01}),
+        (Ridge, {"solver": "sparse_cg", "tol": 1e-12, "alpha": 0.1}),
+        (LinearRegression, {}),
+        (RidgeCV, {}),
+        (RidgeClassifierCV, {}),
+    ],
+)
+@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
+def test_model_pipeline_same_dense_and_sparse(LinearModel, params, csr_container):
+    """Test that sparse and dense linear models give the same results.
+
+    Models use a preprocessing pipeline with a StandardScaler.
+ """ + model_dense = make_pipeline(StandardScaler(with_mean=False), LinearModel(**params)) + + model_sparse = make_pipeline(StandardScaler(with_mean=False), LinearModel(**params)) + + # prepare the data + rng = np.random.RandomState(0) + n_samples = 100 + n_features = 2 + X = rng.randn(n_samples, n_features) + X[X < 0.1] = 0.0 + + X_sparse = csr_container(X) + y = rng.rand(n_samples) + + if is_classifier(model_dense): + y = np.sign(y) + + model_dense.fit(X, y) + model_sparse.fit(X_sparse, y) + + assert_allclose(model_sparse[1].coef_, model_dense[1].coef_, atol=1e-16) + y_pred_dense = model_dense.predict(X) + y_pred_sparse = model_sparse.predict(X_sparse) + assert_allclose(y_pred_dense, y_pred_sparse) + + assert_allclose(model_dense[1].intercept_, model_sparse[1].intercept_) diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py index 1194348dea11c..cd44118778194 100644 --- a/sklearn/linear_model/tests/test_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_coordinate_descent.py @@ -9,7 +9,7 @@ import pytest from scipy import interpolate, sparse -from sklearn.base import clone, config_context, is_classifier +from sklearn.base import clone, config_context from sklearn.datasets import load_diabetes, make_regression from sklearn.exceptions import ConvergenceWarning from sklearn.linear_model import ( @@ -19,7 +19,6 @@ LassoCV, LassoLars, LassoLarsCV, - LinearRegression, MultiTaskElasticNet, MultiTaskElasticNetCV, MultiTaskLasso, @@ -358,56 +357,6 @@ def _scale_alpha_inplace(estimator, n_samples): estimator.set_params(alpha=alpha) -@pytest.mark.filterwarnings("ignore::sklearn.exceptions.ConvergenceWarning") -@pytest.mark.parametrize( - "LinearModel, params", - [ - (Lasso, {"tol": 1e-16, "alpha": 0.1}), - (LassoCV, {"tol": 1e-16}), - (ElasticNetCV, {}), - (RidgeClassifier, {"solver": "sparse_cg", "alpha": 0.1}), - (ElasticNet, {"tol": 1e-16, "l1_ratio": 1, "alpha": 0.01}), - (ElasticNet, {"tol": 1e-16, "l1_ratio": 0, "alpha": 0.01}), - (Ridge, {"solver": "sparse_cg", "tol": 1e-12, "alpha": 0.1}), - (LinearRegression, {}), - (RidgeCV, {}), - (RidgeClassifierCV, {}), - ], -) -@pytest.mark.parametrize("csr_container", CSR_CONTAINERS) -def test_model_pipeline_same_dense_and_sparse(LinearModel, params, csr_container): - # Test that linear model preceded by StandardScaler in the pipeline and - # with normalize set to False gives the same y_pred and the same .coef_ - # given X sparse or dense - - model_dense = make_pipeline(StandardScaler(with_mean=False), LinearModel(**params)) - - model_sparse = make_pipeline(StandardScaler(with_mean=False), LinearModel(**params)) - - # prepare the data - rng = np.random.RandomState(0) - n_samples = 200 - n_features = 2 - X = rng.randn(n_samples, n_features) - X[X < 0.1] = 0.0 - - X_sparse = csr_container(X) - y = rng.rand(n_samples) - - if is_classifier(model_dense): - y = np.sign(y) - - model_dense.fit(X, y) - model_sparse.fit(X_sparse, y) - - assert_allclose(model_sparse[1].coef_, model_dense[1].coef_) - y_pred_dense = model_dense.predict(X) - y_pred_sparse = model_sparse.predict(X_sparse) - assert_allclose(y_pred_dense, y_pred_sparse) - - assert_allclose(model_dense[1].intercept_, model_sparse[1].intercept_) - - def test_lasso_path_return_models_vs_new_return_gives_same_coefficients(): # Test that lasso_path with lars_path style output gives the # same result diff --git a/sklearn/utils/tests/test_pprint.py b/sklearn/utils/tests/test_pprint.py index ee3e267dd5cbe..7fd876eb167bd 100644 --- 
a/sklearn/utils/tests/test_pprint.py
+++ b/sklearn/utils/tests/test_pprint.py
@@ -406,7 +406,7 @@ def test_gridsearch_pipeline(print_changed_only_false):
             "classify__C": C_OPTIONS,
         },
     ]
-    gspipline = GridSearchCV(pipeline, cv=3, n_jobs=1, param_grid=param_grid)
+    gspipeline = GridSearchCV(pipeline, cv=3, n_jobs=1, param_grid=param_grid)
     expected = """
 GridSearchCV(cv=3, error_score='raise-deprecating',
              estimator=Pipeline(memory=None,
@@ -447,7 +447,7 @@ def test_gridsearch_pipeline(print_changed_only_false):
              scoring=None, verbose=0)"""  # noqa: E501
 
     expected = expected[1:]  # remove first \n
-    repr_ = pp.pformat(gspipline)
+    repr_ = pp.pformat(gspipeline)
     # Remove address of '' for reproducibility
     repr_ = re.sub("function chi2 at 0x.*>", "function chi2 at some_address>", repr_)
     assert repr_ == expected

From c960dcaf78efa974b706d7069ac6fb1ef855864f Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Sat, 9 Aug 2025 13:12:28 +0200
Subject: [PATCH 3/7] DOC add whatsnew

---
 .../sklearn.linear_model/31906.enhancement.rst | 9 +++++++++
 1 file changed, 9 insertions(+)
 create mode 100644 doc/whats_new/upcoming_changes/sklearn.linear_model/31906.enhancement.rst

diff --git a/doc/whats_new/upcoming_changes/sklearn.linear_model/31906.enhancement.rst b/doc/whats_new/upcoming_changes/sklearn.linear_model/31906.enhancement.rst
new file mode 100644
index 0000000000000..d6a0615af49a1
--- /dev/null
+++ b/doc/whats_new/upcoming_changes/sklearn.linear_model/31906.enhancement.rst
@@ -0,0 +1,9 @@
+- :class:`linear_model.ElasticNet`, :class:`linear_model.ElasticNetCV`,
+  :class:`linear_model.Lasso`, :class:`linear_model.LassoCV`,
+  :class:`MultiTaskElasticNet`, :class:`MultiTaskElasticNetCV`,
+  :class:`MultiTaskLasso`, :class:`MultiTaskLassoCV`, as well as
+  :func:`linear_model.enet_path` and :func:`linear_model.lasso_path`
+  now use `tol <= dual gap` instead of `tol < dual gap`.
+  The resulting coefficients might differ from previous versions of scikit-learn in
+  rare cases.
+  By :user:`Christian Lorentzen `.

From 223e2ef6920a03c312f2d519fbc79a12a49a92a4 Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Tue, 12 Aug 2025 07:49:54 +0200
Subject: [PATCH 4/7] CLN fix latex formula

---
 sklearn/linear_model/_coordinate_descent.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py
index 3245e240f7722..153d96cdb96f8 100644
--- a/sklearn/linear_model/_coordinate_descent.py
+++ b/sklearn/linear_model/_coordinate_descent.py
@@ -859,7 +859,7 @@ class ElasticNet(MultiOutputMixin, RegressorMixin, LinearModel):
     the maximum coordinate update, i.e. :math:`\\max_j |w_j^{new} - w_j^{old}|`
     is smaller or equal to `tol` times the maximum absolute coefficient,
     :math:`\\max_j |w_j|`. If so, then additionally check whether the dual gap is
-    smaller or equal to `tol` times :math:`||y||_2^2 / n_{\text{samples}}`.
+    smaller or equal to `tol` times :math:`||y||_2^2 / n_{\\text{samples}}`.
 
     Examples
     --------
@@ -1287,7 +1287,7 @@ class Lasso(ElasticNet):
     the maximum coordinate update, i.e. :math:`\\max_j |w_j^{new} - w_j^{old}|`
     is smaller or equal to `tol` times the maximum absolute coefficient,
     :math:`\\max_j |w_j|`. If so, then additionally check whether the dual gap is
-    smaller or equal to `tol` times :math:`||y||_2^2 / n_{\text{samples}}`.
+    smaller or equal to `tol` times :math:`||y||_2^2 / n_{\\text{samples}}`.
The target can be a 2-dimensional array, resulting in the optimization of the following objective:: From fea0c1597ff9576a6c8f325fb015705600c71f92 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Tue, 12 Aug 2025 07:59:23 +0200 Subject: [PATCH 5/7] DOC docstring for tol --- sklearn/linear_model/_coordinate_descent.py | 58 +++++++++------------ 1 file changed, 25 insertions(+), 33 deletions(-) diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py index 153d96cdb96f8..a1abc4fdc28ff 100644 --- a/sklearn/linear_model/_coordinate_descent.py +++ b/sklearn/linear_model/_coordinate_descent.py @@ -786,10 +786,9 @@ class ElasticNet(MultiOutputMixin, RegressorMixin, LinearModel): If ``True``, X will be copied; else, it may be overwritten. tol : float, default=1e-4 - The tolerance for the optimization: if the updates are - smaller than ``tol``, the optimization code checks the - dual gap for optimality and continues until it is smaller - than ``tol``, see Notes below. + The tolerance for the optimization: if the updates are smaller or equal to + ``tol``, the optimization code checks the dual gap for optimality and continues + until it is smaller or equal to ``tol``, see Notes below. warm_start : bool, default=False When set to ``True``, reuse the solution of the previous call to fit as @@ -1205,13 +1204,12 @@ class Lasso(ElasticNet): The maximum number of iterations. tol : float, default=1e-4 - The tolerance for the optimization: if the updates are - smaller than ``tol``, the optimization code checks the - dual gap for optimality and continues until it is smaller - than ``tol``, see Notes below. + The tolerance for the optimization: if the updates are smaller or equal to + ``tol``, the optimization code checks the dual gap for optimality and continues + until it is smaller or equal to ``tol``, see Notes below. warm_start : bool, default=False - When set to True, reuse the solution of the previous call to fit as + When set to ``True``, reuse the solution of the previous call to fit as initialization, otherwise, just erase the previous solution. See :term:`the Glossary `. @@ -1981,10 +1979,9 @@ class LassoCV(RegressorMixin, LinearModelCV): The maximum number of iterations. tol : float, default=1e-4 - The tolerance for the optimization: if the updates are - smaller than ``tol``, the optimization code checks the - dual gap for optimality and continues until it is smaller - than ``tol``. + The tolerance for the optimization: if the updates are smaller or equal to + ``tol``, the optimization code checks the dual gap for optimality and continues + until it is smaller or equal to ``tol``. copy_X : bool, default=True If ``True``, X will be copied; else, it may be overwritten. @@ -2252,10 +2249,9 @@ class ElasticNetCV(RegressorMixin, LinearModelCV): The maximum number of iterations. tol : float, default=1e-4 - The tolerance for the optimization: if the updates are - smaller than ``tol``, the optimization code checks the - dual gap for optimality and continues until it is smaller - than ``tol``. + The tolerance for the optimization: if the updates are smaller or equal to + ``tol``, the optimization code checks the dual gap for optimality and continues + until it is smaller or equal to ``tol``. cv : int, cross-validation generator or iterable, default=None Determines the cross-validation splitting strategy. @@ -2525,10 +2521,9 @@ class MultiTaskElasticNet(Lasso): The maximum number of iterations. 
tol : float, default=1e-4 - The tolerance for the optimization: if the updates are - smaller than ``tol``, the optimization code checks the - dual gap for optimality and continues until it is smaller - than ``tol``. + The tolerance for the optimization: if the updates are smaller or equal to + ``tol``, the optimization code checks the dual gap for optimality and continues + until it is smaller or equal to ``tol``. warm_start : bool, default=False When set to ``True``, reuse the solution of the previous call to fit as @@ -2770,10 +2765,9 @@ class MultiTaskLasso(MultiTaskElasticNet): The maximum number of iterations. tol : float, default=1e-4 - The tolerance for the optimization: if the updates are - smaller than ``tol``, the optimization code checks the - dual gap for optimality and continues until it is smaller - than ``tol``. + The tolerance for the optimization: if the updates are smaller or equal to + ``tol``, the optimization code checks the dual gap for optimality and continues + until it is smaller or equal to ``tol``. warm_start : bool, default=False When set to ``True``, reuse the solution of the previous call to fit as @@ -2949,10 +2943,9 @@ class MultiTaskElasticNetCV(RegressorMixin, LinearModelCV): The maximum number of iterations. tol : float, default=1e-4 - The tolerance for the optimization: if the updates are - smaller than ``tol``, the optimization code checks the - dual gap for optimality and continues until it is smaller - than ``tol``. + The tolerance for the optimization: if the updates are smaller or equal to + ``tol``, the optimization code checks the dual gap for optimality and continues + until it is smaller or equal to ``tol``. cv : int, cross-validation generator or iterable, default=None Determines the cross-validation splitting strategy. @@ -3205,10 +3198,9 @@ class MultiTaskLassoCV(RegressorMixin, LinearModelCV): The maximum number of iterations. tol : float, default=1e-4 - The tolerance for the optimization: if the updates are - smaller than ``tol``, the optimization code checks the - dual gap for optimality and continues until it is smaller - than ``tol``. + The tolerance for the optimization: if the updates are smaller or equal to + ``tol``, the optimization code checks the dual gap for optimality and continues + until it is smaller or equal to ``tol``. copy_X : bool, default=True If ``True``, X will be copied; else, it may be overwritten. From 5f8bdf4ae872c005a2b796d120ab71fe14495a5d Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Wed, 13 Aug 2025 06:49:43 +0200 Subject: [PATCH 6/7] Update doc/whats_new/upcoming_changes/sklearn.linear_model/31906.enhancement.rst Co-authored-by: Omar Salman --- .../upcoming_changes/sklearn.linear_model/31906.enhancement.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/upcoming_changes/sklearn.linear_model/31906.enhancement.rst b/doc/whats_new/upcoming_changes/sklearn.linear_model/31906.enhancement.rst index d6a0615af49a1..eac6864cc3505 100644 --- a/doc/whats_new/upcoming_changes/sklearn.linear_model/31906.enhancement.rst +++ b/doc/whats_new/upcoming_changes/sklearn.linear_model/31906.enhancement.rst @@ -3,7 +3,7 @@ :class:`MultiTaskElasticNet`, :class:`MultiTaskElasticNetCV`, :class:`MultiTaskLasso`, :class:`MultiTaskLassoCV`, as well as :func:`linear_model.enet_path` and :func:`linear_model.lasso_path` - now use `tol <= dual gap` instead of `tol < dual gap`. + now use `dual gap <= tol` instead of `dual gap < tol`. 
   The resulting coefficients might differ from previous versions of scikit-learn in
   rare cases.
   By :user:`Christian Lorentzen `.

From 35712944fdcd3cee7cadbfe595f83571f7a663a9 Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Wed, 13 Aug 2025 06:52:29 +0200
Subject: [PATCH 7/7] Update 31906.enhancement.rst

---
 .../upcoming_changes/sklearn.linear_model/31906.enhancement.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doc/whats_new/upcoming_changes/sklearn.linear_model/31906.enhancement.rst b/doc/whats_new/upcoming_changes/sklearn.linear_model/31906.enhancement.rst
index eac6864cc3505..8417c3dd2ac29 100644
--- a/doc/whats_new/upcoming_changes/sklearn.linear_model/31906.enhancement.rst
+++ b/doc/whats_new/upcoming_changes/sklearn.linear_model/31906.enhancement.rst
@@ -3,7 +3,7 @@
   :class:`MultiTaskElasticNet`, :class:`MultiTaskElasticNetCV`,
   :class:`MultiTaskLasso`, :class:`MultiTaskLassoCV`, as well as
   :func:`linear_model.enet_path` and :func:`linear_model.lasso_path`
-  now use `dual gap <= tol` instead of `dual gap < tol`.
+  now use `dual gap <= tol` instead of `dual gap < tol` as stopping criterion.
   The resulting coefficients might differ from previous versions of scikit-learn in
   rare cases.
   By :user:`Christian Lorentzen `.
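Taken together, the series documents a criterion that can be checked directly on a fitted estimator (a sketch, assuming the solver converges before `max_iter` and that `dual_gap_` is stored rescaled by `n_samples`, matching the `tol * ||y||_2^2 / n_samples` wording in the updated Notes sections):

```python
import numpy as np

from sklearn.datasets import make_regression
from sklearn.linear_model import ElasticNet

X, y = make_regression(n_samples=100, n_features=20, random_state=0)
clf = ElasticNet(alpha=0.1, tol=1e-6).fit(X, y)

# On a converged fit, the stopping criterion from the updated docstrings
# holds with `<=` (not `<`): dual gap <= tol * ||y||_2^2 / n_samples.
assert clf.dual_gap_ <= clf.tol * np.dot(y, y) / len(y)
print(clf.n_iter_, clf.dual_gap_)
```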