Skip to content

Commit e6d0950

Browse files
lestevejeremiedbb
authored andcommitted
MNT Remove DeprecationWarning for scipy.sparse.linalg.cg tol vs rtol argument (#26814)
1 parent e7eb359 commit e6d0950

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

sklearn/linear_model/_ridge.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
)
3434
from ..utils._param_validation import Interval, StrOptions
3535
from ..utils.extmath import row_norms, safe_sparse_dot
36+
from ..utils.fixes import _sparse_linalg_cg
3637
from ..utils.sparsefuncs import mean_variance_axis
3738
from ..utils.validation import _check_sample_weight, check_is_fitted
3839
from ._base import LinearClassifierMixin, LinearModel, _preprocess_data, _rescale_data
@@ -105,7 +106,7 @@ def _mv(x):
105106
C = sp_linalg.LinearOperator(
106107
(n_samples, n_samples), matvec=mv, dtype=X.dtype
107108
)
108-
coef, info = sp_linalg.cg(C, y_column, tol=tol, atol="legacy")
109+
coef, info = _sparse_linalg_cg(C, y_column, rtol=tol)
109110
coefs[i] = X1.rmatvec(coef)
110111
else:
111112
# linear ridge
@@ -114,9 +115,7 @@ def _mv(x):
114115
C = sp_linalg.LinearOperator(
115116
(n_features, n_features), matvec=mv, dtype=X.dtype
116117
)
117-
coefs[i], info = sp_linalg.cg(
118-
C, y_column, maxiter=max_iter, tol=tol, atol="legacy"
119-
)
118+
coefs[i], info = _sparse_linalg_cg(C, y_column, maxiter=max_iter, rtol=tol)
120119

121120
if info < 0:
122121
raise ValueError("Failed with error code %d" % info)

sklearn/utils/fixes.py

+14
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import numpy as np
1717
import scipy
18+
import scipy.sparse.linalg
1819
import scipy.stats
1920
import threadpoolctl
2021

@@ -109,6 +110,19 @@ def _mode(a, axis=0):
109110
return scipy.stats.mode(a, axis=axis)
110111

111112

113+
# TODO: Remove when Scipy 1.12 is the minimum supported version
114+
if sp_base_version >= parse_version("1.12.0"):
115+
_sparse_linalg_cg = scipy.sparse.linalg.cg
116+
else:
117+
118+
def _sparse_linalg_cg(A, b, **kwargs):
119+
if "rtol" in kwargs:
120+
kwargs["tol"] = kwargs.pop("rtol")
121+
if "atol" not in kwargs:
122+
kwargs["atol"] = "legacy"
123+
return scipy.sparse.linalg.cg(A, b, **kwargs)
124+
125+
112126
###############################################################################
113127
# Backport of Python 3.9's importlib.resources
114128
# TODO: Remove when Python 3.9 is the minimum supported version

0 commit comments

Comments
 (0)