[while_loop] support input mutation with auto_functionalize #159010

Open
wants to merge 4 commits into base: gh/ydwu4/296/base
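
For orientation, a minimal sketch of the user-facing pattern this PR enables: a compiled while_loop whose body_fn mutates its carried input in place. This is an illustrative example, not code from the PR; it assumes supports_input_mutation is turned on for WhileLoopHigherOrderVariable (the new tests enable it via patch.object), and, per the test comment below, input mutation is only supported in inference.

# Hedged sketch mirroring test_while_loop_auto_functionalize_input_mutation below.
# Assumes WhileLoopHigherOrderVariable.supports_input_mutation is enabled, as the tests do.
import torch
from torch._higher_order_ops.while_loop import while_loop


def cond_fn(x):
    return x.sum() > 0


def body_fn(x):
    x.add_(-1)  # in-place mutation of the carried input
    return (x.clone(),)


@torch.compile(fullgraph=True)
def f(x, y):
    x = x.clone()
    (out,) = while_loop(cond_fn, body_fn, (x,))
    return y + out


with torch.no_grad():  # input mutation is only supported in inference
    print(f(torch.randn(3, 4), torch.randn(3, 4)))
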
225 changes: 220 additions & 5 deletions test/functorch/test_control_flow.py
@@ -8587,11 +8587,21 @@ def _new_fn():
mod_or_fn.to(device)
return mod_or_fn

with patch.object(
torch._dynamo.variables.higher_order_ops.CondHigherOrderVariable,
"supports_input_mutation",
True,
):
with contextlib.ExitStack() as ctx_stack:
ctx_stack.enter_context(
patch.object(
torch._dynamo.variables.higher_order_ops.CondHigherOrderVariable,
"supports_input_mutation",
True,
),
)
ctx_stack.enter_context(
patch.object(
torch._dynamo.variables.higher_order_ops.WhileLoopHigherOrderVariable,
"supports_input_mutation",
True,
),
)
# Only support input mutation in inference
cloned_args = [_clone(args) for _ in range(3)]
with torch.no_grad():
@@ -8810,6 +8820,211 @@ def forward(self, arg0_1: "f32[4, 3]", arg1_1: "f32[3, 4]"):
""", # noqa: B950
)

@requires_cuda
@unittest.skipIf(not SM70OrLater, "triton")
@parametrize("device", ["cuda", "cpu"])
@parametrize("dynamic", [True, False])
def test_while_loop_auto_functionalize_input_mutation(self, device, dynamic):
class M(torch.nn.Module):
def forward(self, x, y):
def cond_fn(x):
return x.sum() > 0

def body_fn(x):
x.add_(-1)
return (x.clone(),)

x = x.clone()
ret = while_loop(cond_fn, body_fn, (x,))
return y + ret[0]

x, y = (
torch.randn(3, 4),
torch.randn(3, 4),
)
fw_gm = self.check(M, (x, y), device, dynamic)
if not TEST_WITH_CROSSREF and not dynamic and device == "cuda":
self.assertExpectedInline(
normalize_gm(fw_gm.print_readable(print_output=False)),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 4]", arg1_1: "f32[3, 4]"):
clone: "f32[3, 4]" = torch.ops.aten.clone.default(arg0_1); arg0_1 = None

auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
auto_functionalized_subgraph_1 = self.auto_functionalized_subgraph_1
_tree_spec_constant0 = self._tree_spec_constant0
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.while_loop, cond_fn = auto_functionalized_subgraph_0, body_fn = auto_functionalized_subgraph_1, _carried_input0_base_index = 0, _all_bases = [clone], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = auto_functionalized_subgraph_1 = clone = _tree_spec_constant0 = None
getitem: "f32[3, 4]" = auto_functionalized_v2[0]; auto_functionalized_v2 = None

add: "f32[3, 4]" = torch.ops.aten.add.Tensor(arg1_1, getitem); arg1_1 = getitem = None
return (add,)

class auto_functionalized_subgraph_0(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 4]"):
sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
return gt

class auto_functionalized_subgraph_1(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 4]"):
add: "f32[3, 4]" = torch.ops.aten.add.Tensor(arg0_1, -1)
clone: "f32[3, 4]" = torch.ops.aten.clone.default(add)
copy_: "f32[3, 4]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
return (clone,)
""", # noqa: B950
)

@requires_cuda
@unittest.skipIf(not SM70OrLater, "triton")
@parametrize("device", ["cuda", "cpu"])
@parametrize("dynamic", [True, False])
def test_while_loop_auto_functionalize_buffer_mutation(self, device, dynamic):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"buf", torch.ones(8, requires_grad=False, device=device)
)

def forward(self, p, x):
def cond_fn(x):
return x.sum() < 0

def body_fn(x):
x.add_(-1)
self.buf.add_(-1)
return (x + self.buf.sum(),)

x = x.clone()
out = while_loop(cond_fn, body_fn, (x,))
return x + self.buf + out[0]

p, x = torch.tensor(True), torch.randn(1, requires_grad=True)
fw_gm = self.check(M, (p, x), device, dynamic)
if not TEST_WITH_CROSSREF and not dynamic and device == "cuda":
self.assertExpectedInline(
normalize_gm(fw_gm.print_readable(print_output=False)),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[1]", arg1_1: "f32[8]"):
clone: "f32[1]" = torch.ops.aten.clone.default(arg0_1); arg0_1 = None

auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
auto_functionalized_subgraph_1 = self.auto_functionalized_subgraph_1
_tree_spec_constant0 = self._tree_spec_constant0
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.while_loop, cond_fn = auto_functionalized_subgraph_0, body_fn = auto_functionalized_subgraph_1, _carried_input0_base_index = 0, _additional_input0_base_index = 1, _all_bases = [clone, arg1_1], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = auto_functionalized_subgraph_1 = clone = _tree_spec_constant0 = None
getitem: "f32[1]" = auto_functionalized_v2[0]
getitem_1: "f32[1]" = auto_functionalized_v2[1]
getitem_2: "f32[8]" = auto_functionalized_v2[2]; auto_functionalized_v2 = None

add: "f32[8]" = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = None
add_1: "f32[8]" = torch.ops.aten.add.Tensor(add, getitem); add = getitem = None

copy_: "f32[8]" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None
return (add_1,)

class auto_functionalized_subgraph_0(torch.nn.Module):
def forward(self, arg0_1: "f32[1]", arg1_1: "f32[8]"):
sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
lt: "b8[]" = torch.ops.aten.lt.Scalar(sum_1, 0); sum_1 = None
return lt

class auto_functionalized_subgraph_1(torch.nn.Module):
def forward(self, arg0_1: "f32[1]", arg1_1: "f32[8]"):
add: "f32[1]" = torch.ops.aten.add.Tensor(arg0_1, -1)
add_1: "f32[8]" = torch.ops.aten.add.Tensor(arg1_1, -1)
sum_1: "f32[]" = torch.ops.aten.sum.default(add_1)
add_2: "f32[1]" = torch.ops.aten.add.Tensor(add, sum_1); sum_1 = None
copy_: "f32[1]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
copy__1: "f32[8]" = torch.ops.aten.copy_.default(arg1_1, add_1); arg1_1 = add_1 = copy__1 = None
return (add_2,)
""", # noqa: B950
)
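
Reading the expected graph above: auto_functionalized_v2 returns the loop outputs followed by the final value of every tensor in _all_bases, and the top-level graph then copy_'s the final value of any base that is backed by a real graph input (here the buffer arg1_1) back into that input. A purely illustrative helper (hypothetical names, not PR code) capturing that write-back contract:

import torch
from typing import Optional, Sequence


def write_back_mutated_bases(
    final_bases: Sequence[torch.Tensor],
    graph_inputs: Sequence[Optional[torch.Tensor]],
) -> None:
    # graph_inputs[i] is the real input backing base i, or None when the base is a
    # purely local tensor (e.g. the clone of a carried input) that needs no write-back.
    for final, original in zip(final_bases, graph_inputs):
        if original is not None:
            original.copy_(final)


# Mimicking the epilogue above: base 0 is a local clone, base 1 is the mutated buffer.
buf = torch.ones(8)
final_base0, final_base1 = torch.randn(1), torch.zeros(8)
write_back_mutated_bases([final_base0, final_base1], [None, buf])
assert torch.equal(buf, final_base1)
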

@requires_cuda
@unittest.skipIf(not SM70OrLater, "triton")
@torch._dynamo.config.patch(capture_scalar_outputs=True)
@torch._dynamo.config.patch(prefer_deferred_runtime_asserts_over_guards=True)
@parametrize("device", ["cuda", "cpu"])
@parametrize("dynamic", [True, False])
def test_while_loop_auto_functionalize_inplace_mutate_out_buffer_as_carry(
self, device, dynamic
):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"buf", torch.ones(1, requires_grad=False, device=device)
)

def forward(self, p, x):
def cond_fn(it, x, out_buf):
return it < x.size(0)

def body_fn(it, x, out_buf):
out = x.sin()
idx = it.item()
torch._check_is_size(idx, max=x.size(0) - 1)
out_buf[idx].add_(out[idx])
return (it + 1, x + 1, out_buf.clone())

it = torch.tensor(0, dtype=torch.int64)
out_buf = x.clone()
x = x.clone()
out = while_loop(cond_fn, body_fn, (it, x, out_buf))
return x + self.buf + out[0]

p, x = torch.tensor(True), torch.randn(3, 4)
fw_gm = self.check(M, (p, x), device, dynamic)
if not TEST_WITH_CROSSREF and not dynamic and device == "cuda":
self.assertExpectedInline(
normalize_gm(fw_gm.print_readable(print_output=False)),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 4]", arg1_1: "f32[1]"):
_tensor_constant0: "i64[]" = self._tensor_constant0
lift_fresh_copy: "i64[]" = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None

clone: "f32[3, 4]" = torch.ops.aten.clone.default(arg0_1)

clone_1: "f32[3, 4]" = torch.ops.aten.clone.default(arg0_1); arg0_1 = None

auto_functionalized_subgraph_0 = self.auto_functionalized_subgraph_0
auto_functionalized_subgraph_1 = self.auto_functionalized_subgraph_1
_tree_spec_constant0 = self._tree_spec_constant0
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.higher_order.while_loop, cond_fn = auto_functionalized_subgraph_0, body_fn = auto_functionalized_subgraph_1, carried_input0 = lift_fresh_copy, carried_input1 = clone_1, _carried_input2_base_index = 0, _all_bases = [clone], _op_schema = _tree_spec_constant0); auto_functionalized_subgraph_0 = auto_functionalized_subgraph_1 = lift_fresh_copy = clone = _tree_spec_constant0 = None
getitem: "i64[]" = auto_functionalized_v2[0]; auto_functionalized_v2 = None

add: "f32[3, 4]" = torch.ops.aten.add.Tensor(clone_1, arg1_1); clone_1 = arg1_1 = None
add_1: "f32[3, 4]" = torch.ops.aten.add.Tensor(add, getitem); add = getitem = None
return (add_1,)

class auto_functionalized_subgraph_0(torch.nn.Module):
def forward(self, arg0_1: "i64[]", arg1_1: "f32[3, 4]", arg2_1: "f32[3, 4]"):
lt: "b8[]" = torch.ops.aten.lt.Scalar(arg0_1, 3); arg0_1 = None
return lt

class auto_functionalized_subgraph_1(torch.nn.Module):
def forward(self, arg0_1: "i64[]", arg1_1: "f32[3, 4]", arg2_1: "f32[3, 4]"):
sin: "f32[3, 4]" = torch.ops.aten.sin.default(arg1_1)
_local_scalar_dense: "Sym(u3)" = torch.ops.aten._local_scalar_dense.default(arg0_1)
ge_1: "Sym(u3 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
le_1: "Sym(u3 <= 2)" = _local_scalar_dense <= 2
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u3 <= 2 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None
select: "f32[4]" = torch.ops.aten.select.int(arg2_1, 0, _local_scalar_dense)
select_1: "f32[4]" = torch.ops.aten.select.int(sin, 0, _local_scalar_dense); sin = None
add: "f32[4]" = torch.ops.aten.add.Tensor(select, select_1); select = select_1 = None
select_scatter: "f32[3, 4]" = torch.ops.aten.select_scatter.default(arg2_1, add, 0, _local_scalar_dense); add = _local_scalar_dense = None
add_1: "i64[]" = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
add_2: "f32[3, 4]" = torch.ops.aten.add.Tensor(arg1_1, 1); arg1_1 = None
clone: "f32[3, 4]" = torch.ops.aten.clone.default(select_scatter)
copy_: "f32[3, 4]" = torch.ops.aten.copy_.default(arg2_1, select_scatter); arg2_1 = select_scatter = copy_ = None
return (add_1, add_2, clone)
""", # noqa: B950
)
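
A hedged aside on where the select_scatter/copy_ pair in auto_functionalized_subgraph_1 above comes from: functionalizing an in-place update of one row of an input rewrites the view mutation as a select_scatter on the base plus a copy_ back into the input. A standalone sketch (standard torch.func/make_fx APIs, with a concrete index so the unbacked-symint assertions don't appear):

import torch
from torch.fx.experimental.proxy_tensor import make_fx


def update_row(out_buf, row):
    out_buf[1].add_(row)  # mutate one row of the input through a view
    return (out_buf.clone(),)


# Expect roughly: aten.select -> aten.add -> aten.select_scatter -> aten.clone -> aten.copy_
gm = make_fx(torch.func.functionalize(update_row))(torch.zeros(3, 4), torch.ones(4))
gm.print_readable()
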


_hop_schema_test_schema_types = [
"bool",
4 changes: 2 additions & 2 deletions torch/_dynamo/variables/higher_order_ops.py
@@ -1358,8 +1358,8 @@ def unspecialize_carried_inputs(tx, carry) -> VariableTracker:
source_target=self.value,
set_subgraph_inputs="flatten_manual",
should_flatten_outputs=True,
supports_input_mutation=False,
supports_aliasing=False,
supports_input_mutation=self.supports_input_mutation,
supports_aliasing=self.supports_aliasing,
)
validate_subgraph_output_types(body_r)

2 changes: 1 addition & 1 deletion torch/_higher_order_ops/auto_functionalize.py
@@ -588,7 +588,7 @@ def __call__(self, *args, **kwargs):
# Inlining has the benefit of allowing easier fusion inside the subgraph.
# Though the epilogue graph contains copy_, it is OK because inductor can handle it
# and this is also how we have been supporting top-level graph input mutation.
return tuple(torch.func.functionalize(self.orig_callable)(*args, **kwargs))
return torch.func.functionalize(self.orig_callable)(*args, **kwargs)

def __hash__(self):
return id(self.orig_callable)
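
To make the comment above concrete: a standalone sketch (standard torch.func and make_fx APIs, not PR code) that traces a mutating body through torch.func.functionalize. The in-place add_ is rewritten as an out-of-place add followed by a trailing copy_ into the input, which is exactly the shape of auto_functionalized_subgraph_1 in the first new test above.

import torch
from torch.fx.experimental.proxy_tensor import make_fx


def body(x):
    x.add_(-1)  # in-place mutation of the argument
    return (x.clone(),)


# The printed graph should contain aten.add, aten.clone, and a final aten.copy_
# back into the input, rather than aten.add_.
gm = make_fx(torch.func.functionalize(body))(torch.randn(3, 4))
gm.print_readable()
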
30 changes: 26 additions & 4 deletions torch/_higher_order_ops/while_loop.py
@@ -112,10 +112,11 @@ def _find_example_value(n, real_inp):

for idx, arg in enumerate(additional_inputs):
additional_idx = len(carried_inputs) + idx
assert additional_idx not in mutated_inputs, (
    "Lifted additional_inputs cannot be in-place mutated."
)
schema_gen.add_arg(f"additional_input{idx}", arg, is_mutated=False)
schema_gen.add_arg(
    f"additional_input{idx}",
    arg,
    is_mutated=additional_idx in mutated_inputs,
)

for out in body_outputs:
schema_gen.add_output(out)
@@ -498,7 +499,28 @@ def while_loop_fake_tensor_mode(

@while_loop_op.py_functionalize_impl
def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
from torch._higher_order_ops.utils import _check_alias_and_mutation
from torch._higher_order_ops.auto_functionalize import (
can_auto_functionalize,
do_auto_functionalize_v2,
)
from torch._higher_order_ops.utils import _check_alias_and_mutation, HopInstance

hop_instance = HopInstance.create(
while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
)
# For now, we only support auto-functionalization for while_loop under the Python
# functionalization mode.
if can_auto_functionalize(hop_instance) and hasattr(ctx, "mode"):
return do_auto_functionalize_v2(
ctx.mode,
hop_instance,
tuple(
pytree.tree_flatten(
(cond_fn, body_fn, carried_inputs, additional_inputs)
)[0]
),
{},
)

unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs)
unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
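
A small aside on the argument packing above (standard torch.utils._pytree behavior, not PR-specific): tree_flatten of (cond_fn, body_fn, carried_inputs, additional_inputs) yields the two callables followed by every carried and additional tensor, which is the flat positional list handed to do_auto_functionalize_v2.

import torch
import torch.utils._pytree as pytree


def cond_fn(x):
    return x.sum() > 0


def body_fn(x):
    return (x - 1,)


carried_inputs = (torch.randn(3, 4),)
additional_inputs = (torch.ones(8),)
flat_args, _spec = pytree.tree_flatten(
    (cond_fn, body_fn, carried_inputs, additional_inputs)
)
# flat_args == [cond_fn, body_fn, carried_inputs[0], additional_inputs[0]]
assert flat_args[2] is carried_inputs[0] and flat_args[3] is additional_inputs[0]
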
2 changes: 1 addition & 1 deletion torch/_inductor/lowering.py
@@ -7024,7 +7024,7 @@ def _map_output(out: Any):
return out
elif isinstance(out, ir.StorageBox):
return TensorBox(out)
elif isinstance(out, ir.MultiOutput):
elif isinstance(out, (ir.MultiOutput, ir.ReinterpretView)):
return TensorBox.create(out)
else:
raise RuntimeError(f"NYI unsupported output type: {type(out)}")
Expand Down
Loading