[dynamo, nested graph breaks] support nested graph breaks x context managers #159678

Open
wants to merge 9 commits into
base: gh/williamwen42/270/base
120 changes: 78 additions & 42 deletions test/dynamo/test_nested_graph_breaks.py
@@ -68,21 +68,6 @@ def make_nested_cls(cls):
 make_nested_cls(test)
 del test
 
-global_val = 0
-
-
-class CustomizedCtxManager:
-    def __init__(self, val):
-        self.val = val
-
-    def __enter__(self):
-        global global_val
-        global_val += self.val
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        global global_val
-        global_val -= self.val
-
 
 # for use in test_side_effects_globals
 global1, global2, global3, global4 = (torch.zeros(3),) * 4
@@ -222,40 +207,91 @@ def f3(x3):
         self.assertEqual(ref, res)
         self.assertEqual(cnts.frame_count, 2)
 
-    @unittest.expectedFailure
-    def test_ctx_manager(self):
-        global global_val
-        global_val = 0
+    def test_supported_ctx_manager(self):
+        global check, check_disabled, f1, f2, f3
 
         @torch._dynamo.disable
-        def f1():
-            return global_val
+        def check_disabled(value):
+            assert torch.is_grad_enabled() == value
 
-        def f2(x2):
-            with CustomizedCtxManager(8):
-                x2 = x2 + (1 << 4)
-                x2 = x2 + f1()  # 15
-                x2 = x2 + (1 << 5)
-            x2 = x2 << 2
-            x2 = x2 + global_val  # 3
-            with CustomizedCtxManager(4):
-                x2 = x2 << 4
-                x2 = x2 + f1()  # 7
-                x2 = x2 + (1 << 3)
-            return x2
+        def check(value):
+            assert torch.is_grad_enabled() == value
 
-        def f3(x3):
-            with CustomizedCtxManager(2):
-                return f2(x3)
+        def f1(x):
+            with torch.no_grad():
+                x = x + 1
+                check(False)
+                check_disabled(False)
+                check(False)
+                return x + 2
+
+        def f2(x):
+            with torch.enable_grad():
+                x = x + 4
+                check(True)
+                check_disabled(True)
+                check(True)
+                return f1(x) + 8
 
-        def f4(x4):
-            with CustomizedCtxManager(1):
-                return f3(x4)
+        def f3(x):
+            with torch.no_grad():
+                x = x + 16
+                check(False)
+                check_disabled(False)
+                check(False)
+                return f2(x) + 32
 
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(backend=cnts)(f4)
-        x = torch.zeros(3, dtype=torch.long)
-        res = f4(x)
+        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
+        x = torch.zeros(3)
+        res = f3(x)
         ref = opt_fn(x)
         self.assertEqual(ref, res)
         self.assertEqual(cnts.frame_count, 4)
+
+    def test_inactive_ctx_manager(self):
+        global check, f1, f2, f3
+
+        def check(value):
+            assert torch.is_grad_enabled() == value
+
+        def f1(x, ctx1):
+            x = x + 1
+            ctx2 = torch.no_grad()
+            # torch.no_grad() is a stack value at the time of graph break
+            ctx3 = (torch.no_grad(), torch._dynamo.graph_break())[0]
+            x = x + 64
+            torch._dynamo.graph_break()
+            with ctx1:
+                check(False)
+            with ctx2:
+                check(False)
+            with ctx3:
+                check(False)
+            return x + 2
+
+        def f2(x, ctx1):
+            x = x + 4
+            ctx2 = torch.no_grad()
+            x = f1(x, torch.no_grad())
+            with ctx1:
+                check(False)
+            with ctx2:
+                check(False)
+            return x + 8
+
+        def f3(x):
+            x = x + 16
+            ctx = torch.no_grad()
+            x = f2(x, torch.no_grad())
+            with ctx:
+                check(False)
+            return x + 32
+
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
+        x = torch.zeros(3)
+        res = f3(x)
+        ref = opt_fn(x)
+        self.assertEqual(ref, res)
+        self.assertEqual(cnts.frame_count, 3)
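
As background for the new tests, here is a minimal standalone sketch (not part of this PR) of the situation test_inactive_ctx_manager exercises: a grad-mode context manager that is only a local or an argument at the moment a graph break fires inside a nested call, and is entered afterwards. torch.compile, the "eager" backend, and torch._dynamo.graph_break() are existing public APIs; the names check, inner, and outer are invented for illustration, and the compiled result should match eager execution whether or not nested graph break support is available.

import torch
import torch._dynamo


def check(value):
    assert torch.is_grad_enabled() == value


def inner(x, ctx):
    x = x + 1
    torch._dynamo.graph_break()  # break while ctx is still an un-entered no_grad() object
    with ctx:
        check(False)  # grad mode must be off once ctx is finally entered
    return x + 2


def outer(x):
    # the context manager is created here and handed down to the nested frame
    return inner(x, torch.no_grad()) + 4


compiled = torch.compile(outer, backend="eager")
print(compiled(torch.zeros(3)))  # should equal outer(torch.zeros(3))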
2 changes: 1 addition & 1 deletion torch/_dynamo/symbolic_convert.py
@@ -3345,7 +3345,7 @@ def setup_or_before_with(self, inst):
         self.push(exit)
 
         if target:
-            if isinstance(self, InstructionTranslator):
+            if isinstance(self, InstructionTranslator) or config.nested_graph_breaks:
                 self.block_stack.append(
                     BlockStackEntry(inst, target, len(self.stack), ctx)
                 )
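
A rough reading of this one-line change: setup_or_before_with appears to take the full bookkeeping branch (recording the with-block target and context manager in a BlockStackEntry) only for the root InstructionTranslator, so with config.nested_graph_breaks enabled, inlined frames now record it too, which is what a graph break inside an active with-block in a nested call needs in order to be split and resumed. Below is a hedged sketch of exercising that path, assuming nested_graph_breaks is the torch._dynamo.config flag this guard reads on the PR stack and that torch._dynamo.config.patch accepts it; the inner/outer helpers are invented for illustration.

import torch
import torch._dynamo


def inner(x):
    with torch.no_grad():            # with-block active while this frame is being inlined
        torch._dynamo.graph_break()  # break fires inside the active block
        return x + 1


def outer(x):
    return inner(x) * 2


# assumption: nested_graph_breaks is the config flag checked by the new guard
with torch._dynamo.config.patch(nested_graph_breaks=True):
    compiled = torch.compile(outer, backend="eager")
    out = compiled(torch.ones(3))

print(out)  # should match outer(torch.ones(3)) run eagerly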