-
-
Notifications
You must be signed in to change notification settings - Fork 26.2k
ENH Array API support for confusion_matrix #30562
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH Array API support for confusion_matrix #30562
Conversation
result = confusion_matrix(y_true, y_pred) | ||
xp_result, _ = get_namespace(result) | ||
assert _is_numpy_namespace(xp_result) | ||
|
||
# Since the computation always happens with NumPy / SciPy on the CPU, this | ||
# function is expected to return an array allocated on the CPU even when it does | ||
# not match the input array's device. | ||
assert result.device == "cpu" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have adjusted this test to your suggestions from this comment, @ogrisel. But here, the test is narrower because we made the return value of confusion_matrix
to always be a numpy array on cpu.
Regarding the question how to document the return type of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect the CI failure (about the Test Library) is a false positive, as I couldn't reproduce the same error after running the tests multiple times on my local machine. Instead, I encountered a different set of errors related to CUDA. I believe the issue might be linked to how we convert sample_weight
into a NumPy array.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the updates @StefanieSenger
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your reviews, @OmarManzoor and @lesteve!
I have applied your suggestions and commented regarding the documentation.
Now, this PR looks very straightforward. :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM beside the point below.
I am not sure how to best document which scikit-learn functions and classes intentionally rely on internal array API namespace conversions (https://github.com/scikit-learn/scikit-learn/pull/30562/files#r2207423623). I think it's interesting to do it but I agree we can discuss that in a follow-up issue or PR.
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks @StefanieSenger
There are some tests failing which may need to be checked |
Hmm probably some weird interaction with #31701? |
In d1b3439 we fixed the latest failures in the most straightforward (but maybe a bit hacky) way. We also added a test that would have failed in #31701 and would have surfaced the issue. The tension comes from:
A maybe more clean approach would be to pass an additional argument |
@lesteve Thanks for the fixes. I think the current changes look fine. However the approach that you have mentioned about adding in |
I am also fine the current workaround implemented in this PR as it is well explained by the inline comment. +1 for exploring the refactoring suggested by @lesteve in a follow-up PR. |
Let's merge this then! |
Co-authored-by: Virgil Chan <virchan.math@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Loïc Estève <loic.esteve@ymail.com> Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Reference Issues/PRs
towards #26024
closes #30440 (supercedes)
This PR is an alternative, discussed in #30440. It accepts array inputs from all namespaces, converts the input arrays to numpy arrays right away to do the calculations in numpy (which is necessary for the coo_matrix at least) and returns the confusion_matrix in the same namespace as the input and on a cpu device.
That's what we had discussed. For more details see the discussions on both PRs.