
Commit 00acd12

ENH speedup coordinate descent by avoiding calls to axpy in innermost loop (scikit-learn#31956)
1 parent 5736956 commit 00acd12

3 files changed: +77, -101 lines changed
Lines changed: 6 additions & 4 deletions
@@ -1,7 +1,9 @@
 - :class:`linear_model.ElasticNet`, :class:`linear_model.ElasticNetCV`,
-  :class:`linear_model.Lasso` and :class:`linear_model.LassoCV` with `precompute=True`
-  (or `precompute="auto"` and `n_samples > n_features`) are faster to fit by
-  avoiding a BLAS level 1 (axpy) call in the inner most loop.
+  :class:`linear_model.Lasso`, :class:`linear_model.LassoCV`,
+  :class:`linear_model.MultiTaskElasticNet`,
+  :class:`linear_model.MultiTaskElasticNetCV`,
+  :class:`linear_model.MultiTaskLasso` and :class:`linear_model.MultiTaskLassoCV`
+  are faster to fit by avoiding a BLAS level 1 (axpy) call in the innermost loop.
   Same for functions :func:`linear_model.enet_path` and
   :func:`linear_model.lasso_path`.
-  By :user:`Christian Lorentzen <lorentzenchr>`.
+  By :user:`Christian Lorentzen <lorentzenchr>` :pr:`31956` and
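The whole optimization rests on one algebraic identity: instead of adding w_j * X[:,j] back into the residual (one axpy), taking a dot product, and subtracting the new contribution (a second axpy), the correlation can be computed as X[:,j] @ (R + w_j * X[:,j]) = X[:,j] @ R + w_j * ||X[:,j]||**2, and the residual then corrected once by the coefficient delta. A minimal NumPy sketch of one dense coordinate update (illustrative only, not the Cython code; soft_threshold and cd_update are names local to this sketch):

import numpy as np

def soft_threshold(x, alpha):
    # S(x, alpha) = sign(x) * max(|x| - alpha, 0)
    return np.sign(x) * np.maximum(np.abs(x) - alpha, 0.0)

def cd_update(X, R, w, j, alpha, beta, norm2_cols_X):
    # One elastic-net coordinate update, axpy-avoiding variant.
    w_j = w[j]  # store previous value
    # tmp = X[:, j] @ (R + w_j * X[:, j]) without forming the sum explicitly
    tmp = X[:, j] @ R + w_j * norm2_cols_X[j]
    w[j] = soft_threshold(tmp, alpha) / (norm2_cols_X[j] + beta)
    if w[j] != w_j:
        R -= (w[j] - w_j) * X[:, j]  # single residual correction

Compared with the old sequence (axpy in, dot, axpy out), this costs one dot plus at most one axpy per coordinate, and the axpy is skipped entirely whenever the coefficient does not change.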

sklearn/linear_model/_cd_fast.pyx

Lines changed: 70 additions & 96 deletions
@@ -329,22 +329,18 @@ def enet_coordinate_descent(
 
                 w_j = w[j]  # Store previous value
 
-                if w_j != 0.0:
-                    # R += w_j * X[:,j]
-                    _axpy(n_samples, w_j, &X[0, j], 1, &R[0], 1)
-
-                # tmp = (X[:,j]*R).sum()
-                tmp = _dot(n_samples, &X[0, j], 1, &R[0], 1)
+                # tmp = X[:,j] @ (R + w_j * X[:,j])
+                tmp = _dot(n_samples, &X[0, j], 1, &R[0], 1) + w_j * norm2_cols_X[j]
 
                 if positive and tmp < 0:
                     w[j] = 0.0
                 else:
                     w[j] = (fsign(tmp) * fmax(fabs(tmp) - alpha, 0)
                             / (norm2_cols_X[j] + beta))
 
-                if w[j] != 0.0:
-                    # R -= w[j] * X[:,j] # Update residual
-                    _axpy(n_samples, -w[j], &X[0, j], 1, &R[0], 1)
+                if w[j] != w_j:
+                    # R -= (w[j] - w_j) * X[:,j] # Update residual
+                    _axpy(n_samples, w_j - w[j], &X[0, j], 1, &R[0], 1)
 
                 # update the maximum absolute coefficient update
                 d_w_j = fabs(w[j] - w_j)
@@ -450,7 +446,7 @@ def sparse_enet_coordinate_descent(
     # We work with:
     # yw = sample_weight * y
     # R = sample_weight * residual
-    # norm_cols_X = np.sum(sample_weight * (X - X_mean)**2, axis=0)
+    # norm2_cols_X = np.sum(sample_weight * (X - X_mean)**2, axis=0)
 
     if floating is float:
         dtype = np.float32
@@ -461,8 +457,8 @@ def sparse_enet_coordinate_descent(
     cdef unsigned int n_samples = y.shape[0]
     cdef unsigned int n_features = w.shape[0]
 
-    # compute norms of the columns of X
-    cdef floating[::1] norm_cols_X = np.zeros(n_features, dtype=dtype)
+    # compute squared norms of the columns of X
+    cdef floating[::1] norm2_cols_X = np.zeros(n_features, dtype=dtype)
 
     # initial value of the residuals
     # R = y - Zw, weighted version R = sample_weight * (y - Zw)
@@ -523,7 +519,7 @@ def sparse_enet_coordinate_descent(
             for jj in range(startptr, endptr):
                 normalize_sum += (X_data[jj] - X_mean_ii) ** 2
                 R[X_indices[jj]] -= X_data[jj] * w_ii
-            norm_cols_X[ii] = normalize_sum + \
+            norm2_cols_X[ii] = normalize_sum + \
                 (n_samples - endptr + startptr) * X_mean_ii ** 2
             if center:
                 for jj in range(n_samples):
@@ -542,7 +538,7 @@ def sparse_enet_coordinate_descent(
                 normalize_sum += sample_weight[jj] * X_mean_ii ** 2
                 R[jj] += sample_weight[jj] * X_mean_ii * w_ii
                 R_sum += R[jj]
-            norm_cols_X[ii] = normalize_sum
+            norm2_cols_X[ii] = normalize_sum
         startptr = endptr
 
     # Note: No need to update R_sum from here on because the update terms cancel
@@ -564,34 +560,19 @@ def sparse_enet_coordinate_descent(
                 else:
                     ii = f_iter
 
-                if norm_cols_X[ii] == 0.0:
+                if norm2_cols_X[ii] == 0.0:
                     continue
 
                 startptr = X_indptr[ii]
                 endptr = X_indptr[ii + 1]
                 w_ii = w[ii]  # Store previous value
                 X_mean_ii = X_mean[ii]
 
-                if w_ii != 0.0:
-                    # R += w_ii * X[:,ii]
-                    if no_sample_weights:
-                        for jj in range(startptr, endptr):
-                            R[X_indices[jj]] += X_data[jj] * w_ii
-                        if center:
-                            for jj in range(n_samples):
-                                R[jj] -= X_mean_ii * w_ii
-                    else:
-                        for jj in range(startptr, endptr):
-                            tmp = sample_weight[X_indices[jj]]
-                            R[X_indices[jj]] += tmp * X_data[jj] * w_ii
-                        if center:
-                            for jj in range(n_samples):
-                                R[jj] -= sample_weight[jj] * X_mean_ii * w_ii
-
-                # tmp = (X[:,ii] * R).sum()
+                # tmp = X[:,ii] @ (R + w_ii * X[:,ii])
                 tmp = 0.0
                 for jj in range(startptr, endptr):
                     tmp += R[X_indices[jj]] * X_data[jj]
+                tmp += w_ii * norm2_cols_X[ii]
 
                 if center:
                     tmp -= R_sum * X_mean_ii
@@ -600,23 +581,23 @@ def sparse_enet_coordinate_descent(
                     w[ii] = 0.0
                 else:
                     w[ii] = fsign(tmp) * fmax(fabs(tmp) - alpha, 0) \
-                        / (norm_cols_X[ii] + beta)
+                        / (norm2_cols_X[ii] + beta)
 
-                if w[ii] != 0.0:
-                    # R -= w[ii] * X[:,ii] # Update residual
+                if w[ii] != w_ii:
+                    # R -= (w[ii] - w_ii) * X[:,ii] # Update residual
                     if no_sample_weights:
                         for jj in range(startptr, endptr):
-                            R[X_indices[jj]] -= X_data[jj] * w[ii]
+                            R[X_indices[jj]] -= X_data[jj] * (w[ii] - w_ii)
                         if center:
                             for jj in range(n_samples):
-                                R[jj] += X_mean_ii * w[ii]
+                                R[jj] += X_mean_ii * (w[ii] - w_ii)
                     else:
                         for jj in range(startptr, endptr):
-                            tmp = sample_weight[X_indices[jj]]
-                            R[X_indices[jj]] -= tmp * X_data[jj] * w[ii]
+                            kk = X_indices[jj]
+                            R[kk] -= sample_weight[kk] * X_data[jj] * (w[ii] - w_ii)
                         if center:
                             for jj in range(n_samples):
-                                R[jj] += sample_weight[jj] * X_mean_ii * w[ii]
+                                R[jj] += sample_weight[jj] * X_mean_ii * (w[ii] - w_ii)
 
                 # update the maximum absolute coefficient update
                 d_w_ii = fabs(w[ii] - w_ii)
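The same delta correction carries over to the sparse path, where it touches only the stored nonzeros of column ii (plus one dense pass over R when centering is active). A rough scipy.sparse sketch of the weighted branch, ignoring centering (helper name and signature are local to this illustration):

import numpy as np
from scipy import sparse

def sparse_residual_correction(X_csc, R, sample_weight, w_new, w_old, ii):
    # R -= sample_weight * (w_new - w_old) * X[:, ii], visiting only the
    # nonzeros stored for column ii of the CSC matrix.
    delta = w_new - w_old
    if delta == 0.0:
        return
    start, end = X_csc.indptr[ii], X_csc.indptr[ii + 1]
    rows = X_csc.indices[start:end]
    R[rows] -= sample_weight[rows] * X_csc.data[start:end] * delta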
@@ -744,10 +725,13 @@ def enet_coordinate_descent_gram(
     cdef floating w_max
     cdef floating d_w_ii
     cdef floating q_dot_w
-    cdef floating w_norm2
     cdef floating gap = tol + 1.0
     cdef floating d_w_tol = tol
     cdef floating dual_norm_XtA
+    cdef floating R_norm2
+    cdef floating w_norm2
+    cdef floating A_norm2
+    cdef floating const_
     cdef unsigned int ii
     cdef unsigned int n_iter = 0
     cdef unsigned int f_iter
@@ -786,7 +770,7 @@ def enet_coordinate_descent_gram(
                 w[ii] = fsign(tmp) * fmax(fabs(tmp) - alpha, 0) \
                     / (Q[ii, ii] + beta)
 
-            if w[ii] != 0.0 or w_ii != 0.0:
+            if w[ii] != w_ii:
                 # Qw += (w[ii] - w_ii) * Q[ii] # Update Qw = Q @ w
                 _axpy(n_features, w[ii] - w_ii, &Q[ii, 0], 1,
                       &Qw[0], 1)
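In the Gram variant the residual is replaced by the running product Qw = Q @ w, and the guard now compares against the previous value: the old condition `w[ii] != 0.0 or w_ii != 0.0` fired even when the coefficient kept the same nonzero value. A minimal sketch of the invariant being maintained (hypothetical helper name):

import numpy as np

def keep_qw_in_sync(Q, Qw, w, w_ii_old, ii):
    # Maintain Qw == Q @ w after coordinate ii may have changed.
    if w[ii] != w_ii_old:
        Qw += (w[ii] - w_ii_old) * Q[ii]  # one rank-1 update (axpy)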
@@ -899,6 +883,12 @@ def enet_coordinate_descent_multi_task(
     cdef unsigned int n_features = X.shape[1]
     cdef unsigned int n_tasks = Y.shape[1]
 
+    # compute squared norms of the columns of X
+    # same as norm2_cols_X = np.square(X).sum(axis=0)
+    cdef floating[::1] norm2_cols_X = np.einsum(
+        "ij,ij->j", X, X, dtype=dtype, order="C"
+    )
+
     # to store XtA
     cdef floating[:, ::1] XtA = np.zeros((n_features, n_tasks), dtype=dtype)
     cdef floating XtA_axis1norm
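As a quick standalone check of the einsum expression added above: it contracts each column of X with itself, matching np.square(X).sum(axis=0) without materializing the squared matrix as a temporary.

import numpy as np

rng = np.random.default_rng(0)
X = np.asfortranarray(rng.standard_normal((50, 8)))

norm2_einsum = np.einsum("ij,ij->j", X, X)  # sum_i X[i, j] ** 2 per column
norm2_naive = np.square(X).sum(axis=0)      # allocates a (50, 8) temporary
assert np.allclose(norm2_einsum, norm2_naive)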
@@ -907,7 +897,6 @@ def enet_coordinate_descent_multi_task(
     # initial value of the residuals
     cdef floating[::1, :] R = np.zeros((n_samples, n_tasks), dtype=dtype, order='F')
 
-    cdef floating[::1] norm_cols_X = np.zeros(n_features, dtype=dtype)
     cdef floating[::1] tmp = np.zeros(n_tasks, dtype=dtype)
     cdef floating[::1] w_ii = np.zeros(n_tasks, dtype=dtype)
     cdef floating d_w_max
@@ -917,8 +906,8 @@ def enet_coordinate_descent_multi_task(
     cdef floating W_ii_abs_max
     cdef floating gap = tol + 1.0
     cdef floating d_w_tol = tol
-    cdef floating R_norm
-    cdef floating w_norm
+    cdef floating R_norm2
+    cdef floating w_norm2
     cdef floating ry_sum
     cdef floating l21_norm
     cdef unsigned int ii
@@ -928,30 +917,23 @@ def enet_coordinate_descent_multi_task(
     cdef uint32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
     cdef uint32_t* rand_r_state = &rand_r_state_seed
 
-    cdef const floating* X_ptr = &X[0, 0]
-    cdef const floating* Y_ptr = &Y[0, 0]
-
     if l1_reg == 0:
         warnings.warn(
             "Coordinate descent with l1_reg=0 may lead to unexpected"
             " results and is discouraged."
         )
 
     with nogil:
-        # norm_cols_X = (np.asarray(X) ** 2).sum(axis=0)
-        for ii in range(n_features):
-            norm_cols_X[ii] = _nrm2(n_samples, X_ptr + ii * n_samples, 1) ** 2
-
         # R = Y - np.dot(X, W.T)
-        _copy(n_samples * n_tasks, Y_ptr, 1, &R[0, 0], 1)
+        _copy(n_samples * n_tasks, &Y[0, 0], 1, &R[0, 0], 1)
         for ii in range(n_features):
            for jj in range(n_tasks):
                if W[jj, ii] != 0:
-                    _axpy(n_samples, -W[jj, ii], X_ptr + ii * n_samples, 1,
+                    _axpy(n_samples, -W[jj, ii], &X[0, ii], 1,
                           &R[0, jj], 1)
 
         # tol = tol * linalg.norm(Y, ord='fro') ** 2
-        tol = tol * _nrm2(n_samples * n_tasks, Y_ptr, 1) ** 2
+        tol = tol * _nrm2(n_samples * n_tasks, &Y[0, 0], 1) ** 2
 
         for n_iter in range(max_iter):
             w_max = 0.0
@@ -962,54 +944,47 @@ def enet_coordinate_descent_multi_task(
                 else:
                     ii = f_iter
 
-                if norm_cols_X[ii] == 0.0:
+                if norm2_cols_X[ii] == 0.0:
                     continue
 
                 # w_ii = W[:, ii] # Store previous value
                 _copy(n_tasks, &W[0, ii], 1, &w_ii[0], 1)
 
-                # Using Numpy:
-                # R += np.dot(X[:, ii][:, None], w_ii[None, :]) # rank 1 update
-                # Using Blas Level2:
-                # _ger(RowMajor, n_samples, n_tasks, 1.0,
-                #      &X[0, ii], 1,
-                #      &w_ii[0], 1, &R[0, 0], n_tasks)
-                # Using Blas Level1 and for loop to avoid slower threads
-                # for such small vectors
-                for jj in range(n_tasks):
-                    if w_ii[jj] != 0:
-                        _axpy(n_samples, w_ii[jj], X_ptr + ii * n_samples, 1,
-                              &R[0, jj], 1)
-
-                # Using numpy:
-                # tmp = np.dot(X[:, ii][None, :], R).ravel()
-                # Using BLAS Level 2:
-                # _gemv(RowMajor, Trans, n_samples, n_tasks, 1.0, &R[0, 0],
-                #       n_tasks, &X[0, ii], 1, 0.0, &tmp[0], 1)
+                # tmp = X[:, ii] @ (R + w_ii * X[:,ii][:, None])
+                # first part: X[:, ii] @ R
+                # Using BLAS Level 2:
+                # _gemv(RowMajor, Trans, n_samples, n_tasks, 1.0, &R[0, 0],
+                #       n_tasks, &X[0, ii], 1, 0.0, &tmp[0], 1)
+                # second part: (X[:, ii] @ X[:,ii]) * w_ii = norm2_cols * w_ii
+                # Using BLAS Level 1:
+                # _axpy(n_tasks, norm2_cols[ii], &w_ii[0], 1, &tmp[0], 1)
                 # Using BLAS Level 1 (faster for small vectors like here):
                 for jj in range(n_tasks):
-                    tmp[jj] = _dot(n_samples, X_ptr + ii * n_samples, 1,
-                                   &R[0, jj], 1)
+                    tmp[jj] = _dot(n_samples, &X[0, ii], 1, &R[0, jj], 1)
+                    # As we have the loop already, we use it to replace the second BLAS
+                    # Level 1, i.e., _axpy, too.
+                    tmp[jj] += w_ii[jj] * norm2_cols_X[ii]
 
                 # nn = sqrt(np.sum(tmp ** 2))
                 nn = _nrm2(n_tasks, &tmp[0], 1)
 
-                # W[:, ii] = tmp * fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg)
+                # W[:, ii] = tmp * fmax(1. - l1_reg / nn, 0) / (norm2_cols_X[ii] + l2_reg)
                 _copy(n_tasks, &tmp[0], 1, &W[0, ii], 1)
-                _scal(n_tasks, fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg),
+                _scal(n_tasks, fmax(1. - l1_reg / nn, 0) / (norm2_cols_X[ii] + l2_reg),
                       &W[0, ii], 1)
 
+                # Update residual
                 # Using numpy:
-                # R -= np.dot(X[:, ii][:, None], W[:, ii][None, :])
-                # Using BLAS Level 2:
-                # Update residual : rank 1 update
-                # _ger(RowMajor, n_samples, n_tasks, -1.0,
-                #      &X[0, ii], 1, &W[0, ii], 1,
-                #      &R[0, 0], n_tasks)
+                # R -= (W[:, ii] - w_ii) * X[:, ii][:, None]
+                # Using BLAS Level 1 and 2:
+                # _axpy(n_tasks, -1.0, &W[0, ii], 1, &w_ii[0], 1)
+                # _ger(RowMajor, n_samples, n_tasks, 1.0,
+                #      &X[0, ii], 1, &w_ii, 1,
+                #      &R[0, 0], n_tasks)
                 # Using BLAS Level 1 (faster for small vectors like here):
                 for jj in range(n_tasks):
-                    if W[jj, ii] != 0:
-                        _axpy(n_samples, -W[jj, ii], X_ptr + ii * n_samples, 1,
+                    if W[jj, ii] != w_ii[jj]:
+                        _axpy(n_samples, w_ii[jj] - W[jj, ii], &X[0, ii], 1,
                               &R[0, jj], 1)
 
                 # update the maximum absolute coefficient update
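For the multi-task case the identity is applied per task, followed by the group soft-threshold on the length-n_tasks vector tmp. A condensed NumPy rendering of one coordinate update (a sketch only, with W of shape (n_tasks, n_features) and R of shape (n_samples, n_tasks); mt_cd_update is a name local to this illustration):

import numpy as np

def mt_cd_update(X, R, W, ii, l1_reg, l2_reg, norm2_cols_X):
    # One multi-task elastic-net coordinate update, axpy-avoiding variant.
    w_ii = W[:, ii].copy()  # store previous value
    # tmp[jj] = X[:, ii] @ (R[:, jj] + w_ii[jj] * X[:, ii])
    tmp = X[:, ii] @ R + w_ii * norm2_cols_X[ii]
    nn = np.linalg.norm(tmp)  # group (L2) norm across tasks
    shrink = max(1.0 - l1_reg / nn, 0.0) if nn > 0.0 else 0.0
    W[:, ii] = tmp * shrink / (norm2_cols_X[ii] + l2_reg)
    # single rank-1 residual correction by the per-task deltas
    R -= np.outer(X[:, ii], W[:, ii] - w_ii)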
@@ -1031,7 +1006,7 @@ def enet_coordinate_descent_multi_task(
             for ii in range(n_features):
                 for jj in range(n_tasks):
                     XtA[ii, jj] = _dot(
-                        n_samples, X_ptr + ii * n_samples, 1, &R[0, jj], 1
+                        n_samples, &X[0, ii], 1, &R[0, jj], 1
                     ) - l2_reg * W[jj, ii]
 
             # dual_norm_XtA = np.max(np.sqrt(np.sum(XtA ** 2, axis=1)))
@@ -1042,18 +1017,17 @@ def enet_coordinate_descent_multi_task(
                 if XtA_axis1norm > dual_norm_XtA:
                     dual_norm_XtA = XtA_axis1norm
 
-            # TODO: use squared L2 norm directly
-            # R_norm = linalg.norm(R, ord='fro')
-            # w_norm = linalg.norm(W, ord='fro')
-            R_norm = _nrm2(n_samples * n_tasks, &R[0, 0], 1)
-            w_norm = _nrm2(n_features * n_tasks, &W[0, 0], 1)
+            # R_norm2 = linalg.norm(R, ord='fro') ** 2
+            # w_norm2 = linalg.norm(W, ord='fro') ** 2
+            R_norm2 = _dot(n_samples * n_tasks, &R[0, 0], 1, &R[0, 0], 1)
+            w_norm2 = _dot(n_features * n_tasks, &W[0, 0], 1, &W[0, 0], 1)
             if (dual_norm_XtA > l1_reg):
                 const_ = l1_reg / dual_norm_XtA
-                A_norm = R_norm * const_
-                gap = 0.5 * (R_norm ** 2 + A_norm ** 2)
+                A_norm2 = R_norm2 * (const_ ** 2)
+                gap = 0.5 * (R_norm2 + A_norm2)
             else:
                 const_ = 1.0
-                gap = R_norm ** 2
+                gap = R_norm2
 
             # ry_sum = np.sum(R * y)
             ry_sum = _dot(n_samples * n_tasks, &R[0, 0], 1, &Y[0, 0], 1)
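This also resolves the old `# TODO: use squared L2 norm directly`: the duality gap only ever needs the squared quantities R_norm2 and w_norm2, so computing them as dot products skips the square root taken by _nrm2 and the subsequent re-squaring. In NumPy terms (standalone check):

import numpy as np

rng = np.random.default_rng(0)
R = rng.standard_normal((30, 4))

R_norm2_old = np.linalg.norm(R, ord="fro") ** 2  # sqrt, then square again
R_norm2_new = R.ravel() @ R.ravel()              # one dot product, no sqrt
assert np.allclose(R_norm2_old, R_norm2_new)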
@@ -1066,7 +1040,7 @@ def enet_coordinate_descent_multi_task(
             gap += (
                 l1_reg * l21_norm
                 - const_ * ry_sum
-                + 0.5 * l2_reg * (1 + const_ ** 2) * (w_norm ** 2)
+                + 0.5 * l2_reg * (1 + const_ ** 2) * w_norm2
             )
 
             if gap <= tol:

sklearn/linear_model/tests/test_common.py

Lines changed: 1 addition & 1 deletion
@@ -278,7 +278,7 @@ def test_model_pipeline_same_dense_and_sparse(LinearModel, params, csr_container
     model_dense.fit(X, y)
     model_sparse.fit(X_sparse, y)
 
-    assert_allclose(model_sparse[1].coef_, model_dense[1].coef_, atol=1e-16)
+    assert_allclose(model_sparse[1].coef_, model_dense[1].coef_, atol=1e-15)
     y_pred_dense = model_dense.predict(X)
     y_pred_sparse = model_sparse.predict(X_sparse)
     assert_allclose(y_pred_dense, y_pred_sparse)
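The slightly looser atol is plausibly a consequence of the reordered arithmetic: the delta-based updates accumulate the same quantities in a different order, and floating-point addition is not associative, so the dense and sparse code paths can now round differently by a few ulps. A one-line reminder:

# floating-point addition is not associative
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False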
