Skip to content

FIX Fixes LMNN rollback #100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions metric_learn/lmnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,16 @@ def fit(self, X, y):
L = self.L_
objective = np.inf

# we initialize the roll back
L_old = L.copy()
G_old = G.copy()
df_old = df.copy()
a1_old = [a.copy() for a in a1]
a2_old = [a.copy() for a in a2]
objective_old = objective

# main loop
for it in xrange(1, self.max_iter):
df_old = df.copy()
a1_old = [a.copy() for a in a1]
a2_old = [a.copy() for a in a2]
objective_old = objective
# Compute pairwise distances under current metric
Lx = L.dot(self.X_.T).T
g0 = _inplace_paired_L2(*Lx[impostors])
Expand Down Expand Up @@ -158,14 +162,25 @@ def fit(self, X, y):
if delta_obj > 0:
# we're getting worse... roll back!
learn_rate /= 2.0
L = L_old
G = G_old
df = df_old
a1 = a1_old
a2 = a2_old
objective = objective_old
else:
# update L
L -= learn_rate * 2 * L.dot(G)
learn_rate *= 1.01
# We did good. We store this point as reference in case we do
# worse next time.
objective_old = objective
L_old = L.copy()
G_old = G.copy()
df_old = df.copy()
a1_old = [a.copy() for a in a1]
a2_old = [a.copy() for a in a2]
# we update L and will see in the next iteration if it does indeed
# better
L -= learn_rate * 2 * L.dot(G)
learn_rate *= 1.01

# check for convergence
if it > self.min_iter and abs(delta_obj) < self.convergence_tol:
Expand All @@ -177,7 +192,7 @@ def fit(self, X, y):
print("LMNN didn't converge in %d steps." % self.max_iter)

# store the last L
self.L_ = L
self.L_ = L_old
self.n_iter_ = it
return self

Expand Down