diff --git a/aten/src/ATen/native/mps/operations/Pad.mm b/aten/src/ATen/native/mps/operations/Pad.mm
index 0c2c25946bb4b..5c1cd773b8c69 100644
--- a/aten/src/ATen/native/mps/operations/Pad.mm
+++ b/aten/src/ATen/native/mps/operations/Pad.mm
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include
 #endif
 
 namespace at::native {
@@ -243,75 +244,113 @@
     dataType = MPSDataTypeInt8;
   }
 
-  @autoreleasepool {
-    std::string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" + getArrayRefString(padding) +
-        "]:" + std::to_string(constantValue);
-
-    auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) {
-      newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(input));
-      const bool needsSlice = startMask != dims_mask || endMask != dims_mask;
-
-      if (!is_backward_pass) {
-        MPSGraphTensor* padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor_
-                                        withPaddingMode:mode
-                                            leftPadding:leftPadding
-                                           rightPadding:rightPadding
-                                          constantValue:constantValue
-                                                   name:nil];
-        // workaround for the right padding bug in Monterey
-        if (needsSlice) {
-          newCachedGraph->gradInputTensor_ =
-              [mpsGraph sliceTensor:padTensor
-                             starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
-                               ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
-                            strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
-                          startMask:startMask
-                            endMask:endMask
-                        squeezeMask:0
-                               name:nil];
-        } else {
-          newCachedGraph->gradInputTensor_ = padTensor;
-        }
-      } else {
-        newCachedGraph->gradOutputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(grad_output));
-        MPSGraphTensor* padGradTensor =
-            [mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor_
-                                               sourceTensor:newCachedGraph->inputTensor_
-                                                paddingMode:mode
-                                                leftPadding:leftPadding
-                                               rightPadding:rightPadding
-                                                       name:nil];
-        // workaround for negative padding issue with padGradientWithIncomingGradientTensor()
-        if (needsSlice) {
-          newCachedGraph->gradInputTensor_ =
-              [mpsGraph sliceGradientTensor:padGradTensor
-                           fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor_ name:nil]
-                                     starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
-                                       ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
-                                    strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
-                                  startMask:startMask
-                                    endMask:endMask
-                                squeezeMask:0
-                                       name:nil];
-        } else {
-          newCachedGraph->gradInputTensor_ = padGradTensor;
-        }
-      }
-    });
-
-    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input, nullptr, true, dataType);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output, nullptr, true, dataType);
-    Placeholder gradOutputPlaceholder = !is_backward_pass
-        ? Placeholder()
-        : Placeholder(cachedGraph->gradOutputTensor_, grad_output, nullptr, true, dataType);
-
-    NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
-    feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
-    if (is_backward_pass) {
-      feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
-    }
-    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
-  }
+  // For tensors with rank 3 or 4 and padding mode replicate1d/2d, when the 3rd from the
+  // last dimension is 2**16 or greater, MPSGraph returns an incorrect gradient. To work around
+  // this, we break the tensor into chunks where the problematic dimension is no greater than
+  // 2**16-1. This is reported in https://github.com/pytorch/pytorch/issues/135447.
+  // Internal radar for MPSGraph: rdar://149853787.
+  constexpr auto max_sub_batch_size = 65535;
+  int64_t sliced_dim = -1;
+  int64_t sub_batch_start = 0;
+  int64_t remaining_batch_size = 0;
+  if ((ndims == 3 || ndims == 4) && mode == MPSGraphPaddingModeClampToEdge && pad_front == 0 && pad_back == 0) {
+    int64_t batch_size = input_.size(-3);
+    if (batch_size > max_sub_batch_size) {
+      sliced_dim = ndims - 3;
+      remaining_batch_size = batch_size;
+    }
+  }
+  do {
+    Tensor sub_batch_input = input;
+    Tensor sub_batch_grad_output = grad_output;
+    Tensor sub_batch_output = output;
+
+    if (sliced_dim >= 0) {
+      int64_t sub_batch_size =
+          is_backward_pass ? std::min<int64_t>(remaining_batch_size, max_sub_batch_size) : remaining_batch_size;
+      sub_batch_input = at::slice(input, sliced_dim, sub_batch_start, sub_batch_start + sub_batch_size);
+      sub_batch_output = at::slice(output, sliced_dim, sub_batch_start, sub_batch_start + sub_batch_size);
+      if (is_backward_pass) {
+        sub_batch_grad_output = at::slice(grad_output, sliced_dim, sub_batch_start, sub_batch_start + sub_batch_size);
+      }
+      remaining_batch_size -= sub_batch_size;
+      sub_batch_start += sub_batch_size;
+    }
+    @autoreleasepool {
+      std::string key = op_name + getTensorsStringKey({sub_batch_input, sub_batch_grad_output, sub_batch_output}) +
+          ":[" + getArrayRefString(padding) + "]:" + std::to_string(constantValue) + std::to_string(sub_batch_start);
+
+      auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) {
+        newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(sub_batch_input));
+        const bool needsSlice = startMask != dims_mask || endMask != dims_mask;
+
+        if (!is_backward_pass) {
+          MPSGraphTensor* padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor_
+                                          withPaddingMode:mode
+                                              leftPadding:leftPadding
+                                             rightPadding:rightPadding
+                                            constantValue:constantValue
+                                                     name:nil];
+          // workaround for the right padding bug in Monterey
+          if (needsSlice) {
+            newCachedGraph->gradInputTensor_ =
+                [mpsGraph sliceTensor:padTensor
+                               starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
+                                 ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
+                              strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
+                            startMask:startMask
+                              endMask:endMask
+                          squeezeMask:0
+                                 name:nil];
+          } else {
+            newCachedGraph->gradInputTensor_ = padTensor;
+          }
+        } else {
+          newCachedGraph->gradOutputTensor_ =
+              mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(sub_batch_grad_output));
+          MPSGraphTensor* padGradTensor =
+              [mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor_
+                                                 sourceTensor:newCachedGraph->inputTensor_
+                                                  paddingMode:mode
+                                                  leftPadding:leftPadding
+                                                 rightPadding:rightPadding
+                                                         name:nil];
+          // workaround for negative padding issue with padGradientWithIncomingGradientTensor()
+          if (needsSlice) {
+            newCachedGraph->gradInputTensor_ =
+                [mpsGraph sliceGradientTensor:padGradTensor
+                             fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor_ name:nil]
+                                       starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
+                                         ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
+                                      strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
+                                    startMask:startMask
+                                      endMask:endMask
+                                  squeezeMask:0
+                                         name:nil];
+          } else {
+            newCachedGraph->gradInputTensor_ = padGradTensor;
+          }
+        }
+      });
+      Placeholder inputPlaceholder =
+          Placeholder(cachedGraph->inputTensor_, sub_batch_input, getMPSShape(sub_batch_input), true, dataType);
+      Placeholder outputPlaceholder =
+          Placeholder(cachedGraph->gradInputTensor_, sub_batch_output, getMPSShape(sub_batch_output), true, dataType);
+      Placeholder gradOutputPlaceholder = !is_backward_pass ? Placeholder()
+                                                            : Placeholder(cachedGraph->gradOutputTensor_,
+                                                                          sub_batch_grad_output,
+                                                                          getMPSShape(sub_batch_grad_output),
+                                                                          true,
+                                                                          dataType);
+
+      NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
+      feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
+      if (is_backward_pass) {
+        feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
+      }
+      runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
+    }
+  } while (remaining_batch_size > 0);
   return output;
 }
 } // namespace mps
diff --git a/test/test_nn.py b/test/test_nn.py
index 353e0d6abd804..7c2f31a38d71c 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -43,7 +43,7 @@
 from torch.testing._internal.common_device_type import dtypesIfMPS, instantiate_device_type_tests, dtypes, \
     dtypesIfCUDA, precisionOverride, onlyCUDA, onlyCPU, \
     skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \
-    onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, expectedFailureMPS, \
+    onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, expectedFailureMPS, expectedFailureMPSPre15, \
     skipMeta, get_all_device_types
 
 from hypothesis import given
@@ -8781,7 +8781,7 @@ def test_ReplicationPad_empty(self, device, dtype):
         with self.assertRaisesRegex(RuntimeError, 'padding size is expected to be 6'):
             torch._C._nn.replication_pad3d(torch.randn([2]), padding=[])
 
-    @expectedFailureMPS  # Correctness issue https://github.com/pytorch/pytorch/issues/135447
+    @expectedFailureMPSPre15  # Correctness issue https://github.com/pytorch/pytorch/issues/135447
     def test_ReplicationPad1d_large(self, device):
         shapes = ([2, 65736, 4], [65736, 2, 4])
         pl, pr = 3, 4
@@ -8806,7 +8806,7 @@ def test_ReplicationPad1d_large(self, device):
             self.assertEqual(x.grad[:, :, 0], g[:, :, : pl + 1].sum(-1))
             self.assertEqual(x.grad[:, :, -1], g[:, :, -pr - 1:].sum(-1))
 
-    @expectedFailureMPS  # Correctness issue https://github.com/pytorch/pytorch/issues/135447
+    @expectedFailureMPSPre15  # Correctness issue https://github.com/pytorch/pytorch/issues/135447
     def test_ReplicationPad2d_large(self, device):
         shapes = ([2, 65736, 4, 4], [65736, 2, 4, 4])
         pl, pr, pt, pb = 3, 4, 5, 6
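
Note on the workaround above: replication padding writes only along the last one or two dimensions, so splitting the oversized dim -3 into sub-batches, padding each slice independently, and concatenating the results is equivalent to padding the whole tensor at once; that is why the kernel can simply loop over `at::slice` chunks. Below is a minimal Python-level sketch of the same idea (the helper name `replication_pad_chunked` and the example shape are illustrative assumptions, not part of this patch):

```python
# Sketch: sub-batch replication padding along dim -3, mirroring the
# max_sub_batch_size = 65535 (2**16 - 1) chunking in Pad.mm above.
import torch
import torch.nn.functional as F

MAX_SUB_BATCH = 65535


def replication_pad_chunked(x: torch.Tensor, pad: tuple) -> torch.Tensor:
    """Replication-pad x, splitting dim -3 into chunks of at most MAX_SUB_BATCH."""
    dim = x.dim() - 3
    if x.size(dim) <= MAX_SUB_BATCH:
        return F.pad(x, pad, mode="replicate")
    # Padding never touches dim -3, so padding per chunk and concatenating
    # matches padding the whole tensor in one call.
    chunks = [F.pad(c, pad, mode="replicate") for c in x.split(MAX_SUB_BATCH, dim=dim)]
    return torch.cat(chunks, dim=dim)


if torch.backends.mps.is_available():
    # Shape taken from test_ReplicationPad1d_large: dim -3 exceeds 2**16 - 1.
    x = torch.randn(65736, 2, 4, device="mps", requires_grad=True)
    y = replication_pad_chunked(x, (3, 4))
    y.sum().backward()  # each chunk's backward stays below the 2**16 threshold
```

As the `expectedFailureMPSPre15` markers indicate, the two large-tensor tests remain expected failures on MPS before macOS 15, while with this patch they are expected to pass on macOS 15 and newer.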