[MRG+1] Fix fit_transform, stability issue and scale issue in PLS #5358
Conversation
@@ -569,6 +569,17 @@ def svd_flip(u, v, u_based_decision=True):
     return u, v


 def flip_max_to_positive(u):
Please make that private. Also, it's likely to be only useful for the PLS-style models so it's better to put that helper in the pls module itself I think.
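A minimal numpy sketch of what such a sign-fixing helper might do (the name, signature, and exact convention are assumptions based on the diff above, not the merged implementation):

```python
import numpy as np

def _flip_max_to_positive(u):
    # Hypothetical sketch of the helper under discussion: flip the sign of
    # each column of u so that its largest-magnitude entry is positive,
    # making the column signs deterministic across platforms.
    max_abs_rows = np.argmax(np.abs(u), axis=0)
    signs = np.sign(u[max_abs_rows, np.arange(u.shape[1])])
    return u * signs

u = np.array([[1.0, -3.0],
              [-2.0, 1.0]])
flipped = _flip_max_to_positive(u)
```

In `flipped`, the dominant entry of each column is now positive, regardless of the sign convention of the routine that produced `u`.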
I removed it as it is no longer needed
@@ -21,12 +21,12 @@ def test_pls():
     # check equalities of loading (up to the sign of the second column)
     assert_array_almost_equal(
         pls_bynipals.x_loadings_,
-        np.multiply(pls_bysvd.x_loadings_, np.array([1, -1, 1])), decimal=5,
+        pls_bysvd.x_loadings_, decimal=5,
         err_msg="nipals and svd implementation lead to different x loadings")
implementation*s*
@@ -19,6 +20,10 @@
 __all__ = ['PLSCanonical', 'PLSRegression', 'PLSSVD']


 def check_finite(X_pinv):
     pass
This looks like a left-over of a prior experiment. It should be removed, right?
Apart from the minor comments, it looks good to me (+1). Could you please squash the intermediate commits and add a what's new entry to document the fix?
@@ -287,6 +294,11 @@ def fit(self, X, Y):
                 self.n_iter_.append(n_iter_)
             elif self.algorithm == "svd":
                 x_weights, y_weights = _svd_cross_product(X=Xk, Y=Yk)
+            # Forces sign stability of x_weights and y_weights
+            # Sign undeterminacy issue from svd if algorithm == "svd"
+            # and from platform dependant computation if algorithm == 'nipals'
(French speaking) typo: dependent
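The sign-stability fix in the hunk above relies on scikit-learn's existing `svd_flip` utility. A short sketch of how it makes SVD output deterministic (the random matrix is just a stand-in for the cross-product matrix built from `Xk` and `Yk`):

```python
import numpy as np
from sklearn.utils.extmath import svd_flip

# svd_flip resolves the sign indeterminacy of the SVD: each column of U
# (and the matching row of Vt) is flipped so that the largest-magnitude
# entry of the column of U is positive (u_based_decision=True, the default).
rng = np.random.RandomState(0)
C = rng.randn(4, 3)  # stand-in for the X^T Y cross-product matrix
U, s, Vt = np.linalg.svd(C, full_matrices=False)
U, Vt = svd_flip(U, Vt)

# Dominant entry of every column of U after flipping:
dominant = U[np.argmax(np.abs(U), axis=0), np.arange(U.shape[1])]
```

Because each flip is applied to a column of `U` and the matching row of `Vt` together, the product `U @ diag(s) @ Vt` still reconstructs `C` exactly.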
Q : y_loadings_
    Computed such that:
    X = T P^T + Err and Y = U Q^T + Err
so `T` is x_scores_ and `^T` is transpose? Is the `T` standard PLS notation? The mixture of python syntax (for the slicing) and latex syntax (for the `^T`) is also a bit confusing. Maybe `.T`?
T is standard PLS notation AFAIK. I could use `^*` or `'` as well. But `.T` is good.
@@ -113,25 +126,32 @@ def check_ortho(M, err_msg):
         [[-0.61330704, -0.00443647, 0.78983213],
          [-0.74697144, -0.32172099, -0.58183269],
          [-0.25668686, 0.94682413, -0.19399983]])
-    assert_array_almost_equal(pls_2.x_weights_, x_weights)
+    x_weights_sign_flip = pls_2.x_weights_ / x_weights
This looks OK, but I feel it would be more natural to compute the sign flip and then compare R == sign_flip * Python for all of them. But your way works, too.
This would add a redundant test, wouldn't it?
I see your point actually, but we need to compute both sign flips to make sure that `np.abs(sign_flip)` is `np.ones`, so I am not sure the code would be better.
In any case, comparing R to Python is IMHO not the way to proceed in the long term. I am no expert, but I think that some mathematical assertions could be checked instead.
I think in addition to anything else, comparing to R is good ;)
I'm fine with leaving as is.
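The elementwise-ratio check under discussion can be sketched with toy numbers (the arrays below are hypothetical stand-ins for the R reference `x_weights` and the Python result `pls_2.x_weights_`; column 0 differs by a global sign):

```python
import numpy as np

r_weights = np.array([[0.6, -0.1],
                      [-0.7, 0.3]])
py_weights = np.array([[-0.6, -0.1],
                       [0.7, 0.3]])

sign_flip = py_weights / r_weights
# The two results agree up to a columnwise sign iff the ratio is +/-1
# everywhere...
abs_is_one = np.allclose(np.abs(sign_flip), 1.0)
# ...and constant within each column.
columnwise = np.allclose(sign_flip, sign_flip[0])
```

Both conditions together are what makes the test robust to the sign indeterminacy while still catching real numerical differences.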
Force-pushed from 8b83c2b to 0f50e7e
Force-pushed from f48b864 to 307d0c7
I addressed all comments and added a what's new entry.
Did this get much slower than master?
I'll do a quick benchmark.
Performance is not lowered by this PR, using the PLS example from scikit-learn with
LGTM otherwise
Comments addressed, you commented on my own branch
@arthurmensch This PR needs a rebase on top of the current master. Also I don't see the commit that addresses @eickenberg's comments. +1 for merge as well once rebased. I can handle the backport to 0.17 once merged to master.
@@ -257,7 +267,7 @@ def fit(self, X, Y):
         if self.deflation_mode not in ["canonical", "regression"]:
             raise ValueError('The deflation mode is unknown')
         # Scale (in place)
-        X, Y, self.x_mean_, self.y_mean_, self.x_std_, self.y_std_\
+        X, Y, self.x_mean_, self.y_mean_, self.x_std_, self.y_std_ \
`(X, Y, ... ) = _center_scale` instead of having the `\` ?
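Both continuation styles are valid Python; a small sketch of the difference (the `_center_scale` helper below is a hypothetical stand-in for the private function in the diff, not its actual implementation):

```python
import numpy as np

def _center_scale(X, Y):
    # Hypothetical stand-in: center and scale both matrices, return stats.
    x_mean, y_mean = X.mean(axis=0), Y.mean(axis=0)
    x_std, y_std = X.std(axis=0, ddof=1), Y.std(axis=0, ddof=1)
    return (X - x_mean) / x_std, (Y - y_mean) / y_std, \
        x_mean, y_mean, x_std, y_std

X = np.arange(6.0).reshape(3, 2)
Y = np.arange(6.0, 12.0).reshape(3, 2)

# Backslash continuation, as in the diff:
Xs, Ys, x_mean, y_mean, x_std, y_std \
    = _center_scale(X, Y)

# Parenthesized target list, as suggested in the review (no backslash):
(Xs, Ys, x_mean, y_mean,
 x_std, y_std) = _center_scale(X, Y)
```

PEP 8 prefers implicit continuation inside parentheses over backslashes, which is presumably the motivation for the suggestion.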
@@ -306,10 +306,16 @@ Bug fixes
     (`#5360 <https://github.com/scikit-learn/scikit-learn/pull/5360>`_)
     By `Tom Dupre la Tour`_.

<<<<<<< HEAD
oops, that needs to be fixed
Sorry, I failed the rebase again; should be ok now.
Besides the lines 49-50 "stability reasons" nitpick it LGTM. So +3 basically.
There is WIP on PLS (read-only errors PR #5507); it would be good to merge this PR before.
merged as 1c7168d, thanks :) [backporting now]
@amueller this PR is still open. Have you done the backport?
backport at 2bb4d09
This PR fixes a stability issue and sign indeterminacy in PLS and CCA (see bug #2821). Three issues were addressed:

- `fit_transform` did not work, for an obvious reason. I fixed this and changed estimator_checks so that it does not overlook a Transformer unable to perform `fit_transform` without a previous `fit`.
- The pseudo-inverse computation used `lstsq` and is subject to more numerical instability than `scipy.linalg.pinv2`, which, quite amusingly, is the same as `numpy.linalg.pinv` and is based on SVD decomposition. As with `svd_flip`, we fix this within the PLS code. The signs of `x_loadings_`, `x_scores_`, `x_weights_`, `x_rotations_` can differ columnwise from the R implementation, as there is a sign indeterminacy that we seek to lift (cf. SVD with `svd_flip`).
- `test_pls` is based on R output; I still need to change it so that it does not fail because of sign differences.

ping @twiecki, @Fenugreek for PLS-proficient reviews