Commit 2cf2772

Enable output padding when only outermost dim is dynamic (#159404)
Summary:
Pull Request resolved: #159404

When the shape of the output tensor has a dynamic outermost dim, the stride can still be padded to conform to the configured alignment when the padding config specifies it.

Test Plan: CI

Rollback Plan:

Reviewed By: blaine-rister, eellison

Differential Revision: D79146886
1 parent a53d14d commit 2cf2772
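
For context, a minimal sketch (not part of this commit) of how the new behavior can be exercised end to end. The config keys mirror the tests added below; the concrete shape, dtype, and alignment are illustrative assumptions, and the exact padded strides depend on backend and build.

# Sketch: exercise output-stride padding with a dynamic outermost dim.
import torch
import torch._dynamo
from torch._inductor import config

a = torch.randn((32, 50, 30), dtype=torch.float32)
b = torch.randn((32, 50, 30), dtype=torch.float32)

patches = {
    "comprehensive_padding": True,   # enable stride padding
    "padding_alignment_bytes": 64,   # pad strides to 64-byte multiples
    "pad_outputs": True,             # also pad graph outputs
    "padding_stride_threshold": 0,   # pad regardless of stride size
}

with config.patch(patches):
    # Only dim 0 (the outermost dim) is dynamic; the inner dims stay static,
    # so every stride remains a static integer and can still be padded.
    torch._dynamo.mark_dynamic(a, 0)
    torch._dynamo.mark_dynamic(b, 0)
    compiled = torch.compile(torch.add)
    out = compiled(a, b)

# With float32 (4-byte elements) and 64-byte alignment, the inner stride of 30
# is expected to round up to 32, giving strides (1600, 32, 1) rather than the
# contiguous (1500, 30, 1).
print(out.stride())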

File tree: 2 files changed, +102 −19 lines


test/inductor/test_padding.py

Lines changed: 95 additions & 10 deletions
@@ -49,6 +49,18 @@ def geninp():
     return input_dict
 
 
+def get_padded_stride(shape, alignment_bytes, pad_output, itemsize):
+    align = alignment_bytes // itemsize
+    new_strides = [0 for _ in range(len(shape))]
+    new_strides[len(shape) - 1] = 1
+    for i in range(len(shape) - 1, 0, -1):
+        stride = shape[i] * new_strides[i]
+        if pad_output and stride % align != 0:
+            stride = (stride + align - 1) // align * align
+        new_strides[i - 1] = stride
+    return tuple(new_strides)
+
+
 class LinearAndSoftmax(nn.Module):
     """
     It's very common that a transformer model will do a matmul and then
@@ -745,20 +757,11 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor:
         input_tensors = [get_input(shape, alignment_bytes) for _ in range(num_inputs)]
 
         config_patches = {
-            "compile_threads": 1,
             "comprehensive_padding": pad_output,
             "cpu_backend": "triton",
-            "disable_padding_cpu": False,
-            "implicit_fallbacks": False,
-            "inplace_buffers": False,
             "padding_alignment_bytes": alignment_bytes,
-            "pad_channels_last": True,
             "pad_outputs": True,
             "padding_stride_threshold": 0,
-            "triton.prefer_nd_tiling": True,
-            "triton.use_block_ptr": True,
-            "triton.codegen_upcast_to_fp32": False,
-            "unroll_reductions_threshold": 1,
         }
         with config.patch(config_patches):
             compiled = torch.compile(torch.cat)
@@ -767,7 +770,89 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor:
         output_shape = (shape[0] * num_inputs, shape[1])
         output_stride = input_tensors[0].stride()
         output_line = f"buf12 = empty_strided_{GPU_TYPE}({output_shape}, {output_stride}, torch.float32)"
-        self.assertTrue(any(output_line in line for line in code))
+        self.assertTrue(output_line in code[0])
+
+    @parametrize(
+        "shape,alignment_bytes,pad_output",
+        [
+            ((512, 1), 32, False),
+            ((512, 1), 32, True),
+            ((32, 30), 64, False),
+            ((32, 30), 64, True),
+            ((512, 100, 1), 32, False),
+            ((512, 100, 1), 32, True),
+            ((32, 50, 30), 64, False),
+            ((32, 50, 30), 64, True),
+        ],
+    )
+    def test_outer_dynamic_shape_padding(self, shape, alignment_bytes, pad_output):
+        """
+        When only the outermost dim has a dynamic shape, the output can still be
+        padded based on the padding configuration.
+        """
+        num_inputs = 2
+        input_tensors = [
+            torch.randn(shape, dtype=torch.float32) for _ in range(num_inputs)
+        ]
+
+        config_patches = {
+            "comprehensive_padding": pad_output,
+            "cpu_backend": "triton",
+            "padding_alignment_bytes": alignment_bytes,
+            "pad_outputs": True,
+            "padding_stride_threshold": 0,
+        }
+        with config.patch(config_patches):
+            torch._dynamo.mark_dynamic(input_tensors[0], 0)
+            torch._dynamo.mark_dynamic(input_tensors[1], 0)
+            compiled = torch.compile(torch.add)
+            result, _ = run_and_get_code(compiled, *input_tensors)
+
+        expected_stride = get_padded_stride(
+            result.shape, alignment_bytes, pad_output, result.dtype.itemsize
+        )
+        self.assertEqual(result.stride(), expected_stride)
+
+    @parametrize(
+        "shape,alignment_bytes,pad_output",
+        [
+            ((500, 10, 1), 32, False),
+            ((500, 20, 1), 32, True),
+            ((30, 10, 20), 64, True),
+            ((30, 10, 20), 64, False),
+        ],
+    )
+    def test_perm_outer_dynamic_shape_padding(self, shape, alignment_bytes, pad_output):
+        """
+        When only the outermost dim has a dynamic shape, the output can still be padded
+        based on the padding configuration. Test when this occurs after a permute op.
+        """
+
+        def permute_contig(x):
+            return torch.transpose(x, 0, 2).contiguous()
+
+        num_inputs = 1
+        input_tensors = [
+            torch.randn(shape, dtype=torch.float32) for _ in range(num_inputs)
+        ]
+
+        config_patches = {
+            "comprehensive_padding": pad_output,
+            "cpu_backend": "triton",
+            "padding_alignment_bytes": alignment_bytes,
+            "pad_outputs": True,
+            "padding_stride_threshold": 0,
+            "triton.use_block_ptr": True,
+        }
+        with config.patch(config_patches):
+            torch._dynamo.mark_dynamic(input_tensors[0], 2)
+            compiled = torch.compile(permute_contig)
+            result, _ = run_and_get_code(compiled, *input_tensors)
+
+        expected_stride = get_padded_stride(
+            result.shape, alignment_bytes, pad_output, result.dtype.itemsize
+        )
+        self.assertEqual(result.stride(), expected_stride)
 
 
 if __name__ == "__main__":

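The get_padded_stride helper added above recomputes the expected output strides from the innermost dim outward, rounding each stride up to the configured alignment when padding is enabled. A standalone sketch of that arithmetic with concrete, illustrative numbers:

# Sketch of the alignment math encoded by get_padded_stride (illustrative values).
def padded_strides(shape, alignment_bytes, itemsize):
    align = alignment_bytes // itemsize  # alignment expressed in elements
    strides = [0] * len(shape)
    strides[-1] = 1  # innermost dim stays contiguous
    for i in range(len(shape) - 1, 0, -1):
        stride = shape[i] * strides[i]  # unpadded stride of the next-outer dim
        if stride % align != 0:
            stride = (stride + align - 1) // align * align  # round up to alignment
        strides[i - 1] = stride
    return tuple(strides)

# float32 (4-byte) tensor of shape (32, 50, 30) with 64-byte alignment:
# align = 64 // 4 = 16 elements, the inner stride 30 rounds up to 32,
# and 50 * 32 = 1600 is already a multiple of 16, so it stays as is.
assert padded_strides((32, 50, 30), 64, 4) == (1600, 32, 1)
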
torch/_inductor/ir.py

Lines changed: 7 additions & 9 deletions
@@ -3707,10 +3707,8 @@ def _pad_strides(
         # do for dynamic shape.
         #
         # Skip padding the strides for dynamic shape for now.
-        if not all(
-            isinstance(s, (int, sympy.Integer))
-            for s in itertools.chain(in_strides, size)
-        ):
+        # If outermost dim is dynamic, stride still can be fully static
+        if not all(isinstance(s, (int, sympy.Integer)) for s in in_strides):
             return in_strides
 
         stride_order = get_stride_order(in_strides)
@@ -3725,11 +3723,11 @@ def _pad_strides(
         for rank, idx in enumerate(fill_order[1:], start=1):
             prev_idx = fill_order[rank - 1]
             stride = new_strides[prev_idx] * size[prev_idx]
-
-            if stride > config.padding_stride_threshold and stride % align != 0:
-                stride = ceildiv(stride, align) * align
-                padded = True
-            new_strides[idx] = stride
+            if isinstance(stride, (int, sympy.Integer)):
+                if stride > config.padding_stride_threshold and stride % align != 0:
+                    stride = ceildiv(stride, align) * align
+                    padded = True
+            new_strides[idx] = stride
 
         if not padded:
             # Consider a tensor with shape [256, 1, 5, 5]

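The relaxed check is safe because, for a contiguous layout, the stride of a dim is the product of the sizes of the dims inside it, so a dynamic outermost size never appears in any stride. A small illustrative sketch (not code from this commit), using sympy to stand in for the dynamic dim:

# Why a dynamic outermost dim leaves every stride static (illustrative sketch).
import sympy

s0 = sympy.Symbol("s0", positive=True, integer=True)  # dynamic batch size
size = [s0, 50, 30]

strides = [1] * len(size)
for i in range(len(size) - 2, -1, -1):
    strides[i] = strides[i + 1] * size[i + 1]  # stride depends only on inner sizes

print(strides)  # [1500, 30, 1] -- plain ints, no symbols
print(all(isinstance(s, (int, sympy.Integer)) for s in strides))  # True

# The old check also inspected `size`, so the symbolic s0 disabled padding even
# though every stride was static; the new check looks only at the strides, and
# the per-stride isinstance guard in the loop skips any stride that is symbolic.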