Skip to content

Add sample_weight to the calculation of alphas in enet_path and LinearModelCV #23045

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from

Conversation

s-banach
Copy link
Contributor

@s-banach s-banach commented Apr 4, 2022

Reference Issues/PRs

Fixes #22914.

What does this implement/fix? Explain your changes.

Modifies _alpha_grid function in linear_model._coordinate_descent to accept a sample_weight argument.

The function _alpha_grid is called in two places, enet_path and LinearModelCV.
The new sample_weight argument is not used by enet_path, but it is used by LinearModelCV.

Any other comments?

Since my previous PR on this issue, _preprocess_data has been rewritten.

s-banach added 2 commits April 3, 2022 22:11
It seems like this single call to _preprocess_data suffices in all cases.
This tiny example was given in #22914.
The test merely asserts that alpha_max is large enough to force the coefficient to 0.
Copy link
Member

@lorentzenchr lorentzenchr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A first round of review comments.

s-banach added 3 commits April 4, 2022 09:06
As per reviewer's suggestions:
(1) Clarify eps=1.
(2) Parameterize `fit_intercept`.
(1) Give the name `n_samples` to the quantity `X.shape[0]`.
(2) Clarify that `y_offset` and `X_scale` are not used, since these are already applied to the data by `_preprocess_data`.
@lorentzenchr lorentzenchr added this to the 1.2 milestone Apr 21, 2022
@lorentzenchr
Copy link
Member

@TomDLT May I kindly ping you as your help would be much appreciated.

Copy link
Member

@TomDLT TomDLT left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, although I did not check the math.

Main remark: The new test function tests that the computed alpha_max is larger or equal to the true alpha_max. To test that they are actually equal, we could test that alpha_max * 0.99 does not return all-zero coefficients.

We could also add a test that the computation still works without sample weights.

@s-banach
Copy link
Contributor Author

s-banach commented Jul 2, 2022

Main remark: The new test function tests that the computed alpha_max is larger or equal to the true alpha_max. To test that they are actually equal, we could test that alpha_max * 0.99 does not return all-zero coefficients.

I have attempted to update the test according to your recommendation.
It now checks that the max abs coefficient is greater than 1e-3 when alpha=0.99*alpha_max.

We could also add a test that the computation still works without sample weights.

My feeling is that test_enet_cv_sample_weight_consistency basically guarantees this already.
Let me know your thoughts.

@s-banach
Copy link
Contributor Author

s-banach commented Jul 2, 2022

The main thing I'm confused about, is why it's even possible for _alpha_grid to be called before X and y are appropriately scaled by sample_weight. It seems that enet_path or LinearModelCV or something should be refactored such that the call to _preprocess_data can be removed from _alpha_grid.

@TomDLT
Copy link
Member

TomDLT commented Jul 5, 2022

My feeling is that test_enet_cv_sample_weight_consistency basically guarantees this already.

I don't think it guarantees that the alpha_max computation is correct. To do so, we could add a @pytest.mark.parametrize("sample_weight", [[10, 1, 10, 1], None]) to the new test.

The main thing I'm confused about, is why it's even possible for _alpha_grid to be called before X and y are appropriately scaled by sample_weight.

It seems weird indeed. It seems that _pre_fit is called either before _alpha_grid in enet_path, or after _alpha_grid in LinearModelCV.fit (in _path_residuals). We should clarify the situation.

@s-banach
Copy link
Contributor Author

s-banach commented Jul 6, 2022

Per your suggestion, I parameterized the new test by sample_weight.

I know my opinion on this matter isn't very valuable, but I'll share anyway.
As you say, _alpha_grid can currently be called from two contexts: from a path method such as enet_path, or from a LinearModelCV.
Currently, LinearModelCV works by finding the best alpha using CV, then feeding that alpha into the non-CV version of the estimator. Instead, I think LinearModelCV should begin by using path to fit the full dataset. This will allow the user to see the full coef_path after fitting the model, and it may even improve the total runtime due to warm starting.
If this change is made, then _alpha_grid will only ever be called from within a path method. Thus _alpha_grid will not be asked to deal with sample_weight at all.

@TomDLT
Copy link
Member

TomDLT commented Jul 6, 2022

I agree it makes more sense to compute the alpha grid within the path function. We would need to have _path_residuals return the computed alphas, but this is not a problem because the function is private.

We might still need to have _alpha_grid deal with sample_weights in the sparse case.

@jeremiedbb
Copy link
Member

We won't have time to review this one before the 1.2 release. Moving it to 1.3

@jeremiedbb jeremiedbb closed this Nov 24, 2022
@jeremiedbb jeremiedbb reopened this Nov 24, 2022
@jeremiedbb jeremiedbb modified the milestones: 1.2, 1.3 Nov 24, 2022
@jeremiedbb
Copy link
Member

(didn't mean to close :/ )

@jeremiedbb jeremiedbb modified the milestones: 1.3, 1.4 Jul 6, 2023
@glemaitre glemaitre removed this from the 1.4 milestone Dec 7, 2023
@s-banach s-banach closed this by deleting the head repository Jan 23, 2024
@snath-xoc snath-xoc mentioned this pull request Jun 19, 2024
1 task
@snath-xoc snath-xoc mentioned this pull request Jul 9, 2024
1 task
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Calculation of alphas in ElasticNetCV doesn't use sample_weight
5 participants