[MRG+1] Reducing t-SNE memory usage #9032

Merged on Jul 12, 2017 (55 commits)
Commits
bdaf924
ENH improve memory usage of t-sne
tomMoral Jun 8, 2017
3f52705
ENH add a benchmark for t-SNE
tomMoral Jun 8, 2017
cb258ff
ENH improve bench script+add params for TSNE
tomMoral Jun 9, 2017
f77a9d7
ENH add PCA preprocessing to benchmark [ci skip]
tomMoral Jun 9, 2017
22fb977
ENH improve benchmark script
ogrisel Jun 10, 2017
ed4182c
WIP on fixing the optimization schedule and default parameters
ogrisel Jun 10, 2017
cf851f0
Shuffle the data in the benchmark script
ogrisel Jun 10, 2017
a6a28bd
WIP more work on opt scheduling and stopping criterion
ogrisel Jun 13, 2017
e4a0d24
ENH add basic QuadTree in neighbors
tomMoral Jun 18, 2017
02af993
ENH TSNE running with new QuadTree
tomMoral Jun 18, 2017
854d708
FIX picklizable QuadTree + more tests
tomMoral Jun 19, 2017
04ffd54
Compute first NN accuracy
ogrisel Jun 20, 2017
d4f62d2
TST independently seeded rng
tomMoral Jun 20, 2017
ad449ac
Add script to plot TSNE benchmark results
ogrisel Jun 20, 2017
a98d571
CLN remove unsused code
tomMoral Jun 20, 2017
bd96c3a
TST improve parameter testing
tomMoral Jun 21, 2017
acd2666
CLN improve code readability
tomMoral Jun 22, 2017
ffcb6f4
CLN more comments in quad_tree
tomMoral Jun 22, 2017
bebf2d2
CLN remove malloc when possible + trailing white-space
tomMoral Jun 22, 2017
0efe4e5
CLN fix pep8+add comments+rename methods quad_tree
tomMoral Jun 26, 2017
067ec73
CLN pep8+typo+remove plot
tomMoral Jun 29, 2017
71fcd29
CLN remove knn extra args for TSNE
tomMoral Jun 29, 2017
b3276eb
TST reduce test time
tomMoral Jun 29, 2017
1db2871
FIX flake8 unused import
tomMoral Jun 29, 2017
98c64ab
TST make test_preserve_trustworthiness_approximately less strict
ogrisel Jun 29, 2017
d043955
DOC adjust the perplexity range in t-SNE example
ogrisel Jun 30, 2017
487a46b
CLN clarify quad_tree header
tomMoral Jul 3, 2017
f7cdb2e
CLN comment and typo+early free knn
tomMoral Jul 3, 2017
6619f6a
CLN add what's new entry
tomMoral Jul 5, 2017
dec5fb9
TST improve test_t_sne for changed parameters
tomMoral Jul 5, 2017
d7991bb
FIX what's new entry
tomMoral Jul 6, 2017
476aeb6
FIX typo in test and reduce test time
tomMoral Jul 6, 2017
f753da8
Add uniform grid to perplexity example
ogrisel Jul 6, 2017
fdc16d6
TST 2D uniform grid recovery by TSNE
ogrisel Jul 6, 2017
63aecca
TST simpler code
ogrisel Jul 6, 2017
4639bb4
TST fix comment in check_uniform_grid
ogrisel Jul 6, 2017
0509d12
ENH simplify example usage in TSNE docstring
ogrisel Jul 7, 2017
e802689
CLN max_width->squared_max_width to improve code readability
tomMoral Jul 7, 2017
c0cc2f2
FIX duplicated point in quad_tree summary
tomMoral Jul 9, 2017
0737deb
TST add summary test for quad_tree
tomMoral Jul 9, 2017
39f83de
TST fix t-SNE tests for uniform grid
tomMoral Jul 9, 2017
e242881
CLN unified TSNE stopping criterion + comments
tomMoral Jul 10, 2017
b7629cc
FIX pep8+typo/ more log for tests failures
tomMoral Jul 10, 2017
73cb421
FIX quad_tree boundary computations
tomMoral Jul 11, 2017
ae12767
ENH use dbl tsne_bh sum_Q to match exact solver
tomMoral Jul 11, 2017
8a9c9e7
CLN param description+Add private optim controls
tomMoral Jul 11, 2017
ec80a29
FIX TSNE set n_iter_without_progress=300 by default
ogrisel Jul 11, 2017
2b26219
FIX bench_tsne_mnist.py: no n_jobs for now
ogrisel Jul 11, 2017
8e0e798
FIX various optimization schedule issues in TSNE
ogrisel Jul 11, 2017
4a71ecb
TST fix broken test in tsne
ogrisel Jul 11, 2017
9024112
FIX debug code in bench script [ci skip]
ogrisel Jul 11, 2017
a274ab6
TST compare KL error of BH approx vs exact with angle=0
ogrisel Jul 11, 2017
5b36cc7
CLN verbose log more informative
tomMoral Jul 12, 2017
f883519
CLN left out debug print
ogrisel Jul 12, 2017
e9d5c92
FIX numerical precision in Barnes Hut error computation
ogrisel Jul 12, 2017
4 changes: 4 additions & 0 deletions benchmarks/.gitignore
@@ -0,0 +1,4 @@
/bhtsne
*.npy
*.json
/mnist_tsne_output/
169 changes: 169 additions & 0 deletions benchmarks/bench_tsne_mnist.py
@@ -0,0 +1,169 @@
"""
=============================
MNIST dataset T-SNE benchmark
=============================

"""
from __future__ import division, print_function

# License: BSD 3 clause

import os
import os.path as op
from time import time
import numpy as np
import json
import argparse

from sklearn.externals.joblib import Memory
from sklearn.datasets import fetch_mldata
from sklearn.manifold import TSNE
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from sklearn.utils import check_array
from sklearn.utils import shuffle as _shuffle


LOG_DIR = "mnist_tsne_output"
if not os.path.exists(LOG_DIR):
    os.mkdir(LOG_DIR)


memory = Memory(os.path.join(LOG_DIR, 'mnist_tsne_benchmark_data'),
                mmap_mode='r')


@memory.cache
def load_data(dtype=np.float32, order='C', shuffle=True, seed=0):
    """Load the data, then cache and memmap the train/test split"""
    print("Loading dataset...")
    data = fetch_mldata('MNIST original')

    X = check_array(data['data'], dtype=dtype, order=order)
    y = data["target"]

    if shuffle:
        X, y = _shuffle(X, y, random_state=seed)

    # Normalize features
    X /= 255
    return X, y


def nn_accuracy(X, X_embedded, k=1):
    """Accuracy of the first nearest neighbor"""
    # Use the `k` argument instead of a hard-coded value (default is still 1)
    knn = NearestNeighbors(n_neighbors=k, n_jobs=-1)
    _, neighbors_X = knn.fit(X).kneighbors()
    _, neighbors_X_embedded = knn.fit(X_embedded).kneighbors()
    return np.mean(neighbors_X == neighbors_X_embedded)


def tsne_fit_transform(model, data):
    transformed = model.fit_transform(data)
    return transformed, model.n_iter_


def sanitize(filename):
    return filename.replace("/", '-').replace(" ", "_")


if __name__ == "__main__":
    parser = argparse.ArgumentParser('Benchmark for t-SNE')
    parser.add_argument('--order', type=str, default='C',
                        help='Order of the input data')
    parser.add_argument('--perplexity', type=float, default=30)
    parser.add_argument('--bhtsne', action='store_true',
                        help="if set and the reference bhtsne code is "
                        "correctly installed, run it in the benchmark.")
    parser.add_argument('--all', action='store_true',
                        help="if set, run the benchmark with the whole MNIST "
                             "dataset. Note that it will take up to 1 hour.")
    parser.add_argument('--profile', action='store_true',
                        help="if set, run the benchmark with a memory "
                             "profiler.")
    parser.add_argument('--verbose', type=int, default=0)
    parser.add_argument('--pca-components', type=int, default=50,
                        help="Number of principal components for "
                             "preprocessing.")
    args = parser.parse_args()

    X, y = load_data(order=args.order)

    if args.pca_components > 0:
        t0 = time()
        X = PCA(n_components=args.pca_components).fit_transform(X)
        print("PCA preprocessing down to {} dimensions took {:0.3f}s"
              .format(args.pca_components, time() - t0))

    methods = []

    # Put TSNE in methods
    tsne = TSNE(n_components=2, init='pca', perplexity=args.perplexity,
                verbose=args.verbose, n_iter=1000)
    methods.append(("sklearn TSNE",
                    lambda data: tsne_fit_transform(tsne, data)))

    if args.bhtsne:
        try:
            from bhtsne.bhtsne import run_bh_tsne
        except ImportError:
            raise ImportError("""\
If you want comparison with the reference implementation, build the
binary from source (https://github.com/lvdmaaten/bhtsne) in the folder
benchmarks/bhtsne and add an empty `__init__.py` file in the folder:

$ git clone git@github.com:lvdmaaten/bhtsne.git
$ cd bhtsne
$ g++ sptree.cpp tsne.cpp tsne_main.cpp -o bh_tsne -O2
$ touch __init__.py
$ cd ..
""")

        def bhtsne(X):
            """Wrapper for the reference lvdmaaten/bhtsne implementation."""
            # PCA preprocessing is done elsewhere in the benchmark script
            n_iter = -1  # TODO find a way to report the number of iterations
            return run_bh_tsne(X, use_pca=False, perplexity=args.perplexity,
                               verbose=args.verbose > 0), n_iter
        methods.append(("lvdmaaten/bhtsne", bhtsne))

    if args.profile:

        try:
            from memory_profiler import profile
        except ImportError:
            raise ImportError("To run the benchmark with `--profile`, you "
                              "need to install `memory_profiler`. Please "
                              "run `pip install memory_profiler`.")
        methods = [(n, profile(m)) for n, m in methods]

    data_size = [100, 500, 1000, 5000, 10000]
    if args.all:
        data_size.append(70000)

    results = []
    basename, _ = os.path.splitext(__file__)
    log_filename = os.path.join(LOG_DIR, basename + '.json')
    for n in data_size:
        X_train = X[:n]
        y_train = y[:n]
        n = X_train.shape[0]
        for name, method in methods:
            print("Fitting {} on {} samples...".format(name, n))
            t0 = time()
            np.save(os.path.join(LOG_DIR, 'mnist_{}_{}.npy'
                                 .format('original', n)), X_train)
            np.save(os.path.join(LOG_DIR, 'mnist_{}_{}.npy'
                                 .format('original_labels', n)), y_train)
            X_embedded, n_iter = method(X_train)
            duration = time() - t0
            precision_5 = nn_accuracy(X_train, X_embedded)
            print("Fitting {} on {} samples took {:.3f}s in {:d} iterations, "
                  "nn accuracy: {:0.3f}".format(
                      name, n, duration, n_iter, precision_5))
            results.append(dict(method=name, duration=duration, n_samples=n))
            with open(log_filename, 'w', encoding='utf-8') as f:
                json.dump(results, f)
            method_name = sanitize(name)
            np.save(op.join(LOG_DIR, 'mnist_{}_{}.npy'.format(method_name, n)),
                    X_embedded)
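The `nn_accuracy` helper in the benchmark above scores an embedding by checking whether each point keeps the same first nearest neighbor before and after the embedding. A minimal brute-force sketch of that idea follows, using NumPy only. The names `first_neighbor` and `first_nn_accuracy` are illustrative and not part of this PR, and the quadratic memory cost makes this suitable for small inputs only.

```python
import numpy as np


def first_neighbor(points):
    """Index of each point's nearest neighbor, excluding the point itself."""
    points = np.asarray(points, dtype=float)
    # Pairwise squared Euclidean distances, shape (n, n).
    diff = points[:, None, :] - points[None, :, :]
    dist = np.einsum('ijk,ijk->ij', diff, diff)
    # A point must not be reported as its own neighbor.
    np.fill_diagonal(dist, np.inf)
    return dist.argmin(axis=1)


def first_nn_accuracy(X, X_embedded):
    """Fraction of points whose first neighbor is preserved by the embedding."""
    return float(np.mean(first_neighbor(X) == first_neighbor(X_embedded)))


# Toy check: a uniform rescaling preserves every neighbor relation.
rng = np.random.RandomState(0)
X = rng.rand(50, 5)
accuracy_scaled = first_nn_accuracy(X, 2.0 * X)
```

A perfect score here only says that local neighborhoods are preserved; it says nothing about the global layout of the embedding.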
30 changes: 30 additions & 0 deletions benchmarks/plot_tsne_mnist.py
@@ -0,0 +1,30 @@
import matplotlib.pyplot as plt
import numpy as np
import os.path as op

import argparse


LOG_DIR = "mnist_tsne_output"


if __name__ == "__main__":
    parser = argparse.ArgumentParser('Plot benchmark results for t-SNE')
    parser.add_argument(
        '--labels', type=str,
        default=op.join(LOG_DIR, 'mnist_original_labels_10000.npy'),
        help='1D integer numpy array for labels')
    parser.add_argument(
        '--embedding', type=str,
        default=op.join(LOG_DIR, 'mnist_sklearn_TSNE_10000.npy'),
        help='2D float numpy array for embedded data')
    args = parser.parse_args()

    X = np.load(args.embedding)
    y = np.load(args.labels)

    for i in np.unique(y):
        mask = y == i
        plt.scatter(X[mask, 0], X[mask, 1], alpha=0.2, label=int(i))
    plt.legend(loc='best')
    plt.show()
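The default `--embedding` path above relies on the naming scheme used by `bench_tsne_mnist.py`, which passes each method name through `sanitize` before saving the embedding. A short sketch of that mapping follows; the helper `embedding_filename` is hypothetical and introduced only for illustration.

```python
import os.path as op

LOG_DIR = "mnist_tsne_output"


def sanitize(filename):
    # Same replacement rule as in the benchmark script.
    return filename.replace("/", '-').replace(" ", "_")


def embedding_filename(method_name, n_samples):
    """Hypothetical helper: path under which the benchmark saves an embedding."""
    return op.join(LOG_DIR, 'mnist_{}_{}.npy'.format(sanitize(method_name),
                                                     n_samples))


sklearn_path = embedding_filename("sklearn TSNE", 10000)
bhtsne_path = embedding_filename("lvdmaaten/bhtsne", 10000)
```

Passing the second path via `--embedding` would plot the output of the reference implementation instead, provided the benchmark was run with `--bhtsne`.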
17 changes: 17 additions & 0 deletions doc/whats_new.rst
@@ -19,6 +19,7 @@ occurs due to changes in the modelling logic (bug fixes or enhancements), or in
random sampling procedures.

* :class:`sklearn.ensemble.IsolationForest` (bug fix)
* :class:`sklearn.manifold.TSNE` (bug fix)

Details are listed in the changelog below.

@@ -245,6 +246,14 @@ Enhancements
- Speed improvements to :class:`model_selection.StratifiedShuffleSplit`.
:issue:`5991` by :user:`Arthur Mensch <arthurmensch>` and `Joel Nothman`_.

- Memory improvements for the ``barnes_hut`` method in :class:`manifold.TSNE`.
  :issue:`7089` by :user:`Thomas Moreau <tomMoral>` and `Olivier Grisel`_.

- Optimization schedule improvements for :class:`manifold.TSNE` so that the
  results are closer to those of the reference implementation
  `lvdmaaten/bhtsne <https://github.com/lvdmaaten/bhtsne>`_ by
  :user:`Thomas Moreau <tomMoral>` and `Olivier Grisel`_.

Bug fixes
.........

@@ -478,6 +487,14 @@ Bug fixes
and :class:`linear_model.Ridge` when using ``normalize=True``
by `Alexandre Gramfort`_.

- Fixed the implementation of :class:`manifold.TSNE`:
  - ``early_exaggeration`` parameter had no effect and is now used for the
    first 250 optimization iterations.
  - Fixed the ``InsertionError`` reported in :issue:`8992`.
  - Improved the learning schedule to match the one from the reference
    implementation `lvdmaaten/bhtsne <https://github.com/lvdmaaten/bhtsne>`_.
  by :user:`Thomas Moreau <tomMoral>` and `Olivier Grisel`_.

API changes summary
-------------------

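The bug-fix entry above states that `early_exaggeration` now applies during the first 250 optimization iterations. A minimal sketch of such a two-phase schedule follows. It is not the scikit-learn code: the function name, the constant name, and the value 12.0 are assumptions chosen for illustration.

```python
EXPLORATION_N_ITER = 250  # iterations with exaggeration, per the changelog entry


def exaggeration_factor(iteration, early_exaggeration=12.0):
    """Multiplier applied to the input affinities P at a given iteration.

    Assumption for illustration: P is scaled by `early_exaggeration` during
    the first phase and left unscaled (factor 1.0) afterwards.
    """
    if iteration < EXPLORATION_N_ITER:
        return early_exaggeration
    return 1.0


schedule = [exaggeration_factor(i) for i in (0, 249, 250, 999)]
```

Before this fix the parameter had no effect, which under this sketch would correspond to the factor being 1.0 at every iteration.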
44 changes: 38 additions & 6 deletions examples/manifold/plot_t_sne_perplexity.py
@@ -14,7 +14,7 @@
As shown below, t-SNE for higher perplexities finds meaningful topology of
two concentric circles, however the size and the distance of the circles varies
slightly from the original. Contrary to the two circles dataset, the shapes
visually diverge from S-curve topology on the S-curve dateset even for
visually diverge from S-curve topology on the S-curve dataset even for
larger perplexity values.

For further details, "How to Use t-SNE Effectively"
@@ -28,16 +28,17 @@

print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from matplotlib.ticker import NullFormatter
from sklearn import manifold, datasets
from time import time

n_samples = 500
n_samples = 300
n_components = 2
(fig, subplots) = plt.subplots(2, 5, figsize=(15, 8))
perplexities = [5, 50, 100, 150]
(fig, subplots) = plt.subplots(3, 5, figsize=(15, 8))
perplexities = [5, 30, 50, 100]

X, y = datasets.make_circles(n_samples=n_samples, factor=.5, noise=.05)

@@ -71,7 +72,7 @@
X, color = datasets.samples_generator.make_s_curve(n_samples, random_state=0)

ax = subplots[1][0]
ax.scatter(X[:, 0], X[:, 2], c=color, cmap=plt.cm.Spectral)
ax.scatter(X[:, 0], X[:, 2], c=color, cmap=plt.cm.viridis)
ax.xaxis.set_major_formatter(NullFormatter())
ax.yaxis.set_major_formatter(NullFormatter())

@@ -86,9 +87,40 @@
    print("S-curve, perplexity=%d in %.2g sec" % (perplexity, t1 - t0))

    ax.set_title("Perplexity=%d" % perplexity)
    ax.scatter(Y[:, 0], Y[:, 1], c=color, cmap=plt.cm.Spectral)
    ax.scatter(Y[:, 0], Y[:, 1], c=color, cmap=plt.cm.viridis)
    ax.xaxis.set_major_formatter(NullFormatter())
    ax.yaxis.set_major_formatter(NullFormatter())
    ax.axis('tight')


# Another example using a 2D uniform grid
x = np.linspace(0, 1, int(np.sqrt(n_samples)))
xx, yy = np.meshgrid(x, x)
X = np.hstack([
    xx.ravel().reshape(-1, 1),
    yy.ravel().reshape(-1, 1),
])
color = xx.ravel()
ax = subplots[2][0]
ax.scatter(X[:, 0], X[:, 1], c=color, cmap=plt.cm.viridis)
ax.xaxis.set_major_formatter(NullFormatter())
ax.yaxis.set_major_formatter(NullFormatter())

for i, perplexity in enumerate(perplexities):
    ax = subplots[2][i + 1]

    t0 = time()
    tsne = manifold.TSNE(n_components=n_components, init='random',
                         random_state=0, perplexity=perplexity)
    Y = tsne.fit_transform(X)
    t1 = time()
    print("uniform grid, perplexity=%d in %.2g sec" % (perplexity, t1 - t0))

    ax.set_title("Perplexity=%d" % perplexity)
    ax.scatter(Y[:, 0], Y[:, 1], c=color, cmap=plt.cm.viridis)
    ax.xaxis.set_major_formatter(NullFormatter())
    ax.yaxis.set_major_formatter(NullFormatter())
    ax.axis('tight')


plt.show()
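Note that the uniform grid built in the example does not contain exactly `n_samples` points: the side length is `int(np.sqrt(n_samples))`, so with `n_samples = 300` the grid has 17 x 17 = 289 points. A small sketch reproducing the construction follows; `np.column_stack` is used here as an equivalent alternative to the `hstack`/`reshape` combination, and the variable names `side`, `X_grid` and `X_hstack` are illustrative.

```python
import numpy as np

n_samples = 300
side = int(np.sqrt(n_samples))  # truncates sqrt(300) ~ 17.32 down to 17

x = np.linspace(0, 1, side)
xx, yy = np.meshgrid(x, x)

# Two equivalent ways to build the (side * side, 2) array of grid points.
X_grid = np.column_stack([xx.ravel(), yy.ravel()])
X_hstack = np.hstack([xx.ravel().reshape(-1, 1), yy.ravel().reshape(-1, 1)])
```

This matters when reading the perplexity panels: the largest value in `perplexities` (100) is compared against 289 points rather than 300.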