ENH: Make roc_curve array API compatible #30878

Merged: 33 commits merged into scikit-learn:main on Jun 18, 2025

Conversation

@lithomas1 (Contributor) commented Feb 22, 2025

Reference Issues/PRs

xref #26024

What does this implement/fix? Explain your changes.

Makes roc_curve array API compatible.

Any other comments?
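For illustration, a minimal sketch of the intended usage after this change (torch is just one array API compatible backend; any compatible namespace should work):

import torch
from sklearn import config_context
from sklearn.metrics import roc_curve

y_true = torch.tensor([0, 0, 1, 1])
y_score = torch.tensor([0.1, 0.4, 0.35, 0.8])

# with array API dispatch enabled, roc_curve operates on the torch
# tensors directly instead of converting them to numpy arrays
with config_context(array_api_dispatch=True):
    fpr, tpr, thresholds = roc_curve(y_true, y_score)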

github-actions bot commented Feb 22, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: dbdfbd5.

@lithomas1 lithomas1 marked this pull request as ready for review March 2, 2025 03:48
@OmarManzoor (Contributor) left a comment

Thanks for the PR @lithomas1

I added some initial comments.

@lithomas1 (Contributor, Author) commented

Thanks for the review, and sorry for the slow reply.

I addressed the device issues (an MPS run on my Intel MBP uncovered some more issues, which I fixed).

@lithomas1 lithomas1 requested a review from OmarManzoor March 18, 2025 00:33
@OmarManzoor (Contributor) left a comment

Thanks for the updates @lithomas1.
This mostly looks good; however, let's consider waiting for the array-api-extra PR.

lithomas1 and others added 3 commits March 23, 2025 21:32
@OmarManzoor (Contributor) left a comment

LGTM. Thanks @lithomas1

CC: @ogrisel @betatim for a second review

Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Comment on lines 1259 to 1267
if not xp.all(
    xpx.isclose(
        xp.take(out, xp.asarray([last_elem_idx], device=device), axis=axis),
        expected,
        rtol=rtol,
        atol=atol,
        equal_nan=True,
    )
):
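    # (excerpt ends here; the body of this `if` is where stable_cumsum
    # emits its "cumsum was found to be unstable" RuntimeWarning)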
A member left a comment

@lithomas1 it seems that this part is not working (the CI failure shows no warning emitted for cupy).

A member left a comment

I'm at a loss as to what is causing this failure. The docs do talk about automatic type promotion:

> CuPy automatically promotes dtypes of cupy.ndarrays in a function with two or more operands

but the failure occurs even when dtype is float64.

Maybe others can shed some light?

@OmarManzoor (Contributor) commented Jun 5, 2025

I'm not sure whether we want to test it, because this warning is not triggered in general. However, I did push a modification to check whether that works.

A contributor left a comment

It still failed in one case. @betatim @lesteve, what do you think? This warning is not easily triggered.

A member left a comment

I think I would be in favor of not calling scikit-learn's stable_cumsum at all in functions with array API support, and instead directly call xp.cumsum.

Once the whole code base has been migrated away from stable_cumsum, we can deprecate this function. The extra call to np/xp.sum adds unnecessary overhead with no value for the end user: if the cumsum result happens to be unstable, users cannot do anything about it, so I don't really see the point of this warning.
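To illustrate the overhead being discussed, here is a simplified sketch of what stable_cumsum does (condensed from sklearn.utils.extmath and restricted to 1-D input; not the exact source):

import warnings
import numpy as np

def stable_cumsum_sketch(arr, rtol=1e-05, atol=1e-08):
    out = np.cumsum(arr, dtype=np.float64)
    # the stability check costs a second full reduction over the data
    expected = np.sum(arr, dtype=np.float64)
    if not np.isclose(out[-1], expected, rtol=rtol, atol=atol):
        warnings.warn(
            "cumsum was found to be unstable: its last element does not "
            "correspond to sum",
            RuntimeWarning,
        )
    return out

# the proposal is for array API code paths to call xp.cumsum(arr) directly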

@ogrisel (Member) commented Jun 12, 2025

I opened #31533 to recommend consistently stopping the use of stable_cumsum in our code base. Feel free to express your opinion there.

A contributor left a comment

In order to merge this PR, should we remove the warning code altogether, or should we keep the code but stop testing for it?

@ogrisel (Member) commented Jun 12, 2025

For this PR, I would rather leave stable_cumsum unchanged (revert it to what it was in main) and simply not use it in the ROC curve code, calling xp.cumsum directly instead.

Then later, if people agree with the proposal in #31533, we can remove all the other calls to stable_cumsum in other parts of the scikit-learn code base and deprecate stable_cumsum officially (because it is part of our public API and we cannot just delete it).

Comment on lines 1132 to 1137
arr_np = np.asarray(
    [[1, 2e-9, 3e-9] * int(1e6)],
)
arr_xp = xp.asarray(
    arr_np, dtype=getattr(xp, dtype) if dtype is not None else dtype
)
A member left a comment

Why not do:

Suggested change

# before:
arr_np = np.asarray(
    [[1, 2e-9, 3e-9] * int(1e6)],
)
arr_xp = xp.asarray(
    arr_np, dtype=getattr(xp, dtype) if dtype is not None else dtype
)

# after:
arr_np = np.asarray(
    [[1, 2e-9, 3e-9] * int(1e6)], dtype=dtype
)
arr_xp = xp.asarray(arr_np)

?

Comment on lines 1139 to 1142
assert_allclose(
    _convert_to_numpy(stable_cumsum(arr_xp, axis=axis), xp),
    np.cumsum(arr_np, axis=axis),
)
A member left a comment

Brain wave just now: maybe instead of creating arr_np and using _convert_to_numpy, we could use xpx.isclose? One reason to have array-api-extra is to slowly start using what it offers instead of rolling our own. Though I think we have to call .all() on the result of xpx.isclose, because it returns an array, not a single bool? @lucascolley

A contributor left a comment

You could, but then you would have to implement all of the extra features of assert_allclose on top.

We are looking to expose the following as xpx.testing.assert_close when it has everything we need and the public API is agreed upon: https://github.com/data-apis/array-api-extra/blob/main/src/array_api_extra/_lib/_testing.py#L213-L276. It converts to NumPy and uses np.testing at the minute, but feasibly it could use things from other backends instead down the line.

It would be really useful if someone could bump xpx to v0.8.0 in sklearn and try using those private functions, to see what is missing.

A member left a comment

Ah OK, I thought that xpx.isclose was the replacement for assert_allclose :-/

@lucascolley (Contributor) commented Jun 5, 2025

No, xpx.isclose covers https://numpy.org/doc/2.1/reference/generated/numpy.isclose.html, which is used outside of testing (of course it can be used in tests, but the assertions are more feature-complete).
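For instance, the isclose-plus-reduction pattern looks like this (a minimal sketch, assuming array-api-strict as the namespace and array-api-extra installed):

import array_api_strict as xp
import array_api_extra as xpx

a = xp.asarray([1.0, 2.0])
b = xp.asarray([1.0, 2.0 + 1e-12])

# xpx.isclose returns a boolean array, so reduce it to a single bool
assert xp.all(xpx.isclose(a, b))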

A contributor left a comment

> It would be really useful if someone could bump xpx to v0.8.0 in sklearn and try using those private functions, to see what is missing.

x-ref data-apis/array-api-extra#17, please chime in there if anyone tries this out!

@lithomas1 (Contributor, Author) commented

Ugh, looks like something weird is going on with precision for cupy.

Is it OK to just skip the test for cupy? This isn't ideal, but we already have confirmation that the check works on GPU for torch. Maybe there's also a test we could borrow from cuml?

@OmarManzoor (Contributor) commented

> Ugh, looks like something weird is going on with precision for cupy.
>
> Is it OK to just skip the test for cupy? This isn't ideal, but we already have confirmation that the check works on GPU for torch. Maybe there's also a test we could borrow from cuml?

We weren't testing for this warning before, or were we?

@lithomas1 (Contributor, Author) commented

Just for numpy, since stable_cumsum wasn't array API compatible before.

@OmarManzoor (Contributor) commented Jun 5, 2025

> Just for numpy, since stable_cumsum wasn't array API compatible before.

From my experimentation, this warning isn't common on CUDA and seems hard to reproduce. Do you think we should test it only under specific conditions, and maybe not for cupy or torch?

Also, I think some other opinions might be valuable. I'll tag @ogrisel for his input.

@OmarManzoor (Contributor) commented

@lithomas1 Let's do this for stable_cumsum: keep the original numpy-only test, which also checks the warning, and for the array API test just check that the function's output matches numpy's.
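A sketch of that split (hypothetical test names and fixtures; assumes namespaces whose arrays can round-trip through numpy):

import numpy as np
import pytest

from sklearn.utils.extmath import stable_cumsum

def test_stable_cumsum_warns():
    # numpy-only test: the warning stays covered here
    rng = np.random.RandomState(0)
    with pytest.warns(RuntimeWarning):
        stable_cumsum(rng.rand(100_000), rtol=0, atol=0)

def check_cumsum_array_api(xp):
    # array API test: only assert equivalence with the numpy result
    arr_np = np.asarray([[1.0, 2e-9, 3e-9] * int(1e6)])
    arr_xp = xp.asarray(arr_np)
    np.testing.assert_allclose(
        np.asarray(stable_cumsum(arr_xp, axis=1)), np.cumsum(arr_np, axis=1)
    )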

@ogrisel (Member) commented Jun 12, 2025

> Also, I think some other opinions might be valuable. I'll tag @ogrisel for his input.

Thanks for the ping. I replied in the original thread above: #30878 (comment)

For the particular case of this PR, I think we should revert the changes under sklearn/utils and just call xp.cumsum directly in the code that computes the ROC curve.

@OmarManzoor (Contributor) commented

> For the particular case of this PR, I think we should revert the changes under sklearn/utils and just call xp.cumsum directly in the code that computes the ROC curve.

That sounds good. Thanks.

@ogrisel (Member) commented Jun 13, 2025

As shown by @lesteve in #31533 (comment), casting to xp.float64 before computing the cumsum can be required for roc_auc_score to return the correct value on some datasets. This PR should add a non-regression test based on Loïc's comment, if one isn't already present in our test suite.
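A minimal sketch of the cast in question (illustrative variable names, not the PR's exact diff; the array API standard spells the function cumulative_sum):

# accumulate in float64 so that float32 rounding does not distort the
# running true-positive totals on large sample counts
tps = xp.cumulative_sum(xp.astype(y_true * sample_weight, xp.float64))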

@OmarManzoor (Contributor) commented

@ogrisel I made the updates that were discussed. Note that because the non-regression test I added needs a large number of samples, it takes a while to run. Could you kindly have a look at the changes?

@OmarManzoor (Contributor) commented

The failed test is due to the array allocation being too large, which follows from the large number of samples.

@OmarManzoor (Contributor) commented

@lesteve Do you have any suggestions on how we can keep this test?

@ogrisel (Member) left a comment

Given the cost of maintaining a proper non-regression test, I think an inline comment will do.

OmarManzoor and others added 4 commits June 18, 2025 18:05
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
@OmarManzoor (Contributor) commented

I think the suggestions with respect to this PR have been resolved, so let's merge. Thank you for the work @lithomas1!

@OmarManzoor OmarManzoor merged commit dab0842 into scikit-learn:main Jun 18, 2025
36 checks passed