Commit c7ff78d

pytorchbot and eellison authored
Fix inplacing with multiple, fused uses (#150892)
Fix inplacing with multiple, fused uses (#150845)

We had `can_inplace` defined on a single use. When that buffer has multiple uses inside a fused node, we need to check whether the other accesses have the same index. Otherwise, inplacing may cause us to read memory that has already been overwritten.

Pull Request resolved: #150845
Approved by: https://github.com/zou3519, https://github.com/exclamaforte, https://github.com/atalman, https://github.com/jansel

(cherry picked from commit 27ded35)

Co-authored-by: eellison <elias.ellison@gmail.com>
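To illustrate the failure mode, here is a minimal standalone sketch (toy names, not Inductor code and not part of this commit): a "fused kernel" that reads one buffer at two non-equivalent indices. If the output is inplaced over the input, one of the reads observes data that an earlier write already clobbered; in a real parallel GPU kernel the corruption is also nondeterministic.

# Toy sketch (illustrative only): why inplacing is unsafe when a fused
# kernel accesses the same buffer at two different indices.
def fused_kernel(inp, out):
    # Each output element reads inp at index i AND at index (i - 1):
    # two accesses to one buffer with non-equivalent indices.
    n = len(inp)
    for i in range(n):
        out[i] = inp[i] + inp[(i - 1) % n]

x = [0.0, 1.0, 2.0, 3.0]

safe = [0.0] * 4
fused_kernel(x, safe)  # out-of-place: every read sees the original values

aliased = list(x)
fused_kernel(aliased, aliased)  # "inplaced": out aliases inp
# At i=1, the read of inp[0] now sees the value written at i=0.
assert safe == [3.0, 1.0, 3.0, 5.0]
assert aliased != safe  # the in-place version silently corrupted the result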
1 parent: 894909a

File tree

2 files changed (+212, -0 lines)


test/inductor/test_cuda_repro.py

Lines changed: 179 additions & 0 deletions
@@ -1292,6 +1292,185 @@ def fn(x, y, z):
 
         self.assertEqual(ref, res)
 
+    @torch._inductor.config.patch(emulate_precision_casts=True)
+    def test_dont_inplace_disjoint_accesses(self):
+        # TODO - would not need mms if we could annotate donated buffer..
+        def forward(  # noqa: F821, F722
+            arg0_1: "bf16[2048, 2048][2048, 1]cuda:0",  # noqa: F821, F722
+            arg1_1: "bf16[8, 4096, 2048][8388608, 2048, 1]cuda:0",  # noqa: F821, F722
+            arg2_1: "bf16[2048, 2048][2048, 1]cuda:0",  # noqa: F821, F722
+            arg3_1: "bf16[2048, 2048][2048, 1]cuda:0",  # noqa: F821, F722
+            arg4_1: "bf16[2048][1]cuda:0",  # noqa: F821, F722
+            arg5_1: "bf16[2048][1]cuda:0",  # noqa: F821, F722
+            arg6_1: "f32[4096, 128][128, 1]cuda:0",  # noqa: F821, F722
+            arg7_1: "f32[4096, 128][128, 1]cuda:0",  # noqa: F821, F722
+        ):
+            permute = torch.ops.aten.permute.default(arg0_1, [1, 0])
+            arg0_1 = None
+            view = torch.ops.aten.view.default(arg1_1, [32768, 2048])
+            mm = torch.ops.aten.mm.default(view, permute)
+            view = permute = None
+            view_1 = torch.ops.aten.view.default(mm, [8, 4096, 2048])
+            mm = None
+            permute_1 = torch.ops.aten.permute.default(arg2_1, [1, 0])
+            arg2_1 = None
+            view_2 = torch.ops.aten.view.default(arg1_1, [32768, 2048])
+            mm_1 = torch.ops.aten.mm.default(view_2, permute_1)
+            view_2 = permute_1 = None
+            view_3 = torch.ops.aten.view.default(mm_1, [8, 4096, 2048])
+            mm_1 = None
+            permute_2 = torch.ops.aten.permute.default(arg3_1, [1, 0])
+            arg3_1 = None
+            view_4 = torch.ops.aten.view.default(arg1_1, [32768, 2048])
+            arg1_1 = None
+            mm_2 = torch.ops.aten.mm.default(view_4, permute_2)
+            view_4 = permute_2 = None
+            view_5 = torch.ops.aten.view.default(mm_2, [8, 4096, 2048])
+            mm_2 = None
+            convert_element_type_6 = torch.ops.prims.convert_element_type.default(
+                view_1, torch.float32
+            )
+            view_1 = None
+            pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_6, 2)
+            mean = torch.ops.aten.mean.dim(pow_1, [-1], True)
+            pow_1 = None
+            add = torch.ops.aten.add.Tensor(mean, 1e-06)
+            mean = None
+            rsqrt = torch.ops.aten.rsqrt.default(add)
+            add = None
+            mul = torch.ops.aten.mul.Tensor(convert_element_type_6, rsqrt)
+            convert_element_type_6 = rsqrt = None
+            convert_element_type_7 = torch.ops.prims.convert_element_type.default(
+                arg4_1, torch.float32
+            )
+            arg4_1 = None
+            mul_1 = torch.ops.aten.mul.Tensor(convert_element_type_7, mul)
+            convert_element_type_7 = mul = None
+            convert_element_type_8 = torch.ops.prims.convert_element_type.default(
+                mul_1, torch.bfloat16
+            )
+            mul_1 = None
+            convert_element_type_9 = torch.ops.prims.convert_element_type.default(
+                view_3, torch.float32
+            )
+            view_3 = None
+            pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_9, 2)
+            mean_1 = torch.ops.aten.mean.dim(pow_2, [-1], True)
+            pow_2 = None
+            add_1 = torch.ops.aten.add.Tensor(mean_1, 1e-06)
+            mean_1 = None
+            rsqrt_1 = torch.ops.aten.rsqrt.default(add_1)
+            add_1 = None
+            mul_2 = torch.ops.aten.mul.Tensor(convert_element_type_9, rsqrt_1)
+            convert_element_type_9 = rsqrt_1 = None
+            convert_element_type_10 = torch.ops.prims.convert_element_type.default(
+                arg5_1, torch.float32
+            )
+            arg5_1 = None
+            mul_3 = torch.ops.aten.mul.Tensor(convert_element_type_10, mul_2)
+            convert_element_type_10 = mul_2 = None
+            convert_element_type_11 = torch.ops.prims.convert_element_type.default(
+                mul_3, torch.bfloat16
+            )
+            mul_3 = None
+            view_6 = torch.ops.aten.view.default(
+                convert_element_type_8, [8, 4096, -1, 128]
+            )
+            convert_element_type_8 = None
+            view_7 = torch.ops.aten.view.default(
+                convert_element_type_11, [8, 4096, -1, 128]
+            )
+            convert_element_type_11 = None
+            view_8 = torch.ops.aten.view.default(view_5, [8, 4096, -1, 128])
+            view_5 = None
+            convert_element_type_12 = torch.ops.prims.convert_element_type.default(
+                view_6, torch.float32
+            )
+            view_6 = None
+            convert_element_type_13 = torch.ops.prims.convert_element_type.default(
+                view_7, torch.float32
+            )
+            view_7 = None
+            unsqueeze = torch.ops.aten.unsqueeze.default(arg6_1, 0)
+            unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 2)
+            unsqueeze = None
+            unsqueeze_2 = torch.ops.aten.unsqueeze.default(arg7_1, 0)
+            unsqueeze_3 = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2)
+            unsqueeze_2 = None
+            mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_12, unsqueeze_3)
+            unsqueeze_3 = None
+            view_9 = torch.ops.aten.view.default(
+                convert_element_type_12, [8, 4096, 16, 2, 64]
+            )
+            convert_element_type_12 = None
+            unbind = torch.ops.aten.unbind.int(view_9, -2)
+            view_9 = None
+            getitem = unbind[0]
+            getitem_1 = unbind[1]
+            unbind = None
+            neg = torch.ops.aten.neg.default(getitem_1)
+            getitem_1 = None
+            cat = torch.ops.aten.cat.default([neg, getitem], -1)
+            neg = getitem = None
+            mul_5 = torch.ops.aten.mul.Tensor(cat, unsqueeze_1)
+            cat = unsqueeze_1 = None
+            add_2 = torch.ops.aten.add.Tensor(mul_4, mul_5)
+            mul_4 = mul_5 = None
+            unsqueeze_4 = torch.ops.aten.unsqueeze.default(arg6_1, 0)
+            arg6_1 = None
+            unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 2)
+            unsqueeze_4 = None
+            unsqueeze_6 = torch.ops.aten.unsqueeze.default(arg7_1, 0)
+            arg7_1 = None
+            unsqueeze_7 = torch.ops.aten.unsqueeze.default(unsqueeze_6, 2)
+            unsqueeze_6 = None
+            mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_13, unsqueeze_7)
+            unsqueeze_7 = None
+            view_10 = torch.ops.aten.view.default(
+                convert_element_type_13, [8, 4096, 16, 2, 64]
+            )
+            convert_element_type_13 = None
+            unbind_1 = torch.ops.aten.unbind.int(view_10, -2)
+            view_10 = None
+            getitem_2 = unbind_1[0]
+            getitem_3 = unbind_1[1]
+            unbind_1 = None
+            neg_1 = torch.ops.aten.neg.default(getitem_3)
+            getitem_3 = None
+            cat_1 = torch.ops.aten.cat.default([neg_1, getitem_2], -1)
+            neg_1 = getitem_2 = None
+            mul_7 = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_5)
+            cat_1 = unsqueeze_5 = None
+            add_3 = torch.ops.aten.add.Tensor(mul_6, mul_7)
+            mul_6 = mul_7 = None
+            convert_element_type_14 = torch.ops.prims.convert_element_type.default(
+                add_2, torch.bfloat16
+            )
+            add_2 = None
+            convert_element_type_15 = torch.ops.prims.convert_element_type.default(
+                add_3, torch.bfloat16
+            )
+            add_3 = None
+            permute_3 = torch.ops.aten.permute.default(
+                convert_element_type_14, [0, 2, 1, 3]
+            )
+            convert_element_type_14 = None
+            permute_4 = torch.ops.aten.permute.default(
+                convert_element_type_15, [0, 2, 1, 3]
+            )
+            convert_element_type_15 = None
+            permute_5 = torch.ops.aten.permute.default(view_8, [0, 2, 1, 3])
+            view_8 = None
+            return (permute_3, permute_4, permute_5)
+
+        from torch._dynamo.debug_utils import aot_graph_input_parser
+
+        kwargs = aot_graph_input_parser(forward)
+        out, code = run_and_get_code(torch.compile(forward), **kwargs)
+        # ignore tiny values.. prior to this fix absolute error was ~28
+        self.assertEqual(forward(**kwargs), out, atol=0.01, rtol=2)
+        FileCheck().check_not("in_out").run(code[0])
+
     # https://github.com/pytorch/pytorch/issues/104937
     def test_linear_with_zero_infeature_size(self):
         m = nn.Linear(in_features=0, out_features=0, bias=True).to("cuda")
12951474
# https://github.com/pytorch/pytorch/issues/104937
12961475
def test_linear_with_zero_infeature_size(self):
12971476
m = nn.Linear(in_features=0, out_features=0, bias=True).to("cuda")

torch/_inductor/scheduler.py

Lines changed: 33 additions & 0 deletions
@@ -462,6 +462,38 @@ def decide_inplace_update(self) -> None:
             | self.scheduler.completed_operations
         )
 
+        def single_index_in_fused_node(buf_to_be_inplaced: SchedulerBuffer) -> bool:
+            # Inside of NodeUser, we track that the read and the write are
+            # equivalent before deciding if the use can be inplaced.
+            # But if that use is fused into a larger kernel, we need to check
+            # the equivalence of the other accesses in the fused scheduler
+            # node as well.
+            fused_node = buf_to_be_inplaced.scheduler.get_fused_node(self)
+            buf_name = buf_to_be_inplaced.get_name()
+            # Dedup reads/writes with equivalent indices.
+            # TODO - would be nice if we could just cache accesses on ReadWrites,
+            # and enforce the invariant that this class & its members are functional..
+            deps: OrderedSet[Dep] = OrderedSet()
+            for user in buf_to_be_inplaced.users:
+                user_node = user.node
+                if not isinstance(user_node, BaseSchedulerNode):
+                    continue
+
+                if (
+                    buf_to_be_inplaced.scheduler.get_fused_node(user_node)
+                    is not fused_node
+                ):
+                    continue
+
+                deps |= (
+                    o
+                    for o in user_node.read_writes.reads_and_writes()
+                    if o.name == buf_name
+                )
+                if len(deps) > 1:
+                    return False
+
+            return True
+
         for buf in self.get_outputs():
             buf_node = buf.node
             assert buf_node is not None
@@ -513,6 +545,7 @@ def decide_inplace_update(self) -> None:
                         and len(input_buf.node.get_inputs_that_alias_output()) > 0
                     )
                     and can_match_buffer_size(input_buf.node, buf.node)
+                    and single_index_in_fused_node(input_buf)
                 ):
                     # if there isn't a triton kernel, then we don't need to call triton-specific things.
                     # but TODO this might be a convenient place to signal to the Collective kernels to inplace
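As a rough mental model of the new check (the names below are invented for illustration; Inductor's real Dep objects and OrderedSet are richer than this), the function gathers every read and write of the candidate buffer across the users that land in the same fused node, dedups accesses with equivalent indices, and rejects inplacing as soon as more than one distinct access pattern survives:

# Toy rendering of the check, under assumed names (ToyDep, etc.):
from dataclasses import dataclass

@dataclass(frozen=True)  # frozen -> hashable, so equivalent deps collapse in a set
class ToyDep:
    name: str   # which buffer the access touches
    index: str  # symbolic index expression, e.g. "i0" or "i0 + 1"

def single_index_in_fused_node_toy(fused_accesses, buf_name):
    # Keep only accesses to the buffer we want to inplace, deduped by index.
    distinct = {d for d in fused_accesses if d.name == buf_name}
    # More than one surviving access pattern means the fused kernel touches
    # the buffer at disjoint indices, so inplacing could read clobbered memory.
    return len(distinct) <= 1

# buf0 is written at "i0" but also read at "i0 + 1" -> inplacing is rejected:
assert not single_index_in_fused_node_toy(
    [ToyDep("buf0", "i0"), ToyDep("buf0", "i0 + 1"), ToyDep("buf1", "i0")], "buf0"
)
# All accesses to buf0 share one index -> inplacing is allowed:
assert single_index_in_fused_node_toy(
    [ToyDep("buf0", "i0"), ToyDep("buf0", "i0"), ToyDep("buf1", "i0 + 1")], "buf0"
)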
