From d68d751170b2cbf79ad899b485f1dd43d6ff79a0 Mon Sep 17 00:00:00 2001 From: terrytangyuan Date: Tue, 15 Sep 2015 16:46:28 -0400 Subject: [PATCH 1/5] ENH: Initial code for NCA --- sklearn/metric_learning/NCA.py | 271 ++++++++++++++++++++++++++++++++ sklearn/metric_learning/demo.py | 34 ++++ sklearn/tests/test_nca.py | 24 +++ 3 files changed, 329 insertions(+) create mode 100644 sklearn/metric_learning/NCA.py create mode 100644 sklearn/metric_learning/demo.py create mode 100644 sklearn/tests/test_nca.py diff --git a/sklearn/metric_learning/NCA.py b/sklearn/metric_learning/NCA.py new file mode 100644 index 0000000000000..0dd49f5f67893 --- /dev/null +++ b/sklearn/metric_learning/NCA.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu May 22 16:50:00 2014 + +@author: thiolliere +""" +import numpy as np +import scipy.optimize as opt + +class NCAcost(object): + + @staticmethod + def cost(A, X, y, threshold=None): + """Compute the cost function and the gradient + This is the objective function to be minimized + + Parameters: + ----------- + A : array-like + Projection matrix, shape = [dim, n_features] with dim <= n_features + X : array-like + Training data, shape = [n_features, n_samples] + y : array-like + Target values, shape = [n_samples] + + Returns: + -------- + f : float + The value of the objective function + gradf : array-like + The gradient of the objective function, shape = [dim * n_features] + """ + + (D, N) = np.shape(X) + A = np.reshape(A, (np.size(A) / np.size(X, axis=0), np.size(X, axis=0))) + (d, aux) = np.shape(A) + assert D == aux + + AX = np.dot(A, X) + normAX = np.linalg.norm(AX[:, :, None] - AX[:, None, :], axis=0) + + denomSum = np.sum(np.exp(-normAX[:, :]), axis=0) + Pij = np.exp(- normAX) / denomSum[:, None] + if threshold is not None: + Pij[Pij < threshold] = 0 + Pij[Pij > 1-threshold] = 1 + + mask = (y != y[:, None]) + Pijmask = np.ma.masked_array(Pij, mask) + P = np.array(np.sum(Pijmask, axis=1)) + mask = np.negative(mask) + + f = np.sum(P) + + Xi = X[:, :, None] - X[:, None, :] + Xi = np.swapaxes(Xi, 0, 2) + + Xi = Pij[:, :, None] * Xi + + Xij = Xi[:, :, :, None] * Xi[:, :, None, :] + + gradf = np.sum(P[:, None, None] * np.sum(Xij[:], axis=1), axis=0) + + # To optimize (use mask ?) 
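+        # mask was flipped just above, so mask[i] now selects the samples j
+        # with the same label as i; the loop below subtracts their
+        # accumulated outer-product terms from the gradient.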
+        for i in range(N):
+            aux = np.sum(Xij[i, mask[i]], axis=0)
+            gradf -= aux
+
+        gradf = 2 * np.dot(A, gradf)
+        gradf = -np.reshape(gradf, np.size(gradf))
+        f = np.size(X, 1) - f
+
+        return [f, gradf]
+
+    @staticmethod
+    def f(A, X, y):
+        return NCAcost.cost(A, X, y)[0]
+
+    @staticmethod
+    def grad(A, X, y):
+        return NCAcost.cost(A, X, y)[1]
+
+    @staticmethod
+    def cost_g(A, X, y, threshold=None):
+        """Compute the cost function and the gradient for the K-L divergence
+
+        Parameters:
+        -----------
+        A : array-like
+            Projection matrix, shape = [dim, n_features] with dim <= n_features
+        X : array-like
+            Training data, shape = [n_features, n_samples]
+        y : array-like
+            Target values, shape = [n_samples]
+
+        Returns:
+        --------
+        g : float
+            The value of the objective function
+        gradg : array-like
+            The gradient of the objective function, shape = [dim * n_features]
+        """
+
+        (D, N) = np.shape(X)
+        A = np.reshape(A, (np.size(A) / np.size(X, axis=0), np.size(X, axis=0)))
+        (d, aux) = np.shape(A)
+        assert D == aux
+
+        AX = np.dot(A, X)
+        normAX = np.linalg.norm(AX[:, :, None] - AX[:, None, :], axis=0)
+
+        denomSum = np.sum(np.exp(-normAX[:, :]), axis=0)
+        Pij = np.exp(- normAX) / denomSum[:, None]
+        if threshold is not None:
+            Pij[Pij < threshold] = 0
+            Pij[Pij > 1-threshold] = 1
+
+        mask = (y != y[:, None])
+        Pijmask = np.ma.masked_array(Pij, mask)
+        P = np.array(np.sum(Pijmask, axis=1))
+        mask = np.negative(mask)
+
+        g = np.sum(np.log(P))
+
+        Xi = X[:, :, None] - X[:, None, :]
+        Xi = np.swapaxes(Xi, 0, 2)
+
+        Xi = Pij[:, :, None] * Xi
+
+        Xij = Xi[:, :, :, None] * Xi[:, :, None, :]
+
+        gradg = np.sum(np.sum(Xij[:], axis=1), axis=0)
+
+        # To optimize (use mask ?)
+        for i in range(N):
+            aux = np.sum(Xij[i, mask[i]], axis=0) / P[i]
+            gradg -= aux
+
+        gradg = 2 * np.dot(A, gradg)
+        gradg = -np.reshape(gradg, np.size(gradg))
+        g = -g
+
+        return [g, gradg]
+
+
+class NCA(object):
+
+    def __init__(self, metric=None, dim=None,
+                 threshold=None, objective='Mahalanobis', **kwargs):
+        """Classification and/or dimensionality reduction with the neighborhood
+        component analysis.
+
+        The algorithm applies the softmax function on the transformed space and
+        tries to maximise the leave-one-out classification accuracy.
+
+        Parameters:
+        -----------
+        metric : array-like, optional
+            The initial distance metric. If not specified, the algorithm will
+            start from a crude identity-based projection.
+            shape = [dim, n_features] with dim <= n_features being the
+            dimension of the output space
+        dim : int, optional
+            The number of dimensions to keep for dimensionality reduction. If
+            not specified, the algorithm won't perform dimensionality reduction.
+        threshold : float, optional
+            Threshold for the softmax function, set it higher to discard
+            further neighbors.
+        objective : string, optional
+            The objective function to optimize. The two implemented cost
+            functions are for the Mahalanobis distance and the KL-divergence.
+        **kwargs : keyword arguments, optional
+            See scipy.optimize.minimize for the list of additional arguments.
+            Those arguments include:
+
+            method : string
+                The algorithm to use for optimization.
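+                For example 'BFGS' (used in the accompanying test) or
+                'L-BFGS-B'.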
+            options : dict
+                a dictionary of solver options
+            hess, hessp : callable
+                Hessian matrix
+            bounds : sequence
+                Bounds for variables
+            constraints : dict or sequence of dict
+                Constraints definition
+            tol : float
+                Tolerance for termination
+
+        Attributes:
+        -----------
+        metric : array-like
+            The trained distance metric
+        """
+        self.metric = metric
+        self.dim = dim
+        self.threshold = threshold
+        if objective == 'Mahalanobis':
+            self.objective = NCAcost.cost
+        elif objective == 'KL-divergence':
+            self.objective = NCAcost.cost_g
+        self.kwargs = kwargs
+
+    def fit(self, X, y):
+        """Fit the model using X as training data and y as target values.
+
+        Parameters:
+        -----------
+        X : array-like
+            Training data, shape = [n_features, n_samples]
+        y : array-like
+            Target values, shape = [n_samples]
+        """
+        if self.metric is None:
+            if self.dim is None:
+                self.metric = np.eye(np.size(X, 1))
+                self.dim = np.size(X, 1)
+            else:
+                self.metric = np.eye(self.dim, np.size(X, 1) - self.dim)
+
+        res = opt.minimize(fun=self.objective,
+                           x0=self.metric,
+                           args=(X, y, self.threshold),
+                           jac=True,
+                           **self.kwargs
+                           )
+
+        self.metric = np.reshape(res.x,
+                                 (np.size(res.x) / np.size(X, 0),
+                                  np.size(X, 0)))
+
+    def fit_transform(self, X, y):
+        """Fit the model with X and apply the dimensionality reduction on X.
+
+        Parameters:
+        -----------
+        X : array-like
+            Training data, shape = [n_features, n_samples]
+        y : array-like
+            Target values, shape = [n_samples]
+
+        Returns:
+        --------
+        X_new : array-like
+            shape = [dim, n_samples]
+        """
+        self.fit(X, y)
+        return np.dot(self.metric, X)
+
+    def score(self, X, y):
+        """Returns the proportion of X correctly classified by the
+        leave-one-out classification
+
+        Parameters:
+        -----------
+        X : array-like
+            Training data, shape = [n_features, n_samples]
+        y : array-like
+            Target values, shape = [n_samples]
+
+        Returns:
+        --------
+        score : float
+            The proportion of X correctly classified
+        """
+        return 1 - NCAcost.cost(self.metric, X, y)[0]/np.size(X, 1)
+
+    def getParameters(self):
+        """Returns a dictionary of the parameters
+        """
+        return dict(metric=self.metric, dim=self.dim, objective=self.objective,
+                    threshold=self.threshold, **self.kwargs)
diff --git a/sklearn/metric_learning/demo.py b/sklearn/metric_learning/demo.py
new file mode 100644
index 0000000000000..636856300349e
--- /dev/null
+++ b/sklearn/metric_learning/demo.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import NCA
+import pylab as pl
+
+# Initialisation
+N = 300
+aux = (np.concatenate([0.5*np.ones((N/2, 1)),
+                       np.zeros((N/2, 1)), 1.1*np.ones((N/2, 1))], axis=1))
+X = np.concatenate([np.random.rand(N/2, 3),
+                    np.random.rand(N/2, 3) + aux])
+
+y = np.concatenate([np.concatenate([np.ones((N/2, 1)), np.zeros((N/2, 1))]),
+                    np.concatenate([np.zeros((N/2, 1)), np.ones((N/2, 1))])],
+                   axis=1)
+X = X.T
+y = y[:, 0]
+A = np.array([[1, 0, 0], [0, 1, 0]])
+
+# Training
+nca = NCA.NCA(metric=A, method='BFGS', objective='KL-divergence', options={'maxiter': 10, 'disp': True})
+print nca.score(X, y)
+nca.fit(X, y)
+print nca.score(X, y)
+
+# Plot
+pl.subplot(2, 1, 1)
+AX = np.dot(A, X)
+pl.scatter(AX[0, :], AX[1, :], 30, c=y, cmap=pl.cm.Paired)
+pl.subplot(2, 1, 2)
+BX = np.dot(np.reshape(nca.metric, np.shape(A)), X)
+pl.scatter(BX[0, :], BX[1, :], 30, c=y, cmap=pl.cm.Paired)
+pl.show()
diff --git a/sklearn/tests/test_nca.py b/sklearn/tests/test_nca.py
new file mode 100644
index 0000000000000..613c7adc4e08f
--- /dev/null
+++ b/sklearn/tests/test_nca.py
@@ -0,0 +1,24 @@
+import numpy as np
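+# Smoke test for sklearn.metric_learning.NCA: fit a 2x3 projection on two
+# shifted 3-D point clouds and print the leave-one-out score before and after.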
+from sklearn.metric_learning import NCA +import pylab as pl +from sklearn.utils.testing import ignore_warnings +from sklearn.utils.testing import assert_array_almost_equal + +N = 300 +aux = (np.concatenate([0.5*np.ones((N/2, 1)), + np.zeros((N/2, 1)), 1.1*np.ones((N/2, 1))], axis=1)) +X = np.concatenate([np.random.rand(N/2, 3), + np.random.rand(N/2, 3) + aux]) + +y = np.concatenate([np.concatenate([np.ones((N/2, 1)), np.zeros((N/2, 1))]), + np.concatenate([np.zeros((N/2, 1)), np.ones((N/2, 1))])], + axis=1) +X = X.T +y = y[:, 0] +A = np.array([[1, 0, 0], [0, 1, 0]]) + +def test_NCA(): + nca = NCA.NCA(metric=A, method='BFGS', objective='KL-divergence', options={'maxiter': 10, 'disp': True}) + print nca.score(X, y) + nca.fit(X, y) + print nca.score(X, y) From af1a53186b35a1e50eb06b7cd64d2b17ac2f14fc Mon Sep 17 00:00:00 2001 From: terrytangyuan Date: Tue, 15 Sep 2015 17:38:56 -0400 Subject: [PATCH 2/5] TST: Added test for NCA algorithm --- sklearn/metric_learning/NCA.py | 2 +- sklearn/metric_learning/__init__.py | 8 +++++++ sklearn/metric_learning/demo.py | 34 ----------------------------- sklearn/tests/test_nca.py | 10 +++++++++ 4 files changed, 19 insertions(+), 35 deletions(-) create mode 100644 sklearn/metric_learning/__init__.py delete mode 100644 sklearn/metric_learning/demo.py diff --git a/sklearn/metric_learning/NCA.py b/sklearn/metric_learning/NCA.py index 0dd49f5f67893..a4170e67d2238 100644 --- a/sklearn/metric_learning/NCA.py +++ b/sklearn/metric_learning/NCA.py @@ -2,7 +2,7 @@ """ Created on Thu May 22 16:50:00 2014 -@author: thiolliere +@author: thiolliere and Yuan Tang (terrytangyuan) """ import numpy as np import scipy.optimize as opt diff --git a/sklearn/metric_learning/__init__.py b/sklearn/metric_learning/__init__.py new file mode 100644 index 0000000000000..f6b2439aa4434 --- /dev/null +++ b/sklearn/metric_learning/__init__.py @@ -0,0 +1,8 @@ +""" +The :mod:`sklearn.metric_learning` module implements metric learning models. 
+
+The algorithms that have been implemented are:
+Neighborhood Components Analysis (NCA)
+"""
+
+__all__ = ['NCA']
diff --git a/sklearn/metric_learning/demo.py b/sklearn/metric_learning/demo.py
deleted file mode 100644
index 636856300349e..0000000000000
--- a/sklearn/metric_learning/demo.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import numpy as np
-import NCA
-import pylab as pl
-
-# Initialisation
-N = 300
-aux = (np.concatenate([0.5*np.ones((N/2, 1)),
-                       np.zeros((N/2, 1)), 1.1*np.ones((N/2, 1))], axis=1))
-X = np.concatenate([np.random.rand(N/2, 3),
-                    np.random.rand(N/2, 3) + aux])
-
-y = np.concatenate([np.concatenate([np.ones((N/2, 1)), np.zeros((N/2, 1))]),
-                    np.concatenate([np.zeros((N/2, 1)), np.ones((N/2, 1))])],
-                   axis=1)
-X = X.T
-y = y[:, 0]
-A = np.array([[1, 0, 0], [0, 1, 0]])
-
-# Training
-nca = NCA.NCA(metric=A, method='BFGS', objective='KL-divergence', options={'maxiter': 10, 'disp': True})
-print nca.score(X, y)
-nca.fit(X, y)
-print nca.score(X, y)
-
-# Plot
-pl.subplot(2, 1, 1)
-AX = np.dot(A, X)
-pl.scatter(AX[0, :], AX[1, :], 30, c=y, cmap=pl.cm.Paired)
-pl.subplot(2, 1, 2)
-BX = np.dot(np.reshape(nca.metric, np.shape(A)), X)
-pl.scatter(BX[0, :], BX[1, :], 30, c=y, cmap=pl.cm.Paired)
-pl.show()
diff --git a/sklearn/tests/test_nca.py b/sklearn/tests/test_nca.py
index 613c7adc4e08f..520f2a442cfb0 100644
--- a/sklearn/tests/test_nca.py
+++ b/sklearn/tests/test_nca.py
@@ -18,7 +18,17 @@
 A = np.array([[1, 0, 0], [0, 1, 0]])
 
 def test_NCA():
+    # Training
     nca = NCA.NCA(metric=A, method='BFGS', objective='KL-divergence', options={'maxiter': 10, 'disp': True})
     print nca.score(X, y)
     nca.fit(X, y)
     print nca.score(X, y)
+
+    # Plot
+    pl.subplot(2, 1, 1)
+    AX = np.dot(A, X)
+    pl.scatter(AX[0, :], AX[1, :], 30, c=y, cmap=pl.cm.Paired)
+    pl.subplot(2, 1, 2)
+    BX = np.dot(np.reshape(nca.metric, np.shape(A)), X)
+    pl.scatter(BX[0, :], BX[1, :], 30, c=y, cmap=pl.cm.Paired)
+    # pl.show()
\ No newline at end of file
From 7bd48cc3a70249a08d25786bb9cd0521cfa88270 Mon Sep 17 00:00:00 2001
From: terrytangyuan
Date: Tue, 15 Sep 2015 18:13:32 -0400
Subject: [PATCH 3/5] Trigger Travis/Appveyor

---
 sklearn/metric_learning/NCA.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/sklearn/metric_learning/NCA.py b/sklearn/metric_learning/NCA.py
index a4170e67d2238..d97bb25ade12a 100644
--- a/sklearn/metric_learning/NCA.py
+++ b/sklearn/metric_learning/NCA.py
@@ -1,7 +1,5 @@
 # -*- coding: utf-8 -*-
 """
-Created on Thu May 22 16:50:00 2014
-
 @author: thiolliere and Yuan Tang (terrytangyuan)
 """
 import numpy as np
From 0ec9c0d9315337a1294f1bac10afcdeaad9d81bb Mon Sep 17 00:00:00 2001
From: terrytangyuan
Date: Tue, 15 Sep 2015 18:28:06 -0400
Subject: [PATCH 4/5] Deleted plot test to pass Travis

---
 sklearn/tests/test_nca.py | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/sklearn/tests/test_nca.py b/sklearn/tests/test_nca.py
index 520f2a442cfb0..483d674c05095 100644
--- a/sklearn/tests/test_nca.py
+++ b/sklearn/tests/test_nca.py
@@ -1,6 +1,5 @@
 import numpy as np
 from sklearn.metric_learning import NCA
-import pylab as pl
 from sklearn.utils.testing import ignore_warnings
 from sklearn.utils.testing import assert_array_almost_equal
 
@@ -22,13 +21,4 @@ def test_NCA():
     nca = NCA.NCA(metric=A, method='BFGS', objective='KL-divergence', options={'maxiter': 10, 'disp': True})
     print nca.score(X, y)
     nca.fit(X, y)
-    print nca.score(X, y)
-
-    # Plot
-    pl.subplot(2, 1, 1)
-    AX = np.dot(A, X)
-    pl.scatter(AX[0, :], AX[1, :], 30, c=y, cmap=pl.cm.Paired)
-    pl.subplot(2, 1, 2)
-    BX = np.dot(np.reshape(nca.metric, np.shape(A)), X)
-    pl.scatter(BX[0, :], BX[1, :], 30, c=y, cmap=pl.cm.Paired)
-    # pl.show()
\ No newline at end of file
+    print nca.score(X, y)
\ No newline at end of file
From 863c59347143ddb7850c3c2a26d504d7c8b996aa Mon Sep 17 00:00:00 2001
From: terrytangyuan
Date: Wed, 16 Sep 2015 04:42:50 -0400
Subject: [PATCH 5/5] revised and fixing Travis

---
 sklearn/__init__.py            | 2 +-
 sklearn/metric_learning/NCA.py | 9 +++++----
 sklearn/tests/test_nca.py      | 2 +-
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/sklearn/__init__.py b/sklearn/__init__.py
index ff2206ff5617b..6218e21a7bb19 100644
--- a/sklearn/__init__.py
+++ b/sklearn/__init__.py
@@ -66,7 +66,7 @@
            'linear_model', 'manifold', 'metrics', 'mixture', 'multiclass',
            'naive_bayes', 'neighbors', 'neural_network', 'pipeline',
            'preprocessing', 'qda', 'random_projection', 'semi_supervised',
-           'svm', 'tree', 'discriminant_analysis',
+           'svm', 'tree', 'discriminant_analysis', 'metric_learning',
            # Non-modules:
            'clone']
diff --git a/sklearn/metric_learning/NCA.py b/sklearn/metric_learning/NCA.py
index d97bb25ade12a..dd324a3760c5c 100644
--- a/sklearn/metric_learning/NCA.py
+++ b/sklearn/metric_learning/NCA.py
@@ -4,6 +4,7 @@
 """
 import numpy as np
 import scipy.optimize as opt
+from sklearn.base import BaseEstimator
 
 class NCAcost(object):
 
@@ -141,10 +142,10 @@ def cost_g(A, X, y, threshold=None):
         return [g, gradg]
 
 
-class NCA(object):
+class NCA(BaseEstimator):
 
     def __init__(self, metric=None, dim=None,
-                 threshold=None, objective='Mahalanobis', **kwargs):
+                 threshold=None, objective='mahalanobis', **kwargs):
         """Classification and/or dimensionality reduction with the neighborhood
         component analysis.
 
@@ -192,9 +193,9 @@ def __init__(self, metric=None, dim=None,
         self.metric = metric
         self.dim = dim
         self.threshold = threshold
-        if objective == 'Mahalanobis':
+        if objective == 'mahalanobis':
             self.objective = NCAcost.cost
-        elif objective == 'KL-divergence':
+        elif objective == 'kl-divergence':
             self.objective = NCAcost.cost_g
         self.kwargs = kwargs
 
diff --git a/sklearn/tests/test_nca.py b/sklearn/tests/test_nca.py
index 483d674c05095..f7d7a27aadf27 100644
--- a/sklearn/tests/test_nca.py
+++ b/sklearn/tests/test_nca.py
@@ -18,7 +18,7 @@
 def test_NCA():
     # Training
-    nca = NCA.NCA(metric=A, method='BFGS', objective='KL-divergence', options={'maxiter': 10, 'disp': True})
+    nca = NCA.NCA(metric=A, method='BFGS', objective='kl-divergence', options={'maxiter': 10, 'disp': True})
     print nca.score(X, y)
     nca.fit(X, y)
     print nca.score(X, y)
\ No newline at end of file
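
A minimal usage sketch, adapted from the demo.py removed in PATCH 2/5; it assumes the branch is installed so that sklearn.metric_learning is importable, and the exact scores will vary with the random data and optimizer settings:

    import numpy as np
    from sklearn.metric_learning import NCA

    rng = np.random.RandomState(0)
    N = 300

    # Two 3-D point clouds; the second is offset by (0.5, 0, 1.1), so the
    # classes separate well in a 2-D projection.
    offset = np.array([0.5, 0.0, 1.1])
    X = np.concatenate([rng.rand(N // 2, 3),
                        rng.rand(N // 2, 3) + offset]).T   # [n_features, n_samples]
    y = np.concatenate([np.zeros(N // 2), np.ones(N // 2)])

    A = np.array([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0]])                        # initial 2x3 projection

    nca = NCA.NCA(metric=A, method='BFGS', objective='kl-divergence',
                  options={'maxiter': 10, 'disp': True})
    print(nca.score(X, y))   # leave-one-out score with the initial metric
    nca.fit(X, y)
    print(nca.score(X, y))   # typically higher after optimisation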