ENH: Make roc_curve array API compatible #30878


Merged: 33 commits, Jun 18, 2025

Conversation

lithomas1 (Contributor) commented Feb 22, 2025:

Reference Issues/PRs

xref #26024

What does this implement/fix? Explain your changes.

Makes roc_curve array API compatible.
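To illustrate what this compatibility enables, here is a minimal sketch (not code from the PR itself): with scikit-learn's `array_api_dispatch` config flag enabled, the same `roc_curve` call is meant to accept other array API namespaces such as torch tensors; plain NumPy is shown here since it works either way.

```python
import numpy as np
from sklearn.metrics import roc_curve

# Minimal sketch with NumPy inputs. Under
# sklearn.set_config(array_api_dispatch=True), the same call is intended
# to also accept other array API namespaces (e.g. torch tensors).
y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])
fpr, tpr, thresholds = roc_curve(y_true, y_score)
```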

Any other comments?

github-actions bot commented Feb 22, 2025:

✔️ Linting Passed. All linting checks passed (generated for commit dbdfbd5).

lithomas1 marked this pull request as ready for review on March 2, 2025 at 03:48.
OmarManzoor (Contributor) left a comment:

Thanks for the PR @lithomas1

I added some initial comments.

lithomas1 (Contributor, Author) replied:
Thanks for the review, and sorry for the slow reply.

I addressed the device issues (an MPS run on my Intel MBP uncovered some more issues, which I fixed).

lithomas1 requested a review from OmarManzoor on March 18, 2025 at 00:33.
OmarManzoor (Contributor) left a comment:

Thanks for the updates @lithomas1.
This mostly looks good; however, let's consider waiting for the array-api-extra PR.

lithomas1 and others added 3 commits March 23, 2025 21:32
OmarManzoor (Contributor) left a comment:

LGTM. Thanks @lithomas1

CC: @ogrisel @betatim for a second review

Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Comment on lines 1132 to 1137:

```python
arr_np = np.asarray(
    [[1, 2e-9, 3e-9] * int(1e6)],
)
arr_xp = xp.asarray(
    arr_np, dtype=getattr(xp, dtype) if dtype is not None else dtype
)
```
A Member commented:

Why not do this instead (replacing the snippet above)?

Suggested change:

```python
arr_np = np.asarray(
    [[1, 2e-9, 3e-9] * int(1e6)], dtype=dtype
)
arr_xp = xp.asarray(arr_np)
```

?
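As an aside on the `getattr(xp, dtype)` pattern in the original snippet: the test parametrizes `dtype` as a string (or None), which must be resolved to the namespace's dtype object. A minimal, hypothetical sketch of that resolution, with NumPy standing in for `xp`:

```python
import numpy as np

# Hypothetical sketch: dtype arrives as a string parameter (e.g. "float32")
# or None, and getattr on the array namespace resolves it to a dtype object.
xp = np
dtype = "float32"
resolved = getattr(xp, dtype) if dtype is not None else None
arr = xp.asarray([1.0, 2e-9, 3e-9], dtype=resolved)
```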

Comment on lines 1139 to 1142:

```python
assert_allclose(
    _convert_to_numpy(stable_cumsum(arr_xp, axis=axis), xp),
    np.cumsum(arr_np, axis=axis),
)
```
A Member commented:

Brain wave just now: maybe instead of creating arr_np and using _convert_to_numpy we could use xpx.isclose? One reason to have array-api-extra is to slowly start using what it has to offer instead of having our own way. Though I think we have to call .all() on the result of xpx.isclose, because it returns an array rather than a single bool? @lucascolley
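A minimal sketch of the `.all()` point, with NumPy's `isclose` standing in for `xpx.isclose` (the behavior in question is the same: an element-wise boolean array comes back, not a single bool):

```python
import numpy as np

# isclose-style functions return an element-wise boolean array, so a test
# assertion needs a reduction such as .all() before it can be truth-tested.
a = np.asarray([1.0, 2e-9, 3e-9])
b = a + 1e-12
elementwise = np.isclose(a, b)  # boolean array, one entry per element
ok = bool(elementwise.all())    # reduce to a single bool before asserting
```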

A Contributor replied:

You could, but then you would have to implement all of the extra features of assert_allclose on top.

We are looking to expose the following as xpx.testing.assert_close when it has everything we need and the public API is agreed upon: https://github.com/data-apis/array-api-extra/blob/main/src/array_api_extra/_lib/_testing.py#L213-L276. It converts to NumPy and uses np.testing at the minute, but feasibly it could use things from other backends instead down the line.

It would be really useful if someone could bump xpx to v0.8.0 in sklearn and try using those private functions, to see what is missing.

A Member replied:

Ah, OK, I thought that xpx.isclose was the replacement for assert_allclose :-/

lucascolley (Contributor) commented Jun 5, 2025:

No, xpx.isclose covers numpy.isclose (https://numpy.org/doc/2.1/reference/generated/numpy.isclose.html), which is used outside of testing. Of course it can be used in tests, but the assertion helpers are more feature-complete.

A Contributor replied:

> It would be really useful if someone could bump xpx to v0.8.0 in sklearn and try using those private functions, to see what is missing.

x-ref data-apis/array-api-extra#17, please chime in there if anyone tries this out!

lithomas1 (Contributor, Author) commented:

Ugh, it looks like something weird is going on with precision for CuPy.

Is it OK to just skip the test for CuPy? This isn't ideal, but we already have confirmation that the check works on GPU for torch. Maybe there's also a test we could steal from cuML?

OmarManzoor (Contributor) replied:

> Ugh, it looks like something weird is going on with precision for CuPy. Is it OK to just skip the test for CuPy? Maybe there's also a test we could steal from cuML?

We weren't testing for this warning before, or were we?

lithomas1 (Contributor, Author) replied:

Just for numpy, since stable_cumsum wasn't array API compatible before.
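For context on why `stable_cumsum` exists, here is a minimal, hypothetical reproduction of the precision loss a naive low-precision cumulative sum suffers. float16 is used so the effect shows up at a small size; the same thing happens with float32 at larger scales:

```python
import numpy as np

# Summing 5000 ones in float16: once the running total reaches 2048,
# the spacing between adjacent float16 values is 2, so adding 1.0 no
# longer changes the accumulator and the cumsum stalls.
arr = np.ones(5000, dtype=np.float16)
naive = np.cumsum(arr)                      # stays in float16 and stalls
stable = np.cumsum(arr.astype(np.float64))  # upcast first, then accumulate
```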

OmarManzoor (Contributor) commented Jun 5, 2025:

> Just for numpy, since stable_cumsum wasn't array API compatible before.

From my experimentation, this warning isn't too common on CUDA, and it seemed hard to reproduce. Do you think we should test it only under specific conditions, and perhaps not for CuPy or torch?

Also I think some other opinions might be valuable. I'll also tag @ogrisel for his input.

OmarManzoor (Contributor) commented:

@lithomas1 Let's do this for stable_cumsum: keep the original numpy-only test, which also checks the warning, and for the array API test just check that the function's outputs match numpy's.

ogrisel (Member) commented Jun 12, 2025:

> Also I think some other opinions might be valuable. I'll also tag @ogrisel for his input.

Thanks for the ping. I replied in the original thread above: #30878 (comment)

For the particular case of this PR, I think we should revert the changes under sklearn/utils and just call xp.cumsum directly in the code that computes the ROC curve.

OmarManzoor (Contributor) replied:

> For the particular case of this PR, I think we should revert the changes under sklearn/utils and just call xp.cumsum directly in the code that computes the ROC curve.

That sounds good. Thanks.

ogrisel (Member) commented Jun 13, 2025:

As shown by @lesteve in #31533 (comment), casting to xp.float64 before computing the cumsum can be required to get roc_auc_score to return the correct value on some dataset. This PR should add a non-regression test based on Loïc's comment if it's not already present in our test suite.
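A minimal sketch of the effect the comment points at (hypothetical data, not the dataset from #31533): in float32, small per-sample increments are lost once the running total dwarfs them, which skews the cumulative counts roc_auc_score is built from; casting to float64 before the cumulative sum preserves them.

```python
import numpy as np

# In float32, 1e-8 is below half the spacing between floats near 1.0
# (~6e-8), so every increment after the first element is silently dropped.
scores = np.asarray([1.0] + [1e-8] * 1000, dtype=np.float32)
naive = np.cumsum(scores)                    # increments vanish after 1.0
safe = np.cumsum(scores.astype(np.float64))  # increments are preserved
```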

OmarManzoor (Contributor) commented:

@ogrisel I made the updates that were discussed. I did add a non-regression test, but since it needs a large number of samples it takes a bit of time to run. Could you kindly have a look at the changes?

OmarManzoor (Contributor) commented:

The failed test is due to the array allocation being too large, because of the large number of samples.

OmarManzoor (Contributor) commented:

@lesteve Do you have any suggestions on how we can keep this test?

ogrisel (Member) left a comment:

Given the cost to maintain a proper non-regression test, I think an inline comment will do.

OmarManzoor and others added 4 commits June 18, 2025 18:05
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
OmarManzoor (Contributor) commented:

Since the suggestions for this PR have been resolved, let's merge it. Thank you for the work, @lithomas1!

OmarManzoor merged commit dab0842 into scikit-learn:main on Jun 18, 2025.
36 checks passed
jeremiedbb mentioned this pull request on Jul 15, 2025.