Skip to content

Commit 692cd8b

Browse files
stevendbrownagramfort
authored andcommitted
[MRG+1] Reduce runtime of graph_lasso (#9858)
* reduce runtime of graph_lasso * fixed line length overrun * added comment explaining the change * changed explanation comment
1 parent 493f11c commit 692cd8b

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

sklearn/covariance/graph_lasso_.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,19 @@ def graph_lasso(emp_cov, alpha, cov_init=None, mode='cd', tol=1e-4,
203203
# be robust to the max_iter=0 edge case, see:
204204
# https://github.com/scikit-learn/scikit-learn/issues/4134
205205
d_gap = np.inf
206+
# set a sub_covariance buffer
207+
sub_covariance = np.ascontiguousarray(covariance_[1:, 1:])
206208
for i in range(max_iter):
207209
for idx in range(n_features):
208-
sub_covariance = np.ascontiguousarray(
209-
covariance_[indices != idx].T[indices != idx])
210+
# To keep the contiguous matrix `sub_covariance` equal to
211+
# covariance_[indices != idx].T[indices != idx]
212+
# we only need to update 1 column and 1 line when idx changes
213+
if idx > 0:
214+
di = idx - 1
215+
sub_covariance[di] = covariance_[di][indices != idx]
216+
sub_covariance[:, di] = covariance_[:, di][indices != idx]
217+
else:
218+
sub_covariance[:] = covariance_[1:, 1:]
210219
row = emp_cov[idx, indices != idx]
211220
with np.errstate(**errors):
212221
if mode == 'cd':

0 commit comments

Comments
 (0)