
ENH add sample_weight and fit_intercept to Cython enet_coordinate_descent #31375

Open · wants to merge 2 commits into base: main
145 changes: 129 additions & 16 deletions sklearn/linear_model/_cd_fast.pyx
@@ -89,9 +89,11 @@ def enet_coordinate_descent(
floating beta,
const floating[::1, :] X,
const floating[::1] y,
-    unsigned int max_iter,
-    floating tol,
-    object rng,
+    const floating[::1] sample_weight=None,
+    const floating[::1] X_mean=None,
+    unsigned int max_iter=1000,
+    floating tol=1e-4,
+    object rng=np.random.RandomState(0),
bint random=0,
bint positive=0
):
@@ -102,6 +104,11 @@

(1/2) * norm(y - X w, 2)^2 + alpha norm(w, 1) + (beta/2) norm(w, 2)^2

With sample weights, y and X are conceptually scaled by sqrt(sample_weight), but the
implementation avoids the square root and extra memory allocations.

For X_mean, see sparse_enet_coordinate_descent.

Returns
-------
w : ndarray of shape (n_features,)
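As a reading aid (not part of the patch), here is a NumPy sketch of the weighted objective being minimized; `sw` stands for `sample_weight`, and the helper name `enet_objective` is ours:

```python
import numpy as np

def enet_objective(w, alpha, beta, X, y, sw=None):
    """Weighted elastic-net objective; sw=None means unit weights."""
    if sw is None:
        sw = np.ones_like(y)
    R = y - X @ w  # residuals
    return (
        0.5 * np.sum(sw * R**2)        # (1/2) * ||sqrt(sw) * (y - X w)||_2^2
        + alpha * np.sum(np.abs(w))    # alpha * ||w||_1
        + 0.5 * beta * np.dot(w, w)    # (beta/2) * ||w||_2^2
    )
```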
@@ -124,11 +131,13 @@
cdef unsigned int n_features = X.shape[1]

# compute norms of the columns of X
-cdef floating[::1] norm_cols_X = np.square(X).sum(axis=0)
+cdef floating[::1] norm_cols_X

# initial value of the residuals
cdef floating[::1] R = np.empty(n_samples, dtype=dtype)
cdef floating[::1] XtA = np.empty(n_features, dtype=dtype)
# residuals with sample_weight
cdef floating[::1] R_sw

cdef floating tmp
cdef floating w_ii
@@ -138,6 +147,7 @@
cdef floating gap = tol + 1.0
cdef floating d_w_tol = tol
cdef floating dual_norm_XtA
cdef floating R_sum = 0.0 # always takes sample_weights into account
cdef floating R_norm2
cdef floating w_norm2
cdef floating l1_norm
@@ -146,21 +156,75 @@
cdef unsigned int ii
cdef unsigned int n_iter = 0
cdef unsigned int f_iter
cdef unsigned int jj
cdef uint32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
cdef uint32_t* rand_r_state = &rand_r_state_seed
cdef bint center = False
cdef bint no_sample_weights = sample_weight is None

if alpha == 0 and beta == 0:
warnings.warn("Coordinate descent with no regularization may lead to "
"unexpected results and is discouraged.")

if X_mean is None:
if no_sample_weights:
norm_cols_X = np.square(X).sum(axis=0)
else:
norm_cols_X = (sample_weight[:, None] * np.square(X)).sum(axis=0)
else:
# Computation delayed, inside nogil block.
norm_cols_X = np.zeros(n_features, dtype=dtype)

if not no_sample_weights:
R_sw = np.empty_like(R)

with nogil:
# center = (X_mean != 0).any()
if X_mean is not None:
for ii in range(n_features):
if X_mean[ii]:
center = True
break

if center:
if no_sample_weights:
for ii in range(n_features):
for jj in range(n_samples):
norm_cols_X[ii] += (X[jj, ii] - X_mean[ii]) ** 2
else:
for ii in range(n_features):
for jj in range(n_samples):
norm_cols_X[ii] += (
sample_weight[jj] * (X[jj, ii] - X_mean[ii]) ** 2
)

# R = y - np.dot(X, w)
_copy(n_samples, &y[0], 1, &R[0], 1)
_gemv(ColMajor, NoTrans, n_samples, n_features, -1.0, &X[0, 0],
n_samples, &w[0], 1, 1.0, &R[0], 1)
if center:
# R += np.dot(X_mean, w)
# R_sum = np.sum(R) or np.sum(sample_weight * R)
tmp = _dot(n_features, &X_mean[0], 1, &w[0], 1)
R_sum = 0.0
if no_sample_weights:
for jj in range(n_samples):
R[jj] += tmp
R_sum += R[jj]
else:
for jj in range(n_samples):
R[jj] += tmp
R_sum += sample_weight[jj] * R[jj]
# Note: It turns out that R_sum does not need any update from here on.

# tol *= np.dot(y, y)
-tol *= _dot(n_samples, &y[0], 1, &y[0], 1)
+if no_sample_weights:
+    tol *= _dot(n_samples, &y[0], 1, &y[0], 1)
+else:
+    tmp = 0
+    for jj in range(n_samples):
+        tmp += sample_weight[jj] * y[jj]**2
+    tol *= tmp

for n_iter in range(max_iter):
w_max = 0.0
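Two invariants behind the setup above, checked with a dense NumPy sketch on toy data (`sw` stands for `sample_weight`): the delayed norm computation equals the column norms of the weighted, centered design, and because X_mean is the weighted column mean, `R_sum = sum(sw * R)` is left unchanged by any residual update along a centered column, which is why it never needs updating inside the descent loop:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((10, 5))
sw = 0.5 + rng.random(10)
X_mean = np.average(X, weights=sw, axis=0)

# What the nogil loops compute: sum_j sw[j] * (X[j, i] - X_mean[i])**2 per column.
norm_cols_X = (sw[:, None] * (X - X_mean) ** 2).sum(axis=0)

# X_mean is the weighted mean, so every centered column has zero weighted sum;
# hence R += t * (X[:, i] - X_mean[i]) leaves R_sum = np.sum(sw * R) unchanged.
print(np.allclose(sw @ (X - X_mean), 0.0))  # True
```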
@@ -179,9 +243,27 @@
if w_ii != 0.0:
# R += w_ii * X[:,ii]
_axpy(n_samples, w_ii, &X[0, ii], 1, &R[0], 1)
if center and X_mean[ii] != 0.0:
# R -= w_ii * X_mean[ii]
# Note: No need to update R_sum because the update terms cancel
# each other: -w_ii np.sum(X[:,ii] - X_mean[ii]) = 0.
tmp = w_ii * X_mean[ii]
if no_sample_weights:
for jj in range(n_samples):
R[jj] -= tmp
else:
for jj in range(n_samples):
R[jj] -= tmp

# tmp = (X[:,ii]*R).sum()
-tmp = _dot(n_samples, &X[0, ii], 1, &R[0], 1)
+if no_sample_weights:
+    tmp = _dot(n_samples, &X[0, ii], 1, &R[0], 1)
+else:
+    tmp = 0.0
+    for jj in range(n_samples):
+        tmp += X[jj, ii] * sample_weight[jj] * R[jj]
+if center and X_mean[ii] != 0.0:
+    tmp -= R_sum * X_mean[ii]

if positive and tmp < 0:
w[ii] = 0.0
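The weighted branch plus the `R_sum * X_mean[ii]` correction evaluates the gradient coordinate on the implicitly centered column without ever materializing it. A NumPy check of that identity on toy data (`sw` stands for `sample_weight`):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((10, 5))
sw = 0.5 + rng.random(10)
R = rng.random(10)
X_mean = np.average(X, weights=sw, axis=0)
R_sum = np.sum(sw * R)
ii = 2

explicit = (X[:, ii] - X_mean[ii]) @ (sw * R)        # centered column, materialized
implicit = X[:, ii] @ (sw * R) - R_sum * X_mean[ii]  # what the loop plus correction give
print(np.allclose(explicit, implicit))  # True
```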
@@ -192,6 +274,16 @@
if w[ii] != 0.0:
# R -= w[ii] * X[:,ii] # Update residual
_axpy(n_samples, -w[ii], &X[0, ii], 1, &R[0], 1)
if center and X_mean[ii] != 0.0:
# R += w[ii] * X_mean[ii]
# Note: No need to update R_sum, see note above.
tmp = X_mean[ii] * w[ii]
if no_sample_weights:
for jj in range(n_samples):
R[jj] += tmp
else:
for jj in range(n_samples):
R[jj] += tmp

# update the maximum absolute coefficient update
d_w_ii = fabs(w[ii] - w_ii)
@@ -210,18 +302,32 @@

# XtA = np.dot(X.T, R) - beta * w
_copy(n_features, &w[0], 1, &XtA[0], 1)
-_gemv(ColMajor, Trans,
-      n_samples, n_features, 1.0, &X[0, 0], n_samples,
-      &R[0], 1,
-      -beta, &XtA[0], 1)
+if no_sample_weights:
+    _gemv(
+        ColMajor, Trans, n_samples, n_features, 1.0, &X[0, 0],
+        n_samples, &R[0], 1, -beta, &XtA[0], 1
+    )
+else:
+    for jj in range(n_samples):
+        R_sw[jj] = sample_weight[jj] * R[jj]
+    _gemv(
+        ColMajor, Trans, n_samples, n_features, 1.0, &X[0, 0],
+        n_samples, &R_sw[0], 1, -beta, &XtA[0], 1
+    )
+if center:
+    # XtA -= X_mean * R_sum
+    _axpy(n_features, -R_sum, &X_mean[0], 1, &XtA[0], 1)

if positive:
dual_norm_XtA = max(n_features, &XtA[0])
else:
dual_norm_XtA = abs_max(n_features, &XtA[0])

# R_norm2 = np.dot(R, R)
-R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1)
+if no_sample_weights:
+    R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1)
+else:
+    R_norm2 = _dot(n_samples, &R_sw[0], 1, &R[0], 1)

# w_norm2 = np.dot(w, w)
w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)
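In matrix form, the block above computes `XtA = (X - X_mean).T @ (sw * R) - beta * w`: the weights are scattered into `R_sw` once so the plain BLAS gemv can be reused, and the same `R_sw` then yields the weighted `R_norm2`. A NumPy sketch of both identities on toy data:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((10, 5))
sw = 0.5 + rng.random(10)
R = rng.random(10)
w = rng.random(5)
beta = 0.1
X_mean = np.average(X, weights=sw, axis=0)

R_sw = sw * R                  # weighted residuals, computed once
XtA = X.T @ R_sw - beta * w    # the gemv on R_sw
XtA -= np.sum(R_sw) * X_mean   # centering correction: -R_sum * X_mean
R_norm2 = R_sw @ R             # weighted squared residual norm

print(np.allclose(XtA, (X - X_mean).T @ R_sw - beta * w))  # True
print(np.allclose(R_norm2, np.sum(sw * R**2)))             # True
```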
@@ -236,10 +342,12 @@

l1_norm = _asum(n_features, &w[0], 1)

-# np.dot(R.T, y)
-gap += (alpha * l1_norm
-        - const * _dot(n_samples, &R[0], 1, &y[0], 1)
-        + 0.5 * beta * (1 + const ** 2) * (w_norm2))
+gap += alpha * l1_norm + 0.5 * beta * (1 + const ** 2) * (w_norm2)
+# gap -= const * np.dot(R.T, y)
+if no_sample_weights:
+    gap -= const * _dot(n_samples, &R[0], 1, &y[0], 1)
+else:
+    gap -= const * _dot(n_samples, &R_sw[0], 1, &y[0], 1)

if gap < tol:
# return if we reached desired tolerance
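Putting the pieces together, a hedged NumPy sketch of how the weighted duality gap is assembled; the `const = alpha / dual_norm_XtA` clipping comes from unchanged code elided above, centering is omitted for brevity, and the sketch mirrors the standard scikit-learn gap formula rather than quoting the diff verbatim:

```python
import numpy as np

def enet_duality_gap(w, alpha, beta, X, y, sw):
    """Sketch of the weighted dual gap (no centering); assumes sklearn's formula."""
    R = y - X @ w
    R_sw = sw * R
    XtA = X.T @ R_sw - beta * w
    dual_norm_XtA = np.max(np.abs(XtA))
    R_norm2 = R_sw @ R
    w_norm2 = w @ w
    if dual_norm_XtA > alpha:
        const = alpha / dual_norm_XtA
        gap = 0.5 * (1 + const**2) * R_norm2   # 0.5 * (R_norm2 + A_norm2)
    else:
        const = 1.0
        gap = R_norm2
    gap += (
        alpha * np.sum(np.abs(w))              # alpha * l1_norm
        - const * (R_sw @ y)                   # the weighted np.dot(R.T, y) term
        + 0.5 * beta * (1 + const**2) * w_norm2
    )
    return gap
```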
@@ -294,7 +402,9 @@ def sparse_enet_coordinate_descent(
1/2 * sum(sw * (y - Z w)^2, axis=0) + alpha * norm(w, 1)
+ (beta/2) * norm(w, 2)^2

-and X_mean is the weighted average of X (per column).
+and X_mean is the weighted average of X (per column). If the weighted mean of y
+is not zero, the passed y must already be centered, i.e.,
+`y - np.average(y, weights=sample_weight)`.
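For example, a caller would center y like this (minimal sketch with toy data):

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.random(10)
sample_weight = 0.5 + rng.random(10)

# Center y with the same weights before calling the sparse solver:
y_centered = y - np.average(y, weights=sample_weight)
print(np.isclose(np.average(y_centered, weights=sample_weight), 0.0))  # True
```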

Returns
-------
@@ -366,6 +476,7 @@ def sparse_enet_coordinate_descent(
cdef bint no_sample_weights = sample_weight is None
cdef int kk

# R = y
if no_sample_weights:
yw = y
R = y.copy()
@@ -380,6 +491,8 @@
center = True
break

# R -= np.dot(X, w)
# norm_cols_X = np.square(X).sum(axis=0)
for ii in range(n_features):
X_mean_ii = X_mean[ii]
endptr = X_indptr[ii + 1]
13 changes: 12 additions & 1 deletion sklearn/linear_model/_coordinate_descent.py
@@ -693,7 +693,18 @@ def enet_path(
)
elif precompute is False:
model = cd_fast.enet_coordinate_descent(
-    coef_, l1_reg, l2_reg, X, y, max_iter, tol, rng, random, positive
+    coef_,
+    l1_reg,
+    l2_reg,
+    X,
+    y,
+    None,
+    None,
+    max_iter,
+    tol,
+    rng,
+    random,
+    positive,
)
else:
raise ValueError(
52 changes: 52 additions & 0 deletions sklearn/linear_model/tests/test_coordinate_descent.py
Expand Up @@ -32,6 +32,11 @@
lars_path,
lasso_path,
)

# mypy error: Module 'sklearn.linear_model' has no attribute '_cd_fast'
from sklearn.linear_model._cd_fast import (
enet_coordinate_descent, # type: ignore[attr-defined]
)
from sklearn.linear_model._coordinate_descent import _set_order
from sklearn.model_selection import (
BaseCrossValidator,
@@ -1214,6 +1219,53 @@ def test_multi_task_lasso_cv_dtype():
assert_array_almost_equal(est.coef_, [[1, 0, 0]] * 2, decimal=3)


@pytest.mark.parametrize("with_sw", [False, True])
@pytest.mark.parametrize("fit_intercept", [False, True])
def test_enet_coordinate_descent_with_sample_weight(
with_sw, fit_intercept, global_random_seed
):
"""Test that enet_coordinate_descent with sample_weights."""
rng = np.random.RandomState(global_random_seed)
n_samples, n_features = 10, 5
X = np.asfortranarray(rng.rand(n_samples, n_features))
y = rng.rand(n_samples)
sw = 0.5 + rng.rand(n_samples) if with_sw else None
reg = ElasticNet(alpha=1e-2, fit_intercept=fit_intercept, random_state=42, tol=1e-6)
reg.fit(X, y, sample_weight=sw)
# The alpha should be small enough s.t. some coefficients are non-zero.
assert np.sum(np.abs(reg.coef_)) > 0

y_centered = y - np.average(y, weights=sw)
X_mean = np.average(X, weights=sw, axis=0)
coef_ = np.zeros_like(reg.coef_)
l1_reg = reg.alpha * reg.l1_ratio
l2_reg = reg.alpha * (1 - reg.l1_ratio)
l1_reg *= np.sum(sw) if with_sw else n_samples
l2_reg *= np.sum(sw) if with_sw else n_samples
rng = np.random.RandomState(42)
random = reg.selection == "random"
coef_, dual_gap_, eps_, n_iter_ = enet_coordinate_descent(
coef_,
l1_reg,
l2_reg,
X,
y_centered if fit_intercept else y,
sample_weight=sw,
X_mean=X_mean if fit_intercept else None,
max_iter=reg.max_iter,
tol=reg.tol,
rng=rng,
random=random,
positive=reg.positive,
)
if with_sw:
assert dual_gap_ / np.sum(sw) == pytest.approx(reg.dual_gap_)
else:
assert dual_gap_ / n_samples == pytest.approx(reg.dual_gap_)
assert_allclose(coef_, reg.coef_, rtol=1e-4, atol=1e-5)
assert n_iter_ == reg.n_iter_


@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize("alpha", [0.01])
@pytest.mark.parametrize("precompute", [False, True])