diff --git a/test/dynamo/test_nested_graph_breaks.py b/test/dynamo/test_nested_graph_breaks.py
index 5f593d01defc9..9c8a31e080305 100644
--- a/test/dynamo/test_nested_graph_breaks.py
+++ b/test/dynamo/test_nested_graph_breaks.py
@@ -68,21 +68,6 @@ def make_nested_cls(cls):
     make_nested_cls(test)
 del test
 
-global_val = 0
-
-
-class CustomizedCtxManager:
-    def __init__(self, val):
-        self.val = val
-
-    def __enter__(self):
-        global global_val
-        global_val += self.val
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        global global_val
-        global_val -= self.val
-
 
 # for use in test_side_effects_globals
 global1, global2, global3, global4 = (torch.zeros(3),) * 4
@@ -222,40 +207,91 @@ def f3(x3):
         self.assertEqual(ref, res)
         self.assertEqual(cnts.frame_count, 2)
 
-    @unittest.expectedFailure
-    def test_ctx_manager(self):
-        global global_val
-        global_val = 0
+    def test_supported_ctx_manager(self):
+        global check, check_disabled, f1, f2, f3
 
         @torch._dynamo.disable
-        def f1():
-            return global_val
+        def check_disabled(value):
+            assert torch.is_grad_enabled() == value
 
-        def f2(x2):
-            with CustomizedCtxManager(8):
-                x2 = x2 + (1 << 4)
-                x2 = x2 + f1()  # 15
-                x2 = x2 + (1 << 5)
-            x2 = x2 << 2
-            x2 = x2 + global_val  # 3
-            with CustomizedCtxManager(4):
-                x2 = x2 << 4
-                x2 = x2 + f1()  # 7
-                x2 = x2 + (1 << 3)
-            return x2
+        def check(value):
+            assert torch.is_grad_enabled() == value
 
-        def f3(x3):
-            with CustomizedCtxManager(2):
-                return f2(x3)
+        def f1(x):
+            with torch.no_grad():
+                x = x + 1
+                check(False)
+                check_disabled(False)
+                check(False)
+                return x + 2
+
+        def f2(x):
+            with torch.enable_grad():
+                x = x + 4
+                check(True)
+                check_disabled(True)
+                check(True)
+                return f1(x) + 8
 
-        def f4(x4):
-            with CustomizedCtxManager(1):
-                return f3(x4)
+        def f3(x):
+            with torch.no_grad():
+                x = x + 16
+                check(False)
+                check_disabled(False)
+                check(False)
+                return f2(x) + 32
 
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(backend=cnts)(f4)
-        x = torch.zeros(3, dtype=torch.long)
-        res = f4(x)
+        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
+        x = torch.zeros(3)
+        res = f3(x)
+        ref = opt_fn(x)
+        self.assertEqual(ref, res)
+        self.assertEqual(cnts.frame_count, 4)
+
+    def test_inactive_ctx_manager(self):
+        global check, f1, f2, f3
+
+        def check(value):
+            assert torch.is_grad_enabled() == value
+
+        def f1(x, ctx1):
+            x = x + 1
+            ctx2 = torch.no_grad()
+            # torch.no_grad() is a stack value at the time of graph break
+            ctx3 = (torch.no_grad(), torch._dynamo.graph_break())[0]
+            x = x + 64
+            torch._dynamo.graph_break()
+            with ctx1:
+                check(False)
+            with ctx2:
+                check(False)
+            with ctx3:
+                check(False)
+            return x + 2
+
+        def f2(x, ctx1):
+            x = x + 4
+            ctx2 = torch.no_grad()
+            x = f1(x, torch.no_grad())
+            with ctx1:
+                check(False)
+            with ctx2:
+                check(False)
+            return x + 8
+
+        def f3(x):
+            x = x + 16
+            ctx = torch.no_grad()
+            x = f2(x, torch.no_grad())
+            with ctx:
+                check(False)
+            return x + 32
+
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
+        x = torch.zeros(3)
+        res = f3(x)
         ref = opt_fn(x)
         self.assertEqual(ref, res)
         self.assertEqual(cnts.frame_count, 3)
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 4a2127c92dfed..d3648ebe819d7 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -3345,7 +3345,7 @@ def setup_or_before_with(self, inst):
         self.push(exit)
 
         if target:
-            if isinstance(self, InstructionTranslator):
+            if isinstance(self, InstructionTranslator) or config.nested_graph_breaks:
                 self.block_stack.append(
                     BlockStackEntry(inst, target, len(self.stack), ctx)
                 )
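
A minimal sketch of the path this patch gates, assuming a torch build that carries it: the flag torch._dynamo.config.nested_graph_breaks and torch._dynamo.graph_break() are real APIs, but this example is illustrative and not taken from the patch. With the flag on, a with block entered in an inlined (nested) frame also pushes a block stack entry, so a graph break that fires while the context manager is active can be resumed instead of aborting compilation of the nested frame.

    import torch

    # Assumption: a build containing this patch; the flag enables the new
    # branch in setup_or_before_with above for inlined frames.
    torch._dynamo.config.nested_graph_breaks = True

    def inner(x):
        with torch.no_grad():  # inlined frame: block stack entry is now pushed
            torch._dynamo.graph_break()  # break while the ctx manager is active
            return x + 1

    @torch.compile(backend="eager")
    def outer(x):
        return inner(x) + 2

    print(outer(torch.zeros(3)))  # tensor([3., 3., 3.]); grad mode is
                                  # restored correctly across the break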