
Fix array api integration for additive_chi2_kernel with torch mps #29256

Conversation

OmarManzoor (Contributor)

Reference Issues/PRs

Follow up of #29144

What does this implement/fix? Explain your changes.

  • Fixes a test that was failing for the additive_chi2_kernel array API with torch mps (see the sketch below)
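
A minimal sketch of the kind of device mismatch involved, assuming the failure mode discussed later in this thread (a scalar array created on the CPU by default being mixed with MPS inputs); this is illustrative only, not the actual test or the code in sklearn.metrics.pairwise:

    # Hypothetical reproduction; requires an Apple-silicon machine with an
    # MPS-enabled torch build.
    import torch

    X = torch.tensor([[0.3, 0.7], [0.5, 0.5]], device="mps")

    # A scalar array created without an explicit device lands on the CPU ...
    zero = torch.asarray(0.0)

    # ... and mixing it with the MPS input raised a device-mismatch RuntimeError
    # on the torch version in use when this PR was opened:
    # torch.where(X > zero, X, zero)

    # The fix sketched in this PR threads the device of the inputs through:
    zero_mps = torch.asarray(0.0, device=X.device)
    out = torch.where(X > zero_mps, X, zero_mps)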

Any other comments?

✔️ Linting Passed

All linting checks passed. Generated for commit 43d860c.

Tialo (Contributor) commented Jun 14, 2024

Oh, I overlooked that it has to be on the same device. Maybe the common tests should also assert this? That would catch the error at least in the GPU workflows.
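
Such a common-test assertion might look roughly like the sketch below; it uses sklearn's private array API helpers and an illustrative helper name, not the actual common-test code:

    import numpy as np
    from sklearn.metrics.pairwise import additive_chi2_kernel
    from sklearn.utils._array_api import device

    def check_result_device(func, X, Y):
        """Assert the result lives on the same device as the inputs."""
        result = func(X, Y)
        assert device(result) == device(X), (device(result), device(X))

    # With NumPy inputs this is trivially true; run with array_api_dispatch
    # enabled and torch tensors on "mps", a check like this would have caught
    # the bug fixed here.
    check_result_device(additive_chi2_kernel, np.random.rand(4, 3), np.random.rand(2, 3))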

OmarManzoor (Contributor, Author)

> Oh, I overlooked that it has to be on the same device. Maybe the common tests should also assert this? That would catch the error at least in the GPU workflows.

I think this might be specifically an issue with mps because it seems to work fine with cuda.

EdAbati (Contributor) left a comment

I confirm that this fix is needed for the mps device 😕

@@ -30,6 +30,7 @@
 from ..utils._array_api import (
     _find_matching_floating_dtype,
     _is_numpy_namespace,
+    device,
Contributor (review comment on the diff above)

I think we could also use get_namespace_and_device and substitute it for xp, _ = get_namespace(X, Y) below.
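
A sketch of that suggestion, assuming get_namespace_and_device keeps returning a (namespace, is_array_api, device) triple; the tensors, device and variable names below are illustrative only:

    import sklearn
    import torch
    from sklearn.utils._array_api import get_namespace_and_device

    sklearn.set_config(array_api_dispatch=True)  # needs array-api-compat installed

    X = torch.rand(4, 3, device="mps")  # any non-CPU device makes the point
    Y = torch.rand(2, 3, device="mps")

    # One call replaces `xp, _ = get_namespace(X, Y)` plus a separate `device(X)`:
    xp, _, device_ = get_namespace_and_device(X, Y)
    zero = xp.asarray(0.0, device=device_)  # scalar array created on the inputs' device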

EdAbati (Contributor) commented Jun 17, 2024

FYI I have updated to the latest torch (2.3.1) and it looks like specifying the device here is not required anymore when using mps. 👀
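
A quick way to check this locally (illustrative snippet, needs an Apple-silicon machine; the behaviour change is as reported above, with torch >= 2.3.1 reportedly accepting a 0-dim CPU operand next to MPS tensors):

    import torch

    X = torch.rand(3, device="mps")
    zero = torch.asarray(0.0)  # 0-dim scalar tensor, created on the CPU by default

    # On the older torch this raised a device-mismatch RuntimeError for MPS;
    # on torch 2.3.1 the CPU scalar is accepted and the result stays on MPS.
    out = torch.maximum(X, zero)
    assert out.device == X.device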

ogrisel (Member) commented Jun 18, 2024

> FYI I have updated to the latest torch (2.3.1) and it looks like specifying the device here is not required anymore when using mps. 👀

Nice that they did this. It was verbose to have to specify the device even for scalar arrays.

I am -0 for merging this PR. Since the array API support is very new, we can expect people who want to try it to use the latest stable versions of the libraries, and I would rather not make our code base more verbose than necessary.

What do other people think? We could make that version requirement explicit in the _array_api.rst file.

OmarManzoor (Contributor, Author)

> What do other people think? We could make that version requirement explicit in the _array_api.rst file.

I think that makes sense if it is fixed upstream; then there is no need for this change. @ogrisel, should we close this PR?

ogrisel (Member) commented Jun 18, 2024

Let's wait a bit to let others have time to express their opinion.

ogrisel (Member) commented Jun 20, 2024

Let's close, we can reopen if others disagree.

ogrisel closed this on Jun 20, 2024
ogrisel (Member) commented Jun 20, 2024

Note: we would have to check with other libraries. For instance, I think JAX support in array-api-compat is improving, so we might include it in our test suite at some point and end up needing this change after all.

OmarManzoor deleted the torch_mps_fix_for_additive_chi2_kernel branch on June 20, 2024 at 09:05