Skip to content

[simplefsdp auto-bucketing] ir node runtime estimation #157572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: gh/ruisizhang123/1/base
Choose a base branch
from

Conversation

ruisizhang123
Copy link
Contributor

@ruisizhang123 ruisizhang123 commented Jul 3, 2025

Computation Estimation:

  • BaseSchedulerNode/FusedSchedulerNode: Get the generated triton code and use do_bench mode to benchmark the runtime.
  • ExternKernelSchedulerNode: Realize faketensor into realtensor, and log the cuda event time.

Communication Estimation:

  • NCCL estimation: estimate communication using NCCL's ncclGroupSimulateEnd API from this PR.
  • Profiling-based estimation: Realize faketensor into realtensor, and log the cuda event time for performing the communication.

Note: By default, I use Profiling-based estimation to estimate communication. The reason is it gives better estimation performance in multi-node settings.

At the end, it aggregates the estimated runtime from all ranks, and take the median runtime as the final estimated results.

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

[ghstack-poisoned]
ruisizhang123 added a commit that referenced this pull request Jul 3, 2025
ghstack-source-id: 31163dc
Pull-Request: #157572
Copy link

pytorch-bot bot commented Jul 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157572

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures

As of commit 4a0c2e8 with merge base 05c19d1 (image):

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Copy link
Contributor

github-actions bot commented Jul 3, 2025

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@ruisizhang123 ruisizhang123 changed the title add estimation code [simplefsdp auto-bucketing] ir node runtime estimation Jul 3, 2025
@ruisizhang123 ruisizhang123 marked this pull request as draft July 3, 2025 17:26
[ghstack-poisoned]
ruisizhang123 added a commit that referenced this pull request Jul 7, 2025
ghstack-source-id: 9453057
Pull-Request: #157572
[ghstack-poisoned]
ruisizhang123 added a commit that referenced this pull request Jul 7, 2025
ghstack-source-id: d7b6a66
Pull-Request: #157572
process_group = c10d.distributed_c10d._get_default_group()

tensor_input = create_real_tensor(inputs.data.get_size(), inputs.data.get_dtype(), inputs.data.get_device())
tensor_output = create_real_tensor(kernel.layout.size, kernel.layout.dtype, kernel.layout.device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one downside to allocating real tensors at compile time to do the comms estimation is that if we are already close to peak memory at the time that we do the estimation, we risk compile causing an OOM. IIRC we've seen similar problems before with inductor, since it allocates tensors to do autotuning.

Mostly just calling this out as a potential downside. Although maybe this is still the best option, given that your work showed this is noticeably more accurate than other forms of comms estimation? Runtime PGO is another option but it is a much bigger lift to set up. Or just brainstorming: maybe we can also avoid these allocations being problematic by either offloading other tensors to CPU, or e.g. doing the runtime estimation with small tensor shapes, and having a way to accurately estimate the runtime for larger shapes.

Copy link
Contributor Author

@ruisizhang123 ruisizhang123 Jul 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tensors are deleted immediately in L126 with del tensor_input, tensor_output. Do you think there might be potential memory overhead even after they are deleted? 👀

end_evt.record()
end_evt.synchronize()

torch.cuda.synchronize()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: doing cuda sync / event sync should not be necessary to get timings from the event. I think i'd suggest just removing them.

Copy link
Contributor Author

@ruisizhang123 ruisizhang123 Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will need end_evt.synchronize() to ensure the end_event has finished before calculating the runtime in start_evt.elapsed_time(end_evt). Otherwise, it will throw error. I think torch.cuda.synchronize() is not necessary here tho.

[ghstack-poisoned]
ruisizhang123 added a commit that referenced this pull request Jul 8, 2025
ghstack-source-id: 371d299
Pull-Request: #157572
[ghstack-poisoned]
ruisizhang123 added a commit that referenced this pull request Jul 8, 2025
ghstack-source-id: 2030072
Pull-Request: #157572
[ghstack-poisoned]
ruisizhang123 added a commit that referenced this pull request Jul 8, 2025
ghstack-source-id: b0bd331
Pull-Request: #157572
[ghstack-poisoned]
ruisizhang123 added a commit that referenced this pull request Jul 8, 2025
ghstack-source-id: 627560e
Pull-Request: #157572
[ghstack-poisoned]
ruisizhang123 added a commit that referenced this pull request Jul 8, 2025
ghstack-source-id: 72ce66b
Pull-Request: #157572
ruisizhang123 added a commit that referenced this pull request Jul 8, 2025
ghstack-source-id: ad26df5
Pull-Request: #157572
[ghstack-poisoned]
ruisizhang123 added a commit that referenced this pull request Jul 8, 2025
ghstack-source-id: d1d7074
Pull-Request: #157572
[ghstack-poisoned]
ruisizhang123 added a commit that referenced this pull request Jul 8, 2025
ghstack-source-id: 22a16e7
Pull-Request: #157572
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
Git-Hub-Chris pushed a commit to Git-Hub-Chris/PyTorch that referenced this pull request Jul 15, 2025
ghstack-source-id: 40c2647
Pull-Request: pytorch/pytorch#157572
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
Computation Estimation:

-  BaseSchedulerNode/FusedSchedulerNode: Get the generated triton code and use [do_bench](https://github.com/pytorch/pytorch/blob/85111cd165f108ffabb4a90083d59d7a867ebd9f/torch/_inductor/codegen/triton.py#L4234) mode to benchmark the runtime.
- ExternKernelSchedulerNode: Realize faketensor into realtensor, and log the cuda event time.

Communication Estimation:

- NCCL estimation: estimate communication using NCCL's [ncclGroupSimulateEnd API](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/group.html?highlight=ncclsiminfo_t) from this [PR](#149343).
- Profiling-based estimation: Realize faketensor into realtensor, and log the cuda event time for performing the communication.

Note: By default, I use Profiling-based estimation to estimate communication. The reason is it gives better estimation performance in multi-node settings.

At the end, it aggregates the estimated runtime from all ranks, and take the median runtime as the final estimated results.  




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
ruisizhang123 added a commit that referenced this pull request Aug 11, 2025
…cket helper function"

This pr is based on Diff D67292294 from yf225.

Major changes are:

- Change the function structure to be compatible with auto-bucketing
- Group bucketed nodes & dependencies with GroupedSchedulerNodes for easier reordering.

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #160282
* #158609
* #158321
* #158098
*  __->__ #158097
* #157572



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
ruisizhang123 added a commit that referenced this pull request Aug 11, 2025
…tion"

This pr is based on Diff D67292294 from yf225.

Major changes are:

- Change the function structure to be compatible with auto-bucketing
- Group bucketed nodes & dependencies with GroupedSchedulerNodes for easier reordering.

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #160282
* #158609
* #158321
* #158098
*  __->__ #158097
* #157572



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
ruisizhang123 added a commit that referenced this pull request Aug 11, 2025
… with greedy algorithm"

## Greedy Algorithm Design

### Runtime estimation

estimation is done in  [__->__ #157572]. 

#### Communication estimation:

- Optimization: While realize communication data into realtensor gives us good estimation of a node, it could incur significant overhead in auto-bucketing.
- Calibration-based estimation: We sample 20 samples evenly in a range of [min(fwd_ag_tensor_list), 0.3* sum(fwd_ag_tensor_list)], and benchmark its communication time. This also applies to backward reduce scatter tensors.
- When a new IR comes, it will search the saved communication dictionary for the closest tensor size, and use the closest tensor runtime as its predicted runtime. 

#### Cache estimated runtime

Here, we add a cache to save pre-estimated results. Leverage a comm cache dictionary & comp cache dictionary to save pre-estimated inductor ir. If it meets a new IR node with the same key, it will skip the estimation for the new IR node.

- For comm, the key is communication type, input tensor size & output tensor size.
- For comp, the key is (i) generated triton code for BaseSchedulerNode/FusedSchedulerNode; (ii) Extern Kernel args

### Greedy Algorithm Implementation

Core idea: the greedy algorithm decide if the node will be bucketed with the previous one based on several criteria below. The reordering will by itself reorder with the previous computation.

bucketing is done in  [__->__ #158097]
reordering is done in  [__->__ #158098]

#### FWD Pass:

- (i) the bucketed AG communication could be overlapped by the previous computation;
- (ii) the bucketed AG copy-in/out memory doesn’t exceed peak memory; 
- (iii) bucketed AG communication size doesn’t exceed 0.3* sum(fwd_ag_tensor_list), such that the estimated AG communication time is always in the calibration bound.

#### BWD Pass:

- (i) the bucketed AG + RS communication could be overlapped by the previous computation; 
- (ii)  the bucketed AG+RS copy-in/out memory doesn’t exceed peak memory; 
- (iii) RS always have future compute to overlap it, such that its final exposed communication is small; 
- (iv) bucketed AG/RS communication size doesn’t exceed 0.3* sum(fwd_ag_tensor_list)/0.3* sum(bwd_rs_tensor_list), such that  the estimated AG/RS communication time is always in the calibration bound.





cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
ruisizhang123 added a commit that referenced this pull request Aug 11, 2025
…orithm"

## Greedy Algorithm Design

### Runtime estimation

estimation is done in  [__->__ #157572]. 

#### Communication estimation:

- Optimization: While realize communication data into realtensor gives us good estimation of a node, it could incur significant overhead in auto-bucketing.
- Calibration-based estimation: We sample 20 samples evenly in a range of [min(fwd_ag_tensor_list), 0.3* sum(fwd_ag_tensor_list)], and benchmark its communication time. This also applies to backward reduce scatter tensors.
- When a new IR comes, it will search the saved communication dictionary for the closest tensor size, and use the closest tensor runtime as its predicted runtime. 

#### Cache estimated runtime

Here, we add a cache to save pre-estimated results. Leverage a comm cache dictionary & comp cache dictionary to save pre-estimated inductor ir. If it meets a new IR node with the same key, it will skip the estimation for the new IR node.

- For comm, the key is communication type, input tensor size & output tensor size.
- For comp, the key is (i) generated triton code for BaseSchedulerNode/FusedSchedulerNode; (ii) Extern Kernel args

### Greedy Algorithm Implementation

Core idea: the greedy algorithm decide if the node will be bucketed with the previous one based on several criteria below. The reordering will by itself reorder with the previous computation.

bucketing is done in  [__->__ #158097]
reordering is done in  [__->__ #158098]

#### FWD Pass:

- (i) the bucketed AG communication could be overlapped by the previous computation;
- (ii) the bucketed AG copy-in/out memory doesn’t exceed peak memory; 
- (iii) bucketed AG communication size doesn’t exceed 0.3* sum(fwd_ag_tensor_list), such that the estimated AG communication time is always in the calibration bound.

#### BWD Pass:

- (i) the bucketed AG + RS communication could be overlapped by the previous computation; 
- (ii)  the bucketed AG+RS copy-in/out memory doesn’t exceed peak memory; 
- (iii) RS always have future compute to overlap it, such that its final exposed communication is small; 
- (iv) bucketed AG/RS communication size doesn’t exceed 0.3* sum(fwd_ag_tensor_list)/0.3* sum(bwd_rs_tensor_list), such that  the estimated AG/RS communication time is always in the calibration bound.





cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants