-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
ENH: Make roc_curve array API compatible #30878
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR @lithomas1
I added some initial comments.
Thanks for the review, and sorry for the slow reply. I addressed the device issues (a MPS run on my Intel MBP uncovered some more issues that I fixed). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates @lithomas1.
Mostly this looks good. However let's consider waiting for the array-api-extra PR
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
@@ -552,7 +552,7 @@ def isclose( | |||
xp=xp, | |||
) | |||
if equal_nan: | |||
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out) | |||
out = xp.where(xp.isnan(a) & xp.isnan(b), True, out) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I found a bug in array-api-extra related to multi-device support.
I'm planning on upstreaming something there to fix this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR merged, please ping me when you need a release
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks @lithomas1
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
if dtype is not None and dtype not in [np.float32, np.float64]: | ||
dtype = np.float64 | ||
if dtype is not None and dtype not in [xp.float32, xp.float64]: | ||
dtype = _max_precision_float_dtype(xp, device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you update the description of the dtype
parameter in the docstring? I think we should change the statement that np.flaot64
will be the dtype of sample_weight
if nothing else is specified. Maybe "The output will be the highest precision floating point dtype" or some such?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I reworded this slightly to "Otherwise, the output has the highest precision floating point dtype supported by the array namespace/device of the input arrays."
Let me know if that sounds OK to you.
or xp.all(classes == xp.asarray([-1, 1], device=device)) | ||
or xp.all(classes == xp.asarray([0], device=device)) | ||
or xp.all(classes == xp.asarray([-1], device=device)) | ||
or xp.all(classes == xp.asarray([1], device=device)) | ||
): | ||
classes_repr = ", ".join([repr(c) for c in classes.tolist()]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you know why .tolist()
works here? I don't think it is part of the array API standard and classes
is a array API array, which is why I am surprised that this works. Are we just getting lucky or does it make sense that this works?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd convert classes
to a numpy array and then use it here. It should be short, so cheap to convert with _convert_to_numpy
from sklearn.utils._array_api
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the catch, my guess is that we are only hitting this error condition with numpy arrays.
(I don't think the array API checks test any error cases, and this block raises a ValueError).
Reference Issues/PRs
xref #26024
What does this implement/fix? Explain your changes.
Makes roc_curve array API compatible.
Any other comments?