-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
Fix array api integration for additive_chi2_kernel with torch mps #29256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix array api integration for additive_chi2_kernel with torch mps #29256
Conversation
Oh, I overlooked it has to be on the same device. Maybe common tests should also assert it? It should spot the error at least in gpu workflows |
I think this might be specifically an issue with mps because it seems to work fine with cuda. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I confirm that this fix is needed for the mps
device 😕
@@ -30,6 +30,7 @@ | |||
from ..utils._array_api import ( | |||
_find_matching_floating_dtype, | |||
_is_numpy_namespace, | |||
device, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we could also use get_namespace_and_device
and substitute it to xp, _ = get_namespace(X, Y)
below
FYI I have updated to the latest |
Nice that they did this. It was verbose to have to specify the device even for scalar arrays. I am -0 for merging this PR. Since array API is very new, one can expect people who want to try it to use the latest stable versions of libraries and I would rather keep our code based not too verbose if not necessary. What other people think? We could make that requirement explicit in the |
I think it makes sense if it is fixed. No need for the change. @ogrisel should we close this PR? |
Let's wait a bit to let others have time to express their opinion. |
Let's close, we can reopen if others disagree. |
Note: we would have to check with other libraries. For instance I think jax support in |
Reference Issues/PRs
Follow up of #29144
What does this implement/fix? Explain your changes.
Any other comments?