
Commit bf0ece8

ENH add sample_weight to sparse coordinate descent (#22808)

1 parent 4408dfb commit bf0ece8

File tree

8 files changed, +217 -139 lines

doc/whats_new/v1.1.rst

Lines changed: 5 additions & 1 deletion
@@ -623,14 +623,18 @@ Changelog
   :class:`linear_model.ARDRegression` now preserve float32 dtype. :pr:`9087` by
   :user:`Arthur Imbert <Henley13>` and :pr:`22525` by :user:`Meekail Zain <micky774>`.
 
+- |Feature| :class:`ElasticNet`, :class:`ElasticNetCV`, :class:`Lasso` and
+  :class:`LassoCV` support `sample_weight` for sparse input `X`.
+  :pr:`22808` by :user:`Christian Lorentzen <lorentzenchr>`.
+
 - |Fix| The `coef_` and `intercept_` attributes of :class:`LinearRegression` are now
   correctly computed in the presence of sample weights when the input is sparse.
   :pr:`22891` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
 
 - |Fix| The `coef_` and `intercept_` attributes of :class:`Ridge` with
   `solver="sparse_cg"` and `solver="lbfgs"` are now correctly computed in the presence
   of sample weights when the input is sparse.
-   :pr:`22899` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
+  :pr:`22899` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
 
 :mod:`sklearn.manifold`
 .......................
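
For context, a minimal usage sketch of the feature announced above; illustrative only (made-up data and hyperparameters), not part of the commit:

    import numpy as np
    from scipy import sparse
    from sklearn.linear_model import ElasticNet

    rng = np.random.RandomState(0)
    X = sparse.random(50, 10, density=0.3, format="csr", random_state=rng)
    y = rng.rand(50)
    sw = rng.uniform(0.5, 2.0, size=50)

    # passing sample_weight together with sparse X is what this PR enables
    reg = ElasticNet(alpha=0.1).fit(X, y, sample_weight=sw)
    print(reg.coef_)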

sklearn/linear_model/_base.py

Lines changed: 19 additions & 9 deletions
@@ -217,7 +217,6 @@ def _preprocess_data(
     normalize=False,
     copy=True,
     sample_weight=None,
-    return_mean=False,
     check_input=True,
 ):
     """Center and scale data.
@@ -231,7 +230,7 @@ def _preprocess_data(
 
     X_scale is the L2 norm of X - X_offset. If sample_weight is not None,
     then the weighted mean of X and y is zero, and not the mean itself. If
-    return_mean=True, the mean, eventually weighted, is returned, independently
+    fit_intercept=True, the mean, eventually weighted, is returned, independently
     of whether X was centered (option used for optimization with sparse data in
     coordinate_descend).
 
@@ -271,8 +270,6 @@ def _preprocess_data(
     if fit_intercept:
         if sp.issparse(X):
             X_offset, X_var = mean_variance_axis(X, axis=0, weights=sample_weight)
-            if not return_mean:
-                X_offset[:] = X.dtype.type(0)
         else:
             if normalize:
                 X_offset, X_var, _ = _incremental_mean_and_var(
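
With `return_mean` gone, the sparse branch always keeps the weighted column means computed by `mean_variance_axis`. A dense NumPy sketch of the same quantities (illustrative; `Xd` is a dense stand-in for sparse `X`):

    import numpy as np

    def weighted_mean_var(Xd, sample_weight):
        # weighted column means; these become X_offset
        X_offset = np.average(Xd, axis=0, weights=sample_weight)
        # weighted column variances around the weighted means
        X_var = np.average((Xd - X_offset) ** 2, axis=0, weights=sample_weight)
        return X_offset, X_var
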
@@ -328,7 +325,18 @@ def _preprocess_data(
 def _rescale_data(X, y, sample_weight):
     """Rescale data sample-wise by square root of sample_weight.
 
-    For many linear models, this enables easy support for sample_weight.
+    For many linear models, this enables easy support for sample_weight because
+
+        (y - X w)' S (y - X w)
+
+    with S = diag(sample_weight) becomes
+
+        ||y_rescaled - X_rescaled w||_2^2
+
+    when setting
+
+        y_rescaled = sqrt(S) y
+        X_rescaled = sqrt(S) X
 
     Returns
     -------
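
A quick numeric check of the identity added to this docstring (a sketch with made-up data):

    import numpy as np

    rng = np.random.RandomState(0)
    X, y, w = rng.rand(20, 5), rng.rand(20), rng.rand(5)
    sw = rng.uniform(0.1, 2.0, size=20)        # S = diag(sw)

    lhs = (y - X @ w) @ (sw * (y - X @ w))     # (y - Xw)' S (y - Xw)
    rsw = np.sqrt(sw)
    rhs = np.sum((rsw * y - (rsw[:, None] * X) @ w) ** 2)
    assert np.isclose(lhs, rhs)
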
@@ -687,7 +695,6 @@ def fit(self, X, y, sample_weight=None):
             normalize=_normalize,
             copy=self.copy_X,
             sample_weight=sample_weight,
-            return_mean=True,
         )
 
         # Sample weight can be implemented via a simple rescaling.
@@ -824,8 +831,8 @@ def _pre_fit(
             fit_intercept=fit_intercept,
             normalize=normalize,
             copy=False,
-            return_mean=True,
             check_input=check_input,
+            sample_weight=sample_weight,
         )
     else:
         # copy was done in fit if necessary
@@ -838,8 +845,11 @@ def _pre_fit(
             check_input=check_input,
             sample_weight=sample_weight,
         )
-        if sample_weight is not None:
-            X, y, _ = _rescale_data(X, y, sample_weight=sample_weight)
+        # Rescale only in dense case. Sparse cd solver directly deals with
+        # sample_weight.
+        if sample_weight is not None:
+            # This triggers copies anyway.
+            X, y, _ = _rescale_data(X, y, sample_weight=sample_weight)
 
     # FIXME: 'normalize' to be removed in 1.2
     if hasattr(precompute, "__array__"):
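
The new comment encodes the split introduced here: dense inputs fold the weights in once via sqrt(sample_weight), while sparse inputs hand sample_weight through to the coordinate descent solver untouched. A simplified sketch of that dispatch (illustrative names, not the actual `_pre_fit`):

    import numpy as np
    import scipy.sparse as sp

    def apply_sample_weight(X, y, sample_weight):
        if sample_weight is None or sp.issparse(X):
            # sparse: the cd solver applies the weights internally
            return X, y
        # dense: rescale by sqrt(sample_weight); triggers copies
        rsw = np.sqrt(sample_weight)
        return X * rsw[:, np.newaxis], y * rsw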

sklearn/linear_model/_cd_fast.pyx

Lines changed: 109 additions & 45 deletions
@@ -260,22 +260,45 @@ def enet_coordinate_descent(floating[::1] w,
     return w, gap, tol, n_iter + 1
 
 
-def sparse_enet_coordinate_descent(floating [::1] w,
-                                   floating alpha, floating beta,
-                                   np.ndarray[floating, ndim=1, mode='c'] X_data,
-                                   np.ndarray[int, ndim=1, mode='c'] X_indices,
-                                   np.ndarray[int, ndim=1, mode='c'] X_indptr,
-                                   np.ndarray[floating, ndim=1] y,
-                                   floating[:] X_mean, int max_iter,
-                                   floating tol, object rng, bint random=0,
-                                   bint positive=0):
+def sparse_enet_coordinate_descent(
+    floating [::1] w,
+    floating alpha,
+    floating beta,
+    np.ndarray[floating, ndim=1, mode='c'] X_data,
+    np.ndarray[int, ndim=1, mode='c'] X_indices,
+    np.ndarray[int, ndim=1, mode='c'] X_indptr,
+    floating[::1] y,
+    floating[::1] sample_weight,
+    floating[::1] X_mean,
+    int max_iter,
+    floating tol,
+    object rng,
+    bint random=0,
+    bint positive=0,
+):
     """Cython version of the coordinate descent algorithm for Elastic-Net
 
     We minimize:
 
-        (1/2) * norm(y - X w, 2)^2 + alpha norm(w, 1) + (beta/2) * norm(w, 2)^2
+        1/2 * norm(y - Z w, 2)^2 + alpha * norm(w, 1) + (beta/2) * norm(w, 2)^2
+
+        where Z = X - X_mean.
+        With sample weights sw, this becomes
 
+        1/2 * sum(sw * (y - Z w)^2, axis=0) + alpha * norm(w, 1)
+        + (beta/2) * norm(w, 2)^2
+
+        and X_mean is the weighted average of X (per column).
     """
+    # Notes for sample_weight:
+    # For dense X, one centers X and y and then rescales them by sqrt(sample_weight).
+    # Here, for sparse X, we get the sample_weight averaged center X_mean. We take care
+    # that every calculation results as if we had rescaled y and X (and therefore also
+    # X_mean) by sqrt(sample_weight) without actually calculating the square root.
+    # We work with:
+    #     yw = sample_weight * y
+    #     R = sample_weight * residual
+    #     norm_cols_X = np.sum(sample_weight * (X - X_mean)**2, axis=0)
 
     # get the data information into easy vars
     cdef unsigned int n_samples = y.shape[0]
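
A NumPy sketch of the weighted objective stated in the docstring above (dense `X` purely for illustration):

    import numpy as np

    def weighted_enet_objective(w, X, y, X_mean, sw, alpha, beta):
        Z = X - X_mean                    # centered design matrix
        resid = y - Z @ w
        return (0.5 * np.sum(sw * resid ** 2)
                + alpha * np.sum(np.abs(w))
                + 0.5 * beta * np.dot(w, w))
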
@@ -289,18 +312,17 @@ def sparse_enet_coordinate_descent(floating [::1] w,
     cdef unsigned int endptr
 
     # initial value of the residuals
-    cdef floating[:] R = y.copy()
-
-    cdef floating[:] X_T_R
-    cdef floating[:] XtA
+    # R = y - Zw, weighted version R = sample_weight * (y - Zw)
+    cdef floating[::1] R
+    cdef floating[::1] XtA
+    cdef floating[::1] yw
 
     if floating is float:
         dtype = np.float32
     else:
         dtype = np.float64
 
     norm_cols_X = np.zeros(n_features, dtype=dtype)
-    X_T_R = np.zeros(n_features, dtype=dtype)
     XtA = np.zeros(n_features, dtype=dtype)
 
     cdef floating tmp
@@ -324,6 +346,14 @@ def sparse_enet_coordinate_descent(floating [::1] w,
     cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
     cdef UINT32_t* rand_r_state = &rand_r_state_seed
     cdef bint center = False
+    cdef bint no_sample_weights = sample_weight is None
+
+    if no_sample_weights:
+        yw = y
+        R = y.copy()
+    else:
+        yw = np.multiply(sample_weight, y)
+        R = yw.copy()
 
     with nogil:
         # center = (X_mean != 0).any()
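
In NumPy terms, the initialization above amounts to the following sketch (w starts at zero; the first feature loop then removes the X w terms column by column):

    import numpy as np

    def init_weighted_residual(y, sample_weight):
        if sample_weight is None:
            return y, y.copy()            # yw, R
        yw = sample_weight * y
        return yw, yw.copy()              # R = sw * (y - Z @ 0) = sw * y
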
@@ -338,19 +368,32 @@ def sparse_enet_coordinate_descent(floating [::1] w,
             normalize_sum = 0.0
             w_ii = w[ii]
 
-            for jj in range(startptr, endptr):
-                normalize_sum += (X_data[jj] - X_mean_ii) ** 2
-                R[X_indices[jj]] -= X_data[jj] * w_ii
-            norm_cols_X[ii] = normalize_sum + \
-                (n_samples - endptr + startptr) * X_mean_ii ** 2
-
-            if center:
-                for jj in range(n_samples):
-                    R[jj] += X_mean_ii * w_ii
+            if no_sample_weights:
+                for jj in range(startptr, endptr):
+                    normalize_sum += (X_data[jj] - X_mean_ii) ** 2
+                    R[X_indices[jj]] -= X_data[jj] * w_ii
+                norm_cols_X[ii] = normalize_sum + \
+                    (n_samples - endptr + startptr) * X_mean_ii ** 2
+                if center:
+                    for jj in range(n_samples):
+                        R[jj] += X_mean_ii * w_ii
+            else:
+                for jj in range(startptr, endptr):
+                    tmp = sample_weight[X_indices[jj]]
+                    # second term will be subtracted by loop over range(n_samples)
+                    normalize_sum += (tmp * (X_data[jj] - X_mean_ii) ** 2
+                                      - tmp * X_mean_ii ** 2)
+                    R[X_indices[jj]] -= tmp * X_data[jj] * w_ii
+                if center:
+                    for jj in range(n_samples):
+                        normalize_sum += sample_weight[jj] * X_mean_ii ** 2
+                        R[jj] += sample_weight[jj] * X_mean_ii * w_ii
+                norm_cols_X[ii] = normalize_sum
             startptr = endptr
 
         # tol *= np.dot(y, y)
-        tol *= _dot(n_samples, &y[0], 1, &y[0], 1)
+        # with sample weights: tol *= y @ (sw * y)
+        tol *= _dot(n_samples, &y[0], 1, &yw[0], 1)
 
         for n_iter in range(max_iter):
 
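The weighted branch splits sum(sw * (X - m)**2) into a pass over the stored nonzeros plus a full-length correction, using the fact that a zero entry contributes exactly sw * m**2. A numeric check of that split (a sketch with made-up data):

    import numpy as np

    rng = np.random.RandomState(0)
    x = rng.rand(30) * (rng.rand(30) > 0.7)    # one sparse-ish column
    sw = rng.uniform(0.1, 2.0, size=30)
    m = np.average(x, weights=sw)              # weighted column mean

    direct = np.sum(sw * (x - m) ** 2)
    nz = x != 0                                # the stored entries
    split = np.sum(sw[nz] * ((x[nz] - m) ** 2 - m ** 2)) + np.sum(sw) * m ** 2
    assert np.isclose(direct, split)
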
@@ -373,11 +416,19 @@ def sparse_enet_coordinate_descent(floating [::1] w,
 
                 if w_ii != 0.0:
                     # R += w_ii * X[:,ii]
-                    for jj in range(startptr, endptr):
-                        R[X_indices[jj]] += X_data[jj] * w_ii
-                    if center:
-                        for jj in range(n_samples):
-                            R[jj] -= X_mean_ii * w_ii
+                    if no_sample_weights:
+                        for jj in range(startptr, endptr):
+                            R[X_indices[jj]] += X_data[jj] * w_ii
+                        if center:
+                            for jj in range(n_samples):
+                                R[jj] -= X_mean_ii * w_ii
+                    else:
+                        for jj in range(startptr, endptr):
+                            tmp = sample_weight[X_indices[jj]]
+                            R[X_indices[jj]] += tmp * X_data[jj] * w_ii
+                        if center:
+                            for jj in range(n_samples):
+                                R[jj] -= sample_weight[jj] * X_mean_ii * w_ii
 
                 # tmp = (X[:,ii] * R).sum()
                 tmp = 0.0
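
A dense equivalent of the weighted add-back above (a sketch): coordinate ii's contribution is returned to the residual R = sw * (y - Z w) before its new value is computed.

    import numpy as np

    def add_back_coordinate(R, x_col, x_mean_ii, w_ii, sw):
        # R += sw * (x_col - x_mean_ii) * w_ii, in the same two passes
        # as the loops above (stored entries first, then the centering term)
        R += sw * x_col * w_ii
        R -= sw * x_mean_ii * w_ii
        return R
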
@@ -398,20 +449,25 @@ def sparse_enet_coordinate_descent(floating [::1] w,
 
                 if w[ii] != 0.0:
                     # R -= w[ii] * X[:,ii] # Update residual
-                    for jj in range(startptr, endptr):
-                        R[X_indices[jj]] -= X_data[jj] * w[ii]
-
-                    if center:
-                        for jj in range(n_samples):
-                            R[jj] += X_mean_ii * w[ii]
+                    if no_sample_weights:
+                        for jj in range(startptr, endptr):
+                            R[X_indices[jj]] -= X_data[jj] * w[ii]
+                        if center:
+                            for jj in range(n_samples):
+                                R[jj] += X_mean_ii * w[ii]
+                    else:
+                        for jj in range(startptr, endptr):
+                            tmp = sample_weight[X_indices[jj]]
+                            R[X_indices[jj]] -= tmp * X_data[jj] * w[ii]
+                        if center:
+                            for jj in range(n_samples):
+                                R[jj] += sample_weight[jj] * X_mean_ii * w[ii]
 
                 # update the maximum absolute coefficient update
                 d_w_ii = fabs(w[ii] - w_ii)
-                if d_w_ii > d_w_max:
-                    d_w_max = d_w_ii
+                d_w_max = fmax(d_w_max, d_w_ii)
 
-                if fabs(w[ii]) > w_max:
-                    w_max = fabs(w[ii])
+                w_max = fmax(w_max, fabs(w[ii]))
 
             if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
                 # the biggest coordinate update of this iteration was smaller than
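
The two `fmax` rewrites keep the stopping rule unchanged; in NumPy terms (a sketch):

    import numpy as np

    def converged(w_old, w_new, d_w_tol):
        d_w_max = np.max(np.abs(w_new - w_old))   # largest coordinate update
        w_max = np.max(np.abs(w_new))             # largest coefficient
        return w_max == 0.0 or d_w_max / w_max < d_w_tol
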
@@ -424,22 +480,30 @@ def sparse_enet_coordinate_descent(floating [::1] w,
                     for jj in range(n_samples):
                         R_sum += R[jj]
 
+                # XtA = X.T @ R - beta * w
                 for ii in range(n_features):
-                    X_T_R[ii] = 0.0
+                    XtA[ii] = 0.0
                     for jj in range(X_indptr[ii], X_indptr[ii + 1]):
-                        X_T_R[ii] += X_data[jj] * R[X_indices[jj]]
+                        XtA[ii] += X_data[jj] * R[X_indices[jj]]
 
                     if center:
-                        X_T_R[ii] -= X_mean[ii] * R_sum
-                    XtA[ii] = X_T_R[ii] - beta * w[ii]
+                        XtA[ii] -= X_mean[ii] * R_sum
+                    XtA[ii] -= beta * w[ii]
 
                 if positive:
                     dual_norm_XtA = max(n_features, &XtA[0])
                 else:
                     dual_norm_XtA = abs_max(n_features, &XtA[0])
 
                 # R_norm2 = np.dot(R, R)
-                R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1)
+                if no_sample_weights:
+                    R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1)
+                else:
+                    R_norm2 = 0.0
+                    for jj in range(n_samples):
+                        # R is already multiplied by sample_weight
+                        if sample_weight[jj] != 0:
+                            R_norm2 += (R[jj] ** 2) / sample_weight[jj]
 
                 # w_norm2 = np.dot(w, w)
                 w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)
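
Because R stores sw * (y - Z w), dividing R**2 by sw recovers the weighted residual norm sum(sw * (y - Z w)**2) needed for the duality gap; zero-weight samples contribute nothing and are skipped. A numeric check (a sketch):

    import numpy as np

    rng = np.random.RandomState(0)
    resid = rng.rand(25)                  # stands in for y - Z @ w
    sw = rng.uniform(0.1, 2.0, size=25)
    R = sw * resid

    R_norm2 = np.sum(R ** 2 / sw)         # the loop above; sw[jj] != 0 here
    assert np.isclose(R_norm2, np.sum(sw * resid ** 2))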
