Skip to content

[MRG] Create new Mahalanobis mixin #96

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
7b3d739
WIP create MahalanobisMixin
May 25, 2018
f21cc85
ENH Update algorithms with Mahalanobis Mixin:
May 25, 2018
6f8a115
Merge branch 'new_api_design' into feat/mahalanobis_class
Jun 11, 2018
f9e3c82
FIX: add missing import
Jun 11, 2018
1a32c11
FIX: update sklearn's function check_no_fit_attributes_set_in_init to…
Jun 11, 2018
d0f5019
FIX: take function ``_get_args`` from scikit-learn's PR https://githu…
Jun 11, 2018
eba2a60
ENH: add transformer_ attribute and improve docstring
Jun 14, 2018
b5d966f
WIP: move transform() in BaseMetricLearner to transformer_from_metric…
Jun 18, 2018
ee0d1bd
WIP: refactor metric to original formulation: a function, with result…
Jun 18, 2018
6b5a3b5
WIP: make all Mahalanobis Metric Learner algorithms have transformer_…
Jun 19, 2018
6eb65ac
ENH Add score_pairs function
Jun 25, 2018
35ece36
TST add test on toy example for score_pairs
Jun 26, 2018
dca6838
ENH Add embed function
Jun 27, 2018
3254ce3
FIX fix error in slicing of quadruplets
Jun 27, 2018
e209b21
FIX minor corrections
Jun 27, 2018
abea7de
FIX minor corrections
Jun 27, 2018
65e794a
FIX fix PEP8 errors
Jun 27, 2018
12b5429
FIX remove possible one-sample scoring from docstring for now
Jun 27, 2018
eff278e
REF rename n_features_out to num_dims to be more coherent with curren…
Jun 27, 2018
810d191
MAINT: Address https://github.com/metric-learn/metric-learn/pull/96#pu…
Jul 24, 2018
585b5d2
ENH: Add check_tuples
Jul 24, 2018
af0a3ac
FIX: fix parenthesis
Jul 24, 2018
f2b0163
ENH: put docstring of transformer_ in each metric learner
Aug 22, 2018
3c37fd7
FIX: style nitpicks to uniformize transformer_ docstring with children
Aug 22, 2018
912c1db
FIX: make transformer_from_metric public
Aug 22, 2018
0e0ebf1
Address https://github.com/metric-learn/metric-learn/pull/96#pullrequ…
Aug 23, 2018
31350e8
FIX: fix pairwise distances check
Aug 23, 2018
d1f811b
FIX: ensure random state is set in all tests
Aug 23, 2018
4dd8990
FIX: fix test with real value to test in check_tuples
Aug 23, 2018
779a93a
FIX: update MetricTransformer to be abstract method and have the appr…
Sep 3, 2018
75d4ad2
Merge branch 'feat/mahalanobis_class' of https://github.com/wdevazelh…
Sep 3, 2018
657cdcd
MAINT: make BaseMetricLearner and MetricTransformer abstract
Sep 3, 2018
131ccbb
MAINT: remove __init__ method from BaseMetricLearner
Sep 3, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/sandwich.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def sandwich_demo():

for ax_num, ml in enumerate(mls, start=3):
ml.fit(x, y)
tx = ml.transform()
tx = ml.transform(x)
ml_knn = nearest_neighbors(tx, k=2)
ax = plt.subplot(3, 2, ax_num)
plot_sandwich_data(tx, y, axis=ax)
Expand Down
40 changes: 39 additions & 1 deletion metric_learn/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,42 @@ def vector_norm(X):
return np.apply_along_axis(np.linalg.norm, 1, X)
else:
def vector_norm(X):
return np.linalg.norm(X, axis=1)
return np.linalg.norm(X, axis=1)


def check_tuples(tuples):
    """Check that the input is a valid 3D array representing a dataset of tuples.

    Loose counterpart of `check_array` in scikit-learn: only the number of
    dimensions of the input is validated.

    Parameters
    ----------
    tuples : array-like
        The tuples to check. Inputs that are not already a `numpy.ndarray`
        (e.g. nested lists) are converted with `numpy.asarray` before their
        dimensionality is inspected.

    Returns
    -------
    tuples_valid : `numpy.ndarray`
        The validated input. When the input is already a `numpy.ndarray` the
        very same object is returned (no copy is made).

    Raises
    ------
    ValueError
        If the input is a scalar (or a 0-dimensional array), or if it does not
        have exactly 3 dimensions.
    """
    scalar_message = (
        "Expected 3D array, got scalar instead. Cannot apply this function on "
        "scalars.")
    # If input is scalar raise error
    if np.isscalar(tuples):
        raise ValueError(scalar_message)
    # Array-likes such as nested lists have no ``shape`` attribute: convert them
    # so that the checks below raise a meaningful ValueError instead of an
    # AttributeError. Existing arrays are left untouched.
    if not isinstance(tuples, np.ndarray):
        tuples = np.asarray(tuples)
    n_dims = tuples.ndim
    # 0-d arrays are not caught by np.isscalar but are scalars all the same
    if n_dims == 0:
        raise ValueError(scalar_message)
    # If input is 1D raise error
    if n_dims == 1:
        raise ValueError(
            "Expected 3D array, got 1D array instead:\ntuples={}.\n"
            "Reshape your data using tuples.reshape(1, -1, 1) if it contains a "
            "single tuple and the points in the tuple have a single "
            "feature.".format(tuples))
    # If input is 2D raise error
    if n_dims == 2:
        raise ValueError(
            "Expected 3D array, got 2D array instead:\ntuples={}.\n"
            "Reshape your data either using tuples.reshape(-1, {}, 1) if "
            "your data has a single feature or tuples.reshape(1, {}, -1) "
            "if it contains a single tuple.".format(tuples, tuples.shape[1],
                                                    tuples.shape[0]))
    # Anything above 3D cannot be a dataset of tuples either
    if n_dims > 3:
        raise ValueError(
            "Expected 3D array, got {}D array instead.".format(n_dims))
    return tuples
187 changes: 144 additions & 43 deletions metric_learn/base_metric.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,148 @@
from numpy.linalg import cholesky
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_array
from sklearn.metrics import roc_auc_score
import numpy as np
from abc import ABCMeta, abstractmethod
import six
from ._util import check_tuples


class BaseMetricLearner(BaseEstimator):
def __init__(self):
raise NotImplementedError('BaseMetricLearner should not be instantiated')
class BaseMetricLearner(six.with_metaclass(ABCMeta, BaseEstimator)):

def metric(self):
"""Computes the Mahalanobis matrix from the transformation matrix.
@abstractmethod
def score_pairs(self, pairs):
"""Returns the score between pairs
(can be a similarity, or a distance/metric depending on the algorithm)

.. math:: M = L^{\\top} L
Parameters
----------
pairs : `numpy.ndarray`, shape=(n_samples, 2, n_features)
3D array of pairs.

Returns
-------
M : (d x d) matrix
scores: `numpy.ndarray` of shape=(n_pairs,)
The score of every pair.
"""
L = self.transformer()
return L.T.dot(L)

def transformer(self):
"""Computes the transformation matrix from the Mahalanobis matrix.

L = cholesky(M).T
class MetricTransformer(six.with_metaclass(ABCMeta)):

@abstractmethod
def transform(self, X):
"""Applies the metric transformation.

Parameters
----------
X : (n x d) matrix
Data to transform.

Returns
-------
L : upper triangular (d x d) matrix
transformed : (n x d) matrix
Input data transformed to the metric space by :math:`XL^{\\top}`
"""
return cholesky(self.metric()).T


class MetricTransformer(TransformerMixin):
class MahalanobisMixin(six.with_metaclass(ABCMeta, BaseMetricLearner,
MetricTransformer)):
"""Mahalanobis metric learning algorithms.

Algorithm that learns a Mahalanobis (pseudo) distance :math:`d_M(x, x')`,
defined between two column vectors :math:`x` and :math:`x'` by: :math:`d_M(x,
x') = \sqrt{(x-x')^T M (x-x')}`, where :math:`M` is a learned symmetric
positive semi-definite (PSD) matrix. The metric between points can then be
expressed as the euclidean distance between points embedded in a new space
through a linear transformation. Indeed, the above matrix can be decomposed
into the product of a matrix and its transpose (through SVD or Cholesky
decomposition): :math:`d_M(x, x')^2 = (x-x')^T M (x-x') = (x-x')^T L^T L
(x-x') = (L x - L x')^T (L x- L x')`

Attributes
----------
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
The learned linear transformation ``L``.
"""

def score_pairs(self, pairs):
"""Returns the learned Mahalanobis distance between pairs.

This distance is defined as: :math:`d_M(x, x') = \sqrt{(x-x')^T M (x-x')}`
where ``M`` is the learned Mahalanobis matrix, for every pair of points
``x`` and ``x'``. This corresponds to the euclidean distance between
embeddings of the points in a new space, obtained through a linear
transformation. Indeed, we have also: :math:`d_M(x, x') = \sqrt{(x_e -
x_e')^T (x_e- x_e')}`, with :math:`x_e = L x` (See
:class:`MahalanobisMixin`).

def transform(self, X=None):
"""Applies the metric transformation.
Parameters
----------
pairs : `numpy.ndarray`, shape=(n_pairs, 2, n_features)
3D array of pairs.

Returns
-------
scores: `numpy.ndarray` of shape=(n_pairs,)
The learned Mahalanobis distance for every pair.
"""
pairs = check_tuples(pairs)
pairwise_diffs = self.transform(pairs[:, 1, :] - pairs[:, 0, :])
# (for MahalanobisMixin, the embedding is linear so we can just embed the
# difference)
return np.sqrt(np.sum(pairwise_diffs**2, axis=-1))

def transform(self, X):
"""Embeds data points in the learned linear embedding space.

Transforms samples in ``X`` into ``X_embedded``, samples inside a new
embedding space such that: ``X_embedded = X.dot(L.T)``, where ``L`` is
the learned linear transformation (See :class:`MahalanobisMixin`).

Parameters
----------
X : (n x d) matrix, optional
Data to transform. If not supplied, the training data will be used.
X : `numpy.ndarray`, shape=(n_samples, n_features)
The data points to embed.

Returns
-------
transformed : (n x d) matrix
Input data transformed to the metric space by :math:`XL^{\\top}`
X_embedded : `numpy.ndarray`, shape=(n_samples, num_dims)
The embedded data points.
"""
X_checked = check_array(X, accept_sparse=True)
return X_checked.dot(self.transformer_.T)

def metric(self):
return self.transformer_.T.dot(self.transformer_)

def transformer_from_metric(self, metric):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should probably be a private method?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed

"""Computes the transformation matrix from the Mahalanobis matrix.

Since by definition the metric `M` is positive semi-definite (PSD), it
admits a Cholesky decomposition: L = cholesky(M).T. However, currently the
computation of the Cholesky decomposition used does not support
singular (i.e. not positive definite) matrices. If the metric is singular,
this method will return L = w^(1/2) V.T, with M = V*w*V.T being the
eigenvector decomposition of M with the eigenvalues in the diagonal matrix w
and the columns of V being the eigenvectors. If M is diagonal, this method
will just return its elementwise square root (since the diagonalization of
the matrix is itself).

Returns
-------
L : (d x d) matrix
"""
if X is None:
X = self.X_

if np.allclose(metric, np.diag(np.diag(metric))):
return np.sqrt(metric)
elif not np.isclose(np.linalg.det(metric), 0):
return cholesky(metric).T
else:
X = check_array(X, accept_sparse=True)
L = self.transformer()
return X.dot(L.T)
w, V = np.linalg.eigh(metric)
return V.T * np.sqrt(np.maximum(0, w[:, None]))


class _PairsClassifierMixin:
class _PairsClassifierMixin(BaseMetricLearner):

def predict(self, pairs):
"""Predicts the learned metric between input pairs.
Expand All @@ -74,11 +160,11 @@ def predict(self, pairs):
y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
The predicted learned metric value between samples in every pair.
"""
pairwise_diffs = pairs[:, 0, :] - pairs[:, 1, :]
return np.sqrt(np.sum(pairwise_diffs.dot(self.metric()) * pairwise_diffs,
axis=1))
pairs = check_tuples(pairs)
return self.score_pairs(pairs)

def decision_function(self, pairs):
pairs = check_tuples(pairs)
return self.predict(pairs)

def score(self, pairs, y):
Expand All @@ -104,12 +190,32 @@ def score(self, pairs, y):
score : float
The ``roc_auc`` score.
"""
pairs = check_tuples(pairs)
return roc_auc_score(y, self.decision_function(pairs))


class _QuadrupletsClassifierMixin:
class _QuadrupletsClassifierMixin(BaseMetricLearner):

def predict(self, quadruplets):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be more logical if predict would compute accuracy on quadruplets (proportion of quadruplets correctly ordered) and score would compute the difference between distances

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that having a predict and a decision_function doing the same thing is not ideal, and in the case of quadruplets this can be fixed easily since there is no threshold to fix like in pairs
But I did not really get what you mean ?
Since score is the scikit-learn like function for scoring it should return one scalar, whereas predict should return a sample-wise output
I guess the most coherent with scikit-learn would be that predict would output a binary sign (or 0 or 1) depending on the ordering of pairs in the quadruplet, and decision_function would return the differences between distances, since it is a float-type score (like predict_proba but without being necessarily between 0 and 1)
score would return the accuracy of the predict function (proportion of quadruplets correctly ordered)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is what I meant, sorry for the confusion

"""Predicts the ordering between sample distances in input quadruplets.

For each quadruplet, returns -1 if the quadruplet is in the right order (
first pair is more similar than second pair), and 1 if not: the output is
the sign of the first pair's distance minus the second pair's distance.

Parameters
----------
quadruplets : array-like, shape=(n_constraints, 4, n_features)
Input quadruplets.

Returns
-------
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
Predictions of the ordering of pairs, for each quadruplet.
"""
quadruplets = check_tuples(quadruplets)
return np.sign(self.decision_function(quadruplets))

def decision_function(self, quadruplets):
"""Predicts differences between sample distances in input quadruplets.

For each quadruplet of samples, computes the difference between the learned
Expand All @@ -122,18 +228,12 @@ def predict(self, quadruplets):

Returns
-------
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
decision_function : `numpy.ndarray` of floats, shape=(n_constraints,)
Metric differences.
"""
similar_diffs = quadruplets[:, 0, :] - quadruplets[:, 1, :]
dissimilar_diffs = quadruplets[:, 2, :] - quadruplets[:, 3, :]
return (np.sqrt(np.sum(similar_diffs.dot(self.metric()) *
similar_diffs, axis=1)) -
np.sqrt(np.sum(dissimilar_diffs.dot(self.metric()) *
dissimilar_diffs, axis=1)))

def decision_function(self, quadruplets):
return self.predict(quadruplets)
quadruplets = check_tuples(quadruplets)
return (self.score_pairs(quadruplets[:, :2, :]) -
self.score_pairs(quadruplets[:, 2:, :]))

def score(self, quadruplets, y=None):
"""Computes score on input quadruplets
Expand All @@ -154,4 +254,5 @@ def score(self, quadruplets, y=None):
score : float
The quadruplets score.
"""
return - np.mean(np.sign(self.decision_function(quadruplets)))
quadruplets = check_tuples(quadruplets)
return -np.mean(self.predict(quadruplets))
19 changes: 14 additions & 5 deletions metric_learn/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,24 @@
from __future__ import absolute_import
import numpy as np
from sklearn.utils.validation import check_array
from sklearn.base import TransformerMixin

from .base_metric import BaseMetricLearner, MetricTransformer
from .base_metric import MahalanobisMixin


class Covariance(BaseMetricLearner, MetricTransformer):
class Covariance(MahalanobisMixin, TransformerMixin):
"""Covariance metric (baseline method)

Attributes
----------
transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See :meth:`transformer_from_metric`.)
"""

def __init__(self):
pass

def metric(self):
return self.M_

def fit(self, X, y=None):
"""
X : data matrix, (n x d)
Expand All @@ -33,4 +40,6 @@ def fit(self, X, y=None):
self.M_ = 1./self.M_
else:
self.M_ = np.linalg.inv(self.M_)

self.transformer_ = self.transformer_from_metric(check_array(self.M_))
return self
Loading