From 3e71c4bb73de0fa798861ca0728dc0f62d0c9251 Mon Sep 17 00:00:00 2001 From: doaa-altarawy Date: Tue, 24 Jul 2018 15:20:43 -0400 Subject: [PATCH 1/4] Implement penalty.factor of glmnet R package to have different penalty for each variable (#11566) --- sklearn/covariance/graph_lasso_.py | 2 +- sklearn/linear_model/cd_fast.pyx | 28 ++++++++++----- sklearn/linear_model/coordinate_descent.py | 41 ++++++++++++++-------- 3 files changed, 48 insertions(+), 23 deletions(-) diff --git a/sklearn/covariance/graph_lasso_.py b/sklearn/covariance/graph_lasso_.py index 0837acf4a3641..b4797a3e679be 100644 --- a/sklearn/covariance/graph_lasso_.py +++ b/sklearn/covariance/graph_lasso_.py @@ -223,7 +223,7 @@ def graphical_lasso(emp_cov, alpha, cov_init=None, mode='cd', tol=1e-4, coefs = -(precision_[indices != idx, idx] / (precision_[idx, idx] + 1000 * eps)) coefs, _, _, _ = cd_fast.enet_coordinate_descent_gram( - coefs, alpha, 0, sub_covariance, + coefs, alpha, 0, np.empty(0), sub_covariance, row, row, max_iter, enet_tol, check_random_state(None), False) else: diff --git a/sklearn/linear_model/cd_fast.pyx b/sklearn/linear_model/cd_fast.pyx index a51d1bdbdbc96..fd46a124466c4 100644 --- a/sklearn/linear_model/cd_fast.pyx +++ b/sklearn/linear_model/cd_fast.pyx @@ -143,6 +143,7 @@ cdef extern from "cblas.h": @cython.cdivision(True) def enet_coordinate_descent(np.ndarray[floating, ndim=1] w, floating alpha, floating beta, + np.ndarray[floating, ndim=1] l1_weights, np.ndarray[floating, ndim=2, mode='fortran'] X, np.ndarray[floating, ndim=1, mode='c'] y, int max_iter, floating tol, @@ -201,6 +202,8 @@ def enet_coordinate_descent(np.ndarray[floating, ndim=1] w, cdef unsigned int f_iter cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX) cdef UINT32_t* rand_r_state = &rand_r_state_seed + cdef floating w_alpha = alpha + cdef unsigned int use_l1_weight = l1_weights.shape[0] > 0 cdef floating *X_data = X.data cdef floating *y_data = y.data @@ -224,6 +227,9 @@ def enet_coordinate_descent(np.ndarray[floating, ndim=1] w, w_max = 0.0 d_w_max = 0.0 for f_iter in range(n_features): # Loop over coordinates + if use_l1_weight: + w_alpha = l1_weights[f_iter] + if random: ii = rand_int(n_features, rand_r_state) else: @@ -245,7 +251,7 @@ def enet_coordinate_descent(np.ndarray[floating, ndim=1] w, if positive and tmp < 0: w[ii] = 0.0 else: - w[ii] = (fsign(tmp) * fmax(fabs(tmp) - alpha, 0) + w[ii] = (fsign(tmp) * fmax(fabs(tmp) - w_alpha, 0) / (norm_cols_X[ii] + beta)) if w[ii] != 0.0: @@ -284,8 +290,8 @@ def enet_coordinate_descent(np.ndarray[floating, ndim=1] w, # w_norm2 = np.dot(w, w) w_norm2 = dot(n_features, w_data, 1, w_data, 1) - if (dual_norm_XtA > alpha): - const = alpha / dual_norm_XtA + if (dual_norm_XtA > w_alpha): + const = w_alpha / dual_norm_XtA A_norm2 = R_norm2 * (const ** 2) gap = 0.5 * (R_norm2 + A_norm2) else: @@ -295,7 +301,7 @@ def enet_coordinate_descent(np.ndarray[floating, ndim=1] w, l1_norm = asum(n_features, w_data, 1) # np.dot(R.T, y) - gap += (alpha * l1_norm + gap += (w_alpha * l1_norm - const * dot(n_samples, R_data, 1, y_data, n_tasks) + 0.5 * beta * (1 + const ** 2) * (w_norm2)) @@ -529,6 +535,7 @@ def sparse_enet_coordinate_descent(floating [:] w, @cython.wraparound(False) @cython.cdivision(True) def enet_coordinate_descent_gram(floating[:] w, floating alpha, floating beta, + np.ndarray[floating, ndim=1] l1_weights, np.ndarray[floating, ndim=2, mode='c'] Q, np.ndarray[floating, ndim=1, mode='c'] q, np.ndarray[floating, ndim=1] y, @@ -581,6 +588,8 @@ def 
enet_coordinate_descent_gram(floating[:] w, floating alpha, floating beta, cdef unsigned int f_iter cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX) cdef UINT32_t* rand_r_state = &rand_r_state_seed + cdef floating w_alpha = alpha + cdef unsigned int use_l1_weight = l1_weights.shape[0] > 0 cdef floating y_norm2 = np.dot(y, y) cdef floating* w_ptr = &w[0] @@ -599,6 +608,9 @@ def enet_coordinate_descent_gram(floating[:] w, floating alpha, floating beta, w_max = 0.0 d_w_max = 0.0 for f_iter in range(n_features): # Loop over coordinates + if use_l1_weight: + w_alpha = l1_weights[f_iter] + if random: ii = rand_int(n_features, rand_r_state) else: @@ -619,7 +631,7 @@ def enet_coordinate_descent_gram(floating[:] w, floating alpha, floating beta, if positive and tmp < 0: w[ii] = 0.0 else: - w[ii] = fsign(tmp) * fmax(fabs(tmp) - alpha, 0) \ + w[ii] = fsign(tmp) * fmax(fabs(tmp) - w_alpha, 0) \ / (Q[ii, ii] + beta) if w[ii] != 0.0: @@ -659,8 +671,8 @@ def enet_coordinate_descent_gram(floating[:] w, floating alpha, floating beta, # w_norm2 = np.dot(w, w) w_norm2 = dot(n_features, &w[0], 1, &w[0], 1) - if (dual_norm_XtA > alpha): - const = alpha / dual_norm_XtA + if (dual_norm_XtA > w_alpha): + const = w_alpha / dual_norm_XtA A_norm2 = R_norm2 * (const ** 2) gap = 0.5 * (R_norm2 + A_norm2) else: @@ -668,7 +680,7 @@ def enet_coordinate_descent_gram(floating[:] w, floating alpha, floating beta, gap = R_norm2 # The call to dasum is equivalent to the L1 norm of w - gap += (alpha * asum(n_features, &w[0], 1) - + gap += (w_alpha * asum(n_features, &w[0], 1) - const * y_norm2 + const * q_dot_w + 0.5 * beta * (1 + const ** 2) * w_norm2) diff --git a/sklearn/linear_model/coordinate_descent.py b/sklearn/linear_model/coordinate_descent.py index 13e3a3e09ddf9..a3bb4287dccb9 100644 --- a/sklearn/linear_model/coordinate_descent.py +++ b/sklearn/linear_model/coordinate_descent.py @@ -140,7 +140,7 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None, Where:: - ||W||_21 = \\sum_i \\sqrt{\\sum_j w_{ij}^2} + ||W||_21 = \sum_i \sqrt{\sum_j w_{ij}^2} i.e. the sum of norm of each row. @@ -268,7 +268,7 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None, def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, precompute='auto', Xy=None, copy_X=True, coef_init=None, verbose=False, return_n_iter=False, positive=False, - check_input=True, **params): + check_input=True, l1_weights=None, **params): """Compute elastic net path with coordinate descent The elastic net optimization function varies for mono and multi-outputs. @@ -276,13 +276,13 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, For mono-output tasks it is:: 1 / (2 * n_samples) * ||y - Xw||^2_2 - + alpha * l1_ratio * ||w||_1 + + alpha * l1_ratio * l1_weights * ||w||_1 + 0.5 * alpha * (1 - l1_ratio) * ||w||^2_2 For multi-output tasks it is:: (1 / (2 * n_samples)) * ||Y - XW||^Fro_2 - + alpha * l1_ratio * ||W||_21 + + alpha * l1_ratio * l1_weights * ||W||_21 + 0.5 * alpha * (1 - l1_ratio) * ||W||_Fro^2 Where:: @@ -347,6 +347,12 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, Skip input validation checks, including the Gram matrix when provided assuming there are handled by the caller when check_input=False. + l1_weights : array, shape (n_features, ), optional + Apply separate weight to penalties of each coefficient in the L1 term. + If not provided, no weighting is used (the default). 
+ For example, if the weight of a feature is Zero, it means it's not penalized + at all, and that feature will always be there in the model. + **params : kwargs keyword arguments passed to the coordinate descent solver. @@ -454,6 +460,10 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, for i, alpha in enumerate(alphas): l1_reg = alpha * l1_ratio * n_samples l2_reg = alpha * (1.0 - l1_ratio) * n_samples + if l1_weights is not None: + l1_weights_ = np.asfortranarray(l1_weights * l1_reg, dtype=X.dtype) + else: + l1_weights_ = np.asfortranarray([], dtype=X.dtype) if not multi_output and sparse.isspmatrix(X): model = cd_fast.sparse_enet_coordinate_descent( coef_, l1_reg, l2_reg, X.data, X.indices, @@ -469,11 +479,11 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, precompute = check_array(precompute, dtype=X.dtype.type, order='C') model = cd_fast.enet_coordinate_descent_gram( - coef_, l1_reg, l2_reg, precompute, Xy, y, max_iter, + coef_, l1_reg, l2_reg, l1_weights_, precompute, Xy, y, max_iter, tol, rng, random, positive) elif precompute is False: model = cd_fast.enet_coordinate_descent( - coef_, l1_reg, l2_reg, X, y, max_iter, tol, rng, random, + coef_, l1_reg, l2_reg, l1_weights_, X, y, max_iter, tol, rng, random, positive) else: raise ValueError("Precompute should be one of True, False, " @@ -624,8 +634,9 @@ class ElasticNet(LinearModel, RegressorMixin): >>> regr = ElasticNet(random_state=0) >>> regr.fit(X, y) ElasticNet(alpha=1.0, copy_X=True, fit_intercept=True, l1_ratio=0.5, - max_iter=1000, normalize=False, positive=False, precompute=False, - random_state=0, selection='cyclic', tol=0.0001, warm_start=False) + l1_weights=None, max_iter=1000, normalize=False, positive=False, + precompute=False, random_state=0, selection='cyclic', tol=0.0001, + warm_start=False) >>> print(regr.coef_) # doctest: +ELLIPSIS [18.83816048 64.55968825] >>> print(regr.intercept_) # doctest: +ELLIPSIS @@ -652,7 +663,7 @@ class ElasticNet(LinearModel, RegressorMixin): def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True, normalize=False, precompute=False, max_iter=1000, copy_X=True, tol=1e-4, warm_start=False, positive=False, - random_state=None, selection='cyclic'): + random_state=None, selection='cyclic', l1_weights=None): self.alpha = alpha self.l1_ratio = l1_ratio self.fit_intercept = fit_intercept @@ -665,6 +676,7 @@ def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True, self.positive = positive self.random_state = random_state self.selection = selection + self.l1_weights = l1_weights def fit(self, X, y, check_input=True): """Fit model with coordinate descent. 
@@ -754,7 +766,7 @@ def fit(self, X, y, check_input=True): verbose=False, tol=self.tol, positive=self.positive, X_offset=X_offset, X_scale=X_scale, return_n_iter=True, coef_init=coef_[k], max_iter=self.max_iter, - random_state=self.random_state, + random_state=self.random_state, l1_weights=self.l1_weights, selection=self.selection, check_input=False) coef_[k] = this_coef[:, 0] @@ -1534,9 +1546,9 @@ class ElasticNetCV(LinearModelCV, RegressorMixin): >>> regr = ElasticNetCV(cv=5, random_state=0) >>> regr.fit(X, y) ElasticNetCV(alphas=None, copy_X=True, cv=5, eps=0.001, fit_intercept=True, - l1_ratio=0.5, max_iter=1000, n_alphas=100, n_jobs=1, - normalize=False, positive=False, precompute='auto', random_state=0, - selection='cyclic', tol=0.0001, verbose=0) + l1_ratio=0.5, l1_weights=None, max_iter=1000, n_alphas=100, + n_jobs=1, normalize=False, positive=False, precompute='auto', + random_state=0, selection='cyclic', tol=0.0001, verbose=0) >>> print(regr.alpha_) # doctest: +ELLIPSIS 0.1994727942696716 >>> print(regr.intercept_) # doctest: +ELLIPSIS @@ -1583,7 +1595,7 @@ def __init__(self, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, fit_intercept=True, normalize=False, precompute='auto', max_iter=1000, tol=1e-4, cv='warn', copy_X=True, verbose=0, n_jobs=1, positive=False, random_state=None, - selection='cyclic'): + selection='cyclic', l1_weights=None): self.l1_ratio = l1_ratio self.eps = eps self.n_alphas = n_alphas @@ -1600,6 +1612,7 @@ def __init__(self, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, self.positive = positive self.random_state = random_state self.selection = selection + self.l1_weights = l1_weights ############################################################################### From dbeffd5edc5270076725d77872fd6d6287b9d70f Mon Sep 17 00:00:00 2001 From: doaa-altarawy Date: Tue, 24 Jul 2018 15:23:13 -0400 Subject: [PATCH 2/4] Add tests for l1_weights in enet (#11566) --- .../tests/test_coordinate_descent.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py index 834d685f5b23d..9921ea41c4063 100644 --- a/sklearn/linear_model/tests/test_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_coordinate_descent.py @@ -291,6 +291,43 @@ def test_enet_path(): assert_almost_equal(clf1.alpha_, clf2.alpha_) +@pytest.mark.filterwarnings('ignore: You should specify a value') # 0.22 +def test_enet_selective_penalty(): + n_features = 200 + n_informative_features = 20 + # A dataset with small number of samples and large number of features + # with the last n_informative_features as the informative ones + X, y, X_test, y_test = build_dataset(n_samples=50, n_features=n_features, + n_informative_features=n_informative_features) + + # Default weight is 1 for all features, keep l1 penalty + l1_weights = np.ones(n_features) + + # Add some prior knowledge, when we know some features are important + # So, we will relax l1 penalty on the last n_informative_features. 
+ # Use any small number, or zero if you are 100% sure of the prior + # knowledge + l1_weights[:n_informative_features-1] *= 0.001 + + # Run enet with prior knowledge (l1_weights) + clf_with_prior = ElasticNetCV(alphas=[0.01, 0.05, 0.1, 0.5, 1, 1.5], + eps=2e-3, cv=3, l1_weights=l1_weights) + + ignore_warnings(clf_with_prior.fit)(X, y) + + # This is a model without using any prior knowledge + clf_base = ElasticNetCV(alphas=[0.01, 0.05, 0.1, 0.5, 1, 1.5], + eps=2e-3, cv=3) + + ignore_warnings(clf_base.fit)(X, y) + + # Accuracy of the model with prior knowledge should be higher + # than the model without prior knowledge for a hard data set + # (much less samples than features) + assert_greater(clf_with_prior.score(X_test, y_test), + clf_base.score(X_test, y_test)) + + @pytest.mark.filterwarnings('ignore: You should specify a value') # 0.22 def test_path_parameters(): X, y, _, _ = build_dataset() From 9d4c12ed3e144a56f48134d6a31bdfa4f4deefc8 Mon Sep 17 00:00:00 2001 From: doaa-altarawy Date: Tue, 24 Jul 2018 16:46:39 -0400 Subject: [PATCH 3/4] Fix flake8 and docstring errors --- sklearn/linear_model/coordinate_descent.py | 24 ++++++++++++++----- .../tests/test_coordinate_descent.py | 5 ++-- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/sklearn/linear_model/coordinate_descent.py b/sklearn/linear_model/coordinate_descent.py index a3bb4287dccb9..6e11f3807c26d 100644 --- a/sklearn/linear_model/coordinate_descent.py +++ b/sklearn/linear_model/coordinate_descent.py @@ -350,8 +350,8 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, l1_weights : array, shape (n_features, ), optional Apply separate weight to penalties of each coefficient in the L1 term. If not provided, no weighting is used (the default). - For example, if the weight of a feature is Zero, it means it's not penalized - at all, and that feature will always be there in the model. + For example, if the weight of a feature is Zero, it means it's not + penalized at all, and that feature will always be there in the model. **params : kwargs keyword arguments passed to the coordinate descent solver. @@ -479,12 +479,12 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, precompute = check_array(precompute, dtype=X.dtype.type, order='C') model = cd_fast.enet_coordinate_descent_gram( - coef_, l1_reg, l2_reg, l1_weights_, precompute, Xy, y, max_iter, - tol, rng, random, positive) + coef_, l1_reg, l2_reg, l1_weights_, precompute, Xy, y, + max_iter, tol, rng, random, positive) elif precompute is False: model = cd_fast.enet_coordinate_descent( - coef_, l1_reg, l2_reg, l1_weights_, X, y, max_iter, tol, rng, random, - positive) + coef_, l1_reg, l2_reg, l1_weights_, X, y, max_iter, tol, rng, + random, positive) else: raise ValueError("Precompute should be one of True, False, " "'auto' or array-like. Got %r" % precompute) @@ -609,6 +609,12 @@ class ElasticNet(LinearModel, RegressorMixin): (setting to 'random') often leads to significantly faster convergence especially when tol is higher than 1e-4. + l1_weights : array, shape (n_features, ), optional + Apply separate weight to penalties of each coefficient in the L1 term. + If not provided, no weighting is used (the default). + For example, if the weight of a feature is Zero, it means it's not + penalized at all, and that feature will always be there in the model. 
+ Attributes ---------- coef_ : array, shape (n_features,) | (n_targets, n_features) @@ -1511,6 +1517,12 @@ class ElasticNetCV(LinearModelCV, RegressorMixin): (setting to 'random') often leads to significantly faster convergence especially when tol is higher than 1e-4. + l1_weights : array, shape (n_features, ), optional + Apply separate weight to penalties of each coefficient in the L1 term. + If not provided, no weighting is used (the default). + For example, if the weight of a feature is Zero, it means it's not + penalized at all, and that feature will always be there in the model. + Attributes ---------- alpha_ : float diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py index 9921ea41c4063..f7f7d5a1b1fb5 100644 --- a/sklearn/linear_model/tests/test_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_coordinate_descent.py @@ -297,8 +297,9 @@ def test_enet_selective_penalty(): n_informative_features = 20 # A dataset with small number of samples and large number of features # with the last n_informative_features as the informative ones - X, y, X_test, y_test = build_dataset(n_samples=50, n_features=n_features, - n_informative_features=n_informative_features) + X, y, X_test, y_test = build_dataset( + n_samples=50, n_features=n_features, + n_informative_features=n_informative_features) # Default weight is 1 for all features, keep l1 penalty l1_weights = np.ones(n_features) From 67764c8f9b38d15b9c8b7cea87a91c71f072d889 Mon Sep 17 00:00:00 2001 From: doaa-altarawy Date: Tue, 24 Jul 2018 21:07:43 -0400 Subject: [PATCH 4/4] Fix flake8, long line --- sklearn/linear_model/coordinate_descent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/linear_model/coordinate_descent.py b/sklearn/linear_model/coordinate_descent.py index 6e11f3807c26d..9d36e32826420 100644 --- a/sklearn/linear_model/coordinate_descent.py +++ b/sklearn/linear_model/coordinate_descent.py @@ -772,7 +772,8 @@ def fit(self, X, y, check_input=True): verbose=False, tol=self.tol, positive=self.positive, X_offset=X_offset, X_scale=X_scale, return_n_iter=True, coef_init=coef_[k], max_iter=self.max_iter, - random_state=self.random_state, l1_weights=self.l1_weights, + random_state=self.random_state, + l1_weights=self.l1_weights, selection=self.selection, check_input=False) coef_[k] = this_coef[:, 0]
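---

Usage note (not part of the patches above): the following is a minimal sketch of how the `l1_weights` keyword introduced by this series could be used, assuming a scikit-learn build with these patches applied. It mirrors the pattern in the added `test_enet_selective_penalty` test; the dataset, alpha grid, and weight values are illustrative only.

```python
# Illustrative sketch only: assumes scikit-learn built with the patches above,
# which add the `l1_weights` keyword (glmnet's penalty.factor) to
# ElasticNet / ElasticNetCV / enet_path. Data and values are made up.
import numpy as np
from sklearn.linear_model import ElasticNetCV

rng = np.random.RandomState(0)
n_samples, n_features, n_informative = 50, 200, 20

# Ill-posed regression problem: only the first `n_informative` features
# carry signal, and there are far fewer samples than features.
w = np.zeros(n_features)
w[:n_informative] = rng.randn(n_informative)
X = rng.randn(n_samples, n_features)
y = X.dot(w) + 0.1 * rng.randn(n_samples)

# Per-feature L1 weights: 1.0 keeps the usual penalty, small values relax it,
# and 0.0 removes the L1 penalty for that feature so it always stays in the
# model. Here we encode "prior knowledge" that the first block is informative.
l1_weights = np.ones(n_features)
l1_weights[:n_informative] = 0.001

regr = ElasticNetCV(alphas=[0.01, 0.05, 0.1, 0.5, 1, 1.5], cv=3,
                    l1_weights=l1_weights)
regr.fit(X, y)
print(regr.score(X, y))
```

Internally (per the `enet_path` change in patch 1), the weight vector is scaled by `alpha * l1_ratio * n_samples` and passed per-coordinate into the Cython solvers, so an all-ones vector reproduces the unweighted behaviour.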