@@ -297,13 +297,13 @@ static PoolSizes process_pool_sizes(const Tensor& input,
297
297
pooling_dims,
298
298
" ints" );
299
299
300
- TORCH_CHECK (stride.empty () || stride.size () == 1 || stride.size () == 3 ,
300
+ TORCH_CHECK (stride.empty () || stride.size () == 1 || stride.size () == pooling_dims ,
301
301
op_name,
302
302
" : stride must either be omitted, a single int, or a tuple of " ,
303
303
pooling_dims,
304
304
" ints" );
305
305
306
- TORCH_CHECK (padding.size () == 1 || padding.size () == 3 ,
306
+ TORCH_CHECK (padding.size () == 1 || padding.size () == pooling_dims ,
307
307
op_name,
308
308
" : padding must either be a single int, or a tuple of " ,
309
309
pooling_dims,
@@ -333,6 +333,22 @@ static PoolSizes process_pool_sizes(const Tensor& input,
333
333
" : pad should be at most half of effective kernel size" );
334
334
}
335
335
336
+ if (pooling_dims == 2 ) {
337
+ const auto memory_format = input.suggest_memory_format ();
338
+ bool valid_dims = input.size (1 ) != 0 && input.size (2 ) != 0 ;
339
+ if (memory_format == at::MemoryFormat::ChannelsLast) {
340
+ // Expect tensor in NHWC format and allow 0-dim only for N.
341
+ TORCH_CHECK ((dims == 4 && valid_dims && input.size (3 ) != 0 ),
342
+ " Expected 4D (batch mode) tensor expected for input with channels_last layout"
343
+ " with optional 0 dim batch size for input, but got: " ,
344
+ input.sizes ());
345
+ } else {
346
+ TORCH_CHECK ((dims == 3 && input.size (0 ) != 0 && valid_dims) || (dims == 4 && valid_dims && input.size (3 ) != 0 ),
347
+ " Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:" ,
348
+ input.sizes ());
349
+ }
350
+ }
351
+
336
352
for (const auto dim : c10::irange (static_cast <int >(leading_dims == 2 ), dims)) {
337
353
TORCH_CHECK (input.size (dim) > 0 , op_name, " : Expected input's non-batch dimensions to have positive length" );
338
354
}
@@ -786,31 +802,54 @@ static void avg_pool_backward_out_mps_template(const Tensor& grad_input,
786
802
787
803
} // namespace mps
788
804
805
+ // TODO: The MPS graph impl can sometimes give significantly better performance
806
+ // than the Metal impl for cases where the stride is 1 in all dimensions. There
807
+ // may be a code path in the graph kernel that specifically optimizes for that
808
+ // case. We should look into implementing a specialized case in Metal so we can
809
+ // avoid using the graph impl.
810
+ static bool use_graph_for_max_pool2d (IntArrayRef kernel_size, IntArrayRef stride_) {
811
+ IntArrayRef stride = stride_.empty () ? kernel_size : stride_;
812
+ return (stride[0 ] == 1 ) && (stride.size () == 1 || stride[1 ] == 1 );
813
+ }
814
+
789
815
// Forward max_pool2d on MPS without returning indices.
//
// Allocates an empty contiguous output tensor with the input's options, then
// fills it through one of two implementations selected by
// use_graph_for_max_pool2d(kernel_size, stride):
//   - the Metal template (max_pool_with_indices_out_mps_template), called with
//     no indices tensor and pooling_dims = 2, or
//   - the MPSGraph template (pool2d_template), driven by a block that emits a
//     maxPooling2D node.
// Returns the output tensor.
Tensor mps_max_pool2d(const Tensor& input,
                      IntArrayRef kernel_size,
                      IntArrayRef stride,
                      IntArrayRef padding,
                      IntArrayRef dilation,
                      bool ceil_mode) {
  Tensor output = at::empty({0}, input.options(), MemoryFormat::Contiguous);

  // Metal kernel path: taken whenever the graph implementation is not selected.
  if (!use_graph_for_max_pool2d(kernel_size, stride)) {
    mps::max_pool_with_indices_out_mps_template(output,
                                                std::nullopt,
                                                input,
                                                kernel_size,
                                                stride,
                                                padding,
                                                dilation,
                                                ceil_mode,
                                                /*pooling_dims=*/2,
                                                "max_pool2d");
    return output;
  }

  // MPSGraph path: the block builds the pooling node from the cached graph's
  // input tensor and the descriptor supplied by the template.
  mps::PoolingOpBlock max_pool_graph_block = ^PoolingOpFn(cachedGraph, desc) {
    MPSGraph* graph = cachedGraph.graph();
    return [graph maxPooling2DWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil];
  };
  mps::pool2d_template(input,
                       output,
                       std::nullopt,
                       std::nullopt,
                       kernel_size,
                       stride,
                       padding,
                       dilation,
                       ceil_mode,
                       false,
                       std::nullopt,
                       max_pool_graph_block,
                       "max_pool2d");
  return output;
}
816
855
@@ -855,32 +894,45 @@ Tensor mps_max_pool2d_backward(const Tensor& grad_output,
855
894
bool ceil_mode,
856
895
const Tensor& output,
857
896
const Tensor& indices) {
858
- auto indices_memory_format = indices.suggest_memory_format ();
859
-
860
- mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn (cachedGraph, desc) {
861
- MPSGraph* mpsGraph = cachedGraph.graph ();
862
- NSArray <MPSGraphTensor*>* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor: cachedGraph.inputTensor
863
- descriptor: desc
864
- name: nil ];
865
- cachedGraph.indicesTensor = mps::castMPSTensor (mpsGraph, poolOutputs[1 ], ScalarType::Long);
866
- return poolOutputs[0 ];
867
- };
868
- mps::pool2d_template (input,
869
- output,
870
- indices,
871
- std::nullopt,
872
- kernel_size,
873
- stride,
874
- padding,
875
- dilation,
876
- ceil_mode,
877
- false ,
878
- std::nullopt,
879
- pooling_op_block,
880
- " max_pool2d_indices" );
897
+ bool use_graph = use_graph_for_max_pool2d (kernel_size, stride);
898
+ if (use_graph) {
899
+ auto indices_memory_format = indices.suggest_memory_format ();
900
+
901
+ mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn (cachedGraph, desc) {
902
+ MPSGraph* mpsGraph = cachedGraph.graph ();
903
+ NSArray <MPSGraphTensor*>* poolOutputs =
904
+ [mpsGraph maxPooling2DReturnIndicesWithSourceTensor: cachedGraph.inputTensor descriptor: desc name: nil ];
905
+ cachedGraph.indicesTensor = mps::castMPSTensor (mpsGraph, poolOutputs[1 ], ScalarType::Long);
906
+ return poolOutputs[0 ];
907
+ };
908
+ mps::pool2d_template (input,
909
+ output,
910
+ indices,
911
+ std::nullopt,
912
+ kernel_size,
913
+ stride,
914
+ padding,
915
+ dilation,
916
+ ceil_mode,
917
+ false ,
918
+ std::nullopt,
919
+ pooling_op_block,
920
+ " max_pool2d_indices" );
921
+ if (indices_memory_format == MemoryFormat::ChannelsLast) {
922
+ const_cast <Tensor&>(indices) = indices.to (MemoryFormat::ChannelsLast);
923
+ }
881
924
882
- if (indices_memory_format == MemoryFormat::ChannelsLast) {
883
- const_cast <Tensor&>(indices) = indices.to (MemoryFormat::ChannelsLast);
925
+ } else {
926
+ mps::max_pool_with_indices_out_mps_template (output,
927
+ indices,
928
+ input,
929
+ kernel_size,
930
+ stride,
931
+ padding,
932
+ dilation,
933
+ ceil_mode,
934
+ /* pooling_dims=*/ 2 ,
935
+ " max_pool2d" );
884
936
}
885
937
}
886
938
0 commit comments