-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
ENH Use Array API in r2_score
#27904
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Co-authored-by: Tim Head <betatim@gmail.com>
Co-authored-by: Tim Head <betatim@gmail.com>
Some Array API compatible libraries do not have a device called 'cpu'. Instead we try and detect the lib+device combination that does not support float64.
…test_affinity_propagation` (scikit-learn#27095) Signed-off-by: Julien Jerphanion <git@jjerphan.xyz> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
…nto ENH/r2_score_array_api
@betatim @fcharras @adrinjalali I think this is ready for review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some comments. Haven't looked at the tests yet but the rest looks nice.
a = xp.astype(a, output_dtype) | ||
|
||
if weights is None: | ||
return (xp.mean if normalize else xp.sum)(a, axis=axis) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
😮
Kinda cool that we can do this in Python, but also strong stuff :D
Co-authored-by: Tim Head <betatim@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice. Other than the nits, this looks quite good to me now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @ogrisel
Thanks for the reviews @adrinjalali and @betatim. This PR has been much simplified as a result of those reviews. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
What a mission this PR was! Thanks everyone who helped, I think it was worth the effort and wait :D |
Thanks everyone for continuing this PR, I now caught up with latest diff and I'm also a happy bunny. I want to mention 2 differences I think I've spotted between the state where I had left the branch and what has been merged:
I'm happy that I got to learn the existence of We did make very conservative choices on this PR initially and in the end that was a source of several iterations, I'll try to be better at thinking at what is actually needed in scikit-learn and the cost in complexity. As it has been pointed out already, I suspect some of the tools that we introduced then dropped during the PR (like Last word, we had initiated documenting the policy when dealing with array api dispatch with no float64 support at #28034 , now I'm a bit unsure if #27904 (comment) had everyone aligned or if it moved to something a bit different in this PR, I'll try to sum up again and update it. |
Reference Issues/PRs
The PR builds on preliminary explorations done by @elindgren in #27102
It tackles one of the items outlined in #26024.
Any other comments?
This PR proposes to fallbacks to cpu+numpy at the very beginning of the r2_score function whenever the array namespace and the device can't handle float64 precision, because explicit castings to float64 are unavoidable and are used in a lot of steps.
It also proposes improved ways to detect device support for dtypes, and uses it to act accordingly in r2_score and _average, but also updates weighted_sum function.