Enable output padding when only outermost dim is dynamic #159404
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159404
Note: links to docs will display an error until the docs builds have been completed.
⏳ No failures, 12 pending as of commit 7095895 with merge base ee9f8ba.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D79146886
Force-pushed 9c66df6 to 7d7ba40
Summary: When the shape of the output tensor has a dynamic outermost dim, the strides can still be padded to conform to the configured alignment, if specified by the padding config.
Test Plan: CI
Reviewed By: blaine-rister, eellison
Differential Revision: D79146886
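To make the summary concrete, here is a small, hedged sketch in pure Python (not the Inductor implementation; the alignment value, shapes, and `ceildiv` helper are illustrative) of why a dynamic outermost dim does not block padding: no stride depends on the outermost extent.

```python
# Output shape is (s0, 61, 7), where s0 is a dynamic batch-like dim.
# Contiguous strides would be (61 * 7, 7, 1); padding rounds each
# non-unit stride up to a multiple of ALIGN. None of them involve s0.
ALIGN = 32  # illustrative alignment, in elements

def ceildiv(a: int, b: int) -> int:
    return -(-a // b)

strides = [1]                                             # most dense dim
strides.insert(0, ceildiv(7 * strides[0], ALIGN) * ALIGN)   # 7 -> 32
strides.insert(0, ceildiv(61 * strides[0], ALIGN) * ALIGN)  # 61 * 32 = 1952, already aligned

print(strides)  # [1952, 32, 1]: fully static despite dynamic s0
```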
Force-pushed 7d7ba40 to 950ded0
Force-pushed 950ded0 to d39e10a
Force-pushed d39e10a to 1b8e223
Force-pushed 1b8e223 to 539c474
Nice generalization, but one comment.
Also, there's nothing preventing us from padding dynamic strides. If you look at the sdpa lowering in lowerings.py, we pad the inputs to dynamically aligned strides.
The only thing to think about is the stride-padding threshold heuristic in the lowering.
torch/_inductor/ir.py (outdated):

```diff
 if not all(
     isinstance(s, (int, sympy.Integer))
-    for s in itertools.chain(in_strides, size)
+    for s in itertools.chain(in_strides, size[1:])
```
The stride that stays static even when the corresponding dimension is dynamic belongs to the least dense dimension, not necessarily the last dimension. It's only the last dimension when the tensor is contiguous.
Would it be sufficient to just check strides here 🤔
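For illustration, a minimal PyTorch sketch (not the code under review) of that point: after a transpose, the dimension with the largest stride is no longer dim 0, so checking a fixed position in the shape is not enough.

```python
import torch

# Contiguous (batch, M, N): strides are (M*N, N, 1). The batch size
# never appears in any stride, so a dynamic batch leaves strides static.
x = torch.empty(4, 8, 16)
print(x.shape, x.stride())   # torch.Size([4, 8, 16]) (128, 16, 1)

# After a transpose the least dense dimension moves: the largest
# stride is now at dim 1 rather than dim 0.
y = x.transpose(0, 1)
print(y.shape, y.stride())   # torch.Size([8, 4, 16]) (16, 128, 1)
```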
Is there a good way to determine which dim is the outermost dim in the shape? Perhaps...
What about checking `isinstance(s, (int, sympy.Integer))` to see if the strides are static, then picking the least one? A good test case would be an add -> transpose fusion.
To generalize this to symbolic strides, I have a hunch we could sort them by `V.graph.sizevars.statically_known_leq`. For example, if `min(s0, s1) >= 1`, then `1 <= s0 <= s0 * s1`. But if we want a strict ordering, we may need to assume `min(s0, s1) > 1`, or substitute all the symbols with 2. (This makes a pretty strong assumption about what expressions strides can come from.) If this is false, the strides will end up being equal anyway, so it may not matter which one we pick.
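A hedged sketch of that substitution idea with plain sympy (the symbols and stride expressions here are made up for illustration; this is not an Inductor API):

```python
import sympy

s0, s1 = sympy.symbols("s0 s1", positive=True, integer=True)
strides = [s0 * s1, sympy.Integer(1), s0]

def density_key(stride):
    # Substituting 2 for every free symbol gives a concrete value to
    # sort by; this assumes the relative order of strides does not
    # depend on the actual symbol values (a strong assumption).
    return stride.subs({sym: 2 for sym in stride.free_symbols})

print(sorted(strides, key=density_key))  # [1, s0, s0*s1]
```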
Do we even need to check the shapes here? If we check the strides, wouldn't that encompass the above check?
Yup, I think you're right: we just need to check strides. I can take a look at sdpa in a follow-up. We should probably generalize the padding to all dimensions with dynamic shapes, if possible.
@eellison @nandesuka I took another look at the existing code. I'm wondering if this comment is up to date. It seems like it calls get_stride_order to determine the least stride, and that uses size hints under the hood when it sees dynamic shapes. So static shapes may be too strict a requirement, although we might still require backed symints.
I'm wondering if the existing logic can generalize to dynamic shapes/strides if we relax checks like
`if stride > config.padding_stride_threshold and stride % align != 0`
to something like
`if V.graph.sizevars.statically_known_geq(stride, config.padding_stride_threshold)`
Similarly, dynamic strides could be computed with sympy, so that
`stride = ceildiv(stride, align) * align`
becomes
`stride = CeilDiv(stride, align) * align`
That being said, the dynamic stride formulae could get pretty complex, so maybe it's more practical to only pad when
`isinstance(stride, (int, sympy.Integer))`
Alternatively, if we want to fully support dynamic-shape padding, there's a trick which could simplify the formulae: the least stride is always 1, and the second-least stride is `CeilDiv(stride, align) * align`. Then all the other strides follow `stride[k] = shape[k - 1] * stride[k - 1]`. So it seems like we really only need a complex formula for the 2nd-least stride. Does that seem right?
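To sanity-check that trick, a hedged sympy sketch (Inductor defines its own CeilDiv op; sympy's `ceiling` stands in here, and the shapes are invented for illustration):

```python
import sympy
from sympy import ceiling

def padded_strides(shape, align):
    # `shape` is listed from most dense to least dense dimension.
    strides = [sympy.Integer(1)]                        # least stride is always 1
    strides.append(ceiling(shape[0] / align) * align)   # 2nd-least: the only ceil-div
    for k in range(2, len(shape)):
        strides.append(shape[k - 1] * strides[k - 1])   # plain products after that
    return strides

s0 = sympy.Symbol("s0", positive=True, integer=True)
print(padded_strides([s0, 33, 7], 16))
# [1, 16*ceiling(s0/16), 528*ceiling(s0/16)]
```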
For this PR, would it make sense to pad only static strides, then do the dynamic shapes in a follow-up?
Makes sense to me. The static inner stride case comes up a lot in practice with things like dynamic batch size.
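For instance (a hedged illustration, not part of the PR): with a dynamic batch, every stride of a contiguous output depends only on the static inner sizes.

```python
import torch

# Strides of contiguous (batch, 61, 7) outputs are (427, 7, 1) no
# matter what batch is, so they remain static, and thus paddable at
# compile time, even when the batch dim is dynamic.
for batch in (1, 5, 128):
    print(torch.empty(batch, 61, 7).stride())  # (427, 7, 1) each time
```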
Force-pushed 539c474 to 6cc4985
Force-pushed 6eac7e2 to 45724e9
Force-pushed 45724e9 to c4e8773
Force-pushed c4e8773 to da5ef22
Force-pushed da5ef22 to 63a2cc6
Force-pushed 63a2cc6 to ff38ac9
Force-pushed ff38ac9 to d28c92b
Force-pushed d28c92b to 40729d5
Force-pushed 40729d5 to e09ee2d
Force-pushed e09ee2d to 280b130
Force-pushed 280b130 to 7095895
Summary: When the shape of the output tensor has a dynamic outermost dim, the strides can still be padded to conform to the configured alignment, if required.
Test Plan: CI
Differential Revision: D79146886
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben