From 965cc5c9673fbadaa5c92f29e82573e7123be670 Mon Sep 17 00:00:00 2001 From: grudloff Date: Thu, 13 Feb 2020 14:54:42 +0100 Subject: [PATCH 01/19] add _TripletsClassifierMixin --- metric_learn/base_metric.py | 91 +++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index ee73c793..dedce461 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -589,6 +589,97 @@ def _validate_calibration_params(strategy='accuracy', min_rate=None, 'Got {} instead.'.format(type(beta))) +class _TripletsClassifierMixin(BaseMetricLearner): + """Base class for triplets learners. + """ + + _tuple_size = 3 # number of points in a tuple, 3 for triplets + + def predict(self, triplets): + """Predicts the ordering between sample distances in input triplets. + + For each triplets, returns 1 if the first two elements are closer than the + first and last and -1 if not. + + Parameters + ---------- + triplets : array-like, shape=(n_triplets, 3, n_features) or (n_triplets, 3) + 3D Array of triplets to predict, with each row corresponding to three + points, or 2D array of indices of triplets if the metric learner + uses a preprocessor. + + Returns + ------- + prediction : `numpy.ndarray` of floats, shape=(n_constraints,) + Predictions of the ordering of pairs, for each triplet. + """ + + # Aren't the following lines redundant as they are called on + # decision_function (as in quadruplets predict) ??? + # check_is_fitted(self, 'preprocessor_') + # triplets = check_input(triplets, type_of_inputs='tuples', + # preprocessor=self.preprocessor_, + # estimator=self, tuple_size=self._tuple_size) + return np.sign(self.decision_function(triplets)) + + def decision_function(self, triplets): + """Predicts differences between sample distances in input triplets. + + For each triplets in the samples, computes the difference between the + learned metric of the second pair minus the learned metric of the first + pair. The higher it is, the more probable it is that the pairs in the + triplets are presented in the right order, i.e. that the label of the + triplet is 1. The lower it is, the more probable it is that the label of + the triplet is -1. + + Parameters + ---------- + triplet : array-like, shape=(n_triplets, 4, n_features) or \ + (n_triplets, 4) + 3D Array of triplets to predict, with each row corresponding to four + points, or 2D array of indices of triplets if the metric learner + uses a preprocessor. + + Returns + ------- + decision_function : `numpy.ndarray` of floats, shape=(n_constraints,) + Metric differences. + """ + check_is_fitted(self, 'preprocessor_') + triplets = check_input(triplets, type_of_inputs='tuples', + preprocessor=self.preprocessor_, + estimator=self, tuple_size=self._tuple_size) + return (self.score_pairs(triplets[:, [0, 2]]) - + self.score_pairs(triplets[:, :2])) + + def score(self, triplets): + """Computes score on input triplets + + Returns the accuracy score of the following classification task: a record + is correctly classified if the predicted similarity between the first two + samples is higher than that of the first and last element. + + Parameters + ---------- + triplets : array-like, shape=(n_triplets, 4, n_features) or \ + (n_triplets, 4) + 3D Array of triplets to score, with each row corresponding to four + points, or 2D array of indices of triplets if the metric learner + uses a preprocessor. + + Returns + ------- + score : float + The triplets score. 
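    Examples
    --------
    A minimal sketch of the rescaling performed below (the prediction
    vector is hypothetical, standing in for the output of `predict`):

    >>> import numpy as np
    >>> predictions = np.array([1., -1., 1., 1.])
    >>> predictions.mean() / 2 + 0.5  # equals the fraction of +1 predictions
    0.75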
+ """ + # Since the prediction is a vector of values in {-1, +1}, we need to + # rescale them to {0, 1} to compute the accuracy using the mean (because + # then 1 means a correctly classified result (pairs are in the right + # order), and a 0 an incorrectly classified result (pairs are in the + # wrong order). + return self.predict(triplets).mean() / 2 + 0.5 + + class _QuadrupletsClassifierMixin(BaseMetricLearner): """Base class for quadruplets learners. """ From ecd22f6ab862224aa1421cbc16024934936c2545 Mon Sep 17 00:00:00 2001 From: grudloff Date: Mon, 17 Feb 2020 15:39:54 +0100 Subject: [PATCH 02/19] added doc --- doc/weakly_supervised.rst | 113 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/doc/weakly_supervised.rst b/doc/weakly_supervised.rst index cf313ba1..546b9dde 100644 --- a/doc/weakly_supervised.rst +++ b/doc/weakly_supervised.rst @@ -592,6 +592,119 @@ points, while constrains the sum of distances between dissimilar points: -with-side-information.pdf>`_. NIPS 2002 .. [2] Adapted from Matlab code http://www.cs.cmu.edu/%7Eepxing/papers/Old_papers/code_Metric_online.tar.gz +.. _learning_on_triplets: + +Learning on Triplets +==================== + +Some metric learning algorithms learn on triplets of samples. In this case, +one should provide the algorithm with `n_samples` triplets of points. The +semantic of each triplet is that the first two points should be closer +together than the first and the last. + +Fitting +------- +Here is an example for fitting on triplets (see :ref:`fit_ws` for more +details on the input data format and how to fit, in the general case of +learning on tuples). + +>>> from metric_learn import SCML +>>> triplets = np.array([[[1.2, 3.2], [2.3, 5.5], [2.1, 0.6]], +>>> [[4.5, 2.3], [2.1, 2.3], [7.3, 3.4]]]) +>>> scml = SCML(random_state=42) +>>> scml.fit(triplets) +SCML(beta=1e-5, B=None, max_iter=100000, verbose=False, + preprocessor=None, random_state=None) + +Or alternatively (using a preprocessor): + +>>> X = np.array([[[1.2, 3.2], +>>> [2.3, 5.5], +>>> [2.1, 0.6], +>>> [4.5, 2.3], +>>> [2.1, 2.3], +>>> [7.3, 3.4]]) +>>> triplets_indices = np.array([[0, 1, 2], [3, 4, 5]]) +>>> scml = SCML(preprocessor=X, random_state=42) +>>> scml.fit(triplets_indices) +SCML(beta=1e-5, B=None, max_iter=100000, verbose=False, + preprocessor=array([[1.2, 3.2], + [2.3, 5.5], + [2.4, 6.7], + [2.1, 0.6], + [4.5, 2.3], + [2.1, 2.3], + [0.6, 1.2], + [7.3, 3.4]]), + random_state=None) + + +Here, we want to learn a metric that, for each of the two +`triplets`, will put the two first points closer together than the first and the last. + +.. _triplets_predicting: + +Prediction +---------- + +When a triplets learner is fitted, it is also able to predict, for an +upcoming triplet, whether the two first points are more similar than the +first and the last (+1), or not (-1). + +>>> triplets_test = np.array( +... [[[5.6, 5.3], [2.2, 2.1], [1.2, 3.4]], +... [[6.0, 4.2], [4.3, 1.2], [0.1, 7.8]]]) +>>> scml.predict(triplets_test) +array([-1., 1.]) + +.. _triplets_scoring: + +Scoring +------- + +Triplet metric learners can also +return a `decision_function` for a set of pairs. This is basically the "score" +which sign will be taken to find the prediction for the pair, which +corresponds to the difference between the distance between the first two points, +and the distance between the first and last points of the triplet (higher +score means the first and last points are more likely to be more dissimilar than +the two first points (i.e. 
more likely to have a +1 prediction since it's +the right ordering)). + +>>> scml.decision_function(triplets_test) +array([-1.75700306, 4.98982131]) + +In the above example, for the first triplet in `triplets_test`, the +two first points are predicted less similar than the first and last points (they +are further away in the transformed space). + +Unlike for pairs learners, triplets learners don't allow to give a `y` +when fitting, which does not allow to use scikit-learn scoring functions +like: + +>>> from sklearn.model_selection import cross_val_score +>>> cross_val_score(scml, triplets, scoring='f1_score') # this won't work + +(This is actually intentional) + +However, triplets learners do have a default scoring function, which will +basically return the accuracy score on a given test set, i.e. the proportion +of triplets have the right predicted ordering. + +>>> scml.score(triplets_test) +0.5 + +.. note:: + See :ref:`fit_ws` for more details on metric learners functions that are + not specific to learning on pairs, like `transform`, `score_pairs`, + `get_metric` and `get_mahalanobis_matrix`. + + + + +Algorithms +---------- + .. _learning_on_quadruplets: From f0df4adda16c52f8e32be88c56d5d8a2cabbbf3d Mon Sep 17 00:00:00 2001 From: grudloff Date: Tue, 18 Feb 2020 09:53:23 +0100 Subject: [PATCH 03/19] remove redundant code --- metric_learn/base_metric.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index dedce461..b1d6e197 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -613,13 +613,6 @@ def predict(self, triplets): prediction : `numpy.ndarray` of floats, shape=(n_constraints,) Predictions of the ordering of pairs, for each triplet. """ - - # Aren't the following lines redundant as they are called on - # decision_function (as in quadruplets predict) ??? - # check_is_fitted(self, 'preprocessor_') - # triplets = check_input(triplets, type_of_inputs='tuples', - # preprocessor=self.preprocessor_, - # estimator=self, tuple_size=self._tuple_size) return np.sign(self.decision_function(triplets)) def decision_function(self, triplets): @@ -705,10 +698,6 @@ def predict(self, quadruplets): prediction : `numpy.ndarray` of floats, shape=(n_constraints,) Predictions of the ordering of pairs, for each quadruplet. 
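    Examples
    --------
    A sketch of the sign convention applied below (the decision values
    are hypothetical, standing in for the output of `decision_function`):

    >>> import numpy as np
    >>> np.sign(np.array([-1.757, 4.989]))  # negative -> -1, positive -> +1
    array([-1.,  1.])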
""" - check_is_fitted(self, 'preprocessor_') - quadruplets = check_input(quadruplets, type_of_inputs='tuples', - preprocessor=self.preprocessor_, - estimator=self, tuple_size=self._tuple_size) return np.sign(self.decision_function(quadruplets)) def decision_function(self, quadruplets): From b8054f332eada97de3ab7b85d23b6b6486e270e7 Mon Sep 17 00:00:00 2001 From: grudloff Date: Tue, 18 Feb 2020 12:02:45 +0100 Subject: [PATCH 04/19] added tests --- metric_learn/constraints.py | 52 ++++++++++++++++++++++ test/test_triplets_classifiers.py | 71 +++++++++++++++++++++++++++++++ test/test_utils.py | 35 +++++++++++++++ 3 files changed, 158 insertions(+) create mode 100644 test/test_triplets_classifiers.py diff --git a/metric_learn/constraints.py b/metric_learn/constraints.py index 752ca6e0..6075896f 100644 --- a/metric_learn/constraints.py +++ b/metric_learn/constraints.py @@ -6,6 +6,8 @@ import warnings from six.moves import xrange from sklearn.utils import check_random_state +from sklearn.neighbors import NearestNeighbors +from numpy.matlib import repmat __all__ = ['Constraints'] @@ -33,6 +35,56 @@ def positive_negative_pairs(self, num_constraints, same_length=False, return a[:n], b[:n], c[:n], d[:n] return a, b, c, d + def generate_knntriplets(self, X, k_genuine, k_impostor): + + labels = np.unique(self.partial_labels) + L = len(labels) + len_input = np.size(self.partial_labels, 0) + triplets = np.empty((len_input*k_genuine*k_impostor, 3), dtype=np.intp) + + start = 0 + finish = 0 + neigh = NearestNeighbors() + + for i in range(L): + + # generate mask for current label + gen_mask = self.partial_labels == labels[i] + gen_indx = np.where(gen_mask) + + # get k_genuine genuine neighbours + neigh.fit(X=X[gen_indx]) + gen_neigh = np.take(gen_indx, neigh.kneighbors(n_neighbors=k_genuine, + return_distance=False)) + + # generate mask for impostors of current label + imp_indx = np.where(np.invert(gen_mask)) + + # get k_impostor impostor neighbours + neigh.fit(X=X[imp_indx]) + imp_neigh = np.take(imp_indx, neigh.kneighbors( + n_neighbors=k_impostor, + X=X[gen_mask], + return_distance=False)) + + # lenght = len_label*k_genuine*k_impostor + finish += np.sum(gen_mask)*k_genuine*k_impostor + + triplets[start:finish, :] = self._comb(gen_indx, gen_neigh, imp_neigh, + k_genuine, k_impostor) + start = finish + + # TODO: deal with too litle elements for k neighbors to be yielded + + return triplets + + def _comb(self, A, B, C, sizeB, sizeC): + # generate an array will all combinations of choosing + # an element from A, B and C + return np.vstack((repmat(A, sizeB*sizeC, 1).ravel(order='F'), + repmat(np.hstack(B), sizeC, 1).ravel(order='F'), + repmat(C, 1, sizeB).ravel())).T + def _pairs(self, num_constraints, same_label=True, max_iter=10, random_state=np.random): known_label_idx, = np.where(self.partial_labels >= 0) diff --git a/test/test_triplets_classifiers.py b/test/test_triplets_classifiers.py new file mode 100644 index 00000000..ec6184fb --- /dev/null +++ b/test/test_triplets_classifiers.py @@ -0,0 +1,71 @@ +import pytest +from sklearn.exceptions import NotFittedError +from sklearn.model_selection import train_test_split + +from test.test_utils import triplets_learners, ids_triplets_learners +from sklearn.utils.testing import set_random_state +from sklearn import clone +import numpy as np + + +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', triplets_learners, + ids=ids_triplets_learners) +def test_predict_only_one_or_minus_one(estimator, 
build_dataset, + with_preprocessor): + """Test that all predicted values are either +1 or -1""" + input_data, preprocessor = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + triplets_train, triplets_test = train_test_split(input_data) + estimator.fit(triplets_train) + predictions = estimator.predict(triplets_test) + for i in range(len(input_data)): + prediction = estimator.predict(input_data[None, i]) + if(prediction == 0): + print(input_data[i]) + print(preprocessor[input_data[i]]) + print(estimator.decision_function(input_data[None, i])) + not_valid = [e for e in predictions if e not in [-1, 1]] + assert len(not_valid) == 0 + + +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', triplets_learners, + ids=ids_triplets_learners) +def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset, + with_preprocessor): + """Test that a NotFittedError is raised if someone tries to predict and + the metric learner has not been fitted.""" + input_data, preprocessor = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + with pytest.raises(NotFittedError): + estimator.predict(input_data) + + +@pytest.mark.parametrize('estimator, build_dataset', triplets_learners, + ids=ids_triplets_learners) +def test_accuracy_toy_example(estimator, build_dataset): + """Test that the default scoring for triplets (accuracy) works on some + toy example""" + triplets, X = build_dataset(with_preprocessor=True) + triplets = X[triplets] + estimator = clone(estimator) + set_random_state(estimator) + estimator.fit(triplets) + # We take the two first points and we build 4 regularly spaced points on the + # line they define, so that it's easy to build quadruplets of different + # similarities. 
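  # With the identity transformation forced below, the four X_test points are
  # collinear and evenly spaced, so distances reduce to index differences:
  # only the triplet [X_test[1], X_test[2], X_test[3]] satisfies
  # d(first, second) < d(first, third), hence the expected score of 1/4 = 0.25.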
+ X_test = X[0] + np.arange(4)[:, np.newaxis] * (X[0] - X[1]) / 4 + + triplets_test = np.array( + [[X_test[0], X_test[2], X_test[1]], + [X_test[1], X_test[3], X_test[0]], + [X_test[1], X_test[2], X_test[3]], + [X_test[3], X_test[0], X_test[2]]]) + # we force the transformation to be identity so that we control what it does + estimator.components_ = np.eye(X.shape[1]) + assert estimator.score(triplets_test) == 0.25 diff --git a/test/test_utils.py b/test/test_utils.py index 2510ed89..d61a14db 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -19,9 +19,11 @@ Constraints) from metric_learn.base_metric import (ArrayIndexer, MahalanobisMixin, _PairsClassifierMixin, + _TripletsClassifierMixin, _QuadrupletsClassifierMixin) from metric_learn.exceptions import PreprocessorError, NonPSDError from sklearn.datasets import make_regression, make_blobs, load_iris +from metric_learn.lsml import _BaseLSML SEED = 42 @@ -83,6 +85,34 @@ def build_pairs(with_preprocessor=False): return Dataset(X[c], target, None, X[c[:, 0]]) +def build_triplets(with_preprocessor=False): + input_data, labels = load_iris(return_X_y=True) + X, y = shuffle(input_data, labels, random_state=SEED) + constraints = Constraints(y) + triplets = constraints.generate_knntriplets(X, 3, 10) + if with_preprocessor: + # if preprocessor, we build a 2D array of triplets of indices + return triplets, X + else: + # if not, we build a 3D array of triplets of samples + return X[triplets], None + + +class mock_triplet_LSML(_BaseLSML, _TripletsClassifierMixin): + # Mock Triplet learner from LSML which is a quadruplets learner + # in order to test TripletClassifierMixin basic methods + + _tuple_size = 4 + + def fit(self, triplets, weights=None): + quadruplets = triplets[:, [0, 1, 0, 2]] + return self._fit(quadruplets, weights=weights) + + def decision_function(self, triplets): + self._tuple_size = 3 + return _TripletsClassifierMixin.decision_function(self, triplets) + + def build_quadruplets(with_preprocessor=False): # builds a toy quadruplets problem X, indices = build_data() @@ -103,6 +133,11 @@ def build_quadruplets(with_preprocessor=False): [learner for (learner, _) in quadruplets_learners])) +triplets_learners = [(mock_triplet_LSML(), build_triplets)] +ids_triplets_learners = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + triplets_learners])) + pairs_learners = [(ITML(max_iter=2), build_pairs), # max_iter=2 to be faster (MMC(max_iter=2), build_pairs), # max_iter=2 to be faster (SDML(prior='identity', balance_param=1e-5), build_pairs)] From de7aa11f4c69c725fc4b525a15b94034d585adf9 Mon Sep 17 00:00:00 2001 From: grudloff Date: Tue, 18 Feb 2020 13:37:34 +0100 Subject: [PATCH 05/19] triplets added to doc autosumary --- doc/metric_learn.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/metric_learn.rst b/doc/metric_learn.rst index 930404d0..76c91f48 100644 --- a/doc/metric_learn.rst +++ b/doc/metric_learn.rst @@ -14,6 +14,7 @@ Base Classes metric_learn.Constraints metric_learn.base_metric.BaseMetricLearner metric_learn.base_metric._PairsClassifierMixin + metric_learn.base_metric._TripletsClassifierMixin metric_learn.base_metric._QuadrupletsClassifierMixin Supervised Learning Algorithms From cb64e2fbcb6fed778993387fe3ea338db8c2f2bd Mon Sep 17 00:00:00 2001 From: grudloff Date: Fri, 21 Feb 2020 13:37:50 +0100 Subject: [PATCH 06/19] rephrasing, added docstring and small changes --- doc/weakly_supervised.rst | 76 +++++++++++++------------------ metric_learn/base_metric.py | 34 +++++++------- 
metric_learn/constraints.py | 29 ++++++++++-- test/test_triplets_classifiers.py | 2 +- test/test_utils.py | 2 +- 5 files changed, 76 insertions(+), 67 deletions(-) diff --git a/doc/weakly_supervised.rst b/doc/weakly_supervised.rst index 546b9dde..63bb00f7 100644 --- a/doc/weakly_supervised.rst +++ b/doc/weakly_supervised.rst @@ -599,8 +599,8 @@ Learning on Triplets Some metric learning algorithms learn on triplets of samples. In this case, one should provide the algorithm with `n_samples` triplets of points. The -semantic of each triplet is that the first two points should be closer -together than the first and the last. +semantic of each triplet is that the first point should be closer to the +second point than to the third one. Fitting ------- @@ -640,7 +640,8 @@ SCML(beta=1e-5, B=None, max_iter=100000, verbose=False, Here, we want to learn a metric that, for each of the two -`triplets`, will put the two first points closer together than the first and the last. +`triplets`, will make the first point closer to the +second point than to the third one. .. _triplets_predicting: @@ -648,8 +649,8 @@ Prediction ---------- When a triplets learner is fitted, it is also able to predict, for an -upcoming triplet, whether the two first points are more similar than the -first and the last (+1), or not (-1). +upcoming triplet, whether the first point is closer to the second point +than to the third one. (+1), or not (-1). >>> triplets_test = np.array( ... [[[5.6, 5.3], [2.2, 2.1], [1.2, 3.4]], @@ -662,34 +663,28 @@ array([-1., 1.]) Scoring ------- -Triplet metric learners can also -return a `decision_function` for a set of pairs. This is basically the "score" -which sign will be taken to find the prediction for the pair, which -corresponds to the difference between the distance between the first two points, -and the distance between the first and last points of the triplet (higher -score means the first and last points are more likely to be more dissimilar than -the two first points (i.e. more likely to have a +1 prediction since it's -the right ordering)). +Triplet metric learners can also return a `decision_function` for a set of triplets, +which correspond to the distance between the first two points minus the distance +between the first and last points of the triplet (the higher the value, the more +similar the first point to the second point compared to the last one). This "score" +can be interpreted as a measure of likeliness of having a +1 prediction for this +triplet. >>> scml.decision_function(triplets_test) array([-1.75700306, 4.98982131]) -In the above example, for the first triplet in `triplets_test`, the -two first points are predicted less similar than the first and last points (they -are further away in the transformed space). - -Unlike for pairs learners, triplets learners don't allow to give a `y` -when fitting, which does not allow to use scikit-learn scoring functions -like: - ->>> from sklearn.model_selection import cross_val_score ->>> cross_val_score(scml, triplets, scoring='f1_score') # this won't work +In the above example, for the first triplet in `triplets_test`, the first +point is predicted less similar to the second point than to the last point +(they are further away in the transformed space). -(This is actually intentional) +Unlike pairs learners, triplets learners do not allow to give a `y` when fitting: we +assume that the ordering of points within triplets is such that the training triplets +are all positive. 
Therefore, it is not possible to use scikit-learn scoring functions
+(such as 'f1_score') for triplets learners.

 However, triplets learners do have a default scoring function, which will
 basically return the accuracy score on a given test set, i.e. the proportion
-of triplets have the right predicted ordering.
+of triplets that have the right predicted ordering.

 >>> scml.score(triplets_test)
 0.5

@@ -777,16 +772,14 @@ array([-1., 1.])
 .. _quadruplets_scoring:

 Scoring
-------W

-Quadruplet metric learners can also
-return a `decision_function` for a set of pairs. This is basically the "score"
-which sign will be taken to find the prediction for the pair, which
-corresponds to the difference between the distance between the two last points,
-and the distance between the two last points of the quadruplet (higher
-score means the two last points are more likely to be more dissimilar than
-the two first points (i.e. more likely to have a +1 prediction since it's
-the right ordering)).
+Quadruplet metric learners can also return a `decision_function` for a set of
+quadruplets, which correspond to the distance between the first pair of points minus
+the distance between the second pair of points of the triplet (the higher the value,
+the more similar the first pair is than the last pair).
+This "score" can be interpreted as a measure of likeliness of having a +1 prediction
+for this quadruplet.

 >>> lsml.decision_function(quadruplets_test)
 array([-1.75700306, 4.98982131])

@@ -795,17 +788,10 @@ In the above example, for the first quadruplet in `quadruplets_test`, the
 two first points are predicted less similar than the two last points (they
 are further away in the transformed space).

-Unlike for pairs learners, quadruplets learners don't allow to give a `y`
-when fitting, which does not allow to use scikit-learn scoring functions
-like:
-
->>> from sklearn.model_selection import cross_val_score
->>> cross_val_score(lsml, quadruplets, scoring='f1_score') # this won't work
-
-(This is actually intentional, for more details
-about that, see
-`this comment `_
-on github.)
+Like triplet learners, quadruplets learners do not allow to give a `y` when fitting: we
+assume that the ordering of points within quadruplets is such that the training
+quadruplets are all positive. Therefore, it is not possible to use scikit-learn
+scoring functions (such as 'f1_score') for quadruplets learners.

 However, quadruplets learners do have a default scoring function, which will
 basically return the accuracy score on a given test set, i.e. the proportion
diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py
index b1d6e197..6fe3e503 100644
--- a/metric_learn/base_metric.py
+++ b/metric_learn/base_metric.py
@@ -604,7 +604,7 @@ def predict(self, triplets):
 Parameters
 ----------
 triplets : array-like, shape=(n_triplets, 3, n_features) or (n_triplets, 3)
-      3D Array of triplets to predict, with each row corresponding to three
+      3D array of triplets to predict, with each row corresponding to three
       points, or 2D array of indices of triplets if the metric learner
       uses a preprocessor.

@@ -618,18 +618,18 @@ def decision_function(self, triplets):
     """Predicts differences between sample distances in input triplets.

-    For each triplets in the samples, computes the difference between the
-    learned metric of the second pair minus the learned metric of the first
-    pair. The higher it is, the more probable it is that the pairs in the
-    triplets are presented in the right order, i.e. that the label of the
-    triplet is 1. The lower it is, the more probable it is that the label of
-    the triplet is -1.
+    For each triplet (X_a, X_b, X_c) in the samples, computes the difference
+    between the learned metric of the second pair (X_a, X_c) minus the learned
+    metric of the first pair (X_a, X_b). The higher it is, the more probable it
+    is that the pairs in the triplets are presented in the right order, i.e.
that the label of the - triplet is 1. The lower it is, the more probable it is that the label of - the triplet is -1. + For each triplet (X_a, X_b, X_c) in the samples, computes the difference + between the learned metric of the second pair (X_a, X_c) minus the learned + metric of the first pair (X_a, X_b). The higher it is, the more probable it + is that the pairs in the triplets are presented in the right order, i.e. + that the label of the triplet is 1. The lower it is, the more probable it + is that the label of the triplet is -1. Parameters ---------- - triplet : array-like, shape=(n_triplets, 4, n_features) or \ - (n_triplets, 4) - 3D Array of triplets to predict, with each row corresponding to four + triplet : array-like, shape=(n_triplets, 3, n_features) or \ + (n_triplets, 3) + 3D array of triplets to predict, with each row corresponding to three points, or 2D array of indices of triplets if the metric learner uses a preprocessor. @@ -646,17 +646,17 @@ def decision_function(self, triplets): self.score_pairs(triplets[:, :2])) def score(self, triplets): - """Computes score on input triplets + """Computes score on input triplets. - Returns the accuracy score of the following classification task: a record - is correctly classified if the predicted similarity between the first two - samples is higher than that of the first and last element. + Returns the accuracy score of the following classification task: a triplet + (X_a, X_b, X_c) is correctly classified if the predicted similarity between + the first pair (X_a, X_b) is higher than that of the second pair (X_a, X_c) Parameters ---------- - triplets : array-like, shape=(n_triplets, 4, n_features) or \ - (n_triplets, 4) - 3D Array of triplets to score, with each row corresponding to four + triplets : array-like, shape=(n_triplets, 3, n_features) or \ + (n_triplets, 3) + 3D array of triplets to score, with each row corresponding to three points, or 2D array of indices of triplets if the metric learner uses a preprocessor. diff --git a/metric_learn/constraints.py b/metric_learn/constraints.py index 6075896f..9952a40c 100644 --- a/metric_learn/constraints.py +++ b/metric_learn/constraints.py @@ -36,9 +36,31 @@ def positive_negative_pairs(self, num_constraints, same_length=False, return a, b, c, d def generate_knntriplets(self, X, k_genuine, k_impostor): + """ + Generates triplets for every point to `k_genuine` neighbors of the same + class and `k_impostor` neighbors of other classes. + + For every point (X_a) the triplets (X_a, X_b, X_c) are constructed from all + the combinations of taking `k_genuine` neighbors (X_b) of the same class + and `k_impostor` neighbors (X_c) of other classes. + + Parameters + ---------- + X : (n x d) matrix + Input data, where each row corresponds to a single instance. + k_genuine : int + Number of neighbors of the same class to be taken into account. + k_impostor : int + Number of neighbors of different classes to be taken into account. + + Returns + ------- + triplets : array-like, shape=(n_constraints, 3) + 2D array of triplets of indicators. 
+ """ labels = np.unique(self.partial_labels) - L = len(labels) + n_labels = len(labels) len_input = np.size(self.partial_labels, 0) triplets = np.empty((len_input*k_genuine*k_impostor, 3), dtype=np.intp) @@ -46,7 +68,7 @@ def generate_knntriplets(self, X, k_genuine, k_impostor): finish = 0 neigh = NearestNeighbors() - for i in range(L): + for i in range(n_labels): # generate mask for current label gen_mask = self.partial_labels == labels[i] @@ -67,7 +89,7 @@ def generate_knntriplets(self, X, k_genuine, k_impostor): X=X[gen_mask], return_distance=False)) - # lenght = len_label*k_genuine*k_impostor + # length = len_label*k_genuine*k_impostor finish += np.sum(gen_mask)*k_genuine*k_impostor triplets[start:finish, :] = self._comb(gen_indx, gen_neigh, imp_neigh, @@ -79,6 +101,7 @@ def generate_knntriplets(self, X, k_genuine, k_impostor): return triplets def _comb(self, A, B, C, sizeB, sizeC): + # generate_knntripelts helper function # generate an array will all combinations of choosing # an element from A, B and C return np.vstack((repmat(A, sizeB*sizeC, 1).ravel(order='F'), diff --git a/test/test_triplets_classifiers.py b/test/test_triplets_classifiers.py index ec6184fb..013ac6f5 100644 --- a/test/test_triplets_classifiers.py +++ b/test/test_triplets_classifiers.py @@ -57,7 +57,7 @@ def test_accuracy_toy_example(estimator, build_dataset): set_random_state(estimator) estimator.fit(triplets) # We take the two first points and we build 4 regularly spaced points on the - # line they define, so that it's easy to build quadruplets of different + # line they define, so that it's easy to build triplets of different # similarities. X_test = X[0] + np.arange(4)[:, np.newaxis] * (X[0] - X[1]) / 4 diff --git a/test/test_utils.py b/test/test_utils.py index d61a14db..c104efbf 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -89,7 +89,7 @@ def build_triplets(with_preprocessor=False): input_data, labels = load_iris(return_X_y=True) X, y = shuffle(input_data, labels, random_state=SEED) constraints = Constraints(y) - triplets = constraints.generate_knntriplets(X, 3, 10) + triplets = constraints.generate_knntriplets(X, 3, 4) if with_preprocessor: # if preprocessor, we build a 2D array of triplets of indices return triplets, X From 102e120153c14385d57804e740d2350895eae5df Mon Sep 17 00:00:00 2001 From: grudloff Date: Fri, 21 Feb 2020 14:16:10 +0100 Subject: [PATCH 07/19] small rephrasing --- metric_learn/base_metric.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 6fe3e503..53ece28a 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -598,8 +598,8 @@ class _TripletsClassifierMixin(BaseMetricLearner): def predict(self, triplets): """Predicts the ordering between sample distances in input triplets. - For each triplets, returns 1 if the first two elements are closer than the - first and last and -1 if not. + For each triplets, returns 1 if the first element is closer to the second than to the + last and -1 if not. 
Parameters ---------- From 3b421abafe29293845d118c8f3f68877f0b78908 Mon Sep 17 00:00:00 2001 From: grudloff Date: Fri, 21 Feb 2020 14:19:00 +0100 Subject: [PATCH 08/19] small flake8 fix --- metric_learn/base_metric.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 53ece28a..fe7e1aab 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -598,8 +598,8 @@ class _TripletsClassifierMixin(BaseMetricLearner): def predict(self, triplets): """Predicts the ordering between sample distances in input triplets. - For each triplets, returns 1 if the first element is closer to the second than to the - last and -1 if not. + For each triplets, returns 1 if the first element is closer to the second + than to the last and -1 if not. Parameters ---------- From 3e164e44d49dca1c7226c14afbfcef465c7ad46a Mon Sep 17 00:00:00 2001 From: grudloff Date: Mon, 24 Feb 2020 11:26:53 +0100 Subject: [PATCH 09/19] Handle low number of neighbors for knn triplets --- metric_learn/constraints.py | 41 +++++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/metric_learn/constraints.py b/metric_learn/constraints.py index 9952a40c..6aa86763 100644 --- a/metric_learn/constraints.py +++ b/metric_learn/constraints.py @@ -59,10 +59,35 @@ class and `k_impostor` neighbors of other classes. 2D array of triplets of indicators. """ - labels = np.unique(self.partial_labels) + labels, labels_count = np.unique(self.partial_labels, return_counts=True) n_labels = len(labels) len_input = np.size(self.partial_labels, 0) - triplets = np.empty((len_input*k_genuine*k_impostor, 3), dtype=np.intp) + + # Handle the case where there are too few elements to yield k_genuine or + # k_impostor neighbors for every class. + + k_genuine_vec = np.ones(n_labels, dtype=np.intp)*k_genuine + k_impostor_vec = np.ones(n_labels, dtype=np.intp)*k_impostor + + for i in range(n_labels): + if (k_genuine + 1 > labels_count[i]): + k_genuine_vec[i] = labels_count[i]-1 + warnings.warn("The class {} has {} elements but a minimum of {}," + " which corresponds to k_genuine+1, is expected. " + "A lower number of k_genuine will be used for this" + "class.\n" + .format(labels[i], labels_count[i], k_genuine+1)) + if (k_impostor > len_input - labels_count[i]): + k_impostor_vec[i] = len_input - labels_count[i] + warnings.warn("The class {} has {} elements of other classes but a " + "minimum of {}, which corresponds to k_impostor, is" + " expected. A lower number of k_impostor will be used" + " for this class.\n" + .format(labels[i], len_input - labels_count[i], + k_impostor)) + + triplets = np.empty((np.dot(k_genuine_vec*k_impostor_vec, labels_count), + 3), dtype=np.intp) start = 0 finish = 0 @@ -76,7 +101,8 @@ class and `k_impostor` neighbors of other classes. # get k_genuine genuine neighbours neigh.fit(X=X[gen_indx]) - gen_neigh = np.take(gen_indx, neigh.kneighbors(n_neighbors=k_genuine, + gen_neigh = np.take(gen_indx, neigh.kneighbors( + n_neighbors=k_genuine_vec[i], return_distance=False)) # generate mask for impostors of current label @@ -85,19 +111,18 @@ class and `k_impostor` neighbors of other classes. 
# get k_impostor impostor neighbours neigh.fit(X=X[imp_indx]) imp_neigh = np.take(imp_indx, neigh.kneighbors( - n_neighbors=k_impostor, + n_neighbors=k_impostor_vec[i], X=X[gen_mask], return_distance=False)) # length = len_label*k_genuine*k_impostor - finish += np.sum(gen_mask)*k_genuine*k_impostor + finish += labels_count[i]*k_genuine_vec[i]*k_impostor_vec[i] triplets[start:finish, :] = self._comb(gen_indx, gen_neigh, imp_neigh, - k_genuine, k_impostor) + k_genuine_vec[i], + k_impostor_vec[i]) start = finish - # TODO: deal with too litle elements for k neighbors to be yielded - return triplets def _comb(self, A, B, C, sizeB, sizeC): From 8e45abe7aad486020a425c42dd30f46dca2bcd42 Mon Sep 17 00:00:00 2001 From: grudloff Date: Mon, 24 Feb 2020 11:27:39 +0100 Subject: [PATCH 10/19] add tests for knn triplet generation --- test/test_constraints.py | 76 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/test/test_constraints.py b/test/test_constraints.py index 243028f6..cf48e9fc 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -2,6 +2,7 @@ import numpy as np from sklearn.utils import shuffle from metric_learn.constraints import Constraints +from sklearn.datasets import make_blobs SEED = 42 @@ -69,3 +70,78 @@ def test_unknown_labels_not_in_chunks(num_chunks, chunk_size): random_state=SEED) assert np.all(chunks[labels < 0] < 0) + + +def test_generate_knntriplets(): + k = 1 + X = np.array([[0, 0], + [1, 1], + [2, 2], + [3, 3], + [4, 4], + [5, 5], + [6, 6], + [7, 7], + [8, 8]]) + y = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3]) + T_test = np.array([[0, 1, 3], + [1, 0, 3], + [2, 1, 3], + [3, 4, 2], + [4, 3, 2], + [5, 4, 6], + [6, 7, 5], + [7, 6, 5], + [8, 7, 5]]) + T = Constraints(y).generate_knntriplets(X, k, k) + + assert np.array_equal(T, T_test) + + +def test_generate_knntriplets_k_genuine(): + """Checks the correct error raised when k_genuine is too big """ + X, y = shuffle(*make_blobs(random_state=SEED), + random_state=SEED) + + label, labels_count = np.unique(y, return_counts=True) + labels_count_min = np.min(labels_count) + idx_smallest_label = np.where(labels_count == labels_count_min) + + warn_msgs = [] + for idx in idx_smallest_label[0]: + k_genuine = labels_count[idx] + warn_msgs.append("The class {} has {} elements but a minimum of {}," + " which corresponds to k_genuine+1, is expected. " + "A lower number of k_genuine will be used for this" + "class.\n" + .format(label[idx], k_genuine, k_genuine+1)) + + with pytest.warns(UserWarning) as raised_warning: + Constraints(y).generate_knntriplets(X, k_genuine, 1) + for warn in raised_warning: + assert str(warn.message) in warn_msgs + + +def test_generate_knntriplets_k_impostor(): + """Checks the correct error raised when k_impostor is too big """ + X, y = shuffle(*make_blobs(random_state=SEED), + random_state=SEED) + + length = len(y) + label, labels_count = np.unique(y, return_counts=True) + labels_count_max = np.max(labels_count) + idx_smallest_label = np.where(labels_count == labels_count_max) + k_impostor = length - labels_count_max + 1 + + warn_msgs = [] + for idx in idx_smallest_label[0]: + warn_msgs.append("The class {} has {} elements of other classes but a " + "minimum of {}, which corresponds to k_impostor, is" + " expected. 
A lower number of k_impostor will be used" + " for this class.\n" + .format(label[idx], k_impostor-1, k_impostor)) + + with pytest.warns(UserWarning) as raised_warning: + Constraints(y).generate_knntriplets(X, 1, k_impostor) + for warn in raised_warning: + assert str(warn.message) in warn_msgs From aaeb4582bd1679194db0084c24423ee28e00bfc9 Mon Sep 17 00:00:00 2001 From: grudloff Date: Tue, 25 Feb 2020 17:38:45 +0100 Subject: [PATCH 11/19] fixed typos and rephrasing --- doc/weakly_supervised.rst | 12 +++++------ metric_learn/base_metric.py | 10 ++++----- metric_learn/constraints.py | 38 +++++++++++++++++++--------------- test/test_constraints.py | 41 +++++++++++++------------------------ 4 files changed, 47 insertions(+), 54 deletions(-) diff --git a/doc/weakly_supervised.rst b/doc/weakly_supervised.rst index 63bb00f7..72f68627 100644 --- a/doc/weakly_supervised.rst +++ b/doc/weakly_supervised.rst @@ -594,7 +594,7 @@ points, while constrains the sum of distances between dissimilar points: .. _learning_on_triplets: -Learning on Triplets +Learning on triplets ==================== Some metric learning algorithms learn on triplets of samples. In this case, @@ -650,7 +650,7 @@ Prediction When a triplets learner is fitted, it is also able to predict, for an upcoming triplet, whether the first point is closer to the second point -than to the third one. (+1), or not (-1). +than to the third one (+1), or not (-1). >>> triplets_test = np.array( ... [[[5.6, 5.3], [2.2, 2.1], [1.2, 3.4]], @@ -664,7 +664,7 @@ Scoring ------- Triplet metric learners can also return a `decision_function` for a set of triplets, -which correspond to the distance between the first two points minus the distance +which corresponds to the distance between the first two points minus the distance between the first and last points of the triplet (the higher the value, the more similar the first point to the second point compared to the last one). This "score" can be interpreted as a measure of likeliness of having a +1 prediction for this @@ -707,7 +707,7 @@ Learning on quadruplets ======================= Some metric learning algorithms learn on quadruplets of samples. In this case, -one should provide the algorithm with `n_samples` quadruplets of points. Th +one should provide the algorithm with `n_samples` quadruplets of points. The semantic of each quadruplet is that the first two points should be closer together than the last two points. @@ -772,10 +772,10 @@ array([-1., 1.]) .. _quadruplets_scoring: Scoring --------W +------- Quadruplet metric learners can also return a `decision_function` for a set of -quadruplets, which correspond to the distance between the first pair of points minus +quadruplets, which corresponds to the distance between the first pair of points minus the distance between the second pair of points of the triplet (the higher the value, the more similar the first pair is than the last pair). This "score" can be interpreted as a measure of likeliness of having a +1 prediction diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index fe7e1aab..c1e29d93 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -619,11 +619,11 @@ def decision_function(self, triplets): """Predicts differences between sample distances in input triplets. For each triplet (X_a, X_b, X_c) in the samples, computes the difference - between the learned metric of the second pair (X_a, X_c) minus the learned - metric of the first pair (X_a, X_b). 
The higher it is, the more probable it - is that the pairs in the triplets are presented in the right order, i.e. - that the label of the triplet is 1. The lower it is, the more probable it - is that the label of the triplet is -1. + between the learned distance of the second pair (X_a, X_c) minus the + learned distance of the first pair (X_a, X_b). The higher it is, the more + probable it is that the pairs in the triplets are presented in the right + order, i.e. that the label of the triplet is 1. The lower it is, the more + probable it is that the label of the triplet is -1. Parameters ---------- diff --git a/metric_learn/constraints.py b/metric_learn/constraints.py index 6aa86763..01053d5d 100644 --- a/metric_learn/constraints.py +++ b/metric_learn/constraints.py @@ -37,12 +37,17 @@ def positive_negative_pairs(self, num_constraints, same_length=False, def generate_knntriplets(self, X, k_genuine, k_impostor): """ - Generates triplets for every point to `k_genuine` neighbors of the same - class and `k_impostor` neighbors of other classes. + Generates triplets from labeled data. For every point (X_a) the triplets (X_a, X_b, X_c) are constructed from all - the combinations of taking `k_genuine` neighbors (X_b) of the same class - and `k_impostor` neighbors (X_c) of other classes. + the combinations of taking one of its `k_genuine`-nearest neighbors of the + same class (X_b) and taking one of its `k_impostor`-nearest neighbors of + other classes (X_c). + + In the case a class doesn't have enough points in the same class (other + classes) to yield `k_genuine` (`k_impostor`) neighbors a warning will be + raised and the maximum value of genuine (impostor) neighbors will be used + for that class. Parameters ---------- @@ -72,19 +77,20 @@ class and `k_impostor` neighbors of other classes. for i in range(n_labels): if (k_genuine + 1 > labels_count[i]): k_genuine_vec[i] = labels_count[i]-1 - warnings.warn("The class {} has {} elements but a minimum of {}," - " which corresponds to k_genuine+1, is expected. " - "A lower number of k_genuine will be used for this" - "class.\n" - .format(labels[i], labels_count[i], k_genuine+1)) + warnings.warn("The class {} has {} elements, which is not sufficient " + "to generate {} genuine neighbors as specified by " + "k_genuine. Will generate {} genuine neighbors instead." + "\n" + .format(labels[i], labels_count[i], k_genuine+1, + k_genuine_vec[i])) if (k_impostor > len_input - labels_count[i]): k_impostor_vec[i] = len_input - labels_count[i] - warnings.warn("The class {} has {} elements of other classes but a " - "minimum of {}, which corresponds to k_impostor, is" - " expected. A lower number of k_impostor will be used" - " for this class.\n" - .format(labels[i], len_input - labels_count[i], - k_impostor)) + warnings.warn("The class {} has {} elements of other classes, which is" + " not sufficient to generate {} impostor neighbors as " + "specified by k_impostor. Will generate {} impostor " + "neighbors instead.\n" + .format(labels[i], k_impostor_vec[i], k_impostor, + k_impostor_vec[i])) triplets = np.empty((np.dot(k_genuine_vec*k_impostor_vec, labels_count), 3), dtype=np.intp) @@ -126,7 +132,7 @@ class and `k_impostor` neighbors of other classes. 
return triplets def _comb(self, A, B, C, sizeB, sizeC): - # generate_knntripelts helper function + # generate_knntriplets helper function # generate an array will all combinations of choosing # an element from A, B and C return np.vstack((repmat(A, sizeB*sizeC, 1).ravel(order='F'), diff --git a/test/test_constraints.py b/test/test_constraints.py index cf48e9fc..4558f008 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -74,25 +74,11 @@ def test_unknown_labels_not_in_chunks(num_chunks, chunk_size): def test_generate_knntriplets(): k = 1 - X = np.array([[0, 0], - [1, 1], - [2, 2], - [3, 3], - [4, 4], - [5, 5], - [6, 6], - [7, 7], + X = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8]]) y = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3]) - T_test = np.array([[0, 1, 3], - [1, 0, 3], - [2, 1, 3], - [3, 4, 2], - [4, 3, 2], - [5, 4, 6], - [6, 7, 5], - [7, 6, 5], - [8, 7, 5]]) + T_test = np.array([[0, 1, 3], [1, 0, 3], [2, 1, 3], [3, 4, 2], [4, 3, 2], + [5, 4, 6], [6, 7, 5], [7, 6, 5], [8, 7, 5]]) T = Constraints(y).generate_knntriplets(X, k, k) assert np.array_equal(T, T_test) @@ -110,11 +96,11 @@ def test_generate_knntriplets_k_genuine(): warn_msgs = [] for idx in idx_smallest_label[0]: k_genuine = labels_count[idx] - warn_msgs.append("The class {} has {} elements but a minimum of {}," - " which corresponds to k_genuine+1, is expected. " - "A lower number of k_genuine will be used for this" - "class.\n" - .format(label[idx], k_genuine, k_genuine+1)) + warn_msgs.append("The class {} has {} elements, which is not sufficient " + "to generate {} genuine neighbors as specified by " + "k_genuine. Will generate {} genuine neighbors instead." + "\n" + .format(label[idx], k_genuine, k_genuine+1, k_genuine-1)) with pytest.warns(UserWarning) as raised_warning: Constraints(y).generate_knntriplets(X, k_genuine, 1) @@ -135,11 +121,12 @@ def test_generate_knntriplets_k_impostor(): warn_msgs = [] for idx in idx_smallest_label[0]: - warn_msgs.append("The class {} has {} elements of other classes but a " - "minimum of {}, which corresponds to k_impostor, is" - " expected. A lower number of k_impostor will be used" - " for this class.\n" - .format(label[idx], k_impostor-1, k_impostor)) + warn_msgs.append("The class {} has {} elements of other classes, which is" + " not sufficient to generate {} impostor neighbors as " + "specified by k_impostor. 
Will generate {} impostor " + "neighbors instead.\n" + .format(label[idx], k_impostor-1, k_impostor, + k_impostor-1)) with pytest.warns(UserWarning) as raised_warning: Constraints(y).generate_knntriplets(X, 1, k_impostor) From 43cfd6c61662e0c0917b8f2b9077da54e008a26b Mon Sep 17 00:00:00 2001 From: grudloff Date: Wed, 26 Feb 2020 10:58:48 +0100 Subject: [PATCH 12/19] added more tests for knn triplet construction --- test/test_constraints.py | 99 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 92 insertions(+), 7 deletions(-) diff --git a/test/test_constraints.py b/test/test_constraints.py index 4558f008..b07f5099 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -3,6 +3,7 @@ from sklearn.utils import shuffle from metric_learn.constraints import Constraints from sklearn.datasets import make_blobs +from sklearn.neighbors import NearestNeighbors SEED = 42 @@ -73,15 +74,99 @@ def test_unknown_labels_not_in_chunks(num_chunks, chunk_size): def test_generate_knntriplets(): + """Toy example validation of knn triplets construction""" k = 1 - X = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], - [8, 8]]) + X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32], [64, 64], + [128, 128], [256, 256]]) y = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3]) + T_test = np.array([[0, 1, 3], [1, 0, 3], [2, 1, 3], [3, 4, 2], [4, 3, 2], - [5, 4, 6], [6, 7, 5], [7, 6, 5], [8, 7, 5]]) + [5, 4, 2], [6, 7, 5], [7, 6, 5], [8, 7, 5]]) T = Constraints(y).generate_knntriplets(X, k, k) - assert np.array_equal(T, T_test) + assert len(list(set(map(tuple, T)) - set(map(tuple, T_test)))) == 0 + + +@pytest.mark.parametrize("delta_genuine, delta_impostor", [(1, 1), (1, 2), + (2, 1), (2, 2)]) +def test_generate_knntriplets_k(delta_genuine, delta_impostor): + """Checks edge cases of knn triplet construction""" + X, y = shuffle(*make_blobs(random_state=SEED), + random_state=SEED) + + label, labels_count = np.unique(y, return_counts=True) + labels_count_min = np.min(labels_count) + k_genuine = labels_count_min - delta_genuine + + length = len(y) + labels_count_max = np.max(labels_count) + k_impostor = length - labels_count_max + 1 - delta_impostor + + T = Constraints(y).generate_knntriplets(X, k_genuine, k_impostor) + T_test = naive_generate_knntriplets(X, y, k_genuine, k_impostor) + + assert len(list(set(map(tuple, T)) - set(map(tuple, T_test)))) == 0 + + +def naive_generate_knntriplets(X, y, k_genuine, k_impostor): + """ + Generates triplets from labeled data. Naive implementation + intended for testing. + + Parameters + ---------- + X : (n x d) matrix + Input data, where each row corresponds to a single instance. + k_genuine : int + Number of neighbors of the same class to be taken into account. + k_impostor : int + Number of neighbors of different classes to be taken into account. + + Returns + ------- + triplets : array-like, shape=(n_constraints, 3) + 2D array of triplets of indicators. 
+ """ + + labels, labels_count = np.unique(y, return_counts=True) + n_labels = len(labels) + len_input = np.size(y, 0) + + triplets = np.empty((len_input*k_genuine*k_impostor, 3), + dtype=np.intp) + + j = 0 + neigh = NearestNeighbors() + + for i in range(n_labels): + + # generate mask for current label + gen_mask = y == labels[i] + gen_indx = np.where(gen_mask) + + # get k_genuine genuine neighbours + neigh.fit(X=X[gen_indx]) + gen_neigh = np.take(gen_indx, neigh.kneighbors( + n_neighbors=k_genuine, + return_distance=False)) + + # generate mask for impostors of current label + imp_indx = np.where(np.invert(gen_mask)) + + # get k_impostor impostor neighbours + neigh.fit(X=X[imp_indx]) + imp_neigh = np.take(imp_indx, neigh.kneighbors( + n_neighbors=k_impostor, + X=X[gen_mask], + return_distance=False)) + + for a, k in zip(gen_indx[0], range(len(gen_indx[0]))): + for b in gen_neigh[k, :]: + for c in imp_neigh[k, :]: + triplets[j, :] = np.array([a, b, c]) + j += 1 + + return triplets def test_generate_knntriplets_k_genuine(): @@ -92,10 +177,10 @@ def test_generate_knntriplets_k_genuine(): label, labels_count = np.unique(y, return_counts=True) labels_count_min = np.min(labels_count) idx_smallest_label = np.where(labels_count == labels_count_min) + k_genuine = labels_count_min warn_msgs = [] for idx in idx_smallest_label[0]: - k_genuine = labels_count[idx] warn_msgs.append("The class {} has {} elements, which is not sufficient " "to generate {} genuine neighbors as specified by " "k_genuine. Will generate {} genuine neighbors instead." @@ -116,11 +201,11 @@ def test_generate_knntriplets_k_impostor(): length = len(y) label, labels_count = np.unique(y, return_counts=True) labels_count_max = np.max(labels_count) - idx_smallest_label = np.where(labels_count == labels_count_max) + idx_biggest_label = np.where(labels_count == labels_count_max) k_impostor = length - labels_count_max + 1 warn_msgs = [] - for idx in idx_smallest_label[0]: + for idx in idx_biggest_label[0]: warn_msgs.append("The class {} has {} elements of other classes, which is" " not sufficient to generate {} impostor neighbors as " "specified by k_impostor. 
Will generate {} impostor " From 0d5134a2285ed8d2e5b5d42d22ad1cc3fa169eb5 Mon Sep 17 00:00:00 2001 From: grudloff Date: Thu, 27 Feb 2020 10:01:42 +0100 Subject: [PATCH 13/19] sorted triplet & fix test_generate_knntriplets_k --- test/test_constraints.py | 116 ++++++++++++--------------------------- 1 file changed, 36 insertions(+), 80 deletions(-) diff --git a/test/test_constraints.py b/test/test_constraints.py index b07f5099..bd12f368 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -3,7 +3,6 @@ from sklearn.utils import shuffle from metric_learn.constraints import Constraints from sklearn.datasets import make_blobs -from sklearn.neighbors import NearestNeighbors SEED = 42 @@ -84,89 +83,46 @@ def test_generate_knntriplets(): [5, 4, 2], [6, 7, 5], [7, 6, 5], [8, 7, 5]]) T = Constraints(y).generate_knntriplets(X, k, k) - assert len(list(set(map(tuple, T)) - set(map(tuple, T_test)))) == 0 - - -@pytest.mark.parametrize("delta_genuine, delta_impostor", [(1, 1), (1, 2), - (2, 1), (2, 2)]) -def test_generate_knntriplets_k(delta_genuine, delta_impostor): + assert np.array_equal(sorted(T.tolist()), sorted(T_test.tolist())) + + +@pytest.mark.parametrize("k_genuine, k_impostor, T_test", + [(2, 3, + [[0, 1, 3], [0, 1, 4], [0, 1, 5], [0, 2, 3], + [0, 2, 4], [0, 2, 5], [1, 0, 3], [1, 0, 4], + [1, 0, 5], [1, 2, 3], [1, 2, 4], [1, 2, 5], + [2, 0, 3], [2, 0, 4], [2, 0, 5], [2, 1, 3], + [2, 1, 4], [2, 1, 5], [3, 4, 0], [3, 4, 1], + [3, 4, 2], [3, 5, 0], [3, 5, 1], [3, 5, 2], + [4, 3, 0], [4, 3, 1], [4, 3, 2], [4, 5, 0], + [4, 5, 1], [4, 5, 2], [5, 3, 0], [5, 3, 1], + [5, 3, 2], [5, 4, 0], [5, 4, 1], [5, 4, 2]]), + (2, 2, + [[0, 1, 3], [0, 1, 4], [0, 2, 3], [0, 2, 4], + [1, 0, 3], [1, 0, 4], [1, 2, 3], [1, 2, 4], + [2, 0, 3], [2, 0, 4], [2, 1, 3], [2, 1, 4], + [3, 4, 1], [3, 4, 2], [3, 5, 1], [3, 5, 2], + [4, 3, 1], [4, 3, 2], [4, 5, 1], [4, 5, 2], + [5, 3, 1], [5, 3, 2], [5, 4, 1], [5, 4, 2]]), + (1, 3, + [[0, 1, 3], [0, 1, 4], [0, 1, 5], [1, 0, 3], + [1, 0, 4], [1, 0, 5], [2, 1, 3], [2, 1, 4], + [2, 1, 5], [3, 4, 0], [3, 4, 1], [3, 4, 2], + [4, 3, 0], [4, 3, 1], [4, 3, 2], [5, 4, 0], + [5, 4, 1], [5, 4, 2]]), + (1, 2, + [[0, 1, 3], [0, 1, 4], [1, 0, 3], [1, 0, 4], + [2, 1, 3], [2, 1, 4], [3, 4, 1], [3, 4, 2], + [4, 3, 1], [4, 3, 2], [5, 4, 1], [5, 4, 2]])]) +def test_generate_knntriplets_k(k_genuine, k_impostor, T_test): """Checks edge cases of knn triplet construction""" - X, y = shuffle(*make_blobs(random_state=SEED), - random_state=SEED) - - label, labels_count = np.unique(y, return_counts=True) - labels_count_min = np.min(labels_count) - k_genuine = labels_count_min - delta_genuine - length = len(y) - labels_count_max = np.max(labels_count) - k_impostor = length - labels_count_max + 1 - delta_impostor + X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32]]) + y = np.array([1, 1, 1, 2, 2, 2]) T = Constraints(y).generate_knntriplets(X, k_genuine, k_impostor) - T_test = naive_generate_knntriplets(X, y, k_genuine, k_impostor) - - assert len(list(set(map(tuple, T)) - set(map(tuple, T_test)))) == 0 - - -def naive_generate_knntriplets(X, y, k_genuine, k_impostor): - """ - Generates triplets from labeled data. Naive implementation - intended for testing. - - Parameters - ---------- - X : (n x d) matrix - Input data, where each row corresponds to a single instance. - k_genuine : int - Number of neighbors of the same class to be taken into account. - k_impostor : int - Number of neighbors of different classes to be taken into account. 
- - Returns - ------- - triplets : array-like, shape=(n_constraints, 3) - 2D array of triplets of indicators. - """ - - labels, labels_count = np.unique(y, return_counts=True) - n_labels = len(labels) - len_input = np.size(y, 0) - - triplets = np.empty((len_input*k_genuine*k_impostor, 3), - dtype=np.intp) - - j = 0 - neigh = NearestNeighbors() - - for i in range(n_labels): - - # generate mask for current label - gen_mask = y == labels[i] - gen_indx = np.where(gen_mask) - - # get k_genuine genuine neighbours - neigh.fit(X=X[gen_indx]) - gen_neigh = np.take(gen_indx, neigh.kneighbors( - n_neighbors=k_genuine, - return_distance=False)) - - # generate mask for impostors of current label - imp_indx = np.where(np.invert(gen_mask)) - - # get k_impostor impostor neighbours - neigh.fit(X=X[imp_indx]) - imp_neigh = np.take(imp_indx, neigh.kneighbors( - n_neighbors=k_impostor, - X=X[gen_mask], - return_distance=False)) - - for a, k in zip(gen_indx[0], range(len(gen_indx[0]))): - for b in gen_neigh[k, :]: - for c in imp_neigh[k, :]: - triplets[j, :] = np.array([a, b, c]) - j += 1 - - return triplets + + assert np.array_equal(sorted(T.tolist()), T_test) def test_generate_knntriplets_k_genuine(): From 758bf14435bb8ee238c1483cbf7556879d1a695f Mon Sep 17 00:00:00 2001 From: grudloff Date: Thu, 27 Feb 2020 15:21:15 +0100 Subject: [PATCH 14/19] added over the edge knn triplets test --- test/test_constraints.py | 52 +++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/test/test_constraints.py b/test/test_constraints.py index bd12f368..91f51368 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -72,32 +72,8 @@ def test_unknown_labels_not_in_chunks(num_chunks, chunk_size): assert np.all(chunks[labels < 0] < 0) -def test_generate_knntriplets(): - """Toy example validation of knn triplets construction""" - k = 1 - X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32], [64, 64], - [128, 128], [256, 256]]) - y = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3]) - - T_test = np.array([[0, 1, 3], [1, 0, 3], [2, 1, 3], [3, 4, 2], [4, 3, 2], - [5, 4, 2], [6, 7, 5], [7, 6, 5], [8, 7, 5]]) - T = Constraints(y).generate_knntriplets(X, k, k) - - assert np.array_equal(sorted(T.tolist()), sorted(T_test.tolist())) - - @pytest.mark.parametrize("k_genuine, k_impostor, T_test", - [(2, 3, - [[0, 1, 3], [0, 1, 4], [0, 1, 5], [0, 2, 3], - [0, 2, 4], [0, 2, 5], [1, 0, 3], [1, 0, 4], - [1, 0, 5], [1, 2, 3], [1, 2, 4], [1, 2, 5], - [2, 0, 3], [2, 0, 4], [2, 0, 5], [2, 1, 3], - [2, 1, 4], [2, 1, 5], [3, 4, 0], [3, 4, 1], - [3, 4, 2], [3, 5, 0], [3, 5, 1], [3, 5, 2], - [4, 3, 0], [4, 3, 1], [4, 3, 2], [4, 5, 0], - [4, 5, 1], [4, 5, 2], [5, 3, 0], [5, 3, 1], - [5, 3, 2], [5, 4, 0], [5, 4, 1], [5, 4, 2]]), - (2, 2, + [(2, 2, [[0, 1, 3], [0, 1, 4], [0, 2, 3], [0, 2, 4], [1, 0, 3], [1, 0, 4], [1, 2, 3], [1, 2, 4], [2, 0, 3], [2, 0, 4], [2, 1, 3], [2, 1, 4], @@ -114,8 +90,30 @@ def test_generate_knntriplets(): [[0, 1, 3], [0, 1, 4], [1, 0, 3], [1, 0, 4], [2, 1, 3], [2, 1, 4], [3, 4, 1], [3, 4, 2], [4, 3, 1], [4, 3, 2], [5, 4, 1], [5, 4, 2]])]) -def test_generate_knntriplets_k(k_genuine, k_impostor, T_test): - """Checks edge cases of knn triplet construction""" +def test_generate_knntriplets_under_edge(k_genuine, k_impostor, T_test): + """Checks under the edge cases of knn triplet construction with enough + neighbors""" + + X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32]]) + y = np.array([1, 1, 1, 2, 2, 2]) + + T = Constraints(y).generate_knntriplets(X, 
k_genuine, k_impostor)
+
+  assert np.array_equal(sorted(T.tolist()), T_test)
+
+
+@pytest.mark.parametrize("k_genuine, k_impostor,",
+                         [(2, 3), (3, 3), (2, 4), (3, 4)])
+def test_generate_knntriplets(k_genuine, k_impostor):
+  """Checks edge and over-the-edge cases of knn triplet construction when
+  there are not enough neighbors"""
+
+  T_test = [[0, 1, 3], [0, 1, 4], [0, 1, 5], [0, 2, 3], [0, 2, 4], [0, 2, 5],
+            [1, 0, 3], [1, 0, 4], [1, 0, 5], [1, 2, 3], [1, 2, 4], [1, 2, 5],
+            [2, 0, 3], [2, 0, 4], [2, 0, 5], [2, 1, 3], [2, 1, 4], [2, 1, 5],
+            [3, 4, 0], [3, 4, 1], [3, 4, 2], [3, 5, 0], [3, 5, 1], [3, 5, 2],
+            [4, 3, 0], [4, 3, 1], [4, 3, 2], [4, 5, 0], [4, 5, 1], [4, 5, 2],
+            [5, 3, 0], [5, 3, 1], [5, 3, 2], [5, 4, 0], [5, 4, 1], [5, 4, 2]]

   X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32]])
   y = np.array([1, 1, 1, 2, 2, 2])

   T = Constraints(y).generate_knntriplets(X, k_genuine, k_impostor)

From f41fea153258049030837b97102154263643c282 Mon Sep 17 00:00:00 2001
From: grudloff
Date: Thu, 27 Feb 2020 19:02:42 +0100
Subject: [PATCH 15/19] multiple small code refactoring

---
 metric_learn/constraints.py       | 67 ++++++++++++++++++-------------
 test/test_constraints.py          |  8 ++--
 test/test_triplets_classifiers.py |  7 +---
 test/test_utils.py                |  2 +-
 4 files changed, 44 insertions(+), 40 deletions(-)

diff --git a/metric_learn/constraints.py b/metric_learn/constraints.py
index 01053d5d..1c5df376 100644
--- a/metric_learn/constraints.py
+++ b/metric_learn/constraints.py
@@ -66,25 +66,25 @@ def generate_knntriplets(self, X, k_genuine, k_impostor):
 
     labels, labels_count = np.unique(self.partial_labels, return_counts=True)
     n_labels = len(labels)
-    len_input = np.size(self.partial_labels, 0)
+    len_input = self.partial_labels.shape[0]
 
     # Handle the case where there are too few elements to yield k_genuine or
     # k_impostor neighbors for every class.
 
-    k_genuine_vec = np.ones(n_labels, dtype=np.intp)*k_genuine
-    k_impostor_vec = np.ones(n_labels, dtype=np.intp)*k_impostor
+    k_genuine_vec = np.full(n_labels, k_genuine)
+    k_impostor_vec = np.full(n_labels, k_impostor)
 
-    for i in range(n_labels):
-      if (k_genuine + 1 > labels_count[i]):
-        k_genuine_vec[i] = labels_count[i]-1
+    for i, count in enumerate(labels_count):
+      if k_genuine + 1 > count:
+        k_genuine_vec[i] = count-1
         warnings.warn("The class {} has {} elements, which is not sufficient "
                       "to generate {} genuine neighbors as specified by "
                       "k_genuine. Will generate {} genuine neighbors instead."
                       "\n"
-                      .format(labels[i], labels_count[i], k_genuine+1,
+                      .format(labels[i], count, k_genuine+1,
                               k_genuine_vec[i]))
-      if (k_impostor > len_input - labels_count[i]):
-        k_impostor_vec[i] = len_input - labels_count[i]
+      if k_impostor > len_input - count:
+        k_impostor_vec[i] = len_input - count
         warnings.warn("The class {} has {} elements of other classes, which is"
                       " not sufficient to generate {} impostor neighbors as "
                       "specified by k_impostor. 
Will generate {} impostor " @@ -92,17 +92,26 @@ def generate_knntriplets(self, X, k_genuine, k_impostor): .format(labels[i], k_impostor_vec[i], k_impostor, k_impostor_vec[i])) - triplets = np.empty((np.dot(k_genuine_vec*k_impostor_vec, labels_count), - 3), dtype=np.intp) + # The total number of possible triplets combinations comes from taking one + # of the k_genuine_vec[i] genuine neighbors and one of the + # k_impostor_vec[i] impostor neighbors for the labels_count[i] elements in + # every class + comb_per_label = labels_count * k_genuine_vec * k_impostor_vec + num_triplets = np.sum(comb_per_label) + triplets = np.empty((num_triplets, 3), dtype=np.intp) - start = 0 - finish = 0 neigh = NearestNeighbors() - for i in range(n_labels): + # Get start and finish for later triplet assiging + # append zero at the begining for start + start_finish_indices = np.hstack((0, comb_per_label)) + # get cumulative sum + start_finish_indices.cumsum(out=start_finish_indices) + + for i, label in enumerate(labels): # generate mask for current label - gen_mask = self.partial_labels == labels[i] + gen_mask = self.partial_labels == label gen_indx = np.where(gen_mask) # get k_genuine genuine neighbours @@ -112,7 +121,7 @@ def generate_knntriplets(self, X, k_genuine, k_impostor): return_distance=False)) # generate mask for impostors of current label - imp_indx = np.where(np.invert(gen_mask)) + imp_indx = np.where(~gen_mask) # get k_impostor impostor neighbours neigh.fit(X=X[imp_indx]) @@ -122,23 +131,14 @@ def generate_knntriplets(self, X, k_genuine, k_impostor): return_distance=False)) # length = len_label*k_genuine*k_impostor - finish += labels_count[i]*k_genuine_vec[i]*k_impostor_vec[i] + start, finish = start_finish_indices[i:i+2] - triplets[start:finish, :] = self._comb(gen_indx, gen_neigh, imp_neigh, - k_genuine_vec[i], - k_impostor_vec[i]) - start = finish + triplets[start:finish, :] = comb(gen_indx, gen_neigh, imp_neigh, + k_genuine_vec[i], + k_impostor_vec[i]) return triplets - def _comb(self, A, B, C, sizeB, sizeC): - # generate_knntriplets helper function - # generate an array will all combinations of choosing - # an element from A, B and C - return np.vstack((repmat(A, sizeB*sizeC, 1).ravel(order='F'), - repmat(np.hstack(B), sizeC, 1).ravel(order='F'), - repmat(C, 1, sizeB).ravel())).T - def _pairs(self, num_constraints, same_label=True, max_iter=10, random_state=np.random): known_label_idx, = np.where(self.partial_labels >= 0) @@ -197,6 +197,15 @@ def chunks(self, num_chunks=100, chunk_size=2, random_state=None): return chunks +def comb(A, B, C, sizeB, sizeC): + # generate_knntriplets helper function + # generate an array with all combinations of choosing + # an element from A, B and C + return np.vstack((repmat(A, sizeB*sizeC, 1).ravel(order='F'), + repmat(np.hstack(B), sizeC, 1).ravel(order='F'), + repmat(C, 1, sizeB).ravel())).T + + def wrap_pairs(X, constraints): a = np.array(constraints[0]) b = np.array(constraints[1]) diff --git a/test/test_constraints.py b/test/test_constraints.py index 91f51368..e00aeed6 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -130,11 +130,11 @@ def test_generate_knntriplets_k_genuine(): label, labels_count = np.unique(y, return_counts=True) labels_count_min = np.min(labels_count) - idx_smallest_label = np.where(labels_count == labels_count_min) + idx_smallest_label, = np.where(labels_count == labels_count_min) k_genuine = labels_count_min warn_msgs = [] - for idx in idx_smallest_label[0]: + for idx in idx_smallest_label: warn_msgs.append("The 
class {} has {} elements, which is not sufficient " "to generate {} genuine neighbors as specified by " "k_genuine. Will generate {} genuine neighbors instead." @@ -155,11 +155,11 @@ def test_generate_knntriplets_k_impostor(): length = len(y) label, labels_count = np.unique(y, return_counts=True) labels_count_max = np.max(labels_count) - idx_biggest_label = np.where(labels_count == labels_count_max) + idx_biggest_label, = np.where(labels_count == labels_count_max) k_impostor = length - labels_count_max + 1 warn_msgs = [] - for idx in idx_biggest_label[0]: + for idx in idx_biggest_label: warn_msgs.append("The class {} has {} elements of other classes, which is" " not sufficient to generate {} impostor neighbors as " "specified by k_impostor. Will generate {} impostor " diff --git a/test/test_triplets_classifiers.py b/test/test_triplets_classifiers.py index 013ac6f5..8cedd8cc 100644 --- a/test/test_triplets_classifiers.py +++ b/test/test_triplets_classifiers.py @@ -21,12 +21,7 @@ def test_predict_only_one_or_minus_one(estimator, build_dataset, triplets_train, triplets_test = train_test_split(input_data) estimator.fit(triplets_train) predictions = estimator.predict(triplets_test) - for i in range(len(input_data)): - prediction = estimator.predict(input_data[None, i]) - if(prediction == 0): - print(input_data[i]) - print(preprocessor[input_data[i]]) - print(estimator.decision_function(input_data[None, i])) + not_valid = [e for e in predictions if e not in [-1, 1]] assert len(not_valid) == 0 diff --git a/test/test_utils.py b/test/test_utils.py index c104efbf..a4cf86f4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -89,7 +89,7 @@ def build_triplets(with_preprocessor=False): input_data, labels = load_iris(return_X_y=True) X, y = shuffle(input_data, labels, random_state=SEED) constraints = Constraints(y) - triplets = constraints.generate_knntriplets(X, 3, 4) + triplets = constraints.generate_knntriplets(X, k_genuine=3, k_impostor=4) if with_preprocessor: # if preprocessor, we build a 2D array of triplets of indices return triplets, X From 14cf03d4691b9bcd4bba380036afc285759d160f Mon Sep 17 00:00:00 2001 From: grudloff Date: Fri, 28 Feb 2020 13:32:26 +0100 Subject: [PATCH 16/19] more refactoring --- metric_learn/constraints.py | 39 ++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/metric_learn/constraints.py b/metric_learn/constraints.py index 1c5df376..9e4f491e 100644 --- a/metric_learn/constraints.py +++ b/metric_learn/constraints.py @@ -63,16 +63,15 @@ def generate_knntriplets(self, X, k_genuine, k_impostor): triplets : array-like, shape=(n_constraints, 3) 2D array of triplets of indicators. """ - - labels, labels_count = np.unique(self.partial_labels, return_counts=True) - n_labels = len(labels) - len_input = self.partial_labels.shape[0] + known_labels = self.partial_labels[self.partial_labels >= 0] + labels, labels_count = np.unique(known_labels, return_counts=True) + len_input = known_labels.shape[0] # Handle the case where there are too few elements to yield k_genuine or # k_impostor neighbors for every class. 
-    k_genuine_vec = np.full(n_labels, k_genuine)
-    k_impostor_vec = np.full(n_labels, k_impostor)
+    k_genuine_vec = np.full_like(labels, k_genuine)
+    k_impostor_vec = np.full_like(labels, k_impostor)
 
     for i, count in enumerate(labels_count):
       if k_genuine + 1 > count:
@@ -92,29 +91,29 @@ def generate_knntriplets(self, X, k_genuine, k_impostor):
           .format(labels[i], k_impostor_vec[i], k_impostor,
                   k_impostor_vec[i]))
 
-    # The total number of possible triplets combinations comes from taking one
-    # of the k_genuine_vec[i] genuine neighbors and one of the
-    # k_impostor_vec[i] impostor neighbors for the labels_count[i] elements in
-    # every class
+    # The total number of possible triplet combinations per label comes from
+    # taking one of the k_genuine_vec[i] genuine neighbors and one of the
+    # k_impostor_vec[i] impostor neighbors for the labels_count[i] elements
     comb_per_label = labels_count * k_genuine_vec * k_impostor_vec
-    num_triplets = np.sum(comb_per_label)
+
+    # Get start and finish for later triplet assigning
+    # append zero at the beginning for start and get cumulative sum
+    start_finish_indices = np.hstack((0, comb_per_label)).cumsum()
+
+    # Total number of triplets is the sum of all possible combinations per
+    # label
+    num_triplets = start_finish_indices[-1]
     triplets = np.empty((num_triplets, 3), dtype=np.intp)
 
     neigh = NearestNeighbors()
 
-    # Get start and finish for later triplet assiging
-    # append zero at the begining for start
-    start_finish_indices = np.hstack((0, comb_per_label))
-    # get cumulative sum
-    start_finish_indices.cumsum(out=start_finish_indices)
-
     for i, label in enumerate(labels):
 
       # generate mask for current label
-      gen_mask = self.partial_labels == label
+      gen_mask = known_labels == label
       gen_indx = np.where(gen_mask)
 
-      # get k_genuine genuine neighbours
+      # get k_genuine genuine neighbors
       neigh.fit(X=X[gen_indx])
       gen_neigh = np.take(gen_indx, neigh.kneighbors(
                           n_neighbors=k_genuine_vec[i],
@@ -123,7 +122,7 @@ def generate_knntriplets(self, X, k_genuine, k_impostor):
       # generate mask for impostors of current label
       imp_indx = np.where(~gen_mask)
 
-      # get k_impostor impostor neighbours
+      # get k_impostor impostor neighbors
       neigh.fit(X=X[imp_indx])
       imp_neigh = np.take(imp_indx, neigh.kneighbors(
                           n_neighbors=k_impostor_vec[i],

From 088b59af9e3e595c7b4d7a649c260656b40619ef Mon Sep 17 00:00:00 2001
From: grudloff
Date: Fri, 28 Feb 2020 13:41:28 +0100
Subject: [PATCH 17/19] Fix & test unlabeled handling triplet generation

---
 metric_learn/constraints.py | 6 +++++-
 test/test_constraints.py    | 8 ++++----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/metric_learn/constraints.py b/metric_learn/constraints.py
index 9e4f491e..a78f29ce 100644
--- a/metric_learn/constraints.py
+++ b/metric_learn/constraints.py
@@ -63,7 +63,11 @@ def generate_knntriplets(self, X, k_genuine, k_impostor):
     triplets : array-like, shape=(n_constraints, 3)
       2D array of triplets of indicators.
""" - known_labels = self.partial_labels[self.partial_labels >= 0] + # Ignore unlabeled samples + known_labels_mask = self.partial_labels >= 0 + known_labels = self.partial_labels[known_labels_mask] + X = X[known_labels_mask] + labels, labels_count = np.unique(known_labels, return_counts=True) len_input = known_labels.shape[0] diff --git a/test/test_constraints.py b/test/test_constraints.py index e00aeed6..0735e836 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -94,8 +94,8 @@ def test_generate_knntriplets_under_edge(k_genuine, k_impostor, T_test): """Checks under the edge cases of knn triplet construction with enough neighbors""" - X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32]]) - y = np.array([1, 1, 1, 2, 2, 2]) + X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32], [64, 64]]) + y = np.array([1, 1, 1, 2, 2, 2, -1]) T = Constraints(y).generate_knntriplets(X, k_genuine, k_impostor) @@ -115,8 +115,8 @@ def test_generate_knntriplets(k_genuine, k_impostor): [4, 3, 0], [4, 3, 1], [4, 3, 2], [4, 5, 0], [4, 5, 1], [4, 5, 2], [5, 3, 0], [5, 3, 1], [5, 3, 2], [5, 4, 0], [5, 4, 1], [5, 4, 2]] - X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32]]) - y = np.array([1, 1, 1, 2, 2, 2]) + X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32], [64, 64]]) + y = np.array([1, 1, 1, 2, 2, 2, -1]) T = Constraints(y).generate_knntriplets(X, k_genuine, k_impostor) From 59d253afd7afa08fcd8f6f233ae18f07966084ff Mon Sep 17 00:00:00 2001 From: grudloff Date: Fri, 28 Feb 2020 13:47:57 +0100 Subject: [PATCH 18/19] closer unlabeled point --- test/test_constraints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_constraints.py b/test/test_constraints.py index 0735e836..92876779 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -94,7 +94,7 @@ def test_generate_knntriplets_under_edge(k_genuine, k_impostor, T_test): """Checks under the edge cases of knn triplet construction with enough neighbors""" - X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32], [64, 64]]) + X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32], [33, 33]]) y = np.array([1, 1, 1, 2, 2, 2, -1]) T = Constraints(y).generate_knntriplets(X, k_genuine, k_impostor) @@ -115,7 +115,7 @@ def test_generate_knntriplets(k_genuine, k_impostor): [4, 3, 0], [4, 3, 1], [4, 3, 2], [4, 5, 0], [4, 5, 1], [4, 5, 2], [5, 3, 0], [5, 3, 1], [5, 3, 2], [5, 4, 0], [5, 4, 1], [5, 4, 2]] - X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32], [64, 64]]) + X = np.array([[0, 0], [2, 2], [4, 4], [8, 8], [16, 16], [32, 32], [33, 33]]) y = np.array([1, 1, 1, 2, 2, 2, -1]) T = Constraints(y).generate_knntriplets(X, k_genuine, k_impostor) From fb30673a2e6b8c19be7a27cc4fbf14e4fc98ae89 Mon Sep 17 00:00:00 2001 From: grudloff Date: Fri, 28 Feb 2020 16:50:44 +0100 Subject: [PATCH 19/19] small clarity enhancement & repmat replacement --- metric_learn/constraints.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/metric_learn/constraints.py b/metric_learn/constraints.py index a78f29ce..bd51b29f 100644 --- a/metric_learn/constraints.py +++ b/metric_learn/constraints.py @@ -7,7 +7,6 @@ from six.moves import xrange from sklearn.utils import check_random_state from sklearn.neighbors import NearestNeighbors -from numpy.matlib import repmat __all__ = ['Constraints'] @@ -119,19 +118,21 @@ def generate_knntriplets(self, X, k_genuine, k_impostor): # get k_genuine genuine neighbors 
neigh.fit(X=X[gen_indx]) - gen_neigh = np.take(gen_indx, neigh.kneighbors( - n_neighbors=k_genuine_vec[i], - return_distance=False)) + # Take elements of gen_indx according to the yielded k-neighbors + gen_relative_indx = neigh.kneighbors(n_neighbors=k_genuine_vec[i], + return_distance=False) + gen_neigh = np.take(gen_indx, gen_relative_indx) # generate mask for impostors of current label imp_indx = np.where(~gen_mask) # get k_impostor impostor neighbors neigh.fit(X=X[imp_indx]) - imp_neigh = np.take(imp_indx, neigh.kneighbors( - n_neighbors=k_impostor_vec[i], - X=X[gen_mask], - return_distance=False)) + # Take elements of imp_indx according to the yielded k-neighbors + imp_relative_indx = neigh.kneighbors(n_neighbors=k_impostor_vec[i], + X=X[gen_mask], + return_distance=False) + imp_neigh = np.take(imp_indx, imp_relative_indx) # length = len_label*k_genuine*k_impostor start, finish = start_finish_indices[i:i+2] @@ -204,9 +205,9 @@ def comb(A, B, C, sizeB, sizeC): # generate_knntriplets helper function # generate an array with all combinations of choosing # an element from A, B and C - return np.vstack((repmat(A, sizeB*sizeC, 1).ravel(order='F'), - repmat(np.hstack(B), sizeC, 1).ravel(order='F'), - repmat(C, 1, sizeB).ravel())).T + return np.vstack((np.tile(A, (sizeB*sizeC, 1)).ravel(order='F'), + np.tile(np.hstack(B), (sizeC, 1)).ravel(order='F'), + np.tile(C, (1, sizeB)).ravel())).T def wrap_pairs(X, constraints):
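The module-level comb helper, in its final np.tile form above, builds one row per
(anchor, genuine neighbor, impostor neighbor) combination, grouped by anchor.
Below is a minimal standalone sketch (NumPy only; the toy index arrays A, B and
C are invented for illustration, while comb itself is copied from the patch)
checking the tiling construction against the equivalent nested loops that the
removed naive test helper spelled out:

    import numpy as np

    def comb(A, B, C, sizeB, sizeC):
        # One row per (anchor, genuine, impostor) combination, grouped by
        # anchor: A[i] is paired with every element of B[i] and C[i].
        return np.vstack((np.tile(A, (sizeB*sizeC, 1)).ravel(order='F'),
                          np.tile(np.hstack(B), (sizeC, 1)).ravel(order='F'),
                          np.tile(C, (1, sizeB)).ravel())).T

    A = np.array([0, 1])                 # anchor indices
    B = np.array([[10, 11], [12, 13]])   # sizeB=2 genuine neighbors per anchor
    C = np.array([[20, 21], [22, 23]])   # sizeC=2 impostor neighbors per anchor

    vectorized = comb(A, B, C, sizeB=2, sizeC=2)
    looped = np.array([[a, b, c]
                       for i, a in enumerate(A)
                       for b in B[i]
                       for c in C[i]])
    assert np.array_equal(vectorized, looped)

Besides readability, switching to np.tile drops the dependency on the legacy
numpy.matlib module, which mainly exists to support the discouraged np.matrix
class.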
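Earlier in the series (patches 15 and 16), the running start/finish counters
were replaced by a single cumulative sum over the per-class triplet counts.
A short sketch of that indexing scheme, with invented counts standing in for
labels_count * k_genuine_vec * k_impostor_vec:

    import numpy as np

    # invented per-class triplet counts
    comb_per_label = np.array([6, 4, 2])

    # Prepend a zero and take the cumulative sum: the slice
    # start_finish_indices[i]:start_finish_indices[i+1] is the row range that
    # class i fills in the preallocated array; the last entry is the total.
    start_finish_indices = np.hstack((0, comb_per_label)).cumsum()
    # -> array([ 0,  6, 10, 12])

    triplets = np.empty((start_finish_indices[-1], 3), dtype=np.intp)
    for i in range(len(comb_per_label)):
        start, finish = start_finish_indices[i:i+2]
        triplets[start:finish] = i  # stand-in for the comb(...) rows of class i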