19 | 19 | #include <ATen/ops/replication_pad2d_native.h>
20 | 20 | #include <ATen/ops/replication_pad3d_backward_native.h>
21 | 21 | #include <ATen/ops/replication_pad3d_native.h>
   | 22 | +#include <ATen/ops/slice.h>
22 | 23 | #endif
23 | 24 |
24 | 25 | namespace at::native {
243 | 244 |     dataType = MPSDataTypeInt8;
244 | 245 |   }
245 | 246 |
246 |     | -  @autoreleasepool {
247 |     | -    string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" + getArrayRefString(padding) +
248 |     | -        "]:" + std::to_string(constantValue);
249 |     | -
250 |     | -    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
251 |     | -      newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(input));
252 |     | -      const bool needsSlice = startMask != dims_mask || endMask != dims_mask;
253 |     | -
254 |     | -      if (!is_backward_pass) {
255 |     | -        MPSGraphTensor* padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor_
256 |     | -                                        withPaddingMode:mode
257 |     | -                                            leftPadding:leftPadding
258 |     | -                                           rightPadding:rightPadding
259 |     | -                                          constantValue:constantValue
260 |     | -                                                   name:nil];
261 |     | -        // workaround for the right padding bug in Monterey
262 |     | -        if (needsSlice) {
263 |     | -          newCachedGraph->gradInputTensor_ =
264 |     | -              [mpsGraph sliceTensor:padTensor
265 |     | -                             starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
266 |     | -                               ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
267 |     | -                            strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
268 |     | -                          startMask:startMask
269 |     | -                            endMask:endMask
270 |     | -                        squeezeMask:0
271 |     | -                               name:nil];
272 |     | -        } else {
273 |     | -          newCachedGraph->gradInputTensor_ = padTensor;
274 |     | -        }
275 |     | -      } else {
276 |     | -        newCachedGraph->gradOutputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(grad_output));
277 |     | -        MPSGraphTensor* padGradTensor =
278 |     | -            [mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor_
279 |     | -                                               sourceTensor:newCachedGraph->inputTensor_
280 |     | -                                                paddingMode:mode
281 |     | -                                                leftPadding:leftPadding
282 |     | -                                               rightPadding:rightPadding
283 |     | -                                                       name:nil];
284 |     | -        // workaround for negative padding issue with padGradientWithIncomingGradientTensor()
285 |     | -        if (needsSlice) {
286 |     | -          newCachedGraph->gradInputTensor_ =
287 |     | -              [mpsGraph sliceGradientTensor:padGradTensor
288 |     | -                           fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor_ name:nil]
289 |     | -                                     starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
290 |     | -                                       ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
291 |     | -                                    strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
292 |     | -                                  startMask:startMask
293 |     | -                                    endMask:endMask
294 |     | -                                squeezeMask:0
295 |     | -                                       name:nil];
    | 247 | +  // For tensors with rank 3 or 4 and padding mode replicate1d/2d, when the 3rd from the
    | 248 | +  // last dimension is 2**16 or greater, MPSGraph returns an incorrect gradient. To work around this,
    | 249 | +  // we break the tensor into chunks where the problematic dimension is no greater than 2**16-1.
    | 250 | +  // This is reported in https://github.com/pytorch/pytorch/issues/135447.
    | 251 | +  // Internal radar for MPSGraph: rdar://149853787.
    | 252 | +  const int64_t max_sub_batch_size = 65535;
    | 253 | +  int64_t sliced_dim = -1;
    | 254 | +  int64_t sub_batch_start = 0;
    | 255 | +  int64_t remaining_batch_size = 0;
    | 256 | +  if ((ndims == 3 || ndims == 4) && mode == MPSGraphPaddingModeClampToEdge && pad_front == 0 && pad_back == 0) {
    | 257 | +    int64_t batch_size = input_.size(-3);
    | 258 | +    if (batch_size > max_sub_batch_size) {
    | 259 | +      sliced_dim = ndims - 3;
    | 260 | +      remaining_batch_size = batch_size;
    | 261 | +    }
    | 262 | +  }
    | 263 | +  do {
    | 264 | +    Tensor sub_batch_input = input;
    | 265 | +    Tensor sub_batch_grad_output = grad_output;
    | 266 | +    Tensor sub_batch_output = output;
    | 267 | +
    | 268 | +    if (sliced_dim >= 0) {
    | 269 | +      int64_t sub_batch_size =
    | 270 | +          is_backward_pass ? std::min<int64_t>(remaining_batch_size, max_sub_batch_size) : remaining_batch_size;
    | 271 | +      sub_batch_input = at::slice(input, sliced_dim, sub_batch_start, sub_batch_start + sub_batch_size);
    | 272 | +      sub_batch_output = at::slice(output, sliced_dim, sub_batch_start, sub_batch_start + sub_batch_size);
    | 273 | +      if (is_backward_pass) {
    | 274 | +        sub_batch_grad_output = at::slice(grad_output, sliced_dim, sub_batch_start, sub_batch_start + sub_batch_size);
    | 275 | +      }
    | 276 | +      remaining_batch_size -= sub_batch_size;
    | 277 | +      sub_batch_start += sub_batch_size;
    | 278 | +    }
    | 279 | +    @autoreleasepool {
    | 280 | +      string key = op_name + getTensorsStringKey({sub_batch_input, sub_batch_grad_output, sub_batch_output}) + ":[" +
    | 281 | +          getArrayRefString(padding) + "]:" + std::to_string(constantValue) + std::to_string(sub_batch_start);
    | 282 | +
    | 283 | +      auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
    | 284 | +        newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(sub_batch_input));
    | 285 | +        const bool needsSlice = startMask != dims_mask || endMask != dims_mask;
    | 286 | +
    | 287 | +        if (!is_backward_pass) {
    | 288 | +          MPSGraphTensor* padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor_
    | 289 | +                                          withPaddingMode:mode
    | 290 | +                                              leftPadding:leftPadding
    | 291 | +                                             rightPadding:rightPadding
    | 292 | +                                            constantValue:constantValue
    | 293 | +                                                     name:nil];
    | 294 | +          // workaround for the right padding bug in Monterey
    | 295 | +          if (needsSlice) {
    | 296 | +            newCachedGraph->gradInputTensor_ =
    | 297 | +                [mpsGraph sliceTensor:padTensor
    | 298 | +                               starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
    | 299 | +                                 ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
    | 300 | +                              strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
    | 301 | +                            startMask:startMask
    | 302 | +                              endMask:endMask
    | 303 | +                          squeezeMask:0
    | 304 | +                                 name:nil];
    | 305 | +          } else {
    | 306 | +            newCachedGraph->gradInputTensor_ = padTensor;
    | 307 | +          }
296 | 308 |         } else {
297 |     | -          newCachedGraph->gradInputTensor_ = padGradTensor;
    | 309 | +          newCachedGraph->gradOutputTensor_ =
    | 310 | +              mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(sub_batch_grad_output));
    | 311 | +          MPSGraphTensor* padGradTensor =
    | 312 | +              [mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor_
    | 313 | +                                                 sourceTensor:newCachedGraph->inputTensor_
    | 314 | +                                                  paddingMode:mode
    | 315 | +                                                  leftPadding:leftPadding
    | 316 | +                                                 rightPadding:rightPadding
    | 317 | +                                                         name:nil];
    | 318 | +          // workaround for negative padding issue with padGradientWithIncomingGradientTensor()
    | 319 | +          if (needsSlice) {
    | 320 | +            newCachedGraph->gradInputTensor_ =
    | 321 | +                [mpsGraph sliceGradientTensor:padGradTensor
    | 322 | +                             fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor_ name:nil]
    | 323 | +                                       starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
    | 324 | +                                         ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
    | 325 | +                                      strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
    | 326 | +                                    startMask:startMask
    | 327 | +                                      endMask:endMask
    | 328 | +                                  squeezeMask:0
    | 329 | +                                         name:nil];
    | 330 | +          } else {
    | 331 | +            newCachedGraph->gradInputTensor_ = padGradTensor;
    | 332 | +          }
298 | 333 |         }
    | 334 | +      });
    | 335 | +      Placeholder inputPlaceholder =
    | 336 | +          Placeholder(cachedGraph->inputTensor_, sub_batch_input, getMPSShape(sub_batch_input), true, dataType);
    | 337 | +      Placeholder outputPlaceholder =
    | 338 | +          Placeholder(cachedGraph->gradInputTensor_, sub_batch_output, getMPSShape(sub_batch_output), true, dataType);
    | 339 | +      Placeholder gradOutputPlaceholder = !is_backward_pass ? Placeholder()
    | 340 | +                                                             : Placeholder(cachedGraph->gradOutputTensor_,
    | 341 | +                                                                           sub_batch_grad_output,
    | 342 | +                                                                           getMPSShape(sub_batch_grad_output),
    | 343 | +                                                                           true,
    | 344 | +                                                                           dataType);
    | 345 | +
    | 346 | +      NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
    | 347 | +      feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
    | 348 | +      if (is_backward_pass) {
    | 349 | +        feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
299 | 350 |       }
300 |     | -    });
301 |     | -
302 |     | -    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input, nullptr, true, dataType);
303 |     | -    Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output, nullptr, true, dataType);
304 |     | -    Placeholder gradOutputPlaceholder = !is_backward_pass
305 |     | -        ? Placeholder()
306 |     | -        : Placeholder(cachedGraph->gradOutputTensor_, grad_output, nullptr, true, dataType);
307 |     | -
308 |     | -    NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
309 |     | -    feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
310 |     | -    if (is_backward_pass) {
311 |     | -      feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
    | 351 | +      runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
312 | 352 |     }
313 |     | -    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
314 |     | -  }
    | 353 | +  } while (remaining_batch_size > 0);
315 | 354 |   return output;
316 | 355 | }
317 | 356 | } // namespace mps
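
For reference, here is a minimal stand-alone sketch (not part of the patch) of the sub-batch arithmetic the new do/while loop performs in the backward pass, where each chunk of the problematic dimension is clamped to 2**16-1 elements. The batch_size value and the printed ranges are illustrative; the actual tensor slicing via at::slice is elided.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t max_sub_batch_size = 65535;  // 2**16 - 1, the largest size MPSGraph handles correctly
  int64_t batch_size = 200000;               // stand-in for input_.size(-3) in the diff
  int64_t sub_batch_start = 0;
  int64_t remaining_batch_size = batch_size;
  do {
    // Clamp each chunk so the sliced dimension never reaches 2**16.
    int64_t sub_batch_size = std::min<int64_t>(remaining_batch_size, max_sub_batch_size);
    std::printf("slice [%lld, %lld)\n",
                static_cast<long long>(sub_batch_start),
                static_cast<long long>(sub_batch_start + sub_batch_size));
    remaining_batch_size -= sub_batch_size;
    sub_batch_start += sub_batch_size;
  } while (remaining_batch_size > 0);
  // Prints [0,65535), [65535,131070), [131070,196605), [196605,200000):
  // every chunk stays under the 2**16 limit and their union covers the batch.
  return 0;
}

Appending std::to_string(sub_batch_start) to the graph cache key keeps the per-chunk graphs from colliding in the cache when chunks have different shapes.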