@@ -1760,17 +1760,44 @@ def __call__(self, graph: fx.Graph) -> None:
         movable_constructors = self.find_movable_constructors(graph, constructors)
 
         target_device = next(iter(target_devices))
-        for node in movable_constructors:
-            if node in cpu_placeholders:
-                with graph.inserting_after(node):
-                    gpu_node = graph.call_function(
-                        torch.ops.prims.device_put.default, (node, target_device)
+        movable_cpu_placeholders = movable_constructors & cpu_placeholders
+        if movable_cpu_placeholders:
+            node = next(iter(reversed(movable_cpu_placeholders)))
+            last_node = node
+            unsqueezed_nodes = []
+            for elem in movable_cpu_placeholders:
+                with graph.inserting_after(last_node):
+                    unsqueezed_nodes.append(
+                        graph.call_function(torch.ops.aten.unsqueeze.default, (elem, 0))
                     )
-                node.replace_all_uses_with(
-                    gpu_node,
-                    lambda x: x != gpu_node
-                    and x.target != torch.ops.aten.copy_.default,
+                last_node = unsqueezed_nodes[-1]
+            with graph.inserting_after(last_node):
+                cpu_concat = graph.call_function(
+                    torch.ops.aten.cat.default, (unsqueezed_nodes,)
                 )
+            last_node = cpu_concat
+            with graph.inserting_after(last_node):
+                gpu_concat = graph.call_function(
+                    torch.ops.prims.device_put.default,
+                    (cpu_concat, target_device, True),
                 )
+            last_node = gpu_concat
+            with graph.inserting_after(last_node):
+                gpu_split = graph.call_function(
+                    torch.ops.aten.unbind.int, (gpu_concat,)
+                )
+            last_node = gpu_split
+            for idx, node in enumerate(movable_cpu_placeholders):
+                with graph.inserting_after(last_node):
+                    gpu_node = graph.call_function(operator.getitem, (gpu_split, idx))
+                node.replace_all_uses_with(
+                    gpu_node,
+                    lambda x: x
+                    not in [cpu_concat, gpu_concat, gpu_split, gpu_node]
+                    + unsqueezed_nodes
+                    and x.target != torch.ops.aten.copy_.default,
+                )
+                last_node = gpu_node
 
                 # noop elimination if there are other device_put for gpu_node to
                 # target device. Alternatively, we could just move the other device_put
@@ -1784,10 +1811,12 @@ def __call__(self, graph: fx.Graph) -> None:
                 for noop in noop_device_puts:
                     noop.replace_all_uses_with(gpu_node)
                     graph.erase_node(noop)
-            else:
-                kwargs = node.kwargs.copy()
-                kwargs["device"] = target_device
-                node.kwargs = kwargs
+
+        movable_constructors -= movable_cpu_placeholders
+        for node in movable_constructors:
+            kwargs = node.kwargs.copy()
+            kwargs["device"] = target_device
+            node.kwargs = kwargs
 
     def find_movable_constructors(
         self, graph: fx.Graph, constructors: list[fx.Node]
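For reference, a minimal eager-mode sketch of what the rewritten graph computes. This is an illustration under assumptions, not code from this change: the helper name `batched_to_device` is made up, the trailing `True` passed to `prims.device_put` is read here as a non-blocking copy, and the batching only type-checks when the inputs share shape and dtype (the situation the `cpu_placeholders` set targets).

```python
import torch


def batched_to_device(cpu_tensors, device="cuda"):
    """One fused host-to-device copy instead of one copy per input.

    Hypothetical eager-mode analogue of the fx rewrite above; assumes all
    inputs share shape and dtype.
    """
    # aten.unsqueeze(elem, 0): give every input a leading batch dim.
    unsqueezed = [t.unsqueeze(0) for t in cpu_tensors]
    # aten.cat: a single contiguous CPU buffer (unsqueeze + cat here is
    # equivalent to torch.stack; the pass emits the two ops explicitly).
    cpu_concat = torch.cat(unsqueezed)
    # prims.device_put(cpu_concat, target_device, True): one transfer; the
    # trailing True is read here as a non-blocking copy (an assumption).
    gpu_concat = cpu_concat.to(device, non_blocking=True)
    # aten.unbind.int + operator.getitem: peel the batch dim back off so
    # every original user sees a tensor of the original shape.
    return list(torch.unbind(gpu_concat, dim=0))


# Two CPU scalars now cost one H2D transfer instead of two:
# a_gpu, b_gpu = batched_to_device([torch.tensor(1.0), torch.tensor(2.0)])
```

The filter passed to `replace_all_uses_with` plays the role of the return wiring in this sketch: it skips the freshly inserted `cat`/`device_put`/`unbind`/`getitem` nodes (and `copy_` writes back to the placeholders) so the batched copy is not rewired to consume its own outputs.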