From 40f551dcafe76e1a106bd617de0ac96f355eab99 Mon Sep 17 00:00:00 2001 From: ArturoAmorQ Date: Tue, 24 Jan 2023 14:55:40 +0100 Subject: [PATCH 01/13] DOC Add pynndescent to Approximate nearest neighbors in TSNE example --- .../approximate_nearest_neighbors.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/examples/neighbors/approximate_nearest_neighbors.py b/examples/neighbors/approximate_nearest_neighbors.py index 479e324cd6aa4..22a3d653cf8bf 100644 --- a/examples/neighbors/approximate_nearest_neighbors.py +++ b/examples/neighbors/approximate_nearest_neighbors.py @@ -4,9 +4,9 @@ ===================================== This example presents how to chain KNeighborsTransformer and TSNE in a pipeline. -It also shows how to wrap the packages `annoy` and `nmslib` to replace -KNeighborsTransformer and perform approximate nearest neighbors. These packages -can be installed with `pip install annoy nmslib`. +It also shows how to wrap the packages `annoy`, `nmslib` and `pynndescent` to +replace KNeighborsTransformer and perform approximate nearest neighbors. These +packages can be installed with `pip install annoy nmslib pynndescent`. Note: In KNeighborsTransformer we use the definition which includes each training point as its own neighbor in the count of `n_neighbors`, and for @@ -20,9 +20,11 @@ AnnoyTransformer: 0.305 sec NMSlibTransformer: 0.144 sec KNeighborsTransformer: 0.090 sec + PyNNDescentTransformer: 23.402 sec TSNE with AnnoyTransformer: 2.818 sec TSNE with NMSlibTransformer: 2.592 sec TSNE with KNeighborsTransformer: 2.338 sec + TSNE with PyNNDescentTransformer: 6.288 sec TSNE with internal NearestNeighbors: 2.364 sec Benchmarking on MNIST_10000: @@ -30,9 +32,11 @@ AnnoyTransformer: 2.874 sec NMSlibTransformer: 1.098 sec KNeighborsTransformer: 1.264 sec + PyNNDescentTransformer: 7.170 sec TSNE with AnnoyTransformer: 16.118 sec TSNE with NMSlibTransformer: 15.281 sec TSNE with KNeighborsTransformer: 15.400 sec + TSNE with PyNNDescentTransformer: 28.782 sec TSNE with internal NearestNeighbors: 15.573 sec @@ -65,6 +69,7 @@ import matplotlib.pyplot as plt from matplotlib.ticker import NullFormatter from scipy.sparse import csr_matrix +from pynndescent import PyNNDescentTransformer from sklearn.base import BaseEstimator, TransformerMixin from sklearn.neighbors import KNeighborsTransformer @@ -237,6 +242,10 @@ def run_benchmark(): n_neighbors=n_neighbors, mode="distance", metric=metric ), ), + ( + "PyNNDescentTransformer", + PyNNDescentTransformer(n_neighbors=n_neighbors, metric=metric), + ), ( "TSNE with AnnoyTransformer", make_pipeline( @@ -260,6 +269,13 @@ def run_benchmark(): TSNE(metric="precomputed", **tsne_params), ), ), + ( + "TSNE with PyNNDescentTransformer", + make_pipeline( + PyNNDescentTransformer(n_neighbors=n_neighbors, metric=metric), + TSNE(metric="precomputed", **tsne_params), + ), + ), ("TSNE with internal NearestNeighbors", TSNE(metric=metric, **tsne_params)), ] From 5e83ad4113e9faa3759c5c8939580f8d712440ac Mon Sep 17 00:00:00 2001 From: ArturoAmorQ Date: Tue, 24 Jan 2023 15:10:02 +0100 Subject: [PATCH 02/13] Add ImportError message if pynndescent is missing --- examples/neighbors/approximate_nearest_neighbors.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/neighbors/approximate_nearest_neighbors.py b/examples/neighbors/approximate_nearest_neighbors.py index 22a3d653cf8bf..8e5f0aa7a3ed9 100644 --- a/examples/neighbors/approximate_nearest_neighbors.py +++ b/examples/neighbors/approximate_nearest_neighbors.py @@ -65,11 +65,16 @@ print("The package 'nmslib' is required to run this example.") sys.exit() +try: + from pynndescent import PyNNDescentTransformer +except ImportError: + print("The package 'pynndescent' is required to run this example.") + sys.exit() + import numpy as np import matplotlib.pyplot as plt from matplotlib.ticker import NullFormatter from scipy.sparse import csr_matrix -from pynndescent import PyNNDescentTransformer from sklearn.base import BaseEstimator, TransformerMixin from sklearn.neighbors import KNeighborsTransformer From 4488c33d847495aa40770ce50bf1d9cfe3644847 Mon Sep 17 00:00:00 2001 From: ArturoAmorQ Date: Tue, 24 Jan 2023 17:57:29 +0100 Subject: [PATCH 03/13] Remove AnnoyTransformer from benchmark --- .../approximate_nearest_neighbors.py | 87 +------------------ 1 file changed, 4 insertions(+), 83 deletions(-) diff --git a/examples/neighbors/approximate_nearest_neighbors.py b/examples/neighbors/approximate_nearest_neighbors.py index 8e5f0aa7a3ed9..89b104e85b016 100644 --- a/examples/neighbors/approximate_nearest_neighbors.py +++ b/examples/neighbors/approximate_nearest_neighbors.py @@ -4,9 +4,9 @@ ===================================== This example presents how to chain KNeighborsTransformer and TSNE in a pipeline. -It also shows how to wrap the packages `annoy`, `nmslib` and `pynndescent` to -replace KNeighborsTransformer and perform approximate nearest neighbors. These -packages can be installed with `pip install annoy nmslib pynndescent`. +It also shows how to wrap the packages `nmslib` and `pynndescent` to replace +KNeighborsTransformer and perform approximate nearest neighbors. These packages +can be installed with `pip install nmslib pynndescent`. Note: In KNeighborsTransformer we use the definition which includes each training point as its own neighbor in the count of `n_neighbors`, and for @@ -17,11 +17,9 @@ Benchmarking on MNIST_2000: --------------------------- - AnnoyTransformer: 0.305 sec NMSlibTransformer: 0.144 sec KNeighborsTransformer: 0.090 sec PyNNDescentTransformer: 23.402 sec - TSNE with AnnoyTransformer: 2.818 sec TSNE with NMSlibTransformer: 2.592 sec TSNE with KNeighborsTransformer: 2.338 sec TSNE with PyNNDescentTransformer: 6.288 sec @@ -29,11 +27,9 @@ Benchmarking on MNIST_10000: ---------------------------- - AnnoyTransformer: 2.874 sec NMSlibTransformer: 1.098 sec KNeighborsTransformer: 1.264 sec PyNNDescentTransformer: 7.170 sec - TSNE with AnnoyTransformer: 16.118 sec TSNE with NMSlibTransformer: 15.281 sec TSNE with KNeighborsTransformer: 15.400 sec TSNE with PyNNDescentTransformer: 28.782 sec @@ -53,12 +49,6 @@ import time import sys -try: - import annoy -except ImportError: - print("The package 'annoy' is required to run this example.") - sys.exit() - try: import nmslib except ImportError: @@ -131,63 +121,6 @@ def transform(self, X): return kneighbors_graph -class AnnoyTransformer(TransformerMixin, BaseEstimator): - """Wrapper for using annoy.AnnoyIndex as sklearn's KNeighborsTransformer""" - - def __init__(self, n_neighbors=5, metric="euclidean", n_trees=10, search_k=-1): - self.n_neighbors = n_neighbors - self.n_trees = n_trees - self.search_k = search_k - self.metric = metric - - def fit(self, X): - self.n_samples_fit_ = X.shape[0] - self.annoy_ = annoy.AnnoyIndex(X.shape[1], metric=self.metric) - for i, x in enumerate(X): - self.annoy_.add_item(i, x.tolist()) - self.annoy_.build(self.n_trees) - return self - - def transform(self, X): - return self._transform(X) - - def fit_transform(self, X, y=None): - return self.fit(X)._transform(X=None) - - def _transform(self, X): - """As `transform`, but handles X is None for faster `fit_transform`.""" - - n_samples_transform = self.n_samples_fit_ if X is None else X.shape[0] - - # For compatibility reasons, as each sample is considered as its own - # neighbor, one extra neighbor will be computed. - n_neighbors = self.n_neighbors + 1 - - indices = np.empty((n_samples_transform, n_neighbors), dtype=int) - distances = np.empty((n_samples_transform, n_neighbors)) - - if X is None: - for i in range(self.annoy_.get_n_items()): - ind, dist = self.annoy_.get_nns_by_item( - i, n_neighbors, self.search_k, include_distances=True - ) - - indices[i], distances[i] = ind, dist - else: - for i, x in enumerate(X): - indices[i], distances[i] = self.annoy_.get_nns_by_vector( - x.tolist(), n_neighbors, self.search_k, include_distances=True - ) - - indptr = np.arange(0, n_samples_transform * n_neighbors + 1, n_neighbors) - kneighbors_graph = csr_matrix( - (distances.ravel(), indices.ravel(), indptr), - shape=(n_samples_transform, self.n_samples_fit_), - ) - - return kneighbors_graph - - def test_transformers(): """Test that AnnoyTransformer and KNeighborsTransformer give same results""" X = np.random.RandomState(42).randn(10, 2) @@ -195,14 +128,10 @@ def test_transformers(): knn = KNeighborsTransformer() Xt0 = knn.fit_transform(X) - ann = AnnoyTransformer() - Xt1 = ann.fit_transform(X) - nms = NMSlibTransformer() - Xt2 = nms.fit_transform(X) + Xt1 = nms.fit_transform(X) assert_array_almost_equal(Xt0.toarray(), Xt1.toarray(), decimal=5) - assert_array_almost_equal(Xt0.toarray(), Xt2.toarray(), decimal=5) def load_mnist(n_samples): @@ -236,7 +165,6 @@ def run_benchmark(): ) transformers = [ - ("AnnoyTransformer", AnnoyTransformer(n_neighbors=n_neighbors, metric=metric)), ( "NMSlibTransformer", NMSlibTransformer(n_neighbors=n_neighbors, metric=metric), @@ -251,13 +179,6 @@ def run_benchmark(): "PyNNDescentTransformer", PyNNDescentTransformer(n_neighbors=n_neighbors, metric=metric), ), - ( - "TSNE with AnnoyTransformer", - make_pipeline( - AnnoyTransformer(n_neighbors=n_neighbors, metric=metric), - TSNE(metric="precomputed", **tsne_params), - ), - ), ( "TSNE with NMSlibTransformer", make_pipeline( From 62c0932696dd0ab6ef7aaf667cc0c32453172ed3 Mon Sep 17 00:00:00 2001 From: ArturoAmorQ Date: Wed, 25 Jan 2023 15:07:12 +0100 Subject: [PATCH 04/13] Remove test to assert almost exact results --- .../neighbors/approximate_nearest_neighbors.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/examples/neighbors/approximate_nearest_neighbors.py b/examples/neighbors/approximate_nearest_neighbors.py index 89b104e85b016..bc219c472cf12 100644 --- a/examples/neighbors/approximate_nearest_neighbors.py +++ b/examples/neighbors/approximate_nearest_neighbors.py @@ -68,7 +68,6 @@ from sklearn.base import BaseEstimator, TransformerMixin from sklearn.neighbors import KNeighborsTransformer -from sklearn.utils._testing import assert_array_almost_equal from sklearn.datasets import fetch_openml from sklearn.pipeline import make_pipeline from sklearn.manifold import TSNE @@ -121,19 +120,6 @@ def transform(self, X): return kneighbors_graph -def test_transformers(): - """Test that AnnoyTransformer and KNeighborsTransformer give same results""" - X = np.random.RandomState(42).randn(10, 2) - - knn = KNeighborsTransformer() - Xt0 = knn.fit_transform(X) - - nms = NMSlibTransformer() - Xt1 = nms.fit_transform(X) - - assert_array_almost_equal(Xt0.toarray(), Xt1.toarray(), decimal=5) - - def load_mnist(n_samples): """Load MNIST, shuffle the data, and return only n_samples.""" mnist = fetch_openml("mnist_784", as_frame=False, parser="pandas") @@ -249,5 +235,4 @@ def run_benchmark(): if __name__ == "__main__": - test_transformers() run_benchmark() From 7397aed67c21f5f7de0b4dad971378cbe248ec9d Mon Sep 17 00:00:00 2001 From: ArturoAmorQ Date: Wed, 25 Jan 2023 15:10:58 +0100 Subject: [PATCH 05/13] Implement joblib logic and use as many threads as CPUs by default --- .../neighbors/approximate_nearest_neighbors.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/examples/neighbors/approximate_nearest_neighbors.py b/examples/neighbors/approximate_nearest_neighbors.py index bc219c472cf12..49a54194c6f5b 100644 --- a/examples/neighbors/approximate_nearest_neighbors.py +++ b/examples/neighbors/approximate_nearest_neighbors.py @@ -46,6 +46,7 @@ # Author: Tom Dupre la Tour # # License: BSD 3 clause +import joblib import time import sys @@ -77,7 +78,7 @@ class NMSlibTransformer(TransformerMixin, BaseEstimator): """Wrapper for using nmslib as sklearn's KNeighborsTransformer""" - def __init__(self, n_neighbors=5, metric="euclidean", method="sw-graph", n_jobs=1): + def __init__(self, n_neighbors=5, metric="euclidean", method="sw-graph", n_jobs=-1): self.n_neighbors = n_neighbors self.method = method self.metric = metric @@ -96,7 +97,7 @@ def fit(self, X): }[self.metric] self.nmslib_ = nmslib.init(method=self.method, space=space) - self.nmslib_.addDataPointBatch(X) + self.nmslib_.addDataPointBatch(X.copy()) self.nmslib_.createIndex() return self @@ -107,7 +108,16 @@ def transform(self, X): # neighbor, one extra neighbor will be computed. n_neighbors = self.n_neighbors + 1 - results = self.nmslib_.knnQueryBatch(X, k=n_neighbors, num_threads=self.n_jobs) + if self.n_jobs < 0: + # Same handling as done in joblib for negative values of n_jobs: + # in particular, `n_jobs == -1` means "as many threads as CPUs". + num_threads = joblib.cpu_count() + self.n_jobs + 1 + else: + num_threads = self.n_jobs + + results = self.nmslib_.knnQueryBatch( + X.copy(), k=n_neighbors, num_threads=num_threads + ) indices, distances = zip(*results) indices, distances = np.vstack(indices), np.vstack(distances) From ca420a6e9ba3903a5e2936306d321eeefdaf8771 Mon Sep 17 00:00:00 2001 From: ArturoAmorQ Date: Wed, 25 Jan 2023 17:46:21 +0100 Subject: [PATCH 06/13] Implement notebook style and tutorialization --- .../approximate_nearest_neighbors.py | 338 +++++++++++------- 1 file changed, 199 insertions(+), 139 deletions(-) diff --git a/examples/neighbors/approximate_nearest_neighbors.py b/examples/neighbors/approximate_nearest_neighbors.py index 49a54194c6f5b..d6a5747f4c4e1 100644 --- a/examples/neighbors/approximate_nearest_neighbors.py +++ b/examples/neighbors/approximate_nearest_neighbors.py @@ -11,43 +11,15 @@ Note: In KNeighborsTransformer we use the definition which includes each training point as its own neighbor in the count of `n_neighbors`, and for compatibility reasons, one extra neighbor is computed when `mode == 'distance'`. -Please note that we do the same in the proposed wrappers. - -Sample output:: - - Benchmarking on MNIST_2000: - --------------------------- - NMSlibTransformer: 0.144 sec - KNeighborsTransformer: 0.090 sec - PyNNDescentTransformer: 23.402 sec - TSNE with NMSlibTransformer: 2.592 sec - TSNE with KNeighborsTransformer: 2.338 sec - TSNE with PyNNDescentTransformer: 6.288 sec - TSNE with internal NearestNeighbors: 2.364 sec - - Benchmarking on MNIST_10000: - ---------------------------- - NMSlibTransformer: 1.098 sec - KNeighborsTransformer: 1.264 sec - PyNNDescentTransformer: 7.170 sec - TSNE with NMSlibTransformer: 15.281 sec - TSNE with KNeighborsTransformer: 15.400 sec - TSNE with PyNNDescentTransformer: 28.782 sec - TSNE with internal NearestNeighbors: 15.573 sec - - -Note that the prediction speed KNeighborsTransformer was optimized in -scikit-learn 1.1 and therefore approximate methods are not necessarily faster -because computing the index takes time and can nullify the gains obtained at -prediction time. - +Please note that we do the same in the proposed `nmslib` wrapper. """ # Author: Tom Dupre la Tour -# # License: BSD 3 clause -import joblib -import time + +# %% +# First we try to import the packages and warn the user in case they are +# missing. import sys try: @@ -62,16 +34,14 @@ print("The package 'pynndescent' is required to run this example.") sys.exit() +# %% +# We define a wrapper class for implementing the scikit-learn API to the +# `nmslib`, as well as a loading function. +import joblib import numpy as np -import matplotlib.pyplot as plt -from matplotlib.ticker import NullFormatter from scipy.sparse import csr_matrix - from sklearn.base import BaseEstimator, TransformerMixin -from sklearn.neighbors import KNeighborsTransformer from sklearn.datasets import fetch_openml -from sklearn.pipeline import make_pipeline -from sklearn.manifold import TSNE from sklearn.utils import shuffle @@ -137,112 +107,202 @@ def load_mnist(n_samples): return X[:n_samples] / 255, y[:n_samples] -def run_benchmark(): - datasets = [ - ("MNIST_2000", load_mnist(n_samples=2000)), - ("MNIST_10000", load_mnist(n_samples=10000)), - ] - - n_iter = 500 - perplexity = 30 - metric = "euclidean" - # TSNE requires a certain number of neighbors which depends on the - # perplexity parameter. - # Add one since we include each sample as its own neighbor. - n_neighbors = int(3.0 * perplexity + 1) + 1 - - tsne_params = dict( - init="random", # pca not supported for sparse matrices - perplexity=perplexity, - method="barnes_hut", - random_state=42, - n_iter=n_iter, - learning_rate="auto", - ) - - transformers = [ - ( - "NMSlibTransformer", - NMSlibTransformer(n_neighbors=n_neighbors, metric=metric), +# %% +# We benchmark the different exact/approximate nearest neighbors transformers. +import time + +from sklearn.manifold import TSNE +from sklearn.neighbors import KNeighborsTransformer +from sklearn.pipeline import make_pipeline + +datasets = [ + ("MNIST_10000", load_mnist(n_samples=2000)), + ("MNIST_20000", load_mnist(n_samples=20000)), +] + +n_iter = 500 +perplexity = 30 +metric = "euclidean" +# TSNE requires a certain number of neighbors which depends on the +# perplexity parameter. +# Add one since we include each sample as its own neighbor. +n_neighbors = int(3.0 * perplexity + 1) + 1 + +tsne_params = dict( + init="random", # pca not supported for sparse matrices + perplexity=perplexity, + method="barnes_hut", + random_state=42, + n_iter=n_iter, + learning_rate="auto", +) + +transformers = [ + ( + "KNeighborsTransformer", + KNeighborsTransformer(n_neighbors=n_neighbors, mode="distance", metric=metric), + ), + ( + "NMSlibTransformer", + NMSlibTransformer(n_neighbors=n_neighbors, metric=metric), + ), + ( + "PyNNDescentTransformer", + PyNNDescentTransformer( + n_neighbors=n_neighbors, metric=metric, parallel_batch_queries=True ), - ( - "KNeighborsTransformer", + ), +] + +for dataset_name, (X, y) in datasets: + + msg = "Benchmarking on %s:" % dataset_name + print("\n%s\n%s" % (msg, "-" * len(msg))) + + for transformer_name, transformer in transformers: + longest = np.max([len(name) for name, model in transformers]) + whitespaces = " " * (longest - len(transformer_name)) + for _ in range(2): + start = time.time() + transformer.fit(X) + fit_duration = time.time() - start + print( + "%s: %s%.3f sec (fit)" % (transformer_name, whitespaces, fit_duration) + ) + for _ in range(2): + start = time.time() + Xt = transformer.transform(X) + transform_duration = time.time() - start + print( + "%s: %s%.3f sec (transform)" + % (transformer_name, whitespaces, transform_duration) + ) + +# %% +# Sample output:: +# +# Benchmarking on MNIST_10000: +# ---------------------------- +# KNeighborsTransformer: 0.005 sec (fit) +# KNeighborsTransformer: 0.004 sec (fit) +# KNeighborsTransformer: 1.285 sec (transform) +# KNeighborsTransformer: 1.162 sec (transform) +# NMSlibTransformer: 0.226 sec (fit) +# NMSlibTransformer: 0.235 sec (fit) +# NMSlibTransformer: 0.323 sec (transform) +# NMSlibTransformer: 0.295 sec (transform) +# PyNNDescentTransformer: 18.129 sec (fit) +# PyNNDescentTransformer: 4.584 sec (fit) +# PyNNDescentTransformer: 15.092 sec (transform) +# PyNNDescentTransformer: 0.862 sec (transform) +# +# Benchmarking on MNIST_20000: +# ---------------------------- +# KNeighborsTransformer: 0.010 sec (fit) +# KNeighborsTransformer: 0.010 sec (fit) +# KNeighborsTransformer: 6.992 sec (transform) +# KNeighborsTransformer: 6.951 sec (transform) +# NMSlibTransformer: 0.777 sec (fit) +# NMSlibTransformer: 0.788 sec (fit) +# NMSlibTransformer: 0.796 sec (transform) +# NMSlibTransformer: 0.740 sec (transform) +# PyNNDescentTransformer: 13.609 sec (fit) +# PyNNDescentTransformer: 13.359 sec (fit) +# PyNNDescentTransformer: 7.001 sec (transform) +# PyNNDescentTransformer: 1.748 sec (transform) +# +# Notice that the `PyNNDescentTransformer` takes more time during the first +# `fit` and the first `transform` due to storing in the cache memory, but a +# second run will dramatically improve prediction time. Both +# :class:`~sklearn.neighbors.KNeighborsTransformer` and `NMSlibTransformer` show +# more stable `fit` and `transform` times. + +# %% +import matplotlib.pyplot as plt +from matplotlib.ticker import NullFormatter + +transformers = [ + ( + "TSNE with KNeighborsTransformer", + make_pipeline( KNeighborsTransformer( n_neighbors=n_neighbors, mode="distance", metric=metric ), + TSNE(metric="precomputed", **tsne_params), ), - ( - "PyNNDescentTransformer", - PyNNDescentTransformer(n_neighbors=n_neighbors, metric=metric), - ), - ( - "TSNE with NMSlibTransformer", - make_pipeline( - NMSlibTransformer(n_neighbors=n_neighbors, metric=metric), - TSNE(metric="precomputed", **tsne_params), - ), - ), - ( - "TSNE with KNeighborsTransformer", - make_pipeline( - KNeighborsTransformer( - n_neighbors=n_neighbors, mode="distance", metric=metric - ), - TSNE(metric="precomputed", **tsne_params), - ), - ), - ( - "TSNE with PyNNDescentTransformer", - make_pipeline( - PyNNDescentTransformer(n_neighbors=n_neighbors, metric=metric), - TSNE(metric="precomputed", **tsne_params), - ), + ), + ( + "TSNE with NMSlibTransformer", + make_pipeline( + NMSlibTransformer(n_neighbors=n_neighbors, metric=metric), + TSNE(metric="precomputed", **tsne_params), ), - ("TSNE with internal NearestNeighbors", TSNE(metric=metric, **tsne_params)), - ] - - # init the plot - nrows = len(datasets) - ncols = np.sum([1 for name, model in transformers if "TSNE" in name]) - fig, axes = plt.subplots( - nrows=nrows, ncols=ncols, squeeze=False, figsize=(5 * ncols, 4 * nrows) - ) - axes = axes.ravel() - i_ax = 0 + ), + ("TSNE with internal NearestNeighbors", TSNE(metric=metric, **tsne_params)), +] + +# init the plot +nrows = len(datasets) +ncols = np.sum([1 for name, model in transformers if "TSNE" in name]) +fig, axes = plt.subplots( + nrows=nrows, ncols=ncols, squeeze=False, figsize=(5 * ncols, 4 * nrows) +) +axes = axes.ravel() +i_ax = 0 + +for dataset_name, (X, y) in datasets: + + msg = "Benchmarking on %s:" % dataset_name + print("\n%s\n%s" % (msg, "-" * len(msg))) + + for transformer_name, transformer in transformers: + longest = np.max([len(name) for name, model in transformers]) + whitespaces = " " * (longest - len(transformer_name)) + start = time.time() + Xt = transformer.fit_transform(X) + transform_duration = time.time() - start + print( + "%s: %s%.3f sec (fit_transform)" + % (transformer_name, whitespaces, transform_duration) + ) - for dataset_name, (X, y) in datasets: + # plot TSNE embedding which should be very similar across methods + axes[i_ax].set_title(transformer_name + "\non " + dataset_name) + axes[i_ax].scatter( + Xt[:, 0], + Xt[:, 1], + c=y.astype(np.int32), + alpha=0.2, + cmap=plt.cm.viridis, + ) + axes[i_ax].xaxis.set_major_formatter(NullFormatter()) + axes[i_ax].yaxis.set_major_formatter(NullFormatter()) + axes[i_ax].axis("tight") + i_ax += 1 - msg = "Benchmarking on %s:" % dataset_name - print("\n%s\n%s" % (msg, "-" * len(msg))) +fig.tight_layout() +plt.show() - for transformer_name, transformer in transformers: - start = time.time() - Xt = transformer.fit_transform(X) - duration = time.time() - start - - # print the duration report - longest = np.max([len(name) for name, model in transformers]) - whitespaces = " " * (longest - len(transformer_name)) - print("%s: %s%.3f sec" % (transformer_name, whitespaces, duration)) - - # plot TSNE embedding which should be very similar across methods - if "TSNE" in transformer_name: - axes[i_ax].set_title(transformer_name + "\non " + dataset_name) - axes[i_ax].scatter( - Xt[:, 0], - Xt[:, 1], - c=y.astype(np.int32), - alpha=0.2, - cmap=plt.cm.viridis, - ) - axes[i_ax].xaxis.set_major_formatter(NullFormatter()) - axes[i_ax].yaxis.set_major_formatter(NullFormatter()) - axes[i_ax].axis("tight") - i_ax += 1 - - fig.tight_layout() - plt.show() - - -if __name__ == "__main__": - run_benchmark() +# %% +# Sample output:: +# +# Benchmarking on MNIST_10000: +# ---------------------------- +# TSNE with KNeighborsTransformer: 20.111 sec (fit_transform) +# TSNE with NMSlibTransformer: 21.757 sec (fit_transform) +# TSNE with internal NearestNeighbors: 24.828 sec (fit_transform) +# +# Benchmarking on MNIST_20000: +# ---------------------------- +# TSNE with KNeighborsTransformer: 50.994 sec (fit_transform) +# TSNE with NMSlibTransformer: 43.536 sec (fit_transform) +# TSNE with internal NearestNeighbors: 51.955 sec (fit_transform) +# +# Notice that the prediction speed +# :class:`~sklearn.neighbors.KNeighborsTransformer` was optimized in +# scikit-learn 1.1 and therefore the total `fit_transform` time of approximate +# methods is not necessarily lower than the exact +# :class:`~sklearn.neighbors.KNeighborsTransformer` solution. The reason is that +# computing the index takes time and can nullify the benefits obtained by the +# approximation. Indeed, the gains with respect to the exact solution increase +# with increasing number of samples. From 2a3bfc27607740b0be66102dcf4df8e3bc8f62a3 Mon Sep 17 00:00:00 2001 From: Arturo Amor <86408019+ArturoAmorQ@users.noreply.github.com> Date: Thu, 26 Jan 2023 16:53:46 +0100 Subject: [PATCH 07/13] Update examples/neighbors/approximate_nearest_neighbors.py Co-authored-by: Olivier Grisel --- examples/neighbors/approximate_nearest_neighbors.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/neighbors/approximate_nearest_neighbors.py b/examples/neighbors/approximate_nearest_neighbors.py index d6a5747f4c4e1..b742e72a3b5c7 100644 --- a/examples/neighbors/approximate_nearest_neighbors.py +++ b/examples/neighbors/approximate_nearest_neighbors.py @@ -211,9 +211,10 @@ def load_mnist(n_samples): # PyNNDescentTransformer: 7.001 sec (transform) # PyNNDescentTransformer: 1.748 sec (transform) # -# Notice that the `PyNNDescentTransformer` takes more time during the first -# `fit` and the first `transform` due to storing in the cache memory, but a -# second run will dramatically improve prediction time. Both +# `fit` and the first `transform` due to the overhead of the numba just +# in time compiler. But after the first call, the compiled Python code +# is kept in a cache by numba and as a result subsequent calls are do +# not suffer from this initial overhead. Both # :class:`~sklearn.neighbors.KNeighborsTransformer` and `NMSlibTransformer` show # more stable `fit` and `transform` times. From 19c6a5e2a8a6cefa0f8c28c8cd7a6c5263ac414b Mon Sep 17 00:00:00 2001 From: ArturoAmorQ Date: Thu, 26 Jan 2023 16:57:26 +0100 Subject: [PATCH 08/13] Fix format --- .../neighbors/approximate_nearest_neighbors.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/neighbors/approximate_nearest_neighbors.py b/examples/neighbors/approximate_nearest_neighbors.py index b742e72a3b5c7..30f3f9a5ad527 100644 --- a/examples/neighbors/approximate_nearest_neighbors.py +++ b/examples/neighbors/approximate_nearest_neighbors.py @@ -116,8 +116,8 @@ def load_mnist(n_samples): from sklearn.pipeline import make_pipeline datasets = [ - ("MNIST_10000", load_mnist(n_samples=2000)), - ("MNIST_20000", load_mnist(n_samples=20000)), + ("MNIST_10000", load_mnist(n_samples=10_000)), + ("MNIST_20000", load_mnist(n_samples=20_000)), ] n_iter = 500 @@ -211,12 +211,11 @@ def load_mnist(n_samples): # PyNNDescentTransformer: 7.001 sec (transform) # PyNNDescentTransformer: 1.748 sec (transform) # -# `fit` and the first `transform` due to the overhead of the numba just -# in time compiler. But after the first call, the compiled Python code -# is kept in a cache by numba and as a result subsequent calls are do -# not suffer from this initial overhead. Both -# :class:`~sklearn.neighbors.KNeighborsTransformer` and `NMSlibTransformer` show -# more stable `fit` and `transform` times. +# `fit` and the first `transform` due to the overhead of the numba just in time +# compiler. But after the first call, the compiled Python code is kept in a +# cache by numba and as a result subsequent calls are do not suffer from this +# initial overhead. Both :class:`~sklearn.neighbors.KNeighborsTransformer` and +# `NMSlibTransformer` show more stable `fit` and `transform` times. # %% import matplotlib.pyplot as plt From f82a75262fb0bf1443cadb134722b47967d14f2d Mon Sep 17 00:00:00 2001 From: Arturo Amor <86408019+ArturoAmorQ@users.noreply.github.com> Date: Fri, 27 Jan 2023 12:05:06 +0100 Subject: [PATCH 09/13] Update examples/neighbors/approximate_nearest_neighbors.py Co-authored-by: Olivier Grisel --- .../approximate_nearest_neighbors.py | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/examples/neighbors/approximate_nearest_neighbors.py b/examples/neighbors/approximate_nearest_neighbors.py index 30f3f9a5ad527..4c65fe36698f2 100644 --- a/examples/neighbors/approximate_nearest_neighbors.py +++ b/examples/neighbors/approximate_nearest_neighbors.py @@ -298,11 +298,27 @@ def load_mnist(n_samples): # TSNE with NMSlibTransformer: 43.536 sec (fit_transform) # TSNE with internal NearestNeighbors: 51.955 sec (fit_transform) # -# Notice that the prediction speed -# :class:`~sklearn.neighbors.KNeighborsTransformer` was optimized in -# scikit-learn 1.1 and therefore the total `fit_transform` time of approximate -# methods is not necessarily lower than the exact -# :class:`~sklearn.neighbors.KNeighborsTransformer` solution. The reason is that -# computing the index takes time and can nullify the benefits obtained by the -# approximation. Indeed, the gains with respect to the exact solution increase -# with increasing number of samples. +# We can observe that the default `TSNE` estimator with its internal +# :class:`~sklearn.neighbors.NearestNeighbors` implementation is roughly +# equivalent the `TSNE` pipeline with +# :class:`~sklearn.neighbors.KNeighborsTransformer` in terms of performance. +# This is expected because both pipelines rely internally on the same +# :class:`~sklearn.neighbors.NearestNeighbors`NearestNeighbors` implementation +# that performs exacts neighbors search. The approximate `NMSlibTransformer` is +# already slightly faster than exact search on the smallest dataset but this +# speed difference is expected to become more significant on datasets with a +# larger number of samples. +# +# Note however that not all approximate search methods are guaranteed to +# improve upon the speed of the default exact search method: indeed the exact +# search implementation has been significantly improved in scikit-learn 1.1. +# Furthermore, the brute-force exact search method does not require building an +# index at `fit` time. So, to get an overall performance improvement in the +# context of the `TSNE` pipeline, the gains of the approximate search at +# `transform` need to be larger than the extra time speed to build the +# approximate search index at `fit` time. +# +# Finally, the TSNE algorithm itself is also computationally intensive, +# irrespective of the nearest neighbors search. So speeding-up the nearest +# neighbors search step by a factor of 5 would not result in a speed up by a +# factor of 5 for the overall pipeline. From 16236d75e3eec24f5bccd76f1e02072dbb2ea966 Mon Sep 17 00:00:00 2001 From: ArturoAmorQ Date: Fri, 27 Jan 2023 15:46:18 +0100 Subject: [PATCH 10/13] Change double run of KNeighbors and NMSlib transformers to single run --- .../approximate_nearest_neighbors.py | 58 +++++++++---------- 1 file changed, 26 insertions(+), 32 deletions(-) diff --git a/examples/neighbors/approximate_nearest_neighbors.py b/examples/neighbors/approximate_nearest_neighbors.py index 4c65fe36698f2..8810320c4f938 100644 --- a/examples/neighbors/approximate_nearest_neighbors.py +++ b/examples/neighbors/approximate_nearest_neighbors.py @@ -162,14 +162,18 @@ def load_mnist(n_samples): for transformer_name, transformer in transformers: longest = np.max([len(name) for name, model in transformers]) whitespaces = " " * (longest - len(transformer_name)) - for _ in range(2): - start = time.time() - transformer.fit(X) - fit_duration = time.time() - start - print( - "%s: %s%.3f sec (fit)" % (transformer_name, whitespaces, fit_duration) - ) - for _ in range(2): + start = time.time() + transformer.fit(X) + fit_duration = time.time() - start + print("%s: %s%.3f sec (fit)" % (transformer_name, whitespaces, fit_duration)) + start = time.time() + Xt = transformer.transform(X) + transform_duration = time.time() - start + print( + "%s: %s%.3f sec (transform)" + % (transformer_name, whitespaces, transform_duration) + ) + if transformer_name == "PyNNDescentTransformer": start = time.time() Xt = transformer.transform(X) transform_duration = time.time() - start @@ -183,33 +187,23 @@ def load_mnist(n_samples): # # Benchmarking on MNIST_10000: # ---------------------------- -# KNeighborsTransformer: 0.005 sec (fit) -# KNeighborsTransformer: 0.004 sec (fit) -# KNeighborsTransformer: 1.285 sec (transform) -# KNeighborsTransformer: 1.162 sec (transform) -# NMSlibTransformer: 0.226 sec (fit) -# NMSlibTransformer: 0.235 sec (fit) -# NMSlibTransformer: 0.323 sec (transform) -# NMSlibTransformer: 0.295 sec (transform) -# PyNNDescentTransformer: 18.129 sec (fit) -# PyNNDescentTransformer: 4.584 sec (fit) -# PyNNDescentTransformer: 15.092 sec (transform) -# PyNNDescentTransformer: 0.862 sec (transform) +# KNeighborsTransformer: 0.007 sec (fit) +# KNeighborsTransformer: 1.139 sec (transform) +# NMSlibTransformer: 0.208 sec (fit) +# NMSlibTransformer: 0.315 sec (transform) +# PyNNDescentTransformer: 4.823 sec (fit) +# PyNNDescentTransformer: 4.884 sec (transform) +# PyNNDescentTransformer: 0.744 sec (transform) # # Benchmarking on MNIST_20000: # ---------------------------- -# KNeighborsTransformer: 0.010 sec (fit) -# KNeighborsTransformer: 0.010 sec (fit) -# KNeighborsTransformer: 6.992 sec (transform) -# KNeighborsTransformer: 6.951 sec (transform) -# NMSlibTransformer: 0.777 sec (fit) -# NMSlibTransformer: 0.788 sec (fit) -# NMSlibTransformer: 0.796 sec (transform) -# NMSlibTransformer: 0.740 sec (transform) -# PyNNDescentTransformer: 13.609 sec (fit) -# PyNNDescentTransformer: 13.359 sec (fit) -# PyNNDescentTransformer: 7.001 sec (transform) -# PyNNDescentTransformer: 1.748 sec (transform) +# KNeighborsTransformer: 0.011 sec (fit) +# KNeighborsTransformer: 5.769 sec (transform) +# NMSlibTransformer: 0.733 sec (fit) +# NMSlibTransformer: 1.077 sec (transform) +# PyNNDescentTransformer: 14.448 sec (fit) +# PyNNDescentTransformer: 7.103 sec (transform) +# PyNNDescentTransformer: 1.759 sec (transform) # # `fit` and the first `transform` due to the overhead of the numba just in time # compiler. But after the first call, the compiled Python code is kept in a From afe8249e7748ab20a7085fbd023bb428d4642e23 Mon Sep 17 00:00:00 2001 From: ArturoAmorQ Date: Fri, 27 Jan 2023 15:47:29 +0100 Subject: [PATCH 11/13] Fix wording --- .../approximate_nearest_neighbors.py | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/examples/neighbors/approximate_nearest_neighbors.py b/examples/neighbors/approximate_nearest_neighbors.py index 8810320c4f938..691cb53c3478e 100644 --- a/examples/neighbors/approximate_nearest_neighbors.py +++ b/examples/neighbors/approximate_nearest_neighbors.py @@ -205,11 +205,13 @@ def load_mnist(n_samples): # PyNNDescentTransformer: 7.103 sec (transform) # PyNNDescentTransformer: 1.759 sec (transform) # +# Notice that the `PyNNDescentTransformer` takes more time during the first # `fit` and the first `transform` due to the overhead of the numba just in time # compiler. But after the first call, the compiled Python code is kept in a -# cache by numba and as a result subsequent calls are do not suffer from this -# initial overhead. Both :class:`~sklearn.neighbors.KNeighborsTransformer` and -# `NMSlibTransformer` show more stable `fit` and `transform` times. +# cache by numba and subsequent calls do not suffer from this initial overhead. +# Both :class:`~sklearn.neighbors.KNeighborsTransformer` and `NMSlibTransformer` +# are only run once here as they would show more stable `fit` and `transform` +# times (they don't have the cold start problem of PyNNDescentTransformer). # %% import matplotlib.pyplot as plt @@ -292,24 +294,24 @@ def load_mnist(n_samples): # TSNE with NMSlibTransformer: 43.536 sec (fit_transform) # TSNE with internal NearestNeighbors: 51.955 sec (fit_transform) # -# We can observe that the default `TSNE` estimator with its internal -# :class:`~sklearn.neighbors.NearestNeighbors` implementation is roughly -# equivalent the `TSNE` pipeline with +# We can observe that the default :class:`~sklearn.manifold.TSNE` estimator with +# its internal :class:`~sklearn.neighbors.NearestNeighbors` implementation is +# roughly equivalent to the pipeline with :class:`~sklearn.manifold.TSNE` and # :class:`~sklearn.neighbors.KNeighborsTransformer` in terms of performance. # This is expected because both pipelines rely internally on the same -# :class:`~sklearn.neighbors.NearestNeighbors`NearestNeighbors` implementation -# that performs exacts neighbors search. The approximate `NMSlibTransformer` is -# already slightly faster than exact search on the smallest dataset but this -# speed difference is expected to become more significant on datasets with a -# larger number of samples. +# :class:`~sklearn.neighbors.NearestNeighbors` implementation that performs +# exacts neighbors search. The approximate `NMSlibTransformer` is already +# slightly faster than the exact search on the smallest dataset but this speed +# difference is expected to become more significant on datasets with a larger +# number of samples. # -# Note however that not all approximate search methods are guaranteed to -# improve upon the speed of the default exact search method: indeed the exact -# search implementation has been significantly improved in scikit-learn 1.1. -# Furthermore, the brute-force exact search method does not require building an -# index at `fit` time. So, to get an overall performance improvement in the -# context of the `TSNE` pipeline, the gains of the approximate search at -# `transform` need to be larger than the extra time speed to build the +# Notice however that not all approximate search methods are guaranteed to +# improve the speed of the default exact search method: indeed the exact search +# implementation significantly improved since scikit-learn 1.1. Furthermore, the +# brute-force exact search method does not require building an index at `fit` +# time. So, to get an overall performance improvement in the context of the +# :class:`~sklearn.manifold.TSNE` pipeline, the gains of the approximate search +# at `transform` need to be larger than the extra time spent to build the # approximate search index at `fit` time. # # Finally, the TSNE algorithm itself is also computationally intensive, From cb0049f9c96f4d773c05f4cef1f3e447faaaa2ee Mon Sep 17 00:00:00 2001 From: ArturoAmorQ Date: Fri, 27 Jan 2023 15:48:32 +0100 Subject: [PATCH 12/13] Fix logical progression --- examples/neighbors/approximate_nearest_neighbors.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/neighbors/approximate_nearest_neighbors.py b/examples/neighbors/approximate_nearest_neighbors.py index 691cb53c3478e..6532ca3982607 100644 --- a/examples/neighbors/approximate_nearest_neighbors.py +++ b/examples/neighbors/approximate_nearest_neighbors.py @@ -218,6 +218,7 @@ def load_mnist(n_samples): from matplotlib.ticker import NullFormatter transformers = [ + ("TSNE with internal NearestNeighbors", TSNE(metric=metric, **tsne_params)), ( "TSNE with KNeighborsTransformer", make_pipeline( @@ -234,7 +235,6 @@ def load_mnist(n_samples): TSNE(metric="precomputed", **tsne_params), ), ), - ("TSNE with internal NearestNeighbors", TSNE(metric=metric, **tsne_params)), ] # init the plot @@ -284,15 +284,15 @@ def load_mnist(n_samples): # # Benchmarking on MNIST_10000: # ---------------------------- +# TSNE with internal NearestNeighbors: 24.828 sec (fit_transform) # TSNE with KNeighborsTransformer: 20.111 sec (fit_transform) # TSNE with NMSlibTransformer: 21.757 sec (fit_transform) -# TSNE with internal NearestNeighbors: 24.828 sec (fit_transform) # # Benchmarking on MNIST_20000: # ---------------------------- +# TSNE with internal NearestNeighbors: 51.955 sec (fit_transform) # TSNE with KNeighborsTransformer: 50.994 sec (fit_transform) # TSNE with NMSlibTransformer: 43.536 sec (fit_transform) -# TSNE with internal NearestNeighbors: 51.955 sec (fit_transform) # # We can observe that the default :class:`~sklearn.manifold.TSNE` estimator with # its internal :class:`~sklearn.neighbors.NearestNeighbors` implementation is From f66a2b531a98a954e056c0a1b81b8163f2170997 Mon Sep 17 00:00:00 2001 From: ArturoAmorQ Date: Wed, 8 Feb 2023 17:25:53 +0100 Subject: [PATCH 13/13] DOC Change print format of Approximate nearest neighbors in TSNE example --- .../approximate_nearest_neighbors.py | 65 +++++++++---------- 1 file changed, 30 insertions(+), 35 deletions(-) diff --git a/examples/neighbors/approximate_nearest_neighbors.py b/examples/neighbors/approximate_nearest_neighbors.py index 6532ca3982607..8b73fa28b7a6e 100644 --- a/examples/neighbors/approximate_nearest_neighbors.py +++ b/examples/neighbors/approximate_nearest_neighbors.py @@ -156,30 +156,26 @@ def load_mnist(n_samples): for dataset_name, (X, y) in datasets: - msg = "Benchmarking on %s:" % dataset_name - print("\n%s\n%s" % (msg, "-" * len(msg))) + msg = f"Benchmarking on {dataset_name}:" + print(f"\n{msg}\n" + str("-" * len(msg))) for transformer_name, transformer in transformers: longest = np.max([len(name) for name, model in transformers]) - whitespaces = " " * (longest - len(transformer_name)) start = time.time() transformer.fit(X) fit_duration = time.time() - start - print("%s: %s%.3f sec (fit)" % (transformer_name, whitespaces, fit_duration)) + print(f"{transformer_name:<{longest}} {fit_duration:.3f} sec (fit)") start = time.time() Xt = transformer.transform(X) transform_duration = time.time() - start - print( - "%s: %s%.3f sec (transform)" - % (transformer_name, whitespaces, transform_duration) - ) + print(f"{transformer_name:<{longest}} {transform_duration:.3f} sec (transform)") if transformer_name == "PyNNDescentTransformer": start = time.time() Xt = transformer.transform(X) transform_duration = time.time() - start print( - "%s: %s%.3f sec (transform)" - % (transformer_name, whitespaces, transform_duration) + f"{transformer_name:<{longest}} {transform_duration:.3f} sec" + " (transform)" ) # %% @@ -187,23 +183,23 @@ def load_mnist(n_samples): # # Benchmarking on MNIST_10000: # ---------------------------- -# KNeighborsTransformer: 0.007 sec (fit) -# KNeighborsTransformer: 1.139 sec (transform) -# NMSlibTransformer: 0.208 sec (fit) -# NMSlibTransformer: 0.315 sec (transform) -# PyNNDescentTransformer: 4.823 sec (fit) -# PyNNDescentTransformer: 4.884 sec (transform) -# PyNNDescentTransformer: 0.744 sec (transform) +# KNeighborsTransformer 0.007 sec (fit) +# KNeighborsTransformer 1.139 sec (transform) +# NMSlibTransformer 0.208 sec (fit) +# NMSlibTransformer 0.315 sec (transform) +# PyNNDescentTransformer 4.823 sec (fit) +# PyNNDescentTransformer 4.884 sec (transform) +# PyNNDescentTransformer 0.744 sec (transform) # # Benchmarking on MNIST_20000: # ---------------------------- -# KNeighborsTransformer: 0.011 sec (fit) -# KNeighborsTransformer: 5.769 sec (transform) -# NMSlibTransformer: 0.733 sec (fit) -# NMSlibTransformer: 1.077 sec (transform) -# PyNNDescentTransformer: 14.448 sec (fit) -# PyNNDescentTransformer: 7.103 sec (transform) -# PyNNDescentTransformer: 1.759 sec (transform) +# KNeighborsTransformer 0.011 sec (fit) +# KNeighborsTransformer 5.769 sec (transform) +# NMSlibTransformer 0.733 sec (fit) +# NMSlibTransformer 1.077 sec (transform) +# PyNNDescentTransformer 14.448 sec (fit) +# PyNNDescentTransformer 7.103 sec (transform) +# PyNNDescentTransformer 1.759 sec (transform) # # Notice that the `PyNNDescentTransformer` takes more time during the first # `fit` and the first `transform` due to the overhead of the numba just in time @@ -248,18 +244,17 @@ def load_mnist(n_samples): for dataset_name, (X, y) in datasets: - msg = "Benchmarking on %s:" % dataset_name - print("\n%s\n%s" % (msg, "-" * len(msg))) + msg = f"Benchmarking on {dataset_name}:" + print(f"\n{msg}\n" + str("-" * len(msg))) for transformer_name, transformer in transformers: longest = np.max([len(name) for name, model in transformers]) - whitespaces = " " * (longest - len(transformer_name)) start = time.time() Xt = transformer.fit_transform(X) transform_duration = time.time() - start print( - "%s: %s%.3f sec (fit_transform)" - % (transformer_name, whitespaces, transform_duration) + f"{transformer_name:<{longest}} {transform_duration:.3f} sec" + " (fit_transform)" ) # plot TSNE embedding which should be very similar across methods @@ -284,15 +279,15 @@ def load_mnist(n_samples): # # Benchmarking on MNIST_10000: # ---------------------------- -# TSNE with internal NearestNeighbors: 24.828 sec (fit_transform) -# TSNE with KNeighborsTransformer: 20.111 sec (fit_transform) -# TSNE with NMSlibTransformer: 21.757 sec (fit_transform) +# TSNE with internal NearestNeighbors 24.828 sec (fit_transform) +# TSNE with KNeighborsTransformer 20.111 sec (fit_transform) +# TSNE with NMSlibTransformer 21.757 sec (fit_transform) # # Benchmarking on MNIST_20000: # ---------------------------- -# TSNE with internal NearestNeighbors: 51.955 sec (fit_transform) -# TSNE with KNeighborsTransformer: 50.994 sec (fit_transform) -# TSNE with NMSlibTransformer: 43.536 sec (fit_transform) +# TSNE with internal NearestNeighbors 51.955 sec (fit_transform) +# TSNE with KNeighborsTransformer 50.994 sec (fit_transform) +# TSNE with NMSlibTransformer 43.536 sec (fit_transform) # # We can observe that the default :class:`~sklearn.manifold.TSNE` estimator with # its internal :class:`~sklearn.neighbors.NearestNeighbors` implementation is