Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
- :class:`linear_model.ElasticNet`, :class:`linear_model.ElasticNetCV`,
:class:`linear_model.Lasso`, :class:`linear_model.LassoCV`,
:class:`linear_model.MultiTaskElasticNet`,
:class:`linear_model.MultiTaskElasticNetCV`,
:class:`linear_model.MultiTaskLasso`, :class:`linear_model.MultiTaskLassoCV`,
as well as
:func:`linear_model.enet_path` and :func:`linear_model.lasso_path`
now use `dual gap <= tol` instead of `dual gap < tol` as stopping criterion.
The resulting coefficients might differ from previous versions of scikit-learn in
rare cases.
By :user:`Christian Lorentzen <lorentzenchr>`.
14 changes: 7 additions & 7 deletions sklearn/linear_model/_cd_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def enet_coordinate_descent(

if (
w_max == 0.0
or d_w_max / w_max < d_w_tol
or d_w_max / w_max <= d_w_tol
or n_iter == max_iter - 1
):
# the biggest coordinate update of this iteration was smaller
Expand Down Expand Up @@ -298,7 +298,7 @@ def enet_coordinate_descent(
- const_ * _dot(n_samples, &R[0], 1, &y[0], 1) # np.dot(R.T, y)
+ 0.5 * beta * (1 + const_ ** 2) * (w_norm2))

if gap < tol:
if gap <= tol:
# return if we reached desired tolerance
break

Expand Down Expand Up @@ -539,7 +539,7 @@ def sparse_enet_coordinate_descent(

w_max = fmax(w_max, fabs(w[ii]))

if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
if w_max == 0.0 or d_w_max / w_max <= d_w_tol or n_iter == max_iter - 1:
# the biggest coordinate update of this iteration was smaller than
# the tolerance: check the duality gap as ultimate stopping
# criterion
Expand Down Expand Up @@ -586,7 +586,7 @@ def sparse_enet_coordinate_descent(
- const_ * _dot(n_samples, &R[0], 1, &y[0], 1) # np.dot(R.T, y)
+ 0.5 * beta * (1 + const_ ** 2) * w_norm2)

if gap < tol:
if gap <= tol:
# return if we reached desired tolerance
break

Expand Down Expand Up @@ -714,7 +714,7 @@ def enet_coordinate_descent_gram(
if fabs(w[ii]) > w_max:
w_max = fabs(w[ii])

if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
if w_max == 0.0 or d_w_max / w_max <= d_w_tol or n_iter == max_iter - 1:
# the biggest coordinate update of this iteration was smaller than
# the tolerance: check the duality gap as ultimate stopping
# criterion
Expand Down Expand Up @@ -752,7 +752,7 @@ def enet_coordinate_descent_gram(
+ 0.5 * beta * (1 + const_ ** 2) * w_norm2
)

if gap < tol:
if gap <= tol:
# return if we reached desired tolerance
break

Expand Down Expand Up @@ -931,7 +931,7 @@ def enet_coordinate_descent_multi_task(
if W_ii_abs_max > w_max:
w_max = W_ii_abs_max

if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
if w_max == 0.0 or d_w_max / w_max <= d_w_tol or n_iter == max_iter - 1:
# the biggest coordinate update of this iteration was smaller than
# the tolerance: check the duality gap as ultimate stopping
# criterion
Expand Down
70 changes: 31 additions & 39 deletions sklearn/linear_model/_coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,10 +786,9 @@ class ElasticNet(MultiOutputMixin, RegressorMixin, LinearModel):
If ``True``, X will be copied; else, it may be overwritten.

tol : float, default=1e-4
The tolerance for the optimization: if the updates are
smaller than ``tol``, the optimization code checks the
dual gap for optimality and continues until it is smaller
than ``tol``, see Notes below.
The tolerance for the optimization: if the updates are smaller than or equal to
``tol``, the optimization code checks the dual gap for optimality and continues
until it is smaller than or equal to ``tol``, see Notes below.

warm_start : bool, default=False
When set to ``True``, reuse the solution of the previous call to fit as
Expand Down Expand Up @@ -857,9 +856,9 @@ class ElasticNet(MultiOutputMixin, RegressorMixin, LinearModel):

The precise stopping criteria based on `tol` are the following: First, check that
the maximum coordinate update, i.e. :math:`\\max_j |w_j^{new} - w_j^{old}|`
is smaller than `tol` times the maximum absolute coefficient, :math:`\\max_j |w_j|`.
If so, then additionally check whether the dual gap is smaller than `tol` times
:math:`||y||_2^2 / n_{\text{samples}}`.
is smaller than or equal to `tol` times the maximum absolute coefficient,
:math:`\\max_j |w_j|`. If so, then additionally check whether the dual gap is
smaller than or equal to `tol` times :math:`||y||_2^2 / n_{\\text{samples}}`.

Examples
--------
Expand Down Expand Up @@ -1205,13 +1204,12 @@ class Lasso(ElasticNet):
The maximum number of iterations.

tol : float, default=1e-4
The tolerance for the optimization: if the updates are
smaller than ``tol``, the optimization code checks the
dual gap for optimality and continues until it is smaller
than ``tol``, see Notes below.
The tolerance for the optimization: if the updates are smaller than or equal to
``tol``, the optimization code checks the dual gap for optimality and continues
until it is smaller than or equal to ``tol``, see Notes below.

warm_start : bool, default=False
When set to True, reuse the solution of the previous call to fit as
When set to ``True``, reuse the solution of the previous call to fit as
initialization, otherwise, just erase the previous solution.
See :term:`the Glossary <warm_start>`.

Expand Down Expand Up @@ -1285,9 +1283,9 @@ class Lasso(ElasticNet):

The precise stopping criteria based on `tol` are the following: First, check that
the maximum coordinate update, i.e. :math:`\\max_j |w_j^{new} - w_j^{old}|`
is smaller than `tol` times the maximum absolute coefficient, :math:`\\max_j |w_j|`.
If so, then additionally check whether the dual gap is smaller than `tol` times
:math:`||y||_2^2 / n_{\\text{samples}}`.
is smaller than or equal to `tol` times the maximum absolute coefficient,
:math:`\\max_j |w_j|`. If so, then additionally check whether the dual gap is
smaller than or equal to `tol` times :math:`||y||_2^2 / n_{\\text{samples}}`.

The target can be a 2-dimensional array, resulting in the optimization of the
following objective::
Expand Down Expand Up @@ -1981,10 +1979,9 @@ class LassoCV(RegressorMixin, LinearModelCV):
The maximum number of iterations.

tol : float, default=1e-4
The tolerance for the optimization: if the updates are
smaller than ``tol``, the optimization code checks the
dual gap for optimality and continues until it is smaller
than ``tol``.
The tolerance for the optimization: if the updates are smaller than or equal to
``tol``, the optimization code checks the dual gap for optimality and continues
until it is smaller than or equal to ``tol``.

copy_X : bool, default=True
If ``True``, X will be copied; else, it may be overwritten.
Expand Down Expand Up @@ -2252,10 +2249,9 @@ class ElasticNetCV(RegressorMixin, LinearModelCV):
The maximum number of iterations.

tol : float, default=1e-4
The tolerance for the optimization: if the updates are
smaller than ``tol``, the optimization code checks the
dual gap for optimality and continues until it is smaller
than ``tol``.
The tolerance for the optimization: if the updates are smaller than or equal to
``tol``, the optimization code checks the dual gap for optimality and continues
until it is smaller than or equal to ``tol``.

cv : int, cross-validation generator or iterable, default=None
Determines the cross-validation splitting strategy.
Expand Down Expand Up @@ -2525,10 +2521,9 @@ class MultiTaskElasticNet(Lasso):
The maximum number of iterations.

tol : float, default=1e-4
The tolerance for the optimization: if the updates are
smaller than ``tol``, the optimization code checks the
dual gap for optimality and continues until it is smaller
than ``tol``.
The tolerance for the optimization: if the updates are smaller than or equal to
``tol``, the optimization code checks the dual gap for optimality and continues
until it is smaller than or equal to ``tol``.

warm_start : bool, default=False
When set to ``True``, reuse the solution of the previous call to fit as
Expand Down Expand Up @@ -2770,10 +2765,9 @@ class MultiTaskLasso(MultiTaskElasticNet):
The maximum number of iterations.

tol : float, default=1e-4
The tolerance for the optimization: if the updates are
smaller than ``tol``, the optimization code checks the
dual gap for optimality and continues until it is smaller
than ``tol``.
The tolerance for the optimization: if the updates are smaller than or equal to
``tol``, the optimization code checks the dual gap for optimality and continues
until it is smaller than or equal to ``tol``.

warm_start : bool, default=False
When set to ``True``, reuse the solution of the previous call to fit as
Expand Down Expand Up @@ -2949,10 +2943,9 @@ class MultiTaskElasticNetCV(RegressorMixin, LinearModelCV):
The maximum number of iterations.

tol : float, default=1e-4
The tolerance for the optimization: if the updates are
smaller than ``tol``, the optimization code checks the
dual gap for optimality and continues until it is smaller
than ``tol``.
The tolerance for the optimization: if the updates are smaller than or equal to
``tol``, the optimization code checks the dual gap for optimality and continues
until it is smaller than or equal to ``tol``.

cv : int, cross-validation generator or iterable, default=None
Determines the cross-validation splitting strategy.
Expand Down Expand Up @@ -3205,10 +3198,9 @@ class MultiTaskLassoCV(RegressorMixin, LinearModelCV):
The maximum number of iterations.

tol : float, default=1e-4
The tolerance for the optimization: if the updates are
smaller than ``tol``, the optimization code checks the
dual gap for optimality and continues until it is smaller
than ``tol``.
The tolerance for the optimization: if the updates are smaller than or equal to
``tol``, the optimization code checks the dual gap for optimality and continues
until it is smaller than or equal to ``tol``.

copy_X : bool, default=True
If ``True``, X will be copied; else, it may be overwritten.
Expand Down
78 changes: 65 additions & 13 deletions sklearn/linear_model/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@
TheilSenRegressor,
TweedieRegressor,
)
from sklearn.preprocessing import MinMaxScaler
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.svm import LinearSVC, LinearSVR
from sklearn.utils._testing import set_random_state
from sklearn.utils._testing import assert_allclose, set_random_state
from sklearn.utils.fixes import CSR_CONTAINERS


# Note: GammaRegressor() and TweedieRegressor(power != 1) have a non-canonical link.
Expand Down Expand Up @@ -161,6 +163,7 @@ def test_balance_property(model, with_sample_weight, global_random_seed):

@pytest.mark.filterwarnings("ignore:The default of 'normalize'")
@pytest.mark.filterwarnings("ignore:lbfgs failed to converge")
@pytest.mark.filterwarnings("ignore:A column-vector y was passed when a 1d array.*")
@pytest.mark.parametrize(
"Regressor",
[
Expand Down Expand Up @@ -207,28 +210,77 @@ def test_linear_model_regressor_coef_shape(Regressor, ndim):


@pytest.mark.parametrize(
"Classifier",
["Classifier", "params"],
[
LinearSVC,
LogisticRegression,
LogisticRegressionCV,
PassiveAggressiveClassifier,
Perceptron,
RidgeClassifier,
RidgeClassifierCV,
SGDClassifier,
(LinearSVC, {}),
(LogisticRegression, {}),
(LogisticRegressionCV, {"solver": "newton-cholesky"}),
(PassiveAggressiveClassifier, {}),
(Perceptron, {}),
(RidgeClassifier, {}),
(RidgeClassifierCV, {}),
(SGDClassifier, {}),
],
)
@pytest.mark.parametrize("n_classes", [2, 3])
def test_linear_model_classifier_coef_shape(Classifier, n_classes):
def test_linear_model_classifier_coef_shape(Classifier, params, n_classes):
if Classifier in (RidgeClassifier, RidgeClassifierCV):
pytest.xfail(f"{Classifier} does not follow `coef_` shape contract!")

X, y = make_classification(n_informative=10, n_classes=n_classes, random_state=0)
n_features = X.shape[1]

classifier = Classifier()
classifier = Classifier(**params)
set_random_state(classifier)
classifier.fit(X, y)
expected_shape = (1, n_features) if n_classes == 2 else (n_classes, n_features)
assert classifier.coef_.shape == expected_shape


@pytest.mark.parametrize(
    "LinearModel, params",
    [
        (Lasso, {"tol": 1e-15, "alpha": 0.01}),
        (LassoCV, {"tol": 1e-15}),
        (ElasticNetCV, {"tol": 1e-15}),
        (RidgeClassifier, {"solver": "sparse_cg", "alpha": 0.1}),
        (ElasticNet, {"tol": 1e-15, "l1_ratio": 1, "alpha": 0.01}),
        (ElasticNet, {"tol": 1e-15, "l1_ratio": 1e-5, "alpha": 0.01}),
        (Ridge, {"solver": "sparse_cg", "tol": 1e-12, "alpha": 0.1}),
        (LinearRegression, {}),
        (RidgeCV, {}),
        (RidgeClassifierCV, {}),
    ],
)
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_model_pipeline_same_dense_and_sparse(LinearModel, params, csr_container):
    """Check that a linear model fits the same on dense and sparse input.

    Two identical pipelines, each a ``StandardScaler(with_mean=False)``
    followed by ``LinearModel(**params)``, are fitted on the same data: one on
    a dense array, the other on its ``csr_container`` conversion. Their
    coefficients, predictions and intercepts must agree.
    """

    def build_pipeline():
        # Fresh, unfitted pipeline: scaling without centering, then the estimator.
        return make_pipeline(StandardScaler(with_mean=False), LinearModel(**params))

    dense_pipe, sparse_pipe = build_pipeline(), build_pipeline()

    # Toy data: entries below 0.1 are zeroed out. X is drawn before the target
    # so the seeded random stream is always consumed in the same order.
    rng = np.random.RandomState(0)
    n_samples, n_features = 100, 2
    X_dense = rng.randn(n_samples, n_features)
    X_dense[X_dense < 0.1] = 0.0
    X_sparse = csr_container(X_dense)
    target = rng.rand(n_samples)
    if is_classifier(dense_pipe):
        # Classifiers are trained on the sign of the continuous target.
        target = np.sign(target)

    dense_pipe.fit(X_dense, target)
    sparse_pipe.fit(X_sparse, target)

    # Index 1 selects the final estimator of each two-step pipeline.
    dense_est, sparse_est = dense_pipe[1], sparse_pipe[1]
    assert_allclose(sparse_est.coef_, dense_est.coef_, atol=1e-16)
    assert_allclose(dense_pipe.predict(X_dense), sparse_pipe.predict(X_sparse))
    assert_allclose(dense_est.intercept_, sparse_est.intercept_)
Loading
Loading