From 76f8680f4ae1093c28aa61b45f61a26624a58dd3 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 16 Dec 2021 10:56:21 +0100 Subject: [PATCH 1/5] FIX make GPR works with multi-target and normalize_y=False --- doc/whats_new/v1.0.rst | 7 +++++++ sklearn/gaussian_process/_gpr.py | 11 +++++++---- sklearn/gaussian_process/tests/test_gpr.py | 18 ++++++++++++------ 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index ce9069a6e8671..6d24d3a6a3c36 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -35,6 +35,13 @@ Changelog and :class:`decomposition.MiniBatchSparsePCA` to be convex and match the referenced article. :pr:`19210` by :user:`Jérémie du Boisberranger `. +:mod:`sklearn.gaussian_process` +............................... + +- |Fix| :class:`gaussian_process.GaussianProcessRegressor` was not handling + properly multi-target when `normalize_y=False`. + :pr:`xxx` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.metrics` ...................... diff --git a/sklearn/gaussian_process/_gpr.py b/sklearn/gaussian_process/_gpr.py index e0ae8a0b2ff5b..0cb57b30f968e 100644 --- a/sklearn/gaussian_process/_gpr.py +++ b/sklearn/gaussian_process/_gpr.py @@ -239,8 +239,9 @@ def fit(self, X, y): y = (y - self._y_train_mean) / self._y_train_std else: - self._y_train_mean = np.zeros(1) - self._y_train_std = 1 + shape_y_stats = (y.shape[1],) if y.ndim == 2 else 1 + self._y_train_mean = np.zeros(shape=shape_y_stats) + self._y_train_std = np.ones(shape=shape_y_stats) if np.iterable(self.alpha) and self.alpha.shape[0] != y.shape[0]: if self.alpha.shape[0] == 1: @@ -475,8 +476,10 @@ def sample_y(self, X, n_samples=1, random_state=0): y_samples = rng.multivariate_normal(y_mean, y_cov, n_samples).T else: y_samples = [ - rng.multivariate_normal(y_mean[:, i], y_cov, n_samples).T[:, np.newaxis] - for i in range(y_mean.shape[1]) + rng.multivariate_normal( + y_mean[:, target], y_cov[..., target], n_samples + ).T[:, np.newaxis] + for target in range(y_mean.shape[1]) ] y_samples = np.hstack(y_samples) return y_samples diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index 8e57865600987..e5f5196f74504 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -361,12 +361,17 @@ def test_y_multioutput(): assert_almost_equal(y_pred_1d, y_pred_2d[:, 1] / 2) # Standard deviation and covariance do not depend on output - assert_almost_equal(y_std_1d, y_std_2d) - assert_almost_equal(y_cov_1d, y_cov_2d) + for target in range(y_2d.shape[1]): + assert_almost_equal(y_std_1d, y_std_2d[..., target]) + assert_almost_equal(y_cov_1d, y_cov_2d[..., target]) y_sample_1d = gpr.sample_y(X2, n_samples=10) y_sample_2d = gpr_2d.sample_y(X2, n_samples=10) - assert_almost_equal(y_sample_1d, y_sample_2d[:, 0]) + + assert y_sample_1d.shape == (5, 10) + assert y_sample_2d.shape == (5, 2, 10) + # Only the first target will be equal + assert_almost_equal(y_sample_1d, y_sample_2d[:, 0, :]) # Test hyperparameter optimization for kernel in kernels: @@ -654,8 +659,9 @@ def test_gpr_predict_error(): gpr.predict(X, return_cov=True, return_std=True) -def test_y_std_with_multitarget_normalized(): - """Check the proper normalization of `y_std` and `y_cov` in multi-target scene. +@pytest.mark.parametrize("normalize_y", [True, False]) +def test_multitarget_std_cov_shape(normalize_y): + """Check the shape of std. dev. and covariance in multi-output setting. Non-regression test for: https://github.com/scikit-learn/scikit-learn/issues/17394 @@ -673,7 +679,7 @@ def test_y_std_with_multitarget_normalized(): kernel = WhiteKernel(1.0, (1e-1, 1e3)) * C(10.0, (1e-3, 1e3)) model = GaussianProcessRegressor( - kernel=kernel, n_restarts_optimizer=10, alpha=0.1, normalize_y=True + kernel=kernel, n_restarts_optimizer=10, alpha=0.1, normalize_y=normalize_y ) model.fit(X_train, y_train) y_pred, y_std = model.predict(X_test, return_std=True) From 329dee67c3ebf993ae0a5332fd71d0afad04e86e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 16 Dec 2021 10:58:56 +0100 Subject: [PATCH 2/5] update gpr multitarget --- doc/whats_new/v1.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 6d24d3a6a3c36..85233023f2f31 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -40,7 +40,7 @@ Changelog - |Fix| :class:`gaussian_process.GaussianProcessRegressor` was not handling properly multi-target when `normalize_y=False`. - :pr:`xxx` by :user:`Guillaume Lemaitre `. + :pr:`21996` by :user:`Guillaume Lemaitre `. :mod:`sklearn.metrics` ...................... From 361d3d00df81262e706e394feaae12010963783c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 16 Dec 2021 11:07:37 +0100 Subject: [PATCH 3/5] TST add additional case with constant target --- sklearn/gaussian_process/tests/test_gpr.py | 24 +++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index e5f5196f74504..86fa89da720b9 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -18,7 +18,7 @@ from sklearn.gaussian_process.kernels import DotProduct, ExpSineSquared from sklearn.gaussian_process.tests._mini_sequence_kernel import MiniSeqKernel from sklearn.exceptions import ConvergenceWarning - +from sklearn.utils.validation import assert_all_finite from sklearn.utils._testing import ( assert_array_less, assert_almost_equal, @@ -551,8 +551,6 @@ def test_bound_check_fixed_hyperparameter(): GaussianProcessRegressor(kernel=kernel).fit(X, y) -# FIXME: we should test for multitargets as well. However, GPR is broken: -# see: https://github.com/scikit-learn/scikit-learn/pull/19706 @pytest.mark.parametrize("kernel", kernels) def test_constant_target(kernel): """Check that the std. dev. is affected to 1 when normalizing a constant @@ -573,6 +571,26 @@ def test_constant_target(kernel): # set atol because we compare to zero assert_allclose(np.diag(y_cov), 0.0, atol=1e-9) + # Test multi-target data + n_samples, n_targets = X.shape[0], 2 + rng = np.random.RandomState(0) + y = np.concatenate( + [ + rng.normal(size=(n_samples, 1)), # non-constant target + np.full(shape=(n_samples, 1), fill_value=2), # constant target + ], + axis=1, + ) + + gpr.fit(X, y) + Y_pred, Y_cov = gpr.predict(X, return_cov=True) + + assert_all_finite(Y_pred) + assert_all_finite(Y_cov) + + assert Y_pred.shape == (n_samples, n_targets) + assert Y_cov.shape == (n_samples, n_samples, n_targets) + def test_gpr_consistency_std_cov_non_invertible_kernel(): """Check the consistency between the returned std. dev. and the covariance. From 4d6a510332bc8c3d7277a5e135e3b8a6f8aa2c8f Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 16 Dec 2021 11:12:42 +0100 Subject: [PATCH 4/5] add original PR author --- doc/whats_new/v1.0.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 85233023f2f31..f3d9ebab14711 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -40,7 +40,8 @@ Changelog - |Fix| :class:`gaussian_process.GaussianProcessRegressor` was not handling properly multi-target when `normalize_y=False`. - :pr:`21996` by :user:`Guillaume Lemaitre `. + :pr:`21996` by :user:`Guillaume Lemaitre ` and + :user:`Aidar Shakerimoff `. :mod:`sklearn.metrics` ...................... From 75a28bd82b37c4b4a819ba7005e362b94facf20c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 17 Dec 2021 09:31:00 +0100 Subject: [PATCH 5/5] TST improve assert --- sklearn/gaussian_process/tests/test_gpr.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index 86fa89da720b9..bf1917714fcac 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -18,7 +18,6 @@ from sklearn.gaussian_process.kernels import DotProduct, ExpSineSquared from sklearn.gaussian_process.tests._mini_sequence_kernel import MiniSeqKernel from sklearn.exceptions import ConvergenceWarning -from sklearn.utils.validation import assert_all_finite from sklearn.utils._testing import ( assert_array_less, assert_almost_equal, @@ -585,8 +584,8 @@ def test_constant_target(kernel): gpr.fit(X, y) Y_pred, Y_cov = gpr.predict(X, return_cov=True) - assert_all_finite(Y_pred) - assert_all_finite(Y_cov) + assert_allclose(Y_pred[:, 1], 2) + assert_allclose(np.diag(Y_cov[..., 1]), 0.0, atol=1e-9) assert Y_pred.shape == (n_samples, n_targets) assert Y_cov.shape == (n_samples, n_samples, n_targets)