From 86d26d39820185140835db037cfa5f2d7f79c77b Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Mon, 13 May 2019 10:16:17 +0200 Subject: [PATCH 1/8] TST: make tests for LMNN gradient --- test/metric_learn_test.py | 65 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 61 insertions(+), 4 deletions(-) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index a785d60d..24b9180c 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -2,7 +2,7 @@ import re import pytest import numpy as np -from scipy.optimize import check_grad +from scipy.optimize import check_grad, approx_fprime from six.moves import xrange from sklearn.metrics import pairwise_distances from sklearn.datasets import load_iris, make_classification, make_regression @@ -21,7 +21,7 @@ RCA_Supervised, MMC_Supervised, SDML) # Import this specially for testing. from metric_learn.constraints import wrap_pairs -from metric_learn.lmnn import python_LMNN +from metric_learn.lmnn import python_LMNN, _sum_outer_products def class_separation(X, labels): @@ -120,6 +120,61 @@ def test_iris(self): self.iris_labels) self.assertLess(csep, 0.25) + def test_loss_grad_lbfgs(self): + """Test gradient of loss function + Assert that the gradient is almost equal to its finite differences + approximation. + """ + rng = np.random.RandomState(42) + X, y = make_classification(random_state=rng) + L = rng.randn(rng.randint(1, X.shape[1] + 1), X.shape[1]) + lmnn = LMNN() + + k = lmnn.k + reg = lmnn.regularization + + X, y = lmnn._prepare_inputs(X, y, dtype=float, + ensure_min_samples=2) + num_pts, num_dims = X.shape + unique_labels, label_inds = np.unique(y, return_inverse=True) + lmnn.labels_ = np.arange(len(unique_labels)) + lmnn.transformer_ = np.eye(num_dims) + + target_neighbors = lmnn._select_targets(X, label_inds) + impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds) + + # sum outer products + dfG = _sum_outer_products(X, target_neighbors.flatten(), + np.repeat(np.arange(X.shape[0]), k)) + df = np.zeros_like(dfG) + + # storage + a1 = [None]*k + a2 = [None]*k + for nn_idx in xrange(k): + a1[nn_idx] = np.array([]) + a2[nn_idx] = np.array([]) + + # initialize L + + def fun(L): + return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1, + k, reg, + target_neighbors, df, a1, a2)[1] + + def grad(L): + return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, + 1, k, reg, + target_neighbors, df, a1, a2)[0].ravel() + + # compute relative error + epsilon = np.sqrt(np.finfo(float).eps) + rel_diff = (check_grad(fun, grad, L.ravel()) / + np.linalg.norm(approx_fprime(L.ravel(), fun, + epsilon))) + # np.linalg.norm(grad(L)) + np.testing.assert_almost_equal(rel_diff, 0., decimal=5) + def test_convergence_simple_example(capsys): # LMNN should converge on this simple example, which it did not with @@ -421,8 +476,10 @@ def grad(M): return nca._loss_grad_lbfgs(M, X, mask)[1].ravel() # compute relative error - rel_diff = check_grad(fun, grad, M.ravel()) / np.linalg.norm(grad(M)) - np.testing.assert_almost_equal(rel_diff, 0., decimal=6) + epsilon = np.sqrt(np.finfo(float).eps) + rel_diff = (check_grad(fun, grad, M.ravel()) / + np.linalg.norm(approx_fprime(M.ravel(), fun, epsilon))) + np.testing.assert_almost_equal(rel_diff, 0., decimal=10) def test_simple_example(self): """Test on a simple example. From 61eea28e1e046a89ded55639ef92e852ab7cc524 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 21 May 2019 15:06:59 +0200 Subject: [PATCH 2/8] FIX: fix gradient computation --- metric_learn/lmnn.py | 6 ++-- test/metric_learn_test.py | 67 ++++++++++++++++++++++++++++++++++----- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index f9cd0e91..1455799f 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -105,7 +105,7 @@ def fit(self, X, y): # objective than the previous L, following the gradient: while True: # the next point next_L to try out is found by a gradient step - L_next = L - 2 * learn_rate * G + L_next = L - learn_rate * G # we compute the objective at next point # we copy variables that can be modified by _loss_grad, because if we # retry we don t want to modify them several times @@ -191,10 +191,12 @@ def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df, # do the gradient update assert not np.isnan(df).any() G = dfG * reg + df * (1 - reg) + + grad = 2 * L.dot(G) # compute the objective function objective = total_active * (1 - reg) objective += G.flatten().dot(L.T.dot(L).flatten()) - return G, objective, total_active, df, a1, a2 + return grad, objective, total_active, df, a1, a2 def _select_targets(self, X, label_inds): target_neighbors = np.empty((X.shape[0], self.k), dtype=int) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 24b9180c..181c2ad4 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -158,24 +158,75 @@ def test_loss_grad_lbfgs(self): # initialize L def fun(L): - return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1, - k, reg, - target_neighbors, df, a1, a2)[1] + # we copy variables that can be modified by _loss_grad, because we + # want to have the same result when applying the function twice + return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, + 1, k, reg, target_neighbors, df.copy(), + list(a1), list(a2))[1] def grad(L): + # we copy variables that can be modified by _loss_grad, because we + # want to have the same result when applying the function twice return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, - 1, k, reg, - target_neighbors, df, a1, a2)[0].ravel() + 1, k, reg, target_neighbors, df.copy(), + list(a1), list(a2))[0].ravel() # compute relative error epsilon = np.sqrt(np.finfo(float).eps) rel_diff = (check_grad(fun, grad, L.ravel()) / - np.linalg.norm(approx_fprime(L.ravel(), fun, - epsilon))) - # np.linalg.norm(grad(L)) + np.linalg.norm(approx_fprime(L.ravel(), fun, epsilon))) np.testing.assert_almost_equal(rel_diff, 0., decimal=5) +@pytest.mark.parametrize('X, y, loss', [(np.array([[0], [1], [2], [3]]), + [1, 1, 0, 0], 3.0), + (np.array([[0], [1], [2], [3]]), + [1, 0, 0, 1], 26.)]) +def test_toy_ex_lmnn(X, y, loss): + """Test that the loss give the right result on a toy example""" + L = np.array([[1]]) + lmnn = LMNN(k=1, regularization=0.5) + + k = lmnn.k + reg = lmnn.regularization + + X, y = lmnn._prepare_inputs(X, y, dtype=float, + ensure_min_samples=2) + num_pts, num_dims = X.shape + unique_labels, label_inds = np.unique(y, return_inverse=True) + lmnn.labels_ = np.arange(len(unique_labels)) + lmnn.transformer_ = np.eye(num_dims) + + target_neighbors = lmnn._select_targets(X, label_inds) + impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds) + + # sum outer products + dfG = _sum_outer_products(X, target_neighbors.flatten(), + np.repeat(np.arange(X.shape[0]), k)) + df = np.zeros_like(dfG) + + # storage + a1 = [None]*k + a2 = [None]*k + for nn_idx in xrange(k): + a1[nn_idx] = np.array([]) + a2[nn_idx] = np.array([]) + + # initialize L + + def fun(L): + return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1, + k, reg, + target_neighbors, df, a1, a2)[1] + + def grad(L): + return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1, + k, reg, target_neighbors, df, a1, a2)[0].ravel() + + # compute relative error + assert fun(L) == loss + + def test_convergence_simple_example(capsys): # LMNN should converge on this simple example, which it did not with # this issue: https://github.com/metric-learn/metric-learn/issues/88 From d5c9dd07dac9c3304c4210753df9a4ae7775df69 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 21 May 2019 15:21:54 +0200 Subject: [PATCH 3/8] Simplify expression --- metric_learn/lmnn.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index 1455799f..0e488749 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -191,12 +191,11 @@ def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df, # do the gradient update assert not np.isnan(df).any() G = dfG * reg + df * (1 - reg) - - grad = 2 * L.dot(G) + G = L.dot(G) # compute the objective function objective = total_active * (1 - reg) - objective += G.flatten().dot(L.T.dot(L).flatten()) - return grad, objective, total_active, df, a1, a2 + objective += G.flatten().dot(L.flatten()) + return 2 * G, objective, total_active, df, a1, a2 def _select_targets(self, X, label_inds): target_neighbors = np.empty((X.shape[0], self.k), dtype=int) From b238b6571454daf96ca2b468874d9f8ef04e7f09 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 21 May 2019 16:05:49 +0200 Subject: [PATCH 4/8] Be more tolerant for checking NCA --- test/metric_learn_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 181c2ad4..399cf050 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -530,7 +530,7 @@ def grad(M): epsilon = np.sqrt(np.finfo(float).eps) rel_diff = (check_grad(fun, grad, M.ravel()) / np.linalg.norm(approx_fprime(M.ravel(), fun, epsilon))) - np.testing.assert_almost_equal(rel_diff, 0., decimal=10) + np.testing.assert_almost_equal(rel_diff, 0., decimal=6) def test_simple_example(self): """Test on a simple example. From f9511a046b9cf0b04e66bd031d6cdf313e3ed28a Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 21 May 2019 16:58:36 +0200 Subject: [PATCH 5/8] Address https://github.com/metric-learn/metric-learn/pull/201#discussion_r286032898 --- test/metric_learn_test.py | 36 +++++++++++------------------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 399cf050..49ea9340 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -156,20 +156,16 @@ def test_loss_grad_lbfgs(self): a2[nn_idx] = np.array([]) # initialize L + def loss_grad(flat_L): + return lmnn._loss_grad(X, flat_L.reshape(-1, X.shape[1]), dfG, impostors, + 1, k, reg, target_neighbors, df.copy(), + list(a1), list(a2)) - def fun(L): - # we copy variables that can be modified by _loss_grad, because we - # want to have the same result when applying the function twice - return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, - 1, k, reg, target_neighbors, df.copy(), - list(a1), list(a2))[1] + def fun(x): + loss_grad(x)[1] - def grad(L): - # we copy variables that can be modified by _loss_grad, because we - # want to have the same result when applying the function twice - return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, - 1, k, reg, target_neighbors, df.copy(), - list(a1), list(a2))[0].ravel() + def grad(x): + loss_grad(x)[0].ravel() # compute relative error epsilon = np.sqrt(np.finfo(float).eps) @@ -212,19 +208,9 @@ def test_toy_ex_lmnn(X, y, loss): a1[nn_idx] = np.array([]) a2[nn_idx] = np.array([]) - # initialize L - - def fun(L): - return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1, - k, reg, - target_neighbors, df, a1, a2)[1] - - def grad(L): - return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1, - k, reg, target_neighbors, df, a1, a2)[0].ravel() - - # compute relative error - assert fun(L) == loss + # assert that the loss equals the one computed by hand + assert lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1, k, + reg, target_neighbors, df, a1, a2)[1] == loss def test_convergence_simple_example(capsys): From 562f33bfcc7d5fe6a8fc6f65145f4a0d909224d6 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Wed, 22 May 2019 10:31:40 +0200 Subject: [PATCH 6/8] Add checks for bounds argument --- metric_learn/itml.py | 18 +++++++++++------- test/metric_learn_test.py | 39 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/metric_learn/itml.py b/metric_learn/itml.py index 9b6dccb2..ce821f24 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -68,9 +68,13 @@ def _fit(self, pairs, y, bounds=None): X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])}) self.bounds_ = np.percentile(pairwise_distances(X), (5, 95)) else: - assert len(bounds) == 2 + bounds = check_array(bounds, allow_nd=False, ensure_min_samples=0, + ensure_2d=False) + bounds = bounds.ravel() + if bounds.size != 2: + raise ValueError("`bounds` should be an array-like of two elements.") self.bounds_ = bounds - self.bounds_[self.bounds_==0] = 1e-9 + self.bounds_[self.bounds_ == 0] = 1e-9 # init metric if self.A0 is None: A = np.identity(pairs.shape[2]) @@ -133,7 +137,7 @@ class ITML(_BaseITML, _PairsClassifierMixin): Attributes ---------- - bounds_ : array-like, shape=(2,) + bounds_ : `numpy.ndarray`, shape=(2,) Bounds on similarity, aside slack variables, s.t. ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of @@ -170,7 +174,7 @@ def fit(self, pairs, y, bounds=None, calibration_params=None): preprocessor. y: array-like, of shape (n_constraints,) Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. - bounds : `list` of two numbers + bounds : array-like of two numbers Bounds on similarity, aside slack variables, s.t. ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of @@ -191,7 +195,7 @@ def fit(self, pairs, y, bounds=None, calibration_params=None): calibration_params = (calibration_params if calibration_params is not None else dict()) self._validate_calibration_params(**calibration_params) - self._fit(pairs, y) + self._fit(pairs, y, bounds=bounds) self.calibrate_threshold(pairs, y, **calibration_params) return self @@ -201,7 +205,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin): Attributes ---------- - bounds_ : array-like, shape=(2,) + bounds_ : `numpy.ndarray`, shape=(2,) Bounds on similarity, aside slack variables, s.t. ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of @@ -274,7 +278,7 @@ def fit(self, X, y, random_state=np.random, bounds=None): random_state : numpy.random.RandomState, optional If provided, controls random number generation. - bounds : `list` of two numbers + bounds : array-like of two numbers Bounds on similarity, aside slack variables, s.t. ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 49ea9340..7611109a 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -18,7 +18,7 @@ HAS_SKGGM = True from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC, LSML_Supervised, ITML_Supervised, SDML_Supervised, - RCA_Supervised, MMC_Supervised, SDML) + RCA_Supervised, MMC_Supervised, SDML, ITML) # Import this specially for testing. from metric_learn.constraints import wrap_pairs from metric_learn.lmnn import python_LMNN, _sum_outer_products @@ -109,6 +109,43 @@ def test_deprecation_bounds(self): assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y) +@pytest.mark.parametrize('bounds', [None, (20., 100.), [20., 100.], + np.array([20., 100.]), + np.array([[20., 100.]]), + np.array([[20], [100]])]) +def test_bounds_parameters_valid(bounds): + """Asserts that we can provide any array-like of two elements as bounds, + and that the attribute bound_ is a numpy array""" + + pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]]) + y_pairs = [1, -1] + itml = ITML() + itml.fit(pairs, y_pairs, bounds=bounds) + + X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) + y = np.array([1, 0, 1, 0]) + itml_supervised = ITML_Supervised() + itml_supervised.fit(X, y, bounds=bounds) + + +@pytest.mark.parametrize('bounds', ['weird', ['weird1', 'weird2'], + np.array([1, 2, 3])]) +def test_bounds_parameters_invalid(bounds): + """Assert that if a non array-like is put for bounds, or an array-like + of length different than 2, an error is returned""" + pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]]) + y_pairs = [1, -1] + itml = ITML() + with pytest.raises(Exception): + itml.fit(pairs, y_pairs, bounds=bounds) + + X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) + y = np.array([1, 0, 1, 0]) + itml_supervised = ITML_Supervised() + with pytest.raises(Exception): + itml_supervised.fit(X, y, bounds=bounds) + + class TestLMNN(MetricTestCase): def test_iris(self): # Test both impls, if available. From 65057e33e706022b04785d33895105e3d3bca6cf Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Wed, 22 May 2019 10:40:17 +0200 Subject: [PATCH 7/8] Revert "Add checks for bounds argument" This reverts commit 562f33bfcc7d5fe6a8fc6f65145f4a0d909224d6. --- metric_learn/itml.py | 18 +++++++----------- test/metric_learn_test.py | 39 +-------------------------------------- 2 files changed, 8 insertions(+), 49 deletions(-) diff --git a/metric_learn/itml.py b/metric_learn/itml.py index ce821f24..9b6dccb2 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -68,13 +68,9 @@ def _fit(self, pairs, y, bounds=None): X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])}) self.bounds_ = np.percentile(pairwise_distances(X), (5, 95)) else: - bounds = check_array(bounds, allow_nd=False, ensure_min_samples=0, - ensure_2d=False) - bounds = bounds.ravel() - if bounds.size != 2: - raise ValueError("`bounds` should be an array-like of two elements.") + assert len(bounds) == 2 self.bounds_ = bounds - self.bounds_[self.bounds_ == 0] = 1e-9 + self.bounds_[self.bounds_==0] = 1e-9 # init metric if self.A0 is None: A = np.identity(pairs.shape[2]) @@ -137,7 +133,7 @@ class ITML(_BaseITML, _PairsClassifierMixin): Attributes ---------- - bounds_ : `numpy.ndarray`, shape=(2,) + bounds_ : array-like, shape=(2,) Bounds on similarity, aside slack variables, s.t. ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of @@ -174,7 +170,7 @@ def fit(self, pairs, y, bounds=None, calibration_params=None): preprocessor. y: array-like, of shape (n_constraints,) Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. - bounds : array-like of two numbers + bounds : `list` of two numbers Bounds on similarity, aside slack variables, s.t. ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of @@ -195,7 +191,7 @@ def fit(self, pairs, y, bounds=None, calibration_params=None): calibration_params = (calibration_params if calibration_params is not None else dict()) self._validate_calibration_params(**calibration_params) - self._fit(pairs, y, bounds=bounds) + self._fit(pairs, y) self.calibrate_threshold(pairs, y, **calibration_params) return self @@ -205,7 +201,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin): Attributes ---------- - bounds_ : `numpy.ndarray`, shape=(2,) + bounds_ : array-like, shape=(2,) Bounds on similarity, aside slack variables, s.t. ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of @@ -278,7 +274,7 @@ def fit(self, X, y, random_state=np.random, bounds=None): random_state : numpy.random.RandomState, optional If provided, controls random number generation. - bounds : array-like of two numbers + bounds : `list` of two numbers Bounds on similarity, aside slack variables, s.t. ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a`` and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 7611109a..49ea9340 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -18,7 +18,7 @@ HAS_SKGGM = True from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC, LSML_Supervised, ITML_Supervised, SDML_Supervised, - RCA_Supervised, MMC_Supervised, SDML, ITML) + RCA_Supervised, MMC_Supervised, SDML) # Import this specially for testing. from metric_learn.constraints import wrap_pairs from metric_learn.lmnn import python_LMNN, _sum_outer_products @@ -109,43 +109,6 @@ def test_deprecation_bounds(self): assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y) -@pytest.mark.parametrize('bounds', [None, (20., 100.), [20., 100.], - np.array([20., 100.]), - np.array([[20., 100.]]), - np.array([[20], [100]])]) -def test_bounds_parameters_valid(bounds): - """Asserts that we can provide any array-like of two elements as bounds, - and that the attribute bound_ is a numpy array""" - - pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]]) - y_pairs = [1, -1] - itml = ITML() - itml.fit(pairs, y_pairs, bounds=bounds) - - X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) - y = np.array([1, 0, 1, 0]) - itml_supervised = ITML_Supervised() - itml_supervised.fit(X, y, bounds=bounds) - - -@pytest.mark.parametrize('bounds', ['weird', ['weird1', 'weird2'], - np.array([1, 2, 3])]) -def test_bounds_parameters_invalid(bounds): - """Assert that if a non array-like is put for bounds, or an array-like - of length different than 2, an error is returned""" - pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]]) - y_pairs = [1, -1] - itml = ITML() - with pytest.raises(Exception): - itml.fit(pairs, y_pairs, bounds=bounds) - - X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) - y = np.array([1, 0, 1, 0]) - itml_supervised = ITML_Supervised() - with pytest.raises(Exception): - itml_supervised.fit(X, y, bounds=bounds) - - class TestLMNN(MetricTestCase): def test_iris(self): # Test both impls, if available. From 69c328cf7e05d374e6d7b19c78f377765752f8c1 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Wed, 22 May 2019 11:32:15 +0200 Subject: [PATCH 8/8] Add missing return --- test/metric_learn_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 49ea9340..bf079511 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -162,10 +162,10 @@ def loss_grad(flat_L): list(a1), list(a2)) def fun(x): - loss_grad(x)[1] + return loss_grad(x)[1] def grad(x): - loss_grad(x)[0].ravel() + return loss_grad(x)[0].ravel() # compute relative error epsilon = np.sqrt(np.finfo(float).eps)