Multi-target GPR sample_y fails when normalize_y=True #22175

@Tenavi

Describe the bug

Related to #22174; supposed to have been fixed in #20761?

See #22199

When using a GPR model for multi-target data, sample_y fails if we set normalize_y=True. With normalize_y=False, or with a single target, it works as expected.

Steps/Code to Reproduce

import numpy as np
import sklearn
from sklearn.gaussian_process import GaussianProcessRegressor as GPR

print(sklearn.__version__)

X_train = np.random.rand(7,3)
X_test = np.random.rand(4,3)

# ---- Working code ---- #

# n_targets = 1
Y_train = np.random.randn(7,1)

model = GPR(normalize_y=True)

model.fit(X_train, Y_train)

print(model.sample_y(X_test, n_samples=1).shape)
print(model.sample_y(X_test, n_samples=15).shape)

# n_targets = 2
Y_train = np.random.randn(7,2)

model = GPR(normalize_y=False)

model.fit(X_train, Y_train)

print(model.sample_y(X_test, n_samples=1).shape)
print(model.sample_y(X_test, n_samples=15).shape)

# ---- Broken code ---- #

# n_targets = 2
Y_train = np.random.randn(7,2)

model = GPR(normalize_y=True)

model.fit(X_train, Y_train)

print(model.sample_y(X_test, n_samples=1).shape)
print(model.sample_y(X_test, n_samples=15).shape)

Expected Results

GPR.sample_y should return an array of shape (n_samples_X, n_targets, n_samples_y), so the above code should print

(4, 1, 1)
(4, 1, 15)
(4, 2, 1)
(4, 2, 15)
(4, 2, 1)
(4, 2, 15)

Actual Results

Instead we get an error relating to the shape of the predicted covariance:

  File "sklearn/gaussian_process/_gpr.py", line 477, in sample_y
    y_samples = [
  File "sklearn/gaussian_process/_gpr.py", line 478, in <listcomp>
    rng.multivariate_normal(y_mean[:, i], y_cov, n_samples).T[:, np.newaxis]
  File "mtrand.pyx", line 4069, in numpy.random.mtrand.RandomState.multivariate_normal
ValueError: cov must be 2 dimensional and square
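
The traceback shows the whole y_cov being passed to multivariate_normal for every target. Below is a minimal sketch of a possible workaround that samples one target at a time. It assumes (this is not confirmed above) that in the failing case the predicted covariance is a stack of shape (n_samples_X, n_samples_X, n_targets). The helper name sample_per_target and its arguments are hypothetical and not part of scikit-learn.

```python
import numpy as np


def sample_per_target(y_mean, y_cov, n_samples=1, random_state=0):
    """Draw samples target by target from a predictive mean and covariance.

    Hypothetical helper, not a scikit-learn function.
    y_mean: shape (n_points,) or (n_points, n_targets)
    y_cov:  shape (n_points, n_points) or (n_points, n_points, n_targets)
    Returns an array of shape (n_points, n_targets, n_samples).
    """
    rng = np.random.RandomState(random_state)
    y_mean = np.asarray(y_mean)
    y_mean = y_mean.reshape(len(y_mean), -1)  # force (n_points, n_targets)
    y_cov = np.asarray(y_cov)

    samples = []
    for i in range(y_mean.shape[1]):
        # Assumption: a 3-D covariance holds one 2-D matrix per target.
        cov_i = y_cov[..., i] if y_cov.ndim == 3 else y_cov
        draws = rng.multivariate_normal(y_mean[:, i], cov_i, n_samples)
        # draws is (n_samples, n_points); reorder to (n_points, 1, n_samples)
        samples.append(draws.T[:, np.newaxis, :])
    return np.concatenate(samples, axis=1)


# Synthetic stand-ins for a two-target prediction at 4 test points.
y_mean = np.zeros((4, 2))
y_cov = np.stack([np.eye(4), 4.0 * np.eye(4)], axis=-1)  # shape (4, 4, 2)
print(sample_per_target(y_mean, y_cov, n_samples=15).shape)
```

If the assumption about the covariance shape holds, the same helper could be fed the output of model.predict(X_test, return_cov=True) in place of the synthetic arrays.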

Versions

System:
python: 3.9.5 | packaged by conda-forge | (default, Jun 19 2021, 00:27:35) [Clang 11.1.0 ]
executable: /Users/tnakam10/opt/anaconda3/envs/aerofusion/bin/python
machine: macOS-11.6.1-x86_64-i386-64bit

Python dependencies:
pip: 21.3.1
setuptools: 60.5.0
sklearn: 1.0.2
numpy: 1.19.5
scipy: 1.7.3
Cython: None
pandas: 1.3.5
matplotlib: 3.5.1
joblib: 1.1.0
threadpoolctl: 3.0.0

Built with OpenMP: True
