@@ -260,22 +260,45 @@ def enet_coordinate_descent(floating[::1] w,
     return w, gap, tol, n_iter + 1
 
 
-def sparse_enet_coordinate_descent(floating[::1] w,
-                                   floating alpha, floating beta,
-                                   np.ndarray[floating, ndim=1, mode='c'] X_data,
-                                   np.ndarray[int, ndim=1, mode='c'] X_indices,
-                                   np.ndarray[int, ndim=1, mode='c'] X_indptr,
-                                   np.ndarray[floating, ndim=1] y,
-                                   floating[:] X_mean, int max_iter,
-                                   floating tol, object rng, bint random=0,
-                                   bint positive=0):
+def sparse_enet_coordinate_descent(
+    floating[::1] w,
+    floating alpha,
+    floating beta,
+    np.ndarray[floating, ndim=1, mode='c'] X_data,
+    np.ndarray[int, ndim=1, mode='c'] X_indices,
+    np.ndarray[int, ndim=1, mode='c'] X_indptr,
+    floating[::1] y,
+    floating[::1] sample_weight,
+    floating[::1] X_mean,
+    int max_iter,
+    floating tol,
+    object rng,
+    bint random=0,
+    bint positive=0,
+):
     """Cython version of the coordinate descent algorithm for Elastic-Net
 
     We minimize:
 
-        (1/2) * norm(y - X w, 2)^2 + alpha norm(w, 1) + (beta/2) * norm(w, 2)^2
+        1/2 * norm(y - Z w, 2)^2 + alpha * norm(w, 1) + (beta/2) * norm(w, 2)^2
+
+    where Z = X - X_mean.
+    With sample weights sw, this becomes
+
+        1/2 * sum(sw * (y - Z w)^2, axis=0) + alpha * norm(w, 1)
+        + (beta/2) * norm(w, 2)^2
+
+    and X_mean is the weighted average of X (per column).
     """
+    # Notes for sample_weight:
+    # For dense X, one centers X and y and then rescales them by
+    # sqrt(sample_weight). Here, for sparse X, we get the sample_weight
+    # averaged center X_mean. We take care that every calculation results as
+    # if we had rescaled y and X (and therefore also X_mean) by
+    # sqrt(sample_weight) without actually computing the square root.
+    # We work with:
+    #     yw = sample_weight * y
+    #     R = sample_weight * residual
+    #     norm_cols_X = np.sum(sample_weight * (X - X_mean)**2, axis=0)
 
     # get the data information into easy vars
     cdef unsigned int n_samples = y.shape[0]
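The docstring identity above is straightforward to verify numerically: the sample-weighted fit term with Z = X - X_mean equals the plain least-squares term after rescaling y and Z by sqrt(sample_weight), which is the equivalence the no-square-root bookkeeping below relies on. A minimal NumPy sketch (all names are local to the example, not part of this diff):

```python
# Sketch: 1/2 * sum(sw * (y - Z @ w)**2) equals the unweighted objective
# after rescaling y and Z = X - X_mean by sqrt(sw).
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 20, 5
X = rng.standard_normal((n_samples, n_features))
y = rng.standard_normal(n_samples)
w = rng.standard_normal(n_features)
sw = rng.uniform(0.5, 2.0, size=n_samples)

X_mean = np.average(X, axis=0, weights=sw)   # weighted per-column mean
Z = X - X_mean

weighted = 0.5 * np.sum(sw * (y - Z @ w) ** 2)
ys = np.sqrt(sw) * y                          # rescaled target
Zs = np.sqrt(sw)[:, None] * Z                 # rescaled centered data
rescaled = 0.5 * np.sum((ys - Zs @ w) ** 2)
assert np.isclose(weighted, rescaled)
```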
@@ -289,18 +312,17 @@ def sparse_enet_coordinate_descent(floating[::1] w,
     cdef unsigned int endptr
 
     # initial value of the residuals
-    cdef floating[:] R = y.copy()
-
-    cdef floating[:] X_T_R
-    cdef floating[:] XtA
+    # R = y - Zw, weighted version R = sample_weight * (y - Zw)
+    cdef floating[::1] R
+    cdef floating[::1] XtA
+    cdef floating[::1] yw
 
     if floating is float:
         dtype = np.float32
     else:
         dtype = np.float64
 
     norm_cols_X = np.zeros(n_features, dtype=dtype)
-    X_T_R = np.zeros(n_features, dtype=dtype)
     XtA = np.zeros(n_features, dtype=dtype)
 
     cdef floating tmp
@@ -324,6 +346,14 @@ def sparse_enet_coordinate_descent(floating[::1] w,
     cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
     cdef UINT32_t* rand_r_state = &rand_r_state_seed
     cdef bint center = False
+    cdef bint no_sample_weights = sample_weight is None
+
+    if no_sample_weights:
+        yw = y
+        R = y.copy()
+    else:
+        yw = np.multiply(sample_weight, y)
+        R = yw.copy()
 
     with nogil:
         # center = (X_mean != 0).any()
@@ -338,19 +368,32 @@ def sparse_enet_coordinate_descent(floating[::1] w,
             normalize_sum = 0.0
             w_ii = w[ii]
 
-            for jj in range(startptr, endptr):
-                normalize_sum += (X_data[jj] - X_mean_ii) ** 2
-                R[X_indices[jj]] -= X_data[jj] * w_ii
-            norm_cols_X[ii] = normalize_sum + \
-                (n_samples - endptr + startptr) * X_mean_ii ** 2
-
-            if center:
-                for jj in range(n_samples):
-                    R[jj] += X_mean_ii * w_ii
+            if no_sample_weights:
+                for jj in range(startptr, endptr):
+                    normalize_sum += (X_data[jj] - X_mean_ii) ** 2
+                    R[X_indices[jj]] -= X_data[jj] * w_ii
+                norm_cols_X[ii] = normalize_sum + \
+                    (n_samples - endptr + startptr) * X_mean_ii ** 2
+                if center:
+                    for jj in range(n_samples):
+                        R[jj] += X_mean_ii * w_ii
+            else:
+                for jj in range(startptr, endptr):
+                    tmp = sample_weight[X_indices[jj]]
+                    # the - tmp * X_mean_ii ** 2 term compensates for the
+                    # loop over range(n_samples) below, which adds it back
+                    # once per sample
+                    normalize_sum += (tmp * (X_data[jj] - X_mean_ii) ** 2
+                                      - tmp * X_mean_ii ** 2)
+                    R[X_indices[jj]] -= tmp * X_data[jj] * w_ii
+                if center:
+                    for jj in range(n_samples):
+                        normalize_sum += sample_weight[jj] * X_mean_ii ** 2
+                        R[jj] += sample_weight[jj] * X_mean_ii * w_ii
+                norm_cols_X[ii] = normalize_sum
             startptr = endptr
 
         # tol *= np.dot(y, y)
-        tol *= _dot(n_samples, &y[0], 1, &y[0], 1)
+        # with sample weights: tol *= y @ (sw * y)
+        tol *= _dot(n_samples, &y[0], 1, &yw[0], 1)
 
         for n_iter in range(max_iter):
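The compensation trick in the weighted branch above can be checked numerically: over the stored nonzeros we accumulate sw * (x - m)**2 - sw * m**2, and the loop over all samples then adds sw * m**2 back once per sample; together these reproduce the dense weighted column norm. A hedged one-column NumPy sketch (a dense x stands in for a sparse column; names are illustrative, not from the diff):

```python
# Sketch: sparse accumulation for one column equals the dense weighted norm
# np.sum(sw * (x - m)**2).
import numpy as np

rng = np.random.default_rng(1)
n_samples = 30
x = rng.standard_normal(n_samples)
x[rng.random(n_samples) < 0.6] = 0.0           # sparsify the column
sw = rng.uniform(0.5, 2.0, size=n_samples)
m = np.average(x, weights=sw)                  # plays the role of X_mean_ii

nz = np.flatnonzero(x)                         # stored entries (X_indices)
# per stored entry: sw * (x - m)**2 - sw * m**2
acc = np.sum(sw[nz] * (x[nz] - m) ** 2 - sw[nz] * m ** 2)
# loop over all samples adds sw * m**2 back once per sample
acc += np.sum(sw * m ** 2)

assert np.isclose(acc, np.sum(sw * (x - m) ** 2))
```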
@@ -373,11 +416,19 @@ def sparse_enet_coordinate_descent(floating[::1] w,
             if w_ii != 0.0:
                 # R += w_ii * X[:,ii]
-                for jj in range(startptr, endptr):
-                    R[X_indices[jj]] += X_data[jj] * w_ii
-                if center:
-                    for jj in range(n_samples):
-                        R[jj] -= X_mean_ii * w_ii
+                if no_sample_weights:
+                    for jj in range(startptr, endptr):
+                        R[X_indices[jj]] += X_data[jj] * w_ii
+                    if center:
+                        for jj in range(n_samples):
+                            R[jj] -= X_mean_ii * w_ii
+                else:
+                    for jj in range(startptr, endptr):
+                        tmp = sample_weight[X_indices[jj]]
+                        R[X_indices[jj]] += tmp * X_data[jj] * w_ii
+                    if center:
+                        for jj in range(n_samples):
+                            R[jj] -= sample_weight[jj] * X_mean_ii * w_ii
 
             # tmp = (X[:,ii] * R).sum()
             tmp = 0.0
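For orientation: the code following this hunk computes tmp = (X[:,ii] * R).sum() and soft-thresholds it. With R kept as sample_weight * residual throughout, one full coordinate update looks like the dense sketch below (assuming the positive=0 branch; soft_threshold and cd_step are hypothetical helper names, not part of this diff):

```python
# Sketch: one weighted coordinate-descent update for feature ii.
import numpy as np

def soft_threshold(t, alpha):
    """sign(t) * max(|t| - alpha, 0)"""
    return np.sign(t) * max(abs(t) - alpha, 0.0)

def cd_step(Z, R, sw, w, ii, alpha, beta, norm_cols):
    """One coordinate update; R == sw * (y - Z @ w) on entry and on exit."""
    R += sw * Z[:, ii] * w[ii]          # remove feature ii from the residual
    tmp = Z[:, ii] @ R                  # (X[:,ii] * R).sum() in the diff
    w[ii] = soft_threshold(tmp, alpha) / (norm_cols[ii] + beta)
    R -= sw * Z[:, ii] * w[ii]          # restore with the new w[ii]
    return w, R

# usage sketch
rng = np.random.default_rng(3)
Z = rng.standard_normal((15, 4))
y = rng.standard_normal(15)
sw = rng.uniform(0.5, 2.0, size=15)
w = np.zeros(4)
norm_cols = np.sum(sw[:, None] * Z ** 2, axis=0)   # weighted column norms
R = sw * (y - Z @ w)
w, R = cd_step(Z, R, sw, w, 1, alpha=0.1, beta=0.01, norm_cols=norm_cols)
assert np.allclose(R, sw * (y - Z @ w))            # invariant maintained
```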
@@ -398,20 +449,25 @@ def sparse_enet_coordinate_descent(floating[::1] w,
             if w[ii] != 0.0:
                 # R -= w[ii] * X[:,ii]  # Update residual
-                for jj in range(startptr, endptr):
-                    R[X_indices[jj]] -= X_data[jj] * w[ii]
-
-                if center:
-                    for jj in range(n_samples):
-                        R[jj] += X_mean_ii * w[ii]
+                if no_sample_weights:
+                    for jj in range(startptr, endptr):
+                        R[X_indices[jj]] -= X_data[jj] * w[ii]
+                    if center:
+                        for jj in range(n_samples):
+                            R[jj] += X_mean_ii * w[ii]
+                else:
+                    for jj in range(startptr, endptr):
+                        tmp = sample_weight[X_indices[jj]]
+                        R[X_indices[jj]] -= tmp * X_data[jj] * w[ii]
+                    if center:
+                        for jj in range(n_samples):
+                            R[jj] += sample_weight[jj] * X_mean_ii * w[ii]
 
             # update the maximum absolute coefficient update
             d_w_ii = fabs(w[ii] - w_ii)
-            if d_w_ii > d_w_max:
-                d_w_max = d_w_ii
+            d_w_max = fmax(d_w_max, d_w_ii)
 
-            if fabs(w[ii]) > w_max:
-                w_max = fabs(w[ii])
+            w_max = fmax(w_max, fabs(w[ii]))
 
         if w_max == 0.0 or d_w_max / w_max < d_w_tol or n_iter == max_iter - 1:
             # the biggest coordinate update of this iteration was smaller than
@@ -424,22 +480,30 @@ def sparse_enet_coordinate_descent(floating[::1] w,
                 for jj in range(n_samples):
                     R_sum += R[jj]
 
+                # XtA = X.T @ R - beta * w
                 for ii in range(n_features):
-                    X_T_R[ii] = 0.0
+                    XtA[ii] = 0.0
                     for jj in range(X_indptr[ii], X_indptr[ii + 1]):
-                        X_T_R[ii] += X_data[jj] * R[X_indices[jj]]
+                        XtA[ii] += X_data[jj] * R[X_indices[jj]]
 
                     if center:
-                        X_T_R[ii] -= X_mean[ii] * R_sum
-                    XtA[ii] = X_T_R[ii] - beta * w[ii]
+                        XtA[ii] -= X_mean[ii] * R_sum
+                    XtA[ii] -= beta * w[ii]
 
                 if positive:
                     dual_norm_XtA = max(n_features, &XtA[0])
                 else:
                     dual_norm_XtA = abs_max(n_features, &XtA[0])
 
                 # R_norm2 = np.dot(R, R)
-                R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1)
+                if no_sample_weights:
+                    R_norm2 = _dot(n_samples, &R[0], 1, &R[0], 1)
+                else:
+                    R_norm2 = 0.0
+                    for jj in range(n_samples):
+                        # R is already multiplied by sample_weight
+                        if sample_weight[jj] != 0:
+                            R_norm2 += (R[jj] ** 2) / sample_weight[jj]
 
                 # w_norm2 = np.dot(w, w)
                 w_norm2 = _dot(n_features, &w[0], 1, &w[0], 1)