Enable output padding when only outermost dim is dynamic #159404

Open · wants to merge 1 commit into base: main
105 changes: 95 additions & 10 deletions test/inductor/test_padding.py
@@ -49,6 +49,18 @@ def geninp():
return input_dict


def get_padded_stride(shape, alignment_bytes, pad_output, itemsize):
align = alignment_bytes // itemsize
new_strides = [0 for _ in range(len(shape))]
new_strides[len(shape) - 1] = 1
for i in range(len(shape) - 1, 0, -1):
stride = shape[i] * new_strides[i]
if pad_output and stride % align != 0:
stride = (stride + align - 1) // align * align
new_strides[i - 1] = stride
return tuple(new_strides)
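
As a quick sanity check of the helper above (illustration only, not part of the diff), here is what it returns for a few of the parametrized cases below, assuming float32 inputs (itemsize 4):

# Illustration only: expected results from get_padded_stride for float32 inputs.
# alignment_bytes=64 with itemsize=4 means align = 16 elements.
print(get_padded_stride((32, 30), 64, True, 4))      # (32, 1): stride 30 rounded up to 32
print(get_padded_stride((32, 30), 64, False, 4))     # (30, 1): no padding requested
print(get_padded_stride((32, 50, 30), 64, True, 4))  # (1600, 32, 1): 30 -> 32; 50*32 is already aligned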


class LinearAndSoftmax(nn.Module):
"""
It's very common that a transformer model will do a matmul and then
@@ -745,20 +757,11 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor:
input_tensors = [get_input(shape, alignment_bytes) for _ in range(num_inputs)]

config_patches = {
"compile_threads": 1,
"comprehensive_padding": pad_output,
"cpu_backend": "triton",
"disable_padding_cpu": False,
"implicit_fallbacks": False,
"inplace_buffers": False,
"padding_alignment_bytes": alignment_bytes,
"pad_channels_last": True,
"pad_outputs": True,
"padding_stride_threshold": 0,
"triton.prefer_nd_tiling": True,
"triton.use_block_ptr": True,
"triton.codegen_upcast_to_fp32": False,
"unroll_reductions_threshold": 1,
}
with config.patch(config_patches):
compiled = torch.compile(torch.cat)
@@ -767,7 +770,89 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor:
output_shape = (shape[0] * num_inputs, shape[1])
output_stride = input_tensors[0].stride()
output_line = f"buf12 = empty_strided_{GPU_TYPE}({output_shape}, {output_stride}, torch.float32)"
self.assertTrue(any(output_line in line for line in code))
self.assertTrue(output_line in code[0])

@parametrize(
"shape,alignment_bytes,pad_output",
[
((512, 1), 32, False),
((512, 1), 32, True),
((32, 30), 64, False),
((32, 30), 64, True),
((512, 100, 1), 32, False),
((512, 100, 1), 32, True),
((32, 50, 30), 64, False),
((32, 50, 30), 64, True),
],
)
def test_outer_dynamic_shape_padding(self, shape, alignment_bytes, pad_output):
"""
When only the outermost dim is dynamic, the output can still be padded
according to the padding configuration.
"""
num_inputs = 2
input_tensors = [
torch.randn(shape, dtype=torch.float32) for _ in range(num_inputs)
]

config_patches = {
"comprehensive_padding": pad_output,
"cpu_backend": "triton",
"padding_alignment_bytes": alignment_bytes,
"pad_outputs": True,
"padding_stride_threshold": 0,
}
with config.patch(config_patches):
torch._dynamo.mark_dynamic(input_tensors[0], 0)
torch._dynamo.mark_dynamic(input_tensors[1], 0)
compiled = torch.compile(torch.add)
result, _ = run_and_get_code(compiled, *input_tensors)

expected_stride = get_padded_stride(
result.shape, alignment_bytes, pad_output, result.dtype.itemsize
)
self.assertEqual(result.stride(), expected_stride)
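
For concreteness, a minimal sketch of one parametrization (illustration only, not part of the PR; it assumes the same config.patch context as in the test above). With shape (32, 30), alignment_bytes=64 and pad_output=True, the alignment is 16 float32 elements, so the contiguous stride (30, 1) is expected to be padded to (32, 1) even though dim 0 is dynamic:

# Sketch only, assuming the config.patch above is active
# (comprehensive_padding=True, pad_outputs=True, padding_alignment_bytes=64).
a = torch.randn(32, 30)
b = torch.randn(32, 30)
torch._dynamo.mark_dynamic(a, 0)  # only the outermost dim is dynamic
torch._dynamo.mark_dynamic(b, 0)
out = torch.compile(torch.add)(a, b)
# Expected: out.stride() == (32, 1) rather than the contiguous (30, 1),
# matching get_padded_stride((32, 30), 64, True, 4).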

@parametrize(
"shape,alignment_bytes,pad_output",
[
((500, 10, 1), 32, False),
((500, 20, 1), 32, True),
((30, 10, 20), 64, True),
((30, 10, 20), 64, False),
],
)
def test_perm_outer_dynamic_shape_padding(self, shape, alignment_bytes, pad_output):
"""
When only the outermost dim is dynamic, the output can still be padded
according to the padding configuration. Test the case where this happens
after a permute op.
"""

def permute_contig(x):
return torch.transpose(x, 0, 2).contiguous()

num_inputs = 1
input_tensors = [
torch.randn(shape, dtype=torch.float32) for _ in range(num_inputs)
]

config_patches = {
"comprehensive_padding": pad_output,
"cpu_backend": "triton",
"padding_alignment_bytes": alignment_bytes,
"pad_outputs": True,
"padding_stride_threshold": 0,
"triton.use_block_ptr": True,
}
with config.patch(config_patches):
torch._dynamo.mark_dynamic(input_tensors[0], 2)
compiled = torch.compile(permute_contig)
result, _ = run_and_get_code(compiled, *input_tensors)

expected_stride = get_padded_stride(
result.shape, alignment_bytes, pad_output, result.dtype.itemsize
)
self.assertEqual(result.stride(), expected_stride)
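
As a worked example (illustration only): for input shape (30, 10, 20) with dim 2 marked dynamic, permute_contig returns a tensor of shape (20, 10, 30), so the dynamic dim becomes the outermost dim of the output. With alignment_bytes=64 (16 float32 elements), the contiguous strides (300, 30, 1) are expected to be padded to (320, 32, 1):

# Illustration only: expected padded layout of the permuted, contiguous output.
print(get_padded_stride((20, 10, 30), 64, True, 4))   # (320, 32, 1): 30 -> 32; 10*32 is already aligned
print(get_padded_stride((20, 10, 30), 64, False, 4))  # (300, 30, 1): no padding requested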


if __name__ == "__main__":
16 changes: 7 additions & 9 deletions torch/_inductor/ir.py
@@ -3733,10 +3733,8 @@ def _pad_strides(
# do for dynamic shape.
#
# Skip padding the strides for dynamic shape for now.
if not all(
isinstance(s, (int, sympy.Integer))
for s in itertools.chain(in_strides, size)
):
# If only the outermost dim is dynamic, the strides can still be fully static,
# so check in_strides alone rather than in_strides and size.
if not all(isinstance(s, (int, sympy.Integer)) for s in in_strides):
return in_strides

stride_order = get_stride_order(in_strides)
@@ -3751,11 +3749,11 @@
for rank, idx in enumerate(fill_order[1:], start=1):
prev_idx = fill_order[rank - 1]
stride = new_strides[prev_idx] * size[prev_idx]

if stride > config.padding_stride_threshold and stride % align != 0:
stride = ceildiv(stride, align) * align
padded = True
new_strides[idx] = stride
if isinstance(stride, (int, sympy.Integer)):
Contributor: This should definitely be an int, right, since we checked for sympy.Integer above?

Contributor (author): We allow both int and sympy.Integer above. I am not sure whether the stride values can be all sympy.Integer or a mix of int and sympy.Integer; it probably doesn't hurt to leave the check in, though.

Contributor: It should be an invariant based on our checks above, so it makes the code less clear.

if stride > config.padding_stride_threshold and stride % align != 0:
stride = ceildiv(stride, align) * align
padded = True
new_strides[idx] = stride

if not padded:
# Consider a tensor with shape [256, 1, 5, 5]
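
To summarize the effect of this change, below is a simplified, standalone model of the padding rule (illustration only; it assumes a contiguous fill order and omits the fill_order/stride_order bookkeeping of the real _pad_strides). Because the stride of the outermost dim never multiplies in the outermost size, a dynamic outermost dim leaves every computed stride a concrete integer, so padding can still be applied:

# Simplified standalone sketch of the padding rule; not the actual inductor code.
import sympy

def ceildiv(a, b):
    return -(-a // b)

def pad_contiguous_strides(size, align, threshold=0):
    # size is ordered outermost -> innermost; strides are built innermost-out.
    strides = [1] * len(size)
    for i in range(len(size) - 2, -1, -1):
        stride = strides[i + 1] * size[i + 1]
        if isinstance(stride, (int, sympy.Integer)):
            if stride > threshold and stride % align != 0:
                stride = ceildiv(stride, align) * align
        strides[i] = stride
    return strides

s0 = sympy.Symbol("s0", positive=True, integer=True)
# Only the outermost dim is dynamic: every computed stride is still a plain int.
print(pad_contiguous_strides([s0, 50, 30], align=16))  # [1600, 32, 1]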