From db48618f5f50a54743408ada22a6bbb64204bf48 Mon Sep 17 00:00:00 2001 From: Narek Malkhasyan Date: Mon, 26 May 2025 23:37:47 +0400 Subject: [PATCH 1/4] Updated max_pool2d and max_pool3d padding validation to account for dilation --- aten/src/ATen/native/Pool.h | 39 ++++++++++++++++++++++++++++++------ torch/_meta_registrations.py | 19 ++++++++++++------ 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h index 51d19102ad93..4250ac7b1ff5 100644 --- a/aten/src/ATen/native/Pool.h +++ b/aten/src/ATen/native/Pool.h @@ -153,9 +153,20 @@ pool2d_shape_check( input.sizes()); } - TORCH_CHECK(kW/2 >= padW && kH/2 >= padH, - "pad should be smaller than or equal to half of kernel size, but got ", - "padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH); + int effectiveKW = (kW - 1) * dilationW + 1; + int effectiveKH = (kH - 1) * dilationH + 1; + + TORCH_CHECK( + effectiveKW / 2 >= padW && effectiveKH / 2 >= padH, + "pad should be smaller than or equal to half of effective kernel size, but got ", + "padW = ", + padW, + ", padH = ", + padH, + ", effectiveKW = ", + effectiveKW, + ", effectiveKH = ", + effectiveKH); TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1, "Given input size: (", @@ -276,9 +287,25 @@ pool3d_shape_check( "kernel size ", "(kT: ", kT, " kH: ", kH, " kW: ", kW, ")"); } - TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH, - "pad should be smaller than or equal to half of kernel size, but got " - "kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH); + int effectiveKT = (kT - 1) * dilationT + 1; + int effectiveKW = (kW - 1) * dilationW + 1; + int effectiveKH = (kH - 1) * dilationH + 1; + + TORCH_CHECK( + kT / 2 >= pT && kW / 2 >= pW && kH / 2 >= pH, + "pad should be smaller than or equal to half of effective kernel size, but got " + "padT: ", + pT, + " padW: ", + pW, + " padH: ", + pH, + " effectiveKT: ", + effectiveKT, + " effectiveKW: ", + effectiveKW, + " effectiveKH: ", + effectiveKH); TORCH_CHECK(otime >= 1 && owidth >= 1 && oheight >= 1, "Given input size: (", diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index c590e4d0dfbe..45a93c15354e 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -4451,10 +4451,13 @@ def pool2d_shape_check( lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}", ) + effectiveKW = (kW - 1) * dilationW + 1 + effectiveKH = (kH - 1) * dilationH + 1 + torch._check( - kW // 2 >= padW and kH // 2 >= padH, - lambda: "pad should be smaller than or equal to half of kernel size, but got " - f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}", + effectiveKW // 2 >= padW and effectiveKH // 2 >= padH, + lambda: "pad should be smaller than or equal to half of effective kernel size, but got " + f"padW = {padW}, padH = {padH}, effectiveKW = {effectiveKW}, effectiveKH = {effectiveKH}", ) torch._check( @@ -4539,11 +4542,15 @@ def pool3d_shape_check( ), ) + effectiveKT = (kT - 1) * dilationT + 1 + effectiveKW = (kW - 1) * dilationW + 1 + effectiveKH = (kH - 1) * dilationH + 1 + torch._check( - kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH, + effectiveKT / 2 >= pT and effectiveKW / 2 >= pW and effectiveKH / 2 >= pH, lambda: ( - f"pad should be smaller than or equal to half of kernel size, but got " - f"kT: {kT} kW: {kW} kH: {kH} padT: {pT} padW: {pW} padH: {pH}" + f"pad should be smaller than or equal to half of effective kernel size, but got " + f"padT: {pT}, padW: {pW}, padH: {pH}, effectiveKT: {effectiveKT}, effectiveKW: {effectiveKW}, effectiveKH: {effectiveKH}" ), ) From ff5c35ff31a3a171a196ba91c48fff87dce11ad6 Mon Sep 17 00:00:00 2001 From: Narek Malkhasyan Date: Tue, 27 May 2025 09:36:12 +0400 Subject: [PATCH 2/4] Minor fix --- aten/src/ATen/native/Pool.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h index 4250ac7b1ff5..e7d2ed3d34e9 100644 --- a/aten/src/ATen/native/Pool.h +++ b/aten/src/ATen/native/Pool.h @@ -292,7 +292,7 @@ pool3d_shape_check( int effectiveKH = (kH - 1) * dilationH + 1; TORCH_CHECK( - kT / 2 >= pT && kW / 2 >= pW && kH / 2 >= pH, + effectiveKT / 2 >= pT && effectiveKW / 2 >= pW && effectiveKH / 2 >= pH, "pad should be smaller than or equal to half of effective kernel size, but got " "padT: ", pT, From 6787f6a2af533d2a771220bb6fcf82751dc5bcb5 Mon Sep 17 00:00:00 2001 From: Narek Malkhasyan Date: Wed, 28 May 2025 18:26:11 +0400 Subject: [PATCH 3/4] Fixed linter error --- torch/_meta_registrations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 45a93c15354e..62c1b9931dfa 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -4550,7 +4550,8 @@ def pool3d_shape_check( effectiveKT / 2 >= pT and effectiveKW / 2 >= pW and effectiveKH / 2 >= pH, lambda: ( f"pad should be smaller than or equal to half of effective kernel size, but got " - f"padT: {pT}, padW: {pW}, padH: {pH}, effectiveKT: {effectiveKT}, effectiveKW: {effectiveKW}, effectiveKH: {effectiveKH}" + f"padT: {pT}, padW: {pW}, padH: {pH}, " + f"effectiveKT: {effectiveKT}, effectiveKW: {effectiveKW}, effectiveKH: {effectiveKH}" ), ) From caca0c6427aa54aa254fff8652f608572d1615db Mon Sep 17 00:00:00 2001 From: Narek Malkhasyan Date: Wed, 28 May 2025 19:12:27 +0400 Subject: [PATCH 4/4] Updated max_pool1d padding validation --- aten/src/ATen/native/MaxPooling.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/MaxPooling.h b/aten/src/ATen/native/MaxPooling.h index 50d1205ba3ce..59d87442c508 100644 --- a/aten/src/ATen/native/MaxPooling.h +++ b/aten/src/ATen/native/MaxPooling.h @@ -40,6 +40,8 @@ inline void check_max_pool1d( stride = kernel_size; } + int effective_kernel_size = (kernel_size[0] - 1) * dilation[0] + 1; + TORCH_CHECK( kernel_size[0] > 0, "max_pool1d() kernel_size must be greater than zero, but got ", @@ -49,11 +51,11 @@ inline void check_max_pool1d( TORCH_CHECK( padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]); TORCH_CHECK( - padding[0] <= kernel_size[0] / 2, - "max_pool1d() padding should be at most half of kernel size, but got padding=", + padding[0] <= effective_kernel_size / 2, + "max_pool1d() padding should be at most half of effective kernel size, but got padding=", padding[0], - " and kernel_size=", - kernel_size[0]); + " and effective_kernel_size=", + effective_kernel_size); TORCH_CHECK( dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);