From d64311eca525a7ab70e08e464df1a4c4b1bc29dd Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sat, 16 Aug 2025 09:14:04 +0200 Subject: [PATCH 01/10] ENH avoid axpy in enet_coordinate_descent --- sklearn/linear_model/_cd_fast.pyx | 50 ++++++++++++++----------------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx index 422da51c21d88..5842a8f1db2ea 100644 --- a/sklearn/linear_model/_cd_fast.pyx +++ b/sklearn/linear_model/_cd_fast.pyx @@ -178,9 +178,9 @@ def enet_coordinate_descent( cdef unsigned int n_samples = X.shape[0] cdef unsigned int n_features = X.shape[1] - # compute norms of the columns of X - # same as norm_cols_X = np.square(X).sum(axis=0) - cdef floating[::1] norm_cols_X = np.einsum( + # compute squared norms of the columns of X + # same as norm2_cols_X = np.square(X).sum(axis=0) + cdef floating[::1] norm2_cols_X = np.einsum( "ij,ij->j", X, X, dtype=dtype, order="C" ) @@ -229,27 +229,23 @@ def enet_coordinate_descent( else: ii = f_iter - if norm_cols_X[ii] == 0.0: + if norm2_cols_X[ii] == 0.0: continue w_ii = w[ii] # Store previous value - if w_ii != 0.0: - # R += w_ii * X[:,ii] - _axpy(n_samples, w_ii, &X[0, ii], 1, &R[0], 1) - - # tmp = (X[:,ii]*R).sum() - tmp = _dot(n_samples, &X[0, ii], 1, &R[0], 1) + # tmp = X[:,ii] @ (R + w_ii * X[:,ii]) + tmp = _dot(n_samples, &X[0, ii], 1, &R[0], 1) + w_ii * norm2_cols_X[ii] if positive and tmp < 0: w[ii] = 0.0 else: w[ii] = (fsign(tmp) * fmax(fabs(tmp) - alpha, 0) - / (norm_cols_X[ii] + beta)) + / (norm2_cols_X[ii] + beta)) - if w[ii] != 0.0: - # R -= w[ii] * X[:,ii] # Update residual - _axpy(n_samples, -w[ii], &X[0, ii], 1, &R[0], 1) + if w[ii] != w_ii: + # R -= (w[ii] - w_ii) * X[:,ii] # Update residual + _axpy(n_samples, w_ii - w[ii], &X[0, ii], 1, &R[0], 1) # update the maximum absolute coefficient update d_w_ii = fabs(w[ii] - w_ii) @@ -365,7 +361,7 @@ def sparse_enet_coordinate_descent( # We work with: # yw = sample_weight * y # R = sample_weight * residual - # norm_cols_X = np.sum(sample_weight * (X - X_mean)**2, axis=0) + # norm2_cols_X = np.sum(sample_weight * (X - X_mean)**2, axis=0) if floating is float: dtype = np.float32 @@ -377,7 +373,7 @@ def sparse_enet_coordinate_descent( cdef unsigned int n_features = w.shape[0] # compute norms of the columns of X - cdef floating[::1] norm_cols_X = np.zeros(n_features, dtype=dtype) + cdef floating[::1] norm2_cols_X = np.zeros(n_features, dtype=dtype) # initial value of the residuals # R = y - Zw, weighted version R = sample_weight * (y - Zw) @@ -438,7 +434,7 @@ def sparse_enet_coordinate_descent( for jj in range(startptr, endptr): normalize_sum += (X_data[jj] - X_mean_ii) ** 2 R[X_indices[jj]] -= X_data[jj] * w_ii - norm_cols_X[ii] = normalize_sum + \ + norm2_cols_X[ii] = normalize_sum + \ (n_samples - endptr + startptr) * X_mean_ii ** 2 if center: for jj in range(n_samples): @@ -457,7 +453,7 @@ def sparse_enet_coordinate_descent( normalize_sum += sample_weight[jj] * X_mean_ii ** 2 R[jj] += sample_weight[jj] * X_mean_ii * w_ii R_sum += R[jj] - norm_cols_X[ii] = normalize_sum + norm2_cols_X[ii] = normalize_sum startptr = endptr # Note: No need to update R_sum from here on because the update terms cancel @@ -479,7 +475,7 @@ def sparse_enet_coordinate_descent( else: ii = f_iter - if norm_cols_X[ii] == 0.0: + if norm2_cols_X[ii] == 0.0: continue startptr = X_indptr[ii] @@ -515,7 +511,7 @@ def sparse_enet_coordinate_descent( w[ii] = 0.0 else: w[ii] = fsign(tmp) * fmax(fabs(tmp) - alpha, 0) \ - / 
(norm_cols_X[ii] + beta) + / (norm2_cols_X[ii] + beta) if w[ii] != 0.0: # R -= w[ii] * X[:,ii] # Update residual @@ -701,7 +697,7 @@ def enet_coordinate_descent_gram( w[ii] = fsign(tmp) * fmax(fabs(tmp) - alpha, 0) \ / (Q[ii, ii] + beta) - if w[ii] != 0.0 or w_ii != 0.0: + if w[ii] != w_ii: # Qw += (w[ii] - w_ii) * Q[ii] # Update Qw = Q @ w _axpy(n_features, w[ii] - w_ii, &Q[ii, 0], 1, &Qw[0], 1) @@ -816,7 +812,7 @@ def enet_coordinate_descent_multi_task( # initial value of the residuals cdef floating[::1, :] R = np.zeros((n_samples, n_tasks), dtype=dtype, order='F') - cdef floating[::1] norm_cols_X = np.zeros(n_features, dtype=dtype) + cdef floating[::1] norm2_cols_X = np.zeros(n_features, dtype=dtype) cdef floating[::1] tmp = np.zeros(n_tasks, dtype=dtype) cdef floating[::1] w_ii = np.zeros(n_tasks, dtype=dtype) cdef floating d_w_max @@ -847,9 +843,9 @@ def enet_coordinate_descent_multi_task( ) with nogil: - # norm_cols_X = (np.asarray(X) ** 2).sum(axis=0) + # norm2_cols_X = (np.asarray(X) ** 2).sum(axis=0) for ii in range(n_features): - norm_cols_X[ii] = _nrm2(n_samples, X_ptr + ii * n_samples, 1) ** 2 + norm2_cols_X[ii] = _nrm2(n_samples, X_ptr + ii * n_samples, 1) ** 2 # R = Y - np.dot(X, W.T) _copy(n_samples * n_tasks, Y_ptr, 1, &R[0, 0], 1) @@ -871,7 +867,7 @@ def enet_coordinate_descent_multi_task( else: ii = f_iter - if norm_cols_X[ii] == 0.0: + if norm2_cols_X[ii] == 0.0: continue # w_ii = W[:, ii] # Store previous value @@ -903,9 +899,9 @@ def enet_coordinate_descent_multi_task( # nn = sqrt(np.sum(tmp ** 2)) nn = _nrm2(n_tasks, &tmp[0], 1) - # W[:, ii] = tmp * fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg) + # W[:, ii] = tmp * fmax(1. - l1_reg / nn, 0) / (norm2_cols_X[ii] + l2_reg) _copy(n_tasks, &tmp[0], 1, &W[0, ii], 1) - _scal(n_tasks, fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg), + _scal(n_tasks, fmax(1. 
- l1_reg / nn, 0) / (norm2_cols_X[ii] + l2_reg), &W[0, ii], 1) # Using numpy: From 2230141ad2174721af66c85e5a1b5f740366d3f0 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sat, 16 Aug 2025 10:44:09 +0200 Subject: [PATCH 02/10] ENH avoid axpy in sparse_enet_coordinate_descent --- sklearn/linear_model/_cd_fast.pyx | 35 +++++++++---------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx index 5842a8f1db2ea..ca3260bfc0a5d 100644 --- a/sklearn/linear_model/_cd_fast.pyx +++ b/sklearn/linear_model/_cd_fast.pyx @@ -372,7 +372,7 @@ def sparse_enet_coordinate_descent( cdef unsigned int n_samples = y.shape[0] cdef unsigned int n_features = w.shape[0] - # compute norms of the columns of X + # compute squared norms of the columns of X cdef floating[::1] norm2_cols_X = np.zeros(n_features, dtype=dtype) # initial value of the residuals @@ -483,26 +483,11 @@ def sparse_enet_coordinate_descent( w_ii = w[ii] # Store previous value X_mean_ii = X_mean[ii] - if w_ii != 0.0: - # R += w_ii * X[:,ii] - if no_sample_weights: - for jj in range(startptr, endptr): - R[X_indices[jj]] += X_data[jj] * w_ii - if center: - for jj in range(n_samples): - R[jj] -= X_mean_ii * w_ii - else: - for jj in range(startptr, endptr): - tmp = sample_weight[X_indices[jj]] - R[X_indices[jj]] += tmp * X_data[jj] * w_ii - if center: - for jj in range(n_samples): - R[jj] -= sample_weight[jj] * X_mean_ii * w_ii - - # tmp = (X[:,ii] * R).sum() + # tmp = X[:,ii] @ (R + w_ii * X[:,ii]) tmp = 0.0 for jj in range(startptr, endptr): tmp += R[X_indices[jj]] * X_data[jj] + tmp += w_ii * norm2_cols_X[ii] if center: tmp -= R_sum * X_mean_ii @@ -513,21 +498,21 @@ def sparse_enet_coordinate_descent( w[ii] = fsign(tmp) * fmax(fabs(tmp) - alpha, 0) \ / (norm2_cols_X[ii] + beta) - if w[ii] != 0.0: - # R -= w[ii] * X[:,ii] # Update residual + if w[ii] != w_ii: + # R -= (w[ii] - w_ii) * X[:,ii] # Update residual if no_sample_weights: for jj in range(startptr, endptr): - R[X_indices[jj]] -= X_data[jj] * w[ii] + R[X_indices[jj]] -= X_data[jj] * (w[ii] - w_ii) if center: for jj in range(n_samples): - R[jj] += X_mean_ii * w[ii] + R[jj] += X_mean_ii * (w[ii] - w_ii) else: for jj in range(startptr, endptr): - tmp = sample_weight[X_indices[jj]] - R[X_indices[jj]] -= tmp * X_data[jj] * w[ii] + kk = X_indices[jj] + R[kk] -= sample_weight[kk] * X_data[jj] * (w[ii] - w_ii) if center: for jj in range(n_samples): - R[jj] += sample_weight[jj] * X_mean_ii * w[ii] + R[jj] += sample_weight[jj] * X_mean_ii * (w[ii] - w_ii) # update the maximum absolute coefficient update d_w_ii = fabs(w[ii] - w_ii) From 3517d30c9996a2856a29dd7baf591395fde1eee1 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sat, 16 Aug 2025 11:00:07 +0200 Subject: [PATCH 03/10] ENH avoid axpy in enet_coordinate_descent_multi_task --- sklearn/linear_model/_cd_fast.pyx | 97 ++++++++++++++----------------- 1 file changed, 45 insertions(+), 52 deletions(-) diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx index ca3260bfc0a5d..ceca8a6ecb1ee 100644 --- a/sklearn/linear_model/_cd_fast.pyx +++ b/sklearn/linear_model/_cd_fast.pyx @@ -640,10 +640,13 @@ def enet_coordinate_descent_gram( cdef floating w_max cdef floating d_w_ii cdef floating q_dot_w - cdef floating w_norm2 cdef floating gap = tol + 1.0 cdef floating d_w_tol = tol cdef floating dual_norm_XtA + cdef floating R_norm2 + cdef floating w_norm2 + cdef floating A_norm2 + cdef floating const_ cdef unsigned int ii 
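Patches 01–03 all hinge on one identity: the residual with feature `ii`'s contribution added back is `R + w_ii * X[:, ii]`, so the dot product needed by the soft-thresholding step satisfies `X[:, ii] @ (R + w_ii * X[:, ii]) == X[:, ii] @ R + w_ii * ||X[:, ii]||**2`. The add-back `_axpy` therefore disappears, and a single `_axpy` with the coefficient delta updates the residual. A quick numpy sanity check of that algebra (illustrative only; the random data and the `w_old`/`w_new` values are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 5))
R = rng.standard_normal(20)      # stands in for the residual y - X @ w
i, w_old, w_new = 2, 0.7, 0.3    # feature index and hypothetical coefficients

# old scheme: add the contribution back, take the dot product, subtract anew
R_old = R + w_old * X[:, i]
tmp_old = X[:, i] @ R_old
R_old -= w_new * X[:, i]

# new scheme: fold the add-back into the dot product via the squared column
# norm, then apply a single residual update with the coefficient delta
tmp_new = X[:, i] @ R + w_old * (X[:, i] @ X[:, i])
R_new = R - (w_new - w_old) * X[:, i]

assert np.isclose(tmp_old, tmp_new)
assert np.allclose(R_old, R_new)
```

The same algebra carries over verbatim to the sparse and multi-task code paths; only the way the dot products are evaluated changes.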
cdef unsigned int n_iter = 0 cdef unsigned int f_iter @@ -789,6 +792,12 @@ def enet_coordinate_descent_multi_task( cdef unsigned int n_features = X.shape[1] cdef unsigned int n_tasks = Y.shape[1] + # compute squared norms of the columns of X + # same as norm2_cols_X = np.square(X).sum(axis=0) + cdef floating[::1] norm2_cols_X = np.einsum( + "ij,ij->j", X, X, dtype=dtype, order="C" + ) + # to store XtA cdef floating[:, ::1] XtA = np.zeros((n_features, n_tasks), dtype=dtype) cdef floating XtA_axis1norm @@ -797,7 +806,6 @@ def enet_coordinate_descent_multi_task( # initial value of the residuals cdef floating[::1, :] R = np.zeros((n_samples, n_tasks), dtype=dtype, order='F') - cdef floating[::1] norm2_cols_X = np.zeros(n_features, dtype=dtype) cdef floating[::1] tmp = np.zeros(n_tasks, dtype=dtype) cdef floating[::1] w_ii = np.zeros(n_tasks, dtype=dtype) cdef floating d_w_max @@ -807,8 +815,8 @@ def enet_coordinate_descent_multi_task( cdef floating W_ii_abs_max cdef floating gap = tol + 1.0 cdef floating d_w_tol = tol - cdef floating R_norm - cdef floating w_norm + cdef floating R_norm2 + cdef floating w_norm2 cdef floating ry_sum cdef floating l21_norm cdef unsigned int ii @@ -818,9 +826,6 @@ def enet_coordinate_descent_multi_task( cdef uint32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX) cdef uint32_t* rand_r_state = &rand_r_state_seed - cdef const floating* X_ptr = &X[0, 0] - cdef const floating* Y_ptr = &Y[0, 0] - if l1_reg == 0: warnings.warn( "Coordinate descent with l1_reg=0 may lead to unexpected" @@ -828,20 +833,16 @@ def enet_coordinate_descent_multi_task( ) with nogil: - # norm2_cols_X = (np.asarray(X) ** 2).sum(axis=0) - for ii in range(n_features): - norm2_cols_X[ii] = _nrm2(n_samples, X_ptr + ii * n_samples, 1) ** 2 - # R = Y - np.dot(X, W.T) - _copy(n_samples * n_tasks, Y_ptr, 1, &R[0, 0], 1) + _copy(n_samples * n_tasks, &Y[0, 0], 1, &R[0, 0], 1) for ii in range(n_features): for jj in range(n_tasks): if W[jj, ii] != 0: - _axpy(n_samples, -W[jj, ii], X_ptr + ii * n_samples, 1, + _axpy(n_samples, -W[jj, ii], &X[0, ii], 1, &R[0, jj], 1) # tol = tol * linalg.norm(Y, ord='fro') ** 2 - tol = tol * _nrm2(n_samples * n_tasks, Y_ptr, 1) ** 2 + tol = tol * _nrm2(n_samples * n_tasks, &Y[0, 0], 1) ** 2 for n_iter in range(max_iter): w_max = 0.0 @@ -858,28 +859,20 @@ def enet_coordinate_descent_multi_task( # w_ii = W[:, ii] # Store previous value _copy(n_tasks, &W[0, ii], 1, &w_ii[0], 1) - # Using Numpy: - # R += np.dot(X[:, ii][:, None], w_ii[None, :]) # rank 1 update - # Using Blas Level2: - # _ger(RowMajor, n_samples, n_tasks, 1.0, - # &X[0, ii], 1, - # &w_ii[0], 1, &R[0, 0], n_tasks) - # Using Blas Level1 and for loop to avoid slower threads - # for such small vectors - for jj in range(n_tasks): - if w_ii[jj] != 0: - _axpy(n_samples, w_ii[jj], X_ptr + ii * n_samples, 1, - &R[0, jj], 1) - - # Using numpy: - # tmp = np.dot(X[:, ii][None, :], R).ravel() - # Using BLAS Level 2: - # _gemv(RowMajor, Trans, n_samples, n_tasks, 1.0, &R[0, 0], - # n_tasks, &X[0, ii], 1, 0.0, &tmp[0], 1) + # tmp = X[:, ii] @ (R + w_ii * X[:,ii][:, None]) + # first part: X[:, ii] @ R + # Using BLAS Level 2: + # _gemv(RowMajor, Trans, n_samples, n_tasks, 1.0, &R[0, 0], + # n_tasks, &X[0, ii], 1, 0.0, &tmp[0], 1) + # second part: (X[:, ii] @ X[:,ii]) * w_ii = norm2_cols * w_ii + # Using BLAS Level 1: + # _axpy(n_tasks, norm2_cols[ii], &w_ii[0], 1, &tmp[0], 1) # Using BLAS Level 1 (faster for small vectors like here): for jj in range(n_tasks): - tmp[jj] = _dot(n_samples, X_ptr + ii * n_samples, 1, - &R[0, jj], 1) + 
tmp[jj] = _dot(n_samples, &X[0, ii], 1, &R[0, jj], 1)
+                    # As we have the loop already, we use it to replace the second BLAS
+                    # Level 1, i.e., _axpy, too.
+                    tmp[jj] += w_ii[jj] * norm2_cols_X[ii]
 
                 # nn = sqrt(np.sum(tmp ** 2))
                 nn = _nrm2(n_tasks, &tmp[0], 1)
@@ -889,17 +882,18 @@ def enet_coordinate_descent_multi_task(
                 _scal(n_tasks, fmax(1. - l1_reg / nn, 0) / (norm2_cols_X[ii] + l2_reg),
                       &W[0, ii], 1)
 
+                # Update residual
                 # Using numpy:
-                # R -= np.dot(X[:, ii][:, None], W[:, ii][None, :])
-                # Using BLAS Level 2:
-                # Update residual : rank 1 update
-                # _ger(RowMajor, n_samples, n_tasks, -1.0,
-                #      &X[0, ii], 1, &W[0, ii], 1,
-                #      &R[0, 0], n_tasks)
+                # R -= (W[:, ii] - w_ii) * X[:, ii][:, None]
+                # Using BLAS Level 1 and 2:
+                # _axpy(n_tasks, -1.0, &W[0, ii], 1, &w_ii[0], 1)
+                # _ger(RowMajor, n_samples, n_tasks, 1.0,
+                #      &X[0, ii], 1, &w_ii, 1,
+                #      &R[0, 0], n_tasks)
                 # Using BLAS Level 1 (faster for small vectors like here):
                 for jj in range(n_tasks):
-                    if W[jj, ii] != 0:
-                        _axpy(n_samples, -W[jj, ii], X_ptr + ii * n_samples, 1,
+                    if W[jj, ii] != 0 or w_ii[jj] != 0.0:
+                        _axpy(n_samples, w_ii[jj] - W[jj, ii], &X[0, ii], 1,
                               &R[0, jj], 1)
 
                 # update the maximum absolute coefficient update
@@ -921,7 +915,7 @@ def enet_coordinate_descent_multi_task(
             for ii in range(n_features):
                 for jj in range(n_tasks):
                     XtA[ii, jj] = _dot(
-                        n_samples, X_ptr + ii * n_samples, 1, &R[0, jj], 1
+                        n_samples, &X[0, ii], 1, &R[0, jj], 1
                     ) - l2_reg * W[jj, ii]
 
             # dual_norm_XtA = np.max(np.sqrt(np.sum(XtA ** 2, axis=1)))
@@ -932,18 +926,17 @@ def enet_coordinate_descent_multi_task(
                 if XtA_axis1norm > dual_norm_XtA:
                     dual_norm_XtA = XtA_axis1norm
 
-            # TODO: use squared L2 norm directly
-            # R_norm = linalg.norm(R, ord='fro')
-            # w_norm = linalg.norm(W, ord='fro')
-            R_norm = _nrm2(n_samples * n_tasks, &R[0, 0], 1)
-            w_norm = _nrm2(n_features * n_tasks, &W[0, 0], 1)
+            # R_norm2 = linalg.norm(R, ord='fro') ** 2
+            # w_norm2 = linalg.norm(W, ord='fro') ** 2
+            R_norm2 = _dot(n_samples * n_tasks, &R[0, 0], 1, &R[0, 0], 1)
+            w_norm2 = _dot(n_features * n_tasks, &W[0, 0], 1, &W[0, 0], 1)
             if (dual_norm_XtA > l1_reg):
                 const_ = l1_reg / dual_norm_XtA
-                A_norm = R_norm * const_
-                gap = 0.5 * (R_norm ** 2 + A_norm ** 2)
+                A_norm2 = R_norm2 * (const_ ** 2)
+                gap = 0.5 * (R_norm2 + A_norm2)
             else:
                 const_ = 1.0
-                gap = R_norm ** 2
+                gap = R_norm2
 
             # ry_sum = np.sum(R * y)
             ry_sum = _dot(n_samples * n_tasks, &R[0, 0], 1, &Y[0, 0], 1)
@@ -956,7 +949,7 @@ def enet_coordinate_descent_multi_task(
             gap += (
                 l1_reg * l21_norm
                 - const_ * ry_sum
-                + 0.5 * l2_reg * (1 + const_ ** 2) * (w_norm ** 2)
+                + 0.5 * l2_reg * (1 + const_ ** 2) * w_norm2
             )
 
             if gap <= tol:

From 6b4b72cb11a4373cef53709b22809e98e90fbe4a Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Sat, 16 Aug 2025 11:23:01 +0200
Subject: [PATCH 04/10] DOC extend whatsnew entry of 31880

---
 .../sklearn.linear_model/31880.efficiency.rst | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/doc/whats_new/upcoming_changes/sklearn.linear_model/31880.efficiency.rst b/doc/whats_new/upcoming_changes/sklearn.linear_model/31880.efficiency.rst
index 9befdee1e144c..195eb42d907eb 100644
--- a/doc/whats_new/upcoming_changes/sklearn.linear_model/31880.efficiency.rst
+++ b/doc/whats_new/upcoming_changes/sklearn.linear_model/31880.efficiency.rst
@@ -1,7 +1,9 @@
 - :class:`linear_model.ElasticNet`, :class:`linear_model.ElasticNetCV`,
-  :class:`linear_model.Lasso` and :class:`linear_model.LassoCV` with `precompute=True`
-  (or `precompute="auto"`` and `n_samples > n_features`) are faster to
fit by - avoiding a BLAS level 1 (axpy) call in the inner most loop. + :class:`linear_model.Lasso`, :class:`linear_model.LassoCV`, + :class:`linear_model.MultiTaskElasticNet`, + :class:`linear_model.MultiTaskElasticNetCV`, + :class:`linear_model.MultiTaskLasso` and :class:`linear_model.MultiTaskLassoCV` + are faster to fit by avoiding a BLAS level 1 (axpy) call in the innermost loop. Same for functions :func:`linear_model.enet_path` and :func:`linear_model.lasso_path`. - By :user:`Christian Lorentzen `. + By :user:`Christian Lorentzen ` :pr:`31956` and From 4e02b2b35ef8db838296fddd228d5b9affb5eb87 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Mon, 18 Aug 2025 11:16:29 +0200 Subject: [PATCH 05/10] TST reduce assertion atol to 1e-15 --- sklearn/linear_model/tests/test_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/linear_model/tests/test_common.py b/sklearn/linear_model/tests/test_common.py index 348710e70af64..12248a0980db4 100644 --- a/sklearn/linear_model/tests/test_common.py +++ b/sklearn/linear_model/tests/test_common.py @@ -278,7 +278,7 @@ def test_model_pipeline_same_dense_and_sparse(LinearModel, params, csr_container model_dense.fit(X, y) model_sparse.fit(X_sparse, y) - assert_allclose(model_sparse[1].coef_, model_dense[1].coef_, atol=1e-16) + assert_allclose(model_sparse[1].coef_, model_dense[1].coef_, atol=1e-15) y_pred_dense = model_dense.predict(X) y_pred_sparse = model_sparse.predict(X_sparse) assert_allclose(y_pred_dense, y_pred_sparse) From bbde57c73baa2f9142789fd669d7db488e9b6fea Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Fri, 22 Aug 2025 18:01:17 +0200 Subject: [PATCH 06/10] CLN address review --- sklearn/linear_model/_cd_fast.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx index ceca8a6ecb1ee..d2552899af66b 100644 --- a/sklearn/linear_model/_cd_fast.pyx +++ b/sklearn/linear_model/_cd_fast.pyx @@ -892,7 +892,7 @@ def enet_coordinate_descent_multi_task( # &R[0, 0], n_tasks) # Using BLAS Level 1 (faster for small vectors like here): for jj in range(n_tasks): - if W[jj, ii] != 0 or w_ii[jj] != 0.0: + if W[jj, ii] != w_ii[jj]: _axpy(n_samples, w_ii[jj] - W[jj, ii], &X[0, ii], 1, &R[0, jj], 1) From 5a8848d193ee9c30387cd7f7371301421e591863 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Mon, 25 Aug 2025 08:55:53 +0200 Subject: [PATCH 07/10] ENH gap_enet_multi_task and early stopping before main loop --- sklearn/linear_model/_cd_fast.pyx | 121 +++++++++++------- .../tests/test_coordinate_descent.py | 4 +- 2 files changed, 76 insertions(+), 49 deletions(-) diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx index ba8ae2e575576..0bc6be38207f5 100644 --- a/sklearn/linear_model/_cd_fast.pyx +++ b/sklearn/linear_model/_cd_fast.pyx @@ -359,7 +359,6 @@ def enet_coordinate_descent( gap, dual_norm_XtA = gap_enet( n_samples, n_features, w, alpha, beta, X, y, R, XtA, positive ) - if gap <= tol: # return if we reached desired tolerance break @@ -837,6 +836,70 @@ def enet_coordinate_descent_gram( return np.asarray(w), gap, tol, n_iter + 1 +cdef (floating, floating) gap_enet_multi_task( + int n_samples, + int n_features, + int n_tasks, + const floating[::1, :] W, # n_tasks, n_features + floating l1_reg, # L1 penalty + floating l2_reg, # L2 penalty + const floating[::1, :] X, # n_samples, n_features + const floating[::1, :] Y, # n_samples, n_tasks + const floating[::1, :] R, # current residuals = y - X @ 
W.T
+    floating[:, ::1] XtA,  # XtA = X.T @ R - l2_reg * W.T is calculated inplace
+) noexcept nogil:
+    """Compute dual gap for use in enet_coordinate_descent."""
+    cdef floating gap = 0.0
+    cdef floating dual_norm_XtA
+    cdef floating XtA_axis1norm
+    cdef floating R_norm2
+    cdef floating w_norm2 = 0.0
+    cdef floating l21_norm
+    cdef floating A_norm2
+    cdef floating const_
+    cdef unsigned int t, j
+
+    # XtA = X.T @ R - l2_reg * W.T
+    for j in range(n_features):
+        for t in range(n_tasks):
+            XtA[j, t] = _dot(n_samples, &X[0, j], 1, &R[0, t], 1) - l2_reg * W[t, j]
+
+    # dual_norm_XtA = np.max(np.sqrt(np.sum(XtA ** 2, axis=1)))
+    dual_norm_XtA = 0.0
+    for j in range(n_features):
+        # np.sqrt(np.sum(XtA ** 2, axis=1))
+        XtA_axis1norm = _nrm2(n_tasks, &XtA[j, 0], 1)
+        if XtA_axis1norm > dual_norm_XtA:
+            dual_norm_XtA = XtA_axis1norm
+
+    # R_norm2 = linalg.norm(R, ord="fro") ** 2
+    R_norm2 = _dot(n_samples * n_tasks, &R[0, 0], 1, &R[0, 0], 1)
+
+    # w_norm2 = linalg.norm(W, ord="fro") ** 2
+    if l2_reg > 0:
+        w_norm2 = _dot(n_features * n_tasks, &W[0, 0], 1, &W[0, 0], 1)
+
+    if (dual_norm_XtA > l1_reg):
+        const_ = l1_reg / dual_norm_XtA
+        A_norm2 = R_norm2 * (const_ ** 2)
+        gap = 0.5 * (R_norm2 + A_norm2)
+    else:
+        const_ = 1.0
+        gap = R_norm2
+
+    # l21_norm = np.sqrt(np.sum(W ** 2, axis=0)).sum()
+    l21_norm = 0.0
+    for j in range(n_features):
+        l21_norm += _nrm2(n_tasks, &W[0, j], 1)
+
+    gap += (
+        l1_reg * l21_norm
+        - const_ * _dot(n_samples * n_tasks, &R[0, 0], 1, &Y[0, 0], 1)  # np.sum(R * Y)
+        + 0.5 * l2_reg * (1 + const_ ** 2) * w_norm2
+    )
+    return gap, dual_norm_XtA
+
+
 def enet_coordinate_descent_multi_task(
     const floating[::1, :] W,
     floating l1_reg,
@@ -891,7 +954,6 @@ def enet_coordinate_descent_multi_task(
 
     # to store XtA
     cdef floating[:, ::1] XtA = np.zeros((n_features, n_tasks), dtype=dtype)
-    cdef floating XtA_axis1norm
    cdef floating dual_norm_XtA
 
     # initial value of the residuals
@@ -906,10 +968,6 @@ def enet_coordinate_descent_multi_task(
     cdef floating W_ii_abs_max
     cdef floating gap = tol + 1.0
     cdef floating d_w_tol = tol
-    cdef floating R_norm2
-    cdef floating w_norm2
-    cdef floating ry_sum
-    cdef floating l21_norm
     cdef unsigned int ii
     cdef unsigned int jj
     cdef unsigned int n_iter = 0
@@ -935,6 +993,14 @@ def enet_coordinate_descent_multi_task(
         # tol = tol * linalg.norm(Y, ord='fro') ** 2
         tol = tol * _nrm2(n_samples * n_tasks, &Y[0, 0], 1) ** 2
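For reference, the helper added above corresponds to this pure-numpy computation of the multi-task elastic-net duality gap (a sketch assuming `W.shape == (n_tasks, n_features)` as in the Cython code; `gap_enet_multi_task_np` is our illustrative name, not a scikit-learn API):

```python
import numpy as np

def gap_enet_multi_task_np(W, X, Y, l1_reg, l2_reg):
    """Duality gap of 0.5 * ||Y - X @ W.T||_F^2 + l1_reg * ||W||_21
    + 0.5 * l2_reg * ||W||_F^2."""
    R = Y - X @ W.T                        # residuals, shape (n_samples, n_tasks)
    XtA = X.T @ R - l2_reg * W.T           # shape (n_features, n_tasks)
    dual_norm_XtA = np.sqrt((XtA ** 2).sum(axis=1)).max()
    R_norm2 = (R ** 2).sum()
    w_norm2 = (W ** 2).sum() if l2_reg > 0 else 0.0
    if dual_norm_XtA > l1_reg:
        # rescale the residuals to obtain a feasible dual point
        const_ = l1_reg / dual_norm_XtA
        gap = 0.5 * (R_norm2 + const_ ** 2 * R_norm2)
    else:
        const_ = 1.0
        gap = R_norm2
    l21_norm = np.sqrt((W ** 2).sum(axis=0)).sum()
    gap += (
        l1_reg * l21_norm
        - const_ * (R * Y).sum()
        + 0.5 * l2_reg * (1 + const_ ** 2) * w_norm2
    )
    return gap, dual_norm_XtA
```

The gap is now evaluated once right below, before entering the main loop, so that a warm-started `W` which already satisfies the tolerance returns immediately:

+        # Check convergence before entering the main loop.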
+ gap, dual_norm_XtA = gap_enet_multi_task( + n_samples, n_features, n_tasks, W, l1_reg, l2_reg, X, Y, R, XtA + ) + if gap <= tol: + with gil: + return np.asarray(W), gap, tol, 0 + for n_iter in range(max_iter): w_max = 0.0 d_w_max = 0.0 @@ -1001,48 +1067,9 @@ def enet_coordinate_descent_multi_task( # the biggest coordinate update of this iteration was smaller than # the tolerance: check the duality gap as ultimate stopping # criterion - - # XtA = np.dot(X.T, R) - l2_reg * W.T - for ii in range(n_features): - for jj in range(n_tasks): - XtA[ii, jj] = _dot( - n_samples, &X[0, ii], 1, &R[0, jj], 1 - ) - l2_reg * W[jj, ii] - - # dual_norm_XtA = np.max(np.sqrt(np.sum(XtA ** 2, axis=1))) - dual_norm_XtA = 0.0 - for ii in range(n_features): - # np.sqrt(np.sum(XtA ** 2, axis=1)) - XtA_axis1norm = _nrm2(n_tasks, &XtA[ii, 0], 1) - if XtA_axis1norm > dual_norm_XtA: - dual_norm_XtA = XtA_axis1norm - - # R_norm2 = linalg.norm(R, ord='fro') ** 2 - # w_norm2 = linalg.norm(W, ord='fro') ** 2 - R_norm2 = _dot(n_samples * n_tasks, &R[0, 0], 1, &R[0, 0], 1) - w_norm2 = _dot(n_features * n_tasks, &W[0, 0], 1, &W[0, 0], 1) - if (dual_norm_XtA > l1_reg): - const_ = l1_reg / dual_norm_XtA - A_norm2 = R_norm2 * (const_ ** 2) - gap = 0.5 * (R_norm2 + A_norm2) - else: - const_ = 1.0 - gap = R_norm2 - - # ry_sum = np.sum(R * y) - ry_sum = _dot(n_samples * n_tasks, &R[0, 0], 1, &Y[0, 0], 1) - - # l21_norm = np.sqrt(np.sum(W ** 2, axis=0)).sum() - l21_norm = 0.0 - for ii in range(n_features): - l21_norm += _nrm2(n_tasks, &W[0, ii], 1) - - gap += ( - l1_reg * l21_norm - - const_ * ry_sum - + 0.5 * l2_reg * (1 + const_ ** 2) * w_norm2 + gap, dual_norm_XtA = gap_enet_multi_task( + n_samples, n_features, n_tasks, W, l1_reg, l2_reg, X, Y, R, XtA ) - if gap <= tol: # return if we reached desired tolerance break diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py index aa073b9a5080b..e098d297db329 100644 --- a/sklearn/linear_model/tests/test_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_coordinate_descent.py @@ -694,7 +694,7 @@ def test_multitask_enet_and_lasso_cv(): X, y, _, _ = build_dataset(n_features=50, n_targets=3) clf = MultiTaskElasticNetCV(cv=3).fit(X, y) assert_almost_equal(clf.alpha_, 0.00556, 3) - clf = MultiTaskLassoCV(cv=3).fit(X, y) + clf = MultiTaskLassoCV(cv=3, tol=1e-6).fit(X, y) assert_almost_equal(clf.alpha_, 0.00278, 3) X, y, _, _ = build_dataset(n_targets=3) @@ -1231,7 +1231,7 @@ def test_multi_task_lasso_cv_dtype(): X = rng.binomial(1, 0.5, size=(n_samples, n_features)) X = X.astype(int) # make it explicit that X is int y = X[:, [0, 0]].copy() - est = MultiTaskLassoCV(alphas=5, fit_intercept=True).fit(X, y) + est = MultiTaskLassoCV(alphas=5, fit_intercept=True, tol=1e-6).fit(X, y) assert_array_almost_equal(est.coef_, [[1, 0, 0]] * 2, decimal=3) From 697c6bd62ea0bc922310c96fa101ddb0bea1e4ab Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Mon, 25 Aug 2025 09:42:25 +0200 Subject: [PATCH 08/10] MNT ii->j (features) and jj->t (tasks) --- sklearn/linear_model/_cd_fast.pyx | 82 +++++++++++++++---------------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx index 0bc6be38207f5..706509fd0e755 100644 --- a/sklearn/linear_model/_cd_fast.pyx +++ b/sklearn/linear_model/_cd_fast.pyx @@ -960,16 +960,16 @@ def enet_coordinate_descent_multi_task( cdef floating[::1, :] R = np.zeros((n_samples, n_tasks), dtype=dtype, order='F') cdef 
floating[::1] tmp = np.zeros(n_tasks, dtype=dtype) - cdef floating[::1] w_ii = np.zeros(n_tasks, dtype=dtype) + cdef floating[::1] w_j = np.zeros(n_tasks, dtype=dtype) cdef floating d_w_max cdef floating w_max - cdef floating d_w_ii + cdef floating d_w_j cdef floating nn - cdef floating W_ii_abs_max + cdef floating W_j_abs_max cdef floating gap = tol + 1.0 cdef floating d_w_tol = tol - cdef unsigned int ii - cdef unsigned int jj + cdef unsigned int j + cdef unsigned int t cdef unsigned int n_iter = 0 cdef unsigned int f_iter cdef uint32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX) @@ -982,13 +982,13 @@ def enet_coordinate_descent_multi_task( ) with nogil: - # R = Y - np.dot(X, W.T) + # R = Y - X @ W.T _copy(n_samples * n_tasks, &Y[0, 0], 1, &R[0, 0], 1) - for ii in range(n_features): - for jj in range(n_tasks): - if W[jj, ii] != 0: - _axpy(n_samples, -W[jj, ii], &X[0, ii], 1, - &R[0, jj], 1) + for j in range(n_features): + for t in range(n_tasks): + if W[t, j] != 0: + _axpy(n_samples, -W[t, j], &X[0, j], 1, + &R[0, t], 1) # tol = tol * linalg.norm(Y, ord='fro') ** 2 tol = tol * _nrm2(n_samples * n_tasks, &Y[0, 0], 1) ** 2 @@ -1006,62 +1006,62 @@ def enet_coordinate_descent_multi_task( d_w_max = 0.0 for f_iter in range(n_features): # Loop over coordinates if random: - ii = rand_int(n_features, rand_r_state) + j = rand_int(n_features, rand_r_state) else: - ii = f_iter + j = f_iter - if norm2_cols_X[ii] == 0.0: + if norm2_cols_X[j] == 0.0: continue - # w_ii = W[:, ii] # Store previous value - _copy(n_tasks, &W[0, ii], 1, &w_ii[0], 1) + # w_j = W[:, j] # Store previous value + _copy(n_tasks, &W[0, j], 1, &w_j[0], 1) - # tmp = X[:, ii] @ (R + w_ii * X[:,ii][:, None]) - # first part: X[:, ii] @ R + # tmp = X[:, j] @ (R + w_j * X[:,j][:, None]) + # first part: X[:, j] @ R # Using BLAS Level 2: # _gemv(RowMajor, Trans, n_samples, n_tasks, 1.0, &R[0, 0], - # n_tasks, &X[0, ii], 1, 0.0, &tmp[0], 1) - # second part: (X[:, ii] @ X[:,ii]) * w_ii = norm2_cols * w_ii + # n_tasks, &X[0, j], 1, 0.0, &tmp[0], 1) + # second part: (X[:, j] @ X[:,j]) * w_j = norm2_cols * w_j # Using BLAS Level 1: - # _axpy(n_tasks, norm2_cols[ii], &w_ii[0], 1, &tmp[0], 1) + # _axpy(n_tasks, norm2_cols[j], &w_j[0], 1, &tmp[0], 1) # Using BLAS Level 1 (faster for small vectors like here): - for jj in range(n_tasks): - tmp[jj] = _dot(n_samples, &X[0, ii], 1, &R[0, jj], 1) + for t in range(n_tasks): + tmp[t] = _dot(n_samples, &X[0, j], 1, &R[0, t], 1) # As we have the loop already, we use it to replace the second BLAS # Level 1, i.e., _axpy, too. - tmp[jj] += w_ii[jj] * norm2_cols_X[ii] + tmp[t] += w_j[t] * norm2_cols_X[j] # nn = sqrt(np.sum(tmp ** 2)) nn = _nrm2(n_tasks, &tmp[0], 1) - # W[:, ii] = tmp * fmax(1. - l1_reg / nn, 0) / (norm2_cols_X[ii] + l2_reg) - _copy(n_tasks, &tmp[0], 1, &W[0, ii], 1) - _scal(n_tasks, fmax(1. - l1_reg / nn, 0) / (norm2_cols_X[ii] + l2_reg), - &W[0, ii], 1) + # W[:, j] = tmp * fmax(1. - l1_reg / nn, 0) / (norm2_cols_X[j] + l2_reg) + _copy(n_tasks, &tmp[0], 1, &W[0, j], 1) + _scal(n_tasks, fmax(1. 
- l1_reg / nn, 0) / (norm2_cols_X[j] + l2_reg), + &W[0, j], 1) # Update residual # Using numpy: - # R -= (W[:, ii] - w_ii) * X[:, ii][:, None] + # R -= (W[:, j] - w_j) * X[:, j][:, None] # Using BLAS Level 1 and 2: - # _axpy(n_tasks, -1.0, &W[0, ii], 1, &w_ii[0], 1) + # _axpy(n_tasks, -1.0, &W[0, j], 1, &w_j[0], 1) # _ger(RowMajor, n_samples, n_tasks, 1.0, - # &X[0, ii], 1, &w_ii, 1, + # &X[0, j], 1, &w_j, 1, # &R[0, 0], n_tasks) # Using BLAS Level 1 (faster for small vectors like here): - for jj in range(n_tasks): - if W[jj, ii] != w_ii[jj]: - _axpy(n_samples, w_ii[jj] - W[jj, ii], &X[0, ii], 1, - &R[0, jj], 1) + for t in range(n_tasks): + if W[t, j] != w_j[t]: + _axpy(n_samples, w_j[t] - W[t, j], &X[0, j], 1, + &R[0, t], 1) # update the maximum absolute coefficient update - d_w_ii = diff_abs_max(n_tasks, &W[0, ii], &w_ii[0]) + d_w_j = diff_abs_max(n_tasks, &W[0, j], &w_j[0]) - if d_w_ii > d_w_max: - d_w_max = d_w_ii + if d_w_j > d_w_max: + d_w_max = d_w_j - W_ii_abs_max = abs_max(n_tasks, &W[0, ii]) - if W_ii_abs_max > w_max: - w_max = W_ii_abs_max + W_j_abs_max = abs_max(n_tasks, &W[0, j]) + if W_j_abs_max > w_max: + w_max = W_j_abs_max if w_max == 0.0 or d_w_max / w_max <= d_w_tol or n_iter == max_iter - 1: # the biggest coordinate update of this iteration was smaller than From 241ae31acc8f7c76c7b5265213a71e597d8ec7ce Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Mon, 25 Aug 2025 20:57:07 +0200 Subject: [PATCH 09/10] ENH add gap safe screening rules to enet_coordinate_descent_multi_task --- sklearn/linear_model/_cd_fast.pyx | 132 +++++++++++++++----- sklearn/linear_model/_coordinate_descent.py | 2 +- 2 files changed, 103 insertions(+), 31 deletions(-) diff --git a/sklearn/linear_model/_cd_fast.pyx b/sklearn/linear_model/_cd_fast.pyx index 706509fd0e755..4705de5341b04 100644 --- a/sklearn/linear_model/_cd_fast.pyx +++ b/sklearn/linear_model/_cd_fast.pyx @@ -258,7 +258,7 @@ def enet_coordinate_descent( cdef floating dual_norm_XtA cdef unsigned int n_active = n_features cdef uint32_t[::1] active_set - # TODO: use binset insteaf of array of bools + # TODO: use binset instead of array of bools cdef uint8_t[::1] excluded_set cdef unsigned int j cdef unsigned int n_iter = 0 @@ -840,18 +840,31 @@ cdef (floating, floating) gap_enet_multi_task( int n_samples, int n_features, int n_tasks, - const floating[::1, :] W, # n_tasks, n_features - floating l1_reg, # L1 penalty - floating l2_reg, # L2 penalty - const floating[::1, :] X, # n_samples, n_features - const floating[::1, :] Y, # n_samples, n_tasks - const floating[::1, :] R, # current residuals = y - X @ W.T - floating[:, ::1] XtA, # XtA = X.T @ R - l2_reg * W.T is calculated inplace + const floating[::1, :] W, # in + floating l1_reg, + floating l2_reg, + const floating[::1, :] X, # in + const floating[::1, :] Y, # in + const floating[::1, :] R, # in + floating[:, ::1] XtA, # out + floating[::1] XtA_row_norms, # out ) noexcept nogil: - """Compute dual gap for use in enet_coordinate_descent.""" + """Compute dual gap for use in enet_coordinate_descent_multi_task. 
+ + Parameters + ---------- + W : memoryview of shape (n_tasks, n_features) + X : memoryview of shape (n_samples, n_features) + Y : memoryview of shape (n_samples, n_tasks) + R : memoryview of shape (n_samples, n_tasks) + Current residuals = Y - X @ W.T + XtA : memoryview of shape (n_features, n_tasks) + Inplace calculated as XtA = X.T @ R - l2_reg * W.T + XtA_row_norms : memoryview of shape n_features + Inplace calculated as np.sqrt(np.sum(XtA ** 2, axis=1)) + """ cdef floating gap = 0.0 cdef floating dual_norm_XtA - cdef floating XtA_axis1norm cdef floating R_norm2 cdef floating w_norm2 = 0.0 cdef floating l21_norm @@ -868,9 +881,9 @@ cdef (floating, floating) gap_enet_multi_task( dual_norm_XtA = 0.0 for j in range(n_features): # np.sqrt(np.sum(XtA ** 2, axis=1)) - XtA_axis1norm = _nrm2(n_tasks, &XtA[j, 0], 1) - if XtA_axis1norm > dual_norm_XtA: - dual_norm_XtA = XtA_axis1norm + XtA_row_norms[j] = _nrm2(n_tasks, &XtA[j, 0], 1) + if XtA_row_norms[j] > dual_norm_XtA: + dual_norm_XtA = XtA_row_norms[j] # R_norm2 = linalg.norm(R, ord="fro") ** 2 R_norm2 = _dot(n_samples * n_tasks, &R[0, 0], 1, &R[0, 0], 1) @@ -901,7 +914,7 @@ cdef (floating, floating) gap_enet_multi_task( def enet_coordinate_descent_multi_task( - const floating[::1, :] W, + floating[::1, :] W, floating l1_reg, floating l2_reg, const floating[::1, :] X, @@ -909,7 +922,8 @@ def enet_coordinate_descent_multi_task( unsigned int max_iter, floating tol, object rng, - bint random=0 + bint random=0, + bint do_screening=1, ): """Cython version of the coordinate descent algorithm for Elastic-Net multi-task regression @@ -952,15 +966,15 @@ def enet_coordinate_descent_multi_task( "ij,ij->j", X, X, dtype=dtype, order="C" ) - # to store XtA - cdef floating[:, ::1] XtA = np.zeros((n_features, n_tasks), dtype=dtype) - cdef floating dual_norm_XtA - # initial value of the residuals - cdef floating[::1, :] R = np.zeros((n_samples, n_tasks), dtype=dtype, order='F') + cdef floating[::1, :] R = np.empty((n_samples, n_tasks), dtype=dtype, order='F') + cdef floating[:, ::1] XtA = np.empty((n_features, n_tasks), dtype=dtype) + cdef floating[::1] XtA_row_norms = np.empty(n_features, dtype=dtype) - cdef floating[::1] tmp = np.zeros(n_tasks, dtype=dtype) - cdef floating[::1] w_j = np.zeros(n_tasks, dtype=dtype) + cdef floating d_j + cdef floating Xj_theta + cdef floating[::1] tmp = np.empty(n_tasks, dtype=dtype) + cdef floating[::1] w_j = np.empty(n_tasks, dtype=dtype) cdef floating d_w_max cdef floating w_max cdef floating d_w_j @@ -968,6 +982,11 @@ def enet_coordinate_descent_multi_task( cdef floating W_j_abs_max cdef floating gap = tol + 1.0 cdef floating d_w_tol = tol + cdef floating dual_norm_XtA + cdef unsigned int n_active = n_features + cdef uint32_t[::1] active_set + # TODO: use binset instead of array of bools + cdef uint8_t[::1] excluded_set cdef unsigned int j cdef unsigned int t cdef unsigned int n_iter = 0 @@ -981,35 +1000,66 @@ def enet_coordinate_descent_multi_task( " results and is discouraged." 
)
 
+    if do_screening:
+        active_set = np.empty(n_features, dtype=np.uint32)  # map [:n_active] -> j
+        excluded_set = np.empty(n_features, dtype=np.uint8)
+
     with nogil:
         # R = Y - X @ W.T
         _copy(n_samples * n_tasks, &Y[0, 0], 1, &R[0, 0], 1)
         for j in range(n_features):
             for t in range(n_tasks):
                 if W[t, j] != 0:
-                    _axpy(n_samples, -W[t, j], &X[0, j], 1,
-                          &R[0, t], 1)
+                    _axpy(n_samples, -W[t, j], &X[0, j], 1, &R[0, t], 1)
 
         # tol = tol * linalg.norm(Y, ord='fro') ** 2
         tol = tol * _nrm2(n_samples * n_tasks, &Y[0, 0], 1) ** 2
 
         # Check convergence before entering the main loop.
         gap, dual_norm_XtA = gap_enet_multi_task(
-            n_samples, n_features, n_tasks, W, l1_reg, l2_reg, X, Y, R, XtA
+            n_samples, n_features, n_tasks, W, l1_reg, l2_reg, X, Y, R, XtA, XtA_row_norms
         )
         if gap <= tol:
             with gil:
                 return np.asarray(W), gap, tol, 0
 
+        # Gap Safe Screening Rules for multi-task Lasso, see
+        # https://arxiv.org/abs/1703.07285 Eq 2.2. (also arxiv:1506.03736)
+        if do_screening:
+            n_active = 0
+            for j in range(n_features):
+                if norm2_cols_X[j] == 0:
+                    for t in range(n_tasks):
+                        W[t, j] = 0
+                    excluded_set[j] = 1
+                    continue
+                # Xj_theta = ||X[:,j] @ dual_theta||_2
+                Xj_theta = XtA_row_norms[j] / fmax(l1_reg, dual_norm_XtA)
+                d_j = (1 - Xj_theta) / sqrt(norm2_cols_X[j] + l2_reg)
+                if d_j <= sqrt(2 * gap) / l1_reg:
+                    # include feature j
+                    active_set[n_active] = j
+                    excluded_set[j] = 0
+                    n_active += 1
+                else:
+                    # R += W[:, j] * X[:, j][:, None]
+                    for t in range(n_tasks):
+                        _axpy(n_samples, W[t, j], &X[0, j], 1, &R[0, t], 1)
+                        W[t, j] = 0
+                    excluded_set[j] = 1
+
         for n_iter in range(max_iter):
             w_max = 0.0
             d_w_max = 0.0
-            for f_iter in range(n_features):  # Loop over coordinates
+            for f_iter in range(n_active):  # Loop over coordinates
                 if random:
-                    j = rand_int(n_features, rand_r_state)
+                    j = rand_int(n_active, rand_r_state)
                 else:
                     j = f_iter
 
+                if do_screening:
+                    j = active_set[j]
+
                 if norm2_cols_X[j] == 0.0:
                     continue
@@ -1050,8 +1100,7 @@ def enet_coordinate_descent_multi_task(
                 # Using BLAS Level 1 (faster for small vectors like here):
                 for t in range(n_tasks):
                     if W[t, j] != w_j[t]:
-                        _axpy(n_samples, w_j[t] - W[t, j], &X[0, j], 1,
-                              &R[0, t], 1)
+                        _axpy(n_samples, w_j[t] - W[t, j], &X[0, j], 1, &R[0, t], 1)
 
                 # update the maximum absolute coefficient update
                 d_w_j = diff_abs_max(n_tasks, &W[0, j], &w_j[0])
@@ -1068,11 +1117,34 @@ def enet_coordinate_descent_multi_task(
                 # the tolerance: check the duality gap as ultimate stopping
                 # criterion
                 gap, dual_norm_XtA = gap_enet_multi_task(
-                    n_samples, n_features, n_tasks, W, l1_reg, l2_reg, X, Y, R, XtA
+                    n_samples, n_features, n_tasks, W, l1_reg, l2_reg, X, Y, R, XtA, XtA_row_norms
                 )
                 if gap <= tol:
                     # return if we reached desired tolerance
                     break
+
+                # Gap Safe Screening Rules for multi-task Lasso, see
+                # https://arxiv.org/abs/1703.07285 Eq 2.2. (also arxiv:1506.03736)
+                if do_screening:
+                    n_active = 0
+                    for j in range(n_features):
+                        if norm2_cols_X[j] == 0:
+                            continue
+                        # Xj_theta = ||X[:,j] @ dual_theta||_2
+                        Xj_theta = XtA_row_norms[j] / fmax(l1_reg, dual_norm_XtA)
+                        d_j = (1 - Xj_theta) / sqrt(norm2_cols_X[j] + l2_reg)
+                        if d_j <= sqrt(2 * gap) / l1_reg:
+                            # include feature j
+                            active_set[n_active] = j
+                            excluded_set[j] = 0
+                            n_active += 1
+                        else:
+                            # R += W[:, j] * X[:, j][:, None]
+                            for t in range(n_tasks):
+                                _axpy(n_samples, W[t, j], &X[0, j], 1, &R[0, t], 1)
+                                W[t, j] = 0
+                            excluded_set[j] = 1
+
         else:  # for/else, runs if for doesn't end with a `break`
             with gil:
diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py
index abf1f13de8c23..925803fedcdf7 100644
--- a/sklearn/linear_model/_coordinate_descent.py
+++ b/sklearn/linear_model/_coordinate_descent.py
@@ -690,7 +690,7 @@ def enet_path(
         )
     elif multi_output:
         model = cd_fast.enet_coordinate_descent_multi_task(
-            coef_, l1_reg, l2_reg, X, y, max_iter, tol, rng, random
+            coef_, l1_reg, l2_reg, X, y, max_iter, tol, rng, random, do_screening
         )
     elif isinstance(precompute, np.ndarray):
         # We expect precompute to be already Fortran ordered when bypassing

From 7d9e23ada7746a373a196b1a2e0b00b2df88fdfa Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Mon, 25 Aug 2025 22:18:11 +0200
Subject: [PATCH 10/10] CLN try to fix doctests

---
 sklearn/linear_model/_coordinate_descent.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py
index 925803fedcdf7..21c1350a387c2 100644
--- a/sklearn/linear_model/_coordinate_descent.py
+++ b/sklearn/linear_model/_coordinate_descent.py
@@ -3095,10 +3095,10 @@ class MultiTaskElasticNetCV(RegressorMixin, LinearModelCV):
     ...     [[0, 0], [1, 1], [2, 2]])
     MultiTaskElasticNetCV(cv=3)
     >>> print(clf.coef_)
-    [[0.52875032 0.46958558]
-     [0.52875032 0.46958558]]
+    [[0.51841231 0.479658]
+     [0.51841231 0.479658]]
     >>> print(clf.intercept_)
-    [0.00166409 0.00166409]
+    [0.001929... 0.001929...]
     """
 
     _parameter_constraints: dict = {
@@ -3349,7 +3349,7 @@ class MultiTaskLassoCV(RegressorMixin, LinearModelCV):
     >>> r2_score(y, reg.predict(X))
     0.9994
     >>> reg.alpha_
-    np.float64(0.5713)
+    np.float64(0.4321...)
     >>> reg.predict(X[:1,])
     array([[153.7971, 94.9015]])
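Taken together, patch 09's two screening blocks implement the gap safe rule: a feature is provably inactive at the optimum whenever its test statistic exceeds the safe-sphere radius derived from the current duality gap. A numpy sketch of the criterion from arXiv:1703.07285, Eq. 2.2, as transcribed above (`screen_features` is an illustrative helper, not a scikit-learn API):

```python
import numpy as np

def screen_features(X, XtA, gap, dual_norm_XtA, l1_reg, l2_reg):
    """Return the indices of features kept active by the gap safe rule."""
    norm2_cols_X = np.einsum("ij,ij->j", X, X)   # squared column norms of X
    nonzero = norm2_cols_X > 0                   # zero columns are always dropped
    # correlation of each feature with the feasible (rescaled) dual point
    Xj_theta = np.sqrt((XtA ** 2).sum(axis=1)) / max(l1_reg, dual_norm_XtA)
    d_j = np.full(X.shape[1], np.inf)
    d_j[nonzero] = (1.0 - Xj_theta[nonzero]) / np.sqrt(norm2_cols_X[nonzero] + l2_reg)
    # keep feature j iff d_j <= sqrt(2 * gap) / l1_reg
    return np.flatnonzero(d_j <= np.sqrt(2.0 * gap) / l1_reg)
```

An excluded feature has its coefficients zeroed and its contribution folded back into the residual, and the rule is re-evaluated after every duality-gap check: as the gap shrinks, the safe sphere shrinks, so the active set only gets cheaper to sweep.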