Skip to content

Conversation

StefanieSenger
Copy link
Member

@StefanieSenger StefanieSenger commented Dec 30, 2024

Reference Issues/PRs

towards #26024
closes #30440 (supercedes)

This PR is an alternative, discussed in #30440. It accepts array inputs from all namespaces, converts the input arrays to numpy arrays right away to do the calculations in numpy (which is necessary for the coo_matrix at least) and returns the confusion_matrix in the same namespace as the input and on a cpu device.

That's what we had discussed. For more details see the discussions on both PRs.

Copy link

github-actions bot commented Dec 30, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 5b103e7. Link to the linter CI: here

Comment on lines 3120 to 3127
result = confusion_matrix(y_true, y_pred)
xp_result, _ = get_namespace(result)
assert _is_numpy_namespace(xp_result)

# Since the computation always happens with NumPy / SciPy on the CPU, this
# function is expected to return an array allocated on the CPU even when it does
# not match the input array's device.
assert result.device == "cpu"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have adjusted this test to your suggestions from this comment, @ogrisel. But here, the test is narrower because we made the return value of confusion_matrix to always be a numpy array on cpu.

@StefanieSenger
Copy link
Member Author

Regarding the question how to document the return type of confusion_matrix() as a numpy array, I think that keeping
C : ndarray of shape (n_classes, n_classes) in the docstring should be enough, assumed that all the other functions and methods where we have added array api support document the return value type correctly, which is currently not the case.

Copy link
Member

@virchan virchan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect the CI failure (about the Test Library) is a false positive, as I couldn't reproduce the same error after running the tests multiple times on my local machine. Instead, I encountered a different set of errors related to CUDA. I believe the issue might be linked to how we convert sample_weight into a NumPy array.

@StefanieSenger StefanieSenger changed the title ENH Array API support for confusion_matrix converting to numpy array ENH Array API support for confusion_matrix Jul 15, 2025
Copy link
Contributor

@OmarManzoor OmarManzoor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the updates @StefanieSenger

Copy link
Member Author

@StefanieSenger StefanieSenger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your reviews, @OmarManzoor and @lesteve!
I have applied your suggestions and commented regarding the documentation.

Now, this PR looks very straightforward. :)

Copy link
Member

@lesteve lesteve left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM beside the point below.

I am not sure how to best document which scikit-learn functions and classes intentionally rely on internal array API namespace conversions (https://github.com/scikit-learn/scikit-learn/pull/30562/files#r2207423623). I think it's interesting to do it but I agree we can discuss that in a follow-up issue or PR.

Copy link
Contributor

@OmarManzoor OmarManzoor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks @StefanieSenger

@OmarManzoor
Copy link
Contributor

There are some tests failing which may need to be checked

@lesteve
Copy link
Member

lesteve commented Aug 1, 2025

Hmm probably some weird interaction with #31701?

@lesteve
Copy link
Member

lesteve commented Aug 1, 2025

In d1b3439 we fixed the latest failures in the most straightforward (but maybe a bit hacky) way. We also added a test that would have failed in #31701 and would have surfaced the issue.

The tension comes from:

A maybe more clean approach would be to pass an additional argument ensure_min_samples in _check_targets that gets passed to _check_sample_weight and eventually to ensure that ensure_min_samples=ensure_min_samples is passed to the check_array(sample_weight) call. This seems a bit too much since it seems like confusion_matrix is a bit special in being able to handle empty inputs?

@OmarManzoor
Copy link
Contributor

@lesteve Thanks for the fixes. I think the current changes look fine. However the approach that you have mentioned about adding in ensure_min_samples also sounds good. Wouldn't that approach be better considering that such a parameter is actually available in the check_array function?

@ogrisel
Copy link
Member

ogrisel commented Aug 4, 2025

I am also fine the current workaround implemented in this PR as it is well explained by the inline comment.

+1 for exploring the refactoring suggested by @lesteve in a follow-up PR.

@OmarManzoor
Copy link
Contributor

OmarManzoor commented Aug 4, 2025

Let's merge this then!

@OmarManzoor OmarManzoor enabled auto-merge (squash) August 4, 2025 08:22
@OmarManzoor OmarManzoor merged commit 1ff785e into scikit-learn:main Aug 4, 2025
34 checks passed
@StefanieSenger StefanieSenger deleted the array_api_confusion_matrix_numpy branch August 4, 2025 11:15
lucyleeow pushed a commit to lucyleeow/scikit-learn that referenced this pull request Aug 22, 2025
Co-authored-by: Virgil Chan <virchan.math@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants