Skip to content

Commit 43a00d7

Browse files
yanboliang authored and pytorchmergebot committed
[Trace Python Dispatcher] Support FuncTorchInterpreter (#144444)
Pull Request resolved: #144444 Approved by: https://github.com/williamwen42, https://github.com/zou3519 ghstack dependencies: #144439
1 parent 5d02575 commit 43a00d7

File tree

4 files changed

+91
-2
lines changed

4 files changed

+91
-2
lines changed

test/dynamo/test_python_dispatcher.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,34 @@ def fn(x, dks):
102102
# Re-compile since the dispatch key set is different.
103103
self.assertEqual(counter.frame_count, 2)
104104

105+
def test_functorch_interpreter(self):
    """Dynamo should trace through a functorch FuncTorchInterpreter.

    The inner function queries the current functorch interpreter
    (retrieve_current_functorch_interpreter) and branches on its
    transform key/level; compiling it under torch.vmap must produce a
    single graph (fullgraph=True) and must NOT recompile when only the
    tensor *values* change.
    """
    counter = CompileCounter()

    def square_and_add(x, y):
        # Runs under vmap, so the active interpreter is the Vmap
        # interpreter; Dynamo must model .level() and .key() calls.
        interpreter = (
            torch._functorch.pyfunctorch.retrieve_current_functorch_interpreter()
        )
        level = interpreter.level()
        if interpreter.key() == torch._C._functorch.TransformType.Vmap:
            return (x**2 + y) * level
        else:
            return x**2 * level

    @torch.compile(backend=counter, fullgraph=True)
    def fn(x, y):
        return torch.vmap(square_and_add)(x, y)

    x = torch.tensor([1, 2, 3, 4])
    y = torch.tensor([10, 20, 30, 40])
    # Under a single vmap the level is 1, so fn == x**2 + y.
    self.assertEqual(fn(x, y), torch.tensor([11, 24, 39, 56]))
    self.assertEqual(counter.frame_count, 1)

    x = torch.tensor([1, 2, 3, 1])
    y = torch.tensor([10, 20, 30, 10])
    self.assertEqual(fn(x, y), torch.tensor([11, 24, 39, 11]))
    # Same shapes/dtypes, new values: guards must hold — no recompile.
    self.assertEqual(counter.frame_count, 1)
132+
105133

106134
if __name__ == "__main__":
107135
from torch._dynamo.test_case import run_tests

torch/_dynamo/trace_rules.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@
305305
"torch._C._functorch.peek_interpreter_stack": TorchInGraphFunctionVariable,
306306
"torch._C._functorch.unwrap_if_dead": TorchInGraphFunctionVariable,
307307
# everything else
308+
"torch._functorch.pyfunctorch.coerce_cinterpreter": TorchInGraphFunctionVariable,
308309
"torch._higher_order_ops.triton_kernel_wrap.do_prune_configs": UserFunctionVariable,
309310
"torch._higher_order_ops.foreach_map.foreach_map": UserFunctionVariable,
310311
"torch._constrain_as_size": UserFunctionVariable,

torch/_dynamo/variables/builder.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@
217217
)
218218
from .torch import (
219219
DispatchKeySetVariable,
220+
FuncTorchInterpreterVariable,
220221
TorchCtxManagerClassVariable,
221222
TorchInGraphFunctionVariable,
222223
)
@@ -668,7 +669,9 @@ def build_key_value(i, k, v):
668669
items = [SourcelessBuilder.create(self.tx, v) for v in value]
669670
self.install_guards(GuardBuilder.ID_MATCH)
670671
return FrozensetVariable(items, source=self.source)
671-
elif isinstance(value, (enum.Enum, torch.DispatchKey)):
672+
elif isinstance(
673+
value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType)
674+
):
672675
self.install_guards(GuardBuilder.ID_MATCH)
673676
return EnumVariable(value=value, source=self.source)
674677
elif DebuggingVariable.is_reorderable_logging_function(value):
@@ -857,6 +860,9 @@ def build_key_value(i, k, v):
857860
elif isinstance(value, (torch._C._SDPAParams)):
858861
self.install_guards(GuardBuilder.TYPE_MATCH)
859862
return SDPAParamsVariable.create(self.tx, value, self.source)
863+
elif isinstance(value, torch._functorch.pyfunctorch.FuncTorchInterpreter):
864+
self.install_guards(GuardBuilder.ID_MATCH)
865+
return FuncTorchInterpreterVariable(value)
860866
elif isinstance(value, torch.Event):
861867
self.install_guards(GuardBuilder.ID_MATCH)
862868
torch._dynamo.utils.store_user_object_weakref(value)
@@ -2960,7 +2966,9 @@ def create(tx: "InstructionTranslator", value) -> VariableTracker:
29602966
return trace_rules.lookup_callable(value)(value)
29612967
elif is_function_or_wrapper(value):
29622968
return trace_rules.lookup(value)(value)
2963-
elif isinstance(value, (enum.Enum, torch.DispatchKey)):
2969+
elif isinstance(
2970+
value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType)
2971+
):
29642972
return EnumVariable(value)
29652973
elif isinstance(value, (type, abc.ABCMeta)):
29662974
return UserDefinedClassVariable(value)
@@ -3025,6 +3033,11 @@ def make_type_handlers():
30253033
handlers[torch.DispatchKeySet] = lambda tx, value: DispatchKeySetVariable(
30263034
value, mutation_type=ValueMutationNew()
30273035
)
3036+
handlers[
3037+
torch._functorch.pyfunctorch.FuncTorchInterpreter
3038+
] = lambda tx, value: FuncTorchInterpreterVariable(
3039+
value, mutation_type=ValueMutationNew()
3040+
)
30283041

30293042
handlers[
30303043
torch.distributions.constraints._Real

torch/_dynamo/variables/torch.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,25 @@ def handle_unsafe_set_version_counter(
826826
_unsafe_set_version_counter
827827
).call_function(tx, [*args], kwargs)
828828

829+
@register(torch._C._functorch.peek_interpreter_stack)
def handle_functorch_peek_interpreter_stack(
    self, tx: "InstructionTranslator", *args, **kwargs
):
    # Wrap C++ interpreter (torch._C._functorch.CInterpreter) as UserDefinedObjectVariable,
    # but Python interpreter (torch._functorch.pyfunctorch.FuncTorchInterpreter) as FuncTorchInterpreterVariable.
    # NOTE(review): traced args/kwargs are intentionally ignored and the
    # stack is peeked eagerly at trace time — assumes the interpreter
    # stack at trace time matches runtime; confirm against callers.
    return UserDefinedObjectVariable(
        torch._C._functorch.peek_interpreter_stack()
    )
838+
839+
@register(torch._functorch.pyfunctorch.coerce_cinterpreter)
def handle_functorch_pyfunctorch_coerce_cinterpreter(
    self, tx: "InstructionTranslator", *args, **kwargs
):
    # args[0] is the VariableTracker wrapping a C++ CInterpreter; unwrap
    # its concrete value, coerce it eagerly to the Python-level
    # FuncTorchInterpreter, and re-wrap as FuncTorchInterpreterVariable
    # so method calls (key/level/process/...) can be traced.
    cinterpreter = args[0].value
    return FuncTorchInterpreterVariable(
        torch._functorch.pyfunctorch.coerce_cinterpreter(cinterpreter)
    )
847+
829848
@register(torch.tensor)
830849
def handle_torch_tensor(self, tx: "InstructionTranslator", *args, **kwargs):
831850
def check_any_unspec(x):
@@ -1260,3 +1279,31 @@ def call_method(
12601279
elif name == "highestPriorityTypeId":
12611280
return variables.EnumVariable(self.value.highestPriorityTypeId())
12621281
return super().call_method(tx, name, args, kwargs)
1282+
1283+
1284+
class FuncTorchInterpreterVariable(BaseTorchVariable):
    """Represents a torch._functorch.pyfunctorch.FuncTorchInterpreter.

    Models the Python-level functorch interpreter during tracing so that
    its accessor methods can be constant-folded or inlined rather than
    causing a graph break.
    """

    @classmethod
    def create_with_source(cls, value, source):
        # Interpreters are identity-guarded: a different interpreter
        # object at runtime must trigger recompilation.
        install_guard(source.make_guard(GuardBuilder.ID_MATCH))
        return cls(value, source=source)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # key() yields a TransformType enum member — wrap as an enum.
        if name == "key":
            return variables.EnumVariable(self.value.key())
        # process() is a real Python method; inline it through Dynamo,
        # passing this variable as `self`.
        if name == "process":
            bound_fn = variables.UserFunctionVariable(self.value.process.__func__)
            return tx.inline_user_function_return(bound_fn, [self] + args, kwargs)
        # Simple scalar accessors: evaluate eagerly and bake in the result.
        if name in ("level", "batch_size", "randomness"):
            return variables.ConstantVariable.create(getattr(self.value, name)())
        return super().call_method(tx, name, args, kwargs)

0 commit comments

Comments
 (0)