diff --git a/doc/whats_new/v0.23.rst b/doc/whats_new/v0.23.rst
index 720a19cd862f6..19fd78444aad1 100644
--- a/doc/whats_new/v0.23.rst
+++ b/doc/whats_new/v0.23.rst
@@ -196,7 +196,12 @@ Changelog
 ...............................

 - |Enhancement| :func:`gaussian_process.kernels.Matern` returns the RBF kernel
   when ``nu=np.inf``.
-  :pr:`15503` by :user:`Sam Dixon `.
+  :pr:`15503` by :user:`Sam Dixon` .
+
+- |Fix| Fixed bug in :class:`gaussian_process.GaussianProcessRegressor` that
+  caused predicted standard deviations to only be between 0 and 1 when
+  WhiteKernel is not used. :pr:`15782`
+  by :user:`plgreenLIRU`.

 :mod:`sklearn.impute`
 .....................
@@ -218,7 +223,6 @@ Changelog
   :class:`tree.DecisionTreeRegressor`. :pr:`15864` by `Nicolas Hug`_.

-
 :mod:`sklearn.linear_model`
 ...........................
@@ -393,6 +397,6 @@ Changelog
 :mod:`sklearn.cluster`
 ......................

-- |Fix| :class:`cluster.AgglomerativeClustering` add specific error when
-  distance matrix is not square and `affinity=precomputed`.
+- |Fix| :class:`cluster.AgglomerativeClustering` add specific error when
+  distance matrix is not square and `affinity=precomputed`.
   :pr:`16257` by :user:`Simona Maggio `.
diff --git a/sklearn/gaussian_process/_gpr.py b/sklearn/gaussian_process/_gpr.py
index 432e6937e1951..0c1db0d209458 100644
--- a/sklearn/gaussian_process/_gpr.py
+++ b/sklearn/gaussian_process/_gpr.py
@@ -1,7 +1,7 @@
 """Gaussian processes regression. """

 # Authors: Jan Hendrik Metzen
-#
+# Modified by: Pete Green
 # License: BSD 3 clause

 import warnings
@@ -92,13 +92,14 @@ def optimizer(obj_func, initial_theta, bounds):
         must be finite. Note that n_restarts_optimizer == 0 implies that one
         run is performed.

-    normalize_y : bool, default=False
-        Whether the target values y are normalized, i.e., the mean of the
-        observed target values become zero. This parameter should be set to
-        True if the target values' mean is expected to differ considerable from
-        zero. When enabled, the normalization effectively modifies the GP's
-        prior based on the data, which contradicts the likelihood principle;
-        normalization is thus disabled per default.
+    normalize_y : boolean, optional (default: False)
+        Whether the target values y are normalized, i.e., the mean and
+        variance of the target values are set equal to 0 and 1 respectively.
+        This is recommended for cases where zero-mean, unit-variance priors
+        are used. Note that, in this implementation, the normalisation is
+        reversed before the GP predictions are reported.
+
+        .. versionchanged:: 0.23

     copy_X_train : bool, default=True
         If True, a persistent copy of the training data is stored in the
@@ -192,10 +193,14 @@ def fit(self, X, y):
         # Normalize target value
         if self.normalize_y:
             self._y_train_mean = np.mean(y, axis=0)
-            # demean y
-            y = y - self._y_train_mean
+            self._y_train_std = np.std(y, axis=0)
+
+            # Remove mean and make unit variance
+            y = (y - self._y_train_mean) / self._y_train_std
+
         else:
             self._y_train_mean = np.zeros(1)
+            self._y_train_std = 1

         if np.iterable(self.alpha) \
            and self.alpha.shape[0] != y.shape[0]:
@@ -330,10 +335,17 @@ def predict(self, X, return_std=False, return_cov=False):
         else:  # Predict based on GP posterior
             K_trans = self.kernel_(X, self.X_train_)
             y_mean = K_trans.dot(self.alpha_)  # Line 4 (y_mean = f_star)
-            y_mean = self._y_train_mean + y_mean  # undo normal.
+
+            # undo normalisation
+            y_mean = self._y_train_std * y_mean + self._y_train_mean
+
             if return_cov:
                 v = cho_solve((self.L_, True), K_trans.T)  # Line 5
                 y_cov = self.kernel_(X) - K_trans.dot(v)  # Line 6
+
+                # undo normalisation
+                y_cov = y_cov * self._y_train_std**2
+
                 return y_mean, y_cov
             elif return_std:
                 # cache result of K_inv computation
@@ -356,6 +368,10 @@ def predict(self, X, return_std=False, return_cov=False):
                     warnings.warn("Predicted variances smaller than 0. "
                                   "Setting those variances to 0.")
                     y_var[y_var_negative] = 0.0
+
+                # undo normalisation
+                y_var = y_var * self._y_train_std**2
+
                 return y_mean, np.sqrt(y_var)
             else:
                 return y_mean
diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py
index 4ed105db04246..4bdd94e669eb4 100644
--- a/sklearn/gaussian_process/tests/test_gpr.py
+++ b/sklearn/gaussian_process/tests/test_gpr.py
@@ -1,6 +1,7 @@
 """Testing for Gaussian process regression """

 # Author: Jan Hendrik Metzen
+# Modified by: Pete Green
 # License: BSD 3 clause

 import sys
@@ -19,7 +20,8 @@
 from sklearn.utils._testing \
     import (assert_array_less, assert_almost_equal, assert_raise_message,
-            assert_array_almost_equal, assert_array_equal)
+            assert_array_almost_equal, assert_array_equal,
+            assert_allclose)


 def f(x):
@@ -232,33 +234,103 @@ def test_random_starts():
 @pytest.mark.parametrize('kernel', kernels)
 def test_y_normalization(kernel):
-    # Test normalization of the target values in GP
+    """
+    Test normalization of the target values in GP

-    # Fitting non-normalizing GP on normalized y and fitting normalizing GP
-    # on unnormalized y should yield identical results
-    y_mean = y.mean(0)
-    y_norm = y - y_mean
+    Fitting non-normalizing GP on normalized y and fitting normalizing GP
+    on unnormalized y should yield identical results. Note that, here,
+    'normalized y' refers to y that has been made zero mean and unit
+    variance.
+
+    """
+
+    y_mean = np.mean(y)
+    y_std = np.std(y)
+    y_norm = (y - y_mean) / y_std

     # Fit non-normalizing GP on normalized y
     gpr = GaussianProcessRegressor(kernel=kernel)
     gpr.fit(X, y_norm)
+
     # Fit normalizing GP on unnormalized y
     gpr_norm = GaussianProcessRegressor(kernel=kernel, normalize_y=True)
     gpr_norm.fit(X, y)

     # Compare predicted mean, std-devs and covariances
     y_pred, y_pred_std = gpr.predict(X2, return_std=True)
-    y_pred = y_mean + y_pred
+    y_pred = y_pred * y_std + y_mean
+    y_pred_std = y_pred_std * y_std

     y_pred_norm, y_pred_std_norm = gpr_norm.predict(X2, return_std=True)

     assert_almost_equal(y_pred, y_pred_norm)
     assert_almost_equal(y_pred_std, y_pred_std_norm)

     _, y_cov = gpr.predict(X2, return_cov=True)
+    y_cov = y_cov * y_std**2
     _, y_cov_norm = gpr_norm.predict(X2, return_cov=True)
+
     assert_almost_equal(y_cov, y_cov_norm)


+def test_large_variance_y():
+    """
+    Here we test that, when normalize_y=True, our GP can produce a
+    sensible fit to training data whose variance is significantly
+    larger than unity. This test was made in response to issue #15612.
+
+    GP predictions are verified against predictions that were made
+    using GPy which, here, is treated as the 'gold standard'. Note that we
+    only investigate the RBF kernel here, as that is what was used in the
+    GPy implementation.
+
+    The following code can be used to recreate the GPy data:
+
+    --------------------------------------------------------------------------
+    import GPy
+
+    kernel_gpy = GPy.kern.RBF(input_dim=1, lengthscale=1.)
+    gpy = GPy.models.GPRegression(X, np.vstack(y_large), kernel_gpy)
+    gpy.optimize()
+    y_pred_gpy, y_var_gpy = gpy.predict(X2)
+    y_pred_std_gpy = np.sqrt(y_var_gpy)
+    --------------------------------------------------------------------------
+    """
+
+    # Here we utilise a larger variance version of the training data
+    y_large = 10 * y
+
+    # Standard GP with normalize_y=True
+    RBF_params = {'length_scale': 1.0}
+    kernel = RBF(**RBF_params)
+    gpr = GaussianProcessRegressor(kernel=kernel, normalize_y=True)
+    gpr.fit(X, y_large)
+    y_pred, y_pred_std = gpr.predict(X2, return_std=True)
+
+    # 'Gold standard' mean predictions from GPy
+    y_pred_gpy = np.array([15.16918303,
+                           -27.98707845,
+                           -39.31636019,
+                           14.52605515,
+                           69.18503589])
+
+    # 'Gold standard' std predictions from GPy
+    y_pred_std_gpy = np.array([7.78860962,
+                               3.83179178,
+                               0.63149951,
+                               0.52745188,
+                               0.86170042])
+
+    # Based on numerical experiments, it's reasonable to expect our
+    # GP's mean predictions to be within 7% of the predictions made
+    # by GPy.
+    assert_allclose(y_pred, y_pred_gpy, rtol=0.07, atol=0)
+
+    # Based on numerical experiments, it's reasonable to expect our
+    # GP's std predictions to be within 15% of the predictions made
+    # by GPy.
+    assert_allclose(y_pred_std, y_pred_std_gpy, rtol=0.15, atol=0)
+
+
 def test_y_multioutput():
     # Test that GPR can deal with multi-dimensional target values
     y_2d = np.vstack((y, y * 2)).T
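
Below is a minimal, self-contained sketch (not part of the patch; the toy data, kernel choice and printed check are illustrative assumptions) showing the behaviour this diff targets: with normalize_y=True, the predicted standard deviations are rescaled back to the spread of the training targets instead of staying between 0 and 1 when no WhiteKernel is used.

--------------------------------------------------------------------------
# Illustrative sketch only: data, kernel and shapes are assumptions,
# not taken from the patch or the scikit-learn test suite.
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

X_train = np.linspace(0.0, 5.0, 20).reshape(-1, 1)
y_train = 100.0 * np.sin(X_train).ravel()   # targets with variance far above 1

gpr = GaussianProcessRegressor(kernel=RBF(length_scale=1.0), normalize_y=True)
gpr.fit(X_train, y_train)

X_test = np.linspace(0.0, 5.0, 7).reshape(-1, 1)
y_mean, y_std = gpr.predict(X_test, return_std=True)

# With the fix, the reported std devs are on the scale of y_train; before the
# fix they were confined to [0, 1] when no WhiteKernel was used.
print(y_std)
--------------------------------------------------------------------------

Before this change, the same script would report standard deviations bounded by 1 because the un-normalisation of the predicted variance was missing from predict.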