
FIX a bug in KernelPCA.inverse_transform #19732


Merged

Conversation

kstoneriv3
Contributor

@kstoneriv3 kstoneriv3 commented Mar 20, 2021

Reference Issues/PRs

Fix #18902.

As discussed in #18902, PCA reconstructs the mean of the data while KernelPCA does not. This results in inconsistent inverse transformations between PCA and (linear-kernel) KernelPCA. This inconsistency led to a misunderstanding and to the introduction of a bug in #16655.

What does this implement/fix? Explain your changes.

As discussed above, a bug was introduced in KernelPCA.inverse_transform in #16655. This PR removes it.

Additionally, I suggest a small modification to KernelPCA.inverse_transform to improve its compatibility with PCA.inverse_transform when the linear kernel is used: reconstruct the mean, so that KernelPCA.inverse_transform recovers it in the same way PCA.inverse_transform does.
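To make the inconsistency concrete, here is a minimal sketch on arbitrary synthetic data (not code from this PR). With all components kept, PCA.inverse_transform restores the data exactly, while the learned inverse map of a linear-kernel KernelPCA reconstructs only the centered data:

```python
import numpy as np
from sklearn.decomposition import PCA, KernelPCA

rng = np.random.RandomState(0)
# Synthetic data with a clearly nonzero mean.
X = rng.normal(size=(50, 3)) + np.array([5.0, -2.0, 3.0])

# PCA with all components: inverse_transform restores X exactly,
# including the mean.
pca = PCA(n_components=3)
X_pca = pca.inverse_transform(pca.fit_transform(X))

# Linear-kernel KernelPCA: the learned pre-image map reconstructs the
# centered data, so the mean is lost.
kpca = KernelPCA(
    n_components=3, kernel="linear", fit_inverse_transform=True, alpha=1e-3
)
X_kpca = kpca.inverse_transform(kpca.fit_transform(X))

print(np.allclose(X_pca, X))                               # mean restored
print(np.allclose(X_kpca, X - X.mean(axis=0), atol=1e-2))  # mean lost
```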

@kstoneriv3 kstoneriv3 changed the title fix a bug in KernelPCA.inverse_transform FIX a bug in KernelPCA.inverse_transform Mar 20, 2021
@glemaitre glemaitre self-requested a review April 9, 2021 16:19
@kstoneriv3 kstoneriv3 force-pushed the fix/kernel_pca_inverse_transform branch from 83a7e0a to c3822ac Compare April 11, 2021 10:03
@glemaitre
Member

glemaitre commented Apr 14, 2021

We had a thorough look at this issue/PR with @ogrisel. You are completely right that adding the regularization parameter at reconstruction time is unjustified.

Indeed, we think that it is not really useful to account for the mean loss with the linear kernel: it makes the linear kernel a special case, while one should use PCA in that situation anyway. We certainly need to improve the documentation.

So for this PR, we can revert the change that added alpha to the diagonal of K. We need a better unit test, though. One option is to check the reconstruction error by measuring the Frobenius norm between the original matrix and the reconstruction, and checking that it is close enough, and this for each kernel.

In addition, we think that we can do the following improvements:

  • advocate using PCA instead of kernel='linear' (or maybe there are reasons we did not think about).
  • improve the docstring of KernelPCA, and notably inverse_transform, to mention the approximation made when reconstructing;
  • improve the user guide: (i) discuss inverse_transform and the reconstruction, and (ii) discuss the impact of alpha;
  • improve the example: (i) use a held-out test set rather than a single set to illustrate the reconstruction approximation, (ii) reproduce the denoising example from Section 4, (iii) illustrate intuitively the effect of alpha on the denoising.

In addition, I am thinking about some other improvements:

  • change the default kernel to 'rbf': 'linear' as the default is odd since one should use PCA instead;
  • it is confusing to have both an alpha parameter and alphas_. It might be better to rename alphas_ to eigenvectors_ or components_, and lambdas_ to eigenvalues_, which are more explicit names;
  • rename fit_inverse_transform to something more intuitive, e.g. enable_inverse_transform.

@kstoneriv3 Would you mind going forward by reverting the previous PR and improving the test? Do you wish to contribute the future improvements?

@kstoneriv3
Contributor Author

Indeed, we think that it is not really useful to account for the mean loss with the linear kernel: it makes the linear kernel a special case, while one should use PCA in that situation anyway. We certainly need to improve the documentation.

So for this PR, we can revert the change that added alpha to the diagonal of K.

I agree that we should rather revert the bug and improve the documentation.

We need a better unit test, though. One option is to check the reconstruction error by measuring the Frobenius norm between the original matrix and the reconstruction, and checking that it is close enough, and this for each kernel.

OK, I will replace the old test with this one.

  • advocate using PCA instead of kernel='linear' (or maybe there are reasons we did not think about).

Do you mean raising a warning when kernel='linear' is used?

I generally agree with all of the suggestions you made. If no further discussion is needed before these updates, I can do them. Would you like to have some of these updates in the form of separate PRs?

@ogrisel
Member

ogrisel commented Apr 15, 2021

change the default of kernel to rbf. 'linear' as default is weird since one should use PCA instead;

If we do this, we should do it in a dedicated PR with backward compatibility. I think it's low priority.

@ogrisel
Member

ogrisel commented Apr 15, 2021

Do you mean raising a warning when kernel='linear' is used?

Improving the docstring would be enough.

@ogrisel
Member

ogrisel commented Apr 15, 2021

Indeed separate PRs would be helpful. Let's start with undoing the inverse_transform alpha fix and write a new test as part of the current PR.

For instance, the test could check that we can approximately recover the original data, using the Frobenius norm ||X_test - X_test_reconstructed|| with some tolerance (and maybe with a low enough value for the kernel ridge penalty alpha, and large enough n_components and n_samples on the training set compared to the number of features).

For the test data, you could try sklearn.datasets.make_swiss_roll and split it into a train and a test set using train_test_split.
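A rough sketch of this suggested setup (the parameter values are illustrative guesses, not from the discussion). As the next comment notes, the reconstruction quality on the Swiss roll turned out to be poor, so this sketch only reports the error rather than asserting a tight tolerance:

```python
import numpy as np
from sklearn.datasets import make_swiss_roll
from sklearn.decomposition import KernelPCA
from sklearn.model_selection import train_test_split

# Generate the Swiss roll and hold out a test set, as suggested above.
X, _ = make_swiss_roll(n_samples=500, random_state=0)
X_train, X_test = train_test_split(X, random_state=0)

# Illustrative hyperparameters: small alpha, many components.
kpca = KernelPCA(
    n_components=100, kernel="rbf", gamma=1e-2, alpha=1e-3,
    fit_inverse_transform=True,
).fit(X_train)

# Reconstruct held-out points and measure the relative Frobenius error.
X_test_rec = kpca.inverse_transform(kpca.transform(X_test))
err = np.linalg.norm(X_test - X_test_rec) / np.linalg.norm(X_test)
print(f"relative reconstruction error: {err:.3f}")
```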

@glemaitre glemaitre added the Bug label Apr 15, 2021
@glemaitre glemaitre added this to the 0.24.2 milestone Apr 15, 2021
@kstoneriv3
Contributor Author

For instance, the test could check that we can approximately recover the original data, using the Frobenius norm ||X_test - X_test_reconstructed|| with some tolerance (and maybe with a low enough value for the kernel ridge penalty alpha, and large enough n_components and n_samples on the training set compared to the number of features).

For the test data, you could try sklearn.datasets.make_swiss_roll and split it into a train and a test set using train_test_split.

I tried the Swiss roll, but the reconstruction quality was not very good. This is because, on the Swiss roll, kernel PCA's 'rbf' kernel cannot capture well the similarity of data points along the third axis. So I used make_blobs in the test case instead.
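A minimal sketch of such a make_blobs-based reconstruction check (the parameter values and thresholds here are illustrative, not the exact ones merged in the PR). The data are centered first since, as discussed above, KernelPCA does not reconstruct the mean with the linear kernel:

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.decomposition import KernelPCA

X, _ = make_blobs(n_samples=100, n_features=4, random_state=0)
# Center the data: KernelPCA's learned inverse map does not recover the
# mean for the linear kernel.
X = X - X.mean(axis=0)

errs = {}
for kernel, n_components in [("linear", 4), ("rbf", 20)]:
    kpca = KernelPCA(
        n_components=n_components, kernel=kernel,
        fit_inverse_transform=True, alpha=1e-3,
    )
    X_rec = kpca.inverse_transform(kpca.fit_transform(X))
    # Relative Frobenius-norm reconstruction error, per kernel.
    errs[kernel] = np.linalg.norm(X - X_rec) / np.linalg.norm(X)
    print(f"{kernel}: relative reconstruction error = {errs[kernel]:.4f}")
```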

@kstoneriv3
Contributor Author

I will make the changes to the documentation later.

@kstoneriv3
Contributor Author

I have a question: It seems that centering of the kernel is not applied in the sklearn.kernel_approximation. What would be the disadvantage if we were to remove the centering of kernels in the first place?

@glemaitre
Member

I have a question: It seems that centering of the kernel is not applied in the sklearn.kernel_approximation. What would be the disadvantage if we were to remove the centering of kernels in the first place?

Isn't it what KernelCenterer is doing?

@kstoneriv3
Contributor Author

I have a question: It seems that centering of the kernel is not applied in the sklearn.kernel_approximation. What would be the disadvantage if we were to remove the centering of kernels in the first place?

Isn't it what KernelCenterer is doing?

It seems to me that KernelCenterer is applied in KernelPCA but not in sklearn.kernel_approximation. So I wondered what the pros and cons would be if we just removed kernel centering from KernelPCA.
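For context, a small check of what KernelCenterer computes: centering the kernel matrix is equivalent to computing the kernel on mean-centered features, shown here for the linear kernel:

```python
import numpy as np
from sklearn.metrics.pairwise import linear_kernel
from sklearn.preprocessing import KernelCenterer

rng = np.random.RandomState(0)
X = rng.normal(size=(20, 5))

# Center the Gram matrix of the raw features.
K = linear_kernel(X)
K_centered = KernelCenterer().fit_transform(K)

# Same result as computing the linear kernel on mean-centered features.
X_centered = X - X.mean(axis=0)
print(np.allclose(K_centered, linear_kernel(X_centered)))  # True
```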

@kstoneriv3
Contributor Author

kstoneriv3 commented Apr 17, 2021

The progress status of this PR:

  • fix the bug
  • fix the test
  • advocate using PCA instead of kernel='linear' (or maybe there are reasons we did not think about).
  • improve the docstring of KernelPCA, and notably inverse_transform, to mention the approximation made when reconstructing;
  • improve the user guide: (i) discuss inverse_transform and the reconstruction, and (ii) discuss the impact of alpha;

The following items can be included in this PR if we agree on the new names.

  • it is confusing to have both an alpha parameter and alphas_. It might be better to rename alphas_ to eigenvectors_ or components_, and lambdas_ to eigenvalues_, which are more explicit names.
    -> I agree with changing them to eigenvectors_ and eigenvalues_.
  • rename fit_inverse_transform to something more intuitive, e.g. enable_inverse_transform
    -> I am not particular about the name of this argument. fit_inverse_transform is a bit obscure, but it makes clear that we are using a learned pre-image.

The following should go in separate PRs, I guess.

  • improve the example: (i) use a held-out test set rather than a single set to illustrate the reconstruction approximation, (ii) reproduce the denoising example from Section 4, (iii) illustrate intuitively the effect of alpha on the denoising.
  • change the default kernel to 'rbf': 'linear' as the default is odd since one should use PCA instead;

@glemaitre
Member

@kstoneriv3 I think that we can limit this PR to the first two points; it will be easier to review and merge.
Then we can make a PR for the documentation changes, and finally individual PRs for each API change.

So I will review this PR shortly (@ogrisel can maybe make the second review). In parallel, do not hesitate to open the subsequent PRs already.

@kstoneriv3
Contributor Author

kstoneriv3 commented Apr 17, 2021

@kstoneriv3 I think that we can limit this PR to the first two points; it will be easier to review and merge.
Then we can make a PR for the documentation changes, and finally individual PRs for each API change.

So I will review this PR shortly (@ogrisel can maybe make the second review). In parallel, do not hesitate to open the subsequent PRs already.

OK, great. Then I will leave this PR as it is and make separate PRs for other issues.

@kstoneriv3
Contributor Author

kstoneriv3 commented Apr 17, 2021

I tried to reproduce the "denoising example from Section 4", but the denoising quality is not as good as I expected (left: training images; middle: test images to be denoised; right: denoised images). At this quality, I would not add a new denoising example or mention the effect of alpha on denoising in the documentation.

This might be because sklearn's kernel PCA shares the kernel parameters between the forward transform and the pre-image, while the paper uses RBF kernels with different parameters for kernel PCA itself and for the pre-image. The code is available in my gist.

Figure_1

@glemaitre
Member

I tried to reproduce the exact example with the same hyperparameters as in the paper and the same dataset (USPS). I put the code in the "details" section. I get results close to the original paper:

fig_0
fig_1
fig_2
fig_3

# %%
from sklearn.datasets import fetch_openml

usps = fetch_openml(data_id=41082)

# %%
data = usps.data
target = usps.target

# %%
import numpy as np

img = np.reshape(data.iloc[0].to_numpy(), (16, 16))

# %%
import matplotlib.pyplot as plt

plt.imshow(img)

# %%
from sklearn.model_selection import train_test_split

data_rest, data_train, target_rest, target_train = train_test_split(
    data, target, stratify=target, random_state=42, test_size=100,
)
data_rest, data_test, target_rest, target_test = train_test_split(
    data_rest, target_rest, stratify=target_rest, random_state=42,
    test_size=100,
)
data_train, data_test = data_train.to_numpy(), data_test.to_numpy()

# %%
fig, axs = plt.subplots(nrows=10, ncols=10, figsize=(15, 15))
for img, ax in zip(data_test, axs.ravel()):
    ax.imshow(img.reshape((16, 16)), cmap="Greys")
    ax.axis("off")
_ = fig.suptitle("Uncorrupted test dataset")

# %%
rng = np.random.RandomState(0)
# Corrupt the test images with additive Gaussian noise.
noise = rng.normal(scale=0.5, size=data_test.shape)
data_test_corrupted = data_test + noise

# %%
fig, axs = plt.subplots(nrows=10, ncols=10, figsize=(15, 15))
for img, ax in zip(data_test_corrupted, axs.ravel()):
    ax.imshow(img.reshape((16, 16)), cmap="Greys")
    ax.axis("off")
_ = fig.suptitle(
    f"Corrupted test data: "
    f"MSE={np.mean((data_test - data_test_corrupted) ** 2):.2f}",
    size=26,
)

# %%
from sklearn.decomposition import KernelPCA

kpca = KernelPCA(
    n_components=80, kernel="rbf", gamma=0.5, fit_inverse_transform=True,
    alpha=1.0,
)

# %%
kpca.fit(data_train)

# %%
# Denoise the corrupted test images via the learned pre-image map.
data_reconstruct = kpca.inverse_transform(kpca.transform(data_test_corrupted))

# %%
fig, axs = plt.subplots(nrows=10, ncols=10, figsize=(15, 15))
for img, ax in zip(data_reconstruct, axs.ravel()):
    ax.imshow(img.reshape((16, 16)), cmap="Greys")
    ax.axis("off")
_ = fig.suptitle(
    f"Denoising using Kernel PCA with RBF kernel: "
    f"MSE={np.mean((data_test - data_reconstruct) ** 2):.2f}",
    size=26,
)

# %%
from sklearn.decomposition import PCA

pca = PCA(n_components=32)
pca.fit(data_train)
data_reconstruct = pca.inverse_transform(pca.transform(data_test_corrupted))

# %%
fig, axs = plt.subplots(nrows=10, ncols=10, figsize=(15, 15))
for img, ax in zip(data_reconstruct, axs.ravel()):
    ax.imshow(img.reshape((16, 16)), cmap="Greys")
    ax.axis("off")
_ = fig.suptitle(
    f"Denoising using PCA: "
    f"MSE={np.mean((data_test - data_reconstruct) ** 2):.2f}",
    size=26
)

@glemaitre
Member

We might want to play with the parameter alpha: I am not sure the standardization is precise, and I am not sure which normalization was applied to the dataset available on OpenML.

@glemaitre
Member

For instance, the results look better with a stronger regularization, alpha=10:

fig_4

@kstoneriv3
Contributor Author

@glemaitre Thank you! The denoising quality is surprisingly better! Now it makes sense to add this to examples.

Member

@glemaitre glemaitre left a comment

LGTM. I am wondering if we could make the reconstruction better, but I don't know if it is needed. I will rely on a review from @ogrisel.

Member

@ogrisel ogrisel left a comment

A quick note and this should be good:

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
@glemaitre glemaitre merged commit 4946bfc into scikit-learn:main Apr 20, 2021
@glemaitre
Member

@kstoneriv3 thanks for your work. I started to modify the example, and we saw with @ogrisel that this bug was really affecting the reconstruction results. I will push my changes for this example tomorrow. If you want, you can review it.

This was referenced Apr 21, 2021
glemaitre pushed a commit to glemaitre/scikit-learn that referenced this pull request Apr 22, 2021
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
smarie pushed a commit to smarie/scikit-learn that referenced this pull request Apr 27, 2021
smarie pushed a commit to smarie/scikit-learn that referenced this pull request Apr 27, 2021
glemaitre pushed a commit that referenced this pull request Apr 28, 2021
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
@ageron ageron mentioned this pull request Aug 27, 2021
Successfully merging this pull request may close these issues.

A Bug at the inverse_transform of the KernelPCA