diff --git a/doc/metric_learn.rst b/doc/metric_learn.rst index 76c91f48..8f91d91c 100644 --- a/doc/metric_learn.rst +++ b/doc/metric_learn.rst @@ -33,6 +33,7 @@ Supervised Learning Algorithms metric_learn.MMC_Supervised metric_learn.SDML_Supervised metric_learn.RCA_Supervised + metric_learn.SCML_Supervised Weakly Supervised Learning Algorithms ------------------------------------- @@ -45,6 +46,7 @@ Weakly Supervised Learning Algorithms metric_learn.LSML metric_learn.MMC metric_learn.SDML + metric_learn.SCML Unsupervised Learning Algorithms -------------------------------- diff --git a/doc/weakly_supervised.rst b/doc/weakly_supervised.rst index 174d1a8b..82793b5b 100644 --- a/doc/weakly_supervised.rst +++ b/doc/weakly_supervised.rst @@ -700,6 +700,63 @@ of triplets that have the right predicted ordering. Algorithms ---------- +.. _scml: + +:py:class:`SCML <metric_learn.SCML>` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Sparse Compositional Metric Learning +(:py:class:`SCML <metric_learn.SCML>`) + +`SCML` learns a squared Mahalanobis distance from triplet constraints by +optimizing sparse positive weights assigned to a set of :math:`K` rank-one +PSD bases. This can be formulated as an optimization problem with only +:math:`K` parameters, that can be solved with an efficient stochastic +composite scheme. + +The Mahalanobis matrix :math:`M` is built from a basis set :math:`B = \{b_i\}_{i=\{1,...,K\}}` +weighted by a :math:`K` dimensional vector :math:`w = \{w_i\}_{i=\{1,...,K\}}` as: + +.. math:: + + M = \sum_{i=1}^K w_i b_i b_i^T = B \cdot diag(w) \cdot B^T \quad w_i \geq 0 + +Learning :math:`M` in this form makes it PSD by design, as it is a +nonnegative sum of PSD matrices. The basis set :math:`B` is fixed in advance +and it is possible to construct it from the data. The optimization problem +over :math:`w` is formulated as a classic margin-based hinge loss function +involving the set :math:`C` of triplets. A regularization :math:`\ell_1` +is added to yield a sparse combination. 
The formulation is the following: + +.. math:: + + \min_{w\geq 0} \sum_{(x_i,x_j,x_k)\in C} [1 + d_w(x_i,x_j)-d_w(x_i,x_k)]_+ + \beta||w||_1 + +where :math:`[\cdot]_+` is the hinge loss. + +.. topic:: Example Code: + +:: + + from metric_learn import SCML + + triplets = [[[1.2, 7.5], [1.3, 1.5], [6.2, 9.7]], + [[1.3, 4.5], [3.2, 4.6], [5.4, 5.4]], + [[3.2, 7.5], [3.3, 1.5], [8.2, 9.7]], + [[3.3, 4.5], [5.2, 4.6], [7.4, 5.4]]] + + scml = SCML() + scml.fit(triplets) + +.. topic:: References: + + .. [1] Y. Shi, A. Bellet and F. Sha. `Sparse Compositional Metric Learning. + `_. \ + (AAAI), 2014. + + .. [2] Adapted from original \ + `Matlab implementation.`_. + .. _learning_on_quadruplets: @@ -829,13 +886,13 @@ extension leads to more stable estimation when the dimension is high and only a small amount of constraints is given. The loss function of each constraint -:math:`d(\mathbf{x}_a, \mathbf{x}_b) < d(\mathbf{x}_c, \mathbf{x}_d)` is +:math:`d(\mathbf{x}_i, \mathbf{x}_j) < d(\mathbf{x}_k, \mathbf{x}_l)` is denoted as: .. math:: - H(d_\mathbf{M}(\mathbf{x}_a, \mathbf{x}_b) - - d_\mathbf{M}(\mathbf{x}_c, \mathbf{x}_d)) + H(d_\mathbf{M}(\mathbf{x}_i, \mathbf{x}_j) + - d_\mathbf{M}(\mathbf{x}_k, \mathbf{x}_l)) where :math:`H(\cdot)` is the squared Hinge loss function defined as: @@ -845,8 +902,8 @@ where :math:`H(\cdot)` is the squared Hinge loss function defined as: \,\,x^2 \qquad x>0\end{aligned}\right.\\ The summed loss function :math:`L(C)` is the simple sum over all constraints -:math:`C = \{(\mathbf{x}_a , \mathbf{x}_b , \mathbf{x}_c , \mathbf{x}_d) -: d(\mathbf{x}_a , \mathbf{x}_b) < d(\mathbf{x}_c , \mathbf{x}_d)\}`. The +:math:`C = \{(\mathbf{x}_i , \mathbf{x}_j , \mathbf{x}_k , \mathbf{x}_l) +: d(\mathbf{x}_i , \mathbf{x}_j) < d(\mathbf{x}_k , \mathbf{x}_l)\}`. The original paper suggested here should be a weighted sum since the confidence or probability of each constraint might differ. 
However, for the sake of simplicity and assumption of no extra knowledge provided, we just deploy @@ -858,9 +915,9 @@ knowledge: .. math:: - \min_\mathbf{M}(D_{ld}(\mathbf{M, M_0}) + \sum_{(\mathbf{x}_a, - \mathbf{x}_b, \mathbf{x}_c, \mathbf{x}_d)\in C}H(d_\mathbf{M}( - \mathbf{x}_a, \mathbf{x}_b) - d_\mathbf{M}(\mathbf{x}_c, \mathbf{x}_c))\\ + \min_\mathbf{M}(D_{ld}(\mathbf{M, M_0}) + \sum_{(\mathbf{x}_i, + \mathbf{x}_j, \mathbf{x}_k, \mathbf{x}_l)\in C}H(d_\mathbf{M}( + \mathbf{x}_i, \mathbf{x}_j) - d_\mathbf{M}(\mathbf{x}_k, \mathbf{x}_l))\\ where :math:`\mathbf{M}_0` is the prior metric matrix, set as identity by default, :math:`D_{ld}(\mathbf{\cdot, \cdot})` is the LogDet divergence: diff --git a/metric_learn/__init__.py b/metric_learn/__init__.py index b036ccfa..38aa2f7e 100644 --- a/metric_learn/__init__.py +++ b/metric_learn/__init__.py @@ -11,10 +11,12 @@ from .rca import RCA, RCA_Supervised from .mlkr import MLKR from .mmc import MMC, MMC_Supervised +from .scml import SCML, SCML_Supervised from ._version import __version__ __all__ = ['Constraints', 'Covariance', 'ITML', 'ITML_Supervised', 'LMNN', 'LSML', 'LSML_Supervised', 'SDML', 'SDML_Supervised', 'NCA', 'LFDA', 'RCA', 'RCA_Supervised', - 'MLKR', 'MMC', 'MMC_Supervised', '__version__'] + 'MLKR', 'MMC', 'MMC_Supervised', 'SCML', + 'SCML_Supervised', '__version__'] diff --git a/metric_learn/scml.py b/metric_learn/scml.py new file mode 100644 index 00000000..7bbd101a --- /dev/null +++ b/metric_learn/scml.py @@ -0,0 +1,646 @@ +""" +Sparse Compositional Metric Learning (SCML) +""" + +from __future__ import print_function, absolute_import, division +import numpy as np +from .base_metric import _TripletsClassifierMixin, MahalanobisMixin +from ._util import components_from_metric +from sklearn.base import TransformerMixin +from .constraints import Constraints +from sklearn.preprocessing import normalize +from sklearn.neighbors import NearestNeighbors +from sklearn.cluster import KMeans +from 
# Tail of the module's import block (it shares its source line with the
# class definitions below, so it is kept here).
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.utils import check_array, check_random_state
import warnings


class _BaseSCML(MahalanobisMixin):
    """Shared implementation of Sparse Compositional Metric Learning.

    Holds the hyper-parameters, the construction of the 'triplet_diffs'
    basis set and the stochastic optimization of the basis weights. The
    public estimators `SCML` and `SCML_Supervised` only differ in how the
    triplets (and, optionally, the basis set) are obtained before calling
    `_fit`.
    """

    _tuple_size = 3  # constraints are triplets
    _authorized_basis = ['triplet_diffs']

    def __init__(self, beta=1e-5, basis='triplet_diffs', n_basis=None,
                 gamma=5e-3, max_iter=10000, output_iter=500, batch_size=10,
                 verbose=False, preprocessor=None, random_state=None):
        self.beta = beta
        self.basis = basis
        self.n_basis = n_basis
        self.gamma = gamma
        self.max_iter = max_iter
        self.output_iter = output_iter
        self.batch_size = batch_size
        self.verbose = verbose
        self.preprocessor = preprocessor
        self.random_state = random_state
        super(_BaseSCML, self).__init__(preprocessor)

    def _fit(self, triplets, basis=None, n_basis=None):
        """Learn the sparse weights of the basis set from triplets.

        Optimization procedure to find a sparse vector of weights to
        construct the metric from the basis set. This is based on the
        dual averaging method.

        Parameters
        ----------
        triplets : array-like
            Triplets, as points or as indices to be resolved by the
            preprocessor (they are passed through `_prepare_inputs`).

        basis : numpy.ndarray, shape=(n_basis, n_features), optional
            Precomputed basis set. If None, it is built by
            `_initialize_basis` from ``self.basis``.

        n_basis : int, optional
            Number of rows of `basis`. Only read when `basis` is given.

        Returns
        -------
        self : object
            The fitted instance, with `components_` and `n_iter_` set.

        Raises
        ------
        ValueError
            If `max_iter`, `output_iter` or `batch_size` is not an integer,
            or if `output_iter` is greater than `max_iter`.
        """
        for name in ('max_iter', 'output_iter', 'batch_size'):
            value = getattr(self, name)
            if not isinstance(value, int):
                raise ValueError("%s should be an integer, instead it is of "
                                 "type %s" % (name, type(value)))

        if self.output_iter > self.max_iter:
            raise ValueError("The value of output_iter must be equal or "
                             "smaller than max_iter.")

        # Currently prepare_inputs makes triplets contain points and not
        # indices
        triplets = self._prepare_inputs(triplets, type_of_inputs='tuples')

        # TODO:
        # This algorithm is built to work with indices, but in order to be
        # compliant with the current handling of inputs it is converted
        # back to indices by the following function. This should be improved
        # in the future.
        triplets, X = self._to_index_points(triplets)

        if basis is None:
            basis, n_basis = self._initialize_basis(triplets, X)

        dist_diff = self._compute_dist_diff(triplets, X, basis)

        n_triplets = triplets.shape[0]

        # weight vector
        w = np.zeros((1, n_basis))
        # average obj gradient wrt weights
        avg_grad_w = np.zeros((1, n_basis))
        # l2 norm in time of all obj gradients wrt weights
        ada_grad_w = np.zeros((1, n_basis))
        # slack for not dividing by zero
        delta = 0.001

        best_obj = np.inf
        # Start from the all-zero weights so that `best_w` is always
        # defined, even if the objective never improves on `inf`
        # (e.g. it evaluates to NaN).
        best_w = w

        rng = check_random_state(self.random_state)
        # One row of mini-batch triplet indices per iteration
        rand_int = rng.randint(low=0, high=n_triplets,
                               size=(self.max_iter, self.batch_size))

        for it in range(self.max_iter):
            idx = rand_int[it]

            slack_val = 1 + np.matmul(dist_diff[idx, :], w.T)
            slack_mask = np.squeeze(slack_val > 0, axis=1)

            # Only the triplets of the batch violating the margin contribute
            grad_w = np.sum(dist_diff[idx[slack_mask], :],
                            axis=0, keepdims=True) / self.batch_size
            avg_grad_w = (it * avg_grad_w + grad_w) / (it + 1)

            ada_grad_w = np.sqrt(np.square(ada_grad_w) + np.square(grad_w))

            scale_f = -(it + 1) / (self.gamma * (delta + ada_grad_w))

            # proximal operator with negative trimming equivalent
            w = scale_f * np.minimum(avg_grad_w + self.beta, 0)

            if (it + 1) % self.output_iter == 0:
                # regularization part of obj function
                obj1 = np.sum(w) * self.beta

                # Every triplet distance difference in the space given by L
                # plus a slack of one
                slack_val = 1 + np.matmul(dist_diff, w.T)
                # Mask of places with positive slack
                slack_mask = slack_val > 0

                # loss function of learning task part of obj function
                obj2 = np.sum(slack_val[slack_mask]) / n_triplets

                obj = obj1 + obj2
                if self.verbose:
                    count = np.sum(slack_mask)
                    print("[%s] iter %d\t obj %.6f\t num_imp %d" %
                          (self.__class__.__name__, (it + 1), obj, count))

                # update the best
                if obj < best_obj:
                    best_obj = obj
                    best_w = w

        if self.verbose:
            print("max iteration reached.")

        # Number of iterations actually run (the loop index is zero-based)
        self.n_iter_ = self.max_iter
        # L matrix yielded from the best weights
        self.components_ = self._components_from_basis_weights(basis, best_w)

        return self

    @staticmethod
    def _unique_triplet_pairs(triplets):
        """Split index triplets into unique pairs plus lookup indices.

        Parameters
        ----------
        triplets : numpy.ndarray, shape=(n_triplets, 3)
            Triplets of point indices.

        Returns
        -------
        uniq_pairs : numpy.ndarray, shape=(n_unique_pairs, 2)
            Unique pairs, each with its lowest index first.

        pos_idx, neg_idx : numpy.ndarray
            For each triplet, the row of `uniq_pairs` holding its
            (0, 1) pair and its (0, 2) pair respectively.
        """
        n_triplets = triplets.shape[0]
        # all positive then all negative pairs, lowest index first:
        # np.array (2*n_triplets, 2)
        pairs = np.sort(np.vstack((triplets[:, [0, 1]],
                                   triplets[:, [0, 2]])),
                        kind='stable')
        uniq_pairs, inverse = np.unique(pairs, return_inverse=True, axis=0)
        return uniq_pairs, inverse[:n_triplets], inverse[n_triplets:]

    def _compute_dist_diff(self, triplets, X, basis):
        """
        Helper function to compute the distance difference of every triplet
        in the space yielded by the basis set.

        Returns an array with one row per triplet and one column per basis:
        the squared difference along that basis for the (0, 1) pair minus
        the one for the (0, 2) pair.
        """
        # Transformation of data by the basis set
        XB = np.matmul(X, basis.T)

        uniq_pairs, pos_idx, neg_idx = self._unique_triplet_pairs(triplets)

        # squared difference along each basis, only for unique pairs
        dist = np.square(XB[uniq_pairs[:, 0], :] - XB[uniq_pairs[:, 1], :])

        # difference of distances between positive and negative pairs
        return dist[pos_idx] - dist[neg_idx]

    def _components_from_basis_weights(self, basis, w):
        """
        Get the components matrix (L) from the basis set and its weights.

        Bases with a non-positive weight are dropped. If fewer bases remain
        than there are features, a warning is emitted and L is the remaining
        bases scaled by the square root of their weights; otherwise L is
        obtained from the Mahalanobis matrix through
        `components_from_metric`.
        """
        # get rid of inactive bases
        # TODO: Maybe have a tolerance over zero?
        active_idx, = w > 0  # w has a single row
        w = w[..., active_idx]
        basis = basis[active_idx, :]

        n_basis, n_features = basis.shape

        if n_basis < n_features:  # if metric is low-rank
            warnings.warn("The number of bases with nonzero weight is less "
                          "than the number of features of the input, in "
                          "consequence the learned transformation reduces "
                          "the dimension to %d." % n_basis)
            # equivalent to np.diag(np.sqrt(w)).dot(basis)
            return np.sqrt(w.T) * basis

        # if metric is full rank
        return components_from_metric(np.matmul(basis.T, w.T * basis))

    def _to_index_points(self, triplets):
        """Convert triplets of points into triplets of indices.

        Returns the index triplets together with the array `X` of unique
        points they index.
        """
        shape = triplets.shape
        X, triplets = np.unique(np.vstack(triplets), return_inverse=True,
                                axis=0)
        triplets = triplets.reshape(shape[:2])
        return triplets, X

    def _initialize_basis(self, triplets, X):
        """ Checks if the basis array is well constructed or constructs it
        based on one of the available options.

        Returns
        -------
        basis : numpy.ndarray, shape=(n_basis, n_features)
        n_basis : int

        Raises
        ------
        ValueError
            If ``self.basis`` is an array whose second dimension differs
            from the number of features, or is neither an array nor one of
            the authorized options.
        """
        n_features = X.shape[1]

        if isinstance(self.basis, np.ndarray):
            # TODO: should copy?
            basis = check_array(self.basis, copy=True)
            if basis.shape[1] != n_features:
                raise ValueError('The dimensionality ({}) of the provided '
                                 'bases must match the dimensionality of the '
                                 'data ({}).'
                                 .format(basis.shape[1], n_features))
            # A user-provided basis fixes the number of bases
            n_basis = basis.shape[0]
        elif self.basis not in self._authorized_basis:
            raise ValueError(
                "`basis` must be one of the options '{}' "
                "or an array of shape (n_basis, n_features)."
                .format("', '".join(self._authorized_basis)))
        elif self.basis == 'triplet_diffs':
            basis, n_basis = self._generate_bases_dist_diff(triplets, X)

        return basis, n_basis

    def _generate_bases_dist_diff(self, triplets, X):
        """ Constructs the basis set from the differences of positive and
        negative pairs from the triplets constraints.

        The basis set is constructed iteratively by taking n_features
        triplets, then adding and substracting respectively all the
        outerproducts of the positive and negative pairs, and finally
        selecting the eigenvectors of this matrix with positive eigenvalue.
        This is done until n_basis are selected.

        Returns
        -------
        basis : numpy.ndarray, shape=(n_basis, n_features)
        n_basis : int

        Raises
        ------
        ValueError
            If ``self.n_basis`` is neither None nor an integer, or if no
            eigenvector with positive eigenvalue is found in many
            consecutive draws (degenerate triplet set).
        """
        n_features = X.shape[1]
        n_triplets = triplets.shape[0]

        if self.n_basis is None:
            # TODO: Get a good default n_basis directive
            n_basis = n_features * 80
            warnings.warn('As no value for `n_basis` was selected, the '
                          'number of basis will be set to n_basis= %d'
                          % n_basis)
        elif isinstance(self.n_basis, int):
            n_basis = self.n_basis
        else:
            raise ValueError("n_basis should be an integer, instead it is of "
                             "type %s" % type(self.n_basis))

        basis = np.zeros((n_basis, n_features))

        uniq_pairs, pos_idx, neg_idx = self._unique_triplet_pairs(triplets)
        # calculate differences only for unique pairs
        diff = X[uniq_pairs[:, 0], :] - X[uniq_pairs[:, 1], :]

        diff_pos = diff[pos_idx, :]
        diff_neg = diff[neg_idx, :]

        rng = check_random_state(self.random_state)

        n_filled = 0
        # Give up instead of looping forever when the triplets never yield
        # a usable eigenvector.
        max_empty_draws = 100
        empty_draws = 0

        while n_filled < n_basis:
            # Select triplets to yield diff
            select_triplet = rng.choice(n_triplets, size=n_features,
                                        replace=False)

            # select n_features positive and negative differences
            d_pos = diff_pos[select_triplet, :]
            d_neg = diff_neg[select_triplet, :]

            # Yield matrix
            diff_sum = d_pos.T.dot(d_pos) - d_neg.T.dot(d_neg)

            # Calculate eigenvalue and eigenvectors
            eigvals, eigvecs = np.linalg.eigh(diff_sum.T.dot(diff_sum))

            # eigh stores eigenvectors as columns: keep those with positive
            # eigenvalue and lay them out as rows of the basis set.
            new_bases = eigvecs[:, eigvals > 0].T
            n_new = new_bases.shape[0]

            if n_new == 0:
                empty_draws += 1
                if empty_draws >= max_empty_draws:
                    raise ValueError(
                        "Could not build the basis set: no eigenvector with "
                        "positive eigenvalue was found in %d consecutive "
                        "draws of triplets. The triplet set may be "
                        "degenerate." % max_empty_draws)
                continue
            empty_draws = 0

            # Clip the last batch so the basis set is filled exactly
            n_take = min(n_new, n_basis - n_filled)
            basis[n_filled:n_filled + n_take, :] = new_bases[:n_take]
            n_filled += n_take

        return basis, n_basis


class SCML(_BaseSCML, _TripletsClassifierMixin):
    """Sparse Compositional Metric Learning (SCML)

    `SCML` learns a squared Mahalanobis distance from triplet constraints by
    optimizing sparse positive weights assigned to a set of :math:`K`
    rank-one PSD bases. This can be formulated as an optimization problem
    with only :math:`K` parameters, that can be solved with an efficient
    stochastic composite scheme.

    Read more in the :ref:`User Guide <scml>`.

    Parameters
    ----------
    beta : float (default=1e-5)
        L1 regularization parameter.

    basis : string or numpy array, optional (default='triplet_diffs')
        Set of bases to construct the metric. Possible options are
        'triplet_diffs', and an array of shape (n_basis, n_features).

        'triplet_diffs'
            The basis set is constructed from the differences between
            points of positive or negative pairs taken from the triplets
            constraints.

        array
            A matrix of shape (n_basis, n_features), that will be used as
            the basis set for the metric construction.

    n_basis : int, optional
        Number of basis to be yielded. If no value is selected a default
        will be computed based on the input.

    gamma : float (default=5e-3)
        Learning rate for the optimization algorithm.

    max_iter : int (default=10000)
        Number of iterations for the algorithm.

    output_iter : int (default=500)
        Number of iterations to check current weights performance and
        output this information in case verbose is True.

    batch_size : int (default=10)
        Number of triplets sampled at each iteration.

    verbose : bool, optional
        If True, prints information while learning.

    preprocessor : array-like, shape=(n_samples, n_features) or callable
        The preprocessor to call to get triplets from indices. If
        array-like, triplets will be formed like this: X[indices].

    random_state : int or numpy.RandomState or None, optional (default=None)
        A pseudo random number generator object or a seed for it if int.

    Attributes
    ----------
    components_ : `numpy.ndarray`, shape=(n_components, n_features)
        The linear transformation ``L`` deduced from the learned weights
        (See function `_components_from_basis_weights`.) When fewer bases
        than features end up with a nonzero weight, ``n_components`` is the
        number of such bases.

    n_iter_ : int
        Number of iterations run by the optimization.

    Examples
    --------
    >>> from metric_learn import SCML
    >>> triplets = [[[1.2, 7.5], [1.3, 1.5], [6.2, 9.7]],
    ...             [[1.3, 4.5], [3.2, 4.6], [5.4, 5.4]],
    ...             [[3.2, 7.5], [3.3, 1.5], [8.2, 9.7]],
    ...             [[3.3, 4.5], [5.2, 4.6], [7.4, 5.4]]]
    >>> scml = SCML()
    >>> scml.fit(triplets)

    References
    ----------
    .. [1] Y. Shi, A. Bellet and F. Sha. Sparse Compositional Metric
           Learning. (AAAI), 2014.

    .. [2] Adapted from the original Matlab implementation.

    See Also
    --------
    metric_learn.SCML_Supervised : The supervised version of the algorithm.

    :ref:`supervised_version` : The section of the project documentation
      that describes the supervised version of weakly supervised estimators.
    """

    def fit(self, triplets):
        """Learn the SCML model.

        Parameters
        ----------
        triplets : array-like, shape=(n_constraints, 3, n_features) or \
              (n_constraints, 3)
            3D array-like of triplets of points or 2D array of triplets of
            indicators. Triplets are assumed to be ordered such that:
            d(triplets[i, 0],triplets[i, 1]) < d(triplets[i, 0],
            triplets[i, 2]).

        Returns
        -------
        self : object
            Returns the instance.
        """
        return self._fit(triplets)


class SCML_Supervised(_BaseSCML, TransformerMixin):
    """Supervised version of Sparse Compositional Metric Learning (SCML)

    `SCML_Supervised` creates triplets by taking `k_genuine` neighbours
    of the same class and `k_impostor` neighbours from different classes for
    each point and then runs the SCML algorithm on these triplets.

    Read more in the :ref:`User Guide <scml>`.

    Parameters
    ----------
    k_genuine : int (default=3)
        Number of neighbours of the same class taken for each point when
        generating the triplets.

    k_impostor : int (default=10)
        Number of neighbours from different classes taken for each point
        when generating the triplets.

    beta : float (default=1e-5)
        L1 regularization parameter.

    basis : string or numpy array, optional (default='lda')
        Set of bases to construct the metric. Possible options are
        'lda', 'triplet_diffs', and an array of shape
        (n_basis, n_features).

        'lda'
            The `n_basis` basis set is constructed from the LDA of
            significant local regions in the feature space via clustering,
            for each region center k-nearest neighbors are used to obtain
            the LDA scalings, which correspond to the locally
            discriminative basis.

        'triplet_diffs'
            The basis set is constructed from the differences between
            points of positive or negative pairs taken from the generated
            triplets.

        array
            A matrix of shape (n_basis, n_features), that will be used as
            the basis set for the metric construction.

    n_basis : int, optional
        Number of basis to be yielded. If no value is selected a default
        will be computed based on the input.

    gamma : float (default=5e-3)
        Learning rate for the optimization algorithm.

    max_iter : int (default=10000)
        Number of iterations for the algorithm.

    output_iter : int (default=500)
        Number of iterations to check current weights performance and
        output this information in case verbose is True.

    batch_size : int (default=10)
        Number of triplets sampled at each iteration.

    verbose : bool, optional
        If True, prints information while learning.

    preprocessor : array-like, shape=(n_samples, n_features) or callable
        The preprocessor to call to get triplets from indices. If
        array-like, triplets will be formed like this: X[indices].

    random_state : int or numpy.RandomState or None, optional (default=None)
        A pseudo random number generator object or a seed for it if int.

    Attributes
    ----------
    components_ : `numpy.ndarray`, shape=(n_components, n_features)
        The linear transformation ``L`` deduced from the learned weights
        (See function `_components_from_basis_weights`.) When fewer bases
        than features end up with a nonzero weight, ``n_components`` is the
        number of such bases.

    n_iter_ : int
        Number of iterations run by the optimization.

    Examples
    --------
    >>> from metric_learn import SCML_Supervised
    >>> from sklearn.datasets import load_iris
    >>> X, y = load_iris(return_X_y=True)
    >>> scml = SCML_Supervised(random_state=42)
    >>> scml.fit(X, y)

    References
    ----------
    .. [1] Y. Shi, A. Bellet and F. Sha. Sparse Compositional Metric
           Learning. (AAAI), 2014.

    .. [2] Adapted from the original Matlab implementation.

    See Also
    --------
    metric_learn.SCML : The weakly supervised version of this
      algorithm.
    """
    # Add supervised authorized basis construction options
    _authorized_basis = _BaseSCML._authorized_basis + ['lda']

    def __init__(self, k_genuine=3, k_impostor=10, beta=1e-5, basis='lda',
                 n_basis=None, gamma=5e-3, max_iter=10000, output_iter=500,
                 batch_size=10, verbose=False, preprocessor=None,
                 random_state=None):
        self.k_genuine = k_genuine
        self.k_impostor = k_impostor
        # `gamma` must be forwarded as well, otherwise the base class
        # silently falls back to its own default learning rate.
        _BaseSCML.__init__(self, beta=beta, basis=basis, n_basis=n_basis,
                           gamma=gamma, max_iter=max_iter,
                           output_iter=output_iter, batch_size=batch_size,
                           verbose=verbose, preprocessor=preprocessor,
                           random_state=random_state)

    def fit(self, X, y):
        """Create constraints from labels and learn the SCML model.

        Parameters
        ----------
        X : (n x d) matrix
            Input data, where each row corresponds to a single instance.

        y : (n) array-like
            Data labels.

        Returns
        -------
        self : object
            Returns the instance.

        Raises
        ------
        ValueError
            If `k_genuine` or `k_impostor` is not an integer.
        """
        X, y = self._prepare_inputs(X, y, ensure_min_samples=2)

        basis, n_basis = self._initialize_basis_supervised(X, y)

        if not isinstance(self.k_genuine, int):
            raise ValueError("k_genuine should be an integer, instead it is "
                             "of type %s" % type(self.k_genuine))
        if not isinstance(self.k_impostor, int):
            raise ValueError("k_impostor should be an integer, instead it is "
                             "of type %s" % type(self.k_impostor))

        constraints = Constraints(y)
        triplets = constraints.generate_knntriplets(X, self.k_genuine,
                                                    self.k_impostor)

        # `_fit` expects triplets of points, so resolve the indices here
        triplets = X[triplets]

        return self._fit(triplets, basis, n_basis)

    def _initialize_basis_supervised(self, X, y):
        """ Constructs the basis set following one of the supervised options
        in case one is selected.

        Returns ``(None, None)`` for any other option, which lets `_fit`
        fall back to `_initialize_basis`.
        """
        if self.basis == 'lda':
            return self._generate_bases_LDA(X, y)
        return None, None

    def _generate_bases_LDA(self, X, y):
        """ Generates bases for the 'lda' option.

        The basis set is constructed using Linear Discriminant Analysis of
        significant local regions in the feature space via clustering, for
        each region center k-nearest neighbors are used to obtain the LDA
        scalings, which correspond to the locally discriminative basis.
        Currently this is done at two scales `k={10,20}` if
        `n_features <= 50` or else `k={20,50}`.

        Returns
        -------
        basis : numpy.ndarray, shape=(n_basis, n_features)
        n_basis : int

        Raises
        ------
        ValueError
            If ``self.n_basis`` is neither None nor an integer, or if there
            are not enough samples to generate `n_basis` bases.
        """
        labels, class_count = np.unique(y, return_counts=True)
        n_class = len(labels)

        n_features = X.shape[1]
        # Number of basis yielded from each LDA
        num_eig = min(n_class - 1, n_features)

        if self.n_basis is None:
            # TODO: Get a good default n_basis directive
            n_basis = min(20 * n_features, X.shape[0] * 2 * num_eig - 1)
            warnings.warn('As no value for `n_basis` was selected, the '
                          'number of basis will be set to n_basis= %d'
                          % n_basis)
        elif isinstance(self.n_basis, int):
            n_basis = self.n_basis
        else:
            raise ValueError("n_basis should be an integer, instead it is of "
                             "type %s" % type(self.n_basis))

        # Number of clusters needed for 2 scales given the number of basis
        # yielded by every LDA
        n_clusters = int(np.ceil(n_basis / (2 * num_eig)))

        if n_basis < n_class:
            warnings.warn("The number of basis is less than the number of "
                          "classes, which may lead to poor discriminative "
                          "performance.")
        elif n_basis >= X.shape[0] * 2 * num_eig:
            raise ValueError("Not enough samples to generate %d LDA bases, "
                             "n_basisshould be smaller than %d" %
                             (n_basis, X.shape[0] * 2 * num_eig))

        kmeans = KMeans(n_clusters=n_clusters,
                        random_state=self.random_state,
                        algorithm='elkan').fit(X)
        cX = kmeans.cluster_centers_

        n_scales = 2
        if n_features > 50:
            scales = [20, 50]
        else:
            scales = [10, 20]

        # Neighbors taken per class at each scale, capped by the class size
        k_class = np.vstack((np.minimum(class_count, scales[0]),
                             np.minimum(class_count, scales[1])))

        # Builtin `int` instead of the `np.int` alias removed from NumPy
        idx_set = [np.zeros((n_clusters, sum(k_class[0, :])), dtype=int),
                   np.zeros((n_clusters, sum(k_class[1, :])), dtype=int)]

        start_finish_indices = np.hstack((np.zeros((2, 1), dtype=int),
                                          k_class)).cumsum(axis=1)

        neigh = NearestNeighbors()

        for c in range(n_class):
            sel_c = np.where(y == labels[c])

            # get k_class same class neighbors
            neigh.fit(X=X[sel_c])
            # Only take the neighbors once for the biggest scale
            neighbors = neigh.kneighbors(X=cX, n_neighbors=k_class[-1, c],
                                         return_distance=False)

            # add index set of neighbors for every cluster center for both
            # scales
            for s, k in enumerate(k_class[:, c]):
                start, finish = start_finish_indices[s, c:c + 2]
                idx_set[s][:, start:finish] = np.take(sel_c,
                                                      neighbors[:, :k])

        # Compute basis for every cluster in both scales
        basis = np.zeros((n_basis, n_features))
        lda = LinearDiscriminantAnalysis()
        start_finish_indices = np.hstack((np.vstack((0,
                                                     n_clusters * num_eig)),
                                          np.full((2, n_clusters),
                                                  num_eig))).cumsum(axis=1)

        for s in range(n_scales):
            for c in range(n_clusters):
                lda.fit(X[idx_set[s][c, :]], y[idx_set[s][c, :]])
                start, finish = start_finish_indices[s, c:c + 2]
                normalized_scalings = normalize(lda.scalings_.T)
                try:
                    basis[start: finish, :] = normalized_scalings
                except ValueError:
                    # handle tail
                    basis[start:, :] = normalized_scalings[:n_basis - start]
                    break

        return basis, n_basis
-from metric_learn.constraints import wrap_pairs +from metric_learn.constraints import wrap_pairs, Constraints from metric_learn.lmnn import _sum_outer_products @@ -76,6 +77,235 @@ def test_singular_returns_pseudo_inverse(self): pseudo_inverse) +class TestSCML(object): + @pytest.mark.parametrize('basis', ('lda', 'triplet_diffs')) + def test_iris(self, basis): + X, y = load_iris(return_X_y=True) + scml = SCML_Supervised(basis=basis, n_basis=85, k_genuine=7, k_impostor=5, + random_state=42) + scml.fit(X, y) + csep = class_separation(scml.transform(X), y) + assert csep < 0.24 + + def test_big_n_features(self): + X, y = make_classification(n_samples=100, n_classes=3, n_features=60, + n_informative=60, n_redundant=0, n_repeated=0, + random_state=42) + X = StandardScaler().fit_transform(X) + scml = SCML_Supervised(random_state=42) + scml.fit(X, y) + csep = class_separation(scml.transform(X), y) + assert csep < 0.7 + + @pytest.mark.parametrize(('estimator', 'data'), + [(SCML, (np.ones((3, 3, 3)),)), + (SCML_Supervised, (np.array([[0, 0], [0, 1], + [2, 0], [2, 1]]), + np.array([1, 0, 1, 0])))]) + def test_bad_basis(self, estimator, data): + model = estimator(basis='bad_basis') + msg = ("`basis` must be one of the options '{}' or an array of shape " + "(n_basis, n_features)." 
+ .format("', '".join(model._authorized_basis))) + with pytest.raises(ValueError) as raised_error: + model.fit(*data) + assert msg == raised_error.value.args[0] + + def test_dimension_reduction_msg(self): + scml = SCML(n_basis=2) + triplets = np.array([[[0, 1], [2, 1], [0, 0]], + [[2, 1], [0, 1], [2, 0]], + [[0, 0], [2, 0], [0, 1]], + [[2, 0], [0, 0], [2, 1]]]) + msg = ("The number of bases with nonzero weight is less than the " + "number of features of the input, in consequence the " + "learned transformation reduces the dimension to 1.") + with pytest.warns(UserWarning) as raised_warning: + scml.fit(triplets) + assert msg == raised_warning[0].message.args[0] + + @pytest.mark.parametrize(('estimator', 'data'), + [(SCML, (np.array([[[0, 1], [2, 1], [0, 0]], + [[2, 1], [0, 1], [2, 0]], + [[0, 0], [2, 0], [0, 1]], + [[2, 0], [0, 0], [2, 1]]]),)), + (SCML_Supervised, (np.array([[0, 0], [1, 1], + [3, 3]]), + np.array([1, 2, 3])))]) + def test_n_basis_wrong_type(self, estimator, data): + n_basis = 4.0 + model = estimator(n_basis=n_basis) + msg = ("n_basis should be an integer, instead it is of type %s" + % type(n_basis)) + with pytest.raises(ValueError) as raised_error: + model.fit(*data) + assert msg == raised_error.value.args[0] + + def test_small_n_basis_lda(self): + X = np.array([[0, 0], [1, 1], [2, 2], [3, 3]]) + y = np.array([0, 0, 1, 1]) + + n_class = 2 + scml = SCML_Supervised(n_basis=n_class-1) + msg = ("The number of basis is less than the number of classes, which may" + " lead to poor discriminative performance.") + with pytest.warns(UserWarning) as raised_warning: + scml.fit(X, y) + assert msg == raised_warning[0].message.args[0] + + def test_big_n_basis_lda(self): + X = np.array([[0, 0], [1, 1], [3, 3]]) + y = np.array([1, 2, 3]) + + n_class = 3 + num_eig = min(n_class - 1, X.shape[1]) + n_basis = X.shape[0] * 2 * num_eig + + scml = SCML_Supervised(n_basis=n_basis) + msg = ("Not enough samples to generate %d LDA bases, n_basis" + "should be smaller than %d" 
% + (n_basis, n_basis)) + with pytest.raises(ValueError) as raised_error: + scml.fit(X, y) + assert msg == raised_error.value.args[0] + + @pytest.mark.parametrize(('estimator', 'data'), + [(SCML, (np.random.rand(3, 3, 2),)), + (SCML_Supervised, (np.array([[0, 0], [0, 1], + [2, 0], [2, 1]]), + np.array([1, 0, 1, 0])))]) + def test_array_basis(self, estimator, data): + """ Test that the proper error is raised when the shape of the input basis + array is not consistent with the input + """ + basis = np.eye(3) + scml = estimator(n_basis=3, basis=basis) + + msg = ('The dimensionality ({}) of the provided bases must match the ' + 'dimensionality of the data ({}).' + .format(basis.shape[1], data[0].shape[-1])) + with pytest.raises(ValueError) as raised_error: + scml.fit(*data) + assert msg == raised_error.value.args[0] + + @pytest.mark.parametrize(('estimator', 'data'), + [(SCML, (np.array([[0, 1, 2], [0, 1, 3], [1, 0, 2], + [1, 0, 3], [2, 3, 1], [2, 3, 0], + [3, 2, 1], [3, 2, 0]]),)), + (SCML_Supervised, (np.array([0, 1, 2, 3]), + np.array([0, 0, 1, 1])))]) + def test_verbose(self, estimator, data, capsys): + # assert there is proper output when verbose = True + model = estimator(preprocessor=np.array([[0, 0], [1, 1], [2, 2], [3, 3]]), + max_iter=1, output_iter=1, batch_size=1, + basis='triplet_diffs', random_state=42, verbose=True) + model.fit(*data) + out, _ = capsys.readouterr() + expected_out = ('[%s] iter 1\t obj 0.569946\t num_imp 2\n' + 'max iteration reached.\n' % estimator.__name__) + assert out == expected_out + + def test_triplet_diffs_toy(self): + expected_n_basis = 10 + model = SCML_Supervised(n_basis=expected_n_basis) + X = np.array([[0, 0], [1, 1], [2, 2], [3, 3]]) + triplets = np.array([[0, 1, 2], [0, 1, 3], [1, 0, 2], [1, 0, 3], + [2, 3, 1], [2, 3, 0], [3, 2, 1], [3, 2, 0]]) + basis, n_basis = model._generate_bases_dist_diff(triplets, X) + # All points are along the same line, so the only possible basis will be + # the vector along that line normalized. 
+ expected_basis = np.ones((expected_n_basis, 2))/np.sqrt(2) + assert n_basis == expected_n_basis + np.testing.assert_allclose(basis, expected_basis) + + def test_lda_toy(self): + expected_n_basis = 7 + model = SCML_Supervised(n_basis=expected_n_basis) + X = np.array([[0, 0], [1, 1], [2, 2], [3, 3]]) + y = np.array([0, 0, 1, 1]) + basis, n_basis = model._generate_bases_LDA(X, y) + # All points are along the same line, so the only possible basis will be + # the vector along that line normalized. In this case it is possible to + # obtain it with positive or negative orientations. + expected_basis = np.ones((expected_n_basis, 2))/np.sqrt(2) + assert n_basis == expected_n_basis + np.testing.assert_allclose(np.abs(basis), expected_basis) + + @pytest.mark.parametrize('n_samples', [100, 500]) + @pytest.mark.parametrize('n_features', [10, 50, 100]) + @pytest.mark.parametrize('n_classes', [5, 10, 15]) + def test_triplet_diffs(self, n_samples, n_features, n_classes): + X, y = make_classification(n_samples=n_samples, n_classes=n_classes, + n_features=n_features, n_informative=n_features, + n_redundant=0, n_repeated=0) + X = StandardScaler().fit_transform(X) + + model = SCML_Supervised() + constraints = Constraints(y) + triplets = constraints.generate_knntriplets(X, model.k_genuine, + model.k_impostor) + basis, n_basis = model._generate_bases_dist_diff(triplets, X) + + expected_n_basis = n_features * 80 + assert n_basis == expected_n_basis + assert basis.shape == (expected_n_basis, n_features) + + @pytest.mark.parametrize('n_samples', [100, 500]) + @pytest.mark.parametrize('n_features', [10, 50, 100]) + @pytest.mark.parametrize('n_classes', [5, 10, 15]) + def test_lda(self, n_samples, n_features, n_classes): + X, y = make_classification(n_samples=n_samples, n_classes=n_classes, + n_features=n_features, n_informative=n_features, + n_redundant=0, n_repeated=0) + X = StandardScaler().fit_transform(X) + + model = SCML_Supervised() + basis, n_basis = model._generate_bases_LDA(X, y) 
+ + num_eig = min(n_classes - 1, n_features) + expected_n_basis = min(20 * n_features, n_samples * 2 * num_eig - 1) + assert n_basis == expected_n_basis + assert basis.shape == (expected_n_basis, n_features) + + @pytest.mark.parametrize('name', ['max_iter', 'output_iter', 'batch_size', + 'n_basis']) + def test_int_inputs(self, name): + value = 1.0 + d = {name: value} + scml = SCML(**d) + triplets = np.array([[[0, 1], [2, 1], [0, 0]]]) + + msg = ("%s should be an integer, instead it is of type" + " %s" % (name, type(value))) + with pytest.raises(ValueError) as raised_error: + scml.fit(triplets) + assert msg == raised_error.value.args[0] + + @pytest.mark.parametrize('name', ['max_iter', 'output_iter', 'batch_size', + 'k_genuine', 'k_impostor', 'n_basis']) + def test_int_inputs_supervised(self, name): + value = 1.0 + d = {name: value} + scml = SCML_Supervised(**d) + X = np.array([[0, 0], [1, 1], [3, 3], [4, 4]]) + y = np.array([1, 1, 0, 0]) + msg = ("%s should be an integer, instead it is of type" + " %s" % (name, type(value))) + with pytest.raises(ValueError) as raised_error: + scml.fit(X, y) + assert msg == raised_error.value.args[0] + + def test_large_output_iter(self): + scml = SCML(max_iter=1, output_iter=2) + triplets = np.array([[[0, 1], [2, 1], [0, 0]]]) + msg = ("The value of output_iter must be equal or smaller than" + " max_iter.") + + with pytest.raises(ValueError) as raised_error: + scml.fit(triplets) + assert msg == raised_error.value.args[0] + + class TestLSML(MetricTestCase): def test_iris(self): lsml = LSML_Supervised(num_constraints=200) diff --git a/test/test_base_metric.py b/test/test_base_metric.py index b2b1d339..b1be4e84 100644 --- a/test/test_base_metric.py +++ b/test/test_base_metric.py @@ -5,7 +5,7 @@ import numpy as np from sklearn import clone from sklearn.utils.testing import set_random_state -from test.test_utils import ids_metric_learners, metric_learners +from test.test_utils import ids_metric_learners, metric_learners, remove_y def 
remove_spaces(s): @@ -135,12 +135,12 @@ def test_get_metric_is_independent_from_metric_learner(estimator, # we fit the metric learner on it and then we compute the metric on some # points - model.fit(input_data, labels) + model.fit(*remove_y(model, input_data, labels)) metric = model.get_metric() score = metric(X[0], X[1]) # then we refit the estimator on another dataset - model.fit(np.sin(input_data), labels) + model.fit(*remove_y(model, np.sin(input_data), labels)) # we recompute the distance between the two points: it should be the same score_bis = metric(X[0], X[1]) @@ -155,7 +155,7 @@ def test_get_metric_raises_error(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(input_data, labels) + model.fit(*remove_y(model, input_data, labels)) metric = model.get_metric() list_test_get_metric_raises = [(X[0].tolist() + [5.2], X[1]), # vectors with @@ -178,7 +178,7 @@ def test_get_metric_works_does_not_raise(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(input_data, labels) + model.fit(*remove_y(model, input_data, labels)) metric = model.get_metric() list_test_get_metric_doesnt_raise = [(X[0], X[1]), @@ -210,20 +210,20 @@ def test_n_components(estimator, build_dataset): if hasattr(model, 'n_components'): set_random_state(model) model.set_params(n_components=None) - model.fit(input_data, labels) + model.fit(*remove_y(model, input_data, labels)) assert model.components_.shape == (X.shape[1], X.shape[1]) model = clone(estimator) set_random_state(model) model.set_params(n_components=X.shape[1] - 1) - model.fit(input_data, labels) + model.fit(*remove_y(model, input_data, labels)) assert model.components_.shape == (X.shape[1] - 1, X.shape[1]) model = clone(estimator) set_random_state(model) model.set_params(n_components=X.shape[1] + 1) with pytest.raises(ValueError) as expected_err: - model.fit(input_data, labels) 
+ model.fit(*remove_y(model, input_data, labels)) assert (str(expected_err.value) == 'Invalid n_components, must be in [1, {}]'.format(X.shape[1])) @@ -231,7 +231,7 @@ def test_n_components(estimator, build_dataset): set_random_state(model) model.set_params(n_components=0) with pytest.raises(ValueError) as expected_err: - model.fit(input_data, labels) + model.fit(*remove_y(model, input_data, labels)) assert (str(expected_err.value) == 'Invalid n_components, must be in [1, {}]'.format(X.shape[1])) diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py index 91fb435f..2e3c3ef4 100644 --- a/test/test_mahalanobis_mixin.py +++ b/test/test_mahalanobis_mixin.py @@ -15,11 +15,12 @@ from metric_learn._util import make_context, _initialize_metric_mahalanobis from metric_learn.base_metric import (_QuadrupletsClassifierMixin, + _TripletsClassifierMixin, _PairsClassifierMixin) from metric_learn.exceptions import NonPSDError from test.test_utils import (ids_metric_learners, metric_learners, - remove_y_quadruplets, ids_classifiers) + remove_y, ids_classifiers) RNG = check_random_state(0) @@ -33,7 +34,7 @@ def test_score_pairs_pairwise(estimator, build_dataset): X = X[:n_samples] model = clone(estimator) set_random_state(model) - model.fit(*remove_y_quadruplets(estimator, input_data, labels)) + model.fit(*remove_y(estimator, input_data, labels)) pairwise = model.score_pairs(np.array(list(product(X, X))))\ .reshape(n_samples, n_samples) @@ -57,7 +58,7 @@ def test_score_pairs_toy_example(estimator, build_dataset): X = X[:n_samples] model = clone(estimator) set_random_state(model) - model.fit(*remove_y_quadruplets(estimator, input_data, labels)) + model.fit(*remove_y(estimator, input_data, labels)) pairs = np.stack([X[:10], X[10:20]], axis=1) embedded_pairs = pairs.dot(model.components_.T) distances = np.sqrt(np.sum((embedded_pairs[:, 1] - @@ -73,7 +74,7 @@ def test_score_pairs_finite(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = 
clone(estimator) set_random_state(model) - model.fit(*remove_y_quadruplets(estimator, input_data, labels)) + model.fit(*remove_y(estimator, input_data, labels)) pairs = np.array(list(product(X, X))) assert np.isfinite(model.score_pairs(pairs)).all() @@ -87,7 +88,7 @@ def test_score_pairs_dim(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(*remove_y_quadruplets(estimator, input_data, labels)) + model.fit(*remove_y(estimator, input_data, labels)) tuples = np.array(list(product(X, X))) assert model.score_pairs(tuples).shape == (tuples.shape[0],) context = make_context(estimator) @@ -118,7 +119,7 @@ def test_embed_toy_example(estimator, build_dataset): X = X[:n_samples] model = clone(estimator) set_random_state(model) - model.fit(*remove_y_quadruplets(estimator, input_data, labels)) + model.fit(*remove_y(estimator, input_data, labels)) embedded_points = X.dot(model.components_.T) assert_array_almost_equal(model.transform(X), embedded_points) @@ -130,7 +131,7 @@ def test_embed_dim(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(*remove_y_quadruplets(estimator, input_data, labels)) + model.fit(*remove_y(estimator, input_data, labels)) assert model.transform(X).shape == X.shape # assert that ValueError is thrown if input shape is 1D @@ -144,7 +145,7 @@ def test_embed_dim(estimator, build_dataset): # we test that the shape is also OK when doing dimensionality reduction if hasattr(model, 'n_components'): model.set_params(n_components=2) - model.fit(*remove_y_quadruplets(estimator, input_data, labels)) + model.fit(*remove_y(estimator, input_data, labels)) assert model.transform(X).shape == (X.shape[0], 2) # assert that ValueError is thrown if input shape is 1D with pytest.raises(ValueError) as raised_error: @@ -159,7 +160,7 @@ def test_embed_finite(estimator, build_dataset): input_data, labels, _, X = 
build_dataset() model = clone(estimator) set_random_state(model) - model.fit(*remove_y_quadruplets(estimator, input_data, labels)) + model.fit(*remove_y(estimator, input_data, labels)) assert np.isfinite(model.transform(X)).all() @@ -170,7 +171,7 @@ def test_embed_is_linear(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(*remove_y_quadruplets(estimator, input_data, labels)) + model.fit(*remove_y(estimator, input_data, labels)) assert_array_almost_equal(model.transform(X[:10] + X[10:20]), model.transform(X[:10]) + model.transform(X[10:20])) @@ -189,7 +190,7 @@ def test_get_metric_equivalent_to_explicit_mahalanobis(estimator, input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(*remove_y_quadruplets(estimator, input_data, labels)) + model.fit(*remove_y(estimator, input_data, labels)) metric = model.get_metric() n_features = X.shape[1] a, b = (rng.randn(n_features), rng.randn(n_features)) @@ -208,7 +209,7 @@ def test_get_metric_is_pseudo_metric(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(*remove_y_quadruplets(estimator, input_data, labels)) + model.fit(*remove_y(estimator, input_data, labels)) metric = model.get_metric() n_features = X.shape[1] @@ -234,7 +235,7 @@ def test_metric_raises_deprecation_warning(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(*remove_y_quadruplets(estimator, input_data, labels)) + model.fit(*remove_y(estimator, input_data, labels)) with pytest.warns(DeprecationWarning) as raised_warning: model.metric() @@ -251,7 +252,7 @@ def test_get_metric_compatible_with_scikit_learn(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(*remove_y_quadruplets(estimator, input_data, 
labels)) + model.fit(*remove_y(estimator, input_data, labels)) clustering = DBSCAN(metric=model.get_metric()) clustering.fit(X) @@ -264,7 +265,7 @@ def test_get_squared_metric(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(*remove_y_quadruplets(estimator, input_data, labels)) + model.fit(*remove_y(estimator, input_data, labels)) metric = model.get_metric() n_features = X.shape[1] @@ -284,26 +285,31 @@ def test_components_is_2D(estimator, build_dataset): model = clone(estimator) set_random_state(model) # test that it works for X.shape[1] features - model.fit(*remove_y_quadruplets(estimator, input_data, labels)) + model.fit(*remove_y(estimator, input_data, labels)) assert model.components_.shape == (X.shape[1], X.shape[1]) # test that it works for 1 feature trunc_data = input_data[..., :1] # we drop duplicates that might have been formed, i.e. of the form # aabc or abcc or aabb for quadruplets, and aa for pairs. 
+ if isinstance(estimator, _QuadrupletsClassifierMixin): - for slice_idx in [slice(0, 2), slice(2, 4)]: - pairs = trunc_data[:, slice_idx, :] - diffs = pairs[:, 1, :] - pairs[:, 0, :] - to_keep = np.where(np.abs(diffs.ravel()) > 1e-9) - trunc_data = trunc_data[to_keep] - labels = labels[to_keep] + pairs_idx = [[0, 1], [2, 3]] + elif isinstance(estimator, _TripletsClassifierMixin): + pairs_idx = [[0, 1], [0, 2]] elif isinstance(estimator, _PairsClassifierMixin): - diffs = trunc_data[:, 1, :] - trunc_data[:, 0, :] - to_keep = np.where(np.abs(diffs.ravel()) > 1e-9) + pairs_idx = [[0, 1]] + else: + pairs_idx = [] + + for pair_idx in pairs_idx: + pairs = trunc_data[:, pair_idx, :] + diffs = pairs[:, 1, :] - pairs[:, 0, :] + to_keep = np.abs(diffs.ravel()) > 1e-9 trunc_data = trunc_data[to_keep] labels = labels[to_keep] - model.fit(*remove_y_quadruplets(estimator, trunc_data, labels)) + + model.fit(*remove_y(estimator, trunc_data, labels)) assert model.components_.shape == (1, 1) # the components must be 2D @@ -735,9 +741,9 @@ def test_deterministic_initialization(estimator, build_dataset): model.set_params(prior='random') model1 = clone(model) set_random_state(model1, 42) - model1 = model1.fit(input_data, labels) + model1 = model1.fit(*remove_y(model, input_data, labels)) model2 = clone(model) set_random_state(model2, 42) - model2 = model2.fit(input_data, labels) + model2 = model2.fit(*remove_y(model, input_data, labels)) np.testing.assert_allclose(model1.get_mahalanobis_matrix(), model2.get_mahalanobis_matrix()) diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py index b2056c09..7f7d7037 100644 --- a/test/test_sklearn_compat.py +++ b/test/test_sklearn_compat.py @@ -10,7 +10,8 @@ from metric_learn import (Covariance, LFDA, LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised, - MMC_Supervised, RCA_Supervised, SDML_Supervised) + MMC_Supervised, RCA_Supervised, SDML_Supervised, + SCML_Supervised) from sklearn import clone import numpy as np from 
sklearn.model_selection import (cross_val_score, cross_val_predict, @@ -20,8 +21,9 @@ from test.test_utils import (metric_learners, ids_metric_learners, mock_preprocessor, tuples_learners, ids_tuples_learners, pairs_learners, - ids_pairs_learners, remove_y_quadruplets, - quadruplets_learners) + ids_pairs_learners, remove_y, + metric_learners_pipeline, + ids_metric_learners_pipeline) class Stable_RCA_Supervised(RCA_Supervised): @@ -79,6 +81,9 @@ def test_sdml(self): def test_rca(self): check_estimator(Stable_RCA_Supervised) + def test_scml(self): + check_estimator(SCML_Supervised) + RNG = check_random_state(0) @@ -125,8 +130,7 @@ def test_array_like_inputs(estimator, build_dataset, with_preprocessor): input_variants, label_variants = generate_array_like(input_data, labels) for input_variant in input_variants: for label_variant in label_variants: - estimator.fit(*remove_y_quadruplets(estimator, input_variant, - label_variant)) + estimator.fit(*remove_y(estimator, input_variant, label_variant)) if hasattr(estimator, "predict"): estimator.predict(input_variant) if hasattr(estimator, "predict_proba"): @@ -137,8 +141,7 @@ def test_array_like_inputs(estimator, build_dataset, with_preprocessor): estimator.decision_function(input_variant) if hasattr(estimator, "score"): for label_variant in label_variants: - estimator.score(*remove_y_quadruplets(estimator, input_variant, - label_variant)) + estimator.score(*remove_y(estimator, input_variant, label_variant)) X_variants, _ = generate_array_like(X) for X_variant in X_variants: @@ -199,13 +202,10 @@ def test_cross_validation_is_finite(estimator, build_dataset): estimator.set_params(preprocessor=preprocessor) set_random_state(estimator) assert np.isfinite(cross_val_score(estimator, - *remove_y_quadruplets(estimator, - input_data, - labels))).all() + *remove_y(estimator, input_data, labels) + )).all() assert np.isfinite(cross_val_predict(estimator, - *remove_y_quadruplets(estimator, - input_data, - labels) + *remove_y(estimator, 
input_data, labels) )).all() @@ -237,28 +237,26 @@ def test_cross_validation_manual_vs_scikit(estimator, build_dataset, train_mask = np.ones(input_data.shape[0], bool) train_mask[test_slice] = False y_train, y_test = labels[train_mask], labels[test_slice] - estimator.fit(*remove_y_quadruplets(estimator, - input_data[train_mask], - y_train)) + estimator.fit(*remove_y(estimator, input_data[train_mask], y_train)) if hasattr(estimator, "score"): - scores.append(estimator.score(*remove_y_quadruplets( + scores.append(estimator.score(*remove_y( estimator, input_data[test_slice], y_test))) if hasattr(estimator, "predict"): predictions[test_slice] = estimator.predict(input_data[test_slice]) if hasattr(estimator, "score"): assert all(scores == cross_val_score( - estimator, *remove_y_quadruplets(estimator, input_data, labels), + estimator, *remove_y(estimator, input_data, labels), cv=kfold)) if hasattr(estimator, "predict"): assert all(predictions == cross_val_predict( estimator, - *remove_y_quadruplets(estimator, input_data, labels), + *remove_y(estimator, input_data, labels), cv=kfold)) def check_score(estimator, tuples, y): if hasattr(estimator, "score"): - score = estimator.score(*remove_y_quadruplets(estimator, tuples, y)) + score = estimator.score(*remove_y(estimator, tuples, y)) assert np.isfinite(score) @@ -282,7 +280,7 @@ def test_simple_estimator(estimator, build_dataset, with_preprocessor): estimator.set_params(preprocessor=preprocessor) set_random_state(estimator) - estimator.fit(*remove_y_quadruplets(estimator, tuples_train, y_train)) + estimator.fit(*remove_y(estimator, tuples_train, y_train)) check_score(estimator, tuples_test, y_test) check_predict(estimator, tuples_test) @@ -329,62 +327,53 @@ def test_estimators_fit_returns_self(estimator, build_dataset, input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) estimator = clone(estimator) estimator.set_params(preprocessor=preprocessor) - assert estimator.fit(*remove_y_quadruplets(estimator, - 
input_data, - labels)) is estimator + assert estimator.fit(*remove_y(estimator, input_data, labels)) is estimator @pytest.mark.parametrize('with_preprocessor', [True, False]) -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_pipeline, + ids=ids_metric_learners_pipeline) def test_pipeline_consistency(estimator, build_dataset, with_preprocessor): # Adapted from scikit learn # check that make_pipeline(est) gives same score as est - # we do this test on all except quadruplets (since they don't have a y - # in fit): - if estimator.__class__.__name__ not in [e.__class__.__name__ - for (e, _) in - quadruplets_learners]: - input_data, y, preprocessor, _ = build_dataset(with_preprocessor) - - def make_random_state(estimator, in_pipeline): - rs = {} - name_estimator = estimator.__class__.__name__ - if name_estimator[-11:] == '_Supervised': - name_param = 'random_state' - if in_pipeline: - name_param = name_estimator.lower() + '__' + name_param - rs[name_param] = check_random_state(0) - return rs - estimator = clone(estimator) - estimator.set_params(preprocessor=preprocessor) - pipeline = make_pipeline(estimator) - estimator.fit(*remove_y_quadruplets(estimator, input_data, y), - **make_random_state(estimator, False)) - pipeline.fit(*remove_y_quadruplets(estimator, input_data, y), - **make_random_state(estimator, True)) - - if hasattr(estimator, 'score'): - result = estimator.score(*remove_y_quadruplets(estimator, - input_data, - y)) - result_pipe = pipeline.score(*remove_y_quadruplets(estimator, - input_data, - y)) - assert_allclose_dense_sparse(result, result_pipe) + input_data, y, preprocessor, _ = build_dataset(with_preprocessor) - if hasattr(estimator, 'predict'): - result = estimator.predict(input_data) - result_pipe = pipeline.predict(input_data) - assert_allclose_dense_sparse(result, result_pipe) + def make_random_state(estimator, in_pipeline): + rs = 
{} + name_estimator = estimator.__class__.__name__ + if name_estimator[-11:] == '_Supervised': + name_param = 'random_state' + if in_pipeline: + name_param = name_estimator.lower() + '__' + name_param + rs[name_param] = check_random_state(0) + return rs - if issubclass(estimator.__class__, TransformerMixin): - if hasattr(estimator, 'transform'): - result = estimator.transform(input_data) - result_pipe = pipeline.transform(input_data) - assert_allclose_dense_sparse(result, result_pipe) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor, + **make_random_state(estimator, False)) + pipeline = make_pipeline(estimator) + estimator.fit(input_data, y) + estimator.set_params(preprocessor=preprocessor) + pipeline.set_params(**make_random_state(estimator, True)) + pipeline.fit(input_data, y) + + if hasattr(estimator, 'score'): + result = estimator.score(input_data, y) + result_pipe = pipeline.score(input_data, y) + assert_allclose_dense_sparse(result, result_pipe) + + if hasattr(estimator, 'predict'): + result = estimator.predict(input_data) + result_pipe = pipeline.predict(input_data) + assert_allclose_dense_sparse(result, result_pipe) + + if issubclass(estimator.__class__, TransformerMixin): + if hasattr(estimator, 'transform'): + result = estimator.transform(input_data) + result_pipe = pipeline.transform(input_data) + assert_allclose_dense_sparse(result, result_pipe) @pytest.mark.parametrize('with_preprocessor', [True, False]) @@ -398,7 +387,7 @@ def test_dict_unchanged(estimator, build_dataset, with_preprocessor): estimator.set_params(preprocessor=preprocessor) if hasattr(estimator, "n_components"): estimator.n_components = 1 - estimator.fit(*remove_y_quadruplets(estimator, input_data, labels)) + estimator.fit(*remove_y(estimator, input_data, labels)) def check_dict(): assert estimator.__dict__ == dict_before, ( @@ -429,7 +418,7 @@ def test_dont_overwrite_parameters(estimator, build_dataset, estimator.n_components = 1 dict_before_fit = 
estimator.__dict__.copy() - estimator.fit(*remove_y_quadruplets(estimator, input_data, labels)) + estimator.fit(*remove_y(estimator, input_data, labels)) dict_after_fit = estimator.__dict__ public_keys_after_fit = [key for key in dict_after_fit.keys() diff --git a/test/test_triplets_classifiers.py b/test/test_triplets_classifiers.py index 8cedd8cc..10393919 100644 --- a/test/test_triplets_classifiers.py +++ b/test/test_triplets_classifiers.py @@ -14,7 +14,7 @@ def test_predict_only_one_or_minus_one(estimator, build_dataset, with_preprocessor): """Test that all predicted values are either +1 or -1""" - input_data, preprocessor = build_dataset(with_preprocessor) + input_data, _, preprocessor, _ = build_dataset(with_preprocessor) estimator = clone(estimator) estimator.set_params(preprocessor=preprocessor) set_random_state(estimator) @@ -33,7 +33,7 @@ def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset, with_preprocessor): """Test that a NotFittedError is raised if someone tries to predict and the metric learner has not been fitted.""" - input_data, preprocessor = build_dataset(with_preprocessor) + input_data, _, preprocessor, _ = build_dataset(with_preprocessor) estimator = clone(estimator) estimator.set_params(preprocessor=preprocessor) set_random_state(estimator) @@ -46,8 +46,7 @@ def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset, def test_accuracy_toy_example(estimator, build_dataset): """Test that the default scoring for triplets (accuracy) works on some toy example""" - triplets, X = build_dataset(with_preprocessor=True) - triplets = X[triplets] + triplets, _, _, X = build_dataset(with_preprocessor=False) estimator = clone(estimator) set_random_state(estimator) estimator.fit(triplets) diff --git a/test/test_utils.py b/test/test_utils.py index a4cf86f4..fdcb864a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -16,14 +16,13 @@ from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA, LMNN, MLKR, NCA, 
ITML_Supervised, LSML_Supervised, MMC_Supervised, RCA_Supervised, SDML_Supervised, - Constraints) + SCML, SCML_Supervised, Constraints) from metric_learn.base_metric import (ArrayIndexer, MahalanobisMixin, _PairsClassifierMixin, _TripletsClassifierMixin, _QuadrupletsClassifierMixin) from metric_learn.exceptions import PreprocessorError, NonPSDError from sklearn.datasets import make_regression, make_blobs, load_iris -from metric_learn.lsml import _BaseLSML SEED = 42 @@ -92,25 +91,10 @@ def build_triplets(with_preprocessor=False): triplets = constraints.generate_knntriplets(X, k_genuine=3, k_impostor=4) if with_preprocessor: # if preprocessor, we build a 2D array of triplets of indices - return triplets, X + return Dataset(triplets, np.ones(len(triplets)), X, np.arange(len(X))) else: # if not, we build a 3D array of triplets of samples - return X[triplets], None - - -class mock_triplet_LSML(_BaseLSML, _TripletsClassifierMixin): - # Mock Triplet learner from LSML which is a quadruplets learner - # in order to test TripletClassifierMixin basic methods - - _tuple_size = 4 - - def fit(self, triplets, weights=None): - quadruplets = triplets[:, [0, 1, 0, 2]] - return self._fit(quadruplets, weights=weights) - - def decision_function(self, triplets): - self._tuple_size = 3 - return _TripletsClassifierMixin.decision_function(self, triplets) + return Dataset(X[triplets], np.ones(len(triplets)), None, X) def build_quadruplets(with_preprocessor=False): @@ -133,7 +117,7 @@ def build_quadruplets(with_preprocessor=False): [learner for (learner, _) in quadruplets_learners])) -triplets_learners = [(mock_triplet_LSML(), build_triplets)] +triplets_learners = [(SCML(), build_triplets)] ids_triplets_learners = list(map(lambda x: x.__class__.__name__, [learner for (learner, _) in triplets_learners])) @@ -155,7 +139,8 @@ def build_quadruplets(with_preprocessor=False): (MMC_Supervised(max_iter=5), build_classification), (RCA_Supervised(num_chunks=5), build_classification), 
(SDML_Supervised(prior='identity', balance_param=1e-5), - build_classification)] + build_classification), + (SCML_Supervised(), build_classification)] ids_classifiers = list(map(lambda x: x.__class__.__name__, [learner for (learner, _) in classifiers])) @@ -165,10 +150,12 @@ def build_quadruplets(with_preprocessor=False): [learner for (learner, _) in regressors])) WeaklySupervisedClasses = (_PairsClassifierMixin, + _TripletsClassifierMixin, _QuadrupletsClassifierMixin) -tuples_learners = pairs_learners + quadruplets_learners -ids_tuples_learners = ids_pairs_learners + ids_quadruplets_learners +tuples_learners = pairs_learners + triplets_learners + quadruplets_learners +ids_tuples_learners = ids_pairs_learners + ids_triplets_learners \ + + ids_quadruplets_learners supervised_learners = classifiers + regressors ids_supervised_learners = ids_classifiers + ids_regressors @@ -176,14 +163,17 @@ def build_quadruplets(with_preprocessor=False): metric_learners = tuples_learners + supervised_learners ids_metric_learners = ids_tuples_learners + ids_supervised_learners +metric_learners_pipeline = pairs_learners + supervised_learners +ids_metric_learners_pipeline = ids_pairs_learners + ids_supervised_learners + -def remove_y_quadruplets(estimator, X, y): - """Quadruplets learners have no y in fit, but to write test for all - estimators, it is convenient to have this function, that will return X and y - if the estimator needs a y to fit on, and just X otherwise.""" +def remove_y(estimator, X, y): + """Quadruplets and triplets learners have no y in fit, but to write test for + all estimators, it is convenient to have this function, that will return X + and y if the estimator needs a y to fit on, and just X otherwise.""" + no_y_fit = quadruplets_learners + triplets_learners if estimator.__class__.__name__ in [e.__class__.__name__ - for (e, _) in - quadruplets_learners]: + for (e, _) in no_y_fit]: return (X,) else: return (X, y) @@ -831,13 +821,12 @@ def 
test_error_message_tuple_size(estimator, _): per tuple, it throws an error message""" estimator = clone(estimator) set_random_state(estimator) - invalid_pairs = np.array([[[1.3, 6.3], [3., 6.8], [6.5, 4.4]], - [[1.9, 5.3], [1., 7.8], [3.2, 1.2]]]) + invalid_pairs = np.ones((2, 5, 2)) y = [1, 1] with pytest.raises(ValueError) as raised_err: - estimator.fit(*remove_y_quadruplets(estimator, invalid_pairs, y)) - expected_msg = ("Tuples of {} element(s) expected{}. Got tuples of 3 " - "element(s) instead (shape=(2, 3, 2)):\ninput={}.\n" + estimator.fit(*remove_y(estimator, invalid_pairs, y)) + expected_msg = ("Tuples of {} element(s) expected{}. Got tuples of 5 " + "element(s) instead (shape=(2, 5, 2)):\ninput={}.\n" .format(estimator._tuple_size, make_context(estimator), invalid_pairs)) assert str(raised_err.value) == expected_msg @@ -911,35 +900,21 @@ def test_same_with_or_without_preprocessor(estimator, build_dataset): dataset_formed.data, random_state=SEED) - def make_random_state(estimator): - rs = {} - if estimator.__class__.__name__[-11:] == '_Supervised': - rs['random_state'] = check_random_state(SEED) - return rs - estimator_with_preprocessor = clone(estimator) set_random_state(estimator_with_preprocessor) estimator_with_preprocessor.set_params(preprocessor=X) - estimator_with_preprocessor.fit(*remove_y_quadruplets(estimator, - indices_train, - y_train), - **make_random_state(estimator)) + estimator_with_preprocessor.fit(*remove_y(estimator, indices_train, y_train)) estimator_without_preprocessor = clone(estimator) set_random_state(estimator_without_preprocessor) estimator_without_preprocessor.set_params(preprocessor=None) - estimator_without_preprocessor.fit(*remove_y_quadruplets(estimator, - formed_train, - y_train), - **make_random_state(estimator)) + estimator_without_preprocessor.fit(*remove_y(estimator, formed_train, + y_train)) estimator_with_prep_formed = clone(estimator) set_random_state(estimator_with_prep_formed) 
estimator_with_prep_formed.set_params(preprocessor=X) - estimator_with_prep_formed.fit(*remove_y_quadruplets(estimator, - indices_train, - y_train), - **make_random_state(estimator)) + estimator_with_prep_formed.fit(*remove_y(estimator, indices_train, y_train)) # test prediction methods for method in ["predict", "decision_function"]: