Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/modules/linear_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ for another implementation::
>>> clf = linear_model.Lasso(alpha = 0.1)
>>> clf.fit([[0, 0], [1, 1]], [0, 1])
Lasso(alpha=0.1, copy_X=True, fit_intercept=True, max_iter=1000,
normalize=False, positive=False, precompute='auto', random_state=None,
normalize=False, positive=False, precompute=False, random_state=None,
selection='cyclic', tol=0.0001, warm_start=False)
>>> clf.predict([[1, 1]])
array([ 0.8])
Expand Down
2 changes: 1 addition & 1 deletion doc/tutorial/statistical_inference/supervised_learning.rst
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ application of Occam's razor: *prefer simpler models*.
>>> regr.alpha = best_alpha
>>> regr.fit(diabetes_X_train, diabetes_y_train)
Lasso(alpha=0.025118864315095794, copy_X=True, fit_intercept=True,
max_iter=1000, normalize=False, positive=False, precompute='auto',
max_iter=1000, normalize=False, positive=False, precompute=False,
random_state=None, selection='cyclic', tol=0.0001, warm_start=False)
>>> print(regr.coef_)
[ 0. -212.43764548 517.19478111 313.77959962 -160.8303982 -0.
Expand Down
7 changes: 7 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ API changes summary
but previous versions accidentally returned only the positive
probability. Fixed by Will Lamond and `Lars Buitinck`_.

- Change default value of precompute in :class:`ElasticNet` and :class:`Lasso`
to False. Setting precompute to "auto" was found to be slower when
n_samples > n_features since the computation of the Gram matrix is
computationally expensive and outweighs the benefit of fitting the Gram
for just one alpha.
``precompute="auto"`` is now deprecated and will be removed in 0.18.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why 0.18?

I would raise the warning in 0.16 and remove in 0.17

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok I thought the warnings were for two versions ;)

By `Manoj Kumar`_.

.. _changes_0_15_2:

Expand Down
27 changes: 19 additions & 8 deletions sklearn/linear_model/coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .base import LinearModel, _pre_fit
from ..base import RegressorMixin
from .base import center_data, sparse_center_data
from ..utils import check_array
from ..utils import check_array, check_X_y
from ..utils.validation import check_random_state
from ..cross_validation import _check_cv as check_cv
from ..externals.joblib import Parallel, delayed
Expand Down Expand Up @@ -604,6 +604,8 @@ class ElasticNet(LinearModel, RegressorMixin):
calculations. If set to ``'auto'`` let us decide. The Gram
matrix can also be passed as argument. For sparse input
this option is always ``True`` to preserve sparsity.
WARNING : The ``'auto'`` option is deprecated and will
be removed in 0.18.

max_iter : int, optional
The maximum number of iterations
Expand Down Expand Up @@ -665,7 +667,7 @@ class ElasticNet(LinearModel, RegressorMixin):
path = staticmethod(enet_path)

def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True,
normalize=False, precompute='auto', max_iter=1000,
normalize=False, precompute=False, max_iter=1000,
copy_X=True, tol=1e-4, warm_start=False, positive=False,
random_state=None, selection='cyclic'):
self.alpha = alpha
Expand Down Expand Up @@ -708,10 +710,16 @@ def fit(self, X, y):
warnings.warn("With alpha=0, this algorithm does not converge "
"well. You are advised to use the LinearRegression "
"estimator", stacklevel=2)
X = check_array(X, 'csc', dtype=np.float64, order='F', copy=self.copy_X
and self.fit_intercept)
# From now on X can be touched inplace
y = np.asarray(y, dtype=np.float64)

if self.precompute == 'auto':
warnings.warn("Setting precompute to 'auto', was found to be "
"slower even when n_samples > n_features. Hence "
"it will be removed in 0.18.",
DeprecationWarning, stacklevel=2)

X, y = check_X_y(X, y, accept_sparse='csc', dtype=np.float64,
order='F', copy=self.copy_X and self.fit_intercept,
multi_output=True)

X, y, X_mean, y_mean, X_std, precompute, Xy = \
_pre_fit(X, y, None, self.precompute, self.normalize,
Expand Down Expand Up @@ -830,6 +838,8 @@ class Lasso(ElasticNet):
calculations. If set to ``'auto'`` let us decide. The Gram
matrix can also be passed as argument. For sparse input
this option is always ``True`` to preserve sparsity.
WARNING : The ``'auto'`` option is deprecated and will
be removed in 0.18.

max_iter : int, optional
The maximum number of iterations
Expand Down Expand Up @@ -880,7 +890,7 @@ class Lasso(ElasticNet):
>>> clf = linear_model.Lasso(alpha=0.1)
>>> clf.fit([[0,0], [1, 1], [2, 2]], [0, 1, 2])
Lasso(alpha=0.1, copy_X=True, fit_intercept=True, max_iter=1000,
normalize=False, positive=False, precompute='auto', random_state=None,
normalize=False, positive=False, precompute=False, random_state=None,
selection='cyclic', tol=0.0001, warm_start=False)
>>> print(clf.coef_)
[ 0.85 0. ]
Expand All @@ -906,7 +916,7 @@ class Lasso(ElasticNet):
path = staticmethod(enet_path)

def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
precompute='auto', copy_X=True, max_iter=1000,
precompute=False, copy_X=True, max_iter=1000,
tol=1e-4, warm_start=False, positive=False,
random_state=None, selection='cyclic'):
super(Lasso, self).__init__(
Expand Down Expand Up @@ -1207,6 +1217,7 @@ def fit(self, X, y):
model.alpha = best_alpha
model.l1_ratio = best_l1_ratio
model.copy_X = copy_X
model.precompute = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not model.precompute = self.precompute?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the last fit.

  1. There might be a case where self.precompute="auto", and we have deprecated it.
  2. If self.precompute=True and we set model.precompute=True, I think we might be going against our principle of believing that computing the Gram, is useless for doing a single fit.

This is basically to make the last fit as fast as possible.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok got it.

model.fit(X, y)
if not hasattr(self, 'l1_ratio'):
del self.l1_ratio_
Expand Down
12 changes: 12 additions & 0 deletions sklearn/linear_model/tests/test_coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,18 @@ def test_random_descent():
assert_raises(ValueError, clf_random.fit, X, y)


def test_deprection_precompute_enet():
    # NOTE(review): "deprection" is a typo for "deprecation"; renaming would
    # change the public test identifier, so it is left as-is here.
    """
    Test that setting precompute="auto" gives a Deprecation Warning.
    """

    X, y, _, _ = build_dataset(n_samples=20, n_features=10)
    # Both estimators must warn on fit when the deprecated 'auto' is used.
    for estimator_cls in (ElasticNet, Lasso):
        model = estimator_cls(precompute="auto")
        assert_warns(DeprecationWarning, model.fit, X, y)


# Allow running this test module directly (``python test_coordinate_descent.py``)
# via the nose test runner, in addition to collection by the test suite.
if __name__ == '__main__':
    import nose
    nose.runmodule()