
Commit f363733

[while_loop] support input mutation with auto_functionalize

ghstack-source-id: 8640806
Pull Request resolved: #159010

1 parent fba4583 · commit f363733
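What this enables, roughly: functionalization can now rewrite a while_loop whose body mutates its carried inputs (and, with the schema change in torch/_higher_order_ops/while_loop.py below, lifted additional inputs such as buffers) into a torch.ops.higher_order.auto_functionalized_v2 call instead of rejecting the mutation. A minimal sketch adapted from the new tests, not a verbatim excerpt; it assumes the dynamo-side supports_input_mutation flag on WhileLoopHigherOrderVariable is enabled (the tests do this via patch.object) and that the mutation happens under no_grad, since input mutation is only supported in inference:

    import torch
    from torch._higher_order_ops.while_loop import while_loop

    class M(torch.nn.Module):
        def forward(self, x, y):
            def cond_fn(x):
                return x.sum() > 0

            def body_fn(x):
                x.add_(-1)  # in-place mutation of the carried input
                return (x.clone(),)

            x = x.clone()
            ret = while_loop(cond_fn, body_fn, (x,))
            return y + ret[0]

    x, y = torch.randn(3, 4), torch.randn(3, 4)
    with torch.no_grad():  # input mutation is only supported in inference
        out = torch.compile(M(), fullgraph=True)(x, y)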

5 files changed, +250 -13 lines


test/functorch/test_control_flow.py

Lines changed: 220 additions & 5 deletions
@@ -8586,11 +8586,21 @@ def _new_fn():
                 mod_or_fn.to(device)
             return mod_or_fn
 
-        with patch.object(
-            torch._dynamo.variables.higher_order_ops.CondHigherOrderVariable,
-            "supports_input_mutation",
-            True,
-        ):
+        with contextlib.ExitStack() as ctx_stack:
+            ctx_stack.enter_context(
+                patch.object(
+                    torch._dynamo.variables.higher_order_ops.CondHigherOrderVariable,
+                    "supports_input_mutation",
+                    True,
+                ),
+            )
+            ctx_stack.enter_context(
+                patch.object(
+                    torch._dynamo.variables.higher_order_ops.WhileLoopHigherOrderVariable,
+                    "supports_input_mutation",
+                    True,
+                ),
+            )
             # Only suuport input mutation in inference
             cloned_args = [_clone(args) for _ in range(3)]
             with torch.no_grad():
@@ -8809,6 +8819,211 @@ def forward(self, arg0_1: "f32[4, 3]", arg1_1: "f32[3, 4]"):
 """, # noqa: B950
             )
 
+    @requires_cuda
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @parametrize("device", ["cuda", "cpu"])
+    @parametrize("dynamic", [True, False])
+    def test_while_loop_auto_functionalize_input_mutation(self, device, dynamic):
+        class M(torch.nn.Module):
+            def forward(self, x, y):
+                def cond_fn(x):
+                    return x.sum() > 0
+
+                def body_fn(x):
+                    x.add_(-1)
+                    return (x.clone(),)
+
+                x = x.clone()
+                ret = while_loop(cond_fn, body_fn, (x,))
+                return y + ret[0]
+
+        x, y = (
+            torch.randn(3, 4),
+            torch.randn(3, 4),
+        )
+        fw_gm = self.check(M, (x, y), device, dynamic)
+        if not TEST_WITH_CROSSREF and not dynamic and device == "cuda":
+            self.assertExpectedInline(
+                normalize_gm(fw_gm.print_readable(print_output=False)),
+                """\
+class <lambda>(torch.nn.Module):
+    def forward(self, arg0_1: "f32[3, 4]", arg1_1: "f32[3, 4]"):
+        clone: "f32[3, 4]" = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
+
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        auto_functionalized_subgraph_1 = self.auto_functionalized_subgraph_1
+        _tree_spec_constant0 = self._tree_spec_constant0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.while_loop, cond_fn = auto_functionalized_subgraph_0, body_fn = auto_functionalized_subgraph_1, _carried_input0_base_index = 0, _all_bases = [clone], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = auto_functionalized_subgraph_1 = clone = _tree_spec_constant0 = None
+        getitem: "f32[3, 4]" = auto_functionalized_v2[0]; auto_functionalized_v2 = None
+
+        add: "f32[3, 4]" = torch.ops.aten.add.Tensor(arg1_1, getitem); arg1_1 = getitem = None
+        return (add,)
+
+    class auto_functionalized_subgraph_0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[3, 4]"):
+            sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
+            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
+            return gt
+
+    class auto_functionalized_subgraph_1(torch.nn.Module):
+        def forward(self, arg0_1: "f32[3, 4]"):
+            add: "f32[3, 4]" = torch.ops.aten.add.Tensor(arg0_1, -1)
+            clone: "f32[3, 4]" = torch.ops.aten.clone.default(add)
+            copy_: "f32[3, 4]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
+            return (clone,)
+""", # noqa: B950
+            )
+
+    @requires_cuda
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @parametrize("device", ["cuda", "cpu"])
+    @parametrize("dynamic", [True, False])
+    def test_while_loop_auto_functionalize_buffer_mutation(self, device, dynamic):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer(
+                    "buf", torch.ones(8, requires_grad=False, device=device)
+                )
+
+            def forward(self, p, x):
+                def cond_fn(x):
+                    return x.sum() < 0
+
+                def body_fn(x):
+                    x.add_(-1)
+                    self.buf.add_(-1)
+                    return (x + self.buf.sum(),)
+
+                x = x.clone()
+                out = while_loop(cond_fn, body_fn, (x,))
+                return x + self.buf + out[0]
+
+        p, x = torch.tensor(True), torch.randn(1, requires_grad=True)
+        fw_gm = self.check(M, (p, x), device, dynamic)
+        if not TEST_WITH_CROSSREF and not dynamic and device == "cuda":
+            self.assertExpectedInline(
+                normalize_gm(fw_gm.print_readable(print_output=False)),
+                """\
+class <lambda>(torch.nn.Module):
+    def forward(self, arg0_1: "f32[1]", arg1_1: "f32[8]"):
+        clone: "f32[1]" = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
+
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        auto_functionalized_subgraph_1 = self.auto_functionalized_subgraph_1
+        _tree_spec_constant0 = self._tree_spec_constant0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.while_loop, cond_fn = auto_functionalized_subgraph_0, body_fn = auto_functionalized_subgraph_1, _carried_input0_base_index = 0, _additional_input0_base_index = 1, _all_bases = [clone, arg1_1], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = auto_functionalized_subgraph_1 = clone = _tree_spec_constant0 = None
+        getitem: "f32[1]" = auto_functionalized_v2[0]
+        getitem_1: "f32[1]" = auto_functionalized_v2[1]
+        getitem_2: "f32[8]" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
+
+        add: "f32[8]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None
+        add_1: "f32[8]" = torch.ops.aten.add.Tensor(add, getitem); add = getitem = None
+
+        copy_: "f32[8]" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None
+        return (add_1,)
+
+    class auto_functionalized_subgraph_0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[1]", arg1_1: "f32[8]"):
+            sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
+            lt: "b8[]" = torch.ops.aten.lt.Scalar(sum_1, 0); sum_1 = None
+            return lt
+
+    class auto_functionalized_subgraph_1(torch.nn.Module):
+        def forward(self, arg0_1: "f32[1]", arg1_1: "f32[8]"):
+            add: "f32[1]" = torch.ops.aten.add.Tensor(arg0_1, -1)
+            add_1: "f32[8]" = torch.ops.aten.add.Tensor(arg1_1, -1)
+            sum_1: "f32[]" = torch.ops.aten.sum.default(add_1)
+            add_2: "f32[1]" = torch.ops.aten.add.Tensor(add, sum_1); sum_1 = None
+            copy_: "f32[1]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
+            copy__1: "f32[8]" = torch.ops.aten.copy_.default(arg1_1, add_1); arg1_1 = add_1 = copy__1 = None
+            return (add_2,)
+""", # noqa: B950
+            )
+
+    @requires_cuda
+    @unittest.skipIf(not SM70OrLater, "triton")
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    @torch._dynamo.config.patch(prefer_deferred_runtime_asserts_over_guards=True)
+    @parametrize("device", ["cuda", "cpu"])
+    @parametrize("dynamic", [True, False])
+    def test_while_loop_auto_functionalize_inplace_mutate_out_buffer_as_carry(
+        self, device, dynamic
+    ):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer(
+                    "buf", torch.ones(1, requires_grad=False, device=device)
+                )
+
+            def forward(self, p, x):
+                def cond_fn(it, x, out_buf):
+                    return it < x.size(0)
+
+                def body_fn(it, x, out_buf):
+                    out = x.sin()
+                    idx = it.item()
+                    torch._check_is_size(idx, max=x.size(0) - 1)
+                    out_buf[idx].add_(out[idx])
+                    return (it + 1, x + 1, out_buf.clone())
+
+                it = torch.tensor(0, dtype=torch.int64)
+                out_buf = x.clone()
+                x = x.clone()
+                out = while_loop(cond_fn, body_fn, (it, x, out_buf))
+                return x + self.buf + out[0]
+
+        p, x = torch.tensor(True), torch.randn(3, 4)
+        fw_gm = self.check(M, (p, x), device, dynamic)
+        if not TEST_WITH_CROSSREF and not dynamic and device == "cuda":
+            self.assertExpectedInline(
+                normalize_gm(fw_gm.print_readable(print_output=False)),
+                """\
+class <lambda>(torch.nn.Module):
+    def forward(self, arg0_1: "f32[3, 4]", arg1_1: "f32[1]"):
+        _tensor_constant0: "i64[]" = self._tensor_constant0
+        lift_fresh_copy: "i64[]" = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
+
+        clone: "f32[3, 4]" = torch.ops.aten.clone.default(arg0_1)
+
+        clone_1: "f32[3, 4]" = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
+
+        auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
+        auto_functionalized_subgraph_1 = self.auto_functionalized_subgraph_1
+        _tree_spec_constant0 = self._tree_spec_constant0
+        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.while_loop, cond_fn = auto_functionalized_subgraph_0, body_fn = auto_functionalized_subgraph_1, carried_input0 = lift_fresh_copy, carried_input1 = clone_1, _carried_input2_base_index = 0, _all_bases = [clone], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = auto_functionalized_subgraph_1 = lift_fresh_copy = clone = _tree_spec_constant0 = None
+        getitem: "i64[]" = auto_functionalized_v2[0]; auto_functionalized_v2 = None
+
+        add: "f32[3, 4]" = torch.ops.aten.add.Tensor(clone_1, arg1_1); clone_1 = arg1_1 = None
+        add_1: "f32[3, 4]" = torch.ops.aten.add.Tensor(add, getitem); add = getitem = None
+        return (add_1,)
+
+    class auto_functionalized_subgraph_0(torch.nn.Module):
+        def forward(self, arg0_1: "i64[]", arg1_1: "f32[3, 4]", arg2_1: "f32[3, 4]"):
+            lt: "b8[]" = torch.ops.aten.lt.Scalar(arg0_1, 3); arg0_1 = None
+            return lt
+
+    class auto_functionalized_subgraph_1(torch.nn.Module):
+        def forward(self, arg0_1: "i64[]", arg1_1: "f32[3, 4]", arg2_1: "f32[3, 4]"):
+            sin: "f32[3, 4]" = torch.ops.aten.sin.default(arg1_1)
+            _local_scalar_dense: "Sym(u3)" = torch.ops.aten._local_scalar_dense.default(arg0_1)
+            ge_1: "Sym(u3 >= 0)" = _local_scalar_dense >= 0
+            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
+            le_1: "Sym(u3 <= 2)" = _local_scalar_dense <= 2
+            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u3 <= 2 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None
+            select: "f32[4]" = torch.ops.aten.select.int(arg2_1, 0, _local_scalar_dense)
+            select_1: "f32[4]" = torch.ops.aten.select.int(sin, 0, _local_scalar_dense); sin = None
+            add: "f32[4]" = torch.ops.aten.add.Tensor(select, select_1); select = select_1 = None
+            select_scatter: "f32[3, 4]" = torch.ops.aten.select_scatter.default(arg2_1, add, 0, _local_scalar_dense); add = _local_scalar_dense = None
+            add_1: "i64[]" = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
+            add_2: "f32[3, 4]" = torch.ops.aten.add.Tensor(arg1_1, 1); arg1_1 = None
+            clone: "f32[3, 4]" = torch.ops.aten.clone.default(select_scatter)
+            copy_: "f32[3, 4]" = torch.ops.aten.copy_.default(arg2_1, select_scatter); arg2_1 = select_scatter = copy_ = None
+            return (add_1, add_2, clone)
+""", # noqa: B950
+            )
+
 
 _hop_schema_test_schema_types = [
     "bool",

torch/_dynamo/variables/higher_order_ops.py

Lines changed: 2 additions & 2 deletions
@@ -1344,8 +1344,8 @@ def unspecialize_carried_inputs(tx, carry) -> VariableTracker:
             source_target=self.value,
             set_subgraph_inputs="flatten_manual",
             should_flatten_outputs=True,
-            supports_input_mutation=False,
-            supports_aliasing=False,
+            supports_input_mutation=self.supports_input_mutation,
+            supports_aliasing=self.supports_aliasing,
         )
         validate_subgraph_output_types(body_r)
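Because these flags are now read from the variable class instead of being hard-coded to False, callers can opt in per higher-order op; this is how the updated test helper above enables mutation support. A small sketch under that assumption, not part of this diff:

    from unittest.mock import patch

    import torch._dynamo.variables.higher_order_ops as hop_vars

    # Flip the class-level flag so dynamo accepts while_loop bodies that
    # mutate their inputs; the test helper does the same via an ExitStack.
    with patch.object(
        hop_vars.WhileLoopHigherOrderVariable, "supports_input_mutation", True
    ):
        ...  # torch.compile models whose while_loop bodies mutate carried inputs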

torch/_higher_order_ops/auto_functionalize.py

Lines changed: 1 addition & 1 deletion
@@ -588,7 +588,7 @@ def __call__(self, *args, **kwargs):
         # Inlining has the benefit of allowing easiser fusion inside subgraph.
         # Though the epilogue graph contains copy_, it is OK because inductor can handle it
         # and this is also how we have been supporting top-level graph input mutation.
-        return tuple(torch.func.functionalize(self.orig_callable)(*args, **kwargs))
+        return torch.func.functionalize(self.orig_callable)(*args, **kwargs)
 
     def __hash__(self):
         return id(self.orig_callable)
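The one-line change presumably lets the wrapper return whatever structure the original callable produces instead of coercing it to a tuple. For context on the comment above: torch.func.functionalize rewrites in-place ops into out-of-place ones and copies the final values back into the mutated inputs, which is the copy_ epilogue the expected graphs in the new tests show. A standalone sketch of that behavior, not part of this diff:

    import torch

    def body(x):
        x.add_(-1)  # in-place mutation of the input
        return (x.clone(),)

    x = torch.randn(3)
    # functionalize(body) computes the result with out-of-place ops and then
    # copies the final value back into x, mirroring the copy_ epilogue that
    # appears in the auto_functionalized subgraphs above.
    out = torch.func.functionalize(body)(x)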

torch/_higher_order_ops/while_loop.py

Lines changed: 26 additions & 4 deletions
@@ -112,10 +112,11 @@ def _find_example_value(n, real_inp):
 
     for idx, arg in enumerate(additional_inputs):
         additional_idx = len(carried_inputs) + idx
-        assert additional_idx not in mutated_inputs, (
-            "Lifted additional_inputs cannot be in-place mutated."
+        schema_gen.add_arg(
+            f"additional_input{idx}",
+            arg,
+            is_mutated=additional_idx in mutated_inputs,
         )
-        schema_gen.add_arg(f"additional_input{idx}", arg, is_mutated=False)
 
     for out in body_outputs:
         schema_gen.add_output(out)

@@ -498,7 +499,28 @@ def while_loop_fake_tensor_mode(
 
 @while_loop_op.py_functionalize_impl
 def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
-    from torch._higher_order_ops.utils import _check_alias_and_mutation
+    from torch._higher_order_ops.auto_functionalize import (
+        can_auto_functionalize,
+        do_auto_functionalize_v2,
+    )
+    from torch._higher_order_ops.utils import _check_alias_and_mutation, HopInstance
+
+    hop_instance = HopInstance.create(
+        while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
+    )
+    # For now, we only support auto-functionalization for while_loop when using python
+    # functionalization mode
+    if can_auto_functionalize(hop_instance) and hasattr(ctx, "mode"):
+        return do_auto_functionalize_v2(
+            ctx.mode,
+            hop_instance,
+            tuple(
+                pytree.tree_flatten(
+                    (cond_fn, body_fn, carried_inputs, additional_inputs)
+                )[0]
+            ),
+            {},
+        )
 
     unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs)
     unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
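The tuple(pytree.tree_flatten(...)[0]) call hands do_auto_functionalize_v2 a flat positional argument list: the two callables followed by every carried and additional tensor. A small standalone illustration of that flattening, not part of this diff:

    import torch
    import torch.utils._pytree as pytree

    cond_fn = lambda x: x.sum() > 0
    body_fn = lambda x: (x - 1,)
    carried_inputs = (torch.randn(3),)
    additional_inputs = (torch.ones(3),)

    # Callables are pytree leaves, so the flattened list comes out as
    # [cond_fn, body_fn, *carried_inputs, *additional_inputs].
    flat_args, _spec = pytree.tree_flatten(
        (cond_fn, body_fn, carried_inputs, additional_inputs)
    )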

torch/_inductor/lowering.py

Lines changed: 1 addition & 1 deletion
@@ -6964,7 +6964,7 @@ def _map_output(out: Any):
             return out
         elif isinstance(out, ir.StorageBox):
             return TensorBox(out)
-        elif isinstance(out, ir.MultiOutput):
+        elif isinstance(out, (ir.MultiOutput, ir.ReinterpretView)):
             return TensorBox.create(out)
         else:
             raise RuntimeError(f"NYI unsupported output type: {type(out)}")
