-
Notifications
You must be signed in to change notification settings - Fork 24.9k
[simplefsdp auto-bucketing] add ir node bucket helper function #158097
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ruisizhang123
wants to merge
6
commits into
gh/ruisizhang123/4/base
Choose a base branch
from
gh/ruisizhang123/4/head
base: gh/ruisizhang123/4/base
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,070
−4
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158097
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New Failures, 1 Unrelated FailureAs of commit 1961123 with merge base 05c19d1 ( NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This was referenced Jul 11, 2025
Git-Hub-Chris
pushed a commit
to Git-Hub-Chris/PyTorch
that referenced
this pull request
Jul 15, 2025
ghstack-source-id: 8ee81dd Pull-Request: pytorch/pytorch#158097
…tion" This pr is based on Diff D67292294 from yf225. Major changes are: - Change the function structure to be compatible with auto-bucketing - Group bucketed nodes & dependencies with GroupedSchedulerNodes for easier reordering. Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #160282 * #158609 * #158321 * #158098 * __->__ #158097 * #157572 cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
ruisizhang123
added a commit
that referenced
this pull request
Aug 11, 2025
…cket helper function" This pr is based on Diff D67292294 from yf225. Major changes are: - Change the function structure to be compatible with auto-bucketing - Group bucketed nodes & dependencies with GroupedSchedulerNodes for easier reordering. Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #160282 * #158609 * #158321 * #158098 * __->__ #158097 * #157572 cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
ruisizhang123
added a commit
that referenced
this pull request
Aug 11, 2025
… with greedy algorithm" ## Greedy Algorithm Design ### Runtime estimation estimation is done in [__->__ #157572]. #### Communication estimation: - Optimization: While realize communication data into realtensor gives us good estimation of a node, it could incur significant overhead in auto-bucketing. - Calibration-based estimation: We sample 20 samples evenly in a range of [min(fwd_ag_tensor_list), 0.3* sum(fwd_ag_tensor_list)], and benchmark its communication time. This also applies to backward reduce scatter tensors. - When a new IR comes, it will search the saved communication dictionary for the closest tensor size, and use the closest tensor runtime as its predicted runtime. #### Cache estimated runtime Here, we add a cache to save pre-estimated results. Leverage a comm cache dictionary & comp cache dictionary to save pre-estimated inductor ir. If it meets a new IR node with the same key, it will skip the estimation for the new IR node. - For comm, the key is communication type, input tensor size & output tensor size. - For comp, the key is (i) generated triton code for BaseSchedulerNode/FusedSchedulerNode; (ii) Extern Kernel args ### Greedy Algorithm Implementation Core idea: the greedy algorithm decide if the node will be bucketed with the previous one based on several criteria below. The reordering will by itself reorder with the previous computation. bucketing is done in [__->__ #158097] reordering is done in [__->__ #158098] #### FWD Pass: - (i) the bucketed AG communication could be overlapped by the previous computation; - (ii) the bucketed AG copy-in/out memory doesn’t exceed peak memory; - (iii) bucketed AG communication size doesn’t exceed 0.3* sum(fwd_ag_tensor_list), such that the estimated AG communication time is always in the calibration bound. #### BWD Pass: - (i) the bucketed AG + RS communication could be overlapped by the previous computation; - (ii) the bucketed AG+RS copy-in/out memory doesn’t exceed peak memory; - (iii) RS always have future compute to overlap it, such that its final exposed communication is small; - (iv) bucketed AG/RS communication size doesn’t exceed 0.3* sum(fwd_ag_tensor_list)/0.3* sum(bwd_rs_tensor_list), such that the estimated AG/RS communication time is always in the calibration bound. cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
ruisizhang123
added a commit
that referenced
this pull request
Aug 11, 2025
…orithm" ## Greedy Algorithm Design ### Runtime estimation estimation is done in [__->__ #157572]. #### Communication estimation: - Optimization: While realize communication data into realtensor gives us good estimation of a node, it could incur significant overhead in auto-bucketing. - Calibration-based estimation: We sample 20 samples evenly in a range of [min(fwd_ag_tensor_list), 0.3* sum(fwd_ag_tensor_list)], and benchmark its communication time. This also applies to backward reduce scatter tensors. - When a new IR comes, it will search the saved communication dictionary for the closest tensor size, and use the closest tensor runtime as its predicted runtime. #### Cache estimated runtime Here, we add a cache to save pre-estimated results. Leverage a comm cache dictionary & comp cache dictionary to save pre-estimated inductor ir. If it meets a new IR node with the same key, it will skip the estimation for the new IR node. - For comm, the key is communication type, input tensor size & output tensor size. - For comp, the key is (i) generated triton code for BaseSchedulerNode/FusedSchedulerNode; (ii) Extern Kernel args ### Greedy Algorithm Implementation Core idea: the greedy algorithm decide if the node will be bucketed with the previous one based on several criteria below. The reordering will by itself reorder with the previous computation. bucketing is done in [__->__ #158097] reordering is done in [__->__ #158098] #### FWD Pass: - (i) the bucketed AG communication could be overlapped by the previous computation; - (ii) the bucketed AG copy-in/out memory doesn’t exceed peak memory; - (iii) bucketed AG communication size doesn’t exceed 0.3* sum(fwd_ag_tensor_list), such that the estimated AG communication time is always in the calibration bound. #### BWD Pass: - (i) the bucketed AG + RS communication could be overlapped by the previous computation; - (ii) the bucketed AG+RS copy-in/out memory doesn’t exceed peak memory; - (iii) RS always have future compute to overlap it, such that its final exposed communication is small; - (iv) bucketed AG/RS communication size doesn’t exceed 0.3* sum(fwd_ag_tensor_list)/0.3* sum(bwd_rs_tensor_list), such that the estimated AG/RS communication time is always in the calibration bound. cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
ciflow/inductor
module: inductor
oncall: distributed
Add this issue/PR to distributed oncall triage queue
release notes: distributed (fsdp)
release notes category
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This pr is based on Diff D67292294 from @yf225.
Major changes are:
Stack from ghstack (oldest at bottom):
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov