
Calculation of alphas in ElasticNetCV doesn't use sample_weight #22914


Closed
s-banach opened this issue Mar 21, 2022 · 4 comments · Fixed by #29442


@s-banach
Contributor

s-banach commented Mar 21, 2022

Describe the bug

In ElasticNetCV, the first and largest value of alpha, call it alpha_max, should be just large enough to force all of the coefficients to zero. The existing code computes alpha_max correctly when sample_weight is not specified. However, when sample_weight is given, the computation of alpha_max does not take it into account.

Steps/Code to Reproduce

import numpy as np
from sklearn.linear_model import ElasticNet, ElasticNetCV

X = np.array([[3, 1], [2, 5], [5, 3], [1, 4]])
beta = np.array([1, 1])
y = X @ beta
w = np.array([10, 1, 10, 1])

# Fit ElasticNetCV just to get the .alphas_ attribute
enetCV = ElasticNetCV(cv=2)
enetCV.fit(X, y, sample_weight=w)

# The coefficients of an ElasticNet fitted at alpha_max should be [0.  0.].
alpha_max = enetCV.alphas_[0]
enet = ElasticNet(alpha=alpha_max)
enet.fit(X, y, sample_weight=w)
print(enet.coef_)  # [0.1970807  0.19708023]

Expected Results

If the correct value of alpha_max is computed, then enet.coef_ should be right at the cusp of zero, such that any smaller value of alpha makes it nonzero:

def get_alpha_max(X, y, w, l1_ratio=0.5):
    wn = w / w.sum()               # normalize weights to sum to 1
    Xn = X - np.dot(wn, X)         # center X at its weighted mean
    yn = (y - np.dot(wn, y)) * wn  # center y, then fold in the weights
    return np.max(np.abs(yn @ Xn)) / l1_ratio


enet = ElasticNet(alpha=get_alpha_max(X, y, w))
enet.fit(X, y, sample_weight=w)
print(enet.coef_)  # [6.70427878e-17 6.70427878e-17]
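
As a sanity check (not part of the original report), with uniform weights this formula should reduce to the grid sklearn already computes, assuming the defaults fit_intercept=True and l1_ratio=0.5. Continuing from the snippets above:

# Uniform weights should reproduce the unweighted alpha_max that
# ElasticNetCV already exposes as .alphas_[0].
w_uniform = np.ones_like(y, dtype=float)
enetCV_unweighted = ElasticNetCV(cv=2).fit(X, y)
print(np.isclose(get_alpha_max(X, y, w_uniform), enetCV_unweighted.alphas_[0]))  # True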

Actual Results

enet.coef_ is [0.1970807 0.19708023].

Versions

System:
    python: 3.9.7 (default, Sep 16 2021, 13:09:58)  [GCC 7.5.0]
executable: /home/jhopfens/.conda/envs/jhop39/bin/python
   machine: Linux-3.10.0-1160.53.1.el7.x86_64-x86_64-with-glibc2.17

Python dependencies:
          pip: 21.2.4
   setuptools: 58.0.4
      sklearn: 1.0.2
        numpy: 1.21.5
        scipy: 1.7.2
       Cython: 0.29.24
       pandas: 1.3.5
   matplotlib: 3.5.1
       joblib: 1.1.0
threadpoolctl: 3.0.0

Built with OpenMP: True
@lorentzenchr
Member

lorentzenchr commented Mar 21, 2022

I agree, the max alpha should give zero coefficients.
Do you want to give it a try and open a PR?

The change policy might be tricky, as this amounts to changing a model (once fixed, ElasticNetCV.fit will give different results than before) and there is no meaningful deprecation path, IMO.
I consider it a bug fix such that no deprecation cycle is needed.

@s-banach
Contributor Author

s-banach commented Mar 21, 2022

I could simply modify the function sklearn.linear_model._coordinate_descent._alpha_grid, but I'm not sure this would be optimal: it would repeat the multiplication by sqrt(sample_weight) that is already done when we fit ElasticNet (a rough sketch of that calculation is below).
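
For illustration only, here is a minimal standalone sketch of what folding sample_weight into the alpha_max computation could look like. The helper name is hypothetical, and the sqrt-weight rescaling is my reading of what ElasticNet.fit does internally, not sklearn's actual _alpha_grid code:

import numpy as np

def weighted_alpha_max(X, y, sample_weight, l1_ratio=0.5):
    # Hypothetical helper, not sklearn code: fold sample_weight into
    # alpha_max by rescaling rows with sqrt(n * normalized_weight), so
    # that the unweighted formula max|X.T @ y| / (n * l1_ratio) applies.
    n = X.shape[0]
    wn = sample_weight / sample_weight.sum()  # weights summing to 1
    Xc = X - wn @ X                           # weighted-mean centering of X
    yc = y - wn @ y                           # weighted-mean centering of y
    s = np.sqrt(n * wn)                       # the sqrt(sample_weight) rescaling
    return np.max(np.abs((Xc * s[:, None]).T @ (yc * s))) / (n * l1_ratio)

This agrees with get_alpha_max above: (Xc * s).T @ (yc * s) equals n times Xc.T @ (wn * yc), and that extra factor of n cancels against the n in the denominator.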
Personally, here is how I would approach the problem:

(1) I would always make the ElasticNet model call enet_path. Surprisingly, this can be faster than the existing approach, even if we only want to solve for a single value of alpha. Consider the following timings from my own machine.

First we generate the data and the array of alphas:

from sklearn.linear_model import ElasticNet, enet_path
from sklearn.datasets import make_regression

X, y = make_regression(1_000_000, n_features=50, noise=0.1, random_state=0)
alphas, _, _ = enet_path(X, y)

Now we time the alternative approaches:

%%timeit
enet_path(X, y)

Output: 1.41 s ± 348 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
enet = ElasticNet(alpha=alphas[-1])
enet.fit(X, y)

Output: 1.58 s ± 281 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

(2) ElasticNet would then accept an n_alphas parameter.
ElasticNetCV(n_alphas=n_alphas) would always begin by fitting ElasticNet(n_alphas=n_alphas). Then, to get the alphas to use for the different CV folds, we would simply use the .alphas_ attribute of that ElasticNet.

Therefore, the normalization of X and y with respect to sample_weight would happen only once, inside ElasticNet, rather than once inside ElasticNet and again inside _alpha_grid.
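
To make the grid-reuse part concrete, here is a rough sketch using only today's public API (enet_path currently takes no sample_weight, and ElasticNet has no n_alphas parameter, so this shows just the mechanics of computing the grid once and sharing it across folds):

import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import enet_path
from sklearn.model_selection import KFold

X, y = make_regression(200, n_features=10, noise=0.1, random_state=0)

# One full-data path fit determines the grid; under the proposal,
# ElasticNet would compute this internally and expose it as .alphas_.
alphas, _, _ = enet_path(X, y, n_alphas=100)

# Every fold then reuses the same grid via the alphas argument, so the
# grid (and any weight normalization) would be computed exactly once.
mse = np.zeros_like(alphas)
for train, test in KFold(n_splits=3).split(X):
    _, coefs, _ = enet_path(X[train], y[train], alphas=alphas)
    mse += ((y[test][:, None] - X[test] @ coefs) ** 2).mean(axis=0)
print(alphas[np.argmin(mse)])  # the CV-selected alpha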

@jeremiedbb
Member

> I consider it a bug fix such that no deprecation cycle is needed.

I agree with that.

@jeremiedbb
Member

> I could simply modify the function sklearn.linear_model._coordinate_descent._alpha_grid, but I'm not sure this would be optimal: it would repeat the multiplication by sqrt(sample_weight) that is already done when we fit ElasticNet. Personally, here is how I would approach the problem:

I think it's better to first fix the issue, since a simple fix seems possible here, with a good non-regression test like the one you showed in your issue. Then we can think about making some small refactorings in a separate PR.
