
Support the ExitStack context manager (or a simplified version) #109309

@carmocca

Description

🐛 Describe the bug

I'm trying to compile a piece of code with fullgraph=True that uses multiple context managers.

I understand that #98725 is meant to support this kind of behaviour (cc @yanboliang, its author).

Lightning-AI/pytorch-lightning#18557 includes real code that uses this pattern.

Error logs

No response

Minified repro

from contextlib import ExitStack

class A:
    def __enter__(self): pass
    def __exit__(self, exc_type, exc_val, exc_tb): pass

class B:
    def __enter__(self): pass
    def __exit__(self, exc_type, exc_val, exc_tb): pass


def init_context():
    stack = ExitStack()
    stack.enter_context(A())
    stack.enter_context(B())
    return stack


def fn():
    with init_context():
        return 1 + 2

import torch
fn = torch.compile(fn, fullgraph=True)
out = fn()
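For reference, here is a sketch of what the pattern above does in eager Python, without torch.compile. It swaps the repro's A and B classes for a hypothetical `contextlib.contextmanager` helper, `recording`, so the enter/exit order is visible. It also compares the order against a single multi-item `with` statement. The helpers `with_exit_stack` and `with_nested_with` are illustrative names. The nested form appears only for the ordering comparison; this sketch does not check whether it compiles with fullgraph=True on 2.1.0.rc0.

```python
from contextlib import ExitStack, contextmanager


@contextmanager
def recording(name, log):
    # Hypothetical stand-in for the repro's A and B that records enter/exit.
    log.append(f"enter {name}")
    try:
        yield
    finally:
        log.append(f"exit {name}")


def with_exit_stack(log):
    # Same shape as the repro: enter managers on the stack, then `with` the stack.
    stack = ExitStack()
    stack.enter_context(recording("A", log))
    stack.enter_context(recording("B", log))
    with stack:
        log.append("body")


def with_nested_with(log):
    # The same fixed pair of managers written as one multi-item `with` statement.
    with recording("A", log), recording("B", log):
        log.append("body")


stack_log, nested_log = [], []
with_exit_stack(stack_log)
with_nested_with(nested_log)
print(stack_log == nested_log)  # True
print(stack_log)  # ['enter A', 'enter B', 'body', 'exit B', 'exit A']
```

Both forms enter A then B and exit in reverse (LIFO) order. That ordering is what a compiled ExitStack would need to preserve.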

However, this fails with:

Traceback (most recent call last):
  File "/home/carmocca/git/lightning/kk2.py", line 25, in <module>
    out = fn()
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 333, in _fn
    return fn(*args, **kwargs)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 493, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 564, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 486, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 453, in transform
    tracer.run()
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
    tracer.run()
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 594, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 329, in call_method
    raise unimplemented(f"call_method {self} {name} {args} {kwargs}")
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 172, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_method GenericContextWrappingVariable() enter_context [GenericContextWrappingVariable()] {}

from user code:
   File "/home/carmocca/git/lightning/kk2.py", line 20, in fn
    with init_context():
  File "/home/carmocca/git/lightning/kk2.py", line 14, in init_context
    stack.enter_context(A())

I also tried my own simplified implementation:

class ExitStack:
    def __init__(self, context_managers) -> None:
        self._context_managers = context_managers

    def __enter__(self):
        for ctx_manager in self._context_managers:
            ctx_manager.__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        for ctx_manager in reversed(self._context_managers):
            ctx_manager.__exit__(exc_type, exc_value, traceback)

class A:
    def __enter__(self): pass
    def __exit__(self, exc_type, exc_val, exc_tb): pass

class B:
    def __enter__(self): pass
    def __exit__(self, exc_type, exc_val, exc_tb): pass


def init_context():
    stack = ExitStack([A(), B()])
    return stack


def fn():
    with init_context():
        return 1 + 2

import torch
fn = torch.compile(fn, fullgraph=True)
out = fn()
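As a side note, the simplified ExitStack above deliberately skips some of `contextlib.ExitStack`'s behaviour. Its `__enter__` returns None rather than the stack. If B's `__enter__` raises, A is never exited. A truthy return value from an `__exit__` (exception suppression) is ignored. The following is a sketch of a slightly more faithful plain-Python version, assuming those semantics matter for the real use case. `SimpleExitStack` and `Recorder` are hypothetical names. The sketch has not been tested under torch.compile. It also still skips remaining managers if an `__exit__` itself raises.

```python
class SimpleExitStack:
    """Hypothetical reduced ExitStack: enters a fixed list of context managers
    in order and exits them in reverse order."""

    def __init__(self, context_managers):
        self._context_managers = list(context_managers)
        self._entered = []

    def __enter__(self):
        try:
            for cm in self._context_managers:
                cm.__enter__()
                self._entered.append(cm)
        except BaseException as exc:
            # A later __enter__ failed: exit the managers already entered, then re-raise.
            self._unwind(type(exc), exc, exc.__traceback__)
            raise
        return self  # like contextlib.ExitStack, so `with ... as stack` works

    def __exit__(self, exc_type, exc_value, traceback):
        return self._unwind(exc_type, exc_value, traceback)

    def _unwind(self, exc_type, exc_value, traceback):
        suppressed = False
        while self._entered:
            cm = self._entered.pop()  # the last manager entered is exited first
            if cm.__exit__(exc_type, exc_value, traceback):
                # This manager swallowed the exception; outer managers see a clean exit.
                suppressed = True
                exc_type = exc_value = traceback = None
        return suppressed


class Recorder:
    """Hypothetical test manager that logs calls and can fail or suppress."""

    def __init__(self, name, log, fail_on_enter=False, suppress=False):
        self.name, self.log = name, log
        self.fail_on_enter, self.suppress = fail_on_enter, suppress

    def __enter__(self):
        if self.fail_on_enter:
            raise RuntimeError(f"{self.name} failed to enter")
        self.log.append(f"enter {self.name}")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.log.append(f"exit {self.name}")
        return self.suppress


log = []
with SimpleExitStack([Recorder("A", log), Recorder("B", log)]):
    log.append("body")
print(log)  # ['enter A', 'enter B', 'body', 'exit B', 'exit A']
```

Keeping the list of entered managers separate from the full list means a failure partway through `__enter__` unwinds only what was actually entered.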

This fails with:

Traceback (most recent call last):
  File "/home/carmocca/git/lightning/kk2.py", line 33, in <module>
    out = fn()
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 333, in _fn
    return fn(*args, **kwargs)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 493, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 564, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 486, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 453, in transform
    tracer.run()
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1003, in SETUP_WITH
    self.setup_or_before_with(inst)
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1755, in setup_or_before_with
    unimplemented(f"{inst.opname} {ctx}")
  File "/home/carmocca/git/venv/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 172, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: SETUP_WITH UserDefinedObjectVariable(ExitStack)

from user code:
   File "/home/carmocca/git/lightning/kk2.py", line 28, in fn
    with init_context():

The test suite already includes tests for similar classes (https://github.com/yanboliang/pytorch/blob/main/test/dynamo/test_ctx_manager.py), so I would expect supporting this basic version to be feasible.

Versions

2.1.0.rc0

cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @anijain2305 @msaroufim @wconstab @bdhirsh @Xia-Weiwen @aakhundov
