diff --git a/doc/metric_learn.rst b/doc/metric_learn.rst index 930404d0..76c91f48 100644 --- a/doc/metric_learn.rst +++ b/doc/metric_learn.rst @@ -14,6 +14,7 @@ Base Classes metric_learn.Constraints metric_learn.base_metric.BaseMetricLearner metric_learn.base_metric._PairsClassifierMixin + metric_learn.base_metric._TripletsClassifierMixin metric_learn.base_metric._QuadrupletsClassifierMixin Supervised Learning Algorithms diff --git a/doc/weakly_supervised.rst b/doc/weakly_supervised.rst index cf313ba1..72f68627 100644 --- a/doc/weakly_supervised.rst +++ b/doc/weakly_supervised.rst @@ -592,6 +592,114 @@ points, while constrains the sum of distances between dissimilar points: -with-side-information.pdf>`_. NIPS 2002 .. [2] Adapted from Matlab code http://www.cs.cmu.edu/%7Eepxing/papers/Old_papers/code_Metric_online.tar.gz +.. _learning_on_triplets: + +Learning on triplets +==================== + +Some metric learning algorithms learn on triplets of samples. In this case, +one should provide the algorithm with `n_samples` triplets of points. The +semantic of each triplet is that the first point should be closer to the +second point than to the third one. + +Fitting +------- +Here is an example for fitting on triplets (see :ref:`fit_ws` for more +details on the input data format and how to fit, in the general case of +learning on tuples). 
+ +>>> from metric_learn import SCML +>>> triplets = np.array([[[1.2, 3.2], [2.3, 5.5], [2.1, 0.6]], +... [[4.5, 2.3], [2.1, 2.3], [7.3, 3.4]]]) +>>> scml = SCML(random_state=42) +>>> scml.fit(triplets) +SCML(beta=1e-5, B=None, max_iter=100000, verbose=False, + preprocessor=None, random_state=42) + +Or alternatively (using a preprocessor): + +>>> X = np.array([[1.2, 3.2], +... [2.3, 5.5], +... [2.1, 0.6], +... [4.5, 2.3], +... [2.1, 2.3], +... [7.3, 3.4]]) +>>> triplets_indices = np.array([[0, 1, 2], [3, 4, 5]]) +>>> scml = SCML(preprocessor=X, random_state=42) +>>> scml.fit(triplets_indices) +SCML(beta=1e-5, B=None, max_iter=100000, verbose=False, + preprocessor=array([[1.2, 3.2], + [2.3, 5.5], + [2.1, 0.6], + [4.5, 2.3], + [2.1, 2.3], + [7.3, 3.4]]), + random_state=42) + + +Here, we want to learn a metric that, for each of the two +`triplets`, will make the first point closer to the +second point than to the third one. + +.. _triplets_predicting: + +Prediction +---------- + +When a triplets learner is fitted, it is also able to predict, for an +upcoming triplet, whether the first point is closer to the second point +than to the third one (+1), or not (-1). + +>>> triplets_test = np.array( +... [[[5.6, 5.3], [2.2, 2.1], [1.2, 3.4]], +... [[6.0, 4.2], [4.3, 1.2], [0.1, 7.8]]]) +>>> scml.predict(triplets_test) +array([-1., 1.]) + +.. _triplets_scoring: + +Scoring +------- + +Triplet metric learners can also return a `decision_function` for a set of triplets, +which corresponds to the distance between the first and last points minus the distance +between the first two points of the triplet (the higher the value, the more +similar the first point is to the second point compared to the last one). This "score" +can be interpreted as a measure of the likelihood of having a +1 prediction for this +triplet.
+ +>>> scml.decision_function(triplets_test) +array([-1.75700306, 4.98982131]) + +In the above example, for the first triplet in `triplets_test`, the first +point is predicted less similar to the second point than to the last point +(they are further away in the transformed space). + +Unlike pairs learners, triplets learners do not accept a `y` when fitting: we +assume that the ordering of points within triplets is such that the training triplets +are all positive. Therefore, it is not possible to use scikit-learn scoring functions +(such as 'f1_score') for triplets learners. + +However, triplets learners do have a default scoring function, which will +basically return the accuracy score on a given test set, i.e. the proportion +of triplets that have the right predicted ordering. + +>>> scml.score(triplets_test) +0.5 + +.. note:: + See :ref:`fit_ws` for more details on metric learner functions that are + not specific to learning on triplets, like `transform`, `score_pairs`, + `get_metric` and `get_mahalanobis_matrix`. + + + + +Algorithms +---------- + .. _learning_on_quadruplets: @@ -599,7 +707,7 @@ Learning on quadruplets ======================= Some metric learning algorithms learn on quadruplets of samples. In this case, -one should provide the algorithm with `n_samples` quadruplets of points. Th +one should provide the algorithm with `n_samples` quadruplets of points. The semantic of each quadruplet is that the first two points should be closer together than the last two points. @@ -666,14 +774,12 @@ array([-1., 1.]) Scoring ------- -Quadruplet metric learners can also -return a `decision_function` for a set of pairs.
This is basically the "score" -which sign will be taken to find the prediction for the pair, which -corresponds to the difference between the distance between the two last points, -and the distance between the two last points of the quadruplet (higher -score means the two last points are more likely to be more dissimilar than -the two first points (i.e. more likely to have a +1 prediction since it's -the right ordering)). +Quadruplet metric learners can also return a `decision_function` for a set of +quadruplets, which corresponds to the distance between the second pair of points minus +the distance between the first pair of points of the quadruplet (the higher the value, +the more similar the first pair is compared to the last pair). +This "score" can be interpreted as a measure of the likelihood of having a +1 prediction +for this quadruplet. >>> lsml.decision_function(quadruplets_test) array([-1.75700306, 4.98982131]) @@ -682,17 +788,10 @@ In the above example, for the first quadruplet in `quadruplets_test`, the two first points are predicted less similar than the two last points (they are further away in the transformed space). -Unlike for pairs learners, quadruplets learners don't allow to give a `y` -when fitting, which does not allow to use scikit-learn scoring functions -like: - ->>> from sklearn.model_selection import cross_val_score ->>> cross_val_score(lsml, quadruplets, scoring='f1_score') # this won't work - -(This is actually intentional, for more details -about that, see -`this comment `_ -on github.) +Like triplet learners, quadruplets learners do not accept a `y` when fitting: we +assume that the ordering of points within quadruplets is such that the training quadruplets +are all positive. Therefore, it is not possible to use scikit-learn scoring functions +(such as 'f1_score') for quadruplets learners. However, quadruplets learners do have a default scoring function, which will basically return the accuracy score on a given test set, i.e.
class _TripletsClassifierMixin(BaseMetricLearner):
    """Base class for triplets learners.

    A triplet is an ordered group of three points ``(X_a, X_b, X_c)``. The
    methods below compare the pair ``(X_a, X_b)`` against the pair
    ``(X_a, X_c)`` using ``self.score_pairs``.
    """

    _tuple_size = 3  # number of points in a tuple, 3 for triplets

    def predict(self, triplets):
        """Predict the ordering between sample distances in input triplets.

        For each triplet, the result is the sign of `decision_function`:
        1 if the first element is closer to the second than to the last,
        -1 if it is further. Because the sign is taken with `numpy.sign`,
        an exact tie between the two pair scores yields 0.

        Parameters
        ----------
        triplets : array-like, shape=(n_triplets, 3, n_features) or \
                   (n_triplets, 3)
          3D array of triplets to predict, with each row corresponding to
          three points, or 2D array of indices of triplets if the metric
          learner uses a preprocessor.

        Returns
        -------
        prediction : `numpy.ndarray` of floats, shape=(n_triplets,)
          Predicted ordering, one value per triplet.
        """
        differences = self.decision_function(triplets)
        return np.sign(differences)

    def decision_function(self, triplets):
        """Compute differences between pair scores in input triplets.

        For each triplet (X_a, X_b, X_c), returns the score of the second
        pair (X_a, X_c) minus the score of the first pair (X_a, X_b), both
        obtained from ``self.score_pairs``. The higher the value, the more
        the triplet looks correctly ordered (label 1); the lower, the more
        it looks wrongly ordered (label -1).

        Parameters
        ----------
        triplets : array-like, shape=(n_triplets, 3, n_features) or \
                   (n_triplets, 3)
          3D array of triplets to predict, with each row corresponding to
          three points, or 2D array of indices of triplets if the metric
          learner uses a preprocessor.

        Returns
        -------
        decision_function : `numpy.ndarray` of floats, shape=(n_triplets,)
          Metric differences.
        """
        # The fitted attribute is checked before the input is validated.
        check_is_fitted(self, 'preprocessor_')
        triplets = check_input(triplets, type_of_inputs='tuples',
                               preprocessor=self.preprocessor_,
                               estimator=self, tuple_size=self._tuple_size)
        first_vs_third = self.score_pairs(triplets[:, [0, 2]])
        first_vs_second = self.score_pairs(triplets[:, :2])
        return first_vs_third - first_vs_second

    def score(self, triplets):
        """Compute the accuracy score on input triplets.

        A triplet (X_a, X_b, X_c) counts as correctly classified when the
        first pair (X_a, X_b) is predicted more similar than the second
        pair (X_a, X_c), i.e. when `predict` returns +1 for it.

        Parameters
        ----------
        triplets : array-like, shape=(n_triplets, 3, n_features) or \
                   (n_triplets, 3)
          3D array of triplets to score, with each row corresponding to
          three points, or 2D array of indices of triplets if the metric
          learner uses a preprocessor.

        Returns
        -------
        score : float
          The triplets score.
        """
        # Predictions live in {-1, +1}; halving their mean and shifting by
        # 0.5 maps them to {0, 1}, so the result is the fraction of +1.
        predictions = self.predict(triplets)
        return predictions.mean() / 2 + 0.5
def generate_knntriplets(self, X, k_genuine, k_impostor):
    """
    Generates triplets from labeled data.

    For every point (X_a) the triplets (X_a, X_b, X_c) are constructed from
    all the combinations of taking one of its `k_genuine`-nearest neighbors
    of the same class (X_b) and taking one of its `k_impostor`-nearest
    neighbors of other classes (X_c).

    In the case a class doesn't have enough points in the same class (other
    classes) to yield `k_genuine` (`k_impostor`) neighbors a warning will be
    raised and the maximum value of genuine (impostor) neighbors will be
    used for that class. A class left with no genuine or no impostor
    neighbor at all contributes no triplet.

    Samples whose entry in ``self.partial_labels`` is negative are treated
    as unlabeled and ignored.

    Parameters
    ----------
    X : (n x d) matrix
      Input data, where each row corresponds to a single instance.
    k_genuine : int
      Number of neighbors of the same class to be taken into account.
    k_impostor : int
      Number of neighbors of different classes to be taken into account.

    Returns
    -------
    triplets : array-like, shape=(n_constraints, 3)
      2D array of triplets of indices. The indices refer to rows of the
      `X` that was passed in (unlabeled rows included in the numbering).
    """
    # Ignore unlabeled samples, but remember where the labeled ones sit in
    # the caller's X so the result can be expressed in the caller's indexing.
    known_labels_mask = self.partial_labels >= 0
    known_indices = np.flatnonzero(known_labels_mask)
    known_labels = self.partial_labels[known_labels_mask]
    X = X[known_labels_mask]

    labels, labels_count = np.unique(known_labels, return_counts=True)
    len_input = known_labels.shape[0]

    # Handle the case where there are too few elements to yield k_genuine or
    # k_impostor neighbors for every class.
    # The counts are forced to an integer dtype: inheriting the dtype of the
    # labels would give float neighbor counts for float labels.
    k_genuine_vec = np.full(labels.shape, k_genuine, dtype=np.intp)
    k_impostor_vec = np.full(labels.shape, k_impostor, dtype=np.intp)

    for i, count in enumerate(labels_count):
        if k_genuine + 1 > count:
            k_genuine_vec[i] = count-1
            warnings.warn("The class {} has {} elements, which is not sufficient "
                          "to generate {} genuine neighbors as specified by "
                          "k_genuine. Will generate {} genuine neighbors instead."
                          "\n"
                          .format(labels[i], count, k_genuine+1,
                                  k_genuine_vec[i]))
        if k_impostor > len_input - count:
            k_impostor_vec[i] = len_input - count
            warnings.warn("The class {} has {} elements of other classes, which is"
                          " not sufficient to generate {} impostor neighbors as "
                          "specified by k_impostor. Will generate {} impostor "
                          "neighbors instead.\n"
                          .format(labels[i], k_impostor_vec[i], k_impostor,
                                  k_impostor_vec[i]))

    # The total number of possible triplets combinations per label comes from
    # taking one of the k_genuine_vec[i] genuine neighbors and one of the
    # k_impostor_vec[i] impostor neighbors for the labels_count[i] elements
    comb_per_label = labels_count * k_genuine_vec * k_impostor_vec

    # Get start and finish for later triplet assigning
    # append zero at the beginning for start and get cumulative sum
    start_finish_indices = np.hstack((0, comb_per_label)).cumsum()

    # Total number of triplets is the sum of all possible combinations per
    # label
    num_triplets = start_finish_indices[-1]
    triplets = np.empty((num_triplets, 3), dtype=np.intp)

    neigh = NearestNeighbors()

    for i, label in enumerate(labels):
        # length = len_label*k_genuine*k_impostor
        start, finish = start_finish_indices[i:i+2]
        if start == finish:
            # No genuine or no impostor neighbor is available for this
            # class (already warned about above): there is nothing to
            # generate, and a neighbor query for 0 neighbors would fail.
            continue

        # generate mask for current label
        gen_mask = known_labels == label
        gen_indx = np.where(gen_mask)

        # get k_genuine genuine neighbors
        neigh.fit(X=X[gen_indx])
        # Take elements of gen_indx according to the yielded k-neighbors
        gen_relative_indx = neigh.kneighbors(n_neighbors=k_genuine_vec[i],
                                             return_distance=False)
        gen_neigh = np.take(gen_indx, gen_relative_indx)

        # generate mask for impostors of current label
        imp_indx = np.where(~gen_mask)

        # get k_impostor impostor neighbors
        neigh.fit(X=X[imp_indx])
        # Take elements of imp_indx according to the yielded k-neighbors
        imp_relative_indx = neigh.kneighbors(n_neighbors=k_impostor_vec[i],
                                             X=X[gen_mask],
                                             return_distance=False)
        imp_neigh = np.take(imp_indx, imp_relative_indx)

        triplets[start:finish, :] = comb(gen_indx, gen_neigh, imp_neigh,
                                         k_genuine_vec[i],
                                         k_impostor_vec[i])

    # Everything above indexes the label-filtered X; translate back to
    # positions in the X the caller passed in. This is the identity when
    # there is no unlabeled sample or when they all come last.
    return known_indices[triplets]
def comb(A, B, C, sizeB, sizeC):
    """Helper of generate_knntriplets: build all (a, b, c) index triplets.

    Each element of ``A`` is combined with each of its ``sizeB`` entries in
    the matching row of ``B`` and each of its ``sizeC`` entries in the
    matching row of ``C``. The result has one triplet per row.
    """
    # Every anchor repeated sizeB * sizeC times in a row.
    anchor_column = np.tile(A, (sizeB*sizeC, 1)).ravel(order='F')
    # Every entry of B (flattened row by row) repeated sizeC times in a row.
    second_column = np.tile(np.hstack(B), (sizeC, 1)).ravel(order='F')
    # Every row of C cycled sizeB times.
    third_column = np.tile(C, (1, sizeB)).ravel()
    stacked_rows = np.vstack((anchor_column, second_column, third_column))
    return stacked_rows.T
@pytest.mark.parametrize("k_genuine, k_impostor,",
                         [(2, 3), (3, 3), (2, 4), (3, 4)])
def test_generate_knntriplets(k_genuine, k_impostor):
    """Checks knn triplet construction when the requested neighbor counts
    are at or above what the data can provide (2 genuine and 3 impostor
    neighbors per point): every parametrization must give the same
    exhaustive set of triplets."""
    X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32], [33, 33]])
    # the last sample is unlabeled and must not appear in any triplet
    y = np.array([1, 1, 1, 2, 2, 2, -1])

    expected = [
        # anchor 0
        [0, 1, 3], [0, 1, 4], [0, 1, 5], [0, 2, 3], [0, 2, 4], [0, 2, 5],
        # anchor 1
        [1, 0, 3], [1, 0, 4], [1, 0, 5], [1, 2, 3], [1, 2, 4], [1, 2, 5],
        # anchor 2
        [2, 0, 3], [2, 0, 4], [2, 0, 5], [2, 1, 3], [2, 1, 4], [2, 1, 5],
        # anchor 3
        [3, 4, 0], [3, 4, 1], [3, 4, 2], [3, 5, 0], [3, 5, 1], [3, 5, 2],
        # anchor 4
        [4, 3, 0], [4, 3, 1], [4, 3, 2], [4, 5, 0], [4, 5, 1], [4, 5, 2],
        # anchor 5
        [5, 3, 0], [5, 3, 1], [5, 3, 2], [5, 4, 0], [5, 4, 1], [5, 4, 2],
    ]

    generated = Constraints(y).generate_knntriplets(X, k_genuine, k_impostor)

    assert np.array_equal(sorted(generated.tolist()), expected)


def test_generate_knntriplets_k_genuine():
    """Checks that the expected warning is raised when k_genuine is too big
    for the smallest class(es)."""
    X, y = shuffle(*make_blobs(random_state=SEED),
                   random_state=SEED)

    label, labels_count = np.unique(y, return_counts=True)
    smallest_count = np.min(labels_count)
    smallest_label_positions, = np.where(labels_count == smallest_count)
    # asking for as many genuine neighbors as the class has members is one
    # too many, since a point is not its own neighbor
    k_genuine = smallest_count

    template = ("The class {} has {} elements, which is not sufficient "
                "to generate {} genuine neighbors as specified by "
                "k_genuine. Will generate {} genuine neighbors instead."
                "\n")
    warn_msgs = [template.format(label[idx], k_genuine, k_genuine+1,
                                 k_genuine-1)
                 for idx in smallest_label_positions]

    with pytest.warns(UserWarning) as raised_warning:
        Constraints(y).generate_knntriplets(X, k_genuine, 1)
    for warn in raised_warning:
        assert str(warn.message) in warn_msgs


def test_generate_knntriplets_k_impostor():
    """Checks that the expected warning is raised when k_impostor is too big
    for the biggest class(es)."""
    X, y = shuffle(*make_blobs(random_state=SEED),
                   random_state=SEED)

    length = len(y)
    label, labels_count = np.unique(y, return_counts=True)
    biggest_count = np.max(labels_count)
    biggest_label_positions, = np.where(labels_count == biggest_count)
    # one more impostor than the biggest class can possibly have
    k_impostor = length - biggest_count + 1

    template = ("The class {} has {} elements of other classes, which is"
                " not sufficient to generate {} impostor neighbors as "
                "specified by k_impostor. Will generate {} impostor "
                "neighbors instead.\n")
    warn_msgs = [template.format(label[idx], k_impostor-1, k_impostor,
                                 k_impostor-1)
                 for idx in biggest_label_positions]

    with pytest.warns(UserWarning) as raised_warning:
        Constraints(y).generate_knntriplets(X, 1, k_impostor)
    for warn in raised_warning:
        assert str(warn.message) in warn_msgs
@pytest.mark.parametrize('with_preprocessor', [True, False])
@pytest.mark.parametrize('estimator, build_dataset', triplets_learners,
                         ids=ids_triplets_learners)
def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset,
                                              with_preprocessor):
    """Test that predicting with a triplets learner that was never fitted
    raises a NotFittedError."""
    triplets, preprocessor = build_dataset(with_preprocessor)
    # work on a fresh copy so the shared parametrized estimator is untouched
    unfitted = clone(estimator)
    unfitted.set_params(preprocessor=preprocessor)
    set_random_state(unfitted)
    with pytest.raises(NotFittedError):
        unfitted.predict(triplets)
def build_triplets(with_preprocessor=False):
    """Build a toy triplets problem from the shuffled iris dataset.

    Returns ``(triplets_of_indices, X)`` when `with_preprocessor` is true,
    and ``(triplets_of_samples, None)`` otherwise.
    """
    data, target = load_iris(return_X_y=True)
    X, y = shuffle(data, target, random_state=SEED)
    triplets = Constraints(y).generate_knntriplets(X, k_genuine=3,
                                                   k_impostor=4)
    if not with_preprocessor:
        # without a preprocessor, we build a 3D array of triplets of samples
        return X[triplets], None
    # with a preprocessor, we keep the 2D array of triplets of indices
    return triplets, X


class mock_triplet_LSML(_BaseLSML, _TripletsClassifierMixin):
    # Mock triplet learner built on LSML (a quadruplets learner), used only
    # to exercise the basic methods of _TripletsClassifierMixin.

    # the class-level tuple size is the quadruplet one
    _tuple_size = 4

    def fit(self, triplets, weights=None):
        # (a, b, c) is turned into the quadruplet (a, b, a, c)
        as_quadruplets = triplets[:, [0, 1, 0, 2]]
        return self._fit(as_quadruplets, weights=weights)

    def decision_function(self, triplets):
        # switch this instance to triplets before delegating to the mixin
        self._tuple_size = 3
        return _TripletsClassifierMixin.decision_function(self, triplets)
build_quadruplets(with_preprocessor=False): # builds a toy quadruplets problem X, indices = build_data() @@ -103,6 +133,11 @@ def build_quadruplets(with_preprocessor=False): [learner for (learner, _) in quadruplets_learners])) +triplets_learners = [(mock_triplet_LSML(), build_triplets)] +ids_triplets_learners = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + triplets_learners])) + pairs_learners = [(ITML(max_iter=2), build_pairs), # max_iter=2 to be faster (MMC(max_iter=2), build_pairs), # max_iter=2 to be faster (SDML(prior='identity', balance_param=1e-5), build_pairs)]