
[MRG] Add memory efficient implementation of NCA #99


Merged — bellet merged 8 commits into scikit-learn-contrib:master from fix/nca_perf on Jul 9, 2018

Conversation

wdevazelhes (Member)

Fixes #45
This PR adds a more memory-efficient implementation of the gradient of NCA, avoiding the creation of an (n, n, d, d) matrix. It is taken from a PR I opened in scikit-learn: scikit-learn/scikit-learn#10058
Eventually, some other improvements from that PR (like verbosity, input checking, and a choice of initializations) could be integrated into metric-learn, but I think for now just changing the gradient computation is enough to solve #45.
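
For reference, here is a minimal sketch of the trick, modeled on the scikit-learn PR above (the function name and exact call details are mine, not necessarily this PR's code): all per-pair weights are folded into a single (n, n) matrix, which is then contracted with the data matrices, so the (n, n, d, d) tensor of pairwise outer products is never materialized.

```python
import numpy as np
from sklearn.metrics import pairwise_distances
from sklearn.utils.extmath import softmax

def nca_loss_grad(L, X, y):
    # Embed the data: L is (d, d), X is (n, d), so X_embedded is (n, d).
    X_embedded = X.dot(L.T)

    # Softmax over negative squared distances, excluding self-pairs.
    p_ij = pairwise_distances(X_embedded, squared=True)   # (n, n)
    np.fill_diagonal(p_ij, np.inf)
    p_ij = softmax(-p_ij)

    # Loss: total probability of each point picking a same-class neighbor.
    mask = y[:, np.newaxis] == y[np.newaxis, :]           # (n, n)
    masked_p_ij = p_ij * mask
    p = np.sum(masked_p_ij, axis=1, keepdims=True)        # (n, 1)
    loss = np.sum(p)

    # Gradient: fold the pair weights into an (n, n) matrix and contract
    # it with the data matrices, instead of summing n^2 outer products of
    # size (d, d) (which is what the (n, n, d, d) tensor amounts to).
    W = masked_p_ij - p_ij * p                            # (n, n)
    W_sym = W + W.T
    np.fill_diagonal(W_sym, -W.sum(axis=0))               # rows of W sum to 0
    gradient = 2 * X_embedded.T.dot(W_sym).dot(X)         # (d, d)
    return loss, gradient
```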

Here is a little snippet to show the performance gain:

```python
from time import time
from sklearn.datasets import load_iris, load_digits
from metric_learn import NCA

for dataset in [load_iris(), load_digits()]:
    X, y = dataset.data, dataset.target
    nca = NCA()
    start = time()
    nca.fit(X, y)
    print(time() - start)
```

Prints for this PR:
0.1164708137512207 (for iris)
21.722617626190186 (for digits)

Prints for master:
0.9377930164337158 (for iris)
MemoryError (for digits)

Note: with an identity initialization, it prints:
0.10440230369567871 (for iris)
8.215924263000488 (for digits)
Therefore the init is quite important to help the algorithm converge (identity seems to be a good choice here).
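
For scale: digits has n = 1797 samples and d = 64 features, so the (n, n, d, d) float64 array created on master would need roughly 1797² × 64² × 8 B ≈ 106 GB, which explains the MemoryError, while each (n, n) intermediate used by the new gradient takes only about 1797² × 8 B ≈ 26 MB.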

William de Vazelhes added 3 commits June 29, 2018 18:55
… of NCA

- Make gradient computation more memory efficient
- Remove the hard-coded test but add others
- Add deprecation for learning rate (not needed anymore)
- TST: test deprecation
- TST: force the algorithm to converge to pass test_iris using tol
- use checked labels instead of raw y
- update string representation with new arguments
bellet (Member) left a comment


LGTM as a quick fix. You should remove the init parameter, which is not used; we can add more init schemes (factorized across several algorithms) in a later PR.

Is there any documentation to update?


```diff
 from .base_metric import BaseMetricLearner

 EPS = np.finfo(float).eps


 class NCA(BaseMetricLearner):
-  def __init__(self, num_dims=None, max_iter=100, learning_rate=0.01):
+  def __init__(self, num_dims=None, max_iter=100, learning_rate='deprecated',
+               init='auto', random_state=0, tol=None):
```
Member

remove init param

```diff
     self.num_dims = num_dims
     self.max_iter = max_iter
-    self.learning_rate = learning_rate
+    self.learning_rate = learning_rate  # TODO: remove in v.0.5.0
+    self.init = init
```
Member

again
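
(As an aside on the learning_rate deprecation visible in the diff above: the usual pattern, sketched below under assumed names, is to keep the parameter with a sentinel default and warn in fit only when a caller passes a real value. This is illustrative, not the PR's actual code.)

```python
import warnings

class NCA:
    # Sketch of the sentinel-default deprecation pattern only; the real
    # class derives from BaseMetricLearner and does the actual fitting.

    def __init__(self, num_dims=None, max_iter=100,
                 learning_rate='deprecated', tol=None):
        self.num_dims = num_dims
        self.max_iter = max_iter
        self.learning_rate = learning_rate  # TODO: remove in v.0.5.0
        self.tol = tol

    def fit(self, X, y):
        # Warn only if the user explicitly passed a learning rate, since
        # the parameter is no longer used by the optimizer.
        if self.learning_rate != 'deprecated':
            warnings.warn("`learning_rate` is not used anymore and will "
                          "be removed in a future release.",
                          DeprecationWarning)
        # ... actual optimization via scipy goes here ...
        return self
```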

@wdevazelhes (Member Author)

> LGTM as a quick fix. You should remove the init parameter, which is not used; we can add more init schemes (factorized across several algorithms) in a later PR.

Thanks, I had added it and then removed its use since it's a quick fix indeed, but forgot to remove these lines...

> Is there any documentation to update?

There was no documentation before, so I did not add any, but now that you mention it, I think tol and random_state are less clear than learning_rate (and I should maybe also warn about the fact that learning_rate is deprecated), so I will add a docstring to __init__.

@wdevazelhes (Member Author)

I will also remove random_state, since I had only put it there for the PCA initialization, which is not done, so it is not used.

```python
expected = [[-0.09935, -0.2215, 0.3383, 0.443],
            [+0.2532, 0.5835, -0.8461, -0.8915],
            [-0.729, -0.6386, 1.767, 1.832],
            [-0.9405, -0.8461, 2.281, 2.794]]
```
Contributor

Have we verified that the new approach produces a correct result? Checking class separation is a very coarse approximation of correctness.

Member

This is more a general comment but IMHO this sort of test is not very reliable: (i) the "expected" output comes either from running the code itself or running some other implementation which may or may not be reliable (in this case, the source does not seem especially reliable), and (ii) NCA being a nonconvex objective, depending on the initialization but also on the chosen optimization algorithm, one might converge to a different point, which does not imply that the algorithm is incorrect.


wdevazelhes (Member Author) commented on Jul 4, 2018

Also, the code in master uses stochastic gradient descent (updating L after each sample instead of after a pass over the whole dataset). Therefore we will not be able to reproduce this particular result (even if we tested without scipy's optimizer), since in this PR we compute the true full gradient at each iteration.

What is more, printing the loss in test_iris, the loss in this PR is better than the one in master (147.99999880164899 vs 144.57929636406135), so this adds to the argument that this hard-coded array might not be such a good reference to check against.

> @wdevazelhes maybe some of the tests you designed for the sklearn version can be used instead?

Yes, I already added test_finite_differences (which tests the gradient formula; see the sketch below) and test_simple_example (a toy example that checks on 4 points that 2 same-labeled points become closer after fit_transform). Now that you say it, I will also add test_singleton_class and test_one_class, which are also useful: they test edge cases of the dataset where we know some analytical properties of the transformation (if there is only one class, or only singleton classes, the gradient is 0). However, I don't think the other tests need to be included, since they mostly test input formats, verbosity, and other things that are necessary for inclusion in scikit-learn but could be factored out for every algorithm in metric-learn at a later stage of development (and since they were not enforced/tested for NCA before either, leaving them out of this PR is not a regression).
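
For reference, a finite-difference gradient check of the kind test_finite_differences performs can be built on scipy.optimize.check_grad. A minimal sketch, reusing the hypothetical nca_loss_grad helper sketched in the PR description above (not the actual test code):

```python
import numpy as np
from scipy.optimize import check_grad
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)
d = X.shape[1]
rng = np.random.RandomState(42)
L0 = 0.1 * rng.randn(d, d)  # small scale to keep the softmax well-behaved

# check_grad works on flat parameter vectors, so wrap loss and gradient.
fun = lambda L_flat: nca_loss_grad(L_flat.reshape(d, d), X, y)[0]
grad = lambda L_flat: nca_loss_grad(L_flat.reshape(d, d), X, y)[1].ravel()

# check_grad returns the 2-norm of the difference between the analytic
# gradient and a finite-difference approximation; the relative error
# should be tiny if the gradient formula is correct.
err = check_grad(fun, grad, L0.ravel())
print(err / np.linalg.norm(grad(L0.ravel())))
```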

```python
# Compute loss
masked_p_ij = p_ij * mask
p = np.sum(masked_p_ij, axis=1, keepdims=True)  # (n_samples, 1)
loss = np.sum(p)
```
Contributor

Stylistically, I prefer the method call p.sum() over the numpy function. But it's not a big deal if you like the other way better.

Member Author

It's true that p.sum() is more concise; I guess I like it more too. I'll make the change.

@bellet bellet merged commit faa240f into scikit-learn-contrib:master Jul 9, 2018
@wdevazelhes wdevazelhes deleted the fix/nca_perf branch August 22, 2018 06:50
@GeorgePearse commented on Nov 2, 2021

Hi @wdevazelhes, looks like you've done some great work here. What sort of relationship between dataset size and memory consumption does this method lead to? Also is there a verbose option anywhere or could one be added?


Successfully merging this pull request may close these issues.

NCA - performance issues