From 76f8680f4ae1093c28aa61b45f61a26624a58dd3 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 16 Dec 2021 10:56:21 +0100 Subject: [PATCH 01/13] FIX make GPR works with multi-target and normalize_y=False --- doc/whats_new/v1.0.rst | 7 +++++++ sklearn/gaussian_process/_gpr.py | 11 +++++++---- sklearn/gaussian_process/tests/test_gpr.py | 18 ++++++++++++------ 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index ce9069a6e8671..6d24d3a6a3c36 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -35,6 +35,13 @@ Changelog and :class:`decomposition.MiniBatchSparsePCA` to be convex and match the referenced article. :pr:`19210` by :user:`Jérémie du Boisberranger `. +:mod:`sklearn.gaussian_process` +............................... + +- |Fix| :class:`gaussian_process.GaussianProcessRegressor` was not handling + properly multi-target when `normalize_y=False`. + :pr:`xxx` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.metrics` ...................... diff --git a/sklearn/gaussian_process/_gpr.py b/sklearn/gaussian_process/_gpr.py index e0ae8a0b2ff5b..0cb57b30f968e 100644 --- a/sklearn/gaussian_process/_gpr.py +++ b/sklearn/gaussian_process/_gpr.py @@ -239,8 +239,9 @@ def fit(self, X, y): y = (y - self._y_train_mean) / self._y_train_std else: - self._y_train_mean = np.zeros(1) - self._y_train_std = 1 + shape_y_stats = (y.shape[1],) if y.ndim == 2 else 1 + self._y_train_mean = np.zeros(shape=shape_y_stats) + self._y_train_std = np.ones(shape=shape_y_stats) if np.iterable(self.alpha) and self.alpha.shape[0] != y.shape[0]: if self.alpha.shape[0] == 1: @@ -475,8 +476,10 @@ def sample_y(self, X, n_samples=1, random_state=0): y_samples = rng.multivariate_normal(y_mean, y_cov, n_samples).T else: y_samples = [ - rng.multivariate_normal(y_mean[:, i], y_cov, n_samples).T[:, np.newaxis] - for i in range(y_mean.shape[1]) + rng.multivariate_normal( + y_mean[:, target], y_cov[..., target], n_samples + ).T[:, np.newaxis] + for target in range(y_mean.shape[1]) ] y_samples = np.hstack(y_samples) return y_samples diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index 8e57865600987..e5f5196f74504 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -361,12 +361,17 @@ def test_y_multioutput(): assert_almost_equal(y_pred_1d, y_pred_2d[:, 1] / 2) # Standard deviation and covariance do not depend on output - assert_almost_equal(y_std_1d, y_std_2d) - assert_almost_equal(y_cov_1d, y_cov_2d) + for target in range(y_2d.shape[1]): + assert_almost_equal(y_std_1d, y_std_2d[..., target]) + assert_almost_equal(y_cov_1d, y_cov_2d[..., target]) y_sample_1d = gpr.sample_y(X2, n_samples=10) y_sample_2d = gpr_2d.sample_y(X2, n_samples=10) - assert_almost_equal(y_sample_1d, y_sample_2d[:, 0]) + + assert y_sample_1d.shape == (5, 10) + assert y_sample_2d.shape == (5, 2, 10) + # Only the first target will be equal + assert_almost_equal(y_sample_1d, y_sample_2d[:, 0, :]) # Test hyperparameter optimization for kernel in kernels: @@ -654,8 +659,9 @@ def test_gpr_predict_error(): gpr.predict(X, return_cov=True, return_std=True) -def test_y_std_with_multitarget_normalized(): - """Check the proper normalization of `y_std` and `y_cov` in multi-target scene. +@pytest.mark.parametrize("normalize_y", [True, False]) +def test_multitarget_std_cov_shape(normalize_y): + """Check the shape of std. dev. and covariance in multi-output setting. 
Non-regression test for: https://github.com/scikit-learn/scikit-learn/issues/17394 @@ -673,7 +679,7 @@ def test_y_std_with_multitarget_normalized(): kernel = WhiteKernel(1.0, (1e-1, 1e3)) * C(10.0, (1e-3, 1e3)) model = GaussianProcessRegressor( - kernel=kernel, n_restarts_optimizer=10, alpha=0.1, normalize_y=True + kernel=kernel, n_restarts_optimizer=10, alpha=0.1, normalize_y=normalize_y ) model.fit(X_train, y_train) y_pred, y_std = model.predict(X_test, return_std=True) From 329dee67c3ebf993ae0a5332fd71d0afad04e86e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 16 Dec 2021 10:58:56 +0100 Subject: [PATCH 02/13] update gpr multitarget --- doc/whats_new/v1.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 6d24d3a6a3c36..85233023f2f31 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -40,7 +40,7 @@ Changelog - |Fix| :class:`gaussian_process.GaussianProcessRegressor` was not handling properly multi-target when `normalize_y=False`. - :pr:`xxx` by :user:`Guillaume Lemaitre `. + :pr:`21996` by :user:`Guillaume Lemaitre `. :mod:`sklearn.metrics` ...................... From 361d3d00df81262e706e394feaae12010963783c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 16 Dec 2021 11:07:37 +0100 Subject: [PATCH 03/13] TST add additional case with constant target --- sklearn/gaussian_process/tests/test_gpr.py | 24 +++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index e5f5196f74504..86fa89da720b9 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -18,7 +18,7 @@ from sklearn.gaussian_process.kernels import DotProduct, ExpSineSquared from sklearn.gaussian_process.tests._mini_sequence_kernel import MiniSeqKernel from sklearn.exceptions import ConvergenceWarning - +from sklearn.utils.validation import assert_all_finite from sklearn.utils._testing import ( assert_array_less, assert_almost_equal, @@ -551,8 +551,6 @@ def test_bound_check_fixed_hyperparameter(): GaussianProcessRegressor(kernel=kernel).fit(X, y) -# FIXME: we should test for multitargets as well. However, GPR is broken: -# see: https://github.com/scikit-learn/scikit-learn/pull/19706 @pytest.mark.parametrize("kernel", kernels) def test_constant_target(kernel): """Check that the std. dev. is affected to 1 when normalizing a constant @@ -573,6 +571,26 @@ def test_constant_target(kernel): # set atol because we compare to zero assert_allclose(np.diag(y_cov), 0.0, atol=1e-9) + # Test multi-target data + n_samples, n_targets = X.shape[0], 2 + rng = np.random.RandomState(0) + y = np.concatenate( + [ + rng.normal(size=(n_samples, 1)), # non-constant target + np.full(shape=(n_samples, 1), fill_value=2), # constant target + ], + axis=1, + ) + + gpr.fit(X, y) + Y_pred, Y_cov = gpr.predict(X, return_cov=True) + + assert_all_finite(Y_pred) + assert_all_finite(Y_cov) + + assert Y_pred.shape == (n_samples, n_targets) + assert Y_cov.shape == (n_samples, n_samples, n_targets) + def test_gpr_consistency_std_cov_non_invertible_kernel(): """Check the consistency between the returned std. dev. and the covariance. 
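The fix in the first three patches comes down to two things: `fit` now keeps one entry of `_y_train_mean`/`_y_train_std` per target even when `normalize_y=False`, and `sample_y` indexes the covariance per target (`y_cov[..., target]`) instead of reusing a single covariance for every output. A minimal sketch of the resulting behaviour, assuming a scikit-learn build that includes these patches (the data, kernel and sizes below are illustrative, not taken from the test suite):

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

rng = np.random.RandomState(0)
n_samples, n_features, n_targets = 12, 3, 2

X = rng.randn(n_samples, n_features)
# Two targets: one random, one constant, echoing the constant-target test above.
y = np.hstack([rng.randn(n_samples, 1), np.full((n_samples, 1), 2.0)])

# Before the fix, normalize_y=False left `_y_train_std` as a scalar, which broke
# the per-target un-normalisation of the predicted covariance.
gpr = GaussianProcessRegressor(kernel=RBF(), alpha=0.1, normalize_y=False)
gpr.fit(X, y)

y_mean, y_cov = gpr.predict(X, return_cov=True)
print(y_mean.shape)  # expected with the fix: (12, 2)
print(y_cov.shape)   # expected with the fix: (12, 12, 2)

y_samples = gpr.sample_y(X, n_samples=5)
print(y_samples.shape)  # expected with the fix: (12, 2, 5)
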
From 4d6a510332bc8c3d7277a5e135e3b8a6f8aa2c8f Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 16 Dec 2021 11:12:42 +0100 Subject: [PATCH 04/13] add original PR author --- doc/whats_new/v1.0.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 85233023f2f31..f3d9ebab14711 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -40,7 +40,8 @@ Changelog - |Fix| :class:`gaussian_process.GaussianProcessRegressor` was not handling properly multi-target when `normalize_y=False`. - :pr:`21996` by :user:`Guillaume Lemaitre `. + :pr:`21996` by :user:`Guillaume Lemaitre ` and + :user:`Aidar Shakerimoff `. :mod:`sklearn.metrics` ...................... From 75a28bd82b37c4b4a819ba7005e362b94facf20c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 17 Dec 2021 09:31:00 +0100 Subject: [PATCH 05/13] TST improve assert --- sklearn/gaussian_process/tests/test_gpr.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index 86fa89da720b9..bf1917714fcac 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -18,7 +18,6 @@ from sklearn.gaussian_process.kernels import DotProduct, ExpSineSquared from sklearn.gaussian_process.tests._mini_sequence_kernel import MiniSeqKernel from sklearn.exceptions import ConvergenceWarning -from sklearn.utils.validation import assert_all_finite from sklearn.utils._testing import ( assert_array_less, assert_almost_equal, @@ -585,8 +584,8 @@ def test_constant_target(kernel): gpr.fit(X, y) Y_pred, Y_cov = gpr.predict(X, return_cov=True) - assert_all_finite(Y_pred) - assert_all_finite(Y_cov) + assert_allclose(Y_pred[:, 1], 2) + assert_allclose(np.diag(Y_cov[..., 1]), 0.0, atol=1e-9) assert Y_pred.shape == (n_samples, n_targets) assert Y_cov.shape == (n_samples, n_samples, n_targets) From ab057a09484d1ce1e506118de7e3fae5a115176e Mon Sep 17 00:00:00 2001 From: "Nakamura-Zimmerer, Tenavi (ARC-AF)" Date: Wed, 12 Jan 2022 13:16:26 -0800 Subject: [PATCH 06/13] shape fixes --- doc/whats_new/v1.1.rst | 9 +++++ sklearn/gaussian_process/_gpr.py | 12 ++++-- sklearn/gaussian_process/tests/test_gpr.py | 45 +++++++++++++++++++++- 3 files changed, 61 insertions(+), 5 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index bb630edda8ac8..3c927df651ee0 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -268,6 +268,15 @@ Changelog F-statistic). :pr:`17819` by :user:`Juan Carlos Alfaro Jiménez `. +:mod:`sklearn.gaussian_process` +............................... + +- |Fix| `predict` and `sample_y` functions of + :class:`sklearn.gaussian_process.GaussianProcessRegressor` now return + arrays of the correct shape in single-target and multi-target cases, and for + both `normalize_y=False` and `normalize_y=True`. + by :user:`Tenavi Nakamura-Zimmerer `. + :mod:`sklearn.impute` ..................... 
diff --git a/sklearn/gaussian_process/_gpr.py b/sklearn/gaussian_process/_gpr.py index 91ed3e6d8f318..0c0060066a4e2 100644 --- a/sklearn/gaussian_process/_gpr.py +++ b/sklearn/gaussian_process/_gpr.py @@ -240,7 +240,8 @@ def fit(self, X, y): else: self._y_train_mean = np.zeros(1) - self._y_train_std = 1 + # Ensure that there is one entry in _y_train_std for each target + self._y_train_std = np.ones(y.shape[1:]) if np.iterable(self.alpha) and self.alpha.shape[0] != y.shape[0]: if self.alpha.shape[0] == 1: @@ -394,6 +395,10 @@ def predict(self, X, return_std=False, return_cov=False): # undo normalisation y_mean = self._y_train_std * y_mean + self._y_train_mean + # if y_mean has shape (n_samples, 1), reshape to (n_samples,) + if len(y_mean.shape) > 1 and y_mean.shape[1] == 1: + y_mean = np.squeeze(y_mean, axis=1) + # Alg 2.1, page 19, line 5 -> v = L \ K(X_test, X_train)^T V = solve_triangular( self.L_, K_trans.T, lower=GPR_CHOLESKY_LOWER, check_finite=False @@ -407,7 +412,6 @@ def predict(self, X, return_std=False, return_cov=False): y_cov = np.outer(y_cov, self._y_train_std ** 2).reshape( *y_cov.shape, -1 ) - # if y_cov has shape (n_samples, n_samples, 1), reshape to # (n_samples, n_samples) if y_cov.shape[2] == 1: @@ -475,7 +479,9 @@ def sample_y(self, X, n_samples=1, random_state=0): y_samples = rng.multivariate_normal(y_mean, y_cov, n_samples).T else: y_samples = [ - rng.multivariate_normal(y_mean[:, i], y_cov, n_samples).T[:, np.newaxis] + rng.multivariate_normal(y_mean[:, i], y_cov[..., i], n_samples).T[ + :, np.newaxis + ] for i in range(y_mean.shape[1]) ] y_samples = np.hstack(y_samples) diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index 8e57865600987..71c262806414c 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -361,8 +361,9 @@ def test_y_multioutput(): assert_almost_equal(y_pred_1d, y_pred_2d[:, 1] / 2) # Standard deviation and covariance do not depend on output - assert_almost_equal(y_std_1d, y_std_2d) - assert_almost_equal(y_cov_1d, y_cov_2d) + for d in range(2): + assert_almost_equal(y_std_1d, y_std_2d[..., d]) + assert_almost_equal(y_cov_1d, y_cov_2d[..., d]) y_sample_1d = gpr.sample_y(X2, n_samples=10) y_sample_2d = gpr_2d.sample_y(X2, n_samples=10) @@ -682,3 +683,43 @@ def test_y_std_with_multitarget_normalized(): assert y_pred.shape == (n_samples, n_targets) assert y_std.shape == (n_samples, n_targets) assert y_cov.shape == (n_samples, n_samples, n_targets) + + +def test_y_std_cov_with_multitarget(): + """Check the shapes of `y_std` and `y_cov` in multi-target scene. 
+ + Non-regression test for: + https://github.com/scikit-learn/scikit-learn/issues/22174 + """ + + rng = np.random.RandomState(123) + + n_features, n_targets = 3, 2 + n_samples_train, n_samples_test = 9, 7 + + X_train = rng.randn(n_samples_train, n_features) + X_test = rng.randn(n_samples_test, n_features) + + # Test different target shapes + for y_train_shape in ((), (1,), (n_targets,)): + y_train = rng.randn(n_samples_train, *y_train_shape) + + y_test_shape = (n_samples_test,) + y_train.shape[1:] + # If n_targets = 1, convention is to squeeze the predictions + if len(y_test_shape) > 1 and y_test_shape[1] == 1: + y_test_shape = y_test_shape[:1] + + # Test normalized and non-normalized models + models = [ + GaussianProcessRegressor(normalize_y=True), + GaussianProcessRegressor(normalize_y=False), + ] + + for model in models: + model.fit(X_train, y_train) + y_pred, y_std = model.predict(X_test, return_std=True) + _, y_cov = model.predict(X_test, return_cov=True) + + assert y_pred.shape == y_test_shape + assert y_std.shape == y_test_shape + assert y_cov.shape == (n_samples_test,) + y_test_shape From c82ad11bf6b647c00757177c3a43d5e4fc231b47 Mon Sep 17 00:00:00 2001 From: "Nakamura-Zimmerer, Tenavi (ARC-AF)" Date: Wed, 12 Jan 2022 14:00:06 -0800 Subject: [PATCH 07/13] added pr number --- doc/whats_new/v1.1.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 3c927df651ee0..4ee779c779599 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -275,7 +275,7 @@ Changelog :class:`sklearn.gaussian_process.GaussianProcessRegressor` now return arrays of the correct shape in single-target and multi-target cases, and for both `normalize_y=False` and `normalize_y=True`. - by :user:`Tenavi Nakamura-Zimmerer `. + :pr:`22199` by :user:`Tenavi Nakamura-Zimmerer `. :mod:`sklearn.impute` ..................... From e1c0df894aa7ee873fe6da8704594f8572817770 Mon Sep 17 00:00:00 2001 From: "Nakamura-Zimmerer, Tenavi (ARC-AF)" Date: Thu, 27 Jan 2022 11:02:29 -0800 Subject: [PATCH 08/13] consolidated gpr shape tests; moved whats new documentation from 1.0 to 1.1 --- doc/whats_new/v1.0.rst | 11 ---- doc/whats_new/v1.1.rst | 5 +- sklearn/gaussian_process/tests/test_gpr.py | 72 +++++++--------------- 3 files changed, 26 insertions(+), 62 deletions(-) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 58588ef2a6822..27c8fd8033075 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -88,17 +88,6 @@ Changelog precomputed dense distance matrix where the neighbors graph has multiple disconnected components. :pr:`21915` by `Tom Dupre la Tour`_. -:mod:`sklearn.gaussian_process` -............................... - -- |Fix| `predict` and `sample_y` functions of - :class:`sklearn.gaussian_process.GaussianProcessRegressor` now return - arrays of the correct shape in single-target and multi-target cases, and for - both `normalize_y=False` and `normalize_y=True`. - :pr:`21996` by :user:`Guillaume Lemaitre `, - :user:`Aidar Shakerimoff ` and - :user:`Tenavi Nakamura-Zimmerer `. - :mod:`sklearn.metrics` ...................... diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 1545eec6fbf0f..0c52101031a02 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -275,11 +275,14 @@ Changelog :mod:`sklearn.gaussian_process` ............................... 
+ - |Fix| `predict` and `sample_y` functions of :class:`sklearn.gaussian_process.GaussianProcessRegressor` now return arrays of the correct shape in single-target and multi-target cases, and for both `normalize_y=False` and `normalize_y=True`. - :pr:`22199` by :user:`Tenavi Nakamura-Zimmerer `. + :pr:`21996` by :user:`Guillaume Lemaitre `, + :user:`Aidar Shakerimoff ` and + :user:`Tenavi Nakamura-Zimmerer `. - |Fix| :class:`gaussian_process.GaussianProcessClassifier` raises a more informative error if `CompoundKernel` is passed via `kernel`. :pr:`22223` by :user:`MarcoM `. diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index 64cbaed0f73d9..6e15d4fcc2920 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -677,68 +677,40 @@ def test_gpr_predict_error(): @pytest.mark.parametrize("normalize_y", [True, False]) -def test_multitarget_std_cov_shape(normalize_y): - """Check the shape of std. dev. and covariance in multi-output setting. +@pytest.mark.parametrize("n_targets", [0, 1, 10]) +def test_multitarget_shape(normalize_y, n_targets): + """Check the shapes of y_mean, y_std, and y_cov in single-output + (n_targets=0) and multi-output settings, including the edge case when + n_targets=1, where the sklearn convention is to squeeze the predictions. Non-regression test for: https://github.com/scikit-learn/scikit-learn/issues/17394 https://github.com/scikit-learn/scikit-learn/issues/18065 + https://github.com/scikit-learn/scikit-learn/issues/22174 """ rng = np.random.RandomState(1234) - n_samples, n_features, n_targets = 12, 10, 6 - - X_train = rng.randn(n_samples, n_features) - y_train = rng.randn(n_samples, n_targets) - X_test = rng.randn(n_samples, n_features) + n_features, n_samples_train, n_samples_test = 6, 9, 7 - # Generic kernel - kernel = WhiteKernel(1.0, (1e-1, 1e3)) * C(10.0, (1e-3, 1e3)) + y_train_shape = (n_samples_train,) + if n_targets >= 1: + y_train_shape = y_train_shape + (n_targets,) - model = GaussianProcessRegressor( - kernel=kernel, n_restarts_optimizer=10, alpha=0.1, normalize_y=normalize_y - ) - model.fit(X_train, y_train) - y_pred, y_std = model.predict(X_test, return_std=True) - _, y_cov = model.predict(X_test, return_cov=True) - - assert y_pred.shape == (n_samples, n_targets) - assert y_std.shape == (n_samples, n_targets) - assert y_cov.shape == (n_samples, n_samples, n_targets) - - -@pytest.mark.parametrize("normalize_y", [True, False]) -def test_y_std_cov_with_multitarget(normalize_y): - """Check the shapes of `y_std` and `y_cov` in multi-target scene. 
- - Non-regression test for: - https://github.com/scikit-learn/scikit-learn/issues/22174 - """ - - rng = np.random.RandomState(123) - - n_features, n_targets = 3, 2 - n_samples_train, n_samples_test = 9, 7 + # By convention single-output data is squeezed upon prediction + y_test_shape = (n_samples_test,) + if n_targets > 1: + y_test_shape = y_test_shape + (n_targets,) X_train = rng.randn(n_samples_train, n_features) X_test = rng.randn(n_samples_test, n_features) + y_train = rng.randn(*y_train_shape) - # Test different target shapes - for y_train_shape in ((), (1,), (n_targets,)): - y_train = rng.randn(n_samples_train, *y_train_shape) - - y_test_shape = (n_samples_test,) + y_train.shape[1:] - # If n_targets = 1, convention is to squeeze the predictions - if len(y_test_shape) > 1 and y_test_shape[1] == 1: - y_test_shape = y_test_shape[:1] - - # Test normalized and non-normalized models - model = GaussianProcessRegressor(normalize_y=normalize_y) - model.fit(X_train, y_train) + model = GaussianProcessRegressor(normalize_y=normalize_y) + model.fit(X_train, y_train) - y_pred, y_std = model.predict(X_test, return_std=True) - _, y_cov = model.predict(X_test, return_cov=True) + y_pred, y_std = model.predict(X_test, return_std=True) + _, y_cov = model.predict(X_test, return_cov=True) - assert y_pred.shape == y_test_shape - assert y_std.shape == y_test_shape - assert y_cov.shape == (n_samples_test,) + y_test_shape + assert y_pred.shape == y_test_shape + assert y_std.shape == y_test_shape + assert y_cov.shape == (n_samples_test,) + y_test_shape From a8908a89aa85bcab7f0a32550cd21d2e51a75b36 Mon Sep 17 00:00:00 2001 From: "Nakamura-Zimmerer, Tenavi (ARC-AF)" Date: Fri, 28 Jan 2022 09:52:11 -0800 Subject: [PATCH 09/13] added new test for gpr.sample_y shape, changed name of other new test --- doc/whats_new/v1.1.rst | 4 +-- sklearn/gaussian_process/tests/test_gpr.py | 42 +++++++++++++++++++++- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 0c52101031a02..4845d0ae6d7fd 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -276,11 +276,11 @@ Changelog :mod:`sklearn.gaussian_process` ............................... -- |Fix| `predict` and `sample_y` functions of +- |Fix| `predict` and `sample_y` methods of :class:`sklearn.gaussian_process.GaussianProcessRegressor` now return arrays of the correct shape in single-target and multi-target cases, and for both `normalize_y=False` and `normalize_y=True`. - :pr:`21996` by :user:`Guillaume Lemaitre `, + :pr:`22199` by :user:`Guillaume Lemaitre `, :user:`Aidar Shakerimoff ` and :user:`Tenavi Nakamura-Zimmerer `. - |Fix| :class:`gaussian_process.GaussianProcessClassifier` raises diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index 6e15d4fcc2920..c5f647d51a69a 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -678,7 +678,7 @@ def test_gpr_predict_error(): @pytest.mark.parametrize("normalize_y", [True, False]) @pytest.mark.parametrize("n_targets", [0, 1, 10]) -def test_multitarget_shape(normalize_y, n_targets): +def test_predict_shapes(normalize_y, n_targets): """Check the shapes of y_mean, y_std, and y_cov in single-output (n_targets=0) and multi-output settings, including the edge case when n_targets=1, where the sklearn convention is to squeeze the predictions. 
@@ -714,3 +714,43 @@ def test_multitarget_shape(normalize_y, n_targets): assert y_pred.shape == y_test_shape assert y_std.shape == y_test_shape assert y_cov.shape == (n_samples_test,) + y_test_shape + + +@pytest.mark.parametrize("normalize_y", [True, False]) +@pytest.mark.parametrize("n_targets", [0, 1, 10]) +def test_sample_y_shapes(normalize_y, n_targets): + """Check the shapes of y_samples in single-output (n_targets=0) and + multi-output settings, including the edge case when n_targets=1, where the + sklearn convention is to squeeze the predictions. + + Non-regression test for: + https://github.com/scikit-learn/scikit-learn/issues/22175 + """ + rng = np.random.RandomState(1234) + + n_features, n_samples_train = 6, 9 + # Number of spatial locations to predict at + n_samples_X_test = 7 + # Number of sample predictions per test point + n_samples_y_test = 5 + + y_train_shape = (n_samples_train,) + if n_targets >= 1: + y_train_shape = y_train_shape + (n_targets,) + + # By convention single-output data is squeezed upon prediction + if n_targets > 1: + y_test_shape = (n_samples_X_test, n_targets, n_samples_y_test) + else: + y_test_shape = (n_samples_X_test, n_samples_y_test) + + X_train = rng.randn(n_samples_train, n_features) + X_test = rng.randn(n_samples_X_test, n_features) + y_train = rng.randn(*y_train_shape) + + model = GaussianProcessRegressor(normalize_y=normalize_y) + model.fit(X_train, y_train) + + y_samples = model.sample_y(X_test, n_samples=n_samples_y_test) + + assert y_samples.shape == y_test_shape From b1a68dc380f09a6c0e672e43c87f1bbc41dfeb07 Mon Sep 17 00:00:00 2001 From: "Nakamura-Zimmerer, Tenavi (ARC-AF)" Date: Mon, 31 Jan 2022 09:55:53 -0800 Subject: [PATCH 10/13] added check of y_samples.shape before fitting; changed n_targets=0 to n_targets=None in test parameterization --- doc/whats_new/v1.1.rst | 1 + sklearn/gaussian_process/tests/test_gpr.py | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 4845d0ae6d7fd..be74649ea354f 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -283,6 +283,7 @@ Changelog :pr:`22199` by :user:`Guillaume Lemaitre `, :user:`Aidar Shakerimoff ` and :user:`Tenavi Nakamura-Zimmerer `. + - |Fix| :class:`gaussian_process.GaussianProcessClassifier` raises a more informative error if `CompoundKernel` is passed via `kernel`. :pr:`22223` by :user:`MarcoM `. diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index c5f647d51a69a..895b97ba1ccb3 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -677,10 +677,10 @@ def test_gpr_predict_error(): @pytest.mark.parametrize("normalize_y", [True, False]) -@pytest.mark.parametrize("n_targets", [0, 1, 10]) +@pytest.mark.parametrize("n_targets", [None, 1, 10]) def test_predict_shapes(normalize_y, n_targets): """Check the shapes of y_mean, y_std, and y_cov in single-output - (n_targets=0) and multi-output settings, including the edge case when + (n_targets=None) and multi-output settings, including the edge case when n_targets=1, where the sklearn convention is to squeeze the predictions. 
Non-regression test for: @@ -693,12 +693,12 @@ def test_predict_shapes(normalize_y, n_targets): n_features, n_samples_train, n_samples_test = 6, 9, 7 y_train_shape = (n_samples_train,) - if n_targets >= 1: + if n_targets: y_train_shape = y_train_shape + (n_targets,) # By convention single-output data is squeezed upon prediction y_test_shape = (n_samples_test,) - if n_targets > 1: + if n_targets is not None and n_targets > 1: y_test_shape = y_test_shape + (n_targets,) X_train = rng.randn(n_samples_train, n_features) @@ -717,7 +717,7 @@ def test_predict_shapes(normalize_y, n_targets): @pytest.mark.parametrize("normalize_y", [True, False]) -@pytest.mark.parametrize("n_targets", [0, 1, 10]) +@pytest.mark.parametrize("n_targets", [None, 1, 10]) def test_sample_y_shapes(normalize_y, n_targets): """Check the shapes of y_samples in single-output (n_targets=0) and multi-output settings, including the edge case when n_targets=1, where the @@ -735,11 +735,11 @@ def test_sample_y_shapes(normalize_y, n_targets): n_samples_y_test = 5 y_train_shape = (n_samples_train,) - if n_targets >= 1: + if n_targets: y_train_shape = y_train_shape + (n_targets,) # By convention single-output data is squeezed upon prediction - if n_targets > 1: + if n_targets is not None and n_targets > 1: y_test_shape = (n_samples_X_test, n_targets, n_samples_y_test) else: y_test_shape = (n_samples_X_test, n_samples_y_test) @@ -749,8 +749,12 @@ def test_sample_y_shapes(normalize_y, n_targets): y_train = rng.randn(*y_train_shape) model = GaussianProcessRegressor(normalize_y=normalize_y) - model.fit(X_train, y_train) + # Before fitting, model assumes n_targets=1 so output should be 2d array y_samples = model.sample_y(X_test, n_samples=n_samples_y_test) + assert y_samples.shape == (n_samples_X_test, n_samples_y_test) + + model.fit(X_train, y_train) + y_samples = model.sample_y(X_test, n_samples=n_samples_y_test) assert y_samples.shape == y_test_shape From da42222760b9de04f37a0b8dc330215352a83778 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 10 Feb 2022 14:53:07 +0100 Subject: [PATCH 11/13] Update sklearn/gaussian_process/tests/test_gpr.py --- sklearn/gaussian_process/tests/test_gpr.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index 895b97ba1ccb3..cbe87c2f99aa1 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -750,9 +750,13 @@ def test_sample_y_shapes(normalize_y, n_targets): model = GaussianProcessRegressor(normalize_y=normalize_y) - # Before fitting, model assumes n_targets=1 so output should be 2d array - y_samples = model.sample_y(X_test, n_samples=n_samples_y_test) - assert y_samples.shape == (n_samples_X_test, n_samples_y_test) + # FIXME: before fitting, the estimator does not have information regarding + # the number of targets and default to 1. This is inconsistent with the shape + # provided after `fit`. 
This assert should be made once the following issue + # is fixed: + # https://github.com/scikit-learn/scikit-learn/issues/22430 + # y_samples = model.sample_y(X_test, n_samples=n_samples_y_test) + # assert y_samples.shape == y_test_shape model.fit(X_train, y_train) From 80027a7cd2bf9c79ab2716a946a382b41f1e2d70 Mon Sep 17 00:00:00 2001 From: "Nakamura-Zimmerer, Tenavi (ARC-AF)" Date: Mon, 14 Feb 2022 08:13:03 -0800 Subject: [PATCH 12/13] little code cleanup suggested by thomasjpfan --- sklearn/gaussian_process/_gpr.py | 2 +- sklearn/gaussian_process/tests/test_gpr.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/gaussian_process/_gpr.py b/sklearn/gaussian_process/_gpr.py index 40b9510cc369a..50afba3f0e0da 100644 --- a/sklearn/gaussian_process/_gpr.py +++ b/sklearn/gaussian_process/_gpr.py @@ -396,7 +396,7 @@ def predict(self, X, return_std=False, return_cov=False): y_mean = self._y_train_std * y_mean + self._y_train_mean # if y_mean has shape (n_samples, 1), reshape to (n_samples,) - if len(y_mean.shape) > 1 and y_mean.shape[1] == 1: + if y_mean.n_dim > 1 and y_mean.shape[1] == 1: y_mean = np.squeeze(y_mean, axis=1) # Alg 2.1, page 19, line 5 -> v = L \ K(X_test, X_train)^T diff --git a/sklearn/gaussian_process/tests/test_gpr.py b/sklearn/gaussian_process/tests/test_gpr.py index 895b97ba1ccb3..e88a56daf2a6c 100644 --- a/sklearn/gaussian_process/tests/test_gpr.py +++ b/sklearn/gaussian_process/tests/test_gpr.py @@ -693,7 +693,7 @@ def test_predict_shapes(normalize_y, n_targets): n_features, n_samples_train, n_samples_test = 6, 9, 7 y_train_shape = (n_samples_train,) - if n_targets: + if n_targets is not None: y_train_shape = y_train_shape + (n_targets,) # By convention single-output data is squeezed upon prediction @@ -735,7 +735,7 @@ def test_sample_y_shapes(normalize_y, n_targets): n_samples_y_test = 5 y_train_shape = (n_samples_train,) - if n_targets: + if n_targets is not None: y_train_shape = y_train_shape + (n_targets,) # By convention single-output data is squeezed upon prediction From bdcb2e770e6e9e0a6168d6532c5a26fcdfeade8b Mon Sep 17 00:00:00 2001 From: "Nakamura-Zimmerer, Tenavi (ARC-AF)" Date: Mon, 14 Feb 2022 08:16:21 -0800 Subject: [PATCH 13/13] typo fix --- sklearn/gaussian_process/_gpr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/gaussian_process/_gpr.py b/sklearn/gaussian_process/_gpr.py index 50afba3f0e0da..e93a1121af18c 100644 --- a/sklearn/gaussian_process/_gpr.py +++ b/sklearn/gaussian_process/_gpr.py @@ -396,7 +396,7 @@ def predict(self, X, return_std=False, return_cov=False): y_mean = self._y_train_std * y_mean + self._y_train_mean # if y_mean has shape (n_samples, 1), reshape to (n_samples,) - if y_mean.n_dim > 1 and y_mean.shape[1] == 1: + if y_mean.ndim > 1 and y_mean.shape[1] == 1: y_mean = np.squeeze(y_mean, axis=1) # Alg 2.1, page 19, line 5 -> v = L \ K(X_test, X_train)^T
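
Taken together, the consolidated tests in this series pin down the following shape contract for `predict` and `sample_y`. A short sketch, assuming a scikit-learn version that includes these patches (sizes and the default kernel are illustrative):

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor

rng = np.random.RandomState(1234)
n_features, n_train, n_test = 6, 9, 7
X_train = rng.randn(n_train, n_features)
X_test = rng.randn(n_test, n_features)

for n_targets in (None, 1, 10):
    y_shape = (n_train,) if n_targets is None else (n_train, n_targets)
    y_train = rng.randn(*y_shape)

    gpr = GaussianProcessRegressor(normalize_y=True).fit(X_train, y_train)
    y_mean, y_std = gpr.predict(X_test, return_std=True)
    _, y_cov = gpr.predict(X_test, return_cov=True)
    y_samples = gpr.sample_y(X_test, n_samples=5)

    # Single-target data (including y_train of shape (n_train, 1)) is squeezed:
    #   n_targets=None -> (7,)     (7,)     (7, 7)      (7, 5)
    #   n_targets=1    -> (7,)     (7,)     (7, 7)      (7, 5)
    #   n_targets=10   -> (7, 10)  (7, 10)  (7, 7, 10)  (7, 10, 5)
    print(n_targets, y_mean.shape, y_std.shape, y_cov.shape, y_samples.shape)

The same shapes are expected with `normalize_y=False`. The remaining inconsistency, calling `sample_y` before `fit` still assumes a single target, is tracked in https://github.com/scikit-learn/scikit-learn/issues/22430, as noted in the FIXME comment above.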