@@ -8586,11 +8586,21 @@ def _new_fn():
            mod_or_fn.to(device)
            return mod_or_fn

-        with patch.object(
-            torch._dynamo.variables.higher_order_ops.CondHigherOrderVariable,
-            "supports_input_mutation",
-            True,
-        ):
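+        # Allow input mutation for both the cond and while_loop HOPs while
+        # tracing; the ExitStack applies both patches for the whole test body.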
+        with contextlib.ExitStack() as ctx_stack:
+            ctx_stack.enter_context(
+                patch.object(
+                    torch._dynamo.variables.higher_order_ops.CondHigherOrderVariable,
+                    "supports_input_mutation",
+                    True,
+                ),
+            )
+            ctx_stack.enter_context(
+                patch.object(
+                    torch._dynamo.variables.higher_order_ops.WhileLoopHigherOrderVariable,
+                    "supports_input_mutation",
+                    True,
+                ),
+            )
            # Only support input mutation in inference
            cloned_args = [_clone(args) for _ in range(3)]
            with torch.no_grad():
@@ -8809,6 +8819,211 @@ def forward(self, arg0_1: "f32[4, 3]", arg1_1: "f32[3, 4]"):
""",  # noqa: B950
            )

+    @requires_cuda
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @parametrize("device", ["cuda", "cpu"])
+    @parametrize("dynamic", [True, False])
+    def test_while_loop_auto_functionalize_input_mutation(self, device, dynamic):
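+        # body_fn mutates the carried tensor in place; the expected graph below
+        # shows the mutation rewritten into auto_functionalized_v2 over a clone,
+        # with the copy_ back emitted inside the functionalized body subgraph.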
+        class M(torch.nn.Module):
+            def forward(self, x, y):
+                def cond_fn(x):
+                    return x.sum() > 0
+
+                def body_fn(x):
+                    x.add_(-1)
+                    return (x.clone(),)
+
+                x = x.clone()
+                ret = while_loop(cond_fn, body_fn, (x,))
+                return y + ret[0]
+
+        x, y = (
+            torch.randn(3, 4),
+            torch.randn(3, 4),
+        )
+        fw_gm = self.check(M, (x, y), device, dynamic)
+        if not TEST_WITH_CROSSREF and not dynamic and device == "cuda":
+            self.assertExpectedInline(
+                normalize_gm(fw_gm.print_readable(print_output=False)),
+                """\
+class <lambda>(torch.nn.Module):
+    def forward(self, arg0_1: "f32[3, 4]", arg1_1: "f32[3, 4]"):
+        clone: "f32[3, 4]" = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
+
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        auto_functionalized_subgraph_1 = self.auto_functionalized_subgraph_1
+        _tree_spec_constant0 = self._tree_spec_constant0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.while_loop, cond_fn = auto_functionalized_subgraph_0, body_fn = auto_functionalized_subgraph_1, _carried_input0_base_index = 0, _all_bases = [clone], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = auto_functionalized_subgraph_1 = clone = _tree_spec_constant0 = None
+        getitem: "f32[3, 4]" = auto_functionalized_v2[0]; auto_functionalized_v2 = None
+
+        add: "f32[3, 4]" = torch.ops.aten.add.Tensor(arg1_1, getitem); arg1_1 = getitem = None
+        return (add,)
+
+    class auto_functionalized_subgraph_0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[3, 4]"):
+            sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
+            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
+            return gt
+
+    class auto_functionalized_subgraph_1(torch.nn.Module):
+        def forward(self, arg0_1: "f32[3, 4]"):
+            add: "f32[3, 4]" = torch.ops.aten.add.Tensor(arg0_1, -1)
+            clone: "f32[3, 4]" = torch.ops.aten.clone.default(add)
+            copy_: "f32[3, 4]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
+            return (clone,)
+""",  # noqa: B950
+            )
+
+    @requires_cuda
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @parametrize("device", ["cuda", "cpu"])
+    @parametrize("dynamic", [True, False])
+    def test_while_loop_auto_functionalize_buffer_mutation(self, device, dynamic):
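+        # Here both the carry and a module buffer are mutated; the buffer becomes
+        # an extra base of auto_functionalized_v2 and is written back to arg1_1
+        # via copy_ in the outer graph.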
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer(
+                    "buf", torch.ones(8, requires_grad=False, device=device)
+                )
+
+            def forward(self, p, x):
+                def cond_fn(x):
+                    return x.sum() < 0
+
+                def body_fn(x):
+                    x.add_(-1)
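+                    # self.buf is captured by body_fn, so this mutates a lifted
+                    # additional input rather than an explicit carry.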
+                    self.buf.add_(-1)
+                    return (x + self.buf.sum(),)
+
+                x = x.clone()
+                out = while_loop(cond_fn, body_fn, (x,))
+                return x + self.buf + out[0]
+
+        p, x = torch.tensor(True), torch.randn(1, requires_grad=True)
+        fw_gm = self.check(M, (p, x), device, dynamic)
+        if not TEST_WITH_CROSSREF and not dynamic and device == "cuda":
+            self.assertExpectedInline(
+                normalize_gm(fw_gm.print_readable(print_output=False)),
+                """\
+class <lambda>(torch.nn.Module):
+    def forward(self, arg0_1: "f32[1]", arg1_1: "f32[8]"):
+        clone: "f32[1]" = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
+
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        auto_functionalized_subgraph_1 = self.auto_functionalized_subgraph_1
+        _tree_spec_constant0 = self._tree_spec_constant0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.while_loop, cond_fn = auto_functionalized_subgraph_0, body_fn = auto_functionalized_subgraph_1, _carried_input0_base_index = 0, _additional_input0_base_index = 1, _all_bases = [clone, arg1_1], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = auto_functionalized_subgraph_1 = clone = _tree_spec_constant0 = None
+        getitem: "f32[1]" = auto_functionalized_v2[0]
+        getitem_1: "f32[1]" = auto_functionalized_v2[1]
+        getitem_2: "f32[8]" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
+
+        add: "f32[8]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None
+        add_1: "f32[8]" = torch.ops.aten.add.Tensor(add, getitem); add = getitem = None
+
+        copy_: "f32[8]" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None
+        return (add_1,)
+
+    class auto_functionalized_subgraph_0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[1]", arg1_1: "f32[8]"):
+            sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
+            lt: "b8[]" = torch.ops.aten.lt.Scalar(sum_1, 0); sum_1 = None
+            return lt
+
+    class auto_functionalized_subgraph_1(torch.nn.Module):
+        def forward(self, arg0_1: "f32[1]", arg1_1: "f32[8]"):
+            add: "f32[1]" = torch.ops.aten.add.Tensor(arg0_1, -1)
+            add_1: "f32[8]" = torch.ops.aten.add.Tensor(arg1_1, -1)
+            sum_1: "f32[]" = torch.ops.aten.sum.default(add_1)
+            add_2: "f32[1]" = torch.ops.aten.add.Tensor(add, sum_1); sum_1 = None
+            copy_: "f32[1]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
+            copy__1: "f32[8]" = torch.ops.aten.copy_.default(arg1_1, add_1); arg1_1 = add_1 = copy__1 = None
+            return (add_2,)
+""",  # noqa: B950
+            )
+
+    @requires_cuda
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    @torch._dynamo.config.patch(prefer_deferred_runtime_asserts_over_guards=True)
+    @parametrize("device", ["cuda", "cpu"])
+    @parametrize("dynamic", [True, False])
+    def test_while_loop_auto_functionalize_inplace_mutate_out_buffer_as_carry(
+        self, device, dynamic
+    ):
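+        # The out_buf carry is mutated at a data-dependent index (it.item()),
+        # which is why capture_scalar_outputs is enabled and runtime size
+        # assertions show up in the expected graph.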
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer(
+                    "buf", torch.ones(1, requires_grad=False, device=device)
+                )
+
+            def forward(self, p, x):
+                def cond_fn(it, x, out_buf):
+                    return it < x.size(0)
+
+                def body_fn(it, x, out_buf):
+                    out = x.sin()
+                    idx = it.item()
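+                    # Constrain the unbacked index to [0, x.size(0) - 1] so the
+                    # data-dependent indexing below can be traced.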
+                    torch._check_is_size(idx, max=x.size(0) - 1)
+                    out_buf[idx].add_(out[idx])
+                    return (it + 1, x + 1, out_buf.clone())
+
+                it = torch.tensor(0, dtype=torch.int64)
+                out_buf = x.clone()
+                x = x.clone()
+                out = while_loop(cond_fn, body_fn, (it, x, out_buf))
+                return x + self.buf + out[0]
+
+        p, x = torch.tensor(True), torch.randn(3, 4)
+        fw_gm = self.check(M, (p, x), device, dynamic)
+        if not TEST_WITH_CROSSREF and not dynamic and device == "cuda":
+            self.assertExpectedInline(
+                normalize_gm(fw_gm.print_readable(print_output=False)),
+                """\
+class <lambda>(torch.nn.Module):
+    def forward(self, arg0_1: "f32[3, 4]", arg1_1: "f32[1]"):
+        _tensor_constant0: "i64[]" = self._tensor_constant0
+        lift_fresh_copy: "i64[]" = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
+
+        clone: "f32[3, 4]" = torch.ops.aten.clone.default(arg0_1)
+
+        clone_1: "f32[3, 4]" = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
+
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        auto_functionalized_subgraph_1 = self.auto_functionalized_subgraph_1
+        _tree_spec_constant0 = self._tree_spec_constant0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.while_loop, cond_fn = auto_functionalized_subgraph_0, body_fn = auto_functionalized_subgraph_1, carried_input0 = lift_fresh_copy, carried_input1 = clone_1, _carried_input2_base_index = 0, _all_bases = [clone], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = auto_functionalized_subgraph_1 = lift_fresh_copy = clone = _tree_spec_constant0 = None
+        getitem: "i64[]" = auto_functionalized_v2[0]; auto_functionalized_v2 = None
+
+        add: "f32[3, 4]" = torch.ops.aten.add.Tensor(clone_1, arg1_1); clone_1 = arg1_1 = None
+        add_1: "f32[3, 4]" = torch.ops.aten.add.Tensor(add, getitem); add = getitem = None
+        return (add_1,)
+
+    class auto_functionalized_subgraph_0(torch.nn.Module):
+        def forward(self, arg0_1: "i64[]", arg1_1: "f32[3, 4]", arg2_1: "f32[3, 4]"):
+            lt: "b8[]" = torch.ops.aten.lt.Scalar(arg0_1, 3); arg0_1 = None
+            return lt
+
+    class auto_functionalized_subgraph_1(torch.nn.Module):
+        def forward(self, arg0_1: "i64[]", arg1_1: "f32[3, 4]", arg2_1: "f32[3, 4]"):
+            sin: "f32[3, 4]" = torch.ops.aten.sin.default(arg1_1)
+            _local_scalar_dense: "Sym(u3)" = torch.ops.aten._local_scalar_dense.default(arg0_1)
+            ge_1: "Sym(u3 >= 0)" = _local_scalar_dense >= 0
+            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
+            le_1: "Sym(u3 <= 2)" = _local_scalar_dense <= 2
+            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u3 <= 2 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None
+            select: "f32[4]" = torch.ops.aten.select.int(arg2_1, 0, _local_scalar_dense)
+            select_1: "f32[4]" = torch.ops.aten.select.int(sin, 0, _local_scalar_dense); sin = None
+            add: "f32[4]" = torch.ops.aten.add.Tensor(select, select_1); select = select_1 = None
+            select_scatter: "f32[3, 4]" = torch.ops.aten.select_scatter.default(arg2_1, add, 0, _local_scalar_dense); add = _local_scalar_dense = None
+            add_1: "i64[]" = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
+            add_2: "f32[3, 4]" = torch.ops.aten.add.Tensor(arg1_1, 1); arg1_1 = None
+            clone: "f32[3, 4]" = torch.ops.aten.clone.default(select_scatter)
+            copy_: "f32[3, 4]" = torch.ops.aten.copy_.default(arg2_1, select_scatter); arg2_1 = select_scatter = copy_ = None
+            return (add_1, add_2, clone)
+""",  # noqa: B950
+            )
+

_hop_schema_test_schema_types = [
    "bool",