Skip to content

[MPS] Add support for two more isin variants #154010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from

Conversation

malfet
Copy link
Contributor

@malfet malfet commented May 21, 2025

Stack from ghstack (oldest at bottom):

isin_Tensor_Scalar_out is just a redispatch to eq/neq
isin_Scalar_Tensor_out redispatches back to generic isin op, but needs a small tweak to handle float scalars
Make sure that out is resized to an expected value in isin_Tensor_Tensor_out_mps

Add unittests to validate that, but skip them on MacOS-13, where MPS op just returns garbage

Before this change both of those failed

>>> import torch
>>> t = torch.tensor([0, 1, 2], device='mps')
>>> torch.isin(t, 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Tensor_Scalar_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on https://github.com/pytorch/pytorch/issues/141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c25ea6d8802a0c53af9eb961ddf2f058188. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
>>> torch.isin(1, t)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Scalar_Tensor_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on https://github.com/pytorch/pytorch/issues/141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c25ea6d8802a0c53af9eb961ddf2f058188. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

[ghstack-poisoned]
@malfet malfet requested a review from kulinseth as a code owner May 21, 2025 03:48
Copy link

pytorch-bot bot commented May 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154010

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 72 Pending

As of commit 64c8b36 with merge base 2e56ce0 (image):
💚 Looks good so far! There are no failures yet. 💚

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/mps Run MPS tests (subset of trunk) label May 21, 2025
malfet added a commit that referenced this pull request May 21, 2025
`isin_Tensor_Scalar_out` is just a redispatch to eq/neq
`isin_Scalar_Tensor_out` redispatches back to generic `isin` op

Add unittests to validate that
Before this change both of those failed
```python
>>> import torch
>>> t = torch.tensor([0, 1, 2], device='mps')
>>> torch.isin(t, 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Tensor_Scalar_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on #141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c2. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
>>> torch.isin(1, t)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Scalar_Tensor_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on #141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c2. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```

ghstack-source-id: 9c8eb66
Pull Request resolved: #154010
@malfet malfet requested review from dcci, albanD and manuelcandales May 21, 2025 03:48
Copy link
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

@malfet malfet added topic: improvements topic category release notes: mps Release notes category labels May 21, 2025
@albanD albanD removed their request for review May 21, 2025 17:28
[ghstack-poisoned]
malfet added a commit that referenced this pull request May 21, 2025
`isin_Tensor_Scalar_out` is just a redispatch to eq/neq
`isin_Scalar_Tensor_out` redispatches back to generic `isin` op

Add unittests to validate that
Before this change both of those failed
```python
>>> import torch
>>> t = torch.tensor([0, 1, 2], device='mps')
>>> torch.isin(t, 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Tensor_Scalar_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on #141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c2. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
>>> torch.isin(1, t)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Scalar_Tensor_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on #141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c2. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```

ghstack-source-id: 123db4e
Pull Request resolved: #154010
[ghstack-poisoned]
malfet added a commit that referenced this pull request May 22, 2025
`isin_Tensor_Scalar_out` is just a redispatch to eq/neq
`isin_Scalar_Tensor_out` redispatches back to generic `isin` op

Add unittests to validate that
Before this change both of those failed
```python
>>> import torch
>>> t = torch.tensor([0, 1, 2], device='mps')
>>> torch.isin(t, 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Tensor_Scalar_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on #141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c2. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
>>> torch.isin(1, t)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Scalar_Tensor_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on #141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c2. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```

ghstack-source-id: 99c45ae
Pull Request resolved: #154010
[ghstack-poisoned]
malfet added a commit that referenced this pull request May 22, 2025
`isin_Tensor_Scalar_out` is just a redispatch to eq/neq
`isin_Scalar_Tensor_out` redispatches back to generic `isin` op

Add unittests to validate that
Before this change both of those failed
```python
>>> import torch
>>> t = torch.tensor([0, 1, 2], device='mps')
>>> torch.isin(t, 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Tensor_Scalar_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on #141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c2. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
>>> torch.isin(1, t)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Scalar_Tensor_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on #141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c2. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```

ghstack-source-id: 96faf9a
Pull Request resolved: #154010
[ghstack-poisoned]
malfet added a commit that referenced this pull request May 22, 2025
`isin_Tensor_Scalar_out` is just a redispatch to eq/neq
`isin_Scalar_Tensor_out` redispatches back to generic `isin` op

Add unittests to validate that
Before this change both of those failed
```python
>>> import torch
>>> t = torch.tensor([0, 1, 2], device='mps')
>>> torch.isin(t, 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Tensor_Scalar_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on #141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c2. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
>>> torch.isin(1, t)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Scalar_Tensor_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on #141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c2. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```

ghstack-source-id: e60527d
Pull Request resolved: #154010
@malfet
Copy link
Contributor Author

malfet commented May 22, 2025

@pytorchbot merge -f "Roses are red, violets are blue, I want to land this PR now"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions bot deleted the gh/malfet/350/head branch June 23, 2025 02:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants