
[MRG] Add memory efficient implementation of NCA #99

Merged: 8 commits, Jul 9, 2018

95 changes: 75 additions & 20 deletions metric_learn/nca.py
@@ -4,20 +4,51 @@
"""

from __future__ import absolute_import

import warnings
import numpy as np
from six.moves import xrange
from scipy.optimize import minimize
from sklearn.metrics import pairwise_distances
from sklearn.utils.validation import check_X_y

try: # scipy.misc.logsumexp is deprecated in scipy 1.0.0
from scipy.special import logsumexp
except ImportError:
from scipy.misc import logsumexp

from .base_metric import BaseMetricLearner

EPS = np.finfo(float).eps


class NCA(BaseMetricLearner):
def __init__(self, num_dims=None, max_iter=100, learning_rate=0.01):
def __init__(self, num_dims=None, max_iter=100, learning_rate='deprecated',
tol=None):
"""Neighborhood Components Analysis

Parameters
----------
num_dims : int, optional (default=None)
Embedding dimensionality. If None, will be set to ``n_features``
(``d``) at fit time.

max_iter : int, optional (default=100)
Maximum number of iterations done by the optimization algorithm.

learning_rate : Not used

.. deprecated:: 0.4.0
`learning_rate` was deprecated in version 0.4.0 and will
be removed in 0.5.0. The current optimization algorithm does not need
to fix a learning rate.

tol : float, optional (default=None)
Convergence tolerance for the optimization.
"""
self.num_dims = num_dims
self.max_iter = max_iter
self.learning_rate = learning_rate
self.learning_rate = learning_rate # TODO: remove in v.0.5.0
self.tol = tol

def transformer(self):
return self.A_
@@ -27,33 +58,57 @@ def fit(self, X, y):
X: data matrix, (n x d)
y: scalar labels, (n)
"""
if self.learning_rate != 'deprecated':
warnings.warn('"learning_rate" parameter is not used.'
' It has been deprecated in version 0.4 and will be'
' removed in 0.5', DeprecationWarning)

X, labels = check_X_y(X, y)
n, d = X.shape
num_dims = self.num_dims
if num_dims is None:
num_dims = d

# Initialize A to a scaling matrix
A = np.zeros((num_dims, d))
np.fill_diagonal(A, 1./(np.maximum(X.max(axis=0)-X.min(axis=0), EPS)))

# Run NCA
dX = X[:,None] - X[None] # shape (n, n, d)
tmp = np.einsum('...i,...j->...ij', dX, dX) # shape (n, n, d, d)
masks = labels[:,None] == labels[None]
for it in xrange(self.max_iter):
for i, label in enumerate(labels):
mask = masks[i]
Ax = A.dot(X.T).T # shape (n, num_dims)

softmax = np.exp(-((Ax[i] - Ax)**2).sum(axis=1)) # shape (n)
softmax[i] = 0
softmax /= softmax.sum()

t = softmax[:, None, None] * tmp[i] # shape (n, d, d)
d = softmax[mask].sum() * t.sum(axis=0) - t[mask].sum(axis=0)
A += self.learning_rate * A.dot(d)
mask = labels[:, np.newaxis] == labels[np.newaxis, :]
optimizer_params = {'method': 'L-BFGS-B',
'fun': self._loss_grad_lbfgs,
'args': (X, mask, -1.0),
'jac': True,
'x0': A.ravel(),
'options': dict(maxiter=self.max_iter),
'tol': self.tol
}

# Call the optimizer
opt_result = minimize(**optimizer_params)

self.X_ = X
self.A_ = A
self.n_iter_ = it
self.A_ = opt_result.x.reshape(-1, X.shape[1])
self.n_iter_ = opt_result.nit
return self

@staticmethod
def _loss_grad_lbfgs(A, X, mask, sign=1.0):
A = A.reshape(-1, X.shape[1])
X_embedded = np.dot(X, A.T) # (n_samples, num_dims)
# Compute softmax distances
p_ij = pairwise_distances(X_embedded, squared=True)
np.fill_diagonal(p_ij, np.inf)
p_ij = np.exp(-p_ij - logsumexp(-p_ij, axis=1)[:, np.newaxis])
# (n_samples, n_samples)

# Compute loss
masked_p_ij = p_ij * mask
p = masked_p_ij.sum(axis=1, keepdims=True) # (n_samples, 1)
loss = p.sum()

# Compute gradient of loss w.r.t. `transform`
weighted_p_ij = masked_p_ij - p_ij * p
gradient = 2 * (X_embedded.T.dot(weighted_p_ij + weighted_p_ij.T) -
X_embedded.T * weighted_p_ij.sum(axis=0)).dot(X)
return sign * loss, sign * gradient.ravel()
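
A minimal usage sketch of the reworked estimator (assuming a metric-learn install with this patch plus scikit-learn; the dataset and parameter values are only illustrative):

from sklearn.datasets import load_iris
from metric_learn import NCA

X, y = load_iris(return_X_y=True)

# The L-BFGS-B based fit only needs max_iter and tol; learning_rate is no longer used.
nca = NCA(num_dims=2, max_iter=100, tol=1e-5)
nca.fit(X, y)

X_embedded = nca.transform(X)  # shape (n_samples, num_dims) = (150, 2)
print(nca.A_.shape)            # (2, 4): the learned linear transformation
print(nca.n_iter_)             # number of iterations reported by scipy's optimizer
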
128 changes: 113 additions & 15 deletions test/metric_learn_test.py
@@ -1,14 +1,15 @@
import unittest
import numpy as np
from scipy.optimize import check_grad
from six.moves import xrange
from sklearn.metrics import pairwise_distances
from sklearn.datasets import load_iris
from numpy.testing import assert_array_almost_equal
from sklearn.datasets import load_iris, make_classification
from numpy.testing import assert_array_almost_equal, assert_array_equal
from sklearn.utils.testing import assert_warns_message

from metric_learn import (
LMNN, NCA, LFDA, Covariance, MLKR, MMC,
LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised)
# Import this specially for testing.
from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC,
LSML_Supervised, ITML_Supervised, SDML_Supervised,
RCA_Supervised, MMC_Supervised)
from metric_learn.lmnn import python_LMNN


@@ -88,22 +89,119 @@ def test_iris(self):
n = self.iris_points.shape[0]

# Without dimension reduction
nca = NCA(max_iter=(100000//n), learning_rate=0.01)
nca = NCA(max_iter=(100000//n))
nca.fit(self.iris_points, self.iris_labels)
# Result copied from Iris example at
# https://github.com/vomjom/nca/blob/master/README.mkd
expected = [[-0.09935, -0.2215, 0.3383, 0.443],
[+0.2532, 0.5835, -0.8461, -0.8915],
[-0.729, -0.6386, 1.767, 1.832],
[-0.9405, -0.8461, 2.281, 2.794]]
Review thread:

Contributor:
Have we verified that the new approach produces a correct result? Checking class separation is a very coarse approximation of correctness.

Member:
This is more of a general comment, but IMHO this sort of test is not very reliable: (i) the "expected" output comes either from running the code itself or from running some other implementation that may or may not be reliable (in this case, the source does not seem especially reliable), and (ii) since NCA has a nonconvex objective, one might converge to a different point depending on the initialization and the chosen optimization algorithm, which does not imply that the algorithm is incorrect.

Member (Author) @wdevazelhes, Jul 4, 2018:
Also, the code in master uses stochastic gradient descent (updating L after a pass on each sample instead of the whole dataset). Therefore we will not be able to test this particular result (even if we tested without scipy's optimizer) since in this PR we compute the true full gradient at each iteration.

What is more, printing the loss in test_iris shows that the loss in this PR is better than the one in master (147.99999880164899 vs 144.57929636406135), which adds to the argument that this hard-coded array might not be such a good reference to check against.

> @wdevazelhes maybe some of the tests you designed for the sklearn version can be used instead?

Yes, I already added test_finite_differences (which tests the gradient formula) and test_simple_example (a toy example that checks, on 4 points, that two same-labeled points become closer after fit_transform). Now that you mention it, I will also add test_singleton_class and test_one_class, which are also useful: they test edge cases on the dataset, and in some of them we know analytical properties of the transformation (if there is only one class, or only singleton classes, the gradient is 0). However, I don't think the other tests need to be included, since they mostly test input formats, verbose output, and other things that are necessary for inclusion in scikit-learn but could be factored out for every algorithm in metric-learn at a later stage of development (and they were not enforced/tested for NCA before either, so not putting them in this PR is not a regression).
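
The loss value quoted in this thread can be inspected with the new static loss function; a minimal sketch under the test_iris setup (the exact number depends on the optimizer run, and the variable names are only illustrative):

import numpy as np
from sklearn.datasets import load_iris
from metric_learn import NCA

X, y = load_iris(return_X_y=True)
mask = y[:, np.newaxis] == y[np.newaxis, :]  # same-label mask, as built in fit()

nca = NCA(max_iter=100000 // X.shape[0]).fit(X, y)

# With the default sign=1.0, _loss_grad_lbfgs returns the objective that NCA
# maximizes: the expected number of correctly classified points (at most n=150).
loss, grad = NCA._loss_grad_lbfgs(nca.A_, X, mask)
print(loss)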

assert_array_almost_equal(expected, nca.transformer(), decimal=3)
csep = class_separation(nca.transform(), self.iris_labels)
self.assertLess(csep, 0.15)

# With dimension reduction
nca = NCA(max_iter=(100000//n), learning_rate=0.01, num_dims=2)
nca = NCA(max_iter=(100000//n), num_dims=2, tol=1e-9)
nca.fit(self.iris_points, self.iris_labels)
csep = class_separation(nca.transform(), self.iris_labels)
self.assertLess(csep, 0.15)

def test_finite_differences(self):
"""Test gradient of loss function

Assert that the gradient is almost equal to its finite differences
approximation.
"""
# Initialize the transformation `M`, as well as `X` and `y` and `NCA`
X, y = make_classification()
M = np.random.randn(np.random.randint(1, X.shape[1] + 1), X.shape[1])
mask = y[:, np.newaxis] == y[np.newaxis, :]

def fun(M):
return NCA._loss_grad_lbfgs(M, X, mask)[0]

def grad(M):
return NCA._loss_grad_lbfgs(M, X, mask)[1].ravel()

# compute relative error
rel_diff = check_grad(fun, grad, M.ravel()) / np.linalg.norm(grad(M))
np.testing.assert_almost_equal(rel_diff, 0., decimal=6)

def test_simple_example(self):
"""Test on a simple example.

Puts four points in the input space where the opposite labels points are
next to each other. After transform the same labels points should be next
to each other.

"""
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
y = np.array([1, 0, 1, 0])
nca = NCA(num_dims=2,)
nca.fit(X, y)
Xansformed = nca.transform(X)
np.testing.assert_equal(pairwise_distances(Xansformed).argsort()[:, 1],
np.array([2, 3, 0, 1]))

def test_deprecation(self):
# test that the right deprecation message is thrown.
# TODO: remove in v.0.5
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
y = np.array([1, 0, 1, 0])
nca = NCA(num_dims=2, learning_rate=0.01)
msg = ('"learning_rate" parameter is not used.'
' It has been deprecated in version 0.4 and will be'
' removed in 0.5')
assert_warns_message(DeprecationWarning, msg, nca.fit, X, y)

def test_singleton_class(self):
X = self.iris_points
y = self.iris_labels

# one singleton class: test fitting works
singleton_class = 1
ind_singleton, = np.where(y == singleton_class)
y[ind_singleton] = 2
y[ind_singleton[0]] = singleton_class

nca = NCA(max_iter=30)
nca.fit(X, y)

# One non-singleton class: test fitting works
ind_1, = np.where(y == 1)
ind_2, = np.where(y == 2)
y[ind_1] = 0
y[ind_1[0]] = 1
y[ind_2] = 0
y[ind_2[0]] = 2

nca = NCA(max_iter=30)
nca.fit(X, y)

# Only singleton classes: test fitting does nothing (the gradient
# must be null in this case, so the final matrix must stay like
# the initialization)
ind_0, = np.where(y == 0)
ind_1, = np.where(y == 1)
ind_2, = np.where(y == 2)
X = X[[ind_0[0], ind_1[0], ind_2[0]]]
y = y[[ind_0[0], ind_1[0], ind_2[0]]]

EPS = np.finfo(float).eps
A = np.zeros((X.shape[1], X.shape[1]))
np.fill_diagonal(A,
1. / (np.maximum(X.max(axis=0) - X.min(axis=0), EPS)))
nca = NCA(max_iter=30, num_dims=X.shape[1])
nca.fit(X, y)
assert_array_equal(nca.A_, A)

def test_one_class(self):
# if there is only one class the gradient is null, so the final matrix
# must stay like the initialization
X = self.iris_points[self.iris_labels == 0]
y = self.iris_labels[self.iris_labels == 0]
EPS = np.finfo(float).eps
A = np.zeros((X.shape[1], X.shape[1]))
np.fill_diagonal(A,
1. / (np.maximum(X.max(axis=0) - X.min(axis=0), EPS)))
nca = NCA(max_iter=30, num_dims=X.shape[1])
nca.fit(X, y)
assert_array_equal(nca.A_, A)


class TestLFDA(MetricTestCase):
def test_iris(self):
3 changes: 2 additions & 1 deletion test/test_base_metric.py
@@ -16,7 +16,8 @@ def test_lmnn(self):

def test_nca(self):
self.assertEqual(str(metric_learn.NCA()),
"NCA(learning_rate=0.01, max_iter=100, num_dims=None)")
("NCA(learning_rate='deprecated', max_iter=100, "
"num_dims=None, tol=None)"))

def test_lfda(self):
self.assertEqual(str(metric_learn.LFDA()),