-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
array API support for mean_gamma_deviance #29239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
array API support for mean_gamma_deviance #29239
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR. I launched the CUDA tests here:
EDIT: they pass.
Assuming they pass, LGTM. Just a suggestion to simplify the tests below:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
I resolved the conflict and enabled auto-merge
Reference Issues/PRs
towards #26024
What does this implement/fix? Explain your changes.
add array API support for
mean_gamma_deviance
Any other comments?
mean_gamma_deviance
is a special case ofmean_tweedie_deviance
where power=2 and both y_pred and y_true must be strictly positive. For this reason I have added the test casecheck_array_api_regression_metric_gamma
(because the y_true incheck_array_api_regression_metric
contains 0 and I didn't want to change a test that is so widely used). I am not sure if this is the best way to approach this, so if there are any suggestions on how to do this better I would love to know. Thanks!!