Skip to content

Commit 6af1d84

Browse files
committed
FIX improve 'precompute' handling in Lars
1 parent e5f71e7 commit 6af1d84

File tree

4 files changed

+81
-39
lines changed

4 files changed

+81
-39
lines changed

sklearn/linear_model/least_angle.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -174,16 +174,19 @@ def lars_path(X, y, Xy=None, Gram=None, max_iter=500,
174174
swap, nrm2 = linalg.get_blas_funcs(('swap', 'nrm2'), (X,))
175175
solve_cholesky, = get_lapack_funcs(('potrs',), (X,))
176176

177-
if Gram is None:
177+
if Gram is None or Gram is False:
178+
Gram = None
178179
if copy_X:
179180
# force copy. setting the array to be fortran-ordered
180181
# speeds up the calculation of the (partial) Gram matrix
181182
# and allows to easily swap columns
182183
X = X.copy('F')
183-
elif isinstance(Gram, string_types) and Gram == 'auto':
184-
Gram = None
185-
if X.shape[0] > X.shape[1]:
184+
185+
elif isinstance(Gram, string_types) and Gram == 'auto' or Gram is True:
186+
if Gram is True or X.shape[0] > X.shape[1]:
186187
Gram = np.dot(X.T, X)
188+
else:
189+
Gram = None
187190
elif copy_Gram:
188191
Gram = Gram.copy()
189192

@@ -593,16 +596,14 @@ def __init__(self, fit_intercept=True, verbose=False, normalize=True,
593596
self.copy_X = copy_X
594597
self.fit_path = fit_path
595598

596-
def _get_gram(self):
597-
# precompute if n_samples > n_features
598-
precompute = self.precompute
599-
if hasattr(precompute, '__array__'):
600-
Gram = precompute
601-
elif precompute == 'auto':
602-
Gram = 'auto'
603-
else:
604-
Gram = None
605-
return Gram
599+
def _get_gram(self, precompute, X, y):
600+
if (not hasattr(precompute, '__array__')) and (
601+
(precompute is True) or
602+
(precompute == 'auto' and X.shape[0] > X.shape[1]) or
603+
(precompute == 'auto' and y.shape[1] > 1)):
604+
precompute = np.dot(X.T, X)
605+
606+
return precompute
606607

607608
def fit(self, X, y, Xy=None):
608609
"""Fit the model using X, y as training data.
@@ -645,14 +646,7 @@ def fit(self, X, y, Xy=None):
645646
else:
646647
max_iter = self.max_iter
647648

648-
precompute = self.precompute
649-
if not hasattr(precompute, '__array__') and (
650-
precompute is True or
651-
(precompute == 'auto' and X.shape[0] > X.shape[1]) or
652-
(precompute == 'auto' and y.shape[1] > 1)):
653-
Gram = np.dot(X.T, X)
654-
else:
655-
Gram = self._get_gram()
649+
Gram = self._get_gram(self.precompute, X, y)
656650

657651
self.alphas_ = []
658652
self.n_iter_ = []
@@ -972,10 +966,10 @@ class LarsCV(Lars):
972966
copy_X : boolean, optional, default True
973967
If ``True``, X will be copied; else, it may be overwritten.
974968
975-
precompute : True | False | 'auto' | array-like
969+
precompute : True | False | 'auto'
976970
Whether to use a precomputed Gram matrix to speed up
977-
calculations. If set to ``'auto'`` let us decide. The Gram
978-
matrix can also be passed as argument.
971+
calculations. If set to ``'auto'`` let us decide. The Gram matrix
972+
cannot be passed as argument since we will use only subsets of X.
979973
980974
max_iter: integer, optional
981975
Maximum number of iterations to perform.
@@ -1081,7 +1075,13 @@ def fit(self, X, y):
10811075
# init cross-validation generator
10821076
cv = check_cv(self.cv, classifier=False)
10831077

1084-
Gram = 'auto' if self.precompute else None
1078+
# As we use cross-validation, the Gram matrix is not precomputed here
1079+
Gram = self.precompute
1080+
if hasattr(Gram, '__array__'):
1081+
warnings.warn("Parameter 'precompute' cannot be an array in "
1082+
"%s. Automatically switching to 'auto' instead."
1083+
% self.__class__.__name__)
1084+
Gram = 'auto'
10851085

10861086
cv_paths = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
10871087
delayed(_lars_path_residues)(
@@ -1171,10 +1171,10 @@ class LassoLarsCV(LarsCV):
11711171
normalize : boolean, optional, default False
11721172
If True, the regressors X will be normalized before regression.
11731173
1174-
precompute : True | False | 'auto' | array-like
1174+
precompute : True | False | 'auto'
11751175
Whether to use a precomputed Gram matrix to speed up
1176-
calculations. If set to ``'auto'`` let us decide. The Gram
1177-
matrix can also be passed as argument.
1176+
calculations. If set to ``'auto'`` let us decide. The Gram matrix
1177+
cannot be passed as argument since we will use only subsets of X.
11781178
11791179
max_iter : integer, optional
11801180
Maximum number of iterations to perform.
@@ -1404,7 +1404,7 @@ def fit(self, X, y, copy_X=True):
14041404
X, y, self.fit_intercept, self.normalize, self.copy_X)
14051405
max_iter = self.max_iter
14061406

1407-
Gram = self._get_gram()
1407+
Gram = self.precompute
14081408

14091409
alphas_, active_, coef_path_, self.n_iter_ = lars_path(
14101410
X, y, Gram=Gram, copy_X=copy_X, copy_Gram=True, alpha_min=0.0,

sklearn/linear_model/randomized_l1.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def _randomized_lasso(X, y, weights, mask, alpha=1., verbose=False,
165165
alpha = np.atleast_1d(np.asarray(alpha, dtype=np.float64))
166166

167167
X = (1 - weights) * X
168+
168169
with warnings.catch_warnings():
169170
warnings.simplefilter('ignore', ConvergenceWarning)
170171
alphas_, _, coef_ = lars_path(X, y,
@@ -226,10 +227,11 @@ class RandomizedLasso(BaseRandomizedLinearModel):
226227
normalize : boolean, optional, default True
227228
If True, the regressors X will be normalized before regression.
228229
229-
precompute : True | False | 'auto'
230-
Whether to use a precomputed Gram matrix to speed up
231-
calculations. If set to 'auto' let us decide. The Gram
232-
matrix can also be passed as argument.
230+
precompute : True | False | 'auto' | array-like
231+
Whether to use a precomputed Gram matrix to speed up calculations.
232+
If set to 'auto' let us decide.
233+
The Gram matrix can also be passed as argument, but it will be used
234+
only for the selection of parameter alpha, if alpha is 'aic' or 'bic'.
233235
234236
max_iter : integer, optional
235237
Maximum number of iterations to perform in the Lars algorithm.
@@ -328,7 +330,6 @@ def __init__(self, alpha='aic', scaling=.5, sample_fraction=.75,
328330
self.memory = memory
329331

330332
def _make_estimator_and_params(self, X, y):
331-
assert self.precompute in (True, False, None, 'auto')
332333
alpha = self.alpha
333334
if alpha in ('aic', 'bic'):
334335
model = LassoLarsIC(precompute=self.precompute,
@@ -337,9 +338,16 @@ def _make_estimator_and_params(self, X, y):
337338
eps=self.eps)
338339
model.fit(X, y)
339340
self.alpha_ = alpha = model.alpha_
341+
342+
precompute = self.precompute
343+
# A precomputed Gram array is useless, since _randomized_lasso
344+
# changes X at each iteration
345+
if hasattr(precompute, '__array__'):
346+
precompute = 'auto'
347+
assert precompute in (True, False, None, 'auto')
340348
return _randomized_lasso, dict(alpha=alpha, max_iter=self.max_iter,
341349
eps=self.eps,
342-
precompute=self.precompute)
350+
precompute=precompute)
343351

344352

345353
###############################################################################

sklearn/linear_model/tests/test_least_angle.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,20 @@ def test_no_path_all_precomputed():
169169
assert_true(alpha_ == alphas_[-1])
170170

171171

172+
def test_lars_precompute():
173+
# Check for different values of precompute
174+
X, y = diabetes.data, diabetes.target
175+
G = np.dot(X.T, X)
176+
for classifier in [linear_model.Lars, linear_model.LarsCV,
177+
linear_model.LassoLarsIC]:
178+
clf = classifier(precompute=G)
179+
output_1 = ignore_warnings(clf.fit)(X, y).coef_
180+
for precompute in [True, False, 'auto', None]:
181+
clf = classifier(precompute=precompute)
182+
output_2 = clf.fit(X, y).coef_
183+
assert_array_almost_equal(output_1, output_2, decimal=8)
184+
185+
172186
def test_singular_matrix():
173187
# Test when input is a singular matrix
174188
X1 = np.array([[1, 1.], [1., 1.]])

sklearn/linear_model/tests/test_randomized_l1.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,18 @@ def test_randomized_lasso():
4242
# Check randomized lasso
4343
scaling = 0.3
4444
selection_threshold = 0.5
45+
n_resampling = 20
4546

4647
# or with 1 alpha
4748
clf = RandomizedLasso(verbose=False, alpha=1, random_state=42,
48-
scaling=scaling,
49+
scaling=scaling, n_resampling=n_resampling,
4950
selection_threshold=selection_threshold)
5051
feature_scores = clf.fit(X, y).scores_
5152
assert_array_equal(np.argsort(F)[-3:], np.argsort(feature_scores)[-3:])
5253

5354
# or with many alphas
5455
clf = RandomizedLasso(verbose=False, alpha=[1, 0.8], random_state=42,
55-
scaling=scaling,
56+
scaling=scaling, n_resampling=n_resampling,
5657
selection_threshold=selection_threshold)
5758
feature_scores = clf.fit(X, y).scores_
5859
assert_equal(clf.all_scores_.shape, (X.shape[1], 2))
@@ -64,7 +65,7 @@ def test_randomized_lasso():
6465
assert_equal(X_full.shape, X.shape)
6566

6667
clf = RandomizedLasso(verbose=False, alpha='aic', random_state=42,
67-
scaling=scaling)
68+
scaling=scaling, n_resampling=n_resampling)
6869
feature_scores = clf.fit(X, y).scores_
6970
assert_array_equal(feature_scores, X.shape[1] * [1.])
7071

@@ -75,6 +76,25 @@ def test_randomized_lasso():
7576
assert_raises(ValueError, clf.fit, X, y)
7677

7778

79+
def test_randomized_lasso_precompute():
80+
# Check randomized lasso for different values of precompute
81+
n_resampling = 20
82+
alpha = 1
83+
random_state = 42
84+
85+
G = np.dot(X.T, X)
86+
87+
clf = RandomizedLasso(alpha=alpha, random_state=random_state,
88+
precompute=G, n_resampling=n_resampling)
89+
feature_scores_1 = clf.fit(X, y).scores_
90+
91+
for precompute in [True, False, None, 'auto']:
92+
clf = RandomizedLasso(alpha=alpha, random_state=random_state,
93+
precompute=precompute, n_resampling=n_resampling)
94+
feature_scores_2 = clf.fit(X, y).scores_
95+
assert_array_equal(feature_scores_1, feature_scores_2)
96+
97+
7898
def test_randomized_logistic():
7999
# Check randomized sparse logistic regression
80100
iris = load_iris()

0 commit comments

Comments
 (0)