diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index e8fc80644c002..53723c3f6ea86 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -1170,6 +1170,7 @@ Model validation
    neighbors.RadiusNeighborsRegressor
    neighbors.NearestCentroid
    neighbors.NearestNeighbors
+   neighbors.NeighborhoodComponentsAnalysis

 .. autosummary::
    :toctree: generated/
@@ -1432,6 +1433,7 @@ Low-level methods
    utils.assert_all_finite
    utils.check_X_y
    utils.check_array
+   utils.check_scalar
    utils.check_consistent_length
    utils.check_random_state
    utils.class_weight.compute_class_weight
diff --git a/doc/modules/decomposition.rst b/doc/modules/decomposition.rst
index 33e93207ffbf0..5bfa96bbb759c 100644
--- a/doc/modules/decomposition.rst
+++ b/doc/modules/decomposition.rst
@@ -953,3 +953,7 @@ when data can be fetched sequentially.

    * `"Stochastic Variational Inference" `_
      M. Hoffman, D. Blei, C. Wang, J. Paisley, 2013
+
+
+See also :ref:`nca_dim_reduction` for dimensionality reduction with
+Neighborhood Components Analysis.
diff --git a/doc/modules/neighbors.rst b/doc/modules/neighbors.rst
index fec8d98a999eb..094eec438d357 100644
--- a/doc/modules/neighbors.rst
+++ b/doc/modules/neighbors.rst
@@ -510,3 +510,217 @@ the model from 0.81 to 0.82.
   * :ref:`sphx_glr_auto_examples_neighbors_plot_nearest_centroid.py`: an example of
     classification using nearest centroid with different shrink thresholds.
+
+
+.. _nca:
+
+Neighborhood Components Analysis
+================================
+
+.. sectionauthor:: William de Vazelhes
+
+Neighborhood Components Analysis (NCA, :class:`NeighborhoodComponentsAnalysis`)
+is a distance metric learning algorithm which aims to improve the accuracy of
+nearest neighbors classification compared to the standard Euclidean distance.
+The algorithm directly maximizes a stochastic variant of the leave-one-out
+k-nearest neighbors (KNN) score on the training set. It can also learn a
+low-dimensional linear projection of the data that can be used for data
+visualization and fast classification.
+
+.. |nca_illustration_1| image:: ../auto_examples/neighbors/images/sphx_glr_plot_nca_illustration_001.png
+   :target: ../auto_examples/neighbors/plot_nca_illustration.html
+   :scale: 50
+
+.. |nca_illustration_2| image:: ../auto_examples/neighbors/images/sphx_glr_plot_nca_illustration_002.png
+   :target: ../auto_examples/neighbors/plot_nca_illustration.html
+   :scale: 50
+
+.. centered:: |nca_illustration_1| |nca_illustration_2|
+
+In the figure above, we consider some points from a randomly generated
+dataset, and focus on the stochastic KNN classification of point no. 3. The
+thickness of a link between sample 3 and another point corresponds to the
+relative weight (or probability) that a stochastic nearest neighbor
+prediction rule would assign to that point: the closer the point, the thicker
+the link. In the original space, sample 3 has many stochastic neighbors from
+various classes, so the right class is not very likely. However, in the
+projected space learned by NCA, the only stochastic neighbors with
+non-negligible weight are from the same class as sample 3, guaranteeing that
+the latter will be well classified. See the :ref:`mathematical formulation
+<nca_mathematical_formulation>` for more details.
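+
+The weights shown in the figure can be computed explicitly: for a sample
+:math:`i`, they are the softmax of the negative squared Euclidean distances
+to all the other samples, with zero weight on the sample itself. The short
+sketch below only illustrates this computation with plain NumPy; the helper
+name is purely illustrative and is not part of the scikit-learn API::
+
+    import numpy as np
+
+    def stochastic_neighbor_weights(X, i):
+        # illustrative helper, not part of scikit-learn
+        # squared Euclidean distances from sample i to every other sample
+        diff = X[i] - X
+        sq_dist = np.einsum('ij,ij->i', diff, diff).astype(float)
+        sq_dist[i] = np.inf  # a sample never selects itself, so p_ii = 0
+        # softmax over negative squared distances, shifted by the smallest
+        # finite distance for numerical stability
+        weights = np.exp(-(sq_dist - sq_dist[np.isfinite(sq_dist)].min()))
+        return weights / weights.sum()
+
+Computing these weights on the original data and on the data transformed by a
+fitted :class:`NeighborhoodComponentsAnalysis` is essentially what the
+:ref:`sphx_glr_auto_examples_neighbors_plot_nca_illustration.py` example does.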
+ + +Classification +-------------- + +Combined with a nearest neighbors classifier (:class:`KNeighborsClassifier`), +NCA is attractive for classification because it can naturally handle +multi-class problems without any increase in the model size, and does not +introduce additional parameters that require fine-tuning by the user. + +NCA classification has been shown to work well in practice for data sets of +varying size and difficulty. In contrast to related methods such as Linear +Discriminant Analysis, NCA does not make any assumptions about the class +distributions. The nearest neighbor classification can naturally produce highly +irregular decision boundaries. + +To use this model for classification, one needs to combine a +:class:`NeighborhoodComponentsAnalysis` instance that learns the optimal +transformation with a :class:`KNeighborsClassifier` instance that performs the +classification in the projected space. Here is an example using the two +classes: + + >>> from sklearn.neighbors import (NeighborhoodComponentsAnalysis, + ... KNeighborsClassifier) + >>> from sklearn.datasets import load_iris + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.pipeline import Pipeline + >>> X, y = load_iris(return_X_y=True) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, + ... stratify=y, test_size=0.7, random_state=42) + >>> nca = NeighborhoodComponentsAnalysis(random_state=42) + >>> knn = KNeighborsClassifier(n_neighbors=3) + >>> nca_pipe = Pipeline([('nca', nca), ('knn', knn)]) + >>> nca_pipe.fit(X_train, y_train) # doctest: +ELLIPSIS + Pipeline(...) + >>> print(nca_pipe.score(X_test, y_test)) # doctest: +ELLIPSIS + 0.96190476... + +.. |nca_classification_1| image:: ../auto_examples/neighbors/images/sphx_glr_plot_nca_classification_001.png + :target: ../auto_examples/neighbors/plot_nca_classification.html + :scale: 50 + +.. |nca_classification_2| image:: ../auto_examples/neighbors/images/sphx_glr_plot_nca_classification_002.png + :target: ../auto_examples/neighbors/plot_nca_classification.html + :scale: 50 + +.. centered:: |nca_classification_1| |nca_classification_2| + +The plot shows decision boundaries for Nearest Neighbor Classification and +Neighborhood Components Analysis classification on the iris dataset, when +training and scoring on only two features, for visualisation purposes. + +.. _nca_dim_reduction: + +Dimensionality reduction +------------------------ + +NCA can be used to perform supervised dimensionality reduction. The input data +are projected onto a linear subspace consisting of the directions which +minimize the NCA objective. The desired dimensionality can be set using the +parameter ``n_components``. For instance, the following figure shows a +comparison of dimensionality reduction with Principal Component Analysis +(:class:`sklearn.decomposition.PCA`), Linear Discriminant Analysis +(:class:`sklearn.discriminant_analysis.LinearDiscriminantAnalysis`) and +Neighborhood Component Analysis (:class:`NeighborhoodComponentsAnalysis`) on +the Digits dataset, a dataset with size :math:`n_{samples} = 1797` and +:math:`n_{features} = 64`. The data set is split into a training and a test set +of equal size, then standardized. For evaluation the 3-nearest neighbor +classification accuracy is computed on the 2-dimensional projected points found +by each method. Each data sample belongs to one of 10 classes. + +.. 
|nca_dim_reduction_1| image:: ../auto_examples/neighbors/images/sphx_glr_plot_nca_dim_reduction_001.png
+   :target: ../auto_examples/neighbors/plot_nca_dim_reduction.html
+   :width: 32%
+
+.. |nca_dim_reduction_2| image:: ../auto_examples/neighbors/images/sphx_glr_plot_nca_dim_reduction_002.png
+   :target: ../auto_examples/neighbors/plot_nca_dim_reduction.html
+   :width: 32%
+
+.. |nca_dim_reduction_3| image:: ../auto_examples/neighbors/images/sphx_glr_plot_nca_dim_reduction_003.png
+   :target: ../auto_examples/neighbors/plot_nca_dim_reduction.html
+   :width: 32%
+
+.. centered:: |nca_dim_reduction_1| |nca_dim_reduction_2| |nca_dim_reduction_3|
+
+
+.. topic:: Examples:
+
+  * :ref:`sphx_glr_auto_examples_neighbors_plot_nca_classification.py`
+  * :ref:`sphx_glr_auto_examples_neighbors_plot_nca_dim_reduction.py`
+  * :ref:`sphx_glr_auto_examples_manifold_plot_lle_digits.py`
+
+.. _nca_mathematical_formulation:
+
+Mathematical formulation
+------------------------
+
+The goal of NCA is to learn an optimal linear transformation matrix of size
+``(n_components, n_features)``, which maximises the sum over all samples
+:math:`i` of the probability :math:`p_i` that :math:`i` is correctly
+classified, i.e.:
+
+.. math::
+
+  \underset{L}{\arg\max} \sum\limits_{i=0}^{N - 1} p_{i}
+
+with :math:`N` = ``n_samples`` and :math:`p_i` the probability of sample
+:math:`i` being correctly classified according to a stochastic nearest
+neighbors rule in the learned embedded space:
+
+.. math::
+
+  p_{i}=\sum\limits_{j \in C_i}{p_{i j}}
+
+where :math:`C_i` is the set of points in the same class as sample :math:`i`,
+and :math:`p_{i j}` is the softmax over Euclidean distances in the embedded
+space:
+
+.. math::
+
+  p_{i j} = \frac{\exp(-||L x_i - L x_j||^2)}{\sum\limits_{k \ne
+      i} {\exp(-||L x_i - L x_k||^2)}} , \quad p_{i i} = 0
+
+
+Mahalanobis distance
+^^^^^^^^^^^^^^^^^^^^
+
+NCA can be seen as learning a (squared) Mahalanobis distance metric:
+
+.. math::
+
+    || L(x_i - x_j)||^2 = (x_i - x_j)^TM(x_i - x_j),
+
+where :math:`M = L^T L` is a symmetric positive semi-definite matrix of size
+``(n_features, n_features)``.
+
+
+Implementation
+--------------
+
+This implementation follows what is explained in the original paper [1]_. For
+the optimisation method, it currently uses scipy's L-BFGS-B with a full
+gradient computation at each iteration, to avoid tuning the learning rate and
+to provide stable learning.
+
+See the examples below and the docstring of
+:meth:`NeighborhoodComponentsAnalysis.fit` for further information.
+
+Complexity
+----------
+
+Training
+^^^^^^^^
+NCA stores a matrix of pairwise distances, taking ``n_samples ** 2`` memory.
+Time complexity depends on the number of iterations done by the optimisation
+algorithm. However, one can set the maximum number of iterations with the
+argument ``max_iter``. For each iteration, time complexity is
+``O(n_components x n_samples x min(n_samples, n_features))``.
+
+
+Transform
+^^^^^^^^^
+Here the ``transform`` operation returns :math:`XL^T`, therefore its time
+complexity equals ``n_components * n_features * n_samples_test``. There is no
+added space complexity in the operation.
+
+
+.. topic:: References:
+
+  .. [1] `"Neighbourhood Components Analysis"
+      <http://www.cs.nyu.edu/~roweis/papers/ncanips.pdf>`_,
+      J. Goldberger, G. Hinton, S. Roweis, R. Salakhutdinov, Advances in
+      Neural Information Processing Systems, Vol. 17, May 2005, pp. 513-520.
+
+  .. 
[2] `Wikipedia entry on Neighborhood Components Analysis + `_ diff --git a/doc/modules/neural_networks_supervised.rst b/doc/modules/neural_networks_supervised.rst index d3e3ac5710cb1..793de7f8212d1 100644 --- a/doc/modules/neural_networks_supervised.rst +++ b/doc/modules/neural_networks_supervised.rst @@ -152,7 +152,7 @@ indices where the value is `1` represents the assigned classes of that sample:: >>> clf.predict([[0., 0.]]) array([[0, 1]]) -See the examples below and the doc string of +See the examples below and the docstring of :meth:`MLPClassifier.fit` for further information. .. topic:: Examples: diff --git a/doc/modules/sgd.rst b/doc/modules/sgd.rst index 08e864a71b76e..b28c6918cd0f6 100644 --- a/doc/modules/sgd.rst +++ b/doc/modules/sgd.rst @@ -154,7 +154,7 @@ one-vs-all classification. :class:`SGDClassifier` supports both weighted classes and weighted instances via the fit parameters ``class_weight`` and ``sample_weight``. See -the examples below and the doc string of :meth:`SGDClassifier.fit` for +the examples below and the docstring of :meth:`SGDClassifier.fit` for further information. .. topic:: Examples: diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst index 36582d834c708..f0df026b2f01e 100644 --- a/doc/whats_new/v0.21.rst +++ b/doc/whats_new/v0.21.rst @@ -82,7 +82,7 @@ Support for Python 3.4 and below has been officially dropped. - |Fix| Fixed a bug in :class:`decomposition.NMF` where `init = 'nndsvd'`, `init = 'nndsvda'`, and `init = 'nndsvdar'` are allowed when `n_components < n_features` instead of - `n_components <= min(n_samples, n_features)`. + `n_components <= min(n_samples, n_features)`. :issue:`11650` by :user:`Hossein Pourbozorg ` and :user:`Zijie (ZJ) Poh `. @@ -167,7 +167,7 @@ Support for Python 3.4 and below has been officially dropped. - |Fix| Fixed a bug in :class:`linear_model.LassoLarsIC`, where user input ``copy_X=False`` at instance creation would be overridden by default - parameter value ``copy_X=True`` in ``fit``. + parameter value ``copy_X=True`` in ``fit``. :issue:`12972` by :user:`Lucio Fernandez-Arjona ` :mod:`sklearn.manifold` @@ -235,6 +235,12 @@ Support for Python 3.4 and below has been officially dropped. :mod:`sklearn.neighbors` ........................ +- |MajorFeature| A metric learning algorithm: + :class:`neighbors.NeighborhoodComponentsAnalysis`, which implements the + Neighborhood Components Analysis algorithm described in Goldberger et al. + (2005). :issue:`10058` by :user:`William de Vazelhes + ` and :user:`John Chiotellis `. + - |API| Methods in :class:`neighbors.NearestNeighbors` : :func:`~neighbors.NearestNeighbors.kneighbors`, :func:`~neighbors.NearestNeighbors.radius_neighbors`, diff --git a/examples/manifold/plot_lle_digits.py b/examples/manifold/plot_lle_digits.py index 133d81bab0f62..4a3002a05d0dd 100644 --- a/examples/manifold/plot_lle_digits.py +++ b/examples/manifold/plot_lle_digits.py @@ -15,6 +15,11 @@ this example, which is not the default setting. It ensures global stability of the embedding, i.e., the embedding does not depend on random initialization. + +Linear Discriminant Analysis, from the :mod:`sklearn.discriminant_analysis` +module, and Neighborhood Components Analysis, from the :mod:`sklearn.neighbors` +module, are supervised dimensionality reduction method, i.e. they make use of +the provided labels, contrary to other methods. 
""" # Authors: Fabian Pedregosa @@ -30,7 +35,7 @@ import matplotlib.pyplot as plt from matplotlib import offsetbox from sklearn import (manifold, datasets, decomposition, ensemble, - discriminant_analysis, random_projection) + discriminant_analysis, random_projection, neighbors) digits = datasets.load_digits(n_class=6) X = digits.data @@ -39,7 +44,7 @@ n_neighbors = 30 -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- # Scale and visualize the embedding vectors def plot_embedding(X, title=None): x_min, x_max = np.min(X, 0), np.max(X, 0) @@ -70,7 +75,7 @@ def plot_embedding(X, title=None): plt.title(title) -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- # Plot images of the digits n_img_per_row = 20 img = np.zeros((10 * n_img_per_row, 10 * n_img_per_row)) @@ -86,7 +91,7 @@ def plot_embedding(X, title=None): plt.title('A selection from the 64-dimensional digits dataset') -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- # Random 2D projection using a random unitary matrix print("Computing random projection") rp = random_projection.SparseRandomProjection(n_components=2, random_state=42) @@ -104,7 +109,7 @@ def plot_embedding(X, title=None): "Principal Components projection of the digits (time %.2fs)" % (time() - t0)) -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- # Projection on to the first 2 linear discriminant components print("Computing Linear Discriminant Analysis projection") @@ -117,9 +122,9 @@ def plot_embedding(X, title=None): (time() - t0)) -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- # Isomap projection of the digits dataset -print("Computing Isomap embedding") +print("Computing Isomap projection") t0 = time() X_iso = manifold.Isomap(n_neighbors, n_components=2).fit_transform(X) print("Done.") @@ -128,7 +133,7 @@ def plot_embedding(X, title=None): (time() - t0)) -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- # Locally linear embedding of the digits dataset print("Computing LLE embedding") clf = manifold.LocallyLinearEmbedding(n_neighbors, n_components=2, @@ -141,7 +146,7 @@ def plot_embedding(X, title=None): (time() - t0)) -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- # Modified Locally linear embedding of the digits dataset print("Computing modified LLE embedding") clf = manifold.LocallyLinearEmbedding(n_neighbors, n_components=2, @@ -154,7 +159,7 @@ def plot_embedding(X, title=None): (time() - t0)) -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- # HLLE embedding of the digits dataset print("Computing Hessian LLE embedding") clf = manifold.LocallyLinearEmbedding(n_neighbors, n_components=2, @@ -167,7 +172,7 @@ def plot_embedding(X, title=None): (time() - t0)) -#---------------------------------------------------------------------- +# 
---------------------------------------------------------------------- # LTSA embedding of the digits dataset print("Computing LTSA embedding") clf = manifold.LocallyLinearEmbedding(n_neighbors, n_components=2, @@ -179,7 +184,7 @@ def plot_embedding(X, title=None): "Local Tangent Space Alignment of the digits (time %.2fs)" % (time() - t0)) -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- # MDS embedding of the digits dataset print("Computing MDS embedding") clf = manifold.MDS(n_components=2, n_init=1, max_iter=100) @@ -190,7 +195,7 @@ def plot_embedding(X, title=None): "MDS embedding of the digits (time %.2fs)" % (time() - t0)) -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- # Random Trees embedding of the digits dataset print("Computing Totally Random Trees embedding") hasher = ensemble.RandomTreesEmbedding(n_estimators=200, random_state=0, @@ -204,7 +209,7 @@ def plot_embedding(X, title=None): "Random forest embedding of the digits (time %.2fs)" % (time() - t0)) -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- # Spectral embedding of the digits dataset print("Computing Spectral embedding") embedder = manifold.SpectralEmbedding(n_components=2, random_state=0, @@ -216,7 +221,7 @@ def plot_embedding(X, title=None): "Spectral embedding of the digits (time %.2fs)" % (time() - t0)) -#---------------------------------------------------------------------- +# ---------------------------------------------------------------------- # t-SNE embedding of the digits dataset print("Computing t-SNE embedding") tsne = manifold.TSNE(n_components=2, init='pca', random_state=0) @@ -227,4 +232,15 @@ def plot_embedding(X, title=None): "t-SNE embedding of the digits (time %.2fs)" % (time() - t0)) +# ---------------------------------------------------------------------- +# NCA projection of the digits dataset +print("Computing NCA projection") +nca = neighbors.NeighborhoodComponentsAnalysis(n_components=2, random_state=0) +t0 = time() +X_nca = nca.fit_transform(X, y) + +plot_embedding(X_nca, + "NCA embedding of the digits (time %.2fs)" % + (time() - t0)) + plt.show() diff --git a/examples/neighbors/plot_nca_classification.py b/examples/neighbors/plot_nca_classification.py new file mode 100644 index 0000000000000..5536e8eb69e89 --- /dev/null +++ b/examples/neighbors/plot_nca_classification.py @@ -0,0 +1,88 @@ +""" +============================================================================= +Comparing Nearest Neighbors with and without Neighborhood Components Analysis +============================================================================= + +An example comparing nearest neighbors classification with and without +Neighborhood Components Analysis. + +It will plot the class decision boundaries given by a Nearest Neighbors +classifier when using the Euclidean distance on the original features, versus +using the Euclidean distance after the transformation learned by Neighborhood +Components Analysis. The latter aims to find a linear transformation that +maximises the (stochastic) nearest neighbor classification accuracy on the +training set. 
+""" + +# License: BSD 3 clause + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.colors import ListedColormap +from sklearn import datasets +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.neighbors import (KNeighborsClassifier, + NeighborhoodComponentsAnalysis) +from sklearn.pipeline import Pipeline + + +print(__doc__) + +n_neighbors = 1 + +dataset = datasets.load_iris() +X, y = dataset.data, dataset.target + +# we only take two features. We could avoid this ugly +# slicing by using a two-dim dataset +X = X[:, [0, 2]] + +X_train, X_test, y_train, y_test = \ + train_test_split(X, y, stratify=y, test_size=0.7, random_state=42) + +h = .01 # step size in the mesh + +# Create color maps +cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF']) +cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF']) + +names = ['KNN', 'NCA, KNN'] + +classifiers = [Pipeline([('scaler', StandardScaler()), + ('knn', KNeighborsClassifier(n_neighbors=n_neighbors)) + ]), + Pipeline([('scaler', StandardScaler()), + ('nca', NeighborhoodComponentsAnalysis()), + ('knn', KNeighborsClassifier(n_neighbors=n_neighbors)) + ]) + ] + +x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 +y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 +xx, yy = np.meshgrid(np.arange(x_min, x_max, h), + np.arange(y_min, y_max, h)) + +for name, clf in zip(names, classifiers): + + clf.fit(X_train, y_train) + score = clf.score(X_test, y_test) + + # Plot the decision boundary. For that, we will assign a color to each + # point in the mesh [x_min, x_max]x[y_min, y_max]. + Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) + + # Put the result into a color plot + Z = Z.reshape(xx.shape) + plt.figure() + plt.pcolormesh(xx, yy, Z, cmap=cmap_light, alpha=.8) + + # Plot also the training and testing points + plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor='k', s=20) + plt.xlim(xx.min(), xx.max()) + plt.ylim(yy.min(), yy.max()) + plt.title("{} (k = {})".format(name, n_neighbors)) + plt.text(0.9, 0.1, '{:.2f}'.format(score), size=15, + ha='center', va='center', transform=plt.gca().transAxes) + +plt.show() diff --git a/examples/neighbors/plot_nca_dim_reduction.py b/examples/neighbors/plot_nca_dim_reduction.py new file mode 100644 index 0000000000000..ea06b2768e851 --- /dev/null +++ b/examples/neighbors/plot_nca_dim_reduction.py @@ -0,0 +1,101 @@ +""" +============================================================== +Dimensionality Reduction with Neighborhood Components Analysis +============================================================== + +Sample usage of Neighborhood Components Analysis for dimensionality reduction. + +This example compares different (linear) dimensionality reduction methods +applied on the Digits data set. The data set contains images of digits from +0 to 9 with approximately 180 samples of each class. Each image is of +dimension 8x8 = 64, and is reduced to a two-dimensional data point. + +Principal Component Analysis (PCA) applied to this data identifies the +combination of attributes (principal components, or directions in the +feature space) that account for the most variance in the data. Here we +plot the different samples on the 2 first principal components. + +Linear Discriminant Analysis (LDA) tries to identify attributes that +account for the most variance *between classes*. In particular, +LDA, in contrast to PCA, is a supervised method, using known class labels. 
+ +Neighborhood Components Analysis (NCA) tries to find a feature space such +that a stochastic nearest neighbor algorithm will give the best accuracy. +Like LDA, it is a supervised method. + +One can see that NCA enforces a clustering of the data that is visually +meaningful despite the large reduction in dimension. +""" + +# License: BSD 3 clause + +import numpy as np +import matplotlib.pyplot as plt +from sklearn import datasets +from sklearn.model_selection import train_test_split +from sklearn.decomposition import PCA +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from sklearn.neighbors import (KNeighborsClassifier, + NeighborhoodComponentsAnalysis) +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler + +print(__doc__) + +n_neighbors = 3 +random_state = 0 + +# Load Digits dataset +digits = datasets.load_digits() +X, y = digits.data, digits.target + +# Split into train/test +X_train, X_test, y_train, y_test = \ + train_test_split(X, y, test_size=0.5, stratify=y, + random_state=random_state) + +dim = len(X[0]) +n_classes = len(np.unique(y)) + +# Reduce dimension to 2 with PCA +pca = make_pipeline(StandardScaler(), + PCA(n_components=2, random_state=random_state)) + +# Reduce dimension to 2 with LinearDiscriminantAnalysis +lda = make_pipeline(StandardScaler(), + LinearDiscriminantAnalysis(n_components=2)) + +# Reduce dimension to 2 with NeighborhoodComponentAnalysis +nca = make_pipeline(StandardScaler(), + NeighborhoodComponentsAnalysis(n_components=2, + random_state=random_state)) + +# Use a nearest neighbor classifier to evaluate the methods +knn = KNeighborsClassifier(n_neighbors=n_neighbors) + +# Make a list of the methods to be compared +dim_reduction_methods = [('PCA', pca), ('LDA', lda), ('NCA', nca)] + +# plt.figure() +for i, (name, model) in enumerate(dim_reduction_methods): + plt.figure() + # plt.subplot(1, 3, i + 1, aspect=1) + + # Fit the method's model + model.fit(X_train, y_train) + + # Fit a nearest neighbor classifier on the embedded training set + knn.fit(model.transform(X_train), y_train) + + # Compute the nearest neighbor accuracy on the embedded test set + acc_knn = knn.score(model.transform(X_test), y_test) + + # Embed the data set in 2 dimensions using the fitted model + X_embedded = model.transform(X) + + # Plot the projected points and show the evaluation score + plt.scatter(X_embedded[:, 0], X_embedded[:, 1], c=y, s=30, cmap='Set1') + plt.title("{}, KNN (k={})\nTest accuracy = {:.2f}".format(name, + n_neighbors, + acc_knn)) +plt.show() diff --git a/examples/neighbors/plot_nca_illustration.py b/examples/neighbors/plot_nca_illustration.py new file mode 100644 index 0000000000000..bc020fc4a1d40 --- /dev/null +++ b/examples/neighbors/plot_nca_illustration.py @@ -0,0 +1,98 @@ +""" +============================================= +Neighborhood Components Analysis Illustration +============================================= + +An example illustrating the goal of learning a distance metric that maximizes +the nearest neighbors classification accuracy. The example is solely for +illustration purposes. Please refer to the :ref:`User Guide ` for +more information. 
+""" + +# License: BSD 3 clause + +import numpy as np +import matplotlib.pyplot as plt +from sklearn.datasets import make_classification +from sklearn.neighbors import NeighborhoodComponentsAnalysis +from matplotlib import cm +from sklearn.utils.fixes import logsumexp + +print(__doc__) + +n_neighbors = 1 +random_state = 0 + +# Create a tiny data set of 9 samples from 3 classes +X, y = make_classification(n_samples=9, n_features=2, n_informative=2, + n_redundant=0, n_classes=3, n_clusters_per_class=1, + class_sep=1.0, random_state=random_state) + +# Plot the points in the original space +plt.figure() +ax = plt.gca() + +# Draw the graph nodes +for i in range(X.shape[0]): + ax.text(X[i, 0], X[i, 1], str(i), va='center', ha='center') + ax.scatter(X[i, 0], X[i, 1], s=300, c=cm.Set1(y[i]), alpha=0.4) + + +def p_i(X, i): + diff_embedded = X[i] - X + dist_embedded = np.einsum('ij,ij->i', diff_embedded, + diff_embedded) + dist_embedded[i] = np.inf + + # compute exponentiated distances (use the log-sum-exp trick to + # avoid numerical instabilities + exp_dist_embedded = np.exp(-dist_embedded - + logsumexp(-dist_embedded)) + return exp_dist_embedded + + +def relate_point(X, i, ax): + pt_i = X[i] + for j, pt_j in enumerate(X): + thickness = p_i(X, i) + if i != j: + line = ([pt_i[0], pt_j[0]], [pt_i[1], pt_j[1]]) + ax.plot(*line, c=cm.Set1(y[j]), + linewidth=5*thickness[j]) + + +# we consider only point 3 +i = 3 + +# Plot bonds linked to sample i in the original space +relate_point(X, i, ax) +ax.set_title("Original points") +ax.axes.get_xaxis().set_visible(False) +ax.axes.get_yaxis().set_visible(False) +ax.axis('equal') + +# Learn an embedding with NeighborhoodComponentsAnalysis +nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=random_state) +nca = nca.fit(X, y) + +# Plot the points after transformation with NeighborhoodComponentsAnalysis +plt.figure() +ax2 = plt.gca() + +# Get the embedding and find the new nearest neighbors +X_embedded = nca.transform(X) + +relate_point(X_embedded, i, ax2) + +for i in range(len(X)): + ax2.text(X_embedded[i, 0], X_embedded[i, 1], str(i), + va='center', ha='center') + ax2.scatter(X_embedded[i, 0], X_embedded[i, 1], s=300, c=cm.Set1(y[i]), + alpha=0.4) + +# Make axes equal so that boundaries are displayed correctly as circles +ax2.set_title("NCA embedding") +ax2.axes.get_xaxis().set_visible(False) +ax2.axes.get_yaxis().set_visible(False) +ax2.axis('equal') +plt.show() diff --git a/sklearn/neighbors/__init__.py b/sklearn/neighbors/__init__.py index 51116b3f470e6..550cab3c01bca 100644 --- a/sklearn/neighbors/__init__.py +++ b/sklearn/neighbors/__init__.py @@ -13,6 +13,7 @@ from .nearest_centroid import NearestCentroid from .kde import KernelDensity from .lof import LocalOutlierFactor +from .nca import NeighborhoodComponentsAnalysis from .base import VALID_METRICS, VALID_METRICS_SPARSE __all__ = ['BallTree', @@ -28,5 +29,6 @@ 'radius_neighbors_graph', 'KernelDensity', 'LocalOutlierFactor', + 'NeighborhoodComponentsAnalysis', 'VALID_METRICS', 'VALID_METRICS_SPARSE'] diff --git a/sklearn/neighbors/nca.py b/sklearn/neighbors/nca.py new file mode 100644 index 0000000000000..38f62886807f2 --- /dev/null +++ b/sklearn/neighbors/nca.py @@ -0,0 +1,517 @@ +# coding: utf-8 +""" +Neighborhood Component Analysis +""" + +# Authors: William de Vazelhes +# John Chiotellis +# License: BSD 3 clause + +from __future__ import print_function + +from warnings import warn +import numpy as np +import sys +import time +from scipy.optimize import minimize +from ..utils.extmath import 
softmax +from ..metrics import pairwise_distances +from ..base import BaseEstimator, TransformerMixin +from ..preprocessing import LabelEncoder +from ..decomposition import PCA +from ..utils.multiclass import check_classification_targets +from ..utils.random import check_random_state +from ..utils.validation import (check_is_fitted, check_array, check_X_y, + check_scalar) +from ..externals.six import integer_types +from ..exceptions import ConvergenceWarning + + +class NeighborhoodComponentsAnalysis(BaseEstimator, TransformerMixin): + """Neighborhood Components Analysis + + Neighborhood Component Analysis (NCA) is a machine learning algorithm for + metric learning. It learns a linear transformation in a supervised fashion + to improve the classification accuracy of a stochastic nearest neighbors + rule in the transformed space. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + n_components : int, optional (default=None) + Preferred dimensionality of the projected space. + If None it will be set to ``n_features``. + + init : string or numpy array, optional (default='auto') + Initialization of the linear transformation. Possible options are + 'auto', 'pca', 'lda', 'identity', 'random', and a numpy array of shape + (n_features_a, n_features_b). + + 'auto' + Depending on ``n_components``, the most reasonable initialization + will be chosen. If ``n_components <= n_classes`` we use 'lda', as + it uses labels information. If not, but + ``n_components < min(n_features, n_samples)``, we use 'pca', as + it projects data in meaningful directions (those of higher + variance). Otherwise, we just use 'identity'. + + 'pca' + ``n_components`` principal components of the inputs passed + to :meth:`fit` will be used to initialize the transformation. + (See `decomposition.PCA`) + + 'lda' + ``min(n_components, n_classes)`` most discriminative + components of the inputs passed to :meth:`fit` will be used to + initialize the transformation. (If ``n_components > n_classes``, + the rest of the components will be zero.) (See + `discriminant_analysis.LinearDiscriminantAnalysis`) + + 'identity' + If ``n_components`` is strictly smaller than the + dimensionality of the inputs passed to :meth:`fit`, the identity + matrix will be truncated to the first ``n_components`` rows. + + 'random' + The initial transformation will be a random array of shape + `(n_components, n_features)`. Each value is sampled from the + standard normal distribution. + + numpy array + n_features_b must match the dimensionality of the inputs passed to + :meth:`fit` and n_features_a must be less than or equal to that. + If ``n_components`` is not None, n_features_a must match it. + + warm_start : bool, optional, (default=False) + If True and :meth:`fit` has been called before, the solution of the + previous call to :meth:`fit` is used as the initial linear + transformation (``n_components`` and ``init`` will be ignored). + + max_iter : int, optional (default=50) + Maximum number of iterations in the optimization. + + tol : float, optional (default=1e-5) + Convergence tolerance for the optimization. + + callback : callable, optional (default=None) + If not None, this function is called after every iteration of the + optimizer, taking as arguments the current solution (flattened + transformation matrix) and the number of iterations. This might be + useful in case one wants to examine or store the transformation + found after each iteration. + + verbose : int, optional (default=0) + If 0, no progress messages will be printed. 
+ If 1, progress messages will be printed to stdout. + If > 1, progress messages will be printed and the ``disp`` + parameter of :func:`scipy.optimize.minimize` will be set to + ``verbose - 2``. + + random_state : int or numpy.RandomState or None, optional (default=None) + A pseudo random number generator object or a seed for it if int. If + ``init='random'``, ``random_state`` is used to initialize the random + transformation. If ``init='pca'``, ``random_state`` is passed as an + argument to PCA when initializing the transformation. + + Attributes + ---------- + components_ : array, shape (n_components, n_features) + The linear transformation learned during fitting. + + n_iter_ : int + Counts the number of iterations performed by the optimizer. + + Examples + -------- + >>> from sklearn.neighbors.nca import NeighborhoodComponentsAnalysis + >>> from sklearn.neighbors import KNeighborsClassifier + >>> from sklearn.datasets import load_iris + >>> from sklearn.model_selection import train_test_split + >>> X, y = load_iris(return_X_y=True) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, + ... stratify=y, test_size=0.7, random_state=42) + >>> nca = NeighborhoodComponentsAnalysis(random_state=42) + >>> nca.fit(X_train, y_train) # doctest: +ELLIPSIS + NeighborhoodComponentsAnalysis(...) + >>> knn = KNeighborsClassifier(n_neighbors=3) + >>> knn.fit(X_train, y_train) # doctest: +ELLIPSIS + KNeighborsClassifier(...) + >>> print(knn.score(X_test, y_test)) # doctest: +ELLIPSIS + 0.933333... + >>> knn.fit(nca.transform(X_train), y_train) # doctest: +ELLIPSIS + KNeighborsClassifier(...) + >>> print(knn.score(nca.transform(X_test), y_test)) # doctest: +ELLIPSIS + 0.961904... + + References + ---------- + .. [1] J. Goldberger, G. Hinton, S. Roweis, R. Salakhutdinov. + "Neighbourhood Components Analysis". Advances in Neural Information + Processing Systems. 17, 513-520, 2005. + http://www.cs.nyu.edu/~roweis/papers/ncanips.pdf + + .. [2] Wikipedia entry on Neighborhood Components Analysis + https://en.wikipedia.org/wiki/Neighbourhood_components_analysis + + """ + + def __init__(self, n_components=None, init='auto', warm_start=False, + max_iter=50, tol=1e-5, callback=None, verbose=0, + random_state=None): + self.n_components = n_components + self.init = init + self.warm_start = warm_start + self.max_iter = max_iter + self.tol = tol + self.callback = callback + self.verbose = verbose + self.random_state = random_state + + def fit(self, X, y): + """Fit the model according to the given training data. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + The training samples. + + y : array-like, shape (n_samples,) + The corresponding training labels. + + Returns + ------- + self : object + returns a trained NeighborhoodComponentsAnalysis model. 
+ """ + + # Verify inputs X and y and NCA parameters, and transform a copy if + # needed + X, y, init = self._validate_params(X, y) + + # Initialize the random generator + self.random_state_ = check_random_state(self.random_state) + + # Measure the total training time + t_train = time.time() + + # Compute a mask that stays fixed during optimization: + same_class_mask = y[:, np.newaxis] == y[np.newaxis, :] + # (n_samples, n_samples) + + # Initialize the transformation + transformation = self._initialize(X, y, init) + + # Create a dictionary of parameters to be passed to the optimizer + disp = self.verbose - 2 if self.verbose > 1 else -1 + optimizer_params = {'method': 'L-BFGS-B', + 'fun': self._loss_grad_lbfgs, + 'args': (X, same_class_mask, -1.0), + 'jac': True, + 'x0': transformation, + 'tol': self.tol, + 'options': dict(maxiter=self.max_iter, disp=disp), + 'callback': self._callback + } + + # Call the optimizer + self.n_iter_ = 0 + opt_result = minimize(**optimizer_params) + + # Reshape the solution found by the optimizer + self.components_ = opt_result.x.reshape(-1, X.shape[1]) + + # Stop timer + t_train = time.time() - t_train + if self.verbose: + cls_name = self.__class__.__name__ + + # Warn the user if the algorithm did not converge + if not opt_result.success: + warn('[{}] NCA did not converge: {}'.format( + cls_name, opt_result.message), + ConvergenceWarning) + + print('[{}] Training took {:8.2f}s.'.format(cls_name, t_train)) + + return self + + def transform(self, X): + """Applies the learned transformation to the given data. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + Data samples. + + Returns + ------- + X_embedded: array, shape (n_samples, n_components) + The data samples transformed. + + Raises + ------ + NotFittedError + If :meth:`fit` has not been called before. + """ + + check_is_fitted(self, ['components_']) + X = check_array(X) + + return np.dot(X, self.components_.T) + + def _validate_params(self, X, y): + """Validate parameters as soon as :meth:`fit` is called. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + The training samples. + + y : array-like, shape (n_samples,) + The corresponding training labels. + + Returns + ------- + X : array, shape (n_samples, n_features) + The validated training samples. + + y : array, shape (n_samples,) + The validated training labels, encoded to be integers in + the range(0, n_classes). + + init : string or numpy array of shape (n_features_a, n_features_b) + The validated initialization of the linear transformation. + + Raises + ------- + TypeError + If a parameter is not an instance of the desired type. + + ValueError + If a parameter's value violates its legal value range or if the + combination of two or more given parameters is incompatible. + """ + + # Validate the inputs X and y, and converts y to numerical classes. + X, y = check_X_y(X, y, ensure_min_samples=2) + check_classification_targets(y) + y = LabelEncoder().fit_transform(y) + + # Check the preferred dimensionality of the projected space + if self.n_components is not None: + check_scalar(self.n_components, 'n_components', + integer_types, 1) + + if self.n_components > X.shape[1]: + raise ValueError('The preferred dimensionality of the ' + 'projected space `n_components` ({}) cannot ' + 'be greater than the given data ' + 'dimensionality ({})!' 
+ .format(self.n_components, X.shape[1])) + + # If warm_start is enabled, check that the inputs are consistent + check_scalar(self.warm_start, 'warm_start', bool) + if self.warm_start and hasattr(self, 'components_'): + if self.components_.shape[1] != X.shape[1]: + raise ValueError('The new inputs dimensionality ({}) does not ' + 'match the input dimensionality of the ' + 'previously learned transformation ({}).' + .format(X.shape[1], + self.components_.shape[1])) + + check_scalar(self.max_iter, 'max_iter', integer_types, 1) + check_scalar(self.tol, 'tol', float, 0.) + check_scalar(self.verbose, 'verbose', integer_types, 0) + + if self.callback is not None: + if not callable(self.callback): + raise ValueError('`callback` is not callable.') + + # Check how the linear transformation should be initialized + init = self.init + + if isinstance(init, np.ndarray): + init = check_array(init) + + # Assert that init.shape[1] = X.shape[1] + if init.shape[1] != X.shape[1]: + raise ValueError( + 'The input dimensionality ({}) of the given ' + 'linear transformation `init` must match the ' + 'dimensionality of the given inputs `X` ({}).' + .format(init.shape[1], X.shape[1])) + + # Assert that init.shape[0] <= init.shape[1] + if init.shape[0] > init.shape[1]: + raise ValueError( + 'The output dimensionality ({}) of the given ' + 'linear transformation `init` cannot be ' + 'greater than its input dimensionality ({}).' + .format(init.shape[0], init.shape[1])) + + if self.n_components is not None: + # Assert that self.n_components = init.shape[0] + if self.n_components != init.shape[0]: + raise ValueError('The preferred dimensionality of the ' + 'projected space `n_components` ({}) does' + ' not match the output dimensionality of ' + 'the given linear transformation ' + '`init` ({})!' + .format(self.n_components, + init.shape[0])) + elif init in ['auto', 'pca', 'lda', 'identity', 'random']: + pass + else: + raise ValueError( + "`init` must be 'auto', 'pca', 'lda', 'identity', 'random' " + "or a numpy array of shape (n_components, n_features).") + + return X, y, init + + def _initialize(self, X, y, init): + """Initialize the transformation. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + The training samples. + + y : array-like, shape (n_samples,) + The training labels. + + init : string or numpy array of shape (n_features_a, n_features_b) + The validated initialization of the linear transformation. + + Returns + ------- + transformation : array, shape (n_components, n_features) + The initialized linear transformation. + + """ + + transformation = init + if self.warm_start and hasattr(self, 'components_'): + transformation = self.components_ + elif isinstance(init, np.ndarray): + pass + else: + n_samples, n_features = X.shape + n_components = self.n_components or n_features + if init == 'auto': + n_classes = len(np.unique(y)) + if n_components <= min(n_features, n_classes - 1): + init = 'lda' + elif n_components < min(n_features, n_samples): + init = 'pca' + else: + init = 'identity' + if init == 'identity': + transformation = np.eye(n_components, X.shape[1]) + elif init == 'random': + transformation = self.random_state_.randn(n_components, + X.shape[1]) + elif init in {'pca', 'lda'}: + init_time = time.time() + if init == 'pca': + pca = PCA(n_components=n_components, + random_state=self.random_state_) + if self.verbose: + print('Finding principal components... 
', end='') + sys.stdout.flush() + pca.fit(X) + transformation = pca.components_ + elif init == 'lda': + from ..discriminant_analysis import ( + LinearDiscriminantAnalysis) + lda = LinearDiscriminantAnalysis(n_components=n_components) + if self.verbose: + print('Finding most discriminative components... ', + end='') + sys.stdout.flush() + lda.fit(X, y) + transformation = lda.scalings_.T[:n_components] + if self.verbose: + print('done in {:5.2f}s'.format(time.time() - init_time)) + return transformation + + def _callback(self, transformation): + """Called after each iteration of the optimizer. + + Parameters + ---------- + transformation : array, shape=(n_components * n_features,) + The solution computed by the optimizer in this iteration. + """ + if self.callback is not None: + self.callback(transformation, self.n_iter_) + + self.n_iter_ += 1 + + def _loss_grad_lbfgs(self, transformation, X, same_class_mask, sign=1.0): + """Compute the loss and the loss gradient w.r.t. ``transformation``. + + Parameters + ---------- + transformation : array, shape (n_components * n_features,) + The raveled linear transformation on which to compute loss and + evaluate gradient. + + X : array, shape (n_samples, n_features) + The training samples. + + same_class_mask : array, shape (n_samples, n_samples) + A mask where ``mask[i, j] == 1`` if ``X[i]`` and ``X[j]`` belong + to the same class, and ``0`` otherwise. + + Returns + ------- + loss : float + The loss computed for the given transformation. + + gradient : array, shape (n_components * n_features,) + The new (flattened) gradient of the loss. + """ + + if self.n_iter_ == 0: + self.n_iter_ += 1 + if self.verbose: + header_fields = ['Iteration', 'Objective Value', 'Time(s)'] + header_fmt = '{:>10} {:>20} {:>10}' + header = header_fmt.format(*header_fields) + cls_name = self.__class__.__name__ + print('[{}]'.format(cls_name)) + print('[{}] {}\n[{}] {}'.format(cls_name, header, + cls_name, '-' * len(header))) + + t_funcall = time.time() + + transformation = transformation.reshape(-1, X.shape[1]) + X_embedded = np.dot(X, transformation.T) # (n_samples, n_components) + + # Compute softmax distances + p_ij = pairwise_distances(X_embedded, squared=True) + np.fill_diagonal(p_ij, np.inf) + p_ij = softmax(-p_ij) # (n_samples, n_samples) + + # Compute loss + masked_p_ij = p_ij * same_class_mask + p = np.sum(masked_p_ij, axis=1, keepdims=True) # (n_samples, 1) + loss = np.sum(p) + + # Compute gradient of loss w.r.t. 
`transform` + weighted_p_ij = masked_p_ij - p_ij * p + weighted_p_ij_sym = weighted_p_ij + weighted_p_ij.T + np.fill_diagonal(weighted_p_ij_sym, -weighted_p_ij.sum(axis=0)) + gradient = 2 * X_embedded.T.dot(weighted_p_ij_sym).dot(X) + # time complexity of the gradient: O(n_components x n_samples x ( + # n_samples + n_features)) + + if self.verbose: + t_funcall = time.time() - t_funcall + values_fmt = '[{}] {:>10} {:>20.6e} {:>10.2f}' + print(values_fmt.format(self.__class__.__name__, self.n_iter_, + loss, t_funcall)) + sys.stdout.flush() + + return sign * loss, sign * gradient.ravel() diff --git a/sklearn/neighbors/tests/test_nca.py b/sklearn/neighbors/tests/test_nca.py new file mode 100644 index 0000000000000..2397af5bc0ed1 --- /dev/null +++ b/sklearn/neighbors/tests/test_nca.py @@ -0,0 +1,520 @@ +# coding: utf-8 +""" +Testing for Neighborhood Component Analysis module (sklearn.neighbors.nca) +""" + +# Authors: William de Vazelhes +# John Chiotellis +# License: BSD 3 clause + +import pytest +import re +import numpy as np +from numpy.testing import assert_array_equal, assert_array_almost_equal +from scipy.optimize import check_grad +from sklearn import clone +from sklearn.exceptions import ConvergenceWarning +from sklearn.utils import check_random_state +from sklearn.utils.testing import (assert_raises, assert_equal, + assert_raise_message, assert_warns_message) +from sklearn.datasets import load_iris, make_classification, make_blobs +from sklearn.neighbors.nca import NeighborhoodComponentsAnalysis +from sklearn.metrics import pairwise_distances + + +rng = check_random_state(0) +# load and shuffle iris dataset +iris = load_iris() +perm = rng.permutation(iris.target.size) +iris_data = iris.data[perm] +iris_target = iris.target[perm] +EPS = np.finfo(float).eps + + +def test_simple_example(): + """Test on a simple example. + + Puts four points in the input space where the opposite labels points are + next to each other. After transform the samples from the same class + should be next to each other. + + """ + X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) + y = np.array([1, 0, 1, 0]) + nca = NeighborhoodComponentsAnalysis(n_components=2, init='identity', + random_state=42) + nca.fit(X, y) + X_t = nca.transform(X) + assert_array_equal(pairwise_distances(X_t).argsort()[:, 1], + np.array([2, 3, 0, 1])) + + +def test_toy_example_collapse_points(): + """Test on a toy example of three points that should collapse + + We build a simple example: two points from the same class and a point from + a different class in the middle of them. On this simple example, the new + (transformed) points should all collapse into one single point. Indeed, the + objective is 2/(1 + exp(d/2)), with d the euclidean distance between the + two samples from the same class. This is maximized for d=0 (because d>=0), + with an objective equal to 1 (loss=-1.). 
+ + """ + rng = np.random.RandomState(42) + input_dim = 5 + two_points = rng.randn(2, input_dim) + X = np.vstack([two_points, two_points.mean(axis=0)[np.newaxis, :]]) + y = [0, 0, 1] + + class LossStorer: + + def __init__(self, X, y): + self.loss = np.inf # initialize the loss to very high + # Initialize a fake NCA and variables needed to compute the loss: + self.fake_nca = NeighborhoodComponentsAnalysis() + self.fake_nca.n_iter_ = np.inf + self.X, y, _ = self.fake_nca._validate_params(X, y) + self.same_class_mask = y[:, np.newaxis] == y[np.newaxis, :] + + def callback(self, transformation, n_iter): + """Stores the last value of the loss function""" + self.loss, _ = self.fake_nca._loss_grad_lbfgs(transformation, + self.X, + self.same_class_mask, + -1.0) + + loss_storer = LossStorer(X, y) + nca = NeighborhoodComponentsAnalysis(random_state=42, + callback=loss_storer.callback) + X_t = nca.fit_transform(X, y) + print(X_t) + # test that points are collapsed into one point + assert_array_almost_equal(X_t - X_t[0], 0.) + assert abs(loss_storer.loss + 1) < 1e-10 + + +def test_finite_differences(): + """Test gradient of loss function + + Assert that the gradient is almost equal to its finite differences + approximation. + """ + # Initialize the transformation `M`, as well as `X` and `y` and `NCA` + rng = np.random.RandomState(42) + X, y = make_classification() + M = rng.randn(rng.randint(1, X.shape[1] + 1), + X.shape[1]) + nca = NeighborhoodComponentsAnalysis() + nca.n_iter_ = 0 + mask = y[:, np.newaxis] == y[np.newaxis, :] + + def fun(M): + return nca._loss_grad_lbfgs(M, X, mask)[0] + + def grad(M): + return nca._loss_grad_lbfgs(M, X, mask)[1] + + # compute relative error + rel_diff = check_grad(fun, grad, M.ravel()) / np.linalg.norm(grad(M)) + np.testing.assert_almost_equal(rel_diff, 0., decimal=5) + + +def test_params_validation(): + # Test that invalid parameters raise value error + X = np.arange(12).reshape(4, 3) + y = [1, 1, 2, 2] + NCA = NeighborhoodComponentsAnalysis + rng = np.random.RandomState(42) + + # TypeError + assert_raises(TypeError, NCA(max_iter='21').fit, X, y) + assert_raises(TypeError, NCA(verbose='true').fit, X, y) + assert_raises(TypeError, NCA(tol=1).fit, X, y) + assert_raises(TypeError, NCA(n_components='invalid').fit, X, y) + assert_raises(TypeError, NCA(warm_start=1).fit, X, y) + + # ValueError + assert_raise_message(ValueError, + "`init` must be 'auto', 'pca', 'lda', 'identity', " + "'random' or a numpy array of shape " + "(n_components, n_features).", + NCA(init=1).fit, X, y) + assert_raise_message(ValueError, + '`max_iter`= -1, must be >= 1.', + NCA(max_iter=-1).fit, X, y) + + init = rng.rand(5, 3) + assert_raise_message(ValueError, + 'The output dimensionality ({}) of the given linear ' + 'transformation `init` cannot be greater than its ' + 'input dimensionality ({}).' + .format(init.shape[0], init.shape[1]), + NCA(init=init).fit, X, y) + + n_components = 10 + assert_raise_message(ValueError, + 'The preferred dimensionality of the ' + 'projected space `n_components` ({}) cannot ' + 'be greater than the given data ' + 'dimensionality ({})!' 
+ .format(n_components, X.shape[1]), + NCA(n_components=n_components).fit, X, y) + + +def test_transformation_dimensions(): + X = np.arange(12).reshape(4, 3) + y = [1, 1, 2, 2] + + # Fail if transformation input dimension does not match inputs dimensions + transformation = np.array([[1, 2], [3, 4]]) + assert_raises(ValueError, + NeighborhoodComponentsAnalysis(init=transformation).fit, + X, y) + + # Fail if transformation output dimension is larger than + # transformation input dimension + transformation = np.array([[1, 2], [3, 4], [5, 6]]) + # len(transformation) > len(transformation[0]) + assert_raises(ValueError, + NeighborhoodComponentsAnalysis(init=transformation).fit, + X, y) + + # Pass otherwise + transformation = np.arange(9).reshape(3, 3) + NeighborhoodComponentsAnalysis(init=transformation).fit(X, y) + + +def test_n_components(): + rng = np.random.RandomState(42) + X = np.arange(12).reshape(4, 3) + y = [1, 1, 2, 2] + + init = rng.rand(X.shape[1] - 1, 3) + + # n_components = X.shape[1] != transformation.shape[0] + n_components = X.shape[1] + nca = NeighborhoodComponentsAnalysis(init=init, n_components=n_components) + assert_raise_message(ValueError, + 'The preferred dimensionality of the ' + 'projected space `n_components` ({}) does not match ' + 'the output dimensionality of the given ' + 'linear transformation `init` ({})!' + .format(n_components, init.shape[0]), + nca.fit, X, y) + + # n_components > X.shape[1] + n_components = X.shape[1] + 2 + nca = NeighborhoodComponentsAnalysis(init=init, n_components=n_components) + assert_raise_message(ValueError, + 'The preferred dimensionality of the ' + 'projected space `n_components` ({}) cannot ' + 'be greater than the given data ' + 'dimensionality ({})!' + .format(n_components, X.shape[1]), + nca.fit, X, y) + + # n_components < X.shape[1] + nca = NeighborhoodComponentsAnalysis(n_components=2, init='identity') + nca.fit(X, y) + + +def test_init_transformation(): + rng = np.random.RandomState(42) + X, y = make_blobs(n_samples=30, centers=6, n_features=5, random_state=0) + + # Start learning from scratch + nca = NeighborhoodComponentsAnalysis(init='identity') + nca.fit(X, y) + + # Initialize with random + nca_random = NeighborhoodComponentsAnalysis(init='random') + nca_random.fit(X, y) + + # Initialize with auto + nca_auto = NeighborhoodComponentsAnalysis(init='auto') + nca_auto.fit(X, y) + + # Initialize with PCA + nca_pca = NeighborhoodComponentsAnalysis(init='pca') + nca_pca.fit(X, y) + + # Initialize with LDA + nca_lda = NeighborhoodComponentsAnalysis(init='lda') + nca_lda.fit(X, y) + + init = rng.rand(X.shape[1], X.shape[1]) + nca = NeighborhoodComponentsAnalysis(init=init) + nca.fit(X, y) + + # init.shape[1] must match X.shape[1] + init = rng.rand(X.shape[1], X.shape[1] + 1) + nca = NeighborhoodComponentsAnalysis(init=init) + assert_raise_message(ValueError, + 'The input dimensionality ({}) of the given ' + 'linear transformation `init` must match the ' + 'dimensionality of the given inputs `X` ({}).' + .format(init.shape[1], X.shape[1]), + nca.fit, X, y) + + # init.shape[0] must be <= init.shape[1] + init = rng.rand(X.shape[1] + 1, X.shape[1]) + nca = NeighborhoodComponentsAnalysis(init=init) + assert_raise_message(ValueError, + 'The output dimensionality ({}) of the given ' + 'linear transformation `init` cannot be ' + 'greater than its input dimensionality ({}).' 
+ .format(init.shape[0], init.shape[1]), + nca.fit, X, y) + + # init.shape[0] must match n_components + init = rng.rand(X.shape[1], X.shape[1]) + n_components = X.shape[1] - 2 + nca = NeighborhoodComponentsAnalysis(init=init, n_components=n_components) + assert_raise_message(ValueError, + 'The preferred dimensionality of the ' + 'projected space `n_components` ({}) does not match ' + 'the output dimensionality of the given ' + 'linear transformation `init` ({})!' + .format(n_components, init.shape[0]), + nca.fit, X, y) + + +@pytest.mark.parametrize('n_samples', [3, 5, 7, 11]) +@pytest.mark.parametrize('n_features', [3, 5, 7, 11]) +@pytest.mark.parametrize('n_classes', [5, 7, 11]) +@pytest.mark.parametrize('n_components', [3, 5, 7, 11]) +def test_auto_init(n_samples, n_features, n_classes, n_components): + # Test that auto choose the init as expected with every configuration + # of order of n_samples, n_features, n_classes and n_components. + rng = np.random.RandomState(42) + nca_base = NeighborhoodComponentsAnalysis(init='auto', + n_components=n_components, + max_iter=1, + random_state=rng) + if n_classes >= n_samples: + pass + # n_classes > n_samples is impossible, and n_classes == n_samples + # throws an error from lda but is an absurd case + else: + X = rng.randn(n_samples, n_features) + y = np.tile(range(n_classes), n_samples // n_classes + 1)[:n_samples] + if n_components > n_features: + # this would return a ValueError, which is already tested in + # test_params_validation + pass + else: + nca = clone(nca_base) + nca.fit(X, y) + if n_components <= min(n_classes - 1, n_features): + nca_other = clone(nca_base).set_params(init='lda') + elif n_components < min(n_features, n_samples): + nca_other = clone(nca_base).set_params(init='pca') + else: + nca_other = clone(nca_base).set_params(init='identity') + nca_other.fit(X, y) + assert_array_almost_equal(nca.components_, nca_other.components_) + + +def test_warm_start_validation(): + X, y = make_classification(n_samples=30, n_features=5, n_classes=4, + n_redundant=0, n_informative=5, random_state=0) + + nca = NeighborhoodComponentsAnalysis(warm_start=True, max_iter=5) + nca.fit(X, y) + + X_less_features, y = make_classification(n_samples=30, n_features=4, + n_classes=4, n_redundant=0, + n_informative=4, random_state=0) + assert_raise_message(ValueError, + 'The new inputs dimensionality ({}) does not ' + 'match the input dimensionality of the ' + 'previously learned transformation ({}).' + .format(X_less_features.shape[1], + nca.components_.shape[1]), + nca.fit, X_less_features, y) + + +def test_warm_start_effectiveness(): + # A 1-iteration second fit on same data should give almost same result + # with warm starting, and quite different result without warm starting. 
+ + nca_warm = NeighborhoodComponentsAnalysis(warm_start=True, random_state=0) + nca_warm.fit(iris_data, iris_target) + transformation_warm = nca_warm.components_ + nca_warm.max_iter = 1 + nca_warm.fit(iris_data, iris_target) + transformation_warm_plus_one = nca_warm.components_ + + nca_cold = NeighborhoodComponentsAnalysis(warm_start=False, random_state=0) + nca_cold.fit(iris_data, iris_target) + transformation_cold = nca_cold.components_ + nca_cold.max_iter = 1 + nca_cold.fit(iris_data, iris_target) + transformation_cold_plus_one = nca_cold.components_ + + diff_warm = np.sum(np.abs(transformation_warm_plus_one - + transformation_warm)) + diff_cold = np.sum(np.abs(transformation_cold_plus_one - + transformation_cold)) + assert diff_warm < 3.0, ("Transformer changed significantly after one " + "iteration even though it was warm-started.") + + assert diff_cold > diff_warm, ("Cold-started transformer changed less " + "significantly than warm-started " + "transformer after one iteration.") + + +@pytest.mark.parametrize('init_name', ['pca', 'lda', 'identity', 'random', + 'precomputed']) +def test_verbose(init_name, capsys): + # assert there is proper output when verbose = 1, for every initialization + # except auto because auto will call one of the others + rng = np.random.RandomState(42) + X, y = make_blobs(n_samples=30, centers=6, n_features=5, random_state=0) + regexp_init = r'... done in \ *\d+\.\d{2}s' + msgs = {'pca': "Finding principal components" + regexp_init, + 'lda': "Finding most discriminative components" + regexp_init} + if init_name == 'precomputed': + init = rng.randn(X.shape[1], X.shape[1]) + else: + init = init_name + nca = NeighborhoodComponentsAnalysis(verbose=1, init=init) + nca.fit(X, y) + out, _ = capsys.readouterr() + + # check output + lines = re.split('\n+', out) + # if pca or lda init, an additional line is printed, so we test + # it and remove it to test the rest equally among initializations + if init_name in ['pca', 'lda']: + assert re.match(msgs[init_name], lines[0]) + lines = lines[1:] + assert lines[0] == '[NeighborhoodComponentsAnalysis]' + header = '{:>10} {:>20} {:>10}'.format('Iteration', 'Objective Value', + 'Time(s)') + assert lines[1] == '[NeighborhoodComponentsAnalysis] {}'.format(header) + assert lines[2] == ('[NeighborhoodComponentsAnalysis] {}' + .format('-' * len(header))) + for line in lines[3:-2]: + # The following regex will match for instance: + # '[NeighborhoodComponentsAnalysis] 0 6.988936e+01 0.01' + assert re.match(r'\[NeighborhoodComponentsAnalysis\] *\d+ *\d\.\d{6}e' + r'[+|-]\d+\ *\d+\.\d{2}', line) + assert re.match(r'\[NeighborhoodComponentsAnalysis\] Training took\ *' + r'\d+\.\d{2}s\.', lines[-2]) + assert lines[-1] == '' + + +def test_no_verbose(capsys): + # assert by default there is no output (verbose=0) + nca = NeighborhoodComponentsAnalysis() + nca.fit(iris_data, iris_target) + out, _ = capsys.readouterr() + # check output + assert(out == '') + + +def test_singleton_class(): + X = iris_data + y = iris_target + + # one singleton class + singleton_class = 1 + ind_singleton, = np.where(y == singleton_class) + y[ind_singleton] = 2 + y[ind_singleton[0]] = singleton_class + + nca = NeighborhoodComponentsAnalysis(max_iter=30) + nca.fit(X, y) + + # One non-singleton class + ind_1, = np.where(y == 1) + ind_2, = np.where(y == 2) + y[ind_1] = 0 + y[ind_1[0]] = 1 + y[ind_2] = 0 + y[ind_2[0]] = 2 + + nca = NeighborhoodComponentsAnalysis(max_iter=30) + nca.fit(X, y) + + # Only singleton classes + ind_0, = np.where(y == 0) + ind_1, = np.where(y 
== 1) + ind_2, = np.where(y == 2) + X = X[[ind_0[0], ind_1[0], ind_2[0]]] + y = y[[ind_0[0], ind_1[0], ind_2[0]]] + + nca = NeighborhoodComponentsAnalysis(init='identity', max_iter=30) + nca.fit(X, y) + assert_array_equal(X, nca.transform(X)) + + +def test_one_class(): + X = iris_data[iris_target == 0] + y = iris_target[iris_target == 0] + + nca = NeighborhoodComponentsAnalysis(max_iter=30, + n_components=X.shape[1], + init='identity') + nca.fit(X, y) + assert_array_equal(X, nca.transform(X)) + + +def test_callback(capsys): + X = iris_data + y = iris_target + + nca = NeighborhoodComponentsAnalysis(callback='my_cb') + assert_raises(ValueError, nca.fit, X, y) + + max_iter = 10 + + def my_cb(transformation, n_iter): + assert transformation.shape == (iris_data.shape[1]**2,) + rem_iter = max_iter - n_iter + print('{} iterations remaining...'.format(rem_iter)) + + # assert that my_cb is called + nca = NeighborhoodComponentsAnalysis(max_iter=max_iter, + callback=my_cb, verbose=1) + nca.fit(iris_data, iris_target) + out, _ = capsys.readouterr() + + # check output + assert('{} iterations remaining...'.format(max_iter - 1) in out) + + +def test_expected_transformation_shape(): + """Test that the transformation has the expected shape.""" + X = iris_data + y = iris_target + + class TransformationStorer: + + def __init__(self, X, y): + # Initialize a fake NCA and variables needed to call the loss + # function: + self.fake_nca = NeighborhoodComponentsAnalysis() + self.fake_nca.n_iter_ = np.inf + self.X, y, _ = self.fake_nca._validate_params(X, y) + self.same_class_mask = y[:, np.newaxis] == y[np.newaxis, :] + + def callback(self, transformation, n_iter): + """Stores the last value of the transformation taken as input by + the optimizer""" + self.transformation = transformation + + transformation_storer = TransformationStorer(X, y) + cb = transformation_storer.callback + nca = NeighborhoodComponentsAnalysis(max_iter=5, callback=cb) + nca.fit(X, y) + assert_equal(transformation_storer.transformation.size, X.shape[1]**2) + + +def test_convergence_warning(): + nca = NeighborhoodComponentsAnalysis(max_iter=2, verbose=1) + cls_name = nca.__class__.__name__ + assert_warns_message(ConvergenceWarning, + '[{}] NCA did not converge'.format(cls_name), + nca.fit, iris_data, iris_target) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index ebbdbcaa2b702..6150e017e3e28 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -19,7 +19,7 @@ assert_all_finite, check_random_state, column_or_1d, check_array, check_consistent_length, check_X_y, indexable, - check_symmetric) + check_symmetric, check_scalar) from .. 
import get_config
@@ -60,7 +60,7 @@ class Parallel(_joblib.Parallel):
            "check_random_state",
            "compute_class_weight", "compute_sample_weight",
            "column_or_1d", "safe_indexing",
-           "check_consistent_length", "check_X_y", 'indexable',
+           "check_consistent_length", "check_X_y", "check_scalar", 'indexable',
            "check_symmetric", "indices_to_mask", "deprecated",
            "cpu_count", "Parallel", "Memory", "delayed", "parallel_backend",
            "register_parallel_backend", "hash", "effective_n_jobs",
diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py
index e9d766ed44094..e2bc9dd8a58b2 100644
--- a/sklearn/utils/tests/test_validation.py
+++ b/sklearn/utils/tests/test_validation.py
@@ -39,8 +39,8 @@
     assert_all_finite,
     check_memory,
     check_non_negative,
-    _num_samples
-)
+    _num_samples,
+    check_scalar)
 import sklearn
 from sklearn.exceptions import NotFittedError
@@ -797,3 +797,34 @@ def __len__(self):
     X = TestNonNumericShape()
     assert _num_samples(X) == len(X)
+
+
+@pytest.mark.parametrize('x, target_type, min_val, max_val',
+                         [(3, int, 2, 5),
+                          (2.5, float, 2, 5)])
+def test_check_scalar_valid(x, target_type, min_val, max_val):
+    """Test that check_scalar returns no error/warning if valid inputs are
+    provided"""
+    with pytest.warns(None) as record:
+        check_scalar(x, "test_name", target_type, min_val, max_val)
+    assert len(record) == 0
+
+
+@pytest.mark.parametrize('x, target_name, target_type, min_val, max_val, '
+                         'err_msg',
+                         [(1, "test_name1", float, 2, 4,
+                           TypeError("`test_name1` must be an instance of "
+                                     "<class 'float'>, not <class 'int'>.")),
+                          (1, "test_name2", int, 2, 4,
+                           ValueError('`test_name2`= 1, must be >= 2.')),
+                          (5, "test_name3", int, 2, 4,
+                           ValueError('`test_name3`= 5, must be <= 4.'))])
+def test_check_scalar_invalid(x, target_name, target_type, min_val, max_val,
+                              err_msg):
+    """Test that check_scalar returns the right error if a wrong input is
+    given"""
+    with pytest.raises(Exception) as raised_error:
+        check_scalar(x, target_name, target_type=target_type,
+                     min_val=min_val, max_val=max_val)
+    assert str(raised_error.value) == str(err_msg)
+    assert type(raised_error.value) == type(err_msg)
diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py
index 9810f7f865fc3..96922a8e4af28 100644
--- a/sklearn/utils/validation.py
+++ b/sklearn/utils/validation.py
@@ -936,3 +936,45 @@ def check_non_negative(X, whom):
 
     if X_min < 0:
         raise ValueError("Negative values in data passed to %s" % whom)
+
+
+def check_scalar(x, name, target_type, min_val=None, max_val=None):
+    """Validate scalar parameters type and value.
+
+    Parameters
+    ----------
+    x : object
+        The scalar parameter to validate.
+
+    name : str
+        The name of the parameter to be printed in error messages.
+
+    target_type : type or tuple
+        Acceptable data types for the parameter.
+
+    min_val : float or int, optional (default=None)
+        The minimum valid value the parameter can take. If None (default), it
+        is implied that the parameter does not have a lower bound.
+
+    max_val : float or int, optional (default=None)
+        The maximum valid value the parameter can take. If None (default), it
+        is implied that the parameter does not have an upper bound.
+
+    Raises
+    ------
+    TypeError
+        If the parameter's type does not match the desired type.
+
+    ValueError
+        If the parameter's value violates the given bounds.
+    """
+
+    if not isinstance(x, target_type):
+        raise TypeError('`{}` must be an instance of {}, not {}.'
+ .format(name, target_type, type(x))) + + if min_val is not None and x < min_val: + raise ValueError('`{}`= {}, must be >= {}.'.format(name, x, min_val)) + + if max_val is not None and x > max_val: + raise ValueError('`{}`= {}, must be <= {}.'.format(name, x, max_val))
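
As a quick illustration of the ``check_scalar`` helper added above, here is a
minimal usage sketch; the parameter names ``n_neighbors`` and ``tol`` are
purely illustrative and not taken from any particular estimator:

    >>> from sklearn.utils.validation import check_scalar
    >>> check_scalar(3, 'n_neighbors', int, min_val=1)  # valid: returns silently
    >>> check_scalar(1, 'tol', float, min_val=0)  # wrong type
    Traceback (most recent call last):
        ...
    TypeError: `tol` must be an instance of <class 'float'>, not <class 'int'>.
    >>> check_scalar(0, 'n_neighbors', int, min_val=1)  # out of bounds
    Traceback (most recent call last):
        ...
    ValueError: `n_neighbors`= 0, must be >= 1.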