Fix ElasticNetCV sample weight #29442

Merged

Changes from all commits · 39 commits
8d4b501
Update _alpha_grid to take sample_weight
s-banach Apr 4, 2022
2f494db
Add a simple test for alpha_max with sample_weight
s-banach Apr 4, 2022
fa2c821
Update test
s-banach Apr 4, 2022
75e6584
Clarify _alpha_grid.
s-banach Apr 4, 2022
8b6cfc0
Clarify notation
s-banach Apr 5, 2022
2ba4c57
Use Xy if it is provided.
s-banach Jul 2, 2022
5d1f5e7
Update test, check alpha_max is not too large
s-banach Jul 2, 2022
dce169c
Fix test that alpha_max is not too large.
s-banach Jul 2, 2022
380c21f
Test alpha_max without sample_weight.
s-banach Jul 6, 2022
c187cf7
fix elasticnetcv sample weighting adapted from previous commit by s-b…
snath-xoc Jun 19, 2024
40d8b30
Update _preprocess_data inputs in _coordinate_descent.py
snath-xoc Jun 20, 2024
85062c0
added tests for repeated vs weighted on cyclic ElasticNetCV and modif…
snath-xoc Jun 20, 2024
c649d36
Merge branch 'main' into fix_elasticnet_cv_sample_weight
ogrisel Jun 27, 2024
41fcb5f
added to changelog and changed seeding in tests
snath-xoc Jun 27, 2024
335137d
[all random seeds] test_enet_cv_sample_weight
snath-xoc Jun 27, 2024
36cc847
Merge branch 'main' into fix_elasticnet_cv_sample_weight
ogrisel Jun 28, 2024
fec4f74
Revert unrelated changes
ogrisel Jun 28, 2024
c41a8ee
merged test into test_enet_cv_sample_weight_correctness
snath-xoc Jun 28, 2024
ac9f090
changed sample weight to be explicitly set as integers in sklearn/lin…
snath-xoc Jun 29, 2024
bf62b35
Merge branch 'main' into fix_elasticnet_cv_sample_weight
snath-xoc Jul 5, 2024
2a42de2
add sample weights into default score calculation under _log_reg_scor…
Jul 5, 2024
ff02ffa
[all random seeds] test_enet_cv_sample_weight
snath-xoc Jul 5, 2024
63ad46b
[all random seeds] test_enet_cv_sample_weight
snath-xoc Jul 5, 2024
2dc37c9
modified test_coordinate_descent for X, y, group names
snath-xoc Jul 9, 2024
7cdbf52
restoring some files
snath-xoc Jul 9, 2024
3a19575
add global_random_seed
snath-xoc Jul 9, 2024
5876f3d
Merge branch 'main' into fix_elasticnect_cv_sample_weight
ogrisel Aug 7, 2024
f774c6a
updated comments and changelog
snath-xoc Aug 9, 2024
69646d9
[all random seeds]
snath-xoc Aug 9, 2024
d71315c
[all random seeds]
snath-xoc Aug 21, 2024
9d70495
trigger regular CI
snath-xoc Aug 21, 2024
a7dfa0f
Merge branch 'main' into fix_elasticnect_cv_sample_weight
snath-xoc Aug 21, 2024
d2c6c4d
[all random seeds]
snath-xoc Aug 26, 2024
1f7a407
Apply suggestions from code review
snath-xoc Sep 6, 2024
9aa6ff0
Fix linting problem
ogrisel Sep 6, 2024
cf31e14
updated changelog
snath-xoc Sep 6, 2024
e6c89fe
moved fix in changelog to after API
snath-xoc Sep 6, 2024
2b99c9b
Restore non-degenerate alpha check in test
ogrisel Sep 9, 2024
080c846
Remove alpha boundary check
ogrisel Sep 9, 2024
5 changes: 5 additions & 0 deletions doc/whats_new/v1.6.rst
@@ -233,6 +233,11 @@ Changelog
   has no effect. `copy_X` will be removed in 1.8.
   :pr:`29105` by :user:`Adam Li <adam2392>`.
 
+- |Fix| :class:`linear_model.LassoCV` and :class:`linear_model.ElasticNetCV` now
+  take sample weights into account to define the search grid for the internally
+  tuned `alpha` hyper-parameter. :pr:`29442` by :user:`John Hopfensperger <s-banach>`
+  and :user:`Shruti Nath <snath-xoc>`.
+
 :mod:`sklearn.manifold`
 .......................
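In practice, the fix means that integer sample weights and materialized row
repetitions now lead to the same hyper-parameter search. A standalone sketch of
the new behavior (not code from this PR; the data, seed and `cv=2` are
arbitrary illustrative choices):

import numpy as np
from sklearn.linear_model import ElasticNetCV

rng = np.random.RandomState(0)
X = rng.rand(20, 5)
y = X @ rng.rand(5)
sw = rng.randint(1, 4, size=20).astype(float)

# Weighted fit: the alpha grid is now derived from the weighted data.
reg_weighted = ElasticNetCV(cv=2).fit(X, y, sample_weight=sw)

# Reference fit on materialized repetitions of each row.
X_rep = np.repeat(X, sw.astype(int), axis=0)
y_rep = np.repeat(y, sw.astype(int), axis=0)
reg_repeated = ElasticNetCV(cv=2).fit(X_rep, y_rep)

# With the fix, both searches explore the same alpha grid.
np.testing.assert_allclose(reg_weighted.alphas_, reg_repeated.alphas_)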
68 changes: 34 additions & 34 deletions sklearn/linear_model/_coordinate_descent.py
@@ -98,6 +98,7 @@ def _alpha_grid(
     eps=1e-3,
     n_alphas=100,
     copy_X=True,
+    sample_weight=None,
 ):
     """Compute the grid of alpha values for elastic net parameter search

@@ -132,6 +133,8 @@ def _alpha_grid(
 
     copy_X : bool, default=True
         If ``True``, X will be copied; else, it may be overwritten.
+
+    sample_weight : ndarray of shape (n_samples,), default=None
+        Sample weights. If None, all samples are given unit weight.
     """
     if l1_ratio == 0:
         raise ValueError(
@@ -140,43 +143,39 @@ def _alpha_grid(
             "your estimator with the appropriate `alphas=` "
             "argument."
         )
-    n_samples = len(y)
-
-    sparse_center = False
-    if Xy is None:
-        X_sparse = sparse.issparse(X)
-        sparse_center = X_sparse and fit_intercept
-        X = check_array(
-            X, accept_sparse="csc", copy=(copy_X and fit_intercept and not X_sparse)
+    if Xy is not None:
+        Xyw = Xy
+    else:
+        X, y, X_offset, _, _ = _preprocess_data(
+            X,
+            y,
+            fit_intercept=fit_intercept,
+            copy=copy_X,
+            sample_weight=sample_weight,
+            check_input=False,
         )
-        if not X_sparse:
-            # X can be touched inplace thanks to the above line
-            X, y, _, _, _ = _preprocess_data(
-                X, y, fit_intercept=fit_intercept, copy=False
-            )
-        Xy = safe_sparse_dot(X.T, y, dense_output=True)
-
-        if sparse_center:
-            # Workaround to find alpha_max for sparse matrices.
-            # since we should not destroy the sparsity of such matrices.
-            _, _, X_offset, _, X_scale = _preprocess_data(
-                X, y, fit_intercept=fit_intercept
-            )
-            mean_dot = X_offset * np.sum(y)
-
-    if Xy.ndim == 1:
-        Xy = Xy[:, np.newaxis]
-
-    if sparse_center:
-        if fit_intercept:
-            Xy -= mean_dot[:, np.newaxis]
+        if sample_weight is not None:
+            if y.ndim > 1:
+                yw = y * sample_weight.reshape(-1, 1)
+            else:
+                yw = y * sample_weight
+        else:
+            yw = y
+        if sparse.issparse(X):
+            Xyw = safe_sparse_dot(X.T, yw, dense_output=True) - np.sum(yw) * X_offset
+        else:
+            Xyw = np.dot(X.T, yw)
 
-    alpha_max = np.sqrt(np.sum(Xy**2, axis=1)).max() / (n_samples * l1_ratio)
+    if Xyw.ndim == 1:
+        Xyw = Xyw[:, np.newaxis]
+    if sample_weight is not None:
+        n_samples = sample_weight.sum()
+    else:
+        n_samples = X.shape[0]
+    alpha_max = np.sqrt(np.sum(Xyw**2, axis=1)).max() / (n_samples * l1_ratio)
 
-    if alpha_max <= np.finfo(float).resolution:
-        alphas = np.empty(n_alphas)
-        alphas.fill(np.finfo(float).resolution)
-        return alphas
+    if alpha_max <= np.finfo(np.float64).resolution:
+        return np.full(n_alphas, np.finfo(np.float64).resolution)
 
     return np.geomspace(alpha_max, alpha_max * eps, num=n_alphas)
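The core of the change is the weighted alpha_max, the smallest penalty that
drives all coefficients to zero: the weighted correlation X^T (w * y) replaces
X^T y, and the weight total sum(w) replaces the raw sample count. A minimal
standalone sketch (hypothetical helper, not part of the PR; dense data,
centering left to the caller):

import numpy as np

def weighted_alpha_max(X, y, sample_weight, l1_ratio=1.0):
    # Correlation of each feature with the weighted targets.
    yw = y * sample_weight
    Xyw = X.T @ yw
    # The weight total plays the role of the sample count.
    n_samples = sample_weight.sum()
    return np.max(np.abs(Xyw)) / (n_samples * l1_ratio)

# Integer weights agree with materialized repetitions:
X = np.array([[3.0, 1.0], [2.0, 5.0]])
y = np.array([1.0, 2.0])
a_weighted = weighted_alpha_max(X, y, np.array([2.0, 3.0]))
a_repeated = weighted_alpha_max(
    np.repeat(X, [2, 3], axis=0), np.repeat(y, [2, 3]), np.ones(5)
)
assert np.isclose(a_weighted, a_repeated)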

@@ -1702,6 +1701,7 @@ def fit(self, X, y, sample_weight=None, **params):
                     eps=self.eps,
                     n_alphas=self.n_alphas,
                     copy_X=self.copy_X,
+                    sample_weight=sample_weight,
                 )
                 for l1_ratio in l1_ratios
             ]
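For context, this one-line hunk sits inside the list comprehension of
`LinearModelCV.fit` that builds one alpha grid per candidate `l1_ratio`. A
simplified reconstruction of the surrounding code (paraphrased, not a verbatim
quote of the source):

alphas = [
    _alpha_grid(
        X,
        y,
        l1_ratio=l1_ratio,
        fit_intercept=self.fit_intercept,
        eps=self.eps,
        n_alphas=self.n_alphas,
        copy_X=self.copy_X,
        sample_weight=sample_weight,  # the new argument threaded through
    )
    for l1_ratio in l1_ratios
]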
124 changes: 84 additions & 40 deletions sklearn/linear_model/tests/test_coordinate_descent.py
@@ -48,6 +48,7 @@
     assert_almost_equal,
     assert_array_almost_equal,
     assert_array_equal,
+    assert_array_less,
     ignore_warnings,
 )
 from sklearn.utils.fixes import COO_CONTAINERS, CSC_CONTAINERS, CSR_CONTAINERS
@@ -1304,55 +1305,78 @@ def test_enet_sample_weight_consistency(
 
 @pytest.mark.parametrize("fit_intercept", [True, False])
 @pytest.mark.parametrize("sparse_container", [None] + CSC_CONTAINERS)
-def test_enet_cv_sample_weight_correctness(fit_intercept, sparse_container):
-    """Test that ElasticNetCV with sample weights gives correct results."""
-    rng = np.random.RandomState(42)
-    n_splits, n_samples, n_features = 3, 10, 5
-    X = rng.rand(n_splits * n_samples, n_features)
+def test_enet_cv_sample_weight_correctness(
+    fit_intercept, sparse_container, global_random_seed
+):
+    """Test that ElasticNetCV with sample weights gives correct results.
+
+    We fit the same model twice, once with weighted training data, once with repeated
+    data points in the training data and check that both models converge to the
+    same solution.
+
+    Since this model uses an internal cross-validation scheme to tune the alpha
+    regularization parameter, we make sure that the repetitions only occur within
+    a specific CV group. Data points belonging to other CV groups stay
+    unit-weighted / "unrepeated".
+    """
+    rng = np.random.RandomState(global_random_seed)
+    n_splits, n_samples_per_cv, n_features = 3, 10, 5
+    X_with_weights = rng.rand(n_splits * n_samples_per_cv, n_features)
     beta = rng.rand(n_features)
     beta[0:2] = 0
-    y = X @ beta + rng.rand(n_splits * n_samples)
-    sw = np.ones_like(y)
+    y_with_weights = X_with_weights @ beta + rng.rand(n_splits * n_samples_per_cv)
+
     if sparse_container is not None:
-        X = sparse_container(X)
+        X_with_weights = sparse_container(X_with_weights)
     params = dict(tol=1e-6)
 
-    # Set alphas, otherwise the two cv models might use different ones.
-    if fit_intercept:
-        alphas = np.linspace(0.001, 0.01, num=91)
-    else:
-        alphas = np.linspace(0.01, 0.1, num=91)
-
-    # We weight the first fold 2 times more.
-    sw[:n_samples] = 2
-    groups_sw = np.r_[
-        np.full(n_samples, 0), np.full(n_samples, 1), np.full(n_samples, 2)
-    ]
-    splits_sw = list(LeaveOneGroupOut().split(X, groups=groups_sw))
-    reg_sw = ElasticNetCV(
-        alphas=alphas, cv=splits_sw, fit_intercept=fit_intercept, **params
+    # Assign random integer weights only to the first cross-validation group.
+    # The samples in the other cross-validation groups are left with unit
+    # weights.
+
+    sw = np.ones_like(y_with_weights)
+    sw[:n_samples_per_cv] = rng.randint(0, 5, size=n_samples_per_cv)
+    groups_with_weights = np.concatenate(
+        [
+            np.full(n_samples_per_cv, 0),
+            np.full(n_samples_per_cv, 1),
+            np.full(n_samples_per_cv, 2),
+        ]
+    )
+    splits_with_weights = list(
+        LeaveOneGroupOut().split(X_with_weights, groups=groups_with_weights)
+    )
+    reg_with_weights = ElasticNetCV(
+        cv=splits_with_weights, fit_intercept=fit_intercept, **params
     )
-    reg_sw.fit(X, y, sample_weight=sw)
 
-    # We repeat the first fold 2 times and provide splits ourselves
+    reg_with_weights.fit(X_with_weights, y_with_weights, sample_weight=sw)
+
     if sparse_container is not None:
-        X = X.toarray()
-    X = np.r_[X[:n_samples], X]
+        X_with_weights = X_with_weights.toarray()
+    X_with_repetitions = np.repeat(X_with_weights, sw.astype(int), axis=0)
     if sparse_container is not None:
-        X = sparse_container(X)
-    y = np.r_[y[:n_samples], y]
-    groups = np.r_[
-        np.full(2 * n_samples, 0), np.full(n_samples, 1), np.full(n_samples, 2)
-    ]
-    splits = list(LeaveOneGroupOut().split(X, groups=groups))
-    reg = ElasticNetCV(alphas=alphas, cv=splits, fit_intercept=fit_intercept, **params)
-    reg.fit(X, y)
+        X_with_repetitions = sparse_container(X_with_repetitions)
+
+    y_with_repetitions = np.repeat(y_with_weights, sw.astype(int), axis=0)
+    groups_with_repetitions = np.repeat(groups_with_weights, sw.astype(int), axis=0)
+
+    splits_with_repetitions = list(
+        LeaveOneGroupOut().split(X_with_repetitions, groups=groups_with_repetitions)
+    )
+    reg_with_repetitions = ElasticNetCV(
+        cv=splits_with_repetitions, fit_intercept=fit_intercept, **params
+    )
+    reg_with_repetitions.fit(X_with_repetitions, y_with_repetitions)
 
-    # ensure that we chose meaningful alphas, i.e. not boundaries
-    assert alphas[0] < reg.alpha_ < alphas[-1]
-    assert reg_sw.alpha_ == reg.alpha_
-    assert_allclose(reg_sw.coef_, reg.coef_)
-    assert reg_sw.intercept_ == pytest.approx(reg.intercept_)
+    # Check that the alpha selection process is the same:
+    assert_allclose(reg_with_weights.mse_path_, reg_with_repetitions.mse_path_)
+    assert_allclose(reg_with_weights.alphas_, reg_with_repetitions.alphas_)
+    assert reg_with_weights.alpha_ == pytest.approx(reg_with_repetitions.alpha_)
+
+    # Check that the final model coefficients are the same:
+    assert_allclose(reg_with_weights.coef_, reg_with_repetitions.coef_, atol=1e-10)
+    assert reg_with_weights.intercept_ == pytest.approx(reg_with_repetitions.intercept_)
 
 
 @pytest.mark.parametrize("sample_weight", [False, True])
@@ -1444,9 +1468,29 @@ def test_enet_cv_sample_weight_consistency(
     assert_allclose(reg.intercept_, intercept)
 
 
+@pytest.mark.parametrize("X_is_sparse", [False, True])
+@pytest.mark.parametrize("fit_intercept", [False, True])
+@pytest.mark.parametrize("sample_weight", [np.array([10, 1, 10, 1]), None])
+def test_enet_alpha_max_sample_weight(X_is_sparse, fit_intercept, sample_weight):
+    X = np.array([[3.0, 1.0], [2.0, 5.0], [5.0, 3.0], [1.0, 4.0]])
+    beta = np.array([1, 1])
+    y = X @ beta
+    if X_is_sparse:
+        X = sparse.csc_matrix(X)
+    # Test alpha_max makes coefs zero.
+    reg = ElasticNetCV(n_alphas=1, cv=2, eps=1, fit_intercept=fit_intercept)
+    reg.fit(X, y, sample_weight=sample_weight)
+    assert_allclose(reg.coef_, 0, atol=1e-5)
+    alpha_max = reg.alpha_
+    # Test smaller alpha makes coefs nonzero.
+    reg = ElasticNet(alpha=0.99 * alpha_max, fit_intercept=fit_intercept)
+    reg.fit(X, y, sample_weight=sample_weight)
+    assert_array_less(1e-3, np.max(np.abs(reg.coef_)))
+
+
 @pytest.mark.parametrize("estimator", [ElasticNetCV, LassoCV])
 def test_linear_models_cv_fit_with_loky(estimator):
-    # LinearModelsCV.fit performs inplace operations on fancy-indexed memmapped
+    # LinearModelsCV.fit performs operations on fancy-indexed memmapped
     # data when using the loky backend, causing an error due to unexpected
     # behavior of fancy indexing of read-only memmaps (cf. numpy#14132).
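Aside on the alpha_max test above: with eps=1 the two grid endpoints coincide
(alpha_max down to alpha_max * eps) and n_alphas=1 keeps a single point, so the
CV model is fitted exactly at alpha = alpha_max. A quick sketch of why (the
value is illustrative):

import numpy as np

alpha_max = 6.4  # illustrative value
grid = np.geomspace(alpha_max, alpha_max * 1.0, num=1)
assert grid.tolist() == [alpha_max]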
