Skip to content

Commit dd2fa68

Browse files
timstaleyjnothman
authored andcommitted
Fix ARDRegression accuracy issue with scipy 1.3.0 (scikit-learn#14067)
1 parent a6bdb2b commit dd2fa68

File tree

5 files changed

+151
-4
lines changed

5 files changed

+151
-4
lines changed

doc/whats_new/v0.21.rst

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,20 @@ Changelog
4949
:mod:`sklearn.impute`
5050
.....................
5151

52-
- |Fix| Fixed a bug in :class:`SimpleImputer` and :class:`IterativeImputer`
53-
so that no errors are thrown when there are missing values in training data.
54-
:pr:`13974` by `Frank Hoang <fhoang7>`.
52+
- |Fix| Fixed a bug in :class:`impute.SimpleImputer` and
53+
:class:`impute.IterativeImputer` so that no errors are thrown when there are
54+
missing values in training data. :pr:`13974` by `Frank Hoang <fhoang7>`.
5555

5656
:mod:`sklearn.linear_model`
5757
...........................
5858
- |Fix| Fixed a bug in :class:`linear_model.LogisticRegressionCV` where
5959
``refit=False`` would fail depending on the ``'multiclass'`` and
6060
``'penalty'`` parameters (regression introduced in 0.21). :pr:`14087` by
6161
`Nicolas Hug`_.
62+
- |Fix| Compatibility fix for :class:`linear_model.ARDRegression` and
63+
Scipy>=1.3.0. Adapts to upstream changes to the default `pinvh` cutoff
64+
threshold which otherwise results in poor accuracy in some cases.
65+
:pr:`14067` by :user:`Tim Staley <timstaley>`.
6266

6367
:mod:`sklearn.tree`
6468
...................

sklearn/externals/_scipy_linalg.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# This should remained pinned to version 1.2 and not updated like other
2+
# externals.
3+
"""Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers.
4+
All rights reserved.
5+
6+
Redistribution and use in source and binary forms, with or without
7+
modification, are permitted provided that the following conditions
8+
are met:
9+
10+
1. Redistributions of source code must retain the above copyright
11+
notice, this list of conditions and the following disclaimer.
12+
13+
2. Redistributions in binary form must reproduce the above
14+
copyright notice, this list of conditions and the following
15+
disclaimer in the documentation and/or other materials provided
16+
with the distribution.
17+
18+
3. Neither the name of the copyright holder nor the names of its
19+
contributors may be used to endorse or promote products derived
20+
from this software without specific prior written permission.
21+
22+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23+
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26+
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27+
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28+
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30+
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33+
"""
34+
35+
import numpy as np
36+
import scipy.linalg.decomp as decomp
37+
38+
39+
def pinvh(a, cond=None, rcond=None, lower=True, return_rank=False,
40+
check_finite=True):
41+
"""
42+
Compute the (Moore-Penrose) pseudo-inverse of a Hermitian matrix.
43+
44+
Copied in from scipy==1.2.2, in order to preserve the default choice of the
45+
`cond` and `above_cutoff` values which determine which values of the matrix
46+
inversion lie below threshold and are so set to zero. Changes in scipy 1.3
47+
resulted in a smaller default threshold and thus slower convergence of
48+
dependent algorithms in some cases (see Sklearn github issue #14055).
49+
50+
Calculate a generalized inverse of a Hermitian or real symmetric matrix
51+
using its eigenvalue decomposition and including all eigenvalues with
52+
'large' absolute value.
53+
54+
Parameters
55+
----------
56+
a : (N, N) array_like
57+
Real symmetric or complex hermetian matrix to be pseudo-inverted
58+
cond, rcond : float or None
59+
Cutoff for 'small' eigenvalues.
60+
Singular values smaller than rcond * largest_eigenvalue are considered
61+
zero.
62+
63+
If None or -1, suitable machine precision is used.
64+
lower : bool, optional
65+
Whether the pertinent array data is taken from the lower or upper
66+
triangle of a. (Default: lower)
67+
return_rank : bool, optional
68+
if True, return the effective rank of the matrix
69+
check_finite : bool, optional
70+
Whether to check that the input matrix contains only finite numbers.
71+
Disabling may give a performance gain, but may result in problems
72+
(crashes, non-termination) if the inputs do contain infinities or NaNs.
73+
74+
Returns
75+
-------
76+
B : (N, N) ndarray
77+
The pseudo-inverse of matrix `a`.
78+
rank : int
79+
The effective rank of the matrix. Returned if return_rank == True
80+
81+
Raises
82+
------
83+
LinAlgError
84+
If eigenvalue does not converge
85+
86+
Examples
87+
--------
88+
>>> from scipy.linalg import pinvh
89+
>>> a = np.random.randn(9, 6)
90+
>>> a = np.dot(a, a.T)
91+
>>> B = pinvh(a)
92+
>>> np.allclose(a, np.dot(a, np.dot(B, a)))
93+
True
94+
>>> np.allclose(B, np.dot(B, np.dot(a, B)))
95+
True
96+
97+
"""
98+
a = decomp._asarray_validated(a, check_finite=check_finite)
99+
s, u = decomp.eigh(a, lower=lower, check_finite=False)
100+
101+
if rcond is not None:
102+
cond = rcond
103+
if cond in [None, -1]:
104+
t = u.dtype.char.lower()
105+
factor = {'f': 1E3, 'd': 1E6}
106+
cond = factor[t] * np.finfo(t).eps
107+
108+
# For Hermitian matrices, singular values equal abs(eigenvalues)
109+
above_cutoff = (abs(s) > cond * np.max(abs(s)))
110+
psigma_diag = 1.0 / s[above_cutoff]
111+
u = u[:, above_cutoff]
112+
113+
B = np.dot(u * psigma_diag, np.conjugate(u).T)
114+
115+
if return_rank:
116+
return B, len(psigma_diag)
117+
else:
118+
return B

sklearn/linear_model/bayes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
from math import log
99
import numpy as np
1010
from scipy import linalg
11-
from scipy.linalg import pinvh
1211

1312
from .base import LinearModel, _rescale_data
1413
from ..base import RegressorMixin
1514
from ..utils.extmath import fast_logdet
1615
from ..utils import check_X_y
16+
from ..utils.fixes import pinvh
1717

1818

1919
###############################################################################

sklearn/linear_model/tests/test_bayes.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,24 @@ def test_toy_ard_object():
188188
assert_array_almost_equal(clf.predict(test), [1, 3, 4], 2)
189189

190190

191+
def test_ard_accuracy_on_easy_problem():
192+
# Check that ARD converges with reasonable accuracy on an easy problem
193+
# (Github issue #14055)
194+
# This particular seed seems to converge poorly in the failure-case
195+
# (scipy==1.3.0, sklearn==0.21.2)
196+
seed = 45
197+
X = np.random.RandomState(seed=seed).normal(size=(250, 3))
198+
y = X[:, 1]
199+
200+
regressor = ARDRegression()
201+
regressor.fit(X, y)
202+
203+
abs_coef_error = np.abs(1 - regressor.coef_[1])
204+
# Expect an accuracy of better than 1E-4 in most cases -
205+
# Failure-case produces 0.16!
206+
assert abs_coef_error < 0.01
207+
208+
191209
def test_return_std():
192210
# Test return_std option for both Bayesian regressors
193211
def f(X):

sklearn/utils/fixes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ def _parse_version(version_string):
3939
from scipy.misc import comb, logsumexp # noqa
4040

4141

42+
if sp_version >= (1, 3):
43+
# Preserves earlier default choice of pinvh cutoff `cond` value.
44+
# Can be removed once issue #14055 is fully addressed.
45+
from ..externals._scipy_linalg import pinvh
46+
else:
47+
from scipy.linalg import pinvh # noqa
48+
4249
if sp_version >= (0, 19):
4350
def _argmax(arr_or_spmatrix, axis=None):
4451
return arr_or_spmatrix.argmax(axis=axis)

0 commit comments

Comments
 (0)