@@ -49,6 +49,18 @@ def geninp():
     return input_dict


+def get_padded_stride(shape, alignment_bytes, pad_output, itemsize):
+    align = alignment_bytes // itemsize
+    new_strides = [0 for _ in range(len(shape))]
+    new_strides[len(shape) - 1] = 1
+    for i in range(len(shape) - 1, 0, -1):
+        stride = shape[i] * new_strides[i]
+        if pad_output and stride % align != 0:
+            stride = (stride + align - 1) // align * align
+        new_strides[i - 1] = stride
+    return tuple(new_strides)
+
+
 class LinearAndSoftmax(nn.Module):
     """
     It's very common that a transformer model will do a matmul and then
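
A quick sanity check of what the new get_padded_stride helper returns for two of
the shape/alignment pairs used in the parametrized tests below, assuming float32
inputs (itemsize 4); this is illustration only, not part of the diff:

    # 2D case (32, 30) with 64-byte alignment: align = 64 // 4 = 16 elements.
    # The innermost stride stays 1; the row stride 30 is rounded up to 32 when
    # padding is enabled, otherwise it keeps the contiguous value 30.
    assert get_padded_stride((32, 30), 64, pad_output=True, itemsize=4) == (32, 1)
    assert get_padded_stride((32, 30), 64, pad_output=False, itemsize=4) == (30, 1)

    # 3D case (512, 100, 1) with 32-byte alignment: align = 8 elements.
    # The size-1 innermost dim yields stride 1, padded up to 8; the next stride
    # 100 * 8 = 800 is already a multiple of 8, so it is left unchanged.
    assert get_padded_stride((512, 100, 1), 32, pad_output=True, itemsize=4) == (800, 8, 1)
    assert get_padded_stride((512, 100, 1), 32, pad_output=False, itemsize=4) == (100, 1, 1)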
@@ -745,20 +757,11 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor:
         input_tensors = [get_input(shape, alignment_bytes) for _ in range(num_inputs)]

         config_patches = {
-            "compile_threads": 1,
             "comprehensive_padding": pad_output,
             "cpu_backend": "triton",
-            "disable_padding_cpu": False,
-            "implicit_fallbacks": False,
-            "inplace_buffers": False,
             "padding_alignment_bytes": alignment_bytes,
-            "pad_channels_last": True,
             "pad_outputs": True,
             "padding_stride_threshold": 0,
-            "triton.prefer_nd_tiling": True,
-            "triton.use_block_ptr": True,
-            "triton.codegen_upcast_to_fp32": False,
-            "unroll_reductions_threshold": 1,
         }
         with config.patch(config_patches):
             compiled = torch.compile(torch.cat)
@@ -767,7 +770,89 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor:
         output_shape = (shape[0] * num_inputs, shape[1])
         output_stride = input_tensors[0].stride()
         output_line = f"buf12 = empty_strided_{GPU_TYPE}({output_shape}, {output_stride}, torch.float32)"
-        self.assertTrue(any(output_line in line for line in code))
+        self.assertTrue(output_line in code[0])
+
+    @parametrize(
+        "shape,alignment_bytes,pad_output",
+        [
+            ((512, 1), 32, False),
+            ((512, 1), 32, True),
+            ((32, 30), 64, False),
+            ((32, 30), 64, True),
+            ((512, 100, 1), 32, False),
+            ((512, 100, 1), 32, True),
+            ((32, 50, 30), 64, False),
+            ((32, 50, 30), 64, True),
+        ],
+    )
+    def test_outer_dynamic_shape_padding(self, shape, alignment_bytes, pad_output):
+        """
+        When only the outermost dim has a dynamic shape, the output can still be
+        padded according to the padding configuration.
+        """
+        num_inputs = 2
+        input_tensors = [
+            torch.randn(shape, dtype=torch.float32) for _ in range(num_inputs)
+        ]
+
+        config_patches = {
+            "comprehensive_padding": pad_output,
+            "cpu_backend": "triton",
+            "padding_alignment_bytes": alignment_bytes,
+            "pad_outputs": True,
+            "padding_stride_threshold": 0,
+        }
+        with config.patch(config_patches):
+            torch._dynamo.mark_dynamic(input_tensors[0], 0)
+            torch._dynamo.mark_dynamic(input_tensors[1], 0)
+            compiled = torch.compile(torch.add)
+            result, _ = run_and_get_code(compiled, *input_tensors)
+
+        expected_stride = get_padded_stride(
+            result.shape, alignment_bytes, pad_output, result.dtype.itemsize
+        )
+        self.assertEqual(result.stride(), expected_stride)
+
+    @parametrize(
+        "shape,alignment_bytes,pad_output",
+        [
+            ((500, 10, 1), 32, False),
+            ((500, 20, 1), 32, True),
+            ((30, 10, 20), 64, True),
+            ((30, 10, 20), 64, False),
+        ],
+    )
+    def test_perm_outer_dynamic_shape_padding(self, shape, alignment_bytes, pad_output):
+        """
+        When only the outermost dim has a dynamic shape, the output can still be padded
+        according to the padding configuration. Test when this occurs after a permute op.
+        """
+
+        def permute_contig(x):
+            return torch.transpose(x, 0, 2).contiguous()
+
+        num_inputs = 1
+        input_tensors = [
+            torch.randn(shape, dtype=torch.float32) for _ in range(num_inputs)
+        ]
+
+        config_patches = {
+            "comprehensive_padding": pad_output,
+            "cpu_backend": "triton",
+            "padding_alignment_bytes": alignment_bytes,
+            "pad_outputs": True,
+            "padding_stride_threshold": 0,
+            "triton.use_block_ptr": True,
+        }
+        with config.patch(config_patches):
+            torch._dynamo.mark_dynamic(input_tensors[0], 2)
+            compiled = torch.compile(permute_contig)
+            result, _ = run_and_get_code(compiled, *input_tensors)
+
+        expected_stride = get_padded_stride(
+            result.shape, alignment_bytes, pad_output, result.dtype.itemsize
+        )
+        self.assertEqual(result.stride(), expected_stride)


 if __name__ == "__main__":
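
Outside the test harness, the behavior these new tests assert can be reproduced
with roughly the following standalone sketch (hypothetical script; it assumes a
build where the patched inductor config keys above, including the Triton CPU
backend, are available):

    import torch
    import torch._dynamo
    from torch._inductor import config

    # Same padding-related config keys the tests patch above.
    patches = {
        "comprehensive_padding": True,
        "cpu_backend": "triton",
        "padding_alignment_bytes": 32,
        "pad_outputs": True,
        "padding_stride_threshold": 0,
    }

    a = torch.randn(512, 100, 1, dtype=torch.float32)
    b = torch.randn(512, 100, 1, dtype=torch.float32)
    # Mark only the outermost dim as dynamic, mirroring test_outer_dynamic_shape_padding.
    torch._dynamo.mark_dynamic(a, 0)
    torch._dynamo.mark_dynamic(b, 0)

    with config.patch(patches):
        out = torch.compile(torch.add)(a, b)

    # Per the assertions above, the output should carry the padded strides computed
    # by get_padded_stride, e.g. (800, 8, 1) for this shape instead of the
    # contiguous (100, 1, 1).
    print(out.stride())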