diff --git a/aten/src/ATen/native/MaxPooling.h b/aten/src/ATen/native/MaxPooling.h index 50d1205ba3ce..59d87442c508 100644 --- a/aten/src/ATen/native/MaxPooling.h +++ b/aten/src/ATen/native/MaxPooling.h @@ -40,6 +40,8 @@ inline void check_max_pool1d( stride = kernel_size; } + int effective_kernel_size = (kernel_size[0] - 1) * dilation[0] + 1; + TORCH_CHECK( kernel_size[0] > 0, "max_pool1d() kernel_size must be greater than zero, but got ", @@ -49,11 +51,11 @@ inline void check_max_pool1d( TORCH_CHECK( padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]); TORCH_CHECK( - padding[0] <= kernel_size[0] / 2, - "max_pool1d() padding should be at most half of kernel size, but got padding=", + padding[0] <= effective_kernel_size / 2, + "max_pool1d() padding should be at most half of effective kernel size, but got padding=", padding[0], - " and kernel_size=", - kernel_size[0]); + " and effective_kernel_size=", + effective_kernel_size); TORCH_CHECK( dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]); diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h index 51d19102ad93..e7d2ed3d34e9 100644 --- a/aten/src/ATen/native/Pool.h +++ b/aten/src/ATen/native/Pool.h @@ -153,9 +153,20 @@ pool2d_shape_check( input.sizes()); } - TORCH_CHECK(kW/2 >= padW && kH/2 >= padH, - "pad should be smaller than or equal to half of kernel size, but got ", - "padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH); + int effectiveKW = (kW - 1) * dilationW + 1; + int effectiveKH = (kH - 1) * dilationH + 1; + + TORCH_CHECK( + effectiveKW / 2 >= padW && effectiveKH / 2 >= padH, + "pad should be smaller than or equal to half of effective kernel size, but got ", + "padW = ", + padW, + ", padH = ", + padH, + ", effectiveKW = ", + effectiveKW, + ", effectiveKH = ", + effectiveKH); TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1, "Given input size: (", @@ -276,9 +287,25 @@ pool3d_shape_check( "kernel size ", "(kT: ", kT, " kH: ", kH, " kW: ", kW, ")"); } - TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH, - "pad should be smaller than or equal to half of kernel size, but got " - "kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH); + int effectiveKT = (kT - 1) * dilationT + 1; + int effectiveKW = (kW - 1) * dilationW + 1; + int effectiveKH = (kH - 1) * dilationH + 1; + + TORCH_CHECK( + effectiveKT / 2 >= pT && effectiveKW / 2 >= pW && effectiveKH / 2 >= pH, + "pad should be smaller than or equal to half of effective kernel size, but got " + "padT: ", + pT, + " padW: ", + pW, + " padH: ", + pH, + " effectiveKT: ", + effectiveKT, + " effectiveKW: ", + effectiveKW, + " effectiveKH: ", + effectiveKH); TORCH_CHECK(otime >= 1 && owidth >= 1 && oheight >= 1, "Given input size: (", diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index c590e4d0dfbe..62c1b9931dfa 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -4451,10 +4451,13 @@ def pool2d_shape_check( lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}", ) + effectiveKW = (kW - 1) * dilationW + 1 + effectiveKH = (kH - 1) * dilationH + 1 + torch._check( - kW // 2 >= padW and kH // 2 >= padH, - lambda: "pad should be smaller than or equal to half of kernel size, but got " - f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}", + effectiveKW // 2 >= padW and effectiveKH // 2 >= padH, + lambda: "pad should be smaller than or equal to half of effective kernel size, but got " + f"padW = {padW}, padH = {padH}, effectiveKW = {effectiveKW}, effectiveKH = {effectiveKH}", ) torch._check( @@ -4539,11 +4542,16 @@ def pool3d_shape_check( ), ) + effectiveKT = (kT - 1) * dilationT + 1 + effectiveKW = (kW - 1) * dilationW + 1 + effectiveKH = (kH - 1) * dilationH + 1 + torch._check( - kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH, + effectiveKT / 2 >= pT and effectiveKW / 2 >= pW and effectiveKH / 2 >= pH, lambda: ( - f"pad should be smaller than or equal to half of kernel size, but got " - f"kT: {kT} kW: {kW} kH: {kH} padT: {pT} padW: {pW} padH: {pH}" + f"pad should be smaller than or equal to half of effective kernel size, but got " + f"padT: {pT}, padW: {pW}, padH: {pH}, " + f"effectiveKT: {effectiveKT}, effectiveKW: {effectiveKW}, effectiveKH: {effectiveKH}" ), )