Fixed issue with KernelPCA.inverse_transform mean #16655


Merged: 5 commits merged into scikit-learn:master on Mar 10, 2020

Conversation

@lrjball (Contributor) commented Mar 7, 2020

Currently KernelPCA.inverse_transform returns a data set with zero mean, even if the original data did not have zero mean.

I believe that this PR fixes that issue, so that the mean of the inverse-transformed data set is the same as the mean of the original data set.

I've also added a test for this update.

Reference Issues/PRs

Fixes #16654
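
A minimal sketch of the reported behaviour (the dataset, centre, and parameters here are illustrative, not taken from the PR):

from sklearn.datasets import make_blobs
from sklearn.decomposition import KernelPCA

# Data with a clearly non-zero mean
X, _ = make_blobs(n_samples=100, centers=[[5, 5, 5, 5]], random_state=0)
kp = KernelPCA(n_components=4, fit_inverse_transform=True)  # default kernel is linear
X_inv = kp.inverse_transform(kp.fit_transform(X))
print(X.mean(axis=0))      # roughly [5, 5, 5, 5]
print(X_inv.mean(axis=0))  # before the fix: roughly [0, 0, 0, 0]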

lrjball added 2 commits March 7, 2020 16:50
Added the mean to inverse_transform to fix the issue, and added tests as well.
@jnothman (Member) left a comment

Thanks

Please add an entry to the change log at doc/whats_new/v0.23.rst. Like the other entries there, please reference this pull request with :pr: and credit yourself (and other contributors if applicable) with :user:.

Please also note versionchanged in the docstring of inverse_transform
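
For reference, a sketch of what such an entry could look like, following the pattern of the surrounding entries (the wording and credit are placeholders for the author to adjust):

- |Fix| :class:`decomposition.KernelPCA` method ``inverse_transform`` now
  correctly handles the mean of the data. :pr:`16655` by :user:`lrjball <lrjball>`.

and, in the docstring of ``inverse_transform``:

.. versionchanged:: 0.23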

import numpy as np
from sklearn.decomposition import KernelPCA

# X and kernel are provided by the surrounding test's setup/parametrization
kp = KernelPCA(n_components=2, kernel=kernel, fit_inverse_transform=True)
X_trans = kp.fit_transform(X)
X_inv = kp.inverse_transform(X_trans)
assert np.isclose(X.mean(axis=0), X_inv.mean(axis=0)).all()
Member:

Should we be able to assert on the values, not just the means?

@lrjball (Contributor, Author):

Good point, I've updated the test now.

- Added an entry in whats_new v0.23.
- Updated tests to check for closeness of the full data set, not just closeness of the mean (see the sketch below).
- Updated the fix for this bug: realised that it wasn't an issue with the mean not being added on, but that self.alpha was not being taken into account in the inverse transform.
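
Such a check could look like the following (a sketch, with X and X_inv as in the test above; the exact assertion in the PR may differ):

import numpy as np
assert np.allclose(X, X_inv)  # compare the full arrays, not just their means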
@lrjball (Contributor, Author) commented Mar 8, 2020

I realized that the bug wasn't actually due to the mean not being added, but to alpha not being handled properly in inverse_transform. So I have updated the fix and also added an entry in the change log.
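
For context, KernelPCA learns its inverse transform with kernel ridge regression, where alpha is the regularisation term of that regression. A self-contained sketch of the idea (function names and the choice of RBF kernel are illustrative, not sklearn's internals):

import numpy as np
from sklearn.metrics.pairwise import rbf_kernel

def fit_inverse(X_transformed, X, alpha):
    # Solve (K + alpha * I) @ dual_coef = X for the dual coefficients.
    K = rbf_kernel(X_transformed)
    n = K.shape[0]
    return np.linalg.solve(K + alpha * np.eye(n), X)

def inverse_transform(X_new, X_transformed_fit, dual_coef):
    # Pre-image estimate: K(new, fit) @ dual_coef.
    return rbf_kernel(X_new, X_transformed_fit) @ dual_coef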

@jnothman (Member) left a comment

Thanks for the update

@@ -138,6 +138,10 @@ Changelog
:func:`decomposition.non_negative_factorization` now preserves float32 dtype.
:pr:`16280` by :user:`Jeremie du Boisberranger <jeremiedbb>`.

- |Fix| :class:`decomposition.KernelPCA` method ``inverse_transform`` now
Member:

Maybe "fixed ... in the case that the data was not centred" would be more helpful to users.

@lrjball (Contributor, Author) commented Mar 8, 2020

So actually, after doing some digging, it was still returning the wrong thing even for centered data. I only noticed the bug on non-centered data because the mean of the inverse-transformed data was zero when the original data set was not centered.

For example, in 0.22.0 the following still does not work:

import numpy as np
from sklearn.datasets import make_blobs
from sklearn.decomposition import KernelPCA

X, _ = make_blobs(n_samples=100, centers=[[1, 1, 1, 1]], random_state=0)
X = X - X.mean(axis=0)  # explicitly centre the data
kp = KernelPCA(n_components=2, fit_inverse_transform=True)  # default linear kernel
X_trans = kp.fit_transform(X)
X_inv = kp.inverse_transform(X_trans)

assert np.isclose(X, X_inv).all()

So this PR fixes the inverse_transform function for all X.

However, I can still update the message if there is a better way of phrasing this change.

@glemaitre (Member) left a comment

It seems correct. Just a small change in the test to make an assert on the numpy array.

Co-Authored-By: Guillaume Lemaitre <g.lemaitre58@gmail.com>
@jnothman (Member) left a comment

Thank you, @lrjball!

@jnothman (Member) commented Mar 9, 2020

@glemaitre, does this have your approval?

@glemaitre glemaitre merged commit 535ef55 into scikit-learn:master Mar 10, 2020
@glemaitre (Member):

@lrjball Thanks for the fix

ashutosh1919 pushed a commit to ashutosh1919/scikit-learn that referenced this pull request Mar 13, 2020
gio8tisu pushed a commit to gio8tisu/scikit-learn that referenced this pull request May 15, 2020
@kstoneriv3 (Contributor):

It seems to me that this PR broke the code. Could you check #18902?

@kstoneriv3 (Contributor):

> I realized that the bug wasn't actually due to the mean not being added, but to alpha not being handled properly in inverse_transform. ...

Though it was concluded here that alpha was not handled correctly, alpha is in fact handled correctly in the original implementation. The problem was that the mean was not added back to the data when the kernel is linear: with a linear kernel, the information about the mean is completely lost by the centering of the kernel, while with a non-linear kernel it is only partially lost.
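
For a linear kernel this is easy to verify: double-centering the kernel matrix gives exactly the kernel of the mean-removed data, so the mean cannot be recovered from the centered kernel alone. A small demonstration (the data here is illustrative):

import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(50, 3) + 5.0  # data with a non-zero mean
K = X @ X.T  # linear kernel
n = K.shape[0]
H = np.eye(n) - np.ones((n, n)) / n  # centering matrix
Xc = X - X.mean(axis=0)
assert np.allclose(H @ K @ H, Xc @ Xc.T)  # centered kernel == kernel of centered data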

@kstoneriv3 (Contributor) commented Mar 25, 2021

@jnothman @glemaitre
As discussed above and in #18902, I believe this PR introduced a bug, so I created the above PR (#19732) to fix it. I would appreciate it if you could take the time to review it.

Closes: inverse_transform in KernelPCA does not account for the mean (#16654)