-
Notifications
You must be signed in to change notification settings - Fork 24.9k
Description
🐛 Describe the bug
I'm trying to compile, with `fullgraph=True`, a piece of code that enters multiple context managers through a single `with` statement (using `contextlib.ExitStack`).
I saw that #98725 is supposed to support this kind of behaviour (cc: @yanboliang, its author).
Lightning-AI/pytorch-lightning#18557 contains real-world code that uses this pattern.
Error logs
No response
Minified repro
from contextlib import ExitStack
class A:
def __enter__(self): pass
def __exit__(self, exc_type, exc_val, exc_tb): pass
class B:
def __enter__(self): pass
def __exit__(self, exc_type, exc_val, exc_tb): pass
def init_context():
stack = ExitStack()
stack.enter_context(A())
stack.enter_context(B())
return stack
def fn():
with init_context():
return 1 + 2
import torch
fn = torch.compile(fn, fullgraph=True)
out = fn()
However, this fails with:
Traceback (most recent call last):
File "/home/carmocca/git/lightning/kk2.py", line 25, in <module>
out = fn()
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 333, in _fn
return fn(*args, **kwargs)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 493, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
return fn(*args, **kwargs)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
return _compile(
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 564, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 486, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 453, in transform
tracer.run()
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
super().run()
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
return super().call_function(tx, args, kwargs)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
tracer.run()
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 594, in call_function
return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 329, in call_method
raise unimplemented(f"call_method {self} {name} {args} {kwargs}")
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 172, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_method GenericContextWrappingVariable() enter_context [GenericContextWrappingVariable()] {}
from user code:
File "/home/carmocca/git/lightning/kk2.py", line 20, in fn
with init_context():
File "/home/carmocca/git/lightning/kk2.py", line 14, in init_context
stack.enter_context(A())
I also tried a simplified `ExitStack` implementation of my own:
class ExitStack:
def __init__(self, context_managers) -> None:
self._context_managers = context_managers
def __enter__(self):
for ctx_manager in self._context_managers:
ctx_manager.__enter__()
def __exit__(self, exc_type, exc_value, traceback):
for ctx_manager in reversed(self._context_managers):
ctx_manager.__exit__(exc_type, exc_value, traceback)
class A:
def __enter__(self): pass
def __exit__(self, exc_type, exc_val, exc_tb): pass
class B:
def __enter__(self): pass
def __exit__(self, exc_type, exc_val, exc_tb): pass
def init_context():
stack = ExitStack([A(), B()])
return stack
def fn():
with init_context():
return 1 + 2
import torch
fn = torch.compile(fn, fullgraph=True)
out = fn()
This fails with:
Traceback (most recent call last):
File "/home/carmocca/git/lightning/kk2.py", line 33, in <module>
out = fn()
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 333, in _fn
return fn(*args, **kwargs)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 493, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
return fn(*args, **kwargs)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
return _compile(
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 564, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 486, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 453, in transform
tracer.run()
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
super().run()
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1003, in SETUP_WITH
self.setup_or_before_with(inst)
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1755, in setup_or_before_with
unimplemented(f"{inst.opname} {ctx}")
File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 172, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: SETUP_WITH UserDefinedObjectVariable(ExitStack)
from user code:
File "/home/carmocca/git/lightning/kk2.py", line 28, in fn
with init_context():
The Dynamo test suite already covers similar context-manager classes (see https://github.com/yanboliang/pytorch/blob/main/test/dynamo/test_ctx_manager.py), so I would expect supporting this basic version to be feasible.
Versions
PyTorch 2.1.0.rc0 (Python 3.10)
cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @anijain2305 @msaroufim @wconstab @bdhirsh @Xia-Weiwen @aakhundov