Commit 0946f59

veneamueller authored and committed

Compute pseudoinverse using eigendecomposition

1 parent 8d103b5 commit 0946f59
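
The gist of the change: `scipy.linalg.pinv`/`pinv2` compute a pseudoinverse through an SVD, but every call site touched by this commit passes a symmetric (covariance) matrix, for which an eigendecomposition (a = U diag(s) U^T) suffices and is typically cheaper. A minimal standalone sketch of the construction (illustrative only, not the `symmetric_pinv` helper added below):

import numpy as np
from scipy import linalg

rng = np.random.RandomState(0)
x = rng.randn(5, 3)
a = np.dot(x, x.T)  # symmetric PSD matrix of rank 3, shape (5, 5)

# For symmetric a: a = U diag(s) U^T, with real eigenvalues s.
s, u = linalg.eigh(a)

# Invert only the 'large' eigenvalues, zero out the rest.
cutoff = 1e-10 * s.max()
s_inv = np.array([1. / si if si > cutoff else 0. for si in s])

# Pseudoinverse: U diag(s_inv) U^T (u * s_inv scales the columns of U).
a_pinv = np.dot(u * s_inv, u.T)

assert np.allclose(a_pinv, np.linalg.pinv(a))  # matches the SVD route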

5 files changed (+88 lines, -13 lines)

sklearn/covariance/empirical_covariance_.py (3 additions, 3 deletions)

@@ -17,7 +17,7 @@
 
 from ..base import BaseEstimator
 from ..utils import array2d
-from ..utils.extmath import fast_logdet
+from ..utils.extmath import fast_logdet, symmetric_pinv
 
 
 def log_likelihood(emp_cov, precision):
@@ -113,7 +113,7 @@ def _set_covariance(self, covariance):
         self.covariance_ = covariance
         # set precision
         if self.store_precision:
-            self.precision_ = linalg.pinv(covariance)
+            self.precision_ = symmetric_pinv(covariance)
         else:
             self.precision_ = None
 
@@ -129,7 +129,7 @@ def get_precision(self):
         if self.store_precision:
             precision = self.precision_
         else:
-            precision = linalg.pinv(self.covariance_)
+            precision = symmetric_pinv(self.covariance_)
         return precision
 
     def fit(self, X):
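
Both hunks swap `linalg.pinv` for the new helper on `self.covariance_`, which is symmetric by construction, so the replacement is behavior-preserving up to numerical tolerance. A small usage sketch of the caching pattern `get_precision` implements (illustrative; the equal-within-tolerance assertion is an assumption):

import numpy as np
from sklearn.covariance import EmpiricalCovariance

rng = np.random.RandomState(42)
X = rng.randn(50, 3)

# store_precision=True caches precision_ at fit time ...
cached = EmpiricalCovariance(store_precision=True).fit(X).get_precision()

# ... store_precision=False recomputes it from covariance_ on demand.
on_demand = EmpiricalCovariance(store_precision=False).fit(X).get_precision()

assert np.allclose(cached, on_demand)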

sklearn/covariance/graph_lasso_.py (2 additions, 1 deletion)

@@ -17,6 +17,7 @@
     EmpiricalCovariance, log_likelihood
 
 from ..utils import ConvergenceWarning
+from ..utils.extmath import symmetric_pinv
 from ..linear_model import lars_path
 from ..linear_model import cd_fast
 from ..cross_validation import check_cv, cross_val_score
@@ -143,7 +144,7 @@ def graph_lasso(emp_cov, alpha, cov_init=None, mode='cd', tol=1e-4,
     covariance_ *= 0.95
     diagonal = emp_cov.flat[::n_features + 1]
     covariance_.flat[::n_features + 1] = diagonal
-    precision_ = linalg.pinv(covariance_)
+    precision_ = symmetric_pinv(covariance_)
 
     indices = np.arange(n_features)
     costs = list()

sklearn/covariance/robust_covariance.py (10 additions, 8 deletions)

@@ -13,7 +13,7 @@
 from scipy.stats import chi2
 
 from . import empirical_covariance, EmpiricalCovariance
-from ..utils.extmath import fast_logdet
+from ..utils.extmath import fast_logdet, symmetric_pinv
 from ..utils import check_random_state
 
 
@@ -85,7 +85,7 @@ def c_step(X, n_support, remaining_iterations=30, initial_estimates=None,
         location = initial_estimates[0]
         covariance = initial_estimates[1]
         # run a special iteration for that case (to get an initial support)
-        precision = linalg.pinv(covariance)
+        precision = symmetric_pinv(covariance)
         X_centered = X - location
         dist = (np.dot(X_centered, precision) * X_centered).sum(1)
         # compute new estimates
@@ -98,15 +98,15 @@
     # Iterative procedure for Minimum Covariance Determinant computation
     det = fast_logdet(covariance)
     while (det < previous_det) and (remaining_iterations > 0):
-        # compute a new support from the full data set mahalanobis distances
-        precision = linalg.pinv(covariance)
-        X_centered = X - location
-        dist = (np.dot(X_centered, precision) * X_centered).sum(axis=1)
         # save old estimates values
         previous_location = location
         previous_covariance = covariance
         previous_det = det
         previous_support = support
+        # compute a new support from the full data set mahalanobis distances
+        precision = symmetric_pinv(covariance)
+        X_centered = X - location
+        dist = (np.dot(X_centered, precision) * X_centered).sum(axis=1)
         # compute new estimates
         support = np.zeros(n_samples).astype(bool)
         support[np.argsort(dist)[:n_support]] = True
@@ -343,7 +343,8 @@ def fast_mcd(X, support_fraction=None,
         support[np.argsort(np.abs(X - location), axis=0)[:n_support]] = True
         covariance = np.asarray([[np.var(X[support])]])
         location = np.array([location])
-        dist = (np.dot(X_centered, linalg.pinv(covariance)) \
+        precision = symmetric_pinv(covariance)
+        dist = (np.dot(X_centered, precision) \
                * (X_centered)).sum(axis=1)
 
     ### Starting FastMCD algorithm for p-dimensional case
@@ -542,7 +543,8 @@ def fit(self, X):
             raw_location = np.zeros(n_features)
             raw_covariance = self._nonrobust_covariance(
                 X[raw_support], assume_centered=True)
-            raw_dist = np.sum(np.dot(X, linalg.pinv(raw_covariance)) * X, 1)
+            precision = symmetric_pinv(raw_covariance)
+            raw_dist = np.sum(np.dot(X, precision) * X, 1)
         self.raw_location_ = raw_location
         self.raw_covariance_ = raw_covariance
         self.raw_support_ = raw_support
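
Beyond the `pinv` swap, the second hunk also moves the "save old estimates" block ahead of the distance computation, a pure reordering for readability. The distance expression itself vectorizes the per-sample squared Mahalanobis distance x^T P x; a quick sketch of that equivalence (illustrative, using the helper this commit adds):

import numpy as np
from sklearn.utils.extmath import symmetric_pinv

rng = np.random.RandomState(0)
X = rng.randn(10, 3)
X_centered = X - X.mean(axis=0)
covariance = np.dot(X_centered.T, X_centered) / len(X)
precision = symmetric_pinv(covariance)

# Vectorized squared Mahalanobis distances, as in c_step above.
dist = (np.dot(X_centered, precision) * X_centered).sum(axis=1)

# Row-by-row quadratic form x^T P x gives the same values.
expected = np.array([np.dot(np.dot(xi, precision), xi) for xi in X_centered])
assert np.allclose(dist, expected)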

sklearn/utils/extmath.py (62 additions, 0 deletions)

@@ -297,3 +297,65 @@ def weighted_mode(a, w, axis=0):
         oldcounts = np.maximum(counts, oldcounts)
         oldmostfreq = mostfrequent
     return mostfrequent, oldcounts
+
+
+def symmetric_pinv(a, cond=None, rcond=None):
+    """Compute the (Moore-Penrose) pseudo-inverse of a matrix.
+
+    Calculate a generalized inverse of a symmetric matrix using its
+    eigenvalue decomposition and including all 'large' eigenvalues.
+
+    Inspired by ``scipy.linalg.pinv2``, credited to Pearu Peterson and Travis
+    Oliphant.
+
+    Parameters
+    ----------
+    a : array, shape (N, N)
+        Symmetric matrix to be pseudo-inverted
+    cond, rcond : float or None
+        Cutoff for 'small' eigenvalues.
+        Eigenvalues smaller than ``rcond * largest_eigenvalue`` are
+        considered zero.
+
+        If None or -1, suitable machine precision is used.
+
+    Returns
+    -------
+    B : array, shape (N, N)
+
+    Raises LinAlgError if the eigenvalue computation does not converge.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> a = np.random.randn(9, 6)
+    >>> a = np.dot(a, a.T)
+    >>> B = symmetric_pinv(a)
+    >>> np.allclose(a, np.dot(a, np.dot(B, a)))
+    True
+    >>> np.allclose(B, np.dot(B, np.dot(a, B)))
+    True
+
+    """
+    a = np.asarray_chkfinite(a)
+    s, u = linalg.eigh(a)
+    # eigh returns eigenvalues in ascending order (the reverse of svd),
+    # but this doesn't affect anything here.
+
+    t = u.dtype.char
+    if rcond is not None:
+        cond = rcond
+    if cond in [None, -1]:
+        eps = np.finfo(np.float).eps
+        feps = np.finfo(np.single).eps
+        _array_precision = {'f': 0, 'd': 1, 'F': 0, 'D': 1}
+        cond = {0: feps * 1e3, 1: eps * 1e6}[_array_precision[t]]
+    n = a.shape[0]
+    cutoff = cond * np.maximum.reduce(s)
+    psigma = np.zeros((n, n), t)
+    for i in range(len(s)):
+        if s[i] > cutoff:
+            psigma[i, i] = 1.0 / np.conjugate(s[i])
+    # XXX: use lapack/blas routines for dot
+    # XXX: above comment is from scipy, but I (@vene)'ll take a look
+    return np.transpose(np.conjugate(np.dot(np.dot(u, psigma),
+                                            u.T.conjugate())))
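
Note that the helper keeps only eigenvalues above the cutoff (`s[i] > cutoff`), so it implicitly assumes a positive semi-definite input such as a covariance matrix; `pinv2`, which thresholds singular values, can differ on indefinite symmetric matrices. A quick check on a PSD input (illustrative; `pinv2` was the current SciPy API at the time):

import numpy as np
from scipy.linalg import pinv2
from sklearn.utils.extmath import symmetric_pinv

rng = np.random.RandomState(1)
x = rng.randn(9, 6)
a = np.dot(x, x.T)  # symmetric PSD, rank at most 6, shape (9, 9)

B = symmetric_pinv(a)

# Moore-Penrose conditions hold, and the result matches the SVD route.
assert np.allclose(a, np.dot(a, np.dot(B, a)))
assert np.allclose(B, np.dot(B, np.dot(a, B)))
assert np.allclose(B, pinv2(a))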

sklearn/utils/tests/test_utils.py (11 additions, 1 deletion)

@@ -1,13 +1,17 @@
-from nose.tools import assert_equal, assert_raises, assert_true
 import warnings
 
 import numpy as np
 import scipy.sparse as sp
+from scipy.linalg import pinv2
+
+from nose.tools import assert_equal, assert_raises, assert_true
+from numpy.testing import assert_almost_equal
 
 from sklearn.utils import check_random_state
 from sklearn.utils import deprecated
 from sklearn.utils import resample
 from sklearn.utils import safe_mask
+from sklearn.utils.extmath import symmetric_pinv
 
 
 def test_make_rng():
@@ -87,3 +91,9 @@ def test_safe_mask():
 
     mask = safe_mask(X_csr, mask)
     assert_equal(X_csr[mask].shape[0], 3)
+
+
+def test_symmetric_pinv():
+    a = np.random.randn(5, 3)
+    a = np.dot(a, a.T)  # symmetric singular matrix
+    assert_almost_equal(pinv2(a), symmetric_pinv(a))
