diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index dea12f0c..04a69396 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -95,12 +95,16 @@ def fit(self, X, y): L = self.L_ objective = np.inf + # we initialize the roll back + L_old = L.copy() + G_old = G.copy() + df_old = df.copy() + a1_old = [a.copy() for a in a1] + a2_old = [a.copy() for a in a2] + objective_old = objective + # main loop for it in xrange(1, self.max_iter): - df_old = df.copy() - a1_old = [a.copy() for a in a1] - a2_old = [a.copy() for a in a2] - objective_old = objective # Compute pairwise distances under current metric Lx = L.dot(self.X_.T).T g0 = _inplace_paired_L2(*Lx[impostors]) @@ -158,14 +162,25 @@ def fit(self, X, y): if delta_obj > 0: # we're getting worse... roll back! learn_rate /= 2.0 + L = L_old + G = G_old df = df_old a1 = a1_old a2 = a2_old objective = objective_old else: - # update L - L -= learn_rate * 2 * L.dot(G) - learn_rate *= 1.01 + # We did good. We store this point as reference in case we do + # worse next time. + objective_old = objective + L_old = L.copy() + G_old = G.copy() + df_old = df.copy() + a1_old = [a.copy() for a in a1] + a2_old = [a.copy() for a in a2] + # we update L and will see in the next iteration if it does indeed + # better + L -= learn_rate * 2 * L.dot(G) + learn_rate *= 1.01 # check for convergence if it > self.min_iter and abs(delta_obj) < self.convergence_tol: @@ -177,7 +192,7 @@ def fit(self, X, y): print("LMNN didn't converge in %d steps." % self.max_iter) # store the last L - self.L_ = L + self.L_ = L_old self.n_iter_ = it return self