Skip to content

Commit 12f4eb9

Browse files
jakevdpamueller
authored andcommitted
@jakevdp's version of pinvh
speed up symmetric_pinv
1 parent 602bac4 commit 12f4eb9

File tree

5 files changed

+35
-36
lines changed

5 files changed

+35
-36
lines changed

sklearn/covariance/empirical_covariance_.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from ..base import BaseEstimator
1919
from ..utils import array2d
20-
from ..utils.extmath import fast_logdet, symmetric_pinv
20+
from ..utils.extmath import fast_logdet, pinvh
2121

2222

2323
def log_likelihood(emp_cov, precision):
@@ -113,7 +113,7 @@ def _set_covariance(self, covariance):
113113
self.covariance_ = covariance
114114
# set precision
115115
if self.store_precision:
116-
self.precision_ = symmetric_pinv(covariance)
116+
self.precision_ = pinvh(covariance)
117117
else:
118118
self.precision_ = None
119119

@@ -129,7 +129,7 @@ def get_precision(self):
129129
if self.store_precision:
130130
precision = self.precision_
131131
else:
132-
precision = symmetric_pinv(self.covariance_)
132+
precision = pinvh(self.covariance_)
133133
return precision
134134

135135
def fit(self, X):

sklearn/covariance/graph_lasso_.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
EmpiricalCovariance, log_likelihood
1818

1919
from ..utils import ConvergenceWarning
20-
from ..utils.extmath import symmetric_pinv
20+
from ..utils.extmath import pinvh
2121
from ..linear_model import lars_path
2222
from ..linear_model import cd_fast
2323
from ..cross_validation import check_cv, cross_val_score
@@ -144,7 +144,7 @@ def graph_lasso(emp_cov, alpha, cov_init=None, mode='cd', tol=1e-4,
144144
covariance_ *= 0.95
145145
diagonal = emp_cov.flat[::n_features + 1]
146146
covariance_.flat[::n_features + 1] = diagonal
147-
precision_ = symmetric_pinv(covariance_)
147+
precision_ = pinvh(covariance_)
148148

149149
indices = np.arange(n_features)
150150
costs = list()

sklearn/covariance/robust_covariance.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from scipy.stats import chi2
1414

1515
from . import empirical_covariance, EmpiricalCovariance
16-
from ..utils.extmath import fast_logdet, symmetric_pinv
16+
from ..utils.extmath import fast_logdet, pinvh
1717
from ..utils import check_random_state
1818

1919

@@ -85,7 +85,7 @@ def c_step(X, n_support, remaining_iterations=30, initial_estimates=None,
8585
location = initial_estimates[0]
8686
covariance = initial_estimates[1]
8787
# run a special iteration for that case (to get an initial support)
88-
precision = symmetric_pinv(covariance)
88+
precision = pinvh(covariance)
8989
X_centered = X - location
9090
dist = (np.dot(X_centered, precision) * X_centered).sum(1)
9191
# compute new estimates
@@ -104,7 +104,7 @@ def c_step(X, n_support, remaining_iterations=30, initial_estimates=None,
104104
previous_det = det
105105
previous_support = support
106106
# compute a new support from the full data set mahalanobis distances
107-
precision = symmetric_pinv(covariance)
107+
precision = pinvh(covariance)
108108
X_centered = X - location
109109
dist = (np.dot(X_centered, precision) * X_centered).sum(axis=1)
110110
# compute new estimates
@@ -344,7 +344,7 @@ def fast_mcd(X, support_fraction=None,
344344
covariance = np.asarray([[np.var(X[support])]])
345345
location = np.array([location])
346346
# get precision matrix in an optimized way
347-
precision = symmetric_pinv(covariance)
347+
precision = pinvh(covariance)
348348
dist = (np.dot(X_centered, precision) \
349349
* (X_centered)).sum(axis=1)
350350

@@ -545,7 +545,7 @@ def fit(self, X):
545545
raw_covariance = self._nonrobust_covariance(
546546
X[raw_support], assume_centered=True)
547547
# get precision matrix in an optimized way
548-
precision = symmetric_pinv(raw_covariance)
548+
precision = pinvh(raw_covariance)
549549
raw_dist = np.sum(np.dot(X, precision) * X, 1)
550550
self.raw_location_ = raw_location
551551
self.raw_covariance_ = raw_covariance

sklearn/utils/extmath.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -299,61 +299,60 @@ def weighted_mode(a, w, axis=0):
299299
return mostfrequent, oldcounts
300300

301301

302-
def symmetric_pinv(a, cond=None, rcond=None):
303-
"""Compute the (Moore-Penrose) pseudo-inverse of a matrix.
302+
def pinvh(a, cond=None, rcond=None, lower=True):
303+
"""Compute the (Moore-Penrose) pseudo-inverse of a hermetian matrix.
304304
305305
Calculate a generalized inverse of a symmetric matrix using its
306306
eigenvalue decomposition and including all 'large' eigenvalues.
307307
308-
Inspired by ``scipy.linalg.pinv2``, credited to Pearu Peterson and Travis
309-
Oliphant.
310-
311308
Parameters
312309
----------
313310
a : array, shape (N, N)
314-
Symmetric matrix to be pseudo-inverted
311+
Real symmetric or complex hermetian matrix to be pseudo-inverted
315312
cond, rcond : float or None
316313
Cutoff for 'small' eigenvalues.
317314
Singular values smaller than rcond * largest_eigenvalue are considered
318315
zero.
319316
320317
If None or -1, suitable machine precision is used.
318+
lower : boolean
319+
Whether the pertinent array data is taken from the lower or upper
320+
triangle of a. (Default: lower)
321321
322322
Returns
323323
-------
324324
B : array, shape (N, N)
325325
326-
Raises LinAlgError if eigenvalue does not converge
326+
Raises
327+
------
328+
LinAlgError
329+
If eigenvalue does not converge
327330
328331
Examples
329332
--------
330333
>>> from numpy import *
331334
>>> a = random.randn(9, 6)
332335
>>> a = np.dot(a, a.T)
333-
>>> B = symmetric_pinv(a)
336+
>>> B = pinvh(a)
334337
>>> allclose(a, dot(a, dot(B, a)))
335338
True
336339
>>> allclose(B, dot(B, dot(a, B)))
337340
True
338341
339342
"""
340343
a = np.asarray_chkfinite(a)
341-
s, u = linalg.eigh(a)
342-
# eigh returns eigvals in reverse order, but this doesn't affect anything.
344+
s, u = linalg.eigh(a, lower=lower)
343345

344-
t = u.dtype.char
345346
if rcond is not None:
346347
cond = rcond
347348
if cond in [None, -1]:
348-
eps = np.finfo(np.float).eps
349-
feps = np.finfo(np.single).eps
350-
_array_precision = {'f': 0, 'd': 1, 'F': 0, 'D': 1}
351-
cond = {0: feps * 1e3, 1: eps * 1e6}[_array_precision[t]]
352-
n = a.shape[0]
353-
cutoff = cond * np.maximum.reduce(s)
354-
psigma = np.zeros(n, t)
355-
above_cutoff = np.where(s > cutoff)
356-
psigma[above_cutoff] = 1.0 / np.conjugate(s[above_cutoff])
357-
#XXX: use lapack/blas routines for dot
358-
#XXX: above comment is from scipy, but I (@vene)'ll take a look
359-
return np.transpose(np.conjugate(np.dot(u * psigma, u.T.conjugate())))
349+
t = u.dtype.char.lower()
350+
factor = {'f': 1E3, 'd': 1E6}
351+
cond = factor[t] * np.finfo(t).eps
352+
353+
# unlike svd case, eigh can lead to negative eigenvalues
354+
above_cutoff = (abs(s) > cond * np.max(abs(s)))
355+
psigma_diag = np.zeros_like(s)
356+
psigma_diag[above_cutoff] = 1.0 / s[above_cutoff]
357+
358+
return np.dot(u * psigma_diag, np.conjugate(u).T)

sklearn/utils/tests/test_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sklearn.utils import deprecated
1212
from sklearn.utils import resample
1313
from sklearn.utils import safe_mask
14-
from sklearn.utils.extmath import symmetric_pinv
14+
from sklearn.utils.extmath import pinvh
1515

1616

1717
def test_make_rng():
@@ -93,7 +93,7 @@ def test_safe_mask():
9393
assert_equal(X_csr[mask].shape[0], 3)
9494

9595

96-
def test_symmetric_pinv():
96+
def test_pinvh():
9797
a = np.random.randn(5, 3)
9898
a = np.dot(a, a.T) # symmetric singular matrix
99-
assert_almost_equal(pinv2(a), symmetric_pinv(a))
99+
assert_almost_equal(pinv2(a), pinvh(a))

0 commit comments

Comments
 (0)