Description
Describe the bug
Related to #22174; supposed to have been fixed in #20761?
See #22199
When using a GPR model for multi-target data, sample_y fails if we set normalize_y=True. With normalize_y=False the same call works.
Steps/Code to Reproduce
import numpy as np
import sklearn
from sklearn.gaussian_process import GaussianProcessRegressor as GPR
print(sklearn.__version__)
X_train = np.random.rand(7,3)
X_test = np.random.rand(4,3)
# ---- Working code ---- #
# n_targets = 1
Y_train = np.random.randn(7,1)
model = GPR(normalize_y=True)
model.fit(X_train, Y_train)
print(model.sample_y(X_test, n_samples=1).shape)
print(model.sample_y(X_test, n_samples=15).shape)
# n_targets = 2
Y_train = np.random.randn(7,2)
model = GPR(normalize_y=False)
model.fit(X_train, Y_train)
print(model.sample_y(X_test, n_samples=1).shape)
print(model.sample_y(X_test, n_samples=15).shape)
# ---- Broken code ---- #
# n_targets = 2
Y_train = np.random.randn(7,2)
model = GPR(normalize_y=True)
model.fit(X_train, Y_train)
print(model.sample_y(X_test, n_samples=1).shape)
print(model.sample_y(X_test, n_samples=15).shape)
Expected Results
GPR.sample_y should result in an array of shape (n_samples_X, n_targets, n_samples_y), so the above code should produce
(4, 1, 1)
(4, 1, 15)
(4, 2, 1)
(4, 2, 15)
(4, 2, 1)
(4, 2, 15)
Actual Results
Instead we get an error relating to the shape of the predicted covariance:
  File "sklearn/gaussian_process/_gpr.py", line 477, in sample_y
    y_samples = [
  File "sklearn/gaussian_process/_gpr.py", line 478, in <listcomp>
    rng.multivariate_normal(y_mean[:, i], y_cov, n_samples).T[:, np.newaxis]
  File "mtrand.pyx", line 4069, in numpy.random.mtrand.RandomState.multivariate_normal
ValueError: cov must be 2 dimensional and square
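The traceback suggests that y_cov reaches multivariate_normal with more than two dimensions. Below is a minimal NumPy-only sketch of that failure and of one possible per-target workaround. It assumes the predicted covariance has shape (n_samples_X, n_samples_X, n_targets); that shape is an assumption, not something confirmed above, and y_mean and y_cov here are synthetic stand-ins rather than real GPR predictions.

```python
import numpy as np

rng = np.random.RandomState(0)
n_samples_X, n_targets, n_samples_y = 4, 2, 15

# Synthetic stand-ins for the predicted mean and covariance (not real GPR output).
y_mean = rng.randn(n_samples_X, n_targets)
A = rng.randn(n_samples_X, n_samples_X)
base_cov = A @ A.T  # symmetric positive semi-definite, shape (4, 4)

# Assumption: one covariance matrix per target, stacked along the last axis.
y_cov = np.stack([base_cov * (i + 1) for i in range(n_targets)], axis=-1)

# Passing the full 3-D array is rejected, as in the traceback.
try:
    rng.multivariate_normal(y_mean[:, 0], y_cov, n_samples_y)
    error_message = None
except ValueError as exc:
    error_message = str(exc)
print(error_message)

# Possible workaround: draw each target from its own 2-D covariance slice.
y_samples = np.concatenate(
    [
        rng.multivariate_normal(y_mean[:, i], y_cov[..., i], n_samples_y).T[:, np.newaxis]
        for i in range(n_targets)
    ],
    axis=1,
)
print(y_samples.shape)
```

If the assumption about the covariance shape holds, slicing with y_cov[..., i] gives each draw a square 2-D matrix, and the stacked result has the expected (n_samples_X, n_targets, n_samples_y) layout.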
Versions
System:
python: 3.9.5 | packaged by conda-forge | (default, Jun 19 2021, 00:27:35) [Clang 11.1.0 ]
executable: /Users/tnakam10/opt/anaconda3/envs/aerofusion/bin/python
machine: macOS-11.6.1-x86_64-i386-64bit
Python dependencies:
pip: 21.3.1
setuptools: 60.5.0
sklearn: 1.0.2
numpy: 1.19.5
scipy: 1.7.3
Cython: None
pandas: 1.3.5
matplotlib: 3.5.1
joblib: 1.1.0
threadpoolctl: 3.0.0
Built with OpenMP: True