Skip to content

SCML : Sparse Compositional Metric Learning #278

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 71 commits into from
Jun 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
cab9ce7
scml first commit
grudloff Feb 13, 2020
41e2cef
add scml to __init__.py
grudloff Feb 13, 2020
8ee9a87
fix in components calculation
grudloff Feb 14, 2020
f201f9f
remove triplet generator, added in triplets PR
grudloff Feb 18, 2020
87c3da0
change init&fit interface, faster compute & others
grudloff Feb 19, 2020
21a6fc0
added coments & docstrings, small code changes
grudloff Feb 19, 2020
5453c75
typos and added choice of gamma & output_iter
grudloff Feb 19, 2020
5f8d885
some small improvements
grudloff Feb 20, 2020
1083f57
lda tail handling rollback
grudloff Feb 20, 2020
78b9658
performance improvement by precomputing rand_ints
grudloff Feb 20, 2020
bc203f5
small fix in components computation
grudloff Mar 5, 2020
224e861
Merge branch 'master' of https://github.com/scikit-learn-contrib/metr…
grudloff Mar 5, 2020
ecdb74d
flake8 fix
grudloff Mar 5, 2020
f82f3b3
SCML_global fit fix & other small changes
grudloff Mar 5, 2020
2018d09
Proper use of init vars and unsup bases generation
grudloff Mar 11, 2020
e9e654c
triplet dataset format & remove_y for triplets
grudloff Mar 11, 2020
686b7eb
adaptation with dataset format
grudloff Mar 11, 2020
4ff5f4c
remove labels for triplets and quadruplets
grudloff Mar 11, 2020
dc50dc7
remove labels
grudloff Mar 11, 2020
10d1d04
remove labels & old fit random_state asignation
grudloff Mar 11, 2020
2814662
compliant with older numpy versions
grudloff Mar 11, 2020
e9f4362
small typo and fix order
grudloff Mar 11, 2020
a9d1a02
fix n_basis check
grudloff Mar 11, 2020
b1c01fd
initialize_basis_supervised and some refactoring
grudloff Mar 12, 2020
f4217c8
proper n_basis handling
grudloff Mar 12, 2020
8dd0fbe
scml specific tests
grudloff Mar 12, 2020
8c6567e
remove small mistake
grudloff Mar 12, 2020
b8bc94e
test user input basis
grudloff Mar 12, 2020
cfad0b9
Changed names and messages and some refactoring
grudloff Mar 17, 2020
04c8433
triplets in features form passed to _fit
grudloff Mar 17, 2020
e67ff82
change indeces handlig and edge case fix
grudloff Mar 18, 2020
10efc46
name change and typos
grudloff Mar 18, 2020
ed6d42b
improve test_components_is_2D
grudloff Mar 18, 2020
932ff3f
Replace triplet_diffs option by better aproach
grudloff Mar 20, 2020
534cd3f
some comments, docstring and refactoring
grudloff Mar 20, 2020
576fbcb
fix bad triplet set
grudloff Mar 20, 2020
2bee8cc
flake8 fix
grudloff Mar 20, 2020
895f28b
SCML doc first draft
grudloff Mar 20, 2020
8c4ef22
find neighbors for every class only once
grudloff Mar 23, 2020
26da826
improve some docstring and warnings
grudloff Mar 23, 2020
54525d7
add sklearn compat test
grudloff Mar 23, 2020
4140585
changes to doc
grudloff Mar 23, 2020
e54a741
fix and improve tests
grudloff Mar 23, 2020
b84f8b1
use components_from_metric
grudloff Mar 24, 2020
4af49da
change TestSCML to object and parametrize tests
grudloff Mar 24, 2020
13f7088
fix test_iris
grudloff Mar 24, 2020
78e5084
use model._authorized_basis and other fixes
grudloff Mar 24, 2020
daaf5b0
verbose test
grudloff Mar 24, 2020
9338a7d
revert sum_where
grudloff Mar 24, 2020
33dae25
small n_basis warning instead of error
grudloff Mar 25, 2020
dbcf138
add test iris on triplet_diffs
grudloff Mar 25, 2020
34917d3
test lda & triplet_diffs
grudloff Mar 25, 2020
f6f848d
improved messages
grudloff Mar 25, 2020
38fd80b
remove quadruplets and triplets from pipeline test
grudloff Mar 25, 2020
ad47f7f
test big n_features
grudloff Mar 25, 2020
9541b75
Correct output iters
grudloff Mar 26, 2020
ca6c69d
output_iter on supervised and improved verbose
grudloff Mar 26, 2020
9902dfe
flake8 fix
grudloff Mar 26, 2020
fdf3067
bases generation test comments
grudloff Mar 26, 2020
517cce4
change big_n_basis_lda error msg
grudloff Mar 26, 2020
06d92f2
test generated n_basis and basis shape
grudloff Mar 26, 2020
5fd80e9
add mini batch optimization
grudloff Mar 26, 2020
18289e0
correct iter convention
grudloff Mar 27, 2020
bdf981e
eliminate n_samples = 1000
grudloff Mar 27, 2020
c551344
batch grad refactored
grudloff Mar 27, 2020
c02e6e5
adagrad adaptive learning
grudloff Mar 30, 2020
9fd186c
int input checks and tests
grudloff Mar 30, 2020
8dbaad9
flake8 fix
grudloff Mar 31, 2020
95e5fe8
no double division and smaller triplets arrays
grudloff Apr 1, 2020
2aed606
minor grammar fixes
perimosocordiae Jun 17, 2020
cba1cf6
minor formatting tweaks
perimosocordiae Jun 17, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/metric_learn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Supervised Learning Algorithms
metric_learn.MMC_Supervised
metric_learn.SDML_Supervised
metric_learn.RCA_Supervised
metric_learn.SCML_Supervised

Weakly Supervised Learning Algorithms
-------------------------------------
Expand All @@ -45,6 +46,7 @@ Weakly Supervised Learning Algorithms
metric_learn.LSML
metric_learn.MMC
metric_learn.SDML
metric_learn.SCML

Unsupervised Learning Algorithms
--------------------------------
Expand Down
73 changes: 65 additions & 8 deletions doc/weakly_supervised.rst
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,63 @@ of triplets that have the right predicted ordering.
Algorithms
----------

.. _scml:

:py:class:`SCML <metric_learn.SCML>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Sparse Compositional Metric Learning
(:py:class:`SCML <metric_learn.SCML>`)

`SCML` learns a squared Mahalanobis distance from triplet constraints by
optimizing sparse positive weights assigned to a set of :math:`K` rank-one
PSD bases. This can be formulated as an optimization problem with only
:math:`K` parameters, that can be solved with an efficient stochastic
composite scheme.

The Mahalanobis matrix :math:`M` is built from a basis set :math:`B = \{b_i\}_{i=\{1,...,K\}}`
weighted by a :math:`K` dimensional vector :math:`w = \{w_i\}_{i=\{1,...,K\}}` as:

.. math::

M = \sum_{i=1}^K w_i b_i b_i^T = B \cdot diag(w) \cdot B^T \quad w_i \geq 0

Learning :math:`M` in this form makes it PSD by design, as it is a
nonnegative sum of PSD matrices. The basis set :math:`B` is fixed in advance
and it is possible to construct it from the data. The optimization problem
over :math:`w` is formulated as a classic margin-based hinge loss function
involving the set :math:`C` of triplets. A regularization :math:`\ell_1`
is added to yield a sparse combination. The formulation is the following:

.. math::

\min_{w\geq 0} \sum_{(x_i,x_j,x_k)\in C} [1 + d_w(x_i,x_j)-d_w(x_i,x_k)]_+ + \beta||w||_1

where :math:`[\cdot]_+` is the hinge loss.

.. topic:: Example Code:

::

from metric_learn import SCML

triplets = [[[1.2, 7.5], [1.3, 1.5], [6.2, 9.7]],
[[1.3, 4.5], [3.2, 4.6], [5.4, 5.4]],
[[3.2, 7.5], [3.3, 1.5], [8.2, 9.7]],
[[3.3, 4.5], [5.2, 4.6], [7.4, 5.4]]]

scml = SCML()
scml.fit(triplets)

.. topic:: References:

.. [1] Y. Shi, A. Bellet and F. Sha. `Sparse Compositional Metric Learning.
<http://researchers.lille.inria.fr/abellet/papers/aaai14.pdf>`_. \
(AAAI), 2014.

.. [2] Adapted from original \
`Matlab implementation.<https://github.com/bellet/SCML>`_.


.. _learning_on_quadruplets:

Expand Down Expand Up @@ -829,13 +886,13 @@ extension leads to more stable estimation when the dimension is high and
only a small amount of constraints is given.

The loss function of each constraint
:math:`d(\mathbf{x}_a, \mathbf{x}_b) < d(\mathbf{x}_c, \mathbf{x}_d)` is
:math:`d(\mathbf{x}_i, \mathbf{x}_j) < d(\mathbf{x}_k, \mathbf{x}_l)` is
denoted as:

.. math::

H(d_\mathbf{M}(\mathbf{x}_a, \mathbf{x}_b)
- d_\mathbf{M}(\mathbf{x}_c, \mathbf{x}_d))
H(d_\mathbf{M}(\mathbf{x}_i, \mathbf{x}_j)
- d_\mathbf{M}(\mathbf{x}_k, \mathbf{x}_l))

where :math:`H(\cdot)` is the squared Hinge loss function defined as:

Expand All @@ -845,8 +902,8 @@ where :math:`H(\cdot)` is the squared Hinge loss function defined as:
\,\,x^2 \qquad x>0\end{aligned}\right.\\

The summed loss function :math:`L(C)` is the simple sum over all constraints
:math:`C = \{(\mathbf{x}_a , \mathbf{x}_b , \mathbf{x}_c , \mathbf{x}_d)
: d(\mathbf{x}_a , \mathbf{x}_b) < d(\mathbf{x}_c , \mathbf{x}_d)\}`. The
:math:`C = \{(\mathbf{x}_i , \mathbf{x}_j , \mathbf{x}_k , \mathbf{x}_l)
: d(\mathbf{x}_i , \mathbf{x}_j) < d(\mathbf{x}_k , \mathbf{x}_l)\}`. The
original paper suggested here should be a weighted sum since the confidence
or probability of each constraint might differ. However, for the sake of
simplicity and assumption of no extra knowledge provided, we just deploy
Expand All @@ -858,9 +915,9 @@ knowledge:

.. math::

\min_\mathbf{M}(D_{ld}(\mathbf{M, M_0}) + \sum_{(\mathbf{x}_a,
\mathbf{x}_b, \mathbf{x}_c, \mathbf{x}_d)\in C}H(d_\mathbf{M}(
\mathbf{x}_a, \mathbf{x}_b) - d_\mathbf{M}(\mathbf{x}_c, \mathbf{x}_c))\\
\min_\mathbf{M}(D_{ld}(\mathbf{M, M_0}) + \sum_{(\mathbf{x}_i,
\mathbf{x}_j, \mathbf{x}_k, \mathbf{x}_l)\in C}H(d_\mathbf{M}(
\mathbf{x}_i, \mathbf{x}_j) - d_\mathbf{M}(\mathbf{x}_k, \mathbf{x}_l))\\

where :math:`\mathbf{M}_0` is the prior metric matrix, set as identity
by default, :math:`D_{ld}(\mathbf{\cdot, \cdot})` is the LogDet divergence:
Expand Down
4 changes: 3 additions & 1 deletion metric_learn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from .rca import RCA, RCA_Supervised
from .mlkr import MLKR
from .mmc import MMC, MMC_Supervised
from .scml import SCML, SCML_Supervised

from ._version import __version__

__all__ = ['Constraints', 'Covariance', 'ITML', 'ITML_Supervised',
'LMNN', 'LSML', 'LSML_Supervised', 'SDML',
'SDML_Supervised', 'NCA', 'LFDA', 'RCA', 'RCA_Supervised',
'MLKR', 'MMC', 'MMC_Supervised', '__version__']
'MLKR', 'MMC', 'MMC_Supervised', 'SCML',
'SCML_Supervised', '__version__']
Loading