
Make Nd tensors hit fused addmm pass #106911


Closed
eellison wants to merge 4 commits

Conversation

eellison
Contributor

@eellison eellison commented Aug 9, 2023

@pytorch-bot

pytorch-bot bot commented Aug 9, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106911

Note: Links to docs will display an error until the docs builds have been completed.

✅ 1 Unrelated Failure

As of commit 4adfdee with merge base ed07821 (image):

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

```python
mod_c = torch.compile(mod)
out, code = run_and_get_code(mod_c, v, other)
self.assertEqual(out, mod(v, other), rtol=1e-2, atol=1e-2)
# TODO - assert that the fusion appears in the generated code
```
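One way the TODO could be filled in is to scan the generated sources for the fused kernel call. This is only a hedged sketch: the helper name, the `"addmm"` substring, and the stand-in `fake_code` list are illustrative assumptions, not the PyTorch test suite's actual utilities.

```python
# Hypothetical sketch: check that generated inductor code contains a fused
# addmm call. The helper and the matched substring are illustrative
# assumptions, not actual PyTorch test utilities.
def contains_fused_addmm(generated_sources):
    """Return True if any generated source string calls an addmm kernel."""
    return any("addmm" in src for src in generated_sources)

# Stand-in for the `code` list returned by run_and_get_code:
fake_code = ["buf0 = extern_kernels.addmm(bias, inp, weight_t)"]
assert contains_fused_addmm(fake_code)
```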
Contributor Author

cc @jgong5, would you mind taking a look at this ?

@eellison eellison requested review from albanD and ezyang August 9, 2023 22:30
Replace #106433 since I had a bad cla commit.

Speeds up eager convnext bfloat16 inference by 35%, and eager timm bfloat16 inference on average by `.5%`

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov anijain2305

[ghstack-poisoned]
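The core idea of the change can be sketched in plain Python shape arithmetic (the names here are illustrative; the real implementation operates on `SymInt`s inside ATen): an N-d input to linear is flattened to 2-d so it hits the fused addmm path, and the leading dimensions are restored on the result. The leading dims are multiplied out explicitly rather than passing `-1` to reshape, which is ambiguous when a dimension is 0.

```python
from math import prod

def linear_output_shape(input_sizes, out_features):
    """Sketch of the N-d -> 2-d flattening used to route linear through addmm.

    Leading dims are multiplied out explicitly (rather than using -1 in
    reshape, which errors when a dimension is 0), the 2-d matmul runs as
    addmm, and the leading dims are restored on the result.
    """
    *lead, in_features = input_sizes
    flattened_dim = prod(lead)             # explicit product, safe for 0-size dims
    two_d = (flattened_dim, in_features)   # shape fed to addmm
    return two_d, (*lead, out_features)    # final output restores leading dims

two_d, out = linear_output_shape((8, 4, 16), out_features=32)
# two_d == (32, 16), out == (8, 4, 32)
```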
```diff
@@ -152,7 +152,6 @@ def __init__(self):
        )
        self.bn1 = torch.nn.BatchNorm2d(num_features=16)
        self.relu1 = torch.nn.ReLU()
        self.fc1 = torch.nn.Linear(in_features=1638400, out_features=1)
```
Contributor
tell me more?

```cpp
#include <c10/util/MaybeOwned.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <iostream>
```
Contributor
dead now, right?

```cpp
return result.view_symint({input_sizes[0], input_sizes[1], result.sym_size(1)});
// can't use -1 in reshape because it errors when a dimension is 0
c10::SymInt flattened_dim = 1;
for (size_t i = 0, ndim = input_sizes.size(); i < ndim - 1; ++i) {
```
Contributor
`ndim - 1` here is hazardous because you are using the unsigned `size_t` type; you will overflow if `ndim == 0`. Just use `int64_t`; the unsigned type here really is not worth it.
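The hazard can be illustrated by simulating C's unsigned wraparound in Python (Python integers do not overflow, so the modulo is the simulation; the 64-bit width of `size_t` is an assumption):

```python
SIZE_T_BITS = 64  # assumption: 64-bit size_t

def size_t_sub(a, b):
    """Simulate C unsigned subtraction: size_t arithmetic wraps mod 2**64."""
    return (a - b) % (1 << SIZE_T_BITS)

ndim = 0
# In C, `for (size_t i = 0; i < ndim - 1; ++i)` with ndim == 0 compares i
# against SIZE_MAX, so the loop runs ~2**64 times instead of zero times.
assert size_t_sub(ndim, 1) == 2**64 - 1

# With a signed int64_t bound, ndim - 1 == -1 and the loop body never runs.
assert ndim - 1 == -1
```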

```cpp
auto inp_reshape = input.reshape_symint({flattened_dim, input_sizes.at(input_sizes.size() - 1)});
const auto result = at::addmm(bias, inp_reshape, weight.t());
auto new_size = input_sizes.slice(0, input_sizes.size() - 1);
std::vector<SymInt> sizes_vec(new_size.begin(), new_size.end());
```
Contributor
Consider using `SymDimVector` to avoid the heap allocation.

Contributor

@ezyang ezyang left a comment

okey dokey

@albanD albanD removed their request for review August 11, 2023 20:12
@eellison eellison added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 16, 2023
@eellison
Contributor Author

@pytorchbot merge -f "unrelated failure"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort, and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

pytorchmergebot pushed a commit that referenced this pull request Aug 16, 2023
I get a 2% inference speedup in HF with this PR. I checked to see if there were any models where unfusing was slower than the cuBLAS gelu fusion, and I did not see any, which was surprising to me. Sorry for the cublas-activation API churn 😬

Kicking off another run on cuBLAS 12; it's possible that the results have changed since.

Pull Request resolved: #106912
Approved by: https://github.com/jansel
ghstack dependencies: #106911
summerdo pushed a commit to summerdo/pytorch that referenced this pull request Aug 17, 2023
Pull Request resolved: pytorch#106911
Approved by: https://github.com/ezyang
summerdo pushed a commit to summerdo/pytorch that referenced this pull request Aug 17, 2023

Pull Request resolved: pytorch#106912
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#106911
pytorchmergebot pushed a commit that referenced this pull request Aug 17, 2023
This can lead to a large speedup when max autotune is set, e.g. resnet 2.1x -> 2.5x, particularly in combination with freezing.

Pull Request resolved: #107004
Approved by: https://github.com/jansel, https://github.com/shunting314, https://github.com/int3
ghstack dependencies: #106911, #106912
@facebook-github-bot facebook-github-bot deleted the gh/eellison/512/head branch August 20, 2023 14:16
eellison added a commit to eellison/pytorch that referenced this pull request Mar 8, 2024
ghstack-source-id: 603bdd0
Pull Request resolved: pytorch#106911