Skip to content

Commit bddd925

Browse files
thomasjpfanjnothman
authored andcommitted
FIX clip small values in PLS cross-decomposition for increased stability (scikit-learn#13903)
1 parent 8355584 commit bddd925

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

doc/whats_new/v0.20.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ The bundled version of joblib was upgraded from 0.13.0 to 0.13.2.
3131
``remainder`` transformer.
3232
:pr:`14237` by `Andreas Schuderer <schuderer>`.
3333

34+
:mod:`sklearn.decomposition`
35+
............................
36+
37+
- |Fix| Fixed a bug in :class:`cross_decomposition.CCA` improving numerical
38+
stability when `Y` is close to zero. :pr:`13903` by `Thomas Fan`_.
39+
40+
3441
:mod:`sklearn.model_selection`
3542
..............................
3643

sklearn/cross_decomposition/pls_.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ def fit(self, X, Y):
285285
self.n_iter_ = []
286286

287287
# NIPALS algo: outer loop, over components
288+
Y_eps = np.finfo(Yk.dtype).eps
288289
for k in range(self.n_components):
289290
if np.all(np.dot(Yk.T, Yk) < np.finfo(np.double).eps):
290291
# Yk constant
@@ -293,6 +294,10 @@ def fit(self, X, Y):
293294
# 1) weights estimation (inner loop)
294295
# -----------------------------------
295296
if self.algorithm == "nipals":
297+
# Replace columns that are all close to zero with zeros
298+
Yk_mask = np.all(np.abs(Yk) < 10 * Y_eps, axis=0)
299+
Yk[:, Yk_mask] = 0.0
300+
296301
x_weights, y_weights, n_iter_ = \
297302
_nipals_twoblocks_inner_loop(
298303
X=Xk, Y=Yk, mode=self.mode, max_iter=self.max_iter,

0 commit comments

Comments
 (0)