ENH migrate GLMs / TweedieRegressor to linear loss #22548
Conversation
Some hints for reviewers.
class HalfTweedieLossIdentity(BaseLoss):
    """Half Tweedie deviance loss with identity link, for regression.
This new loss class is needed for `TweedieRegressor(link="identity")`. See the Cython implementation in file _loss.pyx.tp.
sklearn/linear_model/_glm/glm.py
family : {'normal', 'poisson', 'gamma', 'inverse-gaussian'} \
        or an ExponentialDispersionModel instance, default='normal'
    The distributional assumption of the GLM, i.e. which distribution from
    the EDM, specifies the loss function to be minimized.

base_loss_class : subclass of BaseLoss, default=HalfSquaredError
    A `base_loss_class` contains a specific loss function as well as the link
    function. The loss to be minimized specifies the distributional assumption
    of the GLM, i.e. the distribution from the EDM. Here are some examples:

    ======================= ======== ==========================
    base_loss_class         Link     Target Domain
    ======================= ======== ==========================
    HalfSquaredError        identity y any real number
    HalfPoissonLoss         log      0 <= y
    HalfGammaLoss           log      0 < y
    HalfInverseGaussianLoss log      0 < y
    HalfTweedieLoss         log      depends on Tweedie power
    ======================= ======== ==========================

link : {'auto', 'identity', 'log'} or an instance of class BaseLink, \
        default='auto'
    The link function of the GLM, i.e. mapping from linear predictor
    `X @ coeff + intercept` to prediction `y_pred`. For instance, with a log
    link, we have `y_pred = exp(X @ coeff + intercept)`. Option 'auto' sets
    the link depending on the chosen family as follows:

    - 'identity' for Normal distribution
    - 'log' for Poisson, Gamma and Inverse Gaussian distributions

base_loss_params : dictionary, default={}
    Arguments to be passed to base_loss_class, e.g. {"power": 1.5} with
    `base_loss_class=HalfTweedieLoss`.
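The log-link mapping described in the docstring can be illustrated with plain NumPy (an illustrative sketch; the `coef` and `intercept` values here are made up, not scikit-learn internals):

```python
import numpy as np

# Illustrative only: made-up coefficients and data.
rng = np.random.default_rng(0)
X = rng.normal(size=(5, 2))
coef = np.array([0.3, -0.2])
intercept = 0.1

raw = X @ coef + intercept   # linear predictor
y_pred = np.exp(raw)         # log link: y_pred = exp(X @ coef + intercept)

# The log link guarantees strictly positive predictions, matching
# e.g. the Poisson and Gamma target domains.
assert np.all(y_pred > 0)
assert np.allclose(np.log(y_pred), raw)
```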
`family` and `link` were private attributes. They are replaced by `base_loss_class` and `base_loss_params`. The new losses have the link functions baked into them.
sklearn/linear_model/_glm/glm.py
self._linear_loss = LinearModelLoss(
    base_loss=self._get_base_loss_instance(),
    fit_intercept=self.fit_intercept,
)
This is the reason for this PR: to use `LinearModelLoss`!
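What a `LinearModelLoss`-style object computes can be sketched in a few lines (a toy stand-in using half squared error with identity link, not the actual scikit-learn class):

```python
import numpy as np

def linear_model_half_squared_loss(coef, X, y, fit_intercept=True):
    """Toy stand-in for what LinearModelLoss provides: the base loss
    evaluated on the linear predictor, plus its gradient w.r.t. coef.
    Base loss here: half squared error with identity link."""
    if fit_intercept:
        w, b = coef[:-1], coef[-1]
    else:
        w, b = coef, 0.0
    raw = X @ w + b                      # linear predictor
    loss = 0.5 * np.mean((raw - y) ** 2)
    grad_raw = (raw - y) / y.shape[0]    # d loss / d raw
    grad = X.T @ grad_raw                # chain rule through X
    if fit_intercept:
        grad = np.append(grad, grad_raw.sum())
    return loss, grad
```

Having loss and gradient in one object is what makes gradient-based solvers like L-BFGS (and potential second-order solvers) pluggable.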
@pytest.mark.parametrize(
    "name, link_class", [("identity", IdentityLink), ("log", LogLink)]
)
def test_tweedie_link_argument(name, link_class):
Replaces former `test_glm_link_argument`.
        (3, LogLink),  # inverse-gaussian
    ],
)
def test_tweedie_link_auto(power, expected_link_class):
Replaces former `test_glm_link_auto`.
    target_type=numbers.Real,
)

message = f"Mean Tweedie deviance error with power={p} can only be used on "
From here on, it is really copy & paste from the now (edit: in two releases, to be) deleted glm_distribution.py.
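For reference, the Tweedie deviance formula being moved around here can be written down in a few lines of NumPy (a sketch of the formula, not the scikit-learn implementation):

```python
import numpy as np

def mean_tweedie_deviance_sketch(y, y_pred, power=0.0):
    """Sketch of the mean Tweedie deviance formula (not sklearn's code).

    power=0 is squared error, power=1 Poisson deviance, power=2 Gamma
    deviance; other powers use the general Tweedie formula.
    """
    y = np.asarray(y, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    p = power
    if p == 0:
        dev = (y - y_pred) ** 2
    elif p == 1:  # requires y >= 0, y_pred > 0
        # y * log(y / y_pred) with the convention 0 * log(0) = 0
        ylogy = y * np.log(np.where(y > 0, y / y_pred, 1.0))
        dev = 2 * (ylogy - y + y_pred)
    elif p == 2:  # requires y > 0, y_pred > 0
        dev = 2 * (np.log(y_pred / y) + y / y_pred - 1)
    else:
        dev = 2 * (
            np.maximum(y, 0) ** (2 - p) / ((1 - p) * (2 - p))
            - y * y_pred ** (1 - p) / (1 - p)
            + y_pred ** (2 - p) / (2 - p)
        )
    return dev.mean()
```

In every case the deviance vanishes when `y_pred == y`, which is a handy sanity check on the formula.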
This might interest @agramfort @rth @TomDLT.
What changed your mind?
I changed my mind, too. Let's do a proper deprecation and hopefully move forward with this. This PR is the last missing piece for possible new second-order solvers for all/most GLMs!
LGTM, thank you very much for the PR.
I opened a PR to your fork to showcase a way to avoid `_base_loss`: lorentzenchr#3 (CLN Idea removing `_base_loss`).
I double-checked the math for `HalfTweedieLossIdentity` and it looks good to me. I have a few small comments on the code. There is an issue with `d2_tweedie_score` in terms of memory.
Small refactor for _mean_tweedie_deviance
…tzenchr/scikit-learn into migrate_glm_to_linear_loss
I ran the benchmarks from #22548 (review) and see the same performance improvements.
LGTM
@@ -770,6 +771,52 @@ def constant_to_optimal_zero(self, y_true, sample_weight=None):
    return term


class HalfTweedieLossIdentity(BaseLoss):
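As a cross-check, the half Tweedie deviance with identity link and its gradient w.r.t. the raw prediction can be sketched in NumPy (an illustration, not the Cython code; assumptions: y > 0, raw > 0, power strictly between 1 and 2):

```python
import numpy as np

def half_tweedie_loss_identity(raw, y, power=1.5):
    """Half Tweedie deviance with identity link: y_pred == raw, so no
    term is constant in the raw prediction and all terms are kept."""
    p = power
    return (
        y ** (2 - p) / ((1 - p) * (2 - p))
        - y * raw ** (1 - p) / (1 - p)
        + raw ** (2 - p) / (2 - p)
    )

def half_tweedie_grad_identity(raw, y, power=1.5):
    # d/d(raw) of the loss above simplifies to raw**(-p) * (raw - y),
    # which vanishes exactly at raw == y (the loss minimum).
    return raw ** (-power) * (raw - y)
```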
My curiosity: when does it make sense to use identity link with power != 0?
Years ago, I thought it a good idea. Meanwhile, I don't think it's useful. Therefore, I opened #19086 without much response.
My curiosity: when does it make sense to use identity link with power != 0?
I have a problem where I know the expectation y' follows the linear model y' = w x. My measurements, y, have Poisson errors.
(The specific problem involves analysis of radiation measurements. The expectation is linear in the amount of source; the measurements are Poisson distributed.)
Using a log link function is just not the right description of my problem. Yes, the whole thing breaks down when evaluating negative values of w, but it seems much better to offer a constraint to avoid ever evaluating negative values of w rather than to exclude the situations where you have an actual linear relationship with Poisson measurements.
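The use case above even has a neat closed form: for the one-parameter identity-link Poisson model E[y] = w * x with all x > 0, setting the derivative of the Poisson negative log-likelihood to zero gives w = sum(y) / sum(x). A small simulation on made-up data (a sketch, not the PR's code) confirms it:

```python
import numpy as np

def fit_identity_poisson(x, y):
    """MLE for E[y] = w * x with Poisson errors and identity link.

    The negative log-likelihood sum(w*x - y*log(w*x)) has derivative
    sum(x) - sum(y)/w, so the optimum is w = sum(y) / sum(x)
    (assuming all x > 0, which keeps the mean w*x positive for w > 0).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    assert np.all(x > 0)
    return y.sum() / x.sum()

# Simulated radiation-style data: linear expectation, Poisson noise.
rng = np.random.default_rng(42)
x = rng.uniform(1.0, 5.0, size=1000)
w_true = 2.0
y = rng.poisson(w_true * x)
w_hat = fit_identity_poisson(x, y)  # close to w_true for large samples
```

With several covariates there is no closed form anymore, which is where a constrained identity-link Poisson GLM would come in.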
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Reference Issues/PRs
This is a follow-up of #21808 and #20567.
It also fixes #22124 (~partial fix of #21406).
What does this implement/fix? Explain your changes.
This PR plugs the new `LinearModelLoss` into the private `GeneralizedLinearRegressor`, thereby removing `sklearn/_loss/glm_distribution.py` and `sklearn/linear_model/_glm/link.py`. The Tweedie deviance code is copy & pasted into the metric `mean_tweedie_deviance`.
Any other comments?
It should be a user-API backward compatible change (`PoissonRegressor`, `GammaRegressor` and `TweedieRegressor`, `mean_tweedie_deviance`).