@@ -1292,6 +1292,185 @@ def fn(x, y, z):
         self.assertEqual(ref, res)
 
+    @torch._inductor.config.patch(emulate_precision_casts=True)
+    def test_dont_inplace_disjoint_accesses(self):
+        # TODO - would not need mms if we could annotate donated buffer..
+        def forward(  # noqa: F821, F722
+            arg0_1: "bf16[2048, 2048][2048, 1]cuda:0",  # noqa: F821, F722
+            arg1_1: "bf16[8, 4096, 2048][8388608, 2048, 1]cuda:0",  # noqa: F821, F722
+            arg2_1: "bf16[2048, 2048][2048, 1]cuda:0",  # noqa: F821, F722
+            arg3_1: "bf16[2048, 2048][2048, 1]cuda:0",  # noqa: F821, F722
+            arg4_1: "bf16[2048][1]cuda:0",  # noqa: F821, F722
+            arg5_1: "bf16[2048][1]cuda:0",  # noqa: F821, F722
+            arg6_1: "f32[4096, 128][128, 1]cuda:0",  # noqa: F821, F722
+            arg7_1: "f32[4096, 128][128, 1]cuda:0",  # noqa: F821, F722
+        ):
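+            # three matmuls projecting the same activation (arg1_1) through
+            # separate weights, q/k/v-style; each mm reads a fresh view of it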
+            permute = torch.ops.aten.permute.default(arg0_1, [1, 0])
+            arg0_1 = None
+            view = torch.ops.aten.view.default(arg1_1, [32768, 2048])
+            mm = torch.ops.aten.mm.default(view, permute)
+            view = permute = None
+            view_1 = torch.ops.aten.view.default(mm, [8, 4096, 2048])
+            mm = None
+            permute_1 = torch.ops.aten.permute.default(arg2_1, [1, 0])
+            arg2_1 = None
+            view_2 = torch.ops.aten.view.default(arg1_1, [32768, 2048])
+            mm_1 = torch.ops.aten.mm.default(view_2, permute_1)
+            view_2 = permute_1 = None
+            view_3 = torch.ops.aten.view.default(mm_1, [8, 4096, 2048])
+            mm_1 = None
+            permute_2 = torch.ops.aten.permute.default(arg3_1, [1, 0])
+            arg3_1 = None
+            view_4 = torch.ops.aten.view.default(arg1_1, [32768, 2048])
+            arg1_1 = None
+            mm_2 = torch.ops.aten.mm.default(view_4, permute_2)
+            view_4 = permute_2 = None
+            view_5 = torch.ops.aten.view.default(mm_2, [8, 4096, 2048])
+            mm_2 = None
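+            # RMSNorm-style normalization of the first projection in fp32,
+            # scaled by arg4_1, then cast back to bf16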
+            convert_element_type_6 = torch.ops.prims.convert_element_type.default(
+                view_1, torch.float32
+            )
+            view_1 = None
+            pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_6, 2)
+            mean = torch.ops.aten.mean.dim(pow_1, [-1], True)
+            pow_1 = None
+            add = torch.ops.aten.add.Tensor(mean, 1e-06)
+            mean = None
+            rsqrt = torch.ops.aten.rsqrt.default(add)
+            add = None
+            mul = torch.ops.aten.mul.Tensor(convert_element_type_6, rsqrt)
+            convert_element_type_6 = rsqrt = None
+            convert_element_type_7 = torch.ops.prims.convert_element_type.default(
+                arg4_1, torch.float32
+            )
+            arg4_1 = None
+            mul_1 = torch.ops.aten.mul.Tensor(convert_element_type_7, mul)
+            convert_element_type_7 = mul = None
+            convert_element_type_8 = torch.ops.prims.convert_element_type.default(
+                mul_1, torch.bfloat16
+            )
+            mul_1 = None
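+            # the same normalization for the second projection, scaled by arg5_1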
+            convert_element_type_9 = torch.ops.prims.convert_element_type.default(
+                view_3, torch.float32
+            )
+            view_3 = None
+            pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_9, 2)
+            mean_1 = torch.ops.aten.mean.dim(pow_2, [-1], True)
+            pow_2 = None
+            add_1 = torch.ops.aten.add.Tensor(mean_1, 1e-06)
+            mean_1 = None
+            rsqrt_1 = torch.ops.aten.rsqrt.default(add_1)
+            add_1 = None
+            mul_2 = torch.ops.aten.mul.Tensor(convert_element_type_9, rsqrt_1)
+            convert_element_type_9 = rsqrt_1 = None
+            convert_element_type_10 = torch.ops.prims.convert_element_type.default(
+                arg5_1, torch.float32
+            )
+            arg5_1 = None
+            mul_3 = torch.ops.aten.mul.Tensor(convert_element_type_10, mul_2)
+            convert_element_type_10 = mul_2 = None
+            convert_element_type_11 = torch.ops.prims.convert_element_type.default(
+                mul_3, torch.bfloat16
+            )
+            mul_3 = None
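+            # split all three projections into 128-wide heads; the two
+            # normalized ones are upcast to fp32 for the rotation below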
+            view_6 = torch.ops.aten.view.default(
+                convert_element_type_8, [8, 4096, -1, 128]
+            )
+            convert_element_type_8 = None
+            view_7 = torch.ops.aten.view.default(
+                convert_element_type_11, [8, 4096, -1, 128]
+            )
+            convert_element_type_11 = None
+            view_8 = torch.ops.aten.view.default(view_5, [8, 4096, -1, 128])
+            view_5 = None
+            convert_element_type_12 = torch.ops.prims.convert_element_type.default(
+                view_6, torch.float32
+            )
+            view_6 = None
+            convert_element_type_13 = torch.ops.prims.convert_element_type.default(
+                view_7, torch.float32
+            )
+            view_7 = None
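+            # rotary-embedding-style rotate-half applied to the first tensor,
+            # using arg6_1/arg7_1 as the rotation tables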
+            unsqueeze = torch.ops.aten.unsqueeze.default(arg6_1, 0)
+            unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 2)
+            unsqueeze = None
+            unsqueeze_2 = torch.ops.aten.unsqueeze.default(arg7_1, 0)
+            unsqueeze_3 = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2)
+            unsqueeze_2 = None
+            mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_12, unsqueeze_3)
+            unsqueeze_3 = None
+            view_9 = torch.ops.aten.view.default(
+                convert_element_type_12, [8, 4096, 16, 2, 64]
+            )
+            convert_element_type_12 = None
+            unbind = torch.ops.aten.unbind.int(view_9, -2)
+            view_9 = None
+            getitem = unbind[0]
+            getitem_1 = unbind[1]
+            unbind = None
+            neg = torch.ops.aten.neg.default(getitem_1)
+            getitem_1 = None
+            cat = torch.ops.aten.cat.default([neg, getitem], -1)
+            neg = getitem = None
+            mul_5 = torch.ops.aten.mul.Tensor(cat, unsqueeze_1)
+            cat = unsqueeze_1 = None
+            add_2 = torch.ops.aten.add.Tensor(mul_4, mul_5)
+            mul_4 = mul_5 = None
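+            # the same rotation applied to the second tensor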
+            unsqueeze_4 = torch.ops.aten.unsqueeze.default(arg6_1, 0)
+            arg6_1 = None
+            unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 2)
+            unsqueeze_4 = None
+            unsqueeze_6 = torch.ops.aten.unsqueeze.default(arg7_1, 0)
+            arg7_1 = None
+            unsqueeze_7 = torch.ops.aten.unsqueeze.default(unsqueeze_6, 2)
+            unsqueeze_6 = None
+            mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_13, unsqueeze_7)
+            unsqueeze_7 = None
+            view_10 = torch.ops.aten.view.default(
+                convert_element_type_13, [8, 4096, 16, 2, 64]
+            )
+            convert_element_type_13 = None
+            unbind_1 = torch.ops.aten.unbind.int(view_10, -2)
+            view_10 = None
+            getitem_2 = unbind_1[0]
+            getitem_3 = unbind_1[1]
+            unbind_1 = None
+            neg_1 = torch.ops.aten.neg.default(getitem_3)
+            getitem_3 = None
+            cat_1 = torch.ops.aten.cat.default([neg_1, getitem_2], -1)
+            neg_1 = getitem_2 = None
+            mul_7 = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_5)
+            cat_1 = unsqueeze_5 = None
+            add_3 = torch.ops.aten.add.Tensor(mul_6, mul_7)
+            mul_6 = mul_7 = None
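+            # cast the rotated tensors back to bf16 and transpose all three
+            # outputs to [batch, heads, seq, head_dim]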
+            convert_element_type_14 = torch.ops.prims.convert_element_type.default(
+                add_2, torch.bfloat16
+            )
+            add_2 = None
+            convert_element_type_15 = torch.ops.prims.convert_element_type.default(
+                add_3, torch.bfloat16
+            )
+            add_3 = None
+            permute_3 = torch.ops.aten.permute.default(
+                convert_element_type_14, [0, 2, 1, 3]
+            )
+            convert_element_type_14 = None
+            permute_4 = torch.ops.aten.permute.default(
+                convert_element_type_15, [0, 2, 1, 3]
+            )
+            convert_element_type_15 = None
+            permute_5 = torch.ops.aten.permute.default(view_8, [0, 2, 1, 3])
+            view_8 = None
+            return (permute_3, permute_4, permute_5)
+
+        from torch._dynamo.debug_utils import aot_graph_input_parser
+
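+        # aot_graph_input_parser generates random inputs matching the
+        # shape/dtype annotations on forward's signature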
+        kwargs = aot_graph_input_parser(forward)
+        out, code = run_and_get_code(torch.compile(forward), **kwargs)
+        # ignore tiny values.. prior to this fix absolute error was ~28
+        self.assertEqual(forward(**kwargs), out, atol=0.01, rtol=2)
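+        # inductor names buffers reused in place "in_out_ptr" in generated code;
+        # assert that no such in-place reuse happens for these disjoint accesses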
+        FileCheck().check_not("in_out").run(code[0])
+
 
     # https://github.com/pytorch/pytorch/issues/104937
     def test_linear_with_zero_infeature_size(self):
         m = nn.Linear(in_features=0, out_features=0, bias=True).to("cuda")