Skip to content

Commit 3883ba7

Browse files
authored
TST add test_multi_task_lasso_vs_skglm (#31957)
1 parent 2883187 commit 3883ba7

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

sklearn/linear_model/_cd_fast.pyx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,12 @@ def enet_coordinate_descent_multi_task(
786786
787787
0.5 * norm(Y - X W.T, 2)^2 + l1_reg ||W.T||_21 + 0.5 * l2_reg norm(W.T, 2)^2
788788
789+
The algorithm follows
790+
Noah Simon, Jerome Friedman, Trevor Hastie. 2013.
791+
A Blockwise Descent Algorithm for Group-penalized Multiresponse and Multinomial
792+
Regression
793+
https://doi.org/10.48550/arXiv.1311.6529
794+
789795
Returns
790796
-------
791797
W : ndarray of shape (n_tasks, n_features)

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,46 @@ def test_uniform_targets():
510510
assert_array_equal(model.alphas_, [np.finfo(float).resolution] * 3)
511511

512512

513+
@pytest.mark.filterwarnings("error::sklearn.exceptions.ConvergenceWarning")
def test_multi_task_lasso_vs_skglm():
    """Test that MultiTaskLasso gives same results as the one from skglm.

    To reproduce numbers, just use
       from skglm import MultiTaskLasso
    """
    # Numbers are with skglm version 0.5.
    n_samples, n_features, n_tasks = 5, 4, 3
    X = np.vander(np.arange(n_samples), n_features)
    Y = np.arange(n_samples * n_tasks).reshape(n_samples, n_tasks)

    def obj(W, X, y, alpha):
        """Evaluate the penalized least-squares objective used by this test.

        `W` stacks the coefficients and, as its last column, the intercepts:
        one row per task. The value is the squared Frobenius norm of the
        residuals divided by `2 * n_samples`, plus `alpha` times the sum over
        features of the Euclidean norm of that feature's coefficients across
        tasks (the L21 penalty).
        """
        intercept = W[:, -1]
        W = W[:, :-1]
        # Norm across tasks (axis 0) for each feature, then summed over features.
        l21_norm = np.sqrt(np.sum(W**2, axis=0)).sum()
        # Use the arguments rather than the enclosing-scope Y / n_samples so
        # the helper evaluates exactly the data it is given.
        return (
            np.linalg.norm(y - X @ W.T - intercept, ord="fro") ** 2
            / (2 * X.shape[0])
            + alpha * l21_norm
        )

    alpha = 0.1
    # TODO: The high number of iterations is required for convergence and shows
    # room for improvement of the CD algorithm.
    m = MultiTaskLasso(alpha=alpha, tol=1e-10, max_iter=5000).fit(X, Y)
    assert_allclose(
        obj(np.c_[m.coef_, m.intercept_], X, Y, alpha=alpha),
        0.4965993692547902,
        rtol=1e-10,
    )
    assert_allclose(
        m.intercept_, [0.219942959407, 1.219942959407, 2.219942959407], rtol=1e-7
    )
    assert_allclose(
        m.coef_,
        np.tile([-0.032075014794, 0.25430904614, 2.44785152982, 0], (n_tasks, 1)),
        rtol=1e-6,
    )
551+
552+
513553
def test_multi_task_lasso_and_enet():
514554
X, y, X_test, y_test = build_dataset()
515555
Y = np.c_[y, y]

0 commit comments

Comments
 (0)