[FakeTensor] Supplement the relevant logic for converting conv1d to conv2d in meta_conv #160408
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160408
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 28a839d with merge base fea7e9d.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@eellison Please review my PR and leave any comments when convenient.
Thanks for the PR - I would have expected us to need to update the fake impl, which has access to the device information.
@@ -2450,6 +2450,19 @@ def pick_memory_format():
    elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
        return torch.preserve_format

    # Expand 1d -> 2d.
Are you sure we don't need to update
pytorch/torch/_subclasses/fake_impls.py
Lines 1024 to 1026 in b7db866

    elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
        mem_fmt = None
    else:
Hmm, not entirely sure yet. I tested the bug fix locally and it worked fine, but I'll double-check and confirm after further investigation.
Fixes #159462
Summary
The issue is caused by the wrong stride of the conv1d result produced by meta_conv:
pytorch/torch/_meta_registrations.py
Lines 2453 to 2471 in 4d5b3f2
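To make the stride mismatch concrete, here is a minimal sketch (plain Python, not the actual PyTorch code; `contiguous_strides` is a hypothetical helper) of how a meta kernel can derive output strides from the output shape alone, assuming a contiguous layout:

```python
# Hypothetical sketch: a meta kernel that only knows the output shape
# typically assumes row-major (contiguous) strides.

def contiguous_strides(shape):
    """Row-major strides: innermost dim has stride 1, each outer dim
    multiplies by the size of the dim to its right."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

# A conv1d output of shape (N, C, L) = (2, 8, 10):
print(contiguous_strides([2, 8, 10]))  # [80, 10, 1]
```

If the real backend produces a differently laid-out tensor, these assumed strides no longer match what runs at execution time.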
This wrong stride is then used to codegen the size assert in Inductor:
pytorch/torch/_inductor/ir.py
Lines 6152 to 6163 in 4d5b3f2
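The generated assert is roughly of the following shape. This is a hedged sketch: `assert_size_stride` below is a stand-in written for illustration, not Inductor's actual implementation.

```python
# Illustrative stand-in for the size/stride check Inductor codegens
# for kernel outputs (not the real Inductor helper).

def assert_size_stride(size, stride, expected_size, expected_stride):
    if tuple(size) != tuple(expected_size):
        raise AssertionError(f"size mismatch: {size} vs {expected_size}")
    if tuple(stride) != tuple(expected_stride):
        raise AssertionError(f"stride mismatch: {stride} vs {expected_stride}")

# The meta kernel predicted contiguous strides (80, 10, 1), but the backend
# produced channels-last-style strides (80, 1, 8), so the assert fires:
try:
    assert_size_stride((2, 8, 10), (80, 1, 8), (2, 8, 10), (80, 10, 1))
except AssertionError as e:
    print(e)
```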
Reason
So why is the computed stride wrong in the meta_conv function? Because the corresponding backend converts conv1d to conv2d and changes the input tensor's size and memory format (to channels-last), but meta_conv does not perform this transformation, so a mismatch happens.
pytorch/aten/src/ATen/native/Convolution.cpp
Lines 1502 to 1510 in 4d5b3f2
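The stride effect of that 1d -> 2d conversion can be sketched in plain Python (illustrative only; `channels_last_strides_4d` is a hypothetical helper, not PyTorch internals): the backend unsqueezes (N, C, L) to (N, C, 1, L), may pick channels-last for the 4d tensor, and the squeezed-back 3d result inherits non-contiguous strides.

```python
# Sketch of why running conv1d as conv2d changes the output strides.

def channels_last_strides_4d(n, c, h, w):
    """Strides of an (N, C, H, W) tensor stored in NHWC (channels-last)
    order: C is innermost (stride 1), then W, then H, then N."""
    return [h * w * c, 1, w * c, c]

n, c, h, w = 2, 8, 1, 10          # conv1d output (2, 8, 10) viewed as 4d
s4 = channels_last_strides_4d(n, c, h, w)  # [80, 1, 80, 8]
# Squeeze the H == 1 dim back out -> 3d strides for (N, C, L):
s3 = [s4[0], s4[1], s4[3]]                 # [80, 1, 8]
print(s3)
```

These squeezed-back strides [80, 1, 8] differ from the contiguous [80, 10, 1] that meta_conv assumed, which is exactly the mismatch the Inductor assert trips over.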
The fix is to add the corresponding logic in meta_conv.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @eellison