ENH: Make roc_curve array API compatible #30878

Open: wants to merge 16 commits into base: main

Conversation

lithomas1
Contributor

@lithomas1 lithomas1 commented Feb 22, 2025

Reference Issues/PRs

xref #26024

What does this implement/fix? Explain your changes.

Makes roc_curve array API compatible.

Any other comments?

github-actions bot commented Feb 22, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 714b3b0. Link to the linter CI: here

@lithomas1 lithomas1 marked this pull request as ready for review March 2, 2025 03:48
Contributor

@OmarManzoor OmarManzoor left a comment

Thanks for the PR @lithomas1

I added some initial comments.

@lithomas1
Contributor Author

Thanks for the review, and sorry for the slow reply.

I addressed the device issues (an MPS run on my Intel MBP uncovered some more issues, which I also fixed).

@lithomas1 lithomas1 requested a review from OmarManzoor March 18, 2025 00:33
Contributor

@OmarManzoor OmarManzoor left a comment

Thanks for the updates @lithomas1.
This mostly looks good. However, let's consider waiting for the array-api-extra PR.

lithomas1 and others added 3 commits March 23, 2025 21:32
@@ -552,7 +552,7 @@ def isclose(
         xp=xp,
     )
     if equal_nan:
-        out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)
+        out = xp.where(xp.isnan(a) & xp.isnan(b), True, out)
Contributor Author

I think I found a bug in array-api-extra related to multi-device support.

I'm planning on upstreaming something there to fix this.
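
A minimal sketch of the pattern in the diff above, using NumPy as a single-device stand-in. The helper name `mark_both_nan_equal` is hypothetical. The explanation that `xp.asarray(True)` without a `device=` argument is created on the namespace's default device is an assumption about why the scalar form is safer; the thread does not spell out the bug.

```python
import numpy as np


def mark_both_nan_equal(a, b, out, xp=np):
    """Set `out` to True wherever both `a` and `b` are NaN.

    Hypothetical helper mirroring the `equal_nan` branch in the diff.
    Passing the Python scalar True leaves it to `xp.where` to build the
    constant alongside `out`. Assumption: `xp.asarray(True)` with no
    `device=` argument would instead be created on the default device,
    which can differ from the device holding `a`, `b` and `out`.
    """
    both_nan = xp.isnan(a) & xp.isnan(b)
    return xp.where(both_nan, True, out)


a = np.asarray([1.0, np.nan, 3.0])
b = np.asarray([1.0, np.nan, 4.0])
# Closeness computed without special NaN handling: NaN never compares close.
out = np.asarray([True, False, False])

result = mark_both_nan_equal(a, b, out)
```

On NumPy both spellings give the same result, since there is only one device; the difference would only show up with a namespace that places arrays on more than one device.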

Contributor

PR merged, please ping me when you need a release

Contributor

@OmarManzoor OmarManzoor left a comment

LGTM. Thanks @lithomas1

CC: @ogrisel @betatim for a second review

Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
-    if dtype is not None and dtype not in [np.float32, np.float64]:
-        dtype = np.float64
+    if dtype is not None and dtype not in [xp.float32, xp.float64]:
+        dtype = _max_precision_float_dtype(xp, device)
Member

Can you update the description of the dtype parameter in the docstring? I think we should change the statement that np.float64 will be the dtype of sample_weight if nothing else is specified. Maybe "The output will be the highest precision floating point dtype" or some such?

Contributor Author

I reworded this slightly to "Otherwise, the output has the highest precision floating point dtype supported by the array namespace/device of the input arrays."

Let me know if that sounds OK to you.
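
A rough sketch of the fallback that the reworded docstring describes. Both function names are hypothetical, and the rule "float32 on PyTorch's MPS device, float64 everywhere else" is an assumption for illustration, not a copy of `_max_precision_float_dtype`. The namespaces below are `SimpleNamespace` stand-ins so the sketch runs without PyTorch or an Apple GPU.

```python
from types import SimpleNamespace


def max_precision_float_dtype_sketch(xp, device):
    """Return the widest float dtype assumed to be usable on `device`."""
    # Assumption: PyTorch's "mps" backend has no float64 support, so
    # float32 is the highest precision available there.
    is_torch = xp.__name__.endswith("torch")
    if is_torch and str(device).startswith("mps"):
        return xp.float32
    return xp.float64


def resolve_dtype(dtype, xp, device):
    """Mirror the check in the diff: keep float32/float64, else fall back."""
    if dtype is not None and dtype not in [xp.float32, xp.float64]:
        return max_precision_float_dtype_sketch(xp, device)
    return dtype


# Stand-in namespace: dtype attributes are plain strings, not real dtypes.
fake_torch = SimpleNamespace(
    __name__="torch", float32="float32", float64="float64", int64="int64"
)

on_mps = resolve_dtype(fake_torch.int64, fake_torch, "mps:0")
on_cpu = resolve_dtype(fake_torch.int64, fake_torch, "cpu")
kept = resolve_dtype(fake_torch.float32, fake_torch, "cpu")
```

Under these assumptions an integer dtype request resolves to float32 on `mps:0` and to float64 on `cpu`, while an explicit float32 request is left alone.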

or xp.all(classes == xp.asarray([-1, 1], device=device))
or xp.all(classes == xp.asarray([0], device=device))
or xp.all(classes == xp.asarray([-1], device=device))
or xp.all(classes == xp.asarray([1], device=device))
):
classes_repr = ", ".join([repr(c) for c in classes.tolist()])
Member

Do you know why .tolist() works here? I don't think it is part of the array API standard, and classes is an array API array, which is why I am surprised that this works. Are we just getting lucky, or does it make sense that this works?

Member

I'd convert classes to a NumPy array and then use it here. It should be short, so it is cheap to convert with _convert_to_numpy from sklearn.utils._array_api.

Contributor Author

Thanks for the catch. My guess is that we only hit this error condition with NumPy arrays.
(I don't think the array API checks test any error cases, and this block raises a ValueError).
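
A minimal sketch of the suggestion above: convert the (short) `classes` array to NumPy before calling `.tolist()`, so the error message does not depend on a method outside the array API standard. The helper name `format_classes` is hypothetical, and plain `np.asarray` is used here as a stand-in for `_convert_to_numpy`; it is an assumption that the input is something `np.asarray` can read directly (a NumPy array or a CPU-resident array), since arrays on other devices would first need an explicit transfer to host memory.

```python
import numpy as np


def format_classes(classes):
    """Build a readable, comma-separated list of class labels.

    Hypothetical helper. Converting to NumPy first means `.tolist()` is
    called on an ndarray, where the method is guaranteed to exist.
    """
    classes_np = np.asarray(classes)
    # tolist() yields plain Python scalars, so repr() gives "0" rather
    # than a NumPy scalar representation.
    return ", ".join(repr(c) for c in classes_np.tolist())


message = format_classes(np.asarray([0, 2, 5]))
```

Because `classes` holds only the unique labels, the conversion cost should be negligible next to the rest of the computation.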

@lithomas1 lithomas1 requested a review from betatim April 11, 2025 14:50