
[MRG+1] Reducing t-SNE memory usage #9032


Merged: 55 commits from optim_tsne into scikit-learn:master, Jul 12, 2017

Conversation


@tomMoral tomMoral commented Jun 7, 2017

The Barnes-Hut algorithm for t-SNE currently has O(N^2) memory complexity, but it could use O(uN). This PR intends to improve the memory usage (see issue #7089).

Steps to proceed

  • Only compute the nearest neighbors distances
  • Validate the _barnes_hut2 implementation against the _barnes_hut functions
  • Check memory usage
  • Set default values of the optimizer hyperparameters to match the reference implementation and check that the results are qualitatively and quantitatively matching (compute trustworthiness; see the sketch after this list)
  • Make TSNE raise a ValueError when n_components > 3 or n_components < 2 with the BH solver enabled.
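
A minimal sketch of the trustworthiness check mentioned in the list above, assuming a scikit-learn version that exposes trustworthiness under sklearn.manifold (the dataset and parameters are illustrative, not the benchmark setup):

from sklearn.datasets import load_digits
from sklearn.manifold import TSNE, trustworthiness

X, _ = load_digits(return_X_y=True)
X_embedded = TSNE(n_components=2, method='barnes_hut',
                  random_state=0).fit_transform(X)
# Scores close to 1.0 mean local neighborhoods are well preserved,
# which is what we want to compare between the two solvers.
print(trustworthiness(X, X_embedded, n_neighbors=5))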

Related


jnothman commented Jun 7, 2017 via email


jnothman commented Jun 7, 2017 via email


vene commented Jun 7, 2017

Is there a reason why you opted to make a new Cython file rather than change the existing one and rely on unit tests to ensure the same behaviour? In particular, did you spot any differences in your refactoring that could be turned into regression tests?

Just saying because the current file layout makes reviewing somewhat more difficult.

(answered IRL, my bad for rushing into this PR)

# set the neighbors to n - 1
distances_nn, neighbors_nn = knn.kneighbors(
X, n_neighbors=k + 1)
distances_nn = distances_nn[:, 1:]

This isn't quite doing the right thing. If two instances are identical, you can't be certain which will be output first. Use kneighbors(None) instead.
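
A minimal sketch of the difference (k and the toy X are illustrative):

import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.random.RandomState(0).rand(100, 5)
k = 10
knn = NearestNeighbors(n_neighbors=k).fit(X)

# Querying with X returns each point as (usually) its own first neighbor,
# so one extra neighbor is requested and the first column dropped; with
# exact duplicates the point itself is not guaranteed to come first.
distances_nn, neighbors_nn = knn.kneighbors(X, n_neighbors=k + 1)
distances_nn, neighbors_nn = distances_nn[:, 1:], neighbors_nn[:, 1:]

# Querying with None explicitly excludes each training point from its own
# neighborhood, which avoids that ambiguity.
distances_nn, neighbors_nn = knn.kneighbors(None, n_neighbors=k)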

if self.metric == 'precomputed':
# Use the precomputed distances to find
# the k nearest neighbors and their distances
neighbors_nn = np.argsort(distances, axis=1)[:, :k]

You can just use NearestNeighbors(metric='precomputed') for this.
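
A minimal sketch of that suggestion, with a toy precomputed distance matrix:

import numpy as np
from sklearn.metrics import pairwise_distances
from sklearn.neighbors import NearestNeighbors

X = np.random.RandomState(0).rand(50, 4)
distances = pairwise_distances(X)   # square (n_samples, n_samples) matrix
k = 10

knn = NearestNeighbors(n_neighbors=k, metric='precomputed').fit(distances)
# Passing None excludes each sample from its own neighbor list, so no
# manual argsort/slicing of the precomputed distances is needed.
distances_nn, neighbors_nn = knn.kneighbors(None)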

else:
# Find the nearest neighbors for every point
# TODO: argument for class knn_estimator=None
# TODO: assert that the knn metric is euclidean

Can't we use another metric?

printf("[t-SNE] [d=%i] Inserting pos %i [%f, %f] duplicate_count=%i "
"into child %p\n", depth, point_index, pos[0], pos[1],
printf("[t-SNE] [d=%li] Inserting pos %li [%f, %f] duplicate_count=%li"
" into child %p\n", depth, point_index, pos[0], pos[1],
tomMoral (author):

Using %li is required when printing a long integer, to avoid compiler warnings.

@tomMoral tomMoral force-pushed the optim_tsne branch 3 times, most recently from 20f7ebf to dfa8985 Compare June 8, 2017 12:43
float[:,:] pos_reference,
np.int64_t[:,:] neighbors,
np.int64_t[:] neighbors,
np.int64_t[:] indptr,

If we really do have n_neighbors non-zero values in each row, I think the previous approach with a 2d array of (samples, neighbors) was better than having an indptr design. Why did you change it?

tomMoral (author):

This makes it easy to efficiently symmetrize the conditional_P matrix using scipy operations for sparse matrices.
Symmetrization is done in the reference implementation, so I tried to stay as close to it as possible. I haven't reviewed all of it yet, though.
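
For context, a minimal sketch of that symmetrization (conditional_P, neighbors_nn and K are stand-ins for the arrays discussed in this thread, filled with random values here):

import numpy as np
from scipy.sparse import csr_matrix

n_samples, K = 1000, 90
rng = np.random.RandomState(0)
conditional_P = rng.rand(n_samples, K)                    # P(j|i) for the K neighbors of i
neighbors_nn = rng.randint(0, n_samples, (n_samples, K))  # their column indices

# CSR layout: row i owns entries [i*K, (i+1)*K), hence the regular indptr.
P = csr_matrix((conditional_P.ravel(), neighbors_nn.ravel(),
                range(0, n_samples * K + 1, K)),
               shape=(n_samples, n_samples))
P = P + P.T        # symmetrize, as in the reference implementation
P /= P.sum()       # renormalize so the joint probabilities sum to 1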

range(0, n_samples * K + 1, K)),
shape=(n_samples, n_samples))

P = P + P.T
tomMoral (author):

Sparse symmetrization is done here.


Indicate it as a comment in the code, rather than a comment in the PR :)

@tomMoral tomMoral force-pushed the optim_tsne branch 5 times, most recently from 8d8c70d to 312d9f0 Compare June 8, 2017 16:35

tomMoral commented Jun 8, 2017

Some benchmark results on MNIST against the reference implementation. The script to reproduce them is included in the benchmark folder.
I have not yet run it on the full MNIST dataset, as that should take more than 1h30 per implementation, but the numbers look okay on subsamples and the memory usage no longer grows quadratically with the number of samples.

There is probably room for improvement, but maybe in the next PR? :)

$ python bench_tsne_mnist.py 
Fitting TSNE on 100 samples...
Fitting T-SNE on 100 samples took 0.226s
Fitting bhtsne on 100 samples took 0.797s
Fitting TSNE on 1000 samples...
Fitting T-SNE on 1000 samples took 4.462s
Fitting bhtsne on 1000 samples took 6.774s
Fitting TSNE on 10000 samples...
Fitting T-SNE on 10000 samples took 133.299s
Fitting bhtsne on 10000 samples took 158.046s

Here is the memory usage reported with memory_profiler.

[memory profile plot: bench_mnist_tsne]

The memory peaks when NearestNeighbors builds its ball tree, then drops a bit. In any case, it stays around 300MB for 10000 samples.

EDIT: a run of this bench with master gives:

$ python bench_tsne_mnist.py 
Fitting TSNE on 100 samples...
Fitting T-SNE on 100 samples took 0.269s
Fitting TSNE on 1000 samples...
Fitting T-SNE on 1000 samples took 4.811s
Fitting TSNE on 5000 samples...
Fitting T-SNE on 5000 samples took 62.745s
Fitting TSNE on 10000 samples...
Fitting T-SNE on 10000 samples took 271.673s

So this PR is a bit faster.
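
The profile above was recorded with memory_profiler; a minimal sketch of recording a comparable trace for a single fit (random data as a stand-in for MNIST, and assuming the memory_profiler package is installed):

import numpy as np
from memory_profiler import memory_usage
from sklearn.manifold import TSNE

X = np.random.RandomState(0).rand(10000, 50).astype(np.float32)

def fit():
    TSNE(n_components=2, method='barnes_hut').fit_transform(X)

# Sample the resident memory of this process every 0.1s while fit() runs.
mem = memory_usage((fit, (), {}), interval=0.1)
print("peak memory: %.1f MiB" % max(mem))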

@@ -79,34 +81,20 @@ cpdef np.ndarray[np.float32_t, ndim=2] _binary_search_perplexity(
# Compute current entropy and corresponding probabilities
# computed just over the nearest neighbors or over all data
# if we're not using neighbors
@vene vene Jun 8, 2017

The comment seems outdated: what does "if we're not using neighbors" mean now? Sorry, I misread.


ogrisel commented Jun 8, 2017

FYI, I am currently running a memory benchmark on the full (70000 samples) MNIST dataset. The memory usage of my python process is constant at 1.2GB.

Update: running MNIST on 70000 samples took ~52 minutes. Here is the memory profile:

[memory profile plot: tsne_mnist]

Most of the time is spent in the ball tree.

Update 2: running MNIST on 70000 samples took ~53 minutes with the reference implementation and 1.3GB of RAM (basically the same behavior as our Cython implementation in this PR).


jnothman commented Jun 9, 2017

Yes, I was going to ask how much of the time was in KNN.

@jnothman jnothman left a comment

If ball tree is the slowest part:

  • we should consider offering n_jobs if multithreaded balltree queries help reduce runtime
  • we might check how the current implementation compares to other single-processor implementations, to confirm our timings are in the ballpark of the state of the art. (I just realised you did that above.)

metric=self.metric)
knn.fit(X)
# LvdM uses 3 * perplexity as the number of neighbors
# And we add one to not count the data point itself
@jnothman jnothman Jun 9, 2017

we do not add 1; you could explain why we pass None, or more explicitly use neighbors.kneighbors_graph(X, include_self=False).

# And we add one to not count the data point itself
# In the event that we have very small # of points
# set the neighbors to n - 1
distances_nn, neighbors_nn = knn.kneighbors(

we could use kneighbors_graph which will already return a sparse matrix representation. This will allow us to rely on kneighbors_graph being as memory efficient as possible, rather than putting it into a sparse matrix in _joint_probabilities_nn.
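
For reference, a minimal sketch of the kneighbors_graph variant being suggested (toy data; mode='distance' is what would carry the neighbor distances):

import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.random.RandomState(0).rand(500, 10)
k = 30

knn = NearestNeighbors(n_neighbors=k).fit(X)
# X=None queries the training points while excluding each point from its
# own neighborhood; the result is already a CSR matrix of distances.
distances_graph = knn.kneighbors_graph(X=None, mode='distance')
print(distances_graph.shape, distances_graph.nnz)   # (500, 500)  15000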

tomMoral (author):

We do not put distances_nn in a sparse matrix in _joint_probabilities_nn, but we do put conditional_P in a sparse matrix for symmetrization purposes.
Using kneighbors_graph would increase the memory needed, to store range(0, n_samples * K + 1, K). I don't see a need to use it unless it is more efficient internally.
What do you think?

@ogrisel ogrisel Jun 9, 2017

We don't need any of the features of the scipy sparse matrix API for the neighbors info itself (we call into Cython code directly after its computation).

Only the conditional probability matrix benefits from being represented as a scipy sparse matrix, which makes the symmetrization a one-liner Python snippet.


tomMoral commented Jun 9, 2017

I ran the benchmark with n_jobs=6

Fitting TSNE on 100 samples...
   Fitting KNN: 0.0008363723754882812
   Predict KNN: 0.1035609245300293
Fitting T-SNE on 100 samples took 0.377s
Fitting TSNE on 1000 samples...
    Fitting KNN: 0.02005767822265625
    Predict KNN: 0.30463409423828125
Fitting T-SNE on 1000 samples took 3.960s
Fitting TSNE on 5000 samples...
    Fitting KNN: 0.14362692832946777
    Predict KNN: 5.516682147979736
Fitting T-SNE on 5000 samples took 17.994s
Fitting TSNE on 10000 samples...
    Fitting KNN: 0.3776566982269287
    Predict KNN: 23.55470323562622
Fitting T-SNE on 10000 samples took 53.554s
Fitting TSNE on 70000 samples...
    Fitting KNN: 13.219616413116455
    Predict KNN: 1376.9010944366455
Fitting T-SNE on 70000 samples took 1579.028s

There is a non-negligible speedup: it took a bit less than 30 minutes to fit the full dataset.
The most expensive computation is still the kNN search.
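
The n_jobs=6 run above presumably parallelizes the neighbor search (offering n_jobs on TSNE itself is only being considered above, so the parameter would go on the internal NearestNeighbors step). A minimal sketch of that kind of setup, with illustrative sizes:

import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.random.RandomState(0).rand(10000, 50)
k = 90   # 3 * perplexity, as in the discussion above

# The kneighbors queries are dispatched over 6 workers.
knn = NearestNeighbors(n_neighbors=k, n_jobs=6).fit(X)
distances_nn, neighbors_nn = knn.kneighbors(None)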


ogrisel commented Jun 9, 2017

I am currently running the original bhtsne code on 70000 MNIST and the memory usage of the Python process is 1.3GB, so I think we can officially declare that our implementation is as memory efficient as it should be 🎉

Update: the original bhtsne code on 70000 MNIST took 53min which is the same as our Cython code: 🎉²


agramfort commented Jun 9, 2017 via email


ogrisel commented Jun 9, 2017

Here is a snippet to plot the resulting embedding:

from sklearn.datasets import fetch_mldata
from sklearn.utils import check_array
import matplotlib.pyplot as plt
import numpy as np

X_embedded = np.load('mnist_tsne_70000.npy')


def load_data(dtype=np.float32, order='C'):
    """Load the data, then cache and memmap the train/test split"""
    print("Loading dataset...")
    data = fetch_mldata('MNIST original')
    X = check_array(data['data'], dtype=dtype, order=order)
    y = data["target"]

    # Normalize features
    X /= 255
    return X, y

_, y = load_data()


plt.figure(figsize=(12, 12))
for c in np.unique(y):
    X_c = X_embedded[y == c]
    plt.scatter(X_c[:, 0], X_c[:, 1], alpha=0.1, label=int(c))
plt.legend(loc='best')

With my 52min run with the code of this PR, this yields:

[scatter plot: MNIST t-SNE embedding from this PR]

This might have converged to a local minimum, but it seems to work well enough.

Update: here is the output of the reference implementation on the same 70000 samples MNIST dataset:

[scatter plot: MNIST t-SNE embedding from the reference implementation]

We probably have discrepancies in the hyperparameters that are worth investigating more thoroughly before merging this PR.


ogrisel commented Jun 9, 2017

I ran another session with PCA preprocessing of the MNIST 70000 dataset, keeping 50 principal components:

Preprocessing the data with 50 dim PCA...
PCA took 4.054s
Fitting TSNE on 100 samples...
Fitting T-SNE on 100 samples took 0.164s
Fitting bhtsne on 100 samples took 0.346s
Fitting TSNE on 1000 samples...
Fitting T-SNE on 1000 samples took 2.488s
Fitting bhtsne on 1000 samples took 4.500s
Fitting TSNE on 5000 samples...
Fitting T-SNE on 5000 samples took 9.055s
Fitting bhtsne on 5000 samples took 26.767s
Fitting TSNE on 10000 samples...
Fitting T-SNE on 10000 samples took 26.286s
Fitting bhtsne on 10000 samples took 65.405s
Fitting TSNE on 70000 samples...
Fitting T-SNE on 70000 samples took 665.962s
Fitting bhtsne on 70000 samples took 870.480s

Our PR (reported as "T-SNE" in the benchmark script output):

[scatter plot: embedding from this PR, with 50-dim PCA preprocessing]

Reference implementation (bhtsne):

[scatter plot: embedding from bhtsne, with 50-dim PCA preprocessing]

I have not changed anything in the way we set the default hyperparameters: this is still to investigate. However, it shows that we should probably do X = PCA(n_components=50).fit_transform(X) in the benchmark script, as the results are visually similar to applying t-SNE directly to the original 784-dimensional data.

It's interesting to note that in this lower-dimensional regime we are significantly faster than the reference implementation.

Also I think we can probably do an MNIST example with 5000 samples by default in the scikit-learn doc (with 50-dim PCA preprocessing).
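
A minimal sketch of that suggested example setup (random data stands in for the 5000 MNIST samples; parameters are illustrative):

import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

rng = np.random.RandomState(0)
X = rng.rand(5000, 784).astype(np.float32)   # stand-in for 5000 MNIST digits

# 50-dim PCA preprocessing, then Barnes-Hut t-SNE on the reduced data.
X_pca = PCA(n_components=50).fit_transform(X)
X_embedded = TSNE(n_components=2, method='barnes_hut',
                  random_state=0).fit_transform(X_pca)
print(X_embedded.shape)   # (5000, 2)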

@ogrisel ogrisel force-pushed the optim_tsne branch 3 times, most recently from 0b7537c to d2d64ff Compare June 10, 2017 12:56
@@ -803,6 +803,8 @@ def _tsne(self, P, degrees_of_freedom, n_samples, random_state, X_embedded,
if self.method == 'barnes_hut':
obj_func = _kl_divergence_bh
opt_args['kwargs']['angle'] = self.angle
# Repeat verbose argument for _kl_divergence_bh
opt_args['kwargs']['verbose'] = self.verbose

Good catch.


ogrisel commented Jul 12, 2017

@tomMoral I am investigating why the errors are often constant: I think that's because we truncate with too large an EPSILON in compute_gradient_positive. I am working on a fix.


ogrisel commented Jul 12, 2017

Alright, the bug in the error computation is fixed, all tests pass (the CircleCI failure is the stock market stuff), the examples look good, and the MNIST benchmark is accurate, reasonably fast and memory efficient.

Merging!


GaelVaroquaux commented Jul 12, 2017 via email

@ogrisel ogrisel merged commit cb1b6c4 into scikit-learn:master Jul 12, 2017

ogrisel commented Jul 12, 2017

There are still plenty of things to improve in this code base, but I think this is a great first step. Thanks @tomMoral for your patience :)

@tomMoral

Thanks @ogrisel for fixing the EPSILON bug!
Hopefully, I will get some time to work on accelerating our t-SNE soon!

@tomMoral tomMoral deleted the optim_tsne branch July 12, 2017 21:51

jnothman commented Jul 13, 2017 via email

massich pushed a commit to massich/scikit-learn that referenced this pull request Jul 13, 2017
…9032)

Use a sparse matrix representation of the neighbors.
Re-factored the QuadTree implementation to avoid insertion errors.
Various fixes in the gradient descent schedule to get the Barnes Hut and exact solvers to behave more robustly and consistently.
yarikoptic added a commit to yarikoptic/scikit-learn that referenced this pull request Jul 27, 2017
Release 0.19b2

* tag '0.19b2': (808 commits)
  Preparing 0.19b2
  [MRG+1] FIX out of bounds array access in SAGA (scikit-learn#9376)
  FIX make test_importances pass on 32 bit linux
  Release 0.19b1
  DOC remove 'in dev' header in whats_new.rst
  DOC typos in whats_news.rst [ci skip]
  [MRG] DOC cleaning up what's new for 0.19 (scikit-learn#9252)
  FIX t-SNE memory usage and many other optimizer issues (scikit-learn#9032)
  FIX broken link in gallery and bad title rendering
  [MRG] DOC Replace \acute by prime (scikit-learn#9332)
  Fix typos (scikit-learn#9320)
  [MRG + 1 (rv) + 1 (alex) + 1] Add a check to test the docstring params and their order (scikit-learn#9206)
  DOC Residual sum vs. regression sum (scikit-learn#9314)
  [MRG] [HOTFIX] Fix capitalization in test and hence fix failing travis at master (scikit-learn#9317)
  More informative error message for classification metrics given regression output (scikit-learn#9275)
  [MRG] COSMIT Remove unused parameters in private functions (scikit-learn#9310)
  [MRG+1] Ridgecv normalize (scikit-learn#9302)
  [MRG + 2] ENH Allow `cross_val_score`, `GridSearchCV` et al. to evaluate on multiple metrics (scikit-learn#7388)
  Add data_home parameter to fetch_kddcup99 (scikit-learn#9289)
  FIX makedirs(..., exists_ok) not available in Python 2 (scikit-learn#9284)
  ...
yarikoptic added a commit to yarikoptic/scikit-learn that referenced this pull request Jul 27, 2017
* releases: (808 commits)
  ...
yarikoptic added a commit to yarikoptic/scikit-learn that referenced this pull request Jul 27, 2017
* dfsg: (808 commits)
  ...
dmohns pushed a commit to dmohns/scikit-learn that referenced this pull request Aug 7, 2017
dmohns pushed a commit to dmohns/scikit-learn that referenced this pull request Aug 7, 2017
NelleV pushed a commit to NelleV/scikit-learn that referenced this pull request Aug 11, 2017
paulha pushed a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
AishwaryaRK pushed a commit to AishwaryaRK/scikit-learn that referenced this pull request Aug 29, 2017
@rth rth mentioned this pull request Sep 5, 2017
maskani-moh pushed a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017