diff --git a/.travis.yml b/.travis.yml
index cda5b00f..f5527089 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -4,18 +4,19 @@ cache: pip
 python:
   - "2.7"
   - "3.4"
+  - "3.6"
 before_install:
+  - sudo apt-get install liblapack-dev
   - pip install --upgrade pip pytest
-  - pip install wheel
-  - pip install codecov
-  - if [[ $TRAVIS_PYTHON_VERSION == "3.4" ]];
-    then pip install pytest-cov;
+  - pip install wheel cython numpy scipy scikit-learn codecov pytest-cov
+  - if [[ ($TRAVIS_PYTHON_VERSION == "3.6") ||
+          ($TRAVIS_PYTHON_VERSION == "2.7")]]; then
+      pip install git+https://github.com/skggm/skggm.git@a0ed406586c4364ea3297a658f415e13b5cbdaf8;
     fi
-  - pip install numpy scipy scikit-learn
 script:
-  - if [[ $TRAVIS_PYTHON_VERSION == "3.4" ]];
-    then pytest test --cov;
-    else pytest test;
-    fi
+  # we do coverage for all versions so that codecov will merge them: this
+  # way we will see that both paths (with or without skggm) are tested
+  - pytest test --cov;
after_success:
   - bash <(curl -s https://codecov.io/bash)
+
diff --git a/README.rst b/README.rst
index c2a0a205..e1bfca51 100644
--- a/README.rst
+++ b/README.rst
@@ -21,7 +21,12 @@ Metric Learning algorithms in Python.
 
 - Python 2.7+, 3.4+
 - numpy, scipy, scikit-learn
-- (for running the examples only: matplotlib)
+
+**Optional dependencies**
+
+- For SDML, using skggm will allow the algorithm to solve problematic cases
+  (install from commit `a0ed406
+  <https://github.com/skggm/skggm/commit/a0ed406586c4364ea3297a658f415e13b5cbdaf8>`_).
+- For running the examples only: matplotlib
 
 **Installation/Setup**
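The README and CI changes above treat skggm as an optional dependency, pinned to commit a0ed406. As a minimal standalone sketch (not part of the patch), this is the import-probing pattern the PR relies on, the same one it adds to metric_learn/sdml.py further down:

    # Sketch: probe for the optional skggm dependency at import time.
    try:
        from inverse_covariance import quic  # skggm's graphical lasso solver
    except ImportError:
        HAS_SKGGM = False
    else:
        HAS_SKGGM = True

    # Callers can then branch on HAS_SKGGM to pick a solver, as
    # _BaseSDML._fit does below with quic() vs. sklearn's graphical_lasso().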
diff --git a/doc/getting_started.rst b/doc/getting_started.rst
index 040adedc..2d2df25e 100644
--- a/doc/getting_started.rst
+++ b/doc/getting_started.rst
@@ -16,7 +16,12 @@ Alternately, download the source repository and run:
 
 - Python 2.7+, 3.4+
 - numpy, scipy, scikit-learn
-- (for running the examples only: matplotlib)
+
+**Optional dependencies**
+
+- For SDML, using skggm will allow the algorithm to solve problematic cases
+  (install from commit `a0ed406
+  <https://github.com/skggm/skggm/commit/a0ed406586c4364ea3297a658f415e13b5cbdaf8>`_).
+- For running the examples only: matplotlib
 
 **Notes**

diff --git a/metric_learn/constraints.py b/metric_learn/constraints.py
index c4ddcae8..e591830b 100644
--- a/metric_learn/constraints.py
+++ b/metric_learn/constraints.py
@@ -96,6 +96,6 @@ def wrap_pairs(X, constraints):
   c = np.array(constraints[2])
   d = np.array(constraints[3])
   constraints = np.vstack((np.column_stack((a, b)), np.column_stack((c, d))))
-  y = np.vstack([np.ones((len(a), 1)), - np.ones((len(c), 1))])
+  y = np.concatenate([np.ones_like(a), -np.ones_like(c)])
   pairs = X[constraints]
   return pairs, y
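The wrap_pairs change above swaps a stacked column vector of labels for a flat array of +1/-1 labels. A minimal sketch (with hypothetical toy index arrays) contrasting the two layouts:

    import numpy as np

    a = np.array([0, 1])  # first elements of similar pairs
    c = np.array([2, 3])  # first elements of dissimilar pairs

    y_old = np.vstack([np.ones((len(a), 1)), - np.ones((len(c), 1))])
    y_new = np.concatenate([np.ones_like(a), -np.ones_like(c)])

    print(y_old.shape)  # (4, 1): a 2D column, awkward to compare with 1D labels
    print(y_new)        # [ 1  1 -1 -1]: one flat label per pair

The flat layout matches what scikit-learn-style estimators expect for y.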
" + "To prevent that, try to decrease the balance parameter " + "`balance_param` and/or to set use_covariance=False.", + ConvergenceWarning) + w -= min_eigval # we translate the eigenvalues to make them all positive + w += 1e-10 # we add a small offset to avoid definiteness problems + sigma0 = (V * w).dot(V.T) + try: + if HAS_SKGGM: + theta0 = pinvh(sigma0) + M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param, + msg=self.verbose, + Theta0=theta0, Sigma0=sigma0) + else: + _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param, + verbose=self.verbose, + cov_init=sigma0) + raised_error = None + w_mahalanobis, _ = np.linalg.eigh(M) + not_spd = any(w_mahalanobis < 0.) + not_finite = not np.isfinite(M).all() + except Exception as e: + raised_error = e + not_spd = False # not_spd not applicable here so we set to False + not_finite = False # not_finite not applicable here so we set to False + if raised_error is not None or not_spd or not_finite: + msg = ("There was a problem in SDML when using {}'s graphical " + "lasso solver.").format("skggm" if HAS_SKGGM else "scikit-learn") + if not HAS_SKGGM: + skggm_advice = (" skggm's graphical lasso can sometimes converge " + "on non SPD cases where scikit-learn's graphical " + "lasso fails to converge. Try to install skggm and " + "rerun the algorithm (see the README.md for the " + "right version of skggm).") + msg += skggm_advice + if raised_error is not None: + msg += " The following error message was thrown: {}.".format( + raised_error) + raise RuntimeError(msg) + + self.transformer_ = transformer_from_metric(np.atleast_2d(M)) return self diff --git a/setup.py b/setup.py index 168fbcb6..dfb20fc0 100755 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ extras_require=dict( docs=['sphinx', 'shinx_rtd_theme', 'numpydoc'], demo=['matplotlib'], + sdml=['skggm>=0.2.9'] ), test_suite='test', keywords=[ diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index e1eace90..ae9a8657 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -10,10 +10,15 @@ from sklearn.utils.testing import assert_warns_message from sklearn.exceptions import ConvergenceWarning from sklearn.utils.validation import check_X_y - -from metric_learn import ( - LMNN, NCA, LFDA, Covariance, MLKR, MMC, - LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised) +try: + from inverse_covariance import quic +except ImportError: + HAS_SKGGM = False +else: + HAS_SKGGM = True +from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC, + LSML_Supervised, ITML_Supervised, SDML_Supervised, + RCA_Supervised, MMC_Supervised, SDML) # Import this specially for testing. 
diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py
index e1eace90..ae9a8657 100644
--- a/test/metric_learn_test.py
+++ b/test/metric_learn_test.py
@@ -10,10 +10,15 @@
 from sklearn.utils.testing import assert_warns_message
 from sklearn.exceptions import ConvergenceWarning
 from sklearn.utils.validation import check_X_y
-
-from metric_learn import (
-    LMNN, NCA, LFDA, Covariance, MLKR, MMC,
-    LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised)
+try:
+  from inverse_covariance import quic
+except ImportError:
+  HAS_SKGGM = False
+else:
+  HAS_SKGGM = True
+from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC,
+                          LSML_Supervised, ITML_Supervised, SDML_Supervised,
+                          RCA_Supervised, MMC_Supervised, SDML)
 # Import this specially for testing.
 from metric_learn.constraints import wrap_pairs
 from metric_learn.lmnn import python_LMNN
@@ -148,28 +153,237 @@ def test_no_twice_same_objective(capsys):
 
 
 class TestSDML(MetricTestCase):
+
+  @pytest.mark.skipif(HAS_SKGGM,
+                      reason="The error can be thrown only if skggm is "
+                             "not installed.")
+  def test_sdml_supervised_raises_warning_msg_not_installed_skggm(self):
+    """Tests that the right error message is raised if someone tries to
+    use SDML_Supervised but has not installed skggm, and that the algorithm
+    fails to converge"""
+    # TODO: remove if we don't need skggm anymore
+    # load_iris: dataset where we know scikit-learn's graphical lasso fails
+    # with a Floating Point error
+    X, y = load_iris(return_X_y=True)
+    sdml_supervised = SDML_Supervised(balance_param=0.5, use_cov=True,
+                                      sparsity_param=0.01)
+    msg = ("There was a problem in SDML when using scikit-learn's graphical "
+           "lasso solver. skggm's graphical lasso can sometimes converge on "
+           "non SPD cases where scikit-learn's graphical lasso fails to "
+           "converge. Try to install skggm and rerun the algorithm (see "
+           "the README.md for the right version of skggm). The following "
+           "error message was thrown:")
+    with pytest.raises(RuntimeError) as raised_error:
+      sdml_supervised.fit(X, y)
+    assert str(raised_error.value).startswith(msg)
+
+  @pytest.mark.skipif(HAS_SKGGM,
+                      reason="The error can be thrown only if skggm is "
+                             "not installed.")
+  def test_sdml_raises_warning_msg_not_installed_skggm(self):
+    """Tests that the right error message is raised if someone tries to
+    use SDML but has not installed skggm, and that the algorithm fails to
+    converge"""
+    # TODO: remove if we don't need skggm anymore
+    # case on which we know that scikit-learn's graphical lasso fails
+    # because it will return a non SPD matrix
+    pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
+    y_pairs = [1, -1]
+    sdml = SDML(use_cov=False, balance_param=100, verbose=True)
+
+    msg = ("There was a problem in SDML when using scikit-learn's graphical "
+           "lasso solver. skggm's graphical lasso can sometimes converge on "
+           "non SPD cases where scikit-learn's graphical lasso fails to "
+           "converge. Try to install skggm and rerun the algorithm (see "
+           "the README.md for the right version of skggm).")
+    with pytest.raises(RuntimeError) as raised_error:
+      sdml.fit(pairs, y_pairs)
+    assert msg == str(raised_error.value)
+
+  @pytest.mark.skipif(not HAS_SKGGM,
+                      reason="The error can be thrown only if skggm is "
+                             "installed.")
+  def test_sdml_raises_warning_msg_installed_skggm(self):
+    """Tests that the right error message is raised if someone uses SDML
+    with skggm installed, on a case where the algorithm fails to
+    converge"""
+    # TODO: remove if we don't need skggm anymore
+    # case on which we know that skggm's graphical lasso fails
+    # because it will return non finite values
+    pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
+    y_pairs = [1, -1]
+    sdml = SDML(use_cov=False, balance_param=100, verbose=True)
+
+    msg = ("There was a problem in SDML when using skggm's graphical "
+           "lasso solver.")
+    with pytest.raises(RuntimeError) as raised_error:
+      sdml.fit(pairs, y_pairs)
+    assert msg == str(raised_error.value)
+
+  @pytest.mark.skipif(not HAS_SKGGM,
+                      reason="The error can be thrown only if skggm is "
+                             "installed.")
+  def test_sdml_supervised_raises_warning_msg_installed_skggm(self):
+    """Tests that the right error message is raised if someone uses
+    SDML_Supervised with skggm installed, on a case where the algorithm
+    fails to converge"""
+    # TODO: remove if we don't need skggm anymore
+    # case on which we know that skggm's graphical lasso fails
+    # because it will return non finite values
+    rng = np.random.RandomState(42)
+    # This example will create a diagonal emp_cov with a negative coeff
+    # (pathological case)
+    X = np.array([[-10., 0.], [10., 0.], [5., 0.], [3., 0.]])
+    y = [0, 0, 1, 1]
+    sdml_supervised = SDML_Supervised(balance_param=0.5, use_cov=False,
+                                      sparsity_param=0.01)
+    msg = ("There was a problem in SDML when using skggm's graphical "
+           "lasso solver.")
+    with pytest.raises(RuntimeError) as raised_error:
+      sdml_supervised.fit(X, y, random_state=rng)
+    assert msg == str(raised_error.value)
+
+  @pytest.mark.skipif(not HAS_SKGGM,
+                      reason="It's only in the case where skggm is installed "
+                             "that no warning should be thrown.")
+  def test_raises_no_warning_installed_skggm(self):
+    # with skggm installed, we should be able to instantiate and fit SDML,
+    # and it should raise no warning
+    pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
+    y_pairs = [1, -1]
+    X, y = make_classification(random_state=42)
+    with pytest.warns(None) as record:
+      sdml = SDML()
+      sdml.fit(pairs, y_pairs)
+    assert len(record) == 0
+    with pytest.warns(None) as record:
+      sdml = SDML_Supervised(use_cov=False, balance_param=1e-5)
+      sdml.fit(X, y)
+    assert len(record) == 0
+
   def test_iris(self):
     # Note: this is a flaky test, which fails for certain seeds.
     # TODO: un-flake it!
     rs = np.random.RandomState(5555)
-    sdml = SDML_Supervised(num_constraints=1500)
+    sdml = SDML_Supervised(num_constraints=1500, use_cov=False,
+                           balance_param=5e-5)
     sdml.fit(self.iris_points, self.iris_labels, random_state=rs)
-    csep = class_separation(sdml.transform(self.iris_points), self.iris_labels)
-    self.assertLess(csep, 0.25)
+    csep = class_separation(sdml.transform(self.iris_points),
+                            self.iris_labels)
+    self.assertLess(csep, 0.22)
 
   def test_deprecation_num_labeled(self):
     # test that a deprecation message is thrown if num_labeled is set at
     # initialization
     # TODO: remove in v.0.6
-    X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
-    y = np.array([1, 0, 1, 0])
-    sdml_supervised = SDML_Supervised(num_labeled=np.inf)
+    X, y = make_classification(random_state=42)
+    sdml_supervised = SDML_Supervised(num_labeled=np.inf, use_cov=False,
+                                      balance_param=5e-5)
     msg = ('"num_labeled" parameter is not used.'
            ' It has been deprecated in version 0.5.0 and will be'
            'removed in 0.6.0')
     assert_warns_message(DeprecationWarning, msg, sdml_supervised.fit, X, y)
 
+  def test_sdml_raises_warning_non_psd(self):
+    """Tests that SDML raises a warning on a toy example where we know the
+    pseudo-covariance matrix is not PSD"""
+    pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
+    y = [1, -1]
+    sdml = SDML(use_cov=True, sparsity_param=0.01, balance_param=0.5)
+    msg = ("Warning, the input matrix of graphical lasso is not "
+           "positive semi-definite (PSD). The algorithm may diverge, "
+           "and lead to degenerate solutions. "
+           "To prevent that, try to decrease the balance parameter "
+           "`balance_param` and/or to set use_cov=False.")
+    with pytest.warns(ConvergenceWarning) as raised_warning:
+      try:
+        sdml.fit(pairs, y)
+      except Exception:
+        pass
+    # we assert that this warning is in one of the warnings raised by the
+    # estimator
+    assert msg in list(map(lambda w: str(w.message), raised_warning))
+
+  def test_sdml_converges_if_psd(self):
+    """Tests that sdml converges on a simple problem where we know the
+    pseudo-covariance matrix is PSD"""
+    pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
+    y = [1, -1]
+    sdml = SDML(use_cov=True, sparsity_param=0.01, balance_param=0.5)
+    sdml.fit(pairs, y)
+    assert np.isfinite(sdml.get_mahalanobis_matrix()).all()
+
+  @pytest.mark.skipif(not HAS_SKGGM,
+                      reason="sklearn's graphical_lasso can sometimes not "
+                             "work on some non SPD problems. We test that "
+                             "it works only if skggm is installed.")
+  def test_sdml_works_on_non_spd_pb_with_skggm(self):
+    """Test that SDML works on a certain non SPD problem on which we know
+    it should work, but scikit-learn's graphical_lasso does not work"""
+    X, y = load_iris(return_X_y=True)
+    sdml = SDML_Supervised(balance_param=0.5, sparsity_param=0.01,
+                           use_cov=True)
+    sdml.fit(X, y)
+
+
+@pytest.mark.skipif(not HAS_SKGGM,
+                    reason='The message should be printed only if skggm is '
+                           'installed.')
+def test_verbose_has_installed_skggm_sdml(capsys):
+  # Test that if users have installed skggm, a message is printed telling
+  # them skggm's solver is used (when they use SDML)
+  # TODO: remove if we don't need skggm anymore
+  pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
+  y_pairs = [1, -1]
+  sdml = SDML(verbose=True)
+  sdml.fit(pairs, y_pairs)
+  out, _ = capsys.readouterr()
+  assert "SDML will use skggm's graphical lasso solver." in out
+
+
+@pytest.mark.skipif(not HAS_SKGGM,
+                    reason='The message should be printed only if skggm is '
+                           'installed.')
+def test_verbose_has_installed_skggm_sdml_supervised(capsys):
+  # Test that if users have installed skggm, a message is printed telling
+  # them skggm's solver is used (when they use SDML_Supervised)
+  # TODO: remove if we don't need skggm anymore
+  X, y = make_classification(random_state=42)
+  sdml = SDML_Supervised(verbose=True)
+  sdml.fit(X, y)
+  out, _ = capsys.readouterr()
+  assert "SDML will use skggm's graphical lasso solver." in out
+
+
+@pytest.mark.skipif(HAS_SKGGM,
+                    reason='The message should be printed only if skggm is '
+                           'not installed.')
+def test_verbose_has_not_installed_skggm_sdml(capsys):
+  # Test that if users have not installed skggm, a message is printed telling
+  # them scikit-learn's solver is used (when they use SDML)
+  # TODO: remove if we don't need skggm anymore
+  pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
+  y_pairs = [1, -1]
+  sdml = SDML(verbose=True)
+  sdml.fit(pairs, y_pairs)
+  out, _ = capsys.readouterr()
+  assert "SDML will use scikit-learn's graphical lasso solver." in out
+
+
+@pytest.mark.skipif(HAS_SKGGM,
+                    reason='The message should be printed only if skggm is '
+                           'not installed.')
+def test_verbose_has_not_installed_skggm_sdml_supervised(capsys):
+  # Test that if users have not installed skggm, a message is printed telling
+  # them scikit-learn's solver is used (when they use SDML_Supervised)
+  # TODO: remove if we don't need skggm anymore
+  X, y = make_classification(random_state=42)
+  sdml = SDML_Supervised(verbose=True, balance_param=1e-5, use_cov=False)
+  sdml.fit(X, y)
+  out, _ = capsys.readouterr()
+  assert "SDML will use scikit-learn's graphical lasso solver." in out
+
 
 class TestNCA(MetricTestCase):
   def test_iris(self):

diff --git a/test/test_fit_transform.py b/test/test_fit_transform.py
index 118f6b90..b85e9273 100644
--- a/test/test_fit_transform.py
+++ b/test/test_fit_transform.py
@@ -1,3 +1,4 @@
+import pytest
 import unittest
 import numpy as np
 from sklearn.datasets import load_iris
@@ -5,7 +6,8 @@
 from metric_learn import (
     LMNN, NCA, LFDA, Covariance, MLKR,
-    LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised)
+    LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised,
+    MMC_Supervised)
 
 
 class TestFitTransform(unittest.TestCase):
@@ -62,12 +64,14 @@ def test_lmnn(self):
 
   def test_sdml_supervised(self):
     seed = np.random.RandomState(1234)
-    sdml = SDML_Supervised(num_constraints=1500)
+    sdml = SDML_Supervised(num_constraints=1500, balance_param=1e-5,
+                           use_cov=False)
     sdml.fit(self.X, self.y, random_state=seed)
     res_1 = sdml.transform(self.X)
 
     seed = np.random.RandomState(1234)
-    sdml = SDML_Supervised(num_constraints=1500)
+    sdml = SDML_Supervised(num_constraints=1500, balance_param=1e-5,
+                           use_cov=False)
     res_2 = sdml.fit_transform(self.X, self.y, random_state=seed)
 
     assert_array_almost_equal(res_1, res_2)
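A side note on the fit/fit_transform test above: the RandomState is rebuilt before the second fit because a RandomState instance is stateful, so reusing the same instance would draw different constraints in the two fits. A tiny sketch of the idea:

    import numpy as np

    rs = np.random.RandomState(1234)
    first = rs.randint(10, size=3)

    rs = np.random.RandomState(1234)  # re-seed to replay the same stream
    second = rs.randint(10, size=3)

    assert np.array_equal(first, second)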
diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py
index 1e555e73..a0bf3b9d 100644
--- a/test/test_mahalanobis_mixin.py
+++ b/test/test_mahalanobis_mixin.py
@@ -10,6 +10,8 @@
 from sklearn.utils.testing import set_random_state
 
 from metric_learn._util import make_context
+from metric_learn.base_metric import (_QuadrupletsClassifierMixin,
+                                      _PairsClassifierMixin)
 from test.test_utils import ids_metric_learners, metric_learners
@@ -96,7 +98,7 @@ def check_is_distance_matrix(pairwise):
   assert np.array_equal(pairwise, pairwise.T)  # symmetry
   assert (pairwise.diagonal() == 0).all()  # identity
   # triangular inequality
-  tol = 1e-15
+  tol = 1e-12
   assert (pairwise <= pairwise[:, :, np.newaxis] +
           pairwise[:, np.newaxis, :] + tol).all()
@@ -281,5 +283,19 @@ def test_transformer_is_2D(estimator, build_dataset):
 
   # test that it works for 1 feature
   trunc_data = input_data[..., :1]
+  # we drop duplicates that might have been formed, i.e. of the form
+  # aabc or abcc or aabb for quadruplets, and aa for pairs.
+  if isinstance(estimator, _QuadrupletsClassifierMixin):
+    for slice_idx in [slice(0, 2), slice(2, 4)]:
+      pairs = trunc_data[:, slice_idx, :]
+      diffs = pairs[:, 1, :] - pairs[:, 0, :]
+      to_keep = np.where(np.abs(diffs.ravel()) > 1e-9)
+      trunc_data = trunc_data[to_keep]
+      labels = labels[to_keep]
+  elif isinstance(estimator, _PairsClassifierMixin):
+    diffs = trunc_data[:, 1, :] - trunc_data[:, 0, :]
+    to_keep = np.where(np.abs(diffs.ravel()) > 1e-9)
+    trunc_data = trunc_data[to_keep]
+    labels = labels[to_keep]
   model.fit(trunc_data, labels)
   assert model.transformer_.shape == (1, 1)  # the transformer must be 2D
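The tolerance bump from 1e-15 to 1e-12 lives inside a dense broadcast expression; here is a self-contained sketch (with a hypothetical toy distance matrix) of what that assertion checks, namely d(j, k) <= d(i, j) + d(i, k) for every triple of points:

    import numpy as np

    pairwise = np.array([[0., 1., 2.],
                         [1., 0., 1.5],
                         [2., 1.5, 0.]])
    tol = 1e-12
    # shapes: (1, n, n) <= (n, n, 1) + (n, 1, n), broadcast to (n, n, n)
    assert (pairwise <= pairwise[:, :, np.newaxis] +
            pairwise[:, np.newaxis, :] + tol).all()

Floating-point round-off in distances computed through the learned transformer can violate the exact inequality by more than 1e-15, hence the looser tolerance.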
diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py
index d9dce685..f1248c9a 100644
--- a/test/test_sklearn_compat.py
+++ b/test/test_sklearn_compat.py
@@ -72,9 +72,18 @@ def test_itml(self):
   def test_mmc(self):
     check_estimator(dMMC)
 
-  # This fails due to a FloatingPointError
-  # def test_sdml(self):
-  #   check_estimator(dSDML)
+  def test_sdml(self):
+    def stable_init(self, sparsity_param=0.01, num_labeled='deprecated',
+                    num_constraints=None, verbose=False, preprocessor=None):
+      # this init makes SDML stable for scikit-learn examples.
+      SDML_Supervised.__init__(self, sparsity_param=sparsity_param,
+                               num_labeled=num_labeled,
+                               num_constraints=num_constraints,
+                               verbose=verbose,
+                               preprocessor=preprocessor,
+                               balance_param=1e-5, use_cov=False)
+    dSDML.__init__ = stable_init
+    check_estimator(dSDML)
 
   # This fails because the default num_chunks isn't data-dependent.
   # def test_rca(self):

diff --git a/test/test_transformer_metric_conversion.py b/test/test_transformer_metric_conversion.py
index 59986011..6cfe8281 100644
--- a/test/test_transformer_metric_conversion.py
+++ b/test/test_transformer_metric_conversion.py
@@ -44,7 +44,8 @@ def test_lmnn(self):
 
   def test_sdml_supervised(self):
     seed = np.random.RandomState(1234)
-    sdml = SDML_Supervised(num_constraints=1500)
+    sdml = SDML_Supervised(num_constraints=1500, use_cov=False,
+                           balance_param=1e-5)
     sdml.fit(self.X, self.y, random_state=seed)
     L = sdml.transformer_
     assert_array_almost_equal(L.T.dot(L), sdml.get_mahalanobis_matrix())

diff --git a/test/test_utils.py b/test/test_utils.py
index 9099e12d..f1df4098 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -102,26 +102,25 @@ def build_quadruplets(with_preprocessor=False):
 
 pairs_learners = [(ITML(), build_pairs),
                   (MMC(max_iter=2), build_pairs),  # max_iter=2 for faster
-                  (SDML(), build_pairs),
-                  ]
+                  (SDML(use_cov=False, balance_param=1e-5), build_pairs)]
 ids_pairs_learners = list(map(lambda x: x.__class__.__name__,
-                              [learner for (learner, _) in
-                               pairs_learners]))
-
-classifiers = [(Covariance(), build_classification),
-               (LFDA(), build_classification),
-               (LMNN(), build_classification),
-               (NCA(), build_classification),
-               (RCA(), build_classification),
-               (ITML_Supervised(max_iter=5), build_classification),
-               (LSML_Supervised(), build_classification),
-               (MMC_Supervised(max_iter=5), build_classification),
-               (RCA_Supervised(num_chunks=10), build_classification),
-               (SDML_Supervised(), build_classification)
-               ]
+                          [learner for (learner, _) in
+                           pairs_learners]))
+
+classifiers = [(Covariance(), build_classification),
+               (LFDA(), build_classification),
+               (LMNN(), build_classification),
+               (NCA(), build_classification),
+               (RCA(), build_classification),
+               (ITML_Supervised(max_iter=5), build_classification),
+               (LSML_Supervised(), build_classification),
+               (MMC_Supervised(max_iter=5), build_classification),
+               (RCA_Supervised(num_chunks=10), build_classification),
+               (SDML_Supervised(use_cov=False, balance_param=1e-5),
+                build_classification)]
 ids_classifiers = list(map(lambda x: x.__class__.__name__,
-                           [learner for (learner, _) in
-                            classifiers]))
+                       [learner for (learner, _) in
+                        classifiers]))
 
 regressors = [(MLKR(), build_regression)]
 ids_regressors = list(map(lambda x: x.__class__.__name__,
@@ -830,9 +829,9 @@ class MockMetricLearner(MahalanobisMixin):
                   "or a callable.".format(type(preprocessor)))
 
 
-@pytest.mark.parametrize('estimator', [ITML(), LSML(), MMC(), SDML()],
-                         ids=['ITML', 'LSML', 'MMC', 'SDML'])
-def test_error_message_tuple_size(estimator):
+@pytest.mark.parametrize('estimator, _', tuples_learners,
+                         ids=ids_tuples_learners)
+def test_error_message_tuple_size(estimator, _):
   """Tests that if a tuples learner is not given the good number of points
   per tuple, it throws an error message"""
   estimator = clone(estimator)
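Taken together, the test changes above settle on use_cov=False with a small balance_param as the stable SDML configuration when skggm is unavailable. As a usage sketch (parameter values copied from test_iris above, not tuning advice):

    import numpy as np
    from sklearn.datasets import load_iris
    from metric_learn import SDML_Supervised

    X, y = load_iris(return_X_y=True)
    sdml = SDML_Supervised(num_constraints=1500, use_cov=False,
                           balance_param=5e-5)
    sdml.fit(X, y, random_state=np.random.RandomState(42))
    X_embedded = sdml.transform(X)  # data mapped by the learned metric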