
FIX Use cho_solve when return_std=True for GaussianProcessRegressor #19939

Merged
merged 8 commits into from
Apr 27, 2021
7 changes: 7 additions & 0 deletions doc/whats_new/v0.24.rst
@@ -45,6 +45,13 @@ Changelog
:mod:`sklearn.gaussian_process`
...............................

- |Fix| Avoid explicitly forming the inverse covariance matrix in
:class:`gaussian_process.GaussianProcessRegressor` when outputting the
standard deviation. With certain covariance matrices this inverse is
numerically unstable to compute explicitly; calling a Cholesky solver
instead mitigates the issue.
:pr:`19939` by :user:`Ian Halvic <iwhalvic>`.

- |Fix| Avoid division by zero when scaling constant target in
:class:`gaussian_process.GaussianProcessRegressor`. It was due to a std. dev.
equal to 0. Now, such a case is detected and the std. dev. is set to 1
24 changes: 9 additions & 15 deletions sklearn/gaussian_process/_gpr.py
@@ -8,7 +8,7 @@
from operator import itemgetter

import numpy as np
from scipy.linalg import cholesky, cho_solve, solve_triangular
from scipy.linalg import cholesky, cho_solve
import scipy.optimize

from ..base import BaseEstimator, RegressorMixin, clone
@@ -270,8 +270,6 @@ def obj_func(theta, eval_gradient=True):
K[np.diag_indices_from(K)] += self.alpha
try:
self.L_ = cholesky(K, lower=True) # Line 2
# self.L_ changed, self._K_inv needs to be recomputed
self._K_inv = None
except np.linalg.LinAlgError as exc:
exc.args = ("The kernel, %s, is not returning a "
"positive definite matrix. Try gradually "
@@ -345,31 +343,27 @@ def predict(self, X, return_std=False, return_cov=False):
else: # Predict based on GP posterior
K_trans = self.kernel_(X, self.X_train_)
y_mean = K_trans.dot(self.alpha_) # Line 4 (y_mean = f_star)

# undo normalisation
y_mean = self._y_train_std * y_mean + self._y_train_mean

if return_cov:
v = cho_solve((self.L_, True), K_trans.T) # Line 5
y_cov = self.kernel_(X) - K_trans.dot(v) # Line 6
# Solve K @ V = K_trans.T
V = cho_solve((self.L_, True), K_trans.T) # Line 5
y_cov = self.kernel_(X) - K_trans.dot(V) # Line 6

# undo normalisation
y_cov = y_cov * self._y_train_std**2

return y_mean, y_cov
elif return_std:
# cache result of K_inv computation
if self._K_inv is None:
# compute inverse K_inv of K based on its Cholesky
# decomposition L and its inverse L_inv
L_inv = solve_triangular(self.L_.T,
np.eye(self.L_.shape[0]))
self._K_inv = L_inv.dot(L_inv.T)
# Solve K @ V = K_trans.T
V = cho_solve((self.L_, True), K_trans.T) # Line 5

# Compute variance of predictive distribution
# Use einsum to avoid explicitly forming the large matrix
# K_trans @ V just to extract its diagonal afterward.
y_var = self.kernel_.diag(X)
y_var -= np.einsum("ij,ij->i",
np.dot(K_trans, self._K_inv), K_trans)
y_var -= np.einsum("ij,ji->i", K_trans, V)

# Check if any of the variances is negative because of
# numerical issues. If yes: set the variance to 0.
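
For a well-conditioned kernel matrix the old and new formulations agree; the difference is that the new one never materialises K^{-1}. Below is a small self-contained sketch of that equivalence (toy matrices; the names A, K, K_trans are chosen for illustration only and are not taken from the PR):

import numpy as np
from scipy.linalg import cholesky, cho_solve, solve_triangular

rng = np.random.RandomState(0)

# Toy, well-conditioned SPD "kernel" matrix K and cross-kernel K_trans.
A = rng.randn(8, 5)
K = A @ A.T + 0.5 * np.eye(8)
K_trans = rng.randn(4, 8)

L = cholesky(K, lower=True)

# Old path: explicitly form K_inv from the inverse of L.T.
L_inv = solve_triangular(L.T, np.eye(L.shape[0]))
K_inv = L_inv @ L_inv.T
var_term_old = np.einsum("ij,ij->i", K_trans @ K_inv, K_trans)

# New path: solve K @ V = K_trans.T instead of forming K_inv.
V = cho_solve((L, True), K_trans.T)
var_term_new = np.einsum("ij,ji->i", K_trans, V)

# Both are the diagonal of K_trans @ K^{-1} @ K_trans.T ...
assert np.allclose(var_term_old, var_term_new)
# ... and the einsum extracts it without forming the full product.
assert np.allclose(var_term_new, np.diag(K_trans @ V))

On a nearly singular K, only the direct solve stays well behaved; that is the situation this PR addresses and which the new non-regression test below exercises.
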
57 changes: 33 additions & 24 deletions sklearn/gaussian_process/tests/test_gpr.py
@@ -20,10 +20,12 @@
from sklearn.gaussian_process.tests._mini_sequence_kernel import MiniSeqKernel
from sklearn.exceptions import ConvergenceWarning

from sklearn.utils._testing \
import (assert_array_less,
assert_almost_equal, assert_array_almost_equal,
assert_array_equal, assert_allclose)
from sklearn.utils._testing import (
assert_array_less,
assert_almost_equal,
assert_array_almost_equal,
assert_allclose
)


def f(x):
@@ -185,7 +187,8 @@ def test_no_optimizer():


@pytest.mark.parametrize('kernel', kernels)
def test_predict_cov_vs_std(kernel):
@pytest.mark.parametrize("target", [y, np.ones(X.shape[0], dtype=np.float64)])
def test_predict_cov_vs_std(kernel, target):
if sys.maxsize <= 2 ** 32 and sys.version_info[:2] == (3, 6):
pytest.xfail("This test may fail on 32bit Py3.6")

@@ -455,25 +458,6 @@ def test_no_fit_default_predict():
assert_array_almost_equal(y_cov1, y_cov2)


@pytest.mark.parametrize('kernel', kernels)
def test_K_inv_reset(kernel):
y2 = f(X2).ravel()

# Test that self._K_inv is reset after a new fit
gpr = GaussianProcessRegressor(kernel=kernel).fit(X, y)
assert hasattr(gpr, '_K_inv')
assert gpr._K_inv is None
gpr.predict(X, return_std=True)
assert gpr._K_inv is not None
gpr.fit(X2, y2)
assert gpr._K_inv is None
gpr.predict(X2, return_std=True)
gpr2 = GaussianProcessRegressor(kernel=kernel).fit(X2, y2)
gpr2.predict(X2, return_std=True)
# the value of K_inv should be independent of the first fit
assert_array_equal(gpr._K_inv, gpr2._K_inv)


def test_warning_bounds():
kernel = RBF(length_scale_bounds=[1e-5, 1e-3])
gpr = GaussianProcessRegressor(kernel=kernel)
@@ -569,3 +553,28 @@ def test_constant_target(kernel):
assert_allclose(y_pred, y_constant)
# set atol because we compare to zero
assert_allclose(np.diag(y_cov), 0., atol=1e-9)


def test_gpr_consistency_std_cov_non_invertible_kernel():
"""Check the consistency between the returned std. dev. and the covariance.
Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/19936
Comment on lines +560 to +561

Member:
Suggested change:
    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/19936
to:
    Non-regression test for issue #19936.

Member:
To be honest I prefer to have the URL :)

@jjerphan (Member, Apr 22, 2021):
Ah yes, I have just seen that both styles exist.

@iwhalvic: you can ignore this comment.

Inconsistencies were observed when the kernel cannot be inverted (or is not
numerically stable).
"""
kernel = (C(8.98576054e+05, (1e-12, 1e12)) *
RBF([5.91326520e+02, 1.32584051e+03], (1e-12, 1e12)) +
WhiteKernel(noise_level=1e-5))
gpr = GaussianProcessRegressor(kernel=kernel, alpha=0, optimizer=None)
X_train = np.array([[0., 0.], [1.54919334, -0.77459667], [-1.54919334, 0.],
[0., -1.54919334], [0.77459667, 0.77459667],
[-0.77459667, 1.54919334]])
y_train = np.array([[-2.14882017e-10], [-4.66975823e+00], [4.01823986e+00],
[-1.30303674e+00], [-1.35760156e+00],
[3.31215668e+00]])
gpr.fit(X_train, y_train)
X_test = np.array([[-1.93649167, -1.93649167], [1.93649167, -1.93649167],
[-1.93649167, 1.93649167], [1.93649167, 1.93649167]])
Comment on lines +565 to +577

@jjerphan (Member, Apr 22, 2021):
Are those values the ones for the case you are describing?

> Inconsistencies were observed when the kernel cannot be inverted (or is not numerically stable).

Member:
Yes. This kernel is nearly singular: #19939 (comment)
I assume this is fine for a non-regression test.

Member:

If the problem is about diagonal elements of gpr.L_ being close to zero, maybe we could make that explicit, e.g. with:

assert np.diag(gpr.L_).min() < 0.01

although I am not 100% sure this is the true cause of the problem.

@ogrisel (Member, Apr 22, 2021):

Or maybe:

assert np.linalg.eigvalsh(kernel(X_train)).min() < 1e-4

@jjerphan (Member, Apr 22, 2021):

The condition number of gpr.L_ might also be a good proxy for whether solving this system comes with numerical instability.

In this case, numerical instability is likely present, since cond(gpr.L_) >> 1:

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C, WhiteKernel

kernel = (C(8.98576054e+05, (1e-12, 1e12)) *
          RBF([5.91326520e+02, 1.32584051e+03], (1e-12, 1e12)) +
          WhiteKernel(noise_level=1e-5))

gpr = GaussianProcessRegressor(kernel=kernel, alpha=0, optimizer=None)

X_train = np.array([[0., 0.], [1.54919334, -0.77459667], [-1.54919334, 0.],
                    [0., -1.54919334], [0.77459667, 0.77459667],
                    [-0.77459667, 1.54919334]])

y_train = np.array([[-2.14882017e-10], [-4.66975823e+00], [4.01823986e+00],
                    [-1.30303674e+00], [-1.35760156e+00],
                    [3.31215668e+00]])

gpr.fit(X_train, y_train)

print(np.linalg.cond(gpr.L_))
# 720955.9866810681

Member:

Yep, this is exactly what we can observe on the main branch when we try to invert L_.T with a constant target. I printed the condition number of L_: when the test fails for the constant target, it is because L_ is ill-conditioned. When the condition number >> 1, the inversion is numerically unstable, but L_ is not singular.

sklearn/gaussian_process/tests/test_gpr.py::test_predict_cov_vs_std[target0-kernel0] 3.6020373678666835
PASSED
sklearn/gaussian_process/tests/test_gpr.py::test_predict_cov_vs_std[target0-kernel1] 4.000792132354216
PASSED
sklearn/gaussian_process/tests/test_gpr.py::test_predict_cov_vs_std[target0-kernel2] 3.6020373678666835
PASSED
sklearn/gaussian_process/tests/test_gpr.py::test_predict_cov_vs_std[target0-kernel3] 20.254629250748767
PASSED
sklearn/gaussian_process/tests/test_gpr.py::test_predict_cov_vs_std[target0-kernel4] 20.254602147788138
PASSED
sklearn/gaussian_process/tests/test_gpr.py::test_predict_cov_vs_std[target0-kernel5] 1.0000016459179157
PASSED
sklearn/gaussian_process/tests/test_gpr.py::test_predict_cov_vs_std[target1-kernel0] 244949.11539454578
FAILED
sklearn/gaussian_process/tests/test_gpr.py::test_predict_cov_vs_std[target1-kernel1] 4.000792132354216
PASSED
sklearn/gaussian_process/tests/test_gpr.py::test_predict_cov_vs_std[target1-kernel2] 244948.4153273236
FAILED
sklearn/gaussian_process/tests/test_gpr.py::test_predict_cov_vs_std[target1-kernel3] 165965.62854978096
FAILED
sklearn/gaussian_process/tests/test_gpr.py::test_predict_cov_vs_std[target1-kernel4] 165956.90214640685
FAILED
sklearn/gaussian_process/tests/test_gpr.py::test_predict_cov_vs_std[target1-kernel5] 165975.30738621883
FAILED
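
A tiny side illustration of that last point (a made-up 2x2 matrix B, not from this thread): a matrix can be comfortably non-singular yet so ill-conditioned that its explicit inverse has huge entries and reacts strongly to a perturbation far smaller than the data itself, which is the regime the failing constant-target cases above sit in.

import numpy as np

B = np.array([[1.0, 1.0],
              [1.0, 1.0 + 1e-10]])
print(np.linalg.cond(B))        # ~4e10: ill-conditioned, but det = 1e-10 != 0
print(np.linalg.inv(B)[0, 0])   # ~1e10: the explicit inverse has huge entries

B[1, 1] += 1e-10                # perturb a single entry by 1e-10 ...
print(np.linalg.inv(B)[0, 0])   # ... and the inverse roughly halves, to ~5e9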

@iwhalvic (Contributor Author, Apr 22, 2021):

Just adding my comments on what is occurring here:
Covariance matrices must be positive semi-definite, so all eigenvalues are >= 0.
Inverting a matrix requires all eigenvalues to be non-zero.

So, theoretically, we could have a covariance matrix with a 0 eigenvalue, i.e. a non-invertible one.

Minimal example: if I ask you to invert A = [[1 1] [1 1]], you will likely not have an answer. By analogy, the previous code, which used solve_triangular(self.L_.T, np.eye(self.L_.shape[0])), was asking: "solve Ax = [1 0]". Once again, there is no solution for x.

In the fixed code, cho_solve((self.L_, True), K_trans.T), L_ and K_trans are both generated from the same kernel function, so we would not expect to ever be posed with such a problem. Instead we pose a problem such as "solve Ax = [1 1]", at which point you can probably provide an answer (indeed, quite a few of them).

Kernel functions such as the one in the test case form matrices like A = [[1 1] [1 1]] because the length scales of the RBF are quite large, so all points become highly correlated with all other points (almost regardless of their distance in X space). I believe a true linear fit using a GP with an RBF kernel would have a correlation length of inf.

The alpha parameter in GaussianProcessRegressor attempts to avoid this, and with a sufficiently large alpha value I predict it would for this case. But regardless, return_cov=True and return_std=True should be as consistent as possible.

But I would welcome a GP expert to weigh in.
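
To make the analogy above concrete, here is a minimal numeric sketch of it (a toy 2x2 matrix, not the PR's code; np.linalg.lstsq stands in for the solver because an exactly singular matrix has no Cholesky factor): inverting A = [[1, 1], [1, 1]] fails, A x = [1, 0] has no exact solution, but A x = [1, 1], whose right-hand side is consistent with A, does.

import numpy as np

A = np.array([[1.0, 1.0],
              [1.0, 1.0]])             # singular: eigenvalues 2 and 0

try:
    np.linalg.inv(A)                   # forming the explicit inverse fails outright
except np.linalg.LinAlgError as exc:
    print("inversion fails:", exc)     # "Singular matrix"

# A x = [1, 0]: the right-hand side lies outside A's column space.
x = np.linalg.lstsq(A, np.array([1.0, 0.0]), rcond=None)[0]
print(A @ x)                           # [0.5, 0.5], not [1, 0]: no exact solution

# A x = [1, 1]: the right-hand side is itself "generated by A",
# just as K_trans and L_ come from the same kernel.
x = np.linalg.lstsq(A, np.array([1.0, 1.0]), rcond=None)[0]
print(A @ x)                           # [1., 1.]: exact solutions exist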

pred1, std = gpr.predict(X_test, return_std=True)
pred2, cov = gpr.predict(X_test, return_cov=True)
assert_allclose(std, np.sqrt(np.diagonal(cov)), rtol=1e-5)