Skip to content

Commit c8753d4

Browse files
authored
ENH Preserving dtype for numpy.float32 in Least Angle Regression (#20155)
1 parent 1c36b49 commit c8753d4

File tree

3 files changed

+72
-6
lines changed

3 files changed

+72
-6
lines changed

doc/whats_new/v1.0.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,10 @@ Changelog
344344
is now faster. This is especially noticeable on large sparse input.
345345
:pr:`19734` by :user:`Fred Robinson <frrad>`.
346346

347+
- |Enhancement| `fit` method preserves dtype for numpy.float32 in
348+
:class:`Lars`, :class:`LassoLars`, :class:`LassoLarsIC`, :class:`LarsCV` and
349+
:class:`LassoLarsCV`. :pr:`20155` by :user:`Takeshi Oura <takoika>`.
350+
347351
:mod:`sklearn.manifold`
348352
.......................
349353

sklearn/linear_model/_least_angle.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -476,12 +476,23 @@ def _lars_path_solver(
476476

477477
max_features = min(max_iter, n_features)
478478

479+
dtypes = set(a.dtype for a in (X, y, Xy, Gram) if a is not None)
480+
if len(dtypes) == 1:
481+
# use the precision level of input data if it is consistent
482+
return_dtype = next(iter(dtypes))
483+
else:
484+
# fallback to double precision otherwise
485+
return_dtype = np.float64
486+
479487
if return_path:
480-
coefs = np.zeros((max_features + 1, n_features))
481-
alphas = np.zeros(max_features + 1)
488+
coefs = np.zeros((max_features + 1, n_features), dtype=return_dtype)
489+
alphas = np.zeros(max_features + 1, dtype=return_dtype)
482490
else:
483-
coef, prev_coef = np.zeros(n_features), np.zeros(n_features)
484-
alpha, prev_alpha = np.array([0.]), np.array([0.]) # better ideas?
491+
coef, prev_coef = (np.zeros(n_features, dtype=return_dtype),
492+
np.zeros(n_features, dtype=return_dtype))
493+
alpha, prev_alpha = (np.array([0.], dtype=return_dtype),
494+
np.array([0.], dtype=return_dtype))
495+
# above better ideas?
485496

486497
n_iter, n_active = 0, 0
487498
active, indices = list(), np.arange(n_features)
@@ -948,7 +959,7 @@ def _fit(self, X, y, max_iter, alpha, fit_path, Xy=None):
948959

949960
self.alphas_ = []
950961
self.n_iter_ = []
951-
self.coef_ = np.empty((n_targets, n_features))
962+
self.coef_ = np.empty((n_targets, n_features), dtype=X.dtype)
952963

953964
if fit_path:
954965
self.active_ = []

sklearn/linear_model/tests/test_least_angle.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from sklearn import linear_model, datasets
1515
from sklearn.linear_model._least_angle import _lars_path_residues
1616
from sklearn.linear_model import LassoLarsIC, lars_path
17-
from sklearn.linear_model import Lars, LassoLars
17+
from sklearn.linear_model import Lars, LassoLars, LarsCV, LassoLarsCV
1818

1919
# TODO: use another dataset that has multiple drops
2020
diabetes = datasets.load_diabetes()
@@ -777,3 +777,54 @@ def test_copy_X_with_auto_gram():
777777
linear_model.lars_path(X, y, Gram='auto', copy_X=True, method='lasso')
778778
# X did not change
779779
assert_allclose(X, X_before)
780+
781+
782+
@pytest.mark.parametrize("LARS, has_coef_path, args",
783+
((Lars, True, {}),
784+
(LassoLars, True, {}),
785+
(LassoLarsIC, False, {}),
786+
(LarsCV, True, {}),
787+
# max_iter=5 is for avoiding ConvergenceWarning
788+
(LassoLarsCV, True, {"max_iter": 5})))
789+
@pytest.mark.parametrize("dtype", (np.float32, np.float64))
790+
def test_lars_dtype_match(LARS, has_coef_path, args, dtype):
791+
# The test ensures that the fit method preserves input dtype
792+
rng = np.random.RandomState(0)
793+
X = rng.rand(6, 6).astype(dtype)
794+
y = rng.rand(6).astype(dtype)
795+
796+
model = LARS(**args)
797+
model.fit(X, y)
798+
assert model.coef_.dtype == dtype
799+
if has_coef_path:
800+
assert model.coef_path_.dtype == dtype
801+
assert model.intercept_.dtype == dtype
802+
803+
804+
@pytest.mark.parametrize("LARS, has_coef_path, args",
805+
((Lars, True, {}),
806+
(LassoLars, True, {}),
807+
(LassoLarsIC, False, {}),
808+
(LarsCV, True, {}),
809+
# max_iter=5 is for avoiding ConvergenceWarning
810+
(LassoLarsCV, True, {"max_iter": 5})))
811+
def test_lars_numeric_consistency(LARS, has_coef_path, args):
812+
# The test ensures numerical consistency between trained coefficients
813+
# of float32 and float64.
814+
rtol = 1e-5
815+
atol = 1e-5
816+
817+
rng = np.random.RandomState(0)
818+
X_64 = rng.rand(6, 6)
819+
y_64 = rng.rand(6)
820+
821+
model_64 = LARS(**args).fit(X_64, y_64)
822+
model_32 = LARS(**args).fit(X_64.astype(np.float32),
823+
y_64.astype(np.float32))
824+
825+
assert_allclose(model_64.coef_, model_32.coef_, rtol=rtol, atol=atol)
826+
if has_coef_path:
827+
assert_allclose(model_64.coef_path_, model_32.coef_path_,
828+
rtol=rtol, atol=atol)
829+
assert_allclose(model_64.intercept_, model_32.intercept_,
830+
rtol=rtol, atol=atol)

0 commit comments

Comments
 (0)