[ROCm] Use opportunistic fastatomics based on heuristics #159430

Open
wants to merge 1 commit into
base: main
from

jerrymannil
Contributor

@jerrymannil jerrymannil commented Jul 29, 2025

  • Opportunistic fast atomics work better with small sizes, since lanes are more likely to perform atomics on the same address
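As a rough illustration of the intuition above (not the heuristic this PR implements), here is a toy model of how often a lane shares its target address with another lane. It assumes 64 lanes per wavefront and that each lane independently picks one of `num_addresses` locations uniformly at random. Both are simplifying assumptions, and `lane_collision_probability` is a hypothetical helper, not a PyTorch or ROCm API.

```python
def lane_collision_probability(num_addresses: int, lanes: int = 64) -> float:
    """Probability that one lane's address is also hit by at least one other lane.

    Toy model (assumption): every lane picks its address independently and
    uniformly from `num_addresses` candidates.
    """
    if num_addresses < 1 or lanes < 1:
        raise ValueError("num_addresses and lanes must be positive")
    # Chance that none of the other (lanes - 1) lanes picked this lane's address.
    no_collision = (1.0 - 1.0 / num_addresses) ** (lanes - 1)
    return 1.0 - no_collision


# Fewer candidate addresses means more same-address atomics within a wavefront.
for n in (64, 4096, 1_632_960):
    print(f"{n:>9} addresses: {lane_collision_probability(n):.6f}")
```

Under this model the collision probability falls quickly as the number of candidate addresses grows, which is consistent with the claim that the opportunistic path pays off mainly for small sizes.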

Co-author: @amd-hhashemi

Reproducer:

import time
import torch

x = torch.randn((1_632_960, 128), device='cuda', dtype=torch.float)
ind = torch.randint(0, x.size(0), size=(5_079_670,), device='cuda')
src = torch.randn((5_079_670, 128), device='cuda', dtype=torch.float)

# Warm up so one-time costs are excluded from the measurement.
for _ in range(20):
    x.index_add_(0, ind, src)
torch.cuda.synchronize()  # drain the warm-up kernels before starting the timer

start_time = time.perf_counter()
for _ in range(100):
    x.index_add_(0, ind, src)
torch.cuda.synchronize()  # wait for all queued kernels before stopping the timer
end_time = time.perf_counter()
mean_time = (end_time - start_time) / 100
print(f"Avg time for index_add_: {mean_time * 1e6:.2f} us")
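For readers unfamiliar with the operation being timed, the sketch below spells out what `index_add_` along dim 0 computes, using plain Python lists instead of tensors. `index_add_rows` is a hypothetical helper written for illustration only. In the reproducer there are 5,079,670 source rows and only 1,632,960 destination rows, so many source rows necessarily land on the same destination row; that accumulation onto shared addresses is presumably where the GPU kernel relies on atomic adds.

```python
def index_add_rows(dest, index, src):
    """In-place reference for dim-0 index_add: dest[index[i]] += src[i]."""
    if len(index) != len(src):
        raise ValueError("index and src must have the same number of rows")
    for i, row in enumerate(index):
        for col, value in enumerate(src[i]):
            dest[row][col] += value
    return dest


x = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
ind = [2, 0, 2]  # row 2 appears twice, so two source rows accumulate into it
src = [[1.0, 2.0], [10.0, 20.0], [100.0, 200.0]]

index_add_rows(x, ind, src)
print(x)  # [[10.0, 20.0], [0.0, 0.0], [101.0, 202.0]]
```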

Perf numbers:

Before:
Avg time for index_add_: 25652.16 us

After:
Avg time for index_add_: 2675.15 us
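For reference, the ratio implied by the two averages above can be worked out directly. This is plain arithmetic on the reported numbers, not a new measurement.

```python
before_us = 25652.16  # "Before" average reported above
after_us = 2675.15    # "After" average reported above

speedup = before_us / after_us
reduction_pct = (1 - after_us / before_us) * 100
print(f"Speedup: {speedup:.2f}x, time reduced by {reduction_pct:.1f}%")
# Speedup: 9.59x, time reduced by 89.6%
```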

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

pytorch-bot bot commented Jul 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159430

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Cancelled Job, 1 Unrelated Failure

As of commit 44e97dd with merge base 1ebcba4:

CANCELLED JOB - The following job was cancelled. Please retry:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added module: rocm AMD GPU support for Pytorch release notes: cuda release notes category labels Jul 29, 2025
@jerrymannil jerrymannil marked this pull request as draft July 29, 2025 23:29
@pruthvistony pruthvistony added topic: not user facing topic category ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR ciflow/rocm Trigger "default" config CI on ROCm ciflow/inductor-rocm Trigger "inductor" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/periodic-rocm-mi300 Trigger "distributed" config CI on ROCm MI300 and removed release notes: cuda release notes category labels Jul 30, 2025

pytorch-bot bot commented Jul 30, 2025

To add the ciflow label ciflow/rocm please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

pytorch-bot bot commented Jul 30, 2025

To add the ciflow label ciflow/periodic please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

pytorch-bot bot commented Jul 30, 2025

To add the ciflow label ciflow/inductor-rocm please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

pytorch-bot bot commented Jul 30, 2025

To add the ciflow label ciflow/rocm-mi300 please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

pytorch-bot bot commented Jul 30, 2025

To add the ciflow label ciflow/periodic-rocm-mi300 please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed ciflow/rocm Trigger "default" config CI on ROCm ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR ciflow/inductor-rocm Trigger "inductor" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/periodic-rocm-mi300 Trigger "distributed" config CI on ROCm MI300 labels Jul 30, 2025
@pruthvistony pruthvistony added ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR rocm This tag is for PRs from ROCm team ciflow/rocm Trigger "default" config CI on ROCm ciflow/inductor-rocm Trigger "inductor" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/periodic-rocm-mi300 Trigger "distributed" config CI on ROCm MI300 labels Jul 30, 2025

pytorch-bot bot commented Jul 30, 2025

To add the ciflow label ciflow/inductor-rocm please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

pytorch-bot bot commented Jul 30, 2025

To add the ciflow label ciflow/rocm-mi300 please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR ciflow/rocm Trigger "default" config CI on ROCm ciflow/inductor-rocm Trigger "inductor" config CI on ROCm ciflow/periodic-rocm-mi300 Trigger "distributed" config CI on ROCm MI300 ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Jul 30, 2025
@pytorch-bot pytorch-bot bot added the ciflow/rocm Trigger "default" config CI on ROCm label Jul 31, 2025
@jithunnair-amd
Collaborator

@jerrymannil Can you please mention if there's a unit test or workload you used to qualify and quantify the improvement?

@pruthvistony
Collaborator

pruthvistony commented Jul 31, 2025

The ticket comments show the results: https://ontrack-internal.amd.com/browse/SWDEV-546136?focusedId=19870475&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-19870475

@jerrymannil
Please update the PR description with all the perf results and testing done.

@pruthvistony pruthvistony marked this pull request as ready for review July 31, 2025 18:54
@jerrymannil
Contributor Author

@pruthvistony @jithunnair-amd
Updated PR description with reproducer and numbers

pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Jul 31, 2025
#2438)

* Merge of pytorch#159430
* Opportunistic fast atomics work better with small sizes, since there
is more chance of lanes doing atomics on the same address

Co-author: @amd-hhashemi
@pruthvistony pruthvistony requested review from malfet and atalman July 31, 2025 21:35
@jerrymannil
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

* Opportunistic fast atomics work better with small sizes, since there is more chance of lanes doing atomics on the same address
@pytorchmergebot
Collaborator

Successfully rebased patch-1 onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout patch-1 && git pull --rebase)

@pytorch-bot pytorch-bot bot removed the ciflow/rocm Trigger "default" config CI on ROCm label Aug 1, 2025
@jerrymannil
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 1, 2025
@pytorchmergebot
Collaborator

Merge failed

Reason: Approvers from one of the following sets are needed:

  • superuser (pytorch/metamates)
  • Core Reviewers (mruberry, lezcano, Skylion007, ngimel, peterbell10, ...)
  • Core Maintainers (soumith, gchanan, ezyang, dzhulgakov, malfet, ...)
Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

Labels
ciflow/trunk Trigger trunk jobs on your pull request module: rocm AMD GPU support for Pytorch open source rocm This tag is for PRs from ROCm team topic: not user facing topic category

5 participants