diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py
index 15c1abdf32db..41944a916923 100644
--- a/test/inductor/test_padding.py
+++ b/test/inductor/test_padding.py
@@ -49,6 +49,18 @@ def geninp():
     return input_dict
 
 
+def get_padded_stride(shape, alignment_bytes, pad_output, itemsize):
+    align = alignment_bytes // itemsize
+    new_strides = [0 for _ in range(len(shape))]
+    new_strides[len(shape) - 1] = 1
+    for i in range(len(shape) - 1, 0, -1):
+        stride = shape[i] * new_strides[i]
+        if pad_output and stride % align != 0:
+            stride = (stride + align - 1) // align * align
+        new_strides[i - 1] = stride
+    return tuple(new_strides)
+
+
 class LinearAndSoftmax(nn.Module):
     """
     It's very common that a transformer model will do a matmul and then
@@ -745,20 +757,11 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor:
         input_tensors = [get_input(shape, alignment_bytes) for _ in range(num_inputs)]
 
         config_patches = {
-            "compile_threads": 1,
             "comprehensive_padding": pad_output,
             "cpu_backend": "triton",
-            "disable_padding_cpu": False,
-            "implicit_fallbacks": False,
-            "inplace_buffers": False,
             "padding_alignment_bytes": alignment_bytes,
-            "pad_channels_last": True,
             "pad_outputs": True,
             "padding_stride_threshold": 0,
-            "triton.prefer_nd_tiling": True,
-            "triton.use_block_ptr": True,
-            "triton.codegen_upcast_to_fp32": False,
-            "unroll_reductions_threshold": 1,
         }
         with config.patch(config_patches):
             compiled = torch.compile(torch.cat)
@@ -767,7 +770,89 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor:
         output_shape = (shape[0] * num_inputs, shape[1])
         output_stride = input_tensors[0].stride()
         output_line = f"buf12 = empty_strided_{GPU_TYPE}({output_shape}, {output_stride}, torch.float32)"
-        self.assertTrue(any(output_line in line for line in code))
+        self.assertTrue(output_line in code[0])
+
+    @parametrize(
+        "shape,alignment_bytes,pad_output",
+        [
+            ((512, 1), 32, False),
+            ((512, 1), 32, True),
+            ((32, 30), 64, False),
+            ((32, 30), 64, True),
+            ((512, 100, 1), 32, False),
+            ((512, 100, 1), 32, True),
+            ((32, 50, 30), 64, False),
+            ((32, 50, 30), 64, True),
+        ],
+    )
+    def test_outer_dynamic_shape_padding(self, shape, alignment_bytes, pad_output):
+        """
+        When only the outermost dim is dynamic, the output can still be padded up
+        based on the padding configuration.
+        """
+        num_inputs = 2
+        input_tensors = [
+            torch.randn(shape, dtype=torch.float32) for _ in range(num_inputs)
+        ]
+
+        config_patches = {
+            "comprehensive_padding": pad_output,
+            "cpu_backend": "triton",
+            "padding_alignment_bytes": alignment_bytes,
+            "pad_outputs": True,
+            "padding_stride_threshold": 0,
+        }
+        with config.patch(config_patches):
+            torch._dynamo.mark_dynamic(input_tensors[0], 0)
+            torch._dynamo.mark_dynamic(input_tensors[1], 0)
+            compiled = torch.compile(torch.add)
+            result, _ = run_and_get_code(compiled, *input_tensors)
+
+        expected_stride = get_padded_stride(
+            result.shape, alignment_bytes, pad_output, result.dtype.itemsize
+        )
+        self.assertEqual(result.stride(), expected_stride)
+
+    @parametrize(
+        "shape,alignment_bytes,pad_output",
+        [
+            ((500, 10, 1), 32, False),
+            ((500, 20, 1), 32, True),
+            ((30, 10, 20), 64, True),
+            ((30, 10, 20), 64, False),
+        ],
+    )
+    def test_perm_outer_dynamic_shape_padding(self, shape, alignment_bytes, pad_output):
+        """
+        When only the outermost dim is dynamic, the output can still be padded up
+        based on the padding configuration. Test when this occurs after a permute op.
+        """
+
+        def permute_contig(x):
+            return torch.transpose(x, 0, 2).contiguous()
+
+        num_inputs = 1
+        input_tensors = [
+            torch.randn(shape, dtype=torch.float32) for _ in range(num_inputs)
+        ]
+
+        config_patches = {
+            "comprehensive_padding": pad_output,
+            "cpu_backend": "triton",
+            "padding_alignment_bytes": alignment_bytes,
+            "pad_outputs": True,
+            "padding_stride_threshold": 0,
+            "triton.use_block_ptr": True,
+        }
+        with config.patch(config_patches):
+            torch._dynamo.mark_dynamic(input_tensors[0], 2)
+            compiled = torch.compile(permute_contig)
+            result, _ = run_and_get_code(compiled, *input_tensors)
+
+        expected_stride = get_padded_stride(
+            result.shape, alignment_bytes, pad_output, result.dtype.itemsize
+        )
+        self.assertEqual(result.stride(), expected_stride)
 
 
 if __name__ == "__main__":
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index a668cd41ebf1..39c344a18466 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -3733,10 +3733,8 @@ def _pad_strides(
         # do for dynamic shape.
         #
         # Skip padding the strides for dynamic shape for now.
-        if not all(
-            isinstance(s, (int, sympy.Integer))
-            for s in itertools.chain(in_strides, size)
-        ):
+        # If only the outermost dim is dynamic, the strides can still be fully static.
+        if not all(isinstance(s, (int, sympy.Integer)) for s in in_strides):
             return in_strides
 
         stride_order = get_stride_order(in_strides)
@@ -3751,11 +3749,11 @@ def _pad_strides(
         for rank, idx in enumerate(fill_order[1:], start=1):
             prev_idx = fill_order[rank - 1]
             stride = new_strides[prev_idx] * size[prev_idx]
-
-            if stride > config.padding_stride_threshold and stride % align != 0:
-                stride = ceildiv(stride, align) * align
-                padded = True
-            new_strides[idx] = stride
+            if isinstance(stride, (int, sympy.Integer)):
+                if stride > config.padding_stride_threshold and stride % align != 0:
+                    stride = ceildiv(stride, align) * align
+                    padded = True
+                new_strides[idx] = stride
 
         if not padded:
             # Consider a tensor with shape [256, 1, 5, 5]