-
Notifications
You must be signed in to change notification settings - Fork 24.9k
Enable output padding when only outermost dim is dynamic #159404
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3733,10 +3733,8 @@ def _pad_strides( | |
# do for dynamic shape. | ||
# | ||
# Skip padding the strides for dynamic shape for now. | ||
if not all( | ||
isinstance(s, (int, sympy.Integer)) | ||
for s in itertools.chain(in_strides, size) | ||
): | ||
# If outermost dim is dynamic, stride still can be fully static | ||
if not all(isinstance(s, (int, sympy.Integer)) for s in in_strides): | ||
return in_strides | ||
|
||
stride_order = get_stride_order(in_strides) | ||
|
@@ -3751,11 +3749,11 @@ def _pad_strides( | |
for rank, idx in enumerate(fill_order[1:], start=1): | ||
prev_idx = fill_order[rank - 1] | ||
stride = new_strides[prev_idx] * size[prev_idx] | ||
|
||
if stride > config.padding_stride_threshold and stride % align != 0: | ||
stride = ceildiv(stride, align) * align | ||
padded = True | ||
new_strides[idx] = stride | ||
if isinstance(stride, (int, sympy.Integer)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should definitively be an int right, since we checked for sympy.Integer above. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We allow both int or sympy.int above. I am not sure if it's possible that the stride values are all sympy.int or partially int/sympy.int, prob doesn't hurt to leave it though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should be an invariant based on our checks above. so it makes the code less clear |
||
if stride > config.padding_stride_threshold and stride % align != 0: | ||
stride = ceildiv(stride, align) * align | ||
padded = True | ||
new_strides[idx] = stride | ||
|
||
if not padded: | ||
# Consider a tensor with shape [256, 1, 5, 5] | ||
|
Uh oh!
There was an error while loading. Please reload this page.