
Commit 5644567

tugsbayasgalan authored and facebook-github-bot committed
Add support for param mutation under inference mode (#159661)
Summary: In the HF model rwkv, we have parameter mutation under inference mode, which should be safe. This PR does multiple things to make that work:
1. Execute global autograd-state mutation while tracing, so that we can actually trace through parameter in-place mutation.
2. Add support for parameter mutation under inference mode in AOTAutograd.
3. Add support for parameter mutation under inference mode in export.

Test Plan: test

Rollback Plan:

Reviewed By: ydwu4

Differential Revision: D79460136
1 parent bfc873d commit 5644567

File tree: 17 files changed, +305 −31 lines
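For orientation, here is a minimal sketch of the pattern this commit enables, adapted from the new tests below (the module `Foo` and its shapes are illustrative, not part of any API):

import torch

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.parameter = torch.nn.Parameter(torch.ones(4, 4))

    def forward(self, x):
        # In-place parameter mutation guarded by no_grad: previously this
        # failed in export/AOTAutograd; after this change it traces through.
        with torch.no_grad():
            self.parameter.div_(2)
        return x + self.parameter

ep = torch.export.export(Foo(), (torch.rand(4, 4),)).run_decompositions()
# The graph signature now records the mutation, mapping the graph output
# name to the parameter FQN, e.g. {"div": "parameter"}.
print(ep.graph_signature.parameters_to_mutate)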

test/export/test_export.py

Lines changed: 57 additions & 0 deletions
@@ -326,6 +326,52 @@ def forward(self, *args):
             dynamic_shapes=dynamic_shapes,
         )
 
+    def test_no_grad_param_inplace(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.parameter = torch.nn.Parameter(torch.ones(4, 4))
+
+            def forward(self, x):
+                with torch.no_grad():
+                    self.parameter.div_(2)
+                return x + self.parameter
+
+        foo_ep = Foo()
+        foo_eager = Foo()
+        ep = export(foo_ep, (torch.rand(4, 4),)).run_decompositions()
+        val = ep.graph_signature.parameters_to_mutate
+        self.assertExpectedInline(
+            str(ep.graph).strip(),
+            """\
+graph():
+    %p_parameter : [num_users=1] = placeholder[target=p_parameter]
+    %x : [num_users=1] = placeholder[target=x]
+    %div : [num_users=2] = call_function[target=torch.ops.aten.div.Tensor](args = (%p_parameter, 2), kwargs = {})
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %div), kwargs = {})
+    return (div, add)""",
+        )
+
+        self.assertTrue("div" in val.keys())
+        self.assertTrue("parameter" in val.values())
+
+        test_inp = torch.rand(4, 4)
+
+        res = foo_eager(test_inp)
+
+        # TODO We almost need to make the param mutation happen outside
+        # of the graph. Or wrap the param mutation in a no_grad HOP. Simply
+        # overriding gm.__call__ doesn't seem to work due to:
+        # 1. graph module does something weird to __call__ so it is not easy to override
+        # 2. We inspect module.forward to bind fake args when retracing
+        with self.assertRaisesRegex(RuntimeError, "leaf"):
+            res_export = ep.module()(torch.rand(4, 4))
+
+        with torch.no_grad():
+            res_export = ep.module()(test_inp)
+
+        self.assertTrue(torch.allclose(res, res_export))
+
     def test_export_slice_unbacked_dim1(self):
         class MySlice(torch.nn.Module):
             def forward(self, x, seq_len):
@@ -4000,6 +4046,17 @@ def forward(self, x):
         inp = torch.randn(3, 3)
         self.assertTrue(torch.allclose(ep.module()(inp)[0], inp + 1))
 
+    def test_set_grad_as_side_effect(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                torch._C._set_grad_enabled(False)
+                return x.sum()
+
+        before = torch.is_grad_enabled()
+        ep = torch.export.export(Foo(), (torch.randn(4, 4),))
+        after = torch.is_grad_enabled()
+        self.assertEqual(before, after)
+
     def test_derived_dim_out_of_order_simplified(self):
         _dimz = torch.export.Dim("_dimz", min=6, max=8)
         dimy = _dimz - 1
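As the TODO in the test above notes, the in-place update is replayed inside the exported graph rather than hoisted out of it, so calling the exported module under a grad-enabled context trips autograd's leaf-variable check (the assertRaisesRegex(RuntimeError, "leaf") in the test). A minimal usage sketch of the workaround the test itself uses, reusing `ep` from the sketch near the top:

# Running the exported module re-executes the parameter mutation; wrap the
# call in no_grad to match the eager module's behavior.
with torch.no_grad():
    out = ep.module()(torch.rand(4, 4))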

test/export/test_serialize.py

Lines changed: 19 additions & 0 deletions
@@ -280,6 +280,25 @@ def forward(self, x):
         actual_out = loaded_ep.module()(*inp)
         self.assertEqual(exp_out, actual_out)
 
+    def test_serialize_param_mutation(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.parameter = torch.nn.Parameter(torch.ones(4, 4))
+
+            def forward(self, x):
+                with torch.no_grad():
+                    self.parameter.div_(2)
+                return x + self.parameter
+
+        foo = Foo()
+        ep = torch.export.export(foo, (torch.rand(4, 4),)).run_decompositions()
+        buffer = io.BytesIO()
+        save(ep, buffer)
+        loaded_ep = load(buffer)
+        val = loaded_ep.graph_signature.parameters_to_mutate
+        self.assertEqual({"div": "parameter"}, val)
+
     def test_serialize_constant_outputs(self):
         class MyModule(torch.nn.Module):
             def __init__(self) -> None:

test/functorch/test_aotdispatch.py

Lines changed: 26 additions & 7 deletions
@@ -5364,11 +5364,15 @@ def forward(self, x):
 
         mod = M()
         inp = torch.randn(2, requires_grad=True)
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Found a graph input that requires gradients, and received a mutation",
-        ):
-            aot_export_module(mod, [inp], trace_joint=False)
+        gm, _ = aot_export_module(mod, [inp], trace_joint=False)
+        self.assertExpectedInline(
+            str(gm.graph).strip(),
+            """\
+graph():
+    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, 4), kwargs = {})
+    return (add, add)""",
+        )
 
     def test_aot_export_input_mutation_on_parameter_banned(self):
         def fn(p, x):
@@ -5379,11 +5383,26 @@ def fn(p, x):
         inp = torch.randn(2)
         with self.assertRaisesRegex(
             RuntimeError,
-            "Found a graph input that requires gradients, and received a mutation",
+            "aot_export_joint_simple does not support input mutations. ViewAndMutationMeta",
         ):
             aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Found a graph input that requires gradients, and received a mutation",
+        ):
             aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)
-        aot_export_module(mod, [inp], trace_joint=False)
+
+        gm, _ = aot_export_module(mod, [inp], trace_joint=False)
+        self.assertExpectedInline(
+            str(gm.graph).strip(),
+            """\
+graph():
+    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
+    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
+    %mul : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg0_1, 2), kwargs = {})
+    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %arg1_1), kwargs = {})
+    return (mul, add)""",
+        )
 
     def test_aot_export_synthetic_bases_banned(self):
         def fn(p, x, y):
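To make the new AOTAutograd behavior concrete: under trace_joint=False (inference), mutating an input that requires grad no longer raises; the mutated value is instead materialized as an extra graph output, as the expected graphs above show. A hedged sketch, where the module `M` is an assumption shaped to match the first expected graph (the test's actual module body is not shown in this diff):

import torch
from torch._functorch.aot_autograd import aot_export_module

class M(torch.nn.Module):
    def forward(self, x):
        x.add_(4)  # in-place mutation of a grad-requiring input
        return (x,)

inp = torch.randn(2, requires_grad=True)
# Previously raised "Found a graph input that requires gradients, and
# received a mutation"; now the graph returns (add, add), where the first
# output carries the updated input value.
gm, signature = aot_export_module(M(), [inp], trace_joint=False)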

torch/_export/serde/export_schema.thrift

Lines changed: 7 additions & 1 deletion
@@ -1,5 +1,5 @@
 // @generated by update_schema.py
-// checksum<<e7f100132ac684ccc67fce91b241821062f1dfe496fdff4b9929aba4ac938b4f>>
+// checksum<<00d94226d15b290b97bd49f9ff12bbfe04b7252c75d2d1bae66d1756fd9b8517>>
 
 namespace py3 torch._export
 namespace cpp2 torch._export.schema
@@ -254,6 +254,11 @@ struct BufferMutationSpec {
   20: string buffer_name;
 }
 
+struct ParameterMutationSpec {
+  10: TensorArgument arg;
+  20: string parameter_name;
+}
+
 struct GradientToParameterSpec {
   10: TensorArgument arg;
   20: string parameter_name;
@@ -281,6 +286,7 @@ union OutputSpec {
   50: GradientToUserInputSpec gradient_to_user_input;
   60: UserInputMutationSpec user_input_mutation;
   70: OutputTokenSpec token;
+  80: ParameterMutationSpec parameter_mutation;
 }
 
 struct GraphSignature {

torch/_export/serde/schema.py

Lines changed: 7 additions & 0 deletions
@@ -327,6 +327,12 @@ class BufferMutationSpec:
     buffer_name: Annotated[str, 20]
 
 
+@dataclass
+class ParameterMutationSpec:
+    arg: Annotated[TensorArgument, 10]
+    parameter_name: Annotated[str, 20]
+
+
 @dataclass
 class GradientToParameterSpec:
     arg: Annotated[TensorArgument, 10]
@@ -359,6 +365,7 @@ class OutputSpec(_Union):
     gradient_to_user_input: Annotated[GradientToUserInputSpec, 50]
     user_input_mutation: Annotated[UserInputMutationSpec, 60]
     token: Annotated[OutputTokenSpec, 70]
+    parameter_mutation: Annotated[ParameterMutationSpec, 80]
 
 
 @dataclass

torch/_export/serde/schema.yaml

Lines changed: 10 additions & 1 deletion
@@ -1,5 +1,5 @@
 # @generated by update_schema.py
-# checksum<<afe0cc0f99e72d00aa05f1a94da938ecb619aabc5d131d3ade489b57799f1e5a>>
+# checksum<<face83b52f81c45eeaeccc97cee19e146b3f7416ed91e015b4510ada7549a72f>>
 
 AOTInductorModelPickleData:
   kind: struct
   fields:
@@ -383,11 +383,20 @@ OutputSpec:
       type: UserInputMutationSpec
     token:
       type: OutputTokenSpec
+    parameter_mutation:
+      type: ParameterMutationSpec
 OutputTokenSpec:
   kind: struct
   fields:
     arg:
       type: TokenArgument
+ParameterMutationSpec:
+  kind: struct
+  fields:
+    arg:
+      type: TensorArgument
+    parameter_name:
+      type: str
 RangeConstraint:
   kind: struct
   fields:

torch/_export/serde/serialize.py

Lines changed: 27 additions & 6 deletions
@@ -69,6 +69,7 @@
     OptionalTensorArgument,
     OutputSpec,
     OutputTokenSpec,
+    ParameterMutationSpec,
     RangeConstraint,
     ScalarType,
     SCHEMA_VERSION,
@@ -1241,6 +1242,15 @@ def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec:
                     buffer_name=spec.target,
                 )
             )
+        elif spec.kind == ep.OutputKind.PARAMETER_MUTATION:
+            assert spec.target is not None
+            assert isinstance(spec.arg, ep.TensorArgument)
+            return OutputSpec.create(
+                parameter_mutation=ParameterMutationSpec(
+                    arg=TensorArgument(name=spec.arg.name),
+                    parameter_name=spec.target,
+                )
+            )
         elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER:
             assert spec.target is not None
             assert isinstance(spec.arg, ep.TensorArgument)
@@ -2199,6 +2209,12 @@ def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec:
                 arg=ep.TensorArgument(name=o.buffer_mutation.arg.name),
                 target=o.buffer_mutation.buffer_name,
             )
+        elif o.type == "parameter_mutation":
+            return ep.OutputSpec(
+                kind=ep.OutputKind.PARAMETER_MUTATION,
+                arg=ep.TensorArgument(name=o.parameter_mutation.arg.name),
+                target=o.parameter_mutation.parameter_name,
+            )
         elif o.type == "gradient_to_parameter":
             return ep.OutputSpec(
                 kind=ep.OutputKind.GRADIENT_TO_PARAMETER,
@@ -3377,17 +3393,19 @@ def rank_output(out) -> tuple[int, Optional[str], int]:
         idx, (_arg, spec) = out
         assert isinstance(spec, OutputSpec)
         if spec.type == "user_output":
-            return 3, None, idx
+            return 4, None, idx
         elif spec.type == "loss_output":
-            return 3, None, idx
+            return 4, None, idx
+        elif spec.type == "parameter_mutation":
+            return 1, spec.parameter_mutation.parameter_name, idx
         elif spec.type == "buffer_mutation":
-            return 1, spec.buffer_mutation.buffer_name, idx
+            return 2, spec.buffer_mutation.buffer_name, idx
         elif spec.type == "gradient_to_parameter":
-            return 4, spec.gradient_to_parameter.parameter_name, idx
+            return 5, spec.gradient_to_parameter.parameter_name, idx
         elif spec.type == "gradient_to_user_input":
-            return 5, None, idx
+            return 6, None, idx
         elif spec.type == "user_input_mutation":
-            return 2, None, idx
+            return 3, None, idx
         elif spec.type == "token":
             return 0, None, idx
         else:
@@ -3500,6 +3518,9 @@ def replace_output(out):
         elif spec.type == "buffer_mutation":
             t = spec.buffer_mutation.arg
             t.name = replace_table[t.name]
+        elif spec.type == "parameter_mutation":
+            t = spec.parameter_mutation.arg
+            t.name = replace_table[t.name]
         elif spec.type == "gradient_to_parameter":
             t = spec.gradient_to_parameter.arg
             t.name = replace_table[t.name]

torch/_export/verifier.py

Lines changed: 13 additions & 1 deletion
@@ -463,7 +463,12 @@ def _verify_exported_program_signature(exported_program) -> None:
         )
 
     num_tokens = len(gs.output_tokens)
-    end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens
+    end = (
+        len(gs.buffers_to_mutate)
+        + len(gs.parameters_to_mutate)
+        + len(gs.user_inputs_to_mutate)
+        + num_tokens
+    )
     mutate_nodes: list[str] = output_nodes[num_tokens:end]
     user_output_nodes = output_nodes[end : end + len(gs.user_outputs)]
 
@@ -475,6 +480,13 @@ def _verify_exported_program_signature(exported_program) -> None:
                     f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
                     f"Buffer nodes available: {gs.buffers} \n"
                 )
+        elif mutation_node in gs.parameters_to_mutate:
+            if gs.parameters_to_mutate[mutation_node] not in gs.parameters:
+                raise SpecViolationError(
+                    f"Parameter output {mutation_node} does not point to a parameter that exists. \n"
+                    f"Dict of parameters that are mutated, in order: {gs.parameters_to_mutate} \n"
+                    f"Parameter nodes available: {gs.parameters} \n"
+                )
         elif mutation_node in gs.user_inputs_to_mutate:
             if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs:
                 raise SpecViolationError(

torch/_functorch/_aot_autograd/input_output_analysis.py

Lines changed: 1 addition & 0 deletions
@@ -460,6 +460,7 @@ def create_graph_signature(
         named_buffers=buffer_names,
         num_user_inputs=num_user_args,
         num_user_outputs=num_user_fw_outs,
+        trace_joint=trace_joint,
         loss_index=loss_index,
         backward_signature=backward_signature,
     )

torch/_functorch/_aot_autograd/schemas.py

Lines changed: 15 additions & 4 deletions
@@ -829,6 +829,7 @@ class GraphSignature:
     # "graph outputs that correspond to updated buffers"
     # to the FQN names of those mutated buffers.
     buffers_to_mutate: dict[GraphOutputName, FQN]
+    parameters_to_mutate: dict[GraphOutputName, FQN]
     user_inputs_to_mutate: dict[GraphOutputName, GraphInputName]
 
     in_spec: pytree.TreeSpec
@@ -852,6 +853,7 @@ def from_tracing_metadata(
         named_buffers: list[str],
         num_user_inputs: int,
         num_user_outputs: int,
+        trace_joint: bool,
         loss_index: Optional[int],
         backward_signature: Optional[BackwardSignature],
     ) -> GraphSignature:
@@ -897,8 +899,9 @@ def from_tracing_metadata(
         mutations = []
         for idx, input_info in enumerate(view_mutation_metadata.input_info):
             if input_info.mutates_data:
-                # Only buffers can be mutated, not parameters
-                assert idx >= len(parameters)
+                if trace_joint:
+                    # Only buffers can be mutated, not parameters
+                    assert idx >= len(parameters)
                 mutations.append(names[idx + num_tokens])
 
         assert len(mutations) == view_mutation_metadata.num_mutated_inp_runtime_indices
@@ -911,12 +914,16 @@ def from_tracing_metadata(
 
         user_inputs_to_mutate = {}
         buffers_to_mutate = {}
+        parameters_to_mutate = {}
         for output_name, mutation_name in outputs_to_mutations.items():
             if mutation_name in user_inputs:
                 user_inputs_to_mutate[output_name] = mutation_name
             else:
-                assert mutation_name in buffers
-                buffers_to_mutate[output_name] = mutation_name
+                assert mutation_name in buffers or mutation_name in parameters
+                if mutation_name in buffers:
+                    buffers_to_mutate[output_name] = mutation_name
+                else:
+                    parameters_to_mutate[output_name] = mutation_name
 
         start, stop = stop, stop + num_user_outputs
         user_outputs = graph_outputs[start:stop]
@@ -937,6 +944,7 @@ def from_tracing_metadata(
             inputs_to_parameters=inputs_to_parameters,  # type: ignore[arg-type]
             user_inputs_to_mutate=user_inputs_to_mutate,
             buffers_to_mutate=buffers_to_mutate,  # type: ignore[arg-type]
+            parameters_to_mutate=parameters_to_mutate,  # type: ignore[arg-type]
             in_spec=in_spec,
             out_spec=out_spec,
             backward_signature=backward_signature,
@@ -983,6 +991,9 @@ class AOTConfig:
     ignore_shape_env: bool = False
     precompile_backend_id: Optional[str] = None
    force_non_lazy_backward_lowering: bool = False
+    # This config makes sure to check certain things like
+    # mutating input with req_grad in export joint tracing.
+    export_trace_joint: bool = False
 
     def __post_init__(self):
         if self.pre_dispatch:
