
Commit ebd8f3a

Work around MPSGraph issue in backward pass of nn.ReplicationPad1d/2d.

Fixes #135447. When the 3rd-from-last dimension is 2^16 or greater, MPSGraph returns 0 for the pad gradient. To work around this, we break the problematic dimension into chunks, each no larger than 2^16 - 1.

Test case for nn.ReplicationPad1d:

```
shape = [65739, 2, 4]
x_cpu = torch.randn(shape, device='cpu', requires_grad=True)
x_mps = x_cpu.clone().detach().to('mps').requires_grad_(True)
model = torch.nn.ReplicationPad1d((1, 1))
out_cpu = model(x_cpu)
out_mps = model(x_mps)

# backward
g_cpu = torch.randn_like(out_cpu)
g_mps = g_cpu.clone().detach().to('mps').requires_grad_(False)
out_cpu.backward(g_cpu)
out_mps.backward(g_mps)
print(f"{((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = }")
# Expected Output:
# ((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = tensor(0)
```

Test case for nn.ReplicationPad2d:

```
shape = [2, 65739, 2, 4]
x_cpu = torch.randn(shape, device='cpu', requires_grad=True)
x_mps = x_cpu.clone().detach().to('mps').requires_grad_(True)
model = torch.nn.ReplicationPad2d((1, 1, 1, 1))
out_cpu = model(x_cpu)
out_mps = model(x_mps)

# backward
g_cpu = torch.randn_like(out_cpu)
g_mps = g_cpu.clone().detach().to('mps').requires_grad_(False)
out_cpu.backward(g_cpu)
out_mps.backward(g_mps)
print(f"{((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = }")
# Expected Output:
# ((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = tensor(0)
```

Both tests produce the expected output with this workaround.
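For intuition, here is a minimal CPU-only sketch (an illustration, not code from this commit) of why chunking is sound: replication padding touches only the last one or two dimensions, so slices taken along an earlier dimension have independent gradients that can be computed per chunk and concatenated. The toy `chunk = 3` stands in for the real 2^16 - 1 limit.

```
import torch
import torch.nn as nn

pad = nn.ReplicationPad1d((1, 1))
x = torch.randn(7, 2, 4)   # small stand-in for [65739, 2, 4]
g = torch.randn(7, 2, 6)   # upstream gradient for the padded output

# Reference: backward through the whole tensor at once.
x_ref = x.clone().requires_grad_(True)
pad(x_ref).backward(g)

# Workaround shape: backward chunk by chunk along dim 0, then concatenate.
chunk = 3                  # stands in for 2**16 - 1
grads = []
for start in range(0, x.size(0), chunk):
    xs = x[start:start + chunk].clone().requires_grad_(True)
    pad(xs).backward(g[start:start + chunk])
    grads.append(xs.grad)

assert torch.allclose(x_ref.grad, torch.cat(grads, dim=0))
```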
1 parent e7ed50f commit ebd8f3a

File tree

2 files changed (+108, -65 lines)


aten/src/ATen/native/mps/operations/Pad.mm

Lines changed: 104 additions & 65 deletions
```diff
@@ -19,6 +19,7 @@
 #include <ATen/ops/replication_pad2d_native.h>
 #include <ATen/ops/replication_pad3d_backward_native.h>
 #include <ATen/ops/replication_pad3d_native.h>
+#include <ATen/ops/slice.h>
 #endif
 
 namespace at::native {
@@ -243,75 +244,113 @@
     dataType = MPSDataTypeInt8;
   }
 
-  @autoreleasepool {
-    string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" + getArrayRefString(padding) +
-        "]:" + std::to_string(constantValue);
-
-    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
-      newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(input));
-      const bool needsSlice = startMask != dims_mask || endMask != dims_mask;
-
-      if (!is_backward_pass) {
-        MPSGraphTensor* padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor_
-                                        withPaddingMode:mode
-                                            leftPadding:leftPadding
-                                           rightPadding:rightPadding
-                                          constantValue:constantValue
-                                                   name:nil];
-        // workaround for the right padding bug in Monterey
-        if (needsSlice) {
-          newCachedGraph->gradInputTensor_ =
-              [mpsGraph sliceTensor:padTensor
-                             starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
-                               ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
-                            strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
-                          startMask:startMask
-                            endMask:endMask
-                        squeezeMask:0
-                               name:nil];
-        } else {
-          newCachedGraph->gradInputTensor_ = padTensor;
-        }
-      } else {
-        newCachedGraph->gradOutputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(grad_output));
-        MPSGraphTensor* padGradTensor =
-            [mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor_
-                                               sourceTensor:newCachedGraph->inputTensor_
-                                                paddingMode:mode
-                                                leftPadding:leftPadding
-                                               rightPadding:rightPadding
-                                                       name:nil];
-        // workaround for negative padding issue with padGradientWithIncomingGradientTensor()
-        if (needsSlice) {
-          newCachedGraph->gradInputTensor_ =
-              [mpsGraph sliceGradientTensor:padGradTensor
-                           fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor_ name:nil]
-                                     starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
-                                       ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
-                                    strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
-                                  startMask:startMask
-                                    endMask:endMask
-                                squeezeMask:0
-                                       name:nil];
+  // For tensors of rank 3 or 4 with padding mode replicate1d/2d, when the 3rd-from-last
+  // dimension is 2**16 or greater, MPSGraph returns an incorrect gradient. To work around this,
+  // we break the tensor into chunks in which the problematic dimension is no greater than 2**16 - 1.
+  // This is reported in https://github.com/pytorch/pytorch/issues/135447.
+  // Internal radar for MPSGraph: rdar://149853787.
+  const int64_t max_sub_batch_size = 65535;
+  int64_t sliced_dim = -1;
+  int64_t sub_batch_start = 0;
+  int64_t remaining_batch_size = 0;
+  if ((ndims == 3 || ndims == 4) && mode == MPSGraphPaddingModeClampToEdge && pad_front == 0 && pad_back == 0) {
+    int64_t batch_size = input_.size(-3);
+    if (batch_size > max_sub_batch_size) {
+      sliced_dim = ndims - 3;
+      remaining_batch_size = batch_size;
+    }
+  }
+  do {
+    Tensor sub_batch_input = input;
+    Tensor sub_batch_grad_output = grad_output;
+    Tensor sub_batch_output = output;
+
+    if (sliced_dim >= 0) {
+      int64_t sub_batch_size =
+          is_backward_pass ? std::min<int64_t>(remaining_batch_size, max_sub_batch_size) : remaining_batch_size;
+      sub_batch_input = at::slice(input, sliced_dim, sub_batch_start, sub_batch_start + sub_batch_size);
+      sub_batch_output = at::slice(output, sliced_dim, sub_batch_start, sub_batch_start + sub_batch_size);
+      if (is_backward_pass) {
+        sub_batch_grad_output = at::slice(grad_output, sliced_dim, sub_batch_start, sub_batch_start + sub_batch_size);
+      }
+      remaining_batch_size -= sub_batch_size;
+      sub_batch_start += sub_batch_size;
+    }
+    @autoreleasepool {
+      string key = op_name + getTensorsStringKey({sub_batch_input, sub_batch_grad_output, sub_batch_output}) + ":[" +
+          getArrayRefString(padding) + "]:" + std::to_string(constantValue) + std::to_string(sub_batch_start);
+
+      auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
+        newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(sub_batch_input));
+        const bool needsSlice = startMask != dims_mask || endMask != dims_mask;
+
+        if (!is_backward_pass) {
+          MPSGraphTensor* padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor_
+                                          withPaddingMode:mode
+                                              leftPadding:leftPadding
+                                             rightPadding:rightPadding
+                                            constantValue:constantValue
+                                                     name:nil];
+          // workaround for the right padding bug in Monterey
+          if (needsSlice) {
+            newCachedGraph->gradInputTensor_ =
+                [mpsGraph sliceTensor:padTensor
+                               starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
+                                 ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
+                              strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
+                            startMask:startMask
+                              endMask:endMask
+                          squeezeMask:0
+                                 name:nil];
+          } else {
+            newCachedGraph->gradInputTensor_ = padTensor;
+          }
         } else {
-          newCachedGraph->gradInputTensor_ = padGradTensor;
+          newCachedGraph->gradOutputTensor_ =
+              mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(sub_batch_grad_output));
+          MPSGraphTensor* padGradTensor =
+              [mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor_
+                                                 sourceTensor:newCachedGraph->inputTensor_
+                                                  paddingMode:mode
+                                                  leftPadding:leftPadding
+                                                 rightPadding:rightPadding
+                                                         name:nil];
+          // workaround for negative padding issue with padGradientWithIncomingGradientTensor()
+          if (needsSlice) {
+            newCachedGraph->gradInputTensor_ =
+                [mpsGraph sliceGradientTensor:padGradTensor
+                             fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor_ name:nil]
+                                       starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
+                                         ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
+                                      strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
+                                    startMask:startMask
+                                      endMask:endMask
+                                  squeezeMask:0
+                                         name:nil];
+          } else {
+            newCachedGraph->gradInputTensor_ = padGradTensor;
+          }
         }
+      });
+      Placeholder inputPlaceholder =
+          Placeholder(cachedGraph->inputTensor_, sub_batch_input, getMPSShape(sub_batch_input), true, dataType);
+      Placeholder outputPlaceholder =
+          Placeholder(cachedGraph->gradInputTensor_, sub_batch_output, getMPSShape(sub_batch_output), true, dataType);
+      Placeholder gradOutputPlaceholder = !is_backward_pass ? Placeholder()
+                                                            : Placeholder(cachedGraph->gradOutputTensor_,
+                                                                          sub_batch_grad_output,
+                                                                          getMPSShape(sub_batch_grad_output),
+                                                                          true,
+                                                                          dataType);
+
+      NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
+      feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
+      if (is_backward_pass) {
+        feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
       }
-    });
-
-    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input, nullptr, true, dataType);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output, nullptr, true, dataType);
-    Placeholder gradOutputPlaceholder = !is_backward_pass
-        ? Placeholder()
-        : Placeholder(cachedGraph->gradOutputTensor_, grad_output, nullptr, true, dataType);
-
-    NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
-    feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
-    if (is_backward_pass) {
-      feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
+      runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
     }
-    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
-  }
+  } while (remaining_batch_size > 0);
   return output;
 }
 } // namespace mps
```
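To make the new control flow easier to follow, here is a pure-Python mirror of the sub-batch bookkeeping above (a sketch only; names follow the C++ variables, and the graph execution step is elided). It models the `sliced_dim >= 0` case, where only the backward pass is chunked and the forward pass still runs whole.

```
MAX_SUB_BATCH_SIZE = 65535  # 2**16 - 1, mirrors max_sub_batch_size in Pad.mm

def sub_batches(batch_size: int, is_backward_pass: bool):
    """Yield (sub_batch_start, sub_batch_size) pairs covering batch_size."""
    remaining_batch_size, sub_batch_start = batch_size, 0
    while True:
        # Only the backward pass needs chunking; forward runs in one pass.
        sub_batch_size = (min(remaining_batch_size, MAX_SUB_BATCH_SIZE)
                          if is_backward_pass else remaining_batch_size)
        yield sub_batch_start, sub_batch_size
        remaining_batch_size -= sub_batch_size
        sub_batch_start += sub_batch_size
        if remaining_batch_size <= 0:
            break

print(list(sub_batches(65739, is_backward_pass=True)))   # [(0, 65535), (65535, 204)]
print(list(sub_batches(65739, is_backward_pass=False)))  # [(0, 65739)]
```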

test/test_mps.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -9556,6 +9556,8 @@ def helper(shape, padding, op, value=0):
         helper((2, 4, 4), (1, 3), nn.ReflectionPad1d)
         # Replication 1D
         helper((2, 1, 6), 3, nn.ReplicationPad1d)
+        # Replication 1D with large batch size (>= 2**16)
+        helper((65539, 1, 6), 3, nn.ReplicationPad1d)
         # Constant Pad 1D
         helper((2, 3, 4), 2, nn.ConstantPad1d)
         # Constant Pad 1D with single dimension input
@@ -9569,6 +9571,8 @@ def helper(shape, padding, op, value=0):
         helper((2, 1, 6, 8), 2, nn.ReplicationPad2d)
         # verify if a change in shape of padding would cause problems with graph caching
         helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d)
+        # verify 2d padding with the 2nd dimension >= 2**16
+        helper((2, 65539, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d)
         # Constant Pad 2D
         helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d)
         # input size < pad size
```
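A note on the test size (my reading; the commit does not state the rationale): 65539 sits just past the 2**16 boundary, so the backward pass must take the chunked path with one full 65535-element sub-batch plus a short tail.

```
assert 65539 >= 2**16                  # crosses the problematic boundary
assert divmod(65539, 65535) == (1, 4)  # one full sub-batch + a 4-element tail
```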
