diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py
index 16dbee61ca0e..1326fb227c06 100644
--- a/test/functorch/test_control_flow.py
+++ b/test/functorch/test_control_flow.py
@@ -8587,11 +8587,21 @@ def _new_fn():
             mod_or_fn.to(device)
             return mod_or_fn
 
-        with patch.object(
-            torch._dynamo.variables.higher_order_ops.CondHigherOrderVariable,
-            "supports_input_mutation",
-            True,
-        ):
+        with contextlib.ExitStack() as ctx_stack:
+            ctx_stack.enter_context(
+                patch.object(
+                    torch._dynamo.variables.higher_order_ops.CondHigherOrderVariable,
+                    "supports_input_mutation",
+                    True,
+                ),
+            )
+            ctx_stack.enter_context(
+                patch.object(
+                    torch._dynamo.variables.higher_order_ops.WhileLoopHigherOrderVariable,
+                    "supports_input_mutation",
+                    True,
+                ),
+            )
             # Only support input mutation in inference
             cloned_args = [_clone(args) for _ in range(3)]
             with torch.no_grad():
@@ -8810,6 +8820,211 @@ def forward(self, arg0_1: "f32[4, 3]", arg1_1: "f32[3, 4]"):
 """,  # noqa: B950
         )
 
+    @requires_cuda
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @parametrize("device", ["cuda", "cpu"])
+    @parametrize("dynamic", [True, False])
+    def test_while_loop_auto_functionalize_input_mutation(self, device, dynamic):
+        class M(torch.nn.Module):
+            def forward(self, x, y):
+                def cond_fn(x):
+                    return x.sum() > 0
+
+                def body_fn(x):
+                    x.add_(-1)
+                    return (x.clone(),)
+
+                x = x.clone()
+                ret = while_loop(cond_fn, body_fn, (x,))
+                return y + ret[0]
+
+        x, y = (
+            torch.randn(3, 4),
+            torch.randn(3, 4),
+        )
+        fw_gm = self.check(M, (x, y), device, dynamic)
+        if not TEST_WITH_CROSSREF and not dynamic and device == "cuda":
+            self.assertExpectedInline(
+                normalize_gm(fw_gm.print_readable(print_output=False)),
+                """\
+class <lambda>(torch.nn.Module):
+    def forward(self, arg0_1: "f32[3, 4]", arg1_1: "f32[3, 4]"):
+        clone: "f32[3, 4]" = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
+
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        auto_functionalized_subgraph_1 = self.auto_functionalized_subgraph_1
+        _tree_spec_constant0 = self._tree_spec_constant0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.while_loop, cond_fn = auto_functionalized_subgraph_0, body_fn = auto_functionalized_subgraph_1, _carried_input0_base_index = 0, _all_bases = [clone], _op_schema = _tree_spec_constant0);  auto_functionalized_subgraph_0 = auto_functionalized_subgraph_1 = clone = _tree_spec_constant0 = None
+        getitem: "f32[3, 4]" = auto_functionalized_v2[0];  auto_functionalized_v2 = None
+
+        add: "f32[3, 4]" = torch.ops.aten.add.Tensor(arg1_1, getitem);  arg1_1 = getitem = None
+        return (add,)
+
+    class auto_functionalized_subgraph_0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[3, 4]"):
+            sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1);  arg0_1 = None
+            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None
+            return gt
+
+    class auto_functionalized_subgraph_1(torch.nn.Module):
+        def forward(self, arg0_1: "f32[3, 4]"):
+            add: "f32[3, 4]" = torch.ops.aten.add.Tensor(arg0_1, -1)
+            clone: "f32[3, 4]" = torch.ops.aten.clone.default(add)
+            copy_: "f32[3, 4]" = torch.ops.aten.copy_.default(arg0_1, add);  arg0_1 = add = copy_ = None
+            return (clone,)
+""",  # noqa: B950
+            )
+
+    @requires_cuda
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @parametrize("device", ["cuda", "cpu"])
+    @parametrize("dynamic", [True, False])
+    def test_while_loop_auto_functionalize_buffer_mutation(self, device, dynamic):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer(
+                    "buf", torch.ones(8, requires_grad=False, device=device)
+                )
+
+            def forward(self, p, x):
+                def cond_fn(x):
+                    return x.sum() < 0
+
+                def body_fn(x):
+                    x.add_(-1)
+                    self.buf.add_(-1)
+                    return (x + self.buf.sum(),)
+
+                x = x.clone()
+                out = while_loop(cond_fn, body_fn, (x,))
+                return x + self.buf + out[0]
+
+        p, x = torch.tensor(True), torch.randn(1, requires_grad=True)
+        fw_gm = self.check(M, (p, x), device, dynamic)
+        if not TEST_WITH_CROSSREF and not dynamic and device == "cuda":
+            self.assertExpectedInline(
+                normalize_gm(fw_gm.print_readable(print_output=False)),
+                """\
+class <lambda>(torch.nn.Module):
+    def forward(self, arg0_1: "f32[1]", arg1_1: "f32[8]"):
+        clone: "f32[1]" = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
+
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        auto_functionalized_subgraph_1 = self.auto_functionalized_subgraph_1
+        _tree_spec_constant0 = self._tree_spec_constant0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.while_loop, cond_fn = auto_functionalized_subgraph_0, body_fn = auto_functionalized_subgraph_1, _carried_input0_base_index = 0, _additional_input0_base_index = 1, _all_bases = [clone, arg1_1], _op_schema = _tree_spec_constant0);  auto_functionalized_subgraph_0 = auto_functionalized_subgraph_1 = clone = _tree_spec_constant0 = None
+        getitem: "f32[1]" = auto_functionalized_v2[0]
+        getitem_1: "f32[1]" = auto_functionalized_v2[1]
+        getitem_2: "f32[8]" = auto_functionalized_v2[2];  auto_functionalized_v2 = None
+
+        add: "f32[8]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2);  getitem_1 = None
+        add_1: "f32[8]" = torch.ops.aten.add.Tensor(add, getitem);  add = getitem = None
+
+        copy_: "f32[8]" = torch.ops.aten.copy_.default(arg1_1, getitem_2);  arg1_1 = getitem_2 = copy_ = None
+        return (add_1,)
+
+    class auto_functionalized_subgraph_0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[1]", arg1_1: "f32[8]"):
+            sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1);  arg0_1 = None
+            lt: "b8[]" = torch.ops.aten.lt.Scalar(sum_1, 0);  sum_1 = None
+            return lt
+
+    class auto_functionalized_subgraph_1(torch.nn.Module):
+        def forward(self, arg0_1: "f32[1]", arg1_1: "f32[8]"):
+            add: "f32[1]" = torch.ops.aten.add.Tensor(arg0_1, -1)
+            add_1: "f32[8]" = torch.ops.aten.add.Tensor(arg1_1, -1)
+            sum_1: "f32[]" = torch.ops.aten.sum.default(add_1)
+            add_2: "f32[1]" = torch.ops.aten.add.Tensor(add, sum_1);  sum_1 = None
+            copy_: "f32[1]" = torch.ops.aten.copy_.default(arg0_1, add);  arg0_1 = add = copy_ = None
+            copy__1: "f32[8]" = torch.ops.aten.copy_.default(arg1_1, add_1);  arg1_1 = add_1 = copy__1 = None
+            return (add_2,)
+""",  # noqa: B950
+            )
+
+    @requires_cuda
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    @torch._dynamo.config.patch(prefer_deferred_runtime_asserts_over_guards=True)
+    @parametrize("device", ["cuda", "cpu"])
+    @parametrize("dynamic", [True, False])
+    def test_while_loop_auto_functionalize_inplace_mutate_out_buffer_as_carry(
+        self, device, dynamic
+    ):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer(
+                    "buf", torch.ones(1, requires_grad=False, device=device)
+                )
+
+            def forward(self, p, x):
+                def cond_fn(it, x, out_buf):
+                    return it < x.size(0)
+
+                def body_fn(it, x, out_buf):
+                    out = x.sin()
+                    idx = it.item()
+                    torch._check_is_size(idx, max=x.size(0) - 1)
+                    out_buf[idx].add_(out[idx])
+                    return (it + 1, x + 1, out_buf.clone())
+
+                it = torch.tensor(0, dtype=torch.int64)
+                out_buf = x.clone()
+                x = x.clone()
+                out = while_loop(cond_fn, body_fn, (it, x, out_buf))
+                return x + self.buf + out[0]
+
+        p, x = torch.tensor(True), torch.randn(3, 4)
+        fw_gm = self.check(M, (p, x), device, dynamic)
+        if not TEST_WITH_CROSSREF and not dynamic and device == "cuda":
+            self.assertExpectedInline(
+                normalize_gm(fw_gm.print_readable(print_output=False)),
+                """\
+class <lambda>(torch.nn.Module):
+    def forward(self, arg0_1: "f32[3, 4]", arg1_1: "f32[1]"):
+        _tensor_constant0: "i64[]" = self._tensor_constant0
+        lift_fresh_copy: "i64[]" = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
+
+        clone: "f32[3, 4]" = torch.ops.aten.clone.default(arg0_1)
+
+        clone_1: "f32[3, 4]" = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
+
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        auto_functionalized_subgraph_1 = self.auto_functionalized_subgraph_1
+        _tree_spec_constant0 = self._tree_spec_constant0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.while_loop, cond_fn = auto_functionalized_subgraph_0, body_fn = auto_functionalized_subgraph_1, carried_input0 = lift_fresh_copy, carried_input1 = clone_1, _carried_input2_base_index = 0, _all_bases = [clone], _op_schema = _tree_spec_constant0);  auto_functionalized_subgraph_0 = auto_functionalized_subgraph_1 = lift_fresh_copy = clone = _tree_spec_constant0 = None
+        getitem: "i64[]" = auto_functionalized_v2[0];  auto_functionalized_v2 = None
+
+        add: "f32[3, 4]" = torch.ops.aten.add.Tensor(clone_1, arg1_1);  clone_1 = arg1_1 = None
+        add_1: "f32[3, 4]" = torch.ops.aten.add.Tensor(add, getitem);  add = getitem = None
+        return (add_1,)
+
+    class auto_functionalized_subgraph_0(torch.nn.Module):
+        def forward(self, arg0_1: "i64[]", arg1_1: "f32[3, 4]", arg2_1: "f32[3, 4]"):
+            lt: "b8[]" = torch.ops.aten.lt.Scalar(arg0_1, 3);  arg0_1 = None
+            return lt
+
+    class auto_functionalized_subgraph_1(torch.nn.Module):
+        def forward(self, arg0_1: "i64[]", arg1_1: "f32[3, 4]", arg2_1: "f32[3, 4]"):
+            sin: "f32[3, 4]" = torch.ops.aten.sin.default(arg1_1)
+            _local_scalar_dense: "Sym(u3)" = torch.ops.aten._local_scalar_dense.default(arg0_1)
+            ge_1: "Sym(u3 >= 0)" = _local_scalar_dense >= 0
+            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
+            le_1: "Sym(u3 <= 2)" = _local_scalar_dense <= 2
+            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u3 <= 2 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None
+            select: "f32[4]" = torch.ops.aten.select.int(arg2_1, 0, _local_scalar_dense)
+            select_1: "f32[4]" = torch.ops.aten.select.int(sin, 0, _local_scalar_dense);  sin = None
+            add: "f32[4]" = torch.ops.aten.add.Tensor(select, select_1);  select = select_1 = None
+            select_scatter: "f32[3, 4]" = torch.ops.aten.select_scatter.default(arg2_1, add, 0, _local_scalar_dense);  add = _local_scalar_dense = None
+            add_1: "i64[]" = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
+            add_2: "f32[3, 4]" = torch.ops.aten.add.Tensor(arg1_1, 1);  arg1_1 = None
+            clone: "f32[3, 4]" = torch.ops.aten.clone.default(select_scatter)
+            copy_: "f32[3, 4]" = torch.ops.aten.copy_.default(arg2_1, select_scatter);  arg2_1 = select_scatter = copy_ = None
+            return (add_1, add_2, clone)
+""",  # noqa: B950
+            )
+
 
 _hop_schema_test_schema_types = [
     "bool",
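Aside: the third test above leans on unbacked-SymInt support to index with a value read out of a carried tensor. Below is a minimal standalone sketch of that pattern, outside any while_loop; the `max=` keyword of `torch._check_is_size` is assumed to behave as in the test above, and `gather_row` is an illustrative name only.

    import torch

    torch._dynamo.config.capture_scalar_outputs = True

    @torch.compile(fullgraph=True)
    def gather_row(it, x):
        idx = it.item()  # data-dependent value becomes an unbacked SymInt
        # Bound the unbacked value so the indexing below is provably in range.
        torch._check_is_size(idx, max=x.size(0) - 1)
        return x[idx]

    gather_row(torch.tensor(1), torch.randn(3, 4))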
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index d3334424c5f4..58633d04dd18 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -1369,8 +1369,8 @@ def unspecialize_carried_inputs(tx, carry) -> VariableTracker:
             source_target=self.value,
             set_subgraph_inputs="flatten_manual",
             should_flatten_outputs=True,
-            supports_input_mutation=False,
-            supports_aliasing=False,
+            supports_input_mutation=self.supports_input_mutation,
+            supports_aliasing=self.supports_aliasing,
         )
 
         validate_subgraph_output_types(body_r)
diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py
index d5aa0d09c8b1..28138d8b29f3 100644
--- a/torch/_higher_order_ops/auto_functionalize.py
+++ b/torch/_higher_order_ops/auto_functionalize.py
@@ -588,7 +588,7 @@ def __call__(self, *args, **kwargs):
         # Inlining has the benefit of allowing easier fusion inside subgraph.
         # Though the epilogue graph contains copy_, it is OK because inductor can handle it
         # and this is also how we have been supporting top-level graph input mutation.
-        return tuple(torch.func.functionalize(self.orig_callable)(*args, **kwargs))
+        return torch.func.functionalize(self.orig_callable)(*args, **kwargs)
 
     def __hash__(self):
         return id(self.orig_callable)
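Note on the auto_functionalize.py change above: the stored callable is run through `torch.func.functionalize` and its result is now returned as-is instead of being coerced into a tuple, preserving the subgraph's original output structure. For readers unfamiliar with it, a rough sketch of what functionalization does to a mutating callable (names here are illustrative, not from the PR):

    import torch

    def body(x):
        x.add_(-1)  # in-place mutation
        return (x.clone(),)

    x = torch.ones(3)
    # The functionalized version computes the same result, but the ops it
    # traces are out-of-place; mutations are replayed onto the inputs at the end.
    out = torch.func.functionalize(body)(x)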
diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py
index 044828ef1a84..9c11ed93625a 100644
--- a/torch/_higher_order_ops/while_loop.py
+++ b/torch/_higher_order_ops/while_loop.py
@@ -112,10 +112,11 @@ def _find_example_value(n, real_inp):
 
     for idx, arg in enumerate(additional_inputs):
         additional_idx = len(carried_inputs) + idx
-        assert additional_idx not in mutated_inputs, (
-            "Lifted additional_inputs cannot be in-place mutated."
+        schema_gen.add_arg(
+            f"additional_input{idx}",
+            arg,
+            is_mutated=additional_idx in mutated_inputs,
         )
-        schema_gen.add_arg(f"additional_input{idx}", arg, is_mutated=False)
 
     for out in body_outputs:
         schema_gen.add_output(out)
@@ -498,7 +499,28 @@ def while_loop_fake_tensor_mode(
 
 @while_loop_op.py_functionalize_impl
 def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
-    from torch._higher_order_ops.utils import _check_alias_and_mutation
+    from torch._higher_order_ops.auto_functionalize import (
+        can_auto_functionalize,
+        do_auto_functionalize_v2,
+    )
+    from torch._higher_order_ops.utils import _check_alias_and_mutation, HopInstance
+
+    hop_instance = HopInstance.create(
+        while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
+    )
+    # For now, we only support auto-functionalization for while_loop when using python
+    # functionalization mode
+    if can_auto_functionalize(hop_instance) and hasattr(ctx, "mode"):
+        return do_auto_functionalize_v2(
+            ctx.mode,
+            hop_instance,
+            tuple(
+                pytree.tree_flatten(
+                    (cond_fn, body_fn, carried_inputs, additional_inputs)
+                )[0]
+            ),
+            {},
+        )
 
     unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs)
     unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 74a562365b69..85495043ba43 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -7024,7 +7024,7 @@ def _map_output(out: Any):
             return out
         elif isinstance(out, ir.StorageBox):
             return TensorBox(out)
-        elif isinstance(out, ir.MultiOutput):
+        elif isinstance(out, (ir.MultiOutput, ir.ReinterpretView)):
            return TensorBox.create(out)
         else:
             raise RuntimeError(f"NYI unsupported output type: {type(out)}")
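Taken together, a minimal sketch of what these changes enable, mirroring the tests above. Note that `supports_input_mutation` is patched on rather than enabled by default in this PR, and that input mutation is only supported in inference, hence the `no_grad` block; this is a sketch under those assumptions, not a supported public API.

    from unittest.mock import patch

    import torch
    from torch._dynamo.variables.higher_order_ops import WhileLoopHigherOrderVariable
    from torch._higher_order_ops.while_loop import while_loop

    def cond_fn(x):
        return x.sum() > 0

    def body_fn(x):
        x.add_(-1)  # in-place mutation on the carried input
        return (x.clone(),)

    def f(x):
        return while_loop(cond_fn, body_fn, (x.clone(),))

    with patch.object(
        WhileLoopHigherOrderVariable, "supports_input_mutation", True
    ), torch.no_grad():
        out = torch.compile(f, fullgraph=True)(torch.ones(3, 4))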