diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 5c338d614b49d..be74649ea354f 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -276,6 +276,14 @@ Changelog :mod:`sklearn.gaussian_process` ............................... +- |Fix| `predict` and `sample_y` methods of + :class:`sklearn.gaussian_process.GaussianProcessRegressor` now return + arrays of the correct shape in single-target and multi-target cases, and for + both `normalize_y=False` and `normalize_y=True`. + :pr:`22199` by :user:`Guillaume Lemaitre <glemaitre>`, + :user:`Aidar Shakerimoff <AidarShakerimoff>` and + :user:`Tenavi Nakamura-Zimmerer <Tenavi>`. + - |Fix| :class:`gaussian_process.GaussianProcessClassifier` raises a more informative error if `CompoundKernel` is passed via `kernel`. :pr:`22223` by :user:`MarcoM <marcozzxx810>`. diff --git a/sklearn/gaussian_process/_gpr.py b/sklearn/gaussian_process/_gpr.py index 91ed3e6d8f318..e93a1121af18c 100644 --- a/sklearn/gaussian_process/_gpr.py +++ b/sklearn/gaussian_process/_gpr.py @@ -239,8 +239,9 @@ def fit(self, X, y): y = (y - self._y_train_mean) / self._y_train_std else: - self._y_train_mean = np.zeros(1) - self._y_train_std = 1 + shape_y_stats = (y.shape[1],) if y.ndim == 2 else 1 + self._y_train_mean = np.zeros(shape=shape_y_stats) + self._y_train_std = np.ones(shape=shape_y_stats) if np.iterable(self.alpha) and self.alpha.shape[0] != y.shape[0]: if self.alpha.shape[0] == 1: @@ -394,6 +395,10 @@ def predict(self, X, return_std=False, return_cov=False): # undo normalisation y_mean = self._y_train_std * y_mean + self._y_train_mean + # if y_mean has shape (n_samples, 1), reshape to (n_samples,) + if y_mean.ndim > 1 and y_mean.shape[1] == 1: + y_mean = np.squeeze(y_mean, axis=1) + # Alg 2.1, page 19, line 5 -> v = L \ K(X_test, X_train)^T V = solve_triangular( self.L_, K_trans.T, lower=GPR_CHOLESKY_LOWER, check_finite=False @@ -407,7 +412,6 @@ def predict(self, X, return_std=False, return_cov=False): y_cov = np.outer(y_cov, self._y_train_std ** 2).reshape( 
*y_cov.shape, -1 ) - # if y_cov has shape (n_samples, n_samples, 1), reshape to # (n_samples, n_samples) if y_cov.shape[2] == 1: @@ -475,8 +479,10 @@ def sample_y(self, X, n_samples=1, random_state=0): y_samples = rng.multivariate_normal(y_mean, y_cov, n_samples).T else: y_samples = [ - rng.multivariate_normal(y_mean[:, i], y_cov, n_samples).T[:, np.newaxis] - for i in range(y_mean.shape[1]) + rng.multivariate_normal( + y_mean[:, target], y_cov[..., target], n_samples + ).T[:, np.newaxis] + for target in range(y_mean.shape[1]) ] y_samples = np.hstack(y_samples) return y_samples diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index 8e57865600987..f14d939de910e 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -18,7 +18,6 @@ from sklearn.gaussian_process.kernels import DotProduct, ExpSineSquared from sklearn.gaussian_process.tests._mini_sequence_kernel import MiniSeqKernel from sklearn.exceptions import ConvergenceWarning - from sklearn.utils._testing import ( assert_array_less, assert_almost_equal, @@ -361,12 +360,17 @@ def test_y_multioutput(): assert_almost_equal(y_pred_1d, y_pred_2d[:, 1] / 2) # Standard deviation and covariance do not depend on output - assert_almost_equal(y_std_1d, y_std_2d) - assert_almost_equal(y_cov_1d, y_cov_2d) + for target in range(y_2d.shape[1]): + assert_almost_equal(y_std_1d, y_std_2d[..., target]) + assert_almost_equal(y_cov_1d, y_cov_2d[..., target]) y_sample_1d = gpr.sample_y(X2, n_samples=10) y_sample_2d = gpr_2d.sample_y(X2, n_samples=10) - assert_almost_equal(y_sample_1d, y_sample_2d[:, 0]) + + assert y_sample_1d.shape == (5, 10) + assert y_sample_2d.shape == (5, 2, 10) + # Only the first target will be equal + assert_almost_equal(y_sample_1d, y_sample_2d[:, 0, :]) # Test hyperparameter optimization for kernel in kernels: @@ -546,8 +550,6 @@ def test_bound_check_fixed_hyperparameter(): 
GaussianProcessRegressor(kernel=kernel).fit(X, y) -# FIXME: we should test for multitargets as well. However, GPR is broken: -# see: https://github.com/scikit-learn/scikit-learn/pull/19706 @pytest.mark.parametrize("kernel", kernels) def test_constant_target(kernel): """Check that the std. dev. is affected to 1 when normalizing a constant @@ -568,6 +570,26 @@ def test_constant_target(kernel): # set atol because we compare to zero assert_allclose(np.diag(y_cov), 0.0, atol=1e-9) + # Test multi-target data + n_samples, n_targets = X.shape[0], 2 + rng = np.random.RandomState(0) + y = np.concatenate( + [ + rng.normal(size=(n_samples, 1)), # non-constant target + np.full(shape=(n_samples, 1), fill_value=2), # constant target + ], + axis=1, + ) + + gpr.fit(X, y) + Y_pred, Y_cov = gpr.predict(X, return_cov=True) + + assert_allclose(Y_pred[:, 1], 2) + assert_allclose(np.diag(Y_cov[..., 1]), 0.0, atol=1e-9) + + assert Y_pred.shape == (n_samples, n_targets) + assert Y_cov.shape == (n_samples, n_samples, n_targets) + def test_gpr_consistency_std_cov_non_invertible_kernel(): """Check the consistency between the returned std. dev. and the covariance. @@ -654,31 +676,89 @@ def test_gpr_predict_error(): gpr.predict(X, return_cov=True, return_std=True) -def test_y_std_with_multitarget_normalized(): - """Check the proper normalization of `y_std` and `y_cov` in multi-target scene. +@pytest.mark.parametrize("normalize_y", [True, False]) +@pytest.mark.parametrize("n_targets", [None, 1, 10]) +def test_predict_shapes(normalize_y, n_targets): + """Check the shapes of y_mean, y_std, and y_cov in single-output + (n_targets=None) and multi-output settings, including the edge case when + n_targets=1, where the sklearn convention is to squeeze the predictions. 
 Non-regression test for: https://github.com/scikit-learn/scikit-learn/issues/17394 https://github.com/scikit-learn/scikit-learn/issues/18065 + https://github.com/scikit-learn/scikit-learn/issues/22174 """ rng = np.random.RandomState(1234) - n_samples, n_features, n_targets = 12, 10, 6 + n_features, n_samples_train, n_samples_test = 6, 9, 7 - X_train = rng.randn(n_samples, n_features) - y_train = rng.randn(n_samples, n_targets) - X_test = rng.randn(n_samples, n_features) + y_train_shape = (n_samples_train,) + if n_targets is not None: + y_train_shape = y_train_shape + (n_targets,) - # Generic kernel - kernel = WhiteKernel(1.0, (1e-1, 1e3)) * C(10.0, (1e-3, 1e3)) + # By convention single-output data is squeezed upon prediction + y_test_shape = (n_samples_test,) + if n_targets is not None and n_targets > 1: + y_test_shape = y_test_shape + (n_targets,) - model = GaussianProcessRegressor( - kernel=kernel, n_restarts_optimizer=10, alpha=0.1, normalize_y=True - ) + X_train = rng.randn(n_samples_train, n_features) + X_test = rng.randn(n_samples_test, n_features) + y_train = rng.randn(*y_train_shape) + + model = GaussianProcessRegressor(normalize_y=normalize_y) model.fit(X_train, y_train) + y_pred, y_std = model.predict(X_test, return_std=True) _, y_cov = model.predict(X_test, return_cov=True) - assert y_pred.shape == (n_samples, n_targets) - assert y_std.shape == (n_samples, n_targets) - assert y_cov.shape == (n_samples, n_samples, n_targets) + assert y_pred.shape == y_test_shape + assert y_std.shape == y_test_shape + assert y_cov.shape == (n_samples_test,) + y_test_shape + + +@pytest.mark.parametrize("normalize_y", [True, False]) +@pytest.mark.parametrize("n_targets", [None, 1, 10]) +def test_sample_y_shapes(normalize_y, n_targets): + """Check the shapes of y_samples in single-output (n_targets=None) and + multi-output settings, including the edge case when n_targets=1, where the + sklearn convention is to squeeze the predictions. 
+ + Non-regression test for: + https://github.com/scikit-learn/scikit-learn/issues/22175 + """ + rng = np.random.RandomState(1234) + + n_features, n_samples_train = 6, 9 + # Number of spatial locations to predict at + n_samples_X_test = 7 + # Number of sample predictions per test point + n_samples_y_test = 5 + + y_train_shape = (n_samples_train,) + if n_targets is not None: + y_train_shape = y_train_shape + (n_targets,) + + # By convention single-output data is squeezed upon prediction + if n_targets is not None and n_targets > 1: + y_test_shape = (n_samples_X_test, n_targets, n_samples_y_test) + else: + y_test_shape = (n_samples_X_test, n_samples_y_test) + + X_train = rng.randn(n_samples_train, n_features) + X_test = rng.randn(n_samples_X_test, n_features) + y_train = rng.randn(*y_train_shape) + + model = GaussianProcessRegressor(normalize_y=normalize_y) + + # FIXME: before fitting, the estimator does not have information regarding + # the number of targets and defaults to 1. This is inconsistent with the shape + # provided after `fit`. This assert should be made once the following issue + # is fixed: + # https://github.com/scikit-learn/scikit-learn/issues/22430 + # y_samples = model.sample_y(X_test, n_samples=n_samples_y_test) + # assert y_samples.shape == y_test_shape + + model.fit(X_train, y_train) + + y_samples = model.sample_y(X_test, n_samples=n_samples_y_test) + assert y_samples.shape == y_test_shape