[MRG] Add Penalty factors for each coefficient in enet (see #11566) #11671

Status: Open. Wants to merge 4 commits into main.

2 changes: 1 addition & 1 deletion sklearn/covariance/graph_lasso_.py
@@ -223,7 +223,7 @@ def graphical_lasso(emp_cov, alpha, cov_init=None, mode='cd', tol=1e-4,
coefs = -(precision_[indices != idx, idx]
/ (precision_[idx, idx] + 1000 * eps))
coefs, _, _, _ = cd_fast.enet_coordinate_descent_gram(
coefs, alpha, 0, sub_covariance,
coefs, alpha, 0, np.empty(0), sub_covariance,
row, row, max_iter, enet_tol,
check_random_state(None), False)
else:
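
A note on the calling convention visible in this hunk: callers that have no per-coefficient weights now pass an empty array, which the Cython solvers treat as "weighting disabled" (see use_l1_weight = l1_weights.shape[0] > 0 further down). A minimal sketch of the two equivalent "unweighted" call styles, with hypothetical variable names:

    import numpy as np

    n_features = 5

    # Sentinel: an empty array disables per-coefficient weighting, so the
    # solver thresholds every coordinate with the scalar alpha.
    no_weights = np.empty(0, dtype=np.float64)

    # At the solver level the weights are used directly as per-coordinate
    # thresholds, so an array filled with alpha is equivalent to no weights.
    alpha = 2.5
    uniform_weights = np.full(n_features, alpha, dtype=np.float64)
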
28 changes: 20 additions & 8 deletions sklearn/linear_model/cd_fast.pyx
@@ -143,6 +143,7 @@ cdef extern from "cblas.h":
@cython.cdivision(True)
def enet_coordinate_descent(np.ndarray[floating, ndim=1] w,
floating alpha, floating beta,
np.ndarray[floating, ndim=1] l1_weights,
np.ndarray[floating, ndim=2, mode='fortran'] X,
np.ndarray[floating, ndim=1, mode='c'] y,
int max_iter, floating tol,
@@ -201,6 +202,8 @@ def enet_coordinate_descent(np.ndarray[floating, ndim=1] w,
cdef unsigned int f_iter
cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
cdef UINT32_t* rand_r_state = &rand_r_state_seed
cdef floating w_alpha = alpha
cdef unsigned int use_l1_weight = l1_weights.shape[0] > 0

cdef floating *X_data = <floating*> X.data
cdef floating *y_data = <floating*> y.data
@@ -224,6 +227,9 @@ def enet_coordinate_descent(np.ndarray[floating, ndim=1] w,
w_max = 0.0
d_w_max = 0.0
for f_iter in range(n_features): # Loop over coordinates
if use_l1_weight:
w_alpha = l1_weights[f_iter]

if random:
ii = rand_int(n_features, rand_r_state)
else:
@@ -245,7 +251,7 @@ def enet_coordinate_descent(np.ndarray[floating, ndim=1] w,
if positive and tmp < 0:
w[ii] = 0.0
else:
w[ii] = (fsign(tmp) * fmax(fabs(tmp) - alpha, 0)
w[ii] = (fsign(tmp) * fmax(fabs(tmp) - w_alpha, 0)
/ (norm_cols_X[ii] + beta))

if w[ii] != 0.0:
@@ -284,8 +290,8 @@ def enet_coordinate_descent(np.ndarray[floating, ndim=1] w,
# w_norm2 = np.dot(w, w)
w_norm2 = dot(n_features, w_data, 1, w_data, 1)

if (dual_norm_XtA > alpha):
const = alpha / dual_norm_XtA
if (dual_norm_XtA > w_alpha):
const = w_alpha / dual_norm_XtA
A_norm2 = R_norm2 * (const ** 2)
gap = 0.5 * (R_norm2 + A_norm2)
else:
@@ -295,7 +301,7 @@ def enet_coordinate_descent(np.ndarray[floating, ndim=1] w,
l1_norm = asum(n_features, w_data, 1)

# np.dot(R.T, y)
gap += (alpha * l1_norm
gap += (w_alpha * l1_norm
- const * dot(n_samples, R_data, 1, y_data, n_tasks)
+ 0.5 * beta * (1 + const ** 2) * (w_norm2))
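
For readers following the Cython changes above, a minimal NumPy sketch of the per-coordinate update they implement: the soft-thresholding step is unchanged except that the L1 threshold may now differ per coordinate (w_alpha = l1_weights[f_iter] instead of the scalar alpha). Names here are illustrative, not the solver's:

    import numpy as np

    def soft_threshold_update(tmp, l1_thresh, col_norm, beta):
        # Mirrors: w[ii] = fsign(tmp) * fmax(fabs(tmp) - w_alpha, 0)
        #                  / (norm_cols_X[ii] + beta)
        return np.sign(tmp) * max(abs(tmp) - l1_thresh, 0.0) / (col_norm + beta)

    # Same partial correlation, two different per-coordinate thresholds:
    print(soft_threshold_update(tmp=2.0, l1_thresh=1.5, col_norm=1.0, beta=0.0))  # 0.5
    print(soft_threshold_update(tmp=2.0, l1_thresh=0.0, col_norm=1.0, beta=0.0))  # 2.0, unpenalized coordinate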

@@ -529,6 +535,7 @@ def sparse_enet_coordinate_descent(floating [:] w,
@cython.wraparound(False)
@cython.cdivision(True)
def enet_coordinate_descent_gram(floating[:] w, floating alpha, floating beta,
np.ndarray[floating, ndim=1] l1_weights,
np.ndarray[floating, ndim=2, mode='c'] Q,
np.ndarray[floating, ndim=1, mode='c'] q,
np.ndarray[floating, ndim=1] y,
@@ -581,6 +588,8 @@ def enet_coordinate_descent_gram(floating[:] w, floating alpha, floating beta,
cdef unsigned int f_iter
cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
cdef UINT32_t* rand_r_state = &rand_r_state_seed
cdef floating w_alpha = alpha
cdef unsigned int use_l1_weight = l1_weights.shape[0] > 0

cdef floating y_norm2 = np.dot(y, y)
cdef floating* w_ptr = <floating*>&w[0]
@@ -599,6 +608,9 @@ def enet_coordinate_descent_gram(floating[:] w, floating alpha, floating beta,
w_max = 0.0
d_w_max = 0.0
for f_iter in range(n_features): # Loop over coordinates
if use_l1_weight:
w_alpha = l1_weights[f_iter]

if random:
ii = rand_int(n_features, rand_r_state)
else:
@@ -619,7 +631,7 @@ def enet_coordinate_descent_gram(floating[:] w, floating alpha, floating beta,
if positive and tmp < 0:
w[ii] = 0.0
else:
w[ii] = fsign(tmp) * fmax(fabs(tmp) - alpha, 0) \
w[ii] = fsign(tmp) * fmax(fabs(tmp) - w_alpha, 0) \
/ (Q[ii, ii] + beta)

if w[ii] != 0.0:
@@ -659,16 +671,16 @@ def enet_coordinate_descent_gram(floating[:] w, floating alpha, floating beta,
# w_norm2 = np.dot(w, w)
w_norm2 = dot(n_features, &w[0], 1, &w[0], 1)

if (dual_norm_XtA > alpha):
const = alpha / dual_norm_XtA
if (dual_norm_XtA > w_alpha):
const = w_alpha / dual_norm_XtA
A_norm2 = R_norm2 * (const ** 2)
gap = 0.5 * (R_norm2 + A_norm2)
else:
const = 1.0
gap = R_norm2

# The call to dasum is equivalent to the L1 norm of w
gap += (alpha * asum(n_features, &w[0], 1) -
gap += (w_alpha * asum(n_features, &w[0], 1) -
const * y_norm2 + const * q_dot_w +
0.5 * beta * (1 + const ** 2) * w_norm2)

56 changes: 41 additions & 15 deletions sklearn/linear_model/coordinate_descent.py
@@ -140,7 +140,7 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,

Where::

||W||_21 = \\sum_i \\sqrt{\\sum_j w_{ij}^2}
||W||_21 = \sum_i \sqrt{\sum_j w_{ij}^2}

i.e. the sum of norm of each row.

@@ -268,21 +268,21 @@ def lasso_path(X, y, eps=1e-3, n_alphas=100, alphas=None,
def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
precompute='auto', Xy=None, copy_X=True, coef_init=None,
verbose=False, return_n_iter=False, positive=False,
check_input=True, **params):
check_input=True, l1_weights=None, **params):
"""Compute elastic net path with coordinate descent

The elastic net optimization function varies for mono and multi-outputs.

For mono-output tasks it is::

1 / (2 * n_samples) * ||y - Xw||^2_2
+ alpha * l1_ratio * ||w||_1
+ alpha * l1_ratio * l1_weights * ||w||_1
+ 0.5 * alpha * (1 - l1_ratio) * ||w||^2_2

Review comment (Member): I think the l1_weights should be inside the L1 norm: ||l1_weights * w||_1

For multi-output tasks it is::

(1 / (2 * n_samples)) * ||Y - XW||^Fro_2
+ alpha * l1_ratio * ||W||_21
+ alpha * l1_ratio * l1_weights * ||W||_21
+ 0.5 * alpha * (1 - l1_ratio) * ||W||_Fro^2

Where::
@@ -347,6 +347,12 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
Skip input validation checks, including the Gram matrix when provided
assuming they are handled by the caller when check_input=False.

l1_weights : array, shape (n_features, ), optional
Apply a separate penalty weight to each coefficient in the L1 term.
If not provided, all coefficients are weighted equally (the default).
For example, a feature whose weight is zero is not penalized at all,
so that feature will always remain in the model.

**params : kwargs
keyword arguments passed to the coordinate descent solver.

@@ -454,6 +460,10 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
for i, alpha in enumerate(alphas):
l1_reg = alpha * l1_ratio * n_samples
l2_reg = alpha * (1.0 - l1_ratio) * n_samples
if l1_weights is not None:
l1_weights_ = np.asfortranarray(l1_weights * l1_reg, dtype=X.dtype)
else:
l1_weights_ = np.asfortranarray([], dtype=X.dtype)
if not multi_output and sparse.isspmatrix(X):
model = cd_fast.sparse_enet_coordinate_descent(
coef_, l1_reg, l2_reg, X.data, X.indices,
@@ -469,12 +479,12 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
precompute = check_array(precompute, dtype=X.dtype.type,
order='C')
model = cd_fast.enet_coordinate_descent_gram(
coef_, l1_reg, l2_reg, precompute, Xy, y, max_iter,
tol, rng, random, positive)
coef_, l1_reg, l2_reg, l1_weights_, precompute, Xy, y,
max_iter, tol, rng, random, positive)
elif precompute is False:
model = cd_fast.enet_coordinate_descent(
coef_, l1_reg, l2_reg, X, y, max_iter, tol, rng, random,
positive)
coef_, l1_reg, l2_reg, l1_weights_, X, y, max_iter, tol, rng,
random, positive)
else:
raise ValueError("Precompute should be one of True, False, "
"'auto' or array-like. Got %r" % precompute)
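
A small worked sketch of how the per-feature penalties handed to the Cython solvers are assembled in the loop above (hypothetical numbers; the scaling is the l1_weights * l1_reg line from the hunk above):

    import numpy as np

    n_samples, alpha, l1_ratio = 50, 0.1, 0.5
    l1_reg = alpha * l1_ratio * n_samples          # scalar L1 penalty: 2.5
    l2_reg = alpha * (1.0 - l1_ratio) * n_samples  # scalar L2 penalty: 2.5

    l1_weights = np.array([0.0, 0.001, 1.0, 1.0])  # user-supplied per-feature weights
    # Per-feature L1 penalties actually passed to the solver; when l1_weights
    # is None, an empty array is passed instead and the solver ignores it.
    l1_weights_ = np.asfortranarray(l1_weights * l1_reg, dtype=np.float64)
    print(l1_weights_)  # [0., 0.0025, 2.5, 2.5]
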
@@ -599,6 +609,12 @@ class ElasticNet(LinearModel, RegressorMixin):
(setting to 'random') often leads to significantly faster convergence
especially when tol is higher than 1e-4.

l1_weights : array, shape (n_features, ), optional
Apply a separate penalty weight to each coefficient in the L1 term.
If not provided, all coefficients are weighted equally (the default).
For example, a feature whose weight is zero is not penalized at all,
so that feature will always remain in the model.

Attributes
----------
coef_ : array, shape (n_features,) | (n_targets, n_features)
@@ -624,8 +640,9 @@ class ElasticNet(LinearModel, RegressorMixin):
>>> regr = ElasticNet(random_state=0)
>>> regr.fit(X, y)
ElasticNet(alpha=1.0, copy_X=True, fit_intercept=True, l1_ratio=0.5,
max_iter=1000, normalize=False, positive=False, precompute=False,
random_state=0, selection='cyclic', tol=0.0001, warm_start=False)
l1_weights=None, max_iter=1000, normalize=False, positive=False,
precompute=False, random_state=0, selection='cyclic', tol=0.0001,
warm_start=False)
>>> print(regr.coef_) # doctest: +ELLIPSIS
[18.83816048 64.55968825]
>>> print(regr.intercept_) # doctest: +ELLIPSIS
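
A usage sketch of the estimator-level parameter this PR proposes (illustration only; l1_weights requires this branch and is not part of released scikit-learn):

    import numpy as np
    from sklearn.linear_model import ElasticNet  # needs this PR's branch

    rng = np.random.RandomState(0)
    X = rng.randn(30, 4)
    y = X @ np.array([5.0, 0.0, 0.0, 3.0])

    # Remove the L1 penalty on features believed to be informative (0 and 3),
    # keep the full penalty on the rest.
    l1_weights = np.array([0.0, 1.0, 1.0, 0.0])
    regr = ElasticNet(alpha=0.5, l1_ratio=0.5, l1_weights=l1_weights,
                      random_state=0)
    regr.fit(X, y)
    print(regr.coef_)  # features 0 and 3 are not shrunk to zero by the L1 term
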
@@ -652,7 +669,7 @@ class ElasticNet(LinearModel, RegressorMixin):
def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True,
normalize=False, precompute=False, max_iter=1000,
copy_X=True, tol=1e-4, warm_start=False, positive=False,
random_state=None, selection='cyclic'):
random_state=None, selection='cyclic', l1_weights=None):
self.alpha = alpha
self.l1_ratio = l1_ratio
self.fit_intercept = fit_intercept
@@ -665,6 +682,7 @@ def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True,
self.positive = positive
self.random_state = random_state
self.selection = selection
self.l1_weights = l1_weights

def fit(self, X, y, check_input=True):
"""Fit model with coordinate descent.
@@ -755,6 +773,7 @@ def fit(self, X, y, check_input=True):
X_offset=X_offset, X_scale=X_scale, return_n_iter=True,
coef_init=coef_[k], max_iter=self.max_iter,
random_state=self.random_state,
l1_weights=self.l1_weights,
selection=self.selection,
check_input=False)
coef_[k] = this_coef[:, 0]
@@ -1499,6 +1518,12 @@ class ElasticNetCV(LinearModelCV, RegressorMixin):
(setting to 'random') often leads to significantly faster convergence
especially when tol is higher than 1e-4.

l1_weights : array, shape (n_features, ), optional
Apply a separate penalty weight to each coefficient in the L1 term.
If not provided, all coefficients are weighted equally (the default).
For example, a feature whose weight is zero is not penalized at all,
so that feature will always remain in the model.

Attributes
----------
alpha_ : float
@@ -1534,9 +1559,9 @@ class ElasticNetCV(LinearModelCV, RegressorMixin):
>>> regr = ElasticNetCV(cv=5, random_state=0)
>>> regr.fit(X, y)
ElasticNetCV(alphas=None, copy_X=True, cv=5, eps=0.001, fit_intercept=True,
l1_ratio=0.5, max_iter=1000, n_alphas=100, n_jobs=1,
normalize=False, positive=False, precompute='auto', random_state=0,
selection='cyclic', tol=0.0001, verbose=0)
l1_ratio=0.5, l1_weights=None, max_iter=1000, n_alphas=100,
n_jobs=1, normalize=False, positive=False, precompute='auto',
random_state=0, selection='cyclic', tol=0.0001, verbose=0)
>>> print(regr.alpha_) # doctest: +ELLIPSIS
0.1994727942696716
>>> print(regr.intercept_) # doctest: +ELLIPSIS
@@ -1583,7 +1608,7 @@ def __init__(self, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
fit_intercept=True, normalize=False, precompute='auto',
max_iter=1000, tol=1e-4, cv='warn', copy_X=True,
verbose=0, n_jobs=1, positive=False, random_state=None,
selection='cyclic'):
selection='cyclic', l1_weights=None):
self.l1_ratio = l1_ratio
self.eps = eps
self.n_alphas = n_alphas
@@ -1600,6 +1625,7 @@ def __init__(self, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
self.positive = positive
self.random_state = random_state
self.selection = selection
self.l1_weights = l1_weights


###############################################################################
38 changes: 38 additions & 0 deletions sklearn/linear_model/tests/test_coordinate_descent.py
@@ -291,6 +291,44 @@ def test_enet_path():
assert_almost_equal(clf1.alpha_, clf2.alpha_)


@pytest.mark.filterwarnings('ignore: You should specify a value') # 0.22
def test_enet_selective_penalty():
n_features = 200
n_informative_features = 20
# A dataset with a small number of samples and a large number of features,
# of which only n_informative_features are informative
X, y, X_test, y_test = build_dataset(
n_samples=50, n_features=n_features,
n_informative_features=n_informative_features)

# Default weight is 1 for all features, keep l1 penalty
l1_weights = np.ones(n_features)

# Add some prior knowledge: we assume we know which features are
# important, so we relax the l1 penalty on those features.
# Use any small positive number, or zero if you are 100% sure of the
# prior knowledge.
l1_weights[:n_informative_features-1] *= 0.001

# Run enet with prior knowledge (l1_weights)
clf_with_prior = ElasticNetCV(alphas=[0.01, 0.05, 0.1, 0.5, 1, 1.5],
eps=2e-3, cv=3, l1_weights=l1_weights)

ignore_warnings(clf_with_prior.fit)(X, y)

# This is a model without using any prior knowledge
clf_base = ElasticNetCV(alphas=[0.01, 0.05, 0.1, 0.5, 1, 1.5],
eps=2e-3, cv=3)

ignore_warnings(clf_base.fit)(X, y)

# The score of the model with prior knowledge should be higher
# than that of the model without prior knowledge on a hard data set
# (far fewer samples than features)
assert_greater(clf_with_prior.score(X_test, y_test),
clf_base.score(X_test, y_test))

Review comment: Please use plain assert
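
A minimal sketch of the same comparison written with a plain assert, as the review comment above requests (not part of this diff):

    # Plain-assert form of the final check in test_enet_selective_penalty:
    assert clf_with_prior.score(X_test, y_test) > clf_base.score(X_test, y_test)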


@pytest.mark.filterwarnings('ignore: You should specify a value') # 0.22
def test_path_parameters():
X, y, _, _ = build_dataset()