[WIP] update impostors #223

Closed
175 changes: 102 additions & 73 deletions metric_learn/lmnn.py
@@ -14,17 +14,25 @@

from __future__ import print_function, absolute_import
import numpy as np
import scipy
import warnings
from collections import Counter
from six.moves import xrange
from sklearn.exceptions import ChangedBehaviorWarning
from sklearn.metrics import euclidean_distances
from sklearn.metrics import euclidean_distances, pairwise_distances
from sklearn.base import TransformerMixin

from ._util import _initialize_transformer, _check_n_components
from .base_metric import MahalanobisMixin


def re_order_target_neighbors(L, X, target_neighbors):
# sort each sample's target neighbors by squared distance under L
# (ascending), so that the last column holds the furthest one
Xl = np.dot(X, L.T)
dd = np.sum((Xl[:, None, :] - Xl[target_neighbors])**2, axis=2)
sorted_neighbors = np.take_along_axis(target_neighbors, dd.argsort(axis=1), 1)
return sorted_neighbors
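
A tiny usage sketch (made-up data, not part of the diff) of the helper above: each row of target_neighbors gets re-sorted by squared distance under L, nearest target first, furthest last.

import numpy as np

X = np.array([[0., 0.], [1., 0.], [0., 2.], [3., 0.]])
target_neighbors = np.array([[2, 1], [0, 2], [0, 1], [1, 0]])
# with L = identity, row 0 is re-sorted: point 1 (squared distance 1)
# now comes before point 2 (squared distance 4)
print(re_order_target_neighbors(np.eye(2), X, target_neighbors))
# [[1 2]
#  [0 2]
#  [0 1]
#  [1 0]]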


class LMNN(MahalanobisMixin, TransformerMixin):
def __init__(self, init=None, k=3, min_iter=50, max_iter=1000,
learn_rate=1e-7, regularization=0.5, convergence_tol=0.001,
@@ -155,35 +163,31 @@ def fit(self, X, y):
raise ValueError('not enough class labels for specified k'
' (smallest class has %d)' % required_k)

target_neighbors = self._select_targets(X, label_inds)
impostors = self._find_impostors(target_neighbors[:, -1], X, label_inds)
target_neighbors = self._select_targets(self.transformer_, X, label_inds)
impostors = self._find_impostors(self.transformer_,
target_neighbors[:, -1], X, label_inds)
if len(impostors) == 0:
# L has already been initialized to an identity matrix
return

# sum outer products
dfG = _sum_outer_products(X, target_neighbors.flatten(),
np.repeat(np.arange(X.shape[0]), k))
df = np.zeros_like(dfG)

# storage
a1 = [None]*k
a2 = [None]*k
for nn_idx in xrange(k):
a1[nn_idx] = np.array([])
a2[nn_idx] = np.array([])

# initialize L
L = self.transformer_

# first iteration: we compute variables (including objective and gradient)
# at initialization point
G, objective, total_active, df, a1, a2 = (
self._loss_grad(X, L, dfG, impostors, 1, k, reg, target_neighbors, df,
a1, a2))
G, objective, total_active = self._loss_grad(X, L, y, dfG, 1,
k, reg, target_neighbors)


# TODO: the log should be printed here
it = 1 # we already made one iteration

print(it, objective, 0, total_active, 1.05e-5) # TODO: replace 1.05e-5 by
# the real learning rate; this placeholder just works around a printing bug
# main loop
for it in xrange(2, self.max_iter):
# then at each iteration, we try to find a value of L that has better
@@ -194,10 +198,12 @@ def fit(self, X, y):
# we compute the objective at next point
# we copy variables that can be modified by _loss_grad, because if we
# retry we don't want to modify them several times
(G_next, objective_next, total_active_next, df_next, a1_next,
a2_next) = (
self._loss_grad(X, L_next, dfG, impostors, it, k, reg,
target_neighbors, df.copy(), list(a1), list(a2)))
# target_neighbors_next = self._select_targets(L_next, X, label_inds)
# TODO: I should just re-order the target neighbors

(G_next, objective_next, total_active_next) = (
self._loss_grad(X, L_next, label_inds, dfG, it, k, reg,
target_neighbors))
assert not np.isnan(objective)
delta_obj = objective_next - objective
if delta_obj > 0:
@@ -212,8 +218,7 @@
# old variables to these new ones before next iteration and we
# slightly increase the learning rate
L = L_next
G, df, objective, total_active, a1, a2 = (
G_next, df_next, objective_next, total_active_next, a1_next, a2_next)
G, objective, total_active = G_next, objective_next, total_active_next
learn_rate *= 1.01
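# (illustrative note, not part of the diff: the step-size update is
# asymmetric; 50 accepted steps only grow learn_rate from 1e-7 to about
# 1e-7 * 1.01**50 ~= 1.64e-7, while the rejected-step branch (collapsed
# above) presumably shrinks it much faster, so backtracking dominates as
# soon as the objective stops improving)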

if self.verbose:
@@ -231,74 +236,97 @@
# store the last L
self.transformer_ = L
self.n_iter_ = it
self.targets_ = target_neighbors
self.impostors_ = impostors
return self

def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df,
a1, a2):
def _loss_grad(self, X, L, y, dfG, it, k, reg, target_neighbors):
# Compute pairwise distances under current metric
Lx = L.dot(X.T).T
g0 = _inplace_paired_L2(*Lx[impostors])
Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:, None, :])
g1, g2 = Ni[impostors]
# compute the gradient
total_active = 0
for nn_idx in reversed(xrange(k)):
act1 = g0 < g1[:, nn_idx]
act2 = g0 < g2[:, nn_idx]
total_active += act1.sum() + act2.sum()

if it > 1:
plus1 = act1 & ~a1[nn_idx]
minus1 = a1[nn_idx] & ~act1
plus2 = act2 & ~a2[nn_idx]
minus2 = a2[nn_idx] & ~act2
else:
plus1 = act1
plus2 = act2
minus1 = np.zeros(0, dtype=int)
minus2 = np.zeros(0, dtype=int)

targets = target_neighbors[:, nn_idx]
PLUS, pweight = _count_edges(plus1, plus2, impostors, targets)
df += _sum_outer_products(X, PLUS[:, 0], PLUS[:, 1], pweight)
MINUS, mweight = _count_edges(minus1, minus2, impostors, targets)
df -= _sum_outer_products(X, MINUS[:, 0], MINUS[:, 1], mweight)

in_imp, out_imp = impostors
df += _sum_outer_products(X, in_imp[minus1], out_imp[minus1])
df += _sum_outer_products(X, in_imp[minus2], out_imp[minus2])

df -= _sum_outer_products(X, in_imp[plus1], out_imp[plus1])
df -= _sum_outer_products(X, in_imp[plus2], out_imp[plus2])

a1[nn_idx] = act1
a2[nn_idx] = act2
# do the gradient update
assert not np.isnan(df).any()
G = dfG * reg + df * (1 - reg)
G = L.dot(G)
# compute the objective function
objective = total_active * (1 - reg)
objective += G.flatten().dot(L.flatten())
return 2 * G, objective, total_active, df, a1, a2

def _select_targets(self, X, label_inds):
Lx = X.dot(L.T)
n_samples = X.shape[0]
target_dist = np.sum((Lx[:, None] - Lx[target_neighbors])**2, axis=2)
# TODO: maybe re-ordering the target neighbors like this is not the most
# efficient approach?
target_idx_sorted = np.take_along_axis(target_neighbors,
target_dist.argsort(axis=1), 1)
target_dist = np.sort(target_dist, axis=1)
total_active, push_loss = 0, 0
weights = np.zeros((n_samples, n_samples))
for c in np.unique(y): # could maybe avoid this loop and vectorize
same_label = np.where(y == c)[0] # TODO: I can have this pre-computed
diff_label = np.where(y != c)[0]
imp_dist = pairwise_distances(Lx[same_label], Lx[diff_label],
squared=True)
# TODO: maybe use a count-based computation here instead
for nn_idx in reversed(xrange(k)):  # could maybe avoid this loop and vectorize
# TODO: simplify indexing when possible
margins = target_dist[same_label][:, nn_idx][:, None] + 1 - imp_dist
active = margins > 0
# we mask the farther impostors because they no longer need to be
# compared
actives = np.sum(active, axis=1)  # number of active impostors per
# sample (1-D array)
current_total_actives = np.sum(actives)
total_active += current_total_actives
pos_margins = margins[active]
push_loss += (1 - reg) * np.sum(pos_margins)
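# (worked example, not part of the diff: with reg = 0.5, if a target
# neighbor lies at squared distance 2 under L and an impostor at squared
# distance 2.5, the margin is 2 + 1 - 2.5 = 0.5 > 0, so that impostor is
# active and contributes (1 - 0.5) * 0.5 = 0.25 to push_loss)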

target_col = target_idx_sorted[same_label][:, nn_idx].ravel()
weights[same_label, target_col] -= actives
weights[target_col, same_label] -= actives
weights[target_col, target_col] += actives
weights[diff_label, diff_label] -= np.sum(active, axis=0)
# TODO: be careful, this may be wrong
weights[diff_label[:, None], same_label[None]] += active.T
weights[same_label[:, None], diff_label[None]] += active

# TODO: maybe some of these updates could be accumulated once at the end
# of the loop over nn_idx?
# TODO: the diagonal terms could maybe be optimized further
# (e.g. 3 * X instead of 3 * np.eye(n).dot(X))
push_grad = ((1 - reg) * weights.T.dot(Lx)).T.dot(X)  # TODO: optimize the
# order of operations, as in NCA
# TODO: do better sparse multiplication (avoid the transpose)
pull_grad = L.dot(dfG * reg)  # we could compute this from Lx instead if d >> n

pull_loss = reg * np.sum(target_dist)
grad = push_grad + pull_grad
grad *= 2
it += 1
objective = pull_loss + push_loss

return grad, objective, total_active
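
A standalone sanity check (not part of the PR) of the identity behind the dense weights accumulation above: a weighted sum of outer products of differences folds into a single matrix product, which is what push_grad exploits.

import numpy as np

rng = np.random.RandomState(0)
n, d = 5, 3
X = rng.randn(n, d)
W = rng.rand(n, n)  # arbitrary pairwise weights, stand-in for `weights`

# direct accumulation: sum_ij W[i, j] * (x_i - x_j)(x_i - x_j)^T
direct = sum(W[i, j] * np.outer(X[i] - X[j], X[i] - X[j])
             for i in range(n) for j in range(n))

# matrix form: X^T (D - W - W^T) X, with D = diag(row sums + column sums)
D = np.diag(W.sum(axis=1) + W.sum(axis=0))
assert np.allclose(direct, X.T.dot(D - W - W.T).dot(X))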

def _select_targets(self, L, X, label_inds):
target_neighbors = np.empty((X.shape[0], self.k), dtype=int)
for label in self.labels_:
inds, = np.nonzero(label_inds == label)
dd = euclidean_distances(X[inds], squared=True)
dd = euclidean_distances(X.dot(L.T)[inds], squared=True)
np.fill_diagonal(dd, np.inf)
nn = np.argsort(dd)[..., :self.k]
target_neighbors[inds] = inds[nn]
return target_neighbors
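
For intuition, a hypothetical standalone example (made-up data, L = identity, k = 1) of what _select_targets computes: each sample's target neighbors are its k nearest same-class points in the transformed space.

import numpy as np
from sklearn.metrics import euclidean_distances

X = np.array([[0.], [1.], [10.], [11.]])
y = np.array([0, 0, 1, 1])
L, k = np.eye(1), 1
target_neighbors = np.empty((len(X), k), dtype=int)
for label in np.unique(y):
    inds, = np.nonzero(y == label)
    dd = euclidean_distances(X.dot(L.T)[inds], squared=True)
    np.fill_diagonal(dd, np.inf)  # a point cannot be its own target
    target_neighbors[inds] = inds[np.argsort(dd)[..., :k]]
print(target_neighbors)  # [[1] [0] [3] [2]]: the nearest same-class point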

def _find_impostors(self, furthest_neighbors, X, label_inds):
Lx = self.transform(X)
def _find_impostors(self, L, furthest_neighbors, X, label_inds):
Lx = X.dot(L.T)
margin_radii = 1 + _inplace_paired_L2(Lx[furthest_neighbors], Lx)
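# (note, not part of the diff: margin_radii[i] is 1 plus the squared
# distance from sample i to its furthest target neighbor; any point with
# a different label that falls inside this radius violates the margin,
# and the pairs collected below are exactly those "impostors")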
impostors = []
for label in self.labels_[:-1]:
in_inds, = np.nonzero(label_inds == label)
out_inds, = np.nonzero(label_inds > label)
out_inds, = np.nonzero(label_inds > label)  # TODO: check why >; presumably
# so each unordered pair of classes is visited once (avoids symmetric
# duplicates)
dist = euclidean_distances(Lx[out_inds], Lx[in_inds], squared=True)
i1, j1 = np.nonzero(dist < margin_radii[out_inds][:, None])
i2, j2 = np.nonzero(dist < margin_radii[in_inds])
Expand Down Expand Up @@ -335,6 +363,7 @@ def _count_edges(act1, act2, impostors, targets):


def _sum_outer_products(data, a_inds, b_inds, weights=None):
# TODO: since this is only used once, maybe replace it with something else?
Xab = data[a_inds] - data[b_inds]
if weights is not None:
return np.dot(Xab.T, Xab * weights[:,None])
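
A quick standalone check (not part of the diff) that the weighted branch computes a weighted sum of outer products in a single matrix product:

import numpy as np

data = np.arange(12.).reshape(4, 3)
a_inds, b_inds = np.array([0, 2]), np.array([1, 3])
weights = np.array([2., 0.5])
Xab = data[a_inds] - data[b_inds]

direct = sum(w * np.outer(x, x) for w, x in zip(weights, Xab))
assert np.allclose(direct, np.dot(Xab.T, Xab * weights[:, None]))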