 14 |  14 | from sklearn import linear_model, datasets
 15 |  15 | from sklearn.linear_model._least_angle import _lars_path_residues
 16 |  16 | from sklearn.linear_model import LassoLarsIC, lars_path
 17 |     | -from sklearn.linear_model import Lars, LassoLars
    |  17 | +from sklearn.linear_model import Lars, LassoLars, LarsCV, LassoLarsCV
 18 |  18 |
 19 |  19 | # TODO: use another dataset that has multiple drops
 20 |  20 | diabetes = datasets.load_diabetes()

@@ -777,3 +777,54 @@ def test_copy_X_with_auto_gram():

777 | 777 |     linear_model.lars_path(X, y, Gram='auto', copy_X=True, method='lasso')
778 | 778 |     # X did not change
779 | 779 |     assert_allclose(X, X_before)
    | 780 | +
    | 781 | +
    | 782 | +@pytest.mark.parametrize("LARS, has_coef_path, args",
    | 783 | +                         ((Lars, True, {}),
    | 784 | +                          (LassoLars, True, {}),
    | 785 | +                          (LassoLarsIC, False, {}),
    | 786 | +                          (LarsCV, True, {}),
    | 787 | +                          # max_iter=5 is for avoiding ConvergenceWarning
    | 788 | +                          (LassoLarsCV, True, {"max_iter": 5})))
    | 789 | +@pytest.mark.parametrize("dtype", (np.float32, np.float64))
    | 790 | +def test_lars_dtype_match(LARS, has_coef_path, args, dtype):
    | 791 | +    # The test ensures that the fit method preserves input dtype
    | 792 | +    rng = np.random.RandomState(0)
    | 793 | +    X = rng.rand(6, 6).astype(dtype)
    | 794 | +    y = rng.rand(6).astype(dtype)
    | 795 | +
    | 796 | +    model = LARS(**args)
    | 797 | +    model.fit(X, y)
    | 798 | +    assert model.coef_.dtype == dtype
    | 799 | +    if has_coef_path:
    | 800 | +        assert model.coef_path_.dtype == dtype
    | 801 | +    assert model.intercept_.dtype == dtype
    | 802 | +
    | 803 | +
    | 804 | +@pytest.mark.parametrize("LARS, has_coef_path, args",
    | 805 | +                         ((Lars, True, {}),
    | 806 | +                          (LassoLars, True, {}),
    | 807 | +                          (LassoLarsIC, False, {}),
    | 808 | +                          (LarsCV, True, {}),
    | 809 | +                          # max_iter=5 is for avoiding ConvergenceWarning
    | 810 | +                          (LassoLarsCV, True, {"max_iter": 5})))
    | 811 | +def test_lars_numeric_consistency(LARS, has_coef_path, args):
    | 812 | +    # The test ensures numerical consistency between trained coefficients
    | 813 | +    # of float32 and float64.
    | 814 | +    rtol = 1e-5
    | 815 | +    atol = 1e-5
    | 816 | +
    | 817 | +    rng = np.random.RandomState(0)
    | 818 | +    X_64 = rng.rand(6, 6)
    | 819 | +    y_64 = rng.rand(6)
    | 820 | +
    | 821 | +    model_64 = LARS(**args).fit(X_64, y_64)
    | 822 | +    model_32 = LARS(**args).fit(X_64.astype(np.float32),
    | 823 | +                                y_64.astype(np.float32))
    | 824 | +
    | 825 | +    assert_allclose(model_64.coef_, model_32.coef_, rtol=rtol, atol=atol)
    | 826 | +    if has_coef_path:
    | 827 | +        assert_allclose(model_64.coef_path_, model_32.coef_path_,
    | 828 | +                        rtol=rtol, atol=atol)
    | 829 | +    assert_allclose(model_64.intercept_, model_32.intercept_,
    | 830 | +                    rtol=rtol, atol=atol)
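
For context, the behavior these new tests exercise is dtype handling across the LARS family of estimators: fit should keep coef_, coef_path_, and intercept_ in the input dtype, and a float32 fit should stay numerically close to the float64 fit. Below is a minimal sketch of that behavior, assuming a scikit-learn version that includes this change; the 6x6 random data and the choice of LassoLarsCV are illustrative only.

import numpy as np
from sklearn.linear_model import LassoLarsCV

# Illustrative data only: 6 samples, 6 features, as in the tests above.
rng = np.random.RandomState(0)
X64 = rng.rand(6, 6)
y64 = rng.rand(6)

# Fit once in float64 and once in float32 (max_iter=5 mirrors the tests'
# way of avoiding a ConvergenceWarning on such a tiny dataset).
model_64 = LassoLarsCV(max_iter=5).fit(X64, y64)
model_32 = LassoLarsCV(max_iter=5).fit(X64.astype(np.float32),
                                       y64.astype(np.float32))

# With this change, fit preserves the input dtype ...
print(model_64.coef_.dtype)   # float64
print(model_32.coef_.dtype)   # float32

# ... and the float32 solution stays close to the float64 one.
np.testing.assert_allclose(model_64.coef_, model_32.coef_,
                           rtol=1e-5, atol=1e-5)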