@@ -510,6 +510,46 @@ def test_uniform_targets():
510
510
assert_array_equal (model .alphas_ , [np .finfo (float ).resolution ] * 3 )
511
511
512
512
513
@pytest.mark.filterwarnings("error::sklearn.exceptions.ConvergenceWarning")
def test_multi_task_lasso_vs_skglm():
    """Test that MultiTaskLasso gives same results as the one from skglm.

    To reproduce numbers, just use
    from skglm import MultiTaskLasso
    """
    # Numbers are with skglm version 0.5.
    n_samples, n_features, n_tasks = 5, 4, 3
    X = np.vander(np.arange(n_samples), n_features)
    Y = np.arange(n_samples * n_tasks).reshape(n_samples, n_tasks)

    def obj(W, X, Y, alpha):
        """Multi-task lasso objective: ||Y - XW' - b||_F^2 / (2n) + alpha * ||W||_21.

        The last column of ``W`` carries the per-task intercepts; the
        remaining columns are the coefficients.
        """
        # BUGFIX: the parameter was named ``y`` but the body read the
        # enclosing ``Y`` via closure, silently ignoring the argument.
        # Renamed so the passed array is actually used (call site passes Y,
        # so the computed value is unchanged).
        intercept = W[:, -1]
        W = W[:, :-1]
        l21_norm = np.sqrt(np.sum(W**2, axis=0)).sum()
        return (
            np.linalg.norm(Y - X @ W.T - intercept, ord="fro") ** 2 / (2 * n_samples)
            + alpha * l21_norm
        )

    alpha = 0.1
    # TODO: The high number of iterations are required for convergence and show room
    # for improvement of the CD algorithm.
    m = MultiTaskLasso(alpha=alpha, tol=1e-10, max_iter=5000).fit(X, Y)
    # The fitted model should reach (up to tight tolerance) the same objective
    # value as skglm's solver.
    assert_allclose(
        obj(np.c_[m.coef_, m.intercept_], X, Y, alpha=alpha),
        0.4965993692547902,
        rtol=1e-10,
    )
    assert_allclose(
        m.intercept_, [0.219942959407, 1.219942959407, 2.219942959407], rtol=1e-7
    )
    # All tasks share the same coefficient vector for this constructed Y,
    # hence the np.tile of a single row across n_tasks rows.
    assert_allclose(
        m.coef_,
        np.tile([-0.032075014794, 0.25430904614, 2.44785152982, 0], (n_tasks, 1)),
        rtol=1e-6,
    )
513
553
def test_multi_task_lasso_and_enet ():
514
554
X , y , X_test , y_test = build_dataset ()
515
555
Y = np .c_ [y , y ]
0 commit comments