From c6d7bcfa64c01630d45bda8156035cf519a34073 Mon Sep 17 00:00:00 2001
From: Jaques Grobler
Date: Tue, 14 May 2013 01:59:40 +0200
Subject: [PATCH 1/3] make new classes for lasso_path/enet_path and deprecate old

---
 doc/whats_new.rst                           |  6 ++
 .../plot_lasso_coordinate_descent_path.py   | 47 +++++------
 sklearn/linear_model/coordinate_descent.py  | 83 +++++++++++++++----
 .../tests/test_coordinate_descent.py        | 37 ++++++++-
 .../linear_model/tests/test_least_angle.py  | 12 ++-
 5 files changed, 143 insertions(+), 42 deletions(-)

diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 1c7625e2f3686..c2a1fa04c0ca5 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -133,6 +133,12 @@ Changelog
 API changes summary
 -------------------
 
+   - :func:`linear_model.lasso_path` and
+     :func:`linear_model.enet_path` can now return their results in the
+     same format as that of :func:`linear_model.lars_path`. This is done by
+     setting the `return_models` parameter to `False`. By
+     `Jaques Grobler`_
+
    - :class:`grid_search.IterGrid` was renamed to
      :class:`grid_search.ParameterGrid`.

diff --git a/examples/linear_model/plot_lasso_coordinate_descent_path.py b/examples/linear_model/plot_lasso_coordinate_descent_path.py
index 0f66a0417d75a..5f944e0ca7ae5 100644
--- a/examples/linear_model/plot_lasso_coordinate_descent_path.py
+++ b/examples/linear_model/plot_lasso_coordinate_descent_path.py
@@ -23,41 +23,40 @@
 X = diabetes.data
 y = diabetes.target
 
-X /= X.std(0)  # Standardize data (easier to set the l1_ratio parameter)
+X /= X.std(axis=0)  # Standardize data (easier to set the l1_ratio parameter)
 
-###############################################################################
 # Compute paths
 
 eps = 5e-3  # the smaller it is the longer is the path
 
 print("Computing regularization path using the lasso...")
-models = lasso_path(X, y, eps=eps)
-alphas_lasso = np.array([model.alpha for model in models])
-coefs_lasso = np.array([model.coef_ for model in models])
+# The return_models parameter makes lasso_path return
+# the alphas and the coefficients as output, instead of a list
+# of models as it does by default. 
Returning the list of models +# is deprecated and will eventually be removed in 0.15 +alphas_lasso, coefs_lasso = lasso_path(X, y, eps, return_models=False) print("Computing regularization path using the positive lasso...") -models = lasso_path(X, y, eps=eps, positive=True) -alphas_positive_lasso = np.array([model.alpha for model in models]) -coefs_positive_lasso = np.array([model.coef_ for model in models]) - +alphas_positive_lasso, coefs_positive_lasso = lasso_path(X, y, eps, + positive=True, + return_models=False) print("Computing regularization path using the elastic net...") -models = enet_path(X, y, eps=eps, l1_ratio=0.8) -alphas_enet = np.array([model.alpha for model in models]) -coefs_enet = np.array([model.coef_ for model in models]) +alphas_enet, coefs_enet = enet_path(X, y, eps=eps, l1_ratio=0.8, + return_models=False) print("Computing regularization path using the positve elastic net...") -models = enet_path(X, y, eps=eps, l1_ratio=0.8, positive=True) -alphas_positive_enet = np.array([model.alpha for model in models]) -coefs_positive_enet = np.array([model.coef_ for model in models]) +alphas_positive_enet, coefs_positive_enet = enet_path(X, y, eps=eps, + l1_ratio=0.8, + positive=True, + return_models=False) -############################################################################### # Display results pl.figure(1) ax = pl.gca() ax.set_color_cycle(2 * ['b', 'r', 'g', 'c', 'k']) -l1 = pl.plot(-np.log10(alphas_lasso), coefs_lasso) -l2 = pl.plot(-np.log10(alphas_enet), coefs_enet, linestyle='--') +l1 = pl.plot(-np.log10(alphas_lasso), coefs_lasso.T) +l2 = pl.plot(-np.log10(alphas_enet), coefs_enet.T, linestyle='--') pl.xlabel('-Log(alpha)') pl.ylabel('coefficients') @@ -69,9 +68,9 @@ pl.figure(2) ax = pl.gca() ax.set_color_cycle(2 * ['b', 'r', 'g', 'c', 'k']) -l1 = pl.plot(-np.log10(alphas_lasso), coefs_lasso) -l2 = pl.plot(-np.log10(alphas_positive_lasso), coefs_positive_lasso, - linestyle='--') +l1 = pl.plot(-np.log10(alphas_lasso), coefs_lasso.T) +l2 = pl.plot(-np.log10(alphas_positive_lasso), coefs_positive_lasso.T, + linestyle='--') pl.xlabel('-Log(alpha)') pl.ylabel('coefficients') @@ -83,9 +82,9 @@ pl.figure(3) ax = pl.gca() ax.set_color_cycle(2 * ['b', 'r', 'g', 'c', 'k']) -l1 = pl.plot(-np.log10(alphas_enet), coefs_enet) -l2 = pl.plot(-np.log10(alphas_positive_enet), coefs_positive_enet, - linestyle='--') +l1 = pl.plot(-np.log10(alphas_enet), coefs_enet.T) +l2 = pl.plot(-np.log10(alphas_positive_enet), coefs_positive_enet.T, + linestyle='--') pl.xlabel('-Log(alpha)') pl.ylabel('coefficients') diff --git a/sklearn/linear_model/coordinate_descent.py b/sklearn/linear_model/coordinate_descent.py index 16cdaff02fbee..21b0aed736aa6 100644 --- a/sklearn/linear_model/coordinate_descent.py +++ b/sklearn/linear_model/coordinate_descent.py @@ -520,7 +520,7 @@ def _alpha_grid(X, y, Xy=None, l1_ratio=1.0, fit_intercept=True, def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None, precompute='auto', Xy=None, fit_intercept=True, - normalize=False, copy_X=True, verbose=False, + normalize=False, copy_X=True, verbose=False, return_models=True, **params): """Compute Lasso path with coordinate descent @@ -569,12 +569,29 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None, verbose : bool or integer Amount of verbosity + return_models : boolean, optional, default True + If ``True``, the function will return list of models. Setting it + to ``False`` will change the function output returning the values + of the alphas and the coefficients along the path. 
Returning the
+        model list will be removed in version 0.15.
+
     params : kwargs
         keyword arguments passed to the Lasso objects
 
     Returns
     -------
     models : a list of models along the regularization path
+        (Is returned if ``return_models`` is set to ``True``, the default.)
+
+    alphas : array, shape: [n_alphas + 1]
+        The alphas along the path where models are computed.
+        (Is returned, along with ``coefs``, when ``return_models`` is set
+        to ``False``)
+
+    coefs : shape (n_features, n_alphas + 1)
+        Coefficients along the path.
+        (Is returned, along with ``alphas``, when ``return_models`` is set
+        to ``False``).
 
     Notes
     -----
@@ -589,6 +606,11 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
     interpolation can be used to retrieve model coefficients between the
     values output by lars_path
 
+    Deprecation Notice: Setting ``return_models`` to ``False`` will make
+    the Lasso Path return an output in the style used by :func:`lars_path`.
+    This will become the norm as of version 0.15. Leaving ``return_models``
+    set to `True` will let the function return a list of models as before.
+
     Examples
     ---------
 
@@ -597,9 +619,9 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
     >>> X = np.array([[1, 2, 3.1], [2.3, 5.4, 4.3]]).T
     >>> y = np.array([1, 2, 3.1])
     >>> # Use lasso_path to compute a coefficient path
-    >>> coef_path = [e.coef_ for e in lasso_path(X, y, alphas=[5., 1., .5], \
-fit_intercept=False)]
-    >>> print(np.array(coef_path).T)
+    >>> coef_path = [e.coef_ for e in lasso_path(X, y, alphas=[5., 1., .5],
+    ...                                          fit_intercept=False)]
+    >>> print np.array(coef_path).T
     [[ 0.          0.          0.46874778]
     [ 0.2159048   0.4425765   0.23689075]]
 
@@ -609,8 +631,8 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
     >>> alphas, active, coef_path_lars = lars_path(X, y, method='lasso')
     >>> from scipy import interpolate
     >>> coef_path_continuous = interpolate.interp1d(alphas[::-1],
-    ...                                             coef_path_lars[:, ::-1])
-    >>> print(coef_path_continuous([5., 1., .5]))
+    ...                                         coef_path_lars[:, ::-1])
+    >>> print coef_path_continuous([5., 1., .5])
     [[ 0.          0.          0.46915237]
     [ 0.2159048   0.4425765   0.23668876]]
 
@@ -627,13 +649,14 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
     return enet_path(X, y, l1_ratio=1., eps=eps, n_alphas=n_alphas,
                      alphas=alphas, precompute=precompute, Xy=Xy,
                      fit_intercept=fit_intercept, normalize=normalize,
-                     copy_X=copy_X, verbose=verbose, **params)
+                     copy_X=copy_X, verbose=verbose,
+                     return_models=return_models, **params)
 
 
 def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
               precompute='auto', Xy=None, fit_intercept=True,
               normalize=False, copy_X=True, verbose=False, rho=None,
-              **params):
+              return_models=True, **params):
     """Compute Elastic-Net path with coordinate descent
 
     The Elastic Net optimization function is::
 
@@ -687,24 +710,53 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
     verbose : bool or integer
         Amount of verbosity
 
+    return_models : boolean, optional, default True
+        If ``True``, the function will return a list of models. Setting it
+        to ``False`` will change the function output, returning the values
+        of the alphas and the coefficients along the path. Returning the
+        model list will be removed in version 0.15.
+
     params : kwargs
         keyword arguments passed to the Lasso objects
 
     Returns
     -------
     models : a list of models along the regularization path
+        (Is returned if ``return_models`` is set to ``True``, the default.)
+
+    alphas : array, shape: [n_alphas + 1]
+        The alphas along the path where models are computed. 
+ (Is returned, along with ``coefs``, when ``return_models`` is set + to ``False``) + + coefs : shape (n_features, n_alphas + 1) + Coefficients along the path. + (Is returned, along with ``alphas``, when ``return_models`` is set + to ``False``). Notes ----- See examples/linear_model/plot_lasso_coordinate_descent_path.py for an example. + Deprecation Notice: Setting ``return_models`` to ``False`` will make + the Lasso Path return an output in the style used by :func:`lars_path`. + This will be become the norm as of version 0.15. Leaving ``return_models`` + set to `True` will let the function return a list of models as before. + See also -------- ElasticNet ElasticNetCV """ - + if return_models: + warnings.warn("Use enet_path(return_models=False), as it returns the" + " coefficients and alphas instead of just a list of" + " models as previously `lasso_path`/`enet_path` did." + " `return_models` will eventually be removed in 0.15," + " after which, returning alphas and coefs" + " will become the norm.", + DeprecationWarning, stacklevel=2) if rho is not None: l1_ratio = rho warnings.warn("rho was renamed to l1_ratio and will be removed " @@ -720,7 +772,6 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, # at each fit... n_samples, n_features = X.shape - if Xy is None: Xy = safe_sparse_dot(X.T, y, dense_output=True) @@ -745,7 +796,7 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, if precompute == 'auto': precompute = (n_samples > n_features) - if precompute: + if precompute or (precompute == 'auto'): if sparse.isspmatrix(X): warnings.warn("precompute is ignored for sparse data") precompute = False @@ -754,6 +805,7 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, coef_ = None # init coef_ models = [] + coefs = [] n_alphas = len(alphas) for i, alpha in enumerate(alphas): @@ -762,7 +814,6 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, fit_intercept=fit_intercept if sparse.isspmatrix(X) else False, precompute=precompute) model.set_params(**params) - model.copy_X = False model.fit(X, y, coef_init=coef_, Xy=Xy) if fit_intercept and not sparse.isspmatrix(X): model.fit_intercept = True @@ -774,10 +825,14 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, print('Path: %03i out of %03i' % (i, n_alphas)) else: sys.stderr.write('.') - coef_ = model.coef_.copy() + coefs.append(model.coef_) + coef_ = coefs[-1].copy() models.append(model) - return models + if return_models: + return models + else: + return alphas, np.asarray(coefs).T def _path_residuals(X, y, train, test, path, path_params, l1_ratio=1, X_order=None, dtype=None): diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py index 5232f343493c2..da911d29b1424 100644 --- a/sklearn/linear_model/tests/test_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_coordinate_descent.py @@ -16,8 +16,9 @@ from sklearn.utils.testing import assert_greater from sklearn.linear_model.coordinate_descent import Lasso, \ - LassoCV, ElasticNet, ElasticNetCV, MultiTaskLasso, MultiTaskElasticNet -from sklearn.linear_model import LassoLarsCV + LassoCV, ElasticNet, ElasticNetCV, MultiTaskLasso, MultiTaskElasticNet, \ + lasso_path +from sklearn.linear_model import LassoLarsCV, lars_path def check_warnings(): @@ -149,7 +150,7 @@ def build_dataset(n_samples=50, n_features=200, n_informative_features=10, return X, y, X_test, y_test -def test_lasso_path(): +def test_lasso_cv(): X, y, 
X_test, y_test = build_dataset() max_iter = 150 clf = LassoCV(n_alphas=10, eps=1e-3, max_iter=max_iter).fit(X, y) @@ -176,6 +177,36 @@ def test_lasso_path(): assert_greater(clf.score(X_test, y_test), 0.99) +def test_lasso_path_return_models_vs_new_return_gives_same_coefficients(): + # Test that lasso_path with lars_path style output gives the + # same result + + # Some toy data + X = np.array([[1, 2, 3.1], [2.3, 5.4, 4.3]]).T + y = np.array([1, 2, 3.1]) + alphas = [5., 1., .5] + # Compute the lasso_path + coef_path = [e.coef_ for e in lasso_path(X, y, alphas=alphas, + fit_intercept=False)] + + # Use lars_path and lasso_path(new output) with 1D linear interpolation + # to compute the the same path + alphas_lars, _, coef_path_lars = lars_path(X, y, method='lasso') + coef_path_cont_lars = interpolate.interp1d(alphas_lars[::-1], + coef_path_lars[:, ::-1]) + alphas_lasso2, coef_path_lasso2 = lasso_path(X, y, alphas=alphas, + fit_intercept=False, + return_models=False) + coef_path_cont_lasso = interpolate.interp1d(alphas_lasso2[::-1], + coef_path_lasso2[:, ::-1]) + + np.testing.assert_array_almost_equal(coef_path_cont_lasso(alphas), + np.asarray(coef_path).T, decimal=1) + np.testing.assert_array_almost_equal(coef_path_cont_lasso(alphas), + coef_path_cont_lars(alphas), + decimal=1) + + def test_enet_path(): # We use a large number of samples and of informative features so that # the l1_ratio selected is more toward ridge than lasso diff --git a/sklearn/linear_model/tests/test_least_angle.py b/sklearn/linear_model/tests/test_least_angle.py index 8ec60760dfd0a..12c48c2bbc1f4 100644 --- a/sklearn/linear_model/tests/test_least_angle.py +++ b/sklearn/linear_model/tests/test_least_angle.py @@ -221,7 +221,7 @@ def test_rank_deficient_design(): def test_lasso_lars_vs_lasso_cd(verbose=False): """ Test that LassoLars and Lasso using coordinate descent give the - same results + same results. """ X = 3 * diabetes.data @@ -301,6 +301,9 @@ def test_lasso_lars_vs_lasso_cd_ill_conditioned(): # Test lasso lars on a very ill-conditioned design, and check that # it does not blow up, and stays somewhat close to a solution given # by the coordinate descent solver + # Also test that lasso_path (using lars_path output style) gives + # the same result as lars_path and previous lasso output style + # under these conditions. rng = np.random.RandomState(42) # Generate data @@ -321,15 +324,22 @@ def test_lasso_lars_vs_lasso_cd_ill_conditioned(): warnings.simplefilter("always", UserWarning) lars_alphas, _, lars_coef = linear_model.lars_path(X, y, method='lasso') + assert_true(len(warning_list) > 0) assert_true(('Dropping a regressor' in warning_list[0].message.args[0]) or ('Early stopping' in warning_list[0].message.args[0])) + _, lasso_coef2 = linear_model.lasso_path(X, y, + alphas=lars_alphas, tol=1e-6, + return_models=False) + lasso_coef = np.zeros((w.shape[0], len(lars_alphas))) for i, model in enumerate(linear_model.lasso_path(X, y, alphas=lars_alphas, tol=1e-6)): lasso_coef[:, i] = model.coef_ np.testing.assert_array_almost_equal(lars_coef, lasso_coef, decimal=1) + np.testing.assert_array_almost_equal(lars_coef, lasso_coef2, decimal=1) + np.testing.assert_array_almost_equal(lasso_coef, lasso_coef2, decimal=1) def test_lars_drop_for_good(): From 029c61525bf84ff31b5d6f0eb101d58de6ab7cd3 Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Mon, 22 Jul 2013 10:35:35 +0200 Subject: [PATCH 2/3] ENH : massive refactoring of CV models in coordinate descent. 
Now the algo core is in path functions --- .../plot_lasso_coordinate_descent_path.py | 12 +- sklearn/linear_model/base.py | 7 +- sklearn/linear_model/coordinate_descent.py | 1279 +++++++++-------- .../tests/test_coordinate_descent.py | 7 +- .../linear_model/tests/test_least_angle.py | 10 +- .../tests/test_sparse_coordinate_descent.py | 8 +- 6 files changed, 665 insertions(+), 658 deletions(-) diff --git a/examples/linear_model/plot_lasso_coordinate_descent_path.py b/examples/linear_model/plot_lasso_coordinate_descent_path.py index 5f944e0ca7ae5..48eee240fc5ee 100644 --- a/examples/linear_model/plot_lasso_coordinate_descent_path.py +++ b/examples/linear_model/plot_lasso_coordinate_descent_path.py @@ -34,18 +34,18 @@ # the alphas and the coefficients as output, instead of a list # of models as it does by default. Returning the list of models # is deprecated and will eventually be removed in 0.15 -alphas_lasso, coefs_lasso = lasso_path(X, y, eps, return_models=False) +alphas_lasso, coefs_lasso, _ = lasso_path(X, y, eps, return_models=False) print("Computing regularization path using the positive lasso...") -alphas_positive_lasso, coefs_positive_lasso = lasso_path(X, y, eps, - positive=True, - return_models=False) +alphas_positive_lasso, coefs_positive_lasso, _ = lasso_path(X, y, eps, + positive=True, + return_models=False) print("Computing regularization path using the elastic net...") -alphas_enet, coefs_enet = enet_path(X, y, eps=eps, l1_ratio=0.8, +alphas_enet, coefs_enet, _ = enet_path(X, y, eps=eps, l1_ratio=0.8, return_models=False) print("Computing regularization path using the positve elastic net...") -alphas_positive_enet, coefs_positive_enet = enet_path(X, y, eps=eps, +alphas_positive_enet, coefs_positive_enet, _ = enet_path(X, y, eps=eps, l1_ratio=0.8, positive=True, return_models=False) diff --git a/sklearn/linear_model/base.py b/sklearn/linear_model/base.py index 1f1c69f0a4ce8..d0cce4365fbff 100644 --- a/sklearn/linear_model/base.py +++ b/sklearn/linear_model/base.py @@ -45,8 +45,10 @@ def sparse_center_data(X, y, fit_intercept, normalize=False): axis 0. Be aware that X will not be centered since it would break the sparsity, but will be normalized if asked so. """ - X_data = np.array(X.data, np.float64) + X = X.astype(np.float64) + if fit_intercept: + X_data = X.data # copy if 'normalize' is True or X is not a csc matrix X = sp.csc_matrix(X, copy=normalize) X_mean, X_std = csc_mean_variance_axis0(X) @@ -65,8 +67,7 @@ def sparse_center_data(X, y, fit_intercept, normalize=False): X_std = np.ones(X.shape[1]) y_mean = 0. 
if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype) - X_data = np.array(X.data, np.float64) - return X_data, y, X_mean, y_mean, X_std + return X, y, X_mean, y_mean, X_std def center_data(X, y, fit_intercept, normalize=False, copy=True, diff --git a/sklearn/linear_model/coordinate_descent.py b/sklearn/linear_model/coordinate_descent.py index 21b0aed736aa6..7914dac66ca6d 100644 --- a/sklearn/linear_model/coordinate_descent.py +++ b/sklearn/linear_model/coordinate_descent.py @@ -28,438 +28,38 @@ ############################################################################### -# ElasticNet model - - -class ElasticNet(LinearModel, RegressorMixin): - """Linear Model trained with L1 and L2 prior as regularizer - - Minimizes the objective function:: - - 1 / (2 * n_samples) * ||y - Xw||^2_2 + - + alpha * l1_ratio * ||w||_1 - + 0.5 * alpha * (1 - l1_ratio) * ||w||^2_2 - - If you are interested in controlling the L1 and L2 penalty - separately, keep in mind that this is equivalent to:: - - a * L1 + b * L2 - - where:: - - alpha = a + b and l1_ratio = a / (a + b) - - The parameter l1_ratio corresponds to alpha in the glmnet R package while - alpha corresponds to the lambda parameter in glmnet. Specifically, l1_ratio - = 1 is the lasso penalty. Currently, l1_ratio <= 0.01 is not reliable, - unless you supply your own sequence of alpha. - - Parameters - ---------- - alpha : float - Constant that multiplies the penalty terms. Defaults to 1.0 - See the notes for the exact mathematical meaning of this - parameter. - ``alpha = 0`` is equivalent to an ordinary least square, solved - by the :class:`LinearRegression` object. For numerical - reasons, using ``alpha = 0`` with the Lasso object is not advised - and you should prefer the LinearRegression object. - - l1_ratio : float - The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For - ``l1_ratio = 0`` the penalty is an L2 penalty. ``For l1_ratio = 1`` it - is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a - combination of L1 and L2. - - fit_intercept: bool - Whether the intercept should be estimated or not. If ``False``, the - data is assumed to be already centered. - - normalize : boolean, optional, default False - If ``True``, the regressors X will be normalized before regression. - - precompute : True | False | 'auto' | array-like - Whether to use a precomputed Gram matrix to speed up - calculations. If set to ``'auto'`` let us decide. The Gram - matrix can also be passed as argument. For sparse input - this option is always ``True`` to preserve sparsity. - - max_iter: int, optional - The maximum number of iterations - - copy_X : boolean, optional, default False - If ``True``, X will be copied; else, it may be overwritten. - - tol: float, optional - The tolerance for the optimization: if the updates are - smaller than ``tol``, the optimization code checks the - dual gap for optimality and continues until it is smaller - than ``tol``. - - warm_start : bool, optional - When set to ``True``, reuse the solution of the previous call to fit as - initialization, otherwise, just erase the previous solution. - - positive: bool, optional - When set to ``True``, forces the coefficients to be positive. 
- - Attributes - ---------- - ``coef_`` : array, shape = (n_features,) | (n_targets, n_features) - parameter vector (w in the cost function formula) - - ``sparse_coef_`` : scipy.sparse matrix, shape = (n_features, 1) | \ - (n_targets, n_features) - ``sparse_coef_`` is a readonly property derived from ``coef_`` - - ``intercept_`` : float | array, shape = (n_targets,) - independent term in decision function. - - ``dual_gap_`` : float | array, shape = (n_targets,) - the current fit is guaranteed to be epsilon-suboptimal with - epsilon := ``dual_gap_`` - - ``eps_`` : float | array, shape = (n_targets,) - ``eps_`` is used to check if the fit converged to the requested - ``tol`` - - Notes - ----- - To avoid unnecessary memory duplication the X argument of the fit method - should be directly passed as a Fortran-contiguous numpy array. - """ - def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True, - normalize=False, precompute='auto', max_iter=1000, - copy_X=True, tol=1e-4, warm_start=False, positive=False, - rho=None): - self.alpha = alpha - self.l1_ratio = l1_ratio - if rho is not None: - self.l1_ratio = rho - warnings.warn("rho was renamed to l1_ratio and will be removed " - "in 0.15", DeprecationWarning) - self.coef_ = None - self.fit_intercept = fit_intercept - self.normalize = normalize - self.precompute = precompute - self.max_iter = max_iter - self.copy_X = copy_X - self.tol = tol - self.warm_start = warm_start - self.positive = positive - self.intercept_ = 0.0 - - def fit(self, X, y, Xy=None, coef_init=None): - """Fit model with coordinate descent - - Parameters - ----------- - X: ndarray or scipy.sparse matrix, (n_samples, n_features) - Data - y: ndarray, shape = (n_samples,) or (n_samples, n_targets) - Target - Xy : array-like, optional - Xy = np.dot(X.T, y) that can be precomputed. It is useful - only when the Gram matrix is precomputed. - coef_init: ndarray of shape n_features or (n_targets, n_features) - The initial coefficients to warm-start the optimization - - Notes - ----- - - Coordinate descent is an algorithm that considers each column of - data at a time hence it will automatically convert the X input - as a Fortran-contiguous numpy array if necessary. - - To avoid memory re-allocation it is advised to allocate the - initial data in memory directly using that format. - """ - if self.alpha == 0: - warnings.warn("With alpha=0, this algorithm does not converge " - "well. 
You are advised to use the LinearRegression " - "estimator", stacklevel=2) - X = atleast2d_or_csc(X, dtype=np.float64, order='F', - copy=self.copy_X and self.fit_intercept) - # From now on X can be touched inplace - y = np.asarray(y, dtype=np.float64) - # now all computation with X can be done inplace - fit = self._sparse_fit if sparse.isspmatrix(X) else self._dense_fit - fit(X, y, Xy, coef_init) - return self - - def _dense_fit(self, X, y, Xy=None, coef_init=None): +# Paths functions +def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy): + n_samples, n_features = X.shape + if sparse.isspmatrix(X): + precompute = False + X, y, X_mean, y_mean, X_std = sparse_center_data( + X, y, fit_intercept, normalize) + else: # copy was done in fit if necessary X, y, X_mean, y_mean, X_std = center_data( - X, y, self.fit_intercept, self.normalize, copy=False) - - if y.ndim == 1: - y = y[:, np.newaxis] - if Xy is not None and Xy.ndim == 1: - Xy = Xy[:, np.newaxis] - - n_samples, n_features = X.shape - n_targets = y.shape[1] - - precompute = self.precompute - if hasattr(precompute, '__array__') \ - and not np.allclose(X_mean, np.zeros(n_features)) \ - and not np.allclose(X_std, np.ones(n_features)): - # recompute Gram - precompute = 'auto' - Xy = None - - coef_ = self._init_coef(coef_init, n_features, n_targets) - dual_gap_ = np.empty(n_targets) - eps_ = np.empty(n_targets) - - l1_reg = self.alpha * self.l1_ratio * n_samples - l2_reg = self.alpha * (1.0 - self.l1_ratio) * n_samples - - # precompute if n_samples > n_features - if precompute == 'auto': - precompute = (n_samples > n_features) - - if hasattr(precompute, '__array__'): - Gram = precompute - elif precompute: - Gram = np.dot(X.T, X) - else: - Gram = None - - for k in xrange(n_targets): - if Gram is None: - coef_[k, :], dual_gap_[k], eps_[k] = \ - cd_fast.enet_coordinate_descent( - coef_[k, :], l1_reg, l2_reg, X, y[:, k], self.max_iter, - self.tol, self.positive) - else: - Gram = Gram.copy() - if Xy is None: - this_Xy = np.dot(X.T, y[:, k]) - else: - this_Xy = Xy[:, k] - coef_[k, :], dual_gap_[k], eps_[k] = \ - cd_fast.enet_coordinate_descent_gram( - coef_[k, :], l1_reg, l2_reg, Gram, this_Xy, y[:, k], - self.max_iter, self.tol, self.positive) - - if dual_gap_[k] > eps_[k]: - warnings.warn('Objective did not converge for ' + - 'target %d, you might want' % k + - ' to increase the number of iterations') - - self.coef_, self.dual_gap_, self.eps_ = (np.squeeze(a) for a in - (coef_, dual_gap_, eps_)) - self._set_intercept(X_mean, y_mean, X_std) - - # return self for chaining fit and predict calls - return self - - def _sparse_fit(self, X, y, Xy=None, coef_init=None): - - if X.shape[0] != y.shape[0]: - raise ValueError("X and y have incompatible shapes.\n" + - "Note: Sparse matrices cannot be indexed w/" + - "boolean masks (use `indices=True` in CV).") - - # NOTE: we are explicitly not centering the data the naive way to - # avoid breaking the sparsity of X - X_data, y, X_mean, y_mean, X_std = sparse_center_data( - X, y, self.fit_intercept, self.normalize) - - if y.ndim == 1: - y = y[:, np.newaxis] - - n_samples, n_features = X.shape[0], X.shape[1] - n_targets = y.shape[1] - - coef_ = self._init_coef(coef_init, n_features, n_targets) - dual_gap_ = np.empty(n_targets) - eps_ = np.empty(n_targets) - - l1_reg = self.alpha * self.l1_ratio * n_samples - l2_reg = self.alpha * (1.0 - self.l1_ratio) * n_samples - - for k in xrange(n_targets): - coef_[k, :], dual_gap_[k], eps_[k] = \ - cd_fast.sparse_enet_coordinate_descent( - coef_[k, :], l1_reg, 
l2_reg, X_data, X.indices, - X.indptr, y[:, k], X_mean / X_std, - self.max_iter, self.tol, self.positive) - - if dual_gap_[k] > eps_[k]: - warnings.warn('Objective did not converge for ' + - 'target %d, you might want' % k + - ' to increase the number of iterations') - - self.coef_, self.dual_gap_, self.eps_ = (np.squeeze(a) for a in - (coef_, dual_gap_, eps_)) - self._set_intercept(X_mean, y_mean, X_std) - - # return self for chaining fit and predict calls - return self - - def _init_coef(self, coef_init, n_features, n_targets): - if coef_init is None: - if not self.warm_start or self.coef_ is None: - coef_ = np.zeros((n_targets, n_features), dtype=np.float64) - else: - coef_ = self.coef_ - else: - coef_ = coef_init - - if coef_.ndim == 1: - coef_ = coef_[np.newaxis, :] - if coef_.shape != (n_targets, n_features): - raise ValueError("X and coef_init have incompatible " - "shapes (%s != %s)." - % (coef_.shape, (n_targets, n_features))) - - return coef_ - - @property - def sparse_coef_(self): - """ sparse representation of the fitted coef """ - return sparse.csr_matrix(self.coef_) - - def decision_function(self, X): - """Decision function of the linear model - - Parameters - ---------- - X : numpy array or scipy.sparse matrix of shape (n_samples, n_features) - - Returns - ------- - T : array, shape = (n_samples,) - The predicted decision function - """ - if sparse.isspmatrix(X): - return np.ravel(safe_sparse_dot(self.coef_, X.T, dense_output=True) - + self.intercept_) - else: - return super(ElasticNet, self).decision_function(X) - - -############################################################################### -# Lasso model - -class Lasso(ElasticNet): - """Linear Model trained with L1 prior as regularizer (aka the Lasso) - - The optimization objective for Lasso is:: - - (1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1 - - Technically the Lasso model is optimizing the same objective function as - the Elastic Net with ``l1_ratio=1.0`` (no L2 penalty). - - Parameters - ---------- - alpha : float, optional - Constant that multiplies the L1 term. Defaults to 1.0. - ``alpha = 0`` is equivalent to an ordinary least square, solved - by the :class:`LinearRegression` object. For numerical - reasons, using ``alpha = 0`` is with the Lasso object is not advised - and you should prefer the LinearRegression object. - - fit_intercept : boolean - whether to calculate the intercept for this model. If set - to false, no intercept will be used in calculations - (e.g. data is expected to be already centered). - - normalize : boolean, optional, default False - If ``True``, the regressors X will be normalized before regression. - - copy_X : boolean, optional, default True - If ``True``, X will be copied; else, it may be overwritten. - - precompute : True | False | 'auto' | array-like - Whether to use a precomputed Gram matrix to speed up - calculations. If set to ``'auto'`` let us decide. The Gram - matrix can also be passed as argument. For sparse input - this option is always ``True`` to preserve sparsity. - - max_iter: int, optional - The maximum number of iterations - - tol : float, optional - The tolerance for the optimization: if the updates are - smaller than ``tol``, the optimization code checks the - dual gap for optimality and continues until it is smaller - than ``tol``. - - warm_start : bool, optional - When set to True, reuse the solution of the previous call to fit as - initialization, otherwise, just erase the previous solution. 
- - positive : bool, optional - When set to ``True``, forces the coefficients to be positive. - - - Attributes - ---------- - ``coef_`` : array, shape = (n_features,) | (n_targets, n_features) - parameter vector (w in the cost function formula) - - ``sparse_coef_`` : scipy.sparse matrix, shape = (n_features, 1) | \ - (n_targets, n_features) - ``sparse_coef_`` is a readonly property derived from ``coef_`` - - ``intercept_`` : float | array, shape = (n_targets,) - independent term in decision function. - - ``dual_gap_`` : float | array, shape = (n_targets,) - the current fit is guaranteed to be epsilon-suboptimal with - epsilon := ``dual_gap_`` - - ``eps_`` : float | array, shape = (n_targets,) - ``eps_`` is used to check if the fit converged to the requested - ``tol`` + X, y, fit_intercept, normalize, copy=copy) - Examples - -------- - >>> from sklearn import linear_model - >>> clf = linear_model.Lasso(alpha=0.1) - >>> clf.fit([[0,0], [1, 1], [2, 2]], [0, 1, 2]) - Lasso(alpha=0.1, copy_X=True, fit_intercept=True, max_iter=1000, - normalize=False, positive=False, precompute='auto', tol=0.0001, - warm_start=False) - >>> print(clf.coef_) - [ 0.85 0. ] - >>> print(clf.intercept_) - 0.15 - - See also - -------- - lars_path - lasso_path - LassoLars - LassoCV - LassoLarsCV - sklearn.decomposition.sparse_encode - - Notes - ----- - The algorithm used to fit the model is coordinate descent. + if hasattr(precompute, '__array__') \ + and not np.allclose(X_mean, np.zeros(n_features)) \ + and not np.allclose(X_std, np.ones(n_features)): + # recompute Gram + precompute = 'auto' + Xy = None - To avoid unnecessary memory duplication the X argument of the fit method - should be directly passed as a Fortran-contiguous numpy array. - """ + # precompute if n_samples > n_features + if precompute == 'auto': + precompute = (n_samples > n_features) - def __init__(self, alpha=1.0, fit_intercept=True, normalize=False, - precompute='auto', copy_X=True, max_iter=1000, - tol=1e-4, warm_start=False, positive=False): - super(Lasso, self).__init__( - alpha=alpha, l1_ratio=1.0, fit_intercept=fit_intercept, - normalize=normalize, precompute=precompute, copy_X=copy_X, - max_iter=max_iter, tol=tol, warm_start=warm_start, - positive=positive) + if precompute is True: + precompute = np.dot(X.T, X) + Xy = np.dot(X.T, y) + else: + Xy = None + return X, y, X_mean, y_mean, X_std, precompute, Xy -############################################################################### -# Classes to store linear models along a regularization path def _alpha_grid(X, y, Xy=None, l1_ratio=1.0, fit_intercept=True, eps=1e-3, n_alphas=100, normalize=False, copy_X=True): @@ -519,14 +119,158 @@ def _alpha_grid(X, y, Xy=None, l1_ratio=1.0, fit_intercept=True, def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None, - precompute='auto', Xy=None, fit_intercept=True, - normalize=False, copy_X=True, verbose=False, return_models=True, + precompute='auto', Xy=None, fit_intercept=None, + normalize=None, copy_X=True, verbose=False, return_models=True, **params): """Compute Lasso path with coordinate descent The optimization objective for Lasso is:: - (1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1 + (1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1 + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples, n_features) + Training data. Pass directly as Fortran-contiguous data to avoid + unnecessary memory duplication + + y : ndarray, shape = (n_samples,) + Target values + + eps : float, optional + Length of the path. 
``eps=1e-3`` means that
+        ``alpha_min / alpha_max = 1e-3``
+
+    n_alphas : int, optional
+        Number of alphas along the regularization path
+
+    alphas : ndarray, optional
+        List of alphas where to compute the models.
+        If ``None`` alphas are set automatically
+
+    precompute : True | False | 'auto' | array-like
+        Whether to use a precomputed Gram matrix to speed up
+        calculations. If set to ``'auto'`` let us decide. The Gram
+        matrix can also be passed as argument.
+
+    Xy : array-like, optional
+        Xy = np.dot(X.T, y) that can be precomputed. It is useful
+        only when the Gram matrix is precomputed.
+
+    fit_intercept : bool
+        Fit or not an intercept.
+        WARNING : will be deprecated in 0.15
+
+    normalize : boolean, optional, default False
+        If ``True``, the regressors X will be normalized before regression.
+        WARNING : will be deprecated in 0.15
+
+    copy_X : boolean, optional, default True
+        If ``True``, X will be copied; else, it may be overwritten.
+
+    verbose : bool or integer
+        Amount of verbosity
+
+    return_models : boolean, optional, default True
+        If ``True``, the function will return a list of models. Setting it
+        to ``False`` will change the function output, returning the values
+        of the alphas and the coefficients along the path. Returning the
+        model list will be removed in version 0.15.
+
+    params : kwargs
+        keyword arguments passed to the Lasso objects
+
+    Returns
+    -------
+    models : a list of models along the regularization path
+        (Is returned if ``return_models`` is set to ``True``, the default.)
+
+    alphas : array, shape: [n_alphas + 1]
+        The alphas along the path where models are computed.
+        (Is returned, along with ``coefs``, when ``return_models`` is set
+        to ``False``)
+
+    coefs : shape (n_features, n_alphas + 1)
+        Coefficients along the path.
+        (Is returned, along with ``alphas``, when ``return_models`` is set
+        to ``False``).
+
+    dual_gaps : shape (n_alphas + 1)
+        The dual gaps at the end of the optimization for each alpha.
+        (Is returned, along with ``alphas``, when ``return_models`` is set
+        to ``False``).
+
+    Notes
+    -----
+    See examples/linear_model/plot_lasso_coordinate_descent_path.py
+    for an example.
+
+    To avoid unnecessary memory duplication the X argument of the fit method
+    should be directly passed as a Fortran-contiguous numpy array.
+
+    Note that in certain cases, the Lars solver may be significantly
+    faster to implement this functionality. In particular, linear
+    interpolation can be used to retrieve model coefficients between the
+    values output by lars_path
+
+    Deprecation Notice: Setting ``return_models`` to ``False`` will make
+    the Lasso Path return an output in the style used by :func:`lars_path`.
+    This will become the norm as of version 0.15. Leaving ``return_models``
+    set to `True` will let the function return a list of models as before.
+
+    Examples
+    ---------
+
+    Comparing lasso_path and lars_path with interpolation:
+
+    >>> X = np.array([[1, 2, 3.1], [2.3, 5.4, 4.3]]).T
+    >>> y = np.array([1, 2, 3.1])
+    >>> # Use lasso_path to compute a coefficient path
+    >>> _, coef_path, _ = lasso_path(X, y, alphas=[5., 1., .5],
+    ...                              return_models=False, fit_intercept=False)
+    >>> print coef_path
+    [[ 0.          0.          0.46874778]
+     [ 0.2159048   0.4425765   0.23689075]]
+
+    >>> # Now use lars_path and 1D linear interpolation to compute the
+    >>> # same path
+    >>> from sklearn.linear_model import lars_path
+    >>> alphas, active, coef_path_lars = lars_path(X, y, method='lasso')
+    >>> from scipy import interpolate
+    >>> coef_path_continuous = interpolate.interp1d(alphas[::-1],
+    ...                                         coef_path_lars[:, ::-1])
+    >>> print coef_path_continuous([5., 1., .5])
+    [[ 0.          0.          0.46915237]
+     [ 0.2159048   0.4425765   0.23668876]]
+
+
+    See also
+    --------
+    lars_path
+    Lasso
+    LassoLars
+    LassoCV
+    LassoLarsCV
+    sklearn.decomposition.sparse_encode
+    """
+    return enet_path(X, y, l1_ratio=1., eps=eps, n_alphas=n_alphas,
+                     alphas=alphas, precompute=precompute, Xy=Xy,
+                     fit_intercept=fit_intercept, normalize=normalize,
+                     copy_X=copy_X, verbose=verbose,
+                     return_models=return_models, **params)
+
+
+def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
+              precompute='auto', Xy=None, fit_intercept=True,
+              normalize=False, copy_X=True, verbose=False, rho=None,
+              return_models=True, **params):
+    """Compute Elastic-Net path with coordinate descent
+
+    The Elastic Net optimization function is::
+
+        1 / (2 * n_samples) * ||y - Xw||^2_2 +
+        + alpha * l1_ratio * ||w||_1
+        + 0.5 * alpha * (1 - l1_ratio) * ||w||^2_2
 
     Parameters
     ----------
@@ -537,7 +281,11 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
     y : ndarray, shape = (n_samples,)
         Target values
 
-    eps : float, optional
+    l1_ratio : float, optional
+        float between 0 and 1 passed to ElasticNet (scaling between
+        l1 and l2 penalties). ``l1_ratio=1`` corresponds to the Lasso
+
+    eps : float
         Length of the path. ``eps=1e-3`` means that
         ``alpha_min / alpha_max = 1e-3``
 
@@ -546,7 +294,7 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
     alphas : ndarray, optional
         List of alphas where to compute the models.
-        If ``None`` alphas are set automatically
+        If None alphas are set automatically
 
     precompute : True | False | 'auto' | array-like
         Whether to use a precomputed Gram matrix to speed up
@@ -558,10 +306,12 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
         only when the Gram matrix is precomputed.
 
     fit_intercept : bool
-        Fit or not an intercept
+        Fit or not an intercept.
+        WARNING : will be deprecated in 0.15
 
     normalize : boolean, optional, default False
         If ``True``, the regressors X will be normalized before regression.
+        WARNING : will be deprecated in 0.15
 
     copy_X : boolean, optional, default True
         If ``True``, X will be copied; else, it may be overwritten.
 
@@ -593,113 +343,387 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
     (Is returned, along with ``alphas``, when ``return_models`` is set
     to ``False``).
 
+    dual_gaps : shape (n_alphas + 1)
+        The dual gaps at the end of the optimization for each alpha.
+        (Is returned, along with ``alphas``, when ``return_models`` is set
+        to ``False``).
+
     Notes
     -----
-    See examples/linear_model/plot_lasso_coordinate_descent_path.py
-    for an example.
-
-    To avoid unnecessary memory duplication the X argument of the fit method
-    should be directly passed as a Fortran-contiguous numpy array.
-
-    Note that in certain cases, the Lars solver may be significantly
-    faster to implement this functionality. In particular, linear
-    interpolation can be used to retrieve model coefficients between the
-    values output by lars_path
+    See examples/plot_lasso_coordinate_descent_path.py for an example.
 
     Deprecation Notice: Setting ``return_models`` to ``False`` will make
     the Lasso Path return an output in the style used by :func:`lars_path`.
     This will become the norm as of version 0.15. Leaving ``return_models``
     set to `True` will let the function return a list of models as before.
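A minimal usage sketch of the return convention described above, assuming only
the signature introduced by this patch (with ``return_models=False`` the path
functions return the plain arrays ``(alphas, coefs, dual_gaps)`` rather than a
list of fitted models):

    >>> import numpy as np
    >>> from sklearn.linear_model import enet_path
    >>> X = np.array([[1, 2, 3.1], [2.3, 5.4, 4.3]]).T
    >>> y = np.array([1, 2, 3.1])
    >>> alphas, coefs, dual_gaps = enet_path(X, y, l1_ratio=0.8,
    ...                                      return_models=False)
    >>> coefs.shape == (X.shape[1], len(alphas))
    True

The coefficient matrix is laid out as (n_features, n_alphas), which is why the
updated plotting example transposes it (``coefs_enet.T``) before plotting.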
- Examples - --------- + See also + -------- + ElasticNet + ElasticNetCV + """ + if return_models: + warnings.warn("Use enet_path(return_models=False), as it returns the" + " coefficients and alphas instead of just a list of" + " models as previously `lasso_path`/`enet_path` did." + " `return_models` will eventually be removed in 0.15," + " after which, returning alphas and coefs" + " will become the norm.", + DeprecationWarning, stacklevel=2) - Comparing lasso_path and lars_path with interpolation: + if normalize is not None: + warnings.warn("normalize param will be removed in 0.15." + " Intercept fitting and feature normalization will be" + " done in estimators.", + DeprecationWarning, stacklevel=2) + else: + normalize = False - >>> X = np.array([[1, 2, 3.1], [2.3, 5.4, 4.3]]).T - >>> y = np.array([1, 2, 3.1]) - >>> # Use lasso_path to compute a coefficient path - >>> coef_path = [e.coef_ for e in lasso_path(X, y, alphas=[5., 1., .5], - ... fit_intercept=False)] - >>> print np.array(coef_path).T - [[ 0. 0. 0.46874778] - [ 0.2159048 0.4425765 0.23689075]] + if fit_intercept is not None: + warnings.warn("fit_intercept param will be removed in 0.15." + " Intercept fitting and feature normalization will be" + " done in estimators.", + DeprecationWarning, stacklevel=2) + else: + fit_intercept = True + + if rho is not None: + l1_ratio = rho + warnings.warn("rho was renamed to l1_ratio and will be removed " + "in 0.15", DeprecationWarning) + + X = atleast2d_or_csc(X, dtype=np.float64, order='F', + copy=copy_X and fit_intercept) + + n_samples, n_features = X.shape + + if sparse.isspmatrix(X): + if 'X_mean' in params: + # As sparse matrices are not actually centered we need this + # to be passed to the CD solver. + X_sparse_scaling = params['X_mean'] / params['X_std'] + else: + X_sparse_scaling = np.ones(n_features) + + X, y, X_mean, y_mean, X_std, precompute, Xy = \ + _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy=False) + + n_samples = X.shape[0] + if alphas is None: + # No need to normalize of fit_intercept: it has been done + # above + alphas = _alpha_grid(X, y, Xy=Xy, l1_ratio=l1_ratio, + fit_intercept=False, eps=eps, n_alphas=n_alphas, + normalize=False, copy_X=False) + else: + alphas = np.sort(alphas)[::-1] # make sure alphas are properly ordered + + n_alphas = len(alphas) + + coef_ = np.zeros(n_features, dtype=np.float64) + models = [] + coefs = np.empty((n_features, n_alphas), dtype=np.float64) + dual_gaps = np.empty(n_alphas) + + tol = params.get('tol', 1e-4) + positive = params.get('positive', False) + max_iter = params.get('max_iter', 1000) + + for i, alpha in enumerate(alphas): + l1_reg = alpha * l1_ratio * n_samples + l2_reg = alpha * (1.0 - l1_ratio) * n_samples + + if sparse.isspmatrix(X): + coef_, dual_gap_, eps_ = cd_fast.sparse_enet_coordinate_descent( + coef_, l1_reg, l2_reg, X.data, X.indices, + X.indptr, y, X_sparse_scaling, + max_iter, tol, positive) + else: + coef_, dual_gap_, eps_ = cd_fast.enet_coordinate_descent( + coef_, l1_reg, l2_reg, X, y, max_iter, tol, positive) + + if dual_gap_ > eps_: + warnings.warn('Objective did not converge.' 
+ + ' You might want' + + ' to increase the number of iterations') + + coefs[:, i] = coef_ + dual_gaps[i] = dual_gap_ + + if return_models: + model = ElasticNet( + alpha=alpha, l1_ratio=l1_ratio, + fit_intercept=fit_intercept if sparse.isspmatrix(X) else False, + precompute=precompute) + model.coef_ = coefs[:, i] + model.dual_gap_ = dual_gaps[-1] + if fit_intercept and not sparse.isspmatrix(X): + model.fit_intercept = True + model._set_intercept(X_mean, y_mean, X_std) + models.append(model) + + if verbose: + if verbose > 2: + print(model) + elif verbose > 1: + print('Path: %03i out of %03i' % (i, n_alphas)) + else: + sys.stderr.write('.') + + if return_models: + return models + else: + return alphas, coefs, dual_gaps + + +############################################################################### +# ElasticNet model + + +class ElasticNet(LinearModel, RegressorMixin): + """Linear Model trained with L1 and L2 prior as regularizer + + Minimizes the objective function:: + + 1 / (2 * n_samples) * ||y - Xw||^2_2 + + + alpha * l1_ratio * ||w||_1 + + 0.5 * alpha * (1 - l1_ratio) * ||w||^2_2 + + If you are interested in controlling the L1 and L2 penalty + separately, keep in mind that this is equivalent to:: + + a * L1 + b * L2 + + where:: + + alpha = a + b and l1_ratio = a / (a + b) + + The parameter l1_ratio corresponds to alpha in the glmnet R package while + alpha corresponds to the lambda parameter in glmnet. Specifically, l1_ratio + = 1 is the lasso penalty. Currently, l1_ratio <= 0.01 is not reliable, + unless you supply your own sequence of alpha. + + Parameters + ---------- + alpha : float + Constant that multiplies the penalty terms. Defaults to 1.0 + See the notes for the exact mathematical meaning of this + parameter. + ``alpha = 0`` is equivalent to an ordinary least square, solved + by the :class:`LinearRegression` object. For numerical + reasons, using ``alpha = 0`` with the Lasso object is not advised + and you should prefer the LinearRegression object. + + l1_ratio : float + The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For + ``l1_ratio = 0`` the penalty is an L2 penalty. ``For l1_ratio = 1`` it + is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a + combination of L1 and L2. + + fit_intercept: bool + Whether the intercept should be estimated or not. If ``False``, the + data is assumed to be already centered. + + normalize : boolean, optional, default False + If ``True``, the regressors X will be normalized before regression. + + precompute : True | False | 'auto' | array-like + Whether to use a precomputed Gram matrix to speed up + calculations. If set to ``'auto'`` let us decide. The Gram + matrix can also be passed as argument. For sparse input + this option is always ``True`` to preserve sparsity. + + max_iter: int, optional + The maximum number of iterations + + copy_X : boolean, optional, default False + If ``True``, X will be copied; else, it may be overwritten. + + tol: float, optional + The tolerance for the optimization: if the updates are + smaller than ``tol``, the optimization code checks the + dual gap for optimality and continues until it is smaller + than ``tol``. + + warm_start : bool, optional + When set to ``True``, reuse the solution of the previous call to fit as + initialization, otherwise, just erase the previous solution. + + positive: bool, optional + When set to ``True``, forces the coefficients to be positive. 
+ + Attributes + ---------- + ``coef_`` : array, shape = (n_features,) | (n_targets, n_features) + parameter vector (w in the cost function formula) + + ``sparse_coef_`` : scipy.sparse matrix, shape = (n_features, 1) | \ + (n_targets, n_features) + ``sparse_coef_`` is a readonly property derived from ``coef_`` + + ``intercept_`` : float | array, shape = (n_targets,) + independent term in decision function. + + Notes + ----- + To avoid unnecessary memory duplication the X argument of the fit method + should be directly passed as a Fortran-contiguous numpy array. + """ + path = staticmethod(enet_path) + + def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True, + normalize=False, precompute='auto', max_iter=1000, + copy_X=True, tol=1e-4, warm_start=False, positive=False, + rho=None): + self.alpha = alpha + self.l1_ratio = l1_ratio + if rho is not None: + self.l1_ratio = rho + warnings.warn("rho was renamed to l1_ratio and will be removed " + "in 0.15", DeprecationWarning) + self.coef_ = None + self.fit_intercept = fit_intercept + self.normalize = normalize + self.precompute = precompute + self.max_iter = max_iter + self.copy_X = copy_X + self.tol = tol + self.warm_start = warm_start + self.positive = positive + self.intercept_ = 0.0 + + def fit(self, X, y, Xy=None, coef_init=None): + """Fit model with coordinate descent + + Parameters + ----------- + X: ndarray or scipy.sparse matrix, (n_samples, n_features) + Data + y: ndarray, shape = (n_samples,) or (n_samples, n_targets) + Target + Xy : array-like, optional + Xy = np.dot(X.T, y) that can be precomputed. It is useful + only when the Gram matrix is precomputed. + WARNING : ignored and will be deprecated in 0.15 + coef_init: ndarray of shape n_features or (n_targets, n_features) + The initial coeffients to warm-start the optimization + WARNING : ignored and will be deprecated in 0.15 + + Notes + ----- + + Coordinate descent is an algorithm that considers each column of + data at a time hence it will automatically convert the X input + as a Fortran-contiguous numpy array if necessary. + + To avoid memory re-allocation it is advised to allocate the + initial data in memory directly using that format. + """ + if Xy is not None: + warnings.warn("Xy param is now ignored and will be removed in " + "0.15. See enet_path function.", + DeprecationWarning, stacklevel=2) + + if coef_init is not None: + warnings.warn("coef_init is now ignored and will be removed in " + "0.15. See enet_path function.", + DeprecationWarning, stacklevel=2) + + if self.alpha == 0: + warnings.warn("With alpha=0, this algorithm does not converge " + "well. You are advised to use the LinearRegression " + "estimator", stacklevel=2) + X = atleast2d_or_csc(X, dtype=np.float64, order='F', + copy=self.copy_X and self.fit_intercept) + # From now on X can be touched inplace + y = np.asarray(y, dtype=np.float64) + + X, y, X_mean, y_mean, X_std, precompute, Xy = \ + _pre_fit(X, y, Xy, self.precompute, self.normalize, + self.fit_intercept, copy=True) + + if y.ndim == 1: + y = y[:, np.newaxis] + if Xy is not None and Xy.ndim == 1: + Xy = Xy[:, np.newaxis] - >>> # Now use lars_path and 1D linear interpolation to compute the - >>> # same path - >>> from sklearn.linear_model import lars_path - >>> alphas, active, coef_path_lars = lars_path(X, y, method='lasso') - >>> from scipy import interpolate - >>> coef_path_continuous = interpolate.interp1d(alphas[::-1], - ... coef_path_lars[:, ::-1]) - >>> print coef_path_continuous([5., 1., .5]) - [[ 0. 0. 
0.46915237] - [ 0.2159048 0.4425765 0.23668876]] + n_samples, n_features = X.shape + n_targets = y.shape[1] + coef_ = np.zeros((n_targets, n_features), dtype=np.float64) + dual_gaps_ = np.zeros(n_targets, dtype=np.float64) - See also - -------- - lars_path - Lasso - LassoLars - LassoCV - LassoLarsCV - sklearn.decomposition.sparse_encode - """ - return enet_path(X, y, l1_ratio=1., eps=eps, n_alphas=n_alphas, - alphas=alphas, precompute=precompute, Xy=Xy, - fit_intercept=fit_intercept, normalize=normalize, - copy_X=copy_X, verbose=verbose, - return_models=return_models, **params) + for k in xrange(n_targets): + if Xy is not None: + this_Xy = Xy[:, k] + else: + this_Xy = None + _, this_coef, this_dual_gap = self.path(X, y[:, k], + l1_ratio=self.l1_ratio, eps=None, + n_alphas=None, alphas=[self.alpha], + precompute=precompute, Xy=this_Xy, + fit_intercept=False, normalize=False, copy_X=True, + verbose=False, tol=self.tol, positive=self.positive, + return_models=False, X_mean=X_mean, X_std=X_std) + coef_[k] = this_coef[:, 0] + dual_gaps_[k] = this_dual_gap[0] + + self.coef_, self.dual_gap_ = map(np.squeeze, [coef_, dual_gaps_]) + self._set_intercept(X_mean, y_mean, X_std) + # return self for chaining fit and predict calls + return self -def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, - precompute='auto', Xy=None, fit_intercept=True, - normalize=False, copy_X=True, verbose=False, rho=None, - return_models=True, **params): - """Compute Elastic-Net path with coordinate descent + @property + def sparse_coef_(self): + """ sparse representation of the fitted coef """ + return sparse.csr_matrix(self.coef_) - The Elastic Net optimization function is:: + def decision_function(self, X): + """Decision function of the linear model - 1 / (2 * n_samples) * ||y - Xw||^2_2 + - + alpha * l1_ratio * ||w||_1 - + 0.5 * alpha * (1 - l1_ratio) * ||w||^2_2 + Parameters + ---------- + X : numpy array or scipy.sparse matrix of shape (n_samples, n_features) - Parameters - ---------- - X : {array-like, sparse matrix}, shape (n_samples, n_features) - Training data. Pass directly as Fortran-contiguous data to avoid - unnecessary memory duplication + Returns + ------- + T : array, shape = (n_samples,) + The predicted decision function + """ + if sparse.isspmatrix(X): + return np.ravel(safe_sparse_dot(self.coef_, X.T, dense_output=True) + + self.intercept_) + else: + return super(ElasticNet, self).decision_function(X) - y : ndarray, shape = (n_samples,) - Target values - l1_ratio : float, optional - float between 0 and 1 passed to ElasticNet (scaling between - l1 and l2 penalties). ``l1_ratio=1`` corresponds to the Lasso +############################################################################### +# Lasso model - eps : float - Length of the path. ``eps=1e-3`` means that - ``alpha_min / alpha_max = 1e-3`` +class Lasso(ElasticNet): + """Linear Model trained with L1 prior as regularizer (aka the Lasso) - n_alphas : int, optional - Number of alphas along the regularization path + The optimization objective for Lasso is:: - alphas : ndarray, optional - List of alphas where to compute the models. - If None alphas are set automatically + (1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1 - precompute : True | False | 'auto' | array-like - Whether to use a precomputed Gram matrix to speed up - calculations. If set to ``'auto'`` let us decide. The Gram - matrix can also be passed as argument. 
+ Technically the Lasso model is optimizing the same objective function as + the Elastic Net with ``l1_ratio=1.0`` (no L2 penalty). - Xy : array-like, optional - Xy = np.dot(X.T, y) that can be precomputed. It is useful - only when the Gram matrix is precomputed. + Parameters + ---------- + alpha : float, optional + Constant that multiplies the L1 term. Defaults to 1.0. + ``alpha = 0`` is equivalent to an ordinary least square, solved + by the :class:`LinearRegression` object. For numerical + reasons, using ``alpha = 0`` is with the Lasso object is not advised + and you should prefer the LinearRegression object. - fit_intercept : bool - Fit or not an intercept + fit_intercept : boolean + whether to calculate the intercept for this model. If set + to false, no intercept will be used in calculations + (e.g. data is expected to be already centered). normalize : boolean, optional, default False If ``True``, the regressors X will be normalized before regression. @@ -707,136 +731,87 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, copy_X : boolean, optional, default True If ``True``, X will be copied; else, it may be overwritten. - verbose : bool or integer - Amount of verbosity + precompute : True | False | 'auto' | array-like + Whether to use a precomputed Gram matrix to speed up + calculations. If set to ``'auto'`` let us decide. The Gram + matrix can also be passed as argument. For sparse input + this option is always ``True`` to preserve sparsity. - return_models : boolean, optional, default True - If ``True``, the function will return list of models. Setting it - to ``False`` will change the function output returning the values - of the alphas and the coefficients along the path. Returning the - model list will be removed in version 0.15. + max_iter: int, optional + The maximum number of iterations - params : kwargs - keyword arguments passed to the Lasso objects + tol : float, optional + The tolerance for the optimization: if the updates are + smaller than ``tol``, the optimization code checks the + dual gap for optimality and continues until it is smaller + than ``tol``. - Returns - ------- - models : a list of models along the regularization path - (Is returned if ``return_models`` is set ``True`` (default). + warm_start : bool, optional + When set to True, reuse the solution of the previous call to fit as + initialization, otherwise, just erase the previous solution. - alphas : array, shape: [n_alphas + 1] - The alphas along the path where models are computed. - (Is returned, along with ``coefs``, when ``return_models`` is set - to ``False``) + positive : bool, optional + When set to ``True``, forces the coefficients to be positive. - coefs : shape (n_features, n_alphas + 1) - Coefficients along the path. - (Is returned, along with ``alphas``, when ``return_models`` is set - to ``False``). + Attributes + ---------- + ``coef_`` : array, shape = (n_features,) | (n_targets, n_features) + parameter vector (w in the cost function formula) - Notes - ----- - See examples/linear_model/plot_lasso_coordinate_descent_path.py - for an example. + ``sparse_coef_`` : scipy.sparse matrix, shape = (n_features, 1) | \ + (n_targets, n_features) + ``sparse_coef_`` is a readonly property derived from ``coef_`` - Deprecation Notice: Setting ``return_models`` to ``False`` will make - the Lasso Path return an output in the style used by :func:`lars_path`. - This will be become the norm as of version 0.15. 
-    set to `True` will let the function return a list of models as before.
+    ``intercept_`` : float | array, shape = (n_targets,)
+        independent term in decision function.
-    See also
+    Examples
     --------
-    ElasticNet
-    ElasticNetCV
-    """
-    if return_models:
-        warnings.warn("Use enet_path(return_models=False), as it returns the"
-                      " coefficients and alphas instead of just a list of"
-                      " models as previously `lasso_path`/`enet_path` did."
-                      " `return_models` will eventually be removed in 0.15,"
-                      " after which, returning alphas and coefs"
-                      " will become the norm.",
-                      DeprecationWarning, stacklevel=2)
-    if rho is not None:
-        l1_ratio = rho
-        warnings.warn("rho was renamed to l1_ratio and will be removed "
-                      "in 0.15", DeprecationWarning)
-
-    X = atleast2d_or_csc(X, dtype=np.float64, order='F',
-                         copy=copy_X and fit_intercept)
-    # From now on X can be touched inplace
-    if not sparse.isspmatrix(X):
-        X, y, X_mean, y_mean, X_std = center_data(X, y, fit_intercept,
-                                                  normalize, copy=False)
-    # XXX : in the sparse case the data will be centered
-    # at each fit...
-
-    n_samples, n_features = X.shape
-    if Xy is None:
-        Xy = safe_sparse_dot(X.T, y, dense_output=True)
-
-    n_samples = X.shape[0]
-    if alphas is None:
-        # No need to normalize of fit_intercept: it has been done
-        # above
-        alphas = _alpha_grid(X, y, Xy=Xy, l1_ratio=l1_ratio,
-                             fit_intercept=False, eps=1e-3, n_alphas=100,
-                             normalize=False, copy_X=False)
-    else:
-        alphas = np.sort(alphas)[::-1]  # make sure alphas are properly ordered
+    >>> from sklearn import linear_model
+    >>> clf = linear_model.Lasso(alpha=0.1)
+    >>> clf.fit([[0,0], [1, 1], [2, 2]], [0, 1, 2])
+    Lasso(alpha=0.1, copy_X=True, fit_intercept=True, max_iter=1000,
+       normalize=False, positive=False, precompute='auto', tol=0.0001,
+       warm_start=False)
+    >>> print(clf.coef_)
+    [ 0.85  0.  ]
+    >>> print(clf.intercept_)
+    0.15
-    if (hasattr(precompute, '__array__')
-            and not np.allclose(X_mean, np.zeros(n_features))
-            and not np.allclose(X_std, np.ones(n_features))):
-        # recompute Gram
-        precompute = 'auto'
-        Xy = None
+    See also
+    --------
+    lars_path
+    lasso_path
+    LassoLars
+    LassoCV
+    LassoLarsCV
+    sklearn.decomposition.sparse_encode
-    # precompute if n_samples > n_features
-    if precompute == 'auto':
-        precompute = (n_samples > n_features)
+    Notes
+    -----
+    The algorithm used to fit the model is coordinate descent.
-    if precompute or (precompute == 'auto'):
-        if sparse.isspmatrix(X):
-            warnings.warn("precompute is ignored for sparse data")
-            precompute = False
-        else:
-            precompute = np.dot(X.T, X)
+    To avoid unnecessary memory duplication the X argument of the fit method
+    should be directly passed as a Fortran-contiguous numpy array.
+    """
+    path = staticmethod(enet_path)
-    coef_ = None  # init coef_
-    models = []
-    coefs = []
+    def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
+                 precompute='auto', copy_X=True, max_iter=1000,
+                 tol=1e-4, warm_start=False, positive=False):
+        super(Lasso, self).__init__(
+            alpha=alpha, l1_ratio=1.0, fit_intercept=fit_intercept,
+            normalize=normalize, precompute=precompute, copy_X=copy_X,
+            max_iter=max_iter, tol=tol, warm_start=warm_start,
+            positive=positive)
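The hunks above replace the deprecated list-of-models output of ``lasso_path``/``enet_path`` with a ``lars_path``-style return value. A minimal usage sketch of the new convention, assuming the 0.14 development API targeted by this patch (the ``return_models`` keyword is itself slated for removal after 0.15)::

    import numpy as np
    from sklearn.datasets import load_diabetes
    from sklearn.linear_model import lasso_path

    data = load_diabetes()
    X = data.data / data.data.std(axis=0)   # standardize, as in the example script
    y = data.target

    # With return_models=False the first two values are the alpha grid and the
    # coefficient paths; slicing keeps the sketch valid whether or not dual
    # gaps are appended as a third element.
    alphas, coefs = lasso_path(X, y, eps=5e-3, return_models=False)[:2]
    print(alphas.shape)   # (n_alphas,), 100 by default
    print(coefs.shape)    # (n_features, n_alphas)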
-    n_alphas = len(alphas)
-    for i, alpha in enumerate(alphas):
-        model = ElasticNet(
-            alpha=alpha, l1_ratio=l1_ratio,
-            fit_intercept=fit_intercept if sparse.isspmatrix(X) else False,
-            precompute=precompute)
-        model.set_params(**params)
-        model.fit(X, y, coef_init=coef_, Xy=Xy)
-        if fit_intercept and not sparse.isspmatrix(X):
-            model.fit_intercept = True
-            model._set_intercept(X_mean, y_mean, X_std)
-        if verbose:
-            if verbose > 2:
-                print(model)
-            elif verbose > 1:
-                print('Path: %03i out of %03i' % (i, n_alphas))
-            else:
-                sys.stderr.write('.')
-        coefs.append(model.coef_)
-        coef_ = coefs[-1].copy()
-        models.append(model)
-    if return_models:
-        return models
-    else:
-        return alphas, np.asarray(coefs).T
+###############################################################################
+# Functions for CV with paths functions
+
 def _path_residuals(X, y, train, test, path, path_params, l1_ratio=1,
                     X_order=None, dtype=None):
-    """ Returns the MSE for the models computed by 'path'
+    """Returns the MSE for the models computed by 'path'
 
     Parameters
     ----------
@@ -873,19 +848,47 @@ def _path_residuals(X, y, train, test, path, path_params, l1_ratio=1,
         The dtype of the arrays expected by the path function to
         avoid memory copies
     """
-    this_mses = list()
+    X_train = X[train]
+    y_train = y[train]
+    X_test = X[test]
+    y_test = y[test]
+    fit_intercept = path_params['fit_intercept']
+    normalize = path_params['normalize']
+    precompute = path_params['precompute']
+    Xy = None
+
+    X_train, y_train, X_mean, y_mean, X_std, precompute, Xy = \
+        _pre_fit(X_train, y_train, Xy, precompute, normalize, fit_intercept,
+                 copy=False)
+
+    # del path_params['precompute']
+    path_params = path_params.copy()
+    path_params['return_models'] = False
+    path_params['fit_intercept'] = False
+    path_params['normalize'] = False
+    path_params['Xy'] = Xy
+    path_params['X_mean'] = X_mean
+    path_params['X_std'] = X_std
+    path_params['precompute'] = precompute
+    path_params['copy_X'] = False
+
     if 'l1_ratio' in path_params:
         path_params['l1_ratio'] = l1_ratio
-    X_train = X[train]
+
     # Do the ordering and type casting here, as if it is done in the path,
     # X is copied and a reference is kept here
     X_train = atleast2d_or_csc(X_train, dtype=dtype, order=X_order)
-    models_train = path(X_train, y[train], **path_params)
+    alphas, coefs, _ = path(X_train, y[train], **path_params)
     del X_train
-    this_mses = np.empty(len(models_train))
-    for i_model, model in enumerate(models_train):
-        y_ = model.predict(X[test])
-        this_mses[i_model] = ((y_ - y[test]) ** 2).mean()
+
+    if normalize:
+        nonzeros = np.flatnonzero(X_std)
+        coefs[nonzeros] /= X_std[nonzeros][:, np.newaxis]
+
+    intercepts = y_mean - np.dot(X_mean, coefs)
+    residues = safe_sparse_dot(X_test, coefs) - y_test[:, np.newaxis]
+    residues += intercepts[np.newaxis, :]
+    this_mses = (residues ** 2).mean(axis=0)
 
     return this_mses, l1_ratio
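For reference, the reworked ``_path_residuals`` above scores a whole coefficient path on the test fold in one vectorised step instead of calling ``predict`` on a list of models. A self-contained sketch of that computation with illustrative names (not the module's private API)::

    import numpy as np

    def path_mse(X_test, y_test, coefs, X_mean, y_mean):
        # coefs has shape (n_features, n_alphas): one column per alpha.
        intercepts = y_mean - np.dot(X_mean, coefs)        # one intercept per alpha
        residues = np.dot(X_test, coefs) - y_test[:, np.newaxis]
        residues += intercepts[np.newaxis, :]
        return (residues ** 2).mean(axis=0)                # MSE for every alpha

    rng = np.random.RandomState(0)
    X, y = rng.randn(20, 5), rng.randn(20)
    coefs = rng.randn(5, 7)                                # fake path with 7 alphas
    print(path_mse(X, y, coefs, X.mean(axis=0), y.mean()).shape)  # (7,)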
@@ -1005,6 +1008,7 @@ def fit(self, X, y):
             mse_alphas = [m[0] for m in mse_alphas]
             mse_alphas = np.array(mse_alphas)
+            mse = np.mean(mse_alphas, axis=0)
             i_best_alpha = np.argmin(mse)
             this_best_mse = mse[i_best_alpha]
@@ -1032,7 +1036,6 @@ def fit(self, X, y):
         self.coef_ = model.coef_
         self.intercept_ = model.intercept_
         self.dual_gap_ = model.dual_gap_
-        self.eps_ = model.eps_
         return self
 
     @property
@@ -1078,9 +1081,9 @@ class LassoCV(LinearModelCV, RegressorMixin):
         dual gap for optimality and continues until it is smaller
         than ``tol``.
 
-    cv : integer or cross-validation generator, optional
+    cv : integer or crossvalidation generator, optional
         If an integer is passed, it is the number of fold (default 3).
-        Specific cross-validation objects can be passed, see the
+        Specific crossvalidation objects can be passed, see the
         :mod:`sklearn.cross_validation` module for the list of possible
         objects.
@@ -1106,7 +1109,7 @@ class LassoCV(LinearModelCV, RegressorMixin):
 
     Notes
     -----
-    See examples/linear_model/plot_lasso_model_selection.py
+    See examples/linear_model/lasso_path_with_crossvalidation.py
    for an example.
 
    To avoid unnecessary memory duplication the X argument of the fit method
@@ -1177,9 +1180,9 @@ class ElasticNetCV(LinearModelCV, RegressorMixin):
         dual gap for optimality and continues until it is smaller
         than ``tol``.
 
-    cv : integer or cross-validation generator, optional
+    cv : integer or crossvalidation generator, optional
         If an integer is passed, it is the number of fold (default 3).
-        Specific cross-validation objects can be passed, see the
+        Specific crossvalidation objects can be passed, see the
         :mod:`sklearn.cross_validation` module for the list of possible
         objects.
@@ -1212,7 +1215,7 @@ class ElasticNetCV(LinearModelCV, RegressorMixin):
 
     Notes
     -----
-    See examples/linear_model/plot_lasso_model_selection.py
+    See examples/linear_model/lasso_path_with_crossvalidation.py
    for an example.
 
    To avoid unnecessary memory duplication the X argument of the fit method
@@ -1387,7 +1390,7 @@ def fit(self, X, y, Xy=None, coef_init=None):
         y: ndarray, shape = (n_samples, n_tasks)
             Target
         coef_init: ndarray of shape n_features
-            The initial coefficients to warm-start the optimization
+            The initial coeffients to warm-start the optimization
 
         Notes
         -----
diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py
index da911d29b1424..17f7165364307 100644
--- a/sklearn/linear_model/tests/test_coordinate_descent.py
+++ b/sklearn/linear_model/tests/test_coordinate_descent.py
@@ -194,7 +194,7 @@ def test_lasso_path_return_models_vs_new_return_gives_same_coefficients():
     alphas_lars, _, coef_path_lars = lars_path(X, y, method='lasso')
     coef_path_cont_lars = interpolate.interp1d(alphas_lars[::-1],
                                                coef_path_lars[:, ::-1])
-    alphas_lasso2, coef_path_lasso2 = lasso_path(X, y, alphas=alphas,
+    alphas_lasso2, coef_path_lasso2, _ = lasso_path(X, y, alphas=alphas,
                                                  fit_intercept=False,
                                                  return_models=False)
     coef_path_cont_lasso = interpolate.interp1d(alphas_lasso2[::-1],
@@ -331,15 +331,14 @@ def test_enet_multitarget():
                                  n_informative_features=10, n_targets=n_targets)
     estimator = ElasticNet(alpha=0.01, fit_intercept=True)
     estimator.fit(X, y)
-    coef, intercept, dual_gap, eps = (estimator.coef_, estimator.intercept_,
-                                      estimator.dual_gap_, estimator.eps_)
+    coef, intercept, dual_gap = (estimator.coef_, estimator.intercept_,
+                                 estimator.dual_gap_)
 
     for k in range(n_targets):
         estimator.fit(X, y[:, k])
         assert_array_almost_equal(coef[k, :], estimator.coef_)
         assert_array_almost_equal(intercept[k], estimator.intercept_)
         assert_array_almost_equal(dual_gap[k], estimator.dual_gap_)
-        assert_array_almost_equal(eps[k], estimator.eps_)
 
 
 if __name__ == '__main__':
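The test updates above drop the removed ``eps_`` attribute and keep one row of coefficients, one intercept and one dual gap per target. A small sketch of the multi-target behaviour they exercise, assuming the 0.14 development code in this patch::

    import numpy as np
    from sklearn.linear_model import ElasticNet

    rng = np.random.RandomState(0)
    X = rng.randn(30, 5)
    Y = rng.randn(30, 3)                 # three regression targets

    enet = ElasticNet(alpha=0.01, fit_intercept=True)
    enet.fit(X, Y)
    print(enet.coef_.shape)              # (3, 5): one coefficient row per target
    print(np.shape(enet.intercept_))     # (3,)
    # dual_gap_ likewise holds one value per target; eps_ no longer exists.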
diff --git a/sklearn/linear_model/tests/test_least_angle.py b/sklearn/linear_model/tests/test_least_angle.py
index 12c48c2bbc1f4..8bc428638e45d 100644
--- a/sklearn/linear_model/tests/test_least_angle.py
+++ b/sklearn/linear_model/tests/test_least_angle.py
@@ -329,14 +329,16 @@ def test_lasso_lars_vs_lasso_cd_ill_conditioned():
     assert_true(('Dropping a regressor' in warning_list[0].message.args[0])
                 or ('Early stopping' in warning_list[0].message.args[0]))
 
-    _, lasso_coef2 = linear_model.lasso_path(X, y,
-                                             alphas=lars_alphas, tol=1e-6,
-                                             return_models=False)
+    _, lasso_coef2, _ = linear_model.lasso_path(X, y,
+                                                alphas=lars_alphas, tol=1e-6,
+                                                return_models=False,
+                                                fit_intercept=False)
 
     lasso_coef = np.zeros((w.shape[0], len(lars_alphas)))
     for i, model in enumerate(linear_model.lasso_path(X, y,
                                                       alphas=lars_alphas,
-                                                      tol=1e-6)):
+                                                      tol=1e-6, fit_intercept=False)):
         lasso_coef[:, i] = model.coef_
+
     np.testing.assert_array_almost_equal(lars_coef, lasso_coef, decimal=1)
     np.testing.assert_array_almost_equal(lars_coef, lasso_coef2, decimal=1)
     np.testing.assert_array_almost_equal(lasso_coef, lasso_coef2, decimal=1)
diff --git a/sklearn/linear_model/tests/test_sparse_coordinate_descent.py b/sklearn/linear_model/tests/test_sparse_coordinate_descent.py
index a54091f0352dc..1b6d0847b21ef 100644
--- a/sklearn/linear_model/tests/test_sparse_coordinate_descent.py
+++ b/sklearn/linear_model/tests/test_sparse_coordinate_descent.py
@@ -235,15 +235,14 @@ def test_enet_multitarget():
     estimator = ElasticNet(alpha=0.01, fit_intercept=True, precompute=None)
     # XXX: There is a bug when precompute is not None!
     estimator.fit(X, y)
-    coef, intercept, dual_gap, eps = (estimator.coef_, estimator.intercept_,
-                                      estimator.dual_gap_, estimator.eps_)
+    coef, intercept, dual_gap = (estimator.coef_, estimator.intercept_,
+                                 estimator.dual_gap_)
 
     for k in range(n_targets):
         estimator.fit(X, y[:, k])
         assert_array_almost_equal(coef[k, :], estimator.coef_)
         assert_array_almost_equal(intercept[k], estimator.intercept_)
         assert_array_almost_equal(dual_gap[k], estimator.dual_gap_)
-        assert_array_almost_equal(eps[k], estimator.eps_)
 
 
 def test_path_parameters():
@@ -256,3 +255,6 @@ def test_path_parameters():
     assert_almost_equal(0.5, clf.l1_ratio)
     assert_equal(n_alphas, clf.n_alphas)
     assert_equal(n_alphas, len(clf.alphas_))
+    sparse_mse_path = clf.mse_path_
+    clf.fit(X.toarray(), y)  # compare with dense data
+    assert_almost_equal(clf.mse_path_, sparse_mse_path)

From 4d1cf3b56589c1eb254fc862f217306300e4d687 Mon Sep 17 00:00:00 2001
From: Alexandre Gramfort
Date: Tue, 23 Jul 2013 14:42:26 +0200
Subject: [PATCH 3/3] update what's new

---
 doc/whats_new.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index c2a1fa04c0ca5..98bbaccb1bf18 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -137,7 +137,7 @@ API changes summary
     :class:`linear_model.enet_path` can return its results in the same
     format as that of :class:`linear_model.lars_path`. This is done by
     setting the `return_models` parameter to `False`. By
-    `Jaques Grobler`_
+    `Jaques Grobler`_ and `Alexandre Gramfort`_
 
   - :class:`grid_search.IterGrid` was renamed to
     :class:`grid_search.ParameterGrid`.
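Taken together, the changes make the coordinate-descent path speak the same ``(alphas, coefficients)`` language as LARS, which is what the updated tests assert. A hedged end-to-end sketch of that comparison on the diabetes data, checking shapes only (exact coefficients agree only within solver tolerance)::

    from sklearn.datasets import load_diabetes
    from sklearn.linear_model import lars_path, lasso_path

    data = load_diabetes()
    X, y = data.data, data.target

    alphas_lars, _, coefs_lars = lars_path(X, y, method='lasso')
    alphas_cd, coefs_cd = lasso_path(X, y, alphas=alphas_lars,
                                     fit_intercept=False,
                                     return_models=False)[:2]
    print(coefs_lars.shape)   # (n_features, n_alphas)
    print(coefs_cd.shape)     # same layout from the coordinate-descent solver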