From 3ef4f9ce7c3552d40e4b97582385be0036bdae2b Mon Sep 17 00:00:00 2001 From: grudloff Date: Fri, 24 Jan 2020 10:05:17 +0100 Subject: [PATCH 01/17] Fix covariance init when matrix is not invertible --- metric_learn/_util.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/metric_learn/_util.py b/metric_learn/_util.py index b476e70b..20c54f22 100644 --- a/metric_learn/_util.py +++ b/metric_learn/_util.py @@ -707,15 +707,19 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, X = input # atleast2d is necessary to deal with scalar covariance matrices M_inv = np.atleast_2d(np.cov(X, rowvar=False)) - s, u = scipy.linalg.eigh(M_inv) - cov_is_definite = _check_sdp_from_eigen(s) - if strict_pd and not cov_is_definite: - raise LinAlgError("Unable to get a true inverse of the covariance " - "matrix since it is not definite. Try another " - "`{}`, or an algorithm that does not " - "require the `{}` to be strictly positive definite." - .format(*((matrix_name,) * 2))) - M = np.dot(u / s, u.T) + if strict_pd: + s, u = scipy.linalg.eigh(M_inv) + cov_is_definite = _check_sdp_from_eigen(s) + if not cov_is_definite: + raise LinAlgError("Unable to get a true inverse of the covariance " + "matrix since it is not definite. Try another " + "`{}`, or an algorithm that does not " + "require the `{}` to be strictly positive definite." + .format(*((matrix_name,) * 2))) + else: + M = np.dot(u / s, u.T) + else: + M = pinvh(M_inv) if return_inverse: return M, M_inv else: From 2c1e308f7fc345848d4bdd52f54e8733a47886df Mon Sep 17 00:00:00 2001 From: grudloff Date: Fri, 24 Jan 2020 15:42:54 +0100 Subject: [PATCH 02/17] replaced import scipy for only required functions --- metric_learn/_util.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/metric_learn/_util.py b/metric_learn/_util.py index 20c54f22..4be74a08 100644 --- a/metric_learn/_util.py +++ b/metric_learn/_util.py @@ -1,5 +1,4 @@ import numpy as np -import scipy import six from numpy.linalg import LinAlgError from sklearn.datasets import make_spd_matrix @@ -8,7 +7,7 @@ from sklearn.utils.validation import check_X_y, check_random_state from .exceptions import PreprocessorError, NonPSDError from sklearn.discriminant_analysis import LinearDiscriminantAnalysis -from scipy.linalg import pinvh +from scipy.linalg import pinvh, eigh import sys import time @@ -679,7 +678,7 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, random_state = check_random_state(random_state) M = init if isinstance(init, np.ndarray): - s, u = scipy.linalg.eigh(init) + s, u = eigh(init) init_is_definite = _check_sdp_from_eigen(s) if strict_pd and not init_is_definite: raise LinAlgError("You should provide a strictly positive definite " @@ -708,7 +707,7 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, # atleast2d is necessary to deal with scalar covariance matrices M_inv = np.atleast_2d(np.cov(X, rowvar=False)) if strict_pd: - s, u = scipy.linalg.eigh(M_inv) + s, u = eigh(M_inv) cov_is_definite = _check_sdp_from_eigen(s) if not cov_is_definite: raise LinAlgError("Unable to get a true inverse of the covariance " From eff07d474ac85c2bd8f4ebc507154c2054e2ecfc Mon Sep 17 00:00:00 2001 From: grudloff Date: Mon, 27 Jan 2020 10:52:30 +0100 Subject: [PATCH 03/17] Change inv for pseudo-inv on custom matrix init --- metric_learn/_util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/metric_learn/_util.py b/metric_learn/_util.py index 4be74a08..ea64ef43 100644 --- a/metric_learn/_util.py +++ b/metric_learn/_util.py @@ -677,8 +677,8 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, random_state = check_random_state(random_state) M = init - if isinstance(init, np.ndarray): - s, u = eigh(init) + if isinstance(M, np.ndarray): + s, u = eigh(M) init_is_definite = _check_sdp_from_eigen(s) if strict_pd and not init_is_definite: raise LinAlgError("You should provide a strictly positive definite " @@ -687,7 +687,7 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, "require the {} to be strictly positive definite." .format(*((matrix_name,) * 3))) if return_inverse: - M_inv = np.dot(u / s, u.T) + M_inv = pinvh(M) return M, M_inv else: return M From bd690c8796bbce955867f1a6eb811ca91dd71006 Mon Sep 17 00:00:00 2001 From: grudloff Date: Mon, 27 Jan 2020 16:02:41 +0100 Subject: [PATCH 04/17] Change from EVD to SVD --- metric_learn/_util.py | 75 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 59 insertions(+), 16 deletions(-) diff --git a/metric_learn/_util.py b/metric_learn/_util.py index ea64ef43..dfb7cbe8 100644 --- a/metric_learn/_util.py +++ b/metric_learn/_util.py @@ -7,9 +7,10 @@ from sklearn.utils.validation import check_X_y, check_random_state from .exceptions import PreprocessorError, NonPSDError from sklearn.discriminant_analysis import LinearDiscriminantAnalysis -from scipy.linalg import pinvh, eigh +from scipy.linalg import pinvh import sys import time +import warnings # hack around lack of axis kwarg in older numpy versions try: @@ -678,7 +679,7 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, random_state = check_random_state(random_state) M = init if isinstance(M, np.ndarray): - s, u = eigh(M) + U, s, Vh = np.linalg.svd(M, full_matrices=False, hermitian=True) init_is_definite = _check_sdp_from_eigen(s) if strict_pd and not init_is_definite: raise LinAlgError("You should provide a strictly positive definite " @@ -686,8 +687,11 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, " {}, or an algorithm that does not " "require the {} to be strictly positive definite." .format(*((matrix_name,) * 3))) + elif return_inverse and not init_is_definite: + warnings.warn('The initialization matrix is not invertible, ' + 'but this isn"t an issue as the pseudo-inverse is used.') if return_inverse: - M_inv = pinvh(M) + M_inv = _pseudo_inverse_from_svd(U, s, Vh) return M, M_inv else: return M @@ -706,19 +710,22 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, X = input # atleast2d is necessary to deal with scalar covariance matrices M_inv = np.atleast_2d(np.cov(X, rowvar=False)) - if strict_pd: - s, u = eigh(M_inv) - cov_is_definite = _check_sdp_from_eigen(s) - if not cov_is_definite: - raise LinAlgError("Unable to get a true inverse of the covariance " - "matrix since it is not definite. Try another " - "`{}`, or an algorithm that does not " - "require the `{}` to be strictly positive definite." - .format(*((matrix_name,) * 2))) - else: - M = np.dot(u / s, u.T) - else: - M = pinvh(M_inv) + U, s, Vh = np.linalg.svd(M_inv, full_matrices=False, hermitian=True) + cov_is_definite = _check_sdp_from_eigen(s) + if strict_pd and not cov_is_definite: + raise LinAlgError("Unable to get a true inverse of the covariance " + "matrix since it is not definite. Try another " + "`{}`, or an algorithm that does not " + "require the `{}` to be strictly positive definite." + .format(*((matrix_name,) * 2))) + elif not cov_is_definite: + warnings.warn('The inverse covariance matrix is not invertible, ' + 'but this isn"t an issue as the pseudo-inverse is used. ' + 'You can remove any linearly dependent features and/or ' + 'reduce the dimensionality of your input, ' + 'for instance using `sklearn.decomposition.PCA` as a ' + 'preprocessing step.') + M = _pseudo_inverse_from_svd(U, s, Vh) if return_inverse: return M, M_inv else: @@ -745,3 +752,39 @@ def _check_n_components(n_features, n_components): if 0 < n_components <= n_features: return n_components raise ValueError('Invalid n_components, must be in [1, %d]' % n_features) + + +def _pseudo_inverse_from_svd(u, s, vt, tol=1e-15): + """Compute the (Moore-Penrose) pseudo-inverse of the SVD of a matrix. + + Parameters + ---------- + u : { (..., M, M), (..., M, K) } array + Unitary array(s). + + s : (..., K) array + Vector(s) with the singular values, within each vector sorted in + descending order. + + vh : { (..., N, N), (..., K, N) } array + Unitary array(s). + + tol : positive `float`, optional + Absolute eigenvalues below tol are considered zero. + + Returns + ------- + output : (…, M, N) array_like + The pseudo-inverse give by the SVD. + """ + # discard small singular values + tol = np.asarray(tol) + cutoff = tol[..., np.core.newaxis]*np.core.amax(s, axis=-1, + keepdims=True) + large = s > cutoff + s = np.core.divide(1, s, where=large, out=s) + s[~large] = 0 + # output = vt.T * s^+ * U.T + return np.core.matmul(np.core.swapaxes(vt, -1, -2), + np.core.multiply(s[..., np.core.newaxis], + np.core.swapaxes(u, -1, -2))) From 098255917af8f3cce1e85d129d76df168889dedc Mon Sep 17 00:00:00 2001 From: grudloff Date: Tue, 28 Jan 2020 16:02:37 +0100 Subject: [PATCH 05/17] Roll back to EVD and pseudo inverse of EVD --- metric_learn/_util.py | 55 ++++++++++++++++++++----------------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/metric_learn/_util.py b/metric_learn/_util.py index dfb7cbe8..a09c3ae6 100644 --- a/metric_learn/_util.py +++ b/metric_learn/_util.py @@ -7,7 +7,7 @@ from sklearn.utils.validation import check_X_y, check_random_state from .exceptions import PreprocessorError, NonPSDError from sklearn.discriminant_analysis import LinearDiscriminantAnalysis -from scipy.linalg import pinvh +from scipy.linalg import pinvh, eigh import sys import time import warnings @@ -679,8 +679,8 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, random_state = check_random_state(random_state) M = init if isinstance(M, np.ndarray): - U, s, Vh = np.linalg.svd(M, full_matrices=False, hermitian=True) - init_is_definite = _check_sdp_from_eigen(s) + w, V = eigh(M, check_finite=False) + init_is_definite = _check_sdp_from_eigen(w) if strict_pd and not init_is_definite: raise LinAlgError("You should provide a strictly positive definite " "matrix as `{}`. This one is not definite. Try another" @@ -691,7 +691,7 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, warnings.warn('The initialization matrix is not invertible, ' 'but this isn"t an issue as the pseudo-inverse is used.') if return_inverse: - M_inv = _pseudo_inverse_from_svd(U, s, Vh) + M_inv = _pseudo_inverse_from_eig(w, V) return M, M_inv else: return M @@ -710,8 +710,8 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, X = input # atleast2d is necessary to deal with scalar covariance matrices M_inv = np.atleast_2d(np.cov(X, rowvar=False)) - U, s, Vh = np.linalg.svd(M_inv, full_matrices=False, hermitian=True) - cov_is_definite = _check_sdp_from_eigen(s) + w, V = eigh(M_inv, check_finite=False) + cov_is_definite = _check_sdp_from_eigen(w) if strict_pd and not cov_is_definite: raise LinAlgError("Unable to get a true inverse of the covariance " "matrix since it is not definite. Try another " @@ -725,7 +725,7 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, 'reduce the dimensionality of your input, ' 'for instance using `sklearn.decomposition.PCA` as a ' 'preprocessing step.') - M = _pseudo_inverse_from_svd(U, s, Vh) + M = _pseudo_inverse_from_eig(w, V) if return_inverse: return M, M_inv else: @@ -754,20 +754,20 @@ def _check_n_components(n_features, n_components): raise ValueError('Invalid n_components, must be in [1, %d]' % n_features) -def _pseudo_inverse_from_svd(u, s, vt, tol=1e-15): - """Compute the (Moore-Penrose) pseudo-inverse of the SVD of a matrix. +def _pseudo_inverse_from_eig(w, V, tol=None): + """Compute the (Moore-Penrose) pseudo-inverse of the EVD of a symetric + matrix. Parameters ---------- - u : { (..., M, M), (..., M, K) } array - Unitary array(s). + w : (..., M) ndarray + The eigenvalues in ascending order, each repeated according to + its multiplicity. - s : (..., K) array - Vector(s) with the singular values, within each vector sorted in - descending order. - - vh : { (..., N, N), (..., K, N) } array - Unitary array(s). + v : {(..., M, M) ndarray, (..., M, M) matrix} + The column ``v[:, i]`` is the normalized eigenvector corresponding + to the eigenvalue ``w[i]``. Will return a matrix object if `a` is + a matrix object. tol : positive `float`, optional Absolute eigenvalues below tol are considered zero. @@ -775,16 +775,13 @@ def _pseudo_inverse_from_svd(u, s, vt, tol=1e-15): Returns ------- output : (…, M, N) array_like - The pseudo-inverse give by the SVD. + The pseudo-inverse given by the EVD. """ - # discard small singular values - tol = np.asarray(tol) - cutoff = tol[..., np.core.newaxis]*np.core.amax(s, axis=-1, - keepdims=True) - large = s > cutoff - s = np.core.divide(1, s, where=large, out=s) - s[~large] = 0 - # output = vt.T * s^+ * U.T - return np.core.matmul(np.core.swapaxes(vt, -1, -2), - np.core.multiply(s[..., np.core.newaxis], - np.core.swapaxes(u, -1, -2))) + if tol is None: + tol = np.amax(w) * np.max(w.shape) * np.finfo(w.dtype).eps + # discard small eigenvalues and invert the rest + large = np.abs(w) > tol + w = np.divide(1, w, where=large, out=w) + w[~large] = 0 + + return np.dot(V * w, np.conjugate(V).T) From baad5a743986c7aa3b6a1b1562c254b7f3304702 Mon Sep 17 00:00:00 2001 From: grudloff Date: Wed, 29 Jan 2020 09:38:46 +0100 Subject: [PATCH 06/17] Fix non-ASCII char --- metric_learn/_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metric_learn/_util.py b/metric_learn/_util.py index a09c3ae6..666d3289 100644 --- a/metric_learn/_util.py +++ b/metric_learn/_util.py @@ -774,7 +774,7 @@ def _pseudo_inverse_from_eig(w, V, tol=None): Returns ------- - output : (…, M, N) array_like + output : (..., M, N) array_like The pseudo-inverse given by the EVD. """ if tol is None: From 80ab9221a3291517bf0220e013d1dacc20a0adf0 Mon Sep 17 00:00:00 2001 From: grudloff Date: Thu, 30 Jan 2020 10:08:32 +0100 Subject: [PATCH 07/17] rephrasing warnings --- metric_learn/_util.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/metric_learn/_util.py b/metric_learn/_util.py index 666d3289..0edd7495 100644 --- a/metric_learn/_util.py +++ b/metric_learn/_util.py @@ -688,8 +688,8 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, "require the {} to be strictly positive definite." .format(*((matrix_name,) * 3))) elif return_inverse and not init_is_definite: - warnings.warn('The initialization matrix is not invertible, ' - 'but this isn"t an issue as the pseudo-inverse is used.') + warnings.warn('The initialization matrix is not invertible: ' + 'using the pseudo-inverse instead.') if return_inverse: M_inv = _pseudo_inverse_from_eig(w, V) return M, M_inv @@ -719,9 +719,10 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, "require the `{}` to be strictly positive definite." .format(*((matrix_name,) * 2))) elif not cov_is_definite: - warnings.warn('The inverse covariance matrix is not invertible, ' - 'but this isn"t an issue as the pseudo-inverse is used. ' - 'You can remove any linearly dependent features and/or ' + warnings.warn('The initialization matrix is not invertible: ' + 'using the pseudo-inverse instead.' + 'To make the inverse covariance matrix invertible' + ' you can remove any linearly dependent features and/or ' 'reduce the dimensionality of your input, ' 'for instance using `sklearn.decomposition.PCA` as a ' 'preprocessing step.') From 66e84d810a5ffc54e6a3fa553d6d55ed84eb90ed Mon Sep 17 00:00:00 2001 From: grudloff Date: Thu, 30 Jan 2020 10:10:39 +0100 Subject: [PATCH 08/17] added tests --- test/test_mahalanobis_mixin.py | 74 +++++++++++++++++++++++++++++++--- 1 file changed, 69 insertions(+), 5 deletions(-) diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py index 91aa129a..571bb9e6 100644 --- a/test/test_mahalanobis_mixin.py +++ b/test/test_mahalanobis_mixin.py @@ -8,12 +8,12 @@ from scipy.stats import ortho_group from sklearn import clone from sklearn.cluster import DBSCAN -from sklearn.datasets import make_spd_matrix -from sklearn.utils import check_random_state +from sklearn.datasets import make_spd_matrix, make_blobs +from sklearn.utils import check_random_state, shuffle from sklearn.utils.multiclass import type_of_target from sklearn.utils.testing import set_random_state -from metric_learn._util import make_context +from metric_learn._util import make_context, _initialize_metric_mahalanobis from metric_learn.base_metric import (_QuadrupletsClassifierMixin, _PairsClassifierMixin) from metric_learn.exceptions import NonPSDError @@ -569,7 +569,7 @@ def test_init_mahalanobis(estimator, build_dataset): in zip(ids_metric_learners, metric_learners) if idml[:4] in ['ITML', 'SDML', 'LSML']]) -def test_singular_covariance_init_or_prior(estimator, build_dataset): +def test_singular_covariance_init_or_prior_strictpd(estimator, build_dataset): """Tests that when using the 'covariance' init or prior, it returns the appropriate error if the covariance matrix is singular, for algorithms that need a strictly PD prior or init (see @@ -603,6 +603,45 @@ def test_singular_covariance_init_or_prior(estimator, build_dataset): assert str(raised_err.value) == msg +@pytest.mark.integration +@pytest.mark.parametrize('estimator, build_dataset', + [(ml, bd) for idml, (ml, bd) + in zip(ids_metric_learners, + metric_learners) + if idml[:3] in ['MMC']], + ids=[idml for idml, (ml, _) + in zip(ids_metric_learners, + metric_learners) + if idml[:3] in ['MMC']]) +def test_singular_covariance_init_of_non_strict_pd(estimator, build_dataset): + """Tests that when using the 'covariance' init or prior, it returns the + appropriate warning if the covariance matrix is singular, for algorithms + that don't need a strictly PD init. Also checks that the returned + inverse matrix has finite values + """ + input_data, labels, _, X = build_dataset() + model = clone(estimator) + set_random_state(model) + # We create a feature that is a linear combination of the first two + # features: + input_data = np.concatenate([input_data, input_data[:, ..., :2].dot([[2], + [3]])], + axis=-1) + model.set_params(init='covariance') + msg = ('The initialization matrix is not invertible: ' + 'using the pseudo-inverse instead.' + 'To make the inverse covariance matrix invertible' + ' you can remove any linearly dependent features and/or ' + 'reduce the dimensionality of your input, ' + 'for instance using `sklearn.decomposition.PCA` as a ' + 'preprocessing step.') + with pytest.warns(UserWarning) as raised_warning: + model.fit(input_data, labels) + assert np.any([str(warning.message) == msg for warning in raised_warning]) + M = model.get_mahalanobis_matrix() + assert np.isfinite(M).all() + + @pytest.mark.integration @pytest.mark.parametrize('estimator, build_dataset', [(ml, bd) for idml, (ml, bd) @@ -614,7 +653,7 @@ def test_singular_covariance_init_or_prior(estimator, build_dataset): metric_learners) if idml[:4] in ['ITML', 'SDML', 'LSML']]) @pytest.mark.parametrize('w0', [1e-20, 0., -1e-20]) -def test_singular_array_init_or_prior(estimator, build_dataset, w0): +def test_singular_array_init_or_prior_strictpd(estimator, build_dataset, w0): """Tests that when using a custom array init (or prior), it returns the appropriate error if it is singular, for algorithms that need a strictly PD prior or init (see @@ -654,6 +693,31 @@ def test_singular_array_init_or_prior(estimator, build_dataset, w0): assert str(raised_err.value) == msg +@pytest.mark.parametrize('w0', [1e-20, 0., -1e-20]) +def test_singular_array_init_of_non_strict_pd(w0): + """Tests that when using a custom array init, it returns the + appropriate warning if it is singular. Also checks if the returned + inverse matrix is finite. This isn't checked for model fitting as no + model curently uses this setting. + """ + rng = np.random.RandomState(42) + X, y = shuffle(*make_blobs(random_state=rng), + random_state=rng) + P = ortho_group.rvs(X.shape[1], random_state=rng) + w = np.abs(rng.randn(X.shape[1])) + w[0] = w0 + M = P.dot(np.diag(w)).dot(P.T) + msg = ('The initialization matrix is not invertible: ' + 'using the pseudo-inverse instead.') + with pytest.warns(UserWarning) as raised_warning: + _, M_inv = _initialize_metric_mahalanobis(X, init=M, + random_state=rng, + return_inverse=True, + strict_pd=False) + assert str(raised_warning[0].message) == msg + assert np.isfinite(M_inv).all() + + @pytest.mark.integration @pytest.mark.parametrize('estimator, build_dataset', metric_learners, ids=ids_metric_learners) From 57877347c8be0a449fcca87e6916843f5dc7b325 Mon Sep 17 00:00:00 2001 From: grudloff Date: Fri, 31 Jan 2020 15:46:13 +0100 Subject: [PATCH 09/17] more rephrasing --- metric_learn/_util.py | 4 ++-- test/test_mahalanobis_mixin.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/metric_learn/_util.py b/metric_learn/_util.py index 0edd7495..fa196a69 100644 --- a/metric_learn/_util.py +++ b/metric_learn/_util.py @@ -719,9 +719,9 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None, "require the `{}` to be strictly positive definite." .format(*((matrix_name,) * 2))) elif not cov_is_definite: - warnings.warn('The initialization matrix is not invertible: ' + warnings.warn('The covariance matrix is not invertible: ' 'using the pseudo-inverse instead.' - 'To make the inverse covariance matrix invertible' + 'To make the covariance matrix invertible' ' you can remove any linearly dependent features and/or ' 'reduce the dimensionality of your input, ' 'for instance using `sklearn.decomposition.PCA` as a ' diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py index 571bb9e6..de9804b5 100644 --- a/test/test_mahalanobis_mixin.py +++ b/test/test_mahalanobis_mixin.py @@ -628,9 +628,9 @@ def test_singular_covariance_init_of_non_strict_pd(estimator, build_dataset): [3]])], axis=-1) model.set_params(init='covariance') - msg = ('The initialization matrix is not invertible: ' + msg = ('The covariance matrix is not invertible: ' 'using the pseudo-inverse instead.' - 'To make the inverse covariance matrix invertible' + 'To make the covariance matrix invertible' ' you can remove any linearly dependent features and/or ' 'reduce the dimensionality of your input, ' 'for instance using `sklearn.decomposition.PCA` as a ' From 9e7ed19a3601192af611dfb8061ec5b695df50ba Mon Sep 17 00:00:00 2001 From: grudloff Date: Mon, 3 Feb 2020 10:14:09 +0100 Subject: [PATCH 10/17] fix test --- test/test_mahalanobis_mixin.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py index de9804b5..91fb435f 100644 --- a/test/test_mahalanobis_mixin.py +++ b/test/test_mahalanobis_mixin.py @@ -638,7 +638,10 @@ def test_singular_covariance_init_of_non_strict_pd(estimator, build_dataset): with pytest.warns(UserWarning) as raised_warning: model.fit(input_data, labels) assert np.any([str(warning.message) == msg for warning in raised_warning]) - M = model.get_mahalanobis_matrix() + M, _ = _initialize_metric_mahalanobis(X, init='covariance', + random_state=RNG, + return_inverse=True, + strict_pd=False) assert np.isfinite(M).all() From f6dd83a6ecef4cb3d7f0eae8fcb88cd0897286f2 Mon Sep 17 00:00:00 2001 From: grudloff Date: Mon, 3 Feb 2020 10:14:18 +0100 Subject: [PATCH 11/17] add test --- test/test_utils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index 3092e168..64093723 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,4 +1,5 @@ import pytest +from scipy.linalg import eigh, pinvh from collections import namedtuple import numpy as np from numpy.testing import assert_array_equal, assert_equal @@ -11,7 +12,7 @@ check_collapsed_pairs, validate_vector, _check_sdp_from_eigen, _check_n_components, check_y_valid_values_for_pairs, - _auto_select_init) + _auto_select_init, _pseudo_inverse_from_eig) from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA, LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised, MMC_Supervised, RCA_Supervised, SDML_Supervised, @@ -1146,3 +1147,11 @@ def test__auto_select_init(has_classes, n_features, n_samples, n_components, """Checks that the auto selection of the init works as expected""" assert (_auto_select_init(has_classes, n_features, n_samples, n_components, n_classes) == result) + + +def test_pseudo_inverse_from_eig_and_pinvh(): + """Checks that _pseudo_inverse_from_eig return the same result as + scipy.linalg.pinvh""" + A = np.random.rand(100, 100) + w, V = eigh(A, check_finite=False) + assert np.array_equal(_pseudo_inverse_from_eig(w, V), pinvh(A)) From 1e4664fe9ea1c16790d066eed4bf022b5b4a6df4 Mon Sep 17 00:00:00 2001 From: grudloff Date: Mon, 3 Feb 2020 11:19:18 +0100 Subject: [PATCH 12/17] fixes & adds singular pinv test fron eig --- test/test_utils.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 64093723..87b3df02 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1149,9 +1149,22 @@ def test__auto_select_init(has_classes, n_features, n_samples, n_components, n_samples, n_components, n_classes) == result) -def test_pseudo_inverse_from_eig_and_pinvh(): +@pytest.mark.parametrize('w0', [1e-20, 0., -1e-20]) +def test_pseudo_inverse_from_eig_and_pinvh_singular(w0): """Checks that _pseudo_inverse_from_eig return the same result as - scipy.linalg.pinvh""" + scipy.linalg.pinvh for a singular matrix""" A = np.random.rand(100, 100) + A = A + A.T + w, V = eigh(A) + w[0] = w0 + A = V.dot(np.diag(w)).dot(V.T) + np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A)) + + +def test_pseudo_inverse_from_eig_and_pinvh_nonsingular(): + """Checks that _pseudo_inverse_from_eig return the same result as + scipy.linalg.pinvh for a non singular matrix""" + A = np.random.rand(100, 100) + A = A + A.T w, V = eigh(A, check_finite=False) - assert np.array_equal(_pseudo_inverse_from_eig(w, V), pinvh(A)) + np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A)) From 5abab24f7e9aaf2791c5e4174ceb490f089f14a3 Mon Sep 17 00:00:00 2001 From: grudloff Date: Mon, 3 Feb 2020 13:46:41 +0100 Subject: [PATCH 13/17] fix tolerance of assert --- test/test_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index 87b3df02..8db264a0 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1167,4 +1167,5 @@ def test_pseudo_inverse_from_eig_and_pinvh_nonsingular(): A = np.random.rand(100, 100) A = A + A.T w, V = eigh(A, check_finite=False) - np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A)) + np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A), + rtol=1e-09) From 6b3b02ec1d866051d6d7f53c79d4a8575c40225e Mon Sep 17 00:00:00 2001 From: grudloff Date: Mon, 3 Feb 2020 14:05:39 +0100 Subject: [PATCH 14/17] fix tolerance of assert --- test/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 8db264a0..45311970 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1158,7 +1158,8 @@ def test_pseudo_inverse_from_eig_and_pinvh_singular(w0): w, V = eigh(A) w[0] = w0 A = V.dot(np.diag(w)).dot(V.T) - np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A)) + np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A), + rtol=1e-06) def test_pseudo_inverse_from_eig_and_pinvh_nonsingular(): @@ -1167,5 +1168,4 @@ def test_pseudo_inverse_from_eig_and_pinvh_nonsingular(): A = np.random.rand(100, 100) A = A + A.T w, V = eigh(A, check_finite=False) - np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A), - rtol=1e-09) + np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A)) From a5c476be2bf69cc0537606c5104b0069c3b8bd36 Mon Sep 17 00:00:00 2001 From: grudloff Date: Mon, 3 Feb 2020 14:19:41 +0100 Subject: [PATCH 15/17] fix tolerance of assert --- test/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index 45311970..f81b4794 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1159,7 +1159,7 @@ def test_pseudo_inverse_from_eig_and_pinvh_singular(w0): w[0] = w0 A = V.dot(np.diag(w)).dot(V.T) np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A), - rtol=1e-06) + rtol=1e-05) def test_pseudo_inverse_from_eig_and_pinvh_nonsingular(): From f2bc3e2987d6adb5abaca9d1b4d94bbbee7cebc2 Mon Sep 17 00:00:00 2001 From: grudloff Date: Tue, 4 Feb 2020 08:53:57 +0100 Subject: [PATCH 16/17] fix random seed --- test/test_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index f81b4794..1beb6324 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1151,8 +1151,9 @@ def test__auto_select_init(has_classes, n_features, n_samples, n_components, @pytest.mark.parametrize('w0', [1e-20, 0., -1e-20]) def test_pseudo_inverse_from_eig_and_pinvh_singular(w0): - """Checks that _pseudo_inverse_from_eig return the same result as + """Checks that _pseudo_inverse_from_eig returns the same result as scipy.linalg.pinvh for a singular matrix""" + np.random.seed(seed=SEED) A = np.random.rand(100, 100) A = A + A.T w, V = eigh(A) @@ -1163,8 +1164,9 @@ def test_pseudo_inverse_from_eig_and_pinvh_singular(w0): def test_pseudo_inverse_from_eig_and_pinvh_nonsingular(): - """Checks that _pseudo_inverse_from_eig return the same result as + """Checks that _pseudo_inverse_from_eig returns the same result as scipy.linalg.pinvh for a non singular matrix""" + np.random.seed(seed=SEED) A = np.random.rand(100, 100) A = A + A.T w, V = eigh(A, check_finite=False) From d1d7ec478a1126a047a19c7c1622ec53fe80fb23 Mon Sep 17 00:00:00 2001 From: grudloff Date: Tue, 4 Feb 2020 10:21:54 +0100 Subject: [PATCH 17/17] isolate random seed setting --- test/test_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 1beb6324..0ea871bb 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1153,8 +1153,8 @@ def test__auto_select_init(has_classes, n_features, n_samples, n_components, def test_pseudo_inverse_from_eig_and_pinvh_singular(w0): """Checks that _pseudo_inverse_from_eig returns the same result as scipy.linalg.pinvh for a singular matrix""" - np.random.seed(seed=SEED) - A = np.random.rand(100, 100) + rng = np.random.RandomState(SEED) + A = rng.rand(100, 100) A = A + A.T w, V = eigh(A) w[0] = w0 @@ -1166,8 +1166,8 @@ def test_pseudo_inverse_from_eig_and_pinvh_singular(w0): def test_pseudo_inverse_from_eig_and_pinvh_nonsingular(): """Checks that _pseudo_inverse_from_eig returns the same result as scipy.linalg.pinvh for a non singular matrix""" - np.random.seed(seed=SEED) - A = np.random.rand(100, 100) + rng = np.random.RandomState(SEED) + A = rng.rand(100, 100) A = A + A.T w, V = eigh(A, check_finite=False) np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A))