Skip to content

Commit 8525ba5

Browse files
authored
ENH speedup enet_coordinate_descent_gram (#31880)
1 parent 52fb066 commit 8525ba5

File tree

2 files changed

+26
-31
lines changed

2 files changed

+26
-31
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
- :class:`linear_model.ElasticNet`, :class:`linear_model.ElasticNetCV`,
2+
:class:`linear_model.Lasso` and :class:`linear_model.LassoCV` with `precompute=True`
3+
(or `precompute="auto"`` and `n_samples > n_features`) are faster to fit by
4+
avoiding a BLAS level 1 (axpy) call in the inner most loop.
5+
Same for functions :func:`linear_model.enet_path` and
6+
:func:`linear_model.lasso_path`.
7+
By :user:`Christian Lorentzen <lorentzenchr>`.

sklearn/linear_model/_cd_fast.pyx

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def sparse_enet_coordinate_descent(
337337
cdef unsigned int n_features = w.shape[0]
338338

339339
# compute norms of the columns of X
340-
cdef floating[:] norm_cols_X = np.zeros(n_features, dtype=dtype)
340+
cdef floating[::1] norm_cols_X = np.zeros(n_features, dtype=dtype)
341341

342342
# initial value of the residuals
343343
# R = y - Zw, weighted version R = sample_weight * (y - Zw)
@@ -609,9 +609,10 @@ def enet_coordinate_descent_gram(
609609
cdef unsigned int n_features = Q.shape[0]
610610

611611
# initial value "Q w" which will be kept of up to date in the iterations
612-
cdef floating[:] H = np.dot(Q, w)
612+
cdef floating[::1] Qw = np.dot(Q, w)
613+
cdef floating[::1] XtA = np.zeros(n_features, dtype=dtype)
614+
cdef floating y_norm2 = np.dot(y, y)
613615

614-
cdef floating[:] XtA = np.zeros(n_features, dtype=dtype)
615616
cdef floating tmp
616617
cdef floating w_ii
617618
cdef floating d_w_max
@@ -628,14 +629,6 @@ def enet_coordinate_descent_gram(
628629
cdef uint32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
629630
cdef uint32_t* rand_r_state = &rand_r_state_seed
630631

631-
cdef floating y_norm2 = np.dot(y, y)
632-
cdef floating* w_ptr = &w[0]
633-
cdef const floating* Q_ptr = &Q[0, 0]
634-
cdef const floating* q_ptr = &q[0]
635-
cdef floating* H_ptr = &H[0]
636-
cdef floating* XtA_ptr = &XtA[0]
637-
tol = tol * y_norm2
638-
639632
if alpha == 0:
640633
warnings.warn(
641634
"Coordinate descent without L1 regularization may "
@@ -644,6 +637,7 @@ def enet_coordinate_descent_gram(
644637
)
645638

646639
with nogil:
640+
tol *= y_norm2
647641
for n_iter in range(max_iter):
648642
w_max = 0.0
649643
d_w_max = 0.0
@@ -658,23 +652,19 @@ def enet_coordinate_descent_gram(
658652

659653
w_ii = w[ii] # Store previous value
660654

661-
if w_ii != 0.0:
662-
# H -= w_ii * Q[ii]
663-
_axpy(n_features, -w_ii, Q_ptr + ii * n_features, 1,
664-
H_ptr, 1)
665-
666-
tmp = q[ii] - H[ii]
655+
# if Q = X.T @ X then tmp = X[:,ii] @ (y - X @ w + X[:, ii] * w_ii)
656+
tmp = q[ii] - Qw[ii] + w_ii * Q[ii, ii]
667657

668658
if positive and tmp < 0:
669659
w[ii] = 0.0
670660
else:
671661
w[ii] = fsign(tmp) * fmax(fabs(tmp) - alpha, 0) \
672662
/ (Q[ii, ii] + beta)
673663

674-
if w[ii] != 0.0:
675-
# H += w[ii] * Q[ii] # Update H = X.T X w
676-
_axpy(n_features, w[ii], Q_ptr + ii * n_features, 1,
677-
H_ptr, 1)
664+
if w[ii] != 0.0 or w_ii != 0.0:
665+
# Qw += (w[ii] - w_ii) * Q[ii] # Update Qw = Q @ w
666+
_axpy(n_features, w[ii] - w_ii, &Q[ii, 0], 1,
667+
&Qw[0], 1)
678668

679669
# update the maximum absolute coefficient update
680670
d_w_ii = fabs(w[ii] - w_ii)
@@ -689,23 +679,21 @@ def enet_coordinate_descent_gram(
689679
# the tolerance: check the duality gap as ultimate stopping
690680
# criterion
691681

692-
# q_dot_w = np.dot(w, q)
693-
q_dot_w = _dot(n_features, w_ptr, 1, q_ptr, 1)
682+
# q_dot_w = w @ q
683+
q_dot_w = _dot(n_features, &w[0], 1, &q[0], 1)
694684

695685
for ii in range(n_features):
696-
XtA[ii] = q[ii] - H[ii] - beta * w[ii]
686+
XtA[ii] = q[ii] - Qw[ii] - beta * w[ii]
697687
if positive:
698-
dual_norm_XtA = max(n_features, XtA_ptr)
688+
dual_norm_XtA = max(n_features, &XtA[0])
699689
else:
700-
dual_norm_XtA = abs_max(n_features, XtA_ptr)
690+
dual_norm_XtA = abs_max(n_features, &XtA[0])
701691

702-
# temp = np.sum(w * H)
703-
tmp = 0.0
704-
for ii in range(n_features):
705-
tmp += w[ii] * H[ii]
692+
# temp = w @ Q @ w
693+
tmp = _dot(n_features, &w[0], 1, &Qw[0], 1)
706694
R_norm2 = y_norm2 + tmp - 2.0 * q_dot_w
707695

708-
# w_norm2 = np.dot(w, w)
696+
# w_norm2 = w @ w
709697
w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)
710698

711699
if (dual_norm_XtA > alpha):

0 commit comments

Comments
 (0)