Skip to content

Commit d64f0c6

Browse files
xwu-498 and pytorchmergebot
authored and committed
Work around MPSGraph issue in backward pass of nn.ReplicationPad1d/2d.
Fixes #135447. When the 3rd from last dimension is 2^16 or greater, MPSGraph returns 0 for the pad gradient. To work around this, we break the problematic dimension into chunks with chunk size being no greater than 2^16 - 1. Test case for nn.ReplicationPad1d: ``` shape = [65739, 2, 4] x_cpu = torch.randn(shape, device='cpu', requires_grad=True) x_mps = x_cpu.clone().detach().to('mps').requires_grad_(True) model = torch.nn.ReplicationPad1d((1, 1)) out_cpu = model(x_cpu) out_mps = model(x_mps) # backward g_cpu = torch.randn_like(out_cpu) g_mps = g_cpu.clone().detach().to('mps').requires_grad_(False) out_cpu.backward(g_cpu) out_mps.backward(g_mps) print(f"{((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = }") # Expected Output: # ((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = tensor(0) ``` Test case for nn.ReplicationPad2d: ``` shape = [2, 65739, 2, 4] x_cpu = torch.randn(shape, device='cpu', requires_grad=True) x_mps = x_cpu.clone().detach().to('mps').requires_grad_(True) model = torch.nn.ReplicationPad2d((1, 1, 1, 1)) out_cpu = model(x_cpu) out_mps = model(x_mps) # backward g_cpu = torch.randn_like(out_cpu) g_mps = g_cpu.clone().detach().to('mps').requires_grad_(False) out_cpu.backward(g_cpu) out_mps.backward(g_mps) print(f"{((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = }") # Expected Output: # ((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = tensor(0) ``` These tests produce expected output with this workaround.
1 parent ef1d45b commit d64f0c6

File tree

2 files changed

+107
-68
lines changed

2 files changed

+107
-68
lines changed

aten/src/ATen/native/mps/operations/Pad.mm

Lines changed: 104 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <ATen/ops/replication_pad2d_native.h>
2020
#include <ATen/ops/replication_pad3d_backward_native.h>
2121
#include <ATen/ops/replication_pad3d_native.h>
22+
#include <ATen/ops/slice.h>
2223
#endif
2324

2425
namespace at::native {
@@ -243,75 +244,113 @@
243244
dataType = MPSDataTypeInt8;
244245
}
245246

246-
@autoreleasepool {
247-
std::string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" + getArrayRefString(padding) +
248-
"]:" + std::to_string(constantValue);
249-
250-
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
251-
newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(input));
252-
const bool needsSlice = startMask != dims_mask || endMask != dims_mask;
253-
254-
if (!is_backward_pass) {
255-
MPSGraphTensor* padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor_
256-
withPaddingMode:mode
257-
leftPadding:leftPadding
258-
rightPadding:rightPadding
259-
constantValue:constantValue
260-
name:nil];
261-
// workaround for the right padding bug in Monterey
262-
if (needsSlice) {
263-
newCachedGraph->gradInputTensor_ =
264-
[mpsGraph sliceTensor:padTensor
265-
starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
266-
ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
267-
strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
268-
startMask:startMask
269-
endMask:endMask
270-
squeezeMask:0
271-
name:nil];
272-
} else {
273-
newCachedGraph->gradInputTensor_ = padTensor;
274-
}
275-
} else {
276-
newCachedGraph->gradOutputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(grad_output));
277-
MPSGraphTensor* padGradTensor =
278-
[mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor_
279-
sourceTensor:newCachedGraph->inputTensor_
280-
paddingMode:mode
281-
leftPadding:leftPadding
282-
rightPadding:rightPadding
283-
name:nil];
284-
// workaround for negative padding issue with padGradientWithIncomingGradientTensor()
285-
if (needsSlice) {
286-
newCachedGraph->gradInputTensor_ =
287-
[mpsGraph sliceGradientTensor:padGradTensor
288-
fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor_ name:nil]
289-
starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
290-
ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
291-
strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
292-
startMask:startMask
293-
endMask:endMask
294-
squeezeMask:0
295-
name:nil];
247+
// For tensors with rank equal to 3 or 4 and padding mode replicate1d/2d, when the 3rd from the
248+
// last dimension is 2**16 or greater, MPSGraph returns incorrect gradient. To work around this,
249+
// we break the tensor into chunks where the problematic dimension is no greater than 2**16-1.
250+
// This is reported in https://github.com/pytorch/pytorch/issues/135447.
251+
// Internal radar for MPSGraph: rdar://149853787.
252+
constexpr auto max_sub_batch_size = 65535;
253+
int64_t sliced_dim = -1;
254+
int64_t sub_batch_start = 0;
255+
int64_t remaining_batch_size = 0;
256+
if ((ndims == 3 || ndims == 4) && mode == MPSGraphPaddingModeClampToEdge && pad_front == 0 && pad_back == 0) {
257+
int64_t batch_size = input_.size(-3);
258+
if (batch_size > max_sub_batch_size) {
259+
sliced_dim = ndims - 3;
260+
remaining_batch_size = batch_size;
261+
}
262+
}
263+
do {
264+
Tensor sub_batch_input = input;
265+
Tensor sub_batch_grad_output = grad_output;
266+
Tensor sub_batch_output = output;
267+
268+
if (sliced_dim >= 0) {
269+
int64_t sub_batch_size =
270+
is_backward_pass ? std::min<int64_t>(remaining_batch_size, max_sub_batch_size) : remaining_batch_size;
271+
sub_batch_input = at::slice(input, sliced_dim, sub_batch_start, sub_batch_start + sub_batch_size);
272+
sub_batch_output = at::slice(output, sliced_dim, sub_batch_start, sub_batch_start + sub_batch_size);
273+
if (is_backward_pass) {
274+
sub_batch_grad_output = at::slice(grad_output, sliced_dim, sub_batch_start, sub_batch_start + sub_batch_size);
275+
}
276+
remaining_batch_size -= sub_batch_size;
277+
sub_batch_start += sub_batch_size;
278+
}
279+
@autoreleasepool {
280+
std::string key = op_name + getTensorsStringKey({sub_batch_input, sub_batch_grad_output, sub_batch_output}) +
281+
":[" + getArrayRefString(padding) + "]:" + std::to_string(constantValue) + std::to_string(sub_batch_start);
282+
283+
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
284+
newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(sub_batch_input));
285+
const bool needsSlice = startMask != dims_mask || endMask != dims_mask;
286+
287+
if (!is_backward_pass) {
288+
MPSGraphTensor* padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor_
289+
withPaddingMode:mode
290+
leftPadding:leftPadding
291+
rightPadding:rightPadding
292+
constantValue:constantValue
293+
name:nil];
294+
// workaround for the right padding bug in Monterey
295+
if (needsSlice) {
296+
newCachedGraph->gradInputTensor_ =
297+
[mpsGraph sliceTensor:padTensor
298+
starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
299+
ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
300+
strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
301+
startMask:startMask
302+
endMask:endMask
303+
squeezeMask:0
304+
name:nil];
305+
} else {
306+
newCachedGraph->gradInputTensor_ = padTensor;
307+
}
296308
} else {
297-
newCachedGraph->gradInputTensor_ = padGradTensor;
309+
newCachedGraph->gradOutputTensor_ =
310+
mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(sub_batch_grad_output));
311+
MPSGraphTensor* padGradTensor =
312+
[mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor_
313+
sourceTensor:newCachedGraph->inputTensor_
314+
paddingMode:mode
315+
leftPadding:leftPadding
316+
rightPadding:rightPadding
317+
name:nil];
318+
// workaround for negative padding issue with padGradientWithIncomingGradientTensor()
319+
if (needsSlice) {
320+
newCachedGraph->gradInputTensor_ =
321+
[mpsGraph sliceGradientTensor:padGradTensor
322+
fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor_ name:nil]
323+
starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
324+
ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
325+
strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
326+
startMask:startMask
327+
endMask:endMask
328+
squeezeMask:0
329+
name:nil];
330+
} else {
331+
newCachedGraph->gradInputTensor_ = padGradTensor;
332+
}
298333
}
334+
});
335+
Placeholder inputPlaceholder =
336+
Placeholder(cachedGraph->inputTensor_, sub_batch_input, getMPSShape(sub_batch_input), true, dataType);
337+
Placeholder outputPlaceholder =
338+
Placeholder(cachedGraph->gradInputTensor_, sub_batch_output, getMPSShape(sub_batch_output), true, dataType);
339+
Placeholder gradOutputPlaceholder = !is_backward_pass ? Placeholder()
340+
: Placeholder(cachedGraph->gradOutputTensor_,
341+
sub_batch_grad_output,
342+
getMPSShape(sub_batch_grad_output),
343+
true,
344+
dataType);
345+
346+
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
347+
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
348+
if (is_backward_pass) {
349+
feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
299350
}
300-
});
301-
302-
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input, nullptr, true, dataType);
303-
Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output, nullptr, true, dataType);
304-
Placeholder gradOutputPlaceholder = !is_backward_pass
305-
? Placeholder()
306-
: Placeholder(cachedGraph->gradOutputTensor_, grad_output, nullptr, true, dataType);
307-
308-
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
309-
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
310-
if (is_backward_pass) {
311-
feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
351+
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
312352
}
313-
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
314-
}
353+
} while (remaining_batch_size > 0);
315354
return output;
316355
}
317356
} // namespace mps

test/test_nn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from torch.testing._internal.common_device_type import dtypesIfMPS, instantiate_device_type_tests, dtypes, \
4444
dtypesIfCUDA, precisionOverride, onlyCUDA, onlyCPU, \
4545
skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \
46-
onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, expectedFailureMPS, \
46+
onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, expectedFailureMPS, expectedFailureMPSPre15, \
4747
skipMeta, get_all_device_types
4848

4949
from hypothesis import given
@@ -8781,7 +8781,7 @@ def test_ReplicationPad_empty(self, device, dtype):
87818781
with self.assertRaisesRegex(RuntimeError, 'padding size is expected to be 6'):
87828782
torch._C._nn.replication_pad3d(torch.randn([2]), padding=[])
87838783

8784-
@expectedFailureMPS # Correctness issue https://github.com/pytorch/pytorch/issues/135447
8784+
@expectedFailureMPSPre15 # Correctness issue https://github.com/pytorch/pytorch/issues/135447
87858785
def test_ReplicationPad1d_large(self, device):
87868786
shapes = ([2, 65736, 4], [65736, 2, 4])
87878787
pl, pr = 3, 4
@@ -8806,7 +8806,7 @@ def test_ReplicationPad1d_large(self, device):
88068806
self.assertEqual(x.grad[:, :, 0], g[:, :, : pl + 1].sum(-1))
88078807
self.assertEqual(x.grad[:, :, -1], g[:, :, -pr - 1:].sum(-1))
88088808

8809-
@expectedFailureMPS # Correctness issue https://github.com/pytorch/pytorch/issues/135447
8809+
@expectedFailureMPSPre15 # Correctness issue https://github.com/pytorch/pytorch/issues/135447
88108810
def test_ReplicationPad2d_large(self, device):
88118811
shapes = ([2, 65736, 4, 4], [65736, 2, 4, 4])
88128812
pl, pr, pt, pb = 3, 4, 5, 6

0 commit comments

Comments
 (0)