Skip to content

Commit ad54b76

Browse files
Charlie-XIAOpunndcoder28
authored andcommitted
FIX ravel prediction of PLSRegression when fitted on 1d y (scikit-learn#26602)
1 parent 22bb0bd commit ad54b76

File tree

3 files changed

+35
-2
lines changed

3 files changed

+35
-2
lines changed

doc/whats_new/v1.4.rst

+7
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ Changelog
6666
- |Enhancement| :func:`base.clone` now supports `dict` as input and creates a
6767
copy. :pr:`26786` by `Adrin Jalali`_.
6868

69+
:mod:`sklearn.cross_decomposition`
70+
..................................
71+
72+
- |Fix| :class:`cross_decomposition.PLSRegression` now automatically ravels the output
73+
of `predict` if fitted with one dimensional `y`.
74+
:pr:`26602` by :user:`Yao Xiao <Charlie-XIAO>`.
75+
6976
:mod:`sklearn.decomposition`
7077
............................
7178

sklearn/cross_decomposition/_pls.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,10 @@ def fit(self, X, Y):
238238
Y, input_name="Y", dtype=np.float64, copy=self.copy, ensure_2d=False
239239
)
240240
if Y.ndim == 1:
241+
self._predict_1d = True
241242
Y = Y.reshape(-1, 1)
243+
else:
244+
self._predict_1d = False
242245

243246
n = X.shape[0]
244247
p = X.shape[1]
@@ -469,8 +472,8 @@ def predict(self, X, copy=True):
469472
# Normalize
470473
X -= self._x_mean
471474
X /= self._x_std
472-
Ypred = X @ self.coef_.T
473-
return Ypred + self.intercept_
475+
Ypred = X @ self.coef_.T + self.intercept_
476+
return Ypred.ravel() if self._predict_1d else Ypred
474477

475478
def fit_transform(self, X, y=None):
476479
"""Learn and apply the dimension reduction on the train data.

sklearn/cross_decomposition/tests/test_pls.py

+23
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
_svd_flip_1d,
1313
)
1414
from sklearn.datasets import load_linnerud, make_regression
15+
from sklearn.ensemble import VotingRegressor
1516
from sklearn.exceptions import ConvergenceWarning
17+
from sklearn.linear_model import LinearRegression
1618
from sklearn.utils import check_random_state
1719
from sklearn.utils.extmath import svd_flip
1820

@@ -621,3 +623,24 @@ def test_pls_set_output(Klass):
621623
assert isinstance(y_trans, np.ndarray)
622624
assert isinstance(X_trans, pd.DataFrame)
623625
assert_array_equal(X_trans.columns, est.get_feature_names_out())
626+
627+
628+
def test_pls_regression_fit_1d_y():
629+
"""Check that when fitting with 1d `y`, prediction should also be 1d.
630+
631+
Non-regression test for Issue #26549.
632+
"""
633+
X = np.array([[1, 1], [2, 4], [3, 9], [4, 16], [5, 25], [6, 36]])
634+
y = np.array([2, 6, 12, 20, 30, 42])
635+
expected = y.copy()
636+
637+
plsr = PLSRegression().fit(X, y)
638+
y_pred = plsr.predict(X)
639+
assert y_pred.shape == expected.shape
640+
641+
# Check that it works in VotingRegressor
642+
lr = LinearRegression().fit(X, y)
643+
vr = VotingRegressor([("lr", lr), ("plsr", plsr)])
644+
y_pred = vr.fit(X, y).predict(X)
645+
assert y_pred.shape == expected.shape
646+
assert_allclose(y_pred, expected)

0 commit comments

Comments
 (0)