From 946ff1de0ff26cf00c1ff1a74539974f5fc47b63 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Fri, 21 Jun 2019 17:17:02 +0200 Subject: [PATCH 1/7] WIP update impostors --- metric_learn/lmnn.py | 69 ++++++++++++++++++++++++++++++++------------ 1 file changed, 50 insertions(+), 19 deletions(-) diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index 20eeea3b..d972e02b 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -25,6 +25,13 @@ from .base_metric import MahalanobisMixin +def re_order_target_neighbors(L_next, X, target_neighbors): + Xl = np.dot(X, L_next.T) + dd = np.sum((Xl[:, None, :] - Xl[target_neighbors])**2, axis=2) + sorted_neighbors = np.take_along_axis(target_neighbors, dd.argsort(axis=1), 1) + return sorted_neighbors + + class LMNN(MahalanobisMixin, TransformerMixin): def __init__(self, init=None, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7, regularization=0.5, convergence_tol=0.001, @@ -155,8 +162,9 @@ def fit(self, X, y): raise ValueError('not enough class labels for specified k' ' (smallest class has %d)' % required_k) - target_neighbors = self._select_targets(X, label_inds) - impostors = self._find_impostors(target_neighbors[:, -1], X, label_inds) + target_neighbors = self._select_targets(self.transformer_, X, label_inds) + impostors = self._find_impostors(self.transformer_, + target_neighbors[:, -1], X, label_inds) if len(impostors) == 0: # L has already been initialized to an identity matrix return @@ -194,10 +202,17 @@ def fit(self, X, y): # we compute the objective at next point # we copy variables that can be modified by _loss_grad, because if we # retry we don t want to modify them several times + # target_neighbors_next = self._select_targets(L_next, X, label_inds) + # TODO: I should just re-order the target neighbors + target_neighbors_next = re_order_target_neighbors(L_next, X, + target_neighbors) + impostors_next = self._find_impostors(L_next, + target_neighbors_next[:, -1], X, + label_inds) (G_next, objective_next, total_active_next, df_next, a1_next, a2_next) = ( - self._loss_grad(X, L_next, dfG, impostors, it, k, reg, - target_neighbors, df.copy(), list(a1), list(a2))) + self._loss_grad(X, L_next, dfG, impostors_next, it, k, reg, + target_neighbors_next, df.copy(), list(a1), list(a2))) assert not np.isnan(objective) delta_obj = objective_next - objective if delta_obj > 0: @@ -212,6 +227,8 @@ def fit(self, X, y): # old variables to these new ones before next iteration and we # slightly increase the learning rate L = L_next + target_neighbors = target_neighbors_next + impostors = impostors_next G, df, objective, total_active, a1, a2 = ( G_next, df_next, objective_next, total_active_next, a1_next, a2_next) learn_rate *= 1.01 @@ -231,20 +248,25 @@ def fit(self, X, y): # store the last L self.transformer_ = L self.n_iter_ = it + self.targets_ = target_neighbors + self.impostors_ = impostors return self def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df, a1, a2): # Compute pairwise distances under current metric Lx = L.dot(X.T).T + n_samples = X.shape[0] g0 = _inplace_paired_L2(*Lx[impostors]) Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:, None, :]) g1, g2 = Ni[impostors] # compute the gradient total_active = 0 for nn_idx in reversed(xrange(k)): - act1 = g0 < g1[:, nn_idx] - act2 = g0 < g2[:, nn_idx] + act1 = np.zeros(X.shape[0]**2, dtype=bool) + act2 = np.zeros(X.shape[0]**2, dtype=bool) + act1[impostors[0] * n_samples + impostors[1]] = g0 < g1[:, nn_idx] + act2[impostors[0] * n_samples + impostors[1]] = g0 < g2[:, nn_idx] total_active += act1.sum() + act2.sum() if it > 1: @@ -259,17 +281,25 @@ def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df, minus2 = np.zeros(0, dtype=int) targets = target_neighbors[:, nn_idx] - PLUS, pweight = _count_edges(plus1, plus2, impostors, targets) + all_points = np.repeat(np.arange(X.shape[0])[None], 2, axis=0) + + PLUS, pweight = _count_edges(np.where(plus1)[0] % n_samples, + np.where(plus2)[0] % n_samples, + all_points, targets) df += _sum_outer_products(X, PLUS[:, 0], PLUS[:, 1], pweight) - MINUS, mweight = _count_edges(minus1, minus2, impostors, targets) + MINUS, mweight = _count_edges(np.where(minus1)[0] % n_samples, + np.where(minus2)[0] % n_samples, + all_points, targets) df -= _sum_outer_products(X, MINUS[:, 0], MINUS[:, 1], mweight) + df += _sum_outer_products(X, np.where(minus1)[0] % n_samples, + np.where(minus1)[0] // n_samples) + df += _sum_outer_products(X, np.where(minus2)[0] % n_samples, + np.where(minus2)[0] // n_samples) - in_imp, out_imp = impostors - df += _sum_outer_products(X, in_imp[minus1], out_imp[minus1]) - df += _sum_outer_products(X, in_imp[minus2], out_imp[minus2]) - - df -= _sum_outer_products(X, in_imp[plus1], out_imp[plus1]) - df -= _sum_outer_products(X, in_imp[plus2], out_imp[plus2]) + df -= _sum_outer_products(X, np.where(plus1)[0] % n_samples, + np.where(plus1)[0] // n_samples) + df -= _sum_outer_products(X, np.where(plus2)[0] % n_samples, + np.where(plus2)[0] // n_samples) a1[nn_idx] = act1 a2[nn_idx] = act2 @@ -282,23 +312,24 @@ def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df, objective += G.flatten().dot(L.flatten()) return 2 * G, objective, total_active, df, a1, a2 - def _select_targets(self, X, label_inds): + def _select_targets(self, L, X, label_inds): target_neighbors = np.empty((X.shape[0], self.k), dtype=int) for label in self.labels_: inds, = np.nonzero(label_inds == label) - dd = euclidean_distances(X[inds], squared=True) + dd = euclidean_distances(X.dot(L.T)[inds], squared=True) np.fill_diagonal(dd, np.inf) nn = np.argsort(dd)[..., :self.k] target_neighbors[inds] = inds[nn] return target_neighbors - def _find_impostors(self, furthest_neighbors, X, label_inds): - Lx = self.transform(X) + def _find_impostors(self, L, furthest_neighbors, X, label_inds): + Lx = X.dot(L.T) margin_radii = 1 + _inplace_paired_L2(Lx[furthest_neighbors], Lx) impostors = [] for label in self.labels_[:-1]: in_inds, = np.nonzero(label_inds == label) - out_inds, = np.nonzero(label_inds > label) + out_inds, = np.nonzero(label_inds > label) # TODO: not sure why >, + # sth like only one pass through labels and avoid symmetric ? dist = euclidean_distances(Lx[out_inds], Lx[in_inds], squared=True) i1,j1 = np.nonzero(dist < margin_radii[out_inds][:,None]) i2,j2 = np.nonzero(dist < margin_radii[in_inds]) From 9a968570c774bdbe3fb0e1f4f9b67f1ee17076d5 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Sun, 30 Jun 2019 23:24:28 +0200 Subject: [PATCH 2/7] Add test for the cost function --- test/metric_learn_test.py | 101 +++++++++++++++++++++++++++++++++++++- 1 file changed, 100 insertions(+), 1 deletion(-) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index c49c9ef5..e1594175 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -6,7 +6,7 @@ from six.moves import xrange from sklearn.metrics import pairwise_distances from sklearn.datasets import (load_iris, make_classification, make_regression, - make_spd_matrix) + make_spd_matrix, make_blobs) from numpy.testing import (assert_array_almost_equal, assert_array_equal, assert_allclose) from sklearn.utils.testing import assert_warns_message @@ -292,6 +292,105 @@ def test_changed_behaviour_warning(self): assert any(msg == str(wrn.message) for wrn in raised_warning) +def test_loss_func(capsys): + """Test the loss function (and its gradient) on a simple example, + by comparing the results with the actual implementation of metric-learn, + with a very simple (but nonperformant) implementation""" + # TODO: we need to find an example where there are still some impostors at + # the beginning and they decrease + # TODO: ideally we would like to do a test where the number of active + # constraints decrease + def hinge(a): + if a > 0: + return a, 1 + else: + return 0, 0 + + def loss_fn(L, X, y, target_neighbors, regularization): + L = L.reshape(-1, X.shape[1]) + Lx = np.dot(X, L.T) + loss = 0 + total_active = 0 + for i in range(X.shape[0]): + for j in target_neighbors[i]: + loss += (1 - regularization) * np.sum((Lx[i] - Lx[j])**2) + for l in range(X.shape[0]): + y_il = int(y[i] == y[l]) + hin, active = hinge(1 + np.sum((Lx[i] - Lx[j])**2) - + np.sum((Lx[i] - Lx[l])**2)) + total_active += active * (1 - y_il) # an active constraint is + # active, and is a constraint + loss += regularization * ((1 - y_il) * hin) + return loss, total_active + + def loss_fn_reduced(*args): + return loss_fn(*args)[0] + + class LMNN_nonperformant(LMNN): + + def fit(self, X, y): + self.y = y + return super(LMNN_nonperformant, self).fit(X, y) + + def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, + df,a1, a2): + loss, total_active = loss_fn(L.ravel(), X, self.y, target_neighbors, + self.regularization) + grad = approx_fprime(L.ravel(), loss_fn_reduced, 1e-3, X, self.y, + target_neighbors, self.regularization) + grad = grad.reshape(-1, X.shape[1]) + return grad, loss, total_active, [], [], [] + + # test that the objective function never has twice the same value + # see https://github.com/metric-learn/metric-learn/issues/88 + + X, y = make_classification(n_samples=10, n_classes=2, + n_features=2, + n_redundant=0, shuffle=True, + scale=[1, 20], random_state=42) + lmnn_perf = LMNN(verbose=True, random_state=42, init='identity', max_iter=10) + lmnn_nonperf = LMNN_nonperformant(verbose=True, random_state=42, + init='identity', max_iter=10) + objectives, obj_diffs, grads, total_active = dict(), dict(), dict(), dict() + for algo, name in zip([lmnn_perf, lmnn_nonperf], ['perf', 'nonperf']): + algo.fit(X, y) + out, _ = capsys.readouterr() + lines = re.split("\n+", out) + # we get only objectives from each line: + # the regexp matches a float that follows an integer (the iteration + # number), and which is followed by a (signed) float (delta obj). It + # matches for instance: + # 3 **1113.7665747189938** -3.182774197440267 46431.0200999999999998e-06 + # regex for a signed number allowing scientific expression + num = '(-?\d+.?\d*(e[+|-]\d+)?)' + strings = [re.search("\d+ (?:{}) (?:{}) (?:(\d+)) (?:{})".format(num, num, + num), + s) for + s in + lines] + objectives[name] = [float(match.group(1)) for match in strings if match is + not + None] + obj_diffs[name] = [float(match.group(3)) for match in strings if match is + not + None] + total_active[name] = [float(match.group(5)) for match in strings if + match is not + None] + grads[name] = [float(match.group(6)) for match in strings if match is not + None] + assert len(strings) >= 10 # we ensure that we actually did more than 10 + # iterations + assert total_active[name][0] >= 2 # we ensure that we have some active + # constraints (that's the case we want to test) + # we remove the last element because it can be equal to the penultimate + # if the last gradient update is null + np.testing.assert_allclose(objectives['perf'], objectives['nonperf']) + np.testing.assert_allclose(obj_diffs['perf'], obj_diffs['nonperf']) + np.testing.assert_allclose(total_active['perf'], total_active['nonperf']) + np.testing.assert_allclose(grads['perf'], grads['nonperf']) + + @pytest.mark.parametrize('X, y, loss', [(np.array([[0], [1], [2], [3]]), [1, 1, 0, 0], 3.0), (np.array([[0], [1], [2], [3]]), From 247041e2a2ee2315b1b0eae2e8eec0fb3d8a6c91 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 2 Jul 2019 02:24:54 +0200 Subject: [PATCH 3/7] Use real gradient for testing --- test/metric_learn_test.py | 49 ++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index e1594175..2f628a35 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -307,24 +307,28 @@ def hinge(a): return 0, 0 def loss_fn(L, X, y, target_neighbors, regularization): - L = L.reshape(-1, X.shape[1]) - Lx = np.dot(X, L.T) - loss = 0 - total_active = 0 - for i in range(X.shape[0]): - for j in target_neighbors[i]: - loss += (1 - regularization) * np.sum((Lx[i] - Lx[j])**2) - for l in range(X.shape[0]): - y_il = int(y[i] == y[l]) - hin, active = hinge(1 + np.sum((Lx[i] - Lx[j])**2) - - np.sum((Lx[i] - Lx[l])**2)) - total_active += active * (1 - y_il) # an active constraint is - # active, and is a constraint - loss += regularization * ((1 - y_il) * hin) - return loss, total_active - - def loss_fn_reduced(*args): - return loss_fn(*args)[0] + L = L.reshape(-1, X.shape[1]) + Lx = np.dot(X, L.T) + loss = 0 + total_active = 0 + grad = np.zeros_like(L) + for i in range(X.shape[0]): + for j in target_neighbors[i]: + loss += (1 - regularization) * np.sum((Lx[i] - Lx[j])**2) + grad += (1 - regularization) * (Lx[i] - Lx[j]).T.dot(X[i] - X[j]) + for l in range(X.shape[0]): + if y[i] != y[l]: + hin, active = hinge(1 + np.sum((Lx[i] - Lx[j])**2) - + np.sum((Lx[i] - Lx[l])**2)) + total_active += active # an active constraint is + # active, and is a constraint + if active: + loss += regularization * hin + grad += (regularization * + ((Lx[i] - Lx[j]).T.dot(X[i] - X[j]) + - (Lx[i] - Lx[l]).T.dot(X[i] - X[l]))) + grad *= 2 + return grad, loss, total_active class LMNN_nonperformant(LMNN): @@ -333,12 +337,9 @@ def fit(self, X, y): return super(LMNN_nonperformant, self).fit(X, y) def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, - df,a1, a2): - loss, total_active = loss_fn(L.ravel(), X, self.y, target_neighbors, - self.regularization) - grad = approx_fprime(L.ravel(), loss_fn_reduced, 1e-3, X, self.y, - target_neighbors, self.regularization) - grad = grad.reshape(-1, X.shape[1]) + df, a1, a2): + grad, loss, total_active = loss_fn(L.ravel(), X, self.y, + target_neighbors, self.regularization) return grad, loss, total_active, [], [], [] # test that the objective function never has twice the same value From ffb2fcf1408727397213118e5d26ddfdd9e21cf0 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 2 Jul 2019 02:28:07 +0200 Subject: [PATCH 4/7] wip new implem --- metric_learn/lmnn.py | 160 ++++++++++++++++++-------------------- test/metric_learn_test.py | 5 +- 2 files changed, 78 insertions(+), 87 deletions(-) diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index d972e02b..50fa84d6 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -14,19 +14,20 @@ from __future__ import print_function, absolute_import import numpy as np +import scipy import warnings from collections import Counter from six.moves import xrange from sklearn.exceptions import ChangedBehaviorWarning -from sklearn.metrics import euclidean_distances +from sklearn.metrics import euclidean_distances, pairwise_distances from sklearn.base import TransformerMixin from ._util import _initialize_transformer, _check_n_components from .base_metric import MahalanobisMixin -def re_order_target_neighbors(L_next, X, target_neighbors): - Xl = np.dot(X, L_next.T) +def re_order_target_neighbors(L, X, target_neighbors): + Xl = np.dot(X, L.T) dd = np.sum((Xl[:, None, :] - Xl[target_neighbors])**2, axis=2) sorted_neighbors = np.take_along_axis(target_neighbors, dd.argsort(axis=1), 1) return sorted_neighbors @@ -172,24 +173,16 @@ def fit(self, X, y): # sum outer products dfG = _sum_outer_products(X, target_neighbors.flatten(), np.repeat(np.arange(X.shape[0]), k)) - df = np.zeros_like(dfG) - - # storage - a1 = [None]*k - a2 = [None]*k - for nn_idx in xrange(k): - a1[nn_idx] = np.array([]) - a2[nn_idx] = np.array([]) # initialize L L = self.transformer_ # first iteration: we compute variables (including objective and gradient) # at initialization point - G, objective, total_active, df, a1, a2 = ( - self._loss_grad(X, L, dfG, impostors, 1, k, reg, target_neighbors, df, - a1, a2)) + G, objective, total_active = self._loss_grad(X, L, y, dfG, 1, + k, reg, target_neighbors) + # TODO: need to print here the log it = 1 # we already made one iteration # main loop @@ -204,15 +197,10 @@ def fit(self, X, y): # retry we don t want to modify them several times # target_neighbors_next = self._select_targets(L_next, X, label_inds) # TODO: I should just re-order the target neighbors - target_neighbors_next = re_order_target_neighbors(L_next, X, - target_neighbors) - impostors_next = self._find_impostors(L_next, - target_neighbors_next[:, -1], X, - label_inds) - (G_next, objective_next, total_active_next, df_next, a1_next, - a2_next) = ( - self._loss_grad(X, L_next, dfG, impostors_next, it, k, reg, - target_neighbors_next, df.copy(), list(a1), list(a2))) + + (G_next, objective_next, total_active_next) = ( + self._loss_grad(X, L_next, label_inds, dfG, it, k, reg, + target_neighbors)) assert not np.isnan(objective) delta_obj = objective_next - objective if delta_obj > 0: @@ -227,10 +215,7 @@ def fit(self, X, y): # old variables to these new ones before next iteration and we # slightly increase the learning rate L = L_next - target_neighbors = target_neighbors_next - impostors = impostors_next - G, df, objective, total_active, a1, a2 = ( - G_next, df_next, objective_next, total_active_next, a1_next, a2_next) + G, objective, total_active = G_next, objective_next, total_active_next learn_rate *= 1.01 if self.verbose: @@ -252,65 +237,71 @@ def fit(self, X, y): self.impostors_ = impostors return self - def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df, - a1, a2): + def _loss_grad(self, X, L, y, dfG, it, k, reg, target_neighbors): # Compute pairwise distances under current metric - Lx = L.dot(X.T).T + Lx = X.dot(L.T) n_samples = X.shape[0] - g0 = _inplace_paired_L2(*Lx[impostors]) - Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:, None, :]) - g1, g2 = Ni[impostors] - # compute the gradient - total_active = 0 - for nn_idx in reversed(xrange(k)): - act1 = np.zeros(X.shape[0]**2, dtype=bool) - act2 = np.zeros(X.shape[0]**2, dtype=bool) - act1[impostors[0] * n_samples + impostors[1]] = g0 < g1[:, nn_idx] - act2[impostors[0] * n_samples + impostors[1]] = g0 < g2[:, nn_idx] - total_active += act1.sum() + act2.sum() - - if it > 1: - plus1 = act1 & ~a1[nn_idx] - minus1 = a1[nn_idx] & ~act1 - plus2 = act2 & ~a2[nn_idx] - minus2 = a2[nn_idx] & ~act2 - else: - plus1 = act1 - plus2 = act2 - minus1 = np.zeros(0, dtype=int) - minus2 = np.zeros(0, dtype=int) - - targets = target_neighbors[:, nn_idx] - all_points = np.repeat(np.arange(X.shape[0])[None], 2, axis=0) - - PLUS, pweight = _count_edges(np.where(plus1)[0] % n_samples, - np.where(plus2)[0] % n_samples, - all_points, targets) - df += _sum_outer_products(X, PLUS[:, 0], PLUS[:, 1], pweight) - MINUS, mweight = _count_edges(np.where(minus1)[0] % n_samples, - np.where(minus2)[0] % n_samples, - all_points, targets) - df -= _sum_outer_products(X, MINUS[:, 0], MINUS[:, 1], mweight) - df += _sum_outer_products(X, np.where(minus1)[0] % n_samples, - np.where(minus1)[0] // n_samples) - df += _sum_outer_products(X, np.where(minus2)[0] % n_samples, - np.where(minus2)[0] // n_samples) - - df -= _sum_outer_products(X, np.where(plus1)[0] % n_samples, - np.where(plus1)[0] // n_samples) - df -= _sum_outer_products(X, np.where(plus2)[0] % n_samples, - np.where(plus2)[0] // n_samples) - - a1[nn_idx] = act1 - a2[nn_idx] = act2 - # do the gradient update - assert not np.isnan(df).any() - G = dfG * reg + df * (1 - reg) - G = L.dot(G) - # compute the objective function - objective = total_active * (1 - reg) - objective += G.flatten().dot(L.flatten()) - return 2 * G, objective, total_active, df, a1, a2 + target_dist = np.sum((Lx[:, None] - Lx[target_neighbors])**2, axis=2) + # TODO: maybe this is not the more efficient, to re-order inplace the + # target neighbors ? + target_idx_sorted = np.take_along_axis(target_neighbors, + target_dist.argsort(axis=1), 1) + target_dist = np.sort(target_dist, axis=1) + total_active, push_loss = 0, 0 + weights = scipy.sparse.csr_matrix((n_samples, n_samples)) + for c in np.unique(y): # could maybe avoid this loop and vectorize + same_label = y == c # TODO: I can have this pre-computed + imp_dist = pairwise_distances(Lx[same_label], Lx[~same_label], + squared=True) + # TODO: do some computations with a count kind of thing maybe + for nn_idx in reversed(xrange(k)): # could maybe avoid this loop and + # vectorize + # TODO: simplify indexing when possible + margins = target_dist[same_label, nn_idx][:, None] + 1 - imp_dist + active = margins > 0 + # we mask the further impostors bc they don't need to be compared + # anymore + actives = np.sum(active, axis=1) # result: like a column (but + # result is "list") + current_total_actives = np.sum(actives) + total_active += current_total_actives + pos_margins = np.ma.masked_array(margins, ~active) + imp_dist = np.ma.masked_array(imp_dist, ~active) + push_loss += (1 - reg) * np.ma.sum(pos_margins) + + weights[same_label, target_idx_sorted[same_label][:, nn_idx]] -= \ + actives + weights[target_idx_sorted[same_label][:, nn_idx], same_label] -= \ + actives + weights[target_idx_sorted[same_label][:, nn_idx], + target_idx_sorted[same_label][:, nn_idx]] += actives + weights[~same_label][:, ~same_label] -= np.ma.sum(active, axis=0) + # + # TODO: be + # careful + # may be wrong here + weights[~same_label][:, same_label] += active.T + weights[same_label][:, ~same_label] += active + + # TODO: maybe for some of the things we can multiply or add a total + # at the end of the loop on nn_idx ? + # TODO: + # maybe the things on the diagonal could be optimized more ( + # like 3 * X instead of 3*np.eye().dot(X) kind of thing ? + push_grad = ((1 - reg) * weights.T.dot(Lx)).T.dot(X) # TODO: optimize + # order of + # ops like + # NCA + # TODO: do better sparse multiplication (avoid the transpose) + pull_grad = L.dot(dfG * reg) # we could do a computation with Lx if d >> n + + pull_loss = reg * np.sum(target_dist) + grad = push_grad + pull_grad + grad *= 2 + it += 1 + objective = pull_loss + push_loss + + return grad, objective, total_active def _select_targets(self, L, X, label_inds): target_neighbors = np.empty((X.shape[0], self.k), dtype=int) @@ -366,6 +357,7 @@ def _count_edges(act1, act2, impostors, targets): def _sum_outer_products(data, a_inds, b_inds, weights=None): + # TODO: since used one time, maybe replace by sth else ? Xab = data[a_inds] - data[b_inds] if weights is not None: return np.dot(Xab.T, Xab * weights[:,None]) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 2f628a35..e4ca9e45 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -336,11 +336,10 @@ def fit(self, X, y): self.y = y return super(LMNN_nonperformant, self).fit(X, y) - def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, - df, a1, a2): + def _loss_grad(self, X, L, y, dfG, it, k, reg, target_neighbors): grad, loss, total_active = loss_fn(L.ravel(), X, self.y, target_neighbors, self.regularization) - return grad, loss, total_active, [], [], [] + return grad, loss, total_active # test that the objective function never has twice the same value # see https://github.com/metric-learn/metric-learn/issues/88 From 862e698ac1a5e1129135f9c21cd550c736369310 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 2 Jul 2019 02:35:46 +0200 Subject: [PATCH 5/7] add warning for self --- test/metric_learn_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index e4ca9e45..6e87ca4a 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -307,6 +307,8 @@ def hinge(a): return 0, 0 def loss_fn(L, X, y, target_neighbors, regularization): + # warning to self: this is probably wrong, see test on the thing that was + # before L = L.reshape(-1, X.shape[1]) Lx = np.dot(X, L.T) loss = 0 From 3aee1d8cdd28d53d9f016238ac5a6a6b8d492c35 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 2 Jul 2019 10:46:32 +0200 Subject: [PATCH 6/7] Improvements --- metric_learn/lmnn.py | 19 ++++++++++------- test/metric_learn_test.py | 43 ++++++++++++++++++++++++++++++++++----- 2 files changed, 50 insertions(+), 12 deletions(-) diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index 50fa84d6..a803a9fc 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -182,9 +182,12 @@ def fit(self, X, y): G, objective, total_active = self._loss_grad(X, L, y, dfG, 1, k, reg, target_neighbors) + # TODO: need to print here the log it = 1 # we already made one iteration + print(it, objective, 0, total_active, 1.05e-5) # TODO: replace by a + # real learning rate here it's just to fix a bug when printing # main loop for it in xrange(2, self.max_iter): # then at each iteration, we try to find a value of L that has better @@ -261,21 +264,23 @@ def _loss_grad(self, X, L, y, dfG, it, k, reg, target_neighbors): active = margins > 0 # we mask the further impostors bc they don't need to be compared # anymore - actives = np.sum(active, axis=1) # result: like a column (but + actives = np.ma.sum(active, axis=1) # result: like a column (but # result is "list") - current_total_actives = np.sum(actives) + current_total_actives = np.ma.sum(actives) total_active += current_total_actives pos_margins = np.ma.masked_array(margins, ~active) imp_dist = np.ma.masked_array(imp_dist, ~active) push_loss += (1 - reg) * np.ma.sum(pos_margins) - weights[same_label, target_idx_sorted[same_label][:, nn_idx]] -= \ + weights[same_label, target_idx_sorted[same_label][:, nn_idx].ravel()] \ + -= \ actives - weights[target_idx_sorted[same_label][:, nn_idx], same_label] -= \ + weights[target_idx_sorted[same_label][:, nn_idx].ravel(), same_label] \ + -= \ actives - weights[target_idx_sorted[same_label][:, nn_idx], - target_idx_sorted[same_label][:, nn_idx]] += actives - weights[~same_label][:, ~same_label] -= np.ma.sum(active, axis=0) + weights[target_idx_sorted[same_label][:, nn_idx].ravel(), + target_idx_sorted[same_label][:, nn_idx].ravel()] += actives + weights[~same_label, ~same_label] -= np.ma.sum(active, axis=0) # # TODO: be # careful diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 6e87ca4a..3c85dbe6 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -2,9 +2,10 @@ import re import pytest import numpy as np +import scipy from scipy.optimize import check_grad, approx_fprime from six.moves import xrange -from sklearn.metrics import pairwise_distances +from sklearn.metrics import pairwise_distances, euclidean_distances from sklearn.datasets import (load_iris, make_classification, make_regression, make_spd_matrix, make_blobs) from numpy.testing import (assert_array_almost_equal, assert_array_equal, @@ -300,6 +301,10 @@ def test_loss_func(capsys): # the beginning and they decrease # TODO: ideally we would like to do a test where the number of active # constraints decrease + # TODO: apparently it worked with master, when master does not recomputes + # the impostors, so see if I could not improve this test to ensure it + # tests well impostors recomputation (or at least if the impostors at + # init are not all the impostors) def hinge(a): if a > 0: return a, 1 @@ -317,7 +322,7 @@ def loss_fn(L, X, y, target_neighbors, regularization): for i in range(X.shape[0]): for j in target_neighbors[i]: loss += (1 - regularization) * np.sum((Lx[i] - Lx[j])**2) - grad += (1 - regularization) * (Lx[i] - Lx[j]).T.dot(X[i] - X[j]) + grad += (1 - regularization) * np.outer(Lx[i] - Lx[j], X[i] - X[j]) for l in range(X.shape[0]): if y[i] != y[l]: hin, active = hinge(1 + np.sum((Lx[i] - Lx[j])**2) - @@ -327,11 +332,39 @@ def loss_fn(L, X, y, target_neighbors, regularization): if active: loss += regularization * hin grad += (regularization * - ((Lx[i] - Lx[j]).T.dot(X[i] - X[j]) - - (Lx[i] - Lx[l]).T.dot(X[i] - X[l]))) - grad *= 2 + (np.outer(Lx[i] - Lx[j], X[i] - X[j]) + - np.outer(Lx[i] - Lx[l], X[i] - X[l]))) + grad = 2 * grad return grad, loss, total_active + # we check that the gradient we have computed in the test is indeed the + # true gradient on a toy example: + X, y = make_classification(random_state=42, class_sep=0.1, n_features=20) + + def _select_targets(X, y, k): + target_neighbors = np.empty((X.shape[0], k), dtype=int) + for label in np.unique(y): + inds, = np.nonzero(y == label) + dd = euclidean_distances(X[inds], squared=True) + np.fill_diagonal(dd, np.inf) + nn = np.argsort(dd)[..., :k] + target_neighbors[inds] = inds[nn] + return target_neighbors + + target_neighbors = _select_targets(X, y, 2) + regularization = 0.5 + x0 = np.random.randn(5, 20) + + def loss(x0): + return loss_fn(x0.reshape(-1, X.shape[1]), X, y, target_neighbors, + regularization)[1] + + def grad(x0): + return loss_fn(x0.reshape(-1, X.shape[1]), X, y, target_neighbors, + regularization)[0].ravel() + + scipy.optimize.check_grad(loss, grad, x0.ravel()) + class LMNN_nonperformant(LMNN): def fit(self, X, y): From 269d51dc2d82512aed33f08f06809d5da37e3d29 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 2 Jul 2019 15:11:47 +0200 Subject: [PATCH 7/7] wip --- metric_learn/lmnn.py | 41 +++++++++++++------------- test/metric_learn_test.py | 62 ++++++++++++++++++++++----------------- 2 files changed, 56 insertions(+), 47 deletions(-) diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index a803a9fc..bb046a95 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -251,43 +251,44 @@ def _loss_grad(self, X, L, y, dfG, it, k, reg, target_neighbors): target_dist.argsort(axis=1), 1) target_dist = np.sort(target_dist, axis=1) total_active, push_loss = 0, 0 - weights = scipy.sparse.csr_matrix((n_samples, n_samples)) + weights = np.zeros((n_samples, n_samples)) for c in np.unique(y): # could maybe avoid this loop and vectorize - same_label = y == c # TODO: I can have this pre-computed - imp_dist = pairwise_distances(Lx[same_label], Lx[~same_label], + same_label = np.where(y == c)[0] # TODO: I can have this pre-computed + diff_label = np.where(y != c)[0] + imp_dist = pairwise_distances(Lx[same_label], Lx[diff_label], squared=True) # TODO: do some computations with a count kind of thing maybe for nn_idx in reversed(xrange(k)): # could maybe avoid this loop and # vectorize # TODO: simplify indexing when possible - margins = target_dist[same_label, nn_idx][:, None] + 1 - imp_dist + margins = target_dist[same_label][:, nn_idx][:, None] + 1 - imp_dist active = margins > 0 # we mask the further impostors bc they don't need to be compared # anymore - actives = np.ma.sum(active, axis=1) # result: like a column (but + actives = np.sum(active, axis=1) # result: like a column (but # result is "list") - current_total_actives = np.ma.sum(actives) + current_total_actives = np.sum(actives) total_active += current_total_actives - pos_margins = np.ma.masked_array(margins, ~active) - imp_dist = np.ma.masked_array(imp_dist, ~active) - push_loss += (1 - reg) * np.ma.sum(pos_margins) - - weights[same_label, target_idx_sorted[same_label][:, nn_idx].ravel()] \ - -= \ - actives - weights[target_idx_sorted[same_label][:, nn_idx].ravel(), same_label] \ + pos_margins = margins[active] + push_loss += (1 - reg) * np.sum(pos_margins) + + weights[same_label, + (target_idx_sorted[same_label][:, nn_idx]).ravel()] \ + -= actives + weights[(target_idx_sorted[same_label][:, nn_idx]).ravel(), + same_label] \ -= \ actives - weights[target_idx_sorted[same_label][:, nn_idx].ravel(), - target_idx_sorted[same_label][:, nn_idx].ravel()] += actives - weights[~same_label, ~same_label] -= np.ma.sum(active, axis=0) + weights[(target_idx_sorted[same_label][:, nn_idx]).ravel(), + (target_idx_sorted[same_label][:, nn_idx]).ravel()] += actives + weights[diff_label, diff_label] -= np.sum(active, axis=0) # # TODO: be # careful # may be wrong here - weights[~same_label][:, same_label] += active.T - weights[same_label][:, ~same_label] += active - + weights[diff_label[:, None], same_label[None]] += active.T + weights[same_label[:, None], diff_label[None]] += active + # TODO: maybe for some of the things we can multiply or add a total # at the end of the loop on nn_idx ? # TODO: diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 3c85dbe6..2efb395e 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -319,6 +319,7 @@ def loss_fn(L, X, y, target_neighbors, regularization): loss = 0 total_active = 0 grad = np.zeros_like(L) + outer = np.zeros((X.shape[0], X.shape[0])) for i in range(X.shape[0]): for j in target_neighbors[i]: loss += (1 - regularization) * np.sum((Lx[i] - Lx[j])**2) @@ -334,36 +335,43 @@ def loss_fn(L, X, y, target_neighbors, regularization): grad += (regularization * (np.outer(Lx[i] - Lx[j], X[i] - X[j]) - np.outer(Lx[i] - Lx[l], X[i] - X[l]))) + outer[i, j] -= 1 + outer[j, i] -= 1 + outer[j, j] += 1 + outer[l, l] -= 1 + outer[l, i] += 1 + outer[i, l] += 1 grad = 2 * grad return grad, loss, total_active - # we check that the gradient we have computed in the test is indeed the - # true gradient on a toy example: - X, y = make_classification(random_state=42, class_sep=0.1, n_features=20) - - def _select_targets(X, y, k): - target_neighbors = np.empty((X.shape[0], k), dtype=int) - for label in np.unique(y): - inds, = np.nonzero(y == label) - dd = euclidean_distances(X[inds], squared=True) - np.fill_diagonal(dd, np.inf) - nn = np.argsort(dd)[..., :k] - target_neighbors[inds] = inds[nn] - return target_neighbors - - target_neighbors = _select_targets(X, y, 2) - regularization = 0.5 - x0 = np.random.randn(5, 20) - - def loss(x0): - return loss_fn(x0.reshape(-1, X.shape[1]), X, y, target_neighbors, - regularization)[1] - - def grad(x0): - return loss_fn(x0.reshape(-1, X.shape[1]), X, y, target_neighbors, - regularization)[0].ravel() - - scipy.optimize.check_grad(loss, grad, x0.ravel()) + # TODO: keep this but make it lighter + # # we check that the gradient we have computed in the test is indeed the + # # true gradient on a toy example: + # X, y = make_classification(random_state=42, class_sep=0.1, n_features=20) + # + # def _select_targets(X, y, k): + # target_neighbors = np.empty((X.shape[0], k), dtype=int) + # for label in np.unique(y): + # inds, = np.nonzero(y == label) + # dd = euclidean_distances(X[inds], squared=True) + # np.fill_diagonal(dd, np.inf) + # nn = np.argsort(dd)[..., :k] + # target_neighbors[inds] = inds[nn] + # return target_neighbors + # + # target_neighbors = _select_targets(X, y, 2) + # regularization = 0.5 + # x0 = np.random.randn(5, 20) # TODO: take smaller x0, X, y + # + # def loss(x0): + # return loss_fn(x0.reshape(-1, X.shape[1]), X, y, target_neighbors, + # regularization)[1] + # + # def grad(x0): + # return loss_fn(x0.reshape(-1, X.shape[1]), X, y, target_neighbors, + # regularization)[0].ravel() + # + # scipy.optimize.check_grad(loss, grad, x0.ravel()) class LMNN_nonperformant(LMNN):