Skip to content

Conversation

arthurmensch
Copy link
Contributor

This PR fixes stability issue and sign indeterminacy in PLS and CCA (see bug #2821).

Three issues were adressed:

  • fit_transform did not work for obvious reason. I fixed this and change estimator_checks so that it do not overlook Transformer unable to perform fit_transform without a previous fit
  • scipy.linalg.pinv is based on lstsq and is subject to more numerical instability than scipy.linalg.pinv2, which, quite amusingly, is the same as numpy.linalg.pinv, and is based on SVD decomposition
  • PLS is subject to sign indeterminacy (like any matrix decomposition method). Similar to svd_flip, we fix this within PLS code.

The signs of x_loadings_, x_score_, x_weights_, x_rotations_ can differ columnwise from R implementation, as there is a sign indeterminacy that we seek to raise (cf SVD with svd_flip).

  • Since test_pls is based on R output, I still need to change it so that it does not fail because of sign differences.

ping @twiecki, @Fenugreek for PLS proficient reviews

@@ -569,6 +569,17 @@ def svd_flip(u, v, u_based_decision=True):
return u, v


def flip_max_to_positive(u):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make that private. Also, it's likely to be only useful for the PLS-style models so it's better to put that helper in the pls module itself I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed it as it is not longer needed

@arthurmensch arthurmensch changed the title [WIP] Fix fit_transform, stability issue and scale issue in PLS [MRG] Fix fit_transform, stability issue and scale issue in PLS Oct 7, 2015
@@ -21,12 +21,12 @@ def test_pls():
# check equalities of loading (up to the sign of the second column)
assert_array_almost_equal(
pls_bynipals.x_loadings_,
np.multiply(pls_bysvd.x_loadings_, np.array([1, -1, 1])), decimal=5,
pls_bysvd.x_loadings_, decimal=5,
err_msg="nipals and svd implementation lead to different x loadings")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

implementation_s_

@ogrisel ogrisel added the Bug label Oct 12, 2015
@ogrisel ogrisel added this to the 0.17 milestone Oct 12, 2015
@@ -19,6 +20,10 @@
__all__ = ['PLSCanonical', 'PLSRegression', 'PLSSVD']


def check_finite(X_pinv):
pass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a left-over of a prior experiment. It should be removed right?

@ogrisel
Copy link
Member

ogrisel commented Oct 12, 2015

Apart from the minor comments, it looks good to me (+1). Could you please squash the intermediate commits and add a what's new entry to document the fix?

@@ -287,6 +294,11 @@ def fit(self, X, Y):
self.n_iter_.append(n_iter_)
elif self.algorithm == "svd":
x_weights, y_weights = _svd_cross_product(X=Xk, Y=Yk)
# Forces sign stability of x_weights and y_weights
# Sign undeterminacy issue from svd if algorithm == "svd"
# and from platform dependant computation if algorithm == 'nipals'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(French speaking) typo: dependent

@ogrisel ogrisel changed the title [MRG] Fix fit_transform, stability issue and scale issue in PLS [MRG+1] Fix fit_transform, stability issue and scale issue in PLS Oct 12, 2015
Q : y_loadings__

Computed such that:
X = T P^T + Err and Y = U Q^T + Err
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so T is x_scores_ and ^T is transpose? Is the T standard PLS notation? the mixture of python syntax (for the slicing) and latex syntax (for the ^T) is also a bit confusing. Maybe .T?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

T is standard PLS notation AFAIK. I could use ^* or ' as weel. But .T is good

@amueller
Copy link
Member

Is pinv2 much slower than scipy.pinv?

@@ -113,25 +126,32 @@ def check_ortho(M, err_msg):
[[-0.61330704, -0.00443647, 0.78983213],
[-0.74697144, -0.32172099, -0.58183269],
[-0.25668686, 0.94682413, -0.19399983]])
assert_array_almost_equal(pls_2.x_weights_, x_weights)
x_weights_sign_flip = pls_2.x_weights_ / x_weights
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks ok but I feel it would be more natural to compute the sign flip and then compare for R == sign_flip * Python for all of them. but your way works, too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would add a redundant test wouldn't it ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point actually, but we need to compute both sign flip to make sure that np.abs(sign_flip) is np.ones so I am not sure if the code would be better.

In any case, comparing R to Python is IMHO not the way to proceed on the long term. I am no expert but I think that some mathematical assertions could be checked instead.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in addition to anything else, comparing to R is good ;)
I'm fine with leaving as is.

@arthurmensch arthurmensch force-pushed the cca_bug branch 4 times, most recently from 8b83c2b to 0f50e7e Compare October 14, 2015 07:49
@arthurmensch arthurmensch force-pushed the cca_bug branch 2 times, most recently from f48b864 to 307d0c7 Compare October 14, 2015 07:54
@arthurmensch
Copy link
Contributor Author

I addressed all comments and added a what's new entry.

@amueller
Copy link
Member

Did this get much slower than master?

@arthurmensch
Copy link
Contributor Author

I ll do a quick benchmark

@arthurmensch
Copy link
Contributor Author

Performance are not lowered in this PR, using PLS example from scikit-learn with N = 5000

@eickenberg
Copy link
Contributor

LGTM otherwise

@arthurmensch
Copy link
Contributor Author

Comments adressed, you commented in my own branch

@ogrisel
Copy link
Member

ogrisel commented Oct 21, 2015

@arthurmensch This PR needs a rebase on top of the current master. Also I don't see the commit that addresses @eickenberg's comments.

+1 for merge as well once rebased. I can handle the backport to 0.17 once merged to master.

@@ -257,7 +267,7 @@ def fit(self, X, Y):
if self.deflation_mode not in ["canonical", "regression"]:
raise ValueError('The deflation mode is unknown')
# Scale (in place)
X, Y, self.x_mean_, self.y_mean_, self.x_std_, self.y_std_\
X, Y, self.x_mean_, self.y_mean_, self.x_std_, self.y_std_ \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(X, Y,
... ) = _center_scale instead of having the \ ?

@@ -306,10 +306,16 @@ Bug fixes
(`#5360 <https://github.com/scikit-learn/scikit-learn/pull/5360>`_)
By `Tom Dupre la Tour`_.

<<<<<<< HEAD
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, that needs to be fixed

@arthurmensch
Copy link
Contributor Author

Sorry I failed the rebase again, should be ok now

@kastnerkyle
Copy link
Member

Besides line 49 - 50 "stability reasons" nitpick it LGTM. So +3 basically

@arthurmensch
Copy link
Contributor Author

There is WIP on PLS (read only errors PR #5507 ), it would be good to merge this PR before

@amueller
Copy link
Member

merged as 1c7168d, thanks :) [backporting now]

@ogrisel
Copy link
Member

ogrisel commented Oct 27, 2015

@amueller this PR is still open. Have your done the backport?

@amueller amueller closed this Nov 2, 2015
@amueller
Copy link
Member

amueller commented Nov 2, 2015

backport at 2bb4d09

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants