diff --git a/metric_learn/_util.py b/metric_learn/_util.py
index b476e70b..fa196a69 100644
--- a/metric_learn/_util.py
+++ b/metric_learn/_util.py
@@ -1,5 +1,4 @@
 import numpy as np
-import scipy
 import six
 from numpy.linalg import LinAlgError
 from sklearn.datasets import make_spd_matrix
@@ -8,9 +7,10 @@
 from sklearn.utils.validation import check_X_y, check_random_state
 from .exceptions import PreprocessorError, NonPSDError
 from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
-from scipy.linalg import pinvh
+from scipy.linalg import pinvh, eigh
 import sys
 import time
+import warnings
 
 # hack around lack of axis kwarg in older numpy versions
 try:
@@ -678,17 +678,20 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None,
 
   random_state = check_random_state(random_state)
   M = init
-  if isinstance(init, np.ndarray):
-    s, u = scipy.linalg.eigh(init)
-    init_is_definite = _check_sdp_from_eigen(s)
+  if isinstance(M, np.ndarray):
+    w, V = eigh(M, check_finite=False)
+    init_is_definite = _check_sdp_from_eigen(w)
     if strict_pd and not init_is_definite:
       raise LinAlgError("You should provide a strictly positive definite "
                         "matrix as `{}`. This one is not definite. Try another"
                         " {}, or an algorithm that does not "
                         "require the {} to be strictly positive definite."
                         .format(*((matrix_name,) * 3)))
+    elif return_inverse and not init_is_definite:
+      warnings.warn('The initialization matrix is not invertible: '
+                    'using the pseudo-inverse instead.')
     if return_inverse:
-      M_inv = np.dot(u / s, u.T)
+      M_inv = _pseudo_inverse_from_eig(w, V)
       return M, M_inv
     else:
       return M
@@ -707,15 +710,23 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None,
       X = input
     # atleast2d is necessary to deal with scalar covariance matrices
     M_inv = np.atleast_2d(np.cov(X, rowvar=False))
-    s, u = scipy.linalg.eigh(M_inv)
-    cov_is_definite = _check_sdp_from_eigen(s)
+    w, V = eigh(M_inv, check_finite=False)
+    cov_is_definite = _check_sdp_from_eigen(w)
     if strict_pd and not cov_is_definite:
       raise LinAlgError("Unable to get a true inverse of the covariance "
                         "matrix since it is not definite. Try another "
                         "`{}`, or an algorithm that does not "
                         "require the `{}` to be strictly positive definite."
                         .format(*((matrix_name,) * 2)))
+    elif not cov_is_definite:
+      warnings.warn('The covariance matrix is not invertible: '
+                    'using the pseudo-inverse instead. '
+                    'To make the covariance matrix invertible'
+                    ' you can remove any linearly dependent features and/or '
+                    'reduce the dimensionality of your input, '
+                    'for instance using `sklearn.decomposition.PCA` as a '
+                    'preprocessing step.')
+    M = _pseudo_inverse_from_eig(w, V)
     if return_inverse:
       return M, M_inv
     else:
@@ -742,3 +753,36 @@ def _check_n_components(n_features, n_components):
   if 0 < n_components <= n_features:
     return n_components
   raise ValueError('Invalid n_components, must be in [1, %d]' % n_features)
+
+
+def _pseudo_inverse_from_eig(w, V, tol=None):
+  """Compute the (Moore-Penrose) pseudo-inverse of the EVD of a symmetric
+  matrix.
+
+  Parameters
+  ----------
+  w : (..., M) ndarray
+    The eigenvalues in ascending order, each repeated according to
+    its multiplicity.
+
+  V : {(..., M, M) ndarray, (..., M, M) matrix}
+    The column ``V[:, i]`` is the normalized eigenvector corresponding
+    to the eigenvalue ``w[i]``. Will return a matrix object if `V` is
+    a matrix object.
+
+  tol : positive `float`, optional
+    Absolute eigenvalues below tol are considered zero.
+
+  Returns
+  -------
+  output : (..., M, M) array_like
+    The pseudo-inverse given by the EVD.
+  """
+  if tol is None:
+    tol = np.amax(w) * np.max(w.shape) * np.finfo(w.dtype).eps
+  # discard small eigenvalues and invert the rest
+  large = np.abs(w) > tol
+  w = np.divide(1, w, where=large, out=w)
+  w[~large] = 0
+
+  return np.dot(V * w, np.conjugate(V).T)
diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py
index 91aa129a..91fb435f 100644
--- a/test/test_mahalanobis_mixin.py
+++ b/test/test_mahalanobis_mixin.py
@@ -8,12 +8,12 @@
 from scipy.stats import ortho_group
 from sklearn import clone
 from sklearn.cluster import DBSCAN
-from sklearn.datasets import make_spd_matrix
-from sklearn.utils import check_random_state
+from sklearn.datasets import make_spd_matrix, make_blobs
+from sklearn.utils import check_random_state, shuffle
 from sklearn.utils.multiclass import type_of_target
 from sklearn.utils.testing import set_random_state
 
-from metric_learn._util import make_context
+from metric_learn._util import make_context, _initialize_metric_mahalanobis
 from metric_learn.base_metric import (_QuadrupletsClassifierMixin,
                                       _PairsClassifierMixin)
 from metric_learn.exceptions import NonPSDError
@@ -569,7 +569,7 @@ def test_init_mahalanobis(estimator, build_dataset):
                               in zip(ids_metric_learners,
                                      metric_learners)
                               if idml[:4] in ['ITML', 'SDML', 'LSML']])
-def test_singular_covariance_init_or_prior(estimator, build_dataset):
+def test_singular_covariance_init_or_prior_strictpd(estimator, build_dataset):
   """Tests that when using the 'covariance' init or prior, it returns the
   appropriate error if the covariance matrix is singular, for algorithms
   that need a strictly PD prior or init (see
@@ -603,6 +603,48 @@ def test_singular_covariance_init_or_prior(estimator, build_dataset):
   assert str(raised_err.value) == msg
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize('estimator, build_dataset',
+                         [(ml, bd) for idml, (ml, bd)
+                          in zip(ids_metric_learners,
+                                 metric_learners)
+                          if idml[:3] in ['MMC']],
+                         ids=[idml for idml, (ml, _)
+                              in zip(ids_metric_learners,
+                                     metric_learners)
+                              if idml[:3] in ['MMC']])
+def test_singular_covariance_init_of_non_strict_pd(estimator, build_dataset):
+  """Tests that when using the 'covariance' init or prior, it returns the
+  appropriate warning if the covariance matrix is singular, for algorithms
+  that don't need a strictly PD init. Also checks that the returned
+  inverse matrix has finite values.
+  """
+  input_data, labels, _, X = build_dataset()
+  model = clone(estimator)
+  set_random_state(model)
+  # We create a feature that is a linear combination of the first two
+  # features:
+  input_data = np.concatenate([input_data, input_data[:, ..., :2].dot([[2],
+                                                                       [3]])],
+                              axis=-1)
+  model.set_params(init='covariance')
+  msg = ('The covariance matrix is not invertible: '
+         'using the pseudo-inverse instead. '
+         'To make the covariance matrix invertible'
+         ' you can remove any linearly dependent features and/or '
+         'reduce the dimensionality of your input, '
+         'for instance using `sklearn.decomposition.PCA` as a '
+         'preprocessing step.')
+  with pytest.warns(UserWarning) as raised_warning:
+    model.fit(input_data, labels)
+  assert np.any([str(warning.message) == msg for warning in raised_warning])
+  M, _ = _initialize_metric_mahalanobis(X, init='covariance',
+                                        random_state=RNG,
+                                        return_inverse=True,
+                                        strict_pd=False)
+  assert np.isfinite(M).all()
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize('estimator, build_dataset',
                          [(ml, bd) for idml, (ml, bd)
@@ -614,7 +656,7 @@ def test_singular_covariance_init_or_prior(estimator, build_dataset):
                                      metric_learners)
                               if idml[:4] in ['ITML', 'SDML', 'LSML']])
 @pytest.mark.parametrize('w0', [1e-20, 0., -1e-20])
-def test_singular_array_init_or_prior(estimator, build_dataset, w0):
+def test_singular_array_init_or_prior_strictpd(estimator, build_dataset, w0):
   """Tests that when using a custom array init (or prior), it returns the
   appropriate error if it is singular, for algorithms
   that need a strictly PD prior or init (see
@@ -654,6 +696,31 @@ def test_singular_covariance_init_or_prior(estimator, build_dataset):
   assert str(raised_err.value) == msg
 
 
+@pytest.mark.parametrize('w0', [1e-20, 0., -1e-20])
+def test_singular_array_init_of_non_strict_pd(w0):
+  """Tests that when using a custom array init, it returns the
+  appropriate warning if it is singular. Also checks if the returned
+  inverse matrix is finite. This isn't checked for model fitting as no
+  model currently uses this setting.
+  """
+  rng = np.random.RandomState(42)
+  X, y = shuffle(*make_blobs(random_state=rng),
+                 random_state=rng)
+  P = ortho_group.rvs(X.shape[1], random_state=rng)
+  w = np.abs(rng.randn(X.shape[1]))
+  w[0] = w0
+  M = P.dot(np.diag(w)).dot(P.T)
+  msg = ('The initialization matrix is not invertible: '
+         'using the pseudo-inverse instead.')
+  with pytest.warns(UserWarning) as raised_warning:
+    _, M_inv = _initialize_metric_mahalanobis(X, init=M,
+                                              random_state=rng,
+                                              return_inverse=True,
+                                              strict_pd=False)
+  assert str(raised_warning[0].message) == msg
+  assert np.isfinite(M_inv).all()
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize('estimator, build_dataset', metric_learners,
                          ids=ids_metric_learners)
diff --git a/test/test_utils.py b/test/test_utils.py
index 3092e168..0ea871bb 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,4 +1,5 @@
 import pytest
+from scipy.linalg import eigh, pinvh
 from collections import namedtuple
 import numpy as np
 from numpy.testing import assert_array_equal, assert_equal
@@ -11,7 +12,7 @@
                                 check_collapsed_pairs, validate_vector,
                                 _check_sdp_from_eigen, _check_n_components,
                                 check_y_valid_values_for_pairs,
-                                _auto_select_init)
+                                _auto_select_init, _pseudo_inverse_from_eig)
 from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA,
                           LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised,
                           MMC_Supervised, RCA_Supervised, SDML_Supervised,
@@ -1146,3 +1147,27 @@ def test__auto_select_init(has_classes, n_features, n_samples, n_components,
   """Checks that the auto selection of the init works as expected"""
   assert (_auto_select_init(has_classes, n_features, n_samples, n_components,
                             n_classes) == result)
+
+
+@pytest.mark.parametrize('w0', [1e-20, 0., -1e-20])
+def test_pseudo_inverse_from_eig_and_pinvh_singular(w0):
+  """Checks that _pseudo_inverse_from_eig returns the same result as
+  scipy.linalg.pinvh for a singular matrix"""
+  rng = np.random.RandomState(SEED)
+  A = rng.rand(100, 100)
+  A = A + A.T
+  w, V = eigh(A)
+  w[0] = w0
+  A = V.dot(np.diag(w)).dot(V.T)
+  np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A),
+                             rtol=1e-05)
+
+
+def test_pseudo_inverse_from_eig_and_pinvh_nonsingular():
+  """Checks that _pseudo_inverse_from_eig returns the same result as
+  scipy.linalg.pinvh for a non-singular matrix"""
+  rng = np.random.RandomState(SEED)
+  A = rng.rand(100, 100)
+  A = A + A.T
+  w, V = eigh(A, check_finite=False)
+  np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A))
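
Note (not part of the patch): the new `_pseudo_inverse_from_eig` helper inverts only the eigenvalues that are numerically non-zero and zeroes out the rest, which is why the tests above can compare it directly against `scipy.linalg.pinvh`. A minimal standalone sketch of that idea, using plain NumPy/SciPy with illustrative names and a small 5x5 example (the tolerance rule mirrors the helper's default):

import numpy as np
from scipy.linalg import eigh, pinvh

rng = np.random.RandomState(42)
A = rng.rand(5, 5)
A = A + A.T                      # symmetric test matrix
w, V = eigh(A)                   # eigendecomposition: A = V * diag(w) * V.T
w[0] = 0.                        # force a zero eigenvalue -> singular matrix
A = V.dot(np.diag(w)).dot(V.T)

# Invert only the eigenvalues that are numerically non-zero, with a cutoff
# relative to the largest eigenvalue, as in the helper.
tol = np.amax(w) * np.max(w.shape) * np.finfo(w.dtype).eps
large = np.abs(w) > tol
w_inv = np.where(large, 1. / np.where(large, w, 1.), 0.)
A_pinv = V.dot(np.diag(w_inv)).dot(V.T)

assert np.allclose(A_pinv, pinvh(A))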