
Commit 444e238

v0i0 authored and pytorchmergebot committed
[inductor] move all cpu scalars using pinned memory for graph partition (#155360) (#158983)
Pull Request resolved: #158983
Approved by: https://github.com/eellison
ghstack dependencies: #158758
1 parent 6085bf7 commit 444e238
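
The change batches all movable CPU scalar inputs into one pinned-memory host-to-device transfer instead of one transfer per scalar, so the generated wrapper code ends up with a single non-blocking .copy_ (this is what the new test below checks). A minimal eager-mode sketch of the idea, assuming a CUDA device is available; the helper name is illustrative and not part of the PR:

import torch

def batch_cpu_scalars_to_cuda(scalars):
    # Rough eager-mode analogue of what the Inductor pass emits: unsqueeze + cat
    # on CPU, one pinned non-blocking host-to-device copy, then unbind.
    stacked = torch.cat([s.unsqueeze(0) for s in scalars])  # 1-D CPU tensor
    pinned = stacked.pin_memory()                            # page-locked staging buffer
    on_gpu = pinned.to("cuda", non_blocking=True)            # single async H2D copy
    return on_gpu.unbind(0)                                  # one 0-d GPU tensor per input

if torch.cuda.is_available():
    x, y = torch.ones(()), torch.tensor(2.0)
    gx, gy = batch_cpu_scalars_to_cuda([x, y])
    print(gx.item(), gy.item())  # 1.0 2.0, both now live on the GPU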

File tree: 2 files changed, +64 −13 lines changed


test/inductor/test_cudagraph_trees.py

Lines changed: 22 additions & 0 deletions
@@ -2849,6 +2849,28 @@ def foo(x):
 
         self.assertEqual(x, torch.tensor(1, device="cpu"))
 
+    @torch._inductor.config.patch("graph_partition", True)
+    def test_graph_partition_cpu_scalar_multiple(self):
+        def f(x, y, z):
+            return x + y, x + z
+
+        compiled_f = torch.compile(f, mode="reduce-overhead")
+
+        inputs = (
+            torch.ones((), device="cpu"),
+            torch.ones((), device="cpu"),
+            torch.ones(2, 2, device="cuda"),
+        )
+        for i in range(3):
+            if i == 0:
+                _, code = run_and_get_code(compiled_f, *inputs)
+                FileCheck().check_regex(r".copy_.*True").run(code[0])
+                FileCheck().check_count(".copy_", 1, exactly=True).run(code[0])
+            else:
+                compiled_f(*inputs)
+        self.assertEqual(compiled_f(*inputs), f(*inputs))
+        self.assertEqual(self.get_manager().new_graph_id().id, 1)
+
     @torch._inductor.config.patch("graph_partition", True)
     @torch._inductor.config.patch("triton.cudagraphs", False)
     def test_graph_partition_reduce_overhead_mode_effectiveness(self):

torch/_inductor/fx_passes/post_grad.py

Lines changed: 42 additions & 13 deletions
@@ -1760,17 +1760,44 @@ def __call__(self, graph: fx.Graph) -> None:
         movable_constructors = self.find_movable_constructors(graph, constructors)
 
         target_device = next(iter(target_devices))
-        for node in movable_constructors:
-            if node in cpu_placeholders:
-                with graph.inserting_after(node):
-                    gpu_node = graph.call_function(
-                        torch.ops.prims.device_put.default, (node, target_device)
+        movable_cpu_placeholders = movable_constructors & cpu_placeholders
+        if movable_cpu_placeholders:
+            node = next(iter(reversed(movable_cpu_placeholders)))
+            last_node = node
+            unsqueezed_nodes = []
+            for elem in movable_cpu_placeholders:
+                with graph.inserting_after(last_node):
+                    unsqueezed_nodes.append(
+                        graph.call_function(torch.ops.aten.unsqueeze.default, (elem, 0))
                     )
-                node.replace_all_uses_with(
-                    gpu_node,
-                    lambda x: x != gpu_node
-                    and x.target != torch.ops.aten.copy_.default,
+                last_node = unsqueezed_nodes[-1]
+            with graph.inserting_after(last_node):
+                cpu_concat = graph.call_function(
+                    torch.ops.aten.cat.default, (unsqueezed_nodes,)
+                )
+                last_node = cpu_concat
+            with graph.inserting_after(last_node):
+                gpu_concat = graph.call_function(
+                    torch.ops.prims.device_put.default,
+                    (cpu_concat, target_device, True),
                 )
+                last_node = gpu_concat
+            with graph.inserting_after(last_node):
+                gpu_split = graph.call_function(
+                    torch.ops.aten.unbind.int, (gpu_concat,)
+                )
+                last_node = gpu_split
+            for idx, node in enumerate(movable_cpu_placeholders):
+                with graph.inserting_after(last_node):
+                    gpu_node = graph.call_function(operator.getitem, (gpu_split, idx))
+                node.replace_all_uses_with(
+                    gpu_node,
+                    lambda x: x
+                    not in [cpu_concat, gpu_concat, gpu_split, gpu_node]
+                    + unsqueezed_nodes
+                    and x.target != torch.ops.aten.copy_.default,
+                )
+                last_node = gpu_node
 
             # noop elimination if there are other device_put for gpu_node to
             # target device. Alternatively, we could just move the other device_put
@@ -1784,10 +1811,12 @@ def __call__(self, graph: fx.Graph) -> None:
             for noop in noop_device_puts:
                 noop.replace_all_uses_with(gpu_node)
                 graph.erase_node(noop)
-            else:
-                kwargs = node.kwargs.copy()
-                kwargs["device"] = target_device
-                node.kwargs = kwargs
+
+        movable_constructors -= movable_cpu_placeholders
+        for node in movable_constructors:
+            kwargs = node.kwargs.copy()
+            kwargs["device"] = target_device
+            node.kwargs = kwargs
 
     def find_movable_constructors(
         self, graph: fx.Graph, constructors: list[fx.Node]
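
For readers less familiar with the FX surgery above: the rewritten pass unsqueezes and concatenates the movable CPU scalar placeholders, issues a single device_put for the concatenated tensor (the trailing True presumably requests a non-blocking copy, in line with the pinned-memory goal of the PR), unbinds the result on the GPU, and rewires every original use to the corresponding getitem, excluding the nodes of the new transfer chain and any copy_ from the replacement. A self-contained toy version of the same rewrite on a traced function, using public ops and a hypothetical to_cuda_pinned helper in place of torch.ops.prims.device_put.default (a sketch, not the pass's code):

import operator

import torch
import torch.fx as fx


def to_cuda_pinned(t):
    # Hypothetical stand-in for the batched device_put: one pinned,
    # non-blocking host-to-device copy for the whole concatenated tensor.
    return t.pin_memory().to("cuda", non_blocking=True)


def f(x, y, gpu_t):
    return gpu_t + x, gpu_t + y


gm = fx.symbolic_trace(f)
graph = gm.graph

placeholders = [n for n in graph.nodes if n.op == "placeholder"]
cpu_scalars = placeholders[:2]   # treat x and y as the movable CPU scalars
last = placeholders[-1]          # insert the new nodes after the last placeholder

with graph.inserting_after(last):
    unsqueezed = [graph.call_function(torch.unsqueeze, (p, 0)) for p in cpu_scalars]
    cpu_concat = graph.call_function(torch.cat, (unsqueezed,))
    gpu_concat = graph.call_function(to_cuda_pinned, (cpu_concat,))
    gpu_split = graph.call_function(torch.unbind, (gpu_concat,))
    for idx, node in enumerate(cpu_scalars):
        gpu_node = graph.call_function(operator.getitem, (gpu_split, idx))
        # Rewire users of the CPU scalar to the GPU value, but keep the
        # unsqueeze nodes of the transfer chain reading from the CPU input.
        node.replace_all_uses_with(gpu_node, lambda u: u not in unsqueezed)

graph.lint()
gm.recompile()

if torch.cuda.is_available():
    out = gm(torch.ones(()), torch.ones(()), torch.ones(2, 2, device="cuda"))
    print([o.device for o in out])  # both outputs computed on the GPU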
