FIX compute y_std properly with multi-target in GPR #20761

Merged on Oct 20, 2021 (24 commits)
Commits
0eec61b  fix: GaussianProcessRegressor fails to compute y_std when n_targets > 1 (patrickctrf, Aug 17, 2021)
e6b7c2e  fix: Reshape to (n_samples,) instead of (n_samples, 1) if single target (patrickctrf, Aug 25, 2021)
d59a825  feature: Add an entry in doc/whats_new, explaining changes and credits. (patrickctrf, Sep 9, 2021)
c321c32  feature: Add a non-regression test for the issue being fixed. (patrickctrf, Sep 9, 2021)
7acf698  fix: Linting error in tests file. (patrickctrf, Sep 9, 2021)
acc4458  fix: Persistent linting error in tests file. (patrickctrf, Sep 9, 2021)
7d0d97a  checkpoint: Testing linting demands (patrickctrf, Sep 9, 2021)
eecb8bc  feature: Added whats_new entry and non-regressive test. (patrickctrf, Sep 9, 2021)
df652ae  fix: Fix whats_new formatting. (patrickctrf, Sep 9, 2021)
937b00a  fix: test_gpr.py formatting fixed using black. (patrickctrf, Sep 9, 2021)
8a17971  fix: There was a typo in test function. (patrickctrf, Sep 9, 2021)
b8f7f3d  checkpoint: Testing pipeline execution. (patrickctrf, Sep 9, 2021)
18871a7  refactor: Reformat sklearn/gaussian_process/tests/test_gpr.py (patrickctrf, Oct 9, 2021)
1dc94bd  fix: doc/whats_new/v1.0.rst pattern (patrickctrf, Oct 9, 2021)
57c0b26  refactor: Update sklearn/gaussian_process/_gpr.py (patrickctrf, Oct 9, 2021)
6b3d12e  refactor: Undo Black changes. (patrickctrf, Oct 12, 2021)
c100c0d  Merge remote-tracking branch 'origin/main' into pr/patrickctrf/20761 (glemaitre, Oct 14, 2021)
c40487d  fix whats new (glemaitre, Oct 14, 2021)
25c944b  black (glemaitre, Oct 14, 2021)
7189727  Merge remote-tracking branch 'origin/main' into pr/patrickctrf/20761 (glemaitre, Oct 19, 2021)
da31226  fix changelog (glemaitre, Oct 19, 2021)
1027317  Merge branch 'master' into pr/patrickctrf/20761 (jeremiedbb, Oct 20, 2021)
d0a1a74  y_cov as well (jeremiedbb, Oct 20, 2021)
d52a00e  cln (jeremiedbb, Oct 20, 2021)
8 changes: 8 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -43,6 +43,14 @@ Fixed models
   between sparse and dense input. :pr:`21195`
   by :user:`Jérémie du Boisberranger <jeremiedbb>`.

+:mod:`sklearn.gaussian_process`
+...............................
+
+- |Fix| Compute `y_std` properly with multi-target in
+  :class:`sklearn.gaussian_process.GaussianProcessRegressor` allowing
+  proper normalization in multi-target scene.
+  :pr:`20761` by :user:`Patrick de C. T. R. Ferreira <patrickctrf>`.
+
 :mod:`sklearn.feature_extraction`
 .................................

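The changelog entry above refers to per-target normalization of the training targets. A minimal sketch of what that means, assuming normalization subtracts a per-column mean and divides by a per-column standard deviation (the array values and variable names here are illustrative stand-ins, not taken from the PR):

```python
import numpy as np

rng = np.random.RandomState(0)
# Hypothetical training targets: 8 samples, 3 targets on different scales.
y_train = rng.randn(8, 3) * np.array([1.0, 10.0, 100.0])

# Per-target statistics: one mean and one std per column.
y_train_mean = np.mean(y_train, axis=0)  # shape (n_targets,)
y_train_std = np.std(y_train, axis=0)    # shape (n_targets,)

# Normalize for fitting, then map values back to the original scale.
y_normalized = (y_train - y_train_mean) / y_train_std
y_restored = y_train_std * y_normalized + y_train_mean
```

Because the standard deviation has one entry per target, any quantity rescaled by it (such as a predictive variance) needs an explicit target axis.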
22 changes: 18 additions & 4 deletions sklearn/gaussian_process/_gpr.py
@@ -349,11 +349,12 @@ def predict(self, X, return_std=False, return_cov=False):
         y_mean : ndarray of shape (n_samples,) or (n_samples, n_targets)
             Mean of predictive distribution a query points.

-        y_std : ndarray of shape (n_samples,), optional
+        y_std : ndarray of shape (n_samples,) or (n_samples, n_targets), optional
             Standard deviation of predictive distribution at query points.
             Only returned when `return_std` is True.

-        y_cov : ndarray of shape (n_samples, n_samples), optional
+        y_cov : ndarray of shape (n_samples, n_samples) or \
+                (n_samples, n_samples, n_targets), optional
             Covariance of joint predictive distribution a query points.
             Only returned when `return_cov` is True.
         """
@@ -403,7 +404,14 @@ def predict(self, X, return_std=False, return_cov=False):
                 y_cov = self.kernel_(X) - V.T @ V

                 # undo normalisation
-                y_cov = y_cov * self._y_train_std ** 2
+                y_cov = np.outer(y_cov, self._y_train_std ** 2).reshape(
+                    *y_cov.shape, -1
+                )
+
+                # if y_cov has shape (n_samples, n_samples, 1), reshape to
+                # (n_samples, n_samples)
+                if y_cov.shape[2] == 1:
+                    y_cov = np.squeeze(y_cov, axis=2)

                 return y_mean, y_cov
             elif return_std:
@@ -424,7 +432,13 @@ def predict(self, X, return_std=False, return_cov=False):
                     y_var[y_var_negative] = 0.0

                 # undo normalisation
-                y_var = y_var * self._y_train_std ** 2
+                y_var = np.outer(y_var, self._y_train_std ** 2).reshape(
+                    *y_var.shape, -1
+                )
+
+                # if y_var has shape (n_samples, 1), reshape to (n_samples,)
+                if y_var.shape[1] == 1:
+                    y_var = np.squeeze(y_var, axis=1)

                 return y_mean, np.sqrt(y_var)
             else:
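The shape problem addressed by the `_gpr.py` change can be reproduced with NumPy alone. The sketch below uses random stand-in arrays in place of the estimator's internal state, and `undo_normalisation` is a hypothetical helper name that mirrors the patched expressions rather than a function in scikit-learn:

```python
import numpy as np

rng = np.random.RandomState(0)
n_samples, n_targets = 5, 3

# Stand-ins (hypothetical values) for the normalized predictive variance
# and the per-target training standard deviation.
y_var = rng.rand(n_samples)              # shape (n_samples,)
y_train_std = rng.rand(n_targets) + 1.0  # shape (n_targets,)

# Old expression: shapes (5,) and (3,) cannot be broadcast together.
try:
    y_var * y_train_std ** 2
    old_failed = False
except ValueError:
    old_failed = True


def undo_normalisation(arr, y_train_std):
    """Rescale arr by the per-target variance, adding a trailing target axis."""
    scaled = np.outer(arr, y_train_std ** 2).reshape(*arr.shape, -1)
    # Drop the trailing axis again in the single-target case.
    if scaled.shape[-1] == 1:
        scaled = np.squeeze(scaled, axis=-1)
    return scaled


y_var_multi = undo_normalisation(y_var, y_train_std)        # (n_samples, n_targets)
y_var_single = undo_normalisation(y_var, np.float64(2.0))   # (n_samples,)
```

Note that when `n_samples` happens to equal `n_targets`, the old expression broadcasts without raising and multiplies element-wise, which gives a result of the wrong shape rather than an error.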
30 changes: 30 additions & 0 deletions sklearn/gaussian_process/tests/test_gpr.py
@@ -652,3 +652,33 @@ def test_gpr_predict_error():
     err_msg = "At most one of return_std or return_cov can be requested."
     with pytest.raises(RuntimeError, match=err_msg):
         gpr.predict(X, return_cov=True, return_std=True)
+
+
+def test_y_std_with_multitarget_normalized():
+    """Check the proper normalization of `y_std` and `y_cov` in multi-target scene.
+
+    Non-regression test for:
+    https://github.com/scikit-learn/scikit-learn/issues/17394
+    https://github.com/scikit-learn/scikit-learn/issues/18065
+    """
+    rng = np.random.RandomState(1234)
+
+    n_samples, n_features, n_targets = 12, 10, 6
+
+    X_train = rng.randn(n_samples, n_features)
+    y_train = rng.randn(n_samples, n_targets)
+    X_test = rng.randn(n_samples, n_features)
+
+    # Generic kernel
+    kernel = WhiteKernel(1.0, (1e-1, 1e3)) * C(10.0, (1e-3, 1e3))
+
+    model = GaussianProcessRegressor(
+        kernel=kernel, n_restarts_optimizer=10, alpha=0.1, normalize_y=True
+    )
+    model.fit(X_train, y_train)
+    y_pred, y_std = model.predict(X_test, return_std=True)
+    _, y_cov = model.predict(X_test, return_cov=True)
+
+    assert y_pred.shape == (n_samples, n_targets)
+    assert y_std.shape == (n_samples, n_targets)
+    assert y_cov.shape == (n_samples, n_samples, n_targets)
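
The test above checks output shapes only. As a possible complement, the following NumPy-only sketch checks values: each slice of the rescaled covariance along the target axis should equal the normalized covariance times that target's variance, and its diagonal should agree with the rescaled standard deviation. The arrays are hypothetical stand-ins, not outputs of a fitted model:

```python
import numpy as np

rng = np.random.RandomState(1234)
n_samples, n_targets = 4, 3

# Hypothetical normalized covariance (symmetric, positive semi-definite)
# and per-target training standard deviation.
A = rng.randn(n_samples, n_samples)
cov_norm = A @ A.T
y_train_std = np.array([0.5, 1.0, 2.0])

# Same rescaling expressions as in the patched predict().
y_cov = np.outer(cov_norm, y_train_std ** 2).reshape(*cov_norm.shape, -1)
y_var = np.outer(np.diag(cov_norm), y_train_std ** 2).reshape(n_samples, -1)
y_std = np.sqrt(y_var)

# Each slice along the last axis is the normalized covariance scaled by
# that target's variance.
for k in range(n_targets):
    assert np.allclose(y_cov[:, :, k], cov_norm * y_train_std[k] ** 2)

# The per-target diagonal of y_cov should match y_std ** 2.
diag_per_target = np.diagonal(y_cov, axis1=0, axis2=1).T  # (n_samples, n_targets)
assert np.allclose(np.sqrt(diag_per_target), y_std)
```

The `np.outer(...).reshape(...)` form is numerically the same as broadcasting with an explicit trailing axis, for example `cov_norm[..., np.newaxis] * y_train_std ** 2`.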