Skip to content

Commit a5f59cc

Browse files
eqy authored and pytorchmergebot committed
[cuDNN][64-bit indexing] update conv depthwise 64bit indexing dispatch condition to match native kernel (#156140)
The native kernel doesn't support batch splitting, so the previous check wasn't aggressive enough in dispatching to cuDNN (#155225). Pull Request resolved: #156140. Approved by: https://github.com/ngimel
1 parent 94f8679 commit a5f59cc

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

aten/src/ATen/native/Convolution.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <ATen/Config.h>
44
#include <ATen/Parallel.h>
55
#include <ATen/TensorOperators.h>
6+
#include <ATen/native/CanUse32BitIndexMath.h>
67
#include <ATen/native/ConvolutionMM3d.h>
78
#include <ATen/native/ConvUtils.h>
89
#include <ATen/native/Pool.h>
@@ -463,7 +464,7 @@ struct ConvParams {
463464
return true;
464465
}
465466
// native kernel doesn't support 64-bit non-splittable case
466-
if (cudnn_enabled && needs_64bit_indexing_no_split(input, weight)) {
467+
if (cudnn_enabled && !(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
467468
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
468469
if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) {
469470
TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions"

test/nn/test_convolution.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4057,11 +4057,22 @@ def test_conv3d_64bit_indexing(self, device):
40574057
@largeTensorTest("20GB")
40584058
@largeTensorTest("80GB", "cpu")
40594059
def test_depthwise_conv_64bit_indexing(self, device):
4060-
x = torch.randn(1, 2, 32800, 32800)
4061-
c = nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1, groups=2)
4060+
x = torch.randn(1, 2, 32800, 32800, dtype=torch.bfloat16).to(
4061+
memory_format=torch.channels_last
4062+
)
4063+
c = nn.Conv2d(
4064+
2, 2, kernel_size=3, stride=1, padding=1, groups=2, dtype=torch.bfloat16
4065+
).to(memory_format=torch.channels_last)
40624066
yref = c(x)
40634067
y = c.to(device=device)(x.to(device=device))
4064-
self.assertEqual(yref, y)
4068+
self.assertEqual(yref, y, atol=1e-3, rtol=1e-4)
4069+
del y, yref
4070+
4071+
# try a batch-splittable case
4072+
x = x.reshape(100, 2, 3280, 3280).contiguous(memory_format=torch.channels_last)
4073+
yref = c(x)
4074+
y = c.to(device=device)(x.to(device=device))
4075+
self.assertEqual(yref, y, atol=1e-3, rtol=1e-4)
40654076

40664077

40674078
instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), allow_mps=True)

0 commit comments

Comments
 (0)